//! oxits 0.1.0
//!
//! Time series classification and transformation library for Rust.
use std::collections::HashMap;

use crate::core::traits::DistanceMetric;

/// K-Nearest Neighbors classifier for time series.
///
/// Stores training data and computes distances at prediction time.
/// Generic over any distance metric (DTW, Euclidean, etc.).

/// Configuration for the KNN classifier.
#[derive(Debug, Clone)]
pub struct KnnConfig {
    /// Number of neighbors (k) consulted during majority voting.
    /// Validated in `Knn::fit` (must be >= 1 and <= training set size).
    pub n_neighbors: usize,
}

impl KnnConfig {
    /// Create a config with the given neighbor count.
    ///
    /// No validation happens here; `Knn::fit` checks the value against
    /// the training set.
    pub fn new(n_neighbors: usize) -> Self {
        Self { n_neighbors }
    }
}

/// A fitted KNN model: the stored training set plus the metric and k.
///
/// KNN is a lazy learner — fitting only copies the data; all distance
/// computation happens at prediction time.
#[derive(Debug, Clone)]
pub struct KnnFitted<D: DistanceMetric> {
    /// Training samples (one time series per inner `Vec`).
    pub x_train: Vec<Vec<f64>>,
    /// Class label for each training sample (parallel to `x_train`).
    pub y_train: Vec<String>,
    /// Distance metric used to compare series.
    pub metric: D,
    /// Number of neighbors consulted during voting.
    pub n_neighbors: usize,
}

/// Namespace type providing the KNN `fit`/`predict`/`score` entry points.
pub struct Knn;

impl Knn {
    /// Fit the KNN classifier.
    ///
    /// KNN is a lazy learner: fitting only validates and stores the
    /// training data; all distance work happens in [`Knn::predict`].
    ///
    /// # Panics
    /// Panics if `x` is empty, if `x` and `y` differ in length, or if
    /// `config.n_neighbors` is 0 or exceeds the training set size.
    pub fn fit<D: DistanceMetric>(
        config: &KnnConfig,
        x: &[Vec<f64>],
        y: &[String],
        metric: D,
    ) -> KnnFitted<D> {
        assert!(!x.is_empty(), "Input must have at least one sample");
        assert_eq!(x.len(), y.len(), "X and y must have same length");
        assert!(config.n_neighbors >= 1, "n_neighbors must be >= 1");
        assert!(
            config.n_neighbors <= x.len(),
            "n_neighbors must not exceed training set size"
        );

        KnnFitted {
            x_train: x.to_vec(),
            y_train: y.to_vec(),
            metric,
            n_neighbors: config.n_neighbors,
        }
    }

    /// Predict a label for each test sample via majority vote among the
    /// `n_neighbors` nearest training samples.
    pub fn predict<D: DistanceMetric>(fitted: &KnnFitted<D>, x: &[Vec<f64>]) -> Vec<String> {
        x.iter()
            .map(|sample| predict_single(sample, fitted))
            .collect()
    }

    /// Compute classification accuracy (fraction of exact label matches).
    ///
    /// # Panics
    /// Panics if `x` and `y` differ in length or are empty. (Previously a
    /// length mismatch was silently truncated by `zip`, and an empty `y`
    /// produced `NaN` from the 0/0 division.)
    pub fn score<D: DistanceMetric>(fitted: &KnnFitted<D>, x: &[Vec<f64>], y: &[String]) -> f64 {
        assert_eq!(x.len(), y.len(), "X and y must have same length");
        assert!(!y.is_empty(), "Input must have at least one sample");
        let predictions = Self::predict(fitted, x);
        let correct = predictions
            .iter()
            .zip(y.iter())
            .filter(|(p, t)| p == t)
            .count();
        correct as f64 / y.len() as f64
    }
}

