use std::collections::HashMap;
use crate::approximation::sfa::{sfa_transform_symbolic, Sfa, SfaConfig, SfaFitted};
use crate::core::config::BinStrategy;
use crate::core::traits::FittableTransformer;
/// Configuration for the [`WeaselMuse`] transformer.
#[derive(Debug, Clone)]
pub struct WeaselMuseConfig {
/// Number of SFA coefficients requested per window (forwarded as `n_coefs`).
pub word_size: usize,
/// Number of bins forwarded to the SFA configuration.
pub n_bins: usize,
/// Sliding-window lengths to fit; a length larger than the series is skipped during fitting.
pub window_sizes: Vec<usize>,
/// Offset between the starts of consecutive sliding windows.
pub window_step: usize,
/// Binning strategy forwarded to the SFA configuration.
pub strategy: BinStrategy,
/// Minimum summed chi-squared score a word needs in order to be kept (compared with `>=`).
pub chi2_threshold: f64,
/// When `true`, the first differences of every feature are processed as additional features.
pub use_first_differences: bool,
}
impl WeaselMuseConfig {
    /// Builds a configuration from the two caller-supplied settings and fills in
    /// defaults for the rest: 4 bins, a window step of 1, quantile binning, a
    /// chi-squared threshold of 2.0, and first differences enabled.
    pub fn new(word_size: usize, window_sizes: Vec<usize>) -> Self {
        Self {
            // Caller-supplied values.
            word_size,
            window_sizes,
            // Defaults.
            n_bins: 4,
            window_step: 1,
            chi2_threshold: 2.0,
            strategy: BinStrategy::Quantile,
            use_first_differences: true,
        }
    }
}
/// Result of [`WeaselMuse::fit`]: everything `WeaselMuse::transform` needs.
#[derive(Debug, Clone)]
pub struct WeaselMuseFitted {
/// One fitted SFA model per `(feature index, window size)` pair, stored as
/// `(feature index, window size, model)`. Feature indices at or beyond the
/// number of input features refer to first-difference series.
pub sfa_models: Vec<(usize, usize, SfaFitted)>,
/// Copy of the configuration used for fitting.
pub config: WeaselMuseConfig,
/// Words that passed the chi-squared selection during fitting.
pub selected_features: Vec<String>,
}
/// Extracts one univariate series per sample for the given (virtual) feature.
///
/// `x` is indexed as `x[sample][feature][timestamp]`. Feature indices below
/// `n_features` return a copy of the raw series. Indices at or above
/// `n_features` return the first differences (`s[i + 1] - s[i]`) of feature
/// `feat_idx - n_features`, which is one element shorter than the raw series.
/// A series with fewer than two points has no differences and yields an empty
/// vector.
///
/// # Panics
///
/// Panics if a sample has no feature at the resolved index.
fn extract_feature_series(
    x: &[Vec<Vec<f64>>],
    feat_idx: usize,
    n_features: usize,
) -> Vec<Vec<f64>> {
    if feat_idx >= n_features {
        let source = feat_idx - n_features;
        x.iter()
            .map(|sample| {
                // `windows(2)` is simply empty for series shorter than two
                // points, so there is no `len() - 1` underflow to worry about.
                sample[source]
                    .windows(2)
                    .map(|pair| pair[1] - pair[0])
                    .collect()
            })
            .collect()
    } else {
        x.iter().map(|sample| sample[feat_idx].clone()).collect()
    }
}
/// Cuts every series into sliding windows of `window_size` points whose start
/// positions are `window_step` apart, and returns all windows of all series
/// concatenated in series order.
///
/// A series shorter than `window_size` contributes no windows.
///
/// # Panics
///
/// Panics if `window_step` is zero and at least one series is long enough to
/// be windowed (`Iterator::step_by` rejects a zero step).
fn extract_windows(series: &[Vec<f64>], window_size: usize, window_step: usize) -> Vec<Vec<f64>> {
    series
        .iter()
        // Dropping too-short series up front keeps the subtraction below from
        // underflowing.
        .filter(|s| s.len() >= window_size)
        .flat_map(|s| {
            (0..=s.len() - window_size)
                .step_by(window_step)
                .map(move |start| s[start..start + window_size].to_vec())
        })
        .collect()
}
/// Stateless entry point for fitting and applying the bag-of-SFA-words
/// transformation on multivariate series shaped `x[sample][feature][timestamp]`.
pub struct WeaselMuse;

