// faer/operator/conjugate_gradient.rs

1use crate::assert;
2use crate::matrix_free::*;
3use linalg::cholesky::llt_pivoting::factor as piv_llt;
4use linalg::matmul::triangular::BlockStructure;
5
/// algorithm parameters
#[derive(Copy, Clone, Debug)]
pub struct CgParams<T: RealField> {
	/// whether the initial guess is implicitly zero or not
	pub initial_guess: InitialGuessStatus,
	/// absolute tolerance for convergence testing
	///
	/// convergence is declared once the residual norm drops below
	/// `max(abs_tolerance, rel_tolerance * norm(rhs))`
	pub abs_tolerance: T,
	/// relative tolerance for convergence testing, scaled by the norm of the right-hand side
	pub rel_tolerance: T,
	/// maximum number of iterations before the solver gives up with a convergence error
	pub max_iters: usize,

	#[doc(hidden)]
	pub non_exhaustive: NonExhaustive,
}
21
/// algorithm result
#[derive(Copy, Clone, Debug)]
pub struct CgInfo<T: RealField> {
	/// absolute residual at the final step, i.e. `norm_l2(b - A * x)` of the returned
	/// iterate (tracked recursively by the algorithm)
	pub abs_residual: T,
	/// relative residual at the final step, i.e. `abs_residual / norm_l2(b)`
	pub rel_residual: T,
	/// number of iterations executed by the algorithm
	pub iter_count: usize,

