use burn::prelude::*;
use burn::nn::{
conv::{Conv1d, Conv1dConfig},
GroupNorm, GroupNormConfig,
};
use burn::tensor::activation::gelu;
/// One feature-extraction stage: a 1-D convolution, an optional group
/// normalization, and a GELU activation, applied in that order.
#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
    /// The 1-D convolution applied first.
    pub conv: Conv1d<B>,
    /// Group normalization applied to the convolution output, when enabled.
    pub norm: Option<GroupNorm<B>>,
}

impl<B: Backend> ConvBlock<B> {
    /// Creates a block on `device`.
    ///
    /// * `in_ch` / `out_ch` – input and output channel counts of the convolution.
    /// * `kernel` / `stride` – convolution kernel size and stride; no padding or
    ///   dilation is configured here.
    /// * `use_group_norm` – when `true`, add a `GroupNorm` configured with `out_ch`
    ///   groups over `out_ch` channels and epsilon `1e-5`.
    /// * `conv_bias` – whether the convolution carries a bias term.
    /// * `device` – device on which the parameters are initialized.
    pub fn new(
        in_ch: usize,
        out_ch: usize,
        kernel: usize,
        stride: usize,
        use_group_norm: bool,
        conv_bias: bool,
        device: &B::Device,
    ) -> Self {
        // The convolution is initialized before the norm, as in the original, so
        // parameter initialization happens in the same order.
        let conv = Conv1dConfig::new(in_ch, out_ch, kernel)
            .with_stride(stride)
            .with_bias(conv_bias)
            .init(device);
        let norm = use_group_norm.then(|| {
            GroupNormConfig::new(out_ch, out_ch)
                .with_epsilon(1e-5)
                .init(device)
        });
        Self { conv, norm }
    }

    /// Runs the convolution, then the normalization (if present), then GELU.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let convolved = self.conv.forward(x);
        let normalized = match &self.norm {
            Some(norm) => norm.forward(convolved),
            None => convolved,
        };
        gelu(normalized)
    }
}
/// Convolutional front-end that encodes every input channel's signal
/// independently with a shared stack of [`ConvBlock`]s.
#[derive(Module, Debug)]
pub struct ConvFeatureEncoder<B: Backend> {
    /// Convolution blocks applied in order; only the first uses group normalization.
    pub blocks: Vec<ConvBlock<B>>,
    /// Channel count passed to [`ConvFeatureEncoder::new`]. It is stored only;
    /// `forward` reads the channel count from its input tensor.
    pub n_channels: usize,
    /// Output channel count of the last block, i.e. the size of each output token.
    pub emb_dim: usize,
}

impl<B: Backend> ConvFeatureEncoder<B> {
    /// Builds the encoder from a per-layer specification.
    ///
    /// `conv_layers_spec` lists `(out_channels, kernel, stride)` for each layer. The
    /// first layer takes a single input channel; each later layer consumes the
    /// previous layer's output channels. Only the first layer gets group normalization.
    ///
    /// # Panics
    ///
    /// Panics if `conv_layers_spec` is empty, because the embedding dimension is
    /// taken from the last layer.
    pub fn new(
        conv_layers_spec: &[(usize, usize, usize)],
        n_channels: usize,
        conv_bias: bool,
        device: &B::Device,
    ) -> Self {
        // Validate before allocating any parameters.
        let &(emb_dim, _, _) = conv_layers_spec
            .last()
            .expect("conv_layers_spec must describe at least one convolution layer");

        let mut blocks = Vec::with_capacity(conv_layers_spec.len());
        // Channels are folded into the batch in `forward`, so the first conv sees 1 channel.
        let mut in_ch = 1;
        for (i, &(out_ch, kernel, stride)) in conv_layers_spec.iter().enumerate() {
            let use_gn = i == 0;
            blocks.push(ConvBlock::new(in_ch, out_ch, kernel, stride, use_gn, conv_bias, device));
            in_ch = out_ch;
        }

        Self { blocks, n_channels, emb_dim }
    }

    /// Encodes a batch of multichannel signals into a token sequence.
    ///
    /// `x` has shape `[batch, channels, time]`. Every `(batch, channel)` pair is run
    /// through the convolution stack as an independent single-channel signal. The
    /// result has shape `[batch, channels * time_out, emb]`, where `emb` is the last
    /// layer's output channel count. Tokens are ordered channel-major: all time steps
    /// of channel 0, then all of channel 1, and so on.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let [batch, channels, time] = x.dims();
        let x = x.reshape([batch * channels, 1, time]);
        let x = self.blocks.iter().fold(x, |x, block| block.forward(x));
        let [_, emb, time_out] = x.dims();
        // [batch*channels, emb, time_out] -> [batch, channels, time_out, emb]
        //                                 -> [batch, channels * time_out, emb]
        x.reshape([batch, channels, emb, time_out])
            .swap_dims(2, 3)
            .reshape([batch, channels * time_out, emb])
    }

    /// Returns how many time steps the convolution stack produces from `n_times`
    /// input samples of one channel.
    ///
    /// Each block maps a length `t` to `(t - kernel) / stride + 1`. The kernel size is
    /// read from the weight's last dimension and the stride from the convolution. This
    /// formula assumes unpadded convolutions with dilation 1; `ConvBlock::new`
    /// configures neither padding nor dilation. If an intermediate length is shorter
    /// than a block's kernel, no window fits and the result is `0`.
    pub fn n_times_out(&self, n_times: usize) -> usize {
        self.blocks.iter().fold(n_times, |t, block| {
            let kernel = block.conv.weight.dims()[2];
            let stride = block.conv.stride;
            t.checked_sub(kernel).map_or(0, |span| span / stride + 1)
        })
    }
}
/// Computes the number of time steps produced by a stack of 1-D convolutions.
///
/// Each entry of `spec` is `(out_channels, kernel, stride)`; only `kernel` and
/// `stride` affect the length. Each layer maps a length `t` to
/// `(t - kernel) / stride + 1`, the output length of an unpadded convolution with
/// dilation 1.
///
/// If a layer's input is shorter than its kernel, no full window fits and that layer
/// yields `0`; later layers then also yield `0` (for any kernel of at least 1). The
/// previous version instead underflowed: it panicked in debug builds and wrapped in
/// release builds.
///
/// # Panics
///
/// Panics if a `stride` is zero and that layer's input is at least as long as its kernel.
pub fn n_times_out(spec: &[(usize, usize, usize)], n_times: usize) -> usize {
    spec.iter().fold(n_times, |t, &(_dim, kernel, stride)| {
        t.checked_sub(kernel).map_or(0, |span| span / stride + 1)
    })
}