1#![allow(dead_code)]
25
26use oxicuda_blas::types::{FillMode, GpuFloat};
27use oxicuda_memory::DeviceBuffer;
28
29use crate::error::{SolverError, SolverResult};
30use crate::handle::SolverHandle;
31
32fn to_f64<T: GpuFloat>(val: T) -> f64 {
37 if T::SIZE == 4 {
38 f32::from_bits(val.to_bits_u64() as u32) as f64
39 } else {
40 f64::from_bits(val.to_bits_u64())
41 }
42}
43
44fn from_f64<T: GpuFloat>(val: f64) -> T {
45 if T::SIZE == 4 {
46 T::from_bits_u64(u64::from((val as f32).to_bits()))
47 } else {
48 T::from_bits_u64(val.to_bits())
49 }
50}
51
/// Bunch–Kaufman pivot-selection threshold, alpha = (1 + sqrt(17)) / 8
/// (the standard value from the diagonal-pivoting literature; see the
/// `bunch_kaufman_alpha_value` test).
const BUNCH_KAUFMAN_ALPHA: f64 = 0.6403882032022076;
58
/// Device-side bookkeeping produced by [`ldlt`].
pub struct LdltResult {
    /// 1-based pivot codes in LAPACK `ipiv` style, as written by the
    /// factorization: a positive entry `p` marks a 1x1 pivot (row/col `k`
    /// interchanged with `p - 1`); two consecutive entries holding the same
    /// negative value mark a 2x2 pivot block.
    pub pivot_info: DeviceBuffer<i32>,
}
74
75impl std::fmt::Debug for LdltResult {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 f.debug_struct("LdltResult")
78 .field("pivot_info_len", &self.pivot_info.len())
79 .finish()
80 }
81}
82
83pub fn ldlt<T: GpuFloat>(
109 handle: &mut SolverHandle,
110 a: &mut DeviceBuffer<T>,
111 n: usize,
112 uplo: FillMode,
113) -> SolverResult<LdltResult> {
114 if n == 0 {
115 let pivot_info = DeviceBuffer::<i32>::zeroed(0)?;
116 return Ok(LdltResult { pivot_info });
117 }
118 if a.len() < n * n {
119 return Err(SolverError::DimensionMismatch(format!(
120 "ldlt: buffer too small ({} < {})",
121 a.len(),
122 n * n
123 )));
124 }
125 if uplo == FillMode::Full {
126 return Err(SolverError::DimensionMismatch(
127 "ldlt: uplo must be Upper or Lower, not Full".into(),
128 ));
129 }
130
131 let ws = n * n * std::mem::size_of::<f64>();
133 handle.ensure_workspace(ws)?;
134
135 let mut a_host = vec![0.0_f64; n * n];
137 read_device_to_host(a, &mut a_host, n * n)?;
138
139 let mut ipiv = vec![0_i32; n];
141 bunch_kaufman_factorize(&mut a_host, n, uplo, &mut ipiv)?;
142
143 let a_device: Vec<T> = a_host.iter().map(|&v| from_f64(v)).collect();
145 write_host_to_device(a, &a_device, n * n)?;
146
147 let mut pivot_info = DeviceBuffer::<i32>::zeroed(n)?;
148 write_host_to_device_i32(&mut pivot_info, &ipiv, n)?;
149
150 Ok(LdltResult { pivot_info })
151}
152
153pub fn ldlt_solve<T: GpuFloat>(
178 handle: &mut SolverHandle,
179 a: &DeviceBuffer<T>,
180 pivot_info: &DeviceBuffer<i32>,
181 b: &mut DeviceBuffer<T>,
182 n: usize,
183 nrhs: usize,
184 uplo: FillMode,
185) -> SolverResult<()> {
186 if n == 0 || nrhs == 0 {
187 return Ok(());
188 }
189 if a.len() < n * n {
190 return Err(SolverError::DimensionMismatch(
191 "ldlt_solve: factor buffer too small".into(),
192 ));
193 }
194 if pivot_info.len() < n {
195 return Err(SolverError::DimensionMismatch(
196 "ldlt_solve: pivot_info buffer too small".into(),
197 ));
198 }
199 if b.len() < n * nrhs {
200 return Err(SolverError::DimensionMismatch(
201 "ldlt_solve: B buffer too small".into(),
202 ));
203 }
204
205 let ws = (n * n + n * nrhs) * std::mem::size_of::<f64>();
207 handle.ensure_workspace(ws)?;
208
209 let mut a_host = vec![0.0_f64; n * n];
211 read_device_to_host(a, &mut a_host, n * n)?;
212
213 let mut ipiv = vec![0_i32; n];
214 read_device_to_host_i32(pivot_info, &mut ipiv, n)?;
215
216 let mut b_host = vec![0.0_f64; n * nrhs];
217 read_device_to_host(b, &mut b_host, n * nrhs)?;
218
219 bunch_kaufman_solve(&a_host, &ipiv, &mut b_host, n, nrhs, uplo)?;
221
222 let b_device: Vec<T> = b_host.iter().map(|&v| from_f64(v)).collect();
224 write_host_to_device(b, &b_device, n * nrhs)?;
225
226 Ok(())
227}
228
/// Factorizes the symmetric column-major matrix `a` (`n x n`) in place with
/// Bunch–Kaufman diagonal pivoting, writing 1-based pivot codes into `ipiv`.
///
/// Only one triangle of `a` needs valid data, selected by `uplo`:
/// * `Lower` — factored directly.
/// * `Upper` — the upper triangle is first reflected into the lower one and
///   the same lower-triangle kernel runs, so `ipiv` always follows the
///   lower convention.
/// * `Full` — rejected.
fn bunch_kaufman_factorize(
    a: &mut [f64],
    n: usize,
    uplo: FillMode,
    ipiv: &mut [i32],
) -> SolverResult<()> {
    match uplo {
        FillMode::Lower => bunch_kaufman_lower(a, n, ipiv),
        FillMode::Upper => {
            // Normalize to the lower layout so a single kernel handles both.
            mirror_upper_to_lower(a, n);
            bunch_kaufman_lower(a, n, ipiv)
        }
        FillMode::Full => Err(SolverError::DimensionMismatch(
            "ldlt: uplo must be Lower or Upper".into(),
        )),
    }
}
253
/// Bunch–Kaufman LDLᵀ factorization on the lower triangle of the
/// column-major matrix `a` (element (i, j), i >= j, stored at `a[j * n + i]`).
///
/// At each step the pivot test compares the diagonal magnitude `akk`
/// against the largest off-diagonal magnitude `lambda` of the current
/// column (and `sigma` of the candidate column `r_idx`) using the
/// `BUNCH_KAUFMAN_ALPHA` threshold, choosing one of:
/// * a 1x1 pivot on the current diagonal,
/// * a 1x1 pivot after interchanging row/col `k` with `r_idx`,
/// * a 2x2 pivot on rows/cols `k`, `k + 1` (after interchanging `k + 1`
///   with `r_idx`), recorded as matching negative `ipiv` entries.
///
/// `ipiv` follows the LAPACK convention (1-based; negative pairs mark 2x2
/// blocks). Returns `SingularMatrix` when a column has no usable pivot.
fn bunch_kaufman_lower(a: &mut [f64], n: usize, ipiv: &mut [i32]) -> SolverResult<()> {
    let mut k = 0;

    while k < n {
        // Largest off-diagonal magnitude below the diagonal in column k.
        let (lambda, r_idx) = column_max_offdiag(a, n, k, true);

        let akk = a[k * n + k].abs();

        // Diagonal and entire sub-column both vanish: no pivot exists.
        if akk < 1e-300 && lambda < 1e-300 {
            return Err(SolverError::SingularMatrix);
        }

        if akk >= BUNCH_KAUFMAN_ALPHA * lambda {
            // Case 1: the diagonal dominates enough for a plain 1x1 pivot.
            perform_1x1_pivot_lower(a, n, k);
            ipiv[k] = (k + 1) as i32;
            k += 1;
        } else {
            // sigma: largest off-diagonal magnitude in the candidate column.
            let (sigma, _) = column_max_offdiag(a, n, r_idx, true);

            if akk * sigma >= BUNCH_KAUFMAN_ALPHA * lambda * lambda {
                // Case 2: diagonal still acceptable once sigma is considered.
                perform_1x1_pivot_lower(a, n, k);
                ipiv[k] = (k + 1) as i32;
                k += 1;
            } else if a[r_idx * n + r_idx].abs() >= BUNCH_KAUFMAN_ALPHA * sigma {
                // Case 3: pivot on the candidate diagonal instead — swap it in.
                if r_idx != k {
                    swap_rows_and_cols(a, n, k, r_idx);
                }
                perform_1x1_pivot_lower(a, n, k);
                ipiv[k] = (r_idx + 1) as i32;
                k += 1;
            } else {
                if k + 1 >= n {
                    // No room for a 2x2 block at the last column; fall back
                    // to a 1x1 pivot.
                    perform_1x1_pivot_lower(a, n, k);
                    ipiv[k] = (k + 1) as i32;
                    k += 1;
                } else {
                    // Case 4: 2x2 pivot on columns k, k+1 with r_idx swapped
                    // into position k+1. Both ipiv slots carry the same
                    // negative code to mark the block.
                    if r_idx != k + 1 {
                        swap_rows_and_cols(a, n, k + 1, r_idx);
                    }
                    perform_2x2_pivot_lower(a, n, k)?;
                    ipiv[k] = -((r_idx + 1) as i32);
                    ipiv[k + 1] = ipiv[k];
                    k += 2;
                }
            }
        }
    }

    Ok(())
}
313
/// Pivot-order selection for an upper-triangle Bunch–Kaufman factorization,
/// walking the columns from `n - 1` down to `0`.
///
/// FIXME(review): unlike [`bunch_kaufman_lower`], this routine only CHOOSES
/// pivots — it records `ipiv` and applies the interchanges but never
/// performs an elimination step (there is no upper counterpart of
/// `perform_1x1_pivot_lower` / `perform_2x2_pivot_lower`), so `a` is left
/// unfactored. It is also currently unreachable: `bunch_kaufman_factorize`
/// routes `FillMode::Upper` through `mirror_upper_to_lower` plus the lower
/// kernel. Treat this as an unfinished stub (kept compiling by
/// `#![allow(dead_code)]`); the elimination must be implemented before
/// wiring it up anywhere.
fn bunch_kaufman_upper(a: &mut [f64], n: usize, ipiv: &mut [i32]) -> SolverResult<()> {
    if n == 0 {
        return Ok(());
    }

    let mut k = n;

    while k > 0 {
        let col = k - 1;
        // Largest off-diagonal magnitude above the diagonal in this column.
        let (lambda, r_idx) = column_max_offdiag(a, n, col, false);
        let akk = a[col * n + col].abs();

        // Diagonal and entire super-column both vanish: no pivot exists.
        if akk < 1e-300 && lambda < 1e-300 {
            return Err(SolverError::SingularMatrix);
        }

        if akk >= BUNCH_KAUFMAN_ALPHA * lambda {
            // 1x1 pivot on the current diagonal (elimination missing).
            ipiv[col] = (col + 1) as i32;
            k -= 1;
        } else {
            let (sigma, _) = column_max_offdiag(a, n, r_idx, false);

            if akk * sigma >= BUNCH_KAUFMAN_ALPHA * lambda * lambda {
                // 1x1 pivot accepted after the sigma test (elimination missing).
                ipiv[col] = (col + 1) as i32;
                k -= 1;
            } else if a[r_idx * n + r_idx].abs() >= BUNCH_KAUFMAN_ALPHA * sigma {
                // 1x1 pivot on the candidate diagonal (elimination missing).
                if r_idx != col {
                    swap_rows_and_cols(a, n, col, r_idx);
                }
                ipiv[col] = (r_idx + 1) as i32;
                k -= 1;
            } else {
                if col == 0 {
                    // No room for a 2x2 block at the first column; fall back.
                    ipiv[col] = (col + 1) as i32;
                    k -= 1;
                } else {
                    // 2x2 pivot over columns col-1, col (elimination missing);
                    // both ipiv slots carry the same negative code.
                    let col2 = col - 1;
                    if r_idx != col2 {
                        swap_rows_and_cols(a, n, col2, r_idx);
                    }
                    ipiv[col] = -((r_idx + 1) as i32);
                    ipiv[col2] = ipiv[col];
                    k -= 2;
                }
            }
        }
    }

    Ok(())
}
365
/// Returns the largest off-diagonal magnitude in column `col` of the
/// column-major matrix and the row index where it occurs.
///
/// With `lower == true` the scan covers rows below the diagonal
/// (`col + 1..n`), otherwise rows above it (`0..col`). Ties keep the first
/// index seen; an empty scan reports `(0.0, col)`.
fn column_max_offdiag(a: &[f64], n: usize, col: usize, lower: bool) -> (f64, usize) {
    let rows = if lower { (col + 1)..n } else { 0..col };

    let mut best = (0.0_f64, col);
    for row in rows {
        let magnitude = a[col * n + row].abs();
        if magnitude > best.0 {
            best = (magnitude, row);
        }
    }

    best
}
396
/// Symmetric interchange of rows/cols `i` and `j`, touching ONLY the lower
/// triangle of the column-major matrix (element (r, c), r >= c, stored at
/// `a[c * n + r]`) — the LAPACK `dsytf2`-style interchange.
///
/// The previous implementation swapped full rows and columns, which reads
/// and writes the upper-triangle storage. That storage is garbage for
/// `FillMode::Lower` inputs that only fill the lower triangle, and it goes
/// stale after the first elimination step (the pivot routines update only
/// the lower trailing block), so any non-adjacent interchange pulled stale
/// values into valid lower positions and corrupted the factorization.
fn swap_rows_and_cols(a: &mut [f64], n: usize, i: usize, j: usize) {
    if i == j {
        return;
    }
    // Normalize so lo < hi; the interchange is symmetric in its arguments.
    let (lo, hi) = if i < j { (i, j) } else { (j, i) };

    // Rows lo and hi within the leading columns (< lo): plain swap.
    for col in 0..lo {
        a.swap(col * n + lo, col * n + hi);
    }
    // Diagonal entries trade places.
    a.swap(lo * n + lo, hi * n + hi);
    // Transposed segment between the two indices:
    // (m, lo) <-> (hi, m) for lo < m < hi. Entry (hi, lo) maps to itself.
    for m in (lo + 1)..hi {
        a.swap(lo * n + m, m * n + hi);
    }
    // Entries below both indices in columns lo and hi.
    for row in (hi + 1)..n {
        a.swap(lo * n + row, hi * n + row);
    }
}
411
/// Applies a 1x1 diagonal pivot at position `k` on the column-major lower
/// triangle: scales the sub-column into `L(i, k) = A(i, k) / d` and applies
/// the rank-one Schur update `A(i, j) -= L(i, k) * d * L(j, k)` to the
/// trailing lower block.
fn perform_1x1_pivot_lower(a: &mut [f64], n: usize, k: usize) {
    let pivot = a[k * n + k];
    // A (near-)zero pivot cannot be divided by; leave the column untouched
    // and let the caller's singularity checks handle it.
    if pivot.abs() < 1e-300 {
        return;
    }
    let scale = 1.0 / pivot;

    // Sub-column of L.
    for row in (k + 1)..n {
        a[k * n + row] *= scale;
    }

    // Trailing Schur complement, lower triangle (including diagonal) only.
    for col in (k + 1)..n {
        let l_col = a[k * n + col];
        for row in col..n {
            let l_row = a[k * n + row];
            a[col * n + row] -= l_row * pivot * l_col;
        }
    }
}
434
/// Applies a 2x2 Bunch–Kaufman pivot at columns `k`, `k + 1` of the
/// column-major lower triangle: the symmetric block
/// `D = [[d11, d21], [d21, d22]]` stays in place, the two sub-columns are
/// overwritten with `L = A * D⁻¹`, and the trailing lower block receives
/// the rank-two Schur update `A -= L * D * Lᵀ`.
///
/// # Errors
/// * `InternalError` if the block would run past the matrix edge.
/// * `SingularMatrix` if the 2x2 block is (near-)singular.
fn perform_2x2_pivot_lower(a: &mut [f64], n: usize, k: usize) -> SolverResult<()> {
    if k + 1 >= n {
        return Err(SolverError::InternalError(
            "ldlt: 2x2 pivot at boundary".into(),
        ));
    }

    // The symmetric 2x2 pivot block D.
    let d11 = a[k * n + k];
    let d21 = a[k * n + (k + 1)];
    let d22 = a[(k + 1) * n + (k + 1)];

    let det = d11 * d22 - d21 * d21;
    if det.abs() < 1e-300 {
        return Err(SolverError::SingularMatrix);
    }
    let inv_det = 1.0 / det;

    // Overwrite the two sub-columns with L = A * D⁻¹, using the explicit
    // 2x2 inverse (adjugate over determinant).
    for i in (k + 2)..n {
        let aik = a[k * n + i];
        let aik1 = a[(k + 1) * n + i];

        a[k * n + i] = (d22 * aik - d21 * aik1) * inv_det;
        a[(k + 1) * n + i] = (-d21 * aik + d11 * aik1) * inv_det;
    }

    // Trailing update A(i, j) -= (L D Lᵀ)(i, j) on the lower triangle only,
    // expanded in terms of the scaled L entries computed above.
    for j in (k + 2)..n {
        let ljk = a[k * n + j];
        let ljk1 = a[(k + 1) * n + j];

        for i in j..n {
            let lik = a[k * n + i];
            let lik1 = a[(k + 1) * n + i];

            a[j * n + i] -=
                lik * d11 * ljk + lik * d21 * ljk1 + lik1 * d21 * ljk + lik1 * d22 * ljk1;
        }
    }

    Ok(())
}
483
484fn bunch_kaufman_solve(
490 a: &[f64],
491 ipiv: &[i32],
492 b: &mut [f64],
493 n: usize,
494 nrhs: usize,
495 uplo: FillMode,
496) -> SolverResult<()> {
497 match uplo {
498 FillMode::Lower => bunch_kaufman_solve_lower(a, ipiv, b, n, nrhs),
499 FillMode::Upper => bunch_kaufman_solve_lower(a, ipiv, b, n, nrhs),
500 FillMode::Full => Err(SolverError::DimensionMismatch(
501 "ldlt_solve: uplo must be Lower or Upper".into(),
502 )),
503 }
504}
505
/// Solves `A * X = B` column by column from the packed lower Bunch–Kaufman
/// factorization left in `a` by `bunch_kaufman_lower`, using the pivot
/// codes in `ipiv` (positive entry = 1x1 pivot, matching negative pair =
/// 2x2 block). `b` holds `nrhs` right-hand sides column-major and is
/// overwritten with the solutions.
fn bunch_kaufman_solve_lower(
    a: &[f64],
    ipiv: &[i32],
    b: &mut [f64],
    n: usize,
    nrhs: usize,
) -> SolverResult<()> {
    for rhs in 0..nrhs {
        let b_col = &mut b[rhs * n..(rhs + 1) * n];

        // Phase 1: forward pass — replay the interchanges in factorization
        // order and solve L * y = Pᵀ b with the unit lower factor.
        let mut k = 0;
        while k < n {
            if ipiv[k] > 0 {
                // 1x1 pivot: swap, then eliminate with sub-column k of L.
                let p = (ipiv[k] - 1) as usize;
                if p != k {
                    b_col.swap(k, p);
                }
                for i in (k + 1)..n {
                    b_col[i] -= a[k * n + i] * b_col[k];
                }
                k += 1;
            } else {
                // 2x2 block: the factorization interchanged row k+1, not k.
                let p = ((-ipiv[k]) - 1) as usize;
                if p != k + 1 {
                    b_col.swap(k + 1, p);
                }
                for i in (k + 2)..n {
                    b_col[i] -= a[k * n + i] * b_col[k];
                    b_col[i] -= a[(k + 1) * n + i] * b_col[k + 1];
                }
                k += 2;
            }
        }

        // Phase 2: solve the block-diagonal system D * z = y.
        k = 0;
        while k < n {
            if ipiv[k] > 0 {
                // 1x1 diagonal entry.
                let dkk = a[k * n + k];
                if dkk.abs() < 1e-300 {
                    return Err(SolverError::SingularMatrix);
                }
                b_col[k] /= dkk;
                k += 1;
            } else {
                if k + 1 >= n {
                    // A 2x2 code on the last row cannot come from a valid
                    // factorization.
                    return Err(SolverError::InternalError(
                        "ldlt_solve: invalid 2x2 pivot at boundary".into(),
                    ));
                }
                // Invert the symmetric 2x2 block explicitly.
                let d11 = a[k * n + k];
                let d21 = a[k * n + (k + 1)];
                let d22 = a[(k + 1) * n + (k + 1)];
                let det = d11 * d22 - d21 * d21;
                if det.abs() < 1e-300 {
                    return Err(SolverError::SingularMatrix);
                }
                let inv_det = 1.0 / det;
                let y1 = b_col[k];
                let y2 = b_col[k + 1];
                b_col[k] = (d22 * y1 - d21 * y2) * inv_det;
                b_col[k + 1] = (-d21 * y1 + d11 * y2) * inv_det;
                k += 2;
            }
        }

        // Phase 3: backward pass — solve Lᵀ * x = z and undo the
        // interchanges in reverse order.
        k = n;
        while k > 0 {
            k -= 1;
            if ipiv[k] > 0 {
                // 1x1 pivot: dot with sub-column k of L, then undo the swap.
                for i in (k + 1)..n {
                    b_col[k] -= a[k * n + i] * b_col[i];
                }
                let p = (ipiv[k] - 1) as usize;
                if p != k {
                    b_col.swap(k, p);
                }
            } else if k > 0 && ipiv[k] < 0 && ipiv[k - 1] == ipiv[k] {
                // Second half of a 2x2 block (matching negative codes at
                // k-1 and k): update both rows, undo the interchange that was
                // applied at level k+1, then jump past the block's first row.
                let k2 = k - 1;
                for i in (k + 1)..n {
                    b_col[k] -= a[k * n + i] * b_col[i];
                    b_col[k2] -= a[k2 * n + i] * b_col[i];
                }
                let p = ((-ipiv[k]) - 1) as usize;
                if p != k {
                    b_col.swap(k, p);
                }
                k = k2;
            }
        }
    }

    Ok(())
}
610
/// Reflects the (valid) upper triangle of the column-major matrix into the
/// lower triangle so that the lower-triangle factorization can run on a
/// matrix supplied in `FillMode::Upper` layout.
///
/// Element (r, c) lives at `a[c * n + r]`; for every `row < col` the upper
/// entry `a[col * n + row]` is copied into the mirrored lower slot
/// `a[row * n + col]`. (The previous version assigned in the opposite
/// direction, overwriting the valid upper data with whatever garbage the
/// unused lower triangle held.)
fn mirror_upper_to_lower(a: &mut [f64], n: usize) {
    for col in 0..n {
        for row in 0..col {
            // lower (col, row) <- upper (row, col)
            a[row * n + col] = a[col * n + row];
        }
    }
}
618
619fn read_device_to_host<T: GpuFloat>(
624 buf: &DeviceBuffer<T>,
625 host: &mut [f64],
626 count: usize,
627) -> SolverResult<()> {
628 if host.len() < count {
629 return Err(SolverError::DimensionMismatch(format!(
630 "read_device_to_host: host buffer too small ({} < {})",
631 host.len(),
632 count
633 )));
634 }
635 let mut staged = vec![T::gpu_zero(); count];
636 buf.copy_to_host(&mut staged)?;
637 for (dst, src) in host.iter_mut().zip(staged.iter()) {
638 *dst = to_f64(*src);
639 }
640 Ok(())
641}
642
643fn write_host_to_device<T: GpuFloat>(
644 buf: &mut DeviceBuffer<T>,
645 data: &[T],
646 count: usize,
647) -> SolverResult<()> {
648 if data.len() < count {
649 return Err(SolverError::DimensionMismatch(format!(
650 "write_host_to_device: source buffer too small ({} < {})",
651 data.len(),
652 count
653 )));
654 }
655 buf.copy_from_host(&data[..count])?;
656 Ok(())
657}
658
659fn read_device_to_host_i32(
660 buf: &DeviceBuffer<i32>,
661 host: &mut [i32],
662 count: usize,
663) -> SolverResult<()> {
664 if host.len() < count {
665 return Err(SolverError::DimensionMismatch(format!(
666 "read_device_to_host_i32: host buffer too small ({} < {})",
667 host.len(),
668 count
669 )));
670 }
671 buf.copy_to_host(&mut host[..count])?;
672 Ok(())
673}
674
675fn write_host_to_device_i32(
676 buf: &mut DeviceBuffer<i32>,
677 data: &[i32],
678 count: usize,
679) -> SolverResult<()> {
680 if data.len() < count {
681 return Err(SolverError::DimensionMismatch(format!(
682 "write_host_to_device_i32: source buffer too small ({} < {})",
683 data.len(),
684 count
685 )));
686 }
687 buf.copy_from_host(&data[..count])?;
688 Ok(())
689}
690
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bunch_kaufman_alpha_value() {
        // The constant must equal the textbook threshold (1 + sqrt(17)) / 8.
        let expected = (1.0_f64 + 17.0_f64.sqrt()) / 8.0;
        assert!((BUNCH_KAUFMAN_ALPHA - expected).abs() < 1e-10);
    }

    #[test]
    fn column_max_offdiag_lower() {
        // Column-major 3x3; column 0 holds 5.0 at row 1 and 3.0 at row 2,
        // so the below-diagonal scan must report (5.0, row 1).
        let a = [1.0, 5.0, 3.0, 0.0, 2.0, 7.0, 0.0, 0.0, 4.0];
        let (max_val, max_idx) = column_max_offdiag(&a, 3, 0, true);
        assert!((max_val - 5.0).abs() < 1e-15);
        assert_eq!(max_idx, 1);
    }

    #[test]
    fn column_max_offdiag_upper() {
        // Column 2's above-diagonal entries are all zero, so the scan
        // reports magnitude 0 and falls back to the diagonal index.
        let a = [1.0, 5.0, 3.0, 0.0, 2.0, 7.0, 0.0, 0.0, 4.0];
        let (max_val, max_idx) = column_max_offdiag(&a, 3, 2, false);
        assert!(max_val.abs() < 1e-15);
        assert_eq!(max_idx, 2);
    }

    #[test]
    fn swap_rows_and_cols_identity() {
        // Swapping an index with itself must be a no-op.
        let mut a = [1.0, 0.0, 0.0, 1.0];
        swap_rows_and_cols(&mut a, 2, 0, 0);
        assert!((a[0] - 1.0).abs() < 1e-15);
        assert!((a[3] - 1.0).abs() < 1e-15);
    }

    #[test]
    fn swap_rows_and_cols_basic() {
        // A symmetric interchange of both indices of the 2x2 identity
        // leaves it intact (the two diagonal entries trade places).
        let mut a = [1.0, 0.0, 0.0, 1.0];
        swap_rows_and_cols(&mut a, 2, 0, 1);
        assert!((a[0] - 1.0).abs() < 1e-15);
        assert!((a[3] - 1.0).abs() < 1e-15);
    }

    #[test]
    fn perform_1x1_pivot_lower_basic() {
        // A = [[4, 2], [2, 3]]: L(1,0) = 2/4 = 0.5 and the Schur
        // complement is 3 - 0.5 * 4 * 0.5 = 2.
        let mut a = [4.0, 2.0, 2.0, 3.0];
        perform_1x1_pivot_lower(&mut a, 2, 0);
        assert!((a[1] - 0.5).abs() < 1e-15);
        assert!((a[3] - 2.0).abs() < 1e-15);
    }

    #[test]
    fn bunch_kaufman_identity_3x3() {
        // The identity factors trivially with three positive 1x1 pivots.
        let mut a = vec![0.0; 9];
        a[0] = 1.0;
        a[4] = 1.0;
        a[8] = 1.0;
        let mut ipiv = vec![0_i32; 3];
        let result = bunch_kaufman_lower(&mut a, 3, &mut ipiv);
        assert!(result.is_ok());
        assert!(ipiv[0] > 0);
        assert!(ipiv[1] > 0);
        assert!(ipiv[2] > 0);
    }

    #[test]
    fn f64_conversion_roundtrip() {
        // f64 staging must be bit-preserving.
        let val = std::f64::consts::E;
        let converted: f64 = from_f64(to_f64(val));
        assert!((converted - val).abs() < 1e-15);
    }

    #[test]
    fn f32_conversion_roundtrip() {
        // f32 survives the f64 staging round-trip to within f32 precision.
        let val = std::f32::consts::E;
        let as_f64 = to_f64(val);
        let back: f32 = from_f64(as_f64);
        assert!((back - val).abs() < 1e-5);
    }
}