use std::collections::HashMap;
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;
/// Adam-style optimizer augmented with PDE-residual diagnostics that
/// modulate the effective learning rate during physics-informed training.
#[derive(Debug)]
pub struct PDEAwareOptimizer {
    /// Base step size before PDE-aware scaling.
    pub learning_rate: f32,
    /// Exponential decay rate for the first-moment (momentum) estimate.
    pub beta1: f32,
    /// Exponential decay rate for the second-moment (variance) estimate.
    pub beta2: f32,
    /// Small constant added to the denominator for numerical stability.
    pub epsilon: f32,
    /// Decoupled (AdamW-style) weight-decay coefficient; 0.0 disables it.
    pub weight_decay: f32,
    /// Strength of the learning-rate damping driven by residual variance.
    pub residual_variance_weight: f32,
    /// Weighting for gradient-alignment adjustments.
    /// NOTE(review): not read by the visible update path — confirm usage.
    pub gradient_alignment_factor: f32,
    /// EMA factor used when smoothing the residual-variance signal.
    pub smoothing_factor: f32,
    /// Gradient-norm level above which a region counts as "sharp".
    pub sharp_gradient_threshold: f32,
    /// Number of `update` calls performed so far (drives bias correction).
    pub step: usize,
    /// Per-parameter first-moment buffers, keyed by parameter identity.
    pub momentum: HashMap<String, Vec<f32>>,
    /// Per-parameter second-moment buffers, keyed by parameter identity.
    pub variance: HashMap<String, Vec<f32>>,
    /// Rolling window of recent residual-variance samples (capped at 100).
    pub residual_variance_history: Vec<f32>,
    /// Rolling window of gradient-alignment samples.
    /// NOTE(review): never written by the visible code — confirm it is needed.
    pub gradient_alignment_history: Vec<f32>,
}
/// Hyperparameter bundle used to construct a `PDEAwareOptimizer`.
#[derive(Debug, Clone)]
pub struct PDEAwareConfig {
    /// Base learning rate.
    pub learning_rate: f32,
    /// First-moment exponential decay rate.
    pub beta1: f32,
    /// Second-moment exponential decay rate.
    pub beta2: f32,
    /// Numerical-stability constant for the update denominator.
    pub epsilon: f32,
    /// Decoupled weight-decay coefficient; 0.0 disables it.
    pub weight_decay: f32,
    /// How strongly residual variance damps the learning rate.
    pub residual_variance_weight: f32,
    /// Weighting for gradient-alignment adjustments.
    pub gradient_alignment_factor: f32,
    /// EMA factor for smoothing the residual-variance signal.
    pub smoothing_factor: f32,
    /// Gradient-norm threshold that marks a region as "sharp".
    pub sharp_gradient_threshold: f32,
}
impl Default for PDEAwareConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
weight_decay: 0.0,
residual_variance_weight: 0.1,
gradient_alignment_factor: 0.05,
smoothing_factor: 0.95,
sharp_gradient_threshold: 1.0,
}
}
}
impl Default for PDEAwareOptimizer {
fn default() -> Self {
Self::new()
}
}
impl PDEAwareOptimizer {
    /// Creates an optimizer from the default `PDEAwareConfig`.
    pub fn new() -> Self {
        Self::from_config(PDEAwareConfig::default())
    }

