1use linfa::dataset::{AsTargets, DatasetBase, Float, WithLapack, WithoutLapack};
3use linfa::traits::Transformer;
4#[cfg(not(feature = "blas"))]
5use linfa_linalg::norm::Norm;
6use ndarray::{Array2, ArrayBase, Axis, Data, Ix2, Zip};
7#[cfg(feature = "blas")]
8use ndarray_linalg::norm::Norm;
9
10#[cfg(feature = "serde")]
11use serde_crate::{Deserialize, Serialize};
12
/// The norm used by [`NormScaler`] to measure the length of each sample (row)
/// before dividing the row by it.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq, Eq)]
enum Norms {
    /// L1 norm: the sum of the absolute values of the row's entries.
    L1,
    /// L2 (Euclidean) norm: the square root of the sum of squared entries.
    L2,
    /// Max norm: the largest absolute value among the row's entries.
    Max,
}
24
/// Norm scaler: rescales every sample (row) of a dataset so that it has unit
/// norm, where the norm is chosen at construction time.
///
/// Build one with [`NormScaler::l1`], [`NormScaler::l2`] or [`NormScaler::max`],
/// then apply it through the `Transformer` trait either to a bare
/// `Array2<F>` or to a `DatasetBase` (whose records are replaced by their
/// scaled version while targets, weights and feature names are carried over).
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NormScaler {
    // Which norm each row is divided by.
    norm: Norms,
}
50
51impl NormScaler {
52 pub fn l2() -> Self {
54 Self { norm: Norms::L2 }
55 }
56
57 pub fn l1() -> Self {
59 Self { norm: Norms::L1 }
60 }
61
62 pub fn max() -> Self {
64 Self { norm: Norms::Max }
65 }
66}
67
68impl<F: Float> Transformer<Array2<F>, Array2<F>> for NormScaler {
69 fn transform(&self, x: Array2<F>) -> Array2<F> {
71 let x = x.with_lapack();
73
74 let norms = match &self.norm {
75 Norms::L1 => x.map_axis(Axis(1), |row| F::cast(row.norm_l1())),
76 Norms::L2 => x.map_axis(Axis(1), |row| F::cast(row.norm_l2())),
77 Norms::Max => x.map_axis(Axis(1), |row| F::cast(row.norm_max())),
78 };
79
80 let mut x = x.without_lapack();
82
83 Zip::from(x.rows_mut())
84 .and(&norms)
85 .for_each(|mut row, &norm| {
86 row.mapv_inplace(|el| el / norm);
87 });
88 x
89 }
90}
91
92impl<F: Float, D: Data<Elem = F>, T: AsTargets>
93 Transformer<DatasetBase<ArrayBase<D, Ix2>, T>, DatasetBase<Array2<F>, T>> for NormScaler
94{
95 fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
97 let feature_names = x.feature_names();
98 let (records, targets, weights) = (x.records, x.targets, x.weights);
99 let records = self.transform(records.to_owned());
100 DatasetBase::new(records, targets)
101 .with_weights(weights)
102 .with_feature_names(feature_names)
103 }
104}
105
#[cfg(test)]
mod tests {

    use crate::norm_scaling::NormScaler;
    use approx::assert_abs_diff_eq;
    use linfa::dataset::DatasetBase;
    use linfa::traits::Transformer;
    use ndarray::{array, Array2};

    /// Shared 3x3 input: a mixed-sign row, a row with a single non-zero
    /// entry and a row with two opposite entries.
    fn sample_records() -> Array2<f64> {
        array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]]
    }

    /// Scales the shared sample dataset with `scaler` and compares the
    /// resulting records against `expected` within `1e-2`.
    fn check(scaler: NormScaler, expected: Array2<f64>) {
        let scaled = scaler.transform(DatasetBase::from(sample_records()));
        assert_abs_diff_eq!(*scaled.records(), expected, epsilon = 1e-2);
    }

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<NormScaler>();
    }

    #[test]
    fn test_norm_l2() {
        check(
            NormScaler::l2(),
            array![[0.4, -0.4, 0.81], [1., 0., 0.], [0., 0.7, -0.7]],
        );
    }

    #[test]
    fn test_norm_l1() {
        check(
            NormScaler::l1(),
            array![[0.25, -0.25, 0.5], [1., 0., 0.], [0., 0.5, -0.5]],
        );
    }

    #[test]
    fn test_norm_max() {
        check(
            NormScaler::max(),
            array![[0.5, -0.5, 1.], [1., 0., 0.], [0., 1., -1.]],
        );
    }

    #[test]
    fn test_no_input() {
        let empty = || Array2::<f64>::from_shape_vec((0, 0), vec![]).unwrap();
        for scaler in &[NormScaler::max(), NormScaler::l1(), NormScaler::l2()] {
            assert_abs_diff_eq!(scaler.transform(empty()), empty());
        }
    }

    #[test]
    fn test_retain_feature_names() {
        let dataset = linfa_datasets::diabetes();
        let expected_names = dataset.feature_names();
        let scaled = NormScaler::l2().transform(dataset);
        assert_eq!(expected_names, scaled.feature_names());
    }
}
167}