/// Online feature selector that keeps the highest-ranked `keep_fraction` of features,
/// ranked by an exponentially weighted moving average (EWMA) of importance scores
/// supplied by the caller.
///
/// Every feature starts active. The mask is first recomputed by the update that brings
/// the update count to `warmup` (the very first update when `warmup` is 0), and then on
/// every subsequent update.
#[derive(Clone, Debug)]
pub struct OnlineFeatureSelector {
    /// Per-feature EWMA of the importances passed to `update_importances`; starts at 0.0.
    importance_ewma: Vec<f64>,
    /// Weight given to the newest importance vector in the EWMA.
    alpha: f64,
    /// Fraction of features kept active once masking has started.
    keep_fraction: f64,
    /// Update count at which masking starts.
    warmup: u64,
    /// `true` for each currently active feature; same length as `importance_ewma`.
    mask: Vec<bool>,
    /// Number of `update_importances` calls since construction or the last `reset`.
    n_updates: u64,
}

impl OnlineFeatureSelector {
    /// Creates a selector for `n_features` features, all initially active.
    ///
    /// * `keep_fraction` — fraction of features to keep once masking starts, clamped to
    ///   `[f64::MIN_POSITIVE, 1.0]`. The kept count is rounded up and is always at least 1
    ///   for a non-empty selector.
    /// * `alpha` — EWMA weight of each new importance vector, clamped to
    ///   `[f64::MIN_POSITIVE, 1.0]`; `1.0` means only the latest importances count.
    /// * `warmup` — number of updates before masking starts.
    ///
    /// NaN arguments are not rejected: `f64::clamp` passes NaN through unchanged.
    pub fn new(n_features: usize, keep_fraction: f64, alpha: f64, warmup: u64) -> Self {
        let keep_fraction = keep_fraction.clamp(f64::MIN_POSITIVE, 1.0);
        let alpha = alpha.clamp(f64::MIN_POSITIVE, 1.0);
        Self {
            importance_ewma: vec![0.0; n_features],
            alpha,
            keep_fraction,
            warmup,
            mask: vec![true; n_features],
            n_updates: 0,
        }
    }

    /// Folds one importance vector into the per-feature EWMA and, once warmup has been
    /// reached, recomputes the active mask.
    ///
    /// Each estimate becomes `alpha * importance + (1 - alpha) * previous`.
    ///
    /// # Panics
    ///
    /// Panics if `importances.len()` differs from [`n_features`](Self::n_features).
    pub fn update_importances(&mut self, importances: &[f64]) {
        self.assert_len(importances.len(), "importances");
        for (ewma, &imp) in self.importance_ewma.iter_mut().zip(importances) {
            *ewma = self.alpha * imp + (1.0 - self.alpha) * *ewma;
        }
        self.n_updates = self.n_updates.saturating_add(1);
        if self.n_updates >= self.warmup {
            self.recompute_mask();
        }
    }

    /// Returns a copy of `features` with every inactive feature replaced by `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `features.len()` differs from [`n_features`](Self::n_features).
    pub fn mask_features(&self, features: &[f64]) -> Vec<f64> {
        self.assert_len(features.len(), "features");
        self.mask
            .iter()
            .zip(features)
            .map(|(&keep, &f)| if keep { f } else { 0.0 })
            .collect()
    }

    /// Zeroes every inactive feature of `features` in place; active ones are untouched.
    ///
    /// # Panics
    ///
    /// Panics if `features.len()` differs from [`n_features`](Self::n_features).
    pub fn mask_features_in_place(&self, features: &mut [f64]) {
        self.assert_len(features.len(), "features");
        for (f, &keep) in features.iter_mut().zip(&self.mask) {
            if !keep {
                *f = 0.0;
            }
        }
    }

    /// Current mask: `true` marks an active feature.
    pub fn active_features(&self) -> &[bool] {
        &self.mask
    }

    /// Number of currently active features.
    pub fn active_count(&self) -> usize {
        self.mask.iter().filter(|&&b| b).count()
    }

    /// Number of features this selector was created for.
    pub fn n_features(&self) -> usize {
        self.importance_ewma.len()
    }

    /// Clears the learned importances, reactivates every feature and restarts the
    /// warmup count. `alpha`, `keep_fraction` and `warmup` are kept.
    pub fn reset(&mut self) {
        self.importance_ewma.fill(0.0);
        self.mask.fill(true);
        self.n_updates = 0;
    }

