onnx_embedding/
lib.rs

//! Downloads the ONNX Runtime library for the host platform and embeds it into Rust code.
2extern crate proc_macro;
3
4use proc_macro::TokenStream;
5use proc_macro2::Literal;
6use syn::{parse_macro_input, LitStr};
7use quote::quote;
8use std::{
9    env,
10    env::consts,
11    fs::{self, File},
12    io::{Write, Cursor},
13    path::{Path, PathBuf},
14};
15mod file_extraction;
16use file_extraction::{
17    FileType,
18    extract_tgz,
19    extract_zip,
20    zip_dir,
21    DylibName
22};
23
24
25/// Constructs the URL and details for downloading ONNX Runtime based on platform.
26/// 
27/// # Arguments
28/// - onnx_version: the version of ONNX to download
29/// 
30/// # Returns
31/// (url, package_name, ext, dylib_name)
32fn get_onnxruntime_url(onnx_version: &str) -> (String, String, String, DylibName, FileType) {
33    let base_url = format!(
34        "https://github.com/microsoft/onnxruntime/releases/download/v{}/",
35        onnx_version
36    );
37
38    match (consts::OS, consts::ARCH) {
39        ("linux", "x86_64") => (
40            format!("{}onnxruntime-linux-x64-{}.tgz", base_url, onnx_version),
41            format!("onnxruntime-linux-x64-{}", onnx_version),
42            "tgz".to_string(),
43            DylibName::So,
44            FileType::Tgz
45        ),
46        ("linux", "aarch64") => (
47            format!("{}onnxruntime-linux-aarch64-{}.tgz", base_url, onnx_version),
48            format!("onnxruntime-linux-aarch64-{}", onnx_version),
49            "tgz".to_string(),
50            DylibName::So,
51            FileType::Tgz
52        ),
53        ("macos", "x86_64") => (
54            format!("{}onnxruntime-osx-x86_64-{}.tgz", base_url, onnx_version),
55            format!("onnxruntime-osx-x86_64-{}", onnx_version),
56            "tgz".to_string(),
57            DylibName::Dylib,
58            FileType::Tgz
59        ),
60        ("macos", "aarch64") => (
61            format!("{}onnxruntime-osx-arm64-{}.tgz", base_url, onnx_version),
62            format!("onnxruntime-osx-arm64-{}", onnx_version),
63            "tgz".to_string(),
64            DylibName::Dylib,
65            FileType::Tgz
66        ),
67        ("windows", "x86_64") => (
68            format!("{}onnxruntime-win-x64-{}.zip", base_url, onnx_version),
69            format!("onnxruntime-win-x64-{}", onnx_version),
70            "zip".to_string(),
71            DylibName::Dll,
72            FileType::Zip
73        ),
74        ("windows", "aarch64") => (
75            format!("{}onnxruntime-win-arm64-{}.zip", base_url, onnx_version),
76            format!("onnxruntime-win-arm64-{}", onnx_version),
77            "zip".to_string(),
78            DylibName::Dll,
79            FileType::Zip
80        ),
81        _ => panic!(
82            "Unsupported platform or architecture: {} {}",
83            consts::OS,
84            consts::ARCH
85        ),
86    }
87}
88
89#[proc_macro]
90pub fn embed_onnx(attr: TokenStream) -> TokenStream {
91
92    // get the onnx version
93    let input = parse_macro_input!(attr as LitStr);
94    let supported_versions = ["1.20.0"];
95    let onnx_version = match input.value().as_str() {
96        "1.20.0" => "1.20.0",
97        _ => panic!(
98            "{} passed in as version, only the following versions are supported: {:?}", 
99            input.value(), supported_versions
100        )
101    };
102
103    let (url, package_name, ext, dylib_name, file_type) = get_onnxruntime_url(
104        onnx_version
105    );
106
107    // Persistent cache under target directory
108    let target_root = std::env::var("CARGO_TARGET_DIR")
109    .map(PathBuf::from)
110    .unwrap_or_else(|_| {
111        Path::new(&env!("CARGO_MANIFEST_DIR"))
112            .parent()
113            .expect("crate not in a workspace?")
114            .join("target")
115    });
116
117    let cache = target_root.join("onnxruntime_cache").join(onnx_version);
118    fs::create_dir_all(&cache).expect("failed to create cache dir");
119
120    let filename = format!("{}.{}", package_name, ext);
121    let download_path = cache.join(&filename);
122    let extract_target = cache.join(&package_name);
123    let lib_path = extract_target.join("lib");
124    let dylib_name_str: &str = dylib_name.clone().into();
125    let dylib_path = lib_path.join(dylib_name_str);
126
127    // obtain the lock for multiple downloads at the same time
128    let lock_path = cache.join("onnx_download.lock");
129    let mut lock = fslock::LockFile::open(&lock_path).expect("Failed to open lock file");
130    lock.lock().expect("Failed to acquire download lock");
131
132    if !download_path.exists() {
133        println!("Downloading ONNX Runtime from {}", url);
134        let response = reqwest::blocking::get(&url)
135            .expect("Failed to download ONNX Runtime")
136            .bytes()
137            .expect("Failed to read ONNX Runtime response");
138
139        let mut file = File::create(&download_path).expect("Failed to create ONNX file");
140        file.write_all(&response).expect("Failed to write ONNX file");
141        println!("Saved to {}", download_path.display());
142    }
143
144    if !dylib_path.exists() {
145        match file_type {
146            FileType::Tgz => extract_tgz(&download_path, &cache).expect("Failed to extract ONNX archive for tgz"),
147            FileType::Zip => extract_zip(&download_path, &cache).expect("Failed to extract ONNX archive for zip")
148        };
149    }
150
151    // zip the contents of the lib dir into bytes
152    let mut buffer = Cursor::new(Vec::new());
153    zip_dir(&lib_path, &mut buffer).expect("Failed to zip directory");
154
155    // attach a flag onto the start of the bytes to denote the name of the dylib
156    let raw_bytes = buffer.into_inner();
157
158    // release the lock for other processes
159    lock.unlock().expect("Failed to release download lock");
160
161    let byte_string = Literal::byte_string(&raw_bytes);
162
163    let tokens = quote! {
164        #byte_string
165    };
166
167    TokenStream::from(tokens)
168}