use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::collections::HashMap;
/// Hyper-parameters for [`Tsbf::fit`].
#[derive(Debug, Clone)]
pub struct TsbfConfig {
    /// Number of estimators (decision stumps) in the ensemble.
    pub n_estimators: usize,
    /// Number of random intervals each estimator samples.
    pub n_intervals: usize,
    /// Minimum length, in timestamps, of every sampled interval.
    pub min_interval_length: usize,
    /// Seed for interval sampling; `None` means `fit` seeds via `StdRng::from_entropy`.
    pub random_seed: Option<u64>,
}

impl TsbfConfig {
    /// Intervals sampled per estimator when built via [`TsbfConfig::new`].
    const DEFAULT_N_INTERVALS: usize = 10;
    /// Minimum interval length when built via [`TsbfConfig::new`].
    const DEFAULT_MIN_INTERVAL_LENGTH: usize = 3;

    /// Creates a config with `n_estimators` estimators, 10 intervals per estimator,
    /// a minimum interval length of 3 and no fixed seed.
    pub fn new(n_estimators: usize) -> Self {
        Self {
            n_estimators,
            n_intervals: Self::DEFAULT_N_INTERVALS,
            min_interval_length: Self::DEFAULT_MIN_INTERVAL_LENGTH,
            random_seed: None,
        }
    }
}
/// One fitted ensemble member: a set of sampled intervals plus a single-threshold
/// decision stump over the features `extract_features` computes for those intervals.
#[derive(Debug, Clone)]
pub(crate) struct FeatureExtractor {
/// Half-open `(start, end)` ranges into a sample (sliced as `sample[start..end]`).
intervals: Vec<(usize, usize)>,
/// Index into the vector returned by `extract_features(sample, &intervals)`.
feature_index: usize,
/// Split point: the selected feature is compared with `<=` against this value.
threshold: f64,
/// Predicted label when the selected feature is `<= threshold`.
left_class: String,
/// Predicted label when the selected feature is `> threshold`.
right_class: String,
}
/// A fitted ensemble produced by `Tsbf::fit`; pass it to `Tsbf::predict` or `Tsbf::score`.
#[derive(Debug, Clone)]
pub struct TsbfFitted {
/// One entry per estimator; each casts one vote per sample in `Tsbf::predict`.
pub(crate) extractors: Vec<FeatureExtractor>,
}
/// Ensemble classifier for equal-length series, named after Time Series Bag-of-Features
/// (TSBF).
///
/// This implementation is a simplified ensemble of interval-feature decision stumps. Each
/// estimator samples random intervals, summarises them with `extract_features`, and keeps
/// the single best Gini split found by `find_best_split`. Prediction is a majority vote
/// over all estimators.
pub struct Tsbf;

