use crate::audio::{MelConfig, StftConfig, mel_spectrogram};
use crate::discriminator::DiscriminatorOutput;
use crate::error::Result;
use mlx_rs::Array;
/// Least-squares adversarial loss for the generator.
///
/// For each discriminator output this takes the mean of `(1 - logits)^2`
/// and adds the per-discriminator means together.
///
/// An empty `disc_outputs` slice yields the initial `0.0` accumulator.
///
/// # Errors
///
/// Propagates any failure reported by the underlying array operations.
pub fn generator_adversarial_loss(disc_outputs: &[DiscriminatorOutput]) -> Result<Array> {
    // Target value the generator wants the discriminator to emit for fakes.
    let target = Array::from_f32(1.0);

    disc_outputs
        .iter()
        .try_fold(Array::from_f32(0.0), |running, disc| -> Result<Array> {
            let residual = target.subtract(&disc.logits)?;
            let per_disc = residual.multiply(&residual)?.mean(None)?;
            Ok(running.add(&per_disc)?)
        })
}
/// Least-squares adversarial loss for the discriminators.
///
/// Returns `(real_loss, fake_loss)` where, summed over paired discriminator
/// outputs:
///
/// * `real_loss` accumulates the mean of `(1 - real_logits)^2`
/// * `fake_loss` accumulates the mean of `fake_logits^2`
///
/// The two slices are walked in lockstep; if their lengths differ, the extra
/// entries of the longer slice are ignored.
///
/// # Errors
///
/// Propagates any failure reported by the underlying array operations.
pub fn discriminator_adversarial_loss(
    real_outputs: &[DiscriminatorOutput],
    fake_outputs: &[DiscriminatorOutput],
) -> Result<(Array, Array)> {
    let target = Array::from_f32(1.0);
    let mut loss_on_real = Array::from_f32(0.0);
    let mut loss_on_fake = Array::from_f32(0.0);

    for (real, fake) in real_outputs.iter().zip(fake_outputs) {
        // Real samples should be scored as 1.
        let residual = target.subtract(&real.logits)?;
        let real_term = residual.multiply(&residual)?.mean(None)?;
        loss_on_real = loss_on_real.add(&real_term)?;

        // Fake samples should be scored as 0, so the residual is the logits themselves.
        let fake_term = fake.logits.multiply(&fake.logits)?.mean(None)?;
        loss_on_fake = loss_on_fake.add(&fake_term)?;
    }

    Ok((loss_on_real, loss_on_fake))
}
/// Mean absolute difference between real and fake discriminator feature maps.
///
/// Every `(real, fake)` discriminator pair contributes one L1 mean per paired
/// feature map; the result is the average of those means over all feature
/// pairs that were visited. Pairing uses `zip`, so surplus discriminators or
/// surplus feature maps on either side are ignored.
///
/// When no feature pair exists, the untouched `0.0` accumulator is returned.
///
/// # Errors
///
/// Propagates any failure reported by the underlying array operations.
pub fn feature_matching_loss(
    real_outputs: &[DiscriminatorOutput],
    fake_outputs: &[DiscriminatorOutput],
) -> Result<Array> {
    // Flatten the nested discriminator/feature structure into one stream of pairs.
    let feature_pairs = real_outputs
        .iter()
        .zip(fake_outputs)
        .flat_map(|(real, fake)| real.features.iter().zip(fake.features.iter()));

    let mut l1_sum = Array::from_f32(0.0);
    let mut pair_count = 0;
    for (real_map, fake_map) in feature_pairs {
        let l1 = real_map.subtract(fake_map)?.abs()?.mean(None)?;
        l1_sum = l1_sum.add(&l1)?;
        pair_count += 1;
    }

    // Avoid dividing by zero when there was nothing to compare.
    if pair_count == 0 {
        return Ok(l1_sum);
    }
    Ok(l1_sum.divide(&Array::from_int(pair_count))?)
}
/// Removes the singleton channel axis of a rank-3 audio array while keeping
/// the leading axis intact.
///
/// Arrays whose rank is not 3 are returned unchanged (as a clone).
fn squeeze_channel_axes(audio: &Array) -> Result<Array> {
    if audio.ndim() != 3 {
        return Ok(audio.clone());
    }

    // Only axes 1 and 2 are candidates. A bare `squeeze()` would also remove
    // the leading axis when it has length 1 (e.g. a batch of one), changing
    // the rank handed to the mel transform depending on batch size.
    // NOTE(review): assumes axis 0 is the batch axis — confirm against callers.
    let shape = audio.shape();
    let mut singleton_axes: Vec<i32> = Vec::with_capacity(2);
    if shape[1] == 1 {
        singleton_axes.push(1);
    }
    if shape[2] == 1 {
        singleton_axes.push(2);
    }

    if singleton_axes.is_empty() {
        // No singleton channel axis: keep the previous behaviour for such layouts.
        return Ok(audio.squeeze()?);
    }
    Ok(audio.squeeze_axes(&singleton_axes)?)
}

