oxits 0.1.0

Time series classification and transformation library for Rust
Documentation
//! Wall-clock benchmark of all oxits algorithms.
//!
//! Run:
//!     cargo run --release --example benchmark --features decomposition
//!
//! Outputs: test_harness/benchmark_results_rust.json

use std::collections::HashMap;
use std::time::Instant;

use rand::rngs::StdRng;
use rand::SeedableRng;
use rand_distr::{Distribution, Normal};

use oxits::approximation::dft::{Dft, DftConfig};
use oxits::approximation::paa::{Paa, PaaConfig};
use oxits::approximation::sax::{Sax, SaxConfig};
use oxits::approximation::sfa::Sfa;
use oxits::approximation::sfa::SfaConfig;
use oxits::bag_of_words::bag_of_words::{BagOfWords, BagOfWordsConfig};
use oxits::classification::bossvs::{Bossvs, BossvsConfig};
use oxits::classification::knn::{EuclideanMetric, Knn, KnnConfig};
use oxits::classification::learning_shapelets::{LearningShapelets, LearningShapeletsConfig};
use oxits::classification::saxvsm::{Saxvsm, SaxvsmConfig};
use oxits::classification::time_series_forest::{TimeSeriesForest, TimeSeriesForestConfig};
use oxits::classification::tsbf::{Tsbf, TsbfConfig};
use oxits::decomposition::ssa::{Ssa, SsaConfig};
use oxits::image::gaf::{Gaf, GafConfig};
use oxits::image::mtf::{Mtf, MtfConfig};
use oxits::image::recurrence_plot::{RecurrencePlot, RecurrencePlotConfig};
use oxits::metrics::dtw::{dtw_classic, dtw_fast, dtw_sakoe_chiba};
use oxits::preprocessing::discretizer::{KBinsDiscretizer, KBinsDiscretizerConfig};
use oxits::preprocessing::scaler::{
    MinMaxScaler, MinMaxScalerConfig, StandardScaler, StandardScalerConfig,
};
use oxits::transformation::bag_of_patterns::{BagOfPatterns, BagOfPatternsConfig};
use oxits::transformation::boss::{Boss, BossConfig};
use oxits::transformation::rocket::{Rocket, RocketConfig};
use oxits::transformation::shapelet_transform::{ShapeletTransform, ShapeletTransformConfig};
use oxits::FittableTransformer;
use oxits::GafMethod;
use oxits::Transformer;

const N_RUNS: usize = 51;
const N_WARMUP: usize = 5;

/// Time `f` and return its 25th-percentile (P25) wall-clock duration in seconds.
///
/// `f` is first called `N_WARMUP` times untimed, then `N_RUNS` more times with each call
/// timed individually. The timings are sorted and the value at index `N_RUNS / 4` is returned.
///
/// Despite its name this returns P25, not the median. P25 is more robust than the median for
/// workloads with upward-skewed noise (rayon scheduling). The name is kept so existing call
/// sites keep compiling.
///
/// Whatever `f` returns is passed through [`std::hint::black_box`]. This keeps the optimizer
/// from treating the measured call as dead code. Closures returning `()` work unchanged.
fn median_time<F: FnMut() -> R, R>(mut f: F) -> f64 {
    for _ in 0..N_WARMUP {
        std::hint::black_box(f());
    }
    let mut times: Vec<f64> = (0..N_RUNS)
        .map(|_| {
            let start = Instant::now();
            std::hint::black_box(f());
            start.elapsed().as_secs_f64()
        })
        .collect();
    // `total_cmp` is a total order on f64, so there is no `unwrap` panic path.
    times.sort_unstable_by(f64::total_cmp);
    times[N_RUNS / 4] // P25: filters upward noise from thread scheduling
}

fn generate_data(seed: u64, n_samples: usize, n_timestamps: usize) -> Vec<Vec<f64>> {
    let mut rng = StdRng::seed_from_u64(seed);
    let normal = Normal::new(0.0, 1.0).unwrap();
    (0..n_samples)
        .map(|_| (0..n_timestamps).map(|_| normal.sample(&mut rng)).collect())
        .collect()
}

fn generate_labels(n_samples: usize) -> Vec<String> {
    // Two-class labels: the first `n_samples / 2` entries are "A", the rest are "B".
    // For odd `n_samples`, class "B" receives the extra sample.
    let n_a = n_samples / 2;
    (0..n_samples)
        .map(|i| String::from(if i < n_a { "A" } else { "B" }))
        .collect()
}

