//! Mutual-information feature selection.
//!
//! Source path: `anofox_ml_preprocessing/mutual_information.rs`.

use anofox_ml_core::{Fit, Float, Result, RustMlError, Transform};
use ndarray::{Array1, Array2};
use std::collections::HashMap;

5/// Parameters for `MutualInformationSelector` (unfitted state).
6///
7/// Selects the top-k features by mutual information with the target variable.
8/// Mutual information measures the dependency between two variables: a higher
9/// score means the feature is more informative about the target.
10///
11/// For continuous features, values are discretized into equal-width bins before
12/// computing mutual information. The number of bins can be tuned via `n_bins`.
13///
14/// This is a **supervised** feature selector and requires target labels `y`.
15///
16/// # Example
17///
18/// ```
19/// use anofox_ml_preprocessing::MutualInformationSelector;
20/// use anofox_ml_core::{Fit, Transform};
21/// use ndarray::array;
22///
23/// let x = array![
24///     [1.0, 100.0],
25///     [2.0, 200.0],
26///     [1.0, 300.0],
27///     [2.0, 400.0],
28/// ];
29/// let y = array![0.0, 1.0, 0.0, 1.0]; // perfectly correlated with col 0
30///
31/// let selector = MutualInformationSelector::new(1);
32/// let fitted = Fit::<f64>::fit(&selector, &x, &y).unwrap();
33/// let x_selected = fitted.transform(&x).unwrap();
34///
35/// assert_eq!(x_selected.ncols(), 1);
36/// ```
37#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
38pub struct MutualInformationSelector {
39    /// Number of top features to select.
40    pub n_features_to_select: usize,
41    /// Number of equal-width bins for discretizing continuous features.
42    pub n_bins: usize,
43}
44
45impl MutualInformationSelector {
46    /// Create a new selector that keeps the top `n_features_to_select` features.
47    pub fn new(n_features_to_select: usize) -> Self {
48        Self {
49            n_features_to_select,
50            n_bins: 10,
51        }
52    }
53
54    /// Set the number of bins for discretizing continuous features.
55    pub fn with_n_bins(mut self, n_bins: usize) -> Self {
56        self.n_bins = n_bins;
57        self
58    }
59}
60
61/// Fitted `MutualInformationSelector` — holds per-feature MI scores and
62/// the indices of the selected top-k features.
63#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
64#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
65pub struct FittedMutualInformationSelector<F: Float> {
66    /// Mutual information score for each feature.
67    mi_scores: Array1<F>,
68    /// Indices of the top-k features (sorted by index for stable column ordering).
69    selected_indices: Vec<usize>,
70    /// Total number of input features (before selection).
71    n_features_in: usize,
72}
73
74impl<F: Float> FittedMutualInformationSelector<F> {
75    /// Per-feature mutual information scores.
76    pub fn mi_scores(&self) -> &Array1<F> {
77        &self.mi_scores
78    }
79
80    /// Indices of the selected features, sorted in ascending order.
81    pub fn selected_indices(&self) -> &[usize] {
82        &self.selected_indices
83    }
84}
85
86/// Discretize a 1-D array of continuous values into `n_bins` equal-width bins.
87///
88/// Returns a `Vec<usize>` where each element is the bin index (0..n_bins-1).
89/// If all values are identical (zero range), every sample is placed in bin 0.
90fn discretize<F: Float>(values: &[F], n_bins: usize) -> Vec<usize> {
91    let mut min_val = values[0];
92    let mut max_val = values[0];
93    for &v in values.iter().skip(1) {
94        if v < min_val {
95            min_val = v;
96        }
97        if v > max_val {
98            max_val = v;
99        }
100    }
101
102    let range = max_val - min_val;
103    let eps = F::from_f64(1e-15).unwrap();
104
105    if range < eps {
106        // All values identical -> single bin.
107        return vec![0; values.len()];
108    }
109
110    let n_bins_f = F::from_usize(n_bins).unwrap();
111    let max_bin = n_bins - 1;
112
113    values
114        .iter()
115        .map(|&v| {
116            let normalized = (v - min_val) / range; // [0, 1]
117            let bin = (normalized * n_bins_f).to_usize().unwrap_or(max_bin);
118            bin.min(max_bin)
119        })
120        .collect()
121}
122
123/// Compute mutual information MI(X, Y) between two discrete integer-valued
124/// random variables represented as parallel slices.
125///
126/// MI(X, Y) = sum_{x,y} p(x,y) * log2(p(x,y) / (p(x) * p(y)))
127///
128/// Convention: 0 * log(0) = 0.
129fn mutual_information_discrete<F: Float>(x_bins: &[usize], y_labels: &[usize]) -> F {
130    let n = x_bins.len();
131    let n_f = F::from_usize(n).unwrap();
132
133    // Count joint and marginal frequencies.
134    let mut joint: HashMap<(usize, usize), usize> = HashMap::new();
135    let mut x_counts: HashMap<usize, usize> = HashMap::new();
136    let mut y_counts: HashMap<usize, usize> = HashMap::new();
137
138    for (&xb, &yb) in x_bins.iter().zip(y_labels.iter()) {
139        *joint.entry((xb, yb)).or_insert(0) += 1;
140        *x_counts.entry(xb).or_insert(0) += 1;
141        *y_counts.entry(yb).or_insert(0) += 1;
142    }
143
144    let mut mi = F::zero();
145    for (&(xb, yb), &count) in &joint {
146        if count == 0 {
147            continue;
148        }
149        let p_xy = F::from_usize(count).unwrap() / n_f;
150        let p_x = F::from_usize(x_counts[&xb]).unwrap() / n_f;
151        let p_y = F::from_usize(y_counts[&yb]).unwrap() / n_f;
152
153        let ratio = p_xy / (p_x * p_y);
154        mi += p_xy * ratio.ln();
155    }
156
157    // Clamp to zero in case of floating-point noise producing a tiny negative.
158    if mi < F::zero() {
159        F::zero()
160    } else {
161        mi
162    }
163}
164
165/// Convert target labels to discrete integer class indices.
166///
167/// Returns `(label_indices, n_classes)`. Unique labels are discovered in
168/// order of first appearance.
169fn labels_to_indices<F: Float>(y: &Array1<F>) -> Vec<usize> {
170    let mut label_map: HashMap<u64, usize> = HashMap::new();
171    let mut indices = Vec::with_capacity(y.len());
172
173    for &val in y.iter() {
174        // Use bit representation as hash key for exact equality.
175        let bits = val.to_f64().unwrap().to_bits();
176        let next_id = label_map.len();
177        let id = *label_map.entry(bits).or_insert(next_id);
178        indices.push(id);
179    }
180
181    indices
182}
183
184impl<F: Float> Fit<F> for MutualInformationSelector {
185    type Fitted = FittedMutualInformationSelector<F>;
186
187    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Self::Fitted> {
188        let (n_samples, n_features) = x.dim();
189
190        if n_samples == 0 || n_features == 0 {
191            return Err(RustMlError::EmptyInput("input array is empty".into()));
192        }
193
194        if y.len() != n_samples {
195            return Err(RustMlError::ShapeMismatch(format!(
196                "X has {} samples but y has {} elements",
197                n_samples,
198                y.len()
199            )));
200        }
201
202        if self.n_features_to_select == 0 {
203            return Err(RustMlError::InvalidParameter(
204                "n_features_to_select must be at least 1".into(),
205            ));
206        }
207
208        if self.n_features_to_select > n_features {
209            return Err(RustMlError::InvalidParameter(format!(
210                "n_features_to_select ({}) exceeds number of features ({})",
211                self.n_features_to_select, n_features
212            )));
213        }
214
215        if self.n_bins == 0 {
216            return Err(RustMlError::InvalidParameter(
217                "n_bins must be at least 1".into(),
218            ));
219        }
220
221        // Convert target labels to integer indices.
222        let y_indices = labels_to_indices(y);
223
224        // Compute MI for each feature column.
225        let mut mi_scores = Array1::<F>::zeros(n_features);
226        for j in 0..n_features {
227            let col: Vec<F> = x.column(j).to_vec();
228            let x_bins = discretize(&col, self.n_bins);
229            mi_scores[j] = mutual_information_discrete::<F>(&x_bins, &y_indices);
230        }
231
232        // Select top-k features by MI score.
233        let mut feature_scores: Vec<(usize, F)> = mi_scores.iter().copied().enumerate().collect();
234        // Sort descending by score; break ties by feature index (ascending).
235        feature_scores.sort_by(|a, b| {
236            b.1.partial_cmp(&a.1)
237                .unwrap_or(std::cmp::Ordering::Equal)
238                .then(a.0.cmp(&b.0))
239        });
240
241        let mut selected_indices: Vec<usize> = feature_scores
242            .iter()
243            .take(self.n_features_to_select)
244            .map(|&(idx, _)| idx)
245            .collect();
246        // Sort indices for stable column ordering in transform.
247        selected_indices.sort_unstable();
248
249        Ok(FittedMutualInformationSelector {
250            mi_scores,
251            selected_indices,
252            n_features_in: n_features,
253        })
254    }
255}
256
257impl<F: Float> Transform<F> for FittedMutualInformationSelector<F> {
258    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
259        if x.ncols() != self.n_features_in {
260            return Err(RustMlError::ShapeMismatch(format!(
261                "expected {} features, got {}",
262                self.n_features_in,
263                x.ncols()
264            )));
265        }
266
267        let n_rows = x.nrows();
268        let n_selected = self.selected_indices.len();
269        let mut result = Array2::<F>::zeros((n_rows, n_selected));
270
271        for (i, row) in x.rows().into_iter().enumerate() {
272            for (out_j, &src_j) in self.selected_indices.iter().enumerate() {
273                result[[i, out_j]] = row[src_j];
274            }
275        }
276
277        Ok(result)
278    }
279}
280
/// Unit tests covering selection, score properties, output shape, parameter
/// validation, and both supported float widths.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_selects_informative_feature_over_noise() {
        // Feature 0: perfectly predicts the class (0->class0, 1->class1).
        // Feature 1: random noise, uncorrelated with class.
        let x = array![
            [0.0, 0.5],
            [0.0, 0.8],
            [0.0, 0.2],
            [0.0, 0.9],
            [1.0, 0.3],
            [1.0, 0.7],
            [1.0, 0.1],
            [1.0, 0.6],
        ];
        let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0];

        let selector = MutualInformationSelector::new(1).with_n_bins(2);
        let fitted = Fit::<f64>::fit(&selector, &x, &y).unwrap();

        // Feature 0 should be selected (it perfectly separates the classes).
        assert_eq!(fitted.selected_indices(), &[0]);

        // MI of feature 0 should be substantially larger than feature 1.
        assert!(
            fitted.mi_scores()[0] > fitted.mi_scores()[1],
            "informative feature MI ({}) should be > noise MI ({})",
            fitted.mi_scores()[0],
            fitted.mi_scores()[1]
        );
    }

    #[test]
    fn test_scores_are_non_negative() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];
        let y = array![0.0, 1.0, 0.0, 1.0];

        let selector = MutualInformationSelector::new(2);
        let fitted = Fit::<f64>::fit(&selector, &x, &y).unwrap();

        for (i, &score) in fitted.mi_scores().iter().enumerate() {
            assert!(
                score >= 0.0,
                "MI score for feature {} is negative: {}",
                i,
                score
            );
        }
    }

    #[test]
    fn test_transform_outputs_correct_shape() {
        let x = array![
            [1.0, 2.0, 3.0, 4.0],
            [5.0, 6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0, 12.0],
        ];
        let y = array![0.0, 1.0, 0.0];

        let selector = MutualInformationSelector::new(2);
        let fitted = Fit::<f64>::fit(&selector, &x, &y).unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_eq!(result.nrows(), 3);
        assert_eq!(result.ncols(), 2);
    }

    #[test]
    fn test_selects_all_when_k_equals_n_features() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![0.0, 1.0, 0.0];

        let selector = MutualInformationSelector::new(2);
        let fitted = Fit::<f64>::fit(&selector, &x, &y).unwrap();

        assert_eq!(fitted.selected_indices(), &[0, 1]);
    }

    #[test]
    fn test_shape_mismatch_x_y() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0, 2.0]; // 3 labels for 2 samples

        let selector = MutualInformationSelector::new(1);
        let result = Fit::<f64>::fit(&selector, &x, &y);

        assert!(result.is_err());
        match result.unwrap_err() {
            RustMlError::ShapeMismatch(msg) => {
                assert!(msg.contains("samples"), "unexpected message: {}", msg);
            }
            other => panic!("expected ShapeMismatch, got {:?}", other),
        }
    }

    #[test]
    fn test_error_on_empty_input() {
        let x = Array2::<f64>::zeros((0, 3));
        let y = Array1::<f64>::zeros(0);

        let selector = MutualInformationSelector::new(1);
        let result = Fit::<f64>::fit(&selector, &x, &y);

        assert!(result.is_err());
    }

    #[test]
    fn test_error_n_features_to_select_exceeds_n_features() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let selector = MutualInformationSelector::new(5); // only 2 features
        let result = Fit::<f64>::fit(&selector, &x, &y);

        assert!(result.is_err());
        match result.unwrap_err() {
            RustMlError::InvalidParameter(msg) => {
                assert!(
                    msg.contains("n_features_to_select"),
                    "unexpected message: {}",
                    msg
                );
            }
            other => panic!("expected InvalidParameter, got {:?}", other),
        }
    }

    #[test]
    fn test_shape_mismatch_on_transform() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let y = array![0.0, 1.0];

        let selector = MutualInformationSelector::new(1);
        let fitted = Fit::<f64>::fit(&selector, &x, &y).unwrap();

        let wrong = array![[1.0, 2.0]]; // 2 cols instead of 3
        assert!(fitted.transform(&wrong).is_err());
    }

    #[test]
    fn test_works_with_f32() {
        let x: Array2<f32> = array![[0.0_f32, 0.5], [0.0, 0.8], [1.0, 0.3], [1.0, 0.7],];
        let y: Array1<f32> = array![0.0_f32, 0.0, 1.0, 1.0];

        let selector = MutualInformationSelector::new(1).with_n_bins(2);
        let fitted = Fit::<f32>::fit(&selector, &x, &y).unwrap();

        assert_eq!(fitted.selected_indices().len(), 1);
        let result = fitted.transform(&x).unwrap();
        assert_eq!(result.ncols(), 1);
    }

    #[test]
    fn test_multiclass_labels() {
        // Feature 0 has 3 bins matching 3 classes; feature 1 is constant.
        let x = array![
            [0.0, 5.0],
            [0.0, 5.0],
            [0.5, 5.0],
            [0.5, 5.0],
            [1.0, 5.0],
            [1.0, 5.0],
        ];
        let y = array![0.0, 0.0, 1.0, 1.0, 2.0, 2.0];

        let selector = MutualInformationSelector::new(1).with_n_bins(3);
        let fitted = Fit::<f64>::fit(&selector, &x, &y).unwrap();

        // Feature 0 should be selected (feature 1 has zero MI since it's constant).
        assert_eq!(fitted.selected_indices(), &[0]);
    }
}
461}