onnx_embedding/
lib.rs

//! Proc macro that downloads a prebuilt ONNX Runtime shared library at compile time and embeds its bytes into the calling crate.
2extern crate proc_macro;
3
4use proc_macro::TokenStream;
5use proc_macro2::Literal;
6use syn::{parse_macro_input, LitStr};
7use quote::quote;
8use std::{
9    env::consts,
10    fs::{self, File},
11    io::{self, Write},
12    path::Path,
13};
14use flate2::read::GzDecoder;
15use tar::Archive;
16
17/// Extracts a `.tgz` file to the specified directory.
18/// 
19/// # Arguments
20/// - tgz_path: path to .tgz to be extracted from
21/// - extract_to: path that the extracted .tgx is unloaded to
22fn extract_tgz<P: AsRef<Path>>(tgz_path: P, extract_to: P) -> io::Result<()> {
23    let file = File::open(&tgz_path)?;
24    let decompressor = GzDecoder::new(file);
25    let mut archive = Archive::new(decompressor);
26    archive.unpack(&extract_to)?;
27    Ok(())
28}
29
30
/// Builds the ONNX Runtime release download details for the current platform.
///
/// `std::env::consts` describes the platform this crate was compiled for. For a
/// proc-macro crate that is the host running the compiler, which can differ from the
/// compilation target when cross-compiling.
///
/// # Arguments
/// - `onnx_version`: ONNX Runtime release version, e.g. `"1.20.0"`
///
/// # Returns
/// `(url, package_name, ext, dylib_name)`:
/// - `url`: GitHub release asset URL
/// - `package_name`: archive base name; `embed_onnx` also expects it to be the
///   archive's top-level directory
/// - `ext`: archive extension (`"tgz"` or `"zip"`)
/// - `dylib_name`: shared library file name inside the package's `lib` directory
///
/// # Panics
/// Panics if the OS/architecture pair has no prebuilt release listed here.
fn get_onnxruntime_url(onnx_version: &str) -> (String, String, String, String) {
    // Release assets are named `onnxruntime-<platform>-<version>.<ext>`.
    let (platform, ext, dylib_name) = match (consts::OS, consts::ARCH) {
        ("linux", "x86_64") => ("linux-x64", "tgz", "libonnxruntime.so"),
        ("linux", "aarch64") => ("linux-aarch64", "tgz", "libonnxruntime.so"),
        ("macos", "x86_64") => ("osx-x86_64", "tgz", "libonnxruntime.dylib"),
        ("macos", "aarch64") => ("osx-arm64", "tgz", "libonnxruntime.dylib"),
        ("windows", "x86_64") => ("win-x64", "zip", "onnxruntime.dll"),
        ("windows", "aarch64") => ("win-arm64", "zip", "onnxruntime.dll"),
        (os, arch) => panic!("Unsupported platform or architecture: {} {}", os, arch),
    };

    let package_name = format!("onnxruntime-{}-{}", platform, onnx_version);
    let url = format!(
        "https://github.com/microsoft/onnxruntime/releases/download/v{}/{}.{}",
        onnx_version, package_name, ext
    );

    (url, package_name, ext.to_owned(), dylib_name.to_owned())
}
88
89#[proc_macro]
90pub fn embed_onnx(attr: TokenStream) -> TokenStream {
91
92    // get the onnx version
93    let input = parse_macro_input!(attr as LitStr);
94    let supported_versions = ["1.20.0"];
95    let onnx_version = match input.value().as_str() {
96        "1.20.0" => "1.20.0",
97        _ => panic!(
98            "{} passed in as version, only the following versions are supported: {:?}", 
99            input.value(), supported_versions
100        )
101    };
102
103    let (url, package_name, ext, dylib_name) = get_onnxruntime_url(onnx_version);
104
105    // Create a temporary directory
106    let temp_dir = tempfile::Builder::new()
107        .prefix("onnxruntime_embed_")
108        .tempdir()
109        .expect("Failed to create temporary directory");
110
111    let temp_path = temp_dir.path().to_path_buf();
112    let filename = format!("{}.{}", package_name, ext);
113    let download_path = temp_path.join(&filename);
114    let extract_target = temp_path.join(&package_name);
115    let tgz_path_str = download_path.to_str().expect("cannot convert download path to string").to_string();
116    let dylib_path = extract_target.join("lib").join(dylib_name);
117
118    if !download_path.exists() {
119        println!("Downloading ONNX Runtime from {}", url);
120        let response = reqwest::blocking::get(&url)
121            .expect("Failed to download ONNX Runtime")
122            .bytes()
123            .expect("Failed to read ONNX Runtime response");
124
125        let mut file = File::create(&download_path).expect("Failed to create ONNX file");
126        file.write_all(&response).expect("Failed to write ONNX file");
127        println!("Saved to {}", download_path.display());
128    }
129
130    if !dylib_path.exists() {
131        extract_tgz(&tgz_path_str, &temp_path.to_str().expect("cannot convert temp path to string").to_owned()).expect("Failed to extract ONNX archive");
132    }
133
134    let bytes: Vec<u8> = fs::read(&dylib_path).expect("Failed to read extracted library");
135
136    // Explicitly clean up the temporary directory
137    temp_dir.close().expect("Failed to remove temporary directory");
138
139    let byte_string = Literal::byte_string(&bytes);
140
141    let tokens = quote! {
142        #byte_string
143    };
144
145    TokenStream::from(tokens)
146}