Skip to main content

baracuda_runtime/
driver_entry.rs

//! Runtime-to-Driver entry-point bridge — `cudaGetDriverEntryPoint`.
//!
//! Most code using this crate goes through the typed Driver loader in
//! `baracuda-driver`. The runtime's `cudaGetDriverEntryPoint` is useful
//! for one narrow case: asking the installed runtime which function
//! pointer it resolves for a given Driver-API symbol name, without
//! touching `libcuda` directly. This is handy for diagnostic tools and for
//! picking up versioned symbol variants (`_ptsz`, `_v2`, …).
9
10use core::ffi::{c_int, c_void};
11use std::ffi::CString;
12
13use baracuda_cuda_sys::runtime::runtime;
14
15use crate::error::{check, Error, Result};
16
/// Typed outcome of [`driver_entry_point`].
///
/// `fn_ptr` is whatever address the runtime wrote back (it stays null when
/// nothing was resolved). `status` is the raw
/// `cudaDriverEntryPointQueryResult` code reported alongside it. The known
/// codes are exposed as associated constants ([`Self::SUCCESS`],
/// [`Self::SYMBOL_NOT_FOUND`], [`Self::VERSION_NOT_SUFFICIENT`]), so callers
/// do not need to compare against bare integers.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct DriverEntryPoint {
    /// Resolved Driver-API function pointer; may be null when the lookup did
    /// not succeed.
    pub fn_ptr: *mut c_void,
    /// Raw `cudaDriverEntryPointQueryResult` value reported by the runtime.
    pub status: i32,
}

impl DriverEntryPoint {
    /// `cudaDriverEntryPointQueryResult` Success (0).
    pub const SUCCESS: i32 = 0;
    /// `cudaDriverEntryPointQueryResult` SymbolNotFound (1).
    pub const SYMBOL_NOT_FOUND: i32 = 1;
    /// `cudaDriverEntryPointQueryResult` VersionNotSufficient (2).
    pub const VERSION_NOT_SUFFICIENT: i32 = 2;

    /// `true` only when the runtime reported success *and* returned a
    /// non-null pointer. Both are checked so a null pointer is never treated
    /// as callable.
    #[inline]
    pub fn is_success(&self) -> bool {
        self.status == Self::SUCCESS && !self.fn_ptr.is_null()
    }

    /// `true` when the runtime reported that the requested symbol was not found.
    #[inline]
    pub fn is_symbol_not_found(&self) -> bool {
        self.status == Self::SYMBOL_NOT_FOUND
    }

    /// `true` when the runtime reported the VersionNotSufficient outcome for
    /// the requested symbol.
    #[inline]
    pub fn is_version_not_sufficient(&self) -> bool {
        self.status == Self::VERSION_NOT_SUFFICIENT
    }
}
32
33/// Resolve a Driver-API symbol by name through the Runtime API
34/// (`cudaGetDriverEntryPoint`). `flags = 0` = default; bit 0 = legacy
35/// stream, bit 1 = per-thread stream (mirrors `cuGetProcAddress`).
36pub fn driver_entry_point(symbol: &str, flags: u64) -> Result<DriverEntryPoint> {
37    let c_sym = CString::new(symbol).map_err(|_| {
38        Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
39            library: "cuda-runtime",
40            symbol: "cudaGetDriverEntryPoint(symbol contained a NUL byte)",
41        })
42    })?;
43    let r = runtime()?;
44    let cu = r.cuda_get_driver_entry_point()?;
45    let mut fn_ptr: *mut c_void = core::ptr::null_mut();
46    let mut driver_status: c_int = 0;
47    check(unsafe { cu(c_sym.as_ptr(), &mut fn_ptr, flags, &mut driver_status) })?;
48    Ok(DriverEntryPoint {
49        fn_ptr,
50        status: driver_status,
51    })
52}