use burn::nn::Linear;
use burn::nn::LinearConfig;
use burn::prelude::*;
use burn::tensor::activation::softmax;
use rand::prelude::*;
use rand_chacha::ChaCha8Rng;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MiniRocketConfig {
pub n_vars: usize,
pub seq_len: usize,
pub n_classes: usize,
pub n_features: usize,
pub seed: u64,
}
impl Default for MiniRocketConfig {
fn default() -> Self {
Self {
n_vars: 1,
seq_len: 100,
n_classes: 2,
n_features: 10000,
seed: 42,
}
}
}
impl MiniRocketConfig {
pub fn new(n_vars: usize, seq_len: usize, n_classes: usize) -> Self {
Self {
n_vars,
seq_len,
n_classes,
..Default::default()
}
}
pub fn init<B: Backend>(&self, device: &B::Device) -> MiniRocket<B> {
MiniRocket::new(self.clone(), device)
}
}
pub const KERNEL_PATTERNS: [[i8; 9]; 84] = [
[-1, -1, -1, -1, -1, -1, -1, -1, 2],
[-1, -1, -1, -1, -1, -1, -1, 2, -1],
[-1, -1, -1, -1, -1, -1, 2, -1, -1],
[-1, -1, -1, -1, -1, 2, -1, -1, -1],
[-1, -1, -1, -1, 2, -1, -1, -1, -1],
[-1, -1, -1, 2, -1, -1, -1, -1, -1],
[-1, -1, 2, -1, -1, -1, -1, -1, -1],
[-1, 2, -1, -1, -1, -1, -1, -1, -1],
[2, -1, -1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1, 2, 2],
[-1, -1, -1, -1, -1, -1, 2, 2, -1],
[-1, -1, -1, -1, -1, 2, 2, -1, -1],
[-1, -1, -1, -1, 2, 2, -1, -1, -1],
[-1, -1, -1, 2, 2, -1, -1, -1, -1],
[-1, -1, 2, 2, -1, -1, -1, -1, -1],
[-1, 2, 2, -1, -1, -1, -1, -1, -1],
[2, 2, -1, -1, -1, -1, -1, -1, -1],
[2, -1, -1, -1, -1, -1, -1, -1, 2],
[-1, -1, -1, -1, -1, -1, 2, -1, 2],
[-1, -1, -1, -1, -1, 2, -1, 2, -1],
[-1, -1, -1, -1, 2, -1, 2, -1, -1],
[-1, -1, -1, 2, -1, 2, -1, -1, -1],
[-1, -1, 2, -1, 2, -1, -1, -1, -1],
[-1, 2, -1, 2, -1, -1, -1, -1, -1],
[2, -1, 2, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, 2, 2, 2, -1],
[-1, -1, -1, -1, 2, 2, 2, -1, -1],
[-1, -1, -1, 2, 2, 2, -1, -1, -1],
[-1, -1, 2, 2, 2, -1, -1, -1, -1],
[-1, 2, 2, 2, -1, -1, -1, -1, -1],
[2, 2, 2, -1, -1, -1, -1, -1, -1],
[2, 2, -1, -1, -1, -1, -1, -1, 2],
[2, -1, -1, -1, -1, -1, -1, 2, 2],
[-1, -1, -1, -1, -1, 2, 2, -1, 2],
[-1, -1, -1, -1, 2, 2, -1, 2, -1],
[-1, -1, -1, 2, 2, -1, 2, -1, -1],
[-1, -1, 2, 2, -1, 2, -1, -1, -1],
[-1, 2, 2, -1, 2, -1, -1, -1, -1],
[2, 2, -1, 2, -1, -1, -1, -1, -1],
[2, -1, 2, -1, -1, -1, -1, -1, 2],
[-1, 2, -1, 2, -1, -1, -1, 2, -1],
[2, -1, 2, -1, -1, -1, -1, 2, -1],
[-1, -1, 2, -1, 2, -1, 2, -1, -1],
[-1, 2, -1, 2, -1, 2, -1, -1, -1],
[2, -1, 2, -1, 2, -1, -1, -1, -1],
[-1, -1, -1, -1, 2, 2, 2, 2, -1],
[-1, -1, -1, 2, 2, 2, 2, -1, -1],
[-1, -1, 2, 2, 2, 2, -1, -1, -1],
[-1, 2, 2, 2, 2, -1, -1, -1, -1],
[2, 2, 2, 2, -1, -1, -1, -1, -1],
[2, 2, 2, -1, -1, -1, -1, -1, 2],
[2, 2, -1, -1, -1, -1, -1, 2, 2],
[2, -1, -1, -1, -1, -1, 2, 2, 2],
[-1, -1, -1, -1, 2, 2, 2, -1, 2],
[-1, -1, -1, 2, 2, 2, -1, 2, -1],
[-1, -1, 2, 2, 2, -1, 2, -1, -1],
[-1, 2, 2, 2, -1, 2, -1, -1, -1],
[2, 2, 2, -1, 2, -1, -1, -1, -1],
[2, 2, -1, 2, -1, -1, -1, -1, 2],
[2, -1, 2, -1, -1, -1, -1, 2, 2],
[-1, 2, -1, 2, -1, -1, 2, 2, -1],
[2, -1, 2, -1, -1, 2, 2, -1, -1],
[-1, 2, -1, 2, 2, -1, -1, 2, -1],
[2, -1, 2, 2, -1, -1, 2, -1, -1],
[-1, 2, 2, -1, -1, 2, -1, 2, -1],
[2, 2, -1, -1, 2, -1, 2, -1, -1],
[2, -1, -1, 2, -1, 2, -1, -1, 2],
[-1, -1, 2, -1, 2, -1, -1, 2, 2],
[-1, 2, -1, 2, -1, 2, 2, -1, -1],
[2, -1, 2, -1, 2, 2, -1, -1, -1],
[-1, 2, -1, 2, 2, -1, -1, -1, 2],
[2, -1, 2, 2, -1, -1, -1, 2, -1],
[-1, 2, 2, -1, -1, -1, 2, -1, 2],
[2, 2, -1, -1, -1, 2, -1, 2, -1],
[2, -1, -1, -1, 2, -1, 2, -1, 2],
[-1, -1, -1, 2, -1, 2, -1, 2, 2],
[-1, -1, 2, -1, 2, -1, 2, 2, -1],
[-1, 2, -1, 2, -1, 2, 2, -1, -1],
[2, -1, 2, -1, 2, 2, -1, -1, -1],
[-1, 2, -1, 2, 2, -1, -1, 2, -1],
[2, -1, 2, 2, -1, -1, 2, -1, -1],
[-1, 2, 2, -1, -1, 2, -1, -1, 2],
[2, 2, -1, -1, 2, -1, -1, 2, -1],
[2, -1, -1, 2, -1, -1, 2, -1, 2],
];
#[derive(Debug, Clone)]
pub struct MiniRocketFeatures {
pub kernels: Vec<Vec<f32>>,
pub dilations: Vec<usize>,
pub biases: Vec<f32>,
pub n_features: usize,
}
#[derive(Module, Debug)]
pub struct MiniRocket<B: Backend> {
classifier: Linear<B>,
}
impl MiniRocketFeatures {
pub fn new(config: &MiniRocketConfig) -> Self {
let mut rng = ChaCha8Rng::seed_from_u64(config.seed);
let n_kernels = config.n_features / 84 + 1;
let mut kernels = Vec::new();
let mut dilations = Vec::new();
let mut biases = Vec::new();
for _ in 0..n_kernels {
let max_dilation = (config.seq_len - 1) / 8;
let dilation = rng.gen_range(1..=max_dilation.max(1));
for pattern in &KERNEL_PATTERNS {
if kernels.len() >= config.n_features {
break;
}
let kernel: Vec<f32> = pattern.iter().map(|&x| x as f32).collect();
kernels.push(kernel);
dilations.push(dilation);
biases.push(rng.gen_range(-1.0..1.0));
}
}
kernels.truncate(config.n_features);
dilations.truncate(config.n_features);
biases.truncate(config.n_features);
Self {
kernels,
dilations,
biases,
n_features: config.n_features,
}
}
pub fn extract<B: Backend>(&self, x: &Tensor<B, 3>) -> Tensor<B, 2> {
let [batch, n_vars, seq_len] = x.dims();
let device = x.device();
let x_data: Vec<f32> = x.to_data().to_vec().unwrap();
let mut features = vec![0.0f32; batch * self.n_features];
for b in 0..batch {
for (k_idx, (kernel, &dilation)) in
self.kernels.iter().zip(&self.dilations).enumerate()
{
let bias = self.biases[k_idx];
let kernel_len = kernel.len();
let effective_len = (kernel_len - 1) * dilation + 1;
if effective_len > seq_len {
continue;
}
let mut ppv_sum = 0.0f32;
let mut count = 0;
for v in 0..n_vars {
for t in 0..=(seq_len - effective_len) {
let mut conv_val = 0.0f32;
for (i, &w) in kernel.iter().enumerate() {
let idx = b * n_vars * seq_len + v * seq_len + t + i * dilation;
conv_val += x_data[idx] * w;
}
if conv_val > bias {
ppv_sum += 1.0;
}
count += 1;
}
}
features[b * self.n_features + k_idx] =
if count > 0 { ppv_sum / count as f32 } else { 0.0 };
}
}
Tensor::<B, 1>::from_floats(features.as_slice(), &device)
.reshape([batch, self.n_features])
}
}
impl<B: Backend> MiniRocket<B> {
pub fn new(config: MiniRocketConfig, device: &B::Device) -> Self {
let classifier = LinearConfig::new(config.n_features, config.n_classes).init(device);
Self { classifier }
}
pub fn forward(&self, features: Tensor<B, 2>) -> Tensor<B, 2> {
self.classifier.forward(features)
}
pub fn forward_probs(&self, features: Tensor<B, 2>) -> Tensor<B, 2> {
let logits = self.forward(features);
softmax(logits, 1)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_minirocket_config() {
let config = MiniRocketConfig::default();
assert_eq!(config.n_features, 10000);
assert_eq!(config.seed, 42);
}
#[test]
fn test_kernel_patterns() {
for pattern in &KERNEL_PATTERNS {
let sum: i8 = pattern.iter().sum();
assert!(sum.abs() <= 27); }
}
}