	#[doc(hidden)]
	pub non_exhaustive: NonExhaustive,
}
35
/// algorithm error
#[derive(Copy, Clone, Debug)]
pub enum CgError<T: ComplexField> {
	/// operator was detected to not be positive definite
	///
	/// detected through failure of the pivoted cholesky factorization of `pᴴ A p`
	NonPositiveDefiniteOperator,
	/// preconditioner was detected to not be positive definite
	NonPositiveDefinitePreconditioner,
	/// convergence failure: the iteration budget was exhausted before the residual
	/// dropped below the convergence threshold
	NoConvergence {
		/// absolute residual at the final step
		abs_residual: T::Real,
		/// relative residual at the final step
		rel_residual: T::Real,
	},
}
51
impl<T: RealField> Default for CgParams<T> {
	#[inline]
	#[math]
	fn default() -> Self {
		// conservative defaults: convergence is governed by the relative tolerance
		// alone, and the iteration count is effectively unbounded
		Self {
			initial_guess: InitialGuessStatus::MaybeNonZero,
			// disabled: zero absolute tolerance never dominates the relative one
			abs_tolerance: zero::<T>(),
			// 128 × machine epsilon for the scalar type
			rel_tolerance: eps::<T>() * from_f64::<T>(128.0),
			max_iters: usize::MAX,
			non_exhaustive: NonExhaustive(()),
		}
	}
}
65
66/// computes the size and alignment of required workspace for executing the conjugate gradient
67/// algorithm
68pub fn conjugate_gradient_scratch<T: ComplexField>(precond: impl Precond<T>, mat: impl LinOp<T>, rhs_ncols: usize, par: Par) -> StackReq {
69	fn implementation<T: ComplexField>(M: &dyn Precond<T>, A: &dyn LinOp<T>, rhs_ncols: usize, par: Par) -> StackReq {
70		let n = A.nrows();
71		let k = rhs_ncols;
72
73		let nk = temp_mat_scratch::<T>(n, k);
74		let kk = temp_mat_scratch::<T>(k, k);
75		let k_usize = StackReq::new::<usize>(k);
76		let chol = piv_llt::cholesky_in_place_scratch::<usize, T>(k, par, default());
77		StackReq::all_of(&[
78			nk,      // residual
79			nk,      // p
80			nk,      // z
81			kk,      // rtz
82			k_usize, // perm
83			k_usize, // perm_inv
84			StackReq::any_of(&[
85				StackReq::all_of(&[
86					nk, // Ap
87					kk, // ptAp | rtz_new
88					StackReq::any_of(&[
89						A.apply_scratch(k, par),
90						chol, // ptAp | rtz
91						StackReq::all_of(&[
92							kk, // alpha | beta
93							kk, // alpha_perm | beta_perm
94						]),
95					]),
96				]),
97				M.apply_scratch(k, par),
98			]),
99		])
100	}
101	implementation(&precond, &mat, rhs_ncols, par)
102}
103
104/// executes the conjugate gradient using the provided preconditioner
105///
106/// # note
107/// this function is also optimized for a rhs with multiple columns
108#[inline]
109#[track_caller]
110pub fn conjugate_gradient<T: ComplexField>(
111	out: MatMut<'_, T>,
112	precond: impl Precond<T>,
113	mat: impl LinOp<T>,
114	rhs: MatRef<'_, T>,
115	params: CgParams<T::Real>,
116	callback: impl FnMut(MatRef<'_, T>),
117	par: Par,
118	stack: &mut MemStack,
119) -> Result<CgInfo<T::Real>, CgError<T::Real>> {
120	#[track_caller]
121	#[math]
122	fn implementation<T: ComplexField>(
123		mut x: MatMut<'_, T>,
124		M: &dyn Precond<T>,
125		A: &dyn LinOp<T>,
126		b: MatRef<'_, T>,
127
128		params: CgParams<T::Real>,
129		callback: &mut dyn FnMut(MatRef<'_, T>),
130		par: Par,
131		mut stack: &mut MemStack,
132	) -> Result<CgInfo<T::Real>, CgError<T::Real>> {
133		assert!(A.nrows() == A.ncols());
134
135		let n = A.nrows();
136		let k = b.ncols();
137		let b_norm = b.norm_l2();
138		if b_norm == zero::<T::Real>() {
139			x.fill(zero());
140			return Ok(CgInfo {
141				abs_residual: zero::<T::Real>(),
142				rel_residual: zero::<T::Real>(),
143				iter_count: 0,
144				non_exhaustive: NonExhaustive(()),
145			});
146		}
147
148		let rel_threshold = params.rel_tolerance * b_norm;
149		let abs_threshold = params.abs_tolerance;
150
151		let threshold = if abs_threshold > rel_threshold { abs_threshold } else { rel_threshold };
152
153		let (mut r, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack.rb_mut()) };
154		let mut r = r.as_mat_mut();
155		let (mut p, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack.rb_mut()) };
156		let mut p = p.as_mat_mut();
157		let (mut z, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack.rb_mut()) };
158		let mut z = z.as_mat_mut();
159		let (mut rtz, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack.rb_mut()) };
160		let mut rtz = rtz.as_mat_mut();
161
162		let (perm, mut stack) = unsafe { stack.rb_mut().make_raw::<usize>(k) };
163		let (perm_inv, mut stack) = unsafe { stack.rb_mut().make_raw::<usize>(k) };
164
165		let abs_residual = if params.initial_guess == InitialGuessStatus::MaybeNonZero {
166			A.apply(r.rb_mut(), x.rb(), par, stack.rb_mut());
167			z!(&mut r, &b).for_each(|uz!(res, rhs)| *res = *rhs - *res);
168			r.norm_l2()
169		} else {
170			copy(b_norm)
171		};
172
173		if abs_residual < threshold {
174			return Ok(CgInfo {
175				rel_residual: abs_residual / b_norm,
176				abs_residual,
177				iter_count: 0,
178				non_exhaustive: NonExhaustive(()),
179			});
180		}
181
182		let tril = BlockStructure::TriangularLower;
183
184		{
185			M.apply(p.rb_mut(), r.rb(), par, stack.rb_mut());
186
187			crate::linalg::matmul::triangular::matmul(
188				rtz.rb_mut(),
189				tril,
190				Accum::Replace,
191				r.rb().adjoint(),
192				BlockStructure::Rectangular,
193				p.rb(),
194				BlockStructure::Rectangular,
195				one::<T>(),
196				par,
197			);
198		}
199		for iter in 0..params.max_iters {
200			{
201				let (mut Ap, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack.rb_mut()) };
202				let mut Ap = Ap.as_mat_mut();
203				let (mut ptAp, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack.rb_mut()) };
204				let mut ptAp = ptAp.as_mat_mut();
205
206				A.apply(Ap.rb_mut(), p.rb(), par, stack.rb_mut());
207				crate::linalg::matmul::triangular::matmul(
208					ptAp.rb_mut(),
209					tril,
210					Accum::Replace,
211					p.rb().adjoint(),
212					BlockStructure::Rectangular,
213					Ap.rb(),
214					BlockStructure::Rectangular,
215					one::<T>(),
216					par,
217				);
218
219				let (info, llt_perm) = match piv_llt::cholesky_in_place(ptAp.rb_mut(), perm, perm_inv, par, stack.rb_mut(), Default::default()) {
220					Ok(ok) => ok,
221					Err(_) => return Err(CgError::NonPositiveDefiniteOperator),
222				};
223
224				let (mut alpha, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack.rb_mut()) };
225				let mut alpha = alpha.as_mat_mut();
226				let (mut alpha_perm, _) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack.rb_mut()) };
227				let mut alpha_perm = alpha_perm.as_mat_mut();
228				alpha.copy_from(&rtz);
229				for j in 0..k {
230					for i in 0..j {
231						alpha.write(i, j, conj(alpha[(j, i)]));
232					}
233				}
234				crate::perm::permute_rows(alpha_perm.rb_mut(), alpha.rb(), llt_perm);
235				crate::linalg::triangular_solve::solve_lower_triangular_in_place(
236					ptAp.rb().get(..info.rank, ..info.rank),
237					alpha_perm.rb_mut().get_mut(..info.rank, ..),
238					par,
239				);
240				crate::linalg::triangular_solve::solve_upper_triangular_in_place(
241					ptAp.rb().get(..info.rank, ..info.rank).adjoint(),
242					alpha_perm.rb_mut().get_mut(..info.rank, ..),
243					par,
244				);
245				alpha_perm.rb_mut().get_mut(info.rank.., ..).fill(zero());
246				crate::perm::permute_rows(alpha.rb_mut(), alpha_perm.rb(), llt_perm.inverse());
247
248				crate::linalg::matmul::matmul(
249					x.rb_mut(),
250					if iter == 0 && params.initial_guess == InitialGuessStatus::Zero {
251						Accum::Replace
252					} else {
253						Accum::Add
254					},
255					p.rb(),
256					alpha.rb(),
257					one::<T>(),
258					par,
259				);
260				crate::linalg::matmul::matmul(r.rb_mut(), Accum::Add, Ap.rb(), alpha.rb(), -one::<T>(), par);
261				callback(x.rb());
262			}
263
264			let abs_residual = r.norm_l2();
265			if abs_residual < threshold {
266				return Ok(CgInfo {
267					rel_residual: abs_residual / b_norm,
268					abs_residual,
269					iter_count: iter + 1,
270					non_exhaustive: NonExhaustive(()),
271				});
272			}
273
274			M.apply(z.rb_mut(), r.rb(), par, stack.rb_mut());
275
276			let (mut rtz_new, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack.rb_mut()) };
277			let mut rtz_new = rtz_new.as_mat_mut();
278			crate::linalg::matmul::triangular::matmul(
279				rtz_new.rb_mut(),
280				tril,
281				Accum::Replace,
282				r.rb().adjoint(),
283				BlockStructure::Rectangular,
284				z.rb(),
285				BlockStructure::Rectangular,
286				one::<T>(),
287				par,
288			);
289
290			{
291				let (info, llt_perm) = match piv_llt::cholesky_in_place(rtz.rb_mut(), perm, perm_inv, par, stack.rb_mut(), Default::default()) {
292					Ok(ok) => ok,
293					Err(_) => return Err(CgError::NonPositiveDefiniteOperator),
294				};
295				let (mut beta, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack.rb_mut()) };
296				let mut beta = beta.as_mat_mut();
297				let (mut beta_perm, _) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack.rb_mut()) };
298				let mut beta_perm = beta_perm.as_mat_mut();
299				beta.copy_from(&rtz_new);
300				for j in 0..k {
301					for i in 0..j {
302						beta.write(i, j, conj(beta[(j, i)]));
303					}
304				}
305				crate::perm::permute_rows(beta_perm.rb_mut(), beta.rb(), llt_perm);
306				crate::linalg::triangular_solve::solve_lower_triangular_in_place(
307					rtz.rb().get(..info.rank, ..info.rank),
308					beta_perm.rb_mut().get_mut(..info.rank, ..),
309					par,
310				);
311				crate::linalg::triangular_solve::solve_upper_triangular_in_place(
312					rtz.rb().get(..info.rank, ..info.rank).adjoint(),
313					beta_perm.rb_mut().get_mut(..info.rank, ..),
314					par,
315				);
316				beta_perm.rb_mut().get_mut(info.rank.., ..).fill(zero());
317				crate::perm::permute_rows(beta.rb_mut(), beta_perm.rb(), llt_perm.inverse());
318				rtz.copy_from(&rtz_new);
319
320				crate::linalg::matmul::matmul(z.rb_mut(), Accum::Add, p.rb(), beta.rb(), one::<T>(), par);
321				p.copy_from(&z);
322			}
323		}
324
325		Err(CgError::NoConvergence {
326			rel_residual: abs_residual / b_norm,
327			abs_residual,
328		})
329	}
330
331	implementation(out, &precond, &mat, rhs, params, &mut { callback }, par, stack)
332}
333
334#[cfg(test)]
335mod tests {
336	use super::*;
337	use crate::stats::prelude::*;
338	use crate::{mat, matrix_free};
339	use dyn_stack::MemBuffer;
340	use equator::assert;
341	use rand::prelude::*;
342
	// solves a small 2×2 spd system with the identity preconditioner; block cg with a
	// full set of rhs columns should converge in a single iteration
	#[test]
	fn test_cg() {
		let ref A = mat![[2.5, -1.0], [-1.0, 3.1]];
		let ref sol = mat![[2.1, 2.4], [4.1, 4.0]];
		// rhs constructed from a known solution so the recovered iterate can be checked
		let ref rhs = A * sol;
		let ref mut out = Mat::<f64>::zeros(2, sol.ncols());
		let mut params = CgParams::default();
		params.max_iters = 10;
		let precond = matrix_free::IdentityPrecond { dim: 2 };
		let result = conjugate_gradient(
			out.as_mut(),
			precond,
			A.as_ref(),
			rhs.as_ref(),
			params,
			|_| {},
			Par::Seq,
			MemStack::new(&mut MemBuffer::new(conjugate_gradient_scratch(precond, A.as_ref(), 2, Par::Seq))),
		);
		let ref out = *out;

		assert!(result.is_ok());
		let result = result.unwrap();
		// the solution must satisfy the same relative tolerance the solver converged with
		assert!((A * out - rhs).norm_l2() <= params.rel_tolerance * rhs.norm_l2());
		assert!(result.iter_count <= 1);
	}
369
	// exercises the rank-deficient "breakdown" path of the block solver: the rhs has
	// more columns (k = 15) than the operator dimension (n = 10), and one column is an
	// exact linear combination of the others, so the pivoted llt must detect a rank
	// deficit instead of failing
	#[test]
	fn test_cg_breakdown() {
		let ref mut rng = StdRng::seed_from_u64(0);
		let n = 10;
		let k = 15;
		// A = Q D Qᴴ with positive eigenvalues: hermitian positive definite by construction
		let ref Q: Mat<c64> = UnitaryMat {
			dim: n,
			standard_normal: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.sample(rng);
		let mut d = Col::zeros(n);
		for i in 0..n {
			d[i] = c64::new(f64::exp(rand::distributions::Standard.sample(rng)).recip(), 0.0);
		}
		let ref A = Q * d.as_ref().as_diagonal() * Q.adjoint();
		// positive diagonal preconditioner
		let ref mut diag = Mat::<c64>::identity(n, n);
		for i in 0..n {
			diag[(i, i)] = c64::new(f64::exp(rand::distributions::Standard.sample(rng)).recip(), 0.0);
		}
		let ref diag = *diag;
		let ref mut sol = CwiseMatDistribution {
			nrows: n,
			ncols: k,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.sample(rng);

		// make the last column the sum of all the others, forcing rank deficiency of the
		// block of right-hand sides
		for i in 0..n {
			sol[(i, k - 1)] = c64::new(0.0, 0.0);
			for j in 0..k - 1 {
				let val = sol[(i, j)];
				sol[(i, k - 1)] += val;
			}
		}

		let ref sol = *sol;
		let ref rhs = A * sol;
		let ref mut out = Mat::<c64>::zeros(n, k);
		let params = CgParams::default();
		let result = conjugate_gradient(
			out.as_mut(),
			diag.as_ref(),
			A.as_ref(),
			rhs.as_ref(),
			params,
			|_| {},
			Par::Seq,
			MemStack::new(&mut MemBuffer::new(conjugate_gradient_scratch::<c64>(
				diag.as_ref(),
				A.as_ref(),
				k,
				Par::Seq,
			))),
		);
		let ref out = *out;

		assert!(result.is_ok());
		let result = result.unwrap();
		// convergence must hold despite the rank-deficient block of directions
		assert!((A * out - rhs).norm_l2() <= params.rel_tolerance * rhs.norm_l2());
		assert!(result.iter_count <= 1);
	}
431}