numr/algorithm/iterative/impl_generic/
bicgstab.rs1use crate::algorithm::sparse_linalg::{IluDecomposition, IluOptions, SparseLinAlgAlgorithms};
7use crate::dtype::DType;
8use crate::error::{Error, Result};
9use crate::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};
10use crate::runtime::Runtime;
11use crate::sparse::{CsrData, SparseOps};
12use crate::tensor::Tensor;
13
14use super::super::helpers::{apply_ilu0_preconditioner, vector_dot, vector_norm};
15use super::super::types::{BiCgStabOptions, BiCgStabResult, PreconditionerType};
16
17pub fn bicgstab_impl<R, C>(
22 client: &C,
23 a: &CsrData<R>,
24 b: &Tensor<R>,
25 x0: Option<&Tensor<R>>,
26 options: BiCgStabOptions,
27) -> Result<BiCgStabResult<R>>
28where
29 R: Runtime<DType = DType>,
30 R::Client: SparseOps<R>,
31 C: SparseLinAlgAlgorithms<R>
32 + SparseOps<R>
33 + BinaryOps<R>
34 + UnaryOps<R>
35 + ReduceOps<R>
36 + ScalarOps<R>,
37{
38 let n = super::super::traits::validate_iterative_inputs(a.shape, b, x0)?;
39 let device = b.device();
40 let dtype = b.dtype();
41
42 if !matches!(dtype, DType::F32 | DType::F64) {
44 return Err(Error::UnsupportedDType {
45 dtype,
46 op: "bicgstab",
47 });
48 }
49
50 let mut x = match x0 {
52 Some(x0) => x0.clone(),
53 None => Tensor::<R>::zeros(&[n], dtype, device),
54 };
55
56 let precond: Option<IluDecomposition<R>> = match options.preconditioner {
58 PreconditionerType::None => None,
59 PreconditionerType::Ilu0 => {
60 let ilu = client.ilu0(a, IluOptions::default())?;
61 Some(ilu)
62 }
63 PreconditionerType::Amg => {
64 return Err(Error::Internal(
65 "AMG preconditioner not supported for BiCGSTAB - use amg_preconditioned_cg"
66 .to_string(),
67 ));
68 }
69 PreconditionerType::Ic0 => {
70 return Err(Error::Internal(
71 "IC0 preconditioner not yet supported for BiCGSTAB - use ILU0".to_string(),
72 ));
73 }
74 };
75
76 let b_norm = vector_norm(client, b)?;
78 if b_norm < options.atol {
79 return Ok(BiCgStabResult {
80 solution: x,
81 iterations: 0,
82 residual_norm: b_norm,
83 converged: true,
84 });
85 }
86
87 let ax = a.spmv(&x)?;
89 let mut r = client.sub(b, &ax)?;
90
91 let r_hat = r.clone();
93
94 let mut rho = 1.0;
96 let mut alpha = 1.0;
97 let mut omega = 1.0;
98
99 let mut v = Tensor::<R>::zeros(&[n], dtype, device);
100 let mut p = Tensor::<R>::zeros(&[n], dtype, device);
101
102 for iter in 0..options.max_iter {
103 let rho_new = vector_dot(client, &r_hat, &r)?;
105
106 if rho_new.abs() < 1e-40 {
108 let res_norm = vector_norm(client, &r)?;
110 return Ok(BiCgStabResult {
111 solution: x,
112 iterations: iter,
113 residual_norm: res_norm,
114 converged: res_norm < options.atol || res_norm / b_norm < options.rtol,
115 });
116 }
117
118 let beta = (rho_new / rho) * (alpha / omega);
120
121 let p_scaled = client.mul_scalar(&p, beta)?;
124 let v_scaled = client.mul_scalar(&v, beta * omega)?;
125 let temp = client.sub(&p_scaled, &v_scaled)?;
126 p = client.add(&r, &temp)?;
127
128 let p_hat = apply_ilu0_preconditioner(client, &precond, &p)?;
130
131 v = a.spmv(&p_hat)?;
133
134 let r_hat_v = vector_dot(client, &r_hat, &v)?;
136 if r_hat_v.abs() < 1e-40 {
137 let res_norm = vector_norm(client, &r)?;
139 return Ok(BiCgStabResult {
140 solution: x,
141 iterations: iter,
142 residual_norm: res_norm,
143 converged: res_norm < options.atol || res_norm / b_norm < options.rtol,
144 });
145 }
146 alpha = rho_new / r_hat_v;
147
148 let v_scaled = client.mul_scalar(&v, alpha)?;
150 let s = client.sub(&r, &v_scaled)?;
151
152 let s_norm = vector_norm(client, &s)?;
154 if s_norm < options.atol || s_norm / b_norm < options.rtol {
155 let p_hat_scaled = client.mul_scalar(&p_hat, alpha)?;
157 x = client.add(&x, &p_hat_scaled)?;
158
159 return Ok(BiCgStabResult {
160 solution: x,
161 iterations: iter + 1,
162 residual_norm: s_norm,
163 converged: true,
164 });
165 }
166
167 let s_hat = apply_ilu0_preconditioner(client, &precond, &s)?;
169
170 let t = a.spmv(&s_hat)?;
172
173 let t_s = vector_dot(client, &t, &s)?;
175 let t_t = vector_dot(client, &t, &t)?;
176 if t_t.abs() < 1e-40 {
177 let res_norm = vector_norm(client, &s)?;
179 return Ok(BiCgStabResult {
180 solution: x,
181 iterations: iter + 1,
182 residual_norm: res_norm,
183 converged: false,
184 });
185 }
186 omega = t_s / t_t;
187
188 let p_hat_scaled = client.mul_scalar(&p_hat, alpha)?;
190 let s_hat_scaled = client.mul_scalar(&s_hat, omega)?;
191 x = client.add(&x, &p_hat_scaled)?;
192 x = client.add(&x, &s_hat_scaled)?;
193
194 let t_scaled = client.mul_scalar(&t, omega)?;
196 r = client.sub(&s, &t_scaled)?;
197
198 rho = rho_new;
200
201 let res_norm = vector_norm(client, &r)?;
203 if res_norm < options.atol || res_norm / b_norm < options.rtol {
204 return Ok(BiCgStabResult {
205 solution: x,
206 iterations: iter + 1,
207 residual_norm: res_norm,
208 converged: true,
209 });
210 }
211
212 if omega.abs() < 1e-40 {
214 return Ok(BiCgStabResult {
215 solution: x,
216 iterations: iter + 1,
217 residual_norm: res_norm,
218 converged: false,
219 });
220 }
221 }
222
223 let ax = a.spmv(&x)?;
225 let r_final = client.sub(b, &ax)?;
226 let final_residual = vector_norm(client, &r_final)?;
227
228 Ok(BiCgStabResult {
229 solution: x,
230 iterations: options.max_iter,
231 residual_norm: final_residual,
232 converged: false,
233 })
234}