arr_rs/linalg/operations/
solving_inverting.rs

1use crate::{
2    core::prelude::*,
3    errors::prelude::*,
4    extensions::prelude::*,
5    linalg::prelude::*,
6    numeric::prelude::*,
7    validators::prelude::*,
8};
9
10/// `ArrayTrait` - Array Linalg Solving equations and Inverting matrices functions
11pub trait ArrayLinalgSolvingInvertingProducts<N: NumericOps> where Self: Sized + Clone {
12
13    /// Solve a linear matrix equation, or system of linear scalar equations
14    ///
15    /// # Arguments
16    ///
17    /// * `other` - other array to perform operations with
18    ///
19    /// # Examples
20    ///
21    /// ```
22    /// use arr_rs::prelude::*;
23    ///
24    /// let arr_1 = Array::new(vec![2., 1., 1., 3.], vec![2, 2]).unwrap();
25    /// let arr_2 = Array::new(vec![5., 8., 3., 6.], vec![2, 2]).unwrap();
26    /// assert_eq!(Array::new(vec![2.4, 3.6, 0.2, 0.8], vec![2, 2]), arr_1.solve(&arr_2));
27    /// ```
28    ///
29    /// # Errors
30    ///
31    /// may returns `ArrayError`
32    fn solve(&self, other: &Array<N>) -> Result<Array<N>, ArrayError>;
33}
34
35impl <N: NumericOps> ArrayLinalgSolvingInvertingProducts<N> for Array<N> {
36
37    fn solve(&self, other: &Self) -> Result<Self, ArrayError> {
38        self.is_dim_supported(&[2])?;
39        let n = self.get_shape()?[0];
40        self.is_square()?;
41        other.get_shape()?[0].is_equal(&n)?;
42
43        if self.det()?[0].to_f64().abs() < 1e-12 {
44            return Err(ArrayError::SingularMatrix);
45        };
46
47        let mut arr_l = Self::identity(n)?.to_array_f64()?.to_matrix()?;
48        let mut arr_u = self.to_array_f64()?.to_matrix()?;
49
50        for j in 0..n {
51            let mut pivot_row = j;
52            for i in j + 1..n {
53                if arr_u[i][j].abs() > arr_u[pivot_row][j].abs() { pivot_row = i; }
54            }
55
56            if pivot_row != j {
57                let tmp = arr_u[pivot_row].clone();
58                arr_u[pivot_row] = arr_u[j].clone();
59                arr_u[j] = tmp;
60
61                let tmp = arr_l[pivot_row].clone();
62                for (idx, item) in tmp.iter().enumerate().take(j) {
63                    arr_l[pivot_row][idx] = arr_l[j][idx];
64                    arr_l[j][idx] = *item;
65                }
66            }
67            for i in j + 1..n {
68                let factor = arr_u[i][j] / arr_u[j][j];
69                arr_l[i][j] = factor;
70                let uu = ((arr_u[j][j..].to_vec().to_array()? * factor)?).get_elements()?;
71                for jj in j..n {
72                    arr_u[i][jj] -= uu[jj - j];
73                }
74            }
75        }
76
77        let other = other.to_array_f64()?;
78        let mut arr_y = Array::<f64>::zeros_like(&other)?.get_rows()?;
79        let arr_b = other.to_array_f64()?.get_rows()?;
80        for i in 0..n {
81            let l_tmp = arr_l[i][..i].to_vec().to_array()?;
82            let y_tmp = arr_y[..i].iter().flatten().copied().collect::<Vec<f64>>().to_array()?;
83            let dot = l_tmp.dot(&y_tmp).unwrap_or(Array::flat(vec![0.; arr_b[0].len()?])?);
84            arr_y[i] = arr_b[i].broadcast_to(dot.get_shape()?)? - dot;
85        }
86
87        let mut arr_x = Array::<f64>::zeros_like(&other)?.get_rows()?;
88        for i in (0..n).rev() {
89            let u_tmp = arr_u[i][i + 1..].to_vec().to_array()?;
90            let x_tmp = arr_x[i + 1..].iter().flatten().copied().collect::<Vec<f64>>().to_array()?;
91            let dot = u_tmp.dot(&x_tmp).unwrap_or(Array::flat(vec![0.; arr_b[0].len()?])?);
92            arr_x[i] = ((arr_y[i].clone() - dot) / arr_u[i][i])?;
93        }
94
95        arr_x.into_iter()
96            .flatten()
97            .collect::<Array<_>>()
98            .to_array_num()
99            .reshape(&other.get_shape()?)
100    }
101}
102
103impl <N: NumericOps> ArrayLinalgSolvingInvertingProducts<N> for Result<Array<N>, ArrayError> {
104
105    fn solve(&self, other: &Array<N>) -> Self {
106        self.clone()?.solve(other)
107    }
108}