use std::collections::HashMap;
use std::time::Instant;
use rand::rngs::StdRng;
use rand::SeedableRng;
use rand_distr::{Distribution, Normal};
use oxits::approximation::dft::{Dft, DftConfig};
use oxits::approximation::paa::{Paa, PaaConfig};
use oxits::approximation::sax::{Sax, SaxConfig};
use oxits::approximation::sfa::Sfa;
use oxits::approximation::sfa::SfaConfig;
use oxits::bag_of_words::bag_of_words::{BagOfWords, BagOfWordsConfig};
use oxits::classification::bossvs::{Bossvs, BossvsConfig};
use oxits::classification::knn::{EuclideanMetric, Knn, KnnConfig};
use oxits::classification::learning_shapelets::{LearningShapelets, LearningShapeletsConfig};
use oxits::classification::saxvsm::{Saxvsm, SaxvsmConfig};
use oxits::classification::time_series_forest::{TimeSeriesForest, TimeSeriesForestConfig};
use oxits::classification::tsbf::{Tsbf, TsbfConfig};
use oxits::decomposition::ssa::{Ssa, SsaConfig};
use oxits::image::gaf::{Gaf, GafConfig};
use oxits::image::mtf::{Mtf, MtfConfig};
use oxits::image::recurrence_plot::{RecurrencePlot, RecurrencePlotConfig};
use oxits::metrics::dtw::{dtw_classic, dtw_fast, dtw_sakoe_chiba};
use oxits::preprocessing::discretizer::{KBinsDiscretizer, KBinsDiscretizerConfig};
use oxits::preprocessing::scaler::{
MinMaxScaler, MinMaxScalerConfig, StandardScaler, StandardScalerConfig,
};
use oxits::transformation::bag_of_patterns::{BagOfPatterns, BagOfPatternsConfig};
use oxits::transformation::boss::{Boss, BossConfig};
use oxits::transformation::rocket::{Rocket, RocketConfig};
use oxits::transformation::shapelet_transform::{ShapeletTransform, ShapeletTransformConfig};
use oxits::FittableTransformer;
use oxits::GafMethod;
use oxits::Transformer;
/// Number of timed iterations per benchmark; odd so the median is a single sample.
const N_RUNS: usize = 51;
/// Untimed iterations executed before measurement starts, so one-off start-up
/// costs are excluded from the recorded timings.
const N_WARMUP: usize = 5;

/// Runs `f` `N_WARMUP` times untimed, then `N_RUNS` times timed, and returns the
/// median wall-clock duration of a single call, in seconds.
///
/// Whatever the closure returns is passed through [`std::hint::black_box`], so a
/// closure that returns its computed value cannot have that work optimised away.
/// Closures returning `()` keep working unchanged.
fn median_time<R, F: FnMut() -> R>(mut f: F) -> f64 {
    for _ in 0..N_WARMUP {
        std::hint::black_box(f());
    }
    let samples: Vec<f64> = (0..N_RUNS)
        .map(|_| {
            let start = Instant::now();
            std::hint::black_box(f());
            start.elapsed().as_secs_f64()
        })
        .collect();
    median_seconds(samples)
}

