use crate::config::BigVGANConfig;
use crate::error::{Result, VocoderError};
use crate::nn::{AMPBlock, Activation1d, SnakeBeta, WeightNormConv1d, WeightNormConvTranspose1d};
use pmetal_bridge::compat::Array;
use std::path::Path;
use zerocopy::FromBytes;
/// BigVGAN vocoder generator: turns a mel spectrogram into a waveform.
///
/// The fields are listed in forward-pass order:
/// `conv_pre` -> for each stage `i`, `upsamples[i]` then the AMP blocks in
/// `amp_blocks[i]` (whose outputs are averaged) -> `activation_post` -> `conv_post`.
#[derive(Debug)]
pub struct BigVGAN {
/// Hyper-parameters the layers were built from.
pub config: BigVGANConfig,
/// Input convolution from `num_mels` to `upsample_initial_channel` channels.
pub conv_pre: WeightNormConv1d,
/// One transposed-convolution upsampler per entry of `config.upsample_rates`;
/// each one halves the channel count.
pub upsamples: Vec<WeightNormConvTranspose1d>,
/// For each upsampling stage, one AMP block per entry of
/// `config.resblock_kernel_sizes`; `forward` averages their outputs.
pub amp_blocks: Vec<Vec<AMPBlock>>,
/// Snake-Beta activation applied before the output convolution.
pub activation_post: Activation1d<SnakeBeta>,
/// Output convolution projecting down to a single channel.
pub conv_post: WeightNormConv1d,
}
impl BigVGAN {
pub fn new(config: BigVGANConfig) -> Result<Self> {
let num_upsamples = config.upsample_rates.len();
let num_kernels = config.resblock_kernel_sizes.len();
let conv_pre = WeightNormConv1d::new(
config.num_mels,
config.upsample_initial_channel,
7,
Some(1),
Some(3),
None,
None,
Some(true),
)?;
let mut upsamples = Vec::with_capacity(num_upsamples);
let mut amp_blocks = Vec::with_capacity(num_upsamples);
let mut channels = config.upsample_initial_channel;
for i in 0..num_upsamples {
let out_channels = channels / 2;
let upsample_rate = config.upsample_rates[i];
let kernel_size = config.upsample_kernel_sizes[i];
let padding = (kernel_size - upsample_rate) / 2;
let upsample = WeightNormConvTranspose1d::new(
channels,
out_channels,
kernel_size,
Some(upsample_rate),
Some(padding),
None,
None,
None,
Some(true),
)?;
upsamples.push(upsample);
let mut stage_amps = Vec::with_capacity(num_kernels);
for j in 0..num_kernels {
let kernel_size = config.resblock_kernel_sizes[j];
let dilations = vec![config.resblock_dilation_sizes[j].clone()];
let amp = AMPBlock::new(out_channels, kernel_size, dilations)?;
stage_amps.push(amp);
}
amp_blocks.push(stage_amps);
channels = out_channels;
}
let activation_post = Activation1d::new(SnakeBeta::new(channels, true)?, Some(2), Some(2))?;
let conv_post = WeightNormConv1d::new(
channels,
1, 7,
Some(1),
Some(3),
None,
None,
Some(true),
)?;
Ok(Self {
config,
conv_pre,
upsamples,
amp_blocks,
activation_post,
conv_post,
})
}
pub fn v2_24khz_100band() -> Result<Self> {
Self::new(BigVGANConfig::v2_24khz_100band())
}
pub fn v2_44khz_128band() -> Result<Self> {
Self::new(BigVGANConfig::v2_44khz_128band())
}
pub fn load_weights(&mut self, path: &Path) -> Result<()> {
let file_data = std::fs::read(path).map_err(VocoderError::from)?;
let tensors = safetensors::SafeTensors::deserialize(&file_data)
.map_err(|e| VocoderError::WeightLoad(e.to_string()))?;
if let Ok(weight_v) = tensors.tensor("conv_pre.weight_v") {
let arr = tensor_to_array(weight_v)?;
self.conv_pre.weight_v.value = arr;
}
if let Ok(weight_g) = tensors.tensor("conv_pre.weight_g") {
let arr = tensor_to_array(weight_g)?;
self.conv_pre.weight_g.value = arr;
}
if let Ok(bias) = tensors.tensor("conv_pre.bias") {
if let Some(ref mut b) = self.conv_pre.bias {
let arr = tensor_to_array(bias)?;
b.value = arr;
}
}
for (i, upsample) in self.upsamples.iter_mut().enumerate() {
let prefix = format!("ups.{}", i);
if let Ok(weight_v) = tensors.tensor(&format!("{}.weight_v", prefix)) {
let arr = tensor_to_array(weight_v)?;
upsample.weight_v.value = arr;
}
if let Ok(weight_g) = tensors.tensor(&format!("{}.weight_g", prefix)) {
let arr = tensor_to_array(weight_g)?;
upsample.weight_g.value = arr;
}
}
if let Ok(weight_v) = tensors.tensor("conv_post.weight_v") {
let arr = tensor_to_array(weight_v)?;
self.conv_post.weight_v.value = arr;
}
if let Ok(weight_g) = tensors.tensor("conv_post.weight_g") {
let arr = tensor_to_array(weight_g)?;
self.conv_post.weight_g.value = arr;
}
if let Ok(bias) = tensors.tensor("conv_post.bias") {
if let Some(ref mut b) = self.conv_post.bias {
let arr = tensor_to_array(bias)?;
b.value = arr;
}
}
Ok(())
}
pub fn from_pretrained(model_id: &str) -> Result<Self> {
use hf_hub::api::sync::ApiBuilder;
let api = ApiBuilder::from_env()
.build()
.map_err(|e| VocoderError::Hub(e.to_string()))?;
let repo = api.model(model_id.to_string());
let config_path = repo
.get("config.json")
.map_err(|e| VocoderError::Hub(e.to_string()))?;
let config_str = std::fs::read_to_string(&config_path).map_err(VocoderError::from)?;
let config: BigVGANConfig =
serde_json::from_str(&config_str).map_err(|e| VocoderError::Config(e.to_string()))?;
let mut model = Self::new(config)?;
let weights_path = repo
.get("model.safetensors")
.map_err(|e| VocoderError::Hub(e.to_string()))?;
model.load_weights(&weights_path)?;
Ok(model)
}
pub fn forward(&self, mel: &Array) -> Result<Array> {
let mut x = self.conv_pre.forward(mel)?;
for (i, upsample) in self.upsamples.iter().enumerate() {
x = upsample.forward(&x)?;
let mut amp_out: Option<Array> = None;
for amp in &self.amp_blocks[i] {
let out = amp.forward(&x)?;
match &_out {
Some(o) => amp_out = Some(o.add(&out)),
None => amp_out = Some(out),
}
}
let num_amps = Array::from_i32(self.amp_blocks[i].len() as i32);
x = amp_out.unwrap().divide(&num_amps);
}
x = self.activation_post.forward(&x)?;
x = self.conv_post.forward(&x)?;
let two = Array::from_f32(2.0);
let two_x = x.multiply(&two);
let sig = two_x.sigmoid();
let one = Array::from_f32(1.0);
let two2 = Array::from_f32(2.0);
Ok(sig.multiply(&two2).subtract(&one))
}
pub fn generate(&self, mel: &Array) -> Result<Array> {
let (mel, was_2d) = if mel.ndim() == 2 {
(mel.reshape(&[1, mel.dim(0), mel.dim(1)]), true)
} else {
(mel.clone(), false)
};
let audio = self.forward(&mel)?;
let audio = audio.squeeze_all();
if was_2d {
Ok(audio.squeeze_all())
} else {
Ok(audio)
}
}
}
/// Converts a safetensors tensor view into an f32 [`Array`] of the same shape.
///
/// F32 payloads are copied as-is. F16 and BF16 payloads are widened to f32.
///
/// Safetensors stores tensor payloads back to back with no per-tensor padding, and
/// older writers did not pad the header. A payload is therefore not guaranteed to be
/// aligned for its element type. Aligned F32 payloads are reinterpreted in place.
/// Everything else is decoded element by element from little-endian bytes, the byte
/// order the safetensors format specifies, so misaligned files load instead of
/// panicking.
///
/// # Errors
///
/// Returns [`VocoderError::WeightLoad`] in any of these cases:
/// - the dtype is not F32, F16 or BF16;
/// - a dimension does not fit in `i32`;
/// - the payload length is not a multiple of the element size.
fn tensor_to_array(tensor: safetensors::tensor::TensorView<'_>) -> Result<Array> {
    let shape = tensor
        .shape()
        .iter()
        .map(|&dim| {
            if dim > i32::MAX as usize {
                Err(VocoderError::WeightLoad(format!(
                    "Tensor dimension {} does not fit in i32",
                    dim
                )))
            } else {
                Ok(dim as i32)
            }
        })
        .collect::<Result<Vec<i32>>>()?;
    let data = tensor.data();
    match tensor.dtype() {
        safetensors::Dtype::F32 => {
            // Fast path: reinterpret the bytes directly when they happen to be aligned.
            if let Ok(floats) = <[f32]>::ref_from_bytes(data) {
                return Ok(Array::from_f32_slice(floats, &shape));
            }
            let floats = decode_le_elements(data, f32::from_le_bytes)?;
            Ok(Array::from_f32_slice(&floats, &shape))
        }
        safetensors::Dtype::F16 => {
            let floats = decode_le_elements(data, |bytes: [u8; 2]| {
                half::f16::from_bits(u16::from_le_bytes(bytes)).to_f32()
            })?;
            Ok(Array::from_f32_slice(&floats, &shape))
        }
        safetensors::Dtype::BF16 => {
            let floats = decode_le_elements(data, |bytes: [u8; 2]| {
                half::bf16::from_bits(u16::from_le_bytes(bytes)).to_f32()
            })?;
            Ok(Array::from_f32_slice(&floats, &shape))
        }
        other => Err(VocoderError::WeightLoad(format!(
            "Unsupported dtype: {:?}",
            other
        ))),
    }
}

