oxicuda_solver/handle.rs
1//! Solver handle management.
2//!
3//! [`SolverHandle`] is the central object for all solver operations, analogous
4//! to `cusolverDnHandle_t` in cuSOLVER. It owns a BLAS handle, CUDA stream,
5//! PTX cache, and a device workspace buffer for intermediate computations.
6
7use std::sync::Arc;
8
9use oxicuda_blas::BlasHandle;
10use oxicuda_driver::{Context, Stream};
11use oxicuda_memory::DeviceBuffer;
12use oxicuda_ptx::arch::SmVersion;
13use oxicuda_ptx::cache::PtxCache;
14
15use crate::error::{SolverError, SolverResult};
16
17/// Central handle for solver operations.
18///
19/// Every solver routine requires a `SolverHandle`. The handle binds operations
20/// to a specific CUDA context and stream, provides access to the underlying
21/// BLAS handle for delegating matrix operations, and manages a resizable
22/// device workspace buffer.
23///
24/// # Thread safety
25///
26/// `SolverHandle` is `Send` but **not** `Sync`. Each thread should create its
27/// own handle (possibly sharing the same [`Arc<Context>`]).
28pub struct SolverHandle {
29 /// The CUDA context this handle is bound to.
30 context: Arc<Context>,
31 /// The stream on which solver kernels are launched.
32 stream: Stream,
33 /// BLAS handle for delegating GEMM, TRSM, etc.
34 blas_handle: BlasHandle,
35 /// Cache for generated PTX kernels.
36 ptx_cache: PtxCache,
37 /// SM architecture of the device.
38 sm_version: SmVersion,
39 /// Resizable device workspace for intermediate computations.
40 workspace: DeviceBuffer<u8>,
41}
42
43impl SolverHandle {
44 /// Creates a new solver handle with a freshly-allocated default stream.
45 ///
46 /// The device's compute capability is queried once and cached as an
47 /// [`SmVersion`] for later kernel dispatch decisions. An initial workspace
48 /// of 4 KiB is allocated.
49 ///
50 /// # Errors
51 ///
52 /// Returns [`SolverError::Cuda`] if stream creation or device query fails.
53 /// Returns [`SolverError::Blas`] if BLAS handle creation fails.
54 pub fn new(ctx: &Arc<Context>) -> SolverResult<Self> {
55 let blas_handle = BlasHandle::new(ctx)?;
56 let sm_version = blas_handle.sm_version();
57 let stream = Stream::new(ctx)?;
58 let ptx_cache = PtxCache::new()
59 .map_err(|e| SolverError::InternalError(format!("failed to create PTX cache: {e}")))?;
60 // Start with a small initial workspace (4 KiB).
61 let workspace = DeviceBuffer::<u8>::zeroed(4096)?;
62
63 Ok(Self {
64 context: Arc::clone(ctx),
65 stream,
66 blas_handle,
67 ptx_cache,
68 sm_version,
69 workspace,
70 })
71 }
72
73 /// Ensures the workspace buffer has at least `bytes` capacity.
74 ///
75 /// If the current workspace is smaller, it is reallocated. The contents
76 /// of the previous workspace are **not** preserved.
77 ///
78 /// # Errors
79 ///
80 /// Returns [`SolverError::Cuda`] if reallocation fails.
81 pub fn ensure_workspace(&mut self, bytes: usize) -> SolverResult<()> {
82 if self.workspace.len() < bytes {
83 self.workspace = DeviceBuffer::<u8>::zeroed(bytes)?;
84 }
85 Ok(())
86 }
87
88 /// Returns a reference to the underlying BLAS handle.
89 pub fn blas(&self) -> &BlasHandle {
90 &self.blas_handle
91 }
92
93 /// Returns a mutable reference to the underlying BLAS handle.
94 pub fn blas_mut(&mut self) -> &mut BlasHandle {
95 &mut self.blas_handle
96 }
97
98 /// Returns a reference to the stream used for kernel launches.
99 pub fn stream(&self) -> &Stream {
100 &self.stream
101 }
102
103 /// Returns a reference to the CUDA context.
104 pub fn context(&self) -> &Arc<Context> {
105 &self.context
106 }
107
108 /// Returns the SM version of the bound device.
109 pub fn sm_version(&self) -> SmVersion {
110 self.sm_version
111 }
112
113 /// Returns a reference to the PTX cache.
114 pub fn ptx_cache(&self) -> &PtxCache {
115 &self.ptx_cache
116 }
117
118 /// Returns a mutable reference to the PTX cache.
119 pub fn ptx_cache_mut(&mut self) -> &mut PtxCache {
120 &mut self.ptx_cache
121 }
122
123 /// Returns a reference to the device workspace buffer.
124 pub fn workspace(&self) -> &DeviceBuffer<u8> {
125 &self.workspace
126 }
127
128 /// Returns a mutable reference to the device workspace buffer.
129 pub fn workspace_mut(&mut self) -> &mut DeviceBuffer<u8> {
130 &mut self.workspace
131 }
132}
133
#[cfg(test)]
mod tests {
    /// Workspace size, in bytes, that the handle constructor requests.
    const INITIAL_WORKSPACE_BYTES: usize = 4096;

    /// Stand-in for the size comparison `ensure_workspace` performs: the
    /// workspace is replaced only when it is strictly smaller than required.
    fn must_grow(current: usize, required: usize) -> bool {
        current < required
    }

    #[test]
    fn initial_workspace_size() {
        // Pins the initial allocation at 4 KiB.
        assert_eq!(INITIAL_WORKSPACE_BYTES, 4096);
    }

    #[test]
    fn workspace_requirement_logic() {
        // A request larger than the current workspace calls for a new buffer.
        let (current, required) = (4096_usize, 8192_usize);
        assert!(must_grow(current, required), "should need reallocation");
    }

    #[test]
    fn workspace_sufficient_logic() {
        // A request smaller than the current workspace leaves it untouched.
        let (current, required) = (8192_usize, 4096_usize);
        assert!(!must_grow(current, required), "should not need reallocation");
    }
}