use super::DiscriminatorOutput;
use crate::audio::{StftConfig, stft};
use crate::error::Result;
use crate::nn::WeightNormConv1d;
use mlx_rs::Array;
/// Collection of [`ResolutionDiscriminator`]s, each constructed from its own
/// `(n_fft, hop_length, win_length)` STFT resolution triple.
#[derive(Debug)]
pub struct MultiResolutionDiscriminator {
/// One sub-discriminator per configured resolution, kept in the order the
/// resolutions were supplied at construction time.
pub discriminators: Vec<ResolutionDiscriminator>,
}
impl MultiResolutionDiscriminator {
    /// Resolutions used by [`Self::new`], as `(n_fft, hop_length, win_length)` triples.
    const DEFAULT_RESOLUTIONS: [(i32, i32, i32); 3] =
        [(1024, 120, 600), (2048, 240, 1200), (512, 50, 240)];

    /// Builds the discriminator with the three built-in STFT resolutions.
    ///
    /// # Errors
    /// Returns the first error produced while constructing a
    /// [`ResolutionDiscriminator`].
    pub fn new() -> Result<Self> {
        // Single construction path: the defaults go through the same code as
        // caller-supplied resolutions instead of duplicating it.
        Self::with_resolutions(Self::DEFAULT_RESOLUTIONS.to_vec())
    }

    /// Builds one [`ResolutionDiscriminator`] per `(n_fft, hop_length, win_length)`
    /// triple, preserving the order of `resolutions`.
    ///
    /// An empty `resolutions` vector yields a discriminator whose
    /// [`Self::forward`] returns an empty vector.
    ///
    /// # Errors
    /// Returns the first error produced while constructing a
    /// [`ResolutionDiscriminator`]; remaining triples are not processed.
    pub fn with_resolutions(resolutions: Vec<(i32, i32, i32)>) -> Result<Self> {
        let discriminators = resolutions
            .into_iter()
            .map(|(n_fft, hop, win)| ResolutionDiscriminator::new(n_fft, hop, win))
            .collect::<Result<Vec<_>>>()?;
        Ok(Self { discriminators })
    }

    /// Runs every sub-discriminator on the same `audio` and collects their
    /// outputs in the order of `self.discriminators`.
    ///
    /// # Errors
    /// Stops at and returns the first error from a sub-discriminator's `forward`.
    pub fn forward(&self, audio: &Array) -> Result<Vec<DiscriminatorOutput>> {
        self.discriminators
            .iter()
            .map(|disc| disc.forward(audio))
            .collect()
    }
}
impl Default for MultiResolutionDiscriminator {
    /// Returns the discriminator produced by [`MultiResolutionDiscriminator::new`].
    ///
    /// # Panics
    /// Panics with the message "Failed to create MRD" if `new` returns an error.
    fn default() -> Self {
        let built = Self::new();
        built.expect("Failed to create MRD")
    }
}
/// Discriminator that scores audio at a single STFT resolution.
#[derive(Debug)]
pub struct ResolutionDiscriminator {
/// STFT parameters applied to the input audio in `forward`.
pub stft_config: StftConfig,
/// Convolution stack applied in order by `forward`; each layer's output is
/// passed through a leaky ReLU and recorded as a feature map.
pub convs: Vec<WeightNormConv1d>,
/// Final convolution that turns the last activation into the logits.
pub conv_post: WeightNormConv1d,
}
impl ResolutionDiscriminator {
    /// Creates a discriminator for one STFT resolution.
    ///
    /// The STFT is configured with `center: true` and the given `n_fft`,
    /// `hop_length` and `win_length`; all other settings come from
    /// `StftConfig::default()`. The first convolution takes `n_fft / 2 + 1`
    /// input channels.
    ///
    /// # Errors
    /// Propagates any error from constructing a `WeightNormConv1d` layer.
    pub fn new(n_fft: i32, hop_length: i32, win_length: i32) -> Result<Self> {
        let stft_config = StftConfig {
            n_fft,
            hop_length,
            win_length: Some(win_length),
            center: true,
            ..Default::default()
        };

        let n_freq = n_fft / 2 + 1;

        // (in_channels, out_channels, stride) for each layer. Keeping the
        // stride in the table avoids coupling it to a magic layer index.
        let layers = [
            (n_freq, 32, 2),
            (32, 128, 2),
            (128, 512, 2),
            (512, 1024, 2),
            (1024, 1024, 1),
        ];

        let mut convs = Vec::with_capacity(layers.len());
        for &(in_ch, out_ch, stride) in &layers {
            convs.push(WeightNormConv1d::new(
                in_ch,
                out_ch,
                3,
                Some(stride),
                Some(1),
                None,
                None,
                Some(true),
            )?);
        }

        let conv_post =
            WeightNormConv1d::new(1024, 1, 3, Some(1), Some(1), None, None, Some(true))?;

        Ok(Self {
            stft_config,
            convs,
            conv_post,
        })
    }

