1use core::fmt;
2use fenris_traits::Real;
3use nalgebra::base::constraint::AreMultipliable;
4use nalgebra::constraint::{DimEq, ShapeConstraint};
5use nalgebra::storage::Storage;
6use nalgebra::{ClosedAdd, ClosedMul, DVector, DVectorView, DVectorViewMut, Dim, Dyn, Matrix, Scalar, U1};
7use nalgebra_sparse::ops::serial::spmm_csr_dense;
8use nalgebra_sparse::ops::Op;
9use nalgebra_sparse::CsrMatrix;
10use num::{One, Zero};
11use std::error::Error;
12use std::marker::PhantomData;
13use std::ops::{Deref, DerefMut};
14
15pub trait LinearOperator<T: Scalar> {
16 fn apply(&self, y: DVectorViewMut<T>, x: DVectorView<T>) -> Result<(), Box<dyn Error>>;
17}
18
19impl<'a, T, A> LinearOperator<T> for &'a A
20where
21 T: Scalar,
22 A: ?Sized + LinearOperator<T>,
23{
24 fn apply(&self, y: DVectorViewMut<T>, x: DVectorView<T>) -> Result<(), Box<dyn Error>> {
25 <A as LinearOperator<T>>::apply(self, y, x)
26 }
27}
28
29impl<T, R, C, S> LinearOperator<T> for Matrix<T, R, C, S>
30where
31 T: Scalar + One + Zero + ClosedMul + ClosedAdd,
32 R: Dim,
33 C: Dim,
34 S: Storage<T, R, C>,
35 ShapeConstraint: DimEq<Dyn, R> + DimEq<C, Dyn> + AreMultipliable<R, C, Dyn, U1>,
36{
37 fn apply(&self, mut y: DVectorViewMut<T>, x: DVectorView<T>) -> Result<(), Box<dyn Error>> {
38 y.gemv(T::one(), self, &x, T::zero());
39 Ok(())
40 }
41}
42
43impl<T> LinearOperator<T> for CsrMatrix<T>
44where
45 T: Scalar + Zero + One + ClosedMul + ClosedAdd,
46{
47 fn apply(&self, mut y: DVectorViewMut<T>, x: DVectorView<T>) -> Result<(), Box<dyn Error>> {
48 spmm_csr_dense(T::zero(), &mut y, T::one(), Op::NoOp(self), Op::NoOp(&x));
49 Ok(())
50 }
51}
52
53pub struct IdentityOperator;
54
55impl<T: Scalar> LinearOperator<T> for IdentityOperator {
56 fn apply(&self, mut y: DVectorViewMut<T>, x: DVectorView<T>) -> Result<(), Box<dyn Error>> {
57 y.copy_from(&x);
58 Ok(())
59 }
60}
61
62pub trait CgStoppingCriterion<T: Scalar> {
63 fn reset(&self, _a: &dyn LinearOperator<T>, _x: DVectorView<T>, _b: DVectorView<T>) {}
65
66 fn has_converged(
67 &self,
68 a: &dyn LinearOperator<T>,
69 x: DVectorView<T>,
70 b: DVectorView<T>,
71 b_norm: T,
72 iteration: usize,
73 approx_residual: DVectorView<T>,
74 ) -> Result<bool, SolveErrorKind>;
75}
76
77#[derive(Debug)]
85pub struct RelativeResidualCriterion<T: Scalar> {
86 tol: T,
87}
88
89impl<T: Scalar + Zero> RelativeResidualCriterion<T> {
90 pub fn new(tol: T) -> Self {
91 Self { tol }
92 }
93}
94
95impl Default for RelativeResidualCriterion<f64> {
96 fn default() -> Self {
97 Self::new(1e-8)
98 }
99}
100
101impl Default for RelativeResidualCriterion<f32> {
102 fn default() -> Self {
103 Self::new(1e-4)
104 }
105}
106
107impl<T> CgStoppingCriterion<T> for RelativeResidualCriterion<T>
108where
109 T: Real,
110{
111 fn has_converged(
112 &self,
113 _a: &dyn LinearOperator<T>,
114 _x: DVectorView<T>,
115 _b: DVectorView<T>,
116 b_norm: T,
117 _iteration: usize,
118 approx_residual: DVectorView<T>,
119 ) -> Result<bool, SolveErrorKind> {
120 let r_approx_norm = approx_residual.norm();
121 let converged = r_approx_norm <= self.tol * b_norm;
122 Ok(converged)
123 }
124}
125
126#[derive(Debug, Clone)]
127#[allow(non_snake_case)]
128pub struct CgWorkspace<T: Scalar> {
129 r: DVector<T>,
130 z: DVector<T>,
131 p: DVector<T>,
132 Ap: DVector<T>,
133}
134
135#[allow(non_snake_case)]
136struct Buffers<'a, T: Scalar> {
137 r: &'a mut DVector<T>,
138 z: &'a mut DVector<T>,
139 p: &'a mut DVector<T>,
140 Ap: &'a mut DVector<T>,
141}
142
143impl<T: Scalar + Zero> Default for CgWorkspace<T> {
144 fn default() -> Self {
145 Self {
146 r: DVector::zeros(0),
147 z: DVector::zeros(0),
148 p: DVector::zeros(0),
149 Ap: DVector::zeros(0),
150 }
151 }
152}
153
154impl<T: Scalar + Zero> CgWorkspace<T> {
155 fn prepare_buffers(&mut self, dim: usize) -> Buffers<T> {
156 self.r.resize_vertically_mut(dim, T::zero());
157 self.z.resize_vertically_mut(dim, T::zero());
158 self.p.resize_vertically_mut(dim, T::zero());
159 self.Ap.resize_vertically_mut(dim, T::zero());
160 Buffers {
161 r: &mut self.r,
162 z: &mut self.z,
163 p: &mut self.p,
164 Ap: &mut self.Ap,
165 }
166 }
167}
168
/// Holds either an owned value or a mutable borrow of one. Both variants dereference
/// (mutably) to the held value.
#[derive(Debug)]
enum OwnedOrMutRef<'a, T> {
    Owned(T),
    MutRef(&'a mut T),
}

