//! oxits 0.1.0
//!
//! Time series classification and transformation library for Rust.
//!
//! Documentation
use std::collections::HashMap;
use std::collections::HashSet;

use crate::approximation::sfa::{Sfa, SfaConfig, SfaFitted};
use crate::core::config::BinStrategy;
use crate::core::traits::FittableTransformer;

/// WEASEL (Word ExtrAction for time SEries cLassification).
///
/// Multi-window BOSS with chi-squared feature selection.
/// Pipeline:
/// 1. For each window size, apply BOSS-like SFA windowing
/// 2. Build per-window histograms with window-size-prefixed words
/// 3. Merge all histograms
/// 4. Apply chi-squared feature selection to prune irrelevant words

#[derive(Debug, Clone)]
pub struct WeaselConfig {
    /// SFA word length (number of Fourier coefficients kept per window).
    pub word_size: usize,
    /// Alphabet size used when discretizing SFA coefficients.
    pub n_bins: usize,
    /// Sliding window lengths; windows longer than the series are skipped at fit time.
    pub window_sizes: Vec<usize>,
    /// Stride between consecutive sliding windows.
    pub window_step: usize,
    /// Binning strategy forwarded to the SFA discretizer.
    pub strategy: BinStrategy,
    /// Forwarded to SFA: subtract the window mean before transforming.
    pub norm_mean: bool,
    /// Forwarded to SFA: divide by the window standard deviation.
    pub norm_std: bool,
    /// Forwarded to SFA: drop the first (sum/DC) Fourier coefficient.
    pub drop_sum: bool,
    /// Forwarded to SFA: use ANOVA-based coefficient selection.
    pub anova: bool,
    /// Minimum chi-squared statistic a word needs to be kept as a feature.
    pub chi2_threshold: f64,
}

impl WeaselConfig {
    /// Create a configuration with the given word length and window sizes.
    ///
    /// All remaining fields start from library defaults: 4 bins, window
    /// step 1, quantile binning, mean and std normalization enabled, sum
    /// coefficient kept, ANOVA selection enabled, chi-squared threshold 2.0.
    #[must_use]
    pub fn new(word_size: usize, window_sizes: Vec<usize>) -> Self {
        Self {
            word_size,
            n_bins: 4,
            window_sizes,
            window_step: 1,
            strategy: BinStrategy::Quantile,
            norm_mean: true,
            norm_std: true,
            drop_sum: false,
            anova: true,
            chi2_threshold: 2.0,
        }
    }
}

/// Result of [`Weasel::fit`]: everything needed to transform new series.
#[derive(Debug, Clone)]
pub struct WeaselFitted {
    /// One fitted SFA model per window size, stored as `(window_size, model)`.
    pub sfa_models: Vec<(usize, SfaFitted)>,
    /// Copy of the configuration used at fit time.
    pub config: WeaselConfig,
    /// Selected features (words that passed the chi-squared test),
    /// each prefixed with its window size (e.g. `"8_abba"`).
    pub selected_features: Vec<String>,
}

/// Stateless namespace for the WEASEL algorithm; all operations are
/// associated functions taking the config / fitted model explicitly.
pub struct Weasel;

impl Weasel {
    /// Fit the WEASEL model.
    pub fn fit(config: &WeaselConfig, x: &[Vec<f64>], y: &[String]) -> WeaselFitted {
        assert!(!x.is_empty(), "Input must have at least one sample");
        assert_eq!(x.len(), y.len(), "X and y must have same length");

        let n_timestamps = x[0].len();

        // Fit SFA for each window size
        let mut sfa_models = Vec::new();
        for &ws in &config.window_sizes {
            if ws > n_timestamps {
                continue;
            }

            let windows = extract_all_windows(x, ws, config.window_step);
            let n_windows_per_sample = (n_timestamps - ws) / config.window_step + 1;

            // Expand labels
            let expanded_y: Vec<String> = y
                .iter()
                .flat_map(|l| std::iter::repeat_n(l.clone(), n_windows_per_sample))
                .collect();

            let sfa_config = SfaConfig {
                n_coefs: Some(config.word_size),
                n_bins: config.n_bins,
                strategy: config.strategy,
                drop_sum: config.drop_sum,
                anova: config.anova,
                norm_mean: config.norm_mean,
                norm_std: config.norm_std,
            };

            let sfa_fitted = Sfa::fit(&sfa_config, &windows, Some(&expanded_y));
            sfa_models.push((ws, sfa_fitted));
        }

        // Build histograms for all training samples
        let histograms = build_histograms(x, &sfa_models, config);

        // Chi-squared feature selection
        let selected_features = chi2_feature_selection(&histograms, y, config.chi2_threshold);

        WeaselFitted {
            sfa_models,
            config: config.clone(),
            selected_features,
        }
    }

    /// Transform time series into WEASEL feature vectors.
    ///
    /// Returns histograms filtered to only selected features.
    pub fn transform(fitted: &WeaselFitted, x: &[Vec<f64>]) -> Vec<HashMap<String, usize>> {
        let histograms = build_histograms(x, &fitted.sfa_models, &fitted.config);

        // Filter to selected features only
        histograms
            .into_iter()
            .map(|hist| {
                hist.into_iter()
                    .filter(|(word, _)| fitted.selected_features.contains(word))
                    .collect()
            })
            .collect()
    }

    /// Fit and transform in one step.
    pub fn fit_transform(
        config: &WeaselConfig,
        x: &[Vec<f64>],
        y: &[String],
    ) -> Vec<HashMap<String, usize>> {
        let fitted = Self::fit(config, x, y);
        Self::transform(&fitted, x)
    }
}

