baml_sys/
loader.rs

1//! Library resolution and loading.
2
3// Allow stderr output for download warning messages.
4#![allow(clippy::print_stderr)]
5
6use std::{
7    path::{Path, PathBuf},
8    sync::Mutex,
9};
10
11use libloading::Library;
12use once_cell::sync::OnceCell;
13
14use crate::{
15    download::download_library,
16    error::{BamlSysError, Result},
17};
18
/// Package version from Cargo.toml (workspace).
///
/// Passed to `download_library` as the version to fetch and used as the
/// per-version subdirectory name of the default cache directory.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");

/// GitHub repository for releases.
///
/// Passed to `download_library` together with [`VERSION`].
const GITHUB_REPO: &str = "boundaryml/baml";

/// Environment variable for explicit library path.
///
/// Surrounding whitespace and double quotes are stripped from the value
/// before it is interpreted as a path.
pub const ENV_LIBRARY_PATH: &str = "BAML_LIBRARY_PATH";

/// Environment variable for cache directory override.
///
/// When set, the library cache lives in this directory instead of the
/// per-user default.
pub const ENV_CACHE_DIR: &str = "BAML_CACHE_DIR";

/// Environment variable to disable automatic download.
///
/// The download is skipped when the value is `true`, compared
/// case-insensitively; any other value leaves it enabled.
pub const ENV_DISABLE_DOWNLOAD: &str = "BAML_LIBRARY_DISABLE_DOWNLOAD";

/// Global library instance.
///
/// Filled in by the first successful `get_library` call.
static LIBRARY: OnceCell<LoadedLibrary> = OnceCell::new();