/// L1 distance between the mel spectrograms of `real_audio` and `fake_audio`.
///
/// Rank-3 inputs have their singleton non-leading axes removed before the
/// transform, so the leading axis survives even when its length is 1. Inputs
/// of any other rank are passed to [`mel_spectrogram`] as they are. The
/// result is the mean absolute difference of the two spectrograms.
///
/// # Errors
///
/// Propagates any failure from the squeeze, from [`mel_spectrogram`], or from
/// the element-wise array operations.
pub fn mel_reconstruction_loss(
    real_audio: &Array,
    fake_audio: &Array,
    mel_config: &MelConfig,
    stft_config: &StftConfig,
) -> Result<Array> {
    let real = squeeze_channel_axes(real_audio)?;
    let fake = squeeze_channel_axes(fake_audio)?;

    let real_mel = mel_spectrogram(&real, mel_config, stft_config)?;
    let fake_mel = mel_spectrogram(&fake, mel_config, stft_config)?;

    Ok(real_mel.subtract(&fake_mel)?.abs()?.mean(None)?)
}
/// Mel reconstruction loss averaged over several STFT resolutions.
///
/// Evaluates [`mel_reconstruction_loss`] once per `(n_fft, hop_length)` pair
/// — `(512, 128)`, `(1024, 256)` and `(2048, 512)` — with the window length
/// set to `n_fft` and centering enabled, then returns the arithmetic mean of
/// the per-resolution losses. All other STFT settings come from
/// `StftConfig::default()`.
///
/// # Errors
///
/// Propagates any failure from [`mel_reconstruction_loss`] or from the array
/// arithmetic used to combine the results.
pub fn multi_scale_mel_loss(
    real_audio: &Array,
    fake_audio: &Array,
    mel_config: &MelConfig,
) -> Result<Array> {
    // (n_fft, hop_length) for each analysis resolution.
    let resolutions = [(512, 128), (1024, 256), (2048, 512)];

    // Derive the divisor from the list so the two cannot drift apart.
    // The cast is exact: the count is a tiny compile-time-known length.
    let resolution_count = Array::from_f32(resolutions.len() as f32);

    let mut summed = Array::from_f32(0.0);
    for (n_fft, hop_length) in resolutions {
        let stft_config = StftConfig {
            n_fft,
            hop_length,
            win_length: Some(n_fft),
            center: true,
            ..Default::default()
        };
        let per_scale = mel_reconstruction_loss(real_audio, fake_audio, mel_config, &stft_config)?;
        summed = summed.add(&per_scale)?;
    }

    Ok(summed.divide(&resolution_count)?)
}
/// Weights applied to the individual terms of the combined generator loss.
#[derive(Debug, Clone)]
pub struct GeneratorLossConfig {
    /// Multiplier for the adversarial term.
    pub lambda_adv: f32,
    /// Multiplier for the feature-matching term.
    pub lambda_fm: f32,
    /// Multiplier for the mel-reconstruction term.
    pub lambda_mel: f32,
}

