1use anofox_ml_core::{Fit, Float, Result, RustMlError, Transform};
2use ndarray::{Array1, Array2};
3use std::collections::HashMap;
4
5#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
38pub struct MutualInformationSelector {
39 pub n_features_to_select: usize,
41 pub n_bins: usize,
43}
44
45impl MutualInformationSelector {
46 pub fn new(n_features_to_select: usize) -> Self {
48 Self {
49 n_features_to_select,
50 n_bins: 10,
51 }
52 }
53
54 pub fn with_n_bins(mut self, n_bins: usize) -> Self {
56 self.n_bins = n_bins;
57 self
58 }
59}
60
61#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
64#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
65pub struct FittedMutualInformationSelector<F: Float> {
66 mi_scores: Array1<F>,
68 selected_indices: Vec<usize>,
70 n_features_in: usize,
72}
73
74impl<F: Float> FittedMutualInformationSelector<F> {
75 pub fn mi_scores(&self) -> &Array1<F> {
77 &self.mi_scores
78 }
79
80 pub fn selected_indices(&self) -> &[usize] {
82 &self.selected_indices
83 }
84}
85
86fn discretize<F: Float>(values: &[F], n_bins: usize) -> Vec<usize> {
91 let mut min_val = values[0];
92 let mut max_val = values[0];
93 for &v in values.iter().skip(1) {
94 if v < min_val {
95 min_val = v;
96 }
97 if v > max_val {
98 max_val = v;
99 }
100 }
101
102 let range = max_val - min_val;
103 let eps = F::from_f64(1e-15).unwrap();
104
105 if range < eps {
106 return vec![0; values.len()];
108 }
109
110 let n_bins_f = F::from_usize(n_bins).unwrap();
111 let max_bin = n_bins - 1;
112
113 values
114 .iter()
115 .map(|&v| {
116 let normalized = (v - min_val) / range; let bin = (normalized * n_bins_f).to_usize().unwrap_or(max_bin);
118 bin.min(max_bin)
119 })
120 .collect()
121}
122
123fn mutual_information_discrete<F: Float>(x_bins: &[usize], y_labels: &[usize]) -> F {
130 let n = x_bins.len();
131 let n_f = F::from_usize(n).unwrap();
132
133 let mut joint: HashMap<(usize, usize), usize> = HashMap::new();
135 let mut x_counts: HashMap<usize, usize> = HashMap::new();
136 let mut y_counts: HashMap<usize, usize> = HashMap::new();
137
138 for (&xb, &yb) in x_bins.iter().zip(y_labels.iter()) {
139 *joint.entry((xb, yb)).or_insert(0) += 1;
140 *x_counts.entry(xb).or_insert(0) += 1;
141 *y_counts.entry(yb).or_insert(0) += 1;
142 }
143
144 let mut mi = F::zero();
145 for (&(xb, yb), &count) in &joint {
146 if count == 0 {
147 continue;
148 }
149 let p_xy = F::from_usize(count).unwrap() / n_f;
150 let p_x = F::from_usize(x_counts[&xb]).unwrap() / n_f;
151 let p_y = F::from_usize(y_counts[&yb]).unwrap() / n_f;
152
153 let ratio = p_xy / (p_x * p_y);
154 mi += p_xy * ratio.ln();
155 }
156
157 if mi < F::zero() {
159 F::zero()
160 } else {
161 mi
162 }
163}
164
165fn labels_to_indices<F: Float>(y: &Array1<F>) -> Vec<usize> {
170 let mut label_map: HashMap<u64, usize> = HashMap::new();
171 let mut indices = Vec::with_capacity(y.len());
172
173 for &val in y.iter() {
174 let bits = val.to_f64().unwrap().to_bits();
176 let next_id = label_map.len();
177 let id = *label_map.entry(bits).or_insert(next_id);
178 indices.push(id);
179 }
180
181 indices
182}
183
184impl<F: Float> Fit<F> for MutualInformationSelector {
185 type Fitted = FittedMutualInformationSelector<F>;
186
187 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Self::Fitted> {
188 let (n_samples, n_features) = x.dim();
189
190 if n_samples == 0 || n_features == 0 {
191 return Err(RustMlError::EmptyInput("input array is empty".into()));
192 }
193
194 if y.len() != n_samples {
195 return Err(RustMlError::ShapeMismatch(format!(
196 "X has {} samples but y has {} elements",
197 n_samples,
198 y.len()
199 )));
200 }
201
202 if self.n_features_to_select == 0 {
203 return Err(RustMlError::InvalidParameter(
204 "n_features_to_select must be at least 1".into(),
205 ));
206 }
207
208 if self.n_features_to_select > n_features {
209 return Err(RustMlError::InvalidParameter(format!(
210 "n_features_to_select ({}) exceeds number of features ({})",
211 self.n_features_to_select, n_features
212 )));
213 }
214
215 if self.n_bins == 0 {
216 return Err(RustMlError::InvalidParameter(
217 "n_bins must be at least 1".into(),
218 ));
219 }
220
221 let y_indices = labels_to_indices(y);
223
224 let mut mi_scores = Array1::<F>::zeros(n_features);
226 for j in 0..n_features {
227 let col: Vec<F> = x.column(j).to_vec();
228 let x_bins = discretize(&col, self.n_bins);
229 mi_scores[j] = mutual_information_discrete::<F>(&x_bins, &y_indices);
230 }
231
232 let mut feature_scores: Vec<(usize, F)> = mi_scores.iter().copied().enumerate().collect();
234 feature_scores.sort_by(|a, b| {
236 b.1.partial_cmp(&a.1)
237 .unwrap_or(std::cmp::Ordering::Equal)
238 .then(a.0.cmp(&b.0))
239 });
240
241 let mut selected_indices: Vec<usize> = feature_scores
242 .iter()
243 .take(self.n_features_to_select)
244 .map(|&(idx, _)| idx)
245 .collect();
246 selected_indices.sort_unstable();
248
249 Ok(FittedMutualInformationSelector {
250 mi_scores,
251 selected_indices,
252 n_features_in: n_features,
253 })
254 }
255}
256
257impl<F: Float> Transform<F> for FittedMutualInformationSelector<F> {
258 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
259 if x.ncols() != self.n_features_in {
260 return Err(RustMlError::ShapeMismatch(format!(
261 "expected {} features, got {}",
262 self.n_features_in,
263 x.ncols()
264 )));
265 }
266
267 let n_rows = x.nrows();
268 let n_selected = self.selected_indices.len();
269 let mut result = Array2::<F>::zeros((n_rows, n_selected));
270
271 for (i, row) in x.rows().into_iter().enumerate() {
272 for (out_j, &src_j) in self.selected_indices.iter().enumerate() {
273 result[[i, out_j]] = row[src_j];
274 }
275 }
276
277 Ok(result)
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_selects_informative_feature_over_noise() {
        // Feature 0 separates the classes perfectly; feature 1 is noise.
        let x = array![
            [0.0, 0.5],
            [0.0, 0.8],
            [0.0, 0.2],
            [0.0, 0.9],
            [1.0, 0.3],
            [1.0, 0.7],
            [1.0, 0.1],
            [1.0, 0.6],
        ];
        let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0];

        let selector = MutualInformationSelector::new(1).with_n_bins(2);
        let fitted = Fit::<f64>::fit(&selector, &x, &y).unwrap();

        assert_eq!(fitted.selected_indices(), &[0]);

        assert!(
            fitted.mi_scores()[0] > fitted.mi_scores()[1],
            "informative feature MI ({}) should be > noise MI ({})",
            fitted.mi_scores()[0],
            fitted.mi_scores()[1]
        );
    }

    #[test]
    fn test_scores_are_non_negative() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];
        let y = array![0.0, 1.0, 0.0, 1.0];

        let selector = MutualInformationSelector::new(2);
        let fitted = Fit::<f64>::fit(&selector, &x, &y).unwrap();

        for (i, &score) in fitted.mi_scores().iter().enumerate() {
            assert!(
                score >= 0.0,
                "MI score for feature {} is negative: {}",
                i,
                score
            );
        }
    }

    #[test]
    fn test_transform_outputs_correct_shape() {
        let x = array![
            [1.0, 2.0, 3.0, 4.0],
            [5.0, 6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0, 12.0],
        ];
        let y = array![0.0, 1.0, 0.0];

        let selector = MutualInformationSelector::new(2);
        let fitted = Fit::<f64>::fit(&selector, &x, &y).unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_eq!(result.nrows(), 3);
        assert_eq!(result.ncols(), 2);
    }

    #[test]
    fn test_transform_returns_selected_column_values() {
        // Column 0 is independent of y (MI 0), column 1 is constant (MI 0),
        // and column 2 equals the class (MI ln 2), so only column 2 is kept.
        let x = array![
            [0.0, 5.0, 0.0],
            [1.0, 5.0, 0.0],
            [0.0, 5.0, 1.0],
            [1.0, 5.0, 1.0],
        ];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let selector = MutualInformationSelector::new(1).with_n_bins(2);
        let fitted = Fit::<f64>::fit(&selector, &x, &y).unwrap();

        assert_eq!(fitted.selected_indices(), &[2]);
        assert!((fitted.mi_scores()[2] - 2.0_f64.ln()).abs() < 1e-12);

        let result = fitted.transform(&x).unwrap();
        assert_eq!(result, array![[0.0], [0.0], [1.0], [1.0]]);
    }

    #[test]
    fn test_constant_feature_has_zero_mi() {
        let x = array![[3.0, 0.0], [3.0, 1.0], [3.0, 0.0], [3.0, 1.0]];
        let y = array![0.0, 1.0, 0.0, 1.0];

        let selector = MutualInformationSelector::new(1);
        let fitted = Fit::<f64>::fit(&selector, &x, &y).unwrap();

        assert_eq!(fitted.mi_scores()[0], 0.0);
        assert_eq!(fitted.selected_indices(), &[1]);
    }

    #[test]
    fn test_selects_all_when_k_equals_n_features() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![0.0, 1.0, 0.0];

        let selector = MutualInformationSelector::new(2);
        let fitted = Fit::<f64>::fit(&selector, &x, &y).unwrap();

        assert_eq!(fitted.selected_indices(), &[0, 1]);
    }

    #[test]
    fn test_shape_mismatch_x_y() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0, 2.0];
        let selector = MutualInformationSelector::new(1);
        let result = Fit::<f64>::fit(&selector, &x, &y);

        assert!(result.is_err());
        match result.unwrap_err() {
            RustMlError::ShapeMismatch(msg) => {
                assert!(msg.contains("samples"), "unexpected message: {}", msg);
            }
            other => panic!("expected ShapeMismatch, got {:?}", other),
        }
    }

    #[test]
    fn test_error_on_empty_input() {
        let x = Array2::<f64>::zeros((0, 3));
        let y = Array1::<f64>::zeros(0);

        let selector = MutualInformationSelector::new(1);
        let result = Fit::<f64>::fit(&selector, &x, &y);

        assert!(result.is_err());
    }

    #[test]
    fn test_error_n_features_to_select_exceeds_n_features() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let selector = MutualInformationSelector::new(5);
        let result = Fit::<f64>::fit(&selector, &x, &y);

        assert!(result.is_err());
        match result.unwrap_err() {
            RustMlError::InvalidParameter(msg) => {
                assert!(
                    msg.contains("n_features_to_select"),
                    "unexpected message: {}",
                    msg
                );
            }
            other => panic!("expected InvalidParameter, got {:?}", other),
        }
    }

    #[test]
    fn test_error_n_features_to_select_zero() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let selector = MutualInformationSelector::new(0);
        match Fit::<f64>::fit(&selector, &x, &y) {
            Err(RustMlError::InvalidParameter(msg)) => {
                assert!(
                    msg.contains("n_features_to_select"),
                    "unexpected message: {}",
                    msg
                );
            }
            Err(other) => panic!("expected InvalidParameter, got {:?}", other),
            Ok(_) => panic!("expected an error for n_features_to_select == 0"),
        }
    }

    #[test]
    fn test_error_n_bins_zero() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let selector = MutualInformationSelector::new(1).with_n_bins(0);
        match Fit::<f64>::fit(&selector, &x, &y) {
            Err(RustMlError::InvalidParameter(msg)) => {
                assert!(msg.contains("n_bins"), "unexpected message: {}", msg);
            }
            Err(other) => panic!("expected InvalidParameter, got {:?}", other),
            Ok(_) => panic!("expected an error for n_bins == 0"),
        }
    }

    #[test]
    fn test_shape_mismatch_on_transform() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let y = array![0.0, 1.0];

        let selector = MutualInformationSelector::new(1);
        let fitted = Fit::<f64>::fit(&selector, &x, &y).unwrap();

        let wrong = array![[1.0, 2.0]];
        assert!(fitted.transform(&wrong).is_err());
    }

    #[test]
    fn test_works_with_f32() {
        let x: Array2<f32> = array![[0.0_f32, 0.5], [0.0, 0.8], [1.0, 0.3], [1.0, 0.7]];
        let y: Array1<f32> = array![0.0_f32, 0.0, 1.0, 1.0];

        let selector = MutualInformationSelector::new(1).with_n_bins(2);
        let fitted = Fit::<f32>::fit(&selector, &x, &y).unwrap();

        assert_eq!(fitted.selected_indices().len(), 1);
        let result = fitted.transform(&x).unwrap();
        assert_eq!(result.ncols(), 1);
    }

    #[test]
    fn test_multiclass_labels() {
        // Feature 0 takes one value per class; feature 1 is constant.
        let x = array![
            [0.0, 5.0],
            [0.0, 5.0],
            [0.5, 5.0],
            [0.5, 5.0],
            [1.0, 5.0],
            [1.0, 5.0],
        ];
        let y = array![0.0, 0.0, 1.0, 1.0, 2.0, 2.0];

        let selector = MutualInformationSelector::new(1).with_n_bins(3);
        let fitted = Fit::<f64>::fit(&selector, &x, &y).unwrap();

        assert_eq!(fitted.selected_indices(), &[0]);
    }
}