linfa_preprocessing/
whitening.rs

//! Methods for uncorrelating data
//!
//! Whitening refers to a collection of methods that, given as input a matrix `X` of records with
//! covariance matrix `sigma`, output a whitening matrix `W` such that `W.T` dot `W` = `sigma^-1`
//! (the inverse of the covariance matrix).
//! Applying the whitening matrix `W` to the mean-centered input data gives a new data matrix `Y`
//! of the same size as the input such that `Y` has
//! unit diagonal (white) covariance matrix.
8
9use crate::error::{PreprocessingError, Result};
10use linfa::dataset::{AsTargets, Records, WithLapack, WithoutLapack};
11use linfa::traits::{Fit, Transformer};
12use linfa::{DatasetBase, Float};
13#[cfg(not(feature = "blas"))]
14use linfa_linalg::{
15    cholesky::{CholeskyInplace, InverseCInplace},
16    svd::SVD,
17};
18use ndarray::{Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix2};
19#[cfg(feature = "blas")]
20use ndarray_linalg::{
21    cholesky::{CholeskyInto, InverseCInto, UPLO},
22    svd::SVD,
23    Scalar,
24};
25
26#[cfg(feature = "serde")]
27use serde_crate::{Deserialize, Serialize};
28
29#[cfg_attr(
30    feature = "serde",
31    derive(Serialize, Deserialize),
32    serde(crate = "serde_crate")
33)]
34#[derive(Debug, Clone, PartialEq, Eq, Hash)]
35pub enum WhiteningMethod {
36    Pca,
37    Zca,
38    Cholesky,
39}
40
41/// Struct that can be fitted to the input data to obtain the related whitening matrix.
42/// Fitting returns a [FittedWhitener] struct that can be used to
43/// apply the whitening transformation to the input data.
44#[cfg_attr(
45    feature = "serde",
46    derive(Serialize, Deserialize),
47    serde(crate = "serde_crate")
48)]
49#[derive(Debug, Clone, PartialEq, Eq)]
50pub struct Whitener {
51    method: WhiteningMethod,
52}
53
54impl Whitener {
55    /// Creates an instance of a Whitener that uses the PCA method
56    pub fn pca() -> Self {
57        Self {
58            method: WhiteningMethod::Pca,
59        }
60    }
61    /// Creates an instance of a Whitener that uses the ZCA (Mahalanobis) method
62    pub fn zca() -> Self {
63        Self {
64            method: WhiteningMethod::Zca,
65        }
66    }
67    /// Creates an instance of a Whitener that uses the cholesky decomposition of the inverse of the covariance matrix
68    pub fn cholesky() -> Self {
69        Self {
70            method: WhiteningMethod::Cholesky,
71        }
72    }
73
74    pub fn method(mut self, method: WhiteningMethod) -> Self {
75        self.method = method;
76        self
77    }
78}
79
80impl<F: Float, D: Data<Elem = F>, T: AsTargets> Fit<ArrayBase<D, Ix2>, T, PreprocessingError>
81    for Whitener
82{
83    type Object = FittedWhitener<F>;
84
85    fn fit(&self, x: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
86        if x.nsamples() == 0 {
87            return Err(PreprocessingError::NotEnoughSamples);
88        }
89        // safe because of above zero samples check
90        let mean = x.records().mean_axis(Axis(0)).unwrap();
91        let sigma = x.records() - &mean;
92
93        // add Lapack + Scalar trait bounds
94        let sigma = sigma.with_lapack();
95
96        let transformation_matrix = match self.method {
97            WhiteningMethod::Pca => {
98                let (_, s, v_t) = sigma.svd(false, true)?;
99
100                // Safe because the second argument in the above call is set to true
101                let mut v_t = v_t.unwrap().without_lapack();
102                #[cfg(feature = "blas")]
103                let s = s.mapv(Scalar::from_real);
104                let s = s.without_lapack();
105
106                let s = s.mapv(|x: F| x.max(F::cast(1e-8)));
107
108                let cov_scale = F::cast(x.nsamples() - 1).sqrt();
109                for (mut v_t, s) in v_t.axis_iter_mut(Axis(0)).zip(s.iter()) {
110                    v_t *= cov_scale / *s;
111                }
112
113                v_t
114            }
115            WhiteningMethod::Zca => {
116                let sigma = sigma.t().dot(&sigma) / F::Lapack::cast(x.nsamples() - 1);
117                let (u, s, _) = sigma.svd(true, false)?;
118
119                // Safe because the first argument in the above call is set to true
120                let u = u.unwrap().without_lapack();
121                #[cfg(feature = "blas")]
122                let s = s.mapv(Scalar::from_real);
123                let s = s.without_lapack();
124
125                let s = s.mapv(|x: F| (F::one() / x.sqrt()).max(F::cast(1e-8)));
126                let lambda: Array2<F> = Array2::<F>::eye(s.len()) * s;
127                u.dot(&lambda).dot(&u.t())
128            }
129            WhiteningMethod::Cholesky => {
130                let sigma = sigma.t().dot(&sigma) / F::Lapack::cast(x.nsamples() - 1);
131                // sigma must be positive definite for us to call cholesky on its inverse, so invc
132                // is allowed here
133                #[cfg(feature = "blas")]
134                let out = sigma
135                    .invc_into()?
136                    .cholesky_into(UPLO::Upper)?
137                    .without_lapack();
138                #[cfg(not(feature = "blas"))]
139                let mut sigma = sigma;
140                #[cfg(not(feature = "blas"))]
141                let out = sigma
142                    .invc_inplace()?
143                    .reversed_axes()
144                    .cholesky_into()?
145                    .reversed_axes()
146                    .without_lapack();
147                out
148            }
149        };
150
151        Ok(FittedWhitener {
152            transformation_matrix,
153            mean,
154        })
155    }
156}
157
158/// Struct that can be used to whiten data. Data will be scaled according to the whitening matrix learned
159/// during fitting.
160/// Obtained by fitting a [Whitener].
161///
162/// Transforming the data used during fitting will yield a scaled data matrix with
163/// unit diagonal covariance matrix.
164///
165/// ### Example
166///
167/// ```rust
168/// use linfa::traits::{Fit, Transformer};
169/// use linfa_preprocessing::whitening::Whitener;
170///
171/// // Load dataset
172/// let dataset = linfa_datasets::diabetes();
173/// // Learn whitening parameters
174/// let whitener = Whitener::pca().fit(&dataset).unwrap();
175/// // transform dataset according to whitening parameters
176/// let dataset = whitener.transform(dataset);
177/// ```
178#[cfg_attr(
179    feature = "serde",
180    derive(Serialize, Deserialize),
181    serde(crate = "serde_crate")
182)]
183#[derive(Debug, Clone, PartialEq, Eq)]
184pub struct FittedWhitener<F: Float> {
185    transformation_matrix: Array2<F>,
186    mean: Array1<F>,
187}
188
189impl<F: Float> FittedWhitener<F> {
190    /// The matrix used for scaling the data
191    pub fn transformation_matrix(&self) -> ArrayView2<F> {
192        self.transformation_matrix.view()
193    }
194
195    /// The means that will be subtracted to the features before scaling the data
196    pub fn mean(&self) -> ArrayView1<F> {
197        self.mean.view()
198    }
199}
200
201impl<F: Float> Transformer<Array2<F>, Array2<F>> for FittedWhitener<F> {
202    fn transform(&self, x: Array2<F>) -> Array2<F> {
203        (x - &self.mean).dot(&self.transformation_matrix.t())
204    }
205}
206
207impl<F: Float, D: Data<Elem = F>, T: AsTargets>
208    Transformer<DatasetBase<ArrayBase<D, Ix2>, T>, DatasetBase<Array2<F>, T>>
209    for FittedWhitener<F>
210{
211    fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
212        let feature_names = x.feature_names();
213        let (records, targets, weights) = (x.records, x.targets, x.weights);
214        let records = self.transform(records.to_owned());
215        DatasetBase::new(records, targets)
216            .with_weights(weights)
217            .with_feature_names(feature_names)
218    }
219}
220
#[cfg(test)]
mod tests {

