Skip to main content

diffsol/linear_solver/
mod.rs

1use crate::{error::DiffsolError, Matrix, NonLinearOpJacobian};
2
3#[cfg(feature = "nalgebra")]
4pub mod nalgebra;
5
6#[cfg(feature = "faer")]
7pub mod faer;
8
9#[cfg(feature = "suitesparse")]
10pub mod suitesparse;
11
12#[cfg(feature = "cuda")]
13pub mod cuda;
14
// Re-export the LU solvers under backend-specific names.
// Each re-export is gated on the same cargo feature as the module it comes from, so that
// disabling a backend does not leave a `pub use` pointing at a module that was not compiled.
#[cfg(feature = "faer")]
pub use faer::lu::LU as FaerLU;
#[cfg(feature = "nalgebra")]
pub use nalgebra::lu::LU as NalgebraLU;
17
18/// A solver for the linear problem `Ax = b`, where `A` is a linear operator that is obtained by taking the linearisation of a nonlinear operator `C`
19pub trait LinearSolver<M: Matrix>: Default {
20    // sets the point at which the linearisation of the operator is evaluated
21    // the operator is assumed to have the same sparsity as that given to [Self::set_problem]
22    fn set_linearisation<C: NonLinearOpJacobian<V = M::V, T = M::T, M = M, C = M::C>>(
23        &mut self,
24        op: &C,
25        x: &M::V,
26        t: M::T,
27    );
28
29    /// Set the problem to be solved, any previous problem is discarded.
30    /// Any internal state of the solver is reset.
31    /// This function will normally set the sparsity pattern of the matrix to be solved.
32    fn set_problem<C: NonLinearOpJacobian<V = M::V, T = M::T, M = M, C = M::C>>(&mut self, op: &C);
33
34    /// Solve the problem `Ax = b` and return the solution `x`.
35    /// panics if [Self::set_linearisation] has not been called previously
36    fn solve(&self, b: &M::V) -> Result<M::V, DiffsolError> {
37        let mut b = b.clone();
38        self.solve_in_place(&mut b)?;
39        Ok(b)
40    }
41
42    fn solve_in_place(&self, b: &mut M::V) -> Result<(), DiffsolError>;
43}
44
/// A reference case for a linear solve: a right-hand side `b` paired with the solution `x`
/// that a solver is expected to produce for it.
pub struct LinearSolveSolution<V> {
    /// Expected solution of the linear system.
    pub x: V,
    /// Right-hand side of the linear system.
    pub b: V,
}

impl<V> LinearSolveSolution<V> {
    /// Creates a reference case from the right-hand side `b` and the expected solution `x`.
    ///
    /// Note the argument order is `(b, x)`, the reverse of the field declaration order.
    pub fn new(b: V, x: V) -> Self {
        LinearSolveSolution { b, x }
    }
}
55
56#[cfg(test)]
57pub mod tests {
58    use crate::{
59        linear_solver::{FaerLU, NalgebraLU},
60        matrix::dense_nalgebra_serial::NalgebraMat,
61        op::{closure::Closure, ParameterisedOp},
62        scalar::scale,
63        vector::VectorRef,
64        FaerMat, FaerVec, LinearSolver, Matrix, NalgebraVec, NonLinearOpJacobian, Op, Vector,
65    };
66    use num_traits::{FromPrimitive, One, Zero};
67
68    use super::LinearSolveSolution;
69
70    #[allow(clippy::type_complexity)]
71    pub fn linear_problem<M: Matrix + 'static>() -> (
72        Closure<
73            M,
74            impl Fn(&M::V, &M::V, M::T, &mut M::V),
75            impl Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
76        >,
77        M::T,
78        M::V,
79        Vec<LinearSolveSolution<M::V>>,
80    ) {
81        let diagonal = M::V::from_vec(
82            vec![M::T::from_f64(2.0).unwrap(), M::T::from_f64(2.0).unwrap()],
83            Default::default(),
84        );
85        let jac1 = M::from_diagonal(&diagonal);
86        let jac2 = M::from_diagonal(&diagonal);
87        let ctx = M::C::default();
88        let p = M::V::zeros(0, ctx.clone());
89        let mut op = Closure::new(
90            // f = J * x
91            move |x, _p, _t, y| jac1.gemv(M::T::one(), x, M::T::zero(), y),
92            move |_x, _p, _t, v, y| jac2.gemv(M::T::one(), v, M::T::zero(), y),
93            2,
94            2,
95            p.len(),
96            ctx.clone(),
97        );
98        op.calculate_sparsity(
99            &M::V::from_element(2, M::T::one(), ctx.clone()),
100            M::T::zero(),
101            &p,
102        );
103        let rtol = M::T::from_f64(1e-6).unwrap();
104        let atol = M::V::from_vec(
105            vec![M::T::from_f64(1e-6).unwrap(), M::T::from_f64(1e-6).unwrap()],
106            ctx.clone(),
107        );
108        let solns = vec![LinearSolveSolution::new(
109            M::V::from_vec(
110                vec![M::T::from_f64(2.0).unwrap(), M::T::from_f64(4.0).unwrap()],
111                ctx.clone(),
112            ),
113            M::V::from_vec(vec![M::T::one(), M::T::from_f64(2.0).unwrap()], ctx.clone()),
114        )];
115        (op, rtol, atol, solns)
116    }
117
118    pub fn test_linear_solver<'a, C>(
119        mut solver: impl LinearSolver<C::M>,
120        op: C,
121        rtol: C::T,
122        atol: &'a C::V,
123        solns: Vec<LinearSolveSolution<C::V>>,
124    ) where
125        C: NonLinearOpJacobian,
126        for<'b> &'b C::V: VectorRef<C::V>,
127    {
128        solver.set_problem(&op);
129        let x = C::V::zeros(op.nout(), op.context().clone());
130        let t = C::T::zero();
131        solver.set_linearisation(&op, &x, t);
132        for soln in solns {
133            let x = solver.solve(&soln.b).unwrap();
134            let tol = { &soln.x * scale(rtol) + atol };
135            x.assert_eq(&soln.x, &tol);
136        }
137    }
138
139    #[test]
140    fn test_lu_nalgebra() {
141        let (op, rtol, atol, solns) = linear_problem::<NalgebraMat<f64>>();
142        let p = NalgebraVec::zeros(0, *op.context());
143        let op = ParameterisedOp::new(&op, &p);
144        let s = NalgebraLU::default();
145        test_linear_solver(s, op, rtol, &atol, solns);
146    }
147    #[test]
148    fn test_lu_faer() {
149        let (op, rtol, atol, solns) = linear_problem::<FaerMat<f64>>();
150        let p = FaerVec::zeros(0, *op.context());
151        let op = ParameterisedOp::new(&op, &p);
152        let s = FaerLU::default();
153        test_linear_solver(s, op, rtol, &atol, solns);
154    }
155
156    #[cfg(feature = "cuda")]
157    #[test]
158    fn test_lu_cuda() {
159        use crate::{CudaLU, CudaMat, CudaVec};
160        let (op, rtol, atol, solns) = linear_problem::<CudaMat<f64>>();
161        let p = CudaVec::zeros(0, op.context().clone());
162        let op = ParameterisedOp::new(&op, &p);
163        let s = CudaLU::default();
164        test_linear_solver(s, op, rtol, &atol, solns);
165    }
166}