sklears_feature_selection/optimization.rs

//! Optimization-based feature selection algorithms
//!
//! This module provides optimization-driven feature selection methods:
//! convex (subgradient) optimization, proximal gradient descent, ADMM,
//! a semidefinite programming relaxation, and approximate integer programming.
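//!
//! A minimal usage sketch (the import paths below assume this crate's
//! conventional module layout and may need adjusting):
//!
//! ```ignore
//! use sklears_feature_selection::optimization::ProximalGradientSelector;
//! use sklears_core::traits::{Fit, Transform};
//!
//! // `x`: (n_samples, n_features) Array2<Float>, `y`: Array1<Float>
//! let selector = ProximalGradientSelector::new().k(5).regularization(0.1);
//! let trained = selector.fit(&x, &y)?;
//! let x_reduced = trained.transform(&x)?; // keeps the 5 selected columns
//! ```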

use crate::base::{FeatureSelector, SelectorMixin};
use scirs2_core::ndarray::{Array1, Array2, Axis};
use sklears_core::{
    error::{validate, Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Trained, Transform, Untrained},
    types::Float,
};
use std::marker::PhantomData;

/// Convex optimization-based feature selection
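///
/// Minimizes the Lasso-style objective
///
/// ```text
/// (1 / 2n) * ||X w - y||^2 + lambda * ||w||_1,   subject to w >= 0,
/// ```
///
/// via projected subgradient descent, then keeps the `k` features with the
/// largest learned weights.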
#[derive(Debug, Clone)]
pub struct ConvexFeatureSelector<State = Untrained> {
    k: usize,
    regularization: Float,
    max_iter: usize,
    tolerance: Float,
    state: PhantomData<State>,
    // Trained state
    weights_: Option<Array1<Float>>,
    selected_features_: Option<Vec<usize>>,
    n_features_: Option<usize>,
    objective_values_: Option<Vec<Float>>,
}

impl Default for ConvexFeatureSelector<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl ConvexFeatureSelector<Untrained> {
    /// Create a new convex feature selector
    pub fn new() -> Self {
        Self {
            k: 10,
            regularization: 1.0,
            max_iter: 1000,
            tolerance: 1e-6,
            state: PhantomData,
            weights_: None,
            selected_features_: None,
            n_features_: None,
            objective_values_: None,
        }
    }

    /// Set the number of features to select
    pub fn k(mut self, k: usize) -> Self {
        self.k = k;
        self
    }

    /// Set the regularization parameter
    pub fn regularization(mut self, regularization: Float) -> Self {
        self.regularization = regularization;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance
    pub fn tolerance(mut self, tolerance: Float) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Solve the convex optimization problem for feature selection
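    ///
    /// Each iteration applies a projected subgradient step with the
    /// diminishing step size `eta_t = 0.01 / (t + 1)`:
    ///
    /// ```text
    /// w <- max(0, w - eta_t * (X^T (X w - y) / n + lambda * sign(w)))
    /// ```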
    fn solve_convex_optimization(
        &self,
        features: &Array2<Float>,
        target: &Array1<Float>,
    ) -> SklResult<(Array1<Float>, Vec<Float>)> {
        let n_features = features.ncols();
        let n_samples = features.nrows();

        // Initialize weights uniformly
        let mut weights = Array1::from_elem(n_features, 1.0 / n_features as Float);
        let mut objective_values = Vec::new();

        // Projected subgradient descent on the L1-regularized objective
        for iter in 0..self.max_iter {
            // Compute predictions
            let predictions = features.dot(&weights);

            // Compute residuals
            let residuals = &predictions - target;

            // Gradient of the smooth least-squares term
            let data_gradient = features.t().dot(&residuals) / n_samples as Float;

            // L1 regularization subgradient
            let reg_gradient = weights.mapv(|w| {
                if w > 0.0 {
                    self.regularization
                } else if w < 0.0 {
                    -self.regularization
                } else {
                    0.0 // Subgradient at 0
                }
            });

            let gradient = data_gradient + reg_gradient;

            // Diminishing step size, as required for subgradient convergence
            let step_size = 0.01 / (iter + 1) as Float;

            // Update weights
            let new_weights = &weights - step_size * &gradient;

            // Apply non-negativity constraint (for simplicity)
            let new_weights = new_weights.mapv(|w| w.max(0.0));

            // Compute objective value at the current (pre-update) weights
            let data_term = residuals.mapv(|r| r * r).sum() / (2.0 * n_samples as Float);
            let reg_term = self.regularization * weights.mapv(|w| w.abs()).sum();
            let objective = data_term + reg_term;
            objective_values.push(objective);

            // Check convergence
            let weight_diff = (&new_weights - &weights).mapv(|d| d.abs()).sum();
            if weight_diff < self.tolerance {
                break;
            }

            weights = new_weights;
        }

        Ok((weights, objective_values))
    }
}

impl Estimator for ConvexFeatureSelector<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for ConvexFeatureSelector<Untrained> {
    type Fitted = ConvexFeatureSelector<Trained>;

    fn fit(self, features: &Array2<Float>, target: &Array1<Float>) -> SklResult<Self::Fitted> {
        let n_features = features.ncols();
        if n_features == 0 {
            return Err(SklearsError::InvalidInput(
                "No features provided".to_string(),
            ));
        }

        if self.k > n_features {
            return Err(SklearsError::InvalidInput(format!(
                "k ({}) cannot be greater than number of features ({})",
                self.k, n_features
            )));
        }

        // Solve convex optimization
        let (weights, objective_values) = self.solve_convex_optimization(features, target)?;

        // Select top k features based on weights
        let mut feature_indices: Vec<usize> = (0..n_features).collect();
        feature_indices.sort_by(|&a, &b| {
            weights[b]
                .partial_cmp(&weights[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let selected_features = feature_indices.into_iter().take(self.k).collect();

        Ok(ConvexFeatureSelector {
            k: self.k,
            regularization: self.regularization,
            max_iter: self.max_iter,
            tolerance: self.tolerance,
            state: PhantomData,
            weights_: Some(weights),
            selected_features_: Some(selected_features),
            n_features_: Some(n_features),
            objective_values_: Some(objective_values),
        })
    }
}

impl Transform<Array2<Float>> for ConvexFeatureSelector<Trained> {
    fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
        validate::check_n_features(x, self.n_features_.unwrap())?;

        let selected_features = self.selected_features_.as_ref().unwrap();
        let n_samples = x.nrows();
        let n_selected = selected_features.len();
        let mut x_new = Array2::zeros((n_samples, n_selected));

        for (new_idx, &old_idx) in selected_features.iter().enumerate() {
            x_new.column_mut(new_idx).assign(&x.column(old_idx));
        }

        Ok(x_new)
    }
}

impl SelectorMixin for ConvexFeatureSelector<Trained> {
    fn get_support(&self) -> SklResult<Array1<bool>> {
        let n_features = self.n_features_.unwrap();
        let selected_features = self.selected_features_.as_ref().unwrap();
        let mut support = Array1::from_elem(n_features, false);

        for &idx in selected_features {
            support[idx] = true;
        }

        Ok(support)
    }

    fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
        let selected_features = self.selected_features_.as_ref().unwrap();
        Ok(indices
            .iter()
            .filter_map(|&idx| selected_features.iter().position(|&f| f == idx))
            .collect())
    }
}

impl FeatureSelector for ConvexFeatureSelector<Trained> {
    fn selected_features(&self) -> &Vec<usize> {
        self.selected_features_.as_ref().unwrap()
    }
}

impl ConvexFeatureSelector<Trained> {
    /// Get the learned weights
    pub fn weights(&self) -> &Array1<Float> {
        self.weights_.as_ref().unwrap()
    }

    /// Get the objective values during optimization
    pub fn objective_values(&self) -> &[Float] {
        self.objective_values_.as_ref().unwrap()
    }

    /// Get the number of selected features
    pub fn n_features_out(&self) -> usize {
        self.selected_features_.as_ref().unwrap().len()
    }
}

/// Proximal gradient method for feature selection
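///
/// Implements ISTA (iterative shrinkage-thresholding): each iteration takes a
/// gradient step on the smooth least-squares term, then applies soft
/// thresholding, the proximal operator of the L1 penalty. Unlike
/// [`ConvexFeatureSelector`], weights may be negative, so features are ranked
/// by absolute weight.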
#[derive(Debug, Clone)]
pub struct ProximalGradientSelector<State = Untrained> {
    k: usize,
    regularization: Float,
    max_iter: usize,
    tolerance: Float,
    step_size: Float,
    state: PhantomData<State>,
    // Trained state
    weights_: Option<Array1<Float>>,
    selected_features_: Option<Vec<usize>>,
    n_features_: Option<usize>,
    objective_values_: Option<Vec<Float>>,
}

impl Default for ProximalGradientSelector<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl ProximalGradientSelector<Untrained> {
    /// Create a new proximal gradient selector
    pub fn new() -> Self {
        Self {
            k: 10,
            regularization: 1.0,
            max_iter: 1000,
            tolerance: 1e-6,
            step_size: 0.01,
            state: PhantomData,
            weights_: None,
            selected_features_: None,
            n_features_: None,
            objective_values_: None,
        }
    }

    /// Set the number of features to select
    pub fn k(mut self, k: usize) -> Self {
        self.k = k;
        self
    }

    /// Set the regularization parameter
    pub fn regularization(mut self, regularization: Float) -> Self {
        self.regularization = regularization;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance
    pub fn tolerance(mut self, tolerance: Float) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Set the step size
    pub fn step_size(mut self, step_size: Float) -> Self {
        self.step_size = step_size;
        self
    }

    /// Soft thresholding operator (proximal operator of the L1 norm)
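    ///
    /// ```text
    /// S_t(x) = sign(x) * max(|x| - t, 0)
    /// ```
    ///
    /// e.g. `S_0.5(1.2) = 0.7` and `S_0.5(-0.3) = 0.0`.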
    fn soft_threshold(&self, x: Float, threshold: Float) -> Float {
        if x > threshold {
            x - threshold
        } else if x < -threshold {
            x + threshold
        } else {
            0.0
        }
    }

    /// Solve using the proximal gradient method (ISTA)
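    ///
    /// Each iteration applies the ISTA update with fixed step size `eta`:
    ///
    /// ```text
    /// w <- S_{eta * lambda}(w - eta * X^T (X w - y) / n)
    /// ```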
    fn solve_proximal_gradient(
        &self,
        features: &Array2<Float>,
        target: &Array1<Float>,
    ) -> SklResult<(Array1<Float>, Vec<Float>)> {
        let n_features = features.ncols();
        let n_samples = features.nrows();

        // Initialize weights
        let mut weights = Array1::zeros(n_features);
        let mut objective_values = Vec::new();

        // Proximal gradient iterations
        for _iter in 0..self.max_iter {
            // Compute predictions
            let predictions = features.dot(&weights);

            // Compute residuals
            let residuals = &predictions - target;

            // Compute gradient of smooth part (data term)
            let gradient = features.t().dot(&residuals) / n_samples as Float;

            // Gradient step
            let temp_weights = &weights - self.step_size * &gradient;

            // Proximal step (soft thresholding for L1 regularization)
            let threshold = self.step_size * self.regularization;
            let new_weights = temp_weights.mapv(|w| self.soft_threshold(w, threshold));

            // Compute objective value
            let data_term = residuals.mapv(|r| r * r).sum() / (2.0 * n_samples as Float);
            let reg_term = self.regularization * weights.mapv(|w| w.abs()).sum();
            let objective = data_term + reg_term;
            objective_values.push(objective);

            // Check convergence
            let weight_diff = (&new_weights - &weights).mapv(|d| d.abs()).sum();
            if weight_diff < self.tolerance {
                break;
            }

            weights = new_weights;
        }

        Ok((weights, objective_values))
    }
}

impl Estimator for ProximalGradientSelector<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for ProximalGradientSelector<Untrained> {
    type Fitted = ProximalGradientSelector<Trained>;

    fn fit(self, features: &Array2<Float>, target: &Array1<Float>) -> SklResult<Self::Fitted> {
        let n_features = features.ncols();
        if n_features == 0 {
            return Err(SklearsError::InvalidInput(
                "No features provided".to_string(),
            ));
        }

        if self.k > n_features {
            return Err(SklearsError::InvalidInput(format!(
                "k ({}) cannot be greater than number of features ({})",
                self.k, n_features
            )));
        }

        // Solve using proximal gradient method
        let (weights, objective_values) = self.solve_proximal_gradient(features, target)?;

        // Select top k features based on absolute weights
        let mut feature_indices: Vec<usize> = (0..n_features).collect();
        feature_indices.sort_by(|&a, &b| {
            weights[b]
                .abs()
                .partial_cmp(&weights[a].abs())
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let selected_features = feature_indices.into_iter().take(self.k).collect();

        Ok(ProximalGradientSelector {
            k: self.k,
            regularization: self.regularization,
            max_iter: self.max_iter,
            tolerance: self.tolerance,
            step_size: self.step_size,
            state: PhantomData,
            weights_: Some(weights),
            selected_features_: Some(selected_features),
            n_features_: Some(n_features),
            objective_values_: Some(objective_values),
        })
    }
}

impl Transform<Array2<Float>> for ProximalGradientSelector<Trained> {
    fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
        validate::check_n_features(x, self.n_features_.unwrap())?;

        let selected_features = self.selected_features_.as_ref().unwrap();
        let n_samples = x.nrows();
        let n_selected = selected_features.len();
        let mut x_new = Array2::zeros((n_samples, n_selected));

        for (new_idx, &old_idx) in selected_features.iter().enumerate() {
            x_new.column_mut(new_idx).assign(&x.column(old_idx));
        }

        Ok(x_new)
    }
}

impl SelectorMixin for ProximalGradientSelector<Trained> {
    fn get_support(&self) -> SklResult<Array1<bool>> {
        let n_features = self.n_features_.unwrap();
        let selected_features = self.selected_features_.as_ref().unwrap();
        let mut support = Array1::from_elem(n_features, false);

        for &idx in selected_features {
            support[idx] = true;
        }

        Ok(support)
    }

    fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
        let selected_features = self.selected_features_.as_ref().unwrap();
        Ok(indices
            .iter()
            .filter_map(|&idx| selected_features.iter().position(|&f| f == idx))
            .collect())
    }
}

impl FeatureSelector for ProximalGradientSelector<Trained> {
    fn selected_features(&self) -> &Vec<usize> {
        self.selected_features_.as_ref().unwrap()
    }
}

impl ProximalGradientSelector<Trained> {
    /// Get the learned weights
    pub fn weights(&self) -> &Array1<Float> {
        self.weights_.as_ref().unwrap()
    }

    /// Get the objective values during optimization
    pub fn objective_values(&self) -> &[Float] {
        self.objective_values_.as_ref().unwrap()
    }

    /// Get the number of selected features
    pub fn n_features_out(&self) -> usize {
        self.selected_features_.as_ref().unwrap().len()
    }
}

/// Alternating Direction Method of Multipliers (ADMM) for feature selection
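///
/// Solves the lasso problem in the equivalent consensus form
///
/// ```text
/// minimize   (1 / 2n) * ||X x - y||^2 + lambda * ||z||_1
/// subject to x = z,
/// ```
///
/// alternating an `x`-update (regularized least squares), a `z`-update
/// (soft thresholding), and a scaled dual update for `u`.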
#[derive(Debug, Clone)]
pub struct ADMMFeatureSelector<State = Untrained> {
    k: usize,
    regularization: Float,
    max_iter: usize,
    tolerance: Float,
    rho: Float, // ADMM penalty parameter
    state: PhantomData<State>,
    // Trained state
    weights_: Option<Array1<Float>>,
    selected_features_: Option<Vec<usize>>,
    n_features_: Option<usize>,
    objective_values_: Option<Vec<Float>>,
}

impl Default for ADMMFeatureSelector<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl ADMMFeatureSelector<Untrained> {
    /// Create a new ADMM feature selector
    pub fn new() -> Self {
        Self {
            k: 10,
            regularization: 1.0,
            max_iter: 1000,
            tolerance: 1e-6,
            rho: 1.0,
            state: PhantomData,
            weights_: None,
            selected_features_: None,
            n_features_: None,
            objective_values_: None,
        }
    }

    /// Set the number of features to select
    pub fn k(mut self, k: usize) -> Self {
        self.k = k;
        self
    }

    /// Set the regularization parameter
    pub fn regularization(mut self, regularization: Float) -> Self {
        self.regularization = regularization;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance
    pub fn tolerance(mut self, tolerance: Float) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Set the ADMM penalty parameter
    pub fn rho(mut self, rho: Float) -> Self {
        self.rho = rho;
        self
    }

    /// Soft thresholding operator
    fn soft_threshold(&self, x: Float, threshold: Float) -> Float {
        if x > threshold {
            x - threshold
        } else if x < -threshold {
            x + threshold
        } else {
            0.0
        }
    }

    /// Solve using ADMM
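    ///
    /// The scaled-form updates implemented below are:
    ///
    /// ```text
    /// x <- solve (X^T X + rho I) x = X^T y + rho (z - u)   (one Gauss-Seidel sweep)
    /// z <- S_{lambda / rho}(x + u)
    /// u <- u + x - z
    /// ```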
    fn solve_admm(
        &self,
        features: &Array2<Float>,
        target: &Array1<Float>,
    ) -> SklResult<(Array1<Float>, Vec<Float>)> {
        let n_features = features.ncols();
        let n_samples = features.nrows();

        // Initialize variables
        let mut x = Array1::<Float>::zeros(n_features); // Primary variable
        let mut z = Array1::<Float>::zeros(n_features); // Auxiliary variable
        let mut u = Array1::<Float>::zeros(n_features); // Scaled dual variable

        let mut objective_values = Vec::new();

        // Precompute matrices for efficiency
        let xtx = features.t().dot(features);
        let xty = features.t().dot(target);

        // ADMM iterations
        for _iter in 0..self.max_iter {
            let z_old = z.clone();

            // x-update: solve the quadratic subproblem
            // (X^T X + rho I) x = X^T y + rho (z - u)
            let rhs = &xty + self.rho * (&z - &u);

            // Simplified solve: one Gauss-Seidel sweep (each row uses the
            // already-updated entries x[j] for j < i)
            for i in 0..n_features {
                let diag_elem = xtx[[i, i]] + self.rho;
                if diag_elem > 1e-12 {
                    let off_diag = (0..n_features)
                        .filter(|&j| j != i)
                        .map(|j| xtx[[i, j]] * x[j])
                        .sum::<Float>();
                    x[i] = (rhs[i] - off_diag) / diag_elem;
                }
            }

            // z-update: soft thresholding
            let threshold = self.regularization / self.rho;
            for i in 0..n_features {
                z[i] = self.soft_threshold(x[i] + u[i], threshold);
            }

            // u-update: scaled dual ascent
            u = &u + &x - &z;

            // Compute objective value
            let predictions = features.dot(&x);
            let residuals = &predictions - target;
            let data_term = residuals.mapv(|r| r * r).sum() / (2.0 * n_samples as Float);
            let reg_term = self.regularization * z.mapv(|z_i| z_i.abs()).sum();
            let objective = data_term + reg_term;
            objective_values.push(objective);

            // Check convergence via primal and dual residuals
            let primal_residual = (&x - &z).mapv(|r| r.abs()).sum();
            let dual_residual = self.rho * (&z - &z_old).mapv(|r| r.abs()).sum();

            if primal_residual < self.tolerance && dual_residual < self.tolerance {
                break;
            }
        }

        Ok((z, objective_values))
    }
}

impl Estimator for ADMMFeatureSelector<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for ADMMFeatureSelector<Untrained> {
    type Fitted = ADMMFeatureSelector<Trained>;

    fn fit(self, features: &Array2<Float>, target: &Array1<Float>) -> SklResult<Self::Fitted> {
        let n_features = features.ncols();
        if n_features == 0 {
            return Err(SklearsError::InvalidInput(
                "No features provided".to_string(),
            ));
        }

        if self.k > n_features {
            return Err(SklearsError::InvalidInput(format!(
                "k ({}) cannot be greater than number of features ({})",
                self.k, n_features
            )));
        }

        // Solve using ADMM
        let (weights, objective_values) = self.solve_admm(features, target)?;

        // Select top k features based on absolute weights
        let mut feature_indices: Vec<usize> = (0..n_features).collect();
        feature_indices.sort_by(|&a, &b| {
            weights[b]
                .abs()
                .partial_cmp(&weights[a].abs())
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let selected_features = feature_indices.into_iter().take(self.k).collect();

        Ok(ADMMFeatureSelector {
            k: self.k,
            regularization: self.regularization,
            max_iter: self.max_iter,
            tolerance: self.tolerance,
            rho: self.rho,
            state: PhantomData,
            weights_: Some(weights),
            selected_features_: Some(selected_features),
            n_features_: Some(n_features),
            objective_values_: Some(objective_values),
        })
    }
}

impl Transform<Array2<Float>> for ADMMFeatureSelector<Trained> {
    fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
        validate::check_n_features(x, self.n_features_.unwrap())?;

        let selected_features = self.selected_features_.as_ref().unwrap();
        let n_samples = x.nrows();
        let n_selected = selected_features.len();
        let mut x_new = Array2::zeros((n_samples, n_selected));

        for (new_idx, &old_idx) in selected_features.iter().enumerate() {
            x_new.column_mut(new_idx).assign(&x.column(old_idx));
        }

        Ok(x_new)
    }
}

impl SelectorMixin for ADMMFeatureSelector<Trained> {
    fn get_support(&self) -> SklResult<Array1<bool>> {
        let n_features = self.n_features_.unwrap();
        let selected_features = self.selected_features_.as_ref().unwrap();
        let mut support = Array1::from_elem(n_features, false);

        for &idx in selected_features {
            support[idx] = true;
        }

        Ok(support)
    }

    fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
        let selected_features = self.selected_features_.as_ref().unwrap();
        Ok(indices
            .iter()
            .filter_map(|&idx| selected_features.iter().position(|&f| f == idx))
            .collect())
    }
}

impl FeatureSelector for ADMMFeatureSelector<Trained> {
    fn selected_features(&self) -> &Vec<usize> {
        self.selected_features_.as_ref().unwrap()
    }
}

impl ADMMFeatureSelector<Trained> {
    /// Get the learned weights
    pub fn weights(&self) -> &Array1<Float> {
        self.weights_.as_ref().unwrap()
    }

    /// Get the objective values during optimization
    pub fn objective_values(&self) -> &[Float] {
        self.objective_values_.as_ref().unwrap()
    }

    /// Get the number of selected features
    pub fn n_features_out(&self) -> usize {
        self.selected_features_.as_ref().unwrap().len()
    }
}

/// Semidefinite Programming (SDP) approach for feature selection
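///
/// Relaxes the binary selection vector `s` into a matrix variable `S ~ s s^T`
/// and approximately solves
///
/// ```text
/// maximize   c^T S c - lambda * tr(S * Sigma)
/// subject to S is PSD,  0 <= S_ii <= 1,
/// ```
///
/// where `c` holds feature/target covariances and `Sigma` is the feature
/// covariance matrix. The diagonal of the solution serves as relaxed
/// selection indicators. Note that both the PSD projection and the solver
/// used here are simplified heuristics rather than a full SDP solver.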
#[derive(Debug, Clone)]
pub struct SemidefiniteFeatureSelector<State = Untrained> {
    k: usize,
    max_iter: usize,
    tolerance: Float,
    regularization: Float,
    state: PhantomData<State>,
    // Trained state
    feature_matrix_: Option<Array2<Float>>,
    selected_features_: Option<Vec<usize>>,
    n_features_: Option<usize>,
    eigenvalues_: Option<Array1<Float>>,
    objective_values_: Option<Vec<Float>>,
}

impl Default for SemidefiniteFeatureSelector<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl SemidefiniteFeatureSelector<Untrained> {
    /// Create a new SDP feature selector
    pub fn new() -> Self {
        Self {
            k: 10,
            max_iter: 100,
            tolerance: 1e-6,
            regularization: 1.0,
            state: PhantomData,
            feature_matrix_: None,
            selected_features_: None,
            n_features_: None,
            eigenvalues_: None,
            objective_values_: None,
        }
    }

    /// Set the number of features to select
    pub fn k(mut self, k: usize) -> Self {
        self.k = k;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance
    pub fn tolerance(mut self, tolerance: Float) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Set the regularization parameter
    pub fn regularization(mut self, regularization: Float) -> Self {
        self.regularization = regularization;
        self
    }

    /// Approximately project a matrix toward the positive semidefinite cone
    fn project_psd(&self, matrix: &Array2<Float>) -> SklResult<Array2<Float>> {
        let n = matrix.nrows();

        // Symmetrize the matrix
        let mut projected = Array2::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                projected[[i, j]] = (matrix[[i, j]] + matrix[[j, i]]) / 2.0;
            }
        }

        // Clamp negative diagonal entries; this is a crude stand-in for the
        // exact projection (which would zero out negative eigenvalues)
        for i in 0..n {
            if projected[[i, i]] < 0.0 {
                projected[[i, i]] = 0.0;
            }
        }

        Ok(projected)
    }

    /// Solve the SDP relaxation for feature selection
    fn solve_sdp_relaxation(
        &self,
        features: &Array2<Float>,
        target: &Array1<Float>,
    ) -> SklResult<(Array2<Float>, Array1<Float>, Vec<Float>)> {
        let n_features = features.ncols();

        // Compute feature covariance matrix
        let centered_features = features - &features.mean_axis(Axis(0)).unwrap();
        let cov_matrix =
            centered_features.t().dot(&centered_features) / (features.nrows() - 1) as Float;

        // Compute feature-target correlations
        let target_centered = target - target.mean().unwrap();
        let correlations =
            centered_features.t().dot(&target_centered) / (features.nrows() - 1) as Float;

        // Initialize variable matrix X (relaxation of x * x^T where x is binary)
        let mut x_matrix = Array2::eye(n_features) * 0.5; // Start with a diagonal matrix
        let mut objective_values = Vec::new();

        // Projected gradient ascent on the SDP relaxation
        for _iter in 0..self.max_iter {
            // Gradient of the objective:
            // maximize correlations^T * X * correlations - regularization * trace(X * cov_matrix)
            let outer_corr = outer_product(&correlations, &correlations);
            let grad = &outer_corr - self.regularization * &cov_matrix;

            // Gradient ascent step (the objective is maximized)
            let step_size = 0.01;
            let x_new = &x_matrix + step_size * &grad;

            // Project onto the constraints: PSD and diagonal bounds
            let mut x_projected = self.project_psd(&x_new)?;

            // Enforce the constraint 0 <= X_{ii} <= 1
            for i in 0..n_features {
                x_projected[[i, i]] = x_projected[[i, i]].clamp(0.0, 1.0);
            }

            // Compute objective value
            let obj = correlations.dot(&x_projected.dot(&correlations))
                - self.regularization * trace(&x_projected.dot(&cov_matrix));
            objective_values.push(obj);

            // Check convergence
            let diff = (&x_projected - &x_matrix).mapv(|x| x.abs()).sum();
            if diff < self.tolerance {
                break;
            }

            x_matrix = x_projected;
        }

        // Extract the diagonal (relaxed selection indicators) from the final matrix
        let eigenvalues = extract_diagonal(&x_matrix);

        Ok((x_matrix, eigenvalues, objective_values))
    }
}

impl Estimator for SemidefiniteFeatureSelector<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for SemidefiniteFeatureSelector<Untrained> {
    type Fitted = SemidefiniteFeatureSelector<Trained>;

    fn fit(self, features: &Array2<Float>, target: &Array1<Float>) -> SklResult<Self::Fitted> {
        let n_features = features.ncols();
        if n_features == 0 {
            return Err(SklearsError::InvalidInput(
                "No features provided".to_string(),
            ));
        }

        if self.k > n_features {
            return Err(SklearsError::InvalidInput(format!(
                "k ({}) cannot be greater than number of features ({})",
                self.k, n_features
            )));
        }

        // Solve SDP relaxation
        let (feature_matrix, eigenvalues, objective_values) =
            self.solve_sdp_relaxation(features, target)?;

        // Select top k features based on diagonal values (relaxed selection indicators)
        let mut feature_indices: Vec<usize> = (0..n_features).collect();
        feature_indices.sort_by(|&a, &b| {
            eigenvalues[b]
                .partial_cmp(&eigenvalues[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let selected_features = feature_indices.into_iter().take(self.k).collect();

        Ok(SemidefiniteFeatureSelector {
            k: self.k,
            max_iter: self.max_iter,
            tolerance: self.tolerance,
            regularization: self.regularization,
            state: PhantomData,
            feature_matrix_: Some(feature_matrix),
            selected_features_: Some(selected_features),
            n_features_: Some(n_features),
            eigenvalues_: Some(eigenvalues),
            objective_values_: Some(objective_values),
        })
    }
}

impl Transform<Array2<Float>> for SemidefiniteFeatureSelector<Trained> {
    fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
        validate::check_n_features(x, self.n_features_.unwrap())?;

        let selected_features = self.selected_features_.as_ref().unwrap();
        let n_samples = x.nrows();
        let n_selected = selected_features.len();
        let mut x_new = Array2::zeros((n_samples, n_selected));

        for (new_idx, &old_idx) in selected_features.iter().enumerate() {
            x_new.column_mut(new_idx).assign(&x.column(old_idx));
        }

        Ok(x_new)
    }
}

impl SelectorMixin for SemidefiniteFeatureSelector<Trained> {
    fn get_support(&self) -> SklResult<Array1<bool>> {
        let n_features = self.n_features_.unwrap();
        let selected_features = self.selected_features_.as_ref().unwrap();
        let mut support = Array1::from_elem(n_features, false);

        for &idx in selected_features {
            support[idx] = true;
        }

        Ok(support)
    }

    fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
        let selected_features = self.selected_features_.as_ref().unwrap();
        Ok(indices
            .iter()
            .filter_map(|&idx| selected_features.iter().position(|&f| f == idx))
            .collect())
    }
}

impl FeatureSelector for SemidefiniteFeatureSelector<Trained> {
    fn selected_features(&self) -> &Vec<usize> {
        self.selected_features_.as_ref().unwrap()
    }
}

impl SemidefiniteFeatureSelector<Trained> {
    /// Get the feature matrix from SDP optimization
    pub fn feature_matrix(&self) -> &Array2<Float> {
        self.feature_matrix_.as_ref().unwrap()
    }

    /// Get the eigenvalues (selection indicators)
    pub fn eigenvalues(&self) -> &Array1<Float> {
        self.eigenvalues_.as_ref().unwrap()
    }

    /// Get the objective values during optimization
    pub fn objective_values(&self) -> &[Float] {
        self.objective_values_.as_ref().unwrap()
    }

    /// Get the number of selected features
    pub fn n_features_out(&self) -> usize {
        self.selected_features_.as_ref().unwrap().len()
    }
}

/// Integer Programming approach for feature selection
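///
/// Approximately solves the cardinality-constrained selection problem
///
/// ```text
/// maximize   sum_i s_i * x_i
/// subject to sum_i x_i = k,  x_i in {0, 1},
/// ```
///
/// where `s_i` is a correlation-based importance score. The cardinality
/// constraint is enforced through a large penalty, and the search combines
/// greedy initialization with 1-opt/2-opt local moves; this is a heuristic,
/// not an exact branch-and-bound solver.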
#[derive(Debug, Clone)]
pub struct IntegerProgrammingFeatureSelector<State = Untrained> {
    k: usize,
    max_iter: usize,
    tolerance: Float,
    greedy_init: bool,
    local_search: bool,
    state: PhantomData<State>,
    // Trained state
    binary_solution_: Option<Array1<bool>>,
    selected_features_: Option<Vec<usize>>,
    n_features_: Option<usize>,
    objective_value_: Option<Float>,
    improvement_history_: Option<Vec<Float>>,
}

impl Default for IntegerProgrammingFeatureSelector<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl IntegerProgrammingFeatureSelector<Untrained> {
    /// Create a new integer programming feature selector
    pub fn new() -> Self {
        Self {
            k: 10,
            max_iter: 1000,
            tolerance: 1e-6,
            greedy_init: true,
            local_search: true,
            state: PhantomData,
            binary_solution_: None,
            selected_features_: None,
            n_features_: None,
            objective_value_: None,
            improvement_history_: None,
        }
    }

    /// Set the number of features to select
    pub fn k(mut self, k: usize) -> Self {
        self.k = k;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance
    pub fn tolerance(mut self, tolerance: Float) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Enable/disable greedy initialization
    pub fn greedy_init(mut self, greedy_init: bool) -> Self {
        self.greedy_init = greedy_init;
        self
    }

    /// Enable/disable local search
    pub fn local_search(mut self, local_search: bool) -> Self {
        self.local_search = local_search;
        self
    }

    /// Compute feature importance scores
    fn compute_feature_scores(
        &self,
        features: &Array2<Float>,
        target: &Array1<Float>,
    ) -> SklResult<Array1<Float>> {
        let n_features = features.ncols();
        let mut scores = Array1::zeros(n_features);

        // Compute correlation-based importance
        for i in 0..n_features {
            let feature_col = features.column(i);
            let correlation = correlation_coefficient(&feature_col.to_owned(), target)?;
            scores[i] = correlation.abs();
        }

        Ok(scores)
    }

    /// Greedy initialization for IP
    fn greedy_initialization(&self, scores: &Array1<Float>) -> Array1<bool> {
        let n_features = scores.len();
        let mut solution = Array1::from_elem(n_features, false);

        // Select top k features greedily
        let mut indices: Vec<usize> = (0..n_features).collect();
        indices.sort_by(|&a, &b| {
            scores[b]
                .partial_cmp(&scores[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        for &idx in indices.iter().take(self.k) {
            solution[idx] = true;
        }

        solution
    }

    /// Evaluate objective function for a binary solution
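    ///
    /// ```text
    /// obj(x) = sum_{i : x_i = 1} s_i - 1000 * |sum_i x_i - k|
    /// ```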
    fn evaluate_objective(&self, solution: &Array1<bool>, scores: &Array1<Float>) -> Float {
        let mut objective = 0.0;
        let mut selected_count = 0;

        for i in 0..solution.len() {
            if solution[i] {
                objective += scores[i];
                selected_count += 1;
            }
        }

        // Penalty for violating cardinality constraint
        if selected_count != self.k {
            objective -= 1000.0 * (selected_count as Float - self.k as Float).abs();
        }

        objective
    }

    /// Local search improvement (1-opt and 2-opt moves)
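    ///
    /// A 1-opt move flips a single bit; a 2-opt move swaps one selected
    /// feature with one unselected feature. A move is kept only when it
    /// improves the penalized objective by more than `tolerance`.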
    fn local_search_improvement(
        &self,
        solution: &mut Array1<bool>,
        scores: &Array1<Float>,
        best_obj: &mut Float,
    ) -> bool {
        let n_features = solution.len();
        let mut improved = false;

        // 1-opt moves: flip single features
        for i in 0..n_features {
            let original = solution[i];
            solution[i] = !solution[i];

            let new_obj = self.evaluate_objective(solution, scores);
            if new_obj > *best_obj + self.tolerance {
                *best_obj = new_obj;
                improved = true;
            } else {
                solution[i] = original; // Revert if no improvement
            }
        }

        // 2-opt moves: swap pairs of features
        if self.local_search {
            for i in 0..n_features {
                for j in (i + 1)..n_features {
                    if solution[i] == solution[j] {
                        continue; // Skip if both have the same value
                    }

                    // Swap
                    let temp = solution[i];
                    solution[i] = solution[j];
                    solution[j] = temp;

                    let new_obj = self.evaluate_objective(solution, scores);
                    if new_obj > *best_obj + self.tolerance {
                        *best_obj = new_obj;
                        improved = true;
                    } else {
                        // Revert swap
                        let temp = solution[i];
                        solution[i] = solution[j];
                        solution[j] = temp;
                    }
                }
            }
        }

        improved
    }

    /// Solve the integer programming problem approximately
    fn solve_integer_programming(
        &self,
        features: &Array2<Float>,
        target: &Array1<Float>,
    ) -> SklResult<(Array1<bool>, Float, Vec<Float>)> {
        let scores = self.compute_feature_scores(features, target)?;
        let mut improvement_history = Vec::new();

        // Initialize solution
        let mut solution = if self.greedy_init {
            self.greedy_initialization(&scores)
        } else {
            // Fallback initialization: deterministically select the first k
            // features (nothing here is randomized)
            let mut fallback_solution = Array1::from_elem(scores.len(), false);
            let indices: Vec<usize> = (0..scores.len()).collect();
            for &idx in indices.iter().take(self.k) {
                fallback_solution[idx] = true;
            }
            fallback_solution
        };

        let mut best_objective = self.evaluate_objective(&solution, &scores);
        improvement_history.push(best_objective);

        // Iterative improvement
        for _iter in 0..self.max_iter {
            let prev_objective = best_objective;

            let improved =
                self.local_search_improvement(&mut solution, &scores, &mut best_objective);
            improvement_history.push(best_objective);

            if !improved || (best_objective - prev_objective).abs() < self.tolerance {
                break;
            }
        }

        Ok((solution, best_objective, improvement_history))
    }
}

impl Estimator for IntegerProgrammingFeatureSelector<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for IntegerProgrammingFeatureSelector<Untrained> {
    type Fitted = IntegerProgrammingFeatureSelector<Trained>;

    fn fit(self, features: &Array2<Float>, target: &Array1<Float>) -> SklResult<Self::Fitted> {
        let n_features = features.ncols();
        if n_features == 0 {
            return Err(SklearsError::InvalidInput(
                "No features provided".to_string(),
            ));
        }

        if self.k > n_features {
            return Err(SklearsError::InvalidInput(format!(
                "k ({}) cannot be greater than number of features ({})",
                self.k, n_features
            )));
        }

        // Solve integer programming problem
        let (binary_solution, objective_value, improvement_history) =
            self.solve_integer_programming(features, target)?;

        // Extract selected features
        let selected_features: Vec<usize> = binary_solution
            .iter()
            .enumerate()
            .filter_map(|(i, &selected)| if selected { Some(i) } else { None })
            .collect();

        Ok(IntegerProgrammingFeatureSelector {
            k: self.k,
            max_iter: self.max_iter,
            tolerance: self.tolerance,
            greedy_init: self.greedy_init,
            local_search: self.local_search,
            state: PhantomData,
            binary_solution_: Some(binary_solution),
            selected_features_: Some(selected_features),
            n_features_: Some(n_features),
            objective_value_: Some(objective_value),
            improvement_history_: Some(improvement_history),
        })
    }
}

impl Transform<Array2<Float>> for IntegerProgrammingFeatureSelector<Trained> {
    fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
        validate::check_n_features(x, self.n_features_.unwrap())?;

        let selected_features = self.selected_features_.as_ref().unwrap();
        let n_samples = x.nrows();
        let n_selected = selected_features.len();
        let mut x_new = Array2::zeros((n_samples, n_selected));

        for (new_idx, &old_idx) in selected_features.iter().enumerate() {
            x_new.column_mut(new_idx).assign(&x.column(old_idx));
        }

        Ok(x_new)
    }
}

impl SelectorMixin for IntegerProgrammingFeatureSelector<Trained> {
    fn get_support(&self) -> SklResult<Array1<bool>> {
        Ok(self.binary_solution_.as_ref().unwrap().clone())
    }

    fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
        let selected_features = self.selected_features_.as_ref().unwrap();
        Ok(indices
            .iter()
            .filter_map(|&idx| selected_features.iter().position(|&f| f == idx))
            .collect())
    }
}

impl FeatureSelector for IntegerProgrammingFeatureSelector<Trained> {
    fn selected_features(&self) -> &Vec<usize> {
        self.selected_features_.as_ref().unwrap()
    }
}

impl IntegerProgrammingFeatureSelector<Trained> {
    /// Get the binary solution
    pub fn binary_solution(&self) -> &Array1<bool> {
        self.binary_solution_.as_ref().unwrap()
    }

    /// Get the final objective value
    pub fn objective_value(&self) -> Float {
        self.objective_value_.unwrap()
    }

    /// Get the improvement history during optimization
    pub fn improvement_history(&self) -> &[Float] {
        self.improvement_history_.as_ref().unwrap()
    }

    /// Get the number of selected features
    pub fn n_features_out(&self) -> usize {
        self.selected_features_.as_ref().unwrap().len()
    }
}
// Helper functions
fn outer_product(a: &Array1<Float>, b: &Array1<Float>) -> Array2<Float> {
    let mut result = Array2::zeros((a.len(), b.len()));
    for i in 0..a.len() {
        for j in 0..b.len() {
            result[[i, j]] = a[i] * b[j];
        }
    }
    result
}

fn trace(matrix: &Array2<Float>) -> Float {
    let n = matrix.nrows().min(matrix.ncols());
    (0..n).map(|i| matrix[[i, i]]).sum()
}

fn extract_diagonal(matrix: &Array2<Float>) -> Array1<Float> {
    let n = matrix.nrows().min(matrix.ncols());
    let mut diag = Array1::zeros(n);
    for i in 0..n {
        diag[i] = matrix[[i, i]];
    }
    diag
}

fn correlation_coefficient(x: &Array1<Float>, y: &Array1<Float>) -> SklResult<Float> {
    if x.len() != y.len() {
        return Err(SklearsError::InvalidInput(
            "Arrays must have the same length".to_string(),
        ));
    }

    let mean_x = x.mean().unwrap();
    let mean_y = y.mean().unwrap();

    let mut num = 0.0;
    let mut den_x = 0.0;
    let mut den_y = 0.0;

    for i in 0..x.len() {
        let diff_x = x[i] - mean_x;
        let diff_y = y[i] - mean_y;
        num += diff_x * diff_y;
        den_x += diff_x * diff_x;
        den_y += diff_y * diff_y;
    }

    // Guard against (near-)zero variance in either array
    if den_x.abs() < 1e-10 || den_y.abs() < 1e-10 {
        return Ok(0.0);
    }

    Ok(num / (den_x * den_y).sqrt())
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    fn create_test_data() -> (Array2<Float>, Array1<Float>) {
        // Create synthetic data with some correlation structure
        let n_samples = 50;
        let n_features = 10;
        let mut features = Array2::zeros((n_samples, n_features));
        let mut target = Array1::zeros(n_samples);

        // Fill with structured data
        for i in 0..n_samples {
            for j in 0..n_features {
                features[[i, j]] = (i as Float * 0.1 + j as Float * 0.01).sin() + 0.1 * j as Float;
            }
            // Make the first few features predictive
            target[i] = features[[i, 0]] + 0.5 * features[[i, 1]] + 0.1 * features[[i, 2]];
        }

        (features, target)
    }

    #[test]
    fn test_convex_feature_selector() {
        let (features, target) = create_test_data();

        let selector = ConvexFeatureSelector::new()
            .k(5)
            .regularization(0.1)
            .max_iter(100);

        let trained = selector.fit(&features, &target).unwrap();
        assert_eq!(trained.n_features_out(), 5);

        // Test transform
        let transformed = trained.transform(&features).unwrap();
        assert_eq!(transformed.ncols(), 5);
        assert_eq!(transformed.nrows(), features.nrows());

        // Test weights
        let weights = trained.weights();
        assert_eq!(weights.len(), features.ncols());
        assert!(weights.iter().all(|&x| x.is_finite()));

        // Test objective values
        let obj_vals = trained.objective_values();
        assert!(!obj_vals.is_empty());
        assert!(obj_vals.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_proximal_gradient_selector() {
        let (features, target) = create_test_data();

        let selector = ProximalGradientSelector::new()
            .k(4)
            .regularization(0.1)
            .step_size(0.01)
            .max_iter(100);

        let trained = selector.fit(&features, &target).unwrap();
        assert_eq!(trained.n_features_out(), 4);

        // Test transform
        let transformed = trained.transform(&features).unwrap();
        assert_eq!(transformed.ncols(), 4);
        assert_eq!(transformed.nrows(), features.nrows());

        // Test weights
        let weights = trained.weights();
        assert_eq!(weights.len(), features.ncols());
        assert!(weights.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_admm_feature_selector() {
        let (features, target) = create_test_data();

        let selector = ADMMFeatureSelector::new()
            .k(3)
            .regularization(0.1)
            .rho(1.0)
            .max_iter(50);

        let trained = selector.fit(&features, &target).unwrap();
        assert_eq!(trained.n_features_out(), 3);

        // Test transform
        let transformed = trained.transform(&features).unwrap();
        assert_eq!(transformed.ncols(), 3);
        assert_eq!(transformed.nrows(), features.nrows());

        // Test weights
        let weights = trained.weights();
        assert_eq!(weights.len(), features.ncols());
        assert!(weights.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_convex_selector_invalid_k() {
        let (features, target) = create_test_data();

        let selector = ConvexFeatureSelector::new().k(features.ncols() + 1);
        assert!(selector.fit(&features, &target).is_err());
    }

    #[test]
    fn test_proximal_selector_invalid_k() {
        let (features, target) = create_test_data();

        let selector = ProximalGradientSelector::new().k(features.ncols() + 1);
        assert!(selector.fit(&features, &target).is_err());
    }

    #[test]
    fn test_admm_selector_invalid_k() {
        let (features, target) = create_test_data();

        let selector = ADMMFeatureSelector::new().k(features.ncols() + 1);
        assert!(selector.fit(&features, &target).is_err());
    }

    #[test]
    fn test_semidefinite_feature_selector() {
        let (features, target) = create_test_data();

        let selector = SemidefiniteFeatureSelector::new()
            .k(4)
            .regularization(0.1)
            .max_iter(50);

        let trained = selector.fit(&features, &target).unwrap();
        assert_eq!(trained.n_features_out(), 4);

        // Test transform
        let transformed = trained.transform(&features).unwrap();
        assert_eq!(transformed.ncols(), 4);
        assert_eq!(transformed.nrows(), features.nrows());

        // Test feature matrix
        let feature_matrix = trained.feature_matrix();
        assert_eq!(feature_matrix.nrows(), features.ncols());
        assert_eq!(feature_matrix.ncols(), features.ncols());

        // Test eigenvalues
        let eigenvalues = trained.eigenvalues();
        assert_eq!(eigenvalues.len(), features.ncols());
        assert!(eigenvalues.iter().all(|&x| x.is_finite()));

        // Test objective values
        let obj_vals = trained.objective_values();
        assert!(!obj_vals.is_empty());
        assert!(obj_vals.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_integer_programming_feature_selector() {
        let (features, target) = create_test_data();

        let selector = IntegerProgrammingFeatureSelector::new()
            .k(3)
            .greedy_init(true)
            .local_search(true)
            .max_iter(100);

        let trained = selector.fit(&features, &target).unwrap();
        assert_eq!(trained.n_features_out(), 3);

        // Test transform
        let transformed = trained.transform(&features).unwrap();
        assert_eq!(transformed.ncols(), 3);
        assert_eq!(transformed.nrows(), features.nrows());

        // Test binary solution
        let binary_solution = trained.binary_solution();
        assert_eq!(binary_solution.len(), features.ncols());
        let selected_count = binary_solution.iter().filter(|&&x| x).count();
        assert_eq!(selected_count, 3);

        // Test objective value
        let obj_value = trained.objective_value();
        assert!(obj_value.is_finite());

        // Test improvement history
        let improvement_history = trained.improvement_history();
        assert!(!improvement_history.is_empty());
        assert!(improvement_history.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_semidefinite_selector_invalid_k() {
        let (features, target) = create_test_data();

        let selector = SemidefiniteFeatureSelector::new().k(features.ncols() + 1);
        assert!(selector.fit(&features, &target).is_err());
    }

    #[test]
    fn test_integer_programming_selector_invalid_k() {
        let (features, target) = create_test_data();

        let selector = IntegerProgrammingFeatureSelector::new().k(features.ncols() + 1);
        assert!(selector.fit(&features, &target).is_err());
    }

    #[test]
    fn test_correlation_coefficient() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let y = Array1::from_vec(vec![2.0, 4.0, 6.0, 8.0, 10.0]);

        let corr = correlation_coefficient(&x, &y).unwrap();
        assert!((corr - 1.0).abs() < 1e-10); // Perfect correlation

        let z = Array1::from_vec(vec![5.0, 4.0, 3.0, 2.0, 1.0]);
        let corr2 = correlation_coefficient(&x, &z).unwrap();
        assert!((corr2 + 1.0).abs() < 1e-10); // Perfect negative correlation
    }

    #[test]
    fn test_helper_functions() {
        let a = Array1::from_vec(vec![1.0, 2.0]);
        let b = Array1::from_vec(vec![3.0, 4.0]);

        let outer = outer_product(&a, &b);
        assert_eq!(outer[[0, 0]], 3.0);
        assert_eq!(outer[[0, 1]], 4.0);
        assert_eq!(outer[[1, 0]], 6.0);
        assert_eq!(outer[[1, 1]], 8.0);

        let matrix = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let tr = trace(&matrix);
        assert_eq!(tr, 5.0); // 1 + 4

        let diag = extract_diagonal(&matrix);
        assert_eq!(diag[0], 1.0);
        assert_eq!(diag[1], 4.0);
    }
}