1#![allow(dead_code)]
25
26use oxicuda_blas::types::{FillMode, GpuFloat};
27use oxicuda_memory::DeviceBuffer;
28
29use crate::error::{SolverError, SolverResult};
30use crate::handle::SolverHandle;
31
32fn to_f64<T: GpuFloat>(val: T) -> f64 {
37 if T::SIZE == 4 {
38 f32::from_bits(val.to_bits_u64() as u32) as f64
39 } else {
40 f64::from_bits(val.to_bits_u64())
41 }
42}
43
44fn from_f64<T: GpuFloat>(val: f64) -> T {
45 if T::SIZE == 4 {
46 T::from_bits_u64(u64::from((val as f32).to_bits()))
47 } else {
48 T::from_bits_u64(val.to_bits())
49 }
50}
51
52const BUNCH_KAUFMAN_ALPHA: f64 = 0.6403882032022076;
58
/// Output of `ldlt`: the pivot bookkeeping from the Bunch-Kaufman
/// factorization. The factor itself is written back into the caller's
/// matrix buffer in place.
pub struct LdltResult {
    // LAPACK-style 1-based pivot indices, one per matrix row/column.
    // A positive entry marks a 1x1 pivot; a pair of equal negative
    // entries marks a 2x2 pivot block.
    pub pivot_info: DeviceBuffer<i32>,
}
74
// Hand-rolled Debug: reports only the pivot buffer's length, since the
// buffer's contents live on the device (and DeviceBuffer itself is
// presumably not Debug — NOTE(review): confirm).
impl std::fmt::Debug for LdltResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LdltResult")
            .field("pivot_info_len", &self.pivot_info.len())
            .finish()
    }
}
82
83pub fn ldlt<T: GpuFloat>(
109 handle: &mut SolverHandle,
110 a: &mut DeviceBuffer<T>,
111 n: usize,
112 uplo: FillMode,
113) -> SolverResult<LdltResult> {
114 if n == 0 {
115 let pivot_info = DeviceBuffer::<i32>::zeroed(0)?;
116 return Ok(LdltResult { pivot_info });
117 }
118 if a.len() < n * n {
119 return Err(SolverError::DimensionMismatch(format!(
120 "ldlt: buffer too small ({} < {})",
121 a.len(),
122 n * n
123 )));
124 }
125 if uplo == FillMode::Full {
126 return Err(SolverError::DimensionMismatch(
127 "ldlt: uplo must be Upper or Lower, not Full".into(),
128 ));
129 }
130
131 let ws = n * n * std::mem::size_of::<f64>();
133 handle.ensure_workspace(ws)?;
134
135 let mut a_host = vec![0.0_f64; n * n];
137 read_device_to_host(a, &mut a_host, n * n)?;
138
139 let mut ipiv = vec![0_i32; n];
141 bunch_kaufman_factorize(&mut a_host, n, uplo, &mut ipiv)?;
142
143 let a_device: Vec<T> = a_host.iter().map(|&v| from_f64(v)).collect();
145 write_host_to_device(a, &a_device, n * n)?;
146
147 let mut pivot_info = DeviceBuffer::<i32>::zeroed(n)?;
148 write_host_to_device_i32(&mut pivot_info, &ipiv, n)?;
149
150 Ok(LdltResult { pivot_info })
151}
152
153pub fn ldlt_solve<T: GpuFloat>(
178 handle: &mut SolverHandle,
179 a: &DeviceBuffer<T>,
180 pivot_info: &DeviceBuffer<i32>,
181 b: &mut DeviceBuffer<T>,
182 n: usize,
183 nrhs: usize,
184 uplo: FillMode,
185) -> SolverResult<()> {
186 if n == 0 || nrhs == 0 {
187 return Ok(());
188 }
189 if a.len() < n * n {
190 return Err(SolverError::DimensionMismatch(
191 "ldlt_solve: factor buffer too small".into(),
192 ));
193 }
194 if pivot_info.len() < n {
195 return Err(SolverError::DimensionMismatch(
196 "ldlt_solve: pivot_info buffer too small".into(),
197 ));
198 }
199 if b.len() < n * nrhs {
200 return Err(SolverError::DimensionMismatch(
201 "ldlt_solve: B buffer too small".into(),
202 ));
203 }
204
205 let ws = (n * n + n * nrhs) * std::mem::size_of::<f64>();
207 handle.ensure_workspace(ws)?;
208
209 let mut a_host = vec![0.0_f64; n * n];
211 read_device_to_host(a, &mut a_host, n * n)?;
212
213 let mut ipiv = vec![0_i32; n];
214 read_device_to_host_i32(pivot_info, &mut ipiv, n)?;
215
216 let mut b_host = vec![0.0_f64; n * nrhs];
217 read_device_to_host(b, &mut b_host, n * nrhs)?;
218
219 bunch_kaufman_solve(&a_host, &ipiv, &mut b_host, n, nrhs, uplo)?;
221
222 let b_device: Vec<T> = b_host.iter().map(|&v| from_f64(v)).collect();
224 write_host_to_device(b, &b_device, n * nrhs)?;
225
226 Ok(())
227}
228
229fn bunch_kaufman_factorize(
235 a: &mut [f64],
236 n: usize,
237 uplo: FillMode,
238 ipiv: &mut [i32],
239) -> SolverResult<()> {
240 match uplo {
241 FillMode::Lower => bunch_kaufman_lower(a, n, ipiv),
242 FillMode::Upper => bunch_kaufman_upper(a, n, ipiv),
243 FillMode::Full => Err(SolverError::DimensionMismatch(
244 "ldlt: uplo must be Lower or Upper".into(),
245 )),
246 }
247}
248
/// Bunch-Kaufman factorization for the lower triangle, column-major
/// storage: entry (row, col) lives at `a[col * n + row]`.
///
/// On return the strictly-lower part of `a` holds the multipliers of the
/// unit-lower factor L, the (block) diagonal holds D, and `ipiv` records
/// 1-based pivots: positive for a 1x1 pivot, an equal negated pair for a
/// 2x2 block.
fn bunch_kaufman_lower(a: &mut [f64], n: usize, ipiv: &mut [i32]) -> SolverResult<()> {
    let mut k = 0;

    while k < n {
        // lambda = largest off-diagonal magnitude in column k (rows > k);
        // r_idx = its row index.
        let (lambda, r_idx) = column_max_offdiag(a, n, k, true);

        let akk = a[k * n + k].abs();

        // Entire remaining column numerically zero -> matrix is singular.
        if akk < 1e-300 && lambda < 1e-300 {
            return Err(SolverError::SingularMatrix);
        }

        if akk >= BUNCH_KAUFMAN_ALPHA * lambda {
            // Case 1: diagonal entry is large enough; plain 1x1 pivot.
            perform_1x1_pivot_lower(a, n, k);
            ipiv[k] = (k + 1) as i32; k += 1;
        } else {
            // sigma = largest off-diagonal magnitude in candidate column r_idx.
            let (sigma, _) = column_max_offdiag(a, n, r_idx, true);

            if akk * sigma >= BUNCH_KAUFMAN_ALPHA * lambda * lambda {
                // Case 2: growth bound still satisfied with the original
                // diagonal; 1x1 pivot without interchange.
                perform_1x1_pivot_lower(a, n, k);
                ipiv[k] = (k + 1) as i32;
                k += 1;
            } else if a[r_idx * n + r_idx].abs() >= BUNCH_KAUFMAN_ALPHA * sigma {
                // Case 3: row r_idx has a large diagonal; swap it into
                // position k and do a 1x1 pivot there.
                if r_idx != k {
                    swap_rows_and_cols(a, n, k, r_idx);
                }
                perform_1x1_pivot_lower(a, n, k);
                ipiv[k] = (r_idx + 1) as i32;
                k += 1;
            } else {
                if k + 1 >= n {
                    // No room for a 2x2 block in the last column; fall back
                    // to a 1x1 pivot. NOTE(review): the helper silently
                    // skips near-zero pivots, so this path can leave the
                    // column unscaled — confirm this is acceptable.
                    perform_1x1_pivot_lower(a, n, k);
                    ipiv[k] = (k + 1) as i32;
                    k += 1;
                } else {
                    // Case 4: 2x2 pivot on rows (k, k+1) after moving row
                    // r_idx into position k+1. Both ipiv entries carry the
                    // same negated 1-based index to mark the block.
                    if r_idx != k + 1 {
                        swap_rows_and_cols(a, n, k + 1, r_idx);
                    }
                    perform_2x2_pivot_lower(a, n, k)?;
                    ipiv[k] = -((r_idx + 1) as i32); ipiv[k + 1] = ipiv[k];
                    k += 2;
                }
            }
        }
    }

    Ok(())
}
308
/// Bunch-Kaufman pivot selection for the upper triangle, scanning columns
/// from `n - 1` down to `0` (off-diagonals of column `col` are in rows
/// `0..col` of the column-major buffer).
///
/// NOTE(review): unlike `bunch_kaufman_lower`, this routine only chooses
/// pivots and applies symmetric interchanges — there is no counterpart to
/// `perform_1x1_pivot_lower` / `perform_2x2_pivot_lower`, so no elimination
/// is performed and `a` is left unfactored. This looks like an incomplete
/// port of the lower-triangle path; confirm intended behavior.
fn bunch_kaufman_upper(a: &mut [f64], n: usize, ipiv: &mut [i32]) -> SolverResult<()> {
    if n == 0 {
        return Ok(());
    }

    // Walk columns from last to first (k is one past the current column).
    let mut k = n;

    while k > 0 {
        let col = k - 1;
        // lambda = largest off-diagonal magnitude above the diagonal in
        // this column; r_idx = its row index.
        let (lambda, r_idx) = column_max_offdiag(a, n, col, false);
        let akk = a[col * n + col].abs();

        // Numerically zero column (diagonal included) -> singular.
        if akk < 1e-300 && lambda < 1e-300 {
            return Err(SolverError::SingularMatrix);
        }

        if akk >= BUNCH_KAUFMAN_ALPHA * lambda {
            // Case 1: diagonal dominates; record a 1x1 pivot.
            ipiv[col] = (col + 1) as i32;
            k -= 1;
        } else {
            // sigma = largest off-diagonal magnitude in candidate column r_idx.
            let (sigma, _) = column_max_offdiag(a, n, r_idx, false);

            if akk * sigma >= BUNCH_KAUFMAN_ALPHA * lambda * lambda {
                // Case 2: growth bound satisfied; 1x1 pivot, no interchange.
                ipiv[col] = (col + 1) as i32;
                k -= 1;
            } else if a[r_idx * n + r_idx].abs() >= BUNCH_KAUFMAN_ALPHA * sigma {
                // Case 3: swap the large-diagonal row into place, 1x1 pivot.
                if r_idx != col {
                    swap_rows_and_cols(a, n, col, r_idx);
                }
                ipiv[col] = (r_idx + 1) as i32;
                k -= 1;
            } else {
                if col == 0 {
                    // No room for a 2x2 block at the first column.
                    ipiv[col] = (col + 1) as i32;
                    k -= 1;
                } else {
                    // Case 4: 2x2 block on columns (col - 1, col); both ipiv
                    // entries carry the same negated 1-based index.
                    let col2 = col - 1;
                    if r_idx != col2 {
                        swap_rows_and_cols(a, n, col2, r_idx);
                    }
                    ipiv[col] = -((r_idx + 1) as i32);
                    ipiv[col2] = ipiv[col];
                    k -= 2;
                }
            }
        }
    }

    Ok(())
}
360
/// Returns the largest absolute off-diagonal value in column `col` of the
/// column-major matrix together with its row index: rows `col + 1..n` are
/// scanned for the lower triangle, rows `0..col` for the upper one.
///
/// When the scanned range is empty or all zeros, returns `(0.0, col)`.
/// Ties keep the first (lowest-index) maximum.
fn column_max_offdiag(a: &[f64], n: usize, col: usize, lower: bool) -> (f64, usize) {
    let rows = if lower { (col + 1)..n } else { 0..col };
    rows.fold((0.0_f64, col), |(best, best_row), row| {
        let magnitude = a[col * n + row].abs();
        if magnitude > best {
            (magnitude, row)
        } else {
            (best, best_row)
        }
    })
}
391
/// Applies the symmetric permutation exchanging indices `i` and `j`
/// (`P * A * P^T`) on the full column-major matrix: rows `i` and `j` are
/// swapped in every column, then columns `i` and `j` in every row.
fn swap_rows_and_cols(a: &mut [f64], n: usize, i: usize, j: usize) {
    if i == j {
        return;
    }
    // Rows i and j within each column (columns are the contiguous
    // length-n chunks of the column-major buffer).
    for column in a.chunks_exact_mut(n).take(n) {
        column.swap(i, j);
    }
    // Columns i and j, element by element across every row.
    for row in 0..n {
        a.swap(i * n + row, j * n + row);
    }
}
406
/// Eliminates column `k` with a 1x1 diagonal pivot (lower triangle,
/// column-major): scales the sub-diagonal of column `k` into L multipliers,
/// then applies the rank-1 update `A22 -= l * d * l^T` to the trailing
/// lower triangle.
///
/// A near-zero pivot is silently left untouched; callers are expected to
/// have screened for singularity beforehand.
fn perform_1x1_pivot_lower(a: &mut [f64], n: usize, k: usize) {
    let pivot = a[k * n + k];
    if pivot.abs() < 1e-300 {
        // Degenerate pivot: skip rather than divide by ~0.
        return;
    }

    // Form the multipliers l_row = a_row / d_kk in place.
    let scale = 1.0 / pivot;
    for row in (k + 1)..n {
        a[k * n + row] *= scale;
    }

    // Trailing symmetric update, lower triangle only (row >= col).
    for col in (k + 1)..n {
        let l_col = a[k * n + col];
        for row in col..n {
            let l_row = a[k * n + row];
            a[col * n + row] -= l_row * pivot * l_col;
        }
    }
}
429
430fn perform_2x2_pivot_lower(a: &mut [f64], n: usize, k: usize) -> SolverResult<()> {
432 if k + 1 >= n {
433 return Err(SolverError::InternalError(
434 "ldlt: 2x2 pivot at boundary".into(),
435 ));
436 }
437
438 let d11 = a[k * n + k];
440 let d21 = a[k * n + (k + 1)];
441 let d22 = a[(k + 1) * n + (k + 1)];
442
443 let det = d11 * d22 - d21 * d21;
445 if det.abs() < 1e-300 {
446 return Err(SolverError::SingularMatrix);
447 }
448 let inv_det = 1.0 / det;
449
450 for i in (k + 2)..n {
453 let aik = a[k * n + i];
454 let aik1 = a[(k + 1) * n + i];
455
456 a[k * n + i] = (d22 * aik - d21 * aik1) * inv_det;
457 a[(k + 1) * n + i] = (-d21 * aik + d11 * aik1) * inv_det;
458 }
459
460 for j in (k + 2)..n {
462 let ljk = a[k * n + j];
463 let ljk1 = a[(k + 1) * n + j];
464
465 for i in j..n {
466 let lik = a[k * n + i];
467 let lik1 = a[(k + 1) * n + i];
468
469 a[j * n + i] -=
472 lik * d11 * ljk + lik * d21 * ljk1 + lik1 * d21 * ljk + lik1 * d22 * ljk1;
473 }
474 }
475
476 Ok(())
477}
478
479fn bunch_kaufman_solve(
485 a: &[f64],
486 ipiv: &[i32],
487 b: &mut [f64],
488 n: usize,
489 nrhs: usize,
490 uplo: FillMode,
491) -> SolverResult<()> {
492 match uplo {
493 FillMode::Lower => bunch_kaufman_solve_lower(a, ipiv, b, n, nrhs),
494 FillMode::Upper => bunch_kaufman_solve_upper(a, ipiv, b, n, nrhs),
495 FillMode::Full => Err(SolverError::DimensionMismatch(
496 "ldlt_solve: uplo must be Lower or Upper".into(),
497 )),
498 }
499}
500
/// Solves `A * X = B` from a lower-triangle Bunch-Kaufman factorization,
/// one right-hand-side column at a time, in three phases per column:
///
/// 1. forward pass: apply the pivot interchanges and `inv(L)`;
/// 2. diagonal pass: solve against the 1x1 / 2x2 blocks of `D`;
/// 3. backward pass: apply `inv(L^T)` and undo the interchanges.
///
/// `a` is column-major with multipliers below the diagonal and D on the
/// (block) diagonal; `ipiv` uses the 1-based LAPACK convention (negated
/// equal pairs mark a 2x2 block).
fn bunch_kaufman_solve_lower(
    a: &[f64],
    ipiv: &[i32],
    b: &mut [f64],
    n: usize,
    nrhs: usize,
) -> SolverResult<()> {
    for rhs in 0..nrhs {
        // Each RHS column is a contiguous length-n slice of b.
        let b_col = &mut b[rhs * n..(rhs + 1) * n];

        // Phase 1: forward substitution with interchanges (y = inv(L) P b).
        let mut k = 0;
        while k < n {
            if ipiv[k] > 0 {
                // 1x1 pivot: interchange was row k <-> ipiv[k]-1.
                let p = (ipiv[k] - 1) as usize;
                if p != k {
                    b_col.swap(k, p);
                }
                // Eliminate column k's multipliers.
                for i in (k + 1)..n {
                    b_col[i] -= a[k * n + i] * b_col[k];
                }
                k += 1;
            } else {
                // 2x2 pivot on rows (k, k+1): interchange was k+1 <-> p.
                let p = ((-ipiv[k]) - 1) as usize;
                if p != k + 1 {
                    b_col.swap(k + 1, p);
                }
                // Eliminate both columns of the block.
                for i in (k + 2)..n {
                    b_col[i] -= a[k * n + i] * b_col[k];
                    b_col[i] -= a[(k + 1) * n + i] * b_col[k + 1];
                }
                k += 2;
            }
        }

        // Phase 2: solve D z = y against the block diagonal.
        k = 0;
        while k < n {
            if ipiv[k] > 0 {
                // Scalar 1x1 block.
                let dkk = a[k * n + k];
                if dkk.abs() < 1e-300 {
                    return Err(SolverError::SingularMatrix);
                }
                b_col[k] /= dkk;
                k += 1;
            } else {
                if k + 1 >= n {
                    // A 2x2 marker at the last row cannot be valid.
                    return Err(SolverError::InternalError(
                        "ldlt_solve: invalid 2x2 pivot at boundary".into(),
                    ));
                }
                // Symmetric 2x2 block solved via its explicit inverse.
                let d11 = a[k * n + k];
                let d21 = a[k * n + (k + 1)];
                let d22 = a[(k + 1) * n + (k + 1)];
                let det = d11 * d22 - d21 * d21;
                if det.abs() < 1e-300 {
                    return Err(SolverError::SingularMatrix);
                }
                let inv_det = 1.0 / det;
                let y1 = b_col[k];
                let y2 = b_col[k + 1];
                b_col[k] = (d22 * y1 - d21 * y2) * inv_det;
                b_col[k + 1] = (-d21 * y1 + d11 * y2) * inv_det;
                k += 2;
            }
        }

        // Phase 3: backward substitution (x = P^T inv(L^T) z), walking the
        // pivots in reverse and undoing each interchange afterwards.
        k = n;
        while k > 0 {
            k -= 1;
            if ipiv[k] > 0 {
                // 1x1 pivot: subtract L^T contribution, then undo the swap.
                for i in (k + 1)..n {
                    b_col[k] -= a[k * n + i] * b_col[i];
                }
                let p = (ipiv[k] - 1) as usize;
                if p != k {
                    b_col.swap(k, p);
                }
            } else if k > 0 && ipiv[k] < 0 && ipiv[k - 1] == ipiv[k] {
                // 2x2 block (k-1, k): update both rows, undo the swap that
                // the forward pass applied at the block's second row.
                let k2 = k - 1;
                for i in (k + 1)..n {
                    b_col[k] -= a[k * n + i] * b_col[i]; b_col[k2] -= a[k2 * n + i] * b_col[i];
                }
                let p = ((-ipiv[k]) - 1) as usize;
                if p != k {
                    b_col.swap(k, p);
                }
                k = k2; }
        }
    }

    Ok(())
}
605
/// Solve path for the upper-triangle factorization.
///
/// NOTE(review): this only applies the 1x1-pivot interchanges and divides
/// by the plain diagonal — 2x2 pivot blocks are ignored and no `U` / `U^T`
/// triangular application happens at all. That is consistent with
/// `bunch_kaufman_upper`, which performs no elimination, so the upper path
/// as a whole appears to be an unfinished stub (correct only for diagonal
/// matrices); confirm before relying on it.
fn bunch_kaufman_solve_upper(
    a: &[f64],
    ipiv: &[i32],
    b: &mut [f64],
    n: usize,
    nrhs: usize,
) -> SolverResult<()> {
    for rhs in 0..nrhs {
        // Each RHS column is a contiguous length-n slice of b.
        let b_col = &mut b[rhs * n..(rhs + 1) * n];

        // Apply the recorded 1x1 interchanges in reverse column order.
        for k in (0..n).rev() {
            if ipiv[k] > 0 {
                let p = (ipiv[k] - 1) as usize;
                if p != k {
                    b_col.swap(k, p);
                }
            }
        }

        // Divide by the diagonal (1x1 pivots only; negative ipiv entries
        // — 2x2 blocks — are skipped entirely).
        for k in 0..n {
            if ipiv[k] > 0 {
                let dkk = a[k * n + k];
                if dkk.abs() < 1e-300 {
                    return Err(SolverError::SingularMatrix);
                }
                b_col[k] /= dkk;
            }
        }

        // Undo the interchanges in forward order.
        for (k, &piv) in ipiv.iter().enumerate().take(n) {
            if piv > 0 {
                let p = (piv - 1) as usize;
                if p != k {
                    b_col.swap(k, p);
                }
            }
        }
    }

    Ok(())
}
653
/// Stand-in for a device-to-host copy.
///
/// NOTE(review): placeholder — the device buffer is ignored and `host` is
/// filled with an identity matrix (interpreting `count` as a square matrix
/// of order sqrt(count)). Wire up a real memcpy before trusting any
/// numerical results from this module.
fn read_device_to_host<T: GpuFloat>(
    _buf: &DeviceBuffer<T>,
    host: &mut [f64],
    count: usize,
) -> SolverResult<()> {
    // Treat the flat buffer as square; max(1) guards count == 0.
    let n_sqrt = (count as f64).sqrt() as usize;
    for (i, h) in host.iter_mut().enumerate().take(count) {
        let row = i % n_sqrt.max(1);
        let col = i / n_sqrt.max(1);
        *h = if row == col { 1.0 } else { 0.0 };
    }
    Ok(())
}
672
/// Stand-in for a host-to-device copy.
///
/// NOTE(review): placeholder no-op — the data is dropped and nothing
/// reaches the device buffer.
fn write_host_to_device<T: GpuFloat>(
    _buf: &mut DeviceBuffer<T>,
    _data: &[T],
    _count: usize,
) -> SolverResult<()> {
    Ok(())
}
680
/// Stand-in for a device-to-host copy of pivot indices.
///
/// NOTE(review): placeholder — the device buffer is ignored and `host` is
/// filled with the identity permutation 1..=count (all 1x1 pivots).
fn read_device_to_host_i32(
    _buf: &DeviceBuffer<i32>,
    host: &mut [i32],
    count: usize,
) -> SolverResult<()> {
    for (i, val) in host.iter_mut().enumerate().take(count) {
        *val = (i + 1) as i32; }
    Ok(())
}
691
/// Stand-in for a host-to-device copy of pivot indices.
///
/// NOTE(review): placeholder no-op — the pivot data is discarded.
fn write_host_to_device_i32(
    _buf: &mut DeviceBuffer<i32>,
    _data: &[i32],
    _count: usize,
) -> SolverResult<()> {
    Ok(())
}
699
#[cfg(test)]
mod tests {
    use super::*;

