use crate::common::{OptimizerState, StateMemoryStats};
use crate::traits::StatefulOptimizer;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::errors::Result;
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;
/// Configuration for the BGE-Adam optimizer: standard Adam hyperparameters
/// plus controls for gradient-entropy weighting and entropy-adaptive betas.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BGEAdamConfig {
/// Base step size applied to each parameter update.
pub learning_rate: f32,
/// Base exponential-decay rate for the first-moment (momentum) estimate.
pub beta1: f32,
/// Base exponential-decay rate for the second-moment (variance) estimate.
pub beta2: f32,
/// Small constant used for numerical stability (update denominator and
/// near-zero-gradient checks).
pub epsilon: f32,
/// Weight-decay coefficient applied directly to parameters inside the
/// learning-rate-scaled update (AdamW-style decoupled decay).
pub weight_decay: f32,
/// Strength of the entropy-based down-weighting of gradient components.
pub entropy_scaling: f32,
/// Rate at which beta1 is increased with measured gradient entropy.
pub beta1_adaptation: f32,
/// Rate at which beta2 is decreased with measured gradient entropy.
pub beta2_adaptation: f32,
/// Floor value reported for computed gradient entropy.
pub min_entropy: f32,
/// Whether to apply Adam-style bias correction to the moment estimates.
pub bias_correction: bool,
/// Whether to re-weight gradient components by entropy before the update.
pub entropy_weighting: bool,
/// Whether beta1/beta2 adapt to the measured gradient entropy.
pub adaptive_parameters: bool,
}
impl Default for BGEAdamConfig {
/// Standard Adam defaults (lr = 1e-3, betas = 0.9/0.999, eps = 1e-8)
/// with all entropy-based features enabled.
fn default() -> Self {
Self {
learning_rate: 1e-3,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
weight_decay: 0.01,
entropy_scaling: 0.1,
beta1_adaptation: 0.05,
beta2_adaptation: 0.05,
min_entropy: 1e-6,
bias_correction: true,
entropy_weighting: true,
adaptive_parameters: true,
}
}
}
/// Adam variant that measures the Shannon entropy of each gradient and uses
/// it to re-weight gradient components and adapt the beta coefficients.
pub struct BGEAdam {
/// Hyperparameters controlling the update rule.
config: BGEAdamConfig,
/// Per-parameter momentum/variance buffers keyed by a string id.
state: OptimizerState,
/// Number of `update` calls performed; used for bias correction.
step_count: usize,
/// Rolling window of recent gradient-entropy samples.
entropy_history: Vec<f32>,
/// Maximum length of `entropy_history` before the oldest entry is dropped.
max_entropy_history: usize,
}
impl BGEAdam {
    /// Creates a BGE-Adam optimizer from explicit hyperparameters.
    ///
    /// Fields not covered by the arguments (`min_entropy`, `bias_correction`,
    /// `entropy_weighting`, `adaptive_parameters`) take their `Default` values.
    pub fn new(
        learning_rate: f32,
        betas: (f32, f32),
        epsilon: f32,
        weight_decay: f32,
        entropy_scaling: f32,
        beta1_adaptation: f32,
        beta2_adaptation: f32,
    ) -> Self {
        let config = BGEAdamConfig {
            learning_rate,
            beta1: betas.0,
            beta2: betas.1,
            epsilon,
            weight_decay,
            entropy_scaling,
            beta1_adaptation,
            beta2_adaptation,
            ..Default::default()
        };
        Self {
            config,
            state: OptimizerState::new(),
            step_count: 0,
            entropy_history: Vec::new(),
            // Size of the rolling entropy window kept for statistics.
            max_entropy_history: 100,
        }
    }

    /// Preset for large (language) models: lower learning rate and beta2,
    /// stronger entropy scaling and adaptation.
    pub fn for_large_models() -> Self {
        Self::new(3e-4, (0.9, 0.95), 1e-8, 0.1, 0.15, 0.08, 0.08)
    }

