next_plaid_cli/
onnx_runtime.rs1use anyhow::{Context, Result};
6use std::env;
7use std::fs;
8use std::path::{Path, PathBuf};
9
/// ONNX Runtime release that is downloaded when no usable local installation is found.
///
/// It also names the per-version cache directory (`~/.cache/onnxruntime/<version>`).
const ORT_VERSION: &str = "1.23.0";

/// File name of the ONNX Runtime shared library on macOS.
#[cfg(target_os = "macos")]
const ORT_LIB_NAME: &str = "libonnxruntime.dylib";

/// File name of the ONNX Runtime shared library on Linux.
#[cfg(target_os = "linux")]
const ORT_LIB_NAME: &str = "libonnxruntime.so";

/// File name of the ONNX Runtime shared library on Windows.
#[cfg(target_os = "windows")]
const ORT_LIB_NAME: &str = "onnxruntime.dll";

/// File name used on other (Unix-like) targets, e.g. the BSDs.
///
/// No prebuilt archive is offered for these platforms, but defining the constant means a
/// missing name is not what breaks the build there; users can still point `ORT_DYLIB_PATH`
/// at a manually installed library.
#[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
const ORT_LIB_NAME: &str = "libonnxruntime.so";
20
21pub fn ensure_onnx_runtime() -> Result<PathBuf> {
24 if let Ok(path) = env::var("ORT_DYLIB_PATH") {
26 let path = PathBuf::from(&path);
27 if path.exists() {
28 return Ok(path);
29 }
30 }
31
32 if let Some(path) = find_onnx_runtime() {
34 env::set_var("ORT_DYLIB_PATH", &path);
35 return Ok(path);
36 }
37
38 let path = download_onnx_runtime()?;
40 env::set_var("ORT_DYLIB_PATH", &path);
41 Ok(path)
42}
43
44fn find_onnx_runtime() -> Option<PathBuf> {
46 let search_paths = get_search_paths();
47
48 for base_path in search_paths {
49 let lib_path = base_path.join(ORT_LIB_NAME);
51 if lib_path.exists() {
52 return Some(lib_path);
53 }
54
55 if let Ok(entries) = fs::read_dir(&base_path) {
57 for entry in entries.flatten() {
58 let name = entry.file_name();
59 let name_str = name.to_string_lossy();
60 if name_str.starts_with("libonnxruntime")
61 && (name_str.ends_with(".dylib") || name_str.ends_with(".so"))
62 {
63 return Some(entry.path());
64 }
65 }
66 }
67
68 let lib_subdir = base_path.join("lib").join(ORT_LIB_NAME);
70 if lib_subdir.exists() {
71 return Some(lib_subdir);
72 }
73 }
74
75 None
76}
77
78fn get_search_paths() -> Vec<PathBuf> {
80 let mut paths = Vec::new();
81
82 if let Some(home) = dirs::home_dir() {
84 paths.push(home.join(".cache").join("onnxruntime").join(ORT_VERSION));
86
87 if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
89 let conda_path = PathBuf::from(&conda_prefix);
90 paths.push(conda_path.join("lib"));
91
92 for entry in [
94 "lib/python3.12",
95 "lib/python3.11",
96 "lib/python3.10",
97 "lib/python3.9",
98 ] {
99 paths.push(
100 conda_path
101 .join(entry)
102 .join("site-packages/onnxruntime/capi"),
103 );
104 }
105 }
106
107 for venv_name in [".venv", "venv", ".env", "env"] {
109 let venv_path = std::env::current_dir()
110 .map(|cwd| cwd.join(venv_name))
111 .unwrap_or_default();
112
113 #[cfg(target_os = "windows")]
114 paths.push(venv_path.join("Lib/site-packages/onnxruntime/capi"));
115
116 #[cfg(not(target_os = "windows"))]
117 for py in ["python3.12", "python3.11", "python3.10", "python3.9"] {
118 paths.push(
119 venv_path
120 .join("lib")
121 .join(py)
122 .join("site-packages/onnxruntime/capi"),
123 );
124 }
125 }
126
127 paths.push(home.join(".cache/uv"));
129
130 #[cfg(target_os = "macos")]
132 {
133 paths.push(PathBuf::from("/opt/homebrew/lib"));
134 paths.push(PathBuf::from("/usr/local/lib"));
135 }
136
137 #[cfg(target_os = "linux")]
139 {
140 paths.push(PathBuf::from("/usr/lib"));
141 paths.push(PathBuf::from("/usr/local/lib"));
142 paths.push(PathBuf::from("/usr/lib/x86_64-linux-gnu"));
143 }
144 }
145
146 paths
147}
148
149fn download_onnx_runtime() -> Result<PathBuf> {
151 let cache_dir = dirs::home_dir()
152 .context("Could not find home directory")?
153 .join(".cache")
154 .join("onnxruntime")
155 .join(ORT_VERSION);
156
157 let lib_path = cache_dir.join(ORT_LIB_NAME);
158
159 if lib_path.exists() {
161 return Ok(lib_path);
162 }
163
164 fs::create_dir_all(&cache_dir)?;
165
166 let (url, archive_lib_path) = get_download_url()?;
167
168 eprintln!("⚙️ Runtime: ONNX {}", ORT_VERSION);
169
170 let response = ureq::get(&url)
172 .call()
173 .context("Failed to download ONNX Runtime")?;
174
175 let mut archive_data = Vec::new();
176 response.into_reader().read_to_end(&mut archive_data)?;
177
178 extract_library(&archive_data, &archive_lib_path, &lib_path)?;
180 Ok(lib_path)
181}
182
183fn get_download_url() -> Result<(String, String)> {
185 let base = format!(
186 "https://github.com/microsoft/onnxruntime/releases/download/v{}",
187 ORT_VERSION
188 );
189
190 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
191 let (archive, lib_path) = (
192 format!("onnxruntime-osx-arm64-{}.tgz", ORT_VERSION),
193 format!(
194 "onnxruntime-osx-arm64-{}/lib/libonnxruntime.{}.dylib",
195 ORT_VERSION, ORT_VERSION
196 ),
197 );
198
199 #[cfg(all(target_os = "macos", target_arch = "x86_64"))]
200 let (archive, lib_path) = (
201 format!("onnxruntime-osx-x86_64-{}.tgz", ORT_VERSION),
202 format!(
203 "onnxruntime-osx-x86_64-{}/lib/libonnxruntime.{}.dylib",
204 ORT_VERSION, ORT_VERSION
205 ),
206 );
207
208 #[cfg(all(target_os = "linux", target_arch = "x86_64"))]
209 let (archive, lib_path) = (
210 format!("onnxruntime-linux-x64-{}.tgz", ORT_VERSION),
211 format!(
212 "onnxruntime-linux-x64-{}/lib/libonnxruntime.so.{}",
213 ORT_VERSION, ORT_VERSION
214 ),
215 );
216
217 #[cfg(all(target_os = "linux", target_arch = "aarch64"))]
218 let (archive, lib_path) = (
219 format!("onnxruntime-linux-aarch64-{}.tgz", ORT_VERSION),
220 format!(
221 "onnxruntime-linux-aarch64-{}/lib/libonnxruntime.so.{}",
222 ORT_VERSION, ORT_VERSION
223 ),
224 );
225
226 #[cfg(all(target_os = "windows", target_arch = "x86_64"))]
227 let (archive, lib_path) = (
228 format!("onnxruntime-win-x64-{}.zip", ORT_VERSION),
229 format!("onnxruntime-win-x64-{}/lib/onnxruntime.dll", ORT_VERSION),
230 );
231
232 #[cfg(not(any(
233 all(target_os = "macos", target_arch = "aarch64"),
234 all(target_os = "macos", target_arch = "x86_64"),
235 all(target_os = "linux", target_arch = "x86_64"),
236 all(target_os = "linux", target_arch = "aarch64"),
237 all(target_os = "windows", target_arch = "x86_64"),
238 )))]
239 return Err(anyhow::anyhow!(
240 "Unsupported platform. Please install ONNX Runtime manually and set ORT_DYLIB_PATH."
241 ));
242
243 Ok((format!("{}/{}", base, archive), lib_path))
244}
245
246#[cfg(not(target_os = "windows"))]
248fn extract_library(archive_data: &[u8], lib_path_in_archive: &str, dest: &Path) -> Result<()> {
249 use flate2::read::GzDecoder;
250 use std::io::Read;
251
252 let decoder = GzDecoder::new(archive_data);
253 let mut archive = tar::Archive::new(decoder);
254
255 for entry in archive.entries()? {
256 let mut entry = entry?;
257 let path = entry.path()?;
258 let path_str = path.to_string_lossy();
259
260 let normalized_path = path_str.strip_prefix("./").unwrap_or(&path_str);
262
263 if normalized_path == lib_path_in_archive {
264 let mut lib_data = Vec::new();
265 entry.read_to_end(&mut lib_data)?;
266 fs::write(dest, lib_data)?;
267
268 #[cfg(unix)]
270 {
271 use std::os::unix::fs::PermissionsExt;
272 fs::set_permissions(dest, fs::Permissions::from_mode(0o755))?;
273 }
274
275 return Ok(());
276 }
277 }
278
279 Err(anyhow::anyhow!(
280 "Library not found in archive: {}",
281 lib_path_in_archive
282 ))
283}
284
/// Copies one file out of an in-memory ZIP archive into `dest`.
///
/// Each member's name is compared against `lib_path_in_archive`, after stripping a single
/// leading `./`.
///
/// # Errors
/// Fails if the archive cannot be opened, the matching member cannot be read, `dest` cannot be
/// written, or no member matches `lib_path_in_archive`.
#[cfg(target_os = "windows")]
fn extract_library(archive_data: &[u8], lib_path_in_archive: &str, dest: &Path) -> Result<()> {
    use std::io::{Cursor, Read};

    let mut bundle = zip::ZipArchive::new(Cursor::new(archive_data))?;

    for index in 0..bundle.len() {
        let mut member = bundle.by_index(index)?;

        // Scope the name borrow so the member can be read mutably afterwards.
        let is_target = {
            let name = member.name();
            name.strip_prefix("./").unwrap_or(name) == lib_path_in_archive
        };
        if !is_target {
            continue;
        }

        let mut bytes = Vec::new();
        member.read_to_end(&mut bytes)?;
        fs::write(dest, bytes)?;
        return Ok(());
    }

    Err(anyhow::anyhow!(
        "Library not found in archive: {}",
        lib_path_in_archive
    ))
}