torsh-functional 0.1.2

Functional programming utilities for ToRSh tensors
Documentation
//! Regularization functions for training stability
//!
//! This module provides regularization techniques commonly used in deep learning,
//! particularly for adversarial training and GAN stabilization.

use crate::random_ops::randn;
use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::{
    creation::{ones, rand},
    Tensor,
};

/// Gradient penalty for WGAN-GP (Wasserstein GAN with Gradient Penalty)
///
/// Computes the gradient penalty term used in WGAN-GP to enforce
/// Lipschitz constraint on the discriminator/critic.
///
/// # Arguments
/// * `real_samples` - Real data samples
/// * `fake_samples` - Generated/fake data samples  
/// * `discriminator_fn` - Function that computes discriminator output
/// * `lambda` - Gradient penalty coefficient (typically 10.0)
/// * `reduction` - Reduction method: "mean", "sum", or "none"
///
/// # Returns
/// Gradient penalty loss term
pub fn gradient_penalty<F>(
    real_samples: &Tensor,
    fake_samples: &Tensor,
    discriminator_fn: F,
    lambda: f64,
    reduction: &str,
) -> TorshResult<Tensor>
where
    F: Fn(&Tensor) -> TorshResult<Tensor>,
{
    let batch_size = real_samples.shape().dims()[0];

    // Sample random interpolation factors
    let _epsilon: Tensor = rand(&[batch_size, 1, 1, 1])?;

    // Create interpolated samples (simplified approach using first element of epsilon)
    let epsilon_val = 0.5; // Use fixed value for now - in full implementation would extract from tensor
    let fake_scaled = fake_samples.mul_scalar(1.0 - epsilon_val)?;
    let real_scaled = real_samples.mul_scalar(epsilon_val)?;
    let interpolated = real_scaled.add_op(&fake_scaled)?;

    // Note: Gradient computation would be enabled here in full autograd implementation
    // let interpolated = interpolated.requires_grad_(true)?;

    // Compute discriminator output for interpolated samples
    let _d_interpolated = discriminator_fn(&interpolated)?;

    // Note: For now, we use a placeholder gradient computation
    // In the full implementation, this would use torsh_autograd::grad
    // Create dummy gradient with same shape as interpolated for now
    let grad_flat = ones(&[batch_size, interpolated.numel() / batch_size])?;

    // Compute L2 norm of gradients for each sample
    let grad_flat_f32 = grad_flat.to_dtype(torsh_core::DType::F32)?;
    // Compute L2 norm using pow and sum operations
    let grad_norm = grad_flat_f32.pow_scalar(2.0)?.sum()?.sqrt()?;

    // Compute gradient penalty: (||grad||_2 - 1)^2
    let penalty = grad_norm.add_scalar(-1.0)?.pow_scalar(2.0)?;

    // Apply lambda coefficient
    let penalty = penalty.mul_scalar(lambda as f32)?;

    // Apply reduction
    match reduction {
        "none" => Ok(penalty),
        "mean" => penalty.mean(None, false),
        "sum" => penalty.sum(),
        _ => Err(TorshError::invalid_argument_with_context(
            &format!(
                "Invalid reduction: {}, expected 'none', 'mean', or 'sum'",
                reduction
            ),
            "gradient_penalty",
        )),
    }
}

