numr/algorithm/iterative/impl_generic/
cgs.rs1use crate::algorithm::sparse_linalg::{IluOptions, SparseLinAlgAlgorithms};
7use crate::dtype::DType;
8use crate::error::{Error, Result};
9use crate::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};
10use crate::runtime::Runtime;
11use crate::sparse::{CsrData, SparseOps};
12use crate::tensor::Tensor;
13
14use super::super::helpers::{BREAKDOWN_TOL, apply_ilu0_preconditioner, vector_dot, vector_norm};
15use super::super::types::{CgsOptions, CgsResult, PreconditionerType};
16
17pub fn cgs_impl<R, C>(
45 client: &C,
46 a: &CsrData<R>,
47 b: &Tensor<R>,
48 x0: Option<&Tensor<R>>,
49 options: CgsOptions,
50) -> Result<CgsResult<R>>
51where
52 R: Runtime<DType = DType>,
53 R::Client: SparseOps<R>,
54 C: SparseLinAlgAlgorithms<R>
55 + SparseOps<R>
56 + BinaryOps<R>
57 + UnaryOps<R>
58 + ReduceOps<R>
59 + ScalarOps<R>,
60{
61 let n = super::super::traits::validate_iterative_inputs(a.shape, b, x0)?;
62 let device = b.device();
63 let dtype = b.dtype();
64
65 if !matches!(dtype, DType::F32 | DType::F64) {
66 return Err(Error::UnsupportedDType { dtype, op: "cgs" });
67 }
68
69 let mut x = match x0 {
70 Some(x0) => x0.clone(),
71 None => Tensor::<R>::zeros(&[n], dtype, device),
72 };
73
74 let precond = match options.preconditioner {
75 PreconditionerType::None => None,
76 PreconditionerType::Ilu0 => Some(client.ilu0(a, IluOptions::default())?),
77 PreconditionerType::Amg => {
78 return Err(Error::Internal(
79 "AMG preconditioner not supported for CGS — use amg_preconditioned_cg".to_string(),
80 ));
81 }
82 PreconditionerType::Ic0 => {
83 return Err(Error::Internal(
84 "IC0 preconditioner not supported for CGS — use ILU0".to_string(),
85 ));
86 }
87 };
88
89 let b_norm = vector_norm(client, b)?;
90 if b_norm < options.atol {
91 return Ok(CgsResult {
92 solution: x,
93 iterations: 0,
94 residual_norm: b_norm,
95 converged: true,
96 });
97 }
98
99 let ax = a.spmv(&x)?;
101 let mut r = client.sub(b, &ax)?;
102
103 let r_hat = r.clone();
105
106 let mut rho = vector_dot(client, &r_hat, &r)?;
107
108 let mut u = r.clone();
110 let mut p = r.clone();
111
112 for iter in 0..options.max_iter {
113 if rho.abs() < BREAKDOWN_TOL {
114 let res_norm = vector_norm(client, &r)?;
115 return Ok(CgsResult {
116 solution: x,
117 iterations: iter,
118 residual_norm: res_norm,
119 converged: res_norm < options.atol || res_norm / b_norm < options.rtol,
120 });
121 }
122
123 let p_hat = apply_ilu0_preconditioner(client, &precond, &p)?;
125
126 let v = a.spmv(&p_hat)?;
128
129 let sigma = vector_dot(client, &r_hat, &v)?;
131 if sigma.abs() < BREAKDOWN_TOL {
132 let res_norm = vector_norm(client, &r)?;
133 return Ok(CgsResult {
134 solution: x,
135 iterations: iter,
136 residual_norm: res_norm,
137 converged: res_norm < options.atol || res_norm / b_norm < options.rtol,
138 });
139 }
140 let alpha = rho / sigma;
141
142 let v_scaled = client.mul_scalar(&v, alpha)?;
144 let q = client.sub(&u, &v_scaled)?;
145
146 let u_plus_q = client.add(&u, &q)?;
148
149 let uq_hat = apply_ilu0_preconditioner(client, &precond, &u_plus_q)?;
151
152 let uq_scaled = client.mul_scalar(&uq_hat, alpha)?;
154 x = client.add(&x, &uq_scaled)?;
155
156 let a_uq = a.spmv(&uq_hat)?;
158 let a_uq_scaled = client.mul_scalar(&a_uq, alpha)?;
159 r = client.sub(&r, &a_uq_scaled)?;
160
161 let res_norm = vector_norm(client, &r)?;
162 if res_norm < options.atol || res_norm / b_norm < options.rtol {
163 return Ok(CgsResult {
164 solution: x,
165 iterations: iter + 1,
166 residual_norm: res_norm,
167 converged: true,
168 });
169 }
170
171 let rho_new = vector_dot(client, &r_hat, &r)?;
173
174 let beta = rho_new / rho;
175
176 let q_scaled = client.mul_scalar(&q, beta)?;
178 u = client.add(&r, &q_scaled)?;
179
180 let p_scaled = client.mul_scalar(&p, beta)?;
182 let q_plus_bp = client.add(&q, &p_scaled)?;
183 let qbp_scaled = client.mul_scalar(&q_plus_bp, beta)?;
184 p = client.add(&u, &qbp_scaled)?;
185
186 rho = rho_new;
187 }
188
189 let ax = a.spmv(&x)?;
190 let r_final = client.sub(b, &ax)?;
191 let final_residual = vector_norm(client, &r_final)?;
192
193 Ok(CgsResult {
194 solution: x,
195 iterations: options.max_iter,
196 residual_norm: final_residual,
197 converged: false,
198 })
199}