1use std::marker::PhantomData;
4
5use scirs2_core::ndarray::{Array1, Array2, Axis};
6use scirs2_linalg::compat::ArrayLinalgExt;
7use sklears_core::{
9 error::{validate, Result, SklearsError},
10 traits::{Estimator, Fit, Predict, Score, Trained, Untrained},
11 types::Float,
12};
13
/// Hyper-parameters for [`OrthogonalMatchingPursuit`].
#[derive(Debug, Clone)]
pub struct OrthogonalMatchingPursuitConfig {
    /// Upper bound on the number of features the greedy selection may pick.
    /// During fitting it is additionally capped by the number of features and
    /// the number of samples; `None` leaves only that cap in place.
    pub n_nonzero_coefs: Option<usize>,
    /// Residual Euclidean-norm threshold below which selection stops.
    /// Fitting falls back to `1e-3` when this is `None`.
    pub tol: Option<Float>,
    /// Whether to estimate an intercept. When `false`, the target is not
    /// centered and the fitted model stores no intercept.
    pub fit_intercept: bool,
    /// Whether each feature column is scaled to unit Euclidean norm before
    /// selection; the fitted coefficients are rescaled back afterwards.
    pub normalize: bool,
}
26
27impl Default for OrthogonalMatchingPursuitConfig {
28 fn default() -> Self {
29 Self {
30 n_nonzero_coefs: None,
31 tol: None,
32 fit_intercept: true,
33 normalize: true,
34 }
35 }
36}
37
/// Orthogonal Matching Pursuit regressor.
///
/// The `State` type parameter is a compile-time marker (`Untrained` by
/// default); fitting produces a value whose marker is `Trained`.
#[derive(Debug, Clone)]
pub struct OrthogonalMatchingPursuit<State = Untrained> {
    /// Hyper-parameters used when fitting.
    config: OrthogonalMatchingPursuitConfig,
    /// Zero-sized marker carrying the `State` type parameter.
    state: PhantomData<State>,
    /// Fitted coefficient vector (one entry per feature); `None` before fitting.
    coef_: Option<Array1<Float>>,
    /// Fitted intercept; `None` before fitting and when `fit_intercept` is `false`.
    intercept_: Option<Float>,
    /// Number of feature columns seen during fitting; `None` before fitting.
    n_features_: Option<usize>,
    /// Number of selection iterations performed during fitting; `None` before fitting.
    n_iter_: Option<usize>,
}
49
50impl OrthogonalMatchingPursuit<Untrained> {
51 pub fn new() -> Self {
53 Self {
54 config: OrthogonalMatchingPursuitConfig::default(),
55 state: PhantomData,
56 coef_: None,
57 intercept_: None,
58 n_features_: None,
59 n_iter_: None,
60 }
61 }
62
63 pub fn n_nonzero_coefs(mut self, n_nonzero_coefs: usize) -> Self {
65 self.config.n_nonzero_coefs = Some(n_nonzero_coefs);
66 self
67 }
68
69 pub fn tol(mut self, tol: Float) -> Self {
71 self.config.tol = Some(tol);
72 self
73 }
74
75 pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
77 self.config.fit_intercept = fit_intercept;
78 self
79 }
80
81 pub fn normalize(mut self, normalize: bool) -> Self {
83 self.config.normalize = normalize;
84 self
85 }
86}
87
88impl Default for OrthogonalMatchingPursuit<Untrained> {
89 fn default() -> Self {
90 Self::new()
91 }
92}
93
94impl Estimator for OrthogonalMatchingPursuit<Untrained> {
95 type Float = Float;
96 type Config = OrthogonalMatchingPursuitConfig;
97 type Error = SklearsError;
98
99 fn config(&self) -> &Self::Config {
100 &self.config
101 }
102}
103
104impl Fit<Array2<Float>, Array1<Float>> for OrthogonalMatchingPursuit<Untrained> {
105 type Fitted = OrthogonalMatchingPursuit<Trained>;
106
107 fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
108 validate::check_consistent_length(x, y)?;
110
111 let n_samples = x.nrows();
112 let n_features = x.ncols();
113
114 let max_features = if let Some(n) = self.config.n_nonzero_coefs {
116 n.min(n_features).min(n_samples)
117 } else if self.config.tol.is_some() {
118 n_features.min(n_samples)
119 } else {
120 n_features.min(n_samples)
122 };
123
124 let tol = self.config.tol.unwrap_or(1e-3);
125
126 let x_mean = x.mean_axis(Axis(0)).ok_or_else(|| {
128 SklearsError::NumericalError(
129 "mean computation should succeed for non-empty array".into(),
130 )
131 })?;
132 let mut x_centered = x - &x_mean;
133
134 let y_mean = if self.config.fit_intercept {
135 y.mean().unwrap_or(0.0)
136 } else {
137 0.0
138 };
139 let y_centered = y - y_mean;
140
141 let x_scale = if self.config.normalize {
143 let mut scale = Array1::zeros(n_features);
144 for j in 0..n_features {
145 let col = x_centered.column(j);
146 scale[j] = col.dot(&col).sqrt();
147 if scale[j] > Float::EPSILON {
148 x_centered.column_mut(j).mapv_inplace(|x| x / scale[j]);
149 } else {
150 scale[j] = 1.0;
151 }
152 }
153 scale
154 } else {
155 Array1::ones(n_features)
156 };
157
158 let mut coef = Array1::zeros(n_features);
160 let mut active: Vec<usize> = Vec::new();
161 let mut residual = y_centered.clone();
162 let mut n_iter = 0;
163
164 for _ in 0..max_features {
166 let correlations = x_centered.t().dot(&residual);
168
169 let mut max_corr = 0.0;
171 let mut best_idx = 0;
172
173 for j in 0..n_features {
174 if !active.contains(&j) {
175 let corr = correlations[j].abs();
176 if corr > max_corr {
177 max_corr = corr;
178 best_idx = j;
179 }
180 }
181 }
182
183 let residual_norm = residual.dot(&residual).sqrt();
185 if residual_norm < tol {
186 break;
187 }
188
189 active.push(best_idx);
191 n_iter += 1;
192
193 let n_active = active.len();
195 let mut x_active = Array2::zeros((n_samples, n_active));
196 for (i, &j) in active.iter().enumerate() {
197 x_active.column_mut(i).assign(&x_centered.column(j));
198 }
199
200 let gram = x_active.t().dot(&x_active);
202 let x_active_t_y = x_active.t().dot(&y_centered);
203
204 let mut gram_reg = gram.clone();
206 for i in 0..n_active {
207 gram_reg[[i, i]] += 1e-10;
208 }
209
210 let coef_active = &gram_reg
211 .solve(&x_active_t_y)
212 .map_err(|e| SklearsError::NumericalError(format!("Failed to solve: {}", e)))?;
213
214 coef.fill(0.0);
216 for (i, &j) in active.iter().enumerate() {
217 coef[j] = coef_active[i];
218 }
219
220 residual = &y_centered - &x_centered.dot(&coef);
222 }
223
224 if self.config.normalize {
226 for j in 0..n_features {
227 if x_scale[j] > 0.0 {
228 coef[j] /= x_scale[j];
229 }
230 }
231 }
232
233 let intercept = if self.config.fit_intercept {
235 Some(y_mean - x_mean.dot(&coef))
236 } else {
237 None
238 };
239
240 Ok(OrthogonalMatchingPursuit {
241 config: self.config,
242 state: PhantomData,
243 coef_: Some(coef),
244 intercept_: intercept,
245 n_features_: Some(n_features),
246 n_iter_: Some(n_iter),
247 })
248 }
249}
250
251impl OrthogonalMatchingPursuit<Trained> {
252 pub fn coef(&self) -> &Array1<Float> {
254 self.coef_.as_ref().expect("Model is trained")
255 }
256
257 pub fn intercept(&self) -> Option<Float> {
259 self.intercept_
260 }
261
262 pub fn n_iter(&self) -> usize {
264 self.n_iter_.expect("Model is trained")
265 }
266}
267
268impl Predict<Array2<Float>, Array1<Float>> for OrthogonalMatchingPursuit<Trained> {
269 fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
270 let n_features = self.n_features_.expect("Model is trained");
271 validate::check_n_features(x, n_features)?;
272
273 let coef = self.coef_.as_ref().expect("Model is trained");
274 let mut predictions = x.dot(coef);
275
276 if let Some(intercept) = self.intercept_ {
277 predictions += intercept;
278 }
279
280 Ok(predictions)
281 }
282}
283
284impl Score<Array2<Float>, Array1<Float>> for OrthogonalMatchingPursuit<Trained> {
285 type Float = Float;
286
287 fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
288 let predictions = self.predict(x)?;
289
290 let ss_res = (&predictions - y).mapv(|x| x * x).sum();
292 let y_mean = y.mean().unwrap_or(0.0);
293 let ss_tot = y.mapv(|yi| (yi - y_mean).powi(2)).sum();
294
295 if ss_tot == 0.0 {
296 return Ok(1.0);
297 }
298
299 Ok(1.0 - (ss_res / ss_tot))
300 }
301}
302
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Two disjoint features: each coefficient is recovered and the
    /// training targets are reproduced.
    #[test]
    fn test_omp_simple() {
        let features = array![
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [2.0, 0.0],
            [0.0, 2.0],
        ];
        let targets = array![2.0, 3.0, 2.0, 3.0, 4.0, 6.0];

        let fitted = OrthogonalMatchingPursuit::new()
            .fit_intercept(false)
            .normalize(false)
            .fit(&features, &targets)
            .expect("operation should succeed");

        let weights = fitted.coef();
        assert_abs_diff_eq!(weights[0], 2.0, epsilon = 1e-5);
        assert_abs_diff_eq!(weights[1], 3.0, epsilon = 1e-5);

        let predicted = fitted
            .predict(&features)
            .expect("prediction should succeed");
        assert_eq!(predicted.len(), targets.len());
        for (got, want) in predicted.iter().zip(targets.iter()) {
            assert_abs_diff_eq!(*got, *want, epsilon = 1e-5);
        }
    }

    /// With a budget of one coefficient, only the dominant feature is selected.
    #[test]
    fn test_omp_max_features() {
        let features = array![
            [1.0, 0.1, 0.01],
            [2.0, 0.2, 0.02],
            [3.0, 0.3, 0.03],
            [4.0, 0.4, 0.04],
            [5.0, 0.5, 0.05],
            [6.0, 0.6, 0.06],
        ];
        let targets = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0];

        let fitted = OrthogonalMatchingPursuit::new()
            .n_nonzero_coefs(1)
            .fit_intercept(false)
            .normalize(false)
            .fit(&features, &targets)
            .expect("operation should succeed");

        let weights = fitted.coef();
        let selected = weights.iter().filter(|c| c.abs() > 1e-10).count();
        assert_eq!(selected, 1);

        assert_abs_diff_eq!(weights[0], 2.0, epsilon = 1e-3);

        assert_eq!(fitted.n_iter(), 1);
    }

    /// A loose tolerance still yields a good fit on nearly linear data.
    #[test]
    fn test_omp_tolerance() {
        let features = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let targets = array![2.1, 3.9, 6.05, 7.95, 10.1];

        let fitted = OrthogonalMatchingPursuit::new()
            .tol(0.5)
            .fit_intercept(false)
            .fit(&features, &targets)
            .expect("operation should succeed");

        // Prediction itself must succeed; its values are checked through the score.
        fitted
            .predict(&features)
            .expect("prediction should succeed");

        let r_squared = fitted
            .score(&features, &targets)
            .expect("scoring should succeed");
        assert!(r_squared > 0.95);
    }

    /// y = 2x + 1: both the slope and the intercept are recovered.
    #[test]
    fn test_omp_with_intercept() {
        let features = array![[1.0], [2.0], [3.0], [4.0]];
        let targets = array![3.0, 5.0, 7.0, 9.0];

        let fitted = OrthogonalMatchingPursuit::new()
            .fit_intercept(true)
            .fit(&features, &targets)
            .expect("operation should succeed");

        let slope = fitted.coef()[0];
        let offset = fitted
            .intercept()
            .expect("intercept should be available");

        assert_abs_diff_eq!(slope, 2.0, epsilon = 1e-5);
        assert_abs_diff_eq!(offset, 1.0, epsilon = 1e-5);
    }

    /// A 3-sparse signal over 10 deterministic features is recovered on its
    /// exact support.
    #[test]
    fn test_omp_sparse_recovery() {
        let n_samples = 20;
        let n_features = 10;

        // Deterministic pseudo-random design matrix with entries in [-1.0, 0.9].
        let mut design = Array2::<Float>::zeros((n_samples, n_features));
        for row in 0..n_samples {
            for col in 0..n_features {
                design[[row, col]] = ((row * 7 + col * 13) % 20) as Float / 10.0 - 1.0;
            }
        }

        let support: [(usize, Float); 3] = [(1, 2.0), (4, -1.5), (7, 1.0)];
        let mut true_coef = Array1::<Float>::zeros(n_features);
        for &(idx, value) in support.iter() {
            true_coef[idx] = value;
        }

        let targets = design.dot(&true_coef);

        let fitted = OrthogonalMatchingPursuit::new()
            .n_nonzero_coefs(3)
            .fit_intercept(false)
            .normalize(true)
            .fit(&design, &targets)
            .expect("operation should succeed");

        let weights = fitted.coef();

        for &(idx, _) in support.iter() {
            assert!(
                weights[idx].abs() > 0.1,
                "Failed to recover non-zero coefficient at index {}",
                idx
            );
        }

        let selected = weights.iter().filter(|c| c.abs() > 1e-10).count();
        assert_eq!(selected, 3);
    }
}