use std::collections::HashMap;
use crate::approximation::sfa::{Sfa, SfaConfig, SfaFitted};
use crate::core::config::BinStrategy;
use crate::core::traits::FittableTransformer;
/// Configuration for the WEASEL transformation.
///
/// Most fields are forwarded unchanged to the per-window SFA models that
/// `Weasel::fit` trains; the rest control window extraction and the final
/// chi-squared feature selection.
#[derive(Debug, Clone)]
pub struct WeaselConfig {
/// Forwarded to SFA as `n_coefs` (wrapped in `Some`).
pub word_size: usize,
/// Forwarded to SFA as `n_bins`; `WeaselConfig::new` defaults this to 4.
pub n_bins: usize,
/// Sliding-window lengths. One SFA model is fitted per size; sizes longer
/// than the first training series are skipped during fitting.
pub window_sizes: Vec<usize>,
/// Offset between the start positions of consecutive windows.
pub window_step: usize,
/// Binning strategy forwarded to SFA.
pub strategy: BinStrategy,
/// Forwarded to SFA as `norm_mean`.
pub norm_mean: bool,
/// Forwarded to SFA as `norm_std`.
pub norm_std: bool,
/// Forwarded to SFA as `drop_sum`.
pub drop_sum: bool,
/// Forwarded to SFA as `anova`.
pub anova: bool,
/// Minimum chi-squared score a word needs in order to be kept as a feature
/// (the comparison is `score >= chi2_threshold`).
pub chi2_threshold: f64,
}
impl WeaselConfig {
pub fn new(word_size: usize, window_sizes: Vec<usize>) -> Self {
Self {
word_size,
n_bins: 4,
window_sizes,
window_step: 1,
strategy: BinStrategy::Quantile,
norm_mean: true,
norm_std: true,
drop_sum: false,
anova: true,
chi2_threshold: 2.0,
}
}
}
/// Result of `Weasel::fit`: everything needed to transform new series.
#[derive(Debug, Clone)]
pub struct WeaselFitted {
/// One `(window_size, fitted SFA model)` pair per window size that was
/// actually used during fitting.
pub sfa_models: Vec<(usize, SfaFitted)>,
/// Copy of the configuration the model was fitted with.
pub config: WeaselConfig,
/// Words (formatted as `"{window_size}_{letters}"`) whose chi-squared score
/// reached the configured threshold; `Weasel::transform` keeps only these.
pub selected_features: Vec<String>,
}
/// Stateless entry point for fitting and applying the WEASEL transformation.
pub struct Weasel;

impl Weasel {
    /// Fits one SFA model per usable window size and selects the words that
    /// pass the chi-squared threshold.
    ///
    /// The length of the first series is taken as the series length for the
    /// whole batch. Window sizes larger than that length are skipped.
    ///
    /// # Panics
    ///
    /// Panics if `x` is empty or if `x` and `y` have different lengths. Window
    /// extraction also panics if `config.window_step` is zero while at least
    /// one window size is usable.
    pub fn fit(config: &WeaselConfig, x: &[Vec<f64>], y: &[String]) -> WeaselFitted {
        assert!(!x.is_empty(), "Input must have at least one sample");
        assert_eq!(x.len(), y.len(), "X and y must have same length");
        let n_timestamps = x[0].len();
        let mut sfa_models = Vec::with_capacity(config.window_sizes.len());
        for &ws in &config.window_sizes {
            if ws > n_timestamps {
                continue;
            }
            let windows = extract_all_windows(x, ws, config.window_step);
            let n_windows_per_sample = (n_timestamps - ws) / config.window_step + 1;
            // Every window inherits the label of the series it was cut from.
            let expanded_y: Vec<String> = y
                .iter()
                .flat_map(|l| std::iter::repeat_n(l.clone(), n_windows_per_sample))
                .collect();
            let sfa_config = SfaConfig {
                n_coefs: Some(config.word_size),
                n_bins: config.n_bins,
                strategy: config.strategy,
                drop_sum: config.drop_sum,
                anova: config.anova,
                norm_mean: config.norm_mean,
                norm_std: config.norm_std,
            };
            let sfa_fitted = Sfa::fit(&sfa_config, &windows, Some(&expanded_y));
            sfa_models.push((ws, sfa_fitted));
        }
        let histograms = build_histograms(x, &sfa_models, config);
        let selected_features = chi2_feature_selection(&histograms, y, config.chi2_threshold);
        WeaselFitted {
            sfa_models,
            config: config.clone(),
            selected_features,
        }
    }

