use super::DiscriminatorOutput;
use crate::error::Result;
use crate::nn::WeightNormConv1d;
use pmetal_bridge::compat::{Array, nn, ops};
#[derive(Debug)]
pub struct MultiPeriodDiscriminator {
pub discriminators: Vec<PeriodDiscriminator>,
}
impl MultiPeriodDiscriminator {
pub fn new() -> Result<Self> {
let periods = vec![2, 3, 5, 7, 11];
let discriminators = periods
.into_iter()
.map(PeriodDiscriminator::new)
.collect::<Result<Vec<_>>>()?;
Ok(Self { discriminators })
}
pub fn with_periods(periods: Vec<i32>) -> Result<Self> {
let discriminators = periods
.into_iter()
.map(PeriodDiscriminator::new)
.collect::<Result<Vec<_>>>()?;
Ok(Self { discriminators })
}
pub fn forward(&self, audio: &Array) -> Result<Vec<DiscriminatorOutput>> {
self.discriminators
.iter()
.map(|d| d.forward(audio))
.collect()
}
}
impl Default for MultiPeriodDiscriminator {
fn default() -> Self {
Self::new().expect("Failed to create MPD")
}
}
#[derive(Debug)]
pub struct PeriodDiscriminator {
pub period: i32,
pub convs: Vec<WeightNormConv1d>,
pub conv_post: WeightNormConv1d,
}
impl PeriodDiscriminator {
pub fn new(period: i32) -> Result<Self> {
let channels = vec![
(period, 32),
(32, 128),
(128, 512),
(512, 1024),
(1024, 1024),
];
let mut convs = Vec::with_capacity(channels.len());
for (i, (in_ch, out_ch)) in channels.iter().enumerate() {
let stride = if i < 4 { 3 } else { 1 };
let kernel = 5;
let padding = 2;
let conv = WeightNormConv1d::new(
*in_ch,
*out_ch,
kernel,
Some(stride),
Some(padding),
None,
None,
Some(true),
)?;
convs.push(conv);
}
let conv_post =
WeightNormConv1d::new(1024, 1, 3, Some(1), Some(1), None, None, Some(true))?;
Ok(Self {
period,
convs,
conv_post,
})
}
pub fn forward(&self, audio: &Array) -> Result<DiscriminatorOutput> {
let batch = audio.dim(0);
let samples = audio.dim(2);
let remainder = samples % self.period;
let x = if remainder != 0 {
let pad_size = self.period - remainder;
let padding = Array::zeros(&[batch, 1, pad_size], 10);
ops::concatenate_axis(&[audio, &padding], -1)
} else {
audio.clone()
};
let new_length = x.dim(2);
let x = x.reshape(&[batch, self.period, new_length / self.period]);
let mut features = Vec::new();
let mut x = x;
for conv in &self.convs {
x = conv.forward(&x)?;
x = nn::leaky_relu(&x, 0.1);
features.push(x.clone());
}
let logits = self.conv_post.forward(&x)?;
Ok(DiscriminatorOutput { logits, features })
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_period_discriminator() {
let disc = PeriodDiscriminator::new(2).unwrap();
let audio = Array::random_normal(&[1, 1, 1024], 10);
let output = disc.forward(&audio).unwrap();
let l2 = output.logits.clone();
l2.eval();
assert!(!output.features.is_empty());
}
#[test]
fn test_mpd() {
let mpd = MultiPeriodDiscriminator::new().unwrap();
let audio = Array::random_normal(&[2, 1, 2048], 10);
let outputs = mpd.forward(&audio).unwrap();
assert_eq!(outputs.len(), 5); }
#[test]
fn test_mpd_different_periods() {
let mpd = MultiPeriodDiscriminator::with_periods(vec![2, 5, 11]).unwrap();
assert_eq!(mpd.discriminators.len(), 3);
}
}