/// Accumulates a gain value per feature and derives importance scores from
/// those running totals.
///
/// Features are addressed by their index in `0..n_features`.
#[derive(Debug, Clone)]
pub struct FeatureImportance {
    /// Accumulated gain per feature, indexed by feature position.
    gains: Vec<f64>,
}

impl FeatureImportance {
    /// Creates a tracker for `n_features` features, each starting at `0.0`.
    pub fn new(n_features: usize) -> Self {
        Self {
            gains: vec![0.0; n_features],
        }
    }

    /// Adds `gain` to the running total of the feature at `feature_idx`.
    ///
    /// # Panics
    ///
    /// Panics if `feature_idx` is not smaller than [`Self::n_features`].
    pub fn update(&mut self, feature_idx: usize, gain: f64) {
        self.gains[feature_idx] += gain;
    }

    /// Returns the accumulated gains, one entry per feature, in index order.
    pub fn importances(&self) -> &[f64] {
        &self.gains
    }

    /// Returns every gain divided by the total gain.
    ///
    /// When the total is exactly `0.0` a vector of zeros is returned instead
    /// of dividing by zero. A NaN total is not special-cased and therefore
    /// produces NaN entries.
    pub fn normalized(&self) -> Vec<f64> {
        let total = self.total_gain();
        if total == 0.0 {
            return vec![0.0; self.gains.len()];
        }
        self.gains.iter().map(|&g| g / total).collect()
    }

    /// Returns up to `k` `(feature_index, gain)` pairs, largest gain first.
    ///
    /// Features with equal gains keep their index order because the sort is
    /// stable. NaN gains are ranked after every numeric gain, so they can
    /// never displace a real value from the result. If `k` exceeds the number
    /// of features, all features are returned.
    pub fn top_k(&self, k: usize) -> Vec<(usize, f64)> {
        use std::cmp::Ordering;

        let mut ranked: Vec<(usize, f64)> = self.gains.iter().copied().enumerate().collect();
        // `partial_cmp` alone is not a total order when NaN is present, which
        // leaves the sort result unspecified (and may make `sort_by` panic).
        // Handling NaN explicitly keeps the comparator a consistent ordering.
        ranked.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither side is NaN here, so `partial_cmp` always yields `Some`.
            (false, false) => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
        });
        ranked.truncate(k);
        ranked
    }

    /// Returns the number of features being tracked.
    pub fn n_features(&self) -> usize {
        self.gains.len()
    }

    /// Returns the sum of all accumulated gains.
    pub fn total_gain(&self) -> f64 {
        self.gains.iter().sum()
    }

    /// Sets every accumulated gain back to `0.0`, keeping the feature count.
    pub fn reset(&mut self) {
        self.gains.fill(0.0);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point values.
    const EPS: f64 = 1e-10;

    /// Returns `true` when `a` and `b` differ by less than `EPS`.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < EPS
    }

    /// A freshly created tracker reports its size and holds only zeros.
    #[test]
    fn empty_state() {
        let tracker = FeatureImportance::new(5);
        assert_eq!(tracker.n_features(), 5);
        assert!(approx_eq(tracker.total_gain(), 0.0));
        assert!(tracker.importances().iter().all(|&gain| gain == 0.0));
    }

    /// One update changes only the targeted feature.
    #[test]
    fn single_update() {
        let mut tracker = FeatureImportance::new(3);
        tracker.update(1, 2.5);

        let gains = tracker.importances();
        assert!(approx_eq(gains[0], 0.0));
        assert!(approx_eq(gains[1], 2.5));
        assert!(approx_eq(gains[2], 0.0));
        assert!(approx_eq(tracker.total_gain(), 2.5));
    }

    /// Repeated updates to the same feature add up.
    #[test]
    fn multiple_updates_accumulate() {
        let mut tracker = FeatureImportance::new(3);
        tracker.update(0, 1.0);
        tracker.update(0, 0.5);
        tracker.update(2, 3.0);

        let gains = tracker.importances();
        assert!(approx_eq(gains[0], 1.5));
        assert!(approx_eq(gains[1], 0.0));
        assert!(approx_eq(gains[2], 3.0));
        assert!(approx_eq(tracker.total_gain(), 4.5));
    }

    /// Normalized values are the per-feature shares and sum to one.
    #[test]
    fn normalized_sums_to_one() {
        let mut tracker = FeatureImportance::new(4);
        tracker.update(0, 2.0);
        tracker.update(1, 3.0);
        tracker.update(2, 1.0);
        tracker.update(3, 4.0);

        let shares = tracker.normalized();
        let share_total: f64 = shares.iter().sum();
        assert!(approx_eq(share_total, 1.0));
        assert!(approx_eq(shares[0], 0.2));
        assert!(approx_eq(shares[1], 0.3));
        assert!(approx_eq(shares[2], 0.1));
        assert!(approx_eq(shares[3], 0.4));
    }

    /// With no gain recorded, normalization yields zeros rather than NaN.
    #[test]
    fn normalized_zero_total_returns_zeros() {
        let tracker = FeatureImportance::new(3);
        let shares = tracker.normalized();
        assert_eq!(shares.len(), 3);
        assert!(shares.iter().all(|&share| share == 0.0));
    }

    /// `top_k` returns the largest gains first.
    #[test]
    fn top_k_ordering() {
        let mut tracker = FeatureImportance::new(5);
        tracker.update(0, 1.0);
        tracker.update(1, 5.0);
        tracker.update(2, 3.0);
        tracker.update(3, 0.5);
        tracker.update(4, 4.0);

        let best = tracker.top_k(3);
        assert_eq!(best.len(), 3);
        assert_eq!(best[0].0, 1);
        assert_eq!(best[1].0, 4);
        assert_eq!(best[2].0, 2);
        assert!(approx_eq(best[0].1, 5.0));
        assert!(approx_eq(best[1].1, 4.0));
        assert!(approx_eq(best[2].1, 3.0));
    }

    /// Asking for more entries than features returns every feature.
    #[test]
    fn top_k_exceeds_n_features() {
        let mut tracker = FeatureImportance::new(2);
        tracker.update(0, 1.0);
        tracker.update(1, 2.0);

        let best = tracker.top_k(10);
        assert_eq!(best.len(), 2);
        assert_eq!(best[0].0, 1);
        assert_eq!(best[1].0, 0);
    }

    /// `reset` zeroes the gains but keeps the feature count.
    #[test]
    fn reset_clears_all_gains() {
        let mut tracker = FeatureImportance::new(3);
        tracker.update(0, 5.0);
        tracker.update(1, 3.0);
        tracker.update(2, 1.0);
        tracker.reset();

        assert!(approx_eq(tracker.total_gain(), 0.0));
        assert!(tracker.importances().iter().all(|&gain| gain == 0.0));
        assert_eq!(tracker.n_features(), 3);
    }

    /// `top_k` on an untouched tracker still returns `k` zero-gain entries.
    #[test]
    fn top_k_empty_gains() {
        let tracker = FeatureImportance::new(4);
        let best = tracker.top_k(2);
        assert_eq!(best.len(), 2);
        assert!(best.iter().all(|(_, gain)| *gain == 0.0));
    }

    /// Requesting zero entries yields an empty result.
    #[test]
    fn top_k_zero_k() {
        let mut tracker = FeatureImportance::new(3);
        tracker.update(0, 1.0);
        let best = tracker.top_k(0);
        assert!(best.is_empty());
    }
}