    /// Preset for vision models: standard Adam betas with milder entropy
    /// effects.
    pub fn for_vision() -> Self {
        Self::new(1e-3, (0.9, 0.999), 1e-8, 0.05, 0.1, 0.05, 0.05)
    }

    /// Preset for robust training: higher beta1, larger epsilon, and the
    /// strongest entropy scaling/adaptation of the presets.
    pub fn for_robust_training() -> Self {
        Self::new(5e-4, (0.95, 0.999), 1e-6, 0.02, 0.2, 0.1, 0.1)
    }

    /// Shannon entropy of the distribution obtained by normalizing the
    /// absolute gradient values.
    ///
    /// Computed in a single pass over the data (the previous implementation
    /// materialized two intermediate vectors per call; summation order is
    /// unchanged, so results are identical). Returns `min_entropy` for
    /// effectively-zero gradients and never less than `min_entropy`.
    fn calculate_gradient_entropy(&self, gradients: &Tensor) -> Result<f32> {
        let grad_data = gradients.data()?;
        let sum_abs_grads: f32 = grad_data.iter().map(|&g| g.abs()).sum();
        if sum_abs_grads < self.config.epsilon {
            // All-zero (or negligible) gradient: report the entropy floor.
            return Ok(self.config.min_entropy);
        }
        let entropy: f32 = grad_data
            .iter()
            .map(|&g| {
                let p = g.abs() / sum_abs_grads;
                if p > self.config.epsilon {
                    // Epsilon inside ln() guards against ln(0) for tiny p.
                    -p * (p + self.config.epsilon).ln()
                } else {
                    0.0
                }
            })
            .sum();
        Ok(entropy.max(self.config.min_entropy))
    }

    /// Down-weights each gradient component by exp(-scaling * entropy * p_i),
    /// where p_i is the component's share of the total absolute gradient.
    ///
    /// Returns the gradient unchanged when entropy weighting is disabled or
    /// the gradient is effectively zero.
    fn apply_entropy_weighting(&self, gradients: &Tensor, entropy: f32) -> Result<Tensor> {
        if !self.config.entropy_weighting {
            return Ok(gradients.clone());
        }
        let grad_data = gradients.data()?;
        let sum_abs_grads: f32 = grad_data.iter().map(|&g| g.abs()).sum();
        if sum_abs_grads < self.config.epsilon {
            return Ok(gradients.clone());
        }
        let weighted_data: Vec<f32> = grad_data
            .iter()
            .map(|&g| {
                let p_i = g.abs() / sum_abs_grads;
                let weight = (-self.config.entropy_scaling * entropy * p_i).exp();
                g * weight
            })
            .collect();
        Tensor::new(weighted_data)
    }

    /// Returns `(beta1, beta2)` adapted to the current gradient entropy:
    /// beta1 grows and beta2 shrinks with entropy, each clamped to a safe
    /// range. Returns the base betas when adaptation is disabled.
    fn get_adaptive_betas(&self, entropy: f32) -> (f32, f32) {
        if !self.config.adaptive_parameters {
            return (self.config.beta1, self.config.beta2);
        }
        let beta1_adaptive = self.config.beta1 * (1.0 + self.config.beta1_adaptation * entropy);
        let beta2_adaptive = self.config.beta2 * (1.0 - self.config.beta2_adaptation * entropy);
        let beta1_adaptive = beta1_adaptive.clamp(0.1, 0.99);
        let beta2_adaptive = beta2_adaptive.clamp(0.9, 0.9999);
        (beta1_adaptive, beta2_adaptive)
    }

    /// Appends an entropy sample, evicting the oldest once the window is
    /// full. `remove(0)` is O(window), which is fine for the fixed window
    /// of 100 entries.
    fn update_entropy_history(&mut self, entropy: f32) {
        self.entropy_history.push(entropy);
        if self.entropy_history.len() > self.max_entropy_history {
            self.entropy_history.remove(0);
        }
    }

    /// Mean of the recorded entropy window; 0.0 when no samples exist.
    pub fn get_average_entropy(&self) -> f32 {
        if self.entropy_history.is_empty() {
            0.0
        } else {
            self.entropy_history.iter().sum::<f32>() / self.entropy_history.len() as f32
        }
    }

