use crate::error::Result;
use crate::nn::{Activation1d, Snake, SnakeBeta, WeightNormConv1d};
use pmetal_bridge::compat::Array;
/// Residual block made of parallel branches, each a chain of
/// `Activation1d<SnakeBeta>` + `WeightNormConv1d` pairs.
///
/// `forward` averages the branch outputs and adds the result back onto the
/// input.
#[derive(Debug)]
pub struct AMPBlock {
/// Channel count handed to every activation and convolution built by `new`.
pub channels: i32,
/// One kernel size per branch; `new` fills this with the same value repeated.
pub kernel_sizes: Vec<i32>,
/// Dilations as passed to `new`: outer index = branch, inner index = layer.
pub dilations: Vec<Vec<i32>>,
/// The constructed branches, in the same order as `dilations`.
pub branches: Vec<ResidualBranch>,
}
/// One branch of an `AMPBlock`: an ordered list of layers applied one after
/// another.
#[derive(Debug)]
pub struct ResidualBranch {
/// `(activation, convolution)` pairs; for each pair the activation runs
/// first and the convolution consumes its output.
pub layers: Vec<(Activation1d<SnakeBeta>, WeightNormConv1d)>,
}
impl AMPBlock {
    /// Builds a block with one residual branch per entry of `dilations`.
    ///
    /// Every dilation value inside a branch yields one
    /// `(Activation1d<SnakeBeta>, WeightNormConv1d)` pair. All convolutions
    /// map `channels` to `channels` and share the same `kernel_size`.
    ///
    /// An empty `dilations` list is accepted and produces a block with no
    /// branches; `forward` then returns its input unchanged.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the `SnakeBeta`, `Activation1d` or
    /// `WeightNormConv1d` constructors.
    pub fn new(channels: i32, kernel_size: i32, dilations: Vec<Vec<i32>>) -> Result<Self> {
        let mut branches = Vec::with_capacity(dilations.len());
        for branch_dilations in &dilations {
            let mut layers = Vec::with_capacity(branch_dilations.len());
            for &dilation in branch_dilations {
                let activation = SnakeBeta::new(channels, true)?;
                // NOTE(review): the two `Some(2)` values look like up/down
                // sampling ratios — confirm against `Activation1d::new`.
                let act1d = Activation1d::new(activation, Some(2), Some(2))?;
                // Half of the dilated kernel extent; presumably chosen so that
                // an odd `kernel_size` keeps the sequence length unchanged.
                let padding = (kernel_size - 1) * dilation / 2;
                // NOTE(review): the optional arguments are passed positionally;
                // confirm their roles against `WeightNormConv1d::new`.
                let conv = WeightNormConv1d::new(
                    channels,
                    channels,
                    kernel_size,
                    Some(1),
                    Some(padding),
                    Some(dilation),
                    None,
                    Some(true),
                )?;
                layers.push((act1d, conv));
            }
            branches.push(ResidualBranch { layers });
        }
        // All branches share one kernel size, so record it once per branch.
        let kernel_sizes = vec![kernel_size; dilations.len()];
        Ok(Self {
            channels,
            kernel_sizes,
            dilations,
            branches,
        })
    }

    /// Convenience constructor: kernel size 3 and three branches, each with
    /// dilations `[1, 3, 5]`.
    ///
    /// # Errors
    ///
    /// Same as [`AMPBlock::new`].
    pub fn bigvgan_v2(channels: i32) -> Result<Self> {
        Self::new(
            channels,
            3,
            vec![vec![1, 3, 5], vec![1, 3, 5], vec![1, 3, 5]],
        )
    }

    /// Applies the block to `x`.
    ///
    /// Each branch starts from a copy of `x` and runs its layers in order
    /// (activation, then convolution). The branch outputs are summed, divided
    /// by the number of branches, and added to `x`.
    ///
    /// A block without branches returns a copy of `x` instead of panicking:
    /// with nothing to average, only the residual path remains.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by a layer's `forward`.
    pub fn forward(&self, x: &Array) -> Result<Array> {
        let mut branch_sum: Option<Array> = None;
        for branch in &self.branches {
            let mut branch_out = x.clone();
            for (activation, conv) in &branch.layers {
                branch_out = activation.forward(&branch_out)?;
                branch_out = conv.forward(&branch_out)?;
            }
            branch_sum = Some(match branch_sum {
                Some(acc) => acc.add(&branch_out),
                None => branch_out,
            });
        }

        // `None` means there were no branches at all; averaging would be a
        // division by zero, so fall back to the identity (residual-only) path.
        let branch_sum = match branch_sum {
            Some(sum) => sum,
            None => return Ok(x.clone()),
        };

        let num_branches = Array::from_i32(self.branches.len() as i32);
        let branch_avg = branch_sum.divide(&num_branches);
        Ok(x.add(&branch_avg))
    }
}
/// Variant of the residual block whose layers use `Snake` activations
/// (wrapped in `Activation1d`) instead of `SnakeBeta`.
#[derive(Debug)]
pub struct AMPBlockSnake {
/// Channel count handed to every activation and convolution built by `new`.
pub channels: i32,
/// The constructed branches, one per entry of the `dilations` passed to `new`.
pub branches: Vec<ResidualBranchSnake>,
}
/// One branch of an `AMPBlockSnake`: an ordered list of layers applied one
/// after another.
#[derive(Debug)]
pub struct ResidualBranchSnake {
/// `(activation, convolution)` pairs; for each pair the activation runs
/// first and the convolution consumes its output.
pub layers: Vec<(Activation1d<Snake>, WeightNormConv1d)>,
}
impl AMPBlockSnake {
    /// Builds a block with one residual branch per entry of `dilations`.
    ///
    /// Every dilation value inside a branch yields one
    /// `(Activation1d<Snake>, WeightNormConv1d)` pair. All convolutions map
    /// `channels` to `channels` and share the same `kernel_size`.
    ///
    /// An empty `dilations` list is accepted and produces a block with no
    /// branches; `forward` then returns its input unchanged.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the `Snake`, `Activation1d` or
    /// `WeightNormConv1d` constructors.
    pub fn new(channels: i32, kernel_size: i32, dilations: Vec<Vec<i32>>) -> Result<Self> {
        let mut branches = Vec::with_capacity(dilations.len());
        for branch_dilations in &dilations {
            let mut layers = Vec::with_capacity(branch_dilations.len());
            for &dilation in branch_dilations {
                let activation = Snake::new(channels, true)?;
                // NOTE(review): the two `Some(2)` values look like up/down
                // sampling ratios — confirm against `Activation1d::new`.
                let act1d = Activation1d::new(activation, Some(2), Some(2))?;
                // Half of the dilated kernel extent; presumably chosen so that
                // an odd `kernel_size` keeps the sequence length unchanged.
                let padding = (kernel_size - 1) * dilation / 2;
                // NOTE(review): the optional arguments are passed positionally;
                // confirm their roles against `WeightNormConv1d::new`.
                let conv = WeightNormConv1d::new(
                    channels,
                    channels,
                    kernel_size,
                    Some(1),
                    Some(padding),
                    Some(dilation),
                    None,
                    Some(true),
                )?;
                layers.push((act1d, conv));
            }
            branches.push(ResidualBranchSnake { layers });
        }
        Ok(Self { channels, branches })
    }

