fenris_sparse/
cg.rs

1use core::fmt;
2use fenris_traits::Real;
3use nalgebra::base::constraint::AreMultipliable;
4use nalgebra::constraint::{DimEq, ShapeConstraint};
5use nalgebra::storage::Storage;
6use nalgebra::{ClosedAdd, ClosedMul, DVector, DVectorView, DVectorViewMut, Dim, Dyn, Matrix, Scalar, U1};
7use nalgebra_sparse::ops::serial::spmm_csr_dense;
8use nalgebra_sparse::ops::Op;
9use nalgebra_sparse::CsrMatrix;
10use num::{One, Zero};
11use std::error::Error;
12use std::marker::PhantomData;
13use std::ops::{Deref, DerefMut};
14
/// A linear operator that can be applied to a dense vector.
///
/// The implementations in this file (dense matrices, CSR matrices, the identity) all compute
/// `y = A x`, overwriting the previous contents of `y`.
pub trait LinearOperator<T: Scalar> {
    /// Applies the operator to `x` and stores the result in `y`.
    ///
    /// Returns an error if the operator could not be applied.
    fn apply(&self, y: DVectorViewMut<T>, x: DVectorView<T>) -> Result<(), Box<dyn Error>>;
}
18
19impl<'a, T, A> LinearOperator<T> for &'a A
20where
21    T: Scalar,
22    A: ?Sized + LinearOperator<T>,
23{
24    fn apply(&self, y: DVectorViewMut<T>, x: DVectorView<T>) -> Result<(), Box<dyn Error>> {
25        <A as LinearOperator<T>>::apply(self, y, x)
26    }
27}
28
29impl<T, R, C, S> LinearOperator<T> for Matrix<T, R, C, S>
30where
31    T: Scalar + One + Zero + ClosedMul + ClosedAdd,
32    R: Dim,
33    C: Dim,
34    S: Storage<T, R, C>,
35    ShapeConstraint: DimEq<Dyn, R> + DimEq<C, Dyn> + AreMultipliable<R, C, Dyn, U1>,
36{
37    fn apply(&self, mut y: DVectorViewMut<T>, x: DVectorView<T>) -> Result<(), Box<dyn Error>> {
38        y.gemv(T::one(), self, &x, T::zero());
39        Ok(())
40    }
41}
42
43impl<T> LinearOperator<T> for CsrMatrix<T>
44where
45    T: Scalar + Zero + One + ClosedMul + ClosedAdd,
46{
47    fn apply(&self, mut y: DVectorViewMut<T>, x: DVectorView<T>) -> Result<(), Box<dyn Error>> {
48        spmm_csr_dense(T::zero(), &mut y, T::one(), Op::NoOp(self), Op::NoOp(&x));
49        Ok(())
50    }
51}
52
53pub struct IdentityOperator;
54
55impl<T: Scalar> LinearOperator<T> for IdentityOperator {
56    fn apply(&self, mut y: DVectorViewMut<T>, x: DVectorView<T>) -> Result<(), Box<dyn Error>> {
57        y.copy_from(&x);
58        Ok(())
59    }
60}
61
/// Decides when the Conjugate Gradient iteration should stop.
pub trait CgStoppingCriterion<T: Scalar> {
    /// Called by CG at the start of a new solve.
    fn reset(&self, _a: &dyn LinearOperator<T>, _x: DVectorView<T>, _b: DVectorView<T>) {}

