use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::tensor::Tensor;
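/// Master switches and nested settings for the gradient-processing
/// pipeline. Every transform is disabled by default (`Default` derives
/// `false` for all the `enable_*` flags).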
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct GradientProcessingConfig {
pub enable_centralization: bool,
pub enable_standardization: bool,
pub enable_adaptive_clipping: bool,
pub enable_noise_injection: bool,
pub enable_smoothing: bool,
pub enable_hessian_preconditioning: bool,
pub adaptive_clipping: AdaptiveClippingConfig,
pub noise_injection: NoiseInjectionConfig,
pub smoothing: SmoothingConfig,
pub hessian_preconditioning: HessianPreconditioningConfig,
}
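/// Adaptive clipping: rather than a fixed threshold, the clip norm chases
/// the `target_percentile` of the last `history_window` gradient norms.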
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveClippingConfig {
pub initial_clip_norm: f32,
pub min_clip_norm: f32,
pub max_clip_norm: f32,
pub adaptation_rate: f32,
pub target_percentile: f32,
pub history_window: usize,
}
impl Default for AdaptiveClippingConfig {
fn default() -> Self {
Self {
initial_clip_norm: 1.0,
min_clip_norm: 0.1,
max_clip_norm: 10.0,
adaptation_rate: 0.01,
target_percentile: 0.9,
history_window: 100,
}
}
}
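/// Annealed noise injection: the scale decays geometrically by
/// `decay_rate` each step and is floored at `min_noise_scale`.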
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NoiseInjectionConfig {
pub initial_noise_scale: f32,
pub decay_rate: f32,
pub min_noise_scale: f32,
pub noise_type: NoiseType,
}
impl Default for NoiseInjectionConfig {
fn default() -> Self {
Self {
initial_noise_scale: 0.1,
decay_rate: 0.999,
min_noise_scale: 1e-6,
noise_type: NoiseType::Gaussian,
}
}
}
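/// Exponential moving-average smoothing of gradients, optionally with
/// Adam-style bias correction (`debias`).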
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SmoothingConfig {
pub decay: f32,
pub debias: bool,
}
impl Default for SmoothingConfig {
fn default() -> Self {
Self {
decay: 0.9,
debias: true,
}
}
}
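/// Diagonal curvature preconditioning. `damping` is added to every
/// estimate and `min_eigenvalue` floors the diagonal before division;
/// `max_condition_number` is carried in the config but not enforced in
/// this module.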
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HessianPreconditioningConfig {
pub approximation_type: HessianApproximationType,
pub damping: f32,
pub update_frequency: usize,
pub history_window: usize,
pub min_eigenvalue: f32,
pub max_condition_number: f32,
}
impl Default for HessianPreconditioningConfig {
fn default() -> Self {
Self {
approximation_type: HessianApproximationType::Diagonal,
damping: 1e-4,
update_frequency: 10,
history_window: 20,
min_eigenvalue: 1e-8,
max_condition_number: 1e6,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NoiseType {
Gaussian,
Uniform,
Laplace,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HessianApproximationType {
Diagonal,
GaussNewton,
FisherInformation,
QuasiNewton,
}
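/// Stateful gradient post-processor meant to run between the backward pass
/// and the optimizer update. Enabled transforms are applied in a fixed
/// order by [`Self::process_gradients`].
///
/// A minimal usage sketch (shapes are illustrative; `Tensor::randn` is the
/// same constructor this module already uses):
///
/// ```ignore
/// let mut processor = GradientProcessor::with_defaults();
/// let mut grads = vec![Tensor::randn(&[16, 16])?];
/// processor.process_gradients(&mut grads)?;
/// ```
///
/// The `hessian_inverse` map is currently never populated; it appears to
/// be state reserved for future full-matrix preconditioners.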
#[derive(Debug)]
pub struct GradientProcessor {
config: GradientProcessingConfig,
current_step: usize,
gradient_norm_history: Vec<f32>,
current_clip_norm: f32,
current_noise_scale: f32,
smoothed_gradients: HashMap<usize, Tensor>,
smoothing_bias_correction: f32,
hessian_diagonal: HashMap<usize, Tensor>,
hessian_inverse: HashMap<usize, Tensor>,
last_hessian_update: usize,
gradient_history: Vec<Vec<Tensor>>,
}
impl GradientProcessor {
pub fn new(config: GradientProcessingConfig) -> Self {
Self {
current_clip_norm: config.adaptive_clipping.initial_clip_norm,
current_noise_scale: config.noise_injection.initial_noise_scale,
config,
current_step: 0,
gradient_norm_history: Vec::new(),
smoothed_gradients: HashMap::new(),
smoothing_bias_correction: 1.0,
hessian_diagonal: HashMap::new(),
hessian_inverse: HashMap::new(),
last_hessian_update: 0,
gradient_history: Vec::new(),
}
}
pub fn with_defaults() -> Self {
Self::new(GradientProcessingConfig::default())
}
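    /// Applies all enabled transforms to `gradients` in place and advances
    /// the step counter. Order matters: centralization and standardization
    /// run first, then smoothing and Hessian preconditioning; clipping sees
    /// the preconditioned result, and noise is injected last so it is not
    /// clipped away.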
pub fn process_gradients(&mut self, gradients: &mut [Tensor]) -> Result<()> {
self.current_step += 1;
if self.config.enable_centralization {
self.apply_centralization(gradients)?;
}
if self.config.enable_standardization {
self.apply_standardization(gradients)?;
}
if self.config.enable_smoothing {
self.apply_smoothing(gradients)?;
}
if self.config.enable_hessian_preconditioning {
self.apply_hessian_preconditioning(gradients)?;
}
if self.config.enable_adaptive_clipping {
self.apply_adaptive_clipping(gradients)?;
}
if self.config.enable_noise_injection {
self.apply_noise_injection(gradients)?;
}
Ok(())
}
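    /// Subtracts the per-tensor mean, in the spirit of gradient
    /// centralization (Yong et al., 2020), though applied over the whole
    /// tensor rather than per output channel.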
fn apply_centralization(&self, gradients: &mut [Tensor]) -> Result<()> {
for gradient in gradients.iter_mut() {
let mean = gradient.mean()?;
*gradient = gradient.sub(&mean)?;
}
Ok(())
}
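    /// Rescales each gradient to zero mean and approximately unit
    /// variance, with an epsilon guard against zero variance.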
fn apply_standardization(&self, gradients: &mut [Tensor]) -> Result<()> {
for gradient in gradients.iter_mut() {
let mean = gradient.mean()?;
let centered = gradient.sub(&mean)?;
            let squared = centered.mul(&centered)?;
let variance = squared.mean()?;
let std_dev = variance.sqrt()?;
let epsilon = Tensor::scalar(1e-8)?;
let std_dev_safe = std_dev.add(&epsilon)?;
            // Divide the centered gradient so the result is standardized
            // (zero mean), not merely variance-scaled.
            *gradient = centered.div(&std_dev_safe)?;
}
Ok(())
}
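    /// Global-norm clipping with a moving threshold: once at least 10
    /// norms have been recorded, the clip norm is nudged toward the
    /// configured percentile of the recent norm history at
    /// `adaptation_rate`.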
fn apply_adaptive_clipping(&mut self, gradients: &mut [Tensor]) -> Result<()> {
let mut total_norm_sq = 0.0;
for gradient in gradients.iter() {
let norm_sq = gradient.norm_squared()?.to_scalar()?;
total_norm_sq += norm_sq;
}
let total_norm = total_norm_sq.sqrt();
self.gradient_norm_history.push(total_norm);
if self.gradient_norm_history.len() > self.config.adaptive_clipping.history_window {
self.gradient_norm_history.remove(0);
}
if self.gradient_norm_history.len() >= 10 {
let mut sorted_norms = self.gradient_norm_history.clone();
sorted_norms.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let percentile_idx = (sorted_norms.len() as f32
* self.config.adaptive_clipping.target_percentile)
as usize;
let target_norm = sorted_norms[percentile_idx.min(sorted_norms.len() - 1)];
let adaptation = self.config.adaptive_clipping.adaptation_rate
* (target_norm - self.current_clip_norm);
self.current_clip_norm += adaptation;
self.current_clip_norm = self
.current_clip_norm
.max(self.config.adaptive_clipping.min_clip_norm)
.min(self.config.adaptive_clipping.max_clip_norm);
}
if total_norm > self.current_clip_norm {
let clip_factor = self.current_clip_norm / total_norm;
for gradient in gradients.iter_mut() {
*gradient = gradient.mul_scalar(clip_factor)?;
}
}
Ok(())
}
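    /// Adds annealed noise to every gradient; see `NoiseInjectionConfig`
    /// for the decay schedule.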
fn apply_noise_injection(&mut self, gradients: &mut [Tensor]) -> Result<()> {
self.current_noise_scale *= self.config.noise_injection.decay_rate;
self.current_noise_scale =
self.current_noise_scale.max(self.config.noise_injection.min_noise_scale);
for gradient in gradients.iter_mut() {
            // Only `Tensor::randn` (a Gaussian sampler) is assumed on `Tensor`,
            // so the Uniform and Laplace variants are approximated by scaled
            // Gaussian draws rather than exact samples.
            let noise = match self.config.noise_injection.noise_type {
                NoiseType::Gaussian => {
                    Tensor::randn(&gradient.shape())?.mul_scalar(self.current_noise_scale)?
                },
                NoiseType::Uniform => {
                    // A uniform on [-b, b] has std b / sqrt(3); b = scale * sqrt(3)
                    // keeps the original amplitude convention.
                    let bound = self.current_noise_scale * 3.0_f32.sqrt();
                    Tensor::randn(&gradient.shape())?.mul_scalar(bound)?
                },
                NoiseType::Laplace => {
                    // A Laplace with scale s has std s * sqrt(2); match that std.
                    Tensor::randn(&gradient.shape())?
                        .mul_scalar(self.current_noise_scale * 2.0_f32.sqrt())?
                },
            };
*gradient = gradient.add(&noise)?;
}
Ok(())
}
    fn apply_smoothing(&mut self, gradients: &mut [Tensor]) -> Result<()> {
        let decay = self.config.smoothing.decay;
        // Track decay^t once per step; updating it per tensor would
        // over-correct whenever there is more than one gradient.
        self.smoothing_bias_correction *= decay;
        for (i, gradient) in gradients.iter_mut().enumerate() {
            // EMA update: m_t = decay * m_{t-1} + (1 - decay) * g_t, with an
            // implicit zero initialization that the debiasing compensates for.
            let new_smoothed = match self.smoothed_gradients.get(&i) {
                Some(smoothed) => {
                    smoothed.mul_scalar(decay)?.add(&gradient.mul_scalar(1.0 - decay)?)?
                },
                None => gradient.mul_scalar(1.0 - decay)?,
            };
            self.smoothed_gradients.insert(i, new_smoothed.clone());
            if self.config.smoothing.debias {
                *gradient = new_smoothed.div_scalar(1.0 - self.smoothing_bias_correction)?;
            } else {
                *gradient = new_smoothed;
            }
        }
        Ok(())
    }
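    /// Records the current gradients, refreshes the curvature estimate
    /// every `update_frequency` steps, and then divides the gradients by
    /// the damped diagonal estimate.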
fn apply_hessian_preconditioning(&mut self, gradients: &mut [Tensor]) -> Result<()> {
self.gradient_history.push(gradients.to_vec());
if self.gradient_history.len() > self.config.hessian_preconditioning.history_window {
self.gradient_history.remove(0);
}
if self.current_step - self.last_hessian_update
>= self.config.hessian_preconditioning.update_frequency
{
self.update_hessian_approximation(gradients)?;
self.last_hessian_update = self.current_step;
}
match self.config.hessian_preconditioning.approximation_type {
HessianApproximationType::Diagonal => {
self.apply_diagonal_preconditioning(gradients)?;
},
HessianApproximationType::GaussNewton => {
self.apply_gauss_newton_preconditioning(gradients)?;
},
HessianApproximationType::FisherInformation => {
self.apply_fisher_information_preconditioning(gradients)?;
},
HessianApproximationType::QuasiNewton => {
self.apply_quasi_newton_preconditioning(gradients)?;
},
}
Ok(())
}
fn update_hessian_approximation(&mut self, gradients: &[Tensor]) -> Result<()> {
match self.config.hessian_preconditioning.approximation_type {
HessianApproximationType::Diagonal => {
self.update_diagonal_hessian(gradients)?;
},
HessianApproximationType::GaussNewton => {
self.update_gauss_newton_hessian(gradients)?;
},
HessianApproximationType::FisherInformation => {
self.update_fisher_information_hessian(gradients)?;
},
HessianApproximationType::QuasiNewton => {
self.update_quasi_newton_hessian(gradients)?;
},
}
Ok(())
}
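    /// Diagonal estimate from the elementwise variance of gradients over
    /// the history window, plus damping. Gradient variance is a common
    /// cheap proxy for curvature.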
fn update_diagonal_hessian(&mut self, gradients: &[Tensor]) -> Result<()> {
for (i, gradient) in gradients.iter().enumerate() {
if self.gradient_history.len() > 1 {
let mut variance = Tensor::zeros(&gradient.shape())?;
let mut mean = Tensor::zeros(&gradient.shape())?;
for grad_vec in &self.gradient_history {
if let Some(hist_grad) = grad_vec.get(i) {
mean = mean.add(hist_grad)?;
}
}
mean = mean.div_scalar(self.gradient_history.len() as f32)?;
for grad_vec in &self.gradient_history {
if let Some(hist_grad) = grad_vec.get(i) {
let diff = hist_grad.sub(&mean)?;
variance = variance.add(&diff.mul(&diff)?)?;
}
}
variance = variance.div_scalar(self.gradient_history.len() as f32)?;
let damping = Tensor::ones(&gradient.shape())?
.mul_scalar(self.config.hessian_preconditioning.damping)?;
variance = variance.add(&damping)?;
self.hessian_diagonal.insert(i, variance);
}
}
Ok(())
}
fn update_gauss_newton_hessian(&mut self, gradients: &[Tensor]) -> Result<()> {
for (i, gradient) in gradients.iter().enumerate() {
            // Elementwise square of the gradient: the diagonal of g g^T, a
            // cheap stand-in for the Gauss-Newton matrix.
            let squared_gradient = gradient.mul(gradient)?;
            let damping = Tensor::ones(&gradient.shape())?
                .mul_scalar(self.config.hessian_preconditioning.damping)?;
            let hessian_approx = squared_gradient.add(&damping)?;
            self.hessian_diagonal.insert(i, hessian_approx);
}
Ok(())
}
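    /// Empirical Fisher diagonal: with a single sample this is the same
    /// elementwise g^2 estimate as the Gauss-Newton variant above.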
fn update_fisher_information_hessian(&mut self, gradients: &[Tensor]) -> Result<()> {
for (i, gradient) in gradients.iter().enumerate() {
let fisher_approx = gradient.mul(gradient)?;
let damping = Tensor::ones(&gradient.shape())?
.mul_scalar(self.config.hessian_preconditioning.damping)?;
let hessian_approx = fisher_approx.add(&damping)?;
self.hessian_diagonal.insert(i, hessian_approx);
}
Ok(())
}
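    /// A crude secant-style estimate: |g_t - g_{t-1}| stands in for the
    /// curvature seen along the most recent step.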
fn update_quasi_newton_hessian(&mut self, gradients: &[Tensor]) -> Result<()> {
if self.gradient_history.len() > 1 {
for (i, gradient) in gradients.iter().enumerate() {
if let Some(prev_grad_vec) =
self.gradient_history.get(self.gradient_history.len() - 2)
{
if let Some(prev_grad) = prev_grad_vec.get(i) {
let grad_diff = gradient.sub(prev_grad)?;
let hessian_approx = grad_diff.abs()?;
let damping = Tensor::ones(&gradient.shape())?
.mul_scalar(self.config.hessian_preconditioning.damping)?;
let final_hessian = hessian_approx.add(&damping)?;
self.hessian_diagonal.insert(i, final_hessian);
}
}
}
}
Ok(())
}
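    /// Divides each gradient elementwise by its curvature estimate,
    /// floored at `min_eigenvalue` to keep the division stable.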
fn apply_diagonal_preconditioning(&mut self, gradients: &mut [Tensor]) -> Result<()> {
for (i, gradient) in gradients.iter_mut().enumerate() {
if let Some(hessian_diag) = self.hessian_diagonal.get(&i) {
let min_val = Tensor::scalar(self.config.hessian_preconditioning.min_eigenvalue)?;
let clamped_hessian = hessian_diag.max(&min_val)?;
*gradient = gradient.div(&clamped_hessian)?;
}
}
Ok(())
}
fn apply_gauss_newton_preconditioning(&mut self, gradients: &mut [Tensor]) -> Result<()> {
self.apply_diagonal_preconditioning(gradients)
}
fn apply_fisher_information_preconditioning(&mut self, gradients: &mut [Tensor]) -> Result<()> {
self.apply_diagonal_preconditioning(gradients)
}
fn apply_quasi_newton_preconditioning(&mut self, gradients: &mut [Tensor]) -> Result<()> {
self.apply_diagonal_preconditioning(gradients)
}
pub fn get_current_clip_norm(&self) -> f32 {
self.current_clip_norm
}
pub fn get_current_noise_scale(&self) -> f32 {
self.current_noise_scale
}
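    /// Returns `(mean, std_dev, max)` over the recorded gradient norms, or
    /// `None` if nothing has been recorded yet.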
pub fn get_gradient_norm_stats(&self) -> Option<(f32, f32, f32)> {
if self.gradient_norm_history.is_empty() {
return None;
}
let sum: f32 = self.gradient_norm_history.iter().sum();
let mean = sum / self.gradient_norm_history.len() as f32;
let variance = self.gradient_norm_history.iter().map(|x| (x - mean).powi(2)).sum::<f32>()
/ self.gradient_norm_history.len() as f32;
let std_dev = variance.sqrt();
let max_norm = self.gradient_norm_history.iter().fold(0.0f32, |acc, &x| acc.max(x));
Some((mean, std_dev, max_norm))
}
pub fn reset(&mut self) {
self.current_step = 0;
self.gradient_norm_history.clear();
self.smoothed_gradients.clear();
self.current_clip_norm = self.config.adaptive_clipping.initial_clip_norm;
self.current_noise_scale = self.config.noise_injection.initial_noise_scale;
self.smoothing_bias_correction = 1.0;
self.hessian_diagonal.clear();
self.hessian_inverse.clear();
self.last_hessian_update = 0;
self.gradient_history.clear();
}
pub fn set_config(&mut self, config: GradientProcessingConfig) {
self.config = config;
self.reset();
}
pub fn get_config(&self) -> &GradientProcessingConfig {
&self.config
}
}
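/// Wrapper that runs a [`GradientProcessor`] over the gradients before
/// delegating to a base optimizer implementing `OptimizerState`.
///
/// A sketch, assuming some `Adam` type in this crate implements
/// `OptimizerState` (the name is illustrative, not a confirmed API):
///
/// ```ignore
/// let mut optimizer =
///     GradientProcessedOptimizer::with_default_processing(Adam::new(1e-3));
/// optimizer.step(&mut parameters)?;
/// ```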
pub struct GradientProcessedOptimizer<T> {
base_optimizer: T,
gradient_processor: GradientProcessor,
}
impl<T> GradientProcessedOptimizer<T> {
pub fn new(base_optimizer: T, config: GradientProcessingConfig) -> Self {
Self {
base_optimizer,
gradient_processor: GradientProcessor::new(config),
}
}
pub fn with_default_processing(base_optimizer: T) -> Self {
Self::new(base_optimizer, GradientProcessingConfig::default())
}
pub fn gradient_processor(&self) -> &GradientProcessor {
&self.gradient_processor
}
pub fn gradient_processor_mut(&mut self) -> &mut GradientProcessor {
&mut self.gradient_processor
}
pub fn base_optimizer(&self) -> &T {
&self.base_optimizer
}
pub fn base_optimizer_mut(&mut self) -> &mut T {
&mut self.base_optimizer
}
}
impl<T: crate::optimizer::OptimizerState> crate::optimizer::OptimizerState
for GradientProcessedOptimizer<T>
{
fn zero_grad(&mut self) -> Result<()> {
self.base_optimizer.zero_grad()
}
fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
let mut gradients = Vec::new();
for param in parameters.iter() {
if let Ok(grad) = param.grad() {
gradients.push(grad);
} else {
return Err(anyhow!("Parameter missing gradient"));
}
}
self.gradient_processor.process_gradients(&mut gradients)?;
for (param, processed_grad) in parameters.iter_mut().zip(gradients.iter()) {
param.set_grad(processed_grad.clone())?;
}
self.base_optimizer.step(parameters)
}
fn get_lr(&self) -> f32 {
self.base_optimizer.get_lr()
}
fn set_lr(&mut self, lr: f32) {
self.base_optimizer.set_lr(lr);
}
fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
self.base_optimizer.state_dict()
}
fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
self.base_optimizer.load_state_dict(state)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_gradient_processing_config_default() {
let config = GradientProcessingConfig::default();
assert!(!config.enable_centralization);
assert!(!config.enable_standardization);
assert!(!config.enable_adaptive_clipping);
assert!(!config.enable_noise_injection);
assert!(!config.enable_smoothing);
}
#[test]
fn test_adaptive_clipping_config_default() {
let config = AdaptiveClippingConfig::default();
assert_eq!(config.initial_clip_norm, 1.0);
assert_eq!(config.min_clip_norm, 0.1);
assert_eq!(config.max_clip_norm, 10.0);
assert_eq!(config.adaptation_rate, 0.01);
assert_eq!(config.target_percentile, 0.9);
assert_eq!(config.history_window, 100);
}
#[test]
fn test_gradient_processor_creation() {
let processor = GradientProcessor::with_defaults();
assert_eq!(processor.current_step, 0);
assert_eq!(processor.gradient_norm_history.len(), 0);
}
#[test]
fn test_gradient_norm_stats_empty() {
let processor = GradientProcessor::with_defaults();
assert!(processor.get_gradient_norm_stats().is_none());
}
#[test]
fn test_gradient_processor_reset() {
let mut processor = GradientProcessor::with_defaults();
processor.current_step = 10;
processor.gradient_norm_history.push(1.0);
processor.reset();
assert_eq!(processor.current_step, 0);
assert_eq!(processor.gradient_norm_history.len(), 0);
assert_eq!(processor.hessian_diagonal.len(), 0);
assert_eq!(processor.gradient_history.len(), 0);
}
#[test]
fn test_hessian_preconditioning_config_default() {
let config = HessianPreconditioningConfig::default();
assert!(matches!(
config.approximation_type,
HessianApproximationType::Diagonal
));
assert_eq!(config.damping, 1e-4);
assert_eq!(config.update_frequency, 10);
assert_eq!(config.history_window, 20);
assert_eq!(config.min_eigenvalue, 1e-8);
assert_eq!(config.max_condition_number, 1e6);
}
#[test]
fn test_hessian_preconditioning_enabled() {
let config = GradientProcessingConfig {
enable_hessian_preconditioning: true,
..GradientProcessingConfig::default()
};
let processor = GradientProcessor::new(config);
assert!(processor.config.enable_hessian_preconditioning);
}
#[test]
fn test_hessian_approximation_types() {
let mut config = GradientProcessingConfig {
enable_hessian_preconditioning: true,
..GradientProcessingConfig::default()
};
config.hessian_preconditioning.approximation_type = HessianApproximationType::Diagonal;
let processor = GradientProcessor::new(config.clone());
assert!(matches!(
processor.config.hessian_preconditioning.approximation_type,
HessianApproximationType::Diagonal
));
config.hessian_preconditioning.approximation_type = HessianApproximationType::GaussNewton;
let processor = GradientProcessor::new(config.clone());
assert!(matches!(
processor.config.hessian_preconditioning.approximation_type,
HessianApproximationType::GaussNewton
));
config.hessian_preconditioning.approximation_type =
HessianApproximationType::FisherInformation;
let processor = GradientProcessor::new(config.clone());
assert!(matches!(
processor.config.hessian_preconditioning.approximation_type,
HessianApproximationType::FisherInformation
));
config.hessian_preconditioning.approximation_type = HessianApproximationType::QuasiNewton;
let processor = GradientProcessor::new(config.clone());
assert!(matches!(
processor.config.hessian_preconditioning.approximation_type,
HessianApproximationType::QuasiNewton
));
}
}