use crate::tensor::Tensor;
use crate::tensor_ops;
use crate::Float;
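/// Strategy for clipping a set of gradient tensors before an optimizer
/// applies them.
///
/// Implementations build the clipped gradients symbolically in the tensor
/// graph, so clipping decisions (and the statistics reported by
/// [`get_clipping_stats`](GradientClipper::get_clipping_stats)) are only
/// fully known once the graph is evaluated.
///
/// A minimal usage sketch (hypothetical setup: `gradients` is assumed to be
/// a slice of gradient tensors already obtained from the graph):
///
/// ```ignore
/// let mut clipper = ClipByGlobalNorm::new(1.0f32);
/// let clipped = clipper.clip_gradients(&gradients);
/// // feed `clipped` to the optimizer in place of the raw gradients
/// ```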
pub trait GradientClipper<F: Float> {
fn clip_gradients<'g>(&mut self, gradients: &[Tensor<'g, F>]) -> Vec<Tensor<'g, F>>;
fn was_clipped(&self) -> bool {
false
}
fn get_clipping_stats(&self) -> ClippingStats<F> {
ClippingStats::default()
}
}
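/// Summary of the most recent call to [`GradientClipper::clip_gradients`].
///
/// Norm and scaling fields are `None` when they cannot be determined at
/// graph-construction time.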
#[derive(Debug, Clone)]
pub struct ClippingStats<F: Float> {
pub was_clipped: bool,
pub original_norm: Option<F>,
pub clipped_norm: Option<F>,
pub clipping_factor: Option<F>,
pub num_clipped: usize,
pub total_gradients: usize,
}
impl<F: Float> Default for ClippingStats<F> {
fn default() -> Self {
Self {
was_clipped: false,
original_norm: None,
clipped_norm: None,
clipping_factor: None,
num_clipped: 0,
total_gradients: 0,
}
}
}
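/// Clips every gradient element into the closed interval
/// `[min_value, max_value]`.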
pub struct ClipByValue<F: Float> {
pub min_value: F,
pub max_value: F,
last_clipped: std::cell::Cell<bool>,
}
impl<F: Float> ClipByValue<F> {
pub fn new(min_value: F, max_value: F) -> Self {
assert!(
min_value < max_value,
"min_value must be less than max_value"
);
Self {
min_value,
max_value,
last_clipped: std::cell::Cell::new(false),
}
}
pub fn symmetric(max_abs_value: F) -> Self {
Self::new(-max_abs_value, max_abs_value)
}
}
impl<F: Float> GradientClipper<F> for ClipByValue<F> {
fn clip_gradients<'g>(&mut self, gradients: &[Tensor<'g, F>]) -> Vec<Tensor<'g, F>> {
// Clamp every element into [min_value, max_value] symbolically.
let clipped: Vec<_> = gradients
.iter()
.map(|grad| tensor_ops::clip(*grad, self.min_value, self.max_value))
.collect();
// The clip is built into the lazy graph, so whether any element was actually
// clamped is unknown until evaluation; report `false` conservatively.
self.last_clipped.set(false);
clipped
}
fn was_clipped(&self) -> bool {
self.last_clipped.get()
}
}
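/// Clips each gradient tensor independently so that its Frobenius norm does
/// not exceed `max_norm`:
///
/// `clipped = grad * min(1, max_norm / ||grad||_F)`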
pub struct ClipByNorm<F: Float> {
pub max_norm: F,
last_clipped: std::cell::Cell<bool>,
last_stats: std::cell::RefCell<ClippingStats<F>>,
}
impl<F: Float> ClipByNorm<F> {
pub fn new(max_norm: F) -> Self {
assert!(max_norm > F::zero(), "max_norm must be positive");
Self {
max_norm,
last_clipped: std::cell::Cell::new(false),
last_stats: std::cell::RefCell::new(ClippingStats::default()),
}
}
}
impl<F: Float> GradientClipper<F> for ClipByNorm<F> {
fn clip_gradients<'g>(&mut self, gradients: &[Tensor<'g, F>]) -> Vec<Tensor<'g, F>> {
let clipped: Vec<_> = gradients
.iter()
.map(|grad| {
// clipped = grad * min(1, max_norm / ||grad||_F)
let grad_norm = tensor_ops::frobenius_norm(grad);
let max_norm_tensor = tensor_ops::scalar(self.max_norm, grad.graph());
let one_tensor = tensor_ops::scalar(F::one(), grad.graph());
let ratio = max_norm_tensor / grad_norm;
let clipping_factor = tensor_ops::minimum(one_tensor, ratio);
(*grad) * clipping_factor
})
.collect();
// The norms are evaluated lazily, so whether clipping occurred cannot be
// decided here; record conservative statistics and leave the norm fields
// as `None`.
self.last_clipped.set(false);
let mut stats = self.last_stats.borrow_mut();
stats.was_clipped = false;
stats.num_clipped = 0;
stats.total_gradients = gradients.len();
clipped
}
fn was_clipped(&self) -> bool {
self.last_clipped.get()
}
fn get_clipping_stats(&self) -> ClippingStats<F> {
self.last_stats.borrow().clone()
}
}
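/// Clips all gradients jointly by the global norm
/// `||g|| = sqrt(sum_i ||grad_i||_F^2)`, scaling every gradient by the same
/// factor `min(1, max_norm / ||g||)` so their relative magnitudes are
/// preserved.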
pub struct ClipByGlobalNorm<F: Float> {
pub max_norm: F,
last_clipped: std::cell::Cell<bool>,
last_stats: std::cell::RefCell<ClippingStats<F>>,
}
impl<F: Float> ClipByGlobalNorm<F> {
pub fn new(max_norm: F) -> Self {
assert!(max_norm > F::zero(), "max_norm must be positive");
Self {
max_norm,
last_clipped: std::cell::Cell::new(false),
last_stats: std::cell::RefCell::new(ClippingStats::default()),
}
}
}
impl<F: Float> GradientClipper<F> for ClipByGlobalNorm<F> {
fn clip_gradients<'g>(&mut self, gradients: &[Tensor<'g, F>]) -> Vec<Tensor<'g, F>> {
if gradients.is_empty() {
return Vec::new();
}
let g = gradients[0].graph();
// global_norm = sqrt(sum_i ||grad_i||_F^2)
let squared_norms: Vec<_> = gradients
.iter()
.map(|grad| {
let norm = tensor_ops::frobenius_norm(grad);
tensor_ops::square(norm)
})
.collect();
let global_norm_squared = tensor_ops::add_n(&squared_norms);
let global_norm = tensor_ops::sqrt(global_norm_squared);
// Every gradient is scaled by the same factor min(1, max_norm / global_norm),
// which preserves the relative magnitudes of the individual gradients.
let max_norm_tensor = tensor_ops::scalar(self.max_norm, g);
let one_tensor = tensor_ops::scalar(F::one(), g);
let ratio = max_norm_tensor / global_norm;
let clipping_factor = tensor_ops::minimum(one_tensor, ratio);
let clipped: Vec<_> = gradients
.iter()
.map(|grad| (*grad) * clipping_factor)
.collect();
// The global norm is only known after evaluation; report conservative
// statistics here.
self.last_clipped.set(false);
let mut stats = self.last_stats.borrow_mut();
stats.was_clipped = false;
stats.num_clipped = 0;
stats.total_gradients = gradients.len();
clipped
}
fn was_clipped(&self) -> bool {
self.last_clipped.get()
}
fn get_clipping_stats(&self) -> ClippingStats<F> {
self.last_stats.borrow().clone()
}
}
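/// Norm-based clipper whose threshold can be adjusted between steps.
///
/// `adaptation_rate` is stored for use by an external adaptation scheme; the
/// clipper itself does not change the threshold automatically, and callers
/// update it through [`set_threshold`](AdaptiveClipByNorm::set_threshold).
/// One possible scheme (a sketch only, assuming the caller can observe an
/// evaluated gradient norm from the previous step) is an exponential moving
/// average:
///
/// ```ignore
/// // Hypothetical adaptation step driven by an observed norm value.
/// let observed_norm = 3.2f32;   // evaluated outside this module
/// let rate = 0.1f32;            // plays the role of `adaptation_rate`
/// let new_threshold = (1.0 - rate) * clipper.current_threshold() + rate * observed_norm;
/// clipper.set_threshold(new_threshold);
/// ```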
pub struct AdaptiveClipByNorm<F: Float> {
base_clipper: ClipByNorm<F>,
#[allow(dead_code)]
adaptation_rate: F,
current_threshold: std::cell::Cell<F>,
}
impl<F: Float> AdaptiveClipByNorm<F> {
pub fn new(initial_max_norm: F, adaptation_rate: F) -> Self {
assert!(
adaptation_rate >= F::zero() && adaptation_rate <= F::one(),
"adaptation_rate must be between 0.0 and 1.0"
);
Self {
base_clipper: ClipByNorm::new(initial_max_norm),
adaptation_rate,
current_threshold: std::cell::Cell::new(initial_max_norm),
}
}
pub fn current_threshold(&self) -> F {
self.current_threshold.get()
}
pub fn set_threshold(&self, new_threshold: F) {
assert!(new_threshold > F::zero(), "threshold must be positive");
self.current_threshold.set(new_threshold);
}
}
impl<F: Float> GradientClipper<F> for AdaptiveClipByNorm<F> {
fn clip_gradients<'g>(&mut self, gradients: &[Tensor<'g, F>]) -> Vec<Tensor<'g, F>> {
// Delegate to the per-tensor norm clipper using the current threshold.
self.base_clipper.max_norm = self.current_threshold.get();
self.base_clipper.clip_gradients(gradients)
}
fn was_clipped(&self) -> bool {
self.base_clipper.was_clipped()
}
fn get_clipping_stats(&self) -> ClippingStats<F> {
self.base_clipper.get_clipping_stats()
}
}
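/// Convenience clipping methods exposed directly on tensors.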
impl<F: Float> Tensor<'_, F> {
pub fn clip_values(self, min_value: F, max_value: F) -> Self {
tensor_ops::clip(self, min_value, max_value)
}
pub fn clip_norm(self, max_norm: F) -> Self {
let norm = tensor_ops::frobenius_norm(self);
let max_norm_tensor = tensor_ops::scalar(max_norm, self.graph());
let one_tensor = tensor_ops::scalar(F::one(), self.graph());
let ratio = max_norm_tensor / norm;
let clipping_factor = tensor_ops::minimum(one_tensor, ratio);
self * clipping_factor
}
}
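/// Ready-made clippers with commonly used thresholds.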
pub mod presets {
use super::*;
pub fn conservative<F: Float>() -> ClipByGlobalNorm<F> {
ClipByGlobalNorm::new(F::from(0.5).expect("Failed to convert constant to float"))
}
pub fn standard<F: Float>() -> ClipByGlobalNorm<F> {
ClipByGlobalNorm::new(F::from(1.0).expect("Failed to convert constant to float"))
}
pub fn aggressive<F: Float>() -> ClipByGlobalNorm<F> {
ClipByGlobalNorm::new(F::from(0.1).expect("Failed to convert constant to float"))
}
pub fn extreme_prevention<F: Float>() -> ClipByValue<F> {
ClipByValue::symmetric(F::from(10.0).expect("Failed to convert constant to float"))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_clip_by_value_creation() {
let clipper = ClipByValue::new(-1.0f32, 1.0f32);
assert_eq!(clipper.min_value, -1.0);
assert_eq!(clipper.max_value, 1.0);
let symmetric = ClipByValue::symmetric(0.5f32);
assert_eq!(symmetric.min_value, -0.5);
assert_eq!(symmetric.max_value, 0.5);
}
#[test]
fn test_clip_by_norm_creation() {
let clipper = ClipByNorm::new(1.0f32);
assert_eq!(clipper.max_norm, 1.0);
}
#[test]
fn test_clip_by_global_norm_creation() {
let clipper = ClipByGlobalNorm::new(1.0f32);
assert_eq!(clipper.max_norm, 1.0);
}
#[test]
fn test_adaptive_clipper() {
let clipper = AdaptiveClipByNorm::new(1.0f32, 0.1);
assert_eq!(clipper.current_threshold(), 1.0);
clipper.set_threshold(0.5);
assert_eq!(clipper.current_threshold(), 0.5);
}
#[test]
fn test_clipping_stats_default() {
let stats = ClippingStats::<f32>::default();
assert!(!stats.was_clipped);
assert_eq!(stats.num_clipped, 0);
assert_eq!(stats.total_gradients, 0);
}
#[test]
fn test_presets() {
assert_eq!(presets::conservative::<f32>().max_norm, 0.5);
assert_eq!(presets::standard::<f32>().max_norm, 1.0);
assert_eq!(presets::aggressive::<f32>().max_norm, 0.1);
let extreme = presets::extreme_prevention::<f32>();
assert_eq!(extreme.min_value, -10.0);
assert_eq!(extreme.max_value, 10.0);
}
#[test]
#[should_panic(expected = "min_value must be less than max_value")]
fn test_clip_by_value_invalid_range() {
ClipByValue::new(1.0f32, -1.0f32);
}
#[test]
#[should_panic(expected = "max_norm must be positive")]
fn test_clip_by_norm_negative_norm() {
ClipByNorm::new(-1.0f32);
}
#[test]
#[should_panic(expected = "adaptation_rate must be between 0.0 and 1.0")]
fn test_adaptive_clipper_invalid_rate() {
AdaptiveClipByNorm::new(1.0f32, 2.0);
}
}