    /// Reports whether the current iterate is accepted as converged.
    ///
    /// The solver passes the operator `a`, the current iterate `x`, the right-hand side `b`,
    /// the norm of `b`, the number of updates made to `x` so far, and CG's running
    /// (approximate) residual.
    ///
    /// Returning `Err` aborts the solve with the given error kind.
    fn has_converged(
        &self,
        a: &dyn LinearOperator<T>,
        x: DVectorView<T>,
        b: DVectorView<T>,
        b_norm: T,
        iteration: usize,
        approx_residual: DVectorView<T>,
    ) -> Result<bool, SolveErrorKind>;
}
76
77/// Relative residual tolerance ||r|| <= tol * ||b||.
78///
79/// Note that we use the *approximate* residual given by Conjugate-Gradient. For ill-conditioned
80/// problems, it is possible that CG's residual converges, but the real residual does not.
81/// However, in these cases, it is often the case that CG in any case is unable to obtain
82/// a more accurate solution, and a better preconditioner would be required if a high-resolution
83/// solution is desired.
84#[derive(Debug)]
85pub struct RelativeResidualCriterion<T: Scalar> {
86    tol: T,
87}
88
89impl<T: Scalar + Zero> RelativeResidualCriterion<T> {
90    pub fn new(tol: T) -> Self {
91        Self { tol }
92    }
93}
94
95impl Default for RelativeResidualCriterion<f64> {
96    fn default() -> Self {
97        Self::new(1e-8)
98    }
99}
100
101impl Default for RelativeResidualCriterion<f32> {
102    fn default() -> Self {
103        Self::new(1e-4)
104    }
105}
106
107impl<T> CgStoppingCriterion<T> for RelativeResidualCriterion<T>
108where
109    T: Real,
110{
111    fn has_converged(
112        &self,
113        _a: &dyn LinearOperator<T>,
114        _x: DVectorView<T>,
115        _b: DVectorView<T>,
116        b_norm: T,
117        _iteration: usize,
118        approx_residual: DVectorView<T>,
119    ) -> Result<bool, SolveErrorKind> {
120        let r_approx_norm = approx_residual.norm();
121        let converged = r_approx_norm <= self.tol * b_norm;
122        Ok(converged)
123    }
124}
125
126#[derive(Debug, Clone)]
127#[allow(non_snake_case)]
128pub struct CgWorkspace<T: Scalar> {
129    r: DVector<T>,
130    z: DVector<T>,
131    p: DVector<T>,
132    Ap: DVector<T>,
133}
134
135#[allow(non_snake_case)]
136struct Buffers<'a, T: Scalar> {
137    r: &'a mut DVector<T>,
138    z: &'a mut DVector<T>,
139    p: &'a mut DVector<T>,
140    Ap: &'a mut DVector<T>,
141}
142
143impl<T: Scalar + Zero> Default for CgWorkspace<T> {
144    fn default() -> Self {
145        Self {
146            r: DVector::zeros(0),
147            z: DVector::zeros(0),
148            p: DVector::zeros(0),
149            Ap: DVector::zeros(0),
150        }
151    }
152}
153
154impl<T: Scalar + Zero> CgWorkspace<T> {
155    fn prepare_buffers(&mut self, dim: usize) -> Buffers<T> {
156        self.r.resize_vertically_mut(dim, T::zero());
157        self.z.resize_vertically_mut(dim, T::zero());
158        self.p.resize_vertically_mut(dim, T::zero());
159        self.Ap.resize_vertically_mut(dim, T::zero());
160        Buffers {
161            r: &mut self.r,
162            z: &mut self.z,
163            p: &mut self.p,
164            Ap: &mut self.Ap,
165        }
166    }
167}
168
/// Either an owned value or a mutable borrow of one, usable uniformly through `Deref`.
#[derive(Debug)]
enum OwnedOrMutRef<'a, T> {
    /// The value is stored inline.
    Owned(T),
    /// The value lives elsewhere and is borrowed mutably.
    MutRef(&'a mut T),
}

impl<'a, T> Deref for OwnedOrMutRef<'a, T> {
    type Target = T;

    fn deref(&self) -> &T {
        match *self {
            OwnedOrMutRef::Owned(ref value) => value,
            OwnedOrMutRef::MutRef(ref borrowed) => &**borrowed,
        }
    }
}