/// Mutex for explicit path setting (before initialization).
///
/// Holds the path registered through `set_library_path`; it is the first
/// candidate consulted when the library is located.
static EXPLICIT_PATH: Mutex<Option<PathBuf>> = Mutex::new(None);
39
40/// A loaded dynamic library with its path.
41pub(crate) struct LoadedLibrary {
42    pub(crate) library: Library,
43    pub(crate) path: PathBuf,
44}
45
46// Safety: libloading::Library is Send + Sync when the underlying
47// library's functions are thread-safe (which baml_cffi is).
48#[allow(unsafe_code)]
49unsafe impl Send for LoadedLibrary {}
50#[allow(unsafe_code)]
51unsafe impl Sync for LoadedLibrary {}
52
53/// Set an explicit library path before initialization.
54///
55/// Must be called before any FFI functions are used.
56/// Returns an error if the library is already loaded.
57pub fn set_library_path(path: impl AsRef<Path>) -> Result<()> {
58    let path = path.as_ref().to_path_buf();
59
60    if LIBRARY.get().is_some() {
61        let existing = LIBRARY.get().unwrap();
62        return Err(BamlSysError::AlreadyInitialized {
63            existing_path: existing.path.clone(),
64            requested_path: path,
65        });
66    }
67
68    let mut explicit = EXPLICIT_PATH.lock().unwrap();
69    *explicit = Some(path);
70    Ok(())
71}
72
73/// Ensure the library is available (for use in build.rs).
74///
75/// This function will:
76/// 1. Check if library exists at configured/default paths
77/// 2. Download if necessary and enabled
78/// 3. Return the path to the library
79///
80/// Does NOT load the library - that happens at runtime.
81pub fn ensure_library() -> Result<PathBuf> {
82    find_or_download_library()
83}
84
85/// Get the loaded library, initializing if necessary.
86pub(crate) fn get_library() -> Result<&'static LoadedLibrary> {
87    LIBRARY.get_or_try_init(|| {
88        let path = find_or_download_library()?;
89        load_library(&path)
90    })
91}
92
93/// Find or download the library, returning its path.
94fn find_or_download_library() -> Result<PathBuf> {
95    let mut searched_paths = Vec::new();
96
97    // 1. Check explicit path set via API
98    {
99        let explicit = EXPLICIT_PATH.lock().unwrap();
100        if let Some(path) = explicit.as_ref() {
101            if path.exists() {
102                return Ok(path.clone());
103            }
104            searched_paths.push(path.clone());
105        }
106    }
107
108    // 2. Check environment variable
109    if let Ok(env_path) = std::env::var(ENV_LIBRARY_PATH) {
110        // Env vars can be wrapped in quotes and spaces, so we need to unwrap them
111        let env_path = env_path.trim().trim_matches('"').trim();
112        let path = PathBuf::from(&env_path);
113        if path.exists() {
114            return Ok(path);
115        }
116        searched_paths.push(path);
117    }
118
119    // 3. Check cache directory
120    let cache_dir = get_cache_dir()?;
121    let lib_filename = get_library_filename()?;
122    let cached_path = cache_dir.join(&lib_filename);
123
124    if cached_path.exists() {
125        return Ok(cached_path);
126    }
127    searched_paths.push(cached_path.clone());
128
129    // 4. Try to download (if enabled)
130    #[cfg(feature = "download")]
131    if std::env::var(ENV_DISABLE_DOWNLOAD).map(|v| v.to_lowercase()) != Ok("true".to_string()) {
132        match download_library(&cache_dir, &lib_filename, VERSION, GITHUB_REPO) {
133            Ok(()) => return Ok(cached_path),
134            Err(e) => {
135                // Log warning but continue to system paths
136                eprintln!("Warning: Failed to download BAML library: {e}");
137            }
138        }
139    }
140
141    // 5. Check system default paths
142    for path in get_system_paths(&lib_filename) {
143        if path.exists() {
144            return Ok(path);
145        }
146        searched_paths.push(path);
147    }
148
149    Err(BamlSysError::LibraryNotFound { searched_paths })
150}
151
152/// Load the library from a path.
153fn load_library(path: &Path) -> Result<LoadedLibrary> {
154    // Safety: We're loading a dynamic library. The library must be
155    // compatible with our expected ABI.
156    #[allow(unsafe_code)]
157    let library = unsafe { Library::new(path) }.map_err(|e| BamlSysError::LoadFailed {
158        path: path.to_path_buf(),
159        source: e,
160    })?;
161
162    Ok(LoadedLibrary {
163        library,
164        path: path.to_path_buf(),
165    })
166}
167
168/// Get the cache directory for libraries.
169fn get_cache_dir() -> Result<PathBuf> {
170    // Check environment variable override
171    if let Ok(cache_dir) = std::env::var(ENV_CACHE_DIR) {
172        let path = PathBuf::from(cache_dir);
173        std::fs::create_dir_all(&path)?;
174        return Ok(path);
175    }
176
177    // Use platform-specific user cache directory
178    let base = dirs_cache_dir().ok_or_else(|| {
179        BamlSysError::CacheDir("Could not determine user cache directory".to_string())
180    })?;
181
182    // Structure: {cache}/baml/libs/{VERSION}/
183    let cache_dir = base.join("baml").join("libs").join(VERSION);
184    std::fs::create_dir_all(&cache_dir)?;
185
186    Ok(cache_dir)
187}
188
/// Get platform-specific cache directory.
///
/// * macOS: `$HOME/Library/Caches`
/// * Linux: `$XDG_CACHE_HOME`, falling back to `$HOME/.cache`
/// * Windows: `%LOCALAPPDATA%`
/// * anything else: `None`
///
/// A variable that is unset, empty, or holds a relative path is ignored, so
/// the result is never a location relative to the current working directory.
/// Returns `None` when no usable variable is found.
fn dirs_cache_dir() -> Option<PathBuf> {
    /// Read `var` as a directory, keeping it only when it is an absolute path.
    #[cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))]
    fn absolute_dir(var: &str) -> Option<PathBuf> {
        std::env::var_os(var)
            .map(PathBuf::from)
            .filter(|dir| dir.is_absolute())
    }

    #[cfg(target_os = "macos")]
    {
        absolute_dir("HOME").map(|home| home.join("Library/Caches"))
    }

    #[cfg(target_os = "linux")]
    {
        absolute_dir("XDG_CACHE_HOME")
            .or_else(|| absolute_dir("HOME").map(|home| home.join(".cache")))
    }

    #[cfg(target_os = "windows")]
    {
        absolute_dir("LOCALAPPDATA")
    }

    #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
    {
        None
    }
}
213
214/// Get the library filename for the current platform.
215fn get_library_filename() -> Result<String> {
216    let (prefix, ext, target_triple) = get_platform_info()?;
217    Ok(format!("{prefix}baml_cffi-{target_triple}.{ext}"))
218}
219
220/// Get platform-specific library info: (prefix, extension, `target_triple`).
221#[allow(clippy::unnecessary_wraps)] // Result needed for unsupported platform cfg fallback
222fn get_platform_info() -> Result<(&'static str, &'static str, &'static str)> {
223    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
224    return Ok(("lib", "dylib", "aarch64-apple-darwin"));
225
226    #[cfg(all(target_os = "macos", target_arch = "x86_64"))]
227    return Ok(("lib", "dylib", "x86_64-apple-darwin"));
228
229    #[cfg(all(target_os = "linux", target_arch = "x86_64"))]
230    return Ok(("lib", "so", "x86_64-unknown-linux-gnu"));
231
232    #[cfg(all(target_os = "linux", target_arch = "aarch64"))]
233    return Ok(("lib", "so", "aarch64-unknown-linux-gnu"));
234
235    #[cfg(all(target_os = "windows", target_arch = "x86_64"))]
236    return Ok(("", "dll", "x86_64-pc-windows-msvc"));
237
238    #[cfg(all(target_os = "windows", target_arch = "aarch64"))]
239    return Ok(("", "dll", "aarch64-pc-windows-msvc"));
240
241    #[cfg(not(any(
242        all(target_os = "macos", target_arch = "aarch64"),
243        all(target_os = "macos", target_arch = "x86_64"),
244        all(target_os = "linux", target_arch = "x86_64"),
245        all(target_os = "linux", target_arch = "aarch64"),
246        all(target_os = "windows", target_arch = "x86_64"),
247        all(target_os = "windows", target_arch = "aarch64"),
248    )))]
249    Err(BamlSysError::UnsupportedPlatform {
250        os: std::env::consts::OS,
251        arch: std::env::consts::ARCH,
252    })
253}
254
/// List the system-wide locations where the library may be installed.
///
/// Paths are returned in probing order. None of them is checked for
/// existence here.
///
/// * macOS: `/usr/local/lib/{lib_filename}`, then
///   `/usr/local/lib/libbaml_cffi.dylib`.
/// * Linux: `/usr/local/lib/{lib_filename}`,
///   `/usr/local/lib/libbaml_cffi.so`, then `/usr/lib/{lib_filename}`.
/// * Windows: for each of `%ProgramFiles%` and `%LOCALAPPDATA%` that is set,
///   `baml\{lib_filename}` followed by `baml\baml_cffi.dll`.
/// * Other platforms: an empty list.
fn get_system_paths(lib_filename: &str) -> Vec<PathBuf> {
    #[cfg(target_os = "macos")]
    let candidates = vec![
        PathBuf::from("/usr/local/lib").join(lib_filename),
        PathBuf::from("/usr/local/lib/libbaml_cffi.dylib"),
    ];

    #[cfg(target_os = "linux")]
    let candidates = vec![
        PathBuf::from("/usr/local/lib").join(lib_filename),
        PathBuf::from("/usr/local/lib/libbaml_cffi.so"),
        PathBuf::from("/usr/lib").join(lib_filename),
    ];

    #[cfg(target_os = "windows")]
    let candidates = {
        let mut found = Vec::new();
        // Same two entries under each root, roots visited in this order.
        for root_var in ["ProgramFiles", "LOCALAPPDATA"] {
            if let Ok(root) = std::env::var(root_var) {
                let baml_dir = PathBuf::from(&root).join("baml");
                found.push(baml_dir.join(lib_filename));
                found.push(baml_dir.join("baml_cffi.dll"));
            }
        }
        found
    };

    #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
    let candidates: Vec<PathBuf> = Vec::new();

    candidates
}