impl<T> Deref for OwnedOrMutRef<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        match self {
            OwnedOrMutRef::Owned(value) => value,
            // `value` is `&&mut T`; reborrow the pointee immutably.
            OwnedOrMutRef::MutRef(value) => &**value,
        }
    }
}

impl<T> DerefMut for OwnedOrMutRef<'_, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        match self {
            OwnedOrMutRef::Owned(value) => value,
            // `value` is `&mut &mut T`; reborrow the pointee for the lifetime of `self`.
            OwnedOrMutRef::MutRef(value) => &mut **value,
        }
    }
}
194
195#[derive(Debug)]
196pub struct ConjugateGradient<'a, T, A, P, Criterion>
197where
198 T: Scalar,
199{
200 workspace: OwnedOrMutRef<'a, CgWorkspace<T>>,
201 operator: A,
202 preconditioner: P,
203 stopping_criterion: Criterion,
204 max_iter: Option<usize>,
205}
206
207impl<'a, T: Scalar + Zero> ConjugateGradient<'a, T, (), IdentityOperator, ()> {
208 pub fn new() -> Self {
209 Self {
210 workspace: OwnedOrMutRef::Owned(CgWorkspace::default()),
211 operator: (),
212 preconditioner: IdentityOperator,
213 stopping_criterion: (),
214 max_iter: None,
215 }
216 }
217}
218
219impl<'a, T: Scalar> ConjugateGradient<'a, T, (), IdentityOperator, ()> {
220 pub fn with_workspace(workspace: &'a mut CgWorkspace<T>) -> Self {
221 Self {
222 workspace: OwnedOrMutRef::MutRef(workspace),
223 operator: (),
224 preconditioner: IdentityOperator,
225 stopping_criterion: (),
226 max_iter: None,
227 }
228 }
229}
230
231impl<'a, T: Scalar, P, Criterion> ConjugateGradient<'a, T, (), P, Criterion> {
232 pub fn with_operator<A>(self, operator: A) -> ConjugateGradient<'a, T, A, P, Criterion> {
233 ConjugateGradient {
234 workspace: self.workspace,
235 operator,
236 preconditioner: self.preconditioner,
237 stopping_criterion: self.stopping_criterion,
238 max_iter: self.max_iter,
239 }
240 }
241}
242
243impl<'a, T: Scalar, A, P, Criterion> ConjugateGradient<'a, T, A, P, Criterion> {
244 pub fn with_preconditioner<P2>(self, preconditioner: P2) -> ConjugateGradient<'a, T, A, P2, Criterion> {
245 ConjugateGradient {
246 workspace: self.workspace,
247 operator: self.operator,
248 preconditioner,
249 stopping_criterion: self.stopping_criterion,
250 max_iter: self.max_iter,
251 }
252 }
253
254 pub fn with_max_iter(self, max_iter: usize) -> Self {
255 Self {
256 max_iter: Some(max_iter),
257 ..self
258 }
259 }
260}
261
262impl<'a, T: Scalar, A, P> ConjugateGradient<'a, T, A, P, ()> {
263 pub fn with_stopping_criterion<Criterion>(
264 self,
265 stopping_criterion: Criterion,
266 ) -> ConjugateGradient<'a, T, A, P, Criterion> {
267 ConjugateGradient {
268 workspace: self.workspace,
269 operator: self.operator,
270 preconditioner: self.preconditioner,
271 stopping_criterion,
272 max_iter: self.max_iter,
273 }
274 }
275}
276
/// The reason a conjugate gradient solve failed.
#[derive(Debug)]
#[non_exhaustive]
pub enum SolveErrorKind {
    /// Applying the operator returned an error.
    OperatorError(Box<dyn Error>),
    /// Applying the preconditioner returned an error.
    PreconditionerError(Box<dyn Error>),
    /// Evaluating the stopping criterion returned an error.
    StoppingCriterionError(Box<dyn Error>),
    /// A search direction `p` with non-positive curvature `pᵀ A p <= 0` was encountered.
    IndefiniteOperator,
    /// A preconditioned residual with `zᵀ r <= 0` was encountered.
    IndefinitePreconditioner,
    /// The configured iteration limit was reached before convergence.
    MaxIterationsReached { max_iter: usize },
}