impl<'a, T> DerefMut for OwnedOrMutRef<'a, T> {
    fn deref_mut(&mut self) -> &mut T {
        match *self {
            OwnedOrMutRef::Owned(ref mut value) => value,
            OwnedOrMutRef::MutRef(ref mut borrowed) => &mut **borrowed,
        }
    }
}
194
/// Conjugate Gradient solver, configured through its `with_*` builder methods.
///
/// The operator and stopping criterion start out as `()` placeholders and are replaced by
/// `with_operator` and `with_stopping_criterion`; solving is only available once both
/// implement the required traits.
#[derive(Debug)]
pub struct ConjugateGradient<'a, T, A, P, Criterion>
where
    T: Scalar,
{
    /// Scratch vectors, either owned by the solver or borrowed from the caller.
    workspace: OwnedOrMutRef<'a, CgWorkspace<T>>,
    /// The operator `A` of the system `Ax = b`.
    operator: A,
    /// The preconditioner applied to the residual.
    preconditioner: P,
    /// Decides when the iteration stops.
    stopping_criterion: Criterion,
    /// Optional cap on the number of iterations; `None` means no limit.
    max_iter: Option<usize>,
}
206
207impl<'a, T: Scalar + Zero> ConjugateGradient<'a, T, (), IdentityOperator, ()> {
208    pub fn new() -> Self {
209        Self {
210            workspace: OwnedOrMutRef::Owned(CgWorkspace::default()),
211            operator: (),
212            preconditioner: IdentityOperator,
213            stopping_criterion: (),
214            max_iter: None,
215        }
216    }
217}
218
219impl<'a, T: Scalar> ConjugateGradient<'a, T, (), IdentityOperator, ()> {
220    pub fn with_workspace(workspace: &'a mut CgWorkspace<T>) -> Self {
221        Self {
222            workspace: OwnedOrMutRef::MutRef(workspace),
223            operator: (),
224            preconditioner: IdentityOperator,
225            stopping_criterion: (),
226            max_iter: None,
227        }
228    }
229}
230
231impl<'a, T: Scalar, P, Criterion> ConjugateGradient<'a, T, (), P, Criterion> {
232    pub fn with_operator<A>(self, operator: A) -> ConjugateGradient<'a, T, A, P, Criterion> {
233        ConjugateGradient {
234            workspace: self.workspace,
235            operator,
236            preconditioner: self.preconditioner,
237            stopping_criterion: self.stopping_criterion,
238            max_iter: self.max_iter,
239        }
240    }
241}
242
243impl<'a, T: Scalar, A, P, Criterion> ConjugateGradient<'a, T, A, P, Criterion> {
244    pub fn with_preconditioner<P2>(self, preconditioner: P2) -> ConjugateGradient<'a, T, A, P2, Criterion> {
245        ConjugateGradient {
246            workspace: self.workspace,
247            operator: self.operator,
248            preconditioner,
249            stopping_criterion: self.stopping_criterion,
250            max_iter: self.max_iter,
251        }
252    }
253
254    pub fn with_max_iter(self, max_iter: usize) -> Self {
255        Self {
256            max_iter: Some(max_iter),
257            ..self
258        }
259    }
260}
261
262impl<'a, T: Scalar, A, P> ConjugateGradient<'a, T, A, P, ()> {
263    pub fn with_stopping_criterion<Criterion>(
264        self,
265        stopping_criterion: Criterion,
266    ) -> ConjugateGradient<'a, T, A, P, Criterion> {
267        ConjugateGradient {
268            workspace: self.workspace,
269            operator: self.operator,
270            preconditioner: self.preconditioner,
271            stopping_criterion,
272            max_iter: self.max_iter,
273        }
274    }
275}
276
/// The reason a CG solve failed.
#[derive(Debug)]
#[non_exhaustive]
pub enum SolveErrorKind {
    /// Applying the operator returned an error.
    OperatorError(Box<dyn Error>),
    /// Applying the preconditioner returned an error.
    PreconditionerError(Box<dyn Error>),
    /// Evaluating the stopping criterion returned an error.
    StoppingCriterionError(Box<dyn Error>),
    /// The operator does not appear to be positive definite.
    IndefiniteOperator,
    /// The preconditioner does not appear to be positive definite.
    IndefinitePreconditioner,
    /// The iteration limit was reached before the stopping criterion was satisfied.
    MaxIterationsReached { max_iter: usize },
}

