1#![allow(dead_code)]
29
30use oxicuda_blas::GpuFloat;
31
32use crate::error::{SolverError, SolverResult};
33use crate::handle::SolverHandle;
34
35fn to_f64<T: GpuFloat>(val: T) -> f64 {
40 if T::SIZE == 4 {
41 f32::from_bits(val.to_bits_u64() as u32) as f64
42 } else {
43 f64::from_bits(val.to_bits_u64())
44 }
45}
46
47fn from_f64<T: GpuFloat>(val: f64) -> T {
48 if T::SIZE == 4 {
49 T::from_bits_u64(u64::from((val as f32).to_bits()))
50 } else {
51 T::from_bits_u64(val.to_bits())
52 }
53}
54
/// A preconditioner `M^{-1}` applied as `z = M^{-1} r` inside the
/// iterative solvers in this module.
pub trait Preconditioner<T: GpuFloat> {
    /// Applies the preconditioner to residual `r`, writing the result into `z`.
    ///
    /// The implementations in this file only touch the overlapping prefix of
    /// `r` and `z`; callers are expected to pass equal-length slices.
    fn apply(&self, r: &[T], z: &mut [T]) -> SolverResult<()>;
}
74
75pub struct IdentityPreconditioner;
79
80impl<T: GpuFloat> Preconditioner<T> for IdentityPreconditioner {
81 fn apply(&self, r: &[T], z: &mut [T]) -> SolverResult<()> {
82 let n = r.len().min(z.len());
83 z[..n].copy_from_slice(&r[..n]);
84 Ok(())
85 }
86}
87
/// Diagonal (Jacobi) preconditioner: `z_i = r_i / d_i`.
pub struct JacobiPreconditioner<T: GpuFloat> {
    // Reciprocals of the matrix diagonal, precomputed at construction time.
    inv_diag: Vec<T>,
}
96
97impl<T: GpuFloat> JacobiPreconditioner<T> {
98 pub fn new(diagonal: &[T]) -> SolverResult<Self> {
108 let mut inv_diag = Vec::with_capacity(diagonal.len());
109 for (i, &d) in diagonal.iter().enumerate() {
110 let dv = to_f64(d);
111 if dv.abs() < 1e-300 {
112 return Err(SolverError::InternalError(format!(
113 "JacobiPreconditioner: zero diagonal at index {i}"
114 )));
115 }
116 inv_diag.push(from_f64(1.0 / dv));
117 }
118 Ok(Self { inv_diag })
119 }
120}
121
122impl<T: GpuFloat> Preconditioner<T> for JacobiPreconditioner<T> {
123 fn apply(&self, r: &[T], z: &mut [T]) -> SolverResult<()> {
124 let n = r.len().min(z.len()).min(self.inv_diag.len());
125 for i in 0..n {
126 z[i] = mul_t(r[i], self.inv_diag[i]);
127 }
128 Ok(())
129 }
130}
131
/// Outcome of a run of one of the iterative solvers below.
#[derive(Debug, Clone)]
pub struct IterativeSolverResult<T> {
    /// Number of iterations actually performed.
    pub iterations: u32,
    /// Final residual norm `||b - A x||`, in the solver's scalar type.
    pub residual: T,
    /// Whether the residual dropped below `tol * ||b||` within the budget.
    pub converged: bool,
}
146
/// Configuration for [`preconditioned_cg`].
#[derive(Debug, Clone)]
pub struct PcgConfig {
    /// Maximum number of CG iterations before giving up.
    pub max_iter: u32,
    /// Relative tolerance: converged when `||r|| < tol * ||b||`.
    pub tol: f64,
}
159
160impl Default for PcgConfig {
161 fn default() -> Self {
162 Self {
163 max_iter: 1000,
164 tol: 1e-6,
165 }
166 }
167}
168
/// Preconditioned conjugate gradient (PCG) for SPD systems `A x = b`.
///
/// * `spmv` computes `y = A * v` for the (implicit) matrix `A`.
/// * `precond` applies `M^{-1}` to a residual.
/// * `x` holds the initial guess on entry and the iterate on exit.
/// * `n` is the active problem size; `b` and `x` must each hold at least
///   `n` entries.
///
/// Converges when `||r|| < config.tol * ||b||`. All reductions are
/// accumulated in `f64` regardless of `T`'s precision.
///
/// # Errors
/// * `SolverError::DimensionMismatch` if `b` or `x` is shorter than `n`.
/// * `SolverError::InternalError` if `p^T A p` or `r^T z` collapses to
///   (near) zero — e.g. when `A` is not SPD.
/// * Any error propagated from `spmv` or `precond`.
pub fn preconditioned_cg<T, P, F>(
    _handle: &SolverHandle,
    spmv: F,
    precond: &P,
    b: &[T],
    x: &mut [T],
    n: u32,
    config: &PcgConfig,
) -> SolverResult<IterativeSolverResult<T>>
where
    T: GpuFloat,
    P: Preconditioner<T>,
    F: Fn(&[T], &mut [T]) -> SolverResult<()>,
{
    let n_usize = n as usize;

    if b.len() < n_usize {
        return Err(SolverError::DimensionMismatch(format!(
            "preconditioned_cg: b length ({}) < n ({n})",
            b.len()
        )));
    }
    if x.len() < n_usize {
        return Err(SolverError::DimensionMismatch(format!(
            "preconditioned_cg: x length ({}) < n ({n})",
            x.len()
        )));
    }
    // Empty system: trivially converged.
    if n == 0 {
        return Ok(IterativeSolverResult {
            iterations: 0,
            residual: T::gpu_zero(),
            converged: true,
        });
    }

    let b_norm = vec_norm(b, n_usize);
    // Relative tolerance; a zero right-hand side means x = 0 exactly.
    let abs_tol = if b_norm > 0.0 {
        config.tol * b_norm
    } else {
        for xi in x.iter_mut().take(n_usize) {
            *xi = T::gpu_zero();
        }
        return Ok(IterativeSolverResult {
            iterations: 0,
            residual: T::gpu_zero(),
            converged: true,
        });
    };

    // r = b - A x (initial residual for the caller's guess).
    let mut r = vec![T::gpu_zero(); n_usize];
    let mut ap = vec![T::gpu_zero(); n_usize];
    spmv(x, &mut ap)?;
    for i in 0..n_usize {
        r[i] = sub_t(b[i], ap[i]);
    }

    // z = M^{-1} r; first search direction p = z.
    let mut z = vec![T::gpu_zero(); n_usize];
    precond.apply(&r, &mut z)?;

    let mut p = z.clone();

    // rz_old = r^T z, carried across iterations for alpha and beta.
    let mut rz_old = dot_product(&r, &z, n_usize);

    let r_norm = vec_norm(&r, n_usize);
    if r_norm < abs_tol {
        return Ok(IterativeSolverResult {
            iterations: 0,
            residual: from_f64(r_norm),
            converged: true,
        });
    }

    for iter in 0..config.max_iter {
        // ap = A p
        spmv(&p, &mut ap)?;

        let pap = dot_product(&p, &ap, n_usize);
        // p^T A p must stay positive for an SPD matrix; a near-zero value
        // would make alpha blow up.
        if pap.abs() < 1e-300 {
            return Err(SolverError::InternalError(
                "preconditioned_cg: p^T * A * p is near zero (A may not be SPD)".into(),
            ));
        }
        let alpha = rz_old / pap;
        let alpha_t = from_f64(alpha);

        // x += alpha * p
        for i in 0..n_usize {
            x[i] = add_t(x[i], mul_t(alpha_t, p[i]));
        }

        // r -= alpha * A p
        for i in 0..n_usize {
            r[i] = sub_t(r[i], mul_t(alpha_t, ap[i]));
        }

        let r_norm_now = vec_norm(&r, n_usize);
        if r_norm_now < abs_tol {
            return Ok(IterativeSolverResult {
                iterations: iter + 1,
                residual: from_f64(r_norm_now),
                converged: true,
            });
        }

        // z = M^{-1} r for the next search direction.
        precond.apply(&r, &mut z)?;

        let rz_new = dot_product(&r, &z, n_usize);

        if rz_old.abs() < 1e-300 {
            return Err(SolverError::InternalError(
                "preconditioned_cg: rz_old is near zero".into(),
            ));
        }
        let beta = rz_new / rz_old;
        let beta_t = from_f64(beta);

        // p = z + beta * p
        for i in 0..n_usize {
            p[i] = add_t(z[i], mul_t(beta_t, p[i]));
        }

        rz_old = rz_new;
    }

    // Iteration budget exhausted without meeting the tolerance.
    let final_norm = vec_norm(&r, n_usize);
    Ok(IterativeSolverResult {
        iterations: config.max_iter,
        residual: from_f64(final_norm),
        converged: false,
    })
}
340
/// Configuration for [`preconditioned_gmres`].
#[derive(Debug, Clone)]
pub struct PgmresConfig {
    /// Maximum total iterations across all restart cycles.
    pub max_iter: u32,
    /// Relative tolerance: converged when `||r|| < tol * ||b||`.
    pub tol: f64,
    /// Krylov subspace dimension per restart cycle (the "m" in GMRES(m)).
    pub restart: u32,
}
355
356impl Default for PgmresConfig {
357 fn default() -> Self {
358 Self {
359 max_iter: 1000,
360 tol: 1e-6,
361 restart: 30,
362 }
363 }
364}
365
366pub fn preconditioned_gmres<T, P, F>(
392 _handle: &SolverHandle,
393 spmv: F,
394 precond: &P,
395 b: &[T],
396 x: &mut [T],
397 n: u32,
398 config: &PgmresConfig,
399) -> SolverResult<IterativeSolverResult<T>>
400where
401 T: GpuFloat,
402 P: Preconditioner<T>,
403 F: Fn(&[T], &mut [T]) -> SolverResult<()>,
404{
405 let n_usize = n as usize;
406
407 if b.len() < n_usize {
409 return Err(SolverError::DimensionMismatch(format!(
410 "preconditioned_gmres: b length ({}) < n ({n})",
411 b.len()
412 )));
413 }
414 if x.len() < n_usize {
415 return Err(SolverError::DimensionMismatch(format!(
416 "preconditioned_gmres: x length ({}) < n ({n})",
417 x.len()
418 )));
419 }
420 if n == 0 {
421 return Ok(IterativeSolverResult {
422 iterations: 0,
423 residual: T::gpu_zero(),
424 converged: true,
425 });
426 }
427
428 let b_norm = vec_norm(b, n_usize);
429 let abs_tol = if b_norm > 0.0 {
430 config.tol * b_norm
431 } else {
432 for xi in x.iter_mut().take(n_usize) {
433 *xi = T::gpu_zero();
434 }
435 return Ok(IterativeSolverResult {
436 iterations: 0,
437 residual: T::gpu_zero(),
438 converged: true,
439 });
440 };
441
442 let m = config.restart.min(n) as usize;
443 let mut total_iters = 0_u32;
444
445 while total_iters < config.max_iter {
447 let (iters, converged, res_norm) = pgmres_cycle(
448 &spmv,
449 precond,
450 b,
451 x,
452 n_usize,
453 m,
454 abs_tol,
455 config.max_iter.saturating_sub(total_iters),
456 )?;
457 total_iters += iters;
458
459 if converged {
460 return Ok(IterativeSolverResult {
461 iterations: total_iters,
462 residual: from_f64(res_norm),
463 converged: true,
464 });
465 }
466
467 if iters == 0 {
468 break; }
470 }
471
472 let mut r = vec![T::gpu_zero(); n_usize];
474 let mut ax = vec![T::gpu_zero(); n_usize];
475 spmv(x, &mut ax)?;
476 for i in 0..n_usize {
477 r[i] = sub_t(b[i], ax[i]);
478 }
479 let r_norm = vec_norm(&r, n_usize);
480
481 Ok(IterativeSolverResult {
482 iterations: total_iters,
483 residual: from_f64(r_norm),
484 converged: r_norm < abs_tol,
485 })
486}
487
488#[allow(clippy::too_many_arguments)]
492fn pgmres_cycle<T, P, F>(
493 spmv: &F,
494 precond: &P,
495 b: &[T],
496 x: &mut [T],
497 n: usize,
498 m: usize,
499 abs_tol: f64,
500 max_iters: u32,
501) -> SolverResult<(u32, bool, f64)>
502where
503 T: GpuFloat,
504 P: Preconditioner<T>,
505 F: Fn(&[T], &mut [T]) -> SolverResult<()>,
506{
507 let mut r = vec![T::gpu_zero(); n];
509 let mut ax = vec![T::gpu_zero(); n];
510 spmv(x, &mut ax)?;
511 for i in 0..n {
512 r[i] = sub_t(b[i], ax[i]);
513 }
514
515 let mut z = vec![T::gpu_zero(); n];
517 precond.apply(&r, &mut z)?;
518
519 let beta = vec_norm(&z, n);
520
521 if beta < abs_tol {
522 return Ok((0, true, vec_norm(&r, n)));
523 }
524
525 let mut v_basis: Vec<Vec<T>> = Vec::with_capacity(m + 1);
527
528 let inv_beta = from_f64(1.0 / beta);
530 let v0: Vec<T> = z.iter().map(|&zi| mul_t(zi, inv_beta)).collect();
531 v_basis.push(v0);
532
533 let mut h = vec![vec![0.0_f64; m + 1]; m];
535
536 let mut cs = vec![0.0_f64; m];
538 let mut sn = vec![0.0_f64; m];
539
540 let mut g = vec![0.0_f64; m + 1];
542 g[0] = beta;
543
544 let mut j = 0;
545 let max_j = m.min(max_iters as usize);
546 let mut converged = false;
547
548 while j < max_j {
549 let mut av = vec![T::gpu_zero(); n];
551 spmv(&v_basis[j], &mut av)?;
552 let mut w = vec![T::gpu_zero(); n];
553 precond.apply(&av, &mut w)?;
554
555 for i in 0..=j {
557 h[j][i] = dot_product(&v_basis[i], &w, n);
558 let h_ij_t = from_f64(h[j][i]);
559 for k in 0..n {
560 w[k] = sub_t(w[k], mul_t(h_ij_t, v_basis[i][k]));
561 }
562 }
563
564 let w_norm = vec_norm(&w, n);
565 h[j][j + 1] = w_norm;
566
567 if w_norm > 1e-300 {
569 let inv_w = from_f64(1.0 / w_norm);
570 let vj1: Vec<T> = w.iter().map(|&wi| mul_t(wi, inv_w)).collect();
571 v_basis.push(vj1);
572 } else {
573 let vj1 = vec![T::gpu_zero(); n];
575 v_basis.push(vj1);
576 }
577
578 for i in 0..j {
580 let tmp = cs[i] * h[j][i] + sn[i] * h[j][i + 1];
581 h[j][i + 1] = -sn[i] * h[j][i] + cs[i] * h[j][i + 1];
582 h[j][i] = tmp;
583 }
584
585 let (c, s) = givens_rotation(h[j][j], h[j][j + 1]);
587 cs[j] = c;
588 sn[j] = s;
589
590 h[j][j] = c * h[j][j] + s * h[j][j + 1];
592 h[j][j + 1] = 0.0;
593
594 let tmp = cs[j] * g[j] + sn[j] * g[j + 1];
596 g[j + 1] = -sn[j] * g[j] + cs[j] * g[j + 1];
597 g[j] = tmp;
598
599 j += 1;
600
601 if g[j].abs() < abs_tol {
603 converged = true;
604 break;
605 }
606 }
607
608 let mut y = vec![0.0_f64; j];
610 for i in (0..j).rev() {
611 y[i] = g[i];
612 for k in (i + 1)..j {
613 y[i] -= h[k][i] * y[k];
614 }
615 if h[i][i].abs() > 1e-300 {
616 y[i] /= h[i][i];
617 }
618 }
619
620 for i in 0..j {
622 let yi_t = from_f64(y[i]);
623 for k in 0..n {
624 x[k] = add_t(x[k], mul_t(yi_t, v_basis[i][k]));
625 }
626 }
627
628 let mut r_final = vec![T::gpu_zero(); n];
630 let mut ax_final = vec![T::gpu_zero(); n];
631 spmv(x, &mut ax_final)?;
632 for i in 0..n {
633 r_final[i] = sub_t(b[i], ax_final[i]);
634 }
635 let r_norm = vec_norm(&r_final, n);
636
637 Ok((j as u32, converged || r_norm < abs_tol, r_norm))
638}
639
/// Computes Givens rotation coefficients `(c, s)` such that
/// `[c s; -s c] * [a; b] = [r; 0]` with `r = sqrt(a^2 + b^2)`.
///
/// Degenerate inputs take trivial rotations: `b ~ 0` yields the identity,
/// `a ~ 0` yields a pure swap carrying the sign of `b`.
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    if b.abs() < 1e-300 {
        return (1.0, 0.0);
    }
    if a.abs() < 1e-300 {
        return (0.0, if b >= 0.0 { 1.0 } else { -1.0 });
    }
    // hypot avoids overflow/underflow of a*a + b*b for extreme magnitudes
    // (the naive form overflows once |a| or |b| exceeds ~1e154).
    let r = a.hypot(b);
    (a / r, b / r)
}
654
655fn dot_product<T: GpuFloat>(a: &[T], b: &[T], n: usize) -> f64 {
656 let mut sum = 0.0_f64;
657 for i in 0..n {
658 sum += to_f64(a[i]) * to_f64(b[i]);
659 }
660 sum
661}
662
663fn vec_norm<T: GpuFloat>(v: &[T], n: usize) -> f64 {
664 dot_product(v, v, n).sqrt()
665}
666
667fn add_t<T: GpuFloat>(a: T, b: T) -> T {
668 from_f64(to_f64(a) + to_f64(b))
669}
670
671fn sub_t<T: GpuFloat>(a: T, b: T) -> T {
672 from_f64(to_f64(a) - to_f64(b))
673}
674
675fn mul_t<T: GpuFloat>(a: T, b: T) -> T {
676 from_f64(to_f64(a) * to_f64(b))
677}
678
#[cfg(test)]
mod tests {
    use super::*;