    /// Computes logits and intermediate feature maps for `audio`.
    ///
    /// Steps: squeeze the input, take the STFT with `self.stft_config`, take
    /// its absolute value, run it through `self.convs` (leaky ReLU with
    /// argument `0.1` after each layer, every activation recorded), then apply
    /// `self.conv_post` to obtain the logits.
    ///
    /// # Errors
    /// Propagates any error from the array operations, the STFT, or the
    /// convolution layers.
    pub fn forward(&self, audio: &Array) -> Result<DiscriminatorOutput> {
        // NOTE(review): `squeeze()` presumably drops every size-1 axis (e.g. a
        // channel axis, and the batch axis when batch == 1) — confirm against
        // the shapes callers pass in.
        let audio_2d = audio.squeeze()?;
        let stft_out = stft(&audio_2d, &self.stft_config)?;
        let magnitude = stft_out.abs()?;

        // A rank-2 magnitude gets a leading axis of size 1 so the convolutions
        // always receive a rank-3 input.
        let mut x = if magnitude.ndim() == 2 {
            magnitude.reshape(&[1, magnitude.dim(0), magnitude.dim(1)])?
        } else {
            magnitude
        };

        let mut features = Vec::with_capacity(self.convs.len());
        for conv in &self.convs {
            let pre_activation = conv.forward(&x)?;
            x = mlx_rs::nn::leaky_relu(&pre_activation, 0.1)?;
            features.push(x.clone());
        }

        let logits = self.conv_post.forward(&x)?;
        Ok(DiscriminatorOutput { logits, features })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single discriminator runs on `[1, 1, 4096]` noise, its logits can be
    /// evaluated, and it records at least one feature map.
    #[test]
    fn test_resolution_discriminator() {
        let discriminator = ResolutionDiscriminator::new(1024, 256, 1024).unwrap();
        let waveform = mlx_rs::random::normal::<f32>(&[1, 1, 4096], None, None, None).unwrap();

        let result = discriminator.forward(&waveform).unwrap();
        result.logits.eval().unwrap();

        assert!(!result.features.is_empty());
    }

    /// The default multi-resolution discriminator yields one output per
    /// built-in resolution.
    #[test]
    fn test_mrd() {
        let model = MultiResolutionDiscriminator::new().unwrap();
        let waveform = mlx_rs::random::normal::<f32>(&[1, 1, 8000], None, None, None).unwrap();

        let per_resolution = model.forward(&waveform).unwrap();

        assert_eq!(per_resolution.len(), 3);
    }

    /// `with_resolutions` builds exactly one sub-discriminator per triple.
    #[test]
    fn test_mrd_custom_resolutions() {
        let custom = vec![(512, 128, 512), (1024, 256, 1024)];
        let model = MultiResolutionDiscriminator::with_resolutions(custom).unwrap();

        assert_eq!(model.discriminators.len(), 2);
    }
}