//! scirs2_sparse/linalg/cgs.rs
//!
//! Conjugate Gradient Squared (CGS) iterative solver.

1use crate::error::{SparseError, SparseResult};
2use crate::linalg::interface::LinearOperator;
3use crate::linalg::iterative::{dot, norm2, BiCGOptions, IterationResult};
4use scirs2_core::numeric::{Float, NumAssign, SparseElement};
5use std::iter::Sum;
6
/// Options for the CGS solver; reuses the BiCG option set
/// (`x0`, `rtol`, `atol`, `max_iter`, `right_preconditioner`).
pub type CGSOptions<F> = BiCGOptions<F>;
/// Result of a CGS solve; reuses the shared iterative-solver result
/// (`x`, `iterations`, `residual_norm`, `converged`, `message`).
pub type CGSResult<F> = IterationResult<F>;
10
11/// Conjugate Gradient Squared solver (CGS)
12///
13/// Implementation following the algorithm from "Templates for the Solution of Linear Systems"
14/// by Barrett et al. This is for non-symmetric linear systems.
15#[allow(dead_code)]
16pub fn cgs<F>(
17    a: &dyn LinearOperator<F>,
18    b: &[F],
19    options: CGSOptions<F>,
20) -> SparseResult<CGSResult<F>>
21where
22    F: Float + NumAssign + Sum + SparseElement + 'static,
23{
24    let (rows, cols) = a.shape();
25    if rows != cols {
26        return Err(SparseError::ValueError(
27            "Matrix must be square for CGS solver".to_string(),
28        ));
29    }
30    if b.len() != rows {
31        return Err(SparseError::DimensionMismatch {
32            expected: rows,
33            found: b.len(),
34        });
35    }
36
37    let n = rows;
38
39    // Initialize solution
40    let mut x: Vec<F> = match &options.x0 {
41        Some(x0) => {
42            if x0.len() != n {
43                return Err(SparseError::DimensionMismatch {
44                    expected: n,
45                    found: x0.len(),
46                });
47            }
48            x0.clone()
49        }
50        None => vec![F::sparse_zero(); n],
51    };
52
53    // Compute initial residual: r = b - A*x
54    let ax = a.matvec(&x)?;
55    let mut r: Vec<F> = b.iter().zip(&ax).map(|(&bi, &axi)| bi - axi).collect();
56
57    // Check if initial guess is solution
58    let mut rnorm = norm2(&r);
59    let bnorm = norm2(b);
60    let tolerance = F::max(options.atol, options.rtol * bnorm);
61
62    if rnorm <= tolerance {
63        return Ok(CGSResult {
64            x,
65            iterations: 0,
66            residual_norm: rnorm,
67            converged: true,
68            message: "Converged with initial guess".to_string(),
69        });
70    }
71
72    // Choose arbitrary r̃ (usually r̃ = r)
73    let r_tilde = r.clone();
74
75    // Initialize vectors
76    let mut u = vec![F::sparse_zero(); n];
77    let mut p = vec![F::sparse_zero(); n];
78    let mut q = vec![F::sparse_zero(); n];
79
80    let mut rho = F::sparse_one();
81    let mut iterations = 0;
82
83    // Main CGS iteration
84    while iterations < options.max_iter {
85        // Compute ρ = (r̃, r)
86        let rho_new = dot(&r_tilde, &r);
87
88        // Check for breakdown
89        if rho_new.abs() < F::epsilon() * F::from(10).unwrap() {
90            return Ok(CGSResult {
91                x,
92                iterations,
93                residual_norm: rnorm,
94                converged: false,
95                message: "CGS breakdown: rho ≈ 0".to_string(),
96            });
97        }
98
99        // Compute β = ρ_i / ρ_{i-1}
100        let beta = if iterations == 0 {
101            F::sparse_zero()
102        } else {
103            rho_new / rho
104        };
105
106        // Update u and p
107        for i in 0..n {
108            u[i] = r[i] + beta * q[i];
109            p[i] = u[i] + beta * (q[i] + beta * p[i]);
110        }
111
112        // Apply right preconditioner if provided
113        let p_prec = if let Some(m) = &options.right_preconditioner {
114            m.matvec(&p)?
115        } else {
116            p.clone()
117        };
118
119        // v = A * M^{-1} * p
120        let v = a.matvec(&p_prec)?;
121
122        // σ = (r̃, v)
123        let sigma = dot(&r_tilde, &v);
124
125        // Check for breakdown
126        if sigma.abs() < F::epsilon() * F::from(10).unwrap() {
127            return Ok(CGSResult {
128                x,
129                iterations,
130                residual_norm: rnorm,
131                converged: false,
132                message: "CGS breakdown: sigma ≈ 0".to_string(),
133            });
134        }
135
136        // α = ρ / σ
137        let alpha = rho_new / sigma;
138
139        // Update q
140        for i in 0..n {
141            q[i] = u[i] - alpha * v[i];
142        }
143
144        // Compute u + q
145        let u_plus_q: Vec<F> = u.iter().zip(&q).map(|(&ui, &qi)| ui + qi).collect();
146
147        // Apply right preconditioner if provided
148        let u_plus_q_prec = if let Some(m) = &options.right_preconditioner {
149            m.matvec(&u_plus_q)?
150        } else {
151            u_plus_q
152        };
153
154        // Update x
155        for i in 0..n {
156            x[i] += alpha * u_plus_q_prec[i];
157        }
158
159        // Apply right preconditioner to q
160        let q_prec = if let Some(m) = &options.right_preconditioner {
161            m.matvec(&q)?
162        } else {
163            q.clone()
164        };
165
166        // Compute A * M^{-1} * q
167        let aq = a.matvec(&q_prec)?;
168
169        // Update r
170        for i in 0..n {
171            r[i] -= alpha * (v[i] + aq[i]);
172        }
173
174        rho = rho_new;
175        iterations += 1;
176
177        // Check convergence
178        rnorm = norm2(&r);
179        if rnorm <= tolerance {
180            break;
181        }
182    }
183
184    Ok(CGSResult {
185        x,
186        iterations,
187        residual_norm: rnorm,
188        converged: rnorm <= tolerance,
189        message: if rnorm <= tolerance {
190            "Converged".to_string()
191        } else {
192            "Maximum iterations reached".to_string()
193        },
194    })
195}