use crate::config::BigVGANConfig;
use crate::error::{Result, VocoderError};
use crate::nn::{AMPBlock, Activation1d, SnakeBeta, WeightNormConv1d, WeightNormConvTranspose1d};
use mlx_rs::Array;
use std::path::Path;
use zerocopy::FromBytes;
#[derive(Debug)]
pub struct BigVGAN {
pub config: BigVGANConfig,
pub conv_pre: WeightNormConv1d,
pub upsamples: Vec<WeightNormConvTranspose1d>,
pub amp_blocks: Vec<Vec<AMPBlock>>,
pub activation_post: Activation1d<SnakeBeta>,
pub conv_post: WeightNormConv1d,
}
impl BigVGAN {
pub fn new(config: BigVGANConfig) -> Result<Self> {
let num_upsamples = config.upsample_rates.len();
let num_kernels = config.resblock_kernel_sizes.len();
let conv_pre = WeightNormConv1d::new(
config.num_mels,
config.upsample_initial_channel,
7,
Some(1),
Some(3),
None,
None,
Some(true),
)?;
let mut upsamples = Vec::with_capacity(num_upsamples);
let mut amp_blocks = Vec::with_capacity(num_upsamples);
let mut channels = config.upsample_initial_channel;
for i in 0..num_upsamples {
let out_channels = channels / 2;
let upsample_rate = config.upsample_rates[i];
let kernel_size = config.upsample_kernel_sizes[i];
let padding = (kernel_size - upsample_rate) / 2;
let upsample = WeightNormConvTranspose1d::new(
channels,
out_channels,
kernel_size,
Some(upsample_rate),
Some(padding),
None,
None,
None,
Some(true),
)?;
upsamples.push(upsample);
let mut stage_amps = Vec::with_capacity(num_kernels);
for j in 0..num_kernels {
let kernel_size = config.resblock_kernel_sizes[j];
let dilations = vec![config.resblock_dilation_sizes[j].clone()];
let amp = AMPBlock::new(out_channels, kernel_size, dilations)?;
stage_amps.push(amp);
}
amp_blocks.push(stage_amps);
channels = out_channels;
}
let activation_post = Activation1d::new(SnakeBeta::new(channels, true)?, Some(2), Some(2))?;
let conv_post = WeightNormConv1d::new(
channels,
1, 7,
Some(1),
Some(3),
None,
None,
Some(true),
)?;
Ok(Self {
config,
conv_pre,
upsamples,
amp_blocks,
activation_post,
conv_post,
})
}
pub fn v2_24khz_100band() -> Result<Self> {
Self::new(BigVGANConfig::v2_24khz_100band())
}
pub fn v2_44khz_128band() -> Result<Self> {
Self::new(BigVGANConfig::v2_44khz_128band())
}
pub fn load_weights(&mut self, path: &Path) -> Result<()> {
let file_data = std::fs::read(path).map_err(VocoderError::from)?;
let tensors = safetensors::SafeTensors::deserialize(&file_data)
.map_err(|e| VocoderError::WeightLoad(e.to_string()))?;
if let Ok(weight_v) = tensors.tensor("conv_pre.weight_v") {
let arr = tensor_to_array(weight_v)?;
self.conv_pre.weight_v.value = arr;
}
if let Ok(weight_g) = tensors.tensor("conv_pre.weight_g") {
let arr = tensor_to_array(weight_g)?;
self.conv_pre.weight_g.value = arr;
}
if let Ok(bias) = tensors.tensor("conv_pre.bias") {
if let Some(ref mut b) = self.conv_pre.bias {
let arr = tensor_to_array(bias)?;
b.value = arr;
}
}
for (i, upsample) in self.upsamples.iter_mut().enumerate() {
let prefix = format!("ups.{}", i);
if let Ok(weight_v) = tensors.tensor(&format!("{}.weight_v", prefix)) {
let arr = tensor_to_array(weight_v)?;
upsample.weight_v.value = arr;
}
if let Ok(weight_g) = tensors.tensor(&format!("{}.weight_g", prefix)) {
let arr = tensor_to_array(weight_g)?;
upsample.weight_g.value = arr;
}
}
if let Ok(weight_v) = tensors.tensor("conv_post.weight_v") {
let arr = tensor_to_array(weight_v)?;
self.conv_post.weight_v.value = arr;
}
if let Ok(weight_g) = tensors.tensor("conv_post.weight_g") {
let arr = tensor_to_array(weight_g)?;
self.conv_post.weight_g.value = arr;
}
if let Ok(bias) = tensors.tensor("conv_post.bias") {
if let Some(ref mut b) = self.conv_post.bias {
let arr = tensor_to_array(bias)?;
b.value = arr;
}
}
Ok(())
}
pub fn from_pretrained(model_id: &str) -> Result<Self> {
use hf_hub::api::sync::ApiBuilder;
let api = ApiBuilder::from_env()
.build()
.map_err(|e| VocoderError::Hub(e.to_string()))?;
let repo = api.model(model_id.to_string());
let config_path = repo
.get("config.json")
.map_err(|e| VocoderError::Hub(e.to_string()))?;
let config_str = std::fs::read_to_string(&config_path).map_err(VocoderError::from)?;
let config: BigVGANConfig =
serde_json::from_str(&config_str).map_err(|e| VocoderError::Config(e.to_string()))?;
let mut model = Self::new(config)?;
let weights_path = repo
.get("model.safetensors")
.map_err(|e| VocoderError::Hub(e.to_string()))?;
model.load_weights(&weights_path)?;
Ok(model)
}
pub fn forward(&self, mel: &Array) -> Result<Array> {
let mut x = self.conv_pre.forward(mel)?;
for (i, upsample) in self.upsamples.iter().enumerate() {
x = upsample.forward(&x)?;
let mut amp_out: Option<Array> = None;
for amp in &self.amp_blocks[i] {
let out = amp.forward(&x)?;
match &_out {
Some(o) => amp_out = Some(o.add(&out)?),
None => amp_out = Some(out),
}
}
let num_amps = Array::from_int(self.amp_blocks[i].len() as i32);
x = amp_out.unwrap().divide(&num_amps)?;
}
x = self.activation_post.forward(&x)?;
x = self.conv_post.forward(&x)?;
Ok(mlx_rs::ops::tanh(&x)?)
}
pub fn generate(&self, mel: &Array) -> Result<Array> {
let (mel, was_2d) = if mel.ndim() == 2 {
(mel.reshape(&[1, mel.dim(0), mel.dim(1)])?, true)
} else {
(mel.clone(), false)
};
let audio = self.forward(&mel)?;
let audio = audio.squeeze()?;
if was_2d {
Ok(audio.squeeze()?) } else {
Ok(audio)
}
}
}
fn tensor_to_array(tensor: safetensors::tensor::TensorView<'_>) -> Result<Array> {
let shape: Vec<i32> = tensor.shape().iter().map(|&s| s as i32).collect();
let data = tensor.data();
match tensor.dtype() {
safetensors::Dtype::F32 => {
let floats: &[f32] = <[f32]>::ref_from_bytes(data).expect("safetensors data aligned");
Ok(Array::from_slice(floats, &shape))
}
safetensors::Dtype::F16 => {
let f16s: &[half::f16] =
<[half::f16]>::ref_from_bytes(data).expect("safetensors data aligned");
let floats: Vec<f32> = f16s.iter().map(|f| f.to_f32()).collect();
Ok(Array::from_slice(&floats, &shape))
}
safetensors::Dtype::BF16 => {
let bf16s: &[half::bf16] =
<[half::bf16]>::ref_from_bytes(data).expect("safetensors data aligned");
let floats: Vec<f32> = bf16s.iter().map(|f| f.to_f32()).collect();
Ok(Array::from_slice(&floats, &shape))
}
_ => Err(VocoderError::WeightLoad(format!(
"Unsupported dtype: {:?}",
tensor.dtype()
))),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_bigvgan_new() {
let config = BigVGANConfig::v2_24khz_100band();
let model = BigVGAN::new(config).unwrap();
assert_eq!(model.upsamples.len(), 6);
assert_eq!(model.amp_blocks.len(), 6);
}
#[test]
fn test_bigvgan_forward_shape() {
let config = BigVGANConfig::v2_24khz_100band();
let model = BigVGAN::new(config).unwrap();
let mel = mlx_rs::random::normal::<f32>(&[1, 100, 10], None, None, None).unwrap();
let audio = model.forward(&mel).unwrap();
audio.eval().unwrap();
assert_eq!(audio.dim(0), 1); assert_eq!(audio.dim(1), 1); }
#[test]
fn test_bigvgan_generate() {
let config = BigVGANConfig::base_24khz_100band();
let model = BigVGAN::new(config).unwrap();
let mel = mlx_rs::random::normal::<f32>(&[100, 8], None, None, None).unwrap();
let audio = model.generate(&mel).unwrap();
audio.eval().unwrap();
assert_eq!(audio.ndim(), 1);
}
#[test]
fn test_bigvgan_output_range() {
let config = BigVGANConfig::base_24khz_100band();
let model = BigVGAN::new(config).unwrap();
let mel = mlx_rs::random::normal::<f32>(&[1, 100, 4], None, None, None).unwrap();
let audio = model.forward(&mel).unwrap();
audio.eval().unwrap();
let max_val = audio.max(None).unwrap();
let min_val = audio.min(None).unwrap();
max_val.eval().unwrap();
min_val.eval().unwrap();
assert!(max_val.item::<f32>() <= 1.0);
assert!(min_val.item::<f32>() >= -1.0);
}
}