mdarray_linalg/solve.rs
//! Linear system solving for equations of the form AX = B, based on LU decomposition.
2use mdarray::{DSlice, DTensor, Layout};
3use thiserror::Error;
4
5/// Error types related to linear system solving
6#[derive(Debug, Error)]
7pub enum SolveError {
8 #[error("Backend error code: {0}")]
9 BackendError(i32),
10
11 #[error("Matrix is singular: U({diagonal},{diagonal}) is exactly zero")]
12 SingularMatrix { diagonal: i32 },
13
14 #[error("Invalid matrix dimensions")]
15 InvalidDimensions,
16}
17
18/// Holds the results of a linear system solve, including
19/// the solution matrix and permutation matrix
20pub struct SolveResult<T> {
21 pub x: DTensor<T, 2>,
22 pub p: DTensor<T, 2>,
23}
24
25/// Result type for linear system solving, returning either a
26/// `SolveResult` or a `SolveError`
27pub type SolveResultType<T> = Result<SolveResult<T>, SolveError>;
28
29/// Linear system solver using LU decomposition
30pub trait Solve<T> {
31 /// Solves linear system AX = b overwriting existing matrices
32 /// A is overwritten with its LU decomposition
33 /// B is overwritten with the solution X
34 /// P is filled with the permutation matrix such that A = P*L*U
35 /// Returns Ok(()) on success, Err(SolveError) on failure
36 fn solve_overwrite<La: Layout, Lb: Layout, Lp: Layout>(
37 &self,
38 a: &mut DSlice<T, 2, La>,
39 b: &mut DSlice<T, 2, Lb>,
40 p: &mut DSlice<T, 2, Lp>,
41 ) -> Result<(), SolveError>;
42
43 /// Solves linear system AX = B with new allocated solution matrix
44 /// A is modified (overwritten with LU decomposition)
45 /// Returns the solution X and P the permutation matrix, or error
46 fn solve<La: Layout, Lb: Layout>(
47 &self,
48 a: &mut DSlice<T, 2, La>,
49 b: &DSlice<T, 2, Lb>,
50 ) -> SolveResultType<T>;
51}