use crate::common::{OptimizerState, StateMemoryStats};
use crate::traits::StatefulOptimizer;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;
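/// Configuration for the MicroAdam optimizer, an Adam variant that stores
/// its moment estimates in compressed form. `compression_ratio` is the
/// target fraction of elements retained per tensor, and
/// `max_compression_error` bounds the relative L2 error tolerated when a
/// gradient is compressed.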
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MicroAdamConfig {
pub learning_rate: f32,
pub beta1: f32,
pub beta2: f32,
pub epsilon: f32,
pub weight_decay: f32,
pub compression_ratio: f32,
pub min_block_size: usize,
pub adaptive_compression: bool,
pub compression_threshold: f32,
pub bias_correction: bool,
pub max_compression_error: f32,
}
impl Default for MicroAdamConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
weight_decay: 0.01,
compression_ratio: 0.1,
min_block_size: 64,
adaptive_compression: true,
compression_threshold: 1e-6,
bias_correction: true,
max_compression_error: 1e-4,
}
}
}
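/// Sparse, scaled snapshot of a dense vector: element `compressed_data[i]`
/// belongs at position `indices[i]`, and every stored value is multiplied
/// by `scale_factor` on decompression to recover its original magnitude.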
#[derive(Debug, Clone)]
struct CompressedGradient {
compressed_data: Vec<f32>,
indices: Vec<usize>,
scale_factor: f32,
original_size: usize,
compression_type: CompressionType,
}
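/// Strategy for selecting which elements to retain: global top-k,
/// magnitude thresholding, per-block top-k, or an adaptive choice among
/// them.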
#[derive(Debug, Clone, Copy)]
enum CompressionType {
TopK,
Threshold,
BlockWise,
#[allow(dead_code)]
Adaptive,
}
impl CompressedGradient {
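    /// Compresses `gradient` down to roughly `compression_ratio` of its
    /// elements (but never fewer than `min_block_size`, capped at the
    /// gradient length), picking the strategy adaptively when enabled.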
fn compress(gradient: &[f32], config: &MicroAdamConfig) -> Self {
let original_size = gradient.len();
let target_size = (original_size as f32 * config.compression_ratio) as usize;
let target_size = target_size.max(config.min_block_size.min(original_size));
let compression_type = if config.adaptive_compression {
Self::choose_adaptive_compression(gradient, config)
} else {
CompressionType::TopK
};
match compression_type {
CompressionType::TopK => Self::compress_topk(gradient, target_size),
CompressionType::Threshold => Self::compress_threshold(gradient, config),
CompressionType::BlockWise => Self::compress_blockwise(gradient, config),
CompressionType::Adaptive => Self::compress_adaptive(gradient, config),
}
}
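    /// Heuristic strategy selection: highly sparse gradients (more than
    /// 80% of entries below the threshold) use thresholding, dense
    /// gradients with large mean magnitude use block-wise selection, and
    /// everything else falls back to global top-k.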
fn choose_adaptive_compression(gradient: &[f32], config: &MicroAdamConfig) -> CompressionType {
let mean_abs = gradient.iter().map(|x| x.abs()).sum::<f32>() / gradient.len() as f32;
let sparsity = gradient.iter().filter(|&&x| x.abs() < config.compression_threshold).count()
as f32
/ gradient.len() as f32;
if sparsity > 0.8 {
CompressionType::Threshold
} else if mean_abs > 1e-3 {
CompressionType::BlockWise
} else {
CompressionType::TopK
}
}
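    /// Keeps the `k` largest-magnitude entries. Retained values are stored
    /// normalized by their maximum absolute value; `scale_factor` undoes
    /// the normalization on decompression.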
fn compress_topk(gradient: &[f32], k: usize) -> Self {
let mut indexed_values: Vec<(usize, f32)> =
gradient.iter().enumerate().map(|(i, &val)| (i, val.abs())).collect();
indexed_values.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let k = k.min(indexed_values.len());
let indices: Vec<usize> = indexed_values[..k].iter().map(|(i, _)| *i).collect();
let compressed_data: Vec<f32> = indices.iter().map(|&i| gradient[i]).collect();
let max_val = compressed_data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
let scale_factor = if max_val > 0.0 { 1.0 / max_val } else { 1.0 };
Self {
compressed_data: compressed_data.iter().map(|x| x * scale_factor).collect(),
indices,
scale_factor: 1.0 / scale_factor,
original_size: gradient.len(),
compression_type: CompressionType::TopK,
}
}
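    /// Keeps every entry whose magnitude is at least
    /// `config.compression_threshold`.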
fn compress_threshold(gradient: &[f32], config: &MicroAdamConfig) -> Self {
let threshold = config.compression_threshold;
let mut indices = Vec::new();
let mut compressed_data = Vec::new();
for (i, &val) in gradient.iter().enumerate() {
if val.abs() >= threshold {
indices.push(i);
compressed_data.push(val);
}
}
let max_val = compressed_data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
let scale_factor = if max_val > 0.0 { 1.0 / max_val } else { 1.0 };
Self {
compressed_data: compressed_data.iter().map(|x| x * scale_factor).collect(),
indices,
scale_factor: 1.0 / scale_factor,
original_size: gradient.len(),
compression_type: CompressionType::Threshold,
}
}
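    /// Splits the gradient into blocks of `min_block_size` and keeps the
    /// top entries within each block, so retained values stay spread
    /// across the whole tensor.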
fn compress_blockwise(gradient: &[f32], config: &MicroAdamConfig) -> Self {
let block_size = config.min_block_size;
let num_blocks = gradient.len().div_ceil(block_size);
let target_elements_per_block =
((block_size as f32 * config.compression_ratio) as usize).max(1);
let mut indices = Vec::new();
let mut compressed_data = Vec::new();
for block_idx in 0..num_blocks {
let start = block_idx * block_size;
let end = (start + block_size).min(gradient.len());
let block = &gradient[start..end];
let mut block_indexed: Vec<(usize, f32)> =
block.iter().enumerate().map(|(i, &val)| (start + i, val.abs())).collect();
block_indexed
.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let k = target_elements_per_block.min(block_indexed.len());
for i in 0..k {
let global_idx = block_indexed[i].0;
indices.push(global_idx);
compressed_data.push(gradient[global_idx]);
}
}
let max_val = compressed_data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
let scale_factor = if max_val > 0.0 { 1.0 / max_val } else { 1.0 };
Self {
compressed_data: compressed_data.iter().map(|x| x * scale_factor).collect(),
indices,
scale_factor: 1.0 / scale_factor,
original_size: gradient.len(),
compression_type: CompressionType::BlockWise,
}
}
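    /// Runs all three strategies and returns the sparsest result that
    /// still meets the target compression ratio, with top-k as the
    /// fallback.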
fn compress_adaptive(gradient: &[f32], config: &MicroAdamConfig) -> Self {
let topk = Self::compress_topk(
gradient,
(gradient.len() as f32 * config.compression_ratio) as usize,
);
let threshold = Self::compress_threshold(gradient, config);
let blockwise = Self::compress_blockwise(gradient, config);
let topk_ratio = topk.compressed_data.len() as f32 / gradient.len() as f32;
let threshold_ratio = threshold.compressed_data.len() as f32 / gradient.len() as f32;
let blockwise_ratio = blockwise.compressed_data.len() as f32 / gradient.len() as f32;
if threshold_ratio <= config.compression_ratio && threshold_ratio < topk_ratio {
threshold
} else if blockwise_ratio <= config.compression_ratio && blockwise_ratio < topk_ratio {
blockwise
} else {
topk
}
}
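    /// Reconstructs a dense vector: retained entries are rescaled back
    /// into place and all dropped entries become zero.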
fn decompress(&self) -> Vec<f32> {
let mut result = vec![0.0; self.original_size];
for (i, &idx) in self.indices.iter().enumerate() {
if idx < self.original_size && i < self.compressed_data.len() {
result[idx] = self.compressed_data[i] * self.scale_factor;
}
}
result
}
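    /// Fraction of elements retained, relative to the original length.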
fn compression_ratio(&self) -> f32 {
self.compressed_data.len() as f32 / self.original_size as f32
}
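    /// Relative L2 reconstruction error: ||g - decompress(g)|| / ||g||.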
fn compression_error(&self, original: &[f32]) -> f32 {
let decompressed = self.decompress();
let mut error_sum = 0.0;
let mut norm_sum = 0.0;
for (orig, decomp) in original.iter().zip(decompressed.iter()) {
error_sum += (orig - decomp).powi(2);
norm_sum += orig.powi(2);
}
if norm_sum > 0.0 {
(error_sum / norm_sum).sqrt()
} else {
0.0
}
}
}
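/// Memory-efficient Adam variant that keeps per-parameter momentum and
/// variance in compressed (sparse) form, keyed by each parameter tensor's
/// address.
///
/// Minimal usage sketch (mirrors the tests below; marked `ignore` since it
/// depends on the surrounding crate):
///
/// ```ignore
/// let mut optimizer = MicroAdam::new_with_lr(1e-3);
/// let gradient = Tensor::new(vec![0.1, -0.05, 0.2])?;
/// let mut parameter = Tensor::new(vec![1.0, 1.0, 1.0])?;
/// optimizer.update(&mut parameter, &gradient)?;
/// ```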
#[derive(Debug)]
pub struct MicroAdam {
config: MicroAdamConfig,
state: OptimizerState,
momentum: HashMap<String, CompressedGradient>,
variance: HashMap<String, CompressedGradient>,
compression_stats: CompressionStats,
}
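/// Running aggregates over every compression performed by this optimizer
/// instance; counts accumulate across update steps, and averages are
/// weighted by tensor size.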
#[derive(Debug, Default)]
struct CompressionStats {
total_parameters: usize,
total_compressed_size: usize,
average_compression_ratio: f32,
average_compression_error: f32,
compression_method_usage: HashMap<String, usize>,
}
impl MicroAdam {
pub fn new() -> Self {
Self::with_config(MicroAdamConfig::default())
}
pub fn new_with_lr(learning_rate: f32) -> Self {
let config = MicroAdamConfig {
learning_rate,
..Default::default()
};
Self::with_config(config)
}
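    /// Preset for very large models: retain about 5% of elements, with
    /// larger blocks and a tighter compression-error budget.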
pub fn for_large_models() -> Self {
let config = MicroAdamConfig {
learning_rate: 1e-4,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
weight_decay: 0.01,
            compression_ratio: 0.05,
            min_block_size: 128,
adaptive_compression: true,
compression_threshold: 1e-7,
bias_correction: true,
max_compression_error: 1e-5,
};
Self::with_config(config)
}
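    /// Preset for tight memory budgets: retain about 2% of elements, with
    /// small blocks.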
pub fn for_memory_constrained() -> Self {
let config = MicroAdamConfig {
learning_rate: 1e-3,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
weight_decay: 0.01,
            compression_ratio: 0.02,
            min_block_size: 32,
adaptive_compression: true,
compression_threshold: 1e-6,
bias_correction: true,
max_compression_error: 1e-4,
};
Self::with_config(config)
}
pub fn with_config(config: MicroAdamConfig) -> Self {
Self {
config,
state: OptimizerState::new(),
momentum: HashMap::new(),
variance: HashMap::new(),
compression_stats: CompressionStats::default(),
}
}
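    /// Fraction of optimizer-state memory saved versus uncompressed Adam,
    /// which would store two dense f32 moments (m and v) per parameter;
    /// hence the `* 2` in the denominator.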
pub fn memory_savings_ratio(&self) -> f32 {
if self.compression_stats.total_parameters > 0 {
1.0 - (self.compression_stats.total_compressed_size as f32
/ (self.compression_stats.total_parameters * 2) as f32)
} else {
0.0
}
}
pub fn compression_statistics(&self) -> String {
format!(
"MicroAdam Compression Stats:\n\
- Total parameters: {}\n\
- Compressed size: {}\n\
- Memory savings: {:.1}%\n\
- Average compression ratio: {:.3}\n\
- Average compression error: {:.2e}",
self.compression_stats.total_parameters,
self.compression_stats.total_compressed_size,
self.memory_savings_ratio() * 100.0,
self.compression_stats.average_compression_ratio,
self.compression_stats.average_compression_error
)
}
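    /// Folds one compression event into the running statistics. Averages
    /// are weighted by element count, so large tensors dominate:
    /// new_avg = (old_avg * (N - n) + x * n) / N, where n is this tensor's
    /// size and N the running total.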
fn update_compression_stats(
&mut self,
_param_id: &str,
compressed: &CompressedGradient,
original_gradient: &[f32],
) {
self.compression_stats.total_parameters += compressed.original_size;
self.compression_stats.total_compressed_size += compressed.compressed_data.len();
let compression_ratio = compressed.compression_ratio();
let compression_error = compressed.compression_error(original_gradient);
let total_params = self.compression_stats.total_parameters as f32;
self.compression_stats.average_compression_ratio =
(self.compression_stats.average_compression_ratio
* (total_params - compressed.original_size as f32)
+ compression_ratio * compressed.original_size as f32)
/ total_params;
self.compression_stats.average_compression_error =
(self.compression_stats.average_compression_error
* (total_params - compressed.original_size as f32)
+ compression_error * compressed.original_size as f32)
/ total_params;
let method_name = format!("{:?}", compressed.compression_type);
*self.compression_stats.compression_method_usage.entry(method_name).or_insert(0) += 1;
}
}
impl Default for MicroAdam {
fn default() -> Self {
Self::new()
}
}
impl Optimizer for MicroAdam {
fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
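        // Optimizer state is keyed by the parameter tensor's address, so
        // each parameter is assumed to live at a stable address across
        // update calls.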
let param_id = format!("{:p}", parameter as *const Tensor);
let grad_data = grad.data()?;
let compressed_gradient = CompressedGradient::compress(&grad_data, &self.config);
let compression_error = compressed_gradient.compression_error(&grad_data);
if compression_error > self.config.max_compression_error {
return Err(TrustformersError::tensor_op_error(
&format!(
"Compression error {} exceeds maximum allowed {}",
compression_error, self.config.max_compression_error
),
"MicroAdam::update",
));
}
        self.update_compression_stats(&param_id, &compressed_gradient, &grad_data);
let momentum = self.momentum.entry(param_id.clone()).or_insert_with(|| {
CompressedGradient::compress(&vec![0.0; grad_data.len()], &self.config)
});
let variance = self.variance.entry(param_id.clone()).or_insert_with(|| {
CompressedGradient::compress(&vec![0.0; grad_data.len()], &self.config)
});
let mut m = momentum.decompress();
let mut v = variance.decompress();
m.resize(grad_data.len(), 0.0);
v.resize(grad_data.len(), 0.0);
self.state.step();
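        // Bias correction counteracts the zero initialization of the
        // moments: divide by (1 - beta^t) at step t.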
let bias_correction1 = if self.config.bias_correction {
1.0 - self.config.beta1.powf(self.state.step as f32)
} else {
1.0
};
let bias_correction2 = if self.config.bias_correction {
1.0 - self.config.beta2.powf(self.state.step as f32)
} else {
1.0
};
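        // Exponential moving averages of the gradient (first moment) and
        // its elementwise square (second moment).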
for i in 0..grad_data.len() {
m[i] = self.config.beta1 * m[i] + (1.0 - self.config.beta1) * grad_data[i];
}
for i in 0..grad_data.len() {
v[i] = self.config.beta2 * v[i] + (1.0 - self.config.beta2) * grad_data[i].powi(2);
}
let mut param_data = parameter.data()?;
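        // Apply decoupled (AdamW-style) weight decay, then the Adam step:
        // p -= lr * m_hat / (sqrt(v_hat) + epsilon).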
for i in 0..grad_data.len() {
let m_hat = m[i] / bias_correction1;
let v_hat = v[i] / bias_correction2;
let update_val =
self.config.learning_rate * m_hat / (v_hat.sqrt() + self.config.epsilon);
if self.config.weight_decay > 0.0 {
param_data[i] *= 1.0 - self.config.learning_rate * self.config.weight_decay;
}
param_data[i] -= update_val;
}
*parameter = Tensor::new(param_data)?;
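        // Re-compress the updated moments for storage; entries dropped
        // here are simply lost until they grow large again (there is no
        // error-feedback buffer).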
*momentum = CompressedGradient::compress(&m, &self.config);
*variance = CompressedGradient::compress(&v, &self.config);
Ok(())
}
    fn zero_grad(&mut self) {
        // No-op: this optimizer does not own or zero parameter gradients.
    }
    fn step(&mut self) {
        // No-op: the step counter is advanced inside `update`.
    }
fn get_lr(&self) -> f32 {
self.config.learning_rate
}
fn set_lr(&mut self, lr: f32) {
self.config.learning_rate = lr;
}
}
impl StatefulOptimizer for MicroAdam {
type Config = MicroAdamConfig;
type State = OptimizerState;
fn config(&self) -> &Self::Config {
&self.config
}
fn state(&self) -> &Self::State {
&self.state
}
fn state_mut(&mut self) -> &mut Self::State {
&mut self.state
}
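    /// Exports state densely: compressed moments are decompressed into
    /// plain tensors keyed as `momentum.<id>` / `variance.<id>`, plus a
    /// scalar `step` entry. `load_state_dict` re-compresses on load, so a
    /// round trip may re-select which elements are retained.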
fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
let mut state_dict = HashMap::new();
for (param_id, momentum) in &self.momentum {
let key = format!("momentum.{}", param_id);
let tensor = Tensor::new(momentum.decompress())?;
state_dict.insert(key, tensor);
}
for (param_id, variance) in &self.variance {
let key = format!("variance.{}", param_id);
let tensor = Tensor::new(variance.decompress())?;
state_dict.insert(key, tensor);
}
state_dict.insert(
"step".to_string(),
Tensor::new(vec![self.state.step as f32])?,
);
Ok(state_dict)
}
fn load_state_dict(&mut self, state_dict: HashMap<String, Tensor>) -> Result<()> {
if let Some(step_tensor) = state_dict.get("step") {
let step_data = step_tensor.data()?;
if !step_data.is_empty() {
self.state.step = step_data[0] as usize;
}
}
for (key, tensor) in &state_dict {
if let Some(param_id) = key.strip_prefix("momentum.") {
let values = tensor.data()?;
let compressed = CompressedGradient::compress(&values, &self.config);
self.momentum.insert(param_id.to_string(), compressed);
} else if let Some(param_id) = key.strip_prefix("variance.") {
let values = tensor.data()?;
let compressed = CompressedGradient::compress(&values, &self.config);
self.variance.insert(param_id.to_string(), compressed);
}
}
Ok(())
}
fn memory_usage(&self) -> StateMemoryStats {
let momentum_size: usize = self.momentum.values().map(|m| m.compressed_data.len()).sum();
let variance_size: usize = self.variance.values().map(|v| v.compressed_data.len()).sum();
StateMemoryStats {
momentum_elements: momentum_size,
variance_elements: variance_size,
third_moment_elements: 0,
total_bytes: (momentum_size + variance_size) * std::mem::size_of::<f32>(),
num_parameters: self.momentum.len(),
}
}
fn reset_state(&mut self) {
self.state.clear();
self.momentum.clear();
self.variance.clear();
self.compression_stats = CompressionStats::default();
}
fn num_parameters(&self) -> usize {
self.momentum.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_microadam_creation() {
let optimizer = MicroAdam::new();
assert_eq!(optimizer.config.learning_rate, 1e-3);
assert_eq!(optimizer.config.beta1, 0.9);
assert_eq!(optimizer.config.beta2, 0.999);
}
#[test]
fn test_microadam_with_config() {
let config = MicroAdamConfig {
learning_rate: 2e-3,
compression_ratio: 0.2,
..Default::default()
};
let optimizer = MicroAdam::with_config(config);
assert_eq!(optimizer.config.learning_rate, 2e-3);
assert_eq!(optimizer.config.compression_ratio, 0.2);
}
#[test]
fn test_microadam_for_large_models() {
let optimizer = MicroAdam::for_large_models();
assert_eq!(optimizer.config.learning_rate, 1e-4);
assert_eq!(optimizer.config.compression_ratio, 0.05);
assert_eq!(optimizer.config.min_block_size, 128);
assert!(optimizer.config.adaptive_compression);
}
#[test]
fn test_microadam_for_memory_constrained() {
let optimizer = MicroAdam::for_memory_constrained();
assert_eq!(optimizer.config.compression_ratio, 0.02);
assert_eq!(optimizer.config.min_block_size, 32);
assert!(optimizer.config.adaptive_compression);
}
#[test]
fn test_compressed_gradient_topk() {
let gradient = vec![0.1, 0.05, 0.2, 0.01, 0.15, 0.03];
let compressed = CompressedGradient::compress_topk(&gradient, 3);
assert_eq!(compressed.compressed_data.len(), 3);
assert_eq!(compressed.indices.len(), 3);
assert_eq!(compressed.original_size, 6);
let mut expected_indices = vec![2, 4, 0];
let mut actual_indices = compressed.indices.clone();
expected_indices.sort();
actual_indices.sort();
assert_eq!(actual_indices, expected_indices);
}
#[test]
fn test_compressed_gradient_threshold() {
let gradient = vec![0.1, 0.001, 0.2, 0.0001, 0.15, 0.0003];
let config = MicroAdamConfig {
compression_threshold: 0.05,
..Default::default()
};
let compressed = CompressedGradient::compress_threshold(&gradient, &config);
assert_eq!(compressed.compressed_data.len(), 3);
assert_eq!(compressed.indices.len(), 3);
let mut expected_indices = vec![0, 2, 4];
let mut actual_indices = compressed.indices.clone();
expected_indices.sort();
actual_indices.sort();
assert_eq!(actual_indices, expected_indices);
}
#[test]
fn test_compression_decompress_cycle() {
let gradient = vec![0.1, 0.05, 0.2, 0.01, 0.15, 0.03];
let config = MicroAdamConfig::default();
let compressed = CompressedGradient::compress(&gradient, &config);
let decompressed = compressed.decompress();
assert_eq!(decompressed.len(), gradient.len());
for (i, &original) in gradient.iter().enumerate() {
if original.abs() > 0.08 {
assert!(
decompressed[i].abs() > 0.0,
"Significant value at index {} was lost",
i
);
}
}
}
#[test]
fn test_compression_error_calculation() {
let gradient = vec![0.1, 0.05, 0.2, 0.01, 0.15, 0.03];
let config = MicroAdamConfig::default();
let compressed = CompressedGradient::compress(&gradient, &config);
let error = compressed.compression_error(&gradient);
assert!(error >= 0.0);
        assert!(error <= 1.0);
    }
#[test]
fn test_microadam_update() -> Result<()> {
let mut optimizer = MicroAdam::new();
let gradient_data = vec![0.1, -0.05, 0.2, -0.01];
let gradient = Tensor::new(gradient_data.clone())?;
let mut parameter = Tensor::new(vec![1.0, 1.0, 1.0, 1.0])?;
optimizer.update(&mut parameter, &gradient)?;
assert_eq!(optimizer.state().step, 1);
let param_data = parameter.data()?;
assert_eq!(param_data.len(), gradient_data.len());
assert_ne!(param_data[0], 1.0);
Ok(())
}
#[test]
fn test_microadam_multiple_updates() -> Result<()> {
let mut optimizer = MicroAdam::new();
let gradient_data = vec![0.1, -0.05, 0.2, -0.01];
let gradient = Tensor::new(gradient_data)?;
let mut parameter = Tensor::new(vec![1.0, 1.0, 1.0, 1.0])?;
for i in 1..=5 {
optimizer.update(&mut parameter, &gradient)?;
assert_eq!(optimizer.state().step, i);
}
Ok(())
}
#[test]
fn test_memory_savings_ratio() {
let config = MicroAdamConfig {
            max_compression_error: 1.0,
            ..MicroAdamConfig::default()
};
let mut optimizer = MicroAdam::with_config(config);
assert_eq!(optimizer.memory_savings_ratio(), 0.0);
        let gradient_data = vec![0.1; 1000];
        let gradient = Tensor::new(gradient_data).expect("Failed to create tensor");
let mut parameter = Tensor::new(vec![1.0; 1000]).expect("Failed to create tensor");
optimizer.update(&mut parameter, &gradient).expect("Optimizer update failed");
let savings = optimizer.memory_savings_ratio();
assert!(savings > 0.0, "Should show memory savings");
assert!(savings < 1.0, "Savings ratio should be less than 100%");
}
#[test]
fn test_compression_statistics() {
let config = MicroAdamConfig {
            max_compression_error: 1.0,
            ..MicroAdamConfig::default()
};
let mut optimizer = MicroAdam::with_config(config);
let gradient_data = vec![0.1; 500];
let gradient = Tensor::new(gradient_data).expect("Failed to create tensor");
let mut parameter = Tensor::new(vec![1.0; 500]).expect("Failed to create tensor");
optimizer.update(&mut parameter, &gradient).expect("Optimizer update failed");
let stats = optimizer.compression_statistics();
assert!(stats.contains("MicroAdam Compression Stats"));
assert!(stats.contains("Total parameters: 500"));
assert!(stats.contains("Memory savings"));
assert!(stats.contains("compression ratio"));
}
#[test]
fn test_learning_rate_setter_getter() {
let mut optimizer = MicroAdam::new();
assert_eq!(optimizer.get_lr(), 1e-3);
optimizer.set_lr(2e-3);
assert_eq!(optimizer.get_lr(), 2e-3);
}
#[test]
fn test_state_dict_operations() -> Result<()> {
let mut optimizer = MicroAdam::new();
let gradient_data = vec![0.1, -0.05, 0.2];
let gradient = Tensor::new(gradient_data)?;
let mut param1 = Tensor::new(vec![1.0, 1.0, 1.0])?;
let mut param2 = Tensor::new(vec![2.0, 2.0, 2.0])?;
optimizer.update(&mut param1, &gradient)?;
optimizer.update(&mut param2, &gradient)?;
let state_dict = optimizer.state_dict()?;
assert!(state_dict.contains_key("step"));
let mut new_optimizer = MicroAdam::new();
new_optimizer.load_state_dict(state_dict)?;
assert_eq!(new_optimizer.state().step, optimizer.state().step);
Ok(())
}
#[test]
fn test_memory_usage_tracking() -> Result<()> {
let config = MicroAdamConfig {
            max_compression_error: 1.0,
            ..MicroAdamConfig::default()
};
let mut optimizer = MicroAdam::with_config(config);
let initial_usage = optimizer.memory_usage();
let gradient_data = vec![0.1; 1000];
let gradient = Tensor::new(gradient_data)?;
let mut parameter = Tensor::new(vec![1.0; 1000])?;
optimizer.update(&mut parameter, &gradient)?;
let after_usage = optimizer.memory_usage();
assert!(after_usage.total_bytes > initial_usage.total_bytes);
assert!(after_usage.momentum_elements > 0);
assert!(after_usage.variance_elements > 0);
Ok(())
}
#[test]
fn test_adaptive_compression_selection() {
        let sparse_gradient = vec![0.0; 1000];
        let dense_gradient = vec![0.1; 1000];
let config = MicroAdamConfig {
adaptive_compression: true,
compression_threshold: 1e-6,
..Default::default()
};
let sparse_compression =
CompressedGradient::choose_adaptive_compression(&sparse_gradient, &config);
let dense_compression =
CompressedGradient::choose_adaptive_compression(&dense_gradient, &config);
        // An all-zero gradient is maximally sparse, so the heuristic should
        // choose thresholding; the dense gradient has a large mean
        // magnitude, so block-wise selection should win.
        assert!(matches!(sparse_compression, CompressionType::Threshold));
        assert!(matches!(dense_compression, CompressionType::BlockWise));
}
}