Skip to main content

ferrolearn_preprocess/
random_projection.rs

1//! Random projection transformers for dimensionality reduction.
2//!
//! Random projections approximately preserve pairwise distances with high
//! probability (Johnson-Lindenstrauss lemma).
4//!
5//! - [`GaussianRandomProjection`] — dense Gaussian random matrix
6//! - [`SparseRandomProjection`] — sparse random matrix with `{-1, 0, +1}` entries
7
8use ferrolearn_core::error::FerroError;
9use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
10use ferrolearn_core::traits::{Fit, FitTransform, Transform};
11use ndarray::{Array1, Array2};
12use num_traits::Float;
13use rand::SeedableRng;
14use rand::rngs::SmallRng;
15use rand_distr::{Distribution, StandardNormal};
16
17// ---------------------------------------------------------------------------
18// GaussianRandomProjection
19// ---------------------------------------------------------------------------
20
/// Dense random projection whose matrix entries are Gaussian.
///
/// Reduces dimensionality by multiplying the input with a random matrix whose
/// entries are drawn from `N(0, 1/n_components)`. This type only stores the
/// configuration; the matrix itself is generated when `fit` is called.
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::random_projection::GaussianRandomProjection;
/// use ferrolearn_core::traits::{Fit, Transform};
/// use ndarray::Array2;
///
/// let x = Array2::<f64>::ones((10, 50));
/// let proj = GaussianRandomProjection::<f64>::new(5);
/// let fitted = proj.fit(&x, &()).unwrap();
/// let out = fitted.transform(&x).unwrap();
/// assert_eq!(out.shape(), &[10, 5]);
/// ```
#[derive(Clone, Debug)]
pub struct GaussianRandomProjection<F> {
    /// Dimensionality of the projected (output) space.
    n_components: usize,
    /// Seed for the random number generator; `None` seeds from the OS.
    random_state: Option<u64>,
    _marker: std::marker::PhantomData<F>,
}
47
48impl<F: Float + Send + Sync + 'static> GaussianRandomProjection<F> {
49    /// Create a new Gaussian random projection with `n_components` output dimensions.
50    #[must_use]
51    pub fn new(n_components: usize) -> Self {
52        Self {
53            n_components,
54            random_state: None,
55            _marker: std::marker::PhantomData,
56        }
57    }
58
59    /// Set the random seed for reproducibility.
60    #[must_use]
61    pub fn random_state(mut self, seed: u64) -> Self {
62        self.random_state = Some(seed);
63        self
64    }
65}
66
67/// Fitted Gaussian random projection holding the projection matrix.
68#[derive(Debug, Clone)]
69pub struct FittedGaussianRandomProjection<F> {
70    /// Projection matrix of shape `(n_features, n_components)`.
71    projection: Array2<F>,
72}
73
74impl<F: Float + Send + Sync + 'static> FittedGaussianRandomProjection<F> {
75    /// Return a reference to the projection matrix.
76    #[must_use]
77    pub fn projection(&self) -> &Array2<F> {
78        &self.projection
79    }
80}
81
82impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for GaussianRandomProjection<F> {
83    type Fitted = FittedGaussianRandomProjection<F>;
84    type Error = FerroError;
85
86    /// Fit the projection by generating a random matrix `R ~ N(0, 1/n_components)`.
87    ///
88    /// # Errors
89    ///
90    /// Returns [`FerroError::InvalidParameter`] if `n_components == 0`.
91    /// Returns [`FerroError::InsufficientSamples`] if `x` has zero rows.
92    fn fit(
93        &self,
94        x: &Array2<F>,
95        _y: &(),
96    ) -> Result<FittedGaussianRandomProjection<F>, FerroError> {
97        if self.n_components == 0 {
98            return Err(FerroError::InvalidParameter {
99                name: "n_components".into(),
100                reason: "must be >= 1".into(),
101            });
102        }
103        if x.nrows() == 0 {
104            return Err(FerroError::InsufficientSamples {
105                required: 1,
106                actual: 0,
107                context: "GaussianRandomProjection::fit".into(),
108            });
109        }
110
111        let n_features = x.ncols();
112        let mut rng: SmallRng = match self.random_state {
113            Some(seed) => SmallRng::seed_from_u64(seed),
114            None => SmallRng::from_os_rng(),
115        };
116
117        let scale = F::one() / F::from(self.n_components).unwrap().sqrt();
118        let normal = StandardNormal;
119        let mut projection = Array2::zeros((n_features, self.n_components));
120        for v in projection.iter_mut() {
121            let sample: f64 = normal.sample(&mut rng);
122            *v = F::from(sample).unwrap() * scale;
123        }
124
125        Ok(FittedGaussianRandomProjection { projection })
126    }
127}
128
129impl<F: Float + Send + Sync + 'static> Transform<Array2<F>>
130    for FittedGaussianRandomProjection<F>
131{
132    type Output = Array2<F>;
133    type Error = FerroError;
134
135    /// Transform data by computing `X @ R`.
136    ///
137    /// # Errors
138    ///
139    /// Returns [`FerroError::ShapeMismatch`] if `x.ncols() != n_features`.
140    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
141        if x.ncols() != self.projection.nrows() {
142            return Err(FerroError::ShapeMismatch {
143                expected: vec![x.nrows(), self.projection.nrows()],
144                actual: vec![x.nrows(), x.ncols()],
145                context: "FittedGaussianRandomProjection::transform".into(),
146            });
147        }
148        Ok(x.dot(&self.projection))
149    }
150}
151
152impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for GaussianRandomProjection<F> {
153    type Output = Array2<F>;
154    type Error = FerroError;
155
156    /// Always returns an error — the projection must be fitted first.
157    fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
158        Err(FerroError::InvalidParameter {
159            name: "GaussianRandomProjection".into(),
160            reason: "projection must be fitted before calling transform; use fit() first".into(),
161        })
162    }
163}
164
165impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for GaussianRandomProjection<F> {
166    type FitError = FerroError;
167
168    /// Fit and transform in one step.
169    fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
170        let fitted = self.fit(x, &())?;
171        fitted.transform(x)
172    }
173}
174
175impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for GaussianRandomProjection<F> {
176    fn fit_pipeline(
177        &self,
178        x: &Array2<F>,
179        _y: &Array1<F>,
180    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
181        let fitted = self.fit(x, &())?;
182        Ok(Box::new(fitted))
183    }
184}
185
186impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F>
187    for FittedGaussianRandomProjection<F>
188{
189    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
190        self.transform(x)
191    }
192}
193
194// ---------------------------------------------------------------------------
195// SparseRandomProjection
196// ---------------------------------------------------------------------------
197
/// Random projection backed by a sparse sign matrix.
///
/// Reduces dimensionality by multiplying the input with a random matrix whose
/// entries take the values `{-1, 0, +1}` with probabilities
/// `{d/2, 1 - d, d/2}`, scaled by `sqrt(1 / (d * n_components))`, where `d`
/// is the density of non-zero entries.
///
/// When no density is configured, `d = 1 / sqrt(n_features)` is used.
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::random_projection::SparseRandomProjection;
/// use ferrolearn_core::traits::{Fit, Transform};
/// use ndarray::Array2;
///
/// let x = Array2::<f64>::ones((10, 100));
/// let proj = SparseRandomProjection::<f64>::new(5).random_state(42);
/// let fitted = proj.fit(&x, &()).unwrap();
/// let out = fitted.transform(&x).unwrap();
/// assert_eq!(out.shape(), &[10, 5]);
/// ```
#[derive(Clone, Debug)]
pub struct SparseRandomProjection<F> {
    /// Dimensionality of the projected (output) space.
    n_components: usize,
    /// Fraction of non-zero matrix entries; `None` selects `1/sqrt(n_features)`.
    density: Option<f64>,
    /// Seed for the random number generator; `None` seeds from the OS.
    random_state: Option<u64>,
    _marker: std::marker::PhantomData<F>,
}
229
230impl<F: Float + Send + Sync + 'static> SparseRandomProjection<F> {
231    /// Create a new sparse random projection with `n_components` output dimensions.
232    #[must_use]
233    pub fn new(n_components: usize) -> Self {
234        Self {
235            n_components,
236            density: None,
237            random_state: None,
238            _marker: std::marker::PhantomData,
239        }
240    }
241
242    /// Set the density of non-zero entries.
243    #[must_use]
244    pub fn density(mut self, density: f64) -> Self {
245        self.density = Some(density);
246        self
247    }
248
249    /// Set the random seed for reproducibility.
250    #[must_use]
251    pub fn random_state(mut self, seed: u64) -> Self {
252        self.random_state = Some(seed);
253        self
254    }
255}
256
257/// Fitted sparse random projection holding the projection matrix.
258#[derive(Debug, Clone)]
259pub struct FittedSparseRandomProjection<F> {
260    /// Projection matrix of shape `(n_features, n_components)`.
261    projection: Array2<F>,
262}
263
264impl<F: Float + Send + Sync + 'static> FittedSparseRandomProjection<F> {
265    /// Return a reference to the projection matrix.
266    #[must_use]
267    pub fn projection(&self) -> &Array2<F> {
268        &self.projection
269    }
270}
271
272impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for SparseRandomProjection<F> {
273    type Fitted = FittedSparseRandomProjection<F>;
274    type Error = FerroError;
275
276    /// Fit the projection by generating a sparse random matrix.
277    ///
278    /// # Errors
279    ///
280    /// Returns [`FerroError::InvalidParameter`] if `n_components == 0` or
281    /// `density` is not in `(0, 1]`.
282    /// Returns [`FerroError::InsufficientSamples`] if `x` has zero rows.
283    fn fit(
284        &self,
285        x: &Array2<F>,
286        _y: &(),
287    ) -> Result<FittedSparseRandomProjection<F>, FerroError> {
288        if self.n_components == 0 {
289            return Err(FerroError::InvalidParameter {
290                name: "n_components".into(),
291                reason: "must be >= 1".into(),
292            });
293        }
294        if x.nrows() == 0 {
295            return Err(FerroError::InsufficientSamples {
296                required: 1,
297                actual: 0,
298                context: "SparseRandomProjection::fit".into(),
299            });
300        }
301
302        let n_features = x.ncols();
303        let d = self
304            .density
305            .unwrap_or_else(|| 1.0 / (n_features as f64).sqrt());
306
307        if d <= 0.0 || d > 1.0 {
308            return Err(FerroError::InvalidParameter {
309                name: "density".into(),
310                reason: format!("must be in (0, 1], got {d}"),
311            });
312        }
313
314        let mut rng: SmallRng = match self.random_state {
315            Some(seed) => SmallRng::seed_from_u64(seed),
316            None => SmallRng::from_os_rng(),
317        };
318
319        let scale = F::from(1.0 / (d * self.n_components as f64).sqrt()).unwrap();
320        let uniform = rand::distr::Uniform::new(0.0_f64, 1.0).unwrap();
321
322        let mut projection = Array2::zeros((n_features, self.n_components));
323        for v in projection.iter_mut() {
324            let u: f64 = uniform.sample(&mut rng);
325            if u < d / 2.0 {
326                *v = scale.neg();
327            } else if u < d {
328                *v = scale;
329            }
330            // else: remains 0
331        }
332
333        Ok(FittedSparseRandomProjection { projection })
334    }
335}
336
337impl<F: Float + Send + Sync + 'static> Transform<Array2<F>>
338    for FittedSparseRandomProjection<F>
339{
340    type Output = Array2<F>;
341    type Error = FerroError;
342
343    /// Transform data by computing `X @ R`.
344    ///
345    /// # Errors
346    ///
347    /// Returns [`FerroError::ShapeMismatch`] if `x.ncols() != n_features`.
348    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
349        if x.ncols() != self.projection.nrows() {
350            return Err(FerroError::ShapeMismatch {
351                expected: vec![x.nrows(), self.projection.nrows()],
352                actual: vec![x.nrows(), x.ncols()],
353                context: "FittedSparseRandomProjection::transform".into(),
354            });
355        }
356        Ok(x.dot(&self.projection))
357    }
358}
359
360impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for SparseRandomProjection<F> {
361    type Output = Array2<F>;
362    type Error = FerroError;
363
364    /// Always returns an error — the projection must be fitted first.
365    fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
366        Err(FerroError::InvalidParameter {
367            name: "SparseRandomProjection".into(),
368            reason: "projection must be fitted before calling transform; use fit() first".into(),
369        })
370    }
371}
372
373impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for SparseRandomProjection<F> {
374    type FitError = FerroError;
375
376    /// Fit and transform in one step.
377    fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
378        let fitted = self.fit(x, &())?;
379        fitted.transform(x)
380    }
381}
382
383impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for SparseRandomProjection<F> {
384    fn fit_pipeline(
385        &self,
386        x: &Array2<F>,
387        _y: &Array1<F>,
388    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
389        let fitted = self.fit(x, &())?;
390        Ok(Box::new(fitted))
391    }
392}
393
394impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F>
395    for FittedSparseRandomProjection<F>
396{
397    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
398        self.transform(x)
399    }
400}
401
402// ---------------------------------------------------------------------------
403// Tests
404// ---------------------------------------------------------------------------
405
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Seed shared by every test that needs a reproducible projection.
    const SEED: u64 = 42;

    /// A `rows x cols` matrix filled with ones.
    fn ones(rows: usize, cols: usize) -> Array2<f64> {
        Array2::ones((rows, cols))
    }

    /// Assert that two matrices agree element-wise to within `1e-10`.
    fn assert_all_close(lhs: &Array2<f64>, rhs: &Array2<f64>) {
        for (a, b) in lhs.iter().zip(rhs.iter()) {
            assert!((a - b).abs() < 1e-10);
        }
    }

    // -- GaussianRandomProjection --

    #[test]
    fn test_gaussian_rp_output_shape() {
        let x = ones(10, 50);
        let fitted = GaussianRandomProjection::<f64>::new(5)
            .random_state(SEED)
            .fit(&x, &())
            .unwrap();
        let projected = fitted.transform(&x).unwrap();
        assert_eq!(projected.dim(), (10, 5));
    }

    #[test]
    fn test_gaussian_rp_deterministic() {
        let x = ones(10, 20);
        let proj = GaussianRandomProjection::<f64>::new(3).random_state(SEED);
        // Two independent fits with the same seed must agree.
        let first = proj.fit(&x, &()).unwrap().transform(&x).unwrap();
        let second = proj.fit(&x, &()).unwrap().transform(&x).unwrap();
        assert_all_close(&first, &second);
    }

    #[test]
    fn test_gaussian_rp_zero_components() {
        let x = ones(5, 10);
        let result = GaussianRandomProjection::<f64>::new(0).fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_gaussian_rp_empty_input() {
        let x = Array2::<f64>::zeros((0, 10));
        let result = GaussianRandomProjection::<f64>::new(5).fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_gaussian_rp_shape_mismatch() {
        let fitted = GaussianRandomProjection::<f64>::new(5)
            .random_state(SEED)
            .fit(&ones(10, 20), &())
            .unwrap();
        // 15 columns instead of the 20 seen during fit.
        let x_bad = ones(5, 15);
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_gaussian_rp_fit_transform() {
        let x = ones(10, 20);
        let projected = GaussianRandomProjection::<f64>::new(5)
            .random_state(SEED)
            .fit_transform(&x)
            .unwrap();
        assert_eq!(projected.dim(), (10, 5));
    }

    #[test]
    fn test_gaussian_rp_f32() {
        let x = Array2::<f32>::ones((5, 10));
        let fitted = GaussianRandomProjection::<f32>::new(3)
            .random_state(SEED)
            .fit(&x, &())
            .unwrap();
        let projected = fitted.transform(&x).unwrap();
        assert_eq!(projected.dim(), (5, 3));
    }

    // -- SparseRandomProjection --

    #[test]
    fn test_sparse_rp_output_shape() {
        let x = ones(10, 100);
        let fitted = SparseRandomProjection::<f64>::new(5)
            .random_state(SEED)
            .fit(&x, &())
            .unwrap();
        let projected = fitted.transform(&x).unwrap();
        assert_eq!(projected.dim(), (10, 5));
    }

    #[test]
    fn test_sparse_rp_deterministic() {
        let x = ones(10, 50);
        let proj = SparseRandomProjection::<f64>::new(3).random_state(SEED);
        // Two independent fits with the same seed must agree.
        let first = proj.fit(&x, &()).unwrap().transform(&x).unwrap();
        let second = proj.fit(&x, &()).unwrap().transform(&x).unwrap();
        assert_all_close(&first, &second);
    }

    #[test]
    fn test_sparse_rp_sparsity() {
        let x = ones(5, 100);
        let fitted = SparseRandomProjection::<f64>::new(10)
            .random_state(SEED)
            .fit(&x, &())
            .unwrap();
        let matrix = fitted.projection();
        // Default density is 1/sqrt(100) = 0.1, so roughly 90% of the entries
        // are expected to be zero; the assertion only requires a majority.
        let n_zero = matrix.iter().filter(|&&v| v == 0.0).count();
        let sparsity = n_zero as f64 / matrix.len() as f64;
        assert!(sparsity > 0.5, "expected sparse matrix, got sparsity={sparsity}");
    }

    #[test]
    fn test_sparse_rp_custom_density() {
        let x = ones(5, 20);
        let fitted = SparseRandomProjection::<f64>::new(5)
            .density(0.5)
            .random_state(SEED)
            .fit(&x, &())
            .unwrap();
        let projected = fitted.transform(&x).unwrap();
        assert_eq!(projected.dim(), (5, 5));
    }

    #[test]
    fn test_sparse_rp_zero_components() {
        let x = ones(5, 10);
        let result = SparseRandomProjection::<f64>::new(0).fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_sparse_rp_invalid_density() {
        let x = ones(5, 10);
        let result = SparseRandomProjection::<f64>::new(5)
            .density(0.0)
            .fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_sparse_rp_empty_input() {
        let x = Array2::<f64>::zeros((0, 10));
        let result = SparseRandomProjection::<f64>::new(5).fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_sparse_rp_shape_mismatch() {
        let fitted = SparseRandomProjection::<f64>::new(5)
            .random_state(SEED)
            .fit(&ones(10, 20), &())
            .unwrap();
        // 15 columns instead of the 20 seen during fit.
        let x_bad = ones(5, 15);
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_sparse_rp_fit_transform() {
        let x = ones(10, 20);
        let projected = SparseRandomProjection::<f64>::new(5)
            .random_state(SEED)
            .fit_transform(&x)
            .unwrap();
        assert_eq!(projected.dim(), (10, 5));
    }

    #[test]
    fn test_sparse_rp_f32() {
        let x = Array2::<f32>::ones((5, 10));
        let fitted = SparseRandomProjection::<f32>::new(3)
            .random_state(SEED)
            .fit(&x, &())
            .unwrap();
        let projected = fitted.transform(&x).unwrap();
        assert_eq!(projected.dim(), (5, 3));
    }
}