Skip to main content

ferrolearn_preprocess/
random_projection.rs

1//! Random projection transformers for dimensionality reduction.
2//!
3//! Random projections preserve pairwise distances in expectation (Johnson-Lindenstrauss lemma).
4//!
5//! - [`GaussianRandomProjection`] — dense Gaussian random matrix
6//! - [`SparseRandomProjection`] — sparse random matrix with `{-1, 0, +1}` entries
7
8use ferrolearn_core::error::FerroError;
9use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
10use ferrolearn_core::traits::{Fit, FitTransform, Transform};
11use ndarray::{Array1, Array2};
12use num_traits::Float;
13use rand::SeedableRng;
14use rand::rngs::SmallRng;
15use rand_distr::{Distribution, StandardNormal};
16
17// ---------------------------------------------------------------------------
18// GaussianRandomProjection
19// ---------------------------------------------------------------------------
20
/// Configuration for a dense Gaussian random projection.
///
/// Fitting produces a `(n_features, n_components)` matrix whose entries are
/// standard-normal draws multiplied by `1 / sqrt(n_components)`, i.e. samples
/// from `N(0, 1/n_components)`. Transforming multiplies the input by that
/// matrix, so the output has `n_components` columns.
///
/// When no seed is configured the generator is initialised from the operating
/// system's entropy source, so repeated fits give different matrices.
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::random_projection::GaussianRandomProjection;
/// use ferrolearn_core::traits::{Fit, Transform};
/// use ndarray::Array2;
///
/// let x = Array2::<f64>::ones((10, 50));
/// let proj = GaussianRandomProjection::<f64>::new(5);
/// let fitted = proj.fit(&x, &()).unwrap();
/// let out = fitted.transform(&x).unwrap();
/// assert_eq!(out.shape(), &[10, 5]);
/// ```
#[derive(Clone, Debug)]
pub struct GaussianRandomProjection<F> {
    /// Target dimensionality (number of columns of the projected data).
    n_components: usize,
    /// Seed for the random generator; `None` requests OS-provided entropy.
    random_state: Option<u64>,
    /// Ties the configuration to the float type `F` without storing a value.
    _marker: std::marker::PhantomData<F>,
}
47
48impl<F: Float + Send + Sync + 'static> GaussianRandomProjection<F> {
49    /// Create a new Gaussian random projection with `n_components` output dimensions.
50    #[must_use]
51    pub fn new(n_components: usize) -> Self {
52        Self {
53            n_components,
54            random_state: None,
55            _marker: std::marker::PhantomData,
56        }
57    }
58
59    /// Set the random seed for reproducibility.
60    #[must_use]
61    pub fn random_state(mut self, seed: u64) -> Self {
62        self.random_state = Some(seed);
63        self
64    }
65}
66
67/// Fitted Gaussian random projection holding the projection matrix.
68#[derive(Debug, Clone)]
69pub struct FittedGaussianRandomProjection<F> {
70    /// Projection matrix of shape `(n_features, n_components)`.
71    projection: Array2<F>,
72}
73
74impl<F: Float + Send + Sync + 'static> FittedGaussianRandomProjection<F> {
75    /// Return a reference to the projection matrix.
76    #[must_use]
77    pub fn projection(&self) -> &Array2<F> {
78        &self.projection
79    }
80}
81
82impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for GaussianRandomProjection<F> {
83    type Fitted = FittedGaussianRandomProjection<F>;
84    type Error = FerroError;
85
86    /// Fit the projection by generating a random matrix `R ~ N(0, 1/n_components)`.
87    ///
88    /// # Errors
89    ///
90    /// Returns [`FerroError::InvalidParameter`] if `n_components == 0`.
91    /// Returns [`FerroError::InsufficientSamples`] if `x` has zero rows.
92    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedGaussianRandomProjection<F>, FerroError> {
93        if self.n_components == 0 {
94            return Err(FerroError::InvalidParameter {
95                name: "n_components".into(),
96                reason: "must be >= 1".into(),
97            });
98        }
99        if x.nrows() == 0 {
100            return Err(FerroError::InsufficientSamples {
101                required: 1,
102                actual: 0,
103                context: "GaussianRandomProjection::fit".into(),
104            });
105        }
106
107        let n_features = x.ncols();
108        let mut rng: SmallRng = match self.random_state {
109            Some(seed) => SmallRng::seed_from_u64(seed),
110            None => SmallRng::from_os_rng(),
111        };
112
113        let scale = F::one() / F::from(self.n_components).unwrap().sqrt();
114        let normal = StandardNormal;
115        let mut projection = Array2::zeros((n_features, self.n_components));
116        for v in &mut projection {
117            let sample: f64 = normal.sample(&mut rng);
118            *v = F::from(sample).unwrap() * scale;
119        }
120
121        Ok(FittedGaussianRandomProjection { projection })
122    }
123}
124
125impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedGaussianRandomProjection<F> {
126    type Output = Array2<F>;
127    type Error = FerroError;
128
129    /// Transform data by computing `X @ R`.
130    ///
131    /// # Errors
132    ///
133    /// Returns [`FerroError::ShapeMismatch`] if `x.ncols() != n_features`.
134    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
135        if x.ncols() != self.projection.nrows() {
136            return Err(FerroError::ShapeMismatch {
137                expected: vec![x.nrows(), self.projection.nrows()],
138                actual: vec![x.nrows(), x.ncols()],
139                context: "FittedGaussianRandomProjection::transform".into(),
140            });
141        }
142        Ok(x.dot(&self.projection))
143    }
144}
145
146impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for GaussianRandomProjection<F> {
147    type Output = Array2<F>;
148    type Error = FerroError;
149
150    /// Always returns an error — the projection must be fitted first.
151    fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
152        Err(FerroError::InvalidParameter {
153            name: "GaussianRandomProjection".into(),
154            reason: "projection must be fitted before calling transform; use fit() first".into(),
155        })
156    }
157}
158
159impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for GaussianRandomProjection<F> {
160    type FitError = FerroError;
161
162    /// Fit and transform in one step.
163    fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
164        let fitted = self.fit(x, &())?;
165        fitted.transform(x)
166    }
167}
168
169impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for GaussianRandomProjection<F> {
170    fn fit_pipeline(
171        &self,
172        x: &Array2<F>,
173        _y: &Array1<F>,
174    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
175        let fitted = self.fit(x, &())?;
176        Ok(Box::new(fitted))
177    }
178}
179
180impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F>
181    for FittedGaussianRandomProjection<F>
182{
183    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
184        self.transform(x)
185    }
186}
187
188// ---------------------------------------------------------------------------
189// SparseRandomProjection
190// ---------------------------------------------------------------------------
191
/// Configuration for a sparse random projection.
///
/// Fitting produces a `(n_features, n_components)` matrix whose entries take
/// the values `{-1, 0, +1}` with probabilities `{d/2, 1 - d, d/2}`, each
/// multiplied by `sqrt(1 / (d * n_components))`, where `d` is the density.
///
/// If no density is configured, `d = 1 / sqrt(n_features)` is used.
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::random_projection::SparseRandomProjection;
/// use ferrolearn_core::traits::{Fit, Transform};
/// use ndarray::Array2;
///
/// let x = Array2::<f64>::ones((10, 100));
/// let proj = SparseRandomProjection::<f64>::new(5).random_state(42);
/// let fitted = proj.fit(&x, &()).unwrap();
/// let out = fitted.transform(&x).unwrap();
/// assert_eq!(out.shape(), &[10, 5]);
/// ```
#[derive(Clone, Debug)]
pub struct SparseRandomProjection<F> {
    /// Target dimensionality (number of columns of the projected data).
    n_components: usize,
    /// Fraction of non-zero matrix entries; `None` selects `1/sqrt(n_features)`.
    density: Option<f64>,
    /// Seed for the random generator; `None` requests OS-provided entropy.
    random_state: Option<u64>,
    /// Ties the configuration to the float type `F` without storing a value.
    _marker: std::marker::PhantomData<F>,
}
223
224impl<F: Float + Send + Sync + 'static> SparseRandomProjection<F> {
225    /// Create a new sparse random projection with `n_components` output dimensions.
226    #[must_use]
227    pub fn new(n_components: usize) -> Self {
228        Self {
229            n_components,
230            density: None,
231            random_state: None,
232            _marker: std::marker::PhantomData,
233        }
234    }
235
236    /// Set the density of non-zero entries.
237    #[must_use]
238    pub fn density(mut self, density: f64) -> Self {
239        self.density = Some(density);
240        self
241    }
242
243    /// Set the random seed for reproducibility.
244    #[must_use]
245    pub fn random_state(mut self, seed: u64) -> Self {
246        self.random_state = Some(seed);
247        self
248    }
249}
250
251/// Fitted sparse random projection holding the projection matrix.
252#[derive(Debug, Clone)]
253pub struct FittedSparseRandomProjection<F> {
254    /// Projection matrix of shape `(n_features, n_components)`.
255    projection: Array2<F>,
256}
257
258impl<F: Float + Send + Sync + 'static> FittedSparseRandomProjection<F> {
259    /// Return a reference to the projection matrix.
260    #[must_use]
261    pub fn projection(&self) -> &Array2<F> {
262        &self.projection
263    }
264}
265
266impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for SparseRandomProjection<F> {
267    type Fitted = FittedSparseRandomProjection<F>;
268    type Error = FerroError;
269
270    /// Fit the projection by generating a sparse random matrix.
271    ///
272    /// # Errors
273    ///
274    /// Returns [`FerroError::InvalidParameter`] if `n_components == 0` or
275    /// `density` is not in `(0, 1]`.
276    /// Returns [`FerroError::InsufficientSamples`] if `x` has zero rows.
277    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedSparseRandomProjection<F>, FerroError> {
278        if self.n_components == 0 {
279            return Err(FerroError::InvalidParameter {
280                name: "n_components".into(),
281                reason: "must be >= 1".into(),
282            });
283        }
284        if x.nrows() == 0 {
285            return Err(FerroError::InsufficientSamples {
286                required: 1,
287                actual: 0,
288                context: "SparseRandomProjection::fit".into(),
289            });
290        }
291
292        let n_features = x.ncols();
293        let d = self
294            .density
295            .unwrap_or_else(|| 1.0 / (n_features as f64).sqrt());
296
297        if d <= 0.0 || d > 1.0 {
298            return Err(FerroError::InvalidParameter {
299                name: "density".into(),
300                reason: format!("must be in (0, 1], got {d}"),
301            });
302        }
303
304        let mut rng: SmallRng = match self.random_state {
305            Some(seed) => SmallRng::seed_from_u64(seed),
306            None => SmallRng::from_os_rng(),
307        };
308
309        let scale = F::from(1.0 / (d * self.n_components as f64).sqrt()).unwrap();
310        let uniform = rand::distr::Uniform::new(0.0_f64, 1.0).unwrap();
311
312        let mut projection = Array2::zeros((n_features, self.n_components));
313        for v in &mut projection {
314            let u: f64 = uniform.sample(&mut rng);
315            if u < d / 2.0 {
316                *v = scale.neg();
317            } else if u < d {
318                *v = scale;
319            }
320            // else: remains 0
321        }
322
323        Ok(FittedSparseRandomProjection { projection })
324    }
325}
326
327impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSparseRandomProjection<F> {
328    type Output = Array2<F>;
329    type Error = FerroError;
330
331    /// Transform data by computing `X @ R`.
332    ///
333    /// # Errors
334    ///
335    /// Returns [`FerroError::ShapeMismatch`] if `x.ncols() != n_features`.
336    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
337        if x.ncols() != self.projection.nrows() {
338            return Err(FerroError::ShapeMismatch {
339                expected: vec![x.nrows(), self.projection.nrows()],
340                actual: vec![x.nrows(), x.ncols()],
341                context: "FittedSparseRandomProjection::transform".into(),
342            });
343        }
344        Ok(x.dot(&self.projection))
345    }
346}
347
348impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for SparseRandomProjection<F> {
349    type Output = Array2<F>;
350    type Error = FerroError;
351
352    /// Always returns an error — the projection must be fitted first.
353    fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
354        Err(FerroError::InvalidParameter {
355            name: "SparseRandomProjection".into(),
356            reason: "projection must be fitted before calling transform; use fit() first".into(),
357        })
358    }
359}
360
361impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for SparseRandomProjection<F> {
362    type FitError = FerroError;
363
364    /// Fit and transform in one step.
365    fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
366        let fitted = self.fit(x, &())?;
367        fitted.transform(x)
368    }
369}
370
371impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for SparseRandomProjection<F> {
372    fn fit_pipeline(
373        &self,
374        x: &Array2<F>,
375        _y: &Array1<F>,
376    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
377        let fitted = self.fit(x, &())?;
378        Ok(Box::new(fitted))
379    }
380}
381
382impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F>
383    for FittedSparseRandomProjection<F>
384{
385    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
386        self.transform(x)
387    }
388}
389
390// ---------------------------------------------------------------------------
391// Tests
392// ---------------------------------------------------------------------------
393
#[cfg(test)]
mod tests {
    //! Unit tests for both projection types: output shapes, seeded
    //! reproducibility, parameter validation and shape checking.

