arr_rs/linalg/operations/
solving_inverting.rs1use crate::{
2 core::prelude::*,
3 errors::prelude::*,
4 extensions::prelude::*,
5 linalg::prelude::*,
6 numeric::prelude::*,
7 validators::prelude::*,
8};
9
10pub trait ArrayLinalgSolvingInvertingProducts<N: NumericOps> where Self: Sized + Clone {
12
13 fn solve(&self, other: &Array<N>) -> Result<Array<N>, ArrayError>;
33}
34
35impl <N: NumericOps> ArrayLinalgSolvingInvertingProducts<N> for Array<N> {
36
37 fn solve(&self, other: &Self) -> Result<Self, ArrayError> {
38 self.is_dim_supported(&[2])?;
39 let n = self.get_shape()?[0];
40 self.is_square()?;
41 other.get_shape()?[0].is_equal(&n)?;
42
43 if self.det()?[0].to_f64().abs() < 1e-12 {
44 return Err(ArrayError::SingularMatrix);
45 };
46
47 let mut arr_l = Self::identity(n)?.to_array_f64()?.to_matrix()?;
48 let mut arr_u = self.to_array_f64()?.to_matrix()?;
49
50 for j in 0..n {
51 let mut pivot_row = j;
52 for i in j + 1..n {
53 if arr_u[i][j].abs() > arr_u[pivot_row][j].abs() { pivot_row = i; }
54 }
55
56 if pivot_row != j {
57 let tmp = arr_u[pivot_row].clone();
58 arr_u[pivot_row] = arr_u[j].clone();
59 arr_u[j] = tmp;
60
61 let tmp = arr_l[pivot_row].clone();
62 for (idx, item) in tmp.iter().enumerate().take(j) {
63 arr_l[pivot_row][idx] = arr_l[j][idx];
64 arr_l[j][idx] = *item;
65 }
66 }
67 for i in j + 1..n {
68 let factor = arr_u[i][j] / arr_u[j][j];
69 arr_l[i][j] = factor;
70 let uu = ((arr_u[j][j..].to_vec().to_array()? * factor)?).get_elements()?;
71 for jj in j..n {
72 arr_u[i][jj] -= uu[jj - j];
73 }
74 }
75 }
76
77 let other = other.to_array_f64()?;
78 let mut arr_y = Array::<f64>::zeros_like(&other)?.get_rows()?;
79 let arr_b = other.to_array_f64()?.get_rows()?;
80 for i in 0..n {
81 let l_tmp = arr_l[i][..i].to_vec().to_array()?;
82 let y_tmp = arr_y[..i].iter().flatten().copied().collect::<Vec<f64>>().to_array()?;
83 let dot = l_tmp.dot(&y_tmp).unwrap_or(Array::flat(vec![0.; arr_b[0].len()?])?);
84 arr_y[i] = arr_b[i].broadcast_to(dot.get_shape()?)? - dot;
85 }
86
87 let mut arr_x = Array::<f64>::zeros_like(&other)?.get_rows()?;
88 for i in (0..n).rev() {
89 let u_tmp = arr_u[i][i + 1..].to_vec().to_array()?;
90 let x_tmp = arr_x[i + 1..].iter().flatten().copied().collect::<Vec<f64>>().to_array()?;
91 let dot = u_tmp.dot(&x_tmp).unwrap_or(Array::flat(vec![0.; arr_b[0].len()?])?);
92 arr_x[i] = ((arr_y[i].clone() - dot) / arr_u[i][i])?;
93 }
94
95 arr_x.into_iter()
96 .flatten()
97 .collect::<Array<_>>()
98 .to_array_num()
99 .reshape(&other.get_shape()?)
100 }
101}
102
103impl <N: NumericOps> ArrayLinalgSolvingInvertingProducts<N> for Result<Array<N>, ArrayError> {
104
105 fn solve(&self, other: &Array<N>) -> Self {
106 self.clone()?.solve(other)
107 }
108}