arr_rs/linalg/operations/decompositions.rs

1use crate::{
2    core::prelude::*,
3    errors::prelude::*,
4    linalg::prelude::*,
5    numeric::prelude::*,
6    validators::prelude::*,
7};
8
9/// `ArrayTrait` - Array Linalg Decompositions functions
10pub trait ArrayLinalgDecompositions<N: NumericOps> where Self: Sized + Clone {
11
12    /// Compute the qr factorization of a matrix
13    ///
14    /// # Examples
15    ///
16    /// ```
17    /// use arr_rs::prelude::*;
18    ///
19    /// let array = Array::new(vec![1., 2., 3., 4., 5., 6., 7., 8., 9.], vec![3, 3]);
20    /// let result = array.qr().unwrap();
21    /// let (q, r) = &result.clone()[0];
22    ///
23    /// assert_eq!(q, &Array::new(vec![0.12309149097933272, 0.9045340337332908, 0.1111111111111111, 0.4923659639173309, 0.30151134457776335, 0.4444444444444444, 0.8616404368553291, -0.30151134457776435, 0.8888888888888888], vec![3, 3]).unwrap());
24    /// assert_eq!(r, &Array::new(vec![8.12403840463596, 9.601136296387953, 11.078234188139945, -6.494804694057166e-15, 0.9045340337332837, 1.809068067466573, 8.11111111111111, 9.555555555555555, 11.], vec![3, 3]).unwrap());
25    /// ```
26    ///
27    /// # Errors
28    ///
29    /// may returns `ArrayError`
30    fn qr(&self) -> LinalgResult<N>;
31}
32
33impl <N: NumericOps> ArrayLinalgDecompositions<N> for Array<N> {
34
35    fn qr(&self) -> LinalgResult<N> {
36        self.is_dim_unsupported(&[0, 1])?;
37        self.is_square()?;
38        if self.ndim()? == 2 {
39            Ok(vec![Self::gram_schmidt(self)?])
40        } else {
41            let shape = self.get_shape()?;
42            let sub_shape = shape[self.ndim()? - 2 ..].to_vec();
43            let qrs = self
44                .ravel()?
45                .split(self.len()? / sub_shape.iter().product::<usize>(), None)?
46                .iter()
47                .map(|arr| arr.reshape(&sub_shape).qr())
48                .collect::<Vec<Result<Vec<(Self, Self)>, _>>>()
49                .has_error()?.into_iter()
50                .flat_map(Result::unwrap)
51                .collect::<Vec<(Self, Self)>>();
52            Ok(qrs)
53        }
54    }
55}
56
57impl <N: NumericOps> ArrayLinalgDecompositions<N> for Result<Array<N>, ArrayError> {
58
59    fn qr(&self) -> LinalgResult<N> {
60        self.clone()?.qr()
61    }
62}
63
64trait QRHelper<N: NumericOps> {
65
66    fn gram_schmidt(arr: &Array<N>) -> Result<(Array<N>, Array<N>), ArrayError> {
67
68        fn project<N: NumericOps>(u: &Array<N>, a: &Array<N>) -> Result<Array<N>, ArrayError> {
69            let cols = u.len()?;
70            let result = u.inner(a).broadcast_to(vec![cols])?
71                / u.inner(u).broadcast_to(vec![cols])?
72                * u.clone();
73            Ok(result)
74        }
75
76        fn normalize<N: NumericOps>(arr: &Array<N>) -> Result<Array<N>, ArrayError> {
77            let norm = arr
78                .get_elements()?
79                .into_iter()
80                .map(|u| u.to_f64().powi(2))
81                .sum::<f64>().sqrt();
82            Array::single(N::from(norm))
83                .broadcast_to(vec![arr.len()?])
84        }
85
86        let (mut u_vecs, mut e_vecs) = (vec![], vec![]);
87        for col in arr.get_columns()? {
88            let mut a = col;
89            for u in &u_vecs { a -= project(u, &a)? }
90            u_vecs.push(a.clone());
91            let a_norm = normalize(&a)?;
92            e_vecs.push(a / a_norm);
93        }
94
95        let q = Array::concatenate(e_vecs, None)
96            .reshape(&arr.get_shape()?)
97            .transpose(None)?;
98        let r = q
99            .transpose(None)
100            .dot(arr)?;
101
102        Ok((q, r))
103    }
104}
105
106impl <N: NumericOps> QRHelper<N> for Array<N> {}