arr_rs/linalg/operations/
decompositions.rs1use crate::{
2 core::prelude::*,
3 errors::prelude::*,
4 linalg::prelude::*,
5 numeric::prelude::*,
6 validators::prelude::*,
7};
8
9pub trait ArrayLinalgDecompositions<N: NumericOps> where Self: Sized + Clone {
11
12 fn qr(&self) -> LinalgResult<N>;
31}
32
33impl <N: NumericOps> ArrayLinalgDecompositions<N> for Array<N> {
34
35 fn qr(&self) -> LinalgResult<N> {
36 self.is_dim_unsupported(&[0, 1])?;
37 self.is_square()?;
38 if self.ndim()? == 2 {
39 Ok(vec![Self::gram_schmidt(self)?])
40 } else {
41 let shape = self.get_shape()?;
42 let sub_shape = shape[self.ndim()? - 2 ..].to_vec();
43 let qrs = self
44 .ravel()?
45 .split(self.len()? / sub_shape.iter().product::<usize>(), None)?
46 .iter()
47 .map(|arr| arr.reshape(&sub_shape).qr())
48 .collect::<Vec<Result<Vec<(Self, Self)>, _>>>()
49 .has_error()?.into_iter()
50 .flat_map(Result::unwrap)
51 .collect::<Vec<(Self, Self)>>();
52 Ok(qrs)
53 }
54 }
55}
56
57impl <N: NumericOps> ArrayLinalgDecompositions<N> for Result<Array<N>, ArrayError> {
58
59 fn qr(&self) -> LinalgResult<N> {
60 self.clone()?.qr()
61 }
62}
63
64trait QRHelper<N: NumericOps> {
65
66 fn gram_schmidt(arr: &Array<N>) -> Result<(Array<N>, Array<N>), ArrayError> {
67
68 fn project<N: NumericOps>(u: &Array<N>, a: &Array<N>) -> Result<Array<N>, ArrayError> {
69 let cols = u.len()?;
70 let result = u.inner(a).broadcast_to(vec![cols])?
71 / u.inner(u).broadcast_to(vec![cols])?
72 * u.clone();
73 Ok(result)
74 }
75
76 fn normalize<N: NumericOps>(arr: &Array<N>) -> Result<Array<N>, ArrayError> {
77 let norm = arr
78 .get_elements()?
79 .into_iter()
80 .map(|u| u.to_f64().powi(2))
81 .sum::<f64>().sqrt();
82 Array::single(N::from(norm))
83 .broadcast_to(vec![arr.len()?])
84 }
85
86 let (mut u_vecs, mut e_vecs) = (vec![], vec![]);
87 for col in arr.get_columns()? {
88 let mut a = col;
89 for u in &u_vecs { a -= project(u, &a)? }
90 u_vecs.push(a.clone());
91 let a_norm = normalize(&a)?;
92 e_vecs.push(a / a_norm);
93 }
94
95 let q = Array::concatenate(e_vecs, None)
96 .reshape(&arr.get_shape()?)
97 .transpose(None)?;
98 let r = q
99 .transpose(None)
100 .dot(arr)?;
101
102 Ok((q, r))
103 }
104}
105
106impl <N: NumericOps> QRHelper<N> for Array<N> {}