use burn::module::{Ignored, Param};
use burn::nn::conv::{Conv1d, Conv1dConfig, ConvTranspose1d, ConvTranspose1dConfig};
use burn::nn::{Embedding, EmbeddingConfig, PaddingConfig1d};
use burn::prelude::*;
use burn::tensor::{Int, TensorData};
/// 1-D convolution made causal by padding only the start of the time axis.
///
/// The wrapped [`Conv1d`] is built with [`PaddingConfig1d::Valid`]; all padding is applied
/// explicitly on the left in [`CausalConv1d::forward`], so no right-hand context is added.
#[derive(Module, Debug)]
pub struct CausalConv1d<B: Backend> {
    /// Underlying convolution (no implicit padding).
    pub conv: Conv1d<B>,
    /// Number of zeros prepended to the last axis of the input before convolving.
    pub left_pad: usize,
}

impl<B: Backend> CausalConv1d<B> {
    /// Builds a causal convolution.
    ///
    /// The left padding is `2 * padding - output_padding`; `stride`, `dilation`, `groups`
    /// and `bias` are forwarded unchanged to [`Conv1dConfig`].
    ///
    /// # Panics
    ///
    /// Panics if `output_padding > 2 * padding`, because the left padding would be negative.
    // The argument list mirrors the convolution hyper-parameters one-to-one and is part of
    // the public API, so it is kept as-is rather than replaced by a config struct.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        channels_in: usize,
        channels_out: usize,
        kernel_size: usize,
        stride: usize,
        dilation: usize,
        groups: usize,
        padding: usize,
        output_padding: usize,
        bias: bool,
        device: &B::Device,
    ) -> Self {
        // Validate before allocating weights: an unchecked `usize` subtraction would panic
        // opaquely in debug builds and silently wrap to a gigantic pad in release builds.
        let Some(left_pad) = (padding * 2).checked_sub(output_padding) else {
            panic!(
                "CausalConv1d: output_padding ({output_padding}) must not exceed \
                 2 * padding ({})",
                padding * 2
            );
        };
        let conv = Conv1dConfig::new(channels_in, channels_out, kernel_size)
            .with_stride(stride)
            .with_dilation(dilation)
            .with_groups(groups)
            .with_padding(PaddingConfig1d::Valid)
            .with_bias(bias)
            .init(device);
        Self { conv, left_pad }
    }

    /// Applies the causal convolution.
    ///
    /// `x` is laid out as burn's `Conv1d` expects (`[batch, channels_in, length]`); the last
    /// axis is left-padded with `left_pad` zeros before the unpadded convolution runs.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let x = if self.left_pad > 0 {
            // Pad tuple is (left, right, top, bottom) over the last two dims: only the
            // beginning of the last axis is padded.
            x.pad((self.left_pad, 0, 0, 0), 0.0)
        } else {
            x
        };
        self.conv.forward(x)
    }
}
/// Transposed 1-D convolution made causal by trimming trailing output samples.
///
/// The wrapped [`ConvTranspose1d`] is built with zero padding and zero output padding;
/// [`CausalTransposeConv1d::forward`] then drops `right_trim` samples from the end.
#[derive(Module, Debug)]
pub struct CausalTransposeConv1d<B: Backend> {
    /// Underlying transposed convolution (padding 0, output padding 0).
    pub conv: ConvTranspose1d<B>,
    /// Number of samples removed from the end of the output's last axis.
    pub right_trim: usize,
}

