1use crate::assert;
2use crate::matrix_free::*;
3use linalg::cholesky::llt_pivoting::factor as piv_llt;
4use linalg::matmul::triangular::BlockStructure;
5
/// Parameters of [`conjugate_gradient`].
#[derive(Copy, Clone, Debug)]
pub struct CgParams<T: RealField> {
	/// Whether the solver reads the contents of its output matrix as the initial guess
	/// (`MaybeNonZero`), or starts from the zero guess (`Zero`).
	pub initial_guess: InitialGuessStatus,
	/// Absolute tolerance on the norm of the residual `B - A X`.
	pub abs_tolerance: T,
	/// Tolerance on the residual norm, relative to the norm of the right-hand side `B`.
	///
	/// The iteration stops once the residual norm is strictly below
	/// `max(abs_tolerance, rel_tolerance * ‖B‖)`.
	pub rel_tolerance: T,
	/// Maximum number of iterations before the solver gives up.
	pub max_iters: usize,

	#[doc(hidden)]
	pub non_exhaustive: NonExhaustive,
}
21
/// Summary of a successful [`conjugate_gradient`] solve.
#[derive(Copy, Clone, Debug)]
pub struct CgInfo<T: RealField> {
	/// Norm of the residual `B - A X` of the returned solution.
	pub abs_residual: T,
	/// `abs_residual` divided by the norm of the right-hand side
	/// (zero when the right-hand side itself is zero).
	pub rel_residual: T,
	/// Number of iterations performed (zero when no iteration was needed).
	pub iter_count: usize,