    // --- configuration defaults and literals ---------------------------

    #[test]
    fn pcg_config_default() {
        let cfg = PcgConfig::default();
        assert_eq!(cfg.max_iter, 1000);
        assert!((cfg.tol - 1e-6).abs() < 1e-15);
    }

    #[test]
    fn pgmres_config_default() {
        let cfg = PgmresConfig::default();
        assert_eq!(cfg.max_iter, 1000);
        assert!((cfg.tol - 1e-6).abs() < 1e-15);
        assert_eq!(cfg.restart, 30);
    }

    #[test]
    fn pcg_config_custom() {
        let cfg = PcgConfig {
            max_iter: 500,
            tol: 1e-10,
        };
        assert_eq!(cfg.max_iter, 500);
        assert!((cfg.tol - 1e-10).abs() < 1e-20);
    }

    #[test]
    fn pgmres_config_custom() {
        let cfg = PgmresConfig {
            max_iter: 2000,
            tol: 1e-8,
            restart: 50,
        };
        assert_eq!(cfg.max_iter, 2000);
        assert_eq!(cfg.restart, 50);
    }

    // --- preconditioners (f64 path) ------------------------------------

    // Identity must copy r into z verbatim.
    #[test]
    fn identity_preconditioner_copies() {
        let r = [1.0_f64, 2.0, 3.0];
        let mut z = [0.0_f64; 3];
        let p = IdentityPreconditioner;
        let result = p.apply(&r, &mut z);
        assert!(result.is_ok());
        for i in 0..3 {
            assert!((z[i] - r[i]).abs() < 1e-15);
        }
    }

