//! oxits 0.1.0
//!
//! Time series classification and transformation library for Rust.
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::collections::HashMap;

/// Configuration for the Time Series Forest classifier.
///
/// Pipeline of the fitted model:
/// 1. Extract random intervals from time series
/// 2. Compute summary statistics (mean, std, slope) for each interval
/// 3. Build an ensemble of decision trees on the features

#[derive(Debug, Clone)]
pub struct TimeSeriesForestConfig {
    // Number of trees in the ensemble.
    pub n_estimators: usize,
    // Shortest interval (in timestamps) that may be sampled.
    pub min_interval_length: usize,
    // Seed for reproducible interval sampling; `None` draws from entropy.
    pub random_seed: Option<u64>,
}

impl TimeSeriesForestConfig {
    /// Create a config with `n_estimators` trees, a minimum interval
    /// length of 3 timestamps, and an entropy-seeded (non-reproducible) RNG.
    pub fn new(n_estimators: usize) -> Self {
        Self {
            n_estimators,
            min_interval_length: 3,
            random_seed: None,
        }
    }

    /// Builder-style variant of [`Self::new`]: consumes the config and
    /// fixes the RNG seed so fits are reproducible.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.random_seed = Some(seed);
        self
    }
}

/// One-split decision tree: a sample goes to the left branch when
/// `features[feature_index] <= threshold`, to the right branch otherwise.
#[derive(Debug, Clone)]
struct DecisionStump {
    // Index into the flattened interval-feature vector.
    feature_index: usize,
    // Split value; left branch is `<= threshold`.
    threshold: f64,
    // Predicted class for the left (<= threshold) branch.
    left_class: String,
    // Predicted class for the right (> threshold) branch.
    right_class: String,
}

/// A single ensemble member: the intervals it sampled plus the stump
/// trained on features extracted from those intervals.
#[derive(Debug, Clone)]
pub(crate) struct TreeModel {
    intervals: Vec<(usize, usize)>, // half-open (start, end) of random intervals
    stump: DecisionStump,           // stump fitted on this tree's interval features
}

/// A fitted Time Series Forest: the trained ensemble returned by
/// [`TimeSeriesForest::fit`] and consumed by `predict`/`score`.
#[derive(Debug, Clone)]
pub struct TimeSeriesForestFitted {
    pub(crate) trees: Vec<TreeModel>, // one entry per estimator
}

/// Stateless entry point for fitting, predicting with, and scoring
/// Time Series Forest models; all functionality lives in associated functions.
pub struct TimeSeriesForest;

impl TimeSeriesForest {
    /// Fit the Time Series Forest.
    ///
    /// For each of `config.n_estimators` trees: draw
    /// `ceil(sqrt(n_timestamps))` random intervals, compute
    /// (mean, std, slope) features per interval, and train a decision
    /// stump on those features.
    ///
    /// # Panics
    /// Panics if `x` is empty, if `x` and `y` differ in length, if the
    /// series have unequal lengths, or if the series are shorter than
    /// `config.min_interval_length`.
    pub fn fit(
        config: &TimeSeriesForestConfig,
        x: &[Vec<f64>],
        y: &[String],
    ) -> TimeSeriesForestFitted {
        assert!(!x.is_empty(), "Input must have at least one sample");
        assert_eq!(x.len(), y.len(), "X and y must have same length");

        let n_timestamps = x[0].len();
        assert!(
            x.iter().all(|sample| sample.len() == n_timestamps),
            "All series must have the same number of timestamps"
        );
        // Guard against an empty (or usize-underflowing) sampling range below.
        assert!(
            n_timestamps >= config.min_interval_length,
            "Series length ({}) must be at least min_interval_length ({})",
            n_timestamps,
            config.min_interval_length
        );

        let mut rng = match config.random_seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_entropy(),
        };

        // Number of random intervals per tree: ceil(sqrt(series length)).
        let n_intervals = ((n_timestamps as f64).sqrt().ceil() as usize).max(1);

        // Build class index map once; reused by every stump for incremental Gini.
        let class_map = build_class_map(y);

        let trees: Vec<TreeModel> = (0..config.n_estimators)
            .map(|_| {
                // Generate random half-open intervals [start, end).
                // The start range is inclusive (`..=`): the old exclusive
                // range could never sample the last valid start position and
                // panicked when n_timestamps == min_interval_length.
                let intervals: Vec<(usize, usize)> = (0..n_intervals)
                    .map(|_| {
                        let start =
                            rng.gen_range(0..=n_timestamps - config.min_interval_length);
                        let end =
                            rng.gen_range(start + config.min_interval_length..=n_timestamps);
                        (start, end)
                    })
                    .collect();

                // Extract features for all samples on this tree's intervals.
                let features: Vec<Vec<f64>> = x
                    .iter()
                    .map(|sample| extract_interval_features(sample, &intervals))
                    .collect();

                // Build a simple decision stump on the interval features.
                let stump = build_stump(&features, y, &class_map);

                TreeModel { intervals, stump }
            })
            .collect();