/// Returns the median of `samples`: the middle value for an odd count, or the mean
/// of the two middle values for an even count.
///
/// # Panics
/// Panics if `samples` is empty.
fn median_seconds(mut samples: Vec<f64>) -> f64 {
    assert!(!samples.is_empty(), "median of an empty sample set is undefined");
    // `total_cmp` gives a total order on f64, so no `unwrap` on a partial comparison.
    samples.sort_unstable_by(f64::total_cmp);
    let mid = samples.len() / 2;
    if samples.len() % 2 == 1 {
        samples[mid]
    } else {
        (samples[mid - 1] + samples[mid]) / 2.0
    }
}
/// Builds `n_samples` series of `n_timestamps` values, each drawn from a standard
/// normal distribution (mean 0, std-dev 1) using an RNG seeded with `seed`.
///
/// Values are drawn row by row, timestamp by timestamp, so the same seed always
/// reproduces the same data set.
fn generate_data(seed: u64, n_samples: usize, n_timestamps: usize) -> Vec<Vec<f64>> {
    let mut generator = StdRng::seed_from_u64(seed);
    let standard_normal = Normal::new(0.0, 1.0).unwrap();
    let mut rows = Vec::with_capacity(n_samples);
    for _ in 0..n_samples {
        let mut row: Vec<f64> = Vec::with_capacity(n_timestamps);
        for _ in 0..n_timestamps {
            row.push(standard_normal.sample(&mut generator));
        }
        rows.push(row);
    }
    rows
}
/// Returns `n_samples` class labels: the first `n_samples / 2` are `"A"` and the
/// remainder are `"B"` (so an odd count gets one extra `"B"`).
fn generate_labels(n_samples: usize) -> Vec<String> {
    let split = n_samples / 2;
    (0..n_samples)
        .map(|idx| if idx < split { "A" } else { "B" })
        .map(String::from)
        .collect()
}
/// Runs every benchmark section and writes the timings (in seconds, keyed by
/// benchmark name) as pretty-printed JSON to `test_harness/benchmark_results_rust.json`.
/// It then prints a per-benchmark summary in milliseconds, sorted by name.
///
/// Every timed closure passes the computed result through `black_box`. Without it,
/// the optimiser could treat the discarded output as dead code and skip the work
/// being measured.
///
/// # Panics
/// Panics if the results cannot be serialised or the output file cannot be written
/// (for example when `test_harness/` does not exist relative to the working directory).
fn main() {
    use std::collections::BTreeMap;
    use std::hint::black_box;

    let mut results: HashMap<String, f64> = HashMap::new();

    println!("Benchmarking preprocessing...");
    {
        let x = generate_data(42, 100, 500);
        let config = StandardScalerConfig::new();
        let t = median_time(|| {
            black_box(StandardScaler::transform(&config, &x));
        });
        results.insert("preprocessing_standard_scaler".into(), t);
        let config = MinMaxScalerConfig::new();
        let t = median_time(|| {
            black_box(MinMaxScaler::transform(&config, &x));
        });
        results.insert("preprocessing_minmax_scaler".into(), t);
        let config = KBinsDiscretizerConfig::new(4);
        let t = median_time(|| {
            black_box(KBinsDiscretizer::transform(&config, &x));
        });
        results.insert("preprocessing_discretizer".into(), t);
    }

    println!("Benchmarking approximation...");
    {
        let x = generate_data(42, 100, 500);
        let config = PaaConfig::new(50);
        let t = median_time(|| {
            black_box(Paa::transform(&config, &x));
        });
        results.insert("approximation_paa".into(), t);
        let config = SaxConfig::new(4);
        let t = median_time(|| {
            black_box(Sax::transform(&config, &x));
        });
        results.insert("approximation_sax".into(), t);
        let config = DftConfig {
            n_coefs: Some(20),
            ..DftConfig::new()
        };
        let t = median_time(|| {
            black_box(Dft::transform(&config, &x));
        });
        results.insert("approximation_dft".into(), t);
        let config = SfaConfig {
            n_coefs: Some(20),
            n_bins: 4,
            ..SfaConfig::new()
        };
        let t = median_time(|| {
            black_box(Sfa::fit_transform(&config, &x, None));
        });
        results.insert("approximation_sfa".into(), t);
    }

    println!("Benchmarking metrics...");
    {
        for n in [100, 500, 1000] {
            let data = generate_data(42, 2, n);
            let a = &data[0];
            let b = &data[1];
            let t = median_time(|| {
                black_box(dtw_classic(a, b));
            });
            results.insert(format!("metrics_dtw_classic_{n}"), t);
        }
        let data = generate_data(42, 2, 500);
        let a = &data[0];
        let b = &data[1];
        let t = median_time(|| {
            black_box(dtw_sakoe_chiba(a, b, 50));
        });
        results.insert("metrics_dtw_sakoe_chiba".into(), t);
        let t = median_time(|| {
            black_box(dtw_fast(a, b, 4, 2));
        });
        results.insert("metrics_dtw_fast".into(), t);
    }

    println!("Benchmarking bag_of_words...");
    {
        let x = generate_data(42, 50, 200);
        let config = BagOfWordsConfig::new(10, 4);
        let t = median_time(|| {
            black_box(BagOfWords::transform(&config, &x));
        });
        results.insert("bag_of_words".into(), t);
    }

    println!("Benchmarking image...");
    {
        let x = generate_data(42, 50, 100);
        let config = GafConfig::new();
        let t = median_time(|| {
            black_box(Gaf::transform(&config, &x));
        });
        results.insert("image_gasf".into(), t);
        let config = GafConfig {
            method: GafMethod::Difference,
            ..GafConfig::new()
        };
        let t = median_time(|| {
            black_box(Gaf::transform(&config, &x));
        });
        results.insert("image_gadf".into(), t);
        let config = MtfConfig::new();
        let t = median_time(|| {
            black_box(Mtf::transform(&config, &x));
        });
        results.insert("image_mtf".into(), t);
        let config = RecurrencePlotConfig::new();
        let t = median_time(|| {
            black_box(RecurrencePlot::transform(&config, &x));
        });
        results.insert("image_recurrence_plot".into(), t);
    }

    println!("Benchmarking decomposition...");
    {
        let x = generate_data(42, 20, 200);
        let config = SsaConfig::new(10);
        let t = median_time(|| {
            black_box(Ssa::transform(&config, &x));
        });
        results.insert("decomposition_ssa".into(), t);
    }

    println!("Benchmarking transformation...");
    {
        let x = generate_data(42, 50, 300);
        let y = generate_labels(50);
        let config = BossConfig::new(10, 4);
        let t = median_time(|| {
            black_box(Boss::fit_transform(&config, &x, None));
        });
        results.insert("transformation_boss".into(), t);
        let config = RocketConfig {
            n_kernels: 500,
            random_seed: Some(42),
        };
        let t = median_time(|| {
            black_box(Rocket::fit_transform(&config, &x));
        });
        results.insert("transformation_rocket".into(), t);
        let config = ShapeletTransformConfig {
            n_shapelets: 5,
            random_seed: Some(42),
            ..ShapeletTransformConfig::new(5)
        };
        let t = median_time(|| {
            black_box(ShapeletTransform::fit_transform(&config, &x, &y));
        });
        results.insert("transformation_shapelet".into(), t);
        let config = BagOfPatternsConfig::new(10, 4);
        let t = median_time(|| {
            black_box(BagOfPatterns::transform(&config, &x));
        });
        results.insert("transformation_bag_of_patterns".into(), t);
    }

    println!("Benchmarking classification...");
    {
        let x_train = generate_data(42, 50, 200);
        let y_train = generate_labels(50);
        let x_test = generate_data(99, 20, 200);
        let knn_config = KnnConfig::new(3);
        let t = median_time(|| {
            let fitted = Knn::fit(&knn_config, &x_train, &y_train, EuclideanMetric);
            black_box(Knn::predict(&fitted, &x_test));
        });
        results.insert("classification_knn_euclidean".into(), t);
        let config = BossvsConfig::new(10);
        let t = median_time(|| {
            let fitted = Bossvs::fit(&config, &x_train, &y_train);
            black_box(Bossvs::predict(&fitted, &x_test));
        });
        results.insert("classification_bossvs".into(), t);
        let config = SaxvsmConfig::new(10, 4);
        let t = median_time(|| {
            let fitted = Saxvsm::fit(&config, &x_train, &y_train);
            black_box(Saxvsm::predict(&fitted, &x_test));
        });
        results.insert("classification_saxvsm".into(), t);
        let config = TimeSeriesForestConfig {
            n_estimators: 50,
            random_seed: Some(42),
            ..TimeSeriesForestConfig::new(50)
        };
        let t = median_time(|| {
            let fitted = TimeSeriesForest::fit(&config, &x_train, &y_train);
            black_box(TimeSeriesForest::predict(&fitted, &x_test));
        });
        results.insert("classification_tsf".into(), t);
        let config = TsbfConfig {
            n_estimators: 50,
            random_seed: Some(42),
            ..TsbfConfig::new(50)
        };
        let t = median_time(|| {
            let fitted = Tsbf::fit(&config, &x_train, &y_train);
            black_box(Tsbf::predict(&fitted, &x_test));
        });
        results.insert("classification_tsbf".into(), t);
        let config = LearningShapeletsConfig {
            n_shapelets_per_size: 3,
            shapelet_sizes: vec![3, 5],
            learning_rate: 0.01,
            n_epochs: 50,
            random_seed: Some(42),
            ..LearningShapeletsConfig::new()
        };
        let t = median_time(|| {
            let fitted = LearningShapelets::fit(&config, &x_train, &y_train);
            black_box(LearningShapelets::predict(&fitted, &x_test));
        });
        results.insert("classification_learning_shapelets".into(), t);
    }

    // Serialise from an ordered map. HashMap iteration order varies between runs,
    // which would make the JSON file's key order unstable. The sorted order also
    // serves the console summary, so no extra sort or key clone is needed.
    let ordered: BTreeMap<&str, f64> = results.iter().map(|(k, v)| (k.as_str(), *v)).collect();
    let out_path = "test_harness/benchmark_results_rust.json";
    let json = serde_json::to_string_pretty(&ordered).expect("Failed to serialize results");
    std::fs::write(out_path, &json).expect("Failed to write results file");
    println!("\nWrote {} benchmarks to {}", ordered.len(), out_path);
    for (name, secs) in &ordered {
        println!(" {:<45} {:>10.3} ms", name, secs * 1000.0);
    }
}