1use core::cell::Cell;
28use core::ffi::c_void;
29use core::marker::PhantomData;
30
31use baracuda_cutlass::{Error, Result};
32use baracuda_driver::Stream;
33use baracuda_kernels_sys::{
34 cusolverDnCreate, cusolverDnDestroy, cusolverDnDgetrf, cusolverDnDgetrf_bufferSize,
35 cusolverDnDgetrs, cusolverDnHandle_t, cusolverDnSetStream, cusolverDnSgetrf,
36 cusolverDnSgetrf_bufferSize, cusolverDnSgetrs, CUBLAS_OP_N,
37};
38use baracuda_kernels_types::{
39 ArchSku, BackendKind, Element, ElementKind, KernelSku, LinalgKind, MathPrecision, OpCategory,
40 PlanPreference, PrecisionGuarantee, TensorMut, Workspace,
41};
42
43use super::cholesky::unpack_workspace;
44
/// Problem description for a dense linear solve planned by `SolvePlan`.
///
/// All fields are validated by `SolvePlan::select`.
#[derive(Copy, Clone, Debug)]
pub struct SolveDescriptor {
    /// Order of the square coefficient matrix: `A` is expected to have shape
    /// `[m, m]`. Must be strictly positive.
    pub m: i32,
    /// Number of right-hand-side columns: `B` is expected to have shape
    /// `[m, nrhs]`. Must be strictly positive.
    pub nrhs: i32,
    /// Element type of the problem; must match the plan's type parameter
    /// (`T::KIND`), and only `F32` and `F64` are accepted by `select`.
    pub element: ElementKind,
}
55
/// Tensor arguments for one `SolvePlan::run` call.
///
/// Shapes are checked against the plan's descriptor before any cuSOLVER call
/// is issued. All four tensors are handed to cuSOLVER as raw pointers.
pub struct SolveArgs<'a, T: Element> {
    /// Coefficient matrix, shape `[M, M]`. Passed mutably to cuSOLVER `getrf`
    /// with leading dimension `M`, then read by `getrs`.
    /// NOTE(review): presumably column-major and overwritten with the LU
    /// factors, per cuSOLVER convention — confirm against the tensor layout.
    pub a: TensorMut<'a, T, 2>,
    /// Right-hand sides, shape `[M, NRHS]`. Passed mutably to cuSOLVER `getrs`
    /// with leading dimension `M`.
    /// NOTE(review): presumably overwritten with the solution — confirm.
    pub b: TensorMut<'a, T, 2>,
    /// Pivot index buffer, shape `[M]`; written by `getrf` and read by `getrs`.
    pub pivot: TensorMut<'a, i32, 1>,
    /// Single-element status buffer, shape `[1]`; passed as the `info` output
    /// of both `getrf` and `getrs`. `run` does not read it back.
    pub info: TensorMut<'a, i32, 1>,
}
75
/// Plan for solving a dense linear system through cuSOLVER's `getrf` / `getrs`
/// routines for element type `T` (`f32` or `f64`).
///
/// Build one with `select`, optionally size the scratch buffer with
/// `query_workspace_size`, then call `run`. The cuSOLVER handle is created
/// lazily on first use and destroyed when the plan is dropped.
pub struct SolvePlan<T: Element> {
    /// Copy of the descriptor validated by `select`.
    desc: SolveDescriptor,
    /// Kernel identity reported by `sku()`; also carries the precision guarantee.
    sku: KernelSku,
    /// cuSOLVER handle; null until `ensure_handle` creates it.
    handle: Cell<cusolverDnHandle_t>,
    /// Cached workspace requirement in bytes; `0` until a size query succeeds.
    workspace_bytes: Cell<usize>,
    /// Ties the plan to its element type without storing a `T`.
    _marker: PhantomData<T>,
}
105
106impl<T: Element> SolvePlan<T> {
107 pub fn select(
109 _stream: &Stream,
110 desc: &SolveDescriptor,
111 _pref: PlanPreference,
112 ) -> Result<Self> {
113 if desc.element != T::KIND {
114 return Err(Error::Unsupported(
115 "baracuda-kernels::SolvePlan: descriptor.element != T::KIND",
116 ));
117 }
118 if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
119 return Err(Error::Unsupported(
120 "baracuda-kernels::SolvePlan: cuSOLVER dense solve supports f32 + f64 only",
121 ));
122 }
123 if desc.m <= 0 {
124 return Err(Error::InvalidProblem(
125 "baracuda-kernels::SolvePlan: m must be > 0",
126 ));
127 }
128 if desc.nrhs <= 0 {
129 return Err(Error::InvalidProblem(
130 "baracuda-kernels::SolvePlan: nrhs must be > 0",
131 ));
132 }
133
134 let math_precision = match T::KIND {
135 ElementKind::F64 => MathPrecision::F64,
136 _ => MathPrecision::F32,
137 };
138 let precision_guarantee = PrecisionGuarantee {
139 math_precision,
140 accumulator: T::KIND,
141 bit_stable_on_same_hardware: false,
142 deterministic: true,
143 };
144 let sku = KernelSku {
145 category: OpCategory::Linalg,
146 op: LinalgKind::Solve as u16,
147 element: T::KIND,
148 aux_element: Some(ElementKind::I32),
149 layout: None,
150 epilogue: None,
151 arch: ArchSku::Sm80,
152 backend: BackendKind::Cusolver,
153 precision_guarantee,
154 };
155 Ok(Self {
156 desc: *desc,
157 sku,
158 handle: Cell::new(core::ptr::null_mut()),
159 workspace_bytes: Cell::new(0),
160 _marker: PhantomData,
161 })
162 }
163
164 #[inline]
166 pub fn sku(&self) -> KernelSku {
167 self.sku
168 }
169
170 #[inline]
172 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
173 self.sku.precision_guarantee
174 }
175
176 #[inline]
179 pub fn workspace_size(&self) -> usize {
180 self.workspace_bytes.get()
181 }
182
183 pub fn query_workspace_size(&self, _stream: &Stream) -> Result<usize> {
185 let h = self.ensure_handle()?;
186 let mut lwork: i32 = 0;
187 let status = match T::KIND {
188 ElementKind::F32 => unsafe {
189 cusolverDnSgetrf_bufferSize(
190 h,
191 self.desc.m,
192 self.desc.m,
193 core::ptr::null_mut(),
194 self.desc.m,
195 &mut lwork as *mut _,
196 )
197 },
198 ElementKind::F64 => unsafe {
199 cusolverDnDgetrf_bufferSize(
200 h,
201 self.desc.m,
202 self.desc.m,
203 core::ptr::null_mut(),
204 self.desc.m,
205 &mut lwork as *mut _,
206 )
207 },
208 _ => unreachable!("select() gates on F32 / F64"),
209 };
210 if status != 0 {
211 return Err(Error::CutlassInternal(-status));
212 }
213 let bytes = (lwork as usize) * core::mem::size_of::<T>();
214 self.workspace_bytes.set(bytes);
215 Ok(bytes)
216 }
217
218 fn ensure_handle(&self) -> Result<cusolverDnHandle_t> {
219 let h = self.handle.get();
220 if !h.is_null() {
221 return Ok(h);
222 }
223 let mut handle: cusolverDnHandle_t = core::ptr::null_mut();
224 let status = unsafe { cusolverDnCreate(&mut handle as *mut _) };
225 if status != 0 {
226 return Err(Error::CutlassInternal(-status));
227 }
228 self.handle.set(handle);
229 Ok(handle)
230 }
231
232 fn bind_stream(&self, h: cusolverDnHandle_t, stream: &Stream) -> Result<()> {
233 let status = unsafe { cusolverDnSetStream(h, stream.as_raw() as *mut c_void) };
234 if status != 0 {
235 return Err(Error::CutlassInternal(-status));
236 }
237 Ok(())
238 }
239
240 fn check_args(&self, args: &SolveArgs<'_, T>) -> Result<()> {
241 let m = self.desc.m;
242 let nrhs = self.desc.nrhs;
243 if args.a.shape != [m, m] {
244 return Err(Error::InvalidProblem(
245 "baracuda-kernels::SolvePlan: A shape != [M, M]",
246 ));
247 }
248 if args.b.shape != [m, nrhs] {
249 return Err(Error::InvalidProblem(
250 "baracuda-kernels::SolvePlan: B shape != [M, NRHS]",
251 ));
252 }
253 if args.pivot.shape != [m] {
254 return Err(Error::InvalidProblem(
255 "baracuda-kernels::SolvePlan: pivot shape != [M]",
256 ));
257 }
258 if args.info.shape != [1] {
259 return Err(Error::InvalidProblem(
260 "baracuda-kernels::SolvePlan: info shape != [1]",
261 ));
262 }
263 Ok(())
264 }
265}
266
267macro_rules! impl_solve_run {
269 ($T:ty, $getrf:ident, $getrs:ident) => {
270 impl SolvePlan<$T> {
271 pub fn run(
273 &self,
274 stream: &Stream,
275 workspace: Workspace<'_>,
276 args: SolveArgs<'_, $T>,
277 ) -> Result<()> {
278 self.check_args(&args)?;
279 let h = self.ensure_handle()?;
280 self.bind_stream(h, stream)?;
281 let m = self.desc.m;
282 let nrhs = self.desc.nrhs;
283
284 let needed = if self.workspace_bytes.get() == 0 {
285 self.query_workspace_size(stream)?
286 } else {
287 self.workspace_bytes.get()
288 };
289 let (ws_ptr, _ws_bytes) = unpack_workspace(workspace, needed)?;
290
291 let a_ptr = args.a.data.as_raw().0 as *mut $T;
292 let b_ptr = args.b.data.as_raw().0 as *mut $T;
293 let pivot_ptr = args.pivot.data.as_raw().0 as *mut i32;
294 let info_ptr = args.info.data.as_raw().0 as *mut i32;
295
296 let status = unsafe {
298 $getrf(h, m, m, a_ptr, m, ws_ptr as *mut $T, pivot_ptr, info_ptr)
299 };
300 if status != 0 {
301 return Err(Error::CutlassInternal(-status));
302 }
303
304 let status = unsafe {
307 $getrs(
308 h,
309 CUBLAS_OP_N,
310 m,
311 nrhs,
312 a_ptr as *const $T,
313 m,
314 pivot_ptr as *const i32,
315 b_ptr,
316 m,
317 info_ptr,
318 )
319 };
320 if status != 0 {
321 return Err(Error::CutlassInternal(-status));
322 }
323 Ok(())
324 }
325 }
326 };
327}
328
329impl_solve_run!(f32, cusolverDnSgetrf, cusolverDnSgetrs);
330impl_solve_run!(f64, cusolverDnDgetrf, cusolverDnDgetrs);
331
332impl<T: Element> Drop for SolvePlan<T> {
333 fn drop(&mut self) {
334 let h = self.handle.get();
335 if !h.is_null() {
336 unsafe {
337 let _ = cusolverDnDestroy(h);
338 }
339 self.handle.set(core::ptr::null_mut());
340 }
341 }
342}