    use super::*;
    use approx::assert_abs_diff_eq;

    use ndarray_rand::{
        rand::distributions::Uniform, rand::rngs::SmallRng, rand::SeedableRng, RandomExt,
    };

    /// Unbiased (`n - 1` denominator) sample covariance of the rows of `x`.
    fn cov<D: Data<Elem = f64>>(x: &ArrayBase<D, Ix2>) -> Array2<f64> {
        let centered = x - &x.mean_axis(Axis(0)).unwrap();
        let dof = (x.nrows() - 1) as f64;
        centered.t().dot(&centered) / dof
    }

    /// Inverse of the sample covariance of `x`, computed with the backend's Cholesky-based
    /// inversion.
    fn inv_cov<D: Data<Elem = f64>>(x: &ArrayBase<D, Ix2>) -> Array2<f64> {
        #[cfg(feature = "blas")]
        let inverse = cov(x).invc_into().unwrap();
        #[cfg(not(feature = "blas"))]
        let inverse = cov(x).invc_inplace().unwrap();
        inverse
    }

    /// Seeded 1000x7 matrix of samples drawn uniformly from [-30, 30).
    fn random_records(seed: u64) -> Array2<f64> {
        let mut rng = SmallRng::seed_from_u64(seed);
        Array2::random_using((1000, 7), Uniform::from(-30. ..30.), &mut rng)
    }

