use burn::nn::Linear;
use burn::nn::LinearConfig;
use burn::prelude::*;
use burn::tensor::activation::softmax;
use rand::prelude::*;
use rand_chacha::ChaCha8Rng;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RocketConfig {
pub n_vars: usize,
pub seq_len: usize,
pub n_classes: usize,
pub n_kernels: usize,
pub kernel_lengths: Vec<usize>,
pub seed: u64,
}
impl Default for RocketConfig {
fn default() -> Self {
Self {
n_vars: 1,
seq_len: 100,
n_classes: 2,
n_kernels: 10000,
kernel_lengths: vec![7, 9, 11],
seed: 42,
}
}
}
impl RocketConfig {
pub fn new(n_vars: usize, seq_len: usize, n_classes: usize) -> Self {
Self {
n_vars,
seq_len,
n_classes,
..Default::default()
}
}
#[must_use]
pub fn with_n_kernels(mut self, n_kernels: usize) -> Self {
self.n_kernels = n_kernels;
self
}
#[must_use]
pub fn with_kernel_lengths(mut self, lengths: Vec<usize>) -> Self {
self.kernel_lengths = lengths;
self
}
pub fn init<B: Backend>(&self, device: &B::Device) -> Rocket<B> {
Rocket::new(self.clone(), device)
}
}
#[derive(Debug, Clone)]
pub struct RandomKernel {
pub weights: Vec<f32>,
pub length: usize,
pub dilation: usize,
pub padding: usize,
pub bias: f32,
}
#[derive(Debug, Clone)]
pub struct RocketFeatures {
pub kernels: Vec<RandomKernel>,
pub n_kernels: usize,
pub n_features: usize,
}
impl RocketFeatures {
pub fn new(config: &RocketConfig) -> Self {
let mut rng = ChaCha8Rng::seed_from_u64(config.seed);
let mut kernels = Vec::with_capacity(config.n_kernels);
for _ in 0..config.n_kernels {
let length = config.kernel_lengths[rng.gen_range(0..config.kernel_lengths.len())];
let mut weights: Vec<f32> = (0..length)
.map(|_| {
let u1: f32 = rng.gen();
let u2: f32 = rng.gen();
(-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos()
})
.collect();
let mean: f32 = weights.iter().sum::<f32>() / length as f32;
for w in &mut weights {
*w -= mean;
}
let max_dilation = ((config.seq_len - 1) / (length - 1)).max(1);
let dilation = rng.gen_range(1..=max_dilation);
let padding = if rng.gen::<bool>() {
(length - 1) * dilation / 2
} else {
0
};
let bias: f32 = rng.gen_range(-1.0..1.0);
kernels.push(RandomKernel {
weights,
length,
dilation,
padding,
bias,
});
}
Self {
kernels,
n_kernels: config.n_kernels,
n_features: config.n_kernels * 2, }
}
pub fn fit_biases<B: Backend>(&mut self, x: &Tensor<B, 3>) {
let [n_samples, n_vars, seq_len] = x.dims();
let x_data: Vec<f32> = x.to_data().to_vec().unwrap();
for kernel in &mut self.kernels {
let effective_len = (kernel.length - 1) * kernel.dilation + 1;
let output_len = if kernel.padding > 0 {
seq_len
} else {
seq_len.saturating_sub(effective_len - 1)
};
if output_len == 0 {
continue;
}
let mut all_outputs: Vec<f32> = Vec::new();
for b in 0..n_samples.min(100) {
for v in 0..n_vars {
for t in 0..output_len {
let mut conv_val = 0.0f32;
for (i, &w) in kernel.weights.iter().enumerate() {
let pos = t as i32 - kernel.padding as i32 + (i * kernel.dilation) as i32;
if pos >= 0 && (pos as usize) < seq_len {
let idx = b * n_vars * seq_len + v * seq_len + pos as usize;
conv_val += x_data[idx] * w;
}
}
all_outputs.push(conv_val);
}
}
}
if !all_outputs.is_empty() {
all_outputs.sort_by(|a, b| a.partial_cmp(b).unwrap());
kernel.bias = all_outputs[all_outputs.len() / 2];
}
}
}
pub fn extract<B: Backend>(&self, x: &Tensor<B, 3>) -> Tensor<B, 2> {
let [batch, n_vars, seq_len] = x.dims();
let device = x.device();
let x_data: Vec<f32> = x.to_data().to_vec().unwrap();
let mut features = vec![0.0f32; batch * self.n_features];
for b in 0..batch {
for (k_idx, kernel) in self.kernels.iter().enumerate() {
let effective_len = (kernel.length - 1) * kernel.dilation + 1;
let output_len = if kernel.padding > 0 {
seq_len
} else {
seq_len.saturating_sub(effective_len - 1)
};
if output_len == 0 {
continue;
}
let mut max_val = f32::NEG_INFINITY;
let mut ppv_sum = 0.0f32;
let mut count = 0;
for v in 0..n_vars {
for t in 0..output_len {
let mut conv_val = 0.0f32;
for (i, &w) in kernel.weights.iter().enumerate() {
let pos = t as i32 - kernel.padding as i32 + (i * kernel.dilation) as i32;
if pos >= 0 && (pos as usize) < seq_len {
let idx = b * n_vars * seq_len + v * seq_len + pos as usize;
conv_val += x_data[idx] * w;
}
}
if conv_val > max_val {
max_val = conv_val;
}
if conv_val > kernel.bias {
ppv_sum += 1.0;
}
count += 1;
}
}
let feat_idx = b * self.n_features + k_idx * 2;
features[feat_idx] = if max_val.is_finite() { max_val } else { 0.0 };
features[feat_idx + 1] = if count > 0 { ppv_sum / count as f32 } else { 0.0 };
}
}
Tensor::<B, 1>::from_floats(features.as_slice(), &device).reshape([batch, self.n_features])
}
}
#[derive(Module, Debug)]
pub struct Rocket<B: Backend> {
classifier: Linear<B>,
n_features: usize,
}
impl<B: Backend> Rocket<B> {
pub fn new(config: RocketConfig, device: &B::Device) -> Self {
let n_features = config.n_kernels * 2;
let classifier = LinearConfig::new(n_features, config.n_classes).init(device);
Self {
classifier,
n_features,
}
}
pub fn n_features(&self) -> usize {
self.n_features
}
pub fn forward(&self, features: Tensor<B, 2>) -> Tensor<B, 2> {
self.classifier.forward(features)
}
pub fn forward_probs(&self, features: Tensor<B, 2>) -> Tensor<B, 2> {
let logits = self.forward(features);
softmax(logits, 1)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rocket_config_default() {
let config = RocketConfig::default();
assert_eq!(config.n_kernels, 10000);
assert_eq!(config.kernel_lengths, vec![7, 9, 11]);
assert_eq!(config.seed, 42);
}
#[test]
fn test_rocket_config_builder() {
let config = RocketConfig::new(3, 200, 10)
.with_n_kernels(5000)
.with_kernel_lengths(vec![5, 7, 9, 11]);
assert_eq!(config.n_vars, 3);
assert_eq!(config.seq_len, 200);
assert_eq!(config.n_classes, 10);
assert_eq!(config.n_kernels, 5000);
assert_eq!(config.kernel_lengths.len(), 4);
}
#[test]
fn test_rocket_features_creation() {
let config = RocketConfig::new(1, 100, 2).with_n_kernels(100);
let features = RocketFeatures::new(&config);
assert_eq!(features.n_kernels, 100);
assert_eq!(features.n_features, 200); assert_eq!(features.kernels.len(), 100);
for kernel in &features.kernels {
assert!(config.kernel_lengths.contains(&kernel.length));
assert!(kernel.dilation >= 1);
assert!(kernel.weights.len() == kernel.length);
let mean: f32 = kernel.weights.iter().sum::<f32>() / kernel.length as f32;
assert!(mean.abs() < 1e-5, "Kernel weights should have zero mean");
}
}
#[test]
fn test_rocket_deterministic() {
let config = RocketConfig::new(1, 100, 2).with_n_kernels(100);
let features1 = RocketFeatures::new(&config);
let features2 = RocketFeatures::new(&config);
for (k1, k2) in features1.kernels.iter().zip(&features2.kernels) {
assert_eq!(k1.length, k2.length);
assert_eq!(k1.dilation, k2.dilation);
assert_eq!(k1.padding, k2.padding);
for (w1, w2) in k1.weights.iter().zip(&k2.weights) {
assert!((w1 - w2).abs() < 1e-6);
}
}
}
}