impl WeaselMuse {
    /// Fits one SFA model per `(feature, window size)` pair and selects the
    /// words whose chi-squared score against `y` reaches
    /// `config.chi2_threshold`.
    ///
    /// When `config.use_first_differences` is set, the first differences of
    /// every feature are treated as `n_features` additional features. Window
    /// sizes that exceed the (possibly differenced) series length are skipped.
    ///
    /// # Panics
    ///
    /// Panics if `x` is empty, if `x` and `y` differ in length, if the first
    /// sample has no features, if `config.window_step` is zero, or if the
    /// samples do not all share the same number of features and timestamps.
    pub fn fit(config: &WeaselMuseConfig, x: &[Vec<Vec<f64>>], y: &[String]) -> WeaselMuseFitted {
        assert!(!x.is_empty(), "Input must have at least one sample");
        assert_eq!(x.len(), y.len(), "X and y must have same length");
        assert!(!x[0].is_empty(), "Samples must have at least one feature");
        // A zero step would otherwise surface as an opaque divide-by-zero below.
        assert!(config.window_step > 0, "window_step must be at least 1");

        let n_features = x[0].len();
        let n_timestamps = x[0][0].len();
        // The per-sample window count is derived from the first sample only, so
        // ragged input would misalign the expanded labels with the windows.
        assert!(
            x.iter().all(|sample| {
                sample.len() == n_features
                    && sample.iter().all(|channel| channel.len() == n_timestamps)
            }),
            "All samples must share the same number of features and timestamps"
        );

        let total_features = if config.use_first_differences {
            n_features * 2
        } else {
            n_features
        };

        let mut sfa_models = Vec::new();
        for feat_idx in 0..total_features {
            // Differenced series are one point shorter than the raw ones.
            let effective_ts_len = if feat_idx >= n_features {
                n_timestamps.saturating_sub(1)
            } else {
                n_timestamps
            };
            // The series depends only on the feature, so build it at most once
            // and share it across all window sizes.
            let mut series_cache: Option<Vec<Vec<f64>>> = None;

            for &ws in &config.window_sizes {
                if ws > effective_ts_len {
                    continue;
                }
                let series: &[Vec<f64>] = series_cache
                    .get_or_insert_with(|| extract_feature_series(x, feat_idx, n_features));
                let windows = extract_windows(series, ws, config.window_step);

                // Every window inherits the label of the sample it was cut from.
                let n_windows_per = (effective_ts_len - ws) / config.window_step + 1;
                let expanded_y: Vec<String> = y
                    .iter()
                    .flat_map(|l| std::iter::repeat_n(l.clone(), n_windows_per))
                    .collect();

                let sfa_config = SfaConfig {
                    n_coefs: Some(config.word_size),
                    n_bins: config.n_bins,
                    strategy: config.strategy,
                    drop_sum: false,
                    anova: true,
                    norm_mean: true,
                    norm_std: true,
                };
                let sfa_fitted = Sfa::fit(&sfa_config, &windows, Some(&expanded_y));
                sfa_models.push((feat_idx, ws, sfa_fitted));
            }
        }

        let histograms = build_histograms(x, &sfa_models, config);
        let selected_features = chi2_select(&histograms, y, config.chi2_threshold);
        WeaselMuseFitted {
            sfa_models,
            config: config.clone(),
            selected_features,
        }
    }