impl Tsbf {
    /// Fits `config.n_estimators` decision stumps on the labelled series `x` / `y`.
    ///
    /// Each stump draws `config.n_intervals` random intervals containing at least
    /// `config.min_interval_length` points (a value of 0 is treated as 1). It computes
    /// interval features for every sample and keeps the best threshold split. Set
    /// `config.random_seed` for reproducible interval sampling.
    ///
    /// # Panics
    ///
    /// Panics if any of the following hold:
    /// - `x` is empty;
    /// - `x` and `y` differ in length;
    /// - the samples do not all have the same length;
    /// - `config.n_estimators` or `config.n_intervals` is zero;
    /// - the series are shorter than the minimum interval length.
    pub fn fit(config: &TsbfConfig, x: &[Vec<f64>], y: &[String]) -> TsbfFitted {
        assert!(!x.is_empty(), "Input must have at least one sample");
        assert_eq!(x.len(), y.len(), "X and y must have same length");
        assert!(config.n_estimators > 0, "n_estimators must be at least 1");
        assert!(config.n_intervals > 0, "n_intervals must be at least 1");
        let n_timestamps = x[0].len();
        assert!(
            x.iter().all(|sample| sample.len() == n_timestamps),
            "All samples must have the same length"
        );
        // A zero-length interval would yield NaN statistics, so require at least one point.
        let min_len = config.min_interval_length.max(1);
        assert!(
            min_len <= n_timestamps,
            "Series length ({}) is shorter than min_interval_length ({})",
            n_timestamps,
            min_len
        );
        let mut rng = match config.random_seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_entropy(),
        };
        let extractors: Vec<FeatureExtractor> = (0..config.n_estimators)
            .map(|_| {
                let intervals: Vec<(usize, usize)> = (0..config.n_intervals)
                    .map(|_| {
                        // Inclusive bound: `n_timestamps - min_len` is the last start that
                        // still fits, and the only one when the two are equal.
                        let start = rng.gen_range(0..=n_timestamps - min_len);
                        let end = rng.gen_range(start + min_len..=n_timestamps);
                        (start, end)
                    })
                    .collect();
                let features: Vec<Vec<f64>> = x
                    .iter()
                    .map(|sample| extract_features(sample, &intervals))
                    .collect();
                let (feature_index, threshold, left_class, right_class) =
                    find_best_split(&features, y);
                FeatureExtractor {
                    intervals,
                    feature_index,
                    threshold,
                    left_class,
                    right_class,
                }
            })
            .collect();
        TsbfFitted { extractors }
    }

    /// Predicts a label for every sample in `x` by majority vote over the fitted stumps.
    ///
    /// Vote ties go to the lexicographically smallest label, so predictions are
    /// reproducible.
    ///
    /// # Panics
    ///
    /// Panics if `fitted` has no estimators, or if a sample is shorter than the fitted
    /// intervals.
    pub fn predict(fitted: &TsbfFitted, x: &[Vec<f64>]) -> Vec<String> {
        x.iter()
            .map(|sample| {
                let mut votes: HashMap<&str, usize> = HashMap::new();
                for ext in &fitted.extractors {
                    let features = extract_features(sample, &ext.intervals);
                    let pred = if features[ext.feature_index] <= ext.threshold {
                        ext.left_class.as_str()
                    } else {
                        ext.right_class.as_str()
                    };
                    *votes.entry(pred).or_insert(0) += 1;
                }
                // HashMap iteration order is randomised per map, so break ties
                // explicitly instead of taking whichever maximum comes last.
                votes
                    .into_iter()
                    .max_by(|(class_a, count_a), (class_b, count_b)| {
                        count_a.cmp(count_b).then_with(|| class_b.cmp(class_a))
                    })
                    .map(|(class, _)| class.to_string())
                    .expect("TsbfFitted must contain at least one estimator")
            })
            .collect()
    }

    /// Returns the fraction of samples in `x` whose prediction equals the label in `y`.
    ///
    /// Empty input yields NaN (0 / 0).
    ///
    /// # Panics
    ///
    /// Panics if `x` and `y` differ in length, and in the same cases as [`Tsbf::predict`].
    pub fn score(fitted: &TsbfFitted, x: &[Vec<f64>], y: &[String]) -> f64 {
        // Without this check `zip` would silently drop the unmatched tail.
        assert_eq!(x.len(), y.len(), "X and y must have same length");
        let predictions = Self::predict(fitted, x);
        let correct = predictions
            .iter()
            .zip(y)
            .filter(|(predicted, truth)| predicted == truth)
            .count();
        correct as f64 / y.len() as f64
    }
}
/// Summarises every `(start, end)` interval of `sample` with five statistics.
///
/// For each interval, the slice `sample[start..end]` contributes, in order:
/// 1. the mean;
/// 2. the population standard deviation (divided by the interval length);
/// 3. the least-squares slope against the in-interval position `0, 1, ...` (0 when the
///    interval has a single point);
/// 4. the minimum;
/// 5. the maximum.
///
/// The result has `5 * intervals.len()` entries, grouped by interval.
///
/// # Panics
///
/// Panics if an interval is out of bounds for `sample`.
/// NOTE: an empty interval (`start == end`) produces NaN mean/std and infinite min/max.
fn extract_features(sample: &[f64], intervals: &[(usize, usize)]) -> Vec<f64> {
    let mut out = Vec::with_capacity(5 * intervals.len());
    for window in intervals.iter().map(|&(start, end)| &sample[start..end]) {
        let len = window.len() as f64;
        let mean = window.iter().sum::<f64>() / len;
        let std = (window.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / len).sqrt();
        // Positions are centred on their own mean, so the slope is cov / spread.
        let centre = (len - 1.0) / 2.0;
        let (cov, spread) = window
            .iter()
            .enumerate()
            .fold((0.0, 0.0), |(cov, spread), (i, &v)| {
                let dx = i as f64 - centre;
                (cov + dx * (v - mean), spread + dx * dx)
            });
        let slope = if spread > 0.0 { cov / spread } else { 0.0 };
        let (min, max) = window
            .iter()
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
                (lo.min(v), hi.max(v))
            });
        out.extend_from_slice(&[mean, std, slope, min, max]);
    }
    out
}
/// Finds the single-feature threshold split of `features` with the lowest weighted Gini
/// impurity of the labels `y`.
///
/// `features[s]` is the feature vector of sample `s` and `y[s]` its label. Candidate
/// thresholds lie between consecutive distinct sorted values of one feature. Samples with
/// `feature <= threshold` fall on the left. On equal impurity the first candidate wins
/// (lowest feature index, then lowest threshold).
///
/// Returns `(feature_index, threshold, left_class, right_class)`, where each class is the
/// majority label on that side. If no feature has two separable values, there is nothing
/// to split on. In that case both classes are the overall majority label, and
/// `(0, 0.0)` is returned as a placeholder split.
///
/// # Panics
///
/// Panics if `features` is empty or its rows contain no features.
fn find_best_split(features: &[Vec<f64>], y: &[String]) -> (usize, f64, String, String) {
    assert!(
        !features.is_empty() && !features[0].is_empty(),
        "find_best_split needs at least one sample with at least one feature"
    );
    debug_assert_eq!(features.len(), y.len(), "expected one label per sample");
    let n_features = features[0].len();
    let n_samples = features.len();
    // (weighted gini, feature index, threshold) of the best split seen so far.
    let mut best: Option<(f64, usize, f64)> = None;
    for f_idx in 0..n_features {
        let mut vals: Vec<(f64, &str)> = features
            .iter()
            .zip(y)
            .map(|(feat, label)| (feat[f_idx], label.as_str()))
            .collect();
        // `total_cmp` is a total order, so NaN features cannot make the sort panic.
        vals.sort_by(|a, b| a.0.total_cmp(&b.0));
        // Labels in sorted order, gathered once; each candidate split is two sub-slices.
        let labels: Vec<&str> = vals.iter().map(|&(_, label)| label).collect();
        for (i, pair) in vals.windows(2).enumerate() {
            let (lo, hi) = (pair[0].0, pair[1].0);
            let gap = hi - lo;
            // No threshold separates (near-)equal neighbours. A NaN gap (a NaN value, or
            // two equal infinities) cannot be split either.
            if gap.is_nan() || gap < 1e-15 {
                continue;
            }
            // Keep the threshold strictly below `hi`. For adjacent floats, or an infinite
            // `hi`, the midpoint can round up to `hi`. That would move `hi` to the left
            // when the split is re-applied below.
            let mid = (lo + hi) / 2.0;
            let threshold = if mid < hi { mid } else { lo };
            let (left, right) = labels.split_at(i + 1);
            let gini = (left.len() as f64 * gini_impurity(left)
                + right.len() as f64 * gini_impurity(right))
                / n_samples as f64;
            let improves = match best {
                Some((best_gini, _, _)) => gini < best_gini,
                None => true,
            };
            if improves {
                best = Some((gini, f_idx, threshold));
            }
        }
    }
    let (best_feature, best_threshold) = match best {
        Some((_, feature, threshold)) => (feature, threshold),
        None => {
            // Nothing to split on: predict the overall majority on both sides so the
            // placeholder threshold never influences predictions.
            let all_labels: Vec<&str> = y.iter().map(String::as_str).collect();
            let majority = majority_class(&all_labels)
                .unwrap_or(y[0].as_str())
                .to_string();
            return (0, 0.0, majority.clone(), majority);
        }
    };
    let mut left_labels: Vec<&str> = Vec::new();
    let mut right_labels: Vec<&str> = Vec::new();
    for (feat, label) in features.iter().zip(y) {
        if feat[best_feature] <= best_threshold {
            left_labels.push(label.as_str());
        } else {
            right_labels.push(label.as_str());
        }
    }
    // An accepted split has `lo <= threshold < hi`, so both sides are non-empty. The
    // `y[0]` fallback is purely defensive.
    let left_class = majority_class(&left_labels)
        .unwrap_or(y[0].as_str())
        .to_string();
    let right_class = majority_class(&right_labels)
        .unwrap_or(y[0].as_str())
        .to_string();
    (best_feature, best_threshold, left_class, right_class)
}
/// Gini impurity `1 - Σ p_k²` of a multiset of labels, where `p_k` is the fraction of
/// labels equal to class `k`.
///
/// Returns 0.0 for an empty slice or a pure slice.
fn gini_impurity(labels: &[&str]) -> f64 {
    if labels.is_empty() {
        return 0.0;
    }
    let mut counts: HashMap<&str, usize> = HashMap::new();
    for &label in labels {
        *counts.entry(label).or_insert(0) += 1;
    }
    // Σ p_k² = Σ c_k² / n². The integer sum is exact and independent of HashMap iteration
    // order, which is randomised per map. Summing float fractions in that order can round
    // differently for identical partitions, and so change which split wins.
    let sum_sq: usize = counts.values().map(|&c| c * c).sum();
    let n = labels.len() as f64;
    1.0 - sum_sq as f64 / (n * n)
}
/// Returns the most frequent label in `labels`, or `None` if the slice is empty.
///
/// Ties are broken in favour of the lexicographically smallest label, so the result is
/// reproducible. HashMap iteration order is randomised per map and must not decide it.
fn majority_class<'a>(labels: &[&'a str]) -> Option<&'a str> {
    let mut counts: HashMap<&'a str, usize> = HashMap::new();
    for &label in labels {
        *counts.entry(label).or_insert(0) += 1;
    }
    counts
        .into_iter()
        .max_by(|(label_a, count_a), (label_b, count_b)| {
            // Higher count wins; on equal counts the smaller label compares as "greater".
            count_a.cmp(count_b).then_with(|| label_b.cmp(label_a))
        })
        .map(|(label, _)| label)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into the owned labels `Tsbf::fit` expects.
    fn labels(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_tsbf_basic() {
        let config = TsbfConfig {
            random_seed: Some(42),
            ..TsbfConfig::new(10)
        };
        let x = vec![
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0],
            vec![7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0],
            vec![8.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0],
        ];
        let y = labels(&["A", "A", "B", "B"]);
        let fitted = Tsbf::fit(&config, &x, &y);
        let predictions = Tsbf::predict(&fitted, &x);
        assert_eq!(predictions.len(), 4);
        // Every interval's slope is positive for "A" and negative for "B". A zero-impurity
        // split therefore always exists, each stump classifies the training set perfectly,
        // and so does the vote.
        assert_eq!(predictions, y);
    }

    #[test]
    fn test_tsbf_score() {
        let config = TsbfConfig {
            random_seed: Some(42),
            ..TsbfConfig::new(20)
        };
        let x = vec![
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 6.0],
            vec![5.0, 4.0, 3.0, 2.0, 1.0, 0.0],
            vec![6.0, 4.0, 3.0, 2.0, 1.0, 0.0],
        ];
        let y = labels(&["A", "A", "B", "B"]);
        let fitted = Tsbf::fit(&config, &x, &y);
        let score = Tsbf::score(&fitted, &x, &y);
        assert!((0.0..=1.0).contains(&score));
        // Rising vs falling series are perfectly separable (see `test_tsbf_basic`).
        assert_eq!(score, 1.0);
    }

    #[test]
    fn test_tsbf_seed_is_reproducible() {
        let config = TsbfConfig {
            random_seed: Some(7),
            ..TsbfConfig::new(5)
        };
        let x = vec![
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 6.0],
            vec![5.0, 4.0, 3.0, 2.0, 1.0, 0.0],
            vec![6.0, 4.0, 3.0, 2.0, 1.0, 0.0],
        ];
        let y = labels(&["A", "A", "B", "B"]);
        let summarize = |fitted: &TsbfFitted| -> Vec<(Vec<(usize, usize)>, usize, f64)> {
            fitted
                .extractors
                .iter()
                .map(|ext| (ext.intervals.clone(), ext.feature_index, ext.threshold))
                .collect()
        };
        let first = Tsbf::fit(&config, &x, &y);
        let second = Tsbf::fit(&config, &x, &y);
        assert_eq!(summarize(&first), summarize(&second));
        assert_eq!(Tsbf::predict(&first, &x), Tsbf::predict(&second, &x));
    }
}