use burn::nn::Linear;
use burn::nn::LinearConfig;
use burn::prelude::*;
use burn::tensor::activation::softmax;
use rand::prelude::*;
use rand_chacha::ChaCha8Rng;
use serde::{Deserialize, Serialize};
use super::minirocket::KERNEL_PATTERNS;
/// Pooling operator applied to each kernel's convolution output when extracting
/// features. A convolution output counts as "positive" when it exceeds the kernel's
/// bias; every enabled operator contributes one feature per kernel.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum FeatureType {
    /// Proportion of positive values among all convolution positions (the default).
    #[default]
    PPV,
    /// Mean of the positive values.
    MPV,
    /// Mean index of the positive values, divided by the number of positions.
    MIPV,
    /// Longest run of consecutive positive values, divided by the number of positions.
    LSPV,
}
/// Hyper-parameters for a MultiRocket model: input shape, kernel count, pooling
/// operators and RNG seed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiRocketConfig {
    /// Number of input variables (channels). NOTE(review): not read by the visible
    /// feature code, which takes the channel count from the input tensor's dims.
    pub n_vars: usize,
    /// Series length; bounds the dilations drawn when the kernel bank is generated.
    pub seq_len: usize,
    /// Number of output classes (output width of the linear classifier).
    pub n_classes: usize,
    /// Number of convolution kernels to generate.
    pub n_kernels: usize,
    /// Pooling operators; each one adds one feature per kernel.
    pub feature_types: Vec<FeatureType>,
    /// Seed of the RNG that draws kernel dilations and initial biases.
    pub seed: u64,
}
impl Default for MultiRocketConfig {
    /// Univariate series of length 100, two classes, 10 000 kernels, every pooling
    /// operator enabled, and seed 42.
    fn default() -> Self {
        let every_pooling_operator = Vec::from([
            FeatureType::PPV,
            FeatureType::MPV,
            FeatureType::MIPV,
            FeatureType::LSPV,
        ]);
        Self {
            seed: 42,
            n_kernels: 10_000,
            feature_types: every_pooling_operator,
            n_classes: 2,
            seq_len: 100,
            n_vars: 1,
        }
    }
}
impl MultiRocketConfig {
    /// Creates a config for `n_vars`-channel series of length `seq_len` with
    /// `n_classes` classes; every other field takes its `Default` value.
    pub fn new(n_vars: usize, seq_len: usize, n_classes: usize) -> Self {
        Self {
            n_vars,
            seq_len,
            n_classes,
            ..Default::default()
        }
    }

    /// Sets the number of convolution kernels to generate.
    #[must_use]
    pub fn with_n_kernels(mut self, n_kernels: usize) -> Self {
        self.n_kernels = n_kernels;
        self
    }

    /// Sets the pooling operators; each one adds one feature per kernel.
    #[must_use]
    pub fn with_feature_types(mut self, feature_types: Vec<FeatureType>) -> Self {
        self.feature_types = feature_types;
        self
    }

    /// Sets the seed of the RNG that draws kernel dilations and initial biases, so the
    /// seed can be chosen without leaving the builder chain.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }

    /// Total feature width: one feature per (kernel, pooling operator) pair.
    pub fn n_features(&self) -> usize {
        self.n_kernels * self.feature_types.len()
    }

