// sklears_linear/omp.rs
1//! Orthogonal Matching Pursuit (OMP) implementation
2
3use std::marker::PhantomData;
4
5use scirs2_core::ndarray::{Array1, Array2, Axis};
6use scirs2_linalg::compat::ArrayLinalgExt;
7// Removed SVD import - using ArrayLinalgExt for both solve and svd methods
8use sklears_core::{
9    error::{validate, Result, SklearsError},
10    traits::{Estimator, Fit, Predict, Score, Trained, Untrained},
11    types::Float,
12};
13
/// Configuration for OrthogonalMatchingPursuit
#[derive(Debug, Clone)]
pub struct OrthogonalMatchingPursuitConfig {
    /// Maximum number of non-zero coefficients in the solution.
    /// `None` lets the solver select up to `min(n_features, n_samples)`.
    pub n_nonzero_coefs: Option<usize>,
    /// Tolerance on the residual norm used as a stopping criterion.
    /// `None` falls back to the implementation default inside `fit`.
    pub tol: Option<Float>,
    /// Whether to fit the intercept
    pub fit_intercept: bool,
    /// Whether to normalize/standardize features before fitting
    pub normalize: bool,
}
26
27impl Default for OrthogonalMatchingPursuitConfig {
28    fn default() -> Self {
29        Self {
30            n_nonzero_coefs: None,
31            tol: None,
32            fit_intercept: true,
33            normalize: true,
34        }
35    }
36}
37
/// Orthogonal Matching Pursuit model
///
/// Uses a typestate parameter (`Untrained` / `Trained`) so that accessors
/// such as `coef()` are only callable after a successful `fit`.
#[derive(Debug, Clone)]
pub struct OrthogonalMatchingPursuit<State = Untrained> {
    config: OrthogonalMatchingPursuitConfig,
    // Zero-sized typestate marker; carries no runtime data.
    state: PhantomData<State>,
    // Trained state fields — all `None` until `fit` succeeds.
    coef_: Option<Array1<Float>>,
    intercept_: Option<Float>,
    n_features_: Option<usize>,
    n_iter_: Option<usize>,
}
49
50impl OrthogonalMatchingPursuit<Untrained> {
51    /// Create a new OMP model
52    pub fn new() -> Self {
53        Self {
54            config: OrthogonalMatchingPursuitConfig::default(),
55            state: PhantomData,
56            coef_: None,
57            intercept_: None,
58            n_features_: None,
59            n_iter_: None,
60        }
61    }
62
63    /// Set the maximum number of non-zero coefficients
64    pub fn n_nonzero_coefs(mut self, n_nonzero_coefs: usize) -> Self {
65        self.config.n_nonzero_coefs = Some(n_nonzero_coefs);
66        self
67    }
68
69    /// Set the tolerance for the residual
70    pub fn tol(mut self, tol: Float) -> Self {
71        self.config.tol = Some(tol);
72        self
73    }
74
75    /// Set whether to fit intercept
76    pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
77        self.config.fit_intercept = fit_intercept;
78        self
79    }
80
81    /// Set whether to normalize features
82    pub fn normalize(mut self, normalize: bool) -> Self {
83        self.config.normalize = normalize;
84        self
85    }
86}
87
88impl Default for OrthogonalMatchingPursuit<Untrained> {
89    fn default() -> Self {
90        Self::new()
91    }
92}
93
impl Estimator for OrthogonalMatchingPursuit<Untrained> {
    type Float = Float;
    type Config = OrthogonalMatchingPursuitConfig;
    type Error = SklearsError;