    /// Panics with a uniform message unless `got` equals the configured feature count.
    /// `what` names the offending argument in the message.
    #[track_caller]
    fn assert_len(&self, got: usize, what: &str) {
        let n = self.importance_ewma.len();
        assert_eq!(
            got,
            n,
            "OnlineFeatureSelector: expected {} {}, got {}",
            n,
            what,
            got
        );
    }

    /// Marks exactly `k = clamp(ceil(n * keep_fraction), 1, n)` features active: those
    /// with the largest EWMA importance, with ties going to the lower index.
    ///
    /// Indices are ranked directly rather than compared against a threshold value. This
    /// keeps the selection exact for tiny differences and infinities. A NaN estimate
    /// ranks below every real number.
    fn recompute_mask(&mut self) {
        let n = self.importance_ewma.len();
        if n == 0 {
            return;
        }
        // Round up so any positive fraction keeps at least one feature; the clamp also
        // catches a NaN keep_fraction, which `as usize` turns into 0.
        let k = ((n as f64 * self.keep_fraction).ceil() as usize).clamp(1, n);

        // With NaN mapped to -inf, `partial_cmp` always returns `Some`, so the comparator
        // is a total order and the sort cannot panic or produce an arbitrary order.
        let rank = |v: f64| if v.is_nan() { f64::NEG_INFINITY } else { v };
        let scores = &self.importance_ewma;
        let mut order: Vec<usize> = (0..n).collect();
        // `sort_by` is stable: equal scores keep ascending index order.
        order.sort_by(|&a, &b| {
            rank(scores[b])
                .partial_cmp(&rank(scores[a]))
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        self.mask.fill(false);
        for &i in &order[..k] {
            self.mask[i] = true;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn all_active_during_warmup() {
        let mut sel = OnlineFeatureSelector::new(4, 0.25, 0.5, 10);
        for _ in 0..9 {
            sel.update_importances(&[1.0, 0.0, 0.0, 0.0]);
        }
        assert_eq!(sel.active_count(), 4);
        assert!(sel.active_features().iter().all(|&b| b));
    }

    #[test]
    fn masking_activates_after_warmup() {
        let mut sel = OnlineFeatureSelector::new(4, 0.5, 0.9, 3);
        for _ in 0..3 {
            sel.update_importances(&[0.9, 0.1, 0.8, 0.05]);
        }
        assert_eq!(sel.active_count(), 2);
        assert!(sel.active_features()[0]);
        assert!(!sel.active_features()[1]);
        assert!(sel.active_features()[2]);
        assert!(!sel.active_features()[3]);
    }

    #[test]
    fn keep_fraction_determines_active_count() {
        let mut sel = OnlineFeatureSelector::new(8, 0.75, 0.9, 1);
        sel.update_importances(&[0.8, 0.7, 0.1, 0.05, 0.6, 0.5, 0.9, 0.4]);
        assert_eq!(sel.active_count(), 6);
    }

    #[test]
    fn mask_features_zeros_inactive() {
        let mut sel = OnlineFeatureSelector::new(4, 0.5, 0.9, 1);
        sel.update_importances(&[1.0, 0.0, 1.0, 0.0]);
        let masked = sel.mask_features(&[10.0, 20.0, 30.0, 40.0]);
        assert_eq!(masked[0], 10.0);
        assert_eq!(masked[1], 0.0);
        assert_eq!(masked[2], 30.0);
        assert_eq!(masked[3], 0.0);
    }

    #[test]
    fn importance_ewma_update_smooths_correctly() {
        let mut sel = OnlineFeatureSelector::new(1, 1.0, 0.5, 100);
        sel.update_importances(&[1.0]);
        assert!((sel.importance_ewma[0] - 0.5).abs() < 1e-12);
        sel.update_importances(&[1.0]);
        assert!((sel.importance_ewma[0] - 0.75).abs() < 1e-12);
        sel.update_importances(&[0.0]);
        assert!((sel.importance_ewma[0] - 0.375).abs() < 1e-12);
    }

    #[test]
    fn full_keep_no_mask() {
        let mut sel = OnlineFeatureSelector::new(5, 1.0, 0.5, 1);
        sel.update_importances(&[0.1, 0.2, 0.3, 0.4, 0.5]);
        assert_eq!(sel.active_count(), 5);
        assert!(sel.active_features().iter().all(|&b| b));
    }

    #[test]
    fn active_count_matches_mask() {
        let mut sel = OnlineFeatureSelector::new(6, 0.5, 0.9, 1);
        sel.update_importances(&[0.9, 0.1, 0.8, 0.05, 0.7, 0.02]);
        let count = sel.active_count();
        let mask_count = sel.active_features().iter().filter(|&&b| b).count();
        assert_eq!(count, mask_count);
        assert_eq!(count, 3);
    }

    #[test]
    fn reset_restores_initial_state() {
        let mut sel = OnlineFeatureSelector::new(4, 0.5, 0.9, 2);
        sel.update_importances(&[1.0, 0.0, 1.0, 0.0]);
        sel.update_importances(&[1.0, 0.0, 1.0, 0.0]);
        assert_eq!(sel.active_count(), 2);
        sel.reset();
        assert_eq!(sel.active_count(), 4);
        assert!(sel.active_features().iter().all(|&b| b));
        assert!(sel.importance_ewma.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn reset_restarts_warmup() {
        let mut sel = OnlineFeatureSelector::new(4, 0.5, 0.9, 2);
        sel.update_importances(&[1.0, 0.0, 1.0, 0.0]);
        sel.update_importances(&[1.0, 0.0, 1.0, 0.0]);
        sel.reset();
        // One update after reset is still inside the warmup window.
        sel.update_importances(&[1.0, 0.0, 1.0, 0.0]);
        assert_eq!(sel.active_count(), 4);
        sel.update_importances(&[1.0, 0.0, 1.0, 0.0]);
        assert_eq!(sel.active_count(), 2);
    }

    #[test]
    fn mask_features_in_place_matches_mask_features() {
        let mut sel = OnlineFeatureSelector::new(4, 0.5, 0.9, 1);
        sel.update_importances(&[0.9, 0.1, 0.8, 0.05]);
        let input = [10.0, 20.0, 30.0, 40.0];
        let alloc = sel.mask_features(&input);
        let mut inplace = input;
        sel.mask_features_in_place(&mut inplace);
        for (a, b) in alloc.iter().zip(inplace.iter()) {
            assert!((a - b).abs() < 1e-12, "mismatch: {} vs {}", a, b);
        }
    }

    #[test]
    fn ties_are_broken_by_lowest_index() {
        let mut sel = OnlineFeatureSelector::new(4, 0.5, 0.9, 1);
        sel.update_importances(&[1.0, 1.0, 1.0, 1.0]);
        assert_eq!(sel.active_features(), &[true, true, false, false]);
    }

    #[test]
    fn tiny_keep_fraction_keeps_at_least_one() {
        let mut sel = OnlineFeatureSelector::new(5, 1e-9, 1.0, 1);
        sel.update_importances(&[0.1, 0.9, 0.3, 0.2, 0.4]);
        assert_eq!(sel.active_count(), 1);
        assert!(sel.active_features()[1]);
    }

    #[test]
    fn zero_features_is_a_no_op() {
        let mut sel = OnlineFeatureSelector::new(0, 0.5, 0.5, 0);
        sel.update_importances(&[]);
        assert_eq!(sel.n_features(), 0);
        assert_eq!(sel.active_count(), 0);
        assert!(sel.mask_features(&[]).is_empty());
    }

    #[test]
    #[should_panic(expected = "expected 4 importances, got 2")]
    fn panics_on_importance_length_mismatch() {
        let mut sel = OnlineFeatureSelector::new(4, 0.5, 0.5, 1);
        sel.update_importances(&[1.0, 2.0]);
    }

    #[test]
    #[should_panic(expected = "expected 4 features, got 3")]
    fn panics_on_mask_features_length_mismatch() {
        let sel = OnlineFeatureSelector::new(4, 0.5, 0.5, 1);
        let _ = sel.mask_features(&[1.0, 2.0, 3.0]);
    }

    #[test]
    #[should_panic(expected = "expected 4 features, got 5")]
    fn panics_on_mask_features_in_place_length_mismatch() {
        let sel = OnlineFeatureSelector::new(4, 0.5, 0.5, 1);
        let mut features = [1.0, 2.0, 3.0, 4.0, 5.0];
        sel.mask_features_in_place(&mut features);
    }
}