    // Jacobi: z_i = r_i / d_i for diag [2, 4, 8] and r [2, 8, 16].
    #[test]
    fn jacobi_preconditioner_basic() {
        let diag = [2.0_f64, 4.0, 8.0];
        let jp = JacobiPreconditioner::new(&diag);
        assert!(jp.is_ok());
        let jp = match jp {
            Ok(v) => v,
            Err(e) => panic!("test: {e}"),
        };

        let r = [2.0_f64, 8.0, 16.0];
        let mut z = [0.0_f64; 3];
        let result = jp.apply(&r, &mut z);
        assert!(result.is_ok());
        assert!((z[0] - 1.0).abs() < 1e-12);
        assert!((z[1] - 2.0).abs() < 1e-12);
        assert!((z[2] - 2.0).abs() < 1e-12);
    }

    // Construction must fail when any diagonal entry is zero.
    #[test]
    fn jacobi_preconditioner_zero_diagonal() {
        let diag = [1.0_f64, 0.0, 3.0];
        let jp = JacobiPreconditioner::new(&diag);
        assert!(jp.is_err());
    }

    // --- result struct plumbing ----------------------------------------

    #[test]
    fn iterative_result_fields() {
        let result: IterativeSolverResult<f64> = IterativeSolverResult {
            iterations: 42,
            residual: 1e-8,
            converged: true,
        };
        assert_eq!(result.iterations, 42);
        assert!(result.converged);
        assert!((result.residual - 1e-8).abs() < 1e-15);
    }