    /// `W.T` dot `W` for the whitening matrix `W` of a fitted whitener.
    fn gram(whitener: &FittedWhitener<f64>) -> Array2<f64> {
        let w = whitener.transformation_matrix();
        w.t().dot(&w)
    }

    /// Fits `whitener` on seed-42 random records and checks that `W.T` dot `W` matches the
    /// inverse sample covariance within `epsilon`.
    fn assert_gram_is_inv_cov(whitener: Whitener, epsilon: f64) {
        let dataset = random_records(42).into();
        let fitted = whitener.fit(&dataset).unwrap();
        assert_abs_diff_eq!(inv_cov(dataset.records()), gram(&fitted), epsilon = epsilon);
    }

    /// Fits `whitener` on seed-64 random records, whitens them, and checks that the result
    /// has identity covariance.
    fn assert_whitens(whitener: Whitener) {
        let dataset = random_records(64).into();
        let whitened = whitener.fit(&dataset).unwrap().transform(dataset);
        let covariance = cov(whitened.records());
        assert_abs_diff_eq!(
            covariance,
            Array2::eye(covariance.nrows()),
            epsilon = 1e-10
        )
    }

    /// Fits `whitener` on a dataset with no samples and unwraps the result.
    fn fit_on_empty(whitener: Whitener) {
        let dataset: DatasetBase<Array2<f64>, _> = Array2::zeros((0, 0)).into();
        let _fitted = whitener.fit(&dataset).unwrap();
    }

    #[test]
    fn autotraits() {
        fn assert_auto<T: Send + Sync + Sized + Unpin>() {}
        assert_auto::<Whitener>();
        assert_auto::<WhiteningMethod>();
        assert_auto::<FittedWhitener<f64>>();
    }

    #[test]
    fn test_zca_matrix() {
        assert_gram_is_inv_cov(Whitener::zca(), 1e-9);
    }

    #[test]
    fn test_cholesky_matrix() {
        assert_gram_is_inv_cov(Whitener::cholesky(), 1e-10);
    }

    #[test]
    fn test_pca_matrix() {
        assert_gram_is_inv_cov(Whitener::pca(), 1e-10);
    }

    #[test]
    fn test_cholesky_whitening() {
        assert_whitens(Whitener::cholesky());
    }

    #[test]
    fn test_zca_whitening() {
        assert_whitens(Whitener::zca());
    }

    #[test]
    fn test_pca_whitening() {
        assert_whitens(Whitener::pca());
    }

    #[test]
    fn test_train_val_matrix() {
        let (train, val) = linfa_datasets::diabetes().split_with_ratio(0.9);
        let expected = (train.records().dim(), val.records().dim());
        let whitener = Whitener::pca().fit(&train).unwrap();
        let whitened_train = whitener.transform(train);
        let whitened_val = whitener.transform(val);
        assert_eq!(
            expected,
            (whitened_train.records.dim(), whitened_val.records.dim())
        );
    }

    #[test]
    fn test_retain_feature_names() {
        let dataset = linfa_datasets::diabetes();
        let before = dataset.feature_names();
        let fitted = Whitener::cholesky().fit(&dataset).unwrap();
        let after = fitted.transform(dataset).feature_names();
        assert_eq!(before, after)
    }

    #[test]
    #[should_panic]
    fn test_pca_fail_on_empty_input() {
        fit_on_empty(Whitener::pca());
    }

    #[test]
    #[should_panic]
    fn test_zca_fail_on_empty_input() {
        fit_on_empty(Whitener::zca());
    }

    #[test]
    #[should_panic]
    fn test_cholesky_fail_on_empty_input() {
        fit_on_empty(Whitener::cholesky());
    }
}