next_plaid_cli/
onnx_runtime.rs

//! ONNX Runtime auto-setup
//!
//! Automatically locates the ONNX Runtime shared library on this machine, or downloads and caches it.
4
5use anyhow::{Context, Result};
6use std::env;
7use std::fs;
8use std::path::{Path, PathBuf};
9
/// ONNX Runtime release that is looked up in the local cache and downloaded when missing.
const ORT_VERSION: &str = "1.23.0";

// Unversioned file name of the ONNX Runtime shared library for the current target.
#[cfg(target_os = "macos")]
const ORT_LIB_NAME: &str = "libonnxruntime.dylib";

#[cfg(target_os = "linux")]
const ORT_LIB_NAME: &str = "libonnxruntime.so";

#[cfg(target_os = "windows")]
const ORT_LIB_NAME: &str = "onnxruntime.dll";

// Fallback for other Unix-like targets (e.g. the BSDs), which also use the `.so` naming
// convention, so that `ORT_LIB_NAME` is defined on every target.
#[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
const ORT_LIB_NAME: &str = "libonnxruntime.so";
20
21/// Ensure ONNX Runtime is available.
22/// Sets ORT_DYLIB_PATH if found or downloaded.
23pub fn ensure_onnx_runtime() -> Result<PathBuf> {
24    // 1. Check if already set
25    if let Ok(path) = env::var("ORT_DYLIB_PATH") {
26        let path = PathBuf::from(&path);
27        if path.exists() {
28            return Ok(path);
29        }
30    }
31
32    // 2. Search common locations
33    if let Some(path) = find_onnx_runtime() {
34        env::set_var("ORT_DYLIB_PATH", &path);
35        return Ok(path);
36    }
37
38    // 3. Download and cache
39    let path = download_onnx_runtime()?;
40    env::set_var("ORT_DYLIB_PATH", &path);
41    Ok(path)
42}
43
44/// Search for ONNX Runtime in common locations
45fn find_onnx_runtime() -> Option<PathBuf> {
46    let search_paths = get_search_paths();
47
48    for base_path in search_paths {
49        // Direct library file
50        let lib_path = base_path.join(ORT_LIB_NAME);
51        if lib_path.exists() {
52            return Some(lib_path);
53        }
54
55        // Versioned library (e.g., libonnxruntime.1.20.1.dylib)
56        if let Ok(entries) = fs::read_dir(&base_path) {
57            for entry in entries.flatten() {
58                let name = entry.file_name();
59                let name_str = name.to_string_lossy();
60                if name_str.starts_with("libonnxruntime")
61                    && (name_str.ends_with(".dylib") || name_str.ends_with(".so"))
62                {
63                    return Some(entry.path());
64                }
65            }
66        }
67
68        // Check lib subdirectory
69        let lib_subdir = base_path.join("lib").join(ORT_LIB_NAME);
70        if lib_subdir.exists() {
71            return Some(lib_subdir);
72        }
73    }
74
75    None
76}
77
78/// Get list of paths to search for ONNX Runtime
79fn get_search_paths() -> Vec<PathBuf> {
80    let mut paths = Vec::new();
81
82    // Home directory for cache
83    if let Some(home) = dirs::home_dir() {
84        // Our cache location
85        paths.push(home.join(".cache").join("onnxruntime").join(ORT_VERSION));
86
87        // Conda environments
88        if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
89            let conda_path = PathBuf::from(&conda_prefix);
90            paths.push(conda_path.join("lib"));
91
92            // Python site-packages in conda
93            for entry in [
94                "lib/python3.12",
95                "lib/python3.11",
96                "lib/python3.10",
97                "lib/python3.9",
98            ] {
99                paths.push(
100                    conda_path
101                        .join(entry)
102                        .join("site-packages/onnxruntime/capi"),
103                );
104            }
105        }
106
107        // Virtual environments
108        for venv_name in [".venv", "venv", ".env", "env"] {
109            let venv_path = std::env::current_dir()
110                .map(|cwd| cwd.join(venv_name))
111                .unwrap_or_default();
112
113            #[cfg(target_os = "windows")]
114            paths.push(venv_path.join("Lib/site-packages/onnxruntime/capi"));
115
116            #[cfg(not(target_os = "windows"))]
117            for py in ["python3.12", "python3.11", "python3.10", "python3.9"] {
118                paths.push(
119                    venv_path
120                        .join("lib")
121                        .join(py)
122                        .join("site-packages/onnxruntime/capi"),
123                );
124            }
125        }
126
127        // UV cache
128        paths.push(home.join(".cache/uv"));
129
130        // Homebrew (macOS)
131        #[cfg(target_os = "macos")]
132        {
133            paths.push(PathBuf::from("/opt/homebrew/lib"));
134            paths.push(PathBuf::from("/usr/local/lib"));
135        }
136
137        // System paths (Linux)
138        #[cfg(target_os = "linux")]
139        {
140            paths.push(PathBuf::from("/usr/lib"));
141            paths.push(PathBuf::from("/usr/local/lib"));
142            paths.push(PathBuf::from("/usr/lib/x86_64-linux-gnu"));
143        }
144    }
145
146    paths
147}
148
149/// Download ONNX Runtime from GitHub releases
150fn download_onnx_runtime() -> Result<PathBuf> {
151    let cache_dir = dirs::home_dir()
152        .context("Could not find home directory")?
153        .join(".cache")
154        .join("onnxruntime")
155        .join(ORT_VERSION);
156
157    let lib_path = cache_dir.join(ORT_LIB_NAME);
158
159    // Already cached
160    if lib_path.exists() {
161        return Ok(lib_path);
162    }
163
164    fs::create_dir_all(&cache_dir)?;
165
166    let (url, archive_lib_path) = get_download_url()?;
167
168    eprintln!("⚙️  Runtime: ONNX {}", ORT_VERSION);
169
170    // Download archive
171    let response = ureq::get(&url)
172        .call()
173        .context("Failed to download ONNX Runtime")?;
174
175    let mut archive_data = Vec::new();
176    response.into_reader().read_to_end(&mut archive_data)?;
177
178    // Extract library from archive
179    extract_library(&archive_data, &archive_lib_path, &lib_path)?;
180    Ok(lib_path)
181}
182
183/// Get download URL for current platform
184fn get_download_url() -> Result<(String, String)> {
185    let base = format!(
186        "https://github.com/microsoft/onnxruntime/releases/download/v{}",
187        ORT_VERSION
188    );
189
190    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
191    let (archive, lib_path) = (
192        format!("onnxruntime-osx-arm64-{}.tgz", ORT_VERSION),
193        format!(
194            "onnxruntime-osx-arm64-{}/lib/libonnxruntime.{}.dylib",
195            ORT_VERSION, ORT_VERSION
196        ),
197    );
198
199    #[cfg(all(target_os = "macos", target_arch = "x86_64"))]
200    let (archive, lib_path) = (
201        format!("onnxruntime-osx-x86_64-{}.tgz", ORT_VERSION),
202        format!(
203            "onnxruntime-osx-x86_64-{}/lib/libonnxruntime.{}.dylib",
204            ORT_VERSION, ORT_VERSION
205        ),
206    );
207
208    #[cfg(all(target_os = "linux", target_arch = "x86_64"))]
209    let (archive, lib_path) = (
210        format!("onnxruntime-linux-x64-{}.tgz", ORT_VERSION),
211        format!(
212            "onnxruntime-linux-x64-{}/lib/libonnxruntime.so.{}",
213            ORT_VERSION, ORT_VERSION
214        ),
215    );
216
217    #[cfg(all(target_os = "linux", target_arch = "aarch64"))]
218    let (archive, lib_path) = (
219        format!("onnxruntime-linux-aarch64-{}.tgz", ORT_VERSION),
220        format!(
221            "onnxruntime-linux-aarch64-{}/lib/libonnxruntime.so.{}",
222            ORT_VERSION, ORT_VERSION
223        ),
224    );
225
226    #[cfg(all(target_os = "windows", target_arch = "x86_64"))]
227    let (archive, lib_path) = (
228        format!("onnxruntime-win-x64-{}.zip", ORT_VERSION),
229        format!("onnxruntime-win-x64-{}/lib/onnxruntime.dll", ORT_VERSION),
230    );
231
232    #[cfg(not(any(
233        all(target_os = "macos", target_arch = "aarch64"),
234        all(target_os = "macos", target_arch = "x86_64"),
235        all(target_os = "linux", target_arch = "x86_64"),
236        all(target_os = "linux", target_arch = "aarch64"),
237        all(target_os = "windows", target_arch = "x86_64"),
238    )))]
239    return Err(anyhow::anyhow!(
240        "Unsupported platform. Please install ONNX Runtime manually and set ORT_DYLIB_PATH."
241    ));
242
243    Ok((format!("{}/{}", base, archive), lib_path))
244}
245
246/// Extract library from tgz archive
247#[cfg(not(target_os = "windows"))]
248fn extract_library(archive_data: &[u8], lib_path_in_archive: &str, dest: &Path) -> Result<()> {
249    use flate2::read::GzDecoder;
250    use std::io::Read;
251
252    let decoder = GzDecoder::new(archive_data);
253    let mut archive = tar::Archive::new(decoder);
254
255    for entry in archive.entries()? {
256        let mut entry = entry?;
257        let path = entry.path()?;
258        let path_str = path.to_string_lossy();
259
260        // Handle paths with or without ./ prefix (macOS archives have ./, Linux doesn't)
261        let normalized_path = path_str.strip_prefix("./").unwrap_or(&path_str);
262
263        if normalized_path == lib_path_in_archive {
264            let mut lib_data = Vec::new();
265            entry.read_to_end(&mut lib_data)?;
266            fs::write(dest, lib_data)?;
267
268            // Make executable on Unix
269            #[cfg(unix)]
270            {
271                use std::os::unix::fs::PermissionsExt;
272                fs::set_permissions(dest, fs::Permissions::from_mode(0o755))?;
273            }
274
275            return Ok(());
276        }
277    }
278
279    Err(anyhow::anyhow!(
280        "Library not found in archive: {}",
281        lib_path_in_archive
282    ))
283}
284
/// Extract a single library file from an in-memory zip archive into `dest` (Windows builds).
///
/// Each member's name is compared with `lib_path_in_archive` after dropping an optional
/// leading `./`. The first matching member is read fully and written to `dest`.
///
/// # Errors
/// Fails if the zip is malformed, a member cannot be read, `dest` cannot be written, or no
/// member matches `lib_path_in_archive`.
#[cfg(target_os = "windows")]
fn extract_library(archive_data: &[u8], lib_path_in_archive: &str, dest: &Path) -> Result<()> {
    use std::io::{Cursor, Read};

    let mut zip = zip::ZipArchive::new(Cursor::new(archive_data))?;

    for index in 0..zip.len() {
        let mut member = zip.by_index(index)?;

        let member_name = member.name();
        let is_target = member_name.strip_prefix("./").unwrap_or(member_name) == lib_path_in_archive;
        if !is_target {
            continue;
        }

        let mut bytes = Vec::new();
        member.read_to_end(&mut bytes)?;
        fs::write(dest, bytes)?;
        return Ok(());
    }

    Err(anyhow::anyhow!(
        "Library not found in archive: {}",
        lib_path_in_archive
    ))
}