/// Classify one sample: find the `n_neighbors` nearest training samples
/// under the fitted metric and return the majority label.
fn predict_single<D: DistanceMetric>(sample: &[f64], fitted: &KnnFitted<D>) -> String {
    // Distance from the query to every training sample, paired with its label.
    let mut distances: Vec<(f64, &str)> = fitted
        .x_train
        .iter()
        .zip(fitted.y_train.iter())
        .map(|(train_sample, label)| (fitted.metric.distance(sample, train_sample), label.as_str()))
        .collect();

    // `total_cmp` imposes a total order even if a metric ever returns NaN;
    // the previous `partial_cmp(..).unwrap()` would panic in that case.
    distances.sort_unstable_by(|a, b| a.0.total_cmp(&b.0));

    // Vote among the k nearest neighbors.
    let mut votes: HashMap<&str, usize> = HashMap::new();
    for &(_, label) in distances.iter().take(fitted.n_neighbors) {
        *votes.entry(label).or_insert(0) += 1;
    }

    // Majority class. Ties are broken by the lexicographically smallest
    // label so the prediction is deterministic (HashMap iteration order
    // is not, which previously made tied votes return arbitrary labels).
    votes
        .into_iter()
        .max_by(|(la, ca), (lb, cb)| ca.cmp(cb).then_with(|| lb.cmp(la)))
        .map(|(label, _)| label.to_string())
        .expect("votes is non-empty: fit guarantees at least one training sample")
}

/// Plain Euclidean (L2) distance between two equal-length series.
#[derive(Debug, Clone)]
pub struct EuclideanMetric;

impl DistanceMetric for EuclideanMetric {
    /// Returns `sqrt(sum((a_i - b_i)^2))`.
    ///
    /// `zip` stops at the shorter input, so a length mismatch would be
    /// silently ignored in release builds; the debug assertion makes that
    /// programming error visible during development.
    fn distance(&self, a: &[f64], b: &[f64]) -> f64 {
        debug_assert_eq!(a.len(), b.len(), "series must have equal length");
        a.iter()
            .zip(b.iter())
            .map(|(&x, &y)| (x - y).powi(2))
            .sum::<f64>()
            .sqrt()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_knn_basic() {
        // Two "A" points near the origin, one "B" point further out.
        let train_x = vec![vec![0.0, 0.0], vec![1.0, 1.0], vec![2.0, 2.0]];
        let train_y = vec!["A".to_string(), "A".to_string(), "B".to_string()];
        let model = Knn::fit(&KnnConfig::new(1), &train_x, &train_y, EuclideanMetric);

        // [0.1, 0.1] is closest to [0, 0], so the 1-NN label is "A".
        assert_eq!(Knn::predict(&model, &[vec![0.1, 0.1]])[0], "A");

        // [1.9, 1.9] is closest to [2, 2], so the 1-NN label is "B".
        assert_eq!(Knn::predict(&model, &[vec![1.9, 1.9]])[0], "B");
    }

    #[test]
    fn test_knn_k3_voting() {
        // Three tight points labeled A, A, B plus one far-away B.
        let train_x = vec![vec![0.0], vec![0.1], vec![0.2], vec![10.0]];
        let train_y: Vec<String> = ["A", "A", "B", "B"].iter().map(|s| s.to_string()).collect();
        let model = Knn::fit(&KnnConfig::new(3), &train_x, &train_y, EuclideanMetric);

        // The 3 nearest to [0.05] vote A, A, B — majority is "A".
        assert_eq!(Knn::predict(&model, &[vec![0.05]])[0], "A");
    }

    #[test]
    fn test_knn_score() {
        let train_x = vec![vec![0.0], vec![10.0]];
        let train_y = vec!["A".to_string(), "B".to_string()];
        let model = Knn::fit(&KnnConfig::new(1), &train_x, &train_y, EuclideanMetric);

        // Scoring the training points themselves must be a perfect 1.0.
        let accuracy = Knn::score(&model, &train_x, &train_y);
        assert!((accuracy - 1.0).abs() < 1e-10);
    }
}