// deepcorr_normalization/cosine.rs

/*
This file is part of DeepCorr.

DeepCorr is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free
Software Foundation, either version 3 of the License, or any later version.

DeepCorr is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with
DeepCorr. If not, see <https://www.gnu.org/licenses/>.
*/

16use ndarray::{Array2, Axis};
17use crate::NormError;
18
/// Normalizes each row of a 2-D `f64` matrix to unit Euclidean (L2) length.
///
/// Once rows have unit length, the dot product of any two of them equals
/// their cosine similarity.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CosineNormalizer {
    /// Rows whose L2 norm is strictly below this threshold are rejected as
    /// zero-magnitude instead of being divided by a (near-)zero norm.
    pub epsilon: f64,
}

23impl Default for CosineNormalizer {
24    fn default() -> Self {
25        Self { epsilon: 1e-10 }
26    }
27}
28
29impl CosineNormalizer {
30    pub fn new(epsilon: f64) -> Self {
31        Self { epsilon }
32    }
33
34    pub fn normalize(&self, data: &Array2<f64>) -> Result<Array2<f64>, NormError> {
35        if data.is_empty() {
36            return Err(NormError::EmptyInput);
37        }
38
39        let mut normalized = data.clone();
40
41        for (idx, mut row) in normalized.axis_iter_mut(Axis(0)).enumerate() {
42            let norm = row.mapv(|x| x.powi(2)).sum().sqrt();
43
44            if norm < self.epsilon {
45                return Err(NormError::ZeroMagnitude(idx));
46            }
47
48            row.mapv_inplace(|x| x / norm);
49        }
50
51        Ok(normalized)
52    }
53}