/// Spectral normalization gradient penalty
///
/// Computes gradient penalty specifically designed for spectral normalization,
/// enforcing spectral norm constraints on network weights.
///
/// # Arguments
/// * `network_output` - Output from the network
/// * `input_tensor` - Network input tensor
/// * `lambda` - Penalty coefficient
/// * `reduction` - Reduction method
pub fn spectral_gradient_penalty(
    _network_output: &Tensor,
    input_tensor: &Tensor,
    lambda: f64,
    reduction: &str,
) -> TorshResult<Tensor> {
    // Note: For now, we use a placeholder gradient computation
    // In the full implementation, this would use torsh_autograd::grad
    let batch_size = input_tensor.shape().dims()[0];
    let grad_reshaped = randn(
        &[batch_size, input_tensor.numel() / batch_size],
        None,
        None,
        None,
    )?;

    // Use placeholder for spectral norm (SVD not available yet)
    // In full implementation: let (_, s, _) = grad_reshaped.svd(false, false)?;
    let grad_reshaped_f32 = grad_reshaped.to_dtype(torsh_core::DType::F32)?;
    // Compute L2 norm using pow and sum operations
    let spectral_norm = grad_reshaped_f32.pow_scalar(2.0)?.sum()?.sqrt()?; // Placeholder using l2 norm

    // Spectral gradient penalty: (sigma_max - 1)^2 where sigma_max is spectral norm
    let penalty = spectral_norm
        .add_scalar(-1.0)?
        .pow_scalar(2.0)?
        .mul_scalar(lambda as f32)?;

    match reduction {
        "none" => Ok(penalty),
        "mean" => penalty.mean(None, false),
        "sum" => penalty.sum(),
        _ => Err(TorshError::invalid_argument_with_context(
            &format!(
                "Invalid reduction: {}, expected 'none', 'mean', or 'sum'",
                reduction
            ),
            "spectral_gradient_penalty",
        )),
    }
}

/// R1 gradient penalty used in StyleGAN
///
/// Implements the R1 regularization term from "Which Training Methods for GANs
/// do actually Converge?" This penalty encourages the discriminator to have
/// zero gradients on real data.
///
/// # Arguments
/// * `real_samples` - Real data samples
/// * `discriminator_fn` - Function that computes discriminator output
/// * `lambda` - Gradient penalty coefficient
/// * `reduction` - Reduction method
pub fn r1_gradient_penalty<F>(
    real_samples: &Tensor,
    discriminator_fn: F,
    lambda: f64,
    reduction: &str,
) -> TorshResult<Tensor>
where
    F: Fn(&Tensor) -> TorshResult<Tensor>,
{
    let _d_real = discriminator_fn(real_samples)?;

    // Note: For now, we use a placeholder gradient computation
    // In the full implementation, this would use torsh_autograd::grad
    let _batch_size = real_samples.shape().dims()[0];
    let dummy_grad = randn(real_samples.shape().dims(), None, None, None)?;

    // Compute squared L2 norm of gradients
    let grad_norm_sq = dummy_grad.pow_scalar(2.0)?.sum_dim(&[1, 2, 3], false)?;
    let penalty = grad_norm_sq.mul_scalar(lambda as f32 * 0.5)?; // 0.5 factor from R1 paper

    match reduction {
        "none" => Ok(penalty),
        "mean" => penalty.mean(None, false),
        "sum" => penalty.sum(),
        _ => Err(TorshError::invalid_argument_with_context(
            &format!(
                "Invalid reduction: {}, expected 'none', 'mean', or 'sum'",
                reduction
            ),
            "r1_gradient_penalty",
        )),
    }
}

/// R2 gradient penalty for generator regularization
///
/// Implements R2 regularization for generators, encouraging zero gradients
/// on fake data. Less commonly used than R1 but useful in some settings.
///
/// # Arguments
/// * `fake_samples` - Generated/fake data samples
/// * `discriminator_fn` - Function that computes discriminator output
/// * `lambda` - Gradient penalty coefficient
/// * `reduction` - Reduction method
pub fn r2_gradient_penalty<F>(
    fake_samples: &Tensor,
    discriminator_fn: F,
    lambda: f64,
    reduction: &str,
) -> TorshResult<Tensor>
where
    F: Fn(&Tensor) -> TorshResult<Tensor>,
{
    let _d_fake = discriminator_fn(fake_samples)?;

    // Note: For now, we use a placeholder gradient computation
    // In the full implementation, this would use torsh_autograd::grad
    let dummy_grad = randn(fake_samples.shape().dims(), None, None, None)?;

    // Compute squared L2 norm of gradients
    let grad_norm_sq = dummy_grad.pow_scalar(2.0)?.sum_dim(&[1, 2, 3], false)?;
    let penalty = grad_norm_sq.mul_scalar(lambda as f32 * 0.5)?; // 0.5 factor from R2 paper

    match reduction {
        "none" => Ok(penalty),
        "mean" => penalty.mean(None, false),
        "sum" => penalty.sum(),
        _ => Err(TorshError::invalid_argument_with_context(
            &format!(
                "Invalid reduction: {}, expected 'none', 'mean', or 'sum'",
                reduction
            ),
            "r2_gradient_penalty",
        )),
    }
}