    /// Expose the model configuration to the generic estimator machinery.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
103
104impl Fit<Array2<Float>, Array1<Float>> for OrthogonalMatchingPursuit<Untrained> {
105    type Fitted = OrthogonalMatchingPursuit<Trained>;
106
107    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
108        // Validate inputs
109        validate::check_consistent_length(x, y)?;
110
111        let n_samples = x.nrows();
112        let n_features = x.ncols();
113
114        // Determine stopping criterion
115        let max_features = if let Some(n) = self.config.n_nonzero_coefs {
116            n.min(n_features).min(n_samples)
117        } else if self.config.tol.is_some() {
118            n_features.min(n_samples)
119        } else {
120            // Default: min(n_features, n_samples)
121            n_features.min(n_samples)
122        };
123
124        let tol = self.config.tol.unwrap_or(1e-3);
125
126        // Center X and y
127        let x_mean = x.mean_axis(Axis(0)).ok_or_else(|| {
128            SklearsError::NumericalError(
129                "mean computation should succeed for non-empty array".into(),
130            )
131        })?;
132        let mut x_centered = x - &x_mean;
133
134        let y_mean = if self.config.fit_intercept {
135            y.mean().unwrap_or(0.0)
136        } else {
137            0.0
138        };
139        let y_centered = y - y_mean;
140
141        // Normalize X if requested
142        let x_scale = if self.config.normalize {
143            let mut scale = Array1::zeros(n_features);
144            for j in 0..n_features {
145                let col = x_centered.column(j);
146                scale[j] = col.dot(&col).sqrt();
147                if scale[j] > Float::EPSILON {
148                    x_centered.column_mut(j).mapv_inplace(|x| x / scale[j]);
149                } else {
150                    scale[j] = 1.0;
151                }
152            }
153            scale
154        } else {
155            Array1::ones(n_features)
156        };
157
158        // Initialize OMP algorithm
159        let mut coef = Array1::zeros(n_features);
160        let mut active: Vec<usize> = Vec::new();
161        let mut residual = y_centered.clone();
162        let mut n_iter = 0;
163
164        // Main OMP loop
165        for _ in 0..max_features {
166            // Compute correlations with residual
167            let correlations = x_centered.t().dot(&residual);
168
169            // Find the most correlated feature not yet selected
170            let mut max_corr = 0.0;
171            let mut best_idx = 0;
172
173            for j in 0..n_features {
174                if !active.contains(&j) {
175                    let corr = correlations[j].abs();
176                    if corr > max_corr {
177                        max_corr = corr;
178                        best_idx = j;
179                    }
180                }
181            }
182
183            // Check stopping criterion
184            let residual_norm = residual.dot(&residual).sqrt();
185            if residual_norm < tol {
186                break;
187            }
188
189            // Add the best feature to active set
190            active.push(best_idx);
191            n_iter += 1;
192
193            // Solve least squares problem on active set
194            let n_active = active.len();
195            let mut x_active = Array2::zeros((n_samples, n_active));
196            for (i, &j) in active.iter().enumerate() {
197                x_active.column_mut(i).assign(&x_centered.column(j));
198            }
199
200            // Solve normal equations: (X_active^T X_active) coef_active = X_active^T y
201            let gram = x_active.t().dot(&x_active);
202            let x_active_t_y = x_active.t().dot(&y_centered);
203
204            // Add small regularization to avoid singular matrix
205            let mut gram_reg = gram.clone();
206            for i in 0..n_active {
207                gram_reg[[i, i]] += 1e-10;
208            }
209
210            let coef_active = &gram_reg
211                .solve(&x_active_t_y)
212                .map_err(|e| SklearsError::NumericalError(format!("Failed to solve: {}", e)))?;
213
214            // Update full coefficient vector
215            coef.fill(0.0);
216            for (i, &j) in active.iter().enumerate() {
217                coef[j] = coef_active[i];
218            }
219
220            // Update residual
221            residual = &y_centered - &x_centered.dot(&coef);
222        }
223
224        // Rescale coefficients if we normalized
225        if self.config.normalize {
226            for j in 0..n_features {
227                if x_scale[j] > 0.0 {
228                    coef[j] /= x_scale[j];
229                }
230            }
231        }
232
233        // Compute intercept if needed
234        let intercept = if self.config.fit_intercept {
235            Some(y_mean - x_mean.dot(&coef))
236        } else {
237            None
238        };
239
240        Ok(OrthogonalMatchingPursuit {
241            config: self.config,
242            state: PhantomData,
243            coef_: Some(coef),
244            intercept_: intercept,
245            n_features_: Some(n_features),
246            n_iter_: Some(n_iter),
247        })
248    }
249}
250
impl OrthogonalMatchingPursuit<Trained> {
    /// Get the fitted coefficient vector (one entry per feature).
    ///
    /// The `Trained` typestate guarantees the field is populated, so the
    /// internal `expect` cannot fire for values produced by `fit`.
    pub fn coef(&self) -> &Array1<Float> {
        self.coef_.as_ref().expect("Model is trained")
    }

    /// Get the fitted intercept, or `None` if `fit_intercept` was disabled.
    pub fn intercept(&self) -> Option<Float> {
        self.intercept_
    }