    use super::*;
    use ndarray::Array2;

    // -- GaussianRandomProjection --

    #[test]
    fn test_gaussian_rp_output_shape() {
        let x = Array2::<f64>::ones((10, 50));
        let proj = GaussianRandomProjection::<f64>::new(5).random_state(42);
        let fitted = proj.fit(&x, &()).unwrap();
        // The matrix is (n_features, n_components); the output keeps the rows.
        assert_eq!(fitted.projection().shape(), &[50, 5]);
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.shape(), &[10, 5]);
    }

    #[test]
    fn test_gaussian_rp_deterministic() {
        let x = Array2::<f64>::ones((10, 20));
        let proj = GaussianRandomProjection::<f64>::new(3).random_state(42);
        let fitted1 = proj.fit(&x, &()).unwrap();
        let fitted2 = proj.fit(&x, &()).unwrap();
        // Same seed must reproduce the matrix itself, not just column sums.
        assert_eq!(fitted1.projection(), fitted2.projection());
        let out1 = fitted1.transform(&x).unwrap();
        let out2 = fitted2.transform(&x).unwrap();
        for (a, b) in out1.iter().zip(out2.iter()) {
            assert!((a - b).abs() < 1e-10);
        }
    }

    #[test]
    fn test_gaussian_rp_zero_components() {
        let x = Array2::<f64>::ones((5, 10));
        let proj = GaussianRandomProjection::<f64>::new(0);
        assert!(proj.fit(&x, &()).is_err());
    }

    #[test]
    fn test_gaussian_rp_empty_input() {
        let x = Array2::<f64>::zeros((0, 10));
        let proj = GaussianRandomProjection::<f64>::new(5);
        assert!(proj.fit(&x, &()).is_err());
    }

    #[test]
    fn test_gaussian_rp_shape_mismatch() {
        let x_train = Array2::<f64>::ones((10, 20));
        let proj = GaussianRandomProjection::<f64>::new(5).random_state(42);
        let fitted = proj.fit(&x_train, &()).unwrap();
        let x_bad = Array2::<f64>::ones((5, 15));
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_gaussian_rp_unfitted_transform_errors() {
        let x = Array2::<f64>::ones((5, 10));
        let proj = GaussianRandomProjection::<f64>::new(3).random_state(42);
        assert!(proj.transform(&x).is_err());
    }

    #[test]
    fn test_gaussian_rp_fit_transform() {
        let x = Array2::<f64>::ones((10, 20));
        let proj = GaussianRandomProjection::<f64>::new(5).random_state(42);
        let out = proj.fit_transform(&x).unwrap();
        assert_eq!(out.shape(), &[10, 5]);
    }

    #[test]
    fn test_gaussian_rp_f32() {
        let x = Array2::<f32>::ones((5, 10));
        let proj = GaussianRandomProjection::<f32>::new(3).random_state(42);
        let fitted = proj.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.shape(), &[5, 3]);
    }

    // -- SparseRandomProjection --

    #[test]
    fn test_sparse_rp_output_shape() {
        let x = Array2::<f64>::ones((10, 100));
        let proj = SparseRandomProjection::<f64>::new(5).random_state(42);
        let fitted = proj.fit(&x, &()).unwrap();
        assert_eq!(fitted.projection().shape(), &[100, 5]);
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.shape(), &[10, 5]);
    }

    #[test]
    fn test_sparse_rp_deterministic() {
        let x = Array2::<f64>::ones((10, 50));
        let proj = SparseRandomProjection::<f64>::new(3).random_state(42);
        let fitted1 = proj.fit(&x, &()).unwrap();
        let fitted2 = proj.fit(&x, &()).unwrap();
        // Same seed must reproduce the matrix itself, not just column sums.
        assert_eq!(fitted1.projection(), fitted2.projection());
        let out1 = fitted1.transform(&x).unwrap();
        let out2 = fitted2.transform(&x).unwrap();
        for (a, b) in out1.iter().zip(out2.iter()) {
            assert!((a - b).abs() < 1e-10);
        }
    }

    #[test]
    fn test_sparse_rp_sparsity() {
        let x = Array2::<f64>::ones((5, 100));
        let proj = SparseRandomProjection::<f64>::new(10).random_state(42);
        let fitted = proj.fit(&x, &()).unwrap();
        let r = fitted.projection();
        // With density = 1/sqrt(100) = 0.1, about 90% of the 1000 entries
        // should be zero. The bounds leave roughly ten standard deviations
        // of slack around that expectation.
        let total = r.len();
        let zeros = r.iter().filter(|&&v| v == 0.0).count();
        let sparsity = zeros as f64 / total as f64;
        assert!(
            sparsity > 0.8 && sparsity < 1.0,
            "expected roughly 90% zeros, got sparsity={sparsity}"
        );
    }

    #[test]
    fn test_sparse_rp_custom_density() {
        let x = Array2::<f64>::ones((5, 20));
        let proj = SparseRandomProjection::<f64>::new(5)
            .density(0.5)
            .random_state(42);
        let fitted = proj.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.shape(), &[5, 5]);
    }

    #[test]
    fn test_sparse_rp_full_density_has_no_zeros() {
        let x = Array2::<f64>::ones((5, 20));
        let proj = SparseRandomProjection::<f64>::new(5)
            .density(1.0)
            .random_state(42);
        let fitted = proj.fit(&x, &()).unwrap();
        // density = 1 is the inclusive upper bound: every entry is +/- scale
        // with scale = 1 / sqrt(1.0 * 5).
        let expected = 1.0 / 5.0_f64.sqrt();
        assert!(
            fitted
                .projection()
                .iter()
                .all(|&v| (v.abs() - expected).abs() < 1e-12)
        );
    }

    #[test]
    fn test_sparse_rp_zero_components() {
        let x = Array2::<f64>::ones((5, 10));
        let proj = SparseRandomProjection::<f64>::new(0);
        assert!(proj.fit(&x, &()).is_err());
    }

    #[test]
    fn test_sparse_rp_invalid_density() {
        let x = Array2::<f64>::ones((5, 10));
        // Zero, negative and above-one densities are all outside (0, 1].
        for bad in [0.0, -0.25, 1.5] {
            let proj = SparseRandomProjection::<f64>::new(5).density(bad);
            assert!(
                proj.fit(&x, &()).is_err(),
                "density {bad} should be rejected"
            );
        }
    }

    #[test]
    fn test_sparse_rp_empty_input() {
        let x = Array2::<f64>::zeros((0, 10));
        let proj = SparseRandomProjection::<f64>::new(5);
        assert!(proj.fit(&x, &()).is_err());
    }

    #[test]
    fn test_sparse_rp_shape_mismatch() {
        let x_train = Array2::<f64>::ones((10, 20));
        let proj = SparseRandomProjection::<f64>::new(5).random_state(42);
        let fitted = proj.fit(&x_train, &()).unwrap();
        let x_bad = Array2::<f64>::ones((5, 15));
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_sparse_rp_unfitted_transform_errors() {
        let x = Array2::<f64>::ones((5, 10));
        let proj = SparseRandomProjection::<f64>::new(3).random_state(42);
        assert!(proj.transform(&x).is_err());
    }

    #[test]
    fn test_sparse_rp_fit_transform() {
        let x = Array2::<f64>::ones((10, 20));
        let proj = SparseRandomProjection::<f64>::new(5).random_state(42);
        let out = proj.fit_transform(&x).unwrap();
        assert_eq!(out.shape(), &[10, 5]);
    }

    #[test]
    fn test_sparse_rp_f32() {
        let x = Array2::<f32>::ones((5, 10));
        let proj = SparseRandomProjection::<f32>::new(3).random_state(42);
        let fitted = proj.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.shape(), &[5, 3]);
    }
}