/// Benchmark entry point: times each oxits algorithm and records the results.
///
/// Each recorded value is the P25 wall-clock time in seconds returned by `median_time`.
/// Results are written as pretty-printed JSON to `test_harness/benchmark_results_rust.json`,
/// with keys sorted by benchmark name so the file is stable between runs. A millisecond
/// summary is then printed.
///
/// The value produced by every measured call is routed through [`std::hint::black_box`]. This
/// stops the optimizer from discarding a call whose result the benchmark otherwise ignores.
///
/// # Panics
///
/// Panics if the output directory cannot be created, or if the results cannot be serialized
/// or written.
fn main() {
    use std::collections::BTreeMap;
    use std::hint::black_box;

    let mut results: HashMap<String, f64> = HashMap::new();

    // ── Preprocessing ──────────────────────────────────────────────
    println!("Benchmarking preprocessing...");
    {
        let x = generate_data(42, 100, 500);

        // StandardScaler
        let config = StandardScalerConfig::new();
        let t = median_time(|| {
            black_box(StandardScaler::transform(&config, &x));
        });
        results.insert("preprocessing_standard_scaler".into(), t);

        // MinMaxScaler
        let config = MinMaxScalerConfig::new();
        let t = median_time(|| {
            black_box(MinMaxScaler::transform(&config, &x));
        });
        results.insert("preprocessing_minmax_scaler".into(), t);

        // KBinsDiscretizer
        let config = KBinsDiscretizerConfig::new(4);
        let t = median_time(|| {
            black_box(KBinsDiscretizer::transform(&config, &x));
        });
        results.insert("preprocessing_discretizer".into(), t);
    }

    // ── Approximation ─────────────────────────────────────────────
    println!("Benchmarking approximation...");
    {
        let x = generate_data(42, 100, 500);

        // PAA
        let config = PaaConfig::new(50);
        let t = median_time(|| {
            black_box(Paa::transform(&config, &x));
        });
        results.insert("approximation_paa".into(), t);

        // SAX
        let config = SaxConfig::new(4);
        let t = median_time(|| {
            black_box(Sax::transform(&config, &x));
        });
        results.insert("approximation_sax".into(), t);

        // DFT
        let config = DftConfig {
            n_coefs: Some(20),
            ..DftConfig::new()
        };
        let t = median_time(|| {
            black_box(Dft::transform(&config, &x));
        });
        results.insert("approximation_dft".into(), t);

        // SFA (fit + transform)
        let config = SfaConfig {
            n_coefs: Some(20),
            n_bins: 4,
            ..SfaConfig::new()
        };
        let t = median_time(|| {
            black_box(Sfa::fit_transform(&config, &x, None));
        });
        results.insert("approximation_sfa".into(), t);
    }

    // ── Metrics (DTW) ─────────────────────────────────────────────
    println!("Benchmarking metrics...");
    {
        for n in [100, 500, 1000] {
            let data = generate_data(42, 2, n);
            let a = &data[0];
            let b = &data[1];
            let t = median_time(|| {
                black_box(dtw_classic(a, b));
            });
            results.insert(format!("metrics_dtw_classic_{n}"), t);
        }

        // Sakoe-Chiba (n=500)
        let data = generate_data(42, 2, 500);
        let a = &data[0];
        let b = &data[1];
        let t = median_time(|| {
            black_box(dtw_sakoe_chiba(a, b, 50));
        });
        results.insert("metrics_dtw_sakoe_chiba".into(), t);

        // Fast DTW (n=500)
        let t = median_time(|| {
            black_box(dtw_fast(a, b, 4, 2));
        });
        results.insert("metrics_dtw_fast".into(), t);
    }

    // ── Bag of Words ──────────────────────────────────────────────
    println!("Benchmarking bag_of_words...");
    {
        let x = generate_data(42, 50, 200);
        let config = BagOfWordsConfig::new(10, 4);
        let t = median_time(|| {
            black_box(BagOfWords::transform(&config, &x));
        });
        results.insert("bag_of_words".into(), t);
    }

    // ── Image ─────────────────────────────────────────────────────
    println!("Benchmarking image...");
    {
        let x = generate_data(42, 50, 100);

        // GASF
        let config = GafConfig::new();
        let t = median_time(|| {
            black_box(Gaf::transform(&config, &x));
        });
        results.insert("image_gasf".into(), t);

        // GADF
        let config = GafConfig {
            method: GafMethod::Difference,
            ..GafConfig::new()
        };
        let t = median_time(|| {
            black_box(Gaf::transform(&config, &x));
        });
        results.insert("image_gadf".into(), t);

        // MTF
        let config = MtfConfig::new();
        let t = median_time(|| {
            black_box(Mtf::transform(&config, &x));
        });
        results.insert("image_mtf".into(), t);

        // RecurrencePlot
        let config = RecurrencePlotConfig::new();
        let t = median_time(|| {
            black_box(RecurrencePlot::transform(&config, &x));
        });
        results.insert("image_recurrence_plot".into(), t);
    }

    // ── Decomposition ─────────────────────────────────────────────
    println!("Benchmarking decomposition...");
    {
        let x = generate_data(42, 20, 200);
        let config = SsaConfig::new(10);
        let t = median_time(|| {
            black_box(Ssa::transform(&config, &x));
        });
        results.insert("decomposition_ssa".into(), t);
    }

    // ── Transformation ────────────────────────────────────────────
    println!("Benchmarking transformation...");
    {
        let x = generate_data(42, 50, 300);
        let y = generate_labels(50);

        // BOSS (fit + transform)
        let config = BossConfig::new(10, 4);
        let t = median_time(|| {
            black_box(Boss::fit_transform(&config, &x, None));
        });
        results.insert("transformation_boss".into(), t);

        // ROCKET (fit + transform)
        let config = RocketConfig {
            n_kernels: 500,
            random_seed: Some(42),
        };
        let t = median_time(|| {
            black_box(Rocket::fit_transform(&config, &x));
        });
        results.insert("transformation_rocket".into(), t);

        // ShapeletTransform (fit + transform)
        let config = ShapeletTransformConfig {
            n_shapelets: 5,
            random_seed: Some(42),
            ..ShapeletTransformConfig::new(5)
        };
        let t = median_time(|| {
            black_box(ShapeletTransform::fit_transform(&config, &x, &y));
        });
        results.insert("transformation_shapelet".into(), t);

        // BagOfPatterns (transform)
        let config = BagOfPatternsConfig::new(10, 4);
        let t = median_time(|| {
            black_box(BagOfPatterns::transform(&config, &x));
        });
        results.insert("transformation_bag_of_patterns".into(), t);
    }

    // ── Classification ────────────────────────────────────────────
    println!("Benchmarking classification...");
    {
        let x_train = generate_data(42, 50, 200);
        let y_train = generate_labels(50);
        let x_test = generate_data(99, 20, 200);

        // KNN Euclidean (fit + predict)
        let knn_config = KnnConfig::new(3);
        let t = median_time(|| {
            let fitted = Knn::fit(&knn_config, &x_train, &y_train, EuclideanMetric);
            black_box(Knn::predict(&fitted, &x_test));
        });
        results.insert("classification_knn_euclidean".into(), t);

        // BOSSVS (fit + predict)
        let config = BossvsConfig::new(10);
        let t = median_time(|| {
            let fitted = Bossvs::fit(&config, &x_train, &y_train);
            black_box(Bossvs::predict(&fitted, &x_test));
        });
        results.insert("classification_bossvs".into(), t);

        // SAXVSM (fit + predict)
        let config = SaxvsmConfig::new(10, 4);
        let t = median_time(|| {
            let fitted = Saxvsm::fit(&config, &x_train, &y_train);
            black_box(Saxvsm::predict(&fitted, &x_test));
        });
        results.insert("classification_saxvsm".into(), t);

        // TimeSeriesForest (fit + predict)
        let config = TimeSeriesForestConfig {
            n_estimators: 50,
            random_seed: Some(42),
            ..TimeSeriesForestConfig::new(50)
        };
        let t = median_time(|| {
            let fitted = TimeSeriesForest::fit(&config, &x_train, &y_train);
            black_box(TimeSeriesForest::predict(&fitted, &x_test));
        });
        results.insert("classification_tsf".into(), t);

        // TSBF (fit + predict) — Rust-only, no pyts equivalent
        let config = TsbfConfig {
            n_estimators: 50,
            random_seed: Some(42),
            ..TsbfConfig::new(50)
        };
        let t = median_time(|| {
            let fitted = Tsbf::fit(&config, &x_train, &y_train);
            black_box(Tsbf::predict(&fitted, &x_test));
        });
        results.insert("classification_tsbf".into(), t);

        // LearningShapelets (fit + predict) — Rust-only, no pyts equivalent
        let config = LearningShapeletsConfig {
            n_shapelets_per_size: 3,
            shapelet_sizes: vec![3, 5],
            learning_rate: 0.01,
            n_epochs: 50,
            random_seed: Some(42),
            ..LearningShapeletsConfig::new()
        };
        let t = median_time(|| {
            let fitted = LearningShapelets::fit(&config, &x_train, &y_train);
            black_box(LearningShapelets::predict(&fitted, &x_test));
        });
        results.insert("classification_learning_shapelets".into(), t);
    }

    // ── Write results ─────────────────────────────────────────────
    let out_path = "test_harness/benchmark_results_rust.json";
    // Create the target directory up front. Otherwise a missing `test_harness/` only
    // surfaces as a write failure after every benchmark has already run.
    if let Some(dir) = std::path::Path::new(out_path).parent() {
        std::fs::create_dir_all(dir).expect("Failed to create output directory");
    }
    // Serialize through a sorted view. HashMap iteration order is randomized, so serializing
    // it directly gives a different key order on every run.
    let sorted: BTreeMap<&String, &f64> = results.iter().collect();
    let json = serde_json::to_string_pretty(&sorted).expect("Failed to serialize results");
    std::fs::write(out_path, &json).expect("Failed to write results file");

    println!("\nWrote {} benchmarks to {}", results.len(), out_path);

    // Print quick summary (the BTreeMap is already ordered by benchmark name).
    for (name, secs) in &sorted {
        println!("  {:<45} {:>10.3} ms", name, **secs * 1000.0);
    }
}