	#[doc(hidden)]
	pub non_exhaustive: NonExhaustive,
}
35
/// Failure modes of [`conjugate_gradient`].
#[derive(Copy, Clone, Debug)]
pub enum CgError<T: ComplexField> {
	/// The operator `A` was found not to be positive definite.
	NonPositiveDefiniteOperator,
	/// The preconditioner `M` was found not to be positive definite.
	NonPositiveDefinitePreconditioner,
	/// The iteration limit was reached before the residual met the requested tolerance.
	NoConvergence {
		/// Absolute residual norm reported when the solver stopped.
		abs_residual: T::Real,
		/// Residual norm relative to the norm of the right-hand side, reported when the
		/// solver stopped.
		rel_residual: T::Real,
	},
}
51
52impl<T: RealField> Default for CgParams<T> {
53 #[inline]
54 #[math]
55 fn default() -> Self {
56 Self {
57 initial_guess: InitialGuessStatus::MaybeNonZero,
58 abs_tolerance: zero::<T>(),
59 rel_tolerance: eps::<T>() * from_f64::<T>(128.0),
60 max_iters: usize::MAX,
61 non_exhaustive: NonExhaustive(()),
62 }
63 }
64}
65
66pub fn conjugate_gradient_scratch<T: ComplexField>(precond: impl Precond<T>, mat: impl LinOp<T>, rhs_ncols: usize, par: Par) -> StackReq {
69 fn implementation<T: ComplexField>(M: &dyn Precond<T>, A: &dyn LinOp<T>, rhs_ncols: usize, par: Par) -> StackReq {
70 let n = A.nrows();
71 let k = rhs_ncols;
72
73 let nk = temp_mat_scratch::<T>(n, k);
74 let kk = temp_mat_scratch::<T>(k, k);
75 let k_usize = StackReq::new::<usize>(k);
76 let chol = piv_llt::cholesky_in_place_scratch::<usize, T>(k, par, default());
77 StackReq::all_of(&[
78 nk, nk, nk, kk, k_usize, k_usize, StackReq::any_of(&[
85 StackReq::all_of(&[
86 nk, kk, StackReq::any_of(&[
89 A.apply_scratch(k, par),
90 chol, StackReq::all_of(&[
92 kk, kk, ]),
95 ]),
96 ]),
97 M.apply_scratch(k, par),
98 ]),
99 ])
100 }
101 implementation(&precond, &mat, rhs_ncols, par)
102}
103
104#[inline]
109#[track_caller]
110pub fn conjugate_gradient<T: ComplexField>(
111 out: MatMut<'_, T>,
112 precond: impl Precond<T>,
113 mat: impl LinOp<T>,
114 rhs: MatRef<'_, T>,
115 params: CgParams<T::Real>,
116 callback: impl FnMut(MatRef<'_, T>),
117 par: Par,
118 stack: &mut MemStack,
119) -> Result<CgInfo<T::Real>, CgError<T::Real>> {
120 #[track_caller]
121 #[math]
122 fn implementation<T: ComplexField>(
123 mut x: MatMut<'_, T>,
124 M: &dyn Precond<T>,
125 A: &dyn LinOp<T>,
126 b: MatRef<'_, T>,
127
128 params: CgParams<T::Real>,
129 callback: &mut dyn FnMut(MatRef<'_, T>),
130 par: Par,
131 mut stack: &mut MemStack,
132 ) -> Result<CgInfo<T::Real>, CgError<T::Real>> {
133 assert!(A.nrows() == A.ncols());
134
135 let n = A.nrows();
136 let k = b.ncols();
137 let b_norm = b.norm_l2();
138 if b_norm == zero::<T::Real>() {
139 x.fill(zero());
140 return Ok(CgInfo {
141 abs_residual: zero::<T::Real>(),
142 rel_residual: zero::<T::Real>(),
143 iter_count: 0,
144 non_exhaustive: NonExhaustive(()),
145 });
146 }
147
148 let rel_threshold = params.rel_tolerance * b_norm;
149 let abs_threshold = params.abs_tolerance;
150
151 let threshold = if abs_threshold > rel_threshold { abs_threshold } else { rel_threshold };
152
153 let (mut r, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack.rb_mut()) };
154 let mut r = r.as_mat_mut();
155 let (mut p, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack.rb_mut()) };
156 let mut p = p.as_mat_mut();
157 let (mut z, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack.rb_mut()) };
158 let mut z = z.as_mat_mut();
159 let (mut rtz, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack.rb_mut()) };
160 let mut rtz = rtz.as_mat_mut();
161
162 let (perm, mut stack) = unsafe { stack.rb_mut().make_raw::<usize>(k) };
163 let (perm_inv, mut stack) = unsafe { stack.rb_mut().make_raw::<usize>(k) };
164
165 let abs_residual = if params.initial_guess == InitialGuessStatus::MaybeNonZero {
166 A.apply(r.rb_mut(), x.rb(), par, stack.rb_mut());
167 z!(&mut r, &b).for_each(|uz!(res, rhs)| *res = *rhs - *res);
168 r.norm_l2()
169 } else {
170 copy(b_norm)
171 };
172
173 if abs_residual < threshold {
174 return Ok(CgInfo {
175 rel_residual: abs_residual / b_norm,
176 abs_residual,
177 iter_count: 0,
178 non_exhaustive: NonExhaustive(()),
179 });
180 }
181
182 let tril = BlockStructure::TriangularLower;
183
184 {
185 M.apply(p.rb_mut(), r.rb(), par, stack.rb_mut());
186
187 crate::linalg::matmul::triangular::matmul(
188 rtz.rb_mut(),
189 tril,
190 Accum::Replace,
191 r.rb().adjoint(),
192 BlockStructure::Rectangular,
193 p.rb(),
194 BlockStructure::Rectangular,
195 one::<T>(),
196 par,
197 );
198 }
199 for iter in 0..params.max_iters {
200 {
201 let (mut Ap, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack.rb_mut()) };
202 let mut Ap = Ap.as_mat_mut();
203 let (mut ptAp, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack.rb_mut()) };
204 let mut ptAp = ptAp.as_mat_mut();
205
206 A.apply(Ap.rb_mut(), p.rb(), par, stack.rb_mut());
207 crate::linalg::matmul::triangular::matmul(
208 ptAp.rb_mut(),
209 tril,
210 Accum::Replace,
211 p.rb().adjoint(),
212 BlockStructure::Rectangular,
213 Ap.rb(),
214 BlockStructure::Rectangular,
215 one::<T>(),
216 par,
217 );
218
219 let (info, llt_perm) = match piv_llt::cholesky_in_place(ptAp.rb_mut(), perm, perm_inv, par, stack.rb_mut(), Default::default()) {
220 Ok(ok) => ok,
221 Err(_) => return Err(CgError::NonPositiveDefiniteOperator),
222 };
223
224 let (mut alpha, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack.rb_mut()) };
225 let mut alpha = alpha.as_mat_mut();
226 let (mut alpha_perm, _) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack.rb_mut()) };
227 let mut alpha_perm = alpha_perm.as_mat_mut();
228 alpha.copy_from(&rtz);
229 for j in 0..k {
230 for i in 0..j {
231 alpha.write(i, j, conj(alpha[(j, i)]));
232 }
233 }
234 crate::perm::permute_rows(alpha_perm.rb_mut(), alpha.rb(), llt_perm);
235 crate::linalg::triangular_solve::solve_lower_triangular_in_place(
236 ptAp.rb().get(..info.rank, ..info.rank),
237 alpha_perm.rb_mut().get_mut(..info.rank, ..),
238 par,
239 );
240 crate::linalg::triangular_solve::solve_upper_triangular_in_place(
241 ptAp.rb().get(..info.rank, ..info.rank).adjoint(),
242 alpha_perm.rb_mut().get_mut(..info.rank, ..),
243 par,
244 );
245 alpha_perm.rb_mut().get_mut(info.rank.., ..).fill(zero());
246 crate::perm::permute_rows(alpha.rb_mut(), alpha_perm.rb(), llt_perm.inverse());
247
248 crate::linalg::matmul::matmul(
249 x.rb_mut(),
250 if iter == 0 && params.initial_guess == InitialGuessStatus::Zero {
251 Accum::Replace
252 } else {
253 Accum::Add
254 },
255 p.rb(),
256 alpha.rb(),
257 one::<T>(),
258 par,
259 );
260 crate::linalg::matmul::matmul(r.rb_mut(), Accum::Add, Ap.rb(), alpha.rb(), -one::<T>(), par);
261 callback(x.rb());
262 }
263
264 let abs_residual = r.norm_l2();
265 if abs_residual < threshold {
266 return Ok(CgInfo {
267 rel_residual: abs_residual / b_norm,
268 abs_residual,
269 iter_count: iter + 1,
270 non_exhaustive: NonExhaustive(()),
271 });
272 }
273
274 M.apply(z.rb_mut(), r.rb(), par, stack.rb_mut());
275
276 let (mut rtz_new, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack.rb_mut()) };
277 let mut rtz_new = rtz_new.as_mat_mut();
278 crate::linalg::matmul::triangular::matmul(
279 rtz_new.rb_mut(),
280 tril,
281 Accum::Replace,
282 r.rb().adjoint(),
283 BlockStructure::Rectangular,
284 z.rb(),
285 BlockStructure::Rectangular,
286 one::<T>(),
287 par,
288 );
289
290 {
291 let (info, llt_perm) = match piv_llt::cholesky_in_place(rtz.rb_mut(), perm, perm_inv, par, stack.rb_mut(), Default::default()) {
292 Ok(ok) => ok,
293 Err(_) => return Err(CgError::NonPositiveDefiniteOperator),
294 };
295 let (mut beta, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack.rb_mut()) };
296 let mut beta = beta.as_mat_mut();
297 let (mut beta_perm, _) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack.rb_mut()) };
298 let mut beta_perm = beta_perm.as_mat_mut();
299 beta.copy_from(&rtz_new);
300 for j in 0..k {
301 for i in 0..j {
302 beta.write(i, j, conj(beta[(j, i)]));
303 }
304 }
305 crate::perm::permute_rows(beta_perm.rb_mut(), beta.rb(), llt_perm);
306 crate::linalg::triangular_solve::solve_lower_triangular_in_place(
307 rtz.rb().get(..info.rank, ..info.rank),
308 beta_perm.rb_mut().get_mut(..info.rank, ..),
309 par,
310 );
311 crate::linalg::triangular_solve::solve_upper_triangular_in_place(
312 rtz.rb().get(..info.rank, ..info.rank).adjoint(),
313 beta_perm.rb_mut().get_mut(..info.rank, ..),
314 par,
315 );
316 beta_perm.rb_mut().get_mut(info.rank.., ..).fill(zero());
317 crate::perm::permute_rows(beta.rb_mut(), beta_perm.rb(), llt_perm.inverse());
318 rtz.copy_from(&rtz_new);
319
320 crate::linalg::matmul::matmul(z.rb_mut(), Accum::Add, p.rb(), beta.rb(), one::<T>(), par);
321 p.copy_from(&z);
322 }
323 }
324
325 Err(CgError::NoConvergence {
326 rel_residual: abs_residual / b_norm,
327 abs_residual,
328 })
329 }
330
331 implementation(out, &precond, &mat, rhs, params, &mut { callback }, par, stack)
332}
333
#[cfg(test)]
mod tests {
	use super::*;
	use crate::stats::prelude::*;
	use crate::{mat, matrix_free};
	use dyn_stack::MemBuffer;
	use equator::assert;
	use rand::prelude::*;

