1extern crate proc_macro;
3
4use proc_macro::TokenStream;
5use proc_macro2::Literal;
6use syn::{parse_macro_input, LitStr};
7use quote::quote;
8use std::{
9 env::consts,
10 fs::{self, File},
11 io::{self, Write},
12 path::Path,
13};
14use flate2::read::GzDecoder;
15use tar::Archive;
16
17fn extract_tgz<P: AsRef<Path>>(tgz_path: P, extract_to: P) -> io::Result<()> {
23 let file = File::open(&tgz_path)?;
24 let decompressor = GzDecoder::new(file);
25 let mut archive = Archive::new(decompressor);
26 archive.unpack(&extract_to)?;
27 Ok(())
28}
29
30
/// Builds the download details of the ONNX Runtime release for this platform.
///
/// Returns `(url, package_name, extension, library_name)` where:
///
/// * `url` is the GitHub release asset URL for `onnx_version`,
/// * `package_name` is the asset name without its extension
///   (`onnxruntime-<platform>-<version>`),
/// * `extension` is the archive extension (`"tgz"` or `"zip"`),
/// * `library_name` is the file name of the shared library for the platform.
///
/// The platform is taken from `std::env::consts::{OS, ARCH}`.
/// NOTE(review): inside a proc-macro crate these constants presumably describe
/// the host the macro runs on rather than the final compilation target —
/// confirm before relying on this when cross-compiling.
///
/// # Panics
///
/// Panics if the OS/architecture pair is not one of the six combinations
/// listed in the `match` below.
fn get_onnxruntime_url(onnx_version: &str) -> (String, String, String, String) {
    // Per-platform pieces: asset platform slug, archive extension, library file name.
    let (platform, ext, dylib) = match (consts::OS, consts::ARCH) {
        ("linux", "x86_64") => ("linux-x64", "tgz", "libonnxruntime.so"),
        ("linux", "aarch64") => ("linux-aarch64", "tgz", "libonnxruntime.so"),
        ("macos", "x86_64") => ("osx-x86_64", "tgz", "libonnxruntime.dylib"),
        ("macos", "aarch64") => ("osx-arm64", "tgz", "libonnxruntime.dylib"),
        ("windows", "x86_64") => ("win-x64", "zip", "onnxruntime.dll"),
        ("windows", "aarch64") => ("win-arm64", "zip", "onnxruntime.dll"),
        (os, arch) => panic!("Unsupported platform or architecture: {} {}", os, arch),
    };

    let package_name = format!("onnxruntime-{}-{}", platform, onnx_version);
    let url = format!(
        "https://github.com/microsoft/onnxruntime/releases/download/v{}/{}.{}",
        onnx_version, package_name, ext
    );

    (url, package_name, ext.to_string(), dylib.to_string())
}
88
89#[proc_macro]
90pub fn embed_onnx(attr: TokenStream) -> TokenStream {
91
92 let input = parse_macro_input!(attr as LitStr);
94 let supported_versions = ["1.20.0"];
95 let onnx_version = match input.value().as_str() {
96 "1.20.0" => "1.20.0",
97 _ => panic!(
98 "{} passed in as version, only the following versions are supported: {:?}",
99 input.value(), supported_versions
100 )
101 };
102
103 let (url, package_name, ext, dylib_name) = get_onnxruntime_url(onnx_version);
104
105 let temp_dir = tempfile::Builder::new()
107 .prefix("onnxruntime_embed_")
108 .tempdir()
109 .expect("Failed to create temporary directory");
110
111 let temp_path = temp_dir.path().to_path_buf();
112 let filename = format!("{}.{}", package_name, ext);
113 let download_path = temp_path.join(&filename);
114 let extract_target = temp_path.join(&package_name);
115 let tgz_path_str = download_path.to_str().expect("cannot convert download path to string").to_string();
116 let dylib_path = extract_target.join("lib").join(dylib_name);
117
118 if !download_path.exists() {
119 println!("Downloading ONNX Runtime from {}", url);
120 let response = reqwest::blocking::get(&url)
121 .expect("Failed to download ONNX Runtime")
122 .bytes()
123 .expect("Failed to read ONNX Runtime response");
124
125 let mut file = File::create(&download_path).expect("Failed to create ONNX file");
126 file.write_all(&response).expect("Failed to write ONNX file");
127 println!("Saved to {}", download_path.display());
128 }
129
130 if !dylib_path.exists() {
131 extract_tgz(&tgz_path_str, &temp_path.to_str().expect("cannot convert temp path to string").to_owned()).expect("Failed to extract ONNX archive");
132 }
133
134 let bytes: Vec<u8> = fs::read(&dylib_path).expect("Failed to read extracted library");
135
136 temp_dir.close().expect("Failed to remove temporary directory");
138
139 let byte_string = Literal::byte_string(&bytes);
140
141 let tokens = quote! {
142 #byte_string
143 };
144
145 TokenStream::from(tokens)
146}