1extern crate proc_macro;
3
4use proc_macro::TokenStream;
5use proc_macro2::Literal;
6use syn::{parse_macro_input, LitStr};
7use quote::quote;
8use std::{
9 env,
10 env::consts,
11 fs::{self, File},
12 io::{self, Write},
13 path::{Path, PathBuf},
14};
15use flate2::read::GzDecoder;
16use tar::Archive;
17
18fn extract_tgz<P: AsRef<Path>>(tgz_path: P, extract_to: P) -> io::Result<()> {
24 let file = File::open(&tgz_path)?;
25 let decompressor = GzDecoder::new(file);
26 let mut archive = Archive::new(decompressor);
27 archive.unpack(&extract_to)?;
28 Ok(())
29}
30
31
/// Resolves the ONNX Runtime release artifact that matches this machine.
///
/// Returns `(download_url, package_name, archive_extension, library_file_name)`. Here
/// `package_name` is the archive's file name without its extension. `library_file_name` is
/// the shared library to look for in the unpacked package's `lib` directory.
///
/// NOTE: the platform comes from `std::env::consts`, i.e. the platform this proc-macro crate
/// was compiled for (the build host). That is not necessarily the crate's compilation target.
///
/// # Panics
///
/// Panics when no prebuilt release exists for the host OS/architecture pair.
fn get_onnxruntime_url(onnx_version: &str) -> (String, String, String, String) {
    let (os, arch) = (consts::OS, consts::ARCH);

    // (release platform tag, archive extension, shared-library file name)
    let (platform, archive_ext, library) = match (os, arch) {
        ("linux", "x86_64") => ("linux-x64", "tgz", "libonnxruntime.so"),
        ("linux", "aarch64") => ("linux-aarch64", "tgz", "libonnxruntime.so"),
        ("macos", "x86_64") => ("osx-x86_64", "tgz", "libonnxruntime.dylib"),
        ("macos", "aarch64") => ("osx-arm64", "tgz", "libonnxruntime.dylib"),
        ("windows", "x86_64") => ("win-x64", "zip", "onnxruntime.dll"),
        ("windows", "aarch64") => ("win-arm64", "zip", "onnxruntime.dll"),
        _ => panic!("Unsupported platform or architecture: {} {}", os, arch),
    };

    let package = format!("onnxruntime-{}-{}", platform, onnx_version);
    let url = format!(
        "https://github.com/microsoft/onnxruntime/releases/download/v{}/{}.{}",
        onnx_version, package, archive_ext
    );

    (url, package, archive_ext.to_string(), library.to_string())
}
89
90#[proc_macro]
91pub fn embed_onnx(attr: TokenStream) -> TokenStream {
92
93 let input = parse_macro_input!(attr as LitStr);
95 let supported_versions = ["1.20.0"];
96 let onnx_version = match input.value().as_str() {
97 "1.20.0" => "1.20.0",
98 _ => panic!(
99 "{} passed in as version, only the following versions are supported: {:?}",
100 input.value(), supported_versions
101 )
102 };
103
104 let (url, package_name, ext, dylib_name) = get_onnxruntime_url(onnx_version);
105
106 let target_root = std::env::var("CARGO_TARGET_DIR")
108 .map(PathBuf::from)
109 .unwrap_or_else(|_| {
110 Path::new(&env!("CARGO_MANIFEST_DIR"))
111 .parent()
112 .expect("crate not in a workspace?")
113 .join("target")
114 });
115
116 let cache = target_root.join("onnxruntime_cache").join(onnx_version);
117 fs::create_dir_all(&cache).expect("failed to create cache dir");
118
119 let filename = format!("{}.{}", package_name, ext);
120 let download_path = cache.join(&filename);
121 let extract_target = cache.join(&package_name);
122 let dylib_path = extract_target.join("lib").join(&dylib_name);
123
124 if !download_path.exists() {
125 println!("Downloading ONNX Runtime from {}", url);
126 let response = reqwest::blocking::get(&url)
127 .expect("Failed to download ONNX Runtime")
128 .bytes()
129 .expect("Failed to read ONNX Runtime response");
130
131 let mut file = File::create(&download_path).expect("Failed to create ONNX file");
132 file.write_all(&response).expect("Failed to write ONNX file");
133 println!("Saved to {}", download_path.display());
134 }
135
136 if !dylib_path.exists() {
137 extract_tgz(&download_path, &cache)
138 .expect("Failed to extract ONNX archive");
139 }
140
141 let bytes: Vec<u8> = fs::read(&dylib_path).expect("Failed to read extracted library");
142
143 let byte_string = Literal::byte_string(&bytes);
144
145 let tokens = quote! {
146 #byte_string
147 };
148
149 TokenStream::from(tokens)
150}
151
152
153