/// Build histograms for all samples across all window sizes.
fn build_histograms(
    x: &[Vec<f64>],
    sfa_models: &[(usize, SfaFitted)],
    config: &WeaselConfig,
) -> Vec<HashMap<String, usize>> {
    let n_samples = x.len();
    let n_timestamps = x[0].len();
    let mut histograms: Vec<HashMap<String, usize>> =
        (0..n_samples).map(|_| HashMap::new()).collect();

    for (ws, sfa_fitted) in sfa_models {
        let n_windows_per_sample = (n_timestamps - ws) / config.window_step + 1;

        let windows = extract_all_windows(x, *ws, config.window_step);
        let symbolic = crate::approximation::sfa::sfa_transform_symbolic(sfa_fitted, &windows);

        for sample_idx in 0..n_samples {
            let start = sample_idx * n_windows_per_sample;
            let end = start + n_windows_per_sample;

            let words: Vec<String> = symbolic[start..end]
                .iter()
                .map(|bins| {
                    // Prefix word with window size for disambiguation
                    let word: String = bins.iter().map(|&b| (b'a' + b) as char).collect();
                    format!("{ws}_{word}")
                })
                .collect();

            // Apply numerosity reduction
            let reduced = {
                let mut result = Vec::new();
                let mut prev = String::new();
                for word in words {
                    if word != prev {
                        prev.clone_from(&word);
                        result.push(word);
                    }
                }
                result
            };

            for word in reduced {
                *histograms[sample_idx].entry(word).or_insert(0) += 1;
            }
        }
    }

    histograms
}

/// Chi-squared feature selection.
/// Returns the set of words whose chi-squared statistic exceeds the threshold.
fn chi2_feature_selection(
    histograms: &[HashMap<String, usize>],
    y: &[String],
    threshold: f64,
) -> Vec<String> {
    // Collect all unique words
    let mut all_words: Vec<String> = histograms.iter().flat_map(|h| h.keys().cloned()).collect();
    all_words.sort();
    all_words.dedup();

    // Get unique classes
    let mut classes: Vec<&str> = y.iter().map(|s| s.as_str()).collect();
    classes.sort();
    classes.dedup();

    let n = histograms.len() as f64;

    let mut selected = Vec::new();

    for word in &all_words {
        let mut chi2 = 0.0;

        for class in &classes {
            // Count samples with this word in this class
            let mut a = 0.0; // has word, is class
            let mut b = 0.0; // has word, not class
            let mut c = 0.0; // no word, is class
            let mut d = 0.0; // no word, not class

            for (i, hist) in histograms.iter().enumerate() {
                let has_word = hist.contains_key(word);
                let is_class = y[i].as_str() == *class;

                match (has_word, is_class) {
                    (true, true) => a += 1.0,
                    (true, false) => b += 1.0,
                    (false, true) => c += 1.0,
                    (false, false) => d += 1.0,
                }
            }

            let numerator = n * (a * d - b * c) * (a * d - b * c);
            let denominator = (a + b) * (c + d) * (a + c) * (b + d);
            if denominator > 0.0 {
                chi2 += numerator / denominator;
            }
        }

        if chi2 >= threshold {
            selected.push(word.clone());
        }
    }

    selected
}

/// Extract sliding windows from all samples.
fn extract_all_windows(x: &[Vec<f64>], window_size: usize, window_step: usize) -> Vec<Vec<f64>> {
    x.iter()
        .flat_map(|sample| {
            let n = sample.len();
            (0..=n - window_size)
                .step_by(window_step)
                .map(move |i| sample[i..i + window_size].to_vec())
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Convert string slices into the owned label vector the API expects.
    fn labels(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|s| (*s).to_string()).collect()
    }

    #[test]
    fn test_weasel_basic() {
        let series: Vec<Vec<f64>> = vec![
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
            vec![7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0],
            vec![0.0, 2.0, 4.0, 6.0, 4.0, 2.0, 0.0, -2.0],
            vec![1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0],
        ];
        let cfg = WeaselConfig::new(2, vec![4, 6]);
        let features = Weasel::fit_transform(&cfg, &series, &labels(&["A", "B", "A", "B"]));
        // One feature histogram per input sample.
        assert_eq!(features.len(), 4);
    }

    #[test]
    fn test_weasel_fit_then_transform() {
        let series: Vec<Vec<f64>> = vec![
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
            vec![5.0, 4.0, 3.0, 2.0, 1.0, 0.0],
            vec![0.0, 3.0, 1.0, 4.0, 2.0, 5.0],
            vec![2.0, 2.0, 2.0, 2.0, 2.0, 2.0],
        ];
        let y = labels(&["A", "B", "A", "B"]);
        let model = Weasel::fit(&WeaselConfig::new(2, vec![3, 4]), &series, &y);
        let features = Weasel::transform(&model, &series);
        assert_eq!(features.len(), 4);
    }

    #[test]
    fn test_chi2_selection() {
        // Words that perfectly track class membership must survive
        // selection at a low threshold.
        let hist_good: HashMap<String, usize> = [("good".to_string(), 5)].into_iter().collect();
        let hist_bad: HashMap<String, usize> = [("bad".to_string(), 5)].into_iter().collect();
        let histograms = vec![hist_good.clone(), hist_bad.clone(), hist_good, hist_bad];
        let selected = chi2_feature_selection(&histograms, &labels(&["A", "B", "A", "B"]), 0.1);
        assert!(selected.contains(&"good".to_string()));
        assert!(selected.contains(&"bad".to_string()));
    }
}