use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rand_distr::Normal;
/// Hyper-parameters for [`Rocket::fit`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RocketConfig {
    /// Number of random convolution kernels to sample. Each kernel contributes
    /// two features (max and proportion of positive values) per series.
    pub n_kernels: usize,
    /// Seed for kernel sampling. `Some(seed)` makes fitting reproducible;
    /// `None` seeds the RNG from system entropy.
    pub random_seed: Option<u64>,
}
impl RocketConfig {
    /// Builds a configuration that samples `n_kernels` kernels without a fixed
    /// seed, so every fit draws fresh random kernels.
    pub fn new(n_kernels: usize) -> Self {
        let random_seed = None;
        Self {
            random_seed,
            n_kernels,
        }
    }
}
/// One random convolution kernel produced by [`Rocket::fit`].
#[derive(Debug, Clone, PartialEq)]
pub struct RocketKernel {
    /// Kernel taps, applied in order.
    pub weights: Vec<f64>,
    /// Number of taps. Expected to equal `weights.len()` and to be at least 1,
    /// since it is used to compute the kernel's receptive-field span.
    pub length: usize,
    /// Constant added to every convolution output.
    pub bias: f64,
    /// Spacing (in samples) between consecutive taps; 1 means contiguous.
    pub dilation: usize,
    /// Number of implicit zeros added at EACH end of a series before convolving.
    pub padding: usize,
}

/// The kernels sampled by [`Rocket::fit`], ready for [`Rocket::transform`].
#[derive(Debug, Clone, PartialEq)]
pub struct RocketFitted {
    /// Kernels in sampling order. In each transformed row, feature `2 * i` is the
    /// max and feature `2 * i + 1` the PPV of `kernels[i]`.
    pub kernels: Vec<RocketKernel>,
}
/// ROCKET-style feature extractor built on random convolutional kernels.
///
/// `Rocket` holds no state. [`Rocket::fit`] samples kernels into a
/// [`RocketFitted`]. [`Rocket::transform`] maps each time series to two pooled
/// features per kernel: the maximum output and the proportion of positive outputs.
pub struct Rocket;

impl Rocket {
    /// Samples `config.n_kernels` random kernels.
    ///
    /// Kernel lengths are drawn from {7, 9, 11}. With `random_seed: Some(s)`, the
    /// RNG is seeded from `s`, so repeated fits yield identical kernels. With
    /// `None`, it is seeded from system entropy.
    pub fn fit(config: &RocketConfig) -> RocketFitted {
        let kernel_lengths = [7, 9, 11];
        let mut rng = config
            .random_seed
            .map_or_else(StdRng::from_entropy, StdRng::seed_from_u64);

        let mut kernels = Vec::with_capacity(config.n_kernels);
        for _ in 0..config.n_kernels {
            kernels.push(generate_kernel(&mut rng, &kernel_lengths));
        }
        RocketFitted { kernels }
    }

    /// Converts every series in `x` into a feature row using the fitted kernels.
    ///
    /// Each row has `2 * fitted.kernels.len()` entries: `(max, ppv)` for each kernel,
    /// in kernel order. Series may have different lengths. With the `parallel`
    /// feature, rows are computed with rayon; row order always matches `x`.
    ///
    /// # Panics
    /// Panics if `x` is empty.
    pub fn transform(fitted: &RocketFitted, x: &[Vec<f64>]) -> Vec<Vec<f64>> {
        assert!(!x.is_empty(), "Input must have at least one sample");

        let featurize = |series: &Vec<f64>| -> Vec<f64> {
            let mut row = Vec::with_capacity(2 * fitted.kernels.len());
            for kernel in fitted.kernels.iter() {
                let (max_val, ppv) = apply_kernel_features(series, kernel);
                row.extend_from_slice(&[max_val, ppv]);
            }
            row
        };

        Self::map_samples(x, featurize)
    }

    /// Fits kernels from `config` and immediately transforms `x` with them.
    pub fn fit_transform(config: &RocketConfig, x: &[Vec<f64>]) -> Vec<Vec<f64>> {
        Self::transform(&Self::fit(config), x)
    }

    /// Applies `f` to every sample in parallel, preserving input order.
    #[cfg(feature = "parallel")]
    fn map_samples<F>(x: &[Vec<f64>], f: F) -> Vec<Vec<f64>>
    where
        F: Fn(&Vec<f64>) -> Vec<f64> + Sync + Send,
    {
        use rayon::prelude::*;
        x.par_iter().map(f).collect()
    }

