1use std::sync::Arc;
18
19use oxicuda_blas::types::{
20 DiagType, FillMode, GpuFloat, Layout, MatrixDesc, MatrixDescMut, Side, Transpose,
21};
22use oxicuda_driver::Module;
23use oxicuda_launch::{Kernel, LaunchParams, grid_size_for};
24use oxicuda_memory::DeviceBuffer;
25use oxicuda_ptx::prelude::*;
26
27use crate::error::{SolverError, SolverResult};
28use crate::handle::SolverHandle;
29use crate::ptx_helpers::SOLVER_BLOCK_SIZE;
30
/// Panel width (number of columns) used by the blocked LU factorization;
/// also the upper bound on the per-panel pivot block handled at once.
const LU_BLOCK_SIZE: u32 = 64;
33
/// Outcome of an LU factorization.
#[derive(Debug, Clone)]
pub struct LuResult {
    /// 0 on success; a positive value is the 1-based index of the first
    /// zero pivot encountered (the matrix is singular at that step) —
    /// see `blocked_lu`, which offsets the panel-local info by `j`.
    pub info: i32,
}
48
49pub fn lu_factorize<T: GpuFloat>(
77 handle: &mut SolverHandle,
78 a: &mut DeviceBuffer<T>,
79 n: u32,
80 lda: u32,
81 pivots: &mut DeviceBuffer<i32>,
82) -> SolverResult<LuResult> {
83 if n == 0 {
85 return Ok(LuResult { info: 0 });
86 }
87 if lda < n {
88 return Err(SolverError::DimensionMismatch(format!(
89 "lu_factorize: lda ({lda}) must be >= n ({n})"
90 )));
91 }
92 let required = n as usize * lda as usize;
93 if a.len() < required {
94 return Err(SolverError::DimensionMismatch(format!(
95 "lu_factorize: buffer too small ({} < {required})",
96 a.len()
97 )));
98 }
99 if pivots.len() < n as usize {
100 return Err(SolverError::DimensionMismatch(format!(
101 "lu_factorize: pivots buffer too small ({} < {n})",
102 pivots.len()
103 )));
104 }
105
106 let panel_workspace = n as usize * LU_BLOCK_SIZE as usize * T::SIZE;
108 handle.ensure_workspace(panel_workspace)?;
109
110 blocked_lu::<T>(handle, a, n, lda, pivots)
111}
112
113pub fn lu_solve<T: GpuFloat>(
131 handle: &SolverHandle,
132 lu: &DeviceBuffer<T>,
133 pivots: &DeviceBuffer<i32>,
134 b: &mut DeviceBuffer<T>,
135 n: u32,
136 nrhs: u32,
137) -> SolverResult<()> {
138 if n == 0 || nrhs == 0 {
139 return Ok(());
140 }
141 if lu.len() < (n as usize * n as usize) {
142 return Err(SolverError::DimensionMismatch(
143 "lu_solve: LU buffer too small".into(),
144 ));
145 }
146 if pivots.len() < n as usize {
147 return Err(SolverError::DimensionMismatch(
148 "lu_solve: pivots buffer too small".into(),
149 ));
150 }
151 if b.len() < (n as usize * nrhs as usize) {
152 return Err(SolverError::DimensionMismatch(
153 "lu_solve: B buffer too small".into(),
154 ));
155 }
156
157 apply_pivots_to_rhs::<T>(handle, b, pivots, n, nrhs)?;
161
162 let l_desc = MatrixDesc::<T>::from_raw(lu.as_device_ptr(), n, n, n, Layout::ColMajor);
164 let mut b_desc = MatrixDescMut::<T>::from_raw(b.as_device_ptr(), n, nrhs, n, Layout::ColMajor);
165
166 oxicuda_blas::level3::trsm(
167 handle.blas(),
168 Side::Left,
169 FillMode::Lower,
170 Transpose::NoTrans,
171 DiagType::Unit,
172 T::gpu_one(),
173 &l_desc,
174 &mut b_desc,
175 )?;
176
177 let u_desc = MatrixDesc::<T>::from_raw(lu.as_device_ptr(), n, n, n, Layout::ColMajor);
179
180 oxicuda_blas::level3::trsm(
181 handle.blas(),
182 Side::Left,
183 FillMode::Upper,
184 Transpose::NoTrans,
185 DiagType::NonUnit,
186 T::gpu_one(),
187 &u_desc,
188 &mut b_desc,
189 )?;
190
191 Ok(())
192}
193
194fn blocked_lu<T: GpuFloat>(
206 handle: &mut SolverHandle,
207 a: &mut DeviceBuffer<T>,
208 n: u32,
209 lda: u32,
210 pivots: &mut DeviceBuffer<i32>,
211) -> SolverResult<LuResult> {
212 let nb = LU_BLOCK_SIZE.min(n);
213 let num_blocks = n.div_ceil(nb);
214 let mut info: i32 = 0;
215
216 for block_idx in 0..num_blocks {
217 let j = block_idx * nb;
218 let jb = nb.min(n - j); let panel_info = panel_lu::<T>(handle, a, n, lda, j, jb, pivots)?;
223 if panel_info > 0 && info == 0 {
224 info = panel_info + j as i32;
225 }
226
227 if j > 0 {
230 apply_panel_pivots::<T>(handle, a, lda, j, jb, pivots, 0, j)?;
231 }
232 let right_start = j + jb;
234 if right_start < n {
235 apply_panel_pivots::<T>(handle, a, lda, j, jb, pivots, right_start, n - right_start)?;
236 }
237
238 if right_start < n {
240 let l_desc = MatrixDesc::<T>::from_raw(
241 a.as_device_ptr() + (j as u64 + j as u64 * lda as u64) * T::SIZE as u64,
242 jb,
243 jb,
244 lda,
245 Layout::ColMajor,
246 );
247 let mut u_desc = MatrixDescMut::<T>::from_raw(
248 a.as_device_ptr() + (j as u64 + right_start as u64 * lda as u64) * T::SIZE as u64,
249 jb,
250 n - right_start,
251 lda,
252 Layout::ColMajor,
253 );
254 oxicuda_blas::level3::trsm(
255 handle.blas(),
256 Side::Left,
257 FillMode::Lower,
258 Transpose::NoTrans,
259 DiagType::Unit,
260 T::gpu_one(),
261 &l_desc,
262 &mut u_desc,
263 )?;
264 }
265
266 let remaining_rows = n.saturating_sub(j + jb);
269 let remaining_cols = n.saturating_sub(j + jb);
270 if remaining_rows > 0 && remaining_cols > 0 {
271 let a21_desc = MatrixDesc::<T>::from_raw(
272 a.as_device_ptr() + ((j + jb) as u64 + j as u64 * lda as u64) * T::SIZE as u64,
273 remaining_rows,
274 jb,
275 lda,
276 Layout::ColMajor,
277 );
278 let a12_desc = MatrixDesc::<T>::from_raw(
279 a.as_device_ptr() + (j as u64 + (j + jb) as u64 * lda as u64) * T::SIZE as u64,
280 jb,
281 remaining_cols,
282 lda,
283 Layout::ColMajor,
284 );
285 let mut a22_desc = MatrixDescMut::<T>::from_raw(
286 a.as_device_ptr()
287 + ((j + jb) as u64 + (j + jb) as u64 * lda as u64) * T::SIZE as u64,
288 remaining_rows,
289 remaining_cols,
290 lda,
291 Layout::ColMajor,
292 );
293
294 let neg_one = T::from_bits_u64({
296 let one = T::gpu_one();
297 let bits = one.to_bits_u64();
299 if T::SIZE == 4 {
300 bits ^ 0x8000_0000
301 } else {
302 bits ^ 0x8000_0000_0000_0000
303 }
304 });
305
306 oxicuda_blas::level3::gemm_api::gemm(
307 handle.blas(),
308 Transpose::NoTrans,
309 Transpose::NoTrans,
310 neg_one,
311 &a21_desc,
312 &a12_desc,
313 T::gpu_one(),
314 &mut a22_desc,
315 )?;
316 }
317 }
318
319 Ok(LuResult { info })
320}
321
322fn panel_lu<T: GpuFloat>(
329 handle: &SolverHandle,
330 a: &mut DeviceBuffer<T>,
331 n: u32,
332 lda: u32,
333 j: u32,
334 jb: u32,
335 pivots: &mut DeviceBuffer<i32>,
336) -> SolverResult<i32> {
337 let sm = handle.sm_version();
338 let panel_rows = n - j;
339
340 let ptx = emit_panel_lu::<T>(sm, jb)?;
342 let module = Arc::new(Module::from_ptx(&ptx)?);
343 let kernel = Kernel::from_module(module, &panel_lu_name::<T>(jb))?;
344
345 let shared_bytes = panel_rows * jb * T::size_u32();
348 let params = LaunchParams::new(1u32, SOLVER_BLOCK_SIZE).with_shared_mem(shared_bytes);
349
350 let panel_offset = (j as u64 + j as u64 * lda as u64) * T::SIZE as u64;
352 let panel_ptr = a.as_device_ptr() + panel_offset;
353
354 let args = (
355 panel_ptr,
356 pivots.as_device_ptr() + (j as u64 * 4), panel_rows,
358 jb,
359 lda,
360 );
361 kernel.launch(¶ms, handle.stream(), &args)?;
362
363 Ok(0)
366}
367
368#[allow(clippy::too_many_arguments)]
373fn apply_panel_pivots<T: GpuFloat>(
374 handle: &SolverHandle,
375 a: &mut DeviceBuffer<T>,
376 lda: u32,
377 j: u32,
378 jb: u32,
379 pivots: &DeviceBuffer<i32>,
380 col_start: u32,
381 col_count: u32,
382) -> SolverResult<()> {
383 if col_count == 0 || jb == 0 {
384 return Ok(());
385 }
386
387 let sm = handle.sm_version();
388 let ptx = emit_pivot_swap::<T>(sm)?;
389 let module = Arc::new(Module::from_ptx(&ptx)?);
390 let kernel = Kernel::from_module(module, &pivot_swap_name::<T>())?;
391
392 let grid = grid_size_for(col_count, SOLVER_BLOCK_SIZE);
393 let params = LaunchParams::new(grid, SOLVER_BLOCK_SIZE);
394
395 let args = (
396 a.as_device_ptr(),
397 pivots.as_device_ptr(),
398 j,
399 jb,
400 col_start,
401 col_count,
402 lda,
403 );
404 kernel.launch(¶ms, handle.stream(), &args)?;
405
406 Ok(())
407}
408
409fn apply_pivots_to_rhs<T: GpuFloat>(
411 handle: &SolverHandle,
412 b: &mut DeviceBuffer<T>,
413 pivots: &DeviceBuffer<i32>,
414 n: u32,
415 nrhs: u32,
416) -> SolverResult<()> {
417 if n == 0 || nrhs == 0 {
418 return Ok(());
419 }
420
421 let sm = handle.sm_version();
422 let ptx = emit_pivot_swap::<T>(sm)?;
423 let module = Arc::new(Module::from_ptx(&ptx)?);
424 let kernel = Kernel::from_module(module, &pivot_swap_name::<T>())?;
425
426 let grid = grid_size_for(nrhs, SOLVER_BLOCK_SIZE);
427 let params = LaunchParams::new(grid, SOLVER_BLOCK_SIZE);
428
429 let args = (
431 b.as_device_ptr(),
432 pivots.as_device_ptr(),
433 0u32, n, 0u32, nrhs, n, );
439 kernel.launch(¶ms, handle.stream(), &args)?;
440
441 Ok(())
442}
443
444fn panel_lu_name<T: GpuFloat>(block_size: u32) -> String {
449 format!("solver_panel_lu_{}_{}", T::NAME, block_size)
450}
451
452fn pivot_swap_name<T: GpuFloat>() -> String {
453 format!("solver_pivot_swap_{}", T::NAME)
454}
455
/// Emits PTX for the panel LU kernel named `panel_lu_name::<T>(panel_cols)`.
///
/// NOTE(review): the emitted kernel body is a scaffold — it loads its
/// parameters and returns immediately without performing any
/// factorization, so launching it leaves the panel unmodified on device.
/// Confirm whether the factorization loop is still to be implemented.
fn emit_panel_lu<T: GpuFloat>(sm: SmVersion, panel_cols: u32) -> SolverResult<String> {
    let name = panel_lu_name::<T>(panel_cols);
    let float_ty = T::PTX_TYPE;

    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("panel_ptr", PtxType::U64)
        .param("pivots_ptr", PtxType::U64)
        .param("panel_rows", PtxType::U32)
        .param("panel_cols", PtxType::U32)
        .param("lda", PtxType::U32)
        .body(move |b| {
            let tid = b.thread_id_x();
            let panel_rows_reg = b.load_param_u32("panel_rows");
            let panel_cols_reg = b.load_param_u32("panel_cols");
            let lda_reg = b.load_param_u32("lda");
            let panel_ptr = b.load_param_u64("panel_ptr");

            // Suppress unused warnings for the scaffold's parameters.
            let _ = (
                tid,
                panel_rows_reg,
                panel_cols_reg,
                lda_reg,
                panel_ptr,
                float_ty,
            );

            b.ret();
        })
        .build()?;

    Ok(ptx)
}
505
/// Emits PTX for the row-swap kernel used both for panel pivot
/// application and for permuting the RHS matrix; one thread per column
/// in `[col_start, col_start + col_count)`.
///
/// NOTE(review): the emitted kernel body is a scaffold — it computes
/// each thread's column base address but performs no loads, swaps, or
/// stores, so launching it leaves device memory untouched. Confirm
/// whether the swap loop is still to be implemented.
fn emit_pivot_swap<T: GpuFloat>(sm: SmVersion) -> SolverResult<String> {
    let name = pivot_swap_name::<T>();
    let float_ty = T::PTX_TYPE;

    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("a_ptr", PtxType::U64)
        .param("pivots_ptr", PtxType::U64)
        .param("j", PtxType::U32)
        .param("jb", PtxType::U32)
        .param("col_start", PtxType::U32)
        .param("col_count", PtxType::U32)
        .param("lda", PtxType::U32)
        .body(move |b| {
            let gid = b.global_thread_id_x();
            let col_count_reg = b.load_param_u32("col_count");

            // Guard: excess threads beyond col_count do nothing.
            b.if_lt_u32(gid.clone(), col_count_reg, |b| {
                let a_ptr = b.load_param_u64("a_ptr");
                let col_start = b.load_param_u32("col_start");
                let lda = b.load_param_u32("lda");

                // Global column index handled by this thread.
                let col_idx = b.add_u32(gid, col_start);

                // Byte address of the column's first element (column-major).
                let col_elem_offset = b.mul_lo_u32(col_idx, lda);
                let _col_base = b.byte_offset_addr(a_ptr, col_elem_offset, T::size_u32());

                // Suppress unused warning until the swap body is emitted.
                let _ = float_ty;
            });

            b.ret();
        })
        .build()?;

    Ok(ptx)
}
551
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference Doolittle LU decomposition (no pivoting) of a 4x4
    /// matrix: returns `(L, U)` with `A = L * U`, L unit lower
    /// triangular, U upper triangular.
    ///
    /// Panics on a numerically zero pivot instead of silently leaving
    /// the corresponding L entries at zero (the original fallback would
    /// mask a singular test input with a wrong-but-plausible result).
    fn doolittle_lu_4x4(a: &[[f64; 4]; 4]) -> ([[f64; 4]; 4], [[f64; 4]; 4]) {
        let mut l = [[0.0_f64; 4]; 4];
        let mut u = [[0.0_f64; 4]; 4];

        for i in 0..4 {
            // L has a unit diagonal by construction.
            l[i][i] = 1.0;

            // Row i of U.
            for j in i..4 {
                let sum: f64 = (0..i).map(|k| l[i][k] * u[k][j]).sum();
                u[i][j] = a[i][j] - sum;
            }

            // Column i of L (divides by the pivot u[i][i]).
            assert!(
                u[i][i].abs() > 1e-15,
                "doolittle_lu_4x4: zero pivot at step {i}; input must be nonsingular"
            );
            for j in (i + 1)..4 {
                let sum: f64 = (0..i).map(|k| l[j][k] * u[k][i]).sum();
                l[j][i] = (a[j][i] - sum) / u[i][i];
            }
        }

        (l, u)
    }

    /// Plain 4x4 matrix product, used to verify `L * U == A`.
    fn matmul_4x4(a: &[[f64; 4]; 4], b: &[[f64; 4]; 4]) -> [[f64; 4]; 4] {
        let mut c = [[0.0_f64; 4]; 4];
        for i in 0..4 {
            for j in 0..4 {
                for k in 0..4 {
                    c[i][j] += a[i][k] * b[k][j];
                }
            }
        }
        c
    }

    /// CPU model of the factorization: check the triangular structure of
    /// L and U and verify the reconstruction L * U == A.
    #[test]
    fn lu_trsm_trailing_update() {
        let a = [
            [4.0_f64, 3.0, 2.0, 1.0],
            [2.0, 5.0, 3.0, 2.0],
            [1.0, 2.0, 6.0, 3.0],
            [1.0, 1.0, 2.0, 7.0],
        ];
        let (l, u) = doolittle_lu_4x4(&a);

        // L must be unit lower triangular.
        for (i, l_row) in l.iter().enumerate() {
            assert!(
                (l_row[i] - 1.0).abs() < 1e-15,
                "L[{i},{i}] must be 1.0 (unit diagonal)"
            );
            for (j, &val) in l_row.iter().enumerate().filter(|(j, _)| *j > i) {
                assert!(
                    val.abs() < 1e-15,
                    "L[{i},{j}] = {val} must be 0.0 (upper triangle)",
                );
            }
        }

        // U must be upper triangular.
        for (i, u_row) in u.iter().enumerate() {
            for (j, &val) in u_row.iter().enumerate().filter(|(j, _)| *j < i) {
                assert!(
                    val.abs() < 1e-15,
                    "U[{i},{j}] = {val} must be 0.0 (lower triangle)",
                );
            }
        }

        // The factors must reproduce A.
        let reconstructed = matmul_4x4(&l, &u);
        for i in 0..4 {
            for j in 0..4 {
                assert!(
                    (reconstructed[i][j] - a[i][j]).abs() < 1e-10,
                    "LU[{i},{j}] = {} ≠ A[{i},{j}] = {} (diff = {})",
                    reconstructed[i][j],
                    a[i][j],
                    (reconstructed[i][j] - a[i][j]).abs()
                );
            }
        }
    }

    /// CPU model of the GEMM trailing update: eliminate the first column
    /// and check the 2x2 trailing Schur complement.
    #[test]
    fn lu_gemm_rank_update_correctness() {
        let a = [[2.0_f64, 4.0, 6.0], [1.0, 3.0, 5.0], [1.0, 2.0, 4.0]];

        // First column of L and first row of U from one elimination step.
        let l_col0 = [1.0_f64, a[1][0] / a[0][0], a[2][0] / a[0][0]];
        let u_row0 = [a[0][0], a[0][1], a[0][2]];

        // trailing = A22 - l21 * u12 (the Schur complement).
        let mut trailing = [[0.0_f64; 2]; 2];
        for i in 0..2 {
            for j in 0..2 {
                trailing[i][j] = a[i + 1][j + 1] - l_col0[i + 1] * u_row0[j + 1];
            }
        }

        assert!(
            (trailing[0][0] - 1.0).abs() < 1e-12,
            "trailing[0,0] should be 1"
        );
        assert!(
            (trailing[0][1] - 2.0).abs() < 1e-12,
            "trailing[0,1] should be 2"
        );
        assert!(trailing[1][0].abs() < 1e-12, "trailing[1,0] should be 0");
        assert!(
            (trailing[1][1] - 1.0).abs() < 1e-12,
            "trailing[1,1] should be 1"
        );
    }

    /// Sanity bounds on the panel width constant.
    #[test]
    fn lu_block_size_positive() {
        let block_size = LU_BLOCK_SIZE;
        assert!(block_size > 0);
        assert!(block_size <= 256);
    }

    /// `info == 0` means success; `info > 0` flags a singular matrix.
    #[test]
    fn lu_result_info() {
        let result = LuResult { info: 0 };
        assert_eq!(result.info, 0);

        let singular = LuResult { info: 3 };
        assert!(singular.info > 0);
    }

    /// Kernel name must encode the element type and panel width.
    #[test]
    fn panel_lu_name_format() {
        let name = panel_lu_name::<f32>(64);
        assert!(name.contains("f32"));
        assert!(name.contains("64"));
    }

    /// Kernel name must encode the element type.
    #[test]
    fn pivot_swap_name_format() {
        let name = pivot_swap_name::<f64>();
        assert!(name.contains("f64"));
    }

    /// Flipping the f32 sign bit of 1.0 yields -1.0 (the neg_one trick
    /// used for the GEMM alpha).
    #[test]
    fn neg_one_f32() {
        let neg = f32::from_bits_u64(f32::gpu_one().to_bits_u64() ^ 0x8000_0000);
        assert!((neg + 1.0).abs() < 1e-10);
    }

    /// Flipping the f64 sign bit of 1.0 yields -1.0.
    #[test]
    fn neg_one_f64() {
        let neg = f64::from_bits_u64(f64::gpu_one().to_bits_u64() ^ 0x8000_0000_0000_0000);
        assert!((neg + 1.0).abs() < 1e-15);
    }
}