baracuda_core/loader.rs
1//! Thin dynamic-loader wrapper around `libloading`.
2//!
3//! This crate does not know about any particular NVIDIA library; each `-sys`
4//! crate instantiates a [`Library`] with its own candidate filenames and
5//! symbol-resolution strategy. The Driver API in particular layers
6//! `cuGetProcAddress`-based symbol resolution on top of [`Library::symbol`];
7//! everything else (cudart, cublas, ...) can call `symbol` directly.
8
9use std::ffi::CStr;
10use std::path::{Path, PathBuf};
11
12use crate::error::LoaderError;
13use crate::platform;
14
15/// Dynamically-loaded NVIDIA library (wraps [`libloading::Library`]).
16pub struct Library {
17 name: &'static str,
18 lib: libloading::Library,
19 /// Records the path the library actually resolved from, for diagnostics.
20 resolved_from: Option<PathBuf>,
21}
22
23impl std::fmt::Debug for Library {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 f.debug_struct("Library")
26 .field("name", &self.name)
27 .field("resolved_from", &self.resolved_from)
28 .finish_non_exhaustive()
29 }
30}
31
32impl Library {
33 /// Open `candidates[0]`, `candidates[1]`, ... in order, falling back to
34 /// each path returned by [`platform::library_search_paths`]. Returns
35 /// the first success or [`LoaderError::LibraryNotFound`] / platform
36 /// error.
37 pub fn open(name: &'static str, candidates: &[&'static str]) -> Result<Self, LoaderError> {
38 if matches!(platform::os_family(), platform::OsFamily::Unsupported) {
39 return Err(LoaderError::UnsupportedPlatform {
40 platform: std::env::consts::OS,
41 });
42 }
43 if candidates.is_empty() {
44 return Err(LoaderError::library_not_found(name, candidates));
45 }
46
47 // Phase 1: try each candidate name bare (OS handles the search).
48 for candidate in candidates {
49 if let Ok(lib) = unsafe { libloading::Library::new(candidate) } {
50 return Ok(Self {
51 name,
52 lib,
53 resolved_from: Some(PathBuf::from(candidate)),
54 });
55 }
56 }
57
58 // Phase 2: try each candidate inside each explicit search directory.
59 let search_paths = platform::library_search_paths();
60 for dir in &search_paths {
61 for candidate in candidates {
62 let full = dir.join(candidate);
63 if let Ok(lib) = unsafe { libloading::Library::new(&full) } {
64 return Ok(Self {
65 name,
66 lib,
67 resolved_from: Some(full),
68 });
69 }
70 }
71 }
72
73 Err(LoaderError::library_not_found_with_search(
74 name,
75 candidates,
76 search_paths.len(),
77 ))
78 }
79
80 /// Open a library at the specific path `path` (no search). Mostly used
81 /// in tests to inject a known library location.
82 pub fn open_at(name: &'static str, path: &Path) -> Result<Self, LoaderError> {
83 let lib = unsafe { libloading::Library::new(path) }?;
84 Ok(Self {
85 name,
86 lib,
87 resolved_from: Some(path.to_path_buf()),
88 })
89 }
90
91 /// The logical library name baracuda knows it by (e.g. `"cuda-driver"`,
92 /// `"cublas"`).
93 #[inline]
94 pub fn name(&self) -> &'static str {
95 self.name
96 }
97
98 /// The absolute path the library actually resolved from, if known.
99 #[inline]
100 pub fn resolved_from(&self) -> Option<&Path> {
101 self.resolved_from.as_deref()
102 }
103
104 /// Resolve `symbol`. The caller is responsible for the type `T` matching
105 /// the C signature of the symbol; consequently, this function is `unsafe`.
106 ///
107 /// # Errors
108 ///
109 /// [`LoaderError::SymbolNotFound`] if `dlsym`/`GetProcAddress` returns
110 /// a null pointer; [`LoaderError::Libloading`] for other `libloading`
111 /// failures.
112 ///
113 /// # Safety
114 ///
115 /// `T` must be a function-pointer type (`unsafe extern "C" fn(...) -> ...`)
116 /// matching the C signature of `symbol`. Calling the returned symbol
117 /// with the wrong signature is undefined behavior.
118 pub unsafe fn symbol<T>(
119 &self,
120 symbol: &'static str,
121 ) -> Result<libloading::Symbol<'_, T>, LoaderError> {
122 let bytes_with_nul: Vec<u8> = symbol.bytes().chain(std::iter::once(0)).collect();
123 let cstr = CStr::from_bytes_with_nul(&bytes_with_nul).map_err(|_| {
124 LoaderError::SymbolNotFound {
125 library: self.name,
126 symbol,
127 }
128 })?;
129 match self.lib.get::<T>(cstr.to_bytes_with_nul()) {
130 Ok(s) => Ok(s),
131 Err(_) => Err(LoaderError::SymbolNotFound {
132 library: self.name,
133 symbol,
134 }),
135 }
136 }
137
138 /// Return a raw pointer to the symbol without wrapping in `libloading::Symbol`.
139 /// Useful for stashing function pointers in `OnceLock`s that outlive the
140 /// borrow checker's view of the library.
141 ///
142 /// # Safety
143 ///
144 /// Same as [`Self::symbol`]. Additionally, the caller must ensure the
145 /// [`Library`] outlives any use of the returned pointer — in practice this
146 /// means storing the [`Library`] in a `static OnceLock<Library>` or
147 /// equivalent.
148 pub unsafe fn raw_symbol(&self, symbol: &'static str) -> Result<*mut (), LoaderError> {
149 let sym: libloading::Symbol<'_, *mut ()> = self.symbol(symbol)?;
150 Ok(*sym)
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// A library that cannot be found anywhere is reported together with
    /// every candidate file name that was tried.
    #[test]
    fn missing_library_reports_candidates() {
        let outcome = Library::open(
            "unobtanium",
            &["libunobtanium.so.42", "unobtanium64_42.dll"],
        );
        match outcome {
            // Acceptable on non-Linux/Windows CI runners.
            Err(LoaderError::UnsupportedPlatform { .. }) => {}
            Err(LoaderError::LibraryNotFound {
                library,
                candidates,
                ..
            }) => {
                assert_eq!(library, "unobtanium");
                assert_eq!(candidates.len(), 2);
            }
            other => panic!("expected LibraryNotFound, got {other:?}"),
        }
    }

    /// On supported platforms, an empty candidate list yields
    /// `LibraryNotFound` carrying the logical name.
    #[test]
    fn empty_candidates_returns_library_not_found() {
        let outcome = Library::open("nothing", &[]);
        match outcome {
            Err(LoaderError::UnsupportedPlatform { .. }) => {}
            Err(LoaderError::LibraryNotFound { library, .. }) => {
                assert_eq!(library, "nothing");
            }
            other => panic!("expected LibraryNotFound, got {other:?}"),
        }
    }
}
191}