1use oxicuda_blas::types::{
18 DiagType, FillMode, GpuFloat, Layout, MatrixDesc, MatrixDescMut, Side, Transpose,
19};
20use oxicuda_memory::DeviceBuffer;
21use oxicuda_ptx::prelude::*;
22
23use crate::error::{SolverError, SolverResult};
24use crate::handle::SolverHandle;
25use crate::ptx_helpers::SOLVER_BLOCK_SIZE;
26
/// Panel width (in columns) used by the blocked LU factorization.
const LU_BLOCK_SIZE: u32 = 64;
29
/// Outcome of an LU factorization.
#[derive(Debug, Clone)]
pub struct LuResult {
    /// LAPACK-style status: 0 means success; a positive value is the
    /// 1-based index of the first column whose pivot was near zero
    /// (the factorization is singular at that column).
    pub info: i32,
}
44
45pub fn lu_factorize<T: GpuFloat>(
73 handle: &mut SolverHandle,
74 a: &mut DeviceBuffer<T>,
75 n: u32,
76 lda: u32,
77 pivots: &mut DeviceBuffer<i32>,
78) -> SolverResult<LuResult> {
79 if n == 0 {
81 return Ok(LuResult { info: 0 });
82 }
83 if lda < n {
84 return Err(SolverError::DimensionMismatch(format!(
85 "lu_factorize: lda ({lda}) must be >= n ({n})"
86 )));
87 }
88 let required = n as usize * lda as usize;
89 if a.len() < required {
90 return Err(SolverError::DimensionMismatch(format!(
91 "lu_factorize: buffer too small ({} < {required})",
92 a.len()
93 )));
94 }
95 if pivots.len() < n as usize {
96 return Err(SolverError::DimensionMismatch(format!(
97 "lu_factorize: pivots buffer too small ({} < {n})",
98 pivots.len()
99 )));
100 }
101
102 let panel_workspace = n as usize * LU_BLOCK_SIZE as usize * T::SIZE;
104 handle.ensure_workspace(panel_workspace)?;
105
106 blocked_lu::<T>(handle, a, n, lda, pivots)
107}
108
109pub fn lu_solve<T: GpuFloat>(
127 handle: &SolverHandle,
128 lu: &DeviceBuffer<T>,
129 pivots: &DeviceBuffer<i32>,
130 b: &mut DeviceBuffer<T>,
131 n: u32,
132 nrhs: u32,
133) -> SolverResult<()> {
134 if n == 0 || nrhs == 0 {
135 return Ok(());
136 }
137 if lu.len() < (n as usize * n as usize) {
138 return Err(SolverError::DimensionMismatch(
139 "lu_solve: LU buffer too small".into(),
140 ));
141 }
142 if pivots.len() < n as usize {
143 return Err(SolverError::DimensionMismatch(
144 "lu_solve: pivots buffer too small".into(),
145 ));
146 }
147 if b.len() < (n as usize * nrhs as usize) {
148 return Err(SolverError::DimensionMismatch(
149 "lu_solve: B buffer too small".into(),
150 ));
151 }
152
153 apply_pivots_to_rhs::<T>(handle, b, pivots, n, nrhs)?;
157
158 let l_desc = MatrixDesc::<T>::from_raw(lu.as_device_ptr(), n, n, n, Layout::ColMajor);
160 let mut b_desc = MatrixDescMut::<T>::from_raw(b.as_device_ptr(), n, nrhs, n, Layout::ColMajor);
161
162 oxicuda_blas::level3::trsm(
163 handle.blas(),
164 Side::Left,
165 FillMode::Lower,
166 Transpose::NoTrans,
167 DiagType::Unit,
168 T::gpu_one(),
169 &l_desc,
170 &mut b_desc,
171 )?;
172
173 let u_desc = MatrixDesc::<T>::from_raw(lu.as_device_ptr(), n, n, n, Layout::ColMajor);
175
176 oxicuda_blas::level3::trsm(
177 handle.blas(),
178 Side::Left,
179 FillMode::Upper,
180 Transpose::NoTrans,
181 DiagType::NonUnit,
182 T::gpu_one(),
183 &u_desc,
184 &mut b_desc,
185 )?;
186
187 Ok(())
188}
189
190fn blocked_lu<T: GpuFloat>(
202 handle: &mut SolverHandle,
203 a: &mut DeviceBuffer<T>,
204 n: u32,
205 lda: u32,
206 pivots: &mut DeviceBuffer<i32>,
207) -> SolverResult<LuResult> {
208 let nb = LU_BLOCK_SIZE.min(n);
209 let num_blocks = n.div_ceil(nb);
210 let mut info: i32 = 0;
211
212 for block_idx in 0..num_blocks {
213 let j = block_idx * nb;
214 let jb = nb.min(n - j); let panel_info = panel_lu::<T>(handle, a, n, lda, j, jb, pivots)?;
219 if panel_info > 0 && info == 0 {
220 info = panel_info + j as i32;
221 }
222
223 if j > 0 {
226 apply_panel_pivots::<T>(handle, a, lda, j, jb, pivots, 0, j)?;
227 }
228 let right_start = j + jb;
230 if right_start < n {
231 apply_panel_pivots::<T>(handle, a, lda, j, jb, pivots, right_start, n - right_start)?;
232 }
233
234 if right_start < n {
236 let l_desc = MatrixDesc::<T>::from_raw(
237 a.as_device_ptr() + (j as u64 + j as u64 * lda as u64) * T::SIZE as u64,
238 jb,
239 jb,
240 lda,
241 Layout::ColMajor,
242 );
243 let mut u_desc = MatrixDescMut::<T>::from_raw(
244 a.as_device_ptr() + (j as u64 + right_start as u64 * lda as u64) * T::SIZE as u64,
245 jb,
246 n - right_start,
247 lda,
248 Layout::ColMajor,
249 );
250 oxicuda_blas::level3::trsm(
251 handle.blas(),
252 Side::Left,
253 FillMode::Lower,
254 Transpose::NoTrans,
255 DiagType::Unit,
256 T::gpu_one(),
257 &l_desc,
258 &mut u_desc,
259 )?;
260 }
261
262 let remaining_rows = n.saturating_sub(j + jb);
265 let remaining_cols = n.saturating_sub(j + jb);
266 if remaining_rows > 0 && remaining_cols > 0 {
267 let a21_desc = MatrixDesc::<T>::from_raw(
268 a.as_device_ptr() + ((j + jb) as u64 + j as u64 * lda as u64) * T::SIZE as u64,
269 remaining_rows,
270 jb,
271 lda,
272 Layout::ColMajor,
273 );
274 let a12_desc = MatrixDesc::<T>::from_raw(
275 a.as_device_ptr() + (j as u64 + (j + jb) as u64 * lda as u64) * T::SIZE as u64,
276 jb,
277 remaining_cols,
278 lda,
279 Layout::ColMajor,
280 );
281 let mut a22_desc = MatrixDescMut::<T>::from_raw(
282 a.as_device_ptr()
283 + ((j + jb) as u64 + (j + jb) as u64 * lda as u64) * T::SIZE as u64,
284 remaining_rows,
285 remaining_cols,
286 lda,
287 Layout::ColMajor,
288 );
289
290 let neg_one = T::from_bits_u64({
292 let one = T::gpu_one();
293 let bits = one.to_bits_u64();
295 if T::SIZE == 4 {
296 bits ^ 0x8000_0000
297 } else {
298 bits ^ 0x8000_0000_0000_0000
299 }
300 });
301
302 oxicuda_blas::level3::gemm_api::gemm(
303 handle.blas(),
304 Transpose::NoTrans,
305 Transpose::NoTrans,
306 neg_one,
307 &a21_desc,
308 &a12_desc,
309 T::gpu_one(),
310 &mut a22_desc,
311 )?;
312 }
313 }
314
315 Ok(LuResult { info })
316}
317
318fn panel_lu<T: GpuFloat>(
325 _handle: &SolverHandle,
326 a: &mut DeviceBuffer<T>,
327 n: u32,
328 lda: u32,
329 j: u32,
330 jb: u32,
331 pivots: &mut DeviceBuffer<i32>,
332) -> SolverResult<i32> {
333 let _ = emit_panel_lu::<T>(_handle.sm_version(), jb)?;
335
336 let n_usize = n as usize;
337 let lda_usize = lda as usize;
338 let j_usize = j as usize;
339 let jb_usize = jb as usize;
340
341 let mut a_host = vec![T::gpu_zero(); a.len()];
342 a.copy_to_host(&mut a_host)?;
343
344 let mut piv_host = vec![0_i32; pivots.len()];
345 pivots.copy_to_host(&mut piv_host)?;
346
347 let mut info: i32 = 0;
348 let panel_end = (j_usize + jb_usize).min(n_usize);
349
350 for kk in 0..jb_usize {
351 let col = j_usize + kk;
352 if col >= n_usize {
353 break;
354 }
355
356 let mut pivot_row = col;
358 let mut max_abs = 0.0_f64;
359 for row in col..n_usize {
360 let bits = a_host[col * lda_usize + row].to_bits_u64();
361 let val = if T::SIZE == 8 {
362 f64::from_bits(bits)
363 } else {
364 f64::from(f32::from_bits(bits as u32))
365 };
366 let abs = val.abs();
367 if abs > max_abs {
368 max_abs = abs;
369 pivot_row = row;
370 }
371 }
372
373 piv_host[col] = pivot_row as i32;
374
375 if pivot_row != col {
377 for c in j_usize..panel_end {
378 a_host.swap(c * lda_usize + col, c * lda_usize + pivot_row);
379 }
380 }
381
382 let pivot_bits = a_host[col * lda_usize + col].to_bits_u64();
384 let pivot_val = if T::SIZE == 8 {
385 f64::from_bits(pivot_bits)
386 } else {
387 f64::from(f32::from_bits(pivot_bits as u32))
388 };
389 if info == 0 && pivot_val.abs() <= 1e-30 {
390 info = (kk + 1) as i32;
391 continue;
392 }
393
394 for row in (col + 1)..n_usize {
396 let x_bits = a_host[col * lda_usize + row].to_bits_u64();
397 let x = if T::SIZE == 8 {
398 f64::from_bits(x_bits)
399 } else {
400 f64::from(f32::from_bits(x_bits as u32))
401 };
402 let scaled = x / pivot_val;
403 a_host[col * lda_usize + row] = if T::SIZE == 8 {
404 T::from_bits_u64(scaled.to_bits())
405 } else {
406 T::from_bits_u64(u64::from((scaled as f32).to_bits()))
407 };
408 }
409
410 for c in (col + 1)..panel_end {
412 let uk_bits = a_host[c * lda_usize + col].to_bits_u64();
413 let u_kc = if T::SIZE == 8 {
414 f64::from_bits(uk_bits)
415 } else {
416 f64::from(f32::from_bits(uk_bits as u32))
417 };
418 for row in (col + 1)..n_usize {
419 let l_bits = a_host[col * lda_usize + row].to_bits_u64();
420 let l_rc = if T::SIZE == 8 {
421 f64::from_bits(l_bits)
422 } else {
423 f64::from(f32::from_bits(l_bits as u32))
424 };
425 let a_bits = a_host[c * lda_usize + row].to_bits_u64();
426 let a_rc = if T::SIZE == 8 {
427 f64::from_bits(a_bits)
428 } else {
429 f64::from(f32::from_bits(a_bits as u32))
430 };
431 let updated = a_rc - l_rc * u_kc;
432 a_host[c * lda_usize + row] = if T::SIZE == 8 {
433 T::from_bits_u64(updated.to_bits())
434 } else {
435 T::from_bits_u64(u64::from((updated as f32).to_bits()))
436 };
437 }
438 }
439 }
440
441 a.copy_from_host(&a_host)?;
442 pivots.copy_from_host(&piv_host)?;
443
444 Ok(info)
445}
446
447#[allow(clippy::too_many_arguments)]
452fn apply_panel_pivots<T: GpuFloat>(
453 _handle: &SolverHandle,
454 a: &mut DeviceBuffer<T>,
455 lda: u32,
456 j: u32,
457 jb: u32,
458 pivots: &DeviceBuffer<i32>,
459 col_start: u32,
460 col_count: u32,
461) -> SolverResult<()> {
462 if col_count == 0 || jb == 0 {
463 return Ok(());
464 }
465
466 let _ = emit_pivot_swap::<T>(_handle.sm_version())?;
468
469 let lda_usize = lda as usize;
470 let j_usize = j as usize;
471 let jb_usize = jb as usize;
472 let col_start_usize = col_start as usize;
473 let col_end = col_start_usize + col_count as usize;
474
475 let mut a_host = vec![T::gpu_zero(); a.len()];
476 a.copy_to_host(&mut a_host)?;
477 let mut piv_host = vec![0_i32; pivots.len()];
478 pivots.copy_to_host(&mut piv_host)?;
479
480 for t in 0..jb_usize {
481 let row = j_usize + t;
482 if row >= piv_host.len() {
483 break;
484 }
485 let piv = piv_host[row].max(0) as usize;
486 if piv >= lda_usize {
487 return Err(SolverError::DimensionMismatch(format!(
488 "apply_panel_pivots: pivot index out of range ({piv} >= lda {lda_usize})"
489 )));
490 }
491 if piv == row {
492 continue;
493 }
494 for col in col_start_usize..col_end {
495 a_host.swap(col * lda_usize + row, col * lda_usize + piv);
496 }
497 }
498
499 a.copy_from_host(&a_host)?;
500
501 Ok(())
502}
503
504fn apply_pivots_to_rhs<T: GpuFloat>(
506 _handle: &SolverHandle,
507 b: &mut DeviceBuffer<T>,
508 pivots: &DeviceBuffer<i32>,
509 n: u32,
510 nrhs: u32,
511) -> SolverResult<()> {
512 if n == 0 || nrhs == 0 {
513 return Ok(());
514 }
515
516 let _ = emit_pivot_swap::<T>(_handle.sm_version())?;
518
519 let n_usize = n as usize;
520 let nrhs_usize = nrhs as usize;
521
522 let mut b_host = vec![T::gpu_zero(); b.len()];
523 b.copy_to_host(&mut b_host)?;
524 let mut piv_host = vec![0_i32; pivots.len()];
525 pivots.copy_to_host(&mut piv_host)?;
526
527 for row in 0..n_usize {
529 if row >= piv_host.len() {
530 break;
531 }
532 let piv = piv_host[row].max(0) as usize;
533 if piv >= n_usize {
534 return Err(SolverError::DimensionMismatch(format!(
535 "apply_pivots_to_rhs: pivot index out of range ({piv} >= n {n_usize})"
536 )));
537 }
538 if piv == row {
539 continue;
540 }
541 for col in 0..nrhs_usize {
542 b_host.swap(col * n_usize + row, col * n_usize + piv);
543 }
544 }
545
546 b.copy_from_host(&b_host)?;
547
548 Ok(())
549}
550
551fn panel_lu_name<T: GpuFloat>(block_size: u32) -> String {
556 format!("solver_panel_lu_{}_{}", T::NAME, block_size)
557}
558
559fn pivot_swap_name<T: GpuFloat>() -> String {
560 format!("solver_pivot_swap_{}", T::NAME)
561}
562
/// Emits PTX for a per-panel LU kernel specialized to `T` and `panel_cols`.
///
/// NOTE(review): the kernel body is currently a stub — it loads its
/// parameters, discards them, and returns immediately. The actual panel
/// factorization runs on the host in `panel_lu`; this emission keeps the
/// kernel name and parameter signature stable for a future device
/// implementation.
fn emit_panel_lu<T: GpuFloat>(sm: SmVersion, panel_cols: u32) -> SolverResult<String> {
    let name = panel_lu_name::<T>(panel_cols);
    let float_ty = T::PTX_TYPE;

    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("panel_ptr", PtxType::U64)
        .param("pivots_ptr", PtxType::U64)
        .param("panel_rows", PtxType::U32)
        .param("panel_cols", PtxType::U32)
        .param("lda", PtxType::U32)
        .body(move |b| {
            let tid = b.thread_id_x();
            let panel_rows_reg = b.load_param_u32("panel_rows");
            let panel_cols_reg = b.load_param_u32("panel_cols");
            let lda_reg = b.load_param_u32("lda");
            let panel_ptr = b.load_param_u64("panel_ptr");

            // Silence unused-value warnings until the device path exists.
            let _ = (
                tid,
                panel_rows_reg,
                panel_cols_reg,
                lda_reg,
                panel_ptr,
                float_ty,
            );

            b.ret();
        })
        .build()?;

    Ok(ptx)
}
612
/// Emits PTX for a row-swap (pivot application) kernel specialized to `T`.
///
/// NOTE(review): only the column-index guard and the column base-address
/// computation are generated so far — no rows are actually swapped on the
/// device. The host fallbacks in `apply_panel_pivots` /
/// `apply_pivots_to_rhs` do the real work; this keeps the kernel name and
/// parameter signature stable for a future device implementation.
fn emit_pivot_swap<T: GpuFloat>(sm: SmVersion) -> SolverResult<String> {
    let name = pivot_swap_name::<T>();
    let float_ty = T::PTX_TYPE;

    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("a_ptr", PtxType::U64)
        .param("pivots_ptr", PtxType::U64)
        .param("j", PtxType::U32)
        .param("jb", PtxType::U32)
        .param("col_start", PtxType::U32)
        .param("col_count", PtxType::U32)
        .param("lda", PtxType::U32)
        .body(move |b| {
            let gid = b.global_thread_id_x();
            let col_count_reg = b.load_param_u32("col_count");

            // One thread per column in [col_start, col_start + col_count).
            b.if_lt_u32(gid.clone(), col_count_reg, |b| {
                let a_ptr = b.load_param_u64("a_ptr");
                let col_start = b.load_param_u32("col_start");
                let lda = b.load_param_u32("lda");

                let col_idx = b.add_u32(gid, col_start);

                // Byte address of this column's base:
                // a_ptr + col_idx * lda * sizeof(T).
                let col_elem_offset = b.mul_lo_u32(col_idx, lda);
                let _col_base = b.byte_offset_addr(a_ptr, col_elem_offset, T::size_u32());

                // Swap body not yet generated.
                let _ = float_ty;
            });

            b.ret();
        })
        .build()?;

    Ok(ptx)
}
658
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference Doolittle LU decomposition (no pivoting) of a 4x4 matrix,
    /// returning unit-lower `L` and upper `U` with `L * U == A`.
    fn doolittle_lu_4x4(a: &[[f64; 4]; 4]) -> ([[f64; 4]; 4], [[f64; 4]; 4]) {
        let mut l = [[0.0_f64; 4]; 4];
        let mut u = [[0.0_f64; 4]; 4];

        for i in 0..4 {
            l[i][i] = 1.0;

            // Row i of U.
            for j in i..4 {
                let acc: f64 = (0..i).map(|k| l[i][k] * u[k][j]).sum();
                u[i][j] = a[i][j] - acc;
            }

            // Column i of L; skip the division when the reference pivot
            // is effectively zero.
            for j in (i + 1)..4 {
                let acc: f64 = (0..i).map(|k| l[j][k] * u[k][i]).sum();
                if u[i][i].abs() > 1e-15 {
                    l[j][i] = (a[j][i] - acc) / u[i][i];
                }
            }
        }

        (l, u)
    }

    /// Dense 4x4 matrix product, used to reconstruct A from L and U.
    fn matmul_4x4(a: &[[f64; 4]; 4], b: &[[f64; 4]; 4]) -> [[f64; 4]; 4] {
        let mut c = [[0.0_f64; 4]; 4];
        for (i, row) in c.iter_mut().enumerate() {
            for (j, cell) in row.iter_mut().enumerate() {
                *cell = (0..4).map(|k| a[i][k] * b[k][j]).sum();
            }
        }
        c
    }

    #[test]
    fn lu_trsm_trailing_update() {
        let a = [
            [4.0_f64, 3.0, 2.0, 1.0],
            [2.0, 5.0, 3.0, 2.0],
            [1.0, 2.0, 6.0, 3.0],
            [1.0, 1.0, 2.0, 7.0],
        ];
        let (l, u) = doolittle_lu_4x4(&a);

        // L must be unit lower triangular.
        for i in 0..4 {
            assert!(
                (l[i][i] - 1.0).abs() < 1e-15,
                "L[{i},{i}] must be 1.0 (unit diagonal)"
            );
            for j in (i + 1)..4 {
                let val = l[i][j];
                assert!(
                    val.abs() < 1e-15,
                    "L[{i},{j}] = {val} must be 0.0 (upper triangle)",
                );
            }
        }

        // U must be upper triangular.
        for i in 0..4 {
            for j in 0..i {
                let val = u[i][j];
                assert!(
                    val.abs() < 1e-15,
                    "U[{i},{j}] = {val} must be 0.0 (lower triangle)",
                );
            }
        }

        // The product of the factors must reproduce A.
        let reconstructed = matmul_4x4(&l, &u);
        for i in 0..4 {
            for j in 0..4 {
                assert!(
                    (reconstructed[i][j] - a[i][j]).abs() < 1e-10,
                    "LU[{i},{j}] = {} ≠ A[{i},{j}] = {} (diff = {})",
                    reconstructed[i][j],
                    a[i][j],
                    (reconstructed[i][j] - a[i][j]).abs()
                );
            }
        }
    }

    #[test]
    fn lu_gemm_rank_update_correctness() {
        // One unblocked elimination step on a 3x3 matrix: the trailing 2x2
        // block after the rank-1 update must match hand-computed values.
        let a = [[2.0_f64, 4.0, 6.0], [1.0, 3.0, 5.0], [1.0, 2.0, 4.0]];

        // First column of L (unit diagonal) and first row of U.
        let l_col0 = [1.0_f64, a[1][0] / a[0][0], a[2][0] / a[0][0]];
        let u_row0 = [a[0][0], a[0][1], a[0][2]];

        // trailing[i][j] = A[i+1][j+1] - L[i+1][0] * U[0][j+1]
        let mut trailing = [[0.0_f64; 2]; 2];
        for (i, row) in trailing.iter_mut().enumerate() {
            for (j, cell) in row.iter_mut().enumerate() {
                *cell = a[i + 1][j + 1] - l_col0[i + 1] * u_row0[j + 1];
            }
        }

        assert!(
            (trailing[0][0] - 1.0).abs() < 1e-12,
            "trailing[0,0] should be 1"
        );
        assert!(
            (trailing[0][1] - 2.0).abs() < 1e-12,
            "trailing[0,1] should be 2"
        );
        assert!(trailing[1][0].abs() < 1e-12, "trailing[1,0] should be 0");
        assert!(
            (trailing[1][1] - 1.0).abs() < 1e-12,
            "trailing[1,1] should be 1"
        );
    }

    #[test]
    fn lu_block_size_positive() {
        // The panel width must be usable as a kernel block dimension.
        assert!(LU_BLOCK_SIZE > 0);
        assert!(LU_BLOCK_SIZE <= 256);
    }

    #[test]
    fn lu_result_info() {
        // info == 0 means success; a positive info flags a singular column.
        assert_eq!(LuResult { info: 0 }.info, 0);
        assert!(LuResult { info: 3 }.info > 0);
    }

    #[test]
    fn panel_lu_name_format() {
        // Kernel names encode the element type and panel width.
        let name = panel_lu_name::<f32>(64);
        assert!(name.contains("f32"));
        assert!(name.contains("64"));
    }

    #[test]
    fn pivot_swap_name_format() {
        assert!(pivot_swap_name::<f64>().contains("f64"));
    }

    #[test]
    fn neg_one_f32() {
        // Flipping the f32 sign bit of 1.0 must yield exactly -1.0.
        let neg = f32::from_bits_u64(f32::gpu_one().to_bits_u64() ^ 0x8000_0000);
        assert!((neg + 1.0).abs() < 1e-10);
    }

    #[test]
    fn neg_one_f64() {
        // Flipping the f64 sign bit of 1.0 must yield exactly -1.0.
        let neg = f64::from_bits_u64(f64::gpu_one().to_bits_u64() ^ 0x8000_0000_0000_0000);
        assert!((neg + 1.0).abs() < 1e-15);
    }
}