    // The magic constant must equal the analytic value (1 + sqrt(17)) / 8.
    #[test]
    fn bunch_kaufman_alpha_value() {
        let expected = (1.0_f64 + 17.0_f64.sqrt()) / 8.0;
        assert!((BUNCH_KAUFMAN_ALPHA - expected).abs() < 1e-10);
    }

    // Column-major 3x3; column 0 is [1, 5, 3], so the largest
    // below-diagonal entry in column 0 is 5.0 at row 1.
    #[test]
    fn column_max_offdiag_lower() {
        let a = [1.0, 5.0, 3.0, 0.0, 2.0, 7.0, 0.0, 0.0, 4.0];
        let (max_val, max_idx) = column_max_offdiag(&a, 3, 0, true);
        assert!((max_val - 5.0).abs() < 1e-15);
        assert_eq!(max_idx, 1);
    }

    // Column 2 is [0, 0, 4]: its above-diagonal entries are all zero, so
    // the scan returns 0.0 with the diagonal index as the fallback.
    #[test]
    fn column_max_offdiag_upper() {
        let a = [1.0, 5.0, 3.0, 0.0, 2.0, 7.0, 0.0, 0.0, 4.0];
        let (max_val, max_idx) = column_max_offdiag(&a, 3, 2, false);
        assert!(max_val.abs() < 1e-15);
        assert_eq!(max_idx, 2); }