    // --- numeric helpers ------------------------------------------------

    // Rotating (3, 4) should produce r = 5 in the first component.
    #[test]
    fn givens_rotation_basic() {
        let (cs, sn) = givens_rotation(3.0, 4.0);
        let r = cs * 3.0 + sn * 4.0;
        assert!((r - 5.0).abs() < 1e-10);
    }

    // b ~ 0 must yield the identity rotation.
    #[test]
    fn givens_rotation_zero_b() {
        let (cs, sn) = givens_rotation(5.0, 0.0);
        assert!((cs - 1.0).abs() < 1e-15);
        assert!(sn.abs() < 1e-15);
    }

    #[test]
    fn dot_product_basic() {
        let a = [1.0_f64, 2.0, 3.0];
        let b = [4.0_f64, 5.0, 6.0];
        assert!((dot_product(&a, &b, 3) - 32.0).abs() < 1e-10);
    }

    // 3-4-5 triangle: ||(3, 4)|| = 5.
    #[test]
    fn vec_norm_basic() {
        let v = [3.0_f64, 4.0];
        assert!((vec_norm(&v, 2) - 5.0).abs() < 1e-10);
    }

    #[test]
    fn add_sub_mul_helpers() {
        let a = 3.0_f64;
        let b = 4.0_f64;
        assert!((to_f64(add_t(a, b)) - 7.0).abs() < 1e-15);
        assert!((to_f64(sub_t(a, b)) - (-1.0)).abs() < 1e-15);
        assert!((to_f64(mul_t(a, b)) - 12.0).abs() < 1e-15);
    }

