1#![allow(dead_code)]
26
27use oxicuda_blas::GpuFloat;
28
29use crate::error::{SolverError, SolverResult};
30use crate::handle::SolverHandle;
31use crate::sparse::preconditioned::{IterativeSolverResult, Preconditioner};
32
33fn to_f64<T: GpuFloat>(val: T) -> f64 {
38 if T::SIZE == 4 {
39 f32::from_bits(val.to_bits_u64() as u32) as f64
40 } else {
41 f64::from_bits(val.to_bits_u64())
42 }
43}
44
45fn from_f64<T: GpuFloat>(val: f64) -> T {
46 if T::SIZE == 4 {
47 T::from_bits_u64(u64::from((val as f32).to_bits()))
48 } else {
49 T::from_bits_u64(val.to_bits())
50 }
51}
52
/// Configuration parameters for the FGMRES iterative solver.
#[derive(Debug, Clone)]
pub struct FgmresConfig {
    /// Krylov subspace dimension per restart cycle (the `m` in FGMRES(m)).
    pub restart: usize,
    /// Upper bound on the total number of inner iterations across all cycles.
    pub max_iter: usize,
    /// Relative convergence tolerance: the solve stops when the residual
    /// norm drops below `tol * ||b||`.
    pub tol: f64,
}
67
68impl Default for FgmresConfig {
69 fn default() -> Self {
70 Self {
71 restart: 30,
72 max_iter: 1000,
73 tol: 1e-6,
74 }
75 }
76}
77
78pub fn fgmres<T, P, F>(
110 _handle: &SolverHandle,
111 spmv: F,
112 precond: &P,
113 b: &[T],
114 x: &mut [T],
115 config: &FgmresConfig,
116) -> SolverResult<IterativeSolverResult<T>>
117where
118 T: GpuFloat,
119 P: Preconditioner<T>,
120 F: Fn(&[T], &mut [T]) -> SolverResult<()>,
121{
122 let n = b.len();
123 if x.len() < n {
124 return Err(SolverError::DimensionMismatch(format!(
125 "fgmres: x length ({}) < b length ({n})",
126 x.len()
127 )));
128 }
129 if n == 0 {
130 return Ok(IterativeSolverResult {
131 iterations: 0,
132 residual: T::gpu_zero(),
133 converged: true,
134 });
135 }
136
137 let b_norm = vec_norm(b, n);
138 let abs_tol = if b_norm > 0.0 {
139 config.tol * b_norm
140 } else {
141 for xi in x.iter_mut().take(n) {
142 *xi = T::gpu_zero();
143 }
144 return Ok(IterativeSolverResult {
145 iterations: 0,
146 residual: T::gpu_zero(),
147 converged: true,
148 });
149 };
150
151 let m = config.restart.min(n);
152 let mut total_iters = 0_u32;
153
154 while (total_iters as usize) < config.max_iter {
156 let remaining = config.max_iter.saturating_sub(total_iters as usize);
157 let (iters, converged, res_norm) =
158 fgmres_cycle(&spmv, precond, b, x, n, m, abs_tol, remaining)?;
159 total_iters += iters;
160
161 if converged {
162 return Ok(IterativeSolverResult {
163 iterations: total_iters,
164 residual: from_f64(res_norm),
165 converged: true,
166 });
167 }
168
169 if iters == 0 {
170 break; }
172 }
173
174 let mut r = vec![T::gpu_zero(); n];
176 let mut ax = vec![T::gpu_zero(); n];
177 spmv(x, &mut ax)?;
178 for i in 0..n {
179 r[i] = sub_t(b[i], ax[i]);
180 }
181 let r_norm = vec_norm(&r, n);
182
183 Ok(IterativeSolverResult {
184 iterations: total_iters,
185 residual: from_f64(r_norm),
186 converged: r_norm < abs_tol,
187 })
188}
189
/// Runs one restart cycle of FGMRES(m) and updates `x` in place.
///
/// Builds an Arnoldi basis `v_basis` of the Krylov space together with the
/// preconditioned directions `z_basis` (flexible variant: one `z_j` per
/// iteration), reducing the Hessenberg matrix to triangular form on the fly
/// with Givens rotations. Returns `(iterations_performed, converged,
/// final_residual_norm)`.
///
/// Storage convention: `h[j][i]` is entry (row `i`, column `j`) of the
/// Hessenberg matrix — column-major, one inner Vec per Arnoldi column.
#[allow(clippy::too_many_arguments)]
fn fgmres_cycle<T, P, F>(
    spmv: &F,
    precond: &P,
    b: &[T],
    x: &mut [T],
    n: usize,
    m: usize,
    abs_tol: f64,
    max_iters: usize,
) -> SolverResult<(u32, bool, f64)>
where
    T: GpuFloat,
    P: Preconditioner<T>,
    F: Fn(&[T], &mut [T]) -> SolverResult<()>,
{
    // Initial residual r = b - A x and its norm beta.
    let mut r = vec![T::gpu_zero(); n];
    let mut ax = vec![T::gpu_zero(); n];
    spmv(x, &mut ax)?;
    for i in 0..n {
        r[i] = sub_t(b[i], ax[i]);
    }
    let beta = vec_norm(&r, n);

    // Already converged on entry: nothing to do this cycle.
    if beta < abs_tol {
        return Ok((0, true, beta));
    }

    let mut v_basis: Vec<Vec<T>> = Vec::with_capacity(m + 1);
    let mut z_basis: Vec<Vec<T>> = Vec::with_capacity(m);

    // v0 = r / beta seeds the orthonormal Arnoldi basis.
    let inv_beta = from_f64(1.0 / beta);
    let v0: Vec<T> = r.iter().map(|&ri| mul_t(ri, inv_beta)).collect();
    v_basis.push(v0);

    // Hessenberg columns (see storage convention above); each column j has
    // up to j + 2 nonzero rows.
    let mut h = vec![vec![0.0_f64; m + 1]; m];

    // Accumulated Givens rotation cosines/sines, applied to later columns.
    let mut cs = vec![0.0_f64; m];
    let mut sn = vec![0.0_f64; m];

    // Right-hand side of the least-squares problem; starts as beta * e1.
    let mut g = vec![0.0_f64; m + 1];
    g[0] = beta;

    let mut j = 0;
    let max_j = m.min(max_iters);
    let mut converged = false;

    while j < max_j {
        // Flexible preconditioning: z_j = M^{-1} v_j, stored per iteration
        // so the solution can be reconstructed from the z's.
        let mut z_j = vec![T::gpu_zero(); n];
        precond.apply(&v_basis[j], &mut z_j)?;
        z_basis.push(z_j);

        // w = A z_j, the new candidate direction.
        let mut w = vec![T::gpu_zero(); n];
        spmv(&z_basis[j], &mut w)?;

        // Modified Gram-Schmidt: orthogonalize w against v_0..=v_j,
        // recording the projection coefficients in column j of H.
        for i in 0..=j {
            h[j][i] = dot_product(&v_basis[i], &w, n);
            let h_ij_t = from_f64(h[j][i]);
            for k in 0..n {
                w[k] = sub_t(w[k], mul_t(h_ij_t, v_basis[i][k]));
            }
        }

        let w_norm = vec_norm(&w, n);
        h[j][j + 1] = w_norm;

        if w_norm > 1e-300 {
            // Normalize to obtain the next basis vector v_{j+1}.
            let inv_w = from_f64(1.0 / w_norm);
            let vj1: Vec<T> = w.iter().map(|&wi| mul_t(wi, inv_w)).collect();
            v_basis.push(vj1);
        } else {
            // Lucky breakdown: the Krylov space is exhausted; push a zero
            // vector so indexing stays consistent.
            let vj1 = vec![T::gpu_zero(); n];
            v_basis.push(vj1);
        }

        // Apply all previous Givens rotations to the new column j of H.
        for i in 0..j {
            let tmp = cs[i] * h[j][i] + sn[i] * h[j][i + 1];
            h[j][i + 1] = -sn[i] * h[j][i] + cs[i] * h[j][i + 1];
            h[j][i] = tmp;
        }

        // New rotation that annihilates the subdiagonal entry h[j][j + 1].
        let (c, s) = givens_rotation(h[j][j], h[j][j + 1]);
        cs[j] = c;
        sn[j] = s;

        h[j][j] = c * h[j][j] + s * h[j][j + 1];
        h[j][j + 1] = 0.0;

        // Apply the same rotation to g; afterwards |g[j + 1]| is the
        // residual norm of the least-squares problem.
        let tmp = cs[j] * g[j] + sn[j] * g[j + 1];
        g[j + 1] = -sn[j] * g[j] + cs[j] * g[j + 1];
        g[j] = tmp;

        j += 1;

        // Cheap convergence check on the implicit residual norm.
        if g[j].abs() < abs_tol {
            converged = true;
            break;
        }
    }

    // Back-substitution: solve the upper-triangular system H y = g
    // (remember h[col][row] layout, hence h[k][i] above the diagonal).
    let mut y = vec![0.0_f64; j];
    for i in (0..j).rev() {
        y[i] = g[i];
        for k in (i + 1)..j {
            y[i] -= h[k][i] * y[k];
        }
        // Skip division on a (near-)singular diagonal to avoid inf/NaN.
        if h[i][i].abs() > 1e-300 {
            y[i] /= h[i][i];
        }
    }

    // Flexible update: x += sum_i y_i * z_i (uses the preconditioned
    // directions, not the orthonormal v's).
    for i in 0..j {
        let yi_t = from_f64(y[i]);
        for k in 0..n {
            x[k] = add_t(x[k], mul_t(yi_t, z_basis[i][k]));
        }
    }

    // Recompute the true residual explicitly; the implicit estimate g[j]
    // can drift from the actual residual in finite precision.
    let mut r_final = vec![T::gpu_zero(); n];
    let mut ax_final = vec![T::gpu_zero(); n];
    spmv(x, &mut ax_final)?;
    for i in 0..n {
        r_final[i] = sub_t(b[i], ax_final[i]);
    }
    let r_norm = vec_norm(&r_final, n);

    Ok((j as u32, converged || r_norm < abs_tol, r_norm))
}
346
/// Computes a Givens rotation `(c, s)` such that the rotation applied to
/// `(a, b)` zeroes the second component: `[c s; -s c] * [a; b] = [r; 0]`.
///
/// Degenerate inputs (either component below 1e-300 in magnitude) return
/// the corresponding trivial rotation directly.
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    if b.abs() < 1e-300 {
        return (1.0, 0.0);
    }
    if a.abs() < 1e-300 {
        return (0.0, if b >= 0.0 { 1.0 } else { -1.0 });
    }
    // `hypot` computes sqrt(a^2 + b^2) without the intermediate overflow /
    // underflow that `(a * a + b * b).sqrt()` suffers for |a|, |b| beyond
    // roughly 1e154 (or below 1e-154).
    let r = a.hypot(b);
    (a / r, b / r)
}
361
362fn dot_product<T: GpuFloat>(a: &[T], b: &[T], n: usize) -> f64 {
363 let mut sum = 0.0_f64;
364 for i in 0..n {
365 sum += to_f64(a[i]) * to_f64(b[i]);
366 }
367 sum
368}
369
370fn vec_norm<T: GpuFloat>(v: &[T], n: usize) -> f64 {
371 dot_product(v, v, n).sqrt()
372}
373
374fn add_t<T: GpuFloat>(a: T, b: T) -> T {
375 from_f64(to_f64(a) + to_f64(b))
376}
377
378fn sub_t<T: GpuFloat>(a: T, b: T) -> T {
379 from_f64(to_f64(a) - to_f64(b))
380}
381
382fn mul_t<T: GpuFloat>(a: T, b: T) -> T {
383 from_f64(to_f64(a) * to_f64(b))
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::preconditioned::IdentityPreconditioner;

    // Default config carries the documented FGMRES(30) / 1000 / 1e-6 values.
    #[test]
    fn fgmres_config_default() {
        let cfg = FgmresConfig::default();
        assert_eq!(cfg.restart, 30);
        assert_eq!(cfg.max_iter, 1000);
        assert!((cfg.tol - 1e-6).abs() < 1e-15);
    }

    // A hand-built config preserves every supplied field, including `tol`
    // (previously unchecked).
    #[test]
    fn fgmres_config_custom() {
        let cfg = FgmresConfig {
            restart: 50,
            max_iter: 2000,
            tol: 1e-10,
        };
        assert_eq!(cfg.restart, 50);
        assert_eq!(cfg.max_iter, 2000);
        assert!((cfg.tol - 1e-10).abs() < 1e-20);
    }

    // Rotation of (3, 4) should map onto (5, 0).
    #[test]
    fn givens_rotation_basic() {
        let (cs, sn) = givens_rotation(3.0, 4.0);
        let r = cs * 3.0 + sn * 4.0;
        assert!((r - 5.0).abs() < 1e-10);
    }

    // b == 0 is the identity rotation.
    #[test]
    fn givens_rotation_zero_b() {
        let (cs, sn) = givens_rotation(5.0, 0.0);
        assert!((cs - 1.0).abs() < 1e-15);
        assert!(sn.abs() < 1e-15);
    }

    // a == 0 is a pure quarter-turn (sign follows b).
    #[test]
    fn givens_rotation_zero_a() {
        let (cs, sn) = givens_rotation(0.0, 3.0);
        assert!(cs.abs() < 1e-15);
        assert!((sn - 1.0).abs() < 1e-15);
    }

    #[test]
    fn dot_product_basic() {
        let a = [1.0_f64, 2.0, 3.0];
        let b = [4.0_f64, 5.0, 6.0];
        // 1*4 + 2*5 + 3*6 = 32
        assert!((dot_product(&a, &b, 3) - 32.0).abs() < 1e-10);
    }

    #[test]
    fn vec_norm_unit() {
        let v = [1.0_f64, 0.0, 0.0];
        assert!((vec_norm(&v, 3) - 1.0).abs() < 1e-15);
    }

    // 3-4-5 triangle again, via the norm helper.
    #[test]
    fn vec_norm_345() {
        let v = [3.0_f64, 4.0];
        assert!((vec_norm(&v, 2) - 5.0).abs() < 1e-10);
    }

    #[test]
    fn add_sub_mul_helpers() {
        let a = 3.0_f64;
        let b = 4.0_f64;
        assert!((to_f64(add_t(a, b)) - 7.0).abs() < 1e-15);
        assert!((to_f64(sub_t(a, b)) - (-1.0)).abs() < 1e-15);
        assert!((to_f64(mul_t(a, b)) - 12.0).abs() < 1e-15);
    }

    // The identity preconditioner must copy its input through unchanged.
    #[test]
    fn identity_preconditioner_with_fgmres() {
        let _precond = IdentityPreconditioner;
        let r = [1.0_f64, 2.0, 3.0];
        let mut z = [0.0_f64; 3];
        let result = _precond.apply(&r, &mut z);
        assert!(result.is_ok());
        assert!((z[0] - 1.0).abs() < 1e-15);
    }

    // f64 values survive the bit-level conversion round trip exactly.
    #[test]
    fn f64_conversion_roundtrip() {
        let val = std::f64::consts::PI;
        let as_f64 = to_f64(val);
        let back: f64 = from_f64(as_f64);
        assert!((back - val).abs() < 1e-15);
    }

    // f32 values round-trip through f64 within single-precision accuracy.
    #[test]
    fn f32_conversion_roundtrip() {
        let val = std::f32::consts::PI;
        let as_f64 = to_f64(val);
        let back: f32 = from_f64(as_f64);
        assert!((back - val).abs() < 1e-5);
    }
}