arr_rs/linalg/operations/
norms.rs

1use crate::{
2    core::prelude::*,
3    errors::prelude::*,
4    linalg::prelude::*,
5    math::prelude::*,
6    numeric::prelude::*,
7    validators::prelude::*,
8};
9
10/// `ArrayTrait` - Array Linalg Norms functions
11pub trait ArrayLinalgNorms<N: NumericOps> where Self: Sized + Clone {
12
13    /// Calculates matrix or vector norm
14    ///
15    /// # Arguments
16    ///
17    /// * `ord` - order of the norm: {non-zero int, inf, -inf, `fro`, `nuc`}. optional
18    /// * `axis` - axis along which vector norms are to be calculated. optional
19    /// * `keepdims` - if true, the result will broadcast correctly against the input. optional
20    ///
21    /// # Examples
22    ///
23    /// ```
24    /// use arr_rs::prelude::*;
25    ///
26    /// let array_a = Array::arange(-4., 4., None);
27    /// let array_b = array_a.reshape(&[3, 3]);
28    ///
29    /// let expected = Array::single(7.745966692414834);
30    /// assert_eq!(expected, array_a.norm(None::<NormOrd>, None, None));
31    /// assert_eq!(expected, array_b.norm(None::<NormOrd>, None, None));
32    /// ```
33    ///
34    /// # Errors
35    ///
36    /// may returns `ArrayError`
37    fn norm(&self, ord: Option<impl NormOrdType>, axis: Option<Vec<isize>>, keepdims: Option<bool>) -> Result<Array<N>, ArrayError>;
38
39    /// Compute the determinant of an array
40    ///
41    /// # Examples
42    ///
43    /// ```
44    /// use arr_rs::prelude::*;
45    ///
46    /// assert_eq!(Array::single(-14), Array::new(vec![3, 8, 4, 6], vec![2, 2]).det());
47    /// ```
48    ///
49    /// # Errors
50    ///
51    /// may returns `ArrayError`
52    fn det(&self) -> Result<Array<N>, ArrayError>;
53}
54
55impl <N: NumericOps> ArrayLinalgNorms<N> for Array<N> {
56
57    fn norm(&self, ord: Option<impl NormOrdType>, axis: Option<Vec<isize>>, keepdims: Option<bool>) -> Result<Self, ArrayError> {
58
59        fn norm_simple<N: NumericOps>(array: &Array<N>, keepdims: Option<bool>) -> Result<Array<N>, ArrayError> {
60            let ndim = array.ndim()?;
61            let array = array.ravel()?;
62            let result = array
63                .dot(&array)
64                .sqrt();
65            if keepdims.unwrap_or(false) { result.reshape(&[ndim; 1]) }
66            else { result }
67        }
68
69        let ndim = self.ndim()?;
70        if axis.is_none() {
71            match ord.clone() {
72                Some(ord) => {
73                    let ord = ord.to_ord()?;
74                    if (ndim == 2 && ord == NormOrd::Fro) || (ndim == 1 && ord == NormOrd::Int(2)) {
75                        return norm_simple(self, keepdims)
76                    }
77                },
78                None => return norm_simple(self, keepdims)
79            }
80        }
81
82        let axis = axis.unwrap_or_else(|| (0..ndim.to_isize()).collect());
83        match axis.len() {
84            1 => {
85                let axis = Some(axis[0]);
86                let ord = match ord {
87                    Some(o) => o.to_ord()?,
88                    None => NormOrd::Int(2),
89                };
90                match ord {
91                    NormOrd::Inf => self.abs().max(axis),
92                    NormOrd::NegInf => self.abs().min(axis),
93                    NormOrd::Int(0) => self.map(|&i| if i == N::zero() { N::zero() } else { N::one() }).sum(axis),
94                    NormOrd::Int(1) => self.abs().sum(axis),
95                    NormOrd::Int(2) => self.multiply(self).abs()?.sum(axis).sqrt(),
96                    NormOrd::Int(value) => self.abs()
97                        .float_power(&Self::single(N::from(value))?)
98                        .sum(axis)
99                        .float_power(&Self::single(N::from(value)).reciprocal()?),
100                    NormOrd::Fro | NormOrd::Nuc => {
101                        Err(ArrayError::ParameterError { param: "`ord`", message: "invalid norm order for vectors." })
102                    },
103                }
104            }
105            2 => {
106                let row_axis = self.normalize_axis(axis[0]).to_isize();
107                let col_axis = self.normalize_axis(axis[1]).to_isize();
108                if row_axis == col_axis {
109                    return Err(ArrayError::ParameterError { param: "`axis`", message: "duplicate axes given." });
110                }
111                let ord = match ord {
112                    Some(o) => o.to_ord()?,
113                    None => NormOrd::Fro,
114                };
115                let result = match ord {
116                    NormOrd::Int(1) => {
117                        let col_axis = if col_axis > row_axis { -col_axis } else { col_axis };
118                        self.abs().sum(Some(row_axis)).max(Some(col_axis))
119                    },
120                    NormOrd::Int(-1) => {
121                        let col_axis = if col_axis > row_axis { -col_axis } else { col_axis };
122                        self.abs().sum(Some(row_axis)).min(Some(col_axis))
123                    },
124                    NormOrd::Inf => {
125                        let row_axis = if row_axis > col_axis { -row_axis } else { row_axis };
126                        self.abs().sum(Some(col_axis)).max(Some(row_axis))
127                    },
128                    NormOrd::NegInf => {
129                        let row_axis = if row_axis > col_axis { -row_axis } else { row_axis };
130                        self.abs().sum(Some(col_axis)).min(Some(row_axis))
131                    },
132                    _ => {
133                        Err(ArrayError::ParameterError { param: "`ord`", message: "invalid norm order for vectors." })
134                    },
135                };
136                if keepdims.unwrap_or(false) {
137                    let mut new_shape = result.get_shape()?;
138                    new_shape.push(1);
139                    result.reshape(&new_shape)
140                } else {
141                    result
142                }
143            }
144            _ => Err(ArrayError::ParameterError { param: "`axis`", message: "improper number of dimensions to norm." })
145        }
146    }
147
148    fn det(&self) -> Result<Self<>, ArrayError> {
149        if self.ndim()? == 0 {
150            Err(ArrayError::MustBeAtLeast { value1: "`dimension`".to_string(), value2: "1".to_string() })
151        } else if self.ndim()? == 1 {
152            Ok(self.clone())
153        } else if self.ndim()? == 2 {
154            let shape = self.get_shape()?;
155            self.is_square()?;
156            if shape[0] == 2 {
157                Self::single(N::from(self[0].to_f64().mul_add(self[3].to_f64(), -self[1].to_f64() * self[2].to_f64())))
158            } else {
159                let elems = (0..shape[0])
160                    .map(|i| self[i * shape[0]].to_f64())
161                    .collect::<Vec<f64>>();
162                let dets = (0..shape[0])
163                    .map(|i| Self::minor(self, i, 0).det())
164                    .collect::<Vec<Result<Self, _>>>()
165                    .has_error()?.into_iter()
166                    .map(Result::unwrap)
167                    .map(|i| i[0].to_f64())
168                    .collect::<Vec<f64>>();
169                let result = elems.iter().zip(&dets)
170                    .enumerate()
171                    .map(|(i, (&e, &d))| e * f64::powi(-1., i.to_i32() + 2) * d)
172                    .sum::<f64>();
173                Self::single(N::from(result))
174            }
175        } else {
176            let shape = self.get_shape()?;
177            shape.is_square()?;
178            let sub_shape = shape[self.ndim()? - 2 ..].to_vec();
179            let dets = self
180                .ravel()?
181                .split(self.len()? / sub_shape.iter().product::<usize>(), None)?
182                .iter()
183                .map(|arr| arr.reshape(&sub_shape).det())
184                .collect::<Vec<Result<Self, _>>>()
185                .has_error()?.into_iter()
186                .flat_map(Result::unwrap)
187                .collect::<Self>();
188            Ok(dets)
189        }
190    }
191}
192
193impl <N: NumericOps> ArrayLinalgNorms<N> for Result<Array<N>, ArrayError> {
194
195    fn norm(&self, ord: Option<impl NormOrdType>, axis: Option<Vec<isize>>, keepdims: Option<bool>) -> Self {
196        self.clone()?.norm(ord, axis, keepdims)
197    }
198
199    fn det(&self) -> Self {
200        self.clone()?.det()
201    }
202}
203
204trait NormsHelper<N: NumericOps> {
205
206    fn minor(arr: &Array<N>, row: usize, col: usize) -> Result<Array<N>, ArrayError> {
207        arr.is_dim_supported(&[2])?;
208        if row >= arr.get_shape()?[0] || col >= arr.get_shape()?[1] {
209            return Err(ArrayError::OutOfBounds { value: "Row or column index out of bounds" })
210        }
211
212        let mut sub_shape = arr.get_shape()?;
213        sub_shape[arr.ndim()? - 1] -= 1;
214        sub_shape[arr.ndim()? - 2] -= 1;
215
216        let mut sub_elements = Vec::new();
217        for (i, &element) in arr.get_elements()?.iter().enumerate() {
218            if i / arr.get_shape()?[1] != row && i % arr.get_shape()?[1] != col {
219                sub_elements.push(element);
220            }
221        }
222
223        Array::new(sub_elements, sub_shape)
224    }
225}
226
227impl <N: NumericOps> NormsHelper<N> for Array<N> {}