impl<B: Backend> CausalTransposeConv1d<B> {
    /// Builds a causal transposed convolution.
    ///
    /// The right trim is `2 * padding - output_padding`; `stride`, `groups` and `bias` are
    /// forwarded unchanged to [`ConvTranspose1dConfig`].
    ///
    /// # Panics
    ///
    /// Panics if `output_padding > 2 * padding`, because the trim would be negative.
    // Signature kept for API compatibility; it mirrors the conv hyper-parameters directly.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        channels_in: usize,
        channels_out: usize,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        output_padding: usize,
        groups: usize,
        bias: bool,
        device: &B::Device,
    ) -> Self {
        // Checked: a plain `usize` subtraction would wrap in release builds.
        let Some(right_trim) = (padding * 2).checked_sub(output_padding) else {
            panic!(
                "CausalTransposeConv1d: output_padding ({output_padding}) must not exceed \
                 2 * padding ({})",
                padding * 2
            );
        };
        let conv = ConvTranspose1dConfig::new([channels_in, channels_out], kernel_size)
            .with_stride(stride)
            .with_padding(0)
            .with_padding_out(0)
            .with_groups(groups)
            .with_bias(bias)
            .init(device);
        Self { conv, right_trim }
    }

    /// Runs the transposed convolution and removes `right_trim` samples from the end of
    /// the last axis.
    ///
    /// # Panics
    ///
    /// Panics if the convolution output is shorter than `right_trim`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let out = self.conv.forward(x);
        if self.right_trim > 0 {
            let len = out.dims()[2];
            let Some(kept) = len.checked_sub(self.right_trim) else {
                panic!(
                    "CausalTransposeConv1d: output length {len} is shorter than right_trim {}",
                    self.right_trim
                );
            };
            out.narrow(2, 0, kept)
        } else {
            out
        }
    }
}
/// Snake activation with a learnable frequency `alpha`:
/// `x + sin(alpha * x)^2 / (alpha + 1e-9)`.
#[derive(Module, Debug)]
pub struct Snake1d<B: Backend> {
    /// Learnable frequency, created with shape `[1, channels, 1]` so it broadcasts over the
    /// first and last axes of the input.
    pub alpha: Param<Tensor<B, 3>>,
}

impl<B: Backend> Snake1d<B> {
    /// Creates the activation with `alpha` initialised to ones.
    pub fn new(channels: usize, device: &B::Device) -> Self {
        Self {
            alpha: Param::from_tensor(Tensor::ones([1, channels, 1], device)),
        }
    }

    /// Applies the activation element-wise.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let alpha = self.alpha.val();
        // The epsilon keeps the reciprocal finite when alpha is exactly zero.
        let inv_alpha = alpha.clone().add_scalar(1e-9).recip();
        let s = (alpha * x.clone()).sin();
        // Square by self-multiplication instead of `powf_scalar(2.0)`: `s` is often
        // negative, and a generic pow is not guaranteed to handle negative bases on every
        // backend (shader `pow` is specified via exp2(y * log2(x))). Multiplication is exact
        // and cheaper.
        let sin_sq = s.clone() * s;
        x + inv_alpha * sin_sq
    }
}
/// Residual block computing `x + conv2(snake2(conv1(snake1(x))))`.
///
/// `conv1` is a dilated kernel-7 causal convolution and `conv2` a kernel-1 convolution;
/// both map `dim` channels to `dim` channels.
#[derive(Module, Debug)]
pub struct CausalResidualUnit<B: Backend> {
    /// Activation applied to the block input.
    pub snake1: Snake1d<B>,
    /// Dilated kernel-7 causal convolution.
    pub conv1: CausalConv1d<B>,
    /// Activation applied between the two convolutions.
    pub snake2: Snake1d<B>,
    /// Kernel-1 convolution producing the residual branch.
    pub conv2: CausalConv1d<B>,
}

impl<B: Backend> CausalResidualUnit<B> {
    /// Builds the block for `dim` channels with the given dilation and group count for
    /// the kernel-7 convolution.
    pub fn new(dim: usize, dilation: usize, groups: usize, device: &B::Device) -> Self {
        const KERNEL: usize = 7;
        // CausalConv1d turns this into a left pad of 2 * half_span, which for KERNEL = 7
        // equals (KERNEL - 1) * dilation — the extent the dilated kernel consumes.
        let half_span = (KERNEL - 1) * dilation / 2;

        // Sub-modules are created in declaration order so weight initialisation draws
        // happen in the same sequence.
        let snake1 = Snake1d::new(dim, device);
        let conv1 = CausalConv1d::new(
            dim, dim, KERNEL, 1, dilation, groups, half_span, 0, true, device,
        );
        let snake2 = Snake1d::new(dim, device);
        let conv2 = CausalConv1d::new(dim, dim, 1, 1, 1, 1, 0, 0, true, device);

        Self {
            snake1,
            conv1,
            snake2,
            conv2,
        }
    }