/// Decodes `data` as consecutive `N`-byte elements, mapping each through `decode`.
///
/// Works for any alignment of `data`. Fails if `data.len()` is not a multiple of `N`.
fn decode_le_elements<const N: usize, F>(data: &[u8], decode: F) -> Result<Vec<f32>>
where
    F: Fn([u8; N]) -> f32,
{
    if data.len() % N != 0 {
        return Err(VocoderError::WeightLoad(format!(
            "Tensor byte length {} is not a multiple of the element size {}",
            data.len(),
            N
        )));
    }
    Ok(data
        .chunks_exact(N)
        .map(|chunk| {
            let mut bytes = [0u8; N];
            bytes.copy_from_slice(chunk);
            decode(bytes)
        })
        .collect())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The v2 24 kHz preset has six upsampling stages, each with one AMP block per
    /// resblock kernel size.
    #[test]
    fn test_bigvgan_new() {
        let model = BigVGAN::new(BigVGANConfig::v2_24khz_100band()).unwrap();
        assert_eq!(model.upsamples.len(), 6);
        assert_eq!(model.amp_blocks.len(), 6);
        let num_kernels = model.config.resblock_kernel_sizes.len();
        assert!(model.amp_blocks.iter().all(|stage| stage.len() == num_kernels));
    }

    /// Output is `[batch, 1, frames * prod(upsample_rates)]`.
    #[test]
    fn test_bigvgan_forward_shape() {
        let model = BigVGAN::new(BigVGANConfig::v2_24khz_100band()).unwrap();
        let frames = 10;
        let mel = Array::random_normal(&[1, 100, frames], 10);
        let audio = model.forward(&mel).unwrap();
        audio.eval();
        assert_eq!(audio.dim(0), 1);
        assert_eq!(audio.dim(1), 1);
        let hop: i64 = model
            .config
            .upsample_rates
            .iter()
            .map(|&rate| rate as i64)
            .product();
        assert_eq!(audio.dim(2) as i64, frames as i64 * hop);
    }

    /// An unbatched 2-D mel spectrogram yields a 1-D waveform.
    #[test]
    fn test_bigvgan_generate() {
        let model = BigVGAN::new(BigVGANConfig::base_24khz_100band()).unwrap();
        let mel = Array::random_normal(&[100, 8], 10);
        let audio = model.generate(&mel).unwrap();
        audio.eval();
        assert_eq!(audio.ndim(), 1);
    }

    /// The final tanh keeps every sample within [-1, 1].
    #[test]
    fn test_bigvgan_output_range() {
        let model = BigVGAN::new(BigVGANConfig::base_24khz_100band()).unwrap();
        let mel = Array::random_normal(&[1, 100, 4], 10);
        let audio = model.forward(&mel).unwrap();
        audio.eval();
        let max_val = audio.max(None);
        let min_val = audio.min(None);
        max_val.eval();
        min_val.eval();
        assert!(max_val.item_f32() <= 1.0);
        assert!(min_val.item_f32() >= -1.0);
    }
}