    /// Builds an optimizer from an explicit configuration, starting with
    /// zero steps, empty per-parameter moment buffers, and empty
    /// diagnostic histories.
    pub fn from_config(config: PDEAwareConfig) -> Self {
        // Destructure so every hyperparameter is carried over by name;
        // a new config field becomes a compile error here instead of a
        // silently dropped value.
        let PDEAwareConfig {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            residual_variance_weight,
            gradient_alignment_factor,
            smoothing_factor,
            sharp_gradient_threshold,
        } = config;
        Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            residual_variance_weight,
            gradient_alignment_factor,
            smoothing_factor,
            sharp_gradient_threshold,
            step: 0,
            momentum: HashMap::new(),
            variance: HashMap::new(),
            residual_variance_history: Vec::new(),
            gradient_alignment_history: Vec::new(),
        }
    }

    /// Preset for Burgers'-equation training (presumably tuned
    /// empirically — values are as shipped, not derived here).
    pub fn for_burgers_equation() -> Self {
        let preset = PDEAwareConfig {
            learning_rate: 5e-4,
            beta1: 0.95,
            beta2: 0.999,
            epsilon: 1e-10,
            weight_decay: 1e-6,
            residual_variance_weight: 0.15,
            gradient_alignment_factor: 0.08,
            smoothing_factor: 0.98,
            sharp_gradient_threshold: 0.8,
        };
        Self::from_config(preset)
    }

    /// Preset for Allen–Cahn-equation training.
    pub fn for_allen_cahn() -> Self {
        let preset = PDEAwareConfig {
            learning_rate: 1e-3,
            beta1: 0.9,
            beta2: 0.995,
            epsilon: 1e-9,
            weight_decay: 1e-5,
            residual_variance_weight: 0.2,
            gradient_alignment_factor: 0.1,
            smoothing_factor: 0.95,
            sharp_gradient_threshold: 1.5,
        };
        Self::from_config(preset)
    }

    /// Preset for Korteweg–de Vries-equation training.
    pub fn for_kdv_equation() -> Self {
        let preset = PDEAwareConfig {
            learning_rate: 2e-4,
            beta1: 0.95,
            beta2: 0.9995,
            epsilon: 1e-12,
            weight_decay: 0.0,
            residual_variance_weight: 0.25,
            gradient_alignment_factor: 0.12,
            smoothing_factor: 0.99,
            sharp_gradient_threshold: 0.5,
        };
        Self::from_config(preset)
    }

    /// Conservative preset for problems dominated by sharp gradients:
    /// small learning rate, heavy smoothing, low sharpness threshold.
    pub fn for_sharp_gradients() -> Self {
        let preset = PDEAwareConfig {
            learning_rate: 1e-4,
            beta1: 0.95,
            beta2: 0.9999,
            epsilon: 1e-10,
            weight_decay: 1e-7,
            residual_variance_weight: 0.3,
            gradient_alignment_factor: 0.15,
            smoothing_factor: 0.99,
            sharp_gradient_threshold: 0.3,
        };
        Self::from_config(preset)
    }

    /// Records `grad_norm` into the rolling history (window of 100) and
    /// returns a smoothed sample.
    ///
    /// NOTE(review): despite the name, the tracked quantity is the raw
    /// gradient norm, and smoothing blends against the previous *raw*
    /// sample rather than a running EMA — confirm this is intentional.
    fn compute_residual_variance_from_norm(&mut self, grad_norm: f32) -> f32 {
        self.residual_variance_history.push(grad_norm);
        // Cap the window; remove(0) is O(window), negligible at 100.
        if self.residual_variance_history.len() > 100 {
            self.residual_variance_history.remove(0);
        }
        let len = self.residual_variance_history.len();
        match len {
            // First-ever sample: nothing to smooth against.
            0 | 1 => grad_norm,
            _ => {
                let previous = self.residual_variance_history[len - 2];
                self.smoothing_factor * previous + (1.0 - self.smoothing_factor) * grad_norm
            },
        }
    }

    /// A region is "sharp" when either the overall gradient norm exceeds
    /// the threshold or the single largest component exceeds twice it.
    fn is_sharp_gradient_region_from_norm(&self, grad_norm: f32, max_grad: f32) -> bool {
        let norm_exceeds = grad_norm > self.sharp_gradient_threshold;
        let peak_exceeds = max_grad > 2.0 * self.sharp_gradient_threshold;
        norm_exceeds || peak_exceeds
    }

    /// Scales `base_lr` down when residual variance is high and halves it
    /// in sharp-gradient regions, clamped to [0.01, 2.0] × `base_lr`.
    pub fn adaptive_learning_rate(
        &self,
        base_lr: f32,
        residual_variance: f32,
        is_sharp_region: bool,
    ) -> f32 {
        let variance_scale = if residual_variance > 0.1 {
            1.0 / (1.0 + self.residual_variance_weight * residual_variance)
        } else {
            1.0
        };
        let sharpness_scale = if is_sharp_region { 0.5 } else { 1.0 };
        (base_lr * variance_scale * sharpness_scale).clamp(base_lr * 0.01, base_lr * 2.0)
    }

    /// Returns a snapshot of the optimizer's diagnostic counters.
    pub fn get_pde_stats(&self) -> PDEAwareStats {
        let history = &self.residual_variance_history;
        let average_residual_variance = if history.is_empty() {
            0.0
        } else {
            history.iter().sum::<f32>() / history.len() as f32
        };
        PDEAwareStats {
            step: self.step,
            average_residual_variance,
            parameters_tracked: self.momentum.len(),
        }
    }
}
/// Snapshot of the optimizer's diagnostic state, as reported by
/// `get_pde_stats`.
#[derive(Debug, Clone)]
pub struct PDEAwareStats {
    /// Update steps performed so far.
    pub step: usize,
    /// Mean of the recent residual-variance window (0.0 when empty).
    pub average_residual_variance: f32,
    /// Number of parameter tensors with allocated moment buffers.
    pub parameters_tracked: usize,
}
impl Optimizer for PDEAwareOptimizer {
    /// Applies one Adam-style update with a PDE-aware adaptive learning rate.
    ///
    /// Only `Tensor::F32` parameter/gradient pairs are supported; any other
    /// combination yields a `tensor_op_error`. Weight decay is applied in
    /// decoupled (AdamW-style) form: added to the update, not the gradient.
    ///
    /// # Errors
    /// Returns `tensor_op_error` when the tensors are not both `F32`, when
    /// the parameter and gradient lengths differ, or when a cached moment
    /// buffer no longer matches the gradient length.
    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
        match (parameter, grad) {
            (Tensor::F32(param), Tensor::F32(grad_arr)) => {
                // Bug fix: the original indexed `m[i]`/`v[i]` (sized from the
                // gradient) up to `param.len()`, so a longer parameter caused
                // an out-of-bounds panic and a shorter one silently ignored
                // part of the gradient. Validate the shapes up front instead.
                if param.len() != grad_arr.len() {
                    return Err(TrustformersError::tensor_op_error(
                        "Parameter/gradient size mismatch",
                        "pde_aware_update",
                    ));
                }
                // NOTE(review): `step` advances once per *tensor*, so when
                // several tensors are updated per iteration the bias
                // correction is skewed — confirm this is intended.
                self.step += 1;
                // Keyed by buffer address: stable only while the tensor's
                // storage is never reallocated — TODO confirm upstream.
                let param_id = format!("{:p}", param.as_ptr());
                // Diagnostics feeding the PDE-aware learning-rate scaling.
                let grad_norm: f32 = grad_arr.iter().map(|g| g * g).sum::<f32>().sqrt();
                let max_grad: f32 = grad_arr.iter().map(|g| g.abs()).fold(0.0, f32::max);
                let residual_variance = self.compute_residual_variance_from_norm(grad_norm);
                let is_sharp_region = self.is_sharp_gradient_region_from_norm(grad_norm, max_grad);
                let adaptive_lr = self.adaptive_learning_rate(
                    self.learning_rate,
                    residual_variance,
                    is_sharp_region,
                );
                // Lazily allocate per-parameter moment buffers.
                let m = self
                    .momentum
                    .entry(param_id.clone())
                    .or_insert_with(|| vec![0.0; grad_arr.len()]);
                let v = self.variance.entry(param_id).or_insert_with(|| vec![0.0; grad_arr.len()]);
                // Guards against stale cached buffers from an earlier call
                // with a same-address tensor of a different length.
                if m.len() != grad_arr.len() || v.len() != grad_arr.len() {
                    return Err(TrustformersError::tensor_op_error(
                        "Momentum/variance buffer size mismatch",
                        "pde_aware_update",
                    ));
                }
                // Standard Adam moment accumulation.
                for i in 0..grad_arr.len() {
                    m[i] = self.beta1 * m[i] + (1.0 - self.beta1) * grad_arr[i];
                    v[i] = self.beta2 * v[i] + (1.0 - self.beta2) * grad_arr[i] * grad_arr[i];
                }
                let bias_correction1 = 1.0 - self.beta1.powi(self.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(self.step as i32);
                // Each element's update depends only on its own index, so it
                // can be applied in place — no intermediate buffer needed.
                for i in 0..param.len() {
                    let m_hat = m[i] / bias_correction1;
                    let v_hat = v[i] / bias_correction2;
                    let mut update = adaptive_lr * m_hat / (v_hat.sqrt() + self.epsilon);
                    if self.weight_decay > 0.0 {
                        update += self.weight_decay * param[i];
                    }
                    param[i] -= update;
                }
                Ok(())
            },
            _ => Err(TrustformersError::tensor_op_error(
                "Unsupported tensor types for PDEAwareOptimizer",
                "pde_aware_update",
            )),
        }
    }

    /// No-op: gradient buffers are presumably cleared by the training loop,
    /// not the optimizer — confirm against callers.
    fn zero_grad(&mut self) {
    }

    /// No-op: the step counter is advanced inside `update` instead.
    fn step(&mut self) {
    }

    /// Returns the current base learning rate (before PDE-aware scaling).
    fn get_lr(&self) -> f32 {
        self.learning_rate
    }

    /// Overrides the base learning rate.
    fn set_lr(&mut self, lr: f32) {
        self.learning_rate = lr;
    }
}
// Unit tests live in a sibling file to keep this module focused.
#[cfg(test)]
#[path = "pde_aware_tests.rs"]
mod pde_aware_tests;