use candle_core::{DType, Module, Result, Tensor};
use candle_nn::{Conv1d, Conv1dConfig, VarBuilder, conv1d};
use crate::audio::encoder::transformer::TransformerLayer;
use crate::nn::rope::simple::SimpleRotaryEmbedding;
/// Residual unit made of two 1-D convolutions whose output is added back to
/// the block input.
#[derive(Debug, Clone)]
struct SeaNetResBlock {
/// First convolution: kernel size 3, maps `channels` down to
/// `channels / compress`, loaded from `block.1.conv`.
conv1: Conv1d,
/// Second convolution: kernel size 1, maps back up to `channels`, loaded
/// from `block.3.conv`.
conv2: Conv1d,
}
impl SeaNetResBlock {
    /// Loads the two convolutions of the block from `vb`.
    ///
    /// * `channels` - input and output channel count of the block.
    /// * `compress` - divisor applied to `channels` to get the hidden width.
    /// * `dilation` - dilation of the first (kernel size 3) convolution.
    /// * `vb` - variable builder; weights are read from `block.1.conv` and
    ///   `block.3.conv` beneath it.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while creating either convolution.
    fn new(channels: usize, compress: usize, dilation: usize, vb: VarBuilder) -> Result<Self> {
        let bottleneck = channels / compress;

        // The convolution itself adds no padding; `forward` pads explicitly.
        let dilated_cfg = Conv1dConfig {
            padding: 0,
            dilation,
            ..Default::default()
        };
        let conv1 = conv1d(channels, bottleneck, 3, dilated_cfg, vb.pp("block.1.conv"))?;

        let pointwise_cfg = Conv1dConfig::default();
        let conv2 = conv1d(bottleneck, channels, 1, pointwise_cfg, vb.pp("block.3.conv"))?;

        Ok(Self { conv1, conv2 })
    }

    /// Runs ELU -> left pad -> `conv1` -> ELU -> `conv2` and adds the result
    /// to `xs`.
    ///
    /// `dilation` only determines how many zeros are prepended along
    /// dimension 2; it is expected to match the dilation given to `new`,
    /// otherwise the branch length will not line up with `xs` for the final
    /// addition.
    ///
    /// # Errors
    ///
    /// Propagates any tensor operation error.
    fn forward(&self, xs: &Tensor, dilation: usize) -> Result<Tensor> {
        const KERNEL: usize = 3;

        // Effective kernel is (KERNEL - 1) * dilation + 1; padding by one less
        // than that, on the left only, keeps the output length equal to the
        // input length.
        let left_pad = (KERNEL - 1) * dilation;

        let branch = xs
            .elu(1.0)?
            .pad_with_zeros(2, left_pad, 0)?
            .apply(&self.conv1)?
            .elu(1.0)?
            .apply(&self.conv2)?;

        xs + branch
    }
}
/// Convolutional encoder: an initial convolution, a sequence of
/// (residual block, strided convolution) stages, and a final convolution.
#[derive(Debug, Clone)]
pub struct SeaNetEncoder {
/// Input convolution (1 -> 64 channels, kernel size 7).
init_conv: Conv1d,
/// One `(residual block, strided downsampling convolution)` pair per entry
/// in `ratios`, in application order.
stages: Vec<(SeaNetResBlock, Conv1d)>,
/// Output convolution (kernel size 3) producing 512 channels.
final_conv: Conv1d,
/// Stride of each stage's downsampling convolution; also used in `forward`
/// to compute that stage's left padding.
ratios: Vec<usize>,
}
impl SeaNetEncoder {
    /// Loads the encoder from `vb`, reading every weight beneath the `layers`
    /// prefix.
    ///
    /// The layout is fixed: strides `[4, 5, 6, 8]`, 64 base filters that
    /// double after every stage, a residual compression factor of 2 and a
    /// 512-channel output.
    ///
    /// Checkpoint indices are `0` for the input convolution, then three
    /// consecutive indices per stage, of which the first holds the residual
    /// block and the third the downsampling convolution. The skipped index is
    /// presumably a parameter-free activation in the original layer list
    /// (TODO confirm against the checkpoint). The final convolution sits one
    /// past the index following the last stage.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while creating a convolution.
    pub fn new(vb: VarBuilder) -> Result<Self> {
        const BASE_FILTERS: usize = 64;
        const COMPRESS: usize = 2;
        const OUT_CHANNELS: usize = 512;

        let ratios = vec![4, 5, 6, 8];
        let vb = vb.pp("layers");

        let init_conv = conv1d(1, BASE_FILTERS, 7, Conv1dConfig::default(), vb.pp("0.conv"))?;

        let mut stages = Vec::new();
        let mut channels = BASE_FILTERS;
        // Index of the next residual block in the checkpoint's layer list.
        let mut next_idx = 1;
        for &ratio in &ratios {
            let resblock = SeaNetResBlock::new(channels, COMPRESS, 1, vb.pp(next_idx))?;

            // The strided convolution lives two slots after the residual block.
            let strided_cfg = Conv1dConfig {
                stride: ratio,
                ..Default::default()
            };
            let downsample = conv1d(
                channels,
                channels * 2,
                ratio * 2,
                strided_cfg,
                vb.pp(format!("{}.conv", next_idx + 2)),
            )?;

            stages.push((resblock, downsample));
            channels *= 2;
            next_idx += 3;
        }

        let final_conv = conv1d(
            channels,
            OUT_CHANNELS,
            3,
            Conv1dConfig::default(),
            vb.pp(format!("{}.conv", next_idx + 1)),
        )?;

        Ok(Self {
            init_conv,
            stages,
            final_conv,
            ratios,
        })
    }