impl Default for GeneratorLossConfig {
    /// Returns the weights `lambda_adv = 1.0`, `lambda_fm = 2.0`,
    /// `lambda_mel = 45.0`.
    fn default() -> Self {
        let (lambda_adv, lambda_fm, lambda_mel) = (1.0, 2.0, 45.0);
        Self {
            lambda_adv,
            lambda_fm,
            lambda_mel,
        }
    }
}
/// Combined generator loss across both discriminator families.
///
/// Builds three unweighted terms:
///
/// * adversarial — [`generator_adversarial_loss`] on the MPD fakes plus the
///   same on the MRD fakes
/// * feature matching — [`feature_matching_loss`] for MPD plus MRD
/// * mel — [`multi_scale_mel_loss`] between `real_audio` and `fake_audio`
///
/// `total` is `lambda_adv * adversarial + lambda_fm * feature_matching +
/// lambda_mel * mel`, with the weights taken from `config`. The returned
/// struct also carries the three unweighted terms.
///
/// # Errors
///
/// Propagates any failure from the component losses or from the array
/// arithmetic used to combine them.
pub fn generator_loss(
    real_audio: &Array,
    fake_audio: &Array,
    real_outputs_mpd: &[DiscriminatorOutput],
    fake_outputs_mpd: &[DiscriminatorOutput],
    real_outputs_mrd: &[DiscriminatorOutput],
    fake_outputs_mrd: &[DiscriminatorOutput],
    mel_config: &MelConfig,
    config: &GeneratorLossConfig,
) -> Result<GeneratorLossOutput> {
    let adversarial = generator_adversarial_loss(fake_outputs_mpd)?
        .add(&generator_adversarial_loss(fake_outputs_mrd)?)?;

    let feature_matching = feature_matching_loss(real_outputs_mpd, fake_outputs_mpd)?
        .add(&feature_matching_loss(real_outputs_mrd, fake_outputs_mrd)?)?;

    let mel = multi_scale_mel_loss(real_audio, fake_audio, mel_config)?;

    // Scale each term by its configured weight, then sum left to right.
    let weighted_adv = adversarial.multiply(&Array::from_f32(config.lambda_adv))?;
    let weighted_fm = feature_matching.multiply(&Array::from_f32(config.lambda_fm))?;
    let weighted_mel = mel.multiply(&Array::from_f32(config.lambda_mel))?;
    let total = weighted_adv.add(&weighted_fm)?.add(&weighted_mel)?;

    Ok(GeneratorLossOutput {
        total,
        adversarial,
        feature_matching,
        mel,
    })
}
/// Result of `generator_loss`: the weighted total plus each unweighted term.
#[derive(Debug)]
pub struct GeneratorLossOutput {
/// Sum of the three terms below, each scaled by its `GeneratorLossConfig` weight.
pub total: Array,
/// Unweighted adversarial term (MPD and MRD contributions added together).
pub adversarial: Array,
/// Unweighted feature-matching term (MPD and MRD contributions added together).
pub feature_matching: Array,
/// Unweighted multi-scale mel reconstruction term.
pub mel: Array,
}
/// Discriminator loss with its real and fake components.
///
/// Delegates to [`discriminator_adversarial_loss`] and reports the two parts
/// together with their sum as `total`.
///
/// # Errors
///
/// Propagates any failure from [`discriminator_adversarial_loss`] or from
/// adding the two components.
pub fn discriminator_loss(
    real_outputs: &[DiscriminatorOutput],
    fake_outputs: &[DiscriminatorOutput],
) -> Result<DiscriminatorLossOutput> {
    let (real, fake) = discriminator_adversarial_loss(real_outputs, fake_outputs)?;
    Ok(DiscriminatorLossOutput {
        total: real.add(&fake)?,
        real,
        fake,
    })
}
/// Result of `discriminator_loss`: the real and fake components and their sum.
#[derive(Debug)]
pub struct DiscriminatorLossOutput {
/// `real + fake`.
pub total: Array,
/// Loss accumulated on the discriminator outputs for real samples.
pub real: Array,
/// Loss accumulated on the discriminator outputs for fake samples.
pub fake: Array,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Draws a normally distributed `f32` array of the given shape.
    fn random_array(shape: &[i32]) -> Array {
        mlx_rs::random::normal::<f32>(shape, None, None, None).unwrap()
    }

    /// Builds one discriminator output with random logits and two random feature maps.
    fn random_disc_outputs(batch: i32) -> Vec<DiscriminatorOutput> {
        let logits = random_array(&[batch, 1, 10]);
        let features = vec![
            random_array(&[batch, 32, 64]),
            random_array(&[batch, 64, 32]),
        ];
        vec![DiscriminatorOutput { logits, features }]
    }

    /// Forces evaluation and checks that the loss is a rank-0 array.
    #[track_caller]
    fn assert_scalar(loss: &Array) {
        loss.eval().unwrap();
        assert_eq!(loss.ndim(), 0);
    }

    #[test]
    fn test_generator_adversarial_loss() {
        let fakes = random_disc_outputs(2);
        assert_scalar(&generator_adversarial_loss(&fakes).unwrap());
    }

    #[test]
    fn test_discriminator_adversarial_loss() {
        let reals = random_disc_outputs(2);
        let fakes = random_disc_outputs(2);
        let (real_loss, fake_loss) = discriminator_adversarial_loss(&reals, &fakes).unwrap();
        assert_scalar(&real_loss);
        assert_scalar(&fake_loss);
    }

    #[test]
    fn test_feature_matching_loss() {
        let reals = random_disc_outputs(2);
        let fakes = random_disc_outputs(2);
        assert_scalar(&feature_matching_loss(&reals, &fakes).unwrap());
    }

    #[test]
    fn test_mel_reconstruction_loss() {
        let mel_config = MelConfig::default();
        let stft_config = StftConfig::default();
        let real_wave = random_array(&[1, 8000]);
        let fake_wave = random_array(&[1, 8000]);
        let loss =
            mel_reconstruction_loss(&real_wave, &fake_wave, &mel_config, &stft_config).unwrap();
        assert_scalar(&loss);
    }

    #[test]
    fn test_loss_config_default() {
        let GeneratorLossConfig {
            lambda_adv,
            lambda_fm,
            lambda_mel,
        } = GeneratorLossConfig::default();
        assert_eq!(lambda_adv, 1.0);
        assert_eq!(lambda_fm, 2.0);
        assert_eq!(lambda_mel, 45.0);
    }
}