Skip to main content

baracuda_core/
loader.rs

//! Thin dynamic-loader wrapper around `libloading`.
//!
//! This crate does not know about any particular NVIDIA library; each `-sys`
//! crate instantiates a [`Library`] with its own candidate filenames and
//! symbol-resolution strategy. The Driver API in particular layers
//! `cuGetProcAddress`-based symbol resolution on top of [`Library::symbol`];
//! everything else (cudart, cublas, ...) can call `symbol` directly.
8
9use std::ffi::CStr;
10use std::path::{Path, PathBuf};
11
12use crate::error::LoaderError;
13use crate::platform;
14
15/// Dynamically-loaded NVIDIA library (wraps [`libloading::Library`]).
16pub struct Library {
17    name: &'static str,
18    lib: libloading::Library,
19    /// Records the path the library actually resolved from, for diagnostics.
20    resolved_from: Option<PathBuf>,
21}
22
23impl std::fmt::Debug for Library {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        f.debug_struct("Library")
26            .field("name", &self.name)
27            .field("resolved_from", &self.resolved_from)
28            .finish_non_exhaustive()
29    }
30}
31
32impl Library {
33    /// Open `candidates[0]`, `candidates[1]`, ... in order, falling back to
34    /// each path returned by [`platform::library_search_paths`]. Returns
35    /// the first success or [`LoaderError::LibraryNotFound`] / platform
36    /// error.
37    pub fn open(name: &'static str, candidates: &[&'static str]) -> Result<Self, LoaderError> {
38        if matches!(platform::os_family(), platform::OsFamily::Unsupported) {
39            return Err(LoaderError::UnsupportedPlatform {
40                platform: std::env::consts::OS,
41            });
42        }
43        if candidates.is_empty() {
44            return Err(LoaderError::library_not_found(name, candidates));
45        }
46
47        // Phase 1: try each candidate name bare (OS handles the search).
48        for candidate in candidates {
49            if let Ok(lib) = unsafe { libloading::Library::new(candidate) } {
50                return Ok(Self {
51                    name,
52                    lib,
53                    resolved_from: Some(PathBuf::from(candidate)),
54                });
55            }
56        }
57
58        // Phase 2: try each candidate inside each explicit search directory.
59        let search_paths = platform::library_search_paths();
60        for dir in &search_paths {
61            for candidate in candidates {
62                let full = dir.join(candidate);
63                if let Ok(lib) = unsafe { libloading::Library::new(&full) } {
64                    return Ok(Self {
65                        name,
66                        lib,
67                        resolved_from: Some(full),
68                    });
69                }
70            }
71        }
72
73        Err(LoaderError::library_not_found_with_search(
74            name,
75            candidates,
76            search_paths.len(),
77        ))
78    }
79
80    /// Open a library at the specific path `path` (no search). Mostly used
81    /// in tests to inject a known library location.
82    pub fn open_at(name: &'static str, path: &Path) -> Result<Self, LoaderError> {
83        let lib = unsafe { libloading::Library::new(path) }?;
84        Ok(Self {
85            name,
86            lib,
87            resolved_from: Some(path.to_path_buf()),
88        })
89    }
90
91    /// The logical library name baracuda knows it by (e.g. `"cuda-driver"`,
92    /// `"cublas"`).
93    #[inline]
94    pub fn name(&self) -> &'static str {
95        self.name
96    }
97
98    /// The absolute path the library actually resolved from, if known.
99    #[inline]
100    pub fn resolved_from(&self) -> Option<&Path> {
101        self.resolved_from.as_deref()
102    }
103
104    /// Resolve `symbol`. The caller is responsible for the type `T` matching
105    /// the C signature of the symbol; consequently, this function is `unsafe`.
106    ///
107    /// # Errors
108    ///
109    /// [`LoaderError::SymbolNotFound`] if `dlsym`/`GetProcAddress` returns
110    /// a null pointer; [`LoaderError::Libloading`] for other `libloading`
111    /// failures.
112    ///
113    /// # Safety
114    ///
115    /// `T` must be a function-pointer type (`unsafe extern "C" fn(...) -> ...`)
116    /// matching the C signature of `symbol`. Calling the returned symbol
117    /// with the wrong signature is undefined behavior.
118    pub unsafe fn symbol<T>(
119        &self,
120        symbol: &'static str,
121    ) -> Result<libloading::Symbol<'_, T>, LoaderError> {
122        let bytes_with_nul: Vec<u8> = symbol.bytes().chain(std::iter::once(0)).collect();
123        let cstr = CStr::from_bytes_with_nul(&bytes_with_nul).map_err(|_| {
124            LoaderError::SymbolNotFound {
125                library: self.name,
126                symbol,
127            }
128        })?;
129        match self.lib.get::<T>(cstr.to_bytes_with_nul()) {
130            Ok(s) => Ok(s),
131            Err(_) => Err(LoaderError::SymbolNotFound {
132                library: self.name,
133                symbol,
134            }),
135        }
136    }
137
138    /// Return a raw pointer to the symbol without wrapping in `libloading::Symbol`.
139    /// Useful for stashing function pointers in `OnceLock`s that outlive the
140    /// borrow checker's view of the library.
141    ///
142    /// # Safety
143    ///
144    /// Same as [`Self::symbol`]. Additionally, the caller must ensure the
145    /// [`Library`] outlives any use of the returned pointer — in practice this
146    /// means storing the [`Library`] in a `static OnceLock<Library>` or
147    /// equivalent.
148    pub unsafe fn raw_symbol(&self, symbol: &'static str) -> Result<*mut (), LoaderError> {
149        let sym: libloading::Symbol<'_, *mut ()> = self.symbol(symbol)?;
150        Ok(*sym)
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// A library that cannot exist must be reported with its logical name and
    /// every candidate that was tried.
    #[test]
    fn missing_library_reports_candidates() {
        let outcome = Library::open(
            "unobtanium",
            &["libunobtanium.so.42", "unobtanium64_42.dll"],
        );
        match outcome {
            // Acceptable on non-Linux/Windows CI runners.
            Err(LoaderError::UnsupportedPlatform { .. }) => {}
            Err(LoaderError::LibraryNotFound {
                library,
                candidates,
                ..
            }) => {
                assert_eq!(library, "unobtanium");
                assert_eq!(candidates.len(), 2);
            }
            other => panic!("expected LibraryNotFound, got {other:?}"),
        }
    }

    /// An empty candidate list short-circuits to `LibraryNotFound`.
    #[test]
    fn empty_candidates_returns_library_not_found() {
        match Library::open("nothing", &[]) {
            Err(LoaderError::UnsupportedPlatform { .. }) => {}
            Err(LoaderError::LibraryNotFound { library, .. }) => {
                assert_eq!(library, "nothing");
            }
            other => panic!("expected LibraryNotFound, got {other:?}"),
        }
    }
}