1extern crate proc_macro;
3
4use proc_macro::TokenStream;
5use proc_macro2::Literal;
6use syn::{parse_macro_input, LitStr};
7use quote::quote;
8use std::{
9 env,
10 env::consts,
11 fs::{self, File},
12 io::{Write, Cursor},
13 path::{Path, PathBuf},
14};
15mod file_extraction;
16use file_extraction::{
17 FileType,
18 extract_tgz,
19 extract_zip,
20 zip_dir,
21 DylibName
22};
23
24
25fn get_onnxruntime_url(onnx_version: &str) -> (String, String, String, DylibName, FileType) {
33 let base_url = format!(
34 "https://github.com/microsoft/onnxruntime/releases/download/v{}/",
35 onnx_version
36 );
37
38 match (consts::OS, consts::ARCH) {
39 ("linux", "x86_64") => (
40 format!("{}onnxruntime-linux-x64-{}.tgz", base_url, onnx_version),
41 format!("onnxruntime-linux-x64-{}", onnx_version),
42 "tgz".to_string(),
43 DylibName::So,
44 FileType::Tgz
45 ),
46 ("linux", "aarch64") => (
47 format!("{}onnxruntime-linux-aarch64-{}.tgz", base_url, onnx_version),
48 format!("onnxruntime-linux-aarch64-{}", onnx_version),
49 "tgz".to_string(),
50 DylibName::So,
51 FileType::Tgz
52 ),
53 ("macos", "x86_64") => (
54 format!("{}onnxruntime-osx-x86_64-{}.tgz", base_url, onnx_version),
55 format!("onnxruntime-osx-x86_64-{}", onnx_version),
56 "tgz".to_string(),
57 DylibName::Dylib,
58 FileType::Tgz
59 ),
60 ("macos", "aarch64") => (
61 format!("{}onnxruntime-osx-arm64-{}.tgz", base_url, onnx_version),
62 format!("onnxruntime-osx-arm64-{}", onnx_version),
63 "tgz".to_string(),
64 DylibName::Dylib,
65 FileType::Tgz
66 ),
67 ("windows", "x86_64") => (
68 format!("{}onnxruntime-win-x64-{}.zip", base_url, onnx_version),
69 format!("onnxruntime-win-x64-{}", onnx_version),
70 "zip".to_string(),
71 DylibName::Dll,
72 FileType::Zip
73 ),
74 ("windows", "aarch64") => (
75 format!("{}onnxruntime-win-arm64-{}.zip", base_url, onnx_version),
76 format!("onnxruntime-win-arm64-{}", onnx_version),
77 "zip".to_string(),
78 DylibName::Dll,
79 FileType::Zip
80 ),
81 _ => panic!(
82 "Unsupported platform or architecture: {} {}",
83 consts::OS,
84 consts::ARCH
85 ),
86 }
87}
88
89#[proc_macro]
90pub fn embed_onnx(attr: TokenStream) -> TokenStream {
91
92 let input = parse_macro_input!(attr as LitStr);
94 let supported_versions = ["1.20.0"];
95 let onnx_version = match input.value().as_str() {
96 "1.20.0" => "1.20.0",
97 _ => panic!(
98 "{} passed in as version, only the following versions are supported: {:?}",
99 input.value(), supported_versions
100 )
101 };
102
103 let (url, package_name, ext, dylib_name, file_type) = get_onnxruntime_url(
104 onnx_version
105 );
106
107 let target_root = std::env::var("CARGO_TARGET_DIR")
109 .map(PathBuf::from)
110 .unwrap_or_else(|_| {
111 Path::new(&env!("CARGO_MANIFEST_DIR"))
112 .parent()
113 .expect("crate not in a workspace?")
114 .join("target")
115 });
116
117 let cache = target_root.join("onnxruntime_cache").join(onnx_version);
118 fs::create_dir_all(&cache).expect("failed to create cache dir");
119
120 let filename = format!("{}.{}", package_name, ext);
121 let download_path = cache.join(&filename);
122 let extract_target = cache.join(&package_name);
123 let lib_path = extract_target.join("lib");
124 let dylib_name_str: &str = dylib_name.clone().into();
125 let dylib_path = lib_path.join(dylib_name_str);
126
127 let lock_path = cache.join("onnx_download.lock");
129 let mut lock = fslock::LockFile::open(&lock_path).expect("Failed to open lock file");
130 lock.lock().expect("Failed to acquire download lock");
131
132 if !download_path.exists() {
133 println!("Downloading ONNX Runtime from {}", url);
134 let response = reqwest::blocking::get(&url)
135 .expect("Failed to download ONNX Runtime")
136 .bytes()
137 .expect("Failed to read ONNX Runtime response");
138
139 let mut file = File::create(&download_path).expect("Failed to create ONNX file");
140 file.write_all(&response).expect("Failed to write ONNX file");
141 println!("Saved to {}", download_path.display());
142 }
143
144 if !dylib_path.exists() {
145 match file_type {
146 FileType::Tgz => extract_tgz(&download_path, &cache).expect("Failed to extract ONNX archive for tgz"),
147 FileType::Zip => extract_zip(&download_path, &cache).expect("Failed to extract ONNX archive for zip")
148 };
149 }
150
151 let mut buffer = Cursor::new(Vec::new());
153 zip_dir(&lib_path, &mut buffer).expect("Failed to zip directory");
154
155 let raw_bytes = buffer.into_inner();
157
158 lock.unlock().expect("Failed to release download lock");
160
161 let byte_string = Literal::byte_string(&raw_bytes);
162
163 let tokens = quote! {
164 #byte_string
165 };
166
167 TokenStream::from(tokens)
168}