use crate::random_ops::randn;
use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::{
creation::{ones, rand},
Tensor,
};
/// WGAN-GP style gradient penalty (simplified placeholder).
///
/// Mixes `real_samples` and `fake_samples` with a fixed coefficient, runs the
/// interpolated batch through `discriminator_fn`, and scores a penalty of the
/// form `lambda * (||g|| - 1)^2`. NOTE(review): no autograd pass happens here
/// yet — the "gradient" `g` is an all-ones stand-in, so the returned value is
/// synthetic until a real backward pass is wired up.
///
/// # Errors
/// Returns an error if any tensor operation fails, or if `reduction` is not
/// one of `"none"`, `"mean"`, or `"sum"`.
pub fn gradient_penalty<F>(
    real_samples: &Tensor,
    fake_samples: &Tensor,
    discriminator_fn: F,
    lambda: f64,
    reduction: &str,
) -> TorshResult<Tensor>
where
    F: Fn(&Tensor) -> TorshResult<Tensor>,
{
    let batch_size = real_samples.shape().dims()[0];
    // Drawn but currently unused; a per-sample random mixing coefficient is
    // presumably the eventual goal — TODO confirm before wiring it in.
    // Shape assumes 4-D (batch, C, H, W) inputs.
    let _epsilon: Tensor = rand(&[batch_size, 1, 1, 1])?;
    // Fixed 50/50 interpolation between the real and fake batches.
    let epsilon_val = 0.5;
    let fake_part = fake_samples.mul_scalar(1.0 - epsilon_val)?;
    let real_part = real_samples.mul_scalar(epsilon_val)?;
    let interpolated = real_part.add_op(&fake_part)?;
    // Evaluate the discriminator so shape/compatibility errors surface here,
    // even though the output is not differentiated yet.
    let _d_interpolated = discriminator_fn(&interpolated)?;
    // Stand-in gradient: all ones, flattened to (batch, features).
    let grad_flat = ones(&[batch_size, interpolated.numel() / batch_size])?;
    let grad_flat_f32 = grad_flat.to_dtype(torsh_core::DType::F32)?;
    // Global L2 norm of the stand-in gradient.
    let grad_norm = grad_flat_f32.pow_scalar(2.0)?.sum()?.sqrt()?;
    // (||g|| - 1)^2 scaled by the penalty weight.
    let penalty = grad_norm
        .add_scalar(-1.0)?
        .pow_scalar(2.0)?
        .mul_scalar(lambda as f32)?;
    if reduction == "none" {
        Ok(penalty)
    } else if reduction == "mean" {
        penalty.mean(None, false)
    } else if reduction == "sum" {
        penalty.sum()
    } else {
        Err(TorshError::invalid_argument_with_context(
            &format!(
                "Invalid reduction: {}, expected 'none', 'mean', or 'sum'",
                reduction
            ),
            "gradient_penalty",
        ))
    }
}
/// Spectral-norm flavoured gradient penalty (simplified placeholder).
///
/// NOTE(review): `_network_output` is currently ignored and the "gradient" is
/// random noise shaped like the flattened input — no real backward pass is
/// performed yet, so the penalty value is synthetic.
///
/// # Errors
/// Returns an error if any tensor operation fails, or if `reduction` is not
/// one of `"none"`, `"mean"`, or `"sum"`.
pub fn spectral_gradient_penalty(
    _network_output: &Tensor,
    input_tensor: &Tensor,
    lambda: f64,
    reduction: &str,
) -> TorshResult<Tensor> {
    let batch_size = input_tensor.shape().dims()[0];
    // Random stand-in gradient, flattened to (batch, features).
    let surrogate_grad = randn(
        &[batch_size, input_tensor.numel() / batch_size],
        None,
        None,
        None,
    )?;
    let surrogate_f32 = surrogate_grad.to_dtype(torsh_core::DType::F32)?;
    // Global L2 norm of the stand-in gradient.
    let spectral_norm = surrogate_f32.pow_scalar(2.0)?.sum()?.sqrt()?;
    // (||g|| - 1)^2 scaled by the penalty weight.
    let penalty = spectral_norm
        .add_scalar(-1.0)?
        .pow_scalar(2.0)?
        .mul_scalar(lambda as f32)?;
    match reduction {
        "mean" => penalty.mean(None, false),
        "sum" => penalty.sum(),
        "none" => Ok(penalty),
        _ => Err(TorshError::invalid_argument_with_context(
            &format!(
                "Invalid reduction: {}, expected 'none', 'mean', or 'sum'",
                reduction
            ),
            "spectral_gradient_penalty",
        )),
    }
}
/// R1 gradient penalty (simplified placeholder).
///
/// Evaluates `discriminator_fn` on `real_samples` and returns
/// `0.5 * lambda * ||g||^2` per sample, where the "gradient" `g` is currently
/// random noise — no autograd pass through the discriminator happens yet, so
/// the value is synthetic.
///
/// NOTE(review): `sum_dim(&[1, 2, 3], ..)` assumes 4-D (batch, C, H, W)
/// input — confirm callers only pass image-shaped tensors.
///
/// # Errors
/// Returns an error if any tensor operation fails, or if `reduction` is not
/// one of `"none"`, `"mean"`, or `"sum"`.
pub fn r1_gradient_penalty<F>(
    real_samples: &Tensor,
    discriminator_fn: F,
    lambda: f64,
    reduction: &str,
) -> TorshResult<Tensor>
where
    F: Fn(&Tensor) -> TorshResult<Tensor>,
{
    // Run the discriminator so shape/compatibility errors surface here, even
    // though the output is not differentiated yet.
    let _d_real = discriminator_fn(real_samples)?;
    // Stand-in gradient with the same shape as the input batch.
    // (Removed a dead `_batch_size` local that was computed and never used.)
    let dummy_grad = randn(real_samples.shape().dims(), None, None, None)?;
    // Squared L2 norm per sample, reduced over the non-batch dimensions.
    let grad_norm_sq = dummy_grad.pow_scalar(2.0)?.sum_dim(&[1, 2, 3], false)?;
    let penalty = grad_norm_sq.mul_scalar(lambda as f32 * 0.5)?;
    match reduction {
        "none" => Ok(penalty),
        "mean" => penalty.mean(None, false),
        "sum" => penalty.sum(),
        _ => Err(TorshError::invalid_argument_with_context(
            &format!(
                "Invalid reduction: {}, expected 'none', 'mean', or 'sum'",
                reduction
            ),
            "r1_gradient_penalty",
        )),
    }
}
/// R2 gradient penalty (simplified placeholder) — the fake-sample twin of
/// [`r1_gradient_penalty`].
///
/// NOTE(review): the "gradient" is random noise (no autograd pass yet), and
/// `sum_dim(&[1, 2, 3], ..)` assumes 4-D (batch, C, H, W) input — confirm
/// callers only pass image-shaped tensors.
///
/// # Errors
/// Returns an error if any tensor operation fails, or if `reduction` is not
/// one of `"none"`, `"mean"`, or `"sum"`.
pub fn r2_gradient_penalty<F>(
    fake_samples: &Tensor,
    discriminator_fn: F,
    lambda: f64,
    reduction: &str,
) -> TorshResult<Tensor>
where
    F: Fn(&Tensor) -> TorshResult<Tensor>,
{
    // Evaluate the discriminator so shape/compatibility errors surface here.
    let _d_fake = discriminator_fn(fake_samples)?;
    // Random stand-in gradient matching the fake batch shape.
    let surrogate_grad = randn(fake_samples.shape().dims(), None, None, None)?;
    // Per-sample squared L2 norm over the non-batch dimensions.
    let norm_sq = surrogate_grad.pow_scalar(2.0)?.sum_dim(&[1, 2, 3], false)?;
    // 0.5 * lambda weighting, per the R1/R2 formulation.
    let weight = lambda as f32 * 0.5;
    let penalty = norm_sq.mul_scalar(weight)?;
    match reduction {
        "sum" => penalty.sum(),
        "none" => Ok(penalty),
        "mean" => penalty.mean(None, false),
        _ => Err(TorshError::invalid_argument_with_context(
            &format!(
                "Invalid reduction: {}, expected 'none', 'mean', or 'sum'",
                reduction
            ),
            "r2_gradient_penalty",
        )),
    }
}
/// Consistency penalty: penalizes the model for producing different outputs on
/// a clean input versus a perturbed copy of it.
///
/// Computes `lambda * (model_fn(input) - model_fn(perturbed_input))^2`
/// element-wise, then applies the requested reduction.
///
/// # Errors
/// Returns an error if `model_fn` or any tensor operation fails, or if
/// `reduction` is not one of `"none"`, `"mean"`, or `"sum"`.
pub fn consistency_penalty<F>(
    model_fn: F,
    input: &Tensor,
    perturbed_input: &Tensor,
    lambda: f64,
    reduction: &str,
) -> TorshResult<Tensor>
where
    F: Fn(&Tensor) -> TorshResult<Tensor>,
{
    // Forward the clean batch first, then the perturbed one.
    let clean_out = model_fn(input)?;
    let noisy_out = model_fn(perturbed_input)?;
    // Element-wise squared deviation, weighted by lambda.
    let deviation = clean_out.sub(&noisy_out)?;
    let penalty = deviation.pow_scalar(2.0)?.mul_scalar(lambda as f32)?;
    match reduction {
        "mean" => penalty.mean(None, false),
        "none" => Ok(penalty),
        "sum" => penalty.sum(),
        _ => Err(TorshError::invalid_argument_with_context(
            &format!(
                "Invalid reduction: {}, expected 'none', 'mean', or 'sum'",
                reduction
            ),
            "consistency_penalty",
        )),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::random_ops::randn;

    #[test]
    fn test_gradient_penalty_shapes() {
        let real = randn(&[4, 3, 32, 32], None, None, None).unwrap();
        let fake = randn(&[4, 3, 32, 32], None, None, None).unwrap();
        // Tiny linear "discriminator": flatten, then project to one logit
        // with a freshly sampled weight matrix.
        let disc = |x: &Tensor| -> TorshResult<Tensor> {
            let flat = x.view(&[x.shape().dims()[0] as i32, -1])?;
            let weight = randn(&[flat.shape().dims()[1], 1], None, None, None)?;
            flat.matmul(&weight)
        };
        let result = gradient_penalty(&real, &fake, disc, 10.0, "mean");
        assert!(result.is_ok());
        // "mean" reduction collapses everything to a scalar (0-dim) tensor.
        let result = result.unwrap();
        assert_eq!(result.shape().dims(), &[] as &[usize]);
    }

    #[test]
    fn test_r1_penalty_shapes() {
        let real = randn(&[4, 3, 32, 32], None, None, None).unwrap();
        // Same linear discriminator shape as in the gradient-penalty test.
        let disc = |x: &Tensor| -> TorshResult<Tensor> {
            let flat = x.view(&[x.shape().dims()[0] as i32, -1])?;
            let weight = randn(&[flat.shape().dims()[1], 1], None, None, None)?;
            flat.matmul(&weight)
        };
        let result = r1_gradient_penalty(&real, disc, 10.0, "mean");
        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.shape().dims(), &[] as &[usize]);
    }

    #[test]
    fn test_consistency_penalty() {
        let input = randn(&[4, 10], None, None, None).unwrap();
        // Perturbed copy: a small constant shift of the clean batch.
        let shifted = input.add_scalar(0.1).unwrap();
        // Tiny linear model with a random projection.
        let model = |x: &Tensor| -> TorshResult<Tensor> {
            let weight = randn(&[10, 5], None, None, None)?;
            x.matmul(&weight)
        };
        let result = consistency_penalty(model, &input, &shifted, 1.0, "mean");
        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.shape().dims(), &[] as &[usize]);
    }
}