    /// Applies the block to `x`.
    ///
    /// Each branch starts from a copy of `x` and runs its layers in order
    /// (activation, then convolution). The branch outputs are summed, divided
    /// by the number of branches, and added to `x`.
    ///
    /// A block without branches returns a copy of `x` instead of panicking:
    /// with nothing to average, only the residual path remains.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by a layer's `forward`.
    pub fn forward(&self, x: &Array) -> Result<Array> {
        let mut branch_sum: Option<Array> = None;
        for branch in &self.branches {
            let mut branch_out = x.clone();
            for (activation, conv) in &branch.layers {
                branch_out = activation.forward(&branch_out)?;
                branch_out = conv.forward(&branch_out)?;
            }
            branch_sum = Some(match branch_sum {
                Some(acc) => acc.add(&branch_out),
                None => branch_out,
            });
        }

        // `None` means there were no branches at all; averaging would be a
        // division by zero, so fall back to the identity (residual-only) path.
        let branch_sum = match branch_sum {
            Some(sum) => sum,
            None => return Ok(x.clone()),
        };

        let num_branches = Array::from_i32(self.branches.len() as i32);
        let branch_avg = branch_sum.divide(&num_branches);
        Ok(x.add(&branch_avg))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single three-layer branch yields an output with the input's shape.
    #[test]
    fn test_amp_block_shape() {
        let block = AMPBlock::new(64, 3, vec![vec![1, 3, 5]]).unwrap();
        let input = Array::random_normal(&[1, 64, 128], 10);
        let result = block.forward(&input).unwrap();
        let evaluated = result.clone();
        evaluated.eval();
        assert_eq!(evaluated.shape(), &[1, 64, 128]);
    }

    /// The `bigvgan_v2` preset has three branches and keeps the input shape.
    #[test]
    fn test_amp_block_bigvgan_v2() {
        let block = AMPBlock::bigvgan_v2(128).unwrap();
        let input = Array::random_normal(&[2, 128, 64], 10);
        let result = block.forward(&input).unwrap();
        let evaluated = result.clone();
        evaluated.eval();
        assert_eq!(evaluated.shape(), &[2, 128, 64]);
        assert_eq!(block.branches.len(), 3);
    }

    /// With one single-layer branch, input and output shapes agree.
    #[test]
    fn test_amp_block_residual() {
        let block = AMPBlock::new(32, 3, vec![vec![1]]).unwrap();
        let input = Array::random_normal(&[1, 32, 16], 10);
        let result = block.forward(&input).unwrap();
        let input_eval = input.clone();
        input_eval.eval();
        let result_eval = result.clone();
        result_eval.eval();
        assert_eq!(result_eval.shape(), input_eval.shape());
    }

    /// The `Snake`-based block keeps the input shape across two branches.
    #[test]
    fn test_amp_block_snake() {
        let block = AMPBlockSnake::new(64, 3, vec![vec![1, 3], vec![1, 3]]).unwrap();
        let input = Array::random_normal(&[1, 64, 32], 10);
        let result = block.forward(&input).unwrap();
        let evaluated = result.clone();
        evaluated.eval();
        assert_eq!(evaluated.shape(), &[1, 64, 32]);
    }

    /// Four branches with differing dilations keep the input shape.
    #[test]
    fn test_amp_block_multiple_branches() {
        let branch_dilations = vec![vec![1, 2], vec![3, 4], vec![5, 6], vec![7, 8]];
        let block = AMPBlock::new(256, 3, branch_dilations).unwrap();
        let input = Array::random_normal(&[1, 256, 64], 10);
        let result = block.forward(&input).unwrap();
        let evaluated = result.clone();
        evaluated.eval();
        assert_eq!(evaluated.shape(), &[1, 256, 64]);
        assert_eq!(block.branches.len(), 4);
    }
}