    /// Converts each series in `x` into a word-count histogram restricted to
    /// the features selected during fitting.
    ///
    /// Returns one histogram per input series, in input order. An empty batch
    /// yields an empty vector.
    pub fn transform(fitted: &WeaselFitted, x: &[Vec<f64>]) -> Vec<HashMap<String, usize>> {
        // Nothing to transform; also avoids asking for the length of a
        // non-existent first series further down.
        if x.is_empty() {
            return Vec::new();
        }
        // Build the lookup set once: a linear `Vec::contains` per histogram
        // entry would cost O(words x selected features).
        let selected: std::collections::HashSet<&str> = fitted
            .selected_features
            .iter()
            .map(String::as_str)
            .collect();
        build_histograms(x, &fitted.sfa_models, &fitted.config)
            .into_iter()
            .map(|mut hist| {
                hist.retain(|word, _| selected.contains(word.as_str()));
                hist
            })
            .collect()
    }

    /// Fits on `(x, y)` and immediately transforms the same `x`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Weasel::fit`].
    pub fn fit_transform(
        config: &WeaselConfig,
        x: &[Vec<f64>],
        y: &[String],
    ) -> Vec<HashMap<String, usize>> {
        let fitted = Self::fit(config, x, y);
        Self::transform(&fitted, x)
    }
}
/// Builds one bag-of-words histogram per series from the fitted SFA models.
///
/// For every `(window_size, model)` pair the series are cut into sliding
/// windows, each window is turned into a word of the form
/// `"{window_size}_{letters}"`, runs of identical consecutive words within a
/// series are collapsed to a single occurrence, and the remaining words are
/// counted.
///
/// The length of the first series is used as the series length for the whole
/// batch. Models whose window size exceeds that length are skipped (the same
/// rule `Weasel::fit` applies), and an empty batch yields an empty vector.
fn build_histograms(
    x: &[Vec<f64>],
    sfa_models: &[(usize, SfaFitted)],
    config: &WeaselConfig,
) -> Vec<HashMap<String, usize>> {
    let Some(first) = x.first() else {
        return Vec::new();
    };
    let n_timestamps = first.len();
    let mut histograms: Vec<HashMap<String, usize>> = vec![HashMap::new(); x.len()];

    for (ws, sfa_fitted) in sfa_models {
        let ws = *ws;
        // A window longer than the series cannot be placed; skipping avoids an
        // unsigned underflow in the window count below.
        if ws > n_timestamps {
            continue;
        }
        let n_windows_per_sample = (n_timestamps - ws) / config.window_step + 1;
        let windows = extract_all_windows(x, ws, config.window_step);
        let symbolic = crate::approximation::sfa::sfa_transform_symbolic(sfa_fitted, &windows);

        for (sample_idx, hist) in histograms.iter_mut().enumerate() {
            // Windows are laid out series after series, so each series owns a
            // contiguous run of `n_windows_per_sample` entries.
            let start = sample_idx * n_windows_per_sample;
            let end = start + n_windows_per_sample;

            // Words are never empty (they always carry the "{ws}_" prefix), so
            // an empty `prev` guarantees the first word is counted.
            let mut prev = String::new();
            for bins in symbolic[start..end].iter() {
                let letters: String = bins.iter().map(|&b| (b'a' + b) as char).collect();
                let word = format!("{ws}_{letters}");
                // Only the first word of each run of identical words counts.
                if word != prev {
                    *hist.entry(word.clone()).or_insert(0) += 1;
                    prev = word;
                }
            }
        }
    }
    histograms
}
/// Selects the words whose presence is associated with the class labels.
///
/// For every word and every class a 2x2 contingency table over the samples is
/// formed (word present / absent versus sample in class / not in class). The
/// per-class statistic `n * (a*d - b*c)^2 / ((a+b)(c+d)(a+c)(b+d))` is summed
/// over all classes, skipping tables whose denominator is zero. A word is kept
/// when that sum is `>= threshold`. Only the presence of a word in a histogram
/// matters, not its count.
///
/// Returns the selected words in ascending lexicographic order.
///
/// # Panics
///
/// Panics if `y` has fewer entries than `histograms`.
fn chi2_feature_selection(
    histograms: &[HashMap<String, usize>],
    y: &[String],
    threshold: f64,
) -> Vec<String> {
    // Sorted, de-duplicated labels fix the order in which per-class terms are
    // summed, which keeps the floating-point result deterministic.
    let mut classes: Vec<&str> = y.iter().map(String::as_str).collect();
    classes.sort_unstable();
    classes.dedup();

    // Single pass over the data: count samples per class and, for every word,
    // the number of samples of each class that contain it. This replaces a
    // full rescan of all histograms for every (word, class) pair.
    let n_samples = histograms.len();
    let mut class_sizes = vec![0_usize; classes.len()];
    let mut word_counts: std::collections::BTreeMap<&str, Vec<usize>> =
        std::collections::BTreeMap::new();
    for (i, hist) in histograms.iter().enumerate() {
        let class_idx = classes
            .binary_search(&y[i].as_str())
            .expect("every label was collected into `classes`");
        class_sizes[class_idx] += 1;
        for word in hist.keys() {
            let per_class = word_counts
                .entry(word.as_str())
                .or_insert_with(|| vec![0; classes.len()]);
            per_class[class_idx] += 1;
        }
    }

    let n = n_samples as f64;
    // The B-tree iterates words in sorted order, so the output is sorted too.
    word_counts
        .into_iter()
        .filter_map(|(word, per_class)| {
            let with_word: usize = per_class.iter().sum();
            let without_word = n_samples - with_word;
            let mut chi2 = 0.0;
            for (&hits, &class_size) in per_class.iter().zip(&class_sizes) {
                let misses = class_size - hits;
                // Contingency cells: a = word & class, b = word & other class,
                // c = no word & class, d = no word & other class.
                let a = hits as f64;
                let b = (with_word - hits) as f64;
                let c = misses as f64;
                let d = (without_word - misses) as f64;
                let cross = a * d - b * c;
                let numerator = n * cross * cross;
                let denominator = (a + b) * (c + d) * (a + c) * (b + d);
                if denominator > 0.0 {
                    chi2 += numerator / denominator;
                }
            }
            (chi2 >= threshold).then(|| word.to_owned())
        })
        .collect()
}
/// Cuts every series in `x` into sliding windows of `window_size` values whose
/// start positions are `window_step` apart.
///
/// Windows are returned series after series, in input order; within a series
/// they are ordered by start position.
///
/// # Panics
///
/// Panics if a series is shorter than `window_size`, or if `window_step` is
/// zero and `x` is not empty.
fn extract_all_windows(x: &[Vec<f64>], window_size: usize, window_step: usize) -> Vec<Vec<f64>> {
    let mut windows = Vec::new();
    for series in x {
        // Last index at which a full window still fits.
        let last_start = series.len() - window_size;
        for start in (0..=last_start).step_by(window_step) {
            let stop = start + window_size;
            windows.push(series[start..stop].to_vec());
        }
    }
    windows
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned labels the API expects.
    fn labels(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| (*name).to_string()).collect()
    }

    /// Builds a histogram holding a single word with the given count.
    fn single_word_hist(word: &str, count: usize) -> HashMap<String, usize> {
        HashMap::from([(word.to_string(), count)])
    }

    /// `fit_transform` returns one histogram per input series.
    #[test]
    fn test_weasel_basic() {
        let x = vec![
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
            vec![7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0],
            vec![0.0, 2.0, 4.0, 6.0, 4.0, 2.0, 0.0, -2.0],
            vec![1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0],
        ];
        let y = labels(&["A", "B", "A", "B"]);
        let config = WeaselConfig::new(2, vec![4, 6]);

        let result = Weasel::fit_transform(&config, &x, &y);

        assert_eq!(result.len(), 4);
    }

    /// Fitting and transforming in two steps also returns one histogram per
    /// input series.
    #[test]
    fn test_weasel_fit_then_transform() {
        let x = vec![
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
            vec![5.0, 4.0, 3.0, 2.0, 1.0, 0.0],
            vec![0.0, 3.0, 1.0, 4.0, 2.0, 5.0],
            vec![2.0, 2.0, 2.0, 2.0, 2.0, 2.0],
        ];
        let y = labels(&["A", "B", "A", "B"]);
        let config = WeaselConfig::new(2, vec![3, 4]);

        let fitted = Weasel::fit(&config, &x, &y);
        let result = Weasel::transform(&fitted, &x);

        assert_eq!(result.len(), 4);
    }

    /// Words that perfectly separate the two classes are both selected.
    #[test]
    fn test_chi2_selection() {
        let good = single_word_hist("good", 5);
        let bad = single_word_hist("bad", 5);
        let histograms = vec![good.clone(), bad.clone(), good, bad];
        let y = labels(&["A", "B", "A", "B"]);

        let selected = chi2_feature_selection(&histograms, &y, 0.1);

        assert!(selected.contains(&"good".to_string()));
        assert!(selected.contains(&"bad".to_string()));
    }
}