    /// Encodes `xs` through the full convolution stack.
    ///
    /// Every convolution is preceded by zero padding on the left side of
    /// dimension 2 only (presumably the time axis of a
    /// `(batch, channels, time)` tensor - TODO confirm with callers): 6 before
    /// the input convolution, `kernel - stride` before each downsampling
    /// convolution and 2 before the final convolution.
    ///
    /// # Errors
    ///
    /// Propagates any tensor operation error.
    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut h = xs.pad_with_zeros(2, 6, 0)?.apply(&self.init_conv)?;

        // `stages` and `ratios` are built together in `new`, so they have the
        // same length and zipping visits every stage.
        for ((resblock, downsample), &ratio) in self.stages.iter().zip(&self.ratios) {
            let kernel = ratio * 2;
            let left_pad = kernel - ratio;
            h = resblock
                .forward(&h, 1)?
                .elu(1.0)?
                .pad_with_zeros(2, left_pad, 0)?
                .apply(downsample)?;
        }

        h.elu(1.0)?
            .pad_with_zeros(2, 2, 0)?
            .apply(&self.final_conv)
    }
}
/// Stack of transformer layers that all share one rotary embedding.
#[derive(Debug, Clone)]
pub struct EncoderTransformer {
/// Layers applied in order by `forward`.
layers: Vec<TransformerLayer>,
/// Rotary embedding passed to every layer's `forward` call.
rope: SimpleRotaryEmbedding,
}
impl EncoderTransformer {
    /// Loads `num_layers` transformer layers from `vb` (layer `i` is read
    /// from `layers.{i}`) and builds the rotary embedding they share.
    ///
    /// * `dim` - model width passed to every layer.
    /// * `num_heads` - number of attention heads; `dim / num_heads` is used
    ///   as the rotary embedding's head dimension.
    /// * `mlp_dim` - feed-forward width passed to every layer.
    /// * `num_layers` - number of layers to load.
    /// * `device`, `dtype` - forwarded to `SimpleRotaryEmbedding::new`.
    ///
    /// # Errors
    ///
    /// Propagates the first error from constructing a layer or the rotary
    /// embedding.
    ///
    /// # Panics
    ///
    /// Panics on division by zero if `num_heads` is 0 (unless an earlier step
    /// already failed).
    pub fn new(
        dim: usize,
        num_heads: usize,
        mlp_dim: usize,
        num_layers: usize,
        device: &candle_core::Device,
        dtype: DType,
        vb: VarBuilder,
    ) -> Result<Self> {
        let mut layers = Vec::new();
        for idx in 0..num_layers {
            let layer_vb = vb.pp(format!("layers.{}", idx));
            layers.push(TransformerLayer::new(dim, num_heads, mlp_dim, layer_vb)?);
        }

        let rope = SimpleRotaryEmbedding::new(dim / num_heads, 8192, 10000.0, device, dtype)?;

        Ok(Self { layers, rope })
    }

    /// Passes `xs` through every layer in order, handing each one the shared
    /// rotary embedding, and returns the last layer's output.
    ///
    /// With zero layers the result is a clone of `xs`.
    ///
    /// # Errors
    ///
    /// Stops at, and propagates, the first layer error.
    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.layers
            .iter()
            .try_fold(xs.clone(), |hidden, layer| -> Result<Tensor> {
                let next = layer.forward(&hidden, &self.rope)?;
                Ok(next)
            })
    }
}
/// Strided, bias-free 1-D convolution that keeps the channel count unchanged.
#[derive(Debug, Clone)]
pub struct Downsample {
/// Convolution with kernel size `2 * stride` and no bias.
conv: Conv1d,
/// Convolution stride; also used in `forward` to derive the left padding.
stride: usize,
}
impl Downsample {
    /// Loads the downsampling convolution from `vb`.
    ///
    /// The weight is read from `conv.weight` with shape
    /// `(dim, dim, 2 * stride)`; no bias is used.
    ///
    /// # Errors
    ///
    /// Propagates the error if the weight cannot be fetched from `vb`.
    pub fn new(dim: usize, stride: usize, vb: VarBuilder) -> Result<Self> {
        let weight = vb.get((dim, dim, stride * 2), "conv.weight")?;
        let strided_cfg = Conv1dConfig {
            stride,
            ..Default::default()
        };

        Ok(Self {
            conv: Conv1d::new(weight, None, strided_cfg),
            stride,
        })
    }

    /// Prepends `kernel - stride` zeros along dimension 2 of `xs`, then
    /// applies the strided convolution.
    ///
    /// # Errors
    ///
    /// Propagates any tensor operation error.
    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let kernel = self.stride * 2;
        let left_pad = kernel - self.stride;
        xs.pad_with_zeros(2, left_pad, 0)?.apply(&self.conv)
    }
}