1use ferrolearn_core::error::FerroError;
33use ferrolearn_core::introspection::HasCoefficients;
34use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
35use ferrolearn_core::traits::{Fit, Predict};
36use ndarray::{Array1, Array2, Axis, ScalarOperand};
37use num_traits::{Float, FromPrimitive};
38
/// Configuration for an Orthogonal Matching Pursuit (OMP) regressor.
///
/// OMP greedily adds one feature at a time to an active set and re-solves a
/// least-squares problem on that set after each addition. At least one of
/// `n_nonzero_coefs` or `tol` must be set before fitting; `fit` rejects a
/// configuration where both are `None`.
#[derive(Debug, Clone)]
pub struct OrthogonalMatchingPursuit<F> {
    /// Upper bound on the number of features selected during fitting.
    /// `fit` returns an error if this exceeds the number of columns in `X`.
    /// When `None`, up to every feature may be selected.
    pub n_nonzero_coefs: Option<usize>,
    /// Threshold on the squared norm of the residual: the greedy loop stops
    /// before selecting another feature once the squared norm drops below
    /// this value. When `None`, no residual-based stopping is applied.
    pub tol: Option<F>,
    /// When `true`, `X` and `y` are centred on their means before fitting and
    /// an intercept is derived from those means; when `false`, the data is
    /// used as given and the intercept is zero.
    pub fit_intercept: bool,
}
64
65impl<F: Float> OrthogonalMatchingPursuit<F> {
66 #[must_use]
71 pub fn new() -> Self {
72 Self {
73 n_nonzero_coefs: None,
74 tol: None,
75 fit_intercept: true,
76 }
77 }
78
79 #[must_use]
81 pub fn with_n_nonzero_coefs(mut self, n: usize) -> Self {
82 self.n_nonzero_coefs = Some(n);
83 self
84 }
85
86 #[must_use]
88 pub fn with_tol(mut self, tol: F) -> Self {
89 self.tol = Some(tol);
90 self
91 }
92
93 #[must_use]
95 pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
96 self.fit_intercept = fit_intercept;
97 self
98 }
99}
100
101impl<F: Float> Default for OrthogonalMatchingPursuit<F> {
102 fn default() -> Self {
103 Self::new()
104 }
105}
106
/// Result of fitting an [`OrthogonalMatchingPursuit`] estimator.
///
/// Predictions are computed as `X · coefficients + intercept`.
#[derive(Debug, Clone)]
pub struct FittedOMP<F> {
    /// One weight per input feature; features that were never selected
    /// during fitting keep a weight of zero.
    coefficients: Array1<F>,
    /// Constant offset added to every prediction; zero when the model was
    /// fitted without an intercept.
    intercept: F,
}
117
118fn cholesky_solve<F: Float>(a: &Array2<F>, b: &Array1<F>) -> Result<Array1<F>, FerroError> {
124 let n = a.nrows();
125 let mut l = Array2::<F>::zeros((n, n));
126
127 for i in 0..n {
128 for j in 0..=i {
129 let mut s = a[[i, j]];
130 for k in 0..j {
131 s = s - l[[i, k]] * l[[j, k]];
132 }
133 if i == j {
134 if s <= F::zero() {
135 return Err(FerroError::NumericalInstability {
136 message: "Cholesky: matrix not positive definite".into(),
137 });
138 }
139 l[[i, j]] = s.sqrt();
140 } else {
141 l[[i, j]] = s / l[[j, j]];
142 }
143 }
144 }
145
146 let mut z = Array1::<F>::zeros(n);
147 for i in 0..n {
148 let mut s = b[i];
149 for k in 0..i {
150 s = s - l[[i, k]] * z[k];
151 }
152 z[i] = s / l[[i, i]];
153 }
154
155 let mut x_sol = Array1::<F>::zeros(n);
156 for i in (0..n).rev() {
157 let mut s = z[i];
158 for k in (i + 1)..n {
159 s = s - l[[k, i]] * x_sol[k];
160 }
161 x_sol[i] = s / l[[i, i]];
162 }
163
164 Ok(x_sol)
165}
166
167fn gaussian_solve<F: Float>(
169 n: usize,
170 a: &Array2<F>,
171 b: &Array1<F>,
172) -> Result<Array1<F>, FerroError> {
173 let mut aug = Array2::<F>::zeros((n, n + 1));
174 for i in 0..n {
175 for j in 0..n {
176 aug[[i, j]] = a[[i, j]];
177 }
178 aug[[i, n]] = b[i];
179 }
180
181 for col in 0..n {
182 let mut max_val = aug[[col, col]].abs();
183 let mut max_row = col;
184 for row in (col + 1)..n {
185 let v = aug[[row, col]].abs();
186 if v > max_val {
187 max_val = v;
188 max_row = row;
189 }
190 }
191
192 if max_val < F::from(1e-12).unwrap_or_else(F::epsilon) {
193 return Err(FerroError::NumericalInstability {
194 message: "singular matrix in Gaussian elimination".into(),
195 });
196 }
197
198 if max_row != col {
199 for j in 0..=n {
200 let tmp = aug[[col, j]];
201 aug[[col, j]] = aug[[max_row, j]];
202 aug[[max_row, j]] = tmp;
203 }
204 }
205
206 let pivot = aug[[col, col]];
207 for row in (col + 1)..n {
208 let factor = aug[[row, col]] / pivot;
209 for j in col..=n {
210 let above = aug[[col, j]];
211 aug[[row, j]] = aug[[row, j]] - factor * above;
212 }
213 }
214 }
215
216 let mut x_sol = Array1::<F>::zeros(n);
217 for i in (0..n).rev() {
218 let mut s = aug[[i, n]];
219 for j in (i + 1)..n {
220 s = s - aug[[i, j]] * x_sol[j];
221 }
222 if aug[[i, i]].abs() < F::from(1e-12).unwrap_or_else(F::epsilon) {
223 return Err(FerroError::NumericalInstability {
224 message: "near-zero pivot in back substitution".into(),
225 });
226 }
227 x_sol[i] = s / aug[[i, i]];
228 }
229
230 Ok(x_sol)
231}
232
233fn ols_active<F: Float + FromPrimitive + 'static>(
235 x: &Array2<F>,
236 y: &Array1<F>,
237 support: &[usize],
238 n_features: usize,
239) -> Result<Array1<F>, FerroError> {
240 let n_samples = x.nrows();
241 let k = support.len();
242
243 let mut xa = Array2::<F>::zeros((n_samples, k));
244 for (col_idx, &j) in support.iter().enumerate() {
245 for i in 0..n_samples {
246 xa[[i, col_idx]] = x[[i, j]];
247 }
248 }
249
250 let xat = xa.t();
251 let xtx = xat.dot(&xa);
252 let xty = xat.dot(y);
253
254 let w_active =
255 cholesky_solve(&xtx, &xty).or_else(|_| gaussian_solve(k, &xtx, &xty))?;
256
257 let mut w = Array1::<F>::zeros(n_features);
258 for (col_idx, &j) in support.iter().enumerate() {
259 w[j] = w_active[col_idx];
260 }
261 Ok(w)
262}
263
264impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
269 for OrthogonalMatchingPursuit<F>
270{
271 type Fitted = FittedOMP<F>;
272 type Error = FerroError;
273
274 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedOMP<F>, FerroError> {
286 let (n_samples, n_features) = x.dim();
287
288 if n_samples != y.len() {
289 return Err(FerroError::ShapeMismatch {
290 expected: vec![n_samples],
291 actual: vec![y.len()],
292 context: "y length must match number of samples in X".into(),
293 });
294 }
295
296 if n_samples == 0 {
297 return Err(FerroError::InsufficientSamples {
298 required: 1,
299 actual: 0,
300 context: "OMP requires at least one sample".into(),
301 });
302 }
303
304 if self.n_nonzero_coefs.is_none() && self.tol.is_none() {
306 return Err(FerroError::InvalidParameter {
307 name: "n_nonzero_coefs / tol".into(),
308 reason: "at least one stopping criterion must be set".into(),
309 });
310 }
311
312 let max_k = self
313 .n_nonzero_coefs
314 .unwrap_or(n_features)
315 .min(n_features);
316
317 if let Some(n) = self.n_nonzero_coefs {
318 if n > n_features {
319 return Err(FerroError::InvalidParameter {
320 name: "n_nonzero_coefs".into(),
321 reason: format!(
322 "cannot exceed number of features ({n_features})"
323 ),
324 });
325 }
326 }
327
328 let (x_work, y_work, x_mean, y_mean) = if self.fit_intercept {
330 let x_mean = x
331 .mean_axis(Axis(0))
332 .ok_or_else(|| FerroError::NumericalInstability {
333 message: "failed to compute column means".into(),
334 })?;
335 let y_mean = y.mean().ok_or_else(|| FerroError::NumericalInstability {
336 message: "failed to compute target mean".into(),
337 })?;
338 let x_c = x - &x_mean;
339 let y_c = y - y_mean;
340 (x_c, y_c, Some(x_mean), Some(y_mean))
341 } else {
342 (x.clone(), y.clone(), None, None)
343 };
344
345 let mut support: Vec<usize> = Vec::with_capacity(max_k);
346 let mut in_support = vec![false; n_features];
347 let mut w = Array1::<F>::zeros(n_features);
348 let mut residual = y_work.clone();
349
350 for _step in 0..max_k {
351 if let Some(tol_val) = self.tol {
353 let res_norm_sq = residual.dot(&residual);
354 if res_norm_sq < tol_val {
355 break;
356 }
357 }
358
359 let mut best_j = None;
361 let mut best_corr = F::zero();
362 for (j, &is_in_support) in in_support.iter().enumerate() {
363 if is_in_support {
364 continue;
365 }
366 let corr = x_work.column(j).dot(&residual).abs();
367 if corr > best_corr {
368 best_corr = corr;
369 best_j = Some(j);
370 }
371 }
372
373 let j = match best_j {
374 Some(j) => j,
375 None => break,
376 };
377
378 support.push(j);
379 in_support[j] = true;
380
381 w = ols_active(&x_work, &y_work, &support, n_features)?;
383
384 residual = &y_work - x_work.dot(&w);
386 }
387
388 let intercept = if let (Some(xm), Some(ym)) = (&x_mean, &y_mean) {
389 *ym - xm.dot(&w)
390 } else {
391 F::zero()
392 };
393
394 Ok(FittedOMP {
395 coefficients: w,
396 intercept,
397 })
398 }
399}
400
401impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>> for FittedOMP<F> {
406 type Output = Array1<F>;
407 type Error = FerroError;
408
409 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
418 if x.ncols() != self.coefficients.len() {
419 return Err(FerroError::ShapeMismatch {
420 expected: vec![self.coefficients.len()],
421 actual: vec![x.ncols()],
422 context: "number of features must match fitted model".into(),
423 });
424 }
425 Ok(x.dot(&self.coefficients) + self.intercept)
426 }
427}
428
429impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F> for FittedOMP<F> {
430 fn coefficients(&self) -> &Array1<F> {
431 &self.coefficients
432 }
433
434 fn intercept(&self) -> F {
435 self.intercept
436 }
437}
438
439impl<F> PipelineEstimator<F> for OrthogonalMatchingPursuit<F>
440where
441 F: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
442{
443 fn fit_pipeline(
444 &self,
445 x: &Array2<F>,
446 y: &Array1<F>,
447 ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
448 let fitted = self.fit(x, y)?;
449 Ok(Box::new(fitted))
450 }
451}
452
453impl<F> FittedPipelineEstimator<F> for FittedOMP<F>
454where
455 F: Float + ScalarOperand + Send + Sync + 'static,
456{
457 fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
458 self.predict(x)
459 }
460}
461
// Unit tests for the OMP estimator: builder behaviour, input validation,
// fitting on small hand-made data sets, prediction, and pipeline integration.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// `new()` leaves both stopping criteria unset and enables the intercept.
    #[test]
    fn test_defaults() {
        let m = OrthogonalMatchingPursuit::<f64>::new();
        assert!(m.n_nonzero_coefs.is_none());
        assert!(m.tol.is_none());
        assert!(m.fit_intercept);
    }

    /// Each `with_*` builder stores the value it is given.
    #[test]
    fn test_builder() {
        let m = OrthogonalMatchingPursuit::<f64>::new()
            .with_n_nonzero_coefs(3)
            .with_tol(1e-4)
            .with_fit_intercept(false);
        assert_eq!(m.n_nonzero_coefs, Some(3));
        assert_relative_eq!(m.tol.unwrap(), 1e-4);
        assert!(!m.fit_intercept);
    }

    /// Fitting fails when `y` has a different length than `X` has rows.
    #[test]
    fn test_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];
        assert!(OrthogonalMatchingPursuit::<f64>::new()
            .with_n_nonzero_coefs(1)
            .fit(&x, &y)
            .is_err());
    }

    /// Fitting fails when neither `n_nonzero_coefs` nor `tol` is configured.
    #[test]
    fn test_no_stopping_criterion() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        assert!(OrthogonalMatchingPursuit::<f64>::new().fit(&x, &y).is_err());
    }

    /// Requesting more non-zero coefficients than features is rejected.
    #[test]
    fn test_n_nonzero_exceeds_features() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        assert!(OrthogonalMatchingPursuit::<f64>::new()
            .with_n_nonzero_coefs(5)
            .fit(&x, &y)
            .is_err());
    }

    /// Recovers slope 2 and intercept 1 from the exact line y = 2x + 1.
    #[test]
    fn test_simple_linear() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let fitted = OrthogonalMatchingPursuit::<f64>::new()
            .with_n_nonzero_coefs(1)
            .fit(&x, &y)
            .unwrap();
        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-6);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1e-6);
    }

    /// With a budget of one feature, exactly one coefficient is non-zero.
    #[test]
    fn test_sparsity() {
        // Columns two and three are scaled copies (0.1x and 0.01x) of the
        // first column; the target is twice the first column.
        let x = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 0.1, 0.01, 2.0, 0.2, 0.02, 3.0, 0.3, 0.03, 4.0, 0.4, 0.04,
                5.0, 0.5, 0.05, 6.0, 0.6, 0.06, 7.0, 0.7, 0.07, 8.0, 0.8, 0.08,
                9.0, 0.9, 0.09, 10.0, 1.0, 0.10,
            ],
        )
        .unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0];

        let fitted = OrthogonalMatchingPursuit::<f64>::new()
            .with_n_nonzero_coefs(1)
            .fit(&x, &y)
            .unwrap();
        let nonzero = fitted
            .coefficients()
            .iter()
            .filter(|&&c| c.abs() > 1e-10)
            .count();
        assert_eq!(nonzero, 1);
    }

    /// With only `tol` set, the fit still reproduces the exact line y = 2x.
    #[test]
    fn test_tol_stopping() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];
        let fitted = OrthogonalMatchingPursuit::<f64>::new()
            .with_tol(1e-10)
            .fit(&x, &y)
            .unwrap();
        let preds = fitted.predict(&x).unwrap();
        // Predictions on the training data should match the targets.
        for (pred, actual) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(pred, actual, epsilon = 1e-4);
        }
    }

    /// `predict` returns one value per input row.
    #[test]
    fn test_predict() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = OrthogonalMatchingPursuit::<f64>::new()
            .with_n_nonzero_coefs(1)
            .fit(&x, &y)
            .unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    /// `predict` rejects input whose column count differs from the fit.
    #[test]
    fn test_predict_feature_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let fitted = OrthogonalMatchingPursuit::<f64>::new()
            .with_n_nonzero_coefs(1)
            .fit(&x, &y)
            .unwrap();
        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    /// The coefficient vector has one entry per input feature.
    #[test]
    fn test_has_coefficients() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let fitted = OrthogonalMatchingPursuit::<f64>::new()
            .with_n_nonzero_coefs(2)
            .fit(&x, &y)
            .unwrap();
        assert_eq!(fitted.coefficients().len(), 2);
    }

    /// With `fit_intercept(false)` the fitted intercept is zero.
    #[test]
    fn test_no_intercept() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = OrthogonalMatchingPursuit::<f64>::new()
            .with_n_nonzero_coefs(1)
            .with_fit_intercept(false)
            .fit(&x, &y)
            .unwrap();
        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    /// The pipeline adapters fit and predict end to end.
    #[test]
    fn test_pipeline() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0];
        let model = OrthogonalMatchingPursuit::<f64>::new().with_n_nonzero_coefs(1);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    /// With two allowed features, the third (unused) feature stays small.
    #[test]
    fn test_multivariate_recovery() {
        let x = Array2::from_shape_vec(
            (5, 3),
            vec![
                1.0, 0.0, 0.5, 0.0, 1.0, 0.3, 1.0, 1.0, 0.1, 2.0, 0.0, 0.8, 0.0, 2.0, 0.4,
            ],
        )
        .unwrap();
        // The target equals 1 * column 0 + 3 * column 1; column 2 plays no part.
        let y = array![1.0, 3.0, 4.0, 2.0, 6.0];
        let fitted = OrthogonalMatchingPursuit::<f64>::new()
            .with_n_nonzero_coefs(2)
            .fit(&x, &y)
            .unwrap();

        assert!(
            fitted.coefficients()[2].abs() < 0.5,
            "irrelevant feature should have near-zero coefficient"
        );
    }
}