    /// Applies `f` to every sample sequentially, preserving input order.
    #[cfg(not(feature = "parallel"))]
    fn map_samples<F>(x: &[Vec<f64>], f: F) -> Vec<Vec<f64>>
    where
        F: Fn(&Vec<f64>) -> Vec<f64>,
    {
        x.iter().map(f).collect()
    }
}
/// Samples one random ROCKET kernel from `rng`.
///
/// Sampling scheme, in RNG draw order:
/// * `length` is drawn uniformly from `candidate_lengths`;
/// * `length` weights are drawn from N(0, 1) and then mean-centred;
/// * the bias is drawn from U(-1, 1);
/// * the dilation is `2^e`, with the integer exponent `e` drawn uniformly from
///   `0..=MAX_DILATION_EXP`. Unlike the ROCKET paper, this bound is fixed rather
///   than derived from the series length, because `fit` never sees the data;
/// * with probability 1/2 the kernel is padded by `((length - 1) * dilation) / 2`
///   zeros at each end, otherwise not at all.
///
/// # Panics
/// Panics if `candidate_lengths` is empty. Lengths are assumed to be at least 1.
fn generate_kernel(rng: &mut StdRng, candidate_lengths: &[usize]) -> RocketKernel {
    /// Largest dilation exponent; dilations range over 1, 2, 4, ..., 128.
    const MAX_DILATION_EXP: u32 = 7;

    let length = candidate_lengths[rng.gen_range(0..candidate_lengths.len())];

    let normal = Normal::new(0.0, 1.0).expect("N(0, 1) has a finite, positive std-dev");
    let mut weights: Vec<f64> = (0..length).map(|_| rng.sample(normal)).collect();
    let mean = weights.iter().sum::<f64>() / weights.len() as f64;
    for w in &mut weights {
        *w -= mean;
    }

    let bias: f64 = rng.gen_range(-1.0..1.0);

    let dilation_exp: u32 = rng.gen_range(0..=MAX_DILATION_EXP);
    let dilation = 2_usize.pow(dilation_exp);

    // ROCKET pads so the kernel's middle tap is centred on every time point. That
    // means HALF the span, per side. For the odd lengths used by `fit`, the output
    // then has the same length as the input. Padding by the full span would instead
    // compute a "full" convolution whose edge positions see mostly zeros.
    let padding = if rng.gen_bool(0.5) {
        ((length - 1) * dilation) / 2
    } else {
        0
    };

    RocketKernel {
        weights,
        length,
        bias,
        dilation,
        padding,
    }
}
/// Convolves `ts` with `kernel` and pools the result into the two ROCKET features.
///
/// Returns `(max, ppv)`: the largest convolution output and the fraction of outputs
/// that are strictly positive. The series is treated as if `kernel.padding` zeros
/// surrounded it on each side. Taps are `kernel.dilation` samples apart. Only
/// positions where the whole kernel fits inside the padded series are evaluated.
///
/// Convolution and pooling are fused, so no per-position output buffer is allocated.
/// If the padded series is shorter than the kernel span, there is no valid position.
/// The output is then treated as the single value `kernel.bias`, giving
/// `(bias, 1.0)` when the bias is positive and `(bias, 0.0)` otherwise.
///
/// `kernel.length` must be at least 1; a zero length underflows the span
/// computation, which panics in debug builds.
#[inline]
fn apply_kernel_features(ts: &[f64], kernel: &RocketKernel) -> (f64, f64) {
    let series_len = ts.len();
    let pad = kernel.padding;
    let span = (kernel.length - 1) * kernel.dilation + 1;
    let padded_len = series_len + 2 * pad;

    if span > padded_len {
        let ppv = if kernel.bias > 0.0 { 1.0 } else { 0.0 };
        return (kernel.bias, ppv);
    }
    let positions = padded_len - span + 1;

    // Convolution output at `start`, an offset into the padded series. The running
    // sum begins at the bias and adds the taps in order.
    let response_at = |start: usize| {
        kernel
            .weights
            .iter()
            .enumerate()
            .fold(kernel.bias, |acc, (tap, &weight)| {
                let shifted = (start + tap * kernel.dilation) as isize - pad as isize;
                // Taps that land in the zero padding contribute nothing.
                if shifted < 0 || (shifted as usize) >= series_len {
                    acc
                } else {
                    acc + weight * ts[shifted as usize]
                }
            })
    };

    let (max_val, positives) = (0..positions).map(response_at).fold(
        (f64::NEG_INFINITY, 0_usize),
        |(best, count), value| {
            let best = if value > best { value } else { best };
            let count = if value > 0.0 { count + 1 } else { count };
            (best, count)
        },
    );

    (max_val, positives as f64 / positions as f64)
}
/// Test-only reference convolution that returns every output value of `kernel`
/// applied to `ts`. The fused `apply_kernel_features` is checked against it.
///
/// It uses the same conventions as the fused version:
/// * `kernel.padding` implicit zeros on each side;
/// * taps `kernel.dilation` apart;
/// * only fully covered positions are evaluated;
/// * the one-element vector `[kernel.bias]` is returned when the padded series is
///   shorter than the kernel span.
#[cfg(test)]
fn apply_kernel(ts: &[f64], kernel: &RocketKernel) -> Vec<f64> {
    let span = (kernel.length - 1) * kernel.dilation + 1;
    let padded_len = ts.len() + 2 * kernel.padding;

    match padded_len.checked_sub(span) {
        None => vec![kernel.bias],
        Some(last_start) => (0..=last_start)
            .map(|start| {
                kernel
                    .weights
                    .iter()
                    .enumerate()
                    .filter_map(|(tap, &weight)| {
                        // Positions before the series start or past its end are padding.
                        let pos = (start + tap * kernel.dilation).checked_sub(kernel.padding)?;
                        ts.get(pos).map(|&value| weight * value)
                    })
                    .fold(kernel.bias, |acc, term| acc + term)
            })
            .collect(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with a fixed seed so kernel sampling is reproducible across runs.
    fn seeded_config(n_kernels: usize, seed: u64) -> RocketConfig {
        RocketConfig {
            n_kernels,
            random_seed: Some(seed),
        }
    }

    #[test]
    fn test_rocket_basic() {
        let rising: Vec<f64> = (0..10).map(f64::from).collect();
        let falling: Vec<f64> = rising.iter().rev().copied().collect();
        let features = Rocket::fit_transform(&seeded_config(10, 42), &[rising, falling]);
        assert_eq!(features.len(), 2);
        // Each of the 10 kernels yields a (max, PPV) pair.
        assert_eq!(features[0].len(), 20);
    }

    #[test]
    fn test_rocket_deterministic() {
        let config = seeded_config(5, 123);
        let x = vec![(1..=8).map(f64::from).collect::<Vec<f64>>()];
        let first = Rocket::fit_transform(&config, &x);
        let second = Rocket::fit_transform(&config, &x);
        assert_eq!(first, second);
    }

    #[test]
    fn test_rocket_features_range() {
        let x = vec![vec![
            0.0, 1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0, 5.0, -5.0, 6.0,
        ]];
        let features = Rocket::fit_transform(&seeded_config(20, 42), &x);
        let row = &features[0];
        for kernel_idx in 0..20 {
            // Odd positions hold each kernel's PPV.
            let ppv = row[2 * kernel_idx + 1];
            assert!((0.0..=1.0).contains(&ppv), "PPV should be in [0, 1]");
        }
    }

    #[test]
    fn test_rocket_fit_then_transform() {
        let fitted = Rocket::fit(&seeded_config(5, 42));
        let x = vec![(1..=10).map(f64::from).collect::<Vec<f64>>()];
        let features = Rocket::transform(&fitted, &x);
        // 5 kernels x (max, PPV).
        assert_eq!(features[0].len(), 10);
    }

    #[test]
    fn test_apply_kernel() {
        let alternating = RocketKernel {
            weights: vec![1.0, -1.0, 1.0],
            length: 3,
            bias: 0.0,
            dilation: 1,
            padding: 0,
        };
        let out = apply_kernel(&[1.0, 2.0, 3.0, 4.0, 5.0], &alternating);
        assert_eq!(out.len(), 3);
        // 1 - 2 + 3
        assert!((out[0] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_apply_kernel_dilated() {
        let pair = RocketKernel {
            weights: vec![1.0, 1.0],
            length: 2,
            bias: 0.0,
            dilation: 2,
            padding: 0,
        };
        let out = apply_kernel(&[1.0, 0.0, 2.0, 0.0, 3.0], &pair);
        assert_eq!(out.len(), 3);
        // ts[0] + ts[2] and ts[2] + ts[4].
        assert!((out[0] - 3.0).abs() < 1e-10);
        assert!((out[2] - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_apply_kernel_features_matches_apply_kernel() {
        let fitted = Rocket::fit(&seeded_config(50, 42));
        let ts = [
            0.5, -1.2, 3.1, 0.0, -0.7, 2.4, 1.1, -0.3, 0.8, -1.5, 2.0, 0.3, -0.9, 1.7, -0.1, 0.6,
            -2.0, 1.3, -0.5, 0.2,
        ];
        for (i, kernel) in fitted.kernels.iter().enumerate() {
            let reference = apply_kernel(&ts, kernel);
            let expected_max = reference
                .iter()
                .copied()
                .fold(f64::NEG_INFINITY, f64::max);
            let positives = reference.iter().filter(|&&v| v > 0.0).count();
            let expected_ppv = positives as f64 / reference.len() as f64;

            let (actual_max, actual_ppv) = apply_kernel_features(&ts, kernel);
            assert!(
                (actual_max - expected_max).abs() < 1e-10,
                "kernel {i}: max mismatch: fused={actual_max}, expected={expected_max}"
            );
            assert!(
                (actual_ppv - expected_ppv).abs() < 1e-10,
                "kernel {i}: ppv mismatch: fused={actual_ppv}, expected={expected_ppv}"
            );
        }
    }
}