    /// Builds one word histogram per sample of `x` and keeps only the words
    /// selected during fitting.
    ///
    /// Returns an empty vector when `x` contains no samples.
    pub fn transform(
        fitted: &WeaselMuseFitted,
        x: &[Vec<Vec<f64>>],
    ) -> Vec<HashMap<String, usize>> {
        if x.is_empty() {
            return Vec::new();
        }
        // Hash the selected vocabulary once instead of scanning the vector for
        // every word of every histogram.
        let selected: std::collections::HashSet<&str> = fitted
            .selected_features
            .iter()
            .map(String::as_str)
            .collect();

        build_histograms(x, &fitted.sfa_models, &fitted.config)
            .into_iter()
            .map(|hist| {
                hist.into_iter()
                    .filter(|(word, _)| selected.contains(word.as_str()))
                    .collect()
            })
            .collect()
    }
}
/// Builds one word histogram per sample from the fitted SFA models.
///
/// For every `(feature index, window size, model)` entry the matching series
/// is windowed, each window is turned into a word of letters (`'a'` plus the
/// bin index) prefixed with `f{feature}_w{window}_`, and the word is counted
/// in the histogram of the sample it came from. A word identical to the one
/// produced by the immediately preceding window of the same sample and model
/// is not counted again.
///
/// Returns an empty vector when `x` has no samples. Models whose window size
/// exceeds the series length of `x` are skipped, mirroring the check done
/// while fitting.
///
/// # Panics
///
/// Panics if the first sample has no features, if `config.window_step` is
/// zero while a model is applicable, or if the symbolic output holds fewer
/// rows than the window count derived from the first sample's shape.
fn build_histograms(
    x: &[Vec<Vec<f64>>],
    sfa_models: &[(usize, usize, SfaFitted)],
    config: &WeaselMuseConfig,
) -> Vec<HashMap<String, usize>> {
    // Nothing to transform; this also protects the `x[0]` shape probes below.
    if x.is_empty() {
        return Vec::new();
    }
    let n_features = x[0].len();
    let n_timestamps = x[0][0].len();
    let mut histograms: Vec<HashMap<String, usize>> = vec![HashMap::new(); x.len()];

    for (feat_idx, ws, sfa_fitted) in sfa_models {
        let (feat_idx, ws) = (*feat_idx, *ws);
        // Differenced series are one point shorter than the raw ones.
        let effective_ts_len = if feat_idx >= n_features {
            n_timestamps.saturating_sub(1)
        } else {
            n_timestamps
        };
        // No window of this size fits, so this model contributes no words.
        if ws > effective_ts_len {
            continue;
        }
        let n_windows_per = (effective_ts_len - ws) / config.window_step + 1;

        let series = extract_feature_series(x, feat_idx, n_features);
        let windows = extract_windows(&series, ws, config.window_step);
        let symbolic = sfa_transform_symbolic(sfa_fitted, &windows);

        for (sample_idx, histogram) in histograms.iter_mut().enumerate() {
            // Windows are laid out sample after sample, `n_windows_per` each.
            let start = sample_idx * n_windows_per;
            let end = start + n_windows_per;
            let mut prev = String::new();
            for bins in &symbolic[start..end] {
                let letters: String = bins.iter().map(|&b| (b'a' + b) as char).collect();
                let word = format!("f{feat_idx}_w{ws}_{letters}");
                // Count a run of identical consecutive words only once.
                if word != prev {
                    prev.clone_from(&word);
                    *histogram.entry(word).or_insert(0) += 1;
                }
            }
        }
    }
    histograms
}
/// Returns, in sorted order, every word whose chi-squared score reaches
/// `threshold`.
///
/// A word's score is the sum over all classes of the 2x2 chi-squared statistic
/// relating "word occurs in the sample's histogram" to "sample belongs to the
/// class". Classes whose contingency table has a zero marginal contribute
/// nothing. `histograms[i]` is paired with label `y[i]`.
fn chi2_select(histograms: &[HashMap<String, usize>], y: &[String], threshold: f64) -> Vec<String> {
    // Sorted, duplicate-free vocabulary across all samples.
    let mut vocabulary: Vec<String> = Vec::new();
    for histogram in histograms {
        vocabulary.extend(histogram.keys().cloned());
    }
    vocabulary.sort();
    vocabulary.dedup();

    // Sorted, duplicate-free class labels.
    let mut labels: Vec<&str> = y.iter().map(String::as_str).collect();
    labels.sort();
    labels.dedup();

    let total = histograms.len() as f64;

    let score_of = |word: &String| -> f64 {
        labels.iter().fold(0.0, |score, &label| {
            // Contingency counts, indexed by `2 * present + in_class`.
            let mut tally = [0usize; 4];
            for (idx, histogram) in histograms.iter().enumerate() {
                let present = usize::from(histogram.contains_key(word));
                let in_class = usize::from(y[idx].as_str() == label);
                tally[present * 2 + in_class] += 1;
            }
            let [d, c, b, a] = tally.map(|count| count as f64);

            let cross = a * d - b * c;
            let numerator = total * cross * cross;
            let denominator = (a + b) * (c + d) * (a + c) * (b + d);
            if denominator > 0.0 {
                score + numerator / denominator
            } else {
                score
            }
        })
    };

    vocabulary
        .into_iter()
        .filter(|word| score_of(word) >= threshold)
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: fitting and transforming four two-feature samples yields
    /// one histogram per sample.
    #[test]
    fn test_weasel_muse_basic() {
        let config = WeaselMuseConfig::new(2, vec![3]);

        // Two rising and two falling ramps; the "uneven" variants have
        // irregular spacing.
        let rising = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let falling = vec![7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0];
        let rising_uneven = vec![0.0, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 7.5];
        let falling_uneven = vec![7.5, 5.5, 4.5, 3.5, 2.5, 1.5, 0.5, 0.0];

        // Class "A" rises in feature 0 and falls in feature 1; class "B" is
        // the mirror image.
        let x = vec![
            vec![rising.clone(), falling.clone()],
            vec![rising_uneven.clone(), falling_uneven.clone()],
            vec![falling, rising],
            vec![falling_uneven, rising_uneven],
        ];
        let y: Vec<String> = ["A", "A", "B", "B"]
            .iter()
            .map(|label| label.to_string())
            .collect();

        let fitted = WeaselMuse::fit(&config, &x, &y);
        let result = WeaselMuse::transform(&fitted, &x);
        assert_eq!(result.len(), 4);
    }
}