Skip to main content

numr/algorithm/iterative/impl_generic/
bicgstab.rs

1//! Generic BiCGSTAB implementation
2//!
3//! Bi-Conjugate Gradient Stabilized method for non-symmetric sparse systems.
4//! Alternative to GMRES with fixed memory footprint.
5
6use crate::algorithm::sparse_linalg::{IluDecomposition, IluOptions, SparseLinAlgAlgorithms};
7use crate::dtype::DType;
8use crate::error::{Error, Result};
9use crate::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};
10use crate::runtime::Runtime;
11use crate::sparse::{CsrData, SparseOps};
12use crate::tensor::Tensor;
13
14use super::super::helpers::{apply_ilu0_preconditioner, vector_dot, vector_norm};
15use super::super::types::{BiCgStabOptions, BiCgStabResult, PreconditionerType};
16
17/// Generic BiCGSTAB implementation
18///
19/// Implements right-preconditioned BiCGSTAB.
20/// Uses less memory than GMRES(m) but convergence can be less predictable.
21pub fn bicgstab_impl<R, C>(
22    client: &C,
23    a: &CsrData<R>,
24    b: &Tensor<R>,
25    x0: Option<&Tensor<R>>,
26    options: BiCgStabOptions,
27) -> Result<BiCgStabResult<R>>
28where
29    R: Runtime<DType = DType>,
30    R::Client: SparseOps<R>,
31    C: SparseLinAlgAlgorithms<R>
32        + SparseOps<R>
33        + BinaryOps<R>
34        + UnaryOps<R>
35        + ReduceOps<R>
36        + ScalarOps<R>,
37{
38    let n = super::super::traits::validate_iterative_inputs(a.shape, b, x0)?;
39    let device = b.device();
40    let dtype = b.dtype();
41
42    // Validate dtype is floating point
43    if !matches!(dtype, DType::F32 | DType::F64) {
44        return Err(Error::UnsupportedDType {
45            dtype,
46            op: "bicgstab",
47        });
48    }
49
50    // Initialize solution
51    let mut x = match x0 {
52        Some(x0) => x0.clone(),
53        None => Tensor::<R>::zeros(&[n], dtype, device),
54    };
55
56    // Compute preconditioner if requested
57    let precond: Option<IluDecomposition<R>> = match options.preconditioner {
58        PreconditionerType::None => None,
59        PreconditionerType::Ilu0 => {
60            let ilu = client.ilu0(a, IluOptions::default())?;
61            Some(ilu)
62        }
63        PreconditionerType::Amg => {
64            return Err(Error::Internal(
65                "AMG preconditioner not supported for BiCGSTAB - use amg_preconditioned_cg"
66                    .to_string(),
67            ));
68        }
69        PreconditionerType::Ic0 => {
70            return Err(Error::Internal(
71                "IC0 preconditioner not yet supported for BiCGSTAB - use ILU0".to_string(),
72            ));
73        }
74    };
75
76    // Compute ||b|| for relative tolerance
77    let b_norm = vector_norm(client, b)?;
78    if b_norm < options.atol {
79        return Ok(BiCgStabResult {
80            solution: x,
81            iterations: 0,
82            residual_norm: b_norm,
83            converged: true,
84        });
85    }
86
87    // r = b - A @ x
88    let ax = a.spmv(&x)?;
89    let mut r = client.sub(b, &ax)?;
90
91    // r_hat = r (shadow residual, kept constant)
92    let r_hat = r.clone();
93
94    // Initialize vectors
95    let mut rho = 1.0;
96    let mut alpha = 1.0;
97    let mut omega = 1.0;
98
99    let mut v = Tensor::<R>::zeros(&[n], dtype, device);
100    let mut p = Tensor::<R>::zeros(&[n], dtype, device);
101
102    for iter in 0..options.max_iter {
103        // rho_new = <r_hat, r>
104        let rho_new = vector_dot(client, &r_hat, &r)?;
105
106        // Check for breakdown
107        if rho_new.abs() < 1e-40 {
108            // BiCGSTAB breakdown - shadow residual orthogonal to residual
109            let res_norm = vector_norm(client, &r)?;
110            return Ok(BiCgStabResult {
111                solution: x,
112                iterations: iter,
113                residual_norm: res_norm,
114                converged: res_norm < options.atol || res_norm / b_norm < options.rtol,
115            });
116        }
117
118        // beta = (rho_new / rho) * (alpha / omega)
119        let beta = (rho_new / rho) * (alpha / omega);
120
121        // p = r + beta * (p - omega * v)
122        // p = r + beta * p - beta * omega * v
123        let p_scaled = client.mul_scalar(&p, beta)?;
124        let v_scaled = client.mul_scalar(&v, beta * omega)?;
125        let temp = client.sub(&p_scaled, &v_scaled)?;
126        p = client.add(&r, &temp)?;
127
128        // Apply preconditioner: p_hat = M^-1 @ p
129        let p_hat = apply_ilu0_preconditioner(client, &precond, &p)?;
130
131        // v = A @ p_hat
132        v = a.spmv(&p_hat)?;
133
134        // alpha = rho_new / <r_hat, v>
135        let r_hat_v = vector_dot(client, &r_hat, &v)?;
136        if r_hat_v.abs() < 1e-40 {
137            // Another breakdown case
138            let res_norm = vector_norm(client, &r)?;
139            return Ok(BiCgStabResult {
140                solution: x,
141                iterations: iter,
142                residual_norm: res_norm,
143                converged: res_norm < options.atol || res_norm / b_norm < options.rtol,
144            });
145        }
146        alpha = rho_new / r_hat_v;
147
148        // s = r - alpha * v
149        let v_scaled = client.mul_scalar(&v, alpha)?;
150        let s = client.sub(&r, &v_scaled)?;
151
152        // Check convergence on s
153        let s_norm = vector_norm(client, &s)?;
154        if s_norm < options.atol || s_norm / b_norm < options.rtol {
155            // x = x + alpha * p_hat
156            let p_hat_scaled = client.mul_scalar(&p_hat, alpha)?;
157            x = client.add(&x, &p_hat_scaled)?;
158
159            return Ok(BiCgStabResult {
160                solution: x,
161                iterations: iter + 1,
162                residual_norm: s_norm,
163                converged: true,
164            });
165        }
166
167        // Apply preconditioner: s_hat = M^-1 @ s
168        let s_hat = apply_ilu0_preconditioner(client, &precond, &s)?;
169
170        // t = A @ s_hat
171        let t = a.spmv(&s_hat)?;
172
173        // omega = <t, s> / <t, t>
174        let t_s = vector_dot(client, &t, &s)?;
175        let t_t = vector_dot(client, &t, &t)?;
176        if t_t.abs() < 1e-40 {
177            // Breakdown
178            let res_norm = vector_norm(client, &s)?;
179            return Ok(BiCgStabResult {
180                solution: x,
181                iterations: iter + 1,
182                residual_norm: res_norm,
183                converged: false,
184            });
185        }
186        omega = t_s / t_t;
187
188        // x = x + alpha * p_hat + omega * s_hat
189        let p_hat_scaled = client.mul_scalar(&p_hat, alpha)?;
190        let s_hat_scaled = client.mul_scalar(&s_hat, omega)?;
191        x = client.add(&x, &p_hat_scaled)?;
192        x = client.add(&x, &s_hat_scaled)?;
193
194        // r = s - omega * t
195        let t_scaled = client.mul_scalar(&t, omega)?;
196        r = client.sub(&s, &t_scaled)?;
197
198        // Update rho
199        rho = rho_new;
200
201        // Check convergence
202        let res_norm = vector_norm(client, &r)?;
203        if res_norm < options.atol || res_norm / b_norm < options.rtol {
204            return Ok(BiCgStabResult {
205                solution: x,
206                iterations: iter + 1,
207                residual_norm: res_norm,
208                converged: true,
209            });
210        }
211
212        // Check for stagnation
213        if omega.abs() < 1e-40 {
214            return Ok(BiCgStabResult {
215                solution: x,
216                iterations: iter + 1,
217                residual_norm: res_norm,
218                converged: false,
219            });
220        }
221    }
222
223    // Max iterations reached
224    let ax = a.spmv(&x)?;
225    let r_final = client.sub(b, &ax)?;
226    let final_residual = vector_norm(client, &r_final)?;
227
228    Ok(BiCgStabResult {
229        solution: x,
230        iterations: options.max_iter,
231        residual_norm: final_residual,
232        converged: false,
233    })
234}