Skip to main content

ferrolearn_linear/
omp.rs

1//! Orthogonal Matching Pursuit (OMP).
2//!
3//! This module provides [`OrthogonalMatchingPursuit`], a greedy feature
4//! selection algorithm that iteratively selects the feature most correlated
5//! with the current residual, adds it to a support set, solves OLS on
6//! the support, and updates the residual. The process repeats until the
7//! desired number of non-zero coefficients is reached or the residual
8//! tolerance is met.
9//!
10//! # Examples
11//!
12//! ```
13//! use ferrolearn_linear::OrthogonalMatchingPursuit;
14//! use ferrolearn_core::{Fit, Predict};
15//! use ndarray::{array, Array1, Array2};
16//!
17//! let x = Array2::from_shape_vec((5, 3), vec![
18//!     1.0, 0.0, 0.0,
19//!     2.0, 0.1, 0.0,
20//!     3.0, 0.0, 0.1,
21//!     4.0, 0.1, 0.0,
22//!     5.0, 0.0, 0.1,
23//! ]).unwrap();
24//! let y = array![2.0, 4.0, 6.0, 8.0, 10.0];
25//!
26//! let model = OrthogonalMatchingPursuit::<f64>::new().with_n_nonzero_coefs(1);
27//! let fitted = model.fit(&x, &y).unwrap();
28//! let preds = fitted.predict(&x).unwrap();
29//! assert_eq!(preds.len(), 5);
30//! ```
31
32use ferrolearn_core::error::FerroError;
33use ferrolearn_core::introspection::HasCoefficients;
34use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
35use ferrolearn_core::traits::{Fit, Predict};
36use ndarray::{Array1, Array2, Axis, ScalarOperand};
37use num_traits::{Float, FromPrimitive};
38
/// Configuration for an Orthogonal Matching Pursuit (OMP) regressor.
///
/// OMP builds a sparse linear model greedily. Each round selects the
/// not-yet-used feature whose inner product with the current residual has
/// the largest magnitude, appends it to the support set, refits ordinary
/// least squares restricted to the support, and recomputes the residual.
///
/// The loop ends as soon as the support holds `n_nonzero_coefs` features
/// or the squared residual norm falls below `tol`, whichever happens
/// first. Fitting rejects a configuration in which both are `None`.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct OrthogonalMatchingPursuit<F> {
    /// Upper bound on the number of selected features (non-zero
    /// coefficients). `None` (the default) allows up to every feature,
    /// leaving `tol` as the only stopping rule.
    pub n_nonzero_coefs: Option<usize>,
    /// Threshold on the *squared* Euclidean norm of the residual; the
    /// search stops once the squared norm is strictly below this value.
    /// `None` (the default) disables the check.
    pub tol: Option<F>,
    /// Whether an intercept (bias) term is estimated.
    pub fit_intercept: bool,
}
64
65impl<F: Float> OrthogonalMatchingPursuit<F> {
66    /// Create a new `OrthogonalMatchingPursuit` with default settings.
67    ///
68    /// Defaults: `n_nonzero_coefs = None`, `tol = None`,
69    /// `fit_intercept = true`.
70    #[must_use]
71    pub fn new() -> Self {
72        Self {
73            n_nonzero_coefs: None,
74            tol: None,
75            fit_intercept: true,
76        }
77    }
78
79    /// Set the maximum number of non-zero coefficients.
80    #[must_use]
81    pub fn with_n_nonzero_coefs(mut self, n: usize) -> Self {
82        self.n_nonzero_coefs = Some(n);
83        self
84    }
85
86    /// Set the residual norm tolerance.
87    #[must_use]
88    pub fn with_tol(mut self, tol: F) -> Self {
89        self.tol = Some(tol);
90        self
91    }
92
93    /// Set whether to fit an intercept term.
94    #[must_use]
95    pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
96        self.fit_intercept = fit_intercept;
97        self
98    }
99}
100
101impl<F: Float> Default for OrthogonalMatchingPursuit<F> {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
107/// Fitted Orthogonal Matching Pursuit model.
108///
109/// Stores the learned (sparse) coefficients and intercept.
110#[derive(Debug, Clone)]
111pub struct FittedOMP<F> {
112    /// Learned coefficient vector (many entries may be zero).
113    coefficients: Array1<F>,
114    /// Learned intercept (bias) term.
115    intercept: F,
116}
117
118// ---------------------------------------------------------------------------
119// Internal helpers
120// ---------------------------------------------------------------------------
121
122/// Cholesky solve for `A x = b`.
123fn cholesky_solve<F: Float>(a: &Array2<F>, b: &Array1<F>) -> Result<Array1<F>, FerroError> {
124    let n = a.nrows();
125    let mut l = Array2::<F>::zeros((n, n));
126
127    for i in 0..n {
128        for j in 0..=i {
129            let mut s = a[[i, j]];
130            for k in 0..j {
131                s = s - l[[i, k]] * l[[j, k]];
132            }
133            if i == j {
134                if s <= F::zero() {
135                    return Err(FerroError::NumericalInstability {
136                        message: "Cholesky: matrix not positive definite".into(),
137                    });
138                }
139                l[[i, j]] = s.sqrt();
140            } else {
141                l[[i, j]] = s / l[[j, j]];
142            }
143        }
144    }
145
146    let mut z = Array1::<F>::zeros(n);
147    for i in 0..n {
148        let mut s = b[i];
149        for k in 0..i {
150            s = s - l[[i, k]] * z[k];
151        }
152        z[i] = s / l[[i, i]];
153    }
154
155    let mut x_sol = Array1::<F>::zeros(n);
156    for i in (0..n).rev() {
157        let mut s = z[i];
158        for k in (i + 1)..n {
159            s = s - l[[k, i]] * x_sol[k];
160        }
161        x_sol[i] = s / l[[i, i]];
162    }
163
164    Ok(x_sol)
165}
166
167/// Gaussian elimination with partial pivoting.
168fn gaussian_solve<F: Float>(
169    n: usize,
170    a: &Array2<F>,
171    b: &Array1<F>,
172) -> Result<Array1<F>, FerroError> {
173    let mut aug = Array2::<F>::zeros((n, n + 1));
174    for i in 0..n {
175        for j in 0..n {
176            aug[[i, j]] = a[[i, j]];
177        }
178        aug[[i, n]] = b[i];
179    }
180
181    for col in 0..n {
182        let mut max_val = aug[[col, col]].abs();
183        let mut max_row = col;
184        for row in (col + 1)..n {
185            let v = aug[[row, col]].abs();
186            if v > max_val {
187                max_val = v;
188                max_row = row;
189            }
190        }
191
192        if max_val < F::from(1e-12).unwrap_or_else(F::epsilon) {
193            return Err(FerroError::NumericalInstability {
194                message: "singular matrix in Gaussian elimination".into(),
195            });
196        }
197
198        if max_row != col {
199            for j in 0..=n {
200                let tmp = aug[[col, j]];
201                aug[[col, j]] = aug[[max_row, j]];
202                aug[[max_row, j]] = tmp;
203            }
204        }
205
206        let pivot = aug[[col, col]];
207        for row in (col + 1)..n {
208            let factor = aug[[row, col]] / pivot;
209            for j in col..=n {
210                let above = aug[[col, j]];
211                aug[[row, j]] = aug[[row, j]] - factor * above;
212            }
213        }
214    }
215
216    let mut x_sol = Array1::<F>::zeros(n);
217    for i in (0..n).rev() {
218        let mut s = aug[[i, n]];
219        for j in (i + 1)..n {
220            s = s - aug[[i, j]] * x_sol[j];
221        }
222        if aug[[i, i]].abs() < F::from(1e-12).unwrap_or_else(F::epsilon) {
223            return Err(FerroError::NumericalInstability {
224                message: "near-zero pivot in back substitution".into(),
225            });
226        }
227        x_sol[i] = s / aug[[i, i]];
228    }
229
230    Ok(x_sol)
231}
232
233/// Solve OLS on the active columns, returning the full-length coefficient vector.
234fn ols_active<F: Float + FromPrimitive + 'static>(
235    x: &Array2<F>,
236    y: &Array1<F>,
237    support: &[usize],
238    n_features: usize,
239) -> Result<Array1<F>, FerroError> {
240    let n_samples = x.nrows();
241    let k = support.len();
242
243    let mut xa = Array2::<F>::zeros((n_samples, k));
244    for (col_idx, &j) in support.iter().enumerate() {
245        for i in 0..n_samples {
246            xa[[i, col_idx]] = x[[i, j]];
247        }
248    }
249
250    let xat = xa.t();
251    let xtx = xat.dot(&xa);
252    let xty = xat.dot(y);
253
254    let w_active =
255        cholesky_solve(&xtx, &xty).or_else(|_| gaussian_solve(k, &xtx, &xty))?;
256
257    let mut w = Array1::<F>::zeros(n_features);
258    for (col_idx, &j) in support.iter().enumerate() {
259        w[j] = w_active[col_idx];
260    }
261    Ok(w)
262}
263
264// ---------------------------------------------------------------------------
265// Fit
266// ---------------------------------------------------------------------------
267
268impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
269    for OrthogonalMatchingPursuit<F>
270{
271    type Fitted = FittedOMP<F>;
272    type Error = FerroError;
273
274    /// Fit the OMP model.
275    ///
276    /// Greedily selects features by correlation with the residual and
277    /// solves OLS on the growing support set.
278    ///
279    /// # Errors
280    ///
281    /// - [`FerroError::ShapeMismatch`] — sample count mismatch.
282    /// - [`FerroError::InsufficientSamples`] — zero samples.
283    /// - [`FerroError::InvalidParameter`] — `n_nonzero_coefs` exceeds features,
284    ///   or neither `n_nonzero_coefs` nor `tol` is set.
285    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedOMP<F>, FerroError> {
286        let (n_samples, n_features) = x.dim();
287
288        if n_samples != y.len() {
289            return Err(FerroError::ShapeMismatch {
290                expected: vec![n_samples],
291                actual: vec![y.len()],
292                context: "y length must match number of samples in X".into(),
293            });
294        }
295
296        if n_samples == 0 {
297            return Err(FerroError::InsufficientSamples {
298                required: 1,
299                actual: 0,
300                context: "OMP requires at least one sample".into(),
301            });
302        }
303
304        // At least one stopping criterion must be set.
305        if self.n_nonzero_coefs.is_none() && self.tol.is_none() {
306            return Err(FerroError::InvalidParameter {
307                name: "n_nonzero_coefs / tol".into(),
308                reason: "at least one stopping criterion must be set".into(),
309            });
310        }
311
312        let max_k = self
313            .n_nonzero_coefs
314            .unwrap_or(n_features)
315            .min(n_features);
316
317        if let Some(n) = self.n_nonzero_coefs {
318            if n > n_features {
319                return Err(FerroError::InvalidParameter {
320                    name: "n_nonzero_coefs".into(),
321                    reason: format!(
322                        "cannot exceed number of features ({n_features})"
323                    ),
324                });
325            }
326        }
327
328        // Center data if fitting intercept.
329        let (x_work, y_work, x_mean, y_mean) = if self.fit_intercept {
330            let x_mean = x
331                .mean_axis(Axis(0))
332                .ok_or_else(|| FerroError::NumericalInstability {
333                    message: "failed to compute column means".into(),
334                })?;
335            let y_mean = y.mean().ok_or_else(|| FerroError::NumericalInstability {
336                message: "failed to compute target mean".into(),
337            })?;
338            let x_c = x - &x_mean;
339            let y_c = y - y_mean;
340            (x_c, y_c, Some(x_mean), Some(y_mean))
341        } else {
342            (x.clone(), y.clone(), None, None)
343        };
344
345        let mut support: Vec<usize> = Vec::with_capacity(max_k);
346        let mut in_support = vec![false; n_features];
347        let mut w = Array1::<F>::zeros(n_features);
348        let mut residual = y_work.clone();
349
350        for _step in 0..max_k {
351            // Check residual tolerance.
352            if let Some(tol_val) = self.tol {
353                let res_norm_sq = residual.dot(&residual);
354                if res_norm_sq < tol_val {
355                    break;
356                }
357            }
358
359            // Find feature most correlated with residual.
360            let mut best_j = None;
361            let mut best_corr = F::zero();
362            for (j, &is_in_support) in in_support.iter().enumerate() {
363                if is_in_support {
364                    continue;
365                }
366                let corr = x_work.column(j).dot(&residual).abs();
367                if corr > best_corr {
368                    best_corr = corr;
369                    best_j = Some(j);
370                }
371            }
372
373            let j = match best_j {
374                Some(j) => j,
375                None => break,
376            };
377
378            support.push(j);
379            in_support[j] = true;
380
381            // OLS on support set.
382            w = ols_active(&x_work, &y_work, &support, n_features)?;
383
384            // Update residual.
385            residual = &y_work - x_work.dot(&w);
386        }
387
388        let intercept = if let (Some(xm), Some(ym)) = (&x_mean, &y_mean) {
389            *ym - xm.dot(&w)
390        } else {
391            F::zero()
392        };
393
394        Ok(FittedOMP {
395            coefficients: w,
396            intercept,
397        })
398    }
399}
400
401// ---------------------------------------------------------------------------
402// Predict / HasCoefficients / Pipeline
403// ---------------------------------------------------------------------------
404
405impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>> for FittedOMP<F> {
406    type Output = Array1<F>;
407    type Error = FerroError;
408
409    /// Predict target values for the given feature matrix.
410    ///
411    /// Computes `X @ coefficients + intercept`.
412    ///
413    /// # Errors
414    ///
415    /// Returns [`FerroError::ShapeMismatch`] if the number of features
416    /// does not match the fitted model.
417    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
418        if x.ncols() != self.coefficients.len() {
419            return Err(FerroError::ShapeMismatch {
420                expected: vec![self.coefficients.len()],
421                actual: vec![x.ncols()],
422                context: "number of features must match fitted model".into(),
423            });
424        }
425        Ok(x.dot(&self.coefficients) + self.intercept)
426    }
427}
428
429impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F> for FittedOMP<F> {
430    fn coefficients(&self) -> &Array1<F> {
431        &self.coefficients
432    }
433
434    fn intercept(&self) -> F {
435        self.intercept
436    }
437}
438
439impl<F> PipelineEstimator<F> for OrthogonalMatchingPursuit<F>
440where
441    F: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
442{
443    fn fit_pipeline(
444        &self,
445        x: &Array2<F>,
446        y: &Array1<F>,
447    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
448        let fitted = self.fit(x, y)?;
449        Ok(Box::new(fitted))
450    }
451}
452
453impl<F> FittedPipelineEstimator<F> for FittedOMP<F>
454where
455    F: Float + ScalarOperand + Send + Sync + 'static,
456{
457    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
458        self.predict(x)
459    }
460}
461
462// ---------------------------------------------------------------------------
463// Tests
464// ---------------------------------------------------------------------------
465
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// Build an `n × 1` design matrix from a slice of values.
    fn single_feature(values: &[f64]) -> Array2<f64> {
        Array2::from_shape_vec((values.len(), 1), values.to_vec()).unwrap()
    }

    /// Estimator limited to `k` non-zero coefficients.
    fn omp_with_k(k: usize) -> OrthogonalMatchingPursuit<f64> {
        OrthogonalMatchingPursuit::new().with_n_nonzero_coefs(k)
    }

    #[test]
    fn test_defaults() {
        let model = OrthogonalMatchingPursuit::<f64>::new();
        assert!(model.n_nonzero_coefs.is_none());
        assert!(model.tol.is_none());
        assert!(model.fit_intercept);
    }

    #[test]
    fn test_builder() {
        let model = omp_with_k(3).with_tol(1e-4).with_fit_intercept(false);
        assert_eq!(model.n_nonzero_coefs, Some(3));
        assert_relative_eq!(model.tol.unwrap(), 1e-4);
        assert!(!model.fit_intercept);
    }

    #[test]
    fn test_shape_mismatch() {
        // Three samples in X but only two targets.
        let x = single_feature(&[1.0, 2.0, 3.0]);
        let y = array![1.0, 2.0];
        let result = omp_with_k(1).fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_no_stopping_criterion() {
        let x = single_feature(&[1.0, 2.0, 3.0]);
        let y = array![1.0, 2.0, 3.0];
        let result = OrthogonalMatchingPursuit::<f64>::new().fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_n_nonzero_exceeds_features() {
        // Two features available, five requested.
        let x = array![[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]];
        let y = array![1.0, 2.0, 3.0];
        let result = omp_with_k(5).fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_simple_linear() {
        // y = 2x + 1.
        let x = single_feature(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let fitted = omp_with_k(1).fit(&x, &y).unwrap();
        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-6);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_sparsity() {
        // A support of size one must leave exactly one non-zero coefficient.
        let x = array![
            [1.0, 0.1, 0.01],
            [2.0, 0.2, 0.02],
            [3.0, 0.3, 0.03],
            [4.0, 0.4, 0.04],
            [5.0, 0.5, 0.05],
            [6.0, 0.6, 0.06],
            [7.0, 0.7, 0.07],
            [8.0, 0.8, 0.08],
            [9.0, 0.9, 0.09],
            [10.0, 1.0, 0.10],
        ];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0];

        let fitted = omp_with_k(1).fit(&x, &y).unwrap();
        let mut nonzero = 0;
        for &coef in fitted.coefficients().iter() {
            if coef.abs() > 1e-10 {
                nonzero += 1;
            }
        }
        assert_eq!(nonzero, 1);
    }

    #[test]
    fn test_tol_stopping() {
        // Perfectly linear target: a single feature reproduces it.
        let x = single_feature(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = OrthogonalMatchingPursuit::<f64>::new()
            .with_tol(1e-10)
            .fit(&x, &y)
            .unwrap();

        let preds = fitted.predict(&x).unwrap();
        for (pred, actual) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(*pred, *actual, epsilon = 1e-4);
        }
    }

    #[test]
    fn test_predict() {
        let x = single_feature(&[1.0, 2.0, 3.0, 4.0]);
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = omp_with_k(1).fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_predict_feature_mismatch() {
        // Fitted on two features, asked to predict on one.
        let x = array![[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]];
        let y = array![1.0, 2.0, 3.0];
        let fitted = omp_with_k(1).fit(&x, &y).unwrap();

        let x_bad = single_feature(&[1.0, 2.0, 3.0]);
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_has_coefficients() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![1.0, 2.0, 3.0];
        let fitted = omp_with_k(2).fit(&x, &y).unwrap();
        assert_eq!(fitted.coefficients().len(), 2);
    }

    #[test]
    fn test_no_intercept() {
        let x = single_feature(&[1.0, 2.0, 3.0, 4.0]);
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = omp_with_k(1)
            .with_fit_intercept(false)
            .fit(&x, &y)
            .unwrap();
        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_pipeline() {
        let x = single_feature(&[1.0, 2.0, 3.0, 4.0]);
        let y = array![3.0, 5.0, 7.0, 9.0];

        let fitted = omp_with_k(1).fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_multivariate_recovery() {
        // y = 1*x1 + 3*x2; with a support of two the third feature should
        // contribute little.
        let x = array![
            [1.0, 0.0, 0.5],
            [0.0, 1.0, 0.3],
            [1.0, 1.0, 0.1],
            [2.0, 0.0, 0.8],
            [0.0, 2.0, 0.4],
        ];
        let y = array![1.0, 3.0, 4.0, 2.0, 6.0];

        let fitted = omp_with_k(2).fit(&x, &y).unwrap();

        let third = fitted.coefficients()[2];
        assert!(
            third.abs() < 0.5,
            "irrelevant feature should have near-zero coefficient"
        );
    }
}
651            fitted.coefficients()[2].abs() < 0.5,
652            "irrelevant feature should have near-zero coefficient"
653        );
654    }
655}