1use crate::{error::DiffsolError, Matrix, NonLinearOpJacobian};
2
3#[cfg(feature = "nalgebra")]
4pub mod nalgebra;
5
6#[cfg(feature = "faer")]
7pub mod faer;
8
9#[cfg(feature = "suitesparse")]
10pub mod suitesparse;
11
12#[cfg(feature = "cuda")]
13pub mod cuda;
14
15pub use faer::lu::LU as FaerLU;
16pub use nalgebra::lu::LU as NalgebraLU;
17
18pub trait LinearSolver<M: Matrix>: Default {
20 fn set_linearisation<C: NonLinearOpJacobian<V = M::V, T = M::T, M = M, C = M::C>>(
23 &mut self,
24 op: &C,
25 x: &M::V,
26 t: M::T,
27 );
28
29 fn set_problem<C: NonLinearOpJacobian<V = M::V, T = M::T, M = M, C = M::C>>(&mut self, op: &C);
33
34 fn solve(&self, b: &M::V) -> Result<M::V, DiffsolError> {
37 let mut b = b.clone();
38 self.solve_in_place(&mut b)?;
39 Ok(b)
40 }
41
42 fn solve_in_place(&self, b: &mut M::V) -> Result<(), DiffsolError>;
43}
44
/// A linear system with a known answer, used to check a linear solver.
///
/// `b` is the right-hand side and `x` is the expected solution of the system.
#[derive(Debug, Clone)]
pub struct LinearSolveSolution<V> {
    /// Expected solution of the system.
    pub x: V,
    /// Right-hand side of the system.
    pub b: V,
}

impl<V> LinearSolveSolution<V> {
    /// Creates a new pair from the right-hand side `b` and the expected solution `x`.
    ///
    /// Note that the argument order (`b` first, then `x`) is the reverse of the
    /// field declaration order.
    pub fn new(b: V, x: V) -> Self {
        Self { x, b }
    }
}
55
#[cfg(test)]
pub mod tests {
    use crate::{
        linear_solver::{FaerLU, NalgebraLU},
        matrix::dense_nalgebra_serial::NalgebraMat,
        op::{closure::Closure, ParameterisedOp},
        scalar::scale,
        vector::VectorRef,
        FaerMat, FaerVec, LinearSolver, Matrix, NalgebraVec, NonLinearOpJacobian, Op, Vector,
    };
    use num_traits::{FromPrimitive, One, Zero};

    use super::LinearSolveSolution;

    /// Builds a 2x2 linear test problem `F(x) = J x` with `J = diag(2, 2)`.
    ///
    /// Returns `(op, rtol, atol, solns)`:
    /// - `op`: a parameter-free [`Closure`] whose right-hand side and Jacobian
    ///   action are both multiplication by `J`;
    /// - `rtol`, `atol`: relative and absolute tolerances (both `1e-6`);
    /// - `solns`: right-hand sides `b` paired with the exact solution of `J x = b`.
    // The anonymous `impl Fn` closure types cannot be named, so the return type
    // cannot be shortened with a type alias.
    #[allow(clippy::type_complexity)]
    pub fn linear_problem<M: Matrix + 'static>() -> (
        Closure<
            M,
            impl Fn(&M::V, &M::V, M::T, &mut M::V),
            impl Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
        >,
        M::T,
        M::V,
        Vec<LinearSolveSolution<M::V>>,
    ) {
        let n = 2;
        let ctx = M::C::default();
        let two = M::T::from_f64(2.0).unwrap();
        let tol = M::T::from_f64(1e-6).unwrap();

        // Create every vector (and hence every matrix) on the same shared context,
        // rather than a separately defaulted one, so the operator and the data it
        // is applied to agree on where they live.
        let diagonal = M::V::from_element(n, two, ctx.clone());
        // Each closure below takes ownership of its own copy of the Jacobian.
        let jac1 = M::from_diagonal(&diagonal);
        let jac2 = M::from_diagonal(&diagonal);
        let p = M::V::zeros(0, ctx.clone());

        let mut op = Closure::new(
            // right-hand side: y = J x
            move |x, _p, _t, y| jac1.gemv(M::T::one(), x, M::T::zero(), y),
            // Jacobian-vector product: y = J v
            move |_x, _p, _t, v, y| jac2.gemv(M::T::one(), v, M::T::zero(), y),
            n,
            n,
            p.len(),
            ctx.clone(),
        );
        op.calculate_sparsity(
            &M::V::from_element(n, M::T::one(), ctx.clone()),
            M::T::zero(),
            &p,
        );

        let rtol = tol;
        let atol = M::V::from_element(n, tol, ctx.clone());
        // J = diag(2, 2), so b = [2, 4] has the exact solution x = [1, 2].
        let solns = vec![LinearSolveSolution::new(
            M::V::from_vec(vec![two, M::T::from_f64(4.0).unwrap()], ctx.clone()),
            M::V::from_vec(vec![M::T::one(), two], ctx.clone()),
        )];
        (op, rtol, atol, solns)
    }

    /// Exercises `solver` on `op` and checks every entry of `solns`.
    ///
    /// The solver is set up with `op`, linearised about `x = 0` at `t = 0`, and
    /// then asked to solve for each right-hand side `b`. Each result must match
    /// the expected `x` to within `x * rtol + atol`.
    ///
    /// # Panics
    /// Panics if a solve fails or a result falls outside the tolerance.
    pub fn test_linear_solver<C>(
        mut solver: impl LinearSolver<C::M>,
        op: C,
        rtol: C::T,
        atol: &C::V,
        solns: Vec<LinearSolveSolution<C::V>>,
    ) where
        C: NonLinearOpJacobian,
        for<'b> &'b C::V: VectorRef<C::V>,
    {
        solver.set_problem(&op);
        let x0 = C::V::zeros(op.nout(), op.context().clone());
        solver.set_linearisation(&op, &x0, C::T::zero());
        for soln in solns {
            let x = solver.solve(&soln.b).expect("linear solve failed");
            // NOTE(review): the tolerance scales `soln.x` itself rather than its
            // magnitude, which assumes non-negative expected solutions (true for
            // `linear_problem`) — confirm against `Vector::assert_eq` before
            // reusing this with other problems.
            let tol = &soln.x * scale(rtol) + atol;
            x.assert_eq(&soln.x, &tol);
        }
    }

    #[test]
    fn test_lu_nalgebra() {
        let (op, rtol, atol, solns) = linear_problem::<NalgebraMat<f64>>();
        // The problem has no parameters, so bind an empty parameter vector.
        let p = NalgebraVec::zeros(0, *op.context());
        let op = ParameterisedOp::new(&op, &p);
        let s = NalgebraLU::default();
        test_linear_solver(s, op, rtol, &atol, solns);
    }

    #[test]
    fn test_lu_faer() {
        let (op, rtol, atol, solns) = linear_problem::<FaerMat<f64>>();
        // The problem has no parameters, so bind an empty parameter vector.
        let p = FaerVec::zeros(0, *op.context());
        let op = ParameterisedOp::new(&op, &p);
        let s = FaerLU::default();
        test_linear_solver(s, op, rtol, &atol, solns);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_lu_cuda() {
        use crate::{CudaLU, CudaMat, CudaVec};
        let (op, rtol, atol, solns) = linear_problem::<CudaMat<f64>>();
        // The problem has no parameters, so bind an empty parameter vector.
        let p = CudaVec::zeros(0, op.context().clone());
        let op = ParameterisedOp::new(&op, &p);
        let s = CudaLU::default();
        test_linear_solver(s, op, rtol, &atol, solns);
    }
}