	/// Small real symmetric positive definite system with two right-hand sides and an identity
	/// preconditioner.
	///
	/// The right-hand side block `A * sol` has full rank 2 on a 2×2 system. A single block step
	/// therefore spans the whole space, and at most one iteration is expected.
	#[test]
	fn test_cg() {
		let ref A = mat![[2.5, -1.0], [-1.0, 3.1]];
		let ref sol = mat![[2.1, 2.4], [4.1, 4.0]];
		let ref rhs = A * sol;
		// zero-filled initial guess (the default `initial_guess` reads it)
		let ref mut out = Mat::<f64>::zeros(2, sol.ncols());
		let mut params = CgParams::default();
		params.max_iters = 10;
		let precond = matrix_free::IdentityPrecond { dim: 2 };
		let result = conjugate_gradient(
			out.as_mut(),
			precond,
			A.as_ref(),
			rhs.as_ref(),
			params,
			|_| {},
			Par::Seq,
			MemStack::new(&mut MemBuffer::new(conjugate_gradient_scratch(precond, A.as_ref(), 2, Par::Seq))),
		);
		// reborrow the solution immutably for the checks below
		let ref out = *out;

		assert!(result.is_ok());
		let result = result.unwrap();
		assert!((A * out - rhs).norm_l2() <= params.rel_tolerance * rhs.norm_l2());
		assert!(result.iter_count <= 1);
	}

