Skip to main content

oxicuda_solver/
handle.rs

//! Solver handle management.
//!
//! [`SolverHandle`] is the central object for all solver operations, analogous
//! to `cusolverDnHandle_t` in cuSOLVER. It owns a BLAS handle, a CUDA stream,
//! a PTX cache, and a device workspace buffer for intermediate computations.
6
7use std::sync::Arc;
8
9use oxicuda_blas::BlasHandle;
10use oxicuda_driver::{Context, Stream};
11use oxicuda_memory::DeviceBuffer;
12use oxicuda_ptx::arch::SmVersion;
13use oxicuda_ptx::cache::PtxCache;
14
15use crate::error::{SolverError, SolverResult};
16
17/// Central handle for solver operations.
18///
19/// Every solver routine requires a `SolverHandle`. The handle binds operations
20/// to a specific CUDA context and stream, provides access to the underlying
21/// BLAS handle for delegating matrix operations, and manages a resizable
22/// device workspace buffer.
23///
24/// # Thread safety
25///
26/// `SolverHandle` is `Send` but **not** `Sync`. Each thread should create its
27/// own handle (possibly sharing the same [`Arc<Context>`]).
28pub struct SolverHandle {
29    /// The CUDA context this handle is bound to.
30    context: Arc<Context>,
31    /// The stream on which solver kernels are launched.
32    stream: Stream,
33    /// BLAS handle for delegating GEMM, TRSM, etc.
34    blas_handle: BlasHandle,
35    /// Cache for generated PTX kernels.
36    ptx_cache: PtxCache,
37    /// SM architecture of the device.
38    sm_version: SmVersion,
39    /// Resizable device workspace for intermediate computations.
40    workspace: DeviceBuffer<u8>,
41}
42
43impl SolverHandle {
44    /// Creates a new solver handle with a freshly-allocated default stream.
45    ///
46    /// The device's compute capability is queried once and cached as an
47    /// [`SmVersion`] for later kernel dispatch decisions. An initial workspace
48    /// of 4 KiB is allocated.
49    ///
50    /// # Errors
51    ///
52    /// Returns [`SolverError::Cuda`] if stream creation or device query fails.
53    /// Returns [`SolverError::Blas`] if BLAS handle creation fails.
54    pub fn new(ctx: &Arc<Context>) -> SolverResult<Self> {
55        let blas_handle = BlasHandle::new(ctx)?;
56        let sm_version = blas_handle.sm_version();
57        let stream = Stream::new(ctx)?;
58        let ptx_cache = PtxCache::new()
59            .map_err(|e| SolverError::InternalError(format!("failed to create PTX cache: {e}")))?;
60        // Start with a small initial workspace (4 KiB).
61        let workspace = DeviceBuffer::<u8>::zeroed(4096)?;
62
63        Ok(Self {
64            context: Arc::clone(ctx),
65            stream,
66            blas_handle,
67            ptx_cache,
68            sm_version,
69            workspace,
70        })
71    }
72
73    /// Ensures the workspace buffer has at least `bytes` capacity.
74    ///
75    /// If the current workspace is smaller, it is reallocated. The contents
76    /// of the previous workspace are **not** preserved.
77    ///
78    /// # Errors
79    ///
80    /// Returns [`SolverError::Cuda`] if reallocation fails.
81    pub fn ensure_workspace(&mut self, bytes: usize) -> SolverResult<()> {
82        if self.workspace.len() < bytes {
83            self.workspace = DeviceBuffer::<u8>::zeroed(bytes)?;
84        }
85        Ok(())
86    }
87
88    /// Returns a reference to the underlying BLAS handle.
89    pub fn blas(&self) -> &BlasHandle {
90        &self.blas_handle
91    }
92
93    /// Returns a mutable reference to the underlying BLAS handle.
94    pub fn blas_mut(&mut self) -> &mut BlasHandle {
95        &mut self.blas_handle
96    }
97
98    /// Returns a reference to the stream used for kernel launches.
99    pub fn stream(&self) -> &Stream {
100        &self.stream
101    }
102
103    /// Returns a reference to the CUDA context.
104    pub fn context(&self) -> &Arc<Context> {
105        &self.context
106    }
107
108    /// Returns the SM version of the bound device.
109    pub fn sm_version(&self) -> SmVersion {
110        self.sm_version
111    }
112
113    /// Returns a reference to the PTX cache.
114    pub fn ptx_cache(&self) -> &PtxCache {
115        &self.ptx_cache
116    }
117
118    /// Returns a mutable reference to the PTX cache.
119    pub fn ptx_cache_mut(&mut self) -> &mut PtxCache {
120        &mut self.ptx_cache
121    }
122
123    /// Returns a reference to the device workspace buffer.
124    pub fn workspace(&self) -> &DeviceBuffer<u8> {
125        &self.workspace
126    }
127
128    /// Returns a mutable reference to the device workspace buffer.
129    pub fn workspace_mut(&mut self) -> &mut DeviceBuffer<u8> {
130        &mut self.workspace
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::SolverHandle;

    /// Mirrors the reallocation predicate in `SolverHandle::ensure_workspace`
    /// (`workspace.len() < bytes`).
    ///
    /// Building a real handle requires a CUDA device, so the predicate is
    /// exercised in isolation. Keep this in sync with `ensure_workspace`.
    fn needs_reallocation(current: usize, required: usize) -> bool {
        current < required
    }

    /// Compile-time check of the documented thread-safety contract.
    ///
    /// `SolverHandle` must be `Send` so it can be moved to a worker thread.
    /// If a field type stops being `Send`, this test fails to compile.
    #[test]
    fn solver_handle_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<SolverHandle>();
    }

    #[test]
    fn workspace_requirement_logic() {
        assert!(needs_reallocation(4096, 8192), "should need reallocation");
    }

    #[test]
    fn workspace_sufficient_logic() {
        assert!(!needs_reallocation(8192, 4096), "should not need reallocation");
    }

    #[test]
    fn workspace_exact_fit_logic() {
        // Boundary of the strict `<`: an exact fit must not reallocate.
        assert!(!needs_reallocation(4096, 4096), "exact fit should not reallocate");
    }
}