Skip to main content

numr/algorithm/iterative/impl_generic/
cgs.rs

1//! Generic CGS (Conjugate Gradient Squared) implementation
2//!
3//! Sonneveld's CGS for non-symmetric systems. Faster convergence than BiCGSTAB
4//! when it works, but can be less stable (residual norms may oscillate wildly).
5
6use crate::algorithm::sparse_linalg::{IluOptions, SparseLinAlgAlgorithms};
7use crate::dtype::DType;
8use crate::error::{Error, Result};
9use crate::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};
10use crate::runtime::Runtime;
11use crate::sparse::{CsrData, SparseOps};
12use crate::tensor::Tensor;
13
14use super::super::helpers::{BREAKDOWN_TOL, apply_ilu0_preconditioner, vector_dot, vector_norm};
15use super::super::types::{CgsOptions, CgsResult, PreconditionerType};
16
/// Generic preconditioned CGS implementation
///
/// Solves `A * x = b` for a (possibly non-symmetric) sparse CSR matrix `A`
/// using Sonneveld's Conjugate Gradient Squared method, with an optional
/// ILU(0) preconditioner `M`.
///
/// Algorithm (Sonneveld):
/// ```text
/// x = x0, r = b - A*x, r_hat = r
/// rho = <r_hat, r>, u = r, p = r
///
/// for iter = 1, 2, ...:
///     p_hat = M^-1 * p
///     v = A * p_hat
///     sigma = <r_hat, v>
///     alpha = rho / sigma
///
///     q = u - alpha * v
///     u_plus_q = u + q
///     uq_hat = M^-1 * u_plus_q
///     x = x + alpha * uq_hat
///     r = r - alpha * A * uq_hat
///
///     if ||r|| < tol: return
///
///     rho_new = <r_hat, r>
///     beta = rho_new / rho
///     u = r + beta * q
///     p = u + beta * (q + beta * p)
///     rho = rho_new
/// ```
///
/// # Arguments
/// * `client` — backend providing the sparse/dense tensor operations used here
/// * `a` — square system matrix in CSR form
/// * `b` — right-hand-side vector
/// * `x0` — optional initial guess; the zero vector is used when `None`
/// * `options` — tolerances (`atol`, `rtol`), `max_iter`, and preconditioner choice
///
/// # Errors
/// * [`Error::UnsupportedDType`] unless `b` is `F32` or `F64`
/// * [`Error::Internal`] for preconditioner types CGS does not support
///   (`Amg`, `Ic0`)
/// * any error propagated from input validation, ILU(0) factorization, SpMV,
///   or the element-wise/reduction ops
pub fn cgs_impl<R, C>(
    client: &C,
    a: &CsrData<R>,
    b: &Tensor<R>,
    x0: Option<&Tensor<R>>,
    options: CgsOptions,
) -> Result<CgsResult<R>>
where
    R: Runtime<DType = DType>,
    R::Client: SparseOps<R>,
    C: SparseLinAlgAlgorithms<R>
        + SparseOps<R>
        + BinaryOps<R>
        + UnaryOps<R>
        + ReduceOps<R>
        + ScalarOps<R>,
{
    // Validates that A is square, b matches its dimension, and x0 (if given)
    // matches too; returns the system size n.
    let n = super::super::traits::validate_iterative_inputs(a.shape, b, x0)?;
    let device = b.device();
    let dtype = b.dtype();

    if !matches!(dtype, DType::F32 | DType::F64) {
        return Err(Error::UnsupportedDType { dtype, op: "cgs" });
    }

    // Initial iterate: caller-supplied guess, or the zero vector.
    let mut x = match x0 {
        Some(x0) => x0.clone(),
        None => Tensor::<R>::zeros(&[n], dtype, device),
    };

    // Only ILU(0) is supported as a preconditioner here; AMG and IC0 are
    // rejected explicitly rather than silently ignored.
    let precond = match options.preconditioner {
        PreconditionerType::None => None,
        PreconditionerType::Ilu0 => Some(client.ilu0(a, IluOptions::default())?),
        PreconditionerType::Amg => {
            return Err(Error::Internal(
                "AMG preconditioner not supported for CGS — use amg_preconditioned_cg".to_string(),
            ));
        }
        PreconditionerType::Ic0 => {
            return Err(Error::Internal(
                "IC0 preconditioner not supported for CGS — use ILU0".to_string(),
            ));
        }
    };

    // Trivial right-hand side: ||b|| already below the absolute tolerance,
    // so the current x is accepted without iterating.
    // NOTE(review): when an x0 is provided, the reported residual_norm here is
    // ||b||, not ||b - A*x0|| — confirm this shortcut is intended for nonzero
    // initial guesses.
    let b_norm = vector_norm(client, b)?;
    if b_norm < options.atol {
        return Ok(CgsResult {
            solution: x,
            iterations: 0,
            residual_norm: b_norm,
            converged: true,
        });
    }

    // r = b - A*x
    let ax = a.spmv(&x)?;
    let mut r = client.sub(b, &ax)?;

    // r_hat = r (shadow residual, kept constant)
    let r_hat = r.clone();

    // rho = <r_hat, r>; drives the alpha/beta recurrences below.
    let mut rho = vector_dot(client, &r_hat, &r)?;

    // u = r, p = r
    let mut u = r.clone();
    let mut p = r.clone();

    for iter in 0..options.max_iter {
        // Breakdown guard: rho ≈ 0 makes beta = rho_new/rho (and alpha)
        // meaningless. Stop and report whatever accuracy was reached;
        // `converged` reflects whether the tolerances happen to be met.
        // BREAKDOWN_TOL is the shared threshold from the helpers module.
        if rho.abs() < BREAKDOWN_TOL {
            let res_norm = vector_norm(client, &r)?;
            return Ok(CgsResult {
                solution: x,
                // `iter` = number of fully completed iterations (breakdown
                // happened before this iteration did any work).
                iterations: iter,
                residual_norm: res_norm,
                converged: res_norm < options.atol || res_norm / b_norm < options.rtol,
            });
        }

        // p_hat = M^-1 * p  (identity when no preconditioner is configured)
        let p_hat = apply_ilu0_preconditioner(client, &precond, &p)?;

        // v = A * p_hat
        let v = a.spmv(&p_hat)?;

        // sigma = <r_hat, v>
        // Second breakdown guard: sigma ≈ 0 would blow up alpha = rho/sigma.
        let sigma = vector_dot(client, &r_hat, &v)?;
        if sigma.abs() < BREAKDOWN_TOL {
            let res_norm = vector_norm(client, &r)?;
            return Ok(CgsResult {
                solution: x,
                iterations: iter,
                residual_norm: res_norm,
                converged: res_norm < options.atol || res_norm / b_norm < options.rtol,
            });
        }
        let alpha = rho / sigma;

        // q = u - alpha * v
        let v_scaled = client.mul_scalar(&v, alpha)?;
        let q = client.sub(&u, &v_scaled)?;

        // u_plus_q = u + q
        let u_plus_q = client.add(&u, &q)?;

        // uq_hat = M^-1 * (u + q)
        let uq_hat = apply_ilu0_preconditioner(client, &precond, &u_plus_q)?;

        // x = x + alpha * uq_hat
        let uq_scaled = client.mul_scalar(&uq_hat, alpha)?;
        x = client.add(&x, &uq_scaled)?;

        // r = r - alpha * A * uq_hat  (recursive residual update; the true
        // residual is recomputed once more after the loop)
        let a_uq = a.spmv(&uq_hat)?;
        let a_uq_scaled = client.mul_scalar(&a_uq, alpha)?;
        r = client.sub(&r, &a_uq_scaled)?;

        // Converged when either the absolute or the relative (to ||b||)
        // tolerance is met.
        let res_norm = vector_norm(client, &r)?;
        if res_norm < options.atol || res_norm / b_norm < options.rtol {
            return Ok(CgsResult {
                solution: x,
                // This iteration completed its update, hence iter + 1.
                iterations: iter + 1,
                residual_norm: res_norm,
                converged: true,
            });
        }

        // rho_new = <r_hat, r>
        let rho_new = vector_dot(client, &r_hat, &r)?;

        let beta = rho_new / rho;

        // u = r + beta * q
        let q_scaled = client.mul_scalar(&q, beta)?;
        u = client.add(&r, &q_scaled)?;

        // p = u + beta * (q + beta * p)
        let p_scaled = client.mul_scalar(&p, beta)?;
        let q_plus_bp = client.add(&q, &p_scaled)?;
        let qbp_scaled = client.mul_scalar(&q_plus_bp, beta)?;
        p = client.add(&u, &qbp_scaled)?;

        rho = rho_new;
    }

    // max_iter exhausted without meeting the tolerances. Recompute the true
    // residual b - A*x from scratch: in CGS the recursively updated r can
    // drift from the actual residual due to rounding.
    let ax = a.spmv(&x)?;
    let r_final = client.sub(b, &ax)?;
    let final_residual = vector_norm(client, &r_final)?;

    Ok(CgsResult {
        solution: x,
        iterations: options.max_iter,
        residual_norm: final_residual,
        converged: false,
    })
}