impl fmt::Display for SolveErrorKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::OperatorError(err) => {
                write!(f, "Error applying operator: ")?;
                fmt::Display::fmt(err, f)
            }
            Self::PreconditionerError(err) => {
                write!(f, "Error applying preconditioner: ")?;
                fmt::Display::fmt(err, f)
            }
            Self::StoppingCriterionError(err) => {
                write!(f, "Error evaluating stopping criterion: ")?;
                fmt::Display::fmt(err, f)
            }
            // These variants carry no payload, so the message must be a complete sentence
            // rather than end in a dangling ": ".
            Self::IndefiniteOperator => write!(f, "Operator appears to be indefinite."),
            Self::IndefinitePreconditioner => write!(f, "Indefinite preconditioner."),
            Self::MaxIterationsReached { max_iter } => {
                write!(f, "Max iterations ({}) reached.", max_iter)
            }
        }
    }
}
311
312#[non_exhaustive]
313#[derive(Debug)]
314pub struct SolveError<T> {
315    pub output: CgOutput<T>,
316    pub kind: SolveErrorKind,
317}
318
319impl<T> SolveError<T> {
320    fn new(output: CgOutput<T>, kind: SolveErrorKind) -> Self {
321        Self { output, kind }
322    }
323}
324
325impl<T> fmt::Display for SolveError<T> {
326    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
327        write!(f, "CG solve failed after {}", self.output.num_iterations)?;
328        write!(f, "Error: {}", self.kind)
329    }
330}
331
332impl<T: fmt::Debug> std::error::Error for SolveError<T> {}
333
334/// y = Ax
335fn apply_operator<'a, T, A>(
336    y: impl Into<DVectorViewMut<'a, T>>,
337    a: &'a A,
338    x: impl Into<DVectorView<'a, T>>,
339) -> Result<(), Box<dyn Error>>
340where
341    T: Scalar,
342    A: LinearOperator<T>,
343{
344    a.apply(y.into(), x.into())
345}
346
/// Statistics reported by a CG solve.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct CgOutput<T> {
    /// Number of iterations of the solver.
    ///
    /// Corresponds to the number of updates made to the (initial) solution vector.
    pub num_iterations: usize,
    // Ties the output to the scalar type without storing a value of it.
    marker: PhantomData<T>,
}
356
357impl<'a, T, A, P, Criterion> ConjugateGradient<'a, T, A, P, Criterion>
358where
359    T: Real,
360    A: LinearOperator<T>,
361    P: LinearOperator<T>,
362    Criterion: CgStoppingCriterion<T>,
363{
364    pub fn solve_with_guess<'b>(
365        &mut self,
366        b: impl Into<DVectorView<'b, T>>,
367        x: impl Into<DVectorViewMut<'b, T>>,
368    ) -> Result<CgOutput<T>, SolveError<T>> {
369        self.solve_with_guess_(b.into(), x.into())
370    }
371
372    #[allow(non_snake_case)]
373    fn solve_with_guess_(&mut self, b: DVectorView<T>, mut x: DVectorViewMut<T>) -> Result<CgOutput<T>, SolveError<T>> {
374        use SolveErrorKind::*;
375        assert_eq!(b.len(), x.len());
376
377        let mut output = CgOutput {
378            num_iterations: 0,
379            marker: PhantomData,
380        };
381
382        let Buffers { r, z, p, Ap } = self.workspace.prepare_buffers(x.len());
383
384        // r = b - Ax
385        // First: r <- Ax
386        if let Err(err) = apply_operator(&mut *r, &self.operator, &x) {
387            return Err(SolveError::new(output, OperatorError(err)));
388        }
389        // Second: r <- b - r
390        r.zip_apply(&b, |r_i, b_i| *r_i = b_i - r_i.clone());
391
392        // z = Pr
393        if let Err(err) = apply_operator(&mut *z, &self.preconditioner, &*r) {
394            return Err(SolveError::new(output, PreconditionerError(err)));
395        }
396
397        // p = z
398        p.copy_from(&z);
399
400        let mut zTr = z.dot(r);
401        let mut pAp;
402
403        let b_norm = b.norm();
404
405        if b_norm == T::zero() {
406            x.fill(T::zero());
407            return Ok(output);
408        }
409
410        loop {
411            // TODO: Can we simplify this monstronsity?
412            let convergence = self.stopping_criterion.has_converged(
413                &self.operator,
414                (&x).into(),
415                (&b).into(),
416                b_norm,
417                output.num_iterations,
418                (&*r).into(),
419            );
420
421            let has_converged = match convergence {
422                Ok(converged) => converged,
423                Err(error_kind) => return Err(SolveError::new(output, error_kind)),
424            };
425
426            if has_converged {
427                break;
428            } else if let Some(max_iter) = self.max_iter {
429                if output.num_iterations >= max_iter {
430                    return Err(SolveError::new(output, MaxIterationsReached { max_iter }));
431                }
432            }
433
434            // Ap = A * p
435            if let Err(err) = apply_operator(&mut *Ap, &self.operator, &*p) {
436                return Err(SolveError::new(output, OperatorError(err)));
437            }
438            pAp = p.dot(&Ap);
439
440            if pAp <= T::zero() {
441                return Err(SolveError {
442                    output,
443                    kind: SolveErrorKind::IndefiniteOperator,
444                });
445            }
446            if zTr <= T::zero() {
447                return Err(SolveError {
448                    output,
449                    kind: SolveErrorKind::IndefinitePreconditioner,
450                });
451            }
452
453            let alpha = zTr / pAp;
454            // x <- x + alpha * p
455            x.zip_apply(&*p, |x_i, p_i| *x_i += alpha * p_i);
456            // r <- r - alpha * Ap
457            r.zip_apply(&*Ap, |r_i, Ap_i| *r_i -= alpha * Ap_i);
458
459            // Number of iterations corresponds to number of updates to the x vector
460            output.num_iterations += 1;
461
462            // z <- P r
463            if let Err(err) = apply_operator(&mut *z, &self.preconditioner, &*r) {
464                return Err(SolveError::new(output, PreconditionerError(err)));
465            }
466            let zTr_next = z.dot(&*r);
467            let beta = zTr_next / zTr;
468
469            // p <- beta * p + z
470            p.zip_apply(&*z, |p_i, z_i| {
471                *p_i *= beta;
472                *p_i += z_i;
473            });
474
475            zTr = zTr_next;
476        }
477
478        Ok(output)
479    }
480}