    /// Builds a [`MultiRocket`] classifier head on `device` from a clone of this config.
    pub fn init<B: Backend>(&self, device: &B::Device) -> MultiRocket<B> {
        MultiRocket::new(self.clone(), device)
    }
}
/// Convolution kernel bank plus pooling configuration that turns raw series into
/// MultiRocket features (built by `new`, tuned by `fit_biases`, applied by `extract`).
#[derive(Debug, Clone)]
pub struct MultiRocketFeatures {
    /// Kernel weights, one `Vec` per kernel.
    pub kernels: Vec<Vec<f32>>,
    /// Dilation of each kernel (index-aligned with `kernels`).
    pub dilations: Vec<usize>,
    /// Threshold above which a convolution output counts as positive
    /// (index-aligned with `kernels`).
    pub biases: Vec<f32>,
    /// Pooling operators applied to each kernel's output.
    pub feature_types: Vec<FeatureType>,
    /// Number of kernels requested by the config.
    pub n_kernels: usize,
    /// Width of one feature row: `n_kernels * feature_types.len()`.
    pub n_features: usize,
}
impl MultiRocketFeatures {
pub fn new(config: &MultiRocketConfig) -> Self {
let mut rng = ChaCha8Rng::seed_from_u64(config.seed);
let n_pattern_sets = config.n_kernels / 84 + 1;
let mut kernels = Vec::new();
let mut dilations = Vec::new();
let mut biases = Vec::new();
for _ in 0..n_pattern_sets {
let max_dilation = (config.seq_len - 1) / 8;
let dilation = rng.gen_range(1..=max_dilation.max(1));
for pattern in &KERNEL_PATTERNS {
if kernels.len() >= config.n_kernels {
break;
}
let kernel: Vec<f32> = pattern.iter().map(|&x| x as f32).collect();
kernels.push(kernel);
dilations.push(dilation);
biases.push(rng.gen_range(-1.0..1.0));
}
}
kernels.truncate(config.n_kernels);
dilations.truncate(config.n_kernels);
biases.truncate(config.n_kernels);
let n_features = config.n_kernels * config.feature_types.len();
Self {
kernels,
dilations,
biases,
feature_types: config.feature_types.clone(),
n_kernels: config.n_kernels,
n_features,
}
}
pub fn fit_biases<B: Backend>(&mut self, x: &Tensor<B, 3>) {
let [n_samples, n_vars, seq_len] = x.dims();
let x_data: Vec<f32> = x.to_data().to_vec().unwrap();
for (_k_idx, ((kernel, &dilation), bias)) in self
.kernels
.iter()
.zip(&self.dilations)
.zip(&mut self.biases)
.enumerate()
{
let kernel_len = kernel.len();
let effective_len = (kernel_len - 1) * dilation + 1;
if effective_len > seq_len {
continue;
}
let mut conv_outputs: Vec<f32> = Vec::new();
for b in 0..n_samples.min(100) {
for v in 0..n_vars {
for t in 0..=(seq_len - effective_len) {
let mut conv_val = 0.0f32;
for (i, &w) in kernel.iter().enumerate() {
let idx = b * n_vars * seq_len + v * seq_len + t + i * dilation;
conv_val += x_data[idx] * w;
}
conv_outputs.push(conv_val);
}
}
}
if !conv_outputs.is_empty() {
conv_outputs.sort_by(|a, b| a.partial_cmp(b).unwrap());
*bias = conv_outputs[conv_outputs.len() / 2];
}
}
}
pub fn extract<B: Backend>(&self, x: &Tensor<B, 3>) -> Tensor<B, 2> {
let [batch, n_vars, seq_len] = x.dims();
let device = x.device();
let x_data: Vec<f32> = x.to_data().to_vec().unwrap();
let n_feature_types = self.feature_types.len();
let mut features = vec![0.0f32; batch * self.n_features];
for b in 0..batch {
for (k_idx, ((kernel, &dilation), &bias)) in self
.kernels
.iter()
.zip(&self.dilations)
.zip(&self.biases)
.enumerate()
{
let kernel_len = kernel.len();
let effective_len = (kernel_len - 1) * dilation + 1;
if effective_len > seq_len {
continue;
}
let mut positive_values: Vec<f32> = Vec::new();
let mut positive_indices: Vec<usize> = Vec::new();
let mut total_count = 0;
let mut current_stretch = 0;
let mut longest_stretch = 0;
for v in 0..n_vars {
for t in 0..=(seq_len - effective_len) {
let mut conv_val = 0.0f32;
for (i, &w) in kernel.iter().enumerate() {
let idx = b * n_vars * seq_len + v * seq_len + t + i * dilation;
conv_val += x_data[idx] * w;
}
if conv_val > bias {
positive_values.push(conv_val);
positive_indices.push(total_count);
current_stretch += 1;
if current_stretch > longest_stretch {
longest_stretch = current_stretch;
}
} else {
current_stretch = 0;
}
total_count += 1;
}
}
for (f_idx, &feature_type) in self.feature_types.iter().enumerate() {
let feature_value = match feature_type {
FeatureType::PPV => {
if total_count > 0 {
positive_values.len() as f32 / total_count as f32
} else {
0.0
}
}
FeatureType::MPV => {
if !positive_values.is_empty() {
positive_values.iter().sum::<f32>() / positive_values.len() as f32
} else {
0.0
}
}
FeatureType::MIPV => {
if !positive_indices.is_empty() && total_count > 0 {
let mean_idx: f32 = positive_indices.iter().sum::<usize>() as f32
/ positive_indices.len() as f32;
mean_idx / total_count as f32
} else {
0.5 }
}
FeatureType::LSPV => {
if total_count > 0 {
longest_stretch as f32 / total_count as f32
} else {
0.0
}
}
};
let feat_idx = b * self.n_features + k_idx * n_feature_types + f_idx;
features[feat_idx] = feature_value;
}
}
}
Tensor::<B, 1>::from_floats(features.as_slice(), &device).reshape([batch, self.n_features])
}
}
/// Linear classifier head applied to MultiRocket feature matrices.
#[derive(Module, Debug)]
pub struct MultiRocket<B: Backend> {
    /// Linear layer mapping `n_features` inputs to `n_classes` logits.
    classifier: Linear<B>,
    /// Expected feature width; marked `#[module(skip)]`, so it is not treated as a
    /// module field by the `Module` derive (presumably excluded from records — confirm).
    #[module(skip)]
    n_features: usize,
}
impl<B: Backend> MultiRocket<B> {
    /// Creates the classifier head: a linear layer from `config.n_features()` inputs to
    /// `config.n_classes` logits, initialised on `device`.
    pub fn new(config: MultiRocketConfig, device: &B::Device) -> Self {
        let in_features = config.n_features();
        Self {
            classifier: LinearConfig::new(in_features, config.n_classes).init(device),
            n_features: in_features,
        }
    }

