1#![allow(dead_code)]
17
18use oxicuda_blas::GpuFloat;
19
20use crate::error::{SolverError, SolverResult};
21use crate::handle::SolverHandle;
22
23fn to_f64<T: GpuFloat>(val: T) -> f64 {
28 if T::SIZE == 4 {
29 f32::from_bits(val.to_bits_u64() as u32) as f64
30 } else {
31 f64::from_bits(val.to_bits_u64())
32 }
33}
34
35fn from_f64<T: GpuFloat>(val: f64) -> T {
36 if T::SIZE == 4 {
37 T::from_bits_u64(u64::from((val as f32).to_bits()))
38 } else {
39 T::from_bits_u64(val.to_bits())
40 }
41}
42
/// Restart length (Krylov subspace size per cycle) used by `GmresConfig::default`.
const DEFAULT_RESTART: u32 = 30;

/// Tuning knobs for the restarted GMRES solver.
#[derive(Debug, Clone)]
pub struct GmresConfig {
    /// Iteration budget for the whole solve, shared across all restart cycles.
    pub max_iter: u32,
    /// Relative tolerance: converged once `||b - A x||_2 < tol * ||b||_2`.
    pub tol: f64,
    /// Krylov subspace dimension built per cycle before restarting
    /// (the solver additionally caps it at the system size).
    pub restart: u32,
}

