arr_rs/linalg/operations/
norms.rs1use crate::{
2 core::prelude::*,
3 errors::prelude::*,
4 linalg::prelude::*,
5 math::prelude::*,
6 numeric::prelude::*,
7 validators::prelude::*,
8};
9
10pub trait ArrayLinalgNorms<N: NumericOps> where Self: Sized + Clone {
12
13 fn norm(&self, ord: Option<impl NormOrdType>, axis: Option<Vec<isize>>, keepdims: Option<bool>) -> Result<Array<N>, ArrayError>;
38
39 fn det(&self) -> Result<Array<N>, ArrayError>;
53}
54
55impl <N: NumericOps> ArrayLinalgNorms<N> for Array<N> {
56
57 fn norm(&self, ord: Option<impl NormOrdType>, axis: Option<Vec<isize>>, keepdims: Option<bool>) -> Result<Self, ArrayError> {
58
59 fn norm_simple<N: NumericOps>(array: &Array<N>, keepdims: Option<bool>) -> Result<Array<N>, ArrayError> {
60 let ndim = array.ndim()?;
61 let array = array.ravel()?;
62 let result = array
63 .dot(&array)
64 .sqrt();
65 if keepdims.unwrap_or(false) { result.reshape(&[ndim; 1]) }
66 else { result }
67 }
68
69 let ndim = self.ndim()?;
70 if axis.is_none() {
71 match ord.clone() {
72 Some(ord) => {
73 let ord = ord.to_ord()?;
74 if (ndim == 2 && ord == NormOrd::Fro) || (ndim == 1 && ord == NormOrd::Int(2)) {
75 return norm_simple(self, keepdims)
76 }
77 },
78 None => return norm_simple(self, keepdims)
79 }
80 }
81
82 let axis = axis.unwrap_or_else(|| (0..ndim.to_isize()).collect());
83 match axis.len() {
84 1 => {
85 let axis = Some(axis[0]);
86 let ord = match ord {
87 Some(o) => o.to_ord()?,
88 None => NormOrd::Int(2),
89 };
90 match ord {
91 NormOrd::Inf => self.abs().max(axis),
92 NormOrd::NegInf => self.abs().min(axis),
93 NormOrd::Int(0) => self.map(|&i| if i == N::zero() { N::zero() } else { N::one() }).sum(axis),
94 NormOrd::Int(1) => self.abs().sum(axis),
95 NormOrd::Int(2) => self.multiply(self).abs()?.sum(axis).sqrt(),
96 NormOrd::Int(value) => self.abs()
97 .float_power(&Self::single(N::from(value))?)
98 .sum(axis)
99 .float_power(&Self::single(N::from(value)).reciprocal()?),
100 NormOrd::Fro | NormOrd::Nuc => {
101 Err(ArrayError::ParameterError { param: "`ord`", message: "invalid norm order for vectors." })
102 },
103 }
104 }
105 2 => {
106 let row_axis = self.normalize_axis(axis[0]).to_isize();
107 let col_axis = self.normalize_axis(axis[1]).to_isize();
108 if row_axis == col_axis {
109 return Err(ArrayError::ParameterError { param: "`axis`", message: "duplicate axes given." });
110 }
111 let ord = match ord {
112 Some(o) => o.to_ord()?,
113 None => NormOrd::Fro,
114 };
115 let result = match ord {
116 NormOrd::Int(1) => {
117 let col_axis = if col_axis > row_axis { -col_axis } else { col_axis };
118 self.abs().sum(Some(row_axis)).max(Some(col_axis))
119 },
120 NormOrd::Int(-1) => {
121 let col_axis = if col_axis > row_axis { -col_axis } else { col_axis };
122 self.abs().sum(Some(row_axis)).min(Some(col_axis))
123 },
124 NormOrd::Inf => {
125 let row_axis = if row_axis > col_axis { -row_axis } else { row_axis };
126 self.abs().sum(Some(col_axis)).max(Some(row_axis))
127 },
128 NormOrd::NegInf => {
129 let row_axis = if row_axis > col_axis { -row_axis } else { row_axis };
130 self.abs().sum(Some(col_axis)).min(Some(row_axis))
131 },
132 _ => {
133 Err(ArrayError::ParameterError { param: "`ord`", message: "invalid norm order for vectors." })
134 },
135 };
136 if keepdims.unwrap_or(false) {
137 let mut new_shape = result.get_shape()?;
138 new_shape.push(1);
139 result.reshape(&new_shape)
140 } else {
141 result
142 }
143 }
144 _ => Err(ArrayError::ParameterError { param: "`axis`", message: "improper number of dimensions to norm." })
145 }
146 }
147
148 fn det(&self) -> Result<Self<>, ArrayError> {
149 if self.ndim()? == 0 {
150 Err(ArrayError::MustBeAtLeast { value1: "`dimension`".to_string(), value2: "1".to_string() })
151 } else if self.ndim()? == 1 {
152 Ok(self.clone())
153 } else if self.ndim()? == 2 {
154 let shape = self.get_shape()?;
155 self.is_square()?;
156 if shape[0] == 2 {
157 Self::single(N::from(self[0].to_f64().mul_add(self[3].to_f64(), -self[1].to_f64() * self[2].to_f64())))
158 } else {
159 let elems = (0..shape[0])
160 .map(|i| self[i * shape[0]].to_f64())
161 .collect::<Vec<f64>>();
162 let dets = (0..shape[0])
163 .map(|i| Self::minor(self, i, 0).det())
164 .collect::<Vec<Result<Self, _>>>()
165 .has_error()?.into_iter()
166 .map(Result::unwrap)
167 .map(|i| i[0].to_f64())
168 .collect::<Vec<f64>>();
169 let result = elems.iter().zip(&dets)
170 .enumerate()
171 .map(|(i, (&e, &d))| e * f64::powi(-1., i.to_i32() + 2) * d)
172 .sum::<f64>();
173 Self::single(N::from(result))
174 }
175 } else {
176 let shape = self.get_shape()?;
177 shape.is_square()?;
178 let sub_shape = shape[self.ndim()? - 2 ..].to_vec();
179 let dets = self
180 .ravel()?
181 .split(self.len()? / sub_shape.iter().product::<usize>(), None)?
182 .iter()
183 .map(|arr| arr.reshape(&sub_shape).det())
184 .collect::<Vec<Result<Self, _>>>()
185 .has_error()?.into_iter()
186 .flat_map(Result::unwrap)
187 .collect::<Self>();
188 Ok(dets)
189 }
190 }
191}
192
193impl <N: NumericOps> ArrayLinalgNorms<N> for Result<Array<N>, ArrayError> {
194
195 fn norm(&self, ord: Option<impl NormOrdType>, axis: Option<Vec<isize>>, keepdims: Option<bool>) -> Self {
196 self.clone()?.norm(ord, axis, keepdims)
197 }
198
199 fn det(&self) -> Self {
200 self.clone()?.det()
201 }
202}
203
204trait NormsHelper<N: NumericOps> {
205
206 fn minor(arr: &Array<N>, row: usize, col: usize) -> Result<Array<N>, ArrayError> {
207 arr.is_dim_supported(&[2])?;
208 if row >= arr.get_shape()?[0] || col >= arr.get_shape()?[1] {
209 return Err(ArrayError::OutOfBounds { value: "Row or column index out of bounds" })
210 }
211
212 let mut sub_shape = arr.get_shape()?;
213 sub_shape[arr.ndim()? - 1] -= 1;
214 sub_shape[arr.ndim()? - 2] -= 1;
215
216 let mut sub_elements = Vec::new();
217 for (i, &element) in arr.get_elements()?.iter().enumerate() {
218 if i / arr.get_shape()?[1] != row && i % arr.get_shape()?[1] != col {
219 sub_elements.push(element);
220 }
221 }
222
223 Array::new(sub_elements, sub_shape)
224 }
225}
226
227impl <N: NumericOps> NormsHelper<N> for Array<N> {}