onnx_embedding/
lib.rs

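//! Procedural macro (`embed_onnx!`) that downloads the ONNX Runtime shared library at
//! build time and embeds its bytes into the calling crate as a byte-string literal.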
2extern crate proc_macro;
3
4use proc_macro::TokenStream;
5use proc_macro2::Literal;
6use syn::{parse_macro_input, LitStr};
7use quote::quote;
8use std::{
9    env,
10    env::consts,
11    fs::{self, File},
12    io::{self, Write},
13    path::{Path, PathBuf},
14};
15use flate2::read::GzDecoder;
16use tar::Archive;
17
18/// Extracts a `.tgz` file to the specified directory.
19/// 
20/// # Arguments
21/// - tgz_path: path to .tgz to be extracted from
22/// - extract_to: path that the extracted .tgx is unloaded to
23fn extract_tgz<P: AsRef<Path>>(tgz_path: P, extract_to: P) -> io::Result<()> {
24    let file = File::open(&tgz_path)?;
25    let decompressor = GzDecoder::new(file);
26    let mut archive = Archive::new(decompressor);
27    archive.unpack(&extract_to)?;
28    Ok(())
29}
30
31
/// Resolves the ONNX Runtime release asset that matches the current platform.
///
/// The platform is taken from `std::env::consts::{OS, ARCH}`.
/// NOTE(review): for a proc-macro crate that is the build host, which may differ
/// from the consumer's target when cross-compiling.
///
/// # Arguments
/// - onnx_version: ONNX Runtime release version, e.g. `"1.20.0"`
///
/// # Returns
/// `(url, package_name, ext, dylib_name)`:
/// - `url`: download URL of the release archive on GitHub
/// - `package_name`: archive name without extension, `onnxruntime-<platform>-<version>`
/// - `ext`: archive extension (`"tgz"` or `"zip"`)
/// - `dylib_name`: file name of the shared library inside the archive's `lib/` directory
///
/// # Panics
/// Panics when the OS/architecture pair has no published ONNX Runtime build.
fn get_onnxruntime_url(onnx_version: &str) -> (String, String, String, String) {
    let (platform, archive_ext, library) = match (consts::OS, consts::ARCH) {
        ("linux", "x86_64") => ("linux-x64", "tgz", "libonnxruntime.so"),
        ("linux", "aarch64") => ("linux-aarch64", "tgz", "libonnxruntime.so"),
        ("macos", "x86_64") => ("osx-x86_64", "tgz", "libonnxruntime.dylib"),
        ("macos", "aarch64") => ("osx-arm64", "tgz", "libonnxruntime.dylib"),
        ("windows", "x86_64") => ("win-x64", "zip", "onnxruntime.dll"),
        ("windows", "aarch64") => ("win-arm64", "zip", "onnxruntime.dll"),
        (os, arch) => panic!("Unsupported platform or architecture: {} {}", os, arch),
    };

    let package = format!("onnxruntime-{}-{}", platform, onnx_version);
    let url = format!(
        "https://github.com/microsoft/onnxruntime/releases/download/v{}/{}.{}",
        onnx_version, package, archive_ext
    );

    (url, package, archive_ext.to_owned(), library.to_owned())
}
89
90#[proc_macro]
91pub fn embed_onnx(attr: TokenStream) -> TokenStream {
92
93    // get the onnx version
94    let input = parse_macro_input!(attr as LitStr);
95    let supported_versions = ["1.20.0"];
96    let onnx_version = match input.value().as_str() {
97        "1.20.0" => "1.20.0",
98        _ => panic!(
99            "{} passed in as version, only the following versions are supported: {:?}", 
100            input.value(), supported_versions
101        )
102    };
103
104    let (url, package_name, ext, dylib_name) = get_onnxruntime_url(onnx_version);
105
106    // Persistent cache under target directory
107    let target_root = std::env::var("CARGO_TARGET_DIR")
108    .map(PathBuf::from)
109    .unwrap_or_else(|_| {
110        Path::new(&env!("CARGO_MANIFEST_DIR"))
111            .parent()
112            .expect("crate not in a workspace?")
113            .join("target")
114    });
115
116    let cache = target_root.join("onnxruntime_cache").join(onnx_version);
117    fs::create_dir_all(&cache).expect("failed to create cache dir");
118
119    let filename = format!("{}.{}", package_name, ext);
120    let download_path = cache.join(&filename);
121    let extract_target = cache.join(&package_name);
122    let dylib_path = extract_target.join("lib").join(&dylib_name);
123
124    if !download_path.exists() {
125        println!("Downloading ONNX Runtime from {}", url);
126        let response = reqwest::blocking::get(&url)
127            .expect("Failed to download ONNX Runtime")
128            .bytes()
129            .expect("Failed to read ONNX Runtime response");
130
131        let mut file = File::create(&download_path).expect("Failed to create ONNX file");
132        file.write_all(&response).expect("Failed to write ONNX file");
133        println!("Saved to {}", download_path.display());
134    }
135
136    if !dylib_path.exists() {
137        extract_tgz(&download_path, &cache)
138            .expect("Failed to extract ONNX archive");
139    }
140
141    let bytes: Vec<u8> = fs::read(&dylib_path).expect("Failed to read extracted library");
142
143    let byte_string = Literal::byte_string(&bytes);
144
145    let tokens = quote! {
146        #byte_string
147    };
148
149    TokenStream::from(tokens)
150}
151
152
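// Earlier implementation, kept for reference only (not compiled). It unpacked into
// a fresh temporary directory on every expansion and deleted it afterwards, so the
// archive was re-downloaded each time. The active `embed_onnx` caches under the
// target directory instead.
// #[proc_macro]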
154// pub fn embed_onnx(attr: TokenStream) -> TokenStream {
155
156//     // get the onnx version
157//     let input = parse_macro_input!(attr as LitStr);
158//     let supported_versions = ["1.20.0"];
159//     let onnx_version = match input.value().as_str() {
160//         "1.20.0" => "1.20.0",
161//         _ => panic!(
162//             "{} passed in as version, only the following versions are supported: {:?}", 
163//             input.value(), supported_versions
164//         )
165//     };
166
167//     let (url, package_name, ext, dylib_name) = get_onnxruntime_url(onnx_version);
168
169//     // Create a temporary directory
170//     let temp_dir = tempfile::Builder::new()
171//         .prefix("onnxruntime_embed_")
172//         .tempdir()
173//         .expect("Failed to create temporary directory");
174
175//     let temp_path = temp_dir.path().to_path_buf();
176//     let filename = format!("{}.{}", package_name, ext);
177//     let download_path = temp_path.join(&filename);
178//     let extract_target = temp_path.join(&package_name);
179//     let tgz_path_str = download_path.to_str().expect("cannot convert download path to string").to_string();
180//     let dylib_path = extract_target.join("lib").join(dylib_name);
181
182//     if !download_path.exists() {
183//         println!("Downloading ONNX Runtime from {}", url);
184//         let response = reqwest::blocking::get(&url)
185//             .expect("Failed to download ONNX Runtime")
186//             .bytes()
187//             .expect("Failed to read ONNX Runtime response");
188
189//         let mut file = File::create(&download_path).expect("Failed to create ONNX file");
190//         file.write_all(&response).expect("Failed to write ONNX file");
191//         println!("Saved to {}", download_path.display());
192//     }
193
194//     if !dylib_path.exists() {
195//         extract_tgz(&tgz_path_str, &temp_path.to_str().expect("cannot convert temp path to string").to_owned()).expect("Failed to extract ONNX archive");
196//     }
197
198//     let bytes: Vec<u8> = fs::read(&dylib_path).expect("Failed to read extracted library");
199
200//     // Explicitly clean up the temporary directory
201//     temp_dir.close().expect("Failed to remove temporary directory");
202
203//     let byte_string = Literal::byte_string(&bytes);
204
205//     let tokens = quote! {
206//         #byte_string
207//     };
208
209//     TokenStream::from(tokens)
210// }