impl Default for GmresConfig {
    /// Returns `max_iter = 1000`, `tol = 1e-6`, `restart = DEFAULT_RESTART` (30).
    fn default() -> Self {
        let restart = DEFAULT_RESTART;
        let tol = 1e-6;
        let max_iter = 1000;
        GmresConfig {
            restart,
            tol,
            max_iter,
        }
    }
}
70
71pub fn gmres_solve<T, F>(
97 _handle: &SolverHandle,
98 spmv: F,
99 b: &[T],
100 x: &mut [T],
101 n: u32,
102 config: &GmresConfig,
103) -> SolverResult<u32>
104where
105 T: GpuFloat,
106 F: Fn(&[T], &mut [T]) -> SolverResult<()>,
107{
108 let n_usize = n as usize;
109
110 if b.len() < n_usize {
112 return Err(SolverError::DimensionMismatch(format!(
113 "gmres_solve: b length ({}) < n ({n})",
114 b.len()
115 )));
116 }
117 if x.len() < n_usize {
118 return Err(SolverError::DimensionMismatch(format!(
119 "gmres_solve: x length ({}) < n ({n})",
120 x.len()
121 )));
122 }
123 if n == 0 {
124 return Ok(0);
125 }
126
127 let b_norm = vec_norm(b, n_usize);
128 let abs_tol = if b_norm > 0.0 {
129 config.tol * b_norm
130 } else {
131 for xi in x.iter_mut().take(n_usize) {
132 *xi = T::gpu_zero();
133 }
134 return Ok(0);
135 };
136
137 let m = config.restart.min(n) as usize;
138 let mut total_iters = 0_u32;
139
140 while total_iters < config.max_iter {
142 let iters = gmres_cycle(
143 &spmv,
144 b,
145 x,
146 n_usize,
147 m,
148 abs_tol,
149 config.max_iter - total_iters,
150 )?;
151 total_iters += iters;
152
153 let mut r = vec![T::gpu_zero(); n_usize];
155 let mut ax = vec![T::gpu_zero(); n_usize];
156 spmv(x, &mut ax)?;
157 for i in 0..n_usize {
158 r[i] = sub_t(b[i], ax[i]);
159 }
160 total_iters += 1; let r_norm = vec_norm(&r, n_usize);
163 if r_norm < abs_tol {
164 return Ok(total_iters);
165 }
166
167 if iters == 0 {
168 break; }
170 }
171
172 let mut r = vec![T::gpu_zero(); n_usize];
174 let mut ax = vec![T::gpu_zero(); n_usize];
175 spmv(x, &mut ax)?;
176 for i in 0..n_usize {
177 r[i] = sub_t(b[i], ax[i]);
178 }
179 let r_norm = vec_norm(&r, n_usize);
180
181 if r_norm < abs_tol {
182 Ok(total_iters)
183 } else {
184 Err(SolverError::ConvergenceFailure {
185 iterations: total_iters,
186 residual: r_norm,
187 })
188 }
189}
190
191fn gmres_cycle<T, F>(
200 spmv: &F,
201 b: &[T],
202 x: &mut [T],
203 n: usize,
204 m: usize,
205 abs_tol: f64,
206 max_iters: u32,
207) -> SolverResult<u32>
208where
209 T: GpuFloat,
210 F: Fn(&[T], &mut [T]) -> SolverResult<()>,
211{
212 let mut r = vec![T::gpu_zero(); n];
214 let mut ax = vec![T::gpu_zero(); n];
215 spmv(x, &mut ax)?;
216 for i in 0..n {
217 r[i] = sub_t(b[i], ax[i]);
218 }
219 let beta = vec_norm(&r, n);
220
221 if beta < abs_tol {
222 return Ok(0);
223 }
224
225 let mut v_basis: Vec<Vec<T>> = Vec::with_capacity(m + 1);
227
228 let inv_beta = from_f64(1.0 / beta);
230 let v0: Vec<T> = r.iter().map(|&ri| mul_t(ri, inv_beta)).collect();
231 v_basis.push(v0);
232
233 let mut h = vec![vec![0.0_f64; m + 1]; m];
235
236 let mut cs = vec![0.0_f64; m];
238 let mut sn = vec![0.0_f64; m];
239
240 let mut g = vec![0.0_f64; m + 1];
242 g[0] = beta;
243
244 let mut j = 0;
245 let max_j = m.min(max_iters as usize);
246
247 while j < max_j {
248 let mut w = vec![T::gpu_zero(); n];
250 spmv(&v_basis[j], &mut w)?;
251
252 for i in 0..=j {
254 h[j][i] = dot_product(&v_basis[i], &w, n);
255 let h_ij_t = from_f64(h[j][i]);
256 for k in 0..n {
257 w[k] = sub_t(w[k], mul_t(h_ij_t, v_basis[i][k]));
258 }
259 }
260
261 let w_norm = vec_norm(&w, n);
262 h[j][j + 1] = w_norm;
263
264 if w_norm > 1e-300 {
266 let inv_w = from_f64(1.0 / w_norm);
267 let vj1: Vec<T> = w.iter().map(|&wi| mul_t(wi, inv_w)).collect();
268 v_basis.push(vj1);
269 } else {
270 let vj1 = vec![T::gpu_zero(); n];
272 v_basis.push(vj1);
273 }
274
275 for i in 0..j {
277 let tmp = cs[i] * h[j][i] + sn[i] * h[j][i + 1];
278 h[j][i + 1] = -sn[i] * h[j][i] + cs[i] * h[j][i + 1];
279 h[j][i] = tmp;
280 }
281
282 let (c, s) = givens_rotation(h[j][j], h[j][j + 1]);
284 cs[j] = c;
285 sn[j] = s;
286
287 h[j][j] = c * h[j][j] + s * h[j][j + 1];
289 h[j][j + 1] = 0.0;
290
291 let tmp = cs[j] * g[j] + sn[j] * g[j + 1];
293 g[j + 1] = -sn[j] * g[j] + cs[j] * g[j + 1];
294 g[j] = tmp;
295
296 j += 1;
297
298 if g[j].abs() < abs_tol {
300 break;
301 }
302 }
303
304 let mut y = vec![0.0_f64; j];
306 for i in (0..j).rev() {
307 y[i] = g[i];
308 for k in (i + 1)..j {
309 y[i] -= h[k][i] * y[k];
310 }
311 if h[i][i].abs() > 1e-300 {
312 y[i] /= h[i][i];
313 }
314 }
315
316 for i in 0..j {
318 let yi_t = from_f64(y[i]);
319 for k in 0..n {
320 x[k] = add_t(x[k], mul_t(yi_t, v_basis[i][k]));
321 }
322 }
323
324 Ok(j as u32)
325}
326
/// Computes a Givens rotation `(c, s)` that annihilates `b` in the pair
/// `(a, b)`, i.e. `c * a + s * b = r` and `-s * a + c * b = 0`.
///
/// Magnitudes below `1e-300` are treated as zero. In the general case `r` is
/// computed with `f64::hypot`, which avoids the overflow (huge inputs) and
/// underflow (tiny inputs) of the naive `sqrt(a * a + b * b)`. The naive form
/// turned e.g. `(1e200, 1e200)` into `(0, 0)` and `(3e-200, 4e-200)` into
/// `(inf, inf)`.
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    if b.abs() < 1e-300 {
        // Nothing to annihilate: identity rotation.
        return (1.0, 0.0);
    }
    if a.abs() < 1e-300 {
        // Pure swap; the sign of s keeps the resulting r = |b| non-negative.
        return (0.0, if b >= 0.0 { 1.0 } else { -1.0 });
    }
    let r = a.hypot(b);
    (a / r, b / r)
}
341
342fn dot_product<T: GpuFloat>(a: &[T], b: &[T], n: usize) -> f64 {
343 let mut sum = 0.0_f64;
344 for i in 0..n {
345 sum += to_f64(a[i]) * to_f64(b[i]);
346 }
347 sum
348}
349
350fn vec_norm<T: GpuFloat>(v: &[T], n: usize) -> f64 {
351 dot_product(v, v, n).sqrt()
352}
353
354fn add_t<T: GpuFloat>(a: T, b: T) -> T {
355 from_f64(to_f64(a) + to_f64(b))
356}
357
358fn sub_t<T: GpuFloat>(a: T, b: T) -> T {
359 from_f64(to_f64(a) - to_f64(b))
360}
361
362fn mul_t<T: GpuFloat>(a: T, b: T) -> T {
363 from_f64(to_f64(a) * to_f64(b))
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Computes `||b - A x||_2` over the first `n` entries, using `spmv` for `A x`.
    fn residual_norm_cpu<T, F>(spmv: &F, b: &[T], x: &[T], n: usize) -> SolverResult<f64>
    where
        T: GpuFloat,
        F: Fn(&[T], &mut [T]) -> SolverResult<()>,
    {
        let mut ax = vec![T::gpu_zero(); n];
        spmv(x, &mut ax)?;
        let r: Vec<T> = (0..n).map(|i| sub_t(b[i], ax[i])).collect();
        Ok(vec_norm(&r, n))
    }

    /// Handle-free copy of the `gmres_solve` driver loop, so the restart logic
    /// can be exercised without a `SolverHandle`.
    ///
    /// Returns the number of Arnoldi iterations summed over all cycles; the
    /// true-residual checks between cycles are not counted as iterations. A
    /// `restart` of 0 is treated as 1.
    fn gmres_solve_cpu<T, F>(
        spmv: F,
        b: &[T],
        x: &mut [T],
        n: u32,
        config: &GmresConfig,
    ) -> SolverResult<u32>
    where
        T: GpuFloat,
        F: Fn(&[T], &mut [T]) -> SolverResult<()>,
    {
        let n_usize = n as usize;

        if b.len() < n_usize {
            return Err(SolverError::DimensionMismatch(format!(
                "gmres_solve_cpu: b length ({}) < n ({n})",
                b.len()
            )));
        }
        if x.len() < n_usize {
            return Err(SolverError::DimensionMismatch(format!(
                "gmres_solve_cpu: x length ({}) < n ({n})",
                x.len()
            )));
        }
        if n == 0 {
            return Ok(0);
        }

        let b_norm = vec_norm(b, n_usize);
        let abs_tol = if b_norm > 0.0 {
            config.tol * b_norm
        } else {
            for xi in x.iter_mut().take(n_usize) {
                *xi = T::gpu_zero();
            }
            return Ok(0);
        };

        // Capped at n, and at least 1 so restart == 0 cannot stall.
        let m = config.restart.min(n).max(1) as usize;
        let mut total_iters = 0_u32;
        let mut last_r_norm = None;

        while total_iters < config.max_iter {
            let iters = gmres_cycle(
                &spmv,
                b,
                x,
                n_usize,
                m,
                abs_tol,
                config.max_iter - total_iters,
            )?;
            total_iters += iters;

            let r_norm = residual_norm_cpu(&spmv, b, x, n_usize)?;
            if r_norm < abs_tol {
                return Ok(total_iters);
            }
            last_r_norm = Some(r_norm);

            if iters == 0 {
                break;
            }
        }

        // No cycle ran only when max_iter == 0; measure the initial guess.
        let r_norm = match last_r_norm {
            Some(r_norm) => r_norm,
            None => residual_norm_cpu(&spmv, b, x, n_usize)?,
        };

        if r_norm < abs_tol {
            Ok(total_iters)
        } else {
            Err(SolverError::ConvergenceFailure {
                iterations: total_iters,
                residual: r_norm,
            })
        }
    }

    /// `out = v` (identity matrix).
    fn identity_spmv(v: &[f64], out: &mut [f64]) -> SolverResult<()> {
        out.copy_from_slice(v);
        Ok(())
    }

    /// `out = diag(1, 4, 9) * v`.
    fn diag_149_spmv(v: &[f64], out: &mut [f64]) -> SolverResult<()> {
        out[0] = 1.0 * v[0];
        out[1] = 4.0 * v[1];
        out[2] = 9.0 * v[2];
        Ok(())
    }

    #[test]
    fn gmres_config_default() {
        let cfg = GmresConfig::default();
        assert_eq!(cfg.max_iter, 1000);
        assert!((cfg.tol - 1e-6).abs() < 1e-15);
        assert_eq!(cfg.restart, DEFAULT_RESTART);
    }

    #[test]
    fn gmres_config_custom() {
        let cfg = GmresConfig {
            max_iter: 500,
            tol: 1e-10,
            restart: 50,
        };
        assert_eq!(cfg.max_iter, 500);
        assert!((cfg.tol - 1e-10).abs() < 1e-20);
        assert_eq!(cfg.restart, 50);
    }

    #[test]
    fn givens_rotation_basic() {
        let (cs, sn) = givens_rotation(3.0, 4.0);
        let r = cs * 3.0 + sn * 4.0;
        assert!((r - 5.0).abs() < 1e-10);
    }

    #[test]
    fn givens_rotation_zero_b() {
        let (cs, sn) = givens_rotation(5.0, 0.0);
        assert!((cs - 1.0).abs() < 1e-15);
        assert!(sn.abs() < 1e-15);
    }

    #[test]
    fn dot_product_basic() {
        let a = [1.0_f64, 2.0, 3.0];
        let b = [4.0_f64, 5.0, 6.0];
        assert!((dot_product(&a, &b, 3) - 32.0).abs() < 1e-10);
    }

    #[test]
    fn vec_norm_unit() {
        let v = [1.0_f64, 0.0, 0.0];
        assert!((vec_norm(&v, 3) - 1.0).abs() < 1e-15);
    }

    #[test]
    fn gmres_converges_identity_3x3() {
        let b = vec![3.0_f64, 7.0, -2.0];
        let mut x = vec![0.0_f64; 3];
        let config = GmresConfig {
            max_iter: 50,
            tol: 1e-10,
            restart: 10,
        };

        let iters = gmres_solve_cpu(identity_spmv, &b, &mut x, 3, &config)
            .expect("GMRES should converge on identity system");

        // The Krylov space of the identity is one-dimensional, so exactly one
        // Arnoldi step is needed; the residual check must not be counted.
        assert_eq!(iters, 1);
        assert!((x[0] - 3.0).abs() < 1e-8, "x[0] = {} expected 3.0", x[0]);
        assert!((x[1] - 7.0).abs() < 1e-8, "x[1] = {} expected 7.0", x[1]);
        assert!(
            (x[2] - (-2.0)).abs() < 1e-8,
            "x[2] = {} expected -2.0",
            x[2]
        );
    }

    #[test]
    fn gmres_converges_tridiagonal_4x4() {
        let b = vec![1.0_f64, 1.0, 1.0, 1.0];
        let mut x = vec![0.0_f64; 4];
        let config = GmresConfig {
            max_iter: 200,
            tol: 1e-10,
            restart: 10,
        };

        let spmv = |v: &[f64], out: &mut [f64]| -> SolverResult<()> {
            out[0] = 2.0 * v[0] - v[1];
            out[1] = -v[0] + 2.0 * v[1] - v[2];
            out[2] = -v[1] + 2.0 * v[2] - v[3];
            out[3] = -v[2] + 2.0 * v[3];
            Ok(())
        };

        let iters = gmres_solve_cpu(spmv, &b, &mut x, 4, &config)
            .expect("GMRES should converge on tridiagonal system");

        assert!(iters <= 4, "took {iters} iterations on a 4x4 system");
        assert!((x[0] - 2.0).abs() < 1e-7, "x[0] = {} expected 2.0", x[0]);
        assert!((x[1] - 3.0).abs() < 1e-7, "x[1] = {} expected 3.0", x[1]);
        assert!((x[2] - 3.0).abs() < 1e-7, "x[2] = {} expected 3.0", x[2]);
        assert!((x[3] - 2.0).abs() < 1e-7, "x[3] = {} expected 2.0", x[3]);
    }

    #[test]
    fn gmres_zero_rhs_returns_zero() {
        let b = vec![0.0_f64; 3];
        let mut x = vec![1.0_f64; 3];
        let config = GmresConfig::default();

        let iters = gmres_solve_cpu(identity_spmv, &b, &mut x, 3, &config)
            .expect("zero RHS should succeed");
        assert_eq!(iters, 0);
        for &xi in &x {
            assert!(xi.abs() < 1e-15, "x should be zeroed for zero RHS");
        }
    }

    #[test]
    fn gmres_dimension_mismatch() {
        let b = vec![1.0_f64];
        let mut x = vec![0.0_f64; 3];
        let config = GmresConfig::default();
        let spmv = |_: &[f64], _: &mut [f64]| -> SolverResult<()> { Ok(()) };
        let result = gmres_solve_cpu(spmv, &b, &mut x, 3, &config);
        assert!(matches!(result, Err(SolverError::DimensionMismatch(_))));
    }

    #[test]
    fn gmres_converges_diagonal_spd() {
        let b = vec![1.0_f64, 4.0, 9.0];
        let mut x = vec![0.0_f64; 3];
        let config = GmresConfig {
            max_iter: 100,
            tol: 1e-10,
            restart: 10,
        };

        let iters = gmres_solve_cpu(diag_149_spmv, &b, &mut x, 3, &config)
            .expect("GMRES should converge on diagonal SPD");

        assert!(iters <= 3, "took {iters} iterations on a 3x3 system");
        assert!((x[0] - 1.0).abs() < 1e-8, "x[0] = {} expected 1.0", x[0]);
        assert!((x[1] - 1.0).abs() < 1e-8, "x[1] = {} expected 1.0", x[1]);
        assert!((x[2] - 1.0).abs() < 1e-8, "x[2] = {} expected 1.0", x[2]);
    }

    #[test]
    fn gmres_failure_reports_exact_iteration_count() {
        let b = vec![1.0_f64, 4.0, 9.0];
        let mut x = vec![0.0_f64; 3];
        let config = GmresConfig {
            max_iter: 1,
            tol: 1e-10,
            restart: 10,
        };

        // One Arnoldi step cannot solve a system with three distinct
        // eigenvalues, so the budget runs out after exactly one iteration.
        let result = gmres_solve_cpu(diag_149_spmv, &b, &mut x, 3, &config);
        assert!(
            matches!(
                result,
                Err(SolverError::ConvergenceFailure { iterations: 1, .. })
            ),
            "unexpected result: {result:?}"
        );
    }

    #[test]
    fn gmres_zero_max_iter_checks_initial_guess() {
        let b = vec![3.0_f64, 7.0, -2.0];
        let config = GmresConfig {
            max_iter: 0,
            tol: 1e-10,
            restart: 10,
        };

        // Initial guess already solves the system.
        let mut x = b.clone();
        let iters = gmres_solve_cpu(identity_spmv, &b, &mut x, 3, &config)
            .expect("exact initial guess should be accepted");
        assert_eq!(iters, 0);

        // Initial guess is wrong and no iterations are allowed.
        let mut x = vec![0.0_f64; 3];
        let result = gmres_solve_cpu(identity_spmv, &b, &mut x, 3, &config);
        assert!(
            matches!(
                result,
                Err(SolverError::ConvergenceFailure { iterations: 0, .. })
            ),
            "unexpected result: {result:?}"
        );
    }

    #[test]
    fn gmres_zero_restart_is_treated_as_one() {
        let b = vec![3.0_f64, 7.0, -2.0];
        let mut x = vec![0.0_f64; 3];
        let config = GmresConfig {
            max_iter: 10,
            tol: 1e-10,
            restart: 0,
        };

        let iters = gmres_solve_cpu(identity_spmv, &b, &mut x, 3, &config)
            .expect("restart = 0 should still make progress");
        assert_eq!(iters, 1);
        for (xi, bi) in x.iter().zip(&b) {
            assert!((xi - bi).abs() < 1e-8, "x = {x:?} expected {b:?}");
        }
    }
}