deepcorr_normalization/cosine.rs
1/*
2This file is part of DeepCorr.
3
4DeepCorr is free software: you can redistribute it and/or modify it under
5the terms of the GNU General Public License as published by the Free
6Software Foundation, either version 3 of the License, or any later version.
7
8DeepCorr is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
9without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
10PURPOSE. See the GNU General Public License for more details.
11
12You should have received a copy of the GNU General Public License along with
13DeepCorr. If not, see <https://www.gnu.org/licenses/>.
14*/
15
16use ndarray::{Array2, Axis};
17use crate::NormError;
18
/// Row-wise L2 normalizer for 2-D `f64` arrays.
///
/// After `normalize`, every row has unit Euclidean length, so the dot product of
/// two normalized rows is their cosine similarity.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CosineNormalizer {
    /// Rows whose L2 norm is strictly below this threshold are rejected by
    /// `normalize` instead of being divided by a (near-)zero norm.
    pub epsilon: f64,
}
22
23impl Default for CosineNormalizer {
24 fn default() -> Self {
25 Self { epsilon: 1e-10 }
26 }
27}
28
29impl CosineNormalizer {
30 pub fn new(epsilon: f64) -> Self {
31 Self { epsilon }
32 }
33
34 pub fn normalize(&self, data: &Array2<f64>) -> Result<Array2<f64>, NormError> {
35 if data.is_empty() {
36 return Err(NormError::EmptyInput);
37 }
38
39 let mut normalized = data.clone();
40
41 for (idx, mut row) in normalized.axis_iter_mut(Axis(0)).enumerate() {
42 let norm = row.mapv(|x| x.powi(2)).sum().sqrt();
43
44 if norm < self.epsilon {
45 return Err(NormError::ZeroMagnitude(idx));
46 }
47
48 row.mapv_inplace(|x| x / norm);
49 }
50
51 Ok(normalized)
52 }
53}