1#![allow(dead_code)]
25
26use oxicuda_blas::GpuFloat;
27
28use crate::error::{SolverError, SolverResult};
29use crate::handle::SolverHandle;
30
31fn to_f64<T: GpuFloat>(val: T) -> f64 {
36 if T::SIZE == 4 {
37 f32::from_bits(val.to_bits_u64() as u32) as f64
38 } else {
39 f64::from_bits(val.to_bits_u64())
40 }
41}
42
43fn from_f64<T: GpuFloat>(val: f64) -> T {
44 if T::SIZE == 4 {
45 T::from_bits_u64(u64::from((val as f32).to_bits()))
46 } else {
47 T::from_bits_u64(val.to_bits())
48 }
49}
50
/// Tunable parameters for the BiCGSTAB solver.
#[derive(Debug, Clone)]
pub struct BiCgStabConfig {
    /// Maximum number of BiCGSTAB iterations to run before giving up.
    pub max_iter: u32,
    /// Relative tolerance; the solver multiplies it by the norm of the
    /// right-hand side to obtain the absolute residual threshold.
    pub tol: f64,
}

impl Default for BiCgStabConfig {
    /// Returns the default settings: 1000 iterations and a tolerance of `1e-6`.
    fn default() -> Self {
        let max_iter = 1000;
        let tol = 1e-6;
        Self { max_iter, tol }
    }
}
72
73pub fn bicgstab_solve<T, F>(
100 _handle: &SolverHandle,
101 spmv: F,
102 b: &[T],
103 x: &mut [T],
104 n: u32,
105 config: &BiCgStabConfig,
106) -> SolverResult<u32>
107where
108 T: GpuFloat,
109 F: Fn(&[T], &mut [T]) -> SolverResult<()>,
110{
111 let n_usize = n as usize;
112
113 if b.len() < n_usize {
115 return Err(SolverError::DimensionMismatch(format!(
116 "bicgstab_solve: b length ({}) < n ({n})",
117 b.len()
118 )));
119 }
120 if x.len() < n_usize {
121 return Err(SolverError::DimensionMismatch(format!(
122 "bicgstab_solve: x length ({}) < n ({n})",
123 x.len()
124 )));
125 }
126 if n == 0 {
127 return Ok(0);
128 }
129
130 let b_norm = vec_norm(b, n_usize);
132 let abs_tol = if b_norm > 0.0 {
133 config.tol * b_norm
134 } else {
135 for xi in x.iter_mut().take(n_usize) {
136 *xi = T::gpu_zero();
137 }
138 return Ok(0);
139 };
140
141 let mut r = vec![T::gpu_zero(); n_usize];
143 let mut tmp = vec![T::gpu_zero(); n_usize];
144 spmv(x, &mut tmp)?;
145 for i in 0..n_usize {
146 r[i] = sub_t(b[i], tmp[i]);
147 }
148
149 let r0_hat = r.clone();
151
152 let mut rho = 1.0_f64;
154 let mut alpha = 1.0_f64;
155 let mut omega = 1.0_f64;
156
157 let mut v = vec![T::gpu_zero(); n_usize];
159 let mut p = vec![T::gpu_zero(); n_usize];
160 let mut s = vec![T::gpu_zero(); n_usize];
161 let mut t = vec![T::gpu_zero(); n_usize];
162
163 for iter in 0..config.max_iter {
164 let rho_new = dot_product(&r0_hat, &r, n_usize);
166
167 if rho_new.abs() < 1e-300 {
168 return Err(SolverError::InternalError(
169 "bicgstab_solve: rho breakdown (r0_hat^T * r ~ 0)".into(),
170 ));
171 }
172
173 let beta = if rho.abs() > 1e-300 && omega.abs() > 1e-300 {
175 (rho_new / rho) * (alpha / omega)
176 } else {
177 0.0
178 };
179 let beta_t = from_f64(beta);
180 let omega_t = from_f64(omega);
181
182 for i in 0..n_usize {
184 let pv = sub_t(p[i], mul_t(omega_t, v[i]));
185 p[i] = add_t(r[i], mul_t(beta_t, pv));
186 }
187
188 spmv(&p, &mut v)?;
190
191 let r0v = dot_product(&r0_hat, &v, n_usize);
193 if r0v.abs() < 1e-300 {
194 return Err(SolverError::InternalError(
195 "bicgstab_solve: alpha breakdown (r0_hat^T * v ~ 0)".into(),
196 ));
197 }
198 alpha = rho_new / r0v;
199 let alpha_t = from_f64(alpha);
200
201 for i in 0..n_usize {
203 s[i] = sub_t(r[i], mul_t(alpha_t, v[i]));
204 }
205
206 let s_norm = vec_norm(&s, n_usize);
208 if s_norm < abs_tol {
209 for i in 0..n_usize {
211 x[i] = add_t(x[i], mul_t(alpha_t, p[i]));
212 }
213 return Ok(iter + 1);
214 }
215
216 spmv(&s, &mut t)?;
218
219 let tt = dot_product(&t, &t, n_usize);
221 omega = if tt.abs() > 1e-300 {
222 dot_product(&t, &s, n_usize) / tt
223 } else {
224 0.0
225 };
226 let omega_new_t = from_f64(omega);
227
228 for i in 0..n_usize {
230 x[i] = add_t(x[i], add_t(mul_t(alpha_t, p[i]), mul_t(omega_new_t, s[i])));
231 }
232
233 for i in 0..n_usize {
235 r[i] = sub_t(s[i], mul_t(omega_new_t, t[i]));
236 }
237
238 let r_norm = vec_norm(&r, n_usize);
240 if r_norm < abs_tol {
241 return Ok(iter + 1);
242 }
243
244 if omega.abs() < 1e-300 {
245 return Err(SolverError::InternalError(
246 "bicgstab_solve: omega breakdown".into(),
247 ));
248 }
249
250 rho = rho_new;
251 }
252
253 Err(SolverError::ConvergenceFailure {
254 iterations: config.max_iter,
255 residual: vec_norm(&r, n_usize),
256 })
257}
258
259fn dot_product<T: GpuFloat>(a: &[T], b: &[T], n: usize) -> f64 {
264 let mut sum = 0.0_f64;
265 for i in 0..n {
266 sum += to_f64(a[i]) * to_f64(b[i]);
267 }
268 sum
269}
270
271fn vec_norm<T: GpuFloat>(v: &[T], n: usize) -> f64 {
272 dot_product(v, v, n).sqrt()
273}
274
275fn add_t<T: GpuFloat>(a: T, b: T) -> T {
276 from_f64(to_f64(a) + to_f64(b))
277}
278
279fn sub_t<T: GpuFloat>(a: T, b: T) -> T {
280 from_f64(to_f64(a) - to_f64(b))
281}
282
283fn mul_t<T: GpuFloat>(a: T, b: T) -> T {
284 from_f64(to_f64(a) * to_f64(b))
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Handle-free BiCGSTAB driver so the iteration can be exercised in unit
    /// tests without constructing a `SolverHandle`.
    ///
    /// `spmv(input, output)` is expected to write `A * input` into `output`;
    /// both slices always have exactly `n` elements. Returns the number of
    /// iterations performed, or `Ok(0)` when `n == 0`, `b` has zero norm, or
    /// the initial guess already satisfies `||b - A x|| <= tol * ||b||`.
    fn bicgstab_solve_cpu<T, F>(
        spmv: F,
        b: &[T],
        x: &mut [T],
        n: u32,
        config: &BiCgStabConfig,
    ) -> SolverResult<u32>
    where
        T: GpuFloat,
        F: Fn(&[T], &mut [T]) -> SolverResult<()>,
    {
        // Magnitude below which a scalar denominator is treated as a breakdown.
        const BREAKDOWN_EPS: f64 = 1e-300;

        let n_usize = n as usize;

        if b.len() < n_usize {
            return Err(SolverError::DimensionMismatch(format!(
                "bicgstab_solve_cpu: b length ({}) < n ({n})",
                b.len()
            )));
        }
        if x.len() < n_usize {
            return Err(SolverError::DimensionMismatch(format!(
                "bicgstab_solve_cpu: x length ({}) < n ({n})",
                x.len()
            )));
        }
        if n == 0 {
            return Ok(0);
        }

        // Work on exactly `n` elements so every `spmv` call sees operands of
        // the same length.
        let b = &b[..n_usize];
        let x = &mut x[..n_usize];

        let b_norm = vec_norm(b, n_usize);
        if !(b_norm > 0.0) {
            // Zero right-hand side: the solution is the zero vector.
            for xi in x.iter_mut() {
                *xi = T::gpu_zero();
            }
            return Ok(0);
        }
        let abs_tol = config.tol * b_norm;

        // Initial residual r = b - A x; `r` temporarily holds A x.
        let mut r = vec![T::gpu_zero(); n_usize];
        spmv(&*x, &mut r)?;
        for (ri, &bi) in r.iter_mut().zip(b) {
            *ri = sub_t(bi, *ri);
        }

        // An initial guess that already meets the tolerance must not enter
        // the loop, where r ~ 0 would be misreported as a rho breakdown.
        if vec_norm(&r, n_usize) <= abs_tol {
            return Ok(0);
        }

        let r0_hat = r.clone();
        let mut rho = 1.0_f64;
        let mut alpha = 1.0_f64;
        let mut omega = 1.0_f64;
        let mut v = vec![T::gpu_zero(); n_usize];
        let mut p = vec![T::gpu_zero(); n_usize];
        let mut s = vec![T::gpu_zero(); n_usize];
        let mut t = vec![T::gpu_zero(); n_usize];

        for iter in 0..config.max_iter {
            let rho_new = dot_product(&r0_hat, &r, n_usize);
            if rho_new.abs() < BREAKDOWN_EPS {
                return Err(SolverError::InternalError(
                    "bicgstab_solve_cpu: rho breakdown".into(),
                ));
            }

            let beta = if rho.abs() > BREAKDOWN_EPS && omega.abs() > BREAKDOWN_EPS {
                (rho_new / rho) * (alpha / omega)
            } else {
                0.0
            };
            let beta_t: T = from_f64(beta);
            let omega_t: T = from_f64(omega);

            // p = r + beta * (p - omega * v)
            for ((pi, &vi), &ri) in p.iter_mut().zip(&v).zip(&r) {
                let pv = sub_t(*pi, mul_t(omega_t, vi));
                *pi = add_t(ri, mul_t(beta_t, pv));
            }

            // v = A p
            spmv(&p, &mut v)?;

            let r0v = dot_product(&r0_hat, &v, n_usize);
            if r0v.abs() < BREAKDOWN_EPS {
                return Err(SolverError::InternalError(
                    "bicgstab_solve_cpu: alpha breakdown".into(),
                ));
            }
            alpha = rho_new / r0v;
            let alpha_t: T = from_f64(alpha);

            // s = r - alpha * v
            for ((si, &ri), &vi) in s.iter_mut().zip(&r).zip(&v) {
                *si = sub_t(ri, mul_t(alpha_t, vi));
            }

            // Early exit: the half step alone already reaches the tolerance.
            if vec_norm(&s, n_usize) <= abs_tol {
                for (xi, &pi) in x.iter_mut().zip(&p) {
                    *xi = add_t(*xi, mul_t(alpha_t, pi));
                }
                return Ok(iter + 1);
            }

            // t = A s
            spmv(&s, &mut t)?;

            let tt = dot_product(&t, &t, n_usize);
            omega = if tt.abs() > BREAKDOWN_EPS {
                dot_product(&t, &s, n_usize) / tt
            } else {
                0.0
            };
            let omega_new_t: T = from_f64(omega);

            // x = x + alpha * p + omega * s
            for ((xi, &pi), &si) in x.iter_mut().zip(&p).zip(&s) {
                let step = add_t(mul_t(alpha_t, pi), mul_t(omega_new_t, si));
                *xi = add_t(*xi, step);
            }

            // r = s - omega * t
            for ((ri, &si), &ti) in r.iter_mut().zip(&s).zip(&t) {
                *ri = sub_t(si, mul_t(omega_new_t, ti));
            }

            if vec_norm(&r, n_usize) <= abs_tol {
                return Ok(iter + 1);
            }

            // Checked only after the convergence test so that a converged
            // iterate is still reported as success.
            if omega.abs() < BREAKDOWN_EPS {
                return Err(SolverError::InternalError(
                    "bicgstab_solve_cpu: omega breakdown".into(),
                ));
            }

            rho = rho_new;
        }

        Err(SolverError::ConvergenceFailure {
            iterations: config.max_iter,
            residual: vec_norm(&r, n_usize),
        })
    }

    #[test]
    fn bicgstab_config_default() {
        let cfg = BiCgStabConfig::default();
        assert_eq!(cfg.max_iter, 1000);
        assert!((cfg.tol - 1e-6).abs() < 1e-15);
    }

    #[test]
    fn bicgstab_config_custom() {
        let cfg = BiCgStabConfig {
            max_iter: 2000,
            tol: 1e-8,
        };
        assert_eq!(cfg.max_iter, 2000);
        assert!((cfg.tol - 1e-8).abs() < 1e-20);
    }

    #[test]
    fn dot_product_basic() {
        let a = [1.0_f64, 2.0, 3.0];
        let b = [4.0_f64, 5.0, 6.0];
        assert!((dot_product(&a, &b, 3) - 32.0).abs() < 1e-10);
    }

    #[test]
    fn vec_norm_basic() {
        let v = [3.0_f64, 4.0];
        assert!((vec_norm(&v, 2) - 5.0).abs() < 1e-10);
    }

    /// Symmetric positive-definite tridiagonal system with a known solution.
    #[test]
    fn bicgstab_converges_spd_3x3() {
        let b = vec![6.0_f64, 0.0, 6.0];
        let mut x = vec![0.0_f64; 3];
        let config = BiCgStabConfig {
            max_iter: 200,
            tol: 1e-10,
        };

        // A = [[4, -1, 0], [-1, 4, -1], [0, -1, 4]]
        let spmv = |v: &[f64], out: &mut [f64]| -> SolverResult<()> {
            out[0] = 4.0 * v[0] - v[1];
            out[1] = -v[0] + 4.0 * v[1] - v[2];
            out[2] = -v[1] + 4.0 * v[2];
            Ok(())
        };

        let _iters = bicgstab_solve_cpu(spmv, &b, &mut x, 3, &config)
            .expect("BiCGSTAB should converge on SPD system");

        // Exact solution is [12/7, 6/7, 12/7].
        let x0_exact = 12.0_f64 / 7.0;
        let x1_exact = 6.0_f64 / 7.0;
        assert!(
            (x[0] - x0_exact).abs() < 1e-7,
            "x[0] = {} expected {x0_exact}",
            x[0]
        );
        assert!(
            (x[1] - x1_exact).abs() < 1e-7,
            "x[1] = {} expected {x1_exact}",
            x[1]
        );
        assert!(
            (x[2] - x0_exact).abs() < 1e-7,
            "x[2] = {} expected {x0_exact}",
            x[2]
        );
    }

    /// With A = I the solution equals the right-hand side.
    #[test]
    fn bicgstab_converges_identity() {
        let b = vec![5.0_f64, -3.0, 2.0];
        let mut x = vec![0.0_f64; 3];
        let config = BiCgStabConfig {
            max_iter: 50,
            tol: 1e-12,
        };

        let spmv = |v: &[f64], out: &mut [f64]| -> SolverResult<()> {
            out.copy_from_slice(v);
            Ok(())
        };

        let _iters = bicgstab_solve_cpu(spmv, &b, &mut x, 3, &config)
            .expect("BiCGSTAB should converge on identity");

        assert!((x[0] - 5.0).abs() < 1e-9);
        assert!((x[1] - (-3.0)).abs() < 1e-9);
        assert!((x[2] - 2.0).abs() < 1e-9);
    }

    /// An initial guess that is already the exact solution must succeed with
    /// zero iterations instead of being reported as a rho breakdown.
    #[test]
    fn bicgstab_exact_initial_guess_returns_zero_iterations() {
        let b = vec![5.0_f64, -3.0, 2.0];
        let mut x = b.clone();
        let config = BiCgStabConfig::default();

        let spmv = |v: &[f64], out: &mut [f64]| -> SolverResult<()> {
            out.copy_from_slice(v);
            Ok(())
        };

        let iters = bicgstab_solve_cpu(spmv, &b, &mut x, 3, &config)
            .expect("exact initial guess should succeed");
        assert_eq!(iters, 0);
        assert_eq!(x, b);
    }

    /// A zero right-hand side short-circuits to the zero solution.
    #[test]
    fn bicgstab_zero_rhs_returns_zero() {
        let b = vec![0.0_f64; 3];
        let mut x = vec![1.0_f64; 3];
        let config = BiCgStabConfig::default();

        let spmv = |v: &[f64], out: &mut [f64]| -> SolverResult<()> {
            out.copy_from_slice(v);
            Ok(())
        };

        let iters =
            bicgstab_solve_cpu(spmv, &b, &mut x, 3, &config).expect("zero RHS should succeed");
        assert_eq!(iters, 0);
        for &xi in &x {
            assert!(xi.abs() < 1e-15);
        }
    }

    /// A right-hand side shorter than `n` is rejected before any work is done.
    #[test]
    fn bicgstab_dimension_mismatch() {
        let b = vec![1.0_f64];
        let mut x = vec![0.0_f64; 3];
        let config = BiCgStabConfig::default();
        let spmv = |_: &[f64], _: &mut [f64]| -> SolverResult<()> { Ok(()) };
        let result = bicgstab_solve_cpu(spmv, &b, &mut x, 3, &config);
        assert!(matches!(result, Err(SolverError::DimensionMismatch(_))));
    }

    /// Diagonal system diag(1, 3, 7) with solution [2, 3, 2].
    #[test]
    fn bicgstab_converges_diagonal() {
        let b = vec![2.0_f64, 9.0, 14.0];
        let mut x = vec![0.0_f64; 3];
        let config = BiCgStabConfig {
            max_iter: 200,
            tol: 1e-10,
        };

        let spmv = |v: &[f64], out: &mut [f64]| -> SolverResult<()> {
            out[0] = 1.0 * v[0];
            out[1] = 3.0 * v[1];
            out[2] = 7.0 * v[2];
            Ok(())
        };

        let _iters = bicgstab_solve_cpu(spmv, &b, &mut x, 3, &config)
            .expect("BiCGSTAB should converge on diagonal system");

        assert!((x[0] - 2.0).abs() < 1e-8, "x[0] = {} expected 2.0", x[0]);
        assert!((x[1] - 3.0).abs() < 1e-8, "x[1] = {} expected 3.0", x[1]);
        assert!((x[2] - 2.0).abs() < 1e-8, "x[2] = {} expected 2.0", x[2]);
    }
}