    // Swapping an index with itself must be a no-op.
    #[test]
    fn swap_rows_and_cols_identity() {
        let mut a = [1.0, 0.0, 0.0, 1.0];
        swap_rows_and_cols(&mut a, 2, 0, 0);
        assert!((a[0] - 1.0).abs() < 1e-15);
        assert!((a[3] - 1.0).abs() < 1e-15);
    }

    // The symmetric permutation P I P^T of the identity is the identity.
    #[test]
    fn swap_rows_and_cols_basic() {
        let mut a = [1.0, 0.0, 0.0, 1.0];
        swap_rows_and_cols(&mut a, 2, 0, 1);
        assert!((a[0] - 1.0).abs() < 1e-15);
        assert!((a[3] - 1.0).abs() < 1e-15);
    }

    // Column-major 2x2 [[4, 2], [2, 3]]: after a 1x1 pivot at k = 0 the
    // multiplier is 2/4 = 0.5 and the Schur complement is 3 - 0.5*4*0.5 = 2.
    #[test]
    fn perform_1x1_pivot_lower_basic() {
        let mut a = [4.0, 2.0, 2.0, 3.0];
        perform_1x1_pivot_lower(&mut a, 2, 0);
        assert!((a[1] - 0.5).abs() < 1e-15);
        assert!((a[3] - 2.0).abs() < 1e-15);
    }

    // The identity matrix needs only 1x1 pivots (positive ipiv entries).
    #[test]
    fn bunch_kaufman_identity_3x3() {
        let mut a = vec![0.0; 9];
        a[0] = 1.0;
        a[4] = 1.0;
        a[8] = 1.0;
        let mut ipiv = vec![0_i32; 3];
        let result = bunch_kaufman_lower(&mut a, 3, &mut ipiv);
        assert!(result.is_ok());
        assert!(ipiv[0] > 0);
        assert!(ipiv[1] > 0);
        assert!(ipiv[2] > 0);
    }

    // f64 -> f64 through the bit-level converters must be lossless.
    #[test]
    fn f64_conversion_roundtrip() {
        let val = std::f64::consts::E;
        let converted: f64 = from_f64(to_f64(val));
        assert!((converted - val).abs() < 1e-15);
    }

    // f32 -> f64 -> f32 only guarantees single precision on the way back.
    #[test]
    fn f32_conversion_roundtrip() {
        let val = std::f32::consts::E;
        let as_f64 = to_f64(val);
        let back: f32 = from_f64(as_f64);
        assert!((back - val).abs() < 1e-5);
    }
}
796}