use candle_core::{Result, Tensor};
use candle_nn::{LayerNorm, Linear, Module, VarBuilder, layer_norm, linear};
use crate::audio::tokenizer::v2::causal_conv::CausalConv1d;
/// ConvNeXt-style residual block built on a causal 1-D convolution.
///
/// The forward pass runs `dwconv`, swaps dimensions 1 and 2, applies `norm`,
/// `pwconv1`, GELU and `pwconv2`, scales the result by `gamma`, swaps the
/// dimensions back and adds the block input as a residual.
#[derive(Debug, Clone)]
pub struct ConvNeXtBlock {
/// Causal convolution mapping `dim` channels to `dim` channels
/// (presumably depthwise: the last numeric constructor argument is `dim` — confirm
/// against `CausalConv1d::new`).
dwconv: CausalConv1d,
/// Layer normalisation over `dim` features, applied after the transpose.
norm: LayerNorm,
/// First pointwise projection, `dim -> 4 * dim`.
pwconv1: Linear,
/// Second pointwise projection, `4 * dim -> dim`.
pwconv2: Linear,
/// Scale tensor of shape `dim`, broadcast-multiplied onto the branch output
/// before the residual addition.
gamma: Tensor,
}
impl ConvNeXtBlock {
    /// Builds a block of width `dim`, taking every parameter from `vb`.
    ///
    /// Sub-modules are looked up under the prefixes `dwconv`, `norm`,
    /// `pwconv1` and `pwconv2`; the scale tensor is looked up as `gamma`
    /// directly under `vb`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced while constructing any sub-module or
    /// fetching `gamma`.
    pub fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
        // Struct-literal fields are evaluated top to bottom, so the lookups
        // happen in the order: dwconv, norm, pwconv1, pwconv2, gamma.
        Ok(Self {
            // NOTE(review): positional args `7, 1, 1, dim` are presumably
            // kernel size, stride, dilation and groups — confirm against
            // `CausalConv1d::new`.
            dwconv: CausalConv1d::new(dim, dim, 7, 1, 1, dim, vb.pp("dwconv"))?,
            // `1e-6` is passed as the layer-norm configuration (epsilon).
            norm: layer_norm(dim, 1e-6, vb.pp("norm"))?,
            // Expand to four times the width, then project back down.
            pwconv1: linear(dim, 4 * dim, vb.pp("pwconv1"))?,
            pwconv2: linear(4 * dim, dim, vb.pp("pwconv2"))?,
            // Shape-`dim` scale with a constant 1e-6 initialisation hint.
            gamma: vb.get_with_hints(dim, "gamma", candle_nn::Init::Const(1e-6))?,
        })
    }

    /// Equivalent to [`ConvNeXtBlock::new`]; forwards both arguments unchanged.
    ///
    /// # Errors
    ///
    /// Same as [`ConvNeXtBlock::new`].
    pub fn load(dim: usize, vb: VarBuilder) -> Result<Self> {
        Self::new(dim, vb)
    }
}
impl Module for ConvNeXtBlock {
    /// Applies the block to `xs` and returns `xs + branch(xs)`.
    ///
    /// The branch is: causal conv -> transpose(1, 2) -> layer norm ->
    /// `pwconv1` -> GELU -> `pwconv2` -> scale by `gamma` -> transpose(1, 2).
    ///
    /// Assumes `xs` is laid out as (batch, channels, time) so that the
    /// transpose puts channels last for the norm and linear layers — TODO
    /// confirm with callers.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying tensor operations.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Convolve, then swap dims 1 and 2 for the per-position layers.
        let branch = self.dwconv.forward(xs)?.transpose(1, 2)?;
        let branch = self.norm.forward(&branch)?;

        // NOTE(review): candle's `gelu()` is the tanh approximation
        // (`gelu_erf()` is the erf form) — confirm which one the reference
        // model was trained with.
        let branch = self.pwconv1.forward(&branch)?.gelu()?;

        // Project back, apply the scale, and restore the original layout.
        let branch = self
            .pwconv2
            .forward(&branch)?
            .broadcast_mul(&self.gamma)?
            .transpose(1, 2)?;

        // Residual connection; `&Tensor + Tensor` already yields `Result<Tensor>`.
        xs + branch
    }
}