use burn::nn::{
conv::{Conv1d, Conv1dConfig},
pool::{AdaptiveAvgPool1d, AdaptiveAvgPool1dConfig},
BatchNorm, BatchNormConfig, Linear, LinearConfig, Relu,
};
use burn::prelude::*;
use burn::tensor::activation::softmax;
use serde::{Deserialize, Serialize};
/// Hyper-parameters describing an [`XCMPlus`] network.
///
/// Derives the serde traits so the configuration can be serialized and
/// deserialized.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct XCMPlusConfig {
    /// Number of input variables; used as the input channel count of each temporal convolution.
    pub n_vars: usize,
    /// Input series length; used as the input channel count of the variable-wise convolution.
    pub seq_len: usize,
    /// Number of output classes, i.e. the width of the logits produced by the model.
    pub n_classes: usize,
    /// Number of convolution filters used by every branch.
    pub n_filters: usize,
    /// Kernel sizes of the temporal convolution branches; one branch is built per entry.
    pub window_sizes: Vec<usize>,
}
impl Default for XCMPlusConfig {
    /// Univariate series of length 100, two classes, 128 filters and
    /// temporal windows of 10, 20 and 40 steps.
    fn default() -> Self {
        let window_sizes = vec![10, 20, 40];
        Self {
            window_sizes,
            n_filters: 128,
            n_classes: 2,
            seq_len: 100,
            n_vars: 1,
        }
    }
}
impl XCMPlusConfig {
    /// Creates a configuration for the given data dimensions.
    ///
    /// `n_filters` and `window_sizes` are taken from [`XCMPlusConfig::default`].
    pub fn new(n_vars: usize, seq_len: usize, n_classes: usize) -> Self {
        let defaults = Self::default();
        Self {
            n_vars,
            seq_len,
            n_classes,
            ..defaults
        }
    }

    /// Builds an [`XCMPlus`] model on `device` from a copy of this configuration.
    pub fn init<B: Backend>(&self, device: &B::Device) -> XCMPlus<B> {
        let owned_config = self.clone();
        XCMPlus::new(owned_config, device)
    }
}
/// Convolutional classifier for multivariate time series.
///
/// The forward pass runs several temporal convolution branches (one per
/// configured window size) and one variable-wise branch. Each branch applies
/// batch-norm, ReLU and global average pooling. The pooled features are then
/// concatenated and fed to a linear layer that produces class logits.
#[derive(Module, Debug)]
pub struct XCMPlus<B: Backend> {
    /// Bias-free temporal convolutions, one per configured window size.
    time_convs: Vec<Conv1d<B>>,
    /// Batch norms paired index-for-index with `time_convs`.
    time_bns: Vec<BatchNorm<B, 1>>,
    /// Bias-free kernel-size-1 convolution applied after swapping axes 1 and 2 of the input.
    var_conv: Conv1d<B>,
    /// Batch norm for the variable-wise branch.
    var_bn: BatchNorm<B, 1>,
    /// Adaptive average pool to output length 1, shared by all branches.
    gap: AdaptiveAvgPool1d,
    /// Linear head over the concatenated pooled features.
    fc: Linear<B>,
}
impl<B: Backend> XCMPlus<B> {
    /// Builds the network described by `config`, allocating parameters on `device`.
    ///
    /// One temporal branch (bias-free conv + batch-norm) is created per entry of
    /// `config.window_sizes`. A variable-wise branch is also created, whose
    /// kernel-size-1 convolution takes `config.seq_len` input channels. The
    /// linear head maps `n_filters * (window_sizes.len() + 1)` pooled features
    /// to `n_classes` logits.
    ///
    /// # Panics
    ///
    /// Panics if any entry of `config.window_sizes` is zero.
    pub fn new(config: XCMPlusConfig, device: &B::Device) -> Self {
        let n_branches = config.window_sizes.len();
        let mut time_convs = Vec::with_capacity(n_branches);
        let mut time_bns = Vec::with_capacity(n_branches);
        for &window_size in &config.window_sizes {
            assert!(window_size > 0, "XCMPlus window sizes must be non-zero");
            // Recent burn releases do not support `PaddingConfig1d::Same` for even
            // kernel sizes, and the default window sizes are all even. Symmetric
            // padding of `window_size / 2` keeps the length unchanged for odd
            // kernels and adds one step for even kernels. Global average pooling
            // removes the length axis afterwards, so this difference does not
            // affect output shapes.
            let conv = Conv1dConfig::new(config.n_vars, config.n_filters, window_size)
                .with_padding(burn::nn::PaddingConfig1d::Explicit(window_size / 2))
                .with_bias(false)
                .init(device);
            let bn = BatchNormConfig::new(config.n_filters).init(device);
            time_convs.push(conv);
            time_bns.push(bn);
        }

        // Time steps act as input channels here, so the model is tied to `seq_len`.
        let var_conv = Conv1dConfig::new(config.seq_len, config.n_filters, 1)
            .with_bias(false)
            .init(device);
        let var_bn = BatchNormConfig::new(config.n_filters).init(device);

        let gap = AdaptiveAvgPool1dConfig::new(1).init();
        let combined_features = config.n_filters * (n_branches + 1);
        let fc = LinearConfig::new(combined_features, config.n_classes).init(device);

        Self {
            time_convs,
            time_bns,
            var_conv,
            var_bn,
            gap,
            fc,
        }
    }

    /// Computes class logits of shape `[batch, n_classes]`.
    ///
    /// `x` must be laid out as `[batch, n_vars, seq_len]`. The temporal
    /// convolutions take `n_vars` input channels, and the variable-wise
    /// convolution takes `seq_len` input channels after axes 1 and 2 are
    /// swapped. `seq_len` must equal the value the model was built with.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
        let relu = Relu::new();
        let mut features = Vec::with_capacity(self.time_convs.len() + 1);

        for (conv, bn) in self.time_convs.iter().zip(&self.time_bns) {
            let out = conv.forward(x.clone());
            let out = bn.forward(out);
            let out = relu.forward(out);
            features.push(self.gap.forward(out));
        }

        // Swap to [batch, seq_len, n_vars] so time steps become channels.
        let x_t = x.swap_dims(1, 2);
        let var_out = self.var_conv.forward(x_t);
        let var_out = self.var_bn.forward(var_out);
        let var_out = relu.forward(var_out);
        features.push(self.gap.forward(var_out));

        // Every feature is [batch, n_filters, 1]. Concatenate on the channel axis,
        // then drop the pooled length-1 axis before the linear head.
        let combined = Tensor::cat(features, 1);
        let [batch, channels, _] = combined.dims();
        let combined = combined.reshape([batch, channels]);
        self.fc.forward(combined)
    }

    /// Like [`XCMPlus::forward`], but returns softmax probabilities over the class axis.
    pub fn forward_probs(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
        let logits = self.forward(x);
        softmax(logits, 1)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses 128 filters and three temporal windows.
    #[test]
    fn test_xcm_config() {
        let XCMPlusConfig {
            n_filters,
            window_sizes,
            ..
        } = XCMPlusConfig::default();
        assert_eq!(n_filters, 128);
        assert_eq!(window_sizes.len(), 3);
    }
}