mod mpd;
mod mrd;
pub use mpd::{MultiPeriodDiscriminator, PeriodDiscriminator};
pub use mrd::{MultiResolutionDiscriminator, ResolutionDiscriminator};
use crate::error::Result;
use mlx_rs::Array;
/// Combined discriminator that bundles a [`MultiPeriodDiscriminator`] and a
/// [`MultiResolutionDiscriminator`] so both can be built and run together.
#[derive(Debug)]
pub struct BigVGANDiscriminator {
/// Multi-period sub-discriminator (re-exported from the `mpd` module).
pub mpd: MultiPeriodDiscriminator,
/// Multi-resolution sub-discriminator (re-exported from the `mrd` module).
pub mrd: MultiResolutionDiscriminator,
}
impl BigVGANDiscriminator {
    /// Builds both sub-discriminators with their own `new()` constructors.
    ///
    /// The multi-period discriminator is constructed first, then the
    /// multi-resolution one.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by either sub-discriminator
    /// constructor.
    pub fn new() -> Result<Self> {
        let mpd = MultiPeriodDiscriminator::new()?;
        let mrd = MultiResolutionDiscriminator::new()?;
        Ok(Self { mpd, mrd })
    }

    /// Runs `audio` through both sub-discriminators.
    ///
    /// The same `audio` reference is handed to the multi-period
    /// discriminator first and to the multi-resolution discriminator second.
    /// NOTE(review): the expected shape/layout of `audio` is defined by the
    /// sub-discriminators and is not visible here — confirm against `mpd`/`mrd`.
    ///
    /// # Returns
    ///
    /// A tuple `(mpd_outputs, mrd_outputs)` holding the outputs of the
    /// multi-period and multi-resolution discriminators, in that order.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by either `forward` call; if the
    /// multi-period pass fails, the multi-resolution pass is not run.
    pub fn forward(
        &self,
        audio: &Array,
    ) -> Result<(Vec<DiscriminatorOutput>, Vec<DiscriminatorOutput>)> {
        // Tuple fields are evaluated left to right, so MPD still runs before MRD.
        Ok((self.mpd.forward(audio)?, self.mrd.forward(audio)?))
    }
}
impl Default for BigVGANDiscriminator {
fn default() -> Self {
Self::new().expect("Failed to create discriminator")
}
}
/// Result record used as the element type of the vectors returned by
/// `BigVGANDiscriminator::forward`.
#[derive(Debug)]
pub struct DiscriminatorOutput {
/// Output logits array.
/// NOTE(review): presumably the final real/fake score map of one
/// sub-discriminator — confirm against the `mpd`/`mrd` implementations.
pub logits: Array,
/// List of feature arrays.
/// NOTE(review): presumably intermediate layer activations (e.g. for a
/// feature-matching loss) — confirm against the `mpd`/`mrd` implementations.
pub features: Vec<Array>,
}