	/// Complex hermitian positive definite system with more right-hand sides (k = 15) than
	/// unknowns (n = 10), solved with a positive diagonal preconditioner.
	///
	/// The last column of `sol` is the sum of the others, so the search-direction block is
	/// rank-deficient. This exercises the rank-revealing (pivoted Cholesky) path. The columns
	/// presumably still span the whole space (generic random data), so one step is expected
	/// to suffice.
	///
	/// NOTE: the generated problem depends on the exact order of the RNG draws below.
	#[test]
	fn test_cg_breakdown() {
		let ref mut rng = StdRng::seed_from_u64(0);
		let n = 10;
		let k = 15;
		// random unitary basis
		let ref Q: Mat<c64> = UnitaryMat {
			dim: n,
			standard_normal: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.sample(rng);
		// strictly positive eigenvalues: exp(u)⁻¹ for a random u
		let mut d = Col::zeros(n);
		for i in 0..n {
			d[i] = c64::new(f64::exp(rand::distributions::Standard.sample(rng)).recip(), 0.0);
		}
		// A = Q diag(d) Qᴴ is hermitian positive definite
		let ref A = Q * d.as_ref().as_diagonal() * Q.adjoint();
		// diagonal preconditioner with strictly positive entries
		let ref mut diag = Mat::<c64>::identity(n, n);
		for i in 0..n {
			diag[(i, i)] = c64::new(f64::exp(rand::distributions::Standard.sample(rng)).recip(), 0.0);
		}
		let ref diag = *diag;
		let ref mut sol = CwiseMatDistribution {
			nrows: n,
			ncols: k,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.sample(rng);

		// make the last column a linear combination of the others to force a rank deficiency
		for i in 0..n {
			sol[(i, k - 1)] = c64::new(0.0, 0.0);
			for j in 0..k - 1 {
				let val = sol[(i, j)];
				sol[(i, k - 1)] += val;
			}
		}

		let ref sol = *sol;
		let ref rhs = A * sol;
		let ref mut out = Mat::<c64>::zeros(n, k);
		let params = CgParams::default();
		let result = conjugate_gradient(
			out.as_mut(),
			diag.as_ref(),
			A.as_ref(),
			rhs.as_ref(),
			params,
			|_| {},
			Par::Seq,
			MemStack::new(&mut MemBuffer::new(conjugate_gradient_scratch::<c64>(
				diag.as_ref(),
				A.as_ref(),
				k,
				Par::Seq,
			))),
		);
		let ref out = *out;

		assert!(result.is_ok());
		let result = result.unwrap();
		assert!((A * out - rhs).norm_l2() <= params.rel_tolerance * rhs.norm_l2());
		assert!(result.iter_count <= 1);
	}
}