    // --- bit-level conversions ------------------------------------------

    // f32 survives a widen/narrow round trip within single precision.
    #[test]
    fn f32_conversion_roundtrip() {
        let val = 3.15_f32;
        let as_f64 = to_f64(val);
        let back: f32 = from_f64(as_f64);
        assert!((back - val).abs() < 1e-6);
    }

    // f64 must round-trip exactly (bit passthrough).
    #[test]
    fn f64_conversion_roundtrip() {
        let val = std::f64::consts::PI;
        let as_f64 = to_f64(val);
        let back: f64 = from_f64(as_f64);
        assert!((back - val).abs() < 1e-15);
    }

    // --- f32 paths of the preconditioners -------------------------------

    #[test]
    fn jacobi_preconditioner_f32() {
        let diag = [2.0_f32, 4.0, 8.0];
        let jp = match JacobiPreconditioner::new(&diag) {
            Ok(v) => v,
            Err(e) => panic!("test: {e}"),
        };

        let r = [2.0_f32, 8.0, 16.0];
        let mut z = [0.0_f32; 3];
        let result = jp.apply(&r, &mut z);
        assert!(result.is_ok());
        assert!((z[0] - 1.0_f32).abs() < 1e-5);
        assert!((z[1] - 2.0_f32).abs() < 1e-5);
        assert!((z[2] - 2.0_f32).abs() < 1e-5);
    }

    #[test]
    fn identity_preconditioner_f32() {
        let r = [1.0_f32, 2.0, 3.0];
        let mut z = [0.0_f32; 3];
        let p = IdentityPreconditioner;
        let result = p.apply(&r, &mut z);
        assert!(result.is_ok());
        for i in 0..3 {
            assert!((z[i] - r[i]).abs() < 1e-6);
        }
    }

    #[test]
    fn iterative_result_not_converged() {
        let result: IterativeSolverResult<f64> = IterativeSolverResult {
            iterations: 1000,
            residual: 1e-2,
            converged: false,
        };
        assert_eq!(result.iterations, 1000);
        assert!(!result.converged);
    }
}