use crate::dtype::Float;
use crate::error::FerrotorchResult;
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
use super::higher_order::grad;
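/// WGAN-GP-style gradient penalty: `lambda * (||grad_x D(x_hat)||_2 - 1)^2`,
/// where `x_hat` is a per-element random interpolation between `real` and
/// `fake`. The norm is taken over all elements at once, so a batched input
/// yields a single global penalty rather than a per-sample mean.
///
/// A minimal usage sketch (tensors built as in the tests below; `sum` stands
/// for `crate::grad_fns::reduction::sum`):
///
/// ```ignore
/// let real = Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 2.0]), vec![2], false)?;
/// let fake = Tensor::from_storage(TensorStorage::cpu(vec![0.5f32, 1.5]), vec![2], false)?;
/// // D(x) = sum(x): the gradient is all ones, so the penalty is 10 * (sqrt(2) - 1)^2.
/// let penalty = gradient_penalty(|x| sum(x), &real, &fake, 10.0)?;
/// ```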
pub fn gradient_penalty<T: Float, F>(
discriminator: F,
real: &Tensor<T>,
fake: &Tensor<T>,
lambda: f64,
) -> FerrotorchResult<Tensor<T>>
where
F: Fn(&Tensor<T>) -> FerrotorchResult<Tensor<T>>,
{
if real.shape() != fake.shape() {
return Err(crate::error::FerrotorchError::ShapeMismatch {
message: format!(
"gradient_penalty: real shape {:?} != fake shape {:?}",
real.shape(),
fake.shape()
),
});
}
    // Interpolate per element: x_hat = alpha * real + (1 - alpha) * fake,
    // with a random alpha per element.
    let n = real.numel();
    let alpha_tensor: Tensor<T> = crate::creation::rand(real.shape())?;
    let alpha_data = alpha_tensor.data()?;
    let real_data = real.data()?;
    let fake_data = fake.data()?;
    let one = <T as num_traits::One>::one();
    let interp_data: Vec<T> = (0..n)
        .map(|i| alpha_data[i] * real_data[i] + (one - alpha_data[i]) * fake_data[i])
        .collect();
    // The interpolate must require grad so D(x_hat) can be differentiated
    // with respect to it.
    let x_interp = Tensor::from_storage(
        TensorStorage::cpu(interp_data),
        real.shape().to_vec(),
        true,
    )?;
    let d_interp = discriminator(&x_interp)?;
    // The final `true` asks `grad` to build a graph through this call, so the
    // penalty itself stays differentiable (see test_gradient_penalty_has_grad_fn).
    let grads = grad(&d_interp, &[&x_interp], false, true)?;
    let grad_interp = match &grads[0] {
        Some(g) => g.clone(),
        None => {
            // The discriminator did not depend on its input; the gradient is zero.
            let zero_data = vec![<T as num_traits::Zero>::zero(); n];
            Tensor::from_storage(TensorStorage::cpu(zero_data), real.shape().to_vec(), false)?
        }
    };
    // Penalty: lambda * (||grad||_2 - 1)^2, built from differentiable ops so
    // the graph above is preserved.
    let grad_sq = crate::grad_fns::arithmetic::pow(&grad_interp, 2.0)?;
    let grad_sq_sum = crate::grad_fns::reduction::sum(&grad_sq)?;
    let grad_norm = crate::grad_fns::arithmetic::sqrt(&grad_sq_sum)?;
    let one_tensor = Tensor::from_storage(TensorStorage::cpu(vec![one]), vec![], false)?;
    let diff = crate::grad_fns::arithmetic::sub(&grad_norm, &one_tensor)?;
    let diff_sq = crate::grad_fns::arithmetic::pow(&diff, 2.0)?;
    let lambda_t = T::from(lambda).unwrap();
    let lambda_tensor =
        Tensor::from_storage(TensorStorage::cpu(vec![lambda_t]), vec![], false)?;
    crate::grad_fns::arithmetic::mul(&lambda_tensor, &diff_sq)
}
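/// Global L2 norm of the gradients of `outputs` with respect to `inputs`:
/// `sqrt(sum_i ||grad_i||^2)`. Inputs that receive no gradient contribute
/// zero. Returns a non-differentiable scalar tensor.
///
/// A quick sketch (values as in the tests below):
///
/// ```ignore
/// // y = sum(x^2) over x = [3, 4] gives grad [6, 8], so the norm is 10.
/// let norm = grad_norm(&y, &[&x])?;
/// ```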
pub fn grad_norm<T: Float>(
outputs: &Tensor<T>,
inputs: &[&Tensor<T>],
) -> FerrotorchResult<Tensor<T>> {
let grads = grad(outputs, inputs, false, false)?;
let zero = <T as num_traits::Zero>::zero();
let mut total_sq = zero;
    // `flatten` skips inputs that received no gradient; they contribute zero.
    for g in grads.iter().flatten() {
        let g_data = g.data()?;
        for &val in g_data.iter() {
            total_sq += val * val;
        }
    }
let norm_val = total_sq.sqrt();
Tensor::from_storage(TensorStorage::cpu(vec![norm_val]), vec![], false)
}
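/// Jacobian-vector product `J_f(x) * v`, approximated with a central finite
/// difference rather than forward-mode autodiff, so `f` need not be built
/// from differentiable ops. The error is O(h^2) for smooth `f` with
/// `h = 1e-4`; the tests below use loose tolerances accordingly.
///
/// A minimal sketch (`pow` is `crate::grad_fns::arithmetic::pow`):
///
/// ```ignore
/// // f(x) = x^2 elementwise at x = [3, 4] with v = [1, 1] is roughly [6, 8].
/// let out = jvp(|x| pow(x, 2.0), &input, &v)?;
/// ```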
pub fn jvp<T: Float, F>(f: F, input: &Tensor<T>, v: &Tensor<T>) -> FerrotorchResult<Tensor<T>>
where
F: Fn(&Tensor<T>) -> FerrotorchResult<Tensor<T>>,
{
if input.shape() != v.shape() {
return Err(crate::error::FerrotorchError::ShapeMismatch {
message: format!(
"jvp: input shape {:?} != v shape {:?}",
input.shape(),
v.shape()
),
});
}
    // Central finite difference: J_f(x) * v ~ (f(x + h*v) - f(x - h*v)) / (2h).
    let h = T::from(1e-4).unwrap();
    let two_h = h + h;
let input_data = input.data()?;
let v_data = v.data()?;
let n = input.numel();
let plus_data: Vec<T> = (0..n).map(|i| input_data[i] + h * v_data[i]).collect();
let x_plus =
Tensor::from_storage(TensorStorage::cpu(plus_data), input.shape().to_vec(), false)?;
let minus_data: Vec<T> = (0..n).map(|i| input_data[i] - h * v_data[i]).collect();
let x_minus = Tensor::from_storage(
TensorStorage::cpu(minus_data),
input.shape().to_vec(),
false,
)?;
let f_plus = f(&x_plus)?;
let f_minus = f(&x_minus)?;
let fp_data = f_plus.data()?;
let fm_data = f_minus.data()?;
let result_data: Vec<T> = fp_data
.iter()
.zip(fm_data.iter())
.map(|(&fp, &fm)| (fp - fm) / two_h)
.collect();
Tensor::from_storage(
TensorStorage::cpu(result_data),
f_plus.shape().to_vec(),
false,
)
}
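/// Vector-Jacobian product `v^T * J_f(x)`, computed exactly with reverse
/// mode: since `v^T * J_f(x) = grad_x <f(x), v>`, the implementation forms
/// the scalar `<f(x), v>` and runs a single backward pass. `v` must have the
/// same number of elements as `f(input)`, and `f` must be built from
/// differentiable ops.
///
/// A minimal sketch (`pow` is `crate::grad_fns::arithmetic::pow`):
///
/// ```ignore
/// // f(x) = x^2 at x = [3, 4] with v = [1, 1] gives exactly [6, 8].
/// let out = vjp(|x| pow(x, 2.0), &input, &v)?;
/// ```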
pub fn vjp<T: Float, F>(f: F, input: &Tensor<T>, v: &Tensor<T>) -> FerrotorchResult<Tensor<T>>
where
F: Fn(&Tensor<T>) -> FerrotorchResult<Tensor<T>>,
{
    // Re-leaf the input with requires_grad = true so reverse mode tracks it.
    let x = Tensor::from_storage(
        TensorStorage::cpu(input.data()?.to_vec()),
        input.shape().to_vec(),
        true,
    )?;
let y = f(&x)?;
let y_data = y.data()?;
let v_data = v.data()?;
if y_data.len() != v_data.len() {
return Err(crate::error::FerrotorchError::ShapeMismatch {
message: format!(
"vjp: f(input) has {} elements but v has {}",
y_data.len(),
v_data.len()
),
});
}
let v_tensor = Tensor::from_storage(
TensorStorage::cpu(v_data.to_vec()),
y.shape().to_vec(),
false,
)?;
    // v^T * J_f(x) = grad_x <f(x), v>: reduce to a scalar, one backward pass.
    let weighted = crate::grad_fns::arithmetic::mul(&y, &v_tensor)?;
    let scalar = crate::grad_fns::reduction::sum(&weighted)?;
    let grads = grad(&scalar, &[&x], false, false)?;
    // Exactly one input was requested, so `grads` has exactly one entry.
    match grads.into_iter().next().unwrap() {
Some(g) => Ok(g),
None => {
let zero_data = vec![<T as num_traits::Zero>::zero(); input.numel()];
Tensor::from_storage(
TensorStorage::cpu(zero_data),
input.shape().to_vec(),
false,
)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::grad_fns::arithmetic::{add, mul, pow};
use crate::grad_fns::reduction::sum;
use crate::storage::TensorStorage;
fn leaf_scalar(val: f32, requires_grad: bool) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(vec![val]), vec![], requires_grad).unwrap()
}
fn leaf_vec(data: &[f32], requires_grad: bool) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(data.to_vec()), vec![data.len()], requires_grad)
.unwrap()
}
fn assert_approx(actual: f32, expected: f32, tol: f32, msg: &str) {
assert!(
(actual - expected).abs() < tol,
"{msg}: expected {expected}, got {actual}"
);
}
#[test]
fn test_gradient_penalty_linear_discriminator() {
let n = 4usize;
let real = leaf_vec(&[1.0, 2.0, 3.0, 4.0], false);
let fake = leaf_vec(&[0.5, 1.5, 2.5, 3.5], false);
let lambda = 10.0;
let penalty = gradient_penalty(|x| sum(x), &real, &fake, lambda).unwrap();
        // grad D = 1 everywhere for D(x) = sum(x), so the norm is sqrt(n).
        let expected = lambda as f32 * ((n as f32).sqrt() - 1.0).powi(2);
assert_approx(
penalty.item().unwrap(),
expected,
1e-3,
"gradient_penalty for linear D(x)=sum(x)",
);
}
#[test]
fn test_gradient_penalty_shape_mismatch() {
let real = leaf_vec(&[1.0, 2.0], false);
let fake = leaf_vec(&[1.0, 2.0, 3.0], false);
let result = gradient_penalty(|x| sum(x), &real, &fake, 10.0);
assert!(result.is_err(), "should error on shape mismatch");
}
#[test]
fn test_gradient_penalty_scalar_input() {
let real = leaf_vec(&[2.0], false);
let fake = leaf_vec(&[2.0], false);
let lambda = 5.0;
let penalty = gradient_penalty(
|x| {
let sq = pow(x, 2.0)?;
sum(&sq)
},
&real,
&fake,
lambda,
)
.unwrap();
        // D(x) = x^2, so grad D(2) = 4 and the penalty is lambda * (4 - 1)^2.
        let expected = lambda as f32 * (4.0f32 - 1.0).powi(2);
assert_approx(
penalty.item().unwrap(),
expected,
1e-2,
"gradient_penalty for D(x)=x^2 at x=2",
);
}
#[test]
fn test_gradient_penalty_has_grad_fn() {
let real = leaf_vec(&[1.0, 2.0], false);
let fake = leaf_vec(&[0.5, 1.5], false);
let penalty = gradient_penalty(|x| sum(x), &real, &fake, 10.0).unwrap();
assert!(
penalty.grad_fn().is_some(),
"gradient_penalty result should have grad_fn for outer optimization"
);
}
#[test]
fn test_grad_norm_simple() {
let x = leaf_vec(&[3.0, 4.0], true);
let y = {
let sq = pow(&x, 2.0).unwrap();
sum(&sq).unwrap()
};
let norm = grad_norm(&y, &[&x]).unwrap();
assert_approx(norm.item().unwrap(), 10.0, 1e-3, "grad_norm of [6,8]");
}
#[test]
fn test_grad_norm_scalar() {
let x = leaf_scalar(2.0, true);
let y = pow(&x, 3.0).unwrap();
let norm = grad_norm(&y, &[&x]).unwrap();
assert_approx(norm.item().unwrap(), 12.0, 1e-3, "grad_norm of scalar");
}
#[test]
fn test_grad_norm_multiple_inputs() {
let x = leaf_scalar(3.0, true);
let y = leaf_scalar(4.0, true);
let x2 = pow(&x, 2.0).unwrap();
let y2 = pow(&y, 2.0).unwrap();
let z = add(&x2, &y2).unwrap();
let norm = grad_norm(&z, &[&x, &y]).unwrap();
assert_approx(norm.item().unwrap(), 10.0, 1e-3, "grad_norm across two inputs");
}
#[test]
fn test_vjp_identity() {
let input = leaf_vec(&[1.0, 2.0, 3.0], false);
let v = leaf_vec(&[4.0, 5.0, 6.0], false);
let result = vjp(
|x| {
let ones = Tensor::from_storage(
TensorStorage::cpu(vec![1.0f32; 3]),
vec![3],
false,
)
.unwrap();
mul(x, &ones)
},
&input,
&v,
)
.unwrap();
let data = result.data().unwrap();
assert_approx(data[0], 4.0, 1e-5, "vjp identity [0]");
assert_approx(data[1], 5.0, 1e-5, "vjp identity [1]");
assert_approx(data[2], 6.0, 1e-5, "vjp identity [2]");
}
#[test]
fn test_vjp_linear_2x() {
let input = leaf_vec(&[1.0, 2.0], false);
let v = leaf_vec(&[3.0, 4.0], false);
let result = vjp(|x| add(x, x), &input, &v).unwrap();
let data = result.data().unwrap();
assert_approx(data[0], 6.0, 1e-5, "vjp 2x [0]");
assert_approx(data[1], 8.0, 1e-5, "vjp 2x [1]");
}
#[test]
fn test_vjp_scalar_mul() {
let input = leaf_vec(&[2.0], false);
let v = leaf_vec(&[5.0], false);
let result = vjp(
|x| {
let c = Tensor::from_storage(TensorStorage::cpu(vec![3.0f32]), vec![1], false)
.unwrap();
mul(x, &c)
},
&input,
&v,
)
.unwrap();
assert_approx(result.data().unwrap()[0], 15.0, 1e-5, "vjp scalar mul");
}
#[test]
fn test_vjp_matches_manual_backward() {
let input = leaf_vec(&[3.0, 4.0], false);
let v = leaf_vec(&[1.0, 1.0], false);
let result = vjp(|x| pow(x, 2.0), &input, &v).unwrap();
let data = result.data().unwrap();
assert_approx(data[0], 6.0, 1e-3, "vjp x^2 [0]");
assert_approx(data[1], 8.0, 1e-3, "vjp x^2 [1]");
}
#[test]
fn test_vjp_shape_mismatch() {
let input = leaf_vec(&[1.0, 2.0], false);
let v = leaf_vec(&[1.0, 2.0, 3.0], false);
let result = vjp(|x| add(x, x), &input, &v);
assert!(result.is_err(), "vjp should error when v shape != f(input) shape");
}
#[test]
fn test_jvp_identity() {
let input = leaf_vec(&[1.0, 2.0, 3.0], false);
let v = leaf_vec(&[4.0, 5.0, 6.0], false);
let result = jvp(
|x| {
let d = x.data().unwrap();
Tensor::from_storage(TensorStorage::cpu(d.to_vec()), x.shape().to_vec(), false)
},
&input,
&v,
)
.unwrap();
let data = result.data().unwrap();
assert_approx(data[0], 4.0, 1e-2, "jvp identity [0]");
assert_approx(data[1], 5.0, 1e-2, "jvp identity [1]");
assert_approx(data[2], 6.0, 1e-2, "jvp identity [2]");
}
#[test]
fn test_jvp_quadratic() {
let input = leaf_vec(&[3.0, 4.0], false);
let v = leaf_vec(&[1.0, 1.0], false);
let result = jvp(
|x| {
let d = x.data().unwrap();
let sq: Vec<f32> = d.iter().map(|&val| val * val).collect();
Tensor::from_storage(TensorStorage::cpu(sq), x.shape().to_vec(), false)
},
&input,
&v,
)
.unwrap();
let data = result.data().unwrap();
assert_approx(data[0], 6.0, 1e-1, "jvp x^2 [0]");
assert_approx(data[1], 8.0, 1e-1, "jvp x^2 [1]");
}
#[test]
fn test_jvp_linear_2x() {
let input = leaf_vec(&[1.0, 2.0], false);
let v = leaf_vec(&[3.0, 4.0], false);
let result = jvp(
|x| {
let d = x.data().unwrap();
let doubled: Vec<f32> = d.iter().map(|&val| val * 2.0).collect();
Tensor::from_storage(TensorStorage::cpu(doubled), x.shape().to_vec(), false)
},
&input,
&v,
)
.unwrap();
let data = result.data().unwrap();
assert_approx(data[0], 6.0, 1e-2, "jvp 2x [0]");
assert_approx(data[1], 8.0, 1e-2, "jvp 2x [1]");
}
#[test]
fn test_jvp_matches_analytical_cubic() {
let input = leaf_vec(&[2.0], false);
let v = leaf_vec(&[1.0], false);
let result = jvp(
|x| {
let d = x.data().unwrap();
let cubed: Vec<f32> = d.iter().map(|&val| val * val * val).collect();
Tensor::from_storage(TensorStorage::cpu(cubed), x.shape().to_vec(), false)
},
&input,
&v,
)
.unwrap();
assert_approx(result.data().unwrap()[0], 12.0, 1e-1, "jvp x^3 at x=2");
}
#[test]
fn test_jvp_shape_mismatch() {
let input = leaf_vec(&[1.0, 2.0], false);
let v = leaf_vec(&[1.0], false);
let result = jvp(
|x| {
let d = x.data().unwrap();
Tensor::from_storage(TensorStorage::cpu(d.to_vec()), x.shape().to_vec(), false)
},
&input,
&v,
);
assert!(result.is_err(), "jvp should error on shape mismatch");
}
#[test]
fn test_gradient_penalty_create_graph_outer_loop() {
let real = leaf_vec(&[1.0, 2.0], false);
let fake = leaf_vec(&[0.5, 1.5], false);
let penalty = gradient_penalty(|x| sum(x), &real, &fake, 10.0).unwrap();
assert!(
penalty.grad_fn().is_some(),
"penalty must have grad_fn for outer-loop optimization"
);
assert!(
penalty.requires_grad(),
"penalty must require grad for outer-loop optimization"
);
}
}