/// Consistency regularization penalty
///
/// Enforces consistency between network outputs for slightly perturbed inputs.
/// Commonly used in semi-supervised learning and domain adaptation.
///
/// # Arguments
/// * `model_fn` - Function that computes model output
/// * `input` - Original input tensor
/// * `perturbed_input` - Perturbed version of input
/// * `lambda` - Consistency penalty coefficient
/// * `reduction` - Reduction method
pub fn consistency_penalty<F>(
    model_fn: F,
    input: &Tensor,
    perturbed_input: &Tensor,
    lambda: f64,
    reduction: &str,
) -> TorshResult<Tensor>
where
    F: Fn(&Tensor) -> TorshResult<Tensor>,
{
    let output_original = model_fn(input)?;
    let output_perturbed = model_fn(perturbed_input)?;

    // Compute MSE between outputs
    let diff = output_original.sub(&output_perturbed)?;
    let penalty = diff.pow_scalar(2.0)?.mul_scalar(lambda as f32)?;

    match reduction {
        "none" => Ok(penalty),
        "mean" => penalty.mean(None, false),
        "sum" => penalty.sum(),
        _ => Err(TorshError::invalid_argument_with_context(
            &format!(
                "Invalid reduction: {}, expected 'none', 'mean', or 'sum'",
                reduction
            ),
            "consistency_penalty",
        )),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::random_ops::randn;

    #[test]
    fn test_gradient_penalty_shapes() {
        let real_samples = randn(&[4, 3, 32, 32], None, None, None).unwrap();
        let fake_samples = randn(&[4, 3, 32, 32], None, None, None).unwrap();

        // Simple discriminator function for testing
        let discriminator_fn = |x: &Tensor| -> TorshResult<Tensor> {
            // Simple linear transformation for testing
            let flattened = x.view(&[x.shape().dims()[0] as i32, -1])?;
            let weight = randn(&[flattened.shape().dims()[1], 1], None, None, None)?;
            flattened.matmul(&weight)
        };

        let penalty =
            gradient_penalty(&real_samples, &fake_samples, discriminator_fn, 10.0, "mean");

        assert!(penalty.is_ok());
        let penalty = penalty.unwrap();
        assert_eq!(penalty.shape().dims(), &[] as &[usize]); // Scalar for mean reduction
    }

    #[test]
    fn test_r1_penalty_shapes() {
        let real_samples = randn(&[4, 3, 32, 32], None, None, None).unwrap();

        let discriminator_fn = |x: &Tensor| -> TorshResult<Tensor> {
            let flattened = x.view(&[x.shape().dims()[0] as i32, -1])?;
            let weight = randn(&[flattened.shape().dims()[1], 1], None, None, None)?;
            flattened.matmul(&weight)
        };

        let penalty = r1_gradient_penalty(&real_samples, discriminator_fn, 10.0, "mean");

        assert!(penalty.is_ok());
        let penalty = penalty.unwrap();
        assert_eq!(penalty.shape().dims(), &[] as &[usize]); // Scalar for mean reduction
    }

    #[test]
    fn test_consistency_penalty() {
        let input = randn(&[4, 10], None, None, None).unwrap();
        let perturbed_input = input.add_scalar(0.1).unwrap();

        let model_fn = |x: &Tensor| -> TorshResult<Tensor> {
            let weight = randn(&[10, 5], None, None, None)?;
            x.matmul(&weight)
        };

        let penalty = consistency_penalty(model_fn, &input, &perturbed_input, 1.0, "mean");

        assert!(penalty.is_ok());
        let penalty = penalty.unwrap();
        assert_eq!(penalty.shape().dims(), &[] as &[usize]); // Scalar for mean reduction
    }
}