    /// Returns `(min, max, mean)` over the recorded entropy window, or all
    /// zeros when no samples exist.
    pub fn get_entropy_stats(&self) -> (f32, f32, f32) {
        if self.entropy_history.is_empty() {
            return (0.0, 0.0, 0.0);
        }
        // Single pass for both extrema instead of two separate folds.
        let (min_entropy, max_entropy) = self
            .entropy_history
            .iter()
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &e| {
                (lo.min(e), hi.max(e))
            });
        (min_entropy, max_entropy, self.get_average_entropy())
    }
}
impl Optimizer for BGEAdam {
    /// No-op: gradients are supplied explicitly to `update`, so there is no
    /// internal gradient buffer to clear.
    fn zero_grad(&mut self) {}

    /// Performs one BGE-Adam step on `parameter` using `gradient`.
    ///
    /// Pipeline: measure gradient entropy, record it, entropy-weight the
    /// gradient, adapt the betas, update the first/second moment estimates,
    /// optionally bias-correct, then apply the Adam update with AdamW-style
    /// decoupled weight decay.
    ///
    /// The previous implementation round-tripped the momentum/variance
    /// buffers through `Tensor::new(...)` / `.data()?` four times with no
    /// effect; those dead conversions are removed. The arithmetic sequence
    /// is unchanged.
    fn update(&mut self, parameter: &mut Tensor, gradient: &Tensor) -> Result<()> {
        let param_data = parameter.data()?;
        // NOTE(review): the state key is the address of the parameter's data
        // buffer. Assigning a fresh `Tensor` to `*parameter` below very
        // likely reallocates, so this key can change between steps, orphaning
        // old momentum/variance entries and preventing accumulation — confirm
        // Tensor's allocation behavior and prefer a stable parameter id.
        let param_id = format!("{:p}", param_data.as_ptr());
        self.step_count += 1;

        // Entropy-driven components of the update.
        let entropy = self.calculate_gradient_entropy(gradient)?;
        self.update_entropy_history(entropy);
        let weighted_gradients = self.apply_entropy_weighting(gradient, entropy)?;
        let (beta1_adaptive, beta2_adaptive) = self.get_adaptive_betas(entropy);

        let param_size = param_data.len();
        let weighted_grad_data = weighted_gradients.data()?;

        // First/second moment estimates with the entropy-adapted betas.
        let prev_momentum =
            self.state.get_or_create_momentum(param_id.clone(), param_size).clone();
        let new_momentum_data: Vec<f32> = prev_momentum
            .iter()
            .zip(weighted_grad_data.iter())
            .map(|(&m, &g)| beta1_adaptive * m + (1.0 - beta1_adaptive) * g)
            .collect();
        let prev_variance =
            self.state.get_or_create_variance(param_id.clone(), param_size).clone();
        let new_variance_data: Vec<f32> = prev_variance
            .iter()
            .zip(weighted_grad_data.iter())
            .map(|(&v, &g)| beta2_adaptive * v + (1.0 - beta2_adaptive) * g * g)
            .collect();
        self.state.momentum.insert(param_id.clone(), new_momentum_data.clone());
        self.state.variance.insert(param_id, new_variance_data.clone());

        // Optional bias correction. Note the correction uses the *current*
        // adaptive betas raised to the step count, which is approximate when
        // the betas vary per step.
        let (corrected_momentum, corrected_variance) = if self.config.bias_correction {
            let step_f32 = self.step_count as f32;
            let momentum_correction = 1.0 - beta1_adaptive.powf(step_f32);
            let variance_correction = 1.0 - beta2_adaptive.powf(step_f32);
            (
                new_momentum_data
                    .iter()
                    .map(|&m| m / momentum_correction)
                    .collect::<Vec<f32>>(),
                new_variance_data
                    .iter()
                    .map(|&v| v / variance_correction)
                    .collect::<Vec<f32>>(),
            )
        } else {
            (new_momentum_data, new_variance_data)
        };

        // p <- p - lr * (m_hat / (sqrt(v_hat) + eps) + wd * p)
        let updated_params: Vec<f32> = param_data
            .iter()
            .zip(corrected_momentum.iter())
            .zip(corrected_variance.iter())
            .map(|((&p, &m), &v)| {
                let update = m / (v.sqrt() + self.config.epsilon);
                let weight_decay_term = self.config.weight_decay * p;
                p - self.config.learning_rate * (update + weight_decay_term)
            })
            .collect();
        *parameter = Tensor::new(updated_params)?;
        Ok(())
    }

