Skip to main content

baracuda_kernels/linalg/
solve.rs

1//! Linear solve `A · X = B` via `getrf` + `getrs`.
2//!
3//! Wraps cuSOLVER's `cusolverDnSgetrf` / `Dgetrf` (LU factorization
4//! with partial pivoting) followed by `cusolverDnSgetrs` / `Dgetrs`
5//! (triangular substitutions over the packed `LU` + pivot). The plan
6//! owns no scratch state across calls — pivot + info are caller-
7//! provided, the workspace bytes are reported through
8//! [`SolvePlan::workspace_size`] and supplied as `Workspace::Borrowed`.
9//!
10//! **2-D only** — single `A`, single `B`. No batching today.
11//!
12//! **In-place semantics**: `A` is overwritten with the packed `LU`
13//! factors (cuSOLVER `getrf` convention — `L` in the strict lower
14//! triangle with implicit unit diagonal, `U` in the upper triangle).
15//! `B` is overwritten with the solution `X`.
16//!
17//! **Storage convention**: like the rest of the linalg family, the
18//! trailblazer passes through cuSOLVER's column-major view of the
19//! caller's byte storage. The LU plan documents the same convention —
20//! callers that want row-major end-to-end semantics must transpose on
21//! either side (a future shape-layout op can fuse the transpose).
22//!
23//! **Workspace**: cuSOLVER's `getrs` is workspace-free; the entire
24//! workspace requirement is the one queried from
25//! `cusolverDnSgetrf_bufferSize` / `Dgetrf_bufferSize`.
26
27use core::cell::Cell;
28use core::ffi::c_void;
29use core::marker::PhantomData;
30
31use baracuda_cutlass::{Error, Result};
32use baracuda_driver::Stream;
33use baracuda_kernels_sys::{
34    cusolverDnCreate, cusolverDnDestroy, cusolverDnDgetrf, cusolverDnDgetrf_bufferSize,
35    cusolverDnDgetrs, cusolverDnHandle_t, cusolverDnSetStream, cusolverDnSgetrf,
36    cusolverDnSgetrf_bufferSize, cusolverDnSgetrs, CUBLAS_OP_N,
37};
38use baracuda_kernels_types::{
39    ArchSku, BackendKind, Element, ElementKind, KernelSku, LinalgKind, MathPrecision, OpCategory,
40    PlanPreference, PrecisionGuarantee, TensorMut, Workspace,
41};
42
43use super::cholesky::unpack_workspace;
44
45/// Descriptor for a linear-solve.
46#[derive(Copy, Clone, Debug)]
47pub struct SolveDescriptor {
48    /// Order `M` of the (square) coefficient matrix `A`.
49    pub m: i32,
50    /// Number of right-hand sides — column count of `B` / `X`.
51    pub nrhs: i32,
52    /// Element type. Must be `F32` or `F64`.
53    pub element: ElementKind,
54}
55
56/// Args bundle for a linear-solve launch.
57///
58/// `a` is overwritten in place with the packed `LU` factors produced by
59/// `getrf`. `b` is overwritten in place with the solution `X`. `pivot`
60/// receives cuSOLVER's 1-based pivot indices (length `M`). `info`
61/// receives the single factorization-status word (`0` on success,
62/// `k > 0` if `U[k, k] == 0` at step `k`).
63pub struct SolveArgs<'a, T: Element> {
64    /// Coefficient matrix `[M, M]` (column-major). Overwritten with
65    /// packed `LU` in place.
66    pub a: TensorMut<'a, T, 2>,
67    /// Right-hand side `[M, NRHS]` (column-major) on input; solution
68    /// `X` on output.
69    pub b: TensorMut<'a, T, 2>,
70    /// Pivot vector `[M]` (1-based per LAPACK convention).
71    pub pivot: TensorMut<'a, i32, 1>,
72    /// Single-cell info: `0` on success.
73    pub info: TensorMut<'a, i32, 1>,
74}
75
76/// Linear-solve plan — `A · X = B` via `getrf` + `getrs`.
77///
78/// Two-step pipeline per `run`: `getrf` factors `A` in place to packed
79/// `LU` + pivots, then `getrs` solves over the packed factorization.
80/// `B` is overwritten with `X`.
81///
82/// **When to use**: square solve over a general `A`. Use
83/// [`super::CholeskyPlan`] + a `trsm` chain when `A` is SPD;
84/// [`super::LstSqPlan`] for least-squares.
85///
86/// **Dtypes**: `f32`, `f64`.
87///
88/// **Shape**: `[M, M]` × `[M, NRHS]`. 2-D only.
89///
90/// **Storage**: column-major end-to-end.
91///
92/// **Workspace**: cuSOLVER `_bufferSize` for `getrf` (queried lazily on
93/// first `run`).
94///
95/// **Precision guarantee**: deterministic; not bit-stable across runs.
96///
97/// Owns a lazy cuSOLVER handle (`!Sync` / `!Send`); destroyed on `Drop`.
98pub struct SolvePlan<T: Element> {
99    desc: SolveDescriptor,
100    sku: KernelSku,
101    handle: Cell<cusolverDnHandle_t>,
102    workspace_bytes: Cell<usize>,
103    _marker: PhantomData<T>,
104}
105
106impl<T: Element> SolvePlan<T> {
107    /// Pick a kernel + validate the descriptor.
108    pub fn select(
109        _stream: &Stream,
110        desc: &SolveDescriptor,
111        _pref: PlanPreference,
112    ) -> Result<Self> {
113        if desc.element != T::KIND {
114            return Err(Error::Unsupported(
115                "baracuda-kernels::SolvePlan: descriptor.element != T::KIND",
116            ));
117        }
118        if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
119            return Err(Error::Unsupported(
120                "baracuda-kernels::SolvePlan: cuSOLVER dense solve supports f32 + f64 only",
121            ));
122        }
123        if desc.m <= 0 {
124            return Err(Error::InvalidProblem(
125                "baracuda-kernels::SolvePlan: m must be > 0",
126            ));
127        }
128        if desc.nrhs <= 0 {
129            return Err(Error::InvalidProblem(
130                "baracuda-kernels::SolvePlan: nrhs must be > 0",
131            ));
132        }
133
134        let math_precision = match T::KIND {
135            ElementKind::F64 => MathPrecision::F64,
136            _ => MathPrecision::F32,
137        };
138        let precision_guarantee = PrecisionGuarantee {
139            math_precision,
140            accumulator: T::KIND,
141            bit_stable_on_same_hardware: false,
142            deterministic: true,
143        };
144        let sku = KernelSku {
145            category: OpCategory::Linalg,
146            op: LinalgKind::Solve as u16,
147            element: T::KIND,
148            aux_element: Some(ElementKind::I32),
149            layout: None,
150            epilogue: None,
151            arch: ArchSku::Sm80,
152            backend: BackendKind::Cusolver,
153            precision_guarantee,
154        };
155        Ok(Self {
156            desc: *desc,
157            sku,
158            handle: Cell::new(core::ptr::null_mut()),
159            workspace_bytes: Cell::new(0),
160            _marker: PhantomData,
161        })
162    }
163
164    /// Kernel SKU identity.
165    #[inline]
166    pub fn sku(&self) -> KernelSku {
167        self.sku
168    }
169
170    /// Numerical guarantees.
171    #[inline]
172    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
173        self.sku.precision_guarantee
174    }
175
176    /// Workspace size in bytes (the `getrf` requirement; `getrs` is
177    /// workspace-free). Lazily populated on first `run`.
178    #[inline]
179    pub fn workspace_size(&self) -> usize {
180        self.workspace_bytes.get()
181    }
182
183    /// Materialize the handle and query workspace size.
184    pub fn query_workspace_size(&self, _stream: &Stream) -> Result<usize> {
185        let h = self.ensure_handle()?;
186        let mut lwork: i32 = 0;
187        let status = match T::KIND {
188            ElementKind::F32 => unsafe {
189                cusolverDnSgetrf_bufferSize(
190                    h,
191                    self.desc.m,
192                    self.desc.m,
193                    core::ptr::null_mut(),
194                    self.desc.m,
195                    &mut lwork as *mut _,
196                )
197            },
198            ElementKind::F64 => unsafe {
199                cusolverDnDgetrf_bufferSize(
200                    h,
201                    self.desc.m,
202                    self.desc.m,
203                    core::ptr::null_mut(),
204                    self.desc.m,
205                    &mut lwork as *mut _,
206                )
207            },
208            _ => unreachable!("select() gates on F32 / F64"),
209        };
210        if status != 0 {
211            return Err(Error::CutlassInternal(-status));
212        }
213        let bytes = (lwork as usize) * core::mem::size_of::<T>();
214        self.workspace_bytes.set(bytes);
215        Ok(bytes)
216    }
217
218    fn ensure_handle(&self) -> Result<cusolverDnHandle_t> {
219        let h = self.handle.get();
220        if !h.is_null() {
221            return Ok(h);
222        }
223        let mut handle: cusolverDnHandle_t = core::ptr::null_mut();
224        let status = unsafe { cusolverDnCreate(&mut handle as *mut _) };
225        if status != 0 {
226            return Err(Error::CutlassInternal(-status));
227        }
228        self.handle.set(handle);
229        Ok(handle)
230    }
231
232    fn bind_stream(&self, h: cusolverDnHandle_t, stream: &Stream) -> Result<()> {
233        let status = unsafe { cusolverDnSetStream(h, stream.as_raw() as *mut c_void) };
234        if status != 0 {
235            return Err(Error::CutlassInternal(-status));
236        }
237        Ok(())
238    }
239
240    fn check_args(&self, args: &SolveArgs<'_, T>) -> Result<()> {
241        let m = self.desc.m;
242        let nrhs = self.desc.nrhs;
243        if args.a.shape != [m, m] {
244            return Err(Error::InvalidProblem(
245                "baracuda-kernels::SolvePlan: A shape != [M, M]",
246            ));
247        }
248        if args.b.shape != [m, nrhs] {
249            return Err(Error::InvalidProblem(
250                "baracuda-kernels::SolvePlan: B shape != [M, NRHS]",
251            ));
252        }
253        if args.pivot.shape != [m] {
254            return Err(Error::InvalidProblem(
255                "baracuda-kernels::SolvePlan: pivot shape != [M]",
256            ));
257        }
258        if args.info.shape != [1] {
259            return Err(Error::InvalidProblem(
260                "baracuda-kernels::SolvePlan: info shape != [1]",
261            ));
262        }
263        Ok(())
264    }
265}
266
267// Macro to instantiate run() for f32 / f64.
268macro_rules! impl_solve_run {
269    ($T:ty, $getrf:ident, $getrs:ident) => {
270        impl SolvePlan<$T> {
271            /// Run the linear solve.
272            pub fn run(
273                &self,
274                stream: &Stream,
275                workspace: Workspace<'_>,
276                args: SolveArgs<'_, $T>,
277            ) -> Result<()> {
278                self.check_args(&args)?;
279                let h = self.ensure_handle()?;
280                self.bind_stream(h, stream)?;
281                let m = self.desc.m;
282                let nrhs = self.desc.nrhs;
283
284                let needed = if self.workspace_bytes.get() == 0 {
285                    self.query_workspace_size(stream)?
286                } else {
287                    self.workspace_bytes.get()
288                };
289                let (ws_ptr, _ws_bytes) = unpack_workspace(workspace, needed)?;
290
291                let a_ptr = args.a.data.as_raw().0 as *mut $T;
292                let b_ptr = args.b.data.as_raw().0 as *mut $T;
293                let pivot_ptr = args.pivot.data.as_raw().0 as *mut i32;
294                let info_ptr = args.info.data.as_raw().0 as *mut i32;
295
296                // 1. getrf — factors A in place, writes pivot + info.
297                let status = unsafe {
298                    $getrf(h, m, m, a_ptr, m, ws_ptr as *mut $T, pivot_ptr, info_ptr)
299                };
300                if status != 0 {
301                    return Err(Error::CutlassInternal(-status));
302                }
303
304                // 2. getrs — solves A · X = B in place over B. trans
305                //    == N because storage is end-to-end column-major.
306                let status = unsafe {
307                    $getrs(
308                        h,
309                        CUBLAS_OP_N,
310                        m,
311                        nrhs,
312                        a_ptr as *const $T,
313                        m,
314                        pivot_ptr as *const i32,
315                        b_ptr,
316                        m,
317                        info_ptr,
318                    )
319                };
320                if status != 0 {
321                    return Err(Error::CutlassInternal(-status));
322                }
323                Ok(())
324            }
325        }
326    };
327}
328
329impl_solve_run!(f32, cusolverDnSgetrf, cusolverDnSgetrs);
330impl_solve_run!(f64, cusolverDnDgetrf, cusolverDnDgetrs);
331
332impl<T: Element> Drop for SolvePlan<T> {
333    fn drop(&mut self) {
334        let h = self.handle.get();
335        if !h.is_null() {
336            unsafe {
337                let _ = cusolverDnDestroy(h);
338            }
339            self.handle.set(core::ptr::null_mut());
340        }
341    }
342}