Skip to main content

vulkan_rust/
loader.rs

1use std::ffi::{CStr, c_void};
2
3use crate::error::LoadError;
4
/// Abstraction over loading symbols from the Vulkan shared library.
///
/// The trait is object safe, so a loader can be stored as `Box<dyn Loader>`
/// or shared as `Arc<dyn Loader>`. The `Send + Sync` supertraits make such a
/// shared loader usable from multiple threads.
///
/// # Safety
///
/// Implementations must return valid function pointers for the requested
/// symbol name, or null if the symbol is not found. Returning a pointer
/// to the wrong function causes undefined behavior.
///
/// Every non-null pointer an implementation returns must stay valid for as
/// long as the loader itself is alive. For example, [`LibloadingLoader`] keeps
/// its shared library open only until it is dropped.
///
/// # Examples
///
/// ```
/// use std::ffi::{CStr, c_void};
/// use vulkan_rust::Loader;
///
/// struct NullLoader;
///
/// unsafe impl Loader for NullLoader {
///     unsafe fn load(&self, _name: &CStr) -> *const c_void {
///         std::ptr::null()
///     }
/// }
///
/// let loader = NullLoader;
/// let ptr = unsafe { loader.load(c"vkCreateInstance") };
/// assert!(ptr.is_null());
/// ```
pub unsafe trait Loader: Send + Sync {
    /// Load a function by name from the Vulkan library.
    ///
    /// Returns a raw function pointer, or null if the symbol is not found.
    ///
    /// # Safety
    ///
    /// The caller must only transmute the returned pointer to a function
    /// type matching the Vulkan command identified by `name`. The caller
    /// must also not use the pointer after `self` has been dropped.
    unsafe fn load(&self, name: &CStr) -> *const c_void;
}
42
43/// Default [`Loader`] implementation backed by `libloading`.
44///
45/// Loads the platform-appropriate Vulkan shared library at construction
46/// time and resolves symbols from it on demand.
47///
48/// # Examples
49///
50/// ```no_run
51/// use vulkan_rust::LibloadingLoader;
52///
53/// let loader = unsafe { LibloadingLoader::new() }
54///     .expect("Vulkan library not found");
55/// ```
56pub struct LibloadingLoader {
57    lib: libloading::Library,
58}
59
60impl LibloadingLoader {
61    /// Load the platform's Vulkan shared library.
62    pub fn new() -> Result<Self, LoadError> {
63        // SAFETY: loading the platform's Vulkan shared library is standard initialization.
64        let lib = unsafe { load_vulkan_library()? };
65        Ok(Self { lib })
66    }
67}
68
69unsafe impl Loader for LibloadingLoader {
70    unsafe fn load(&self, name: &CStr) -> *const c_void {
71        // SAFETY: name is a valid CStr; libloading resolves it from the loaded library.
72        unsafe {
73            self.lib
74                .get::<*const c_void>(name.to_bytes_with_nul())
75                .map(|sym| *sym)
76                .unwrap_or(std::ptr::null())
77        }
78    }
79}
80
81/// Load the platform-specific Vulkan shared library.
82///
83/// # Safety
84///
85/// Loading a shared library can execute arbitrary initialization code.
86unsafe fn load_vulkan_library() -> Result<libloading::Library, LoadError> {
87    #[cfg(target_os = "windows")]
88    const LIB_NAMES: &[&str] = &["vulkan-1.dll"];
89
90    #[cfg(target_os = "linux")]
91    const LIB_NAMES: &[&str] = &["libvulkan.so.1", "libvulkan.so"];
92
93    #[cfg(target_os = "android")]
94    const LIB_NAMES: &[&str] = &["libvulkan.so.1", "libvulkan.so"];
95
96    #[cfg(target_os = "macos")]
97    const LIB_NAMES: &[&str] = &["libvulkan.1.dylib", "libMoltenVK.dylib"];
98
99    let mut last_err = None;
100    for name in LIB_NAMES {
101        // SAFETY: loading a shared library by platform-known name.
102        match unsafe { libloading::Library::new(name) } {
103            Ok(lib) => return Ok(lib),
104            Err(e) => last_err = Some(e),
105        }
106    }
107    Err(LoadError::Library(
108        last_err.expect("LIB_NAMES is non-empty"),
109    ))
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn null_loader_returns_null() {
        struct TestNullLoader;
        unsafe impl Loader for TestNullLoader {
            unsafe fn load(&self, _name: &CStr) -> *const c_void {
                std::ptr::null()
            }
        }
        let loader = TestNullLoader;
        let ptr = unsafe { loader.load(c"vkGetInstanceProcAddr") };
        assert!(ptr.is_null());
    }

    #[test]
    fn missing_entry_point_error_message() {
        // Pins the user-facing text of `LoadError::MissingEntryPoint`. This
        // does not exercise `load_vulkan_library`.
        let err = LoadError::MissingEntryPoint;
        assert_eq!(
            err.to_string(),
            "vkGetInstanceProcAddr not found in Vulkan library"
        );
    }

    #[test]
    #[cfg(not(miri))] // libloading calls FFI that Miri cannot interpret
    fn libloading_loader_new_returns_error_message_on_missing_lib() {
        // `new()` cannot be forced to fail portably, so build the same kind of
        // error it would wrap by loading a library name that cannot exist.
        let lib_err = unsafe { libloading::Library::new("__nonexistent_vulkan_lib__") }
            .expect_err("a nonsense library name must not load");
        let msg = LoadError::Library(lib_err).to_string();
        assert!(msg.contains("failed to load Vulkan library"));
    }

    /// Compile-time check that `T` is `Send + Sync`.
    fn assert_send_sync<T: Send + Sync + ?Sized>() {}

    #[test]
    fn loader_trait_requires_send_sync() {
        // `dyn Loader` is `Send + Sync` only because the trait declares them as
        // supertraits, so this stops compiling if that requirement is dropped.
        assert_send_sync::<dyn Loader>();
        assert_send_sync::<LibloadingLoader>();
    }

    #[test]
    #[cfg(not(miri))] // libloading calls FFI that Miri cannot interpret
    fn libloading_error_wraps_into_library_variant() {
        // Checks the variant produced by wrapping a real libloading error,
        // i.e. the shape `LibloadingLoader::new()` reports when no Vulkan
        // library can be opened. `new()` itself is not called here.
        let lib_err = unsafe { libloading::Library::new("__no_such_lib__") }
            .expect_err("a nonsense library name must not load");
        let err = LoadError::Library(lib_err);
        assert!(
            matches!(err, LoadError::Library(_)),
            "expected Library variant"
        );
    }

    #[test]
    #[cfg(not(miri))] // libloading calls FFI that Miri cannot interpret
    fn libloading_loader_new_exercises_load_path() {
        // Exercises LibloadingLoader::new() and load_vulkan_library() without
        // requiring a working ICD. On CI (libvulkan-dev installed) this succeeds;
        // on machines without Vulkan it exercises the error path. Either way
        // the production code is covered.
        match LibloadingLoader::new() {
            Ok(loader) => {
                // Only the call itself is exercised here; the #[ignore]d
                // tests below assert non-null results on a real runtime.
                let ptr = unsafe { loader.load(c"vkGetInstanceProcAddr") };
                let _ = ptr;

                // Unknown symbol should return null.
                let unknown = unsafe { loader.load(c"vkNotARealFunction_XYZ") };
                assert!(unknown.is_null(), "unknown symbol should return null");
            }
            Err(e) => {
                // No Vulkan library on this system, verify error is Library variant.
                assert!(
                    matches!(e, LoadError::Library(_)),
                    "expected LoadError::Library, got {e}"
                );
                assert!(e.to_string().contains("failed to load Vulkan library"));
            }
        }
    }

    #[test]
    fn custom_loader_returns_non_null() {
        struct FixedLoader;
        unsafe impl Loader for FixedLoader {
            unsafe fn load(&self, _name: &CStr) -> *const c_void {
                0xDEAD as *const c_void
            }
        }
        let loader = FixedLoader;
        let ptr = unsafe { loader.load(c"vkGetInstanceProcAddr") };
        assert!(!ptr.is_null());
        assert_eq!(ptr as usize, 0xDEAD);
    }

    #[test]
    fn loader_is_object_safe() {
        // Verify Loader can be used as a trait object, which is critical
        // for Entry's Arc<dyn Loader> storage.
        struct TestLoader;
        unsafe impl Loader for TestLoader {
            unsafe fn load(&self, _name: &CStr) -> *const c_void {
                0xABCD as *const c_void
            }
        }
        let loader: Box<dyn Loader> = Box::new(TestLoader);
        let ptr = unsafe { loader.load(c"vkGetInstanceProcAddr") };
        assert_eq!(ptr as usize, 0xABCD);
    }

    #[test]
    fn loader_behind_arc_works() {
        use std::sync::Arc;
        struct TestLoader;
        unsafe impl Loader for TestLoader {
            unsafe fn load(&self, _name: &CStr) -> *const c_void {
                0x1234 as *const c_void
            }
        }
        let loader: Arc<dyn Loader> = Arc::new(TestLoader);
        let ptr = unsafe { loader.load(c"vkGetInstanceProcAddr") };
        assert_eq!(ptr as usize, 0x1234);
        assert_eq!(Arc::strong_count(&loader), 1);
    }

    #[test]
    fn loader_resolves_different_names_independently() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        // Function-local static: only this test touches it.
        static CALL_COUNT: AtomicUsize = AtomicUsize::new(0);

        struct CountingLoader;
        unsafe impl Loader for CountingLoader {
            unsafe fn load(&self, name: &CStr) -> *const c_void {
                CALL_COUNT.fetch_add(1, Ordering::SeqCst);
                match name.to_bytes() {
                    b"vkGetInstanceProcAddr" => 0x1000 as *const c_void,
                    b"vkGetDeviceProcAddr" => 0x2000 as *const c_void,
                    _ => std::ptr::null(),
                }
            }
        }

        CALL_COUNT.store(0, Ordering::SeqCst);
        let loader = CountingLoader;
        let gipa = unsafe { loader.load(c"vkGetInstanceProcAddr") };
        let gdpa = unsafe { loader.load(c"vkGetDeviceProcAddr") };
        let unknown = unsafe { loader.load(c"vkUnknown") };

        assert_eq!(gipa as usize, 0x1000);
        assert_eq!(gdpa as usize, 0x2000);
        assert!(unknown.is_null());
        assert_eq!(CALL_COUNT.load(Ordering::SeqCst), 3);
    }

    #[test]
    #[ignore] // requires Vulkan runtime
    fn libloading_loader_new_succeeds() {
        let loader = LibloadingLoader::new().expect("failed to load Vulkan library");
        let ptr = unsafe { loader.load(c"vkGetInstanceProcAddr") };
        assert!(!ptr.is_null(), "vkGetInstanceProcAddr should be non-null");
    }

    #[test]
    #[ignore] // requires Vulkan runtime
    fn libloading_loader_resolves_device_proc_addr() {
        let loader = LibloadingLoader::new().expect("failed to load Vulkan library");
        let ptr = unsafe { loader.load(c"vkGetDeviceProcAddr") };
        assert!(!ptr.is_null(), "vkGetDeviceProcAddr should be non-null");
    }

    #[test]
    #[ignore] // requires Vulkan runtime
    fn libloading_loader_returns_null_for_unknown_symbol() {
        let loader = LibloadingLoader::new().expect("failed to load Vulkan library");
        let ptr = unsafe { loader.load(c"vkNotARealFunction_XYZ") };
        assert!(ptr.is_null(), "unknown symbol should return null");
    }

    #[test]
    #[ignore] // requires Vulkan runtime
    fn libloading_loader_resolves_create_instance() {
        let loader = LibloadingLoader::new().expect("failed to load Vulkan library");
        let ptr = unsafe { loader.load(c"vkCreateInstance") };
        assert!(!ptr.is_null(), "vkCreateInstance should be non-null");
    }

    #[test]
    #[ignore] // requires Vulkan runtime
    fn libloading_loader_distinct_pointers_for_different_symbols() {
        let loader = LibloadingLoader::new().expect("failed to load Vulkan library");
        let gipa = unsafe { loader.load(c"vkGetInstanceProcAddr") };
        let gdpa = unsafe { loader.load(c"vkGetDeviceProcAddr") };
        assert!(!gipa.is_null());
        assert!(!gdpa.is_null());
        assert_ne!(
            gipa, gdpa,
            "different symbols should return different pointers"
        );
    }

    #[test]
    #[ignore] // requires Vulkan runtime
    fn libloading_loader_same_symbol_returns_same_pointer() {
        let loader = LibloadingLoader::new().expect("failed to load Vulkan library");
        let ptr1 = unsafe { loader.load(c"vkGetInstanceProcAddr") };
        let ptr2 = unsafe { loader.load(c"vkGetInstanceProcAddr") };
        assert_eq!(ptr1, ptr2, "same symbol should return the same pointer");
    }
}