    /// Advances the shared optimizer-state step counter.
    ///
    /// NOTE(review): this counter is separate from `self.step_count`, which
    /// `update` increments and bias correction reads — confirm the two are
    /// intended to be tracked independently.
    fn step(&mut self) {
        self.state.step();
    }

    /// Sets the learning rate (e.g. from a scheduler).
    fn set_lr(&mut self, lr: f32) {
        self.config.learning_rate = lr;
    }

    /// Returns the current learning rate.
    fn get_lr(&self) -> f32 {
        self.config.learning_rate
    }
}
impl StatefulOptimizer for BGEAdam {
    type Config = BGEAdamConfig;
    type State = OptimizerState;

    /// Read-only access to the hyperparameter configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }

    /// Read-only access to the per-parameter optimizer state.
    fn state(&self) -> &Self::State {
        &self.state
    }

    /// Mutable access to the per-parameter optimizer state.
    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    /// Serializes optimizer state into named tensors.
    ///
    /// Layout: `<param_id>_momentum` / `<param_id>_variance` per parameter,
    /// plus `entropy_history` and a single-element `step_count` tensor.
    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state_dict = HashMap::new();
        for (key, buffer) in &self.state.momentum {
            state_dict.insert(format!("{}_momentum", key), Tensor::new(buffer.clone())?);
        }
        for (key, buffer) in &self.state.variance {
            state_dict.insert(format!("{}_variance", key), Tensor::new(buffer.clone())?);
        }
        state_dict.insert(
            "entropy_history".to_string(),
            Tensor::new(self.entropy_history.clone())?,
        );
        // NOTE(review): step_count is stored as f32, so the round-trip is
        // exact only while step_count < 2^24.
        state_dict.insert(
            "step_count".to_string(),
            Tensor::new(vec![self.step_count as f32])?,
        );
        Ok(state_dict)
    }

    /// Restores state previously produced by `state_dict`.
    ///
    /// Unrecognized keys are ignored. The tensors are consumed and each
    /// buffer is used exactly once per iteration, so the per-branch
    /// `.clone()` calls of the previous implementation were redundant and
    /// have been removed.
    fn load_state_dict(&mut self, state_dict: HashMap<String, Tensor>) -> Result<()> {
        for (key, tensor) in state_dict {
            let data = tensor.data()?;
            if key == "entropy_history" {
                self.entropy_history = data;
            } else if key == "step_count" {
                if let Some(&count) = data.first() {
                    self.step_count = count as usize;
                }
            } else if let Some(param_key) = key.strip_suffix("_momentum") {
                self.state.momentum.insert(param_key.to_string(), data);
            } else if let Some(param_key) = key.strip_suffix("_variance") {
                self.state.variance.insert(param_key.to_string(), data);
            }
        }
        Ok(())
    }

    /// Approximate memory footprint of the optimizer state.
    ///
    /// Counts only the f32 payloads (momentum, variance, entropy history);
    /// HashMap and key overhead are not included.
    fn memory_usage(&self) -> StateMemoryStats {
        let momentum_size: usize = self.state.momentum.values().map(|v| v.len()).sum();
        let variance_size: usize = self.state.variance.values().map(|v| v.len()).sum();
        let entropy_size = self.entropy_history.len();
        let total_bytes =
            (momentum_size + variance_size + entropy_size) * std::mem::size_of::<f32>();
        StateMemoryStats {
            momentum_elements: momentum_size,
            variance_elements: variance_size,
            // BGE-Adam keeps no third-moment estimate.
            third_moment_elements: 0,
            total_bytes,
            num_parameters: self.state.momentum.len().max(self.state.variance.len()),
        }
    }

    /// Clears all per-parameter state, the step counter, and entropy history.
    fn reset_state(&mut self) {
        self.state.clear();
        self.step_count = 0;
        self.entropy_history.clear();
    }

    /// Number of distinct parameters with tracked state.
    fn num_parameters(&self) -> usize {
        self.state.momentum.len().max(self.state.variance.len())
    }
}
// Unit tests for BGEAdam: construction, presets, entropy machinery, state
// round-trips, and feature toggles.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
// Constructor wires hyperparameters into the config and starts at step 0.
#[test]
fn test_bge_adam_creation() {
let optimizer = BGEAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01, 0.1, 0.05, 0.05);
assert_eq!(optimizer.config.learning_rate, 1e-3);
assert_eq!(optimizer.config.beta1, 0.9);
assert_eq!(optimizer.config.beta2, 0.999);
assert_eq!(optimizer.step_count, 0);
}
// Each preset applies its documented learning rate and beta values.
#[test]
fn test_bge_adam_presets() {
let llm_optimizer = BGEAdam::for_large_models();
assert_eq!(llm_optimizer.config.learning_rate, 3e-4);
assert_eq!(llm_optimizer.config.beta2, 0.95);
let vision_optimizer = BGEAdam::for_vision();
assert_eq!(vision_optimizer.config.learning_rate, 1e-3);
assert_eq!(vision_optimizer.config.beta2, 0.999);
let robust_optimizer = BGEAdam::for_robust_training();
assert_eq!(robust_optimizer.config.learning_rate, 5e-4);
assert_eq!(robust_optimizer.config.beta1, 0.95);
}
// A non-degenerate gradient distribution yields strictly positive entropy.
#[test]
fn test_entropy_calculation() -> Result<()> {
let optimizer = BGEAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01, 0.1, 0.05, 0.05);
let gradients = Tensor::new(vec![1.0, 2.0, 1.0, 0.5])?;
let entropy = optimizer.calculate_gradient_entropy(&gradients)?;
assert!(entropy > 0.0);
Ok(())
}
// Weighting with nonzero entropy must change the gradient values.
#[test]
fn test_entropy_weighting() -> Result<()> {
let optimizer = BGEAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01, 0.1, 0.05, 0.05);
let gradients = Tensor::new(vec![1.0, 2.0, 1.0, 0.5])?;
let entropy = 1.0;
let weighted_gradients = optimizer.apply_entropy_weighting(&gradients, entropy)?;
let orig_data = gradients.data()?;
let weighted_data = weighted_gradients.data()?;
assert_ne!(orig_data, weighted_data);
Ok(())
}
// Positive entropy raises beta1 and lowers beta2, staying inside the clamps.
#[test]
fn test_adaptive_betas() {
let optimizer = BGEAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01, 0.1, 0.05, 0.05);
let entropy = 1.0;
let (beta1_adaptive, beta2_adaptive) = optimizer.get_adaptive_betas(entropy);
assert!(beta1_adaptive > 0.9);
assert!(beta2_adaptive < 0.999);
assert!(beta1_adaptive < 0.99);
assert!(beta2_adaptive > 0.9);
}
// History records each sample; stats over {1.0, 1.5, 0.8} give avg 1.1.
#[test]
fn test_entropy_history() {
let mut optimizer = BGEAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01, 0.1, 0.05, 0.05);
optimizer.update_entropy_history(1.0);
optimizer.update_entropy_history(1.5);
optimizer.update_entropy_history(0.8);
assert_eq!(optimizer.entropy_history.len(), 3);
assert_relative_eq!(optimizer.get_average_entropy(), 1.1, epsilon = 1e-6);
let (min_entropy, max_entropy, avg_entropy) = optimizer.get_entropy_stats();
assert_relative_eq!(min_entropy, 0.8, epsilon = 1e-6);
assert_relative_eq!(max_entropy, 1.5, epsilon = 1e-6);
assert_relative_eq!(avg_entropy, 1.1, epsilon = 1e-6);
}
// set_lr/get_lr round-trip the learning rate.
#[test]
fn test_lr_setter_getter() {
let mut optimizer = BGEAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01, 0.1, 0.05, 0.05);
assert_eq!(optimizer.get_lr(), 1e-3);
optimizer.set_lr(2e-3);
assert_eq!(optimizer.get_lr(), 2e-3);
}
// state_dict/load_state_dict round-trips entropy history and step count,
// even into an optimizer built with different hyperparameters.
#[test]
fn test_state_dict_operations() -> Result<()> {
let mut optimizer = BGEAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01, 0.1, 0.05, 0.05);
optimizer.update_entropy_history(1.0);
optimizer.update_entropy_history(1.5);
optimizer.step_count = 10;
let state_dict = optimizer.state_dict()?;
assert!(state_dict.contains_key("entropy_history"));
assert!(state_dict.contains_key("step_count"));
let mut new_optimizer = BGEAdam::new(2e-3, (0.8, 0.99), 1e-7, 0.02, 0.2, 0.1, 0.1);
new_optimizer.load_state_dict(state_dict)?;
assert_eq!(new_optimizer.entropy_history.len(), 2);
assert_eq!(new_optimizer.step_count, 10);
assert_relative_eq!(new_optimizer.get_average_entropy(), 1.25, epsilon = 1e-6);
Ok(())
}
// reset_state clears the step counter, history, and per-parameter buffers.
#[test]
fn test_reset() {
let mut optimizer = BGEAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01, 0.1, 0.05, 0.05);
optimizer.update_entropy_history(1.0);
optimizer.step_count = 5;
optimizer.reset_state();
assert_eq!(optimizer.step_count, 0);
assert_eq!(optimizer.entropy_history.len(), 0);
assert_eq!(optimizer.state.momentum.len(), 0);
assert_eq!(optimizer.state.variance.len(), 0);
}
// A freshly constructed optimizer reports zero state memory.
#[test]
fn test_memory_usage() {
let optimizer = BGEAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01, 0.1, 0.05, 0.05);
let stats = optimizer.memory_usage();
assert_eq!(stats.total_bytes, 0); assert_eq!(stats.num_parameters, 0);
assert_eq!(stats.momentum_elements, 0);
assert_eq!(stats.variance_elements, 0);
}
// config() exposes exactly the values passed to new().
#[test]
fn test_config_access() {
let optimizer = BGEAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01, 0.1, 0.05, 0.05);
let config = optimizer.config();
assert_eq!(config.learning_rate, 1e-3);
assert_eq!(config.beta1, 0.9);
assert_eq!(config.beta2, 0.999);
assert_eq!(config.entropy_scaling, 0.1);
assert_eq!(config.beta1_adaptation, 0.05);
assert_eq!(config.beta2_adaptation, 0.05);
}
// With entropy features disabled, weighting returns the gradient unchanged
// and the betas are not adapted.
#[test]
fn test_disabled_features() {
let config = BGEAdamConfig {
entropy_weighting: false,
adaptive_parameters: false,
..BGEAdamConfig::default()
};
let optimizer = BGEAdam {
config,
state: OptimizerState::new(),
step_count: 0,
entropy_history: Vec::new(),
max_entropy_history: 100,
};
let gradients = Tensor::new(vec![1.0, 2.0, 1.0, 0.5]).expect("Failed to create tensor");
let weighted = optimizer
.apply_entropy_weighting(&gradients, 1.0)
.expect("Operation failed in test");
let grad_data = gradients.data().expect("Operation failed in test");
let weighted_data = weighted.data().expect("Operation failed in test");
assert_eq!(grad_data, weighted_data);
let (beta1, beta2) = optimizer.get_adaptive_betas(1.0);
assert_eq!(beta1, optimizer.config.beta1);
assert_eq!(beta2, optimizer.config.beta2);
}
}