        TimeSeriesForestFitted { trees }
    }

    /// Predict class labels using majority voting over all trees.
    pub fn predict(fitted: &TimeSeriesForestFitted, x: &[Vec<f64>]) -> Vec<String> {
        #[cfg(feature = "parallel")]
        {
            use rayon::prelude::*;
            return x
                .par_iter()
                .map(|sample| predict_single(fitted, sample))
                .collect();
        }

        #[cfg(not(feature = "parallel"))]
        x.iter()
            .map(|sample| predict_single(fitted, sample))
            .collect()
    }

    /// Compute classification accuracy in `[0, 1]`.
    ///
    /// Returns 0.0 for empty input instead of dividing by zero (NaN).
    pub fn score(fitted: &TimeSeriesForestFitted, x: &[Vec<f64>], y: &[String]) -> f64 {
        if y.is_empty() {
            return 0.0;
        }
        let predictions = Self::predict(fitted, x);
        let correct = predictions
            .iter()
            .zip(y.iter())
            .filter(|(p, t)| p == t)
            .count();
        correct as f64 / y.len() as f64
    }
}

/// Predict a single sample by majority vote over the fitted trees.
///
/// Ties are broken deterministically in favor of the lexicographically
/// smallest class label; the previous `max_by_key` over HashMap iteration
/// order made tied predictions non-reproducible across runs.
fn predict_single(fitted: &TimeSeriesForestFitted, sample: &[f64]) -> String {
    let mut votes: HashMap<&str, usize> = HashMap::new();
    for tree in &fitted.trees {
        let features = extract_interval_features(sample, &tree.intervals);
        let goes_left = predict_stump(&tree.stump, &features);
        let class = if goes_left {
            tree.stump.left_class.as_str()
        } else {
            tree.stump.right_class.as_str()
        };
        *votes.entry(class).or_insert(0) += 1;
    }
    votes
        .into_iter()
        // Highest count wins; on equal counts the smaller label wins.
        .max_by(|(ca, na), (cb, nb)| na.cmp(nb).then_with(|| cb.cmp(ca)))
        .map(|(class, _)| class.to_string())
        .expect("fitted model must contain at least one tree")
}

/// Extract features (mean, std, slope) from intervals.
fn extract_interval_features(sample: &[f64], intervals: &[(usize, usize)]) -> Vec<f64> {
    let mut features = Vec::with_capacity(intervals.len() * 3);
    for &(start, end) in intervals {
        let slice = &sample[start..end];
        let n = slice.len() as f64;

        // Mean
        let mean = slice.iter().sum::<f64>() / n;

        // Standard deviation
        let var = slice
            .iter()
            .map(|&v| {
                let d = v - mean;
                d * d
            })
            .sum::<f64>()
            / n;
        let std = var.sqrt();

        // Slope (linear regression coefficient)
        let x_mean = (n - 1.0) / 2.0;
        let mut num = 0.0;
        let mut den = 0.0;
        for (i, &v) in slice.iter().enumerate() {
            let xi = i as f64 - x_mean;
            num += xi * (v - mean);
            den += xi * xi;
        }
        let slope = if den > 0.0 { num / den } else { 0.0 };

        features.push(mean);
        features.push(std);
        features.push(slope);
    }
    features
}

/// Maps class labels to contiguous indices 0..n_classes.
struct ClassMap {
    n_classes: usize,
    indices: Vec<usize>, // index for each sample, parallel to y
}

fn build_class_map(y: &[String]) -> ClassMap {
    let mut unique: Vec<&str> = y.iter().map(|s| s.as_str()).collect();
    unique.sort();
    unique.dedup();
    let n_classes = unique.len();

    let label_to_idx: HashMap<&str, usize> =
        unique.iter().enumerate().map(|(i, &s)| (s, i)).collect();

    let indices: Vec<usize> = y.iter().map(|s| label_to_idx[s.as_str()]).collect();

    ClassMap { n_classes, indices }
}