impl fmt::Display for SolveErrorKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The wrapped errors are printed with their `Display` representation.
        match self {
            Self::OperatorError(err) => write!(f, "Error applying operator: {}", err),
            Self::PreconditionerError(err) => write!(f, "Error applying preconditioner: {}", err),
            Self::StoppingCriterionError(err) => {
                write!(f, "Error evaluating stopping criterion: {}", err)
            }
            Self::IndefiniteOperator => write!(f, "Operator appears to be indefinite."),
            Self::IndefinitePreconditioner => write!(f, "Preconditioner appears to be indefinite."),
            Self::MaxIterationsReached { max_iter } => {
                write!(f, "Max iterations ({}) reached.", max_iter)
            }
        }
    }
}
311
312#[non_exhaustive]
313#[derive(Debug)]
314pub struct SolveError<T> {
315 pub output: CgOutput<T>,
316 pub kind: SolveErrorKind,
317}
318
319impl<T> SolveError<T> {
320 fn new(output: CgOutput<T>, kind: SolveErrorKind) -> Self {
321 Self { output, kind }
322 }
323}
324
325impl<T> fmt::Display for SolveError<T> {
326 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
327 write!(f, "CG solve failed after {}", self.output.num_iterations)?;
328 write!(f, "Error: {}", self.kind)
329 }
330}
331
332impl<T: fmt::Debug> std::error::Error for SolveError<T> {}
333
334fn apply_operator<'a, T, A>(
336 y: impl Into<DVectorViewMut<'a, T>>,
337 a: &'a A,
338 x: impl Into<DVectorView<'a, T>>,
339) -> Result<(), Box<dyn Error>>
340where
341 T: Scalar,
342 A: LinearOperator<T>,
343{
344 a.apply(y.into(), x.into())
345}
346
/// Statistics reported by a conjugate gradient solve.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct CgOutput<T> {
    /// Number of completed CG iterations.
    pub num_iterations: usize,
    // Associates the output with the scalar type of the solve without storing a value.
    marker: PhantomData<T>,
}
356
357impl<'a, T, A, P, Criterion> ConjugateGradient<'a, T, A, P, Criterion>
358where
359 T: Real,
360 A: LinearOperator<T>,
361 P: LinearOperator<T>,
362 Criterion: CgStoppingCriterion<T>,
363{
364 pub fn solve_with_guess<'b>(
365 &mut self,
366 b: impl Into<DVectorView<'b, T>>,
367 x: impl Into<DVectorViewMut<'b, T>>,
368 ) -> Result<CgOutput<T>, SolveError<T>> {
369 self.solve_with_guess_(b.into(), x.into())
370 }
371
372 #[allow(non_snake_case)]
373 fn solve_with_guess_(&mut self, b: DVectorView<T>, mut x: DVectorViewMut<T>) -> Result<CgOutput<T>, SolveError<T>> {
374 use SolveErrorKind::*;
375 assert_eq!(b.len(), x.len());
376
377 let mut output = CgOutput {
378 num_iterations: 0,
379 marker: PhantomData,
380 };
381
382 let Buffers { r, z, p, Ap } = self.workspace.prepare_buffers(x.len());
383
384 if let Err(err) = apply_operator(&mut *r, &self.operator, &x) {
387 return Err(SolveError::new(output, OperatorError(err)));
388 }
389 r.zip_apply(&b, |r_i, b_i| *r_i = b_i - r_i.clone());
391
392 if let Err(err) = apply_operator(&mut *z, &self.preconditioner, &*r) {
394 return Err(SolveError::new(output, PreconditionerError(err)));
395 }
396
397 p.copy_from(&z);
399
400 let mut zTr = z.dot(r);
401 let mut pAp;
402
403 let b_norm = b.norm();
404
405 if b_norm == T::zero() {
406 x.fill(T::zero());
407 return Ok(output);
408 }
409
410 loop {
411 let convergence = self.stopping_criterion.has_converged(
413 &self.operator,
414 (&x).into(),
415 (&b).into(),
416 b_norm,
417 output.num_iterations,
418 (&*r).into(),
419 );
420
421 let has_converged = match convergence {
422 Ok(converged) => converged,
423 Err(error_kind) => return Err(SolveError::new(output, error_kind)),
424 };
425
426 if has_converged {
427 break;
428 } else if let Some(max_iter) = self.max_iter {
429 if output.num_iterations >= max_iter {
430 return Err(SolveError::new(output, MaxIterationsReached { max_iter }));
431 }
432 }
433
434 if let Err(err) = apply_operator(&mut *Ap, &self.operator, &*p) {
436 return Err(SolveError::new(output, OperatorError(err)));
437 }
438 pAp = p.dot(&Ap);
439
440 if pAp <= T::zero() {
441 return Err(SolveError {
442 output,
443 kind: SolveErrorKind::IndefiniteOperator,
444 });
445 }
446 if zTr <= T::zero() {
447 return Err(SolveError {
448 output,
449 kind: SolveErrorKind::IndefinitePreconditioner,
450 });
451 }
452
453 let alpha = zTr / pAp;
454 x.zip_apply(&*p, |x_i, p_i| *x_i += alpha * p_i);
456 r.zip_apply(&*Ap, |r_i, Ap_i| *r_i -= alpha * Ap_i);
458
459 output.num_iterations += 1;
461
462 if let Err(err) = apply_operator(&mut *z, &self.preconditioner, &*r) {
464 return Err(SolveError::new(output, PreconditionerError(err)));
465 }
466 let zTr_next = z.dot(&*r);
467 let beta = zTr_next / zTr;
468
469 p.zip_apply(&*z, |p_i, z_i| {
471 *p_i *= beta;
472 *p_i += z_i;
473 });
474
475 zTr = zTr_next;
476 }
477
478 Ok(output)
479 }
480}