    /// Get the number of OMP iterations (features selected) during fitting.
    pub fn n_iter(&self) -> usize {
        self.n_iter_.expect("Model is trained")
    }
}
267
268impl Predict<Array2<Float>, Array1<Float>> for OrthogonalMatchingPursuit<Trained> {
269    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
270        let n_features = self.n_features_.expect("Model is trained");
271        validate::check_n_features(x, n_features)?;
272
273        let coef = self.coef_.as_ref().expect("Model is trained");
274        let mut predictions = x.dot(coef);
275
276        if let Some(intercept) = self.intercept_ {
277            predictions += intercept;
278        }
279
280        Ok(predictions)
281    }
282}
283
284impl Score<Array2<Float>, Array1<Float>> for OrthogonalMatchingPursuit<Trained> {
285    type Float = Float;
286
287    fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
288        let predictions = self.predict(x)?;
289
290        // Calculate R² score
291        let ss_res = (&predictions - y).mapv(|x| x * x).sum();
292        let y_mean = y.mean().unwrap_or(0.0);
293        let ss_tot = y.mapv(|yi| (yi - y_mean).powi(2)).sum();
294
295        if ss_tot == 0.0 {
296            return Ok(1.0);
297        }
298
299        Ok(1.0 - (ss_res / ss_tot))
300    }
301}
302
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_omp_simple() {
        // Simple test with orthogonal features: the exact noiseless model
        // y = 2*x1 + 3*x2 should be recovered to high precision.
        let x = array![
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [2.0, 0.0],
            [0.0, 2.0],
        ];
        let y = array![2.0, 3.0, 2.0, 3.0, 4.0, 6.0]; // y = 2*x1 + 3*x2

        let model = OrthogonalMatchingPursuit::new()
            .fit_intercept(false)
            .normalize(false)
            .fit(&x, &y)
            .expect("operation should succeed");

        // Should recover the true coefficients
        let coef = model.coef();
        assert_abs_diff_eq!(coef[0], 2.0, epsilon = 1e-5);
        assert_abs_diff_eq!(coef[1], 3.0, epsilon = 1e-5);

        // Predictions should be perfect
        let predictions = model.predict(&x).expect("prediction should succeed");
        for i in 0..y.len() {
            assert_abs_diff_eq!(predictions[i], y[i], epsilon = 1e-5);
        }
    }

    #[test]
    fn test_omp_max_features() {
        // Test limiting the number of selected features: the columns are
        // collinear (x2 = 0.1*x1, x3 = 0.01*x1), so a single feature is
        // enough to represent y exactly.
        let x = array![
            [1.0, 0.1, 0.01],
            [2.0, 0.2, 0.02],
            [3.0, 0.3, 0.03],
            [4.0, 0.4, 0.04],
            [5.0, 0.5, 0.05],
            [6.0, 0.6, 0.06],
        ];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0]; // y = 2*x1

        let model = OrthogonalMatchingPursuit::new()
            .n_nonzero_coefs(1)
            .fit_intercept(false)
            .normalize(false)
            .fit(&x, &y)
            .expect("operation should succeed");

        // Exactly one coefficient should be non-zero.
        let coef = model.coef();
        let n_nonzero = coef.iter().filter(|&&c| c.abs() > 1e-10).count();
        assert_eq!(n_nonzero, 1);

        // First coefficient should be selected and close to 2.0
        assert_abs_diff_eq!(coef[0], 2.0, epsilon = 1e-3);

        // Check that we ran exactly 1 iteration
        assert_eq!(model.n_iter(), 1);
    }

    #[test]
    fn test_omp_tolerance() {
        // Test stopping based on the residual-norm tolerance.
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let y = array![2.1, 3.9, 6.05, 7.95, 10.1]; // y ≈ 2x with small noise

        let model = OrthogonalMatchingPursuit::new()
            .tol(0.5) // Relatively high tolerance
            .fit_intercept(false)
            .fit(&x, &y)
            .expect("operation should succeed");

        // Should get a reasonable approximation (high R² on the train data).
        let _predictions = model.predict(&x).expect("prediction should succeed");
        let r2 = model.score(&x, &y).expect("scoring should succeed");
        assert!(r2 > 0.95);
    }

    #[test]
    fn test_omp_with_intercept() {
        // With intercept fitting enabled, both slope and offset of
        // y = 2x + 1 should be recovered.
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![3.0, 5.0, 7.0, 9.0]; // y = 2x + 1

        let model = OrthogonalMatchingPursuit::new()
            .fit_intercept(true)
            .fit(&x, &y)
            .expect("operation should succeed");

        assert_abs_diff_eq!(model.coef()[0], 2.0, epsilon = 1e-5);
        assert_abs_diff_eq!(
            model.intercept().expect("intercept should be available"),
            1.0,
            epsilon = 1e-5
        );
    }

    #[test]
    fn test_omp_sparse_recovery() {
        // Create a sparse signal recovery problem: a noiseless target built
        // from 3 of 10 deterministic pseudo-random columns.
        let n_samples = 20;
        let n_features = 10;
        let mut x = Array2::zeros((n_samples, n_features));
        let mut true_coef = Array1::zeros(n_features);

        // Generate random-like data deterministically
        for i in 0..n_samples {
            for j in 0..n_features {
                x[[i, j]] = ((i * 7 + j * 13) % 20) as Float / 10.0 - 1.0;
            }
        }

        // True coefficients are sparse (only 3 non-zero)
        true_coef[1] = 2.0;
        true_coef[4] = -1.5;
        true_coef[7] = 1.0;

        let y = x.dot(&true_coef);

        let model = OrthogonalMatchingPursuit::new()
            .n_nonzero_coefs(3)
            .fit_intercept(false)
            .normalize(true)
            .fit(&x, &y)
            .expect("operation should succeed");

        let coef = model.coef();

        // Should recover the support (non-zero indices)
        for j in 0..n_features {
            if true_coef[j] != 0.0 {
                assert!(
                    coef[j].abs() > 0.1,
                    "Failed to recover non-zero coefficient at index {}",
                    j
                );
            }
        }

        // Should have exactly 3 non-zero coefficients
        let n_nonzero = coef.iter().filter(|&&c| c.abs() > 1e-10).count();
        assert_eq!(n_nonzero, 3);
    }
}