    /// Feature width the classifier expects as input.
    pub fn n_features(&self) -> usize {
        self.n_features
    }

    /// Maps a feature matrix (rows of width `n_features`) to class logits.
    pub fn forward(&self, features: Tensor<B, 2>) -> Tensor<B, 2> {
        self.classifier.forward(features)
    }

    /// Same as [`Self::forward`], followed by a softmax over dimension 1 (the classes).
    pub fn forward_probs(&self, features: Tensor<B, 2>) -> Tensor<B, 2> {
        softmax(self.forward(features), 1)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multirocket_config_default() {
        let config = MultiRocketConfig::default();
        assert_eq!(config.n_kernels, 10000);
        assert_eq!(config.feature_types.len(), 4);
        assert_eq!(config.n_features(), 40000);
    }

    #[test]
    fn test_multirocket_config_builder() {
        let config = MultiRocketConfig::new(3, 200, 10)
            .with_n_kernels(5000)
            .with_feature_types(vec![FeatureType::PPV, FeatureType::MPV]);
        assert_eq!(config.n_vars, 3);
        assert_eq!(config.seq_len, 200);
        assert_eq!(config.n_classes, 10);
        assert_eq!(config.n_kernels, 5000);
        assert_eq!(config.feature_types.len(), 2);
        assert_eq!(config.n_features(), 10000);
    }

    #[test]
    fn test_multirocket_features_creation() {
        let config = MultiRocketConfig::new(1, 100, 2).with_n_kernels(100);
        let features = MultiRocketFeatures::new(&config);
        assert_eq!(features.n_kernels, 100);
        assert_eq!(features.kernels.len(), 100);
        assert_eq!(features.feature_types.len(), 4);
        assert_eq!(features.n_features, 400);
    }

    /// The kernel bank is fully determined by the config (including its seed).
    #[test]
    fn test_multirocket_features_deterministic_for_seed() {
        let config = MultiRocketConfig::new(1, 100, 2).with_n_kernels(200);
        let first = MultiRocketFeatures::new(&config);
        let second = MultiRocketFeatures::new(&config);
        assert_eq!(first.kernels, second.kernels);
        assert_eq!(first.dilations, second.dilations);
        assert_eq!(first.biases, second.biases);
    }

    /// Exact kernel counts at the edges (none, one, not a multiple of the pattern count),
    /// with dilations and initial biases inside their documented ranges.
    #[test]
    fn test_multirocket_features_kernel_count_edges() {
        for n_kernels in [0, 1, 85] {
            let config = MultiRocketConfig::new(1, 100, 2).with_n_kernels(n_kernels);
            let features = MultiRocketFeatures::new(&config);
            assert_eq!(features.kernels.len(), n_kernels);
            assert_eq!(features.dilations.len(), n_kernels);
            assert_eq!(features.biases.len(), n_kernels);
            // seq_len 100 => dilations in 1..=(99 / 8).
            assert!(features.dilations.iter().all(|&d| (1..=12).contains(&d)));
            assert!(features.biases.iter().all(|b| (-1.0..1.0).contains(b)));
        }
    }
}