    /// Adds the residual branch to the input.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let hidden = self.conv1.forward(self.snake1.forward(x.clone()));
        let branch = self.conv2.forward(self.snake2.forward(hidden));
        x + branch
    }
}
/// Adds input-dependent Gaussian noise: `x + n * linear(x)`.
///
/// `n ~ N(0, 1)` is drawn on every call with a single channel, so it broadcasts across
/// channels.
#[derive(Module, Debug)]
pub struct NoiseBlock<B: Backend> {
    /// Bias-free kernel-1 projection whose output scales the noise.
    pub linear: CausalConv1d<B>,
}

impl<B: Backend> NoiseBlock<B> {
    /// Builds the block for `dim` channels.
    pub fn new(dim: usize, device: &B::Device) -> Self {
        let linear = CausalConv1d::new(dim, dim, 1, 1, 1, 1, 0, 0, false, device);
        Self { linear }
    }

    /// Returns `x` plus freshly sampled noise modulated by `linear(x)`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let [batch, _, time] = x.dims();
        let standard_normal = burn::tensor::Distribution::Normal(0.0, 1.0);
        let noise = Tensor::<B, 3>::random([batch, 1, time], standard_normal, &x.device());
        let gain = self.linear.forward(x.clone());
        x + noise * gain
    }
}
/// Per-channel affine conditioning on a sample-rate bucket: `x * scale[id] + bias[id]`.
#[derive(Module, Debug)]
pub struct SampleRateConditionLayer<B: Backend> {
    /// Scale vectors, one row of length `input_dim` per bucket.
    pub scale_embed: Embedding<B>,
    /// Bias vectors, one row of length `input_dim` per bucket.
    pub bias_embed: Embedding<B>,
    /// Conditioning variant name; stored but not a learnable parameter.
    pub cond_type: Ignored<String>,
}

impl<B: Backend> SampleRateConditionLayer<B> {
    /// Builds the layer with `sr_bin_buckets` embedding rows of width `input_dim`.
    ///
    /// # Panics
    ///
    /// Panics unless `cond_type` is `"scale_bias"` or `"scale_bias_init"`.
    pub fn new(
        input_dim: usize,
        sr_bin_buckets: usize,
        cond_type: &str,
        device: &B::Device,
    ) -> Self {
        assert!(
            matches!(cond_type, "scale_bias" | "scale_bias_init"),
            "Only scale_bias / scale_bias_init cond_type is supported, got: {cond_type}"
        );
        Self {
            scale_embed: EmbeddingConfig::new(sr_bin_buckets, input_dim).init(device),
            bias_embed: EmbeddingConfig::new(sr_bin_buckets, input_dim).init(device),
            cond_type: Ignored(cond_type.to_string()),
        }
    }

    /// Applies the bucket-selected scale and bias to `x`.
    ///
    /// `sr_cond` holds bucket ids — presumably one per batch element of `x`; confirm with
    /// callers.
    pub fn forward(&self, x: Tensor<B, 3>, sr_cond: Tensor<B, 1, Int>) -> Tensor<B, 3> {
        // [n] -> [n, 1]: a length-1 id sequence per row, as Embedding expects 2-D ids.
        let ids: Tensor<B, 2, Int> = sr_cond.unsqueeze_dim(1);
        // Embedding yields [n, 1, input_dim]; swapping to [n, input_dim, 1] lets the
        // vectors broadcast along the last axis of `x`.
        let scale = self.scale_embed.forward(ids.clone()).swap_dims(1, 2);
        let bias = self.bias_embed.forward(ids).swap_dims(1, 2);
        x * scale + bias
    }
}
/// Maps a sample rate to its bucket index, returned as a one-element `Int` tensor.
///
/// The index is the length of the leading run of `boundaries` entries that are `<= sr`.
/// Equivalently, it is the position of the first boundary strictly greater than `sr`, or
/// `boundaries.len()` if there is none. With ascending `boundaries` this is ordinary
/// bucketization.
pub fn sr_bucket<B: Backend>(
    sr: i32,
    boundaries: &[i32],
    device: &B::Device,
) -> Tensor<B, 1, Int> {
    let bucket = boundaries.iter().take_while(|&&edge| edge <= sr).count() as i32;
    let data = TensorData::new(vec![i64::from(bucket)], [1]);
    Tensor::<B, 1, Int>::from_data(data, device)
}