/// Build a decision stump using incremental Gini computation.
/// Instead of allocating Vec<&str> + HashMap per split point, we maintain
/// left_counts/right_counts arrays and slide the split point.
fn build_stump(features: &[Vec<f64>], y: &[String], class_map: &ClassMap) -> DecisionStump {
    let n_features = features[0].len();
    let n_samples = features.len();
    let n_classes = class_map.n_classes;

    let mut best_gini = f64::INFINITY;
    let mut best_feature = 0;
    let mut best_threshold = 0.0;

    // Pre-allocate sort indices and count arrays (reused across features)
    let mut sort_indices: Vec<usize> = (0..n_samples).collect();
    let mut left_counts = vec![0u32; n_classes];
    let mut right_counts = vec![0u32; n_classes];

    for f_idx in 0..n_features {
        // Sort sample indices by this feature value
        sort_indices
            .iter_mut()
            .enumerate()
            .for_each(|(i, v)| *v = i);
        sort_indices.sort_by(|&a, &b| features[a][f_idx].partial_cmp(&features[b][f_idx]).unwrap());

        // Initialize: all samples in right, none in left
        left_counts.iter_mut().for_each(|c| *c = 0);
        right_counts.iter_mut().for_each(|c| *c = 0);
        for &si in &sort_indices {
            right_counts[class_map.indices[si]] += 1;
        }

        let mut n_left = 0u32;
        let mut n_right = n_samples as u32;

        // Slide split point: move one sample at a time from right to left
        for i in 0..n_samples - 1 {
            let si = sort_indices[i];
            let class_idx = class_map.indices[si];

            // Move sample from right to left
            left_counts[class_idx] += 1;
            right_counts[class_idx] -= 1;
            n_left += 1;
            n_right -= 1;

            // Skip if same value as next (no valid split between equal values)
            let val = features[si][f_idx];
            let next_val = features[sort_indices[i + 1]][f_idx];
            if (val - next_val).abs() < 1e-15 {
                continue;
            }

            // Compute weighted Gini from counts
            let gini_left = gini_from_counts(&left_counts, n_left);
            let gini_right = gini_from_counts(&right_counts, n_right);
            let gini = (n_left as f64 * gini_left + n_right as f64 * gini_right) / n_samples as f64;

            if gini < best_gini {
                best_gini = gini;
                best_feature = f_idx;
                best_threshold = (val + next_val) / 2.0;
            }
        }
    }

    // Determine majority class on each side
    let left_labels: Vec<&str> = features
        .iter()
        .zip(y.iter())
        .filter(|(f, _)| f[best_feature] <= best_threshold)
        .map(|(_, l)| l.as_str())
        .collect();
    let right_labels: Vec<&str> = features
        .iter()
        .zip(y.iter())
        .filter(|(f, _)| f[best_feature] > best_threshold)
        .map(|(_, l)| l.as_str())
        .collect();

    let left_class = majority_class(&left_labels)
        .unwrap_or_else(|| y[0].as_str())
        .to_string();
    let right_class = majority_class(&right_labels)
        .unwrap_or_else(|| y[0].as_str())
        .to_string();

    DecisionStump {
        feature_index: best_feature,
        threshold: best_threshold,
        left_class,
        right_class,
    }
}

/// Compute Gini impurity from class counts: 1 - sum((count/n)^2)
#[inline]
fn gini_from_counts(counts: &[u32], n: u32) -> f64 {
    if n == 0 {
        return 0.0;
    }
    let n_f = n as f64;
    let sum_sq: f64 = counts
        .iter()
        .map(|&c| {
            let p = c as f64 / n_f;
            p * p
        })
        .sum();
    1.0 - sum_sq
}

/// True when the sample falls on the stump's left branch
/// (feature value <= threshold), false for the right branch.
fn predict_stump(stump: &DecisionStump, features: &[f64]) -> bool {
    let value = features[stump.feature_index];
    value <= stump.threshold
}

/// Most frequent label in `labels`, or `None` for an empty slice.
/// Ties are broken arbitrarily (hash-map iteration order).
fn majority_class<'a>(labels: &[&'a str]) -> Option<&'a str> {
    let mut tally: HashMap<&str, usize> = HashMap::new();
    for &label in labels {
        *tally.entry(label).or_default() += 1;
    }
    tally
        .into_iter()
        .max_by_key(|&(_, n)| n)
        .map(|(label, _)| label)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Reproducible config: `n` trees, fixed seed 42.
    fn seeded_config(n: usize) -> TimeSeriesForestConfig {
        let mut config = TimeSeriesForestConfig::new(n);
        config.random_seed = Some(42);
        config
    }

    /// Two samples of class "A" followed by two of class "B".
    fn labels() -> Vec<String> {
        ["A", "A", "B", "B"].iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_tsf_basic() {
        let series = vec![
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0],
            vec![7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0],
            vec![8.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0],
        ];
        let y = labels();
        let fitted = TimeSeriesForest::fit(&seeded_config(10), &series, &y);
        let predictions = TimeSeriesForest::predict(&fitted, &series);
        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_tsf_score() {
        let series = vec![
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 6.0],
            vec![5.0, 4.0, 3.0, 2.0, 1.0, 0.0],
            vec![6.0, 4.0, 3.0, 2.0, 1.0, 0.0],
        ];
        let y = labels();
        let fitted = TimeSeriesForest::fit(&seeded_config(20), &series, &y);
        let score = TimeSeriesForest::score(&fitted, &series, &y);
        assert!((0.0..=1.0).contains(&score));
    }

    #[test]
    fn test_interval_features() {
        let sample = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let intervals = vec![(0, 3), (2, 5)];
        let features = extract_interval_features(&sample, &intervals);
        // 2 intervals * 3 features = 6
        assert_eq!(features.len(), 6);
        // First interval [1,2,3]: mean=2, std=sqrt(2/3), slope=1
        assert!((features[0] - 2.0).abs() < 1e-10);
    }
}