iterative_solvers/solver.rs
//! Iterative-solver abstractions: the [`IterativeSolver`] trait and the [`LinearSystem`] wrapper for a linear system `Ax = b`.
2
3use crate::{IterSolverError, IterSolverResult};
4use nalgebra::{DMatrix, DVector};
5
6/// A linear system `Ax = b` with a matrix `A` and a right-hand side vector `b`.
7pub struct LinearSystem {
8 mat: DMatrix<f64>,
9 rhs: DVector<f64>,
10}
11
12impl LinearSystem {
13 /// Create a new `LinearSystem` with a matrix `A` and a right-hand side vector `b`.
14 ///
15 /// # Arguments
16 ///
17 /// * `mat` - The matrix `A`.
18 /// * `rhs` - The right-hand side vector `b`.
19 ///
20 /// # Errors
21 ///
22 /// Returns an error if the matrix is not square or if the matrix and the right-hand side vector do not match.
23 ///
24 /// # Examples
25 ///
26 /// ```rust
27 /// use nalgebra::{DMatrix, DVector};
28 /// use iterative_solvers::LinearProblem;
29 ///
30 /// let mat = DMatrix::from_row_slice(2, 2, &[4.0, 1.0, 1.0, 3.0]);
31 /// let rhs = DVector::from_vec(vec![1.0, 2.0]);
32 /// let problem = LinearProblem::new(mat, rhs).unwrap();
33 /// ```
34 pub fn new(mat: DMatrix<f64>, rhs: DVector<f64>) -> IterSolverResult<Self> {
35 if !mat.is_square() {
36 return Err(IterSolverError::DimensionError(format!(
37 "The matrix is not square, whose shape is ({}, {})",
38 mat.shape().0,
39 mat.shape().1
40 )));
41 }
42 if mat.nrows() != rhs.len() {
43 return Err(IterSolverError::DimensionError(format!(
44 "The matrix with order {}, and the rhs with length {}, do not match",
45 mat.nrows(),
46 rhs.len()
47 )));
48 }
49 Ok(Self { mat, rhs })
50 }
51
52 /// Solve the linear problem `Ax = b` using the given solver.
53 ///
54 /// # Arguments
55 ///
56 /// * `solver` - The solver to use.
57 ///
58 /// # Errors
59 ///
60 /// Returns an error if the solver fails to solve the linear problem.
61 ///
62 /// # Examples
63 ///
64 /// ```rust
65 /// use nalgebra::{DMatrix, DVector};
66 /// use iterative_solvers::{LinearProblem, CG};
67 ///
68 /// let mat = DMatrix::from_row_slice(2, 2, &[4.0, 1.0, 1.0, 3.0]);
69 /// let rhs = DVector::from_vec(vec![1.0, 2.0]);
70 /// let problem = LinearProblem::new(mat, rhs).unwrap();
71 /// let solver = CG::new(&mat, &rhs, 1e-6).unwrap();
72 /// let solution = problem.solve(solver).unwrap();
73 /// ```
74 pub fn solve(&self, solver: impl IterativeSolver) -> IterSolverResult<DVector<f64>> {
75 solver.solve(&self.mat, &self.rhs)
76 }
77}
78
79/// A trait for iterative solvers.
80pub trait IterativeSolver {
81 /// Solve the linear problem `Ax = b` using the given solver.
82 ///
83 /// # Arguments
84 ///
85 /// * `mat` - The matrix `A`.
86 /// * `rhs` - The right-hand side vector `b`.
87 ///
88 /// # Returns
89 ///
90 /// A `DVector` containing the solution to the linear problem.
91 fn solve(&self, mat: &DMatrix<f64>, rhs: &DVector<f64>) -> IterSolverResult<DVector<f64>>;
92}