use std::collections::HashMap;
use candle_core::{DType, Device, Tensor};
use crate::config::TernaryConfig;
use crate::error::Result;
use crate::ternary::{
calculate_memory_savings, ternary_quantize_deterministic, ternary_quantize_stochastic,
};
/// Running totals kept for a single named parameter between resets.
#[derive(Debug, Clone)]
struct AccumulatedGradient {
/// Element-wise sum of every quantized gradient added for this parameter since the last reset.
ternary: Tensor,
/// Sum of the scale values returned alongside each quantized gradient.
scale_sum: f32,
/// Shape the parameter was registered with; only used to count elements for `memory_savings`.
shape: Vec<usize>,
}
/// Accumulates ternary-quantized gradients per named parameter across several steps.
pub struct TernaryGradientAccumulator {
/// Quantization and accumulation settings (threshold, rounding mode, step target).
config: TernaryConfig,
/// Clone of the device handed to `new`; the accumulator buffers are created on it.
device: Device,
/// Per-parameter running sums, keyed by parameter name.
accumulators: HashMap<String, AccumulatedGradient>,
/// Number of successful `accumulate` calls since construction or the last `reset`.
count: usize,
}
impl TernaryGradientAccumulator {
pub fn new(
param_shapes: &[(String, Vec<usize>)],
config: TernaryConfig,
device: &Device,
) -> Result<Self> {
let mut accumulators = HashMap::new();
for (name, shape) in param_shapes {
let ternary = Tensor::zeros(shape.as_slice(), DType::F32, device)?;
accumulators.insert(
name.clone(),
AccumulatedGradient {
ternary,
scale_sum: 0.0,
shape: shape.clone(),
},
);
}
Ok(Self {
config,
device: device.clone(),
accumulators,
count: 0,
})
}
pub fn accumulate(&mut self, gradients: &HashMap<String, Tensor>) -> Result<()> {
let threshold = Some(self.config.ternary_threshold);
for (name, grad) in gradients {
if let Some(accum) = self.accumulators.get_mut(name) {
let (ternary, scale) = if self.config.use_stochastic_rounding {
ternary_quantize_stochastic(grad, threshold)?
} else {
ternary_quantize_deterministic(grad, threshold)?
};
accum.ternary = accum.ternary.add(&ternary)?;
accum.scale_sum += scale;
}
}
self.count += 1;
Ok(())
}
#[allow(clippy::cast_precision_loss)]
pub fn get_accumulated(&self) -> Result<HashMap<String, Tensor>> {
let mut accumulated = HashMap::new();
for (name, accum) in &self.accumulators {
if self.count > 0 {
let avg_scale = accum.scale_sum / self.count as f32;
let result = (&accum.ternary * avg_scale as f64)?;
let result = (result / self.count as f64)?;
accumulated.insert(name.clone(), result);
} else {
accumulated.insert(name.clone(), accum.ternary.clone());
}
}
Ok(accumulated)
}
pub fn reset(&mut self) -> Result<()> {
for accum in self.accumulators.values_mut() {
accum.ternary = accum.ternary.zeros_like()?;
accum.scale_sum = 0.0;
}
self.count = 0;
Ok(())
}
#[must_use]
pub const fn count(&self) -> usize {
self.count
}
#[must_use]
pub fn memory_savings(&self) -> f32 {
let param_count: usize = self.accumulators.values().map(|a| a.shape.iter().product::<usize>()).sum();
let num_tensors = self.accumulators.len();
calculate_memory_savings(param_count, num_tensors)
}
#[must_use]
pub fn ready_for_update(&self) -> bool {
self.count >= self.config.accumulation_steps
}
}
/// Couples a `TernaryGradientAccumulator` with step and update counters.
pub struct TernaryOptimizerWrapper {
/// Settings shared with the accumulator; `accumulation_steps` decides when `step` signals an update.
config: TernaryConfig,
/// Holds the quantized gradient sums between updates.
accumulator: TernaryGradientAccumulator,
/// Number of `step` calls counted so far (can be overwritten by `load_state` / `reset_state`).
step_count: usize,
/// Number of `get_gradients_for_update` calls counted so far (can be overwritten likewise).
update_count: usize,
}
impl TernaryOptimizerWrapper {
    /// Creates a wrapper whose accumulator tracks the parameters listed in `param_shapes`.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying `TernaryGradientAccumulator` cannot be built.
    pub fn new(
        param_shapes: &[(String, Vec<usize>)],
        config: TernaryConfig,
        device: &Device,
    ) -> Result<Self> {
        let accumulator = TernaryGradientAccumulator::new(param_shapes, config.clone(), device)?;
        Ok(Self {
            config,
            accumulator,
            step_count: 0,
            update_count: 0,
        })
    }

    /// Accumulates one set of gradients and reports whether an optimizer update is due.
    ///
    /// Returns `Ok(true)` when the step counter is a multiple of `config.accumulation_steps`.
    /// An `accumulation_steps` of zero is treated as one, so every step signals an update.
    ///
    /// # Errors
    ///
    /// Returns an error if accumulating the gradients fails; the step counter is not
    /// incremented in that case.
    pub fn step(&mut self, gradients: &HashMap<String, Tensor>) -> Result<bool> {
        self.accumulator.accumulate(gradients)?;
        self.step_count += 1;

        // `% 0` would panic. Clamping to 1 makes a zero setting mean "update every step",
        // which agrees with `ready_for_update` (`count >= 0` is always true).
        let interval = self.config.accumulation_steps.max(1);
        Ok(self.step_count % interval == 0)
    }

    /// Returns the averaged gradients, clears the accumulator and counts one update.
    ///
    /// # Errors
    ///
    /// Returns an error if reading or resetting the accumulator fails; the update counter is
    /// not incremented in that case.
    pub fn get_gradients_for_update(&mut self) -> Result<HashMap<String, Tensor>> {
        let grads = self.accumulator.get_accumulated()?;
        self.accumulator.reset()?;
        self.update_count += 1;
        Ok(grads)
    }

    /// Snapshot of the counters, the accumulator's memory-savings figure and the configured
    /// accumulation interval.
    #[must_use]
    pub fn get_stats(&self) -> OptimizerStats {
        OptimizerStats {
            step_count: self.step_count,
            update_count: self.update_count,
            memory_savings: self.accumulator.memory_savings(),
            accumulation_steps: self.config.accumulation_steps,
        }
    }

    /// Number of steps counted so far.
    #[must_use]
    pub const fn step_count(&self) -> usize {
        self.step_count
    }

    /// Number of updates counted so far.
    #[must_use]
    pub const fn update_count(&self) -> usize {
        self.update_count
    }

    /// Sets both counters back to zero.
    ///
    /// Only the counters are touched; gradients already held by the accumulator are kept.
    pub fn reset_state(&mut self) {
        self.step_count = 0;
        self.update_count = 0;
    }

    /// Overwrites both counters, e.g. when restoring previously saved values.
    ///
    /// Only the counters are touched; the accumulator's contents are left as they are.
    pub fn load_state(&mut self, step_count: usize, update_count: usize) {
        self.step_count = step_count;
        self.update_count = update_count;
    }
}
/// Snapshot of a `TernaryOptimizerWrapper`'s counters, as produced by `get_stats`.
#[derive(Debug, Clone)]
pub struct OptimizerStats {
/// Value of the wrapper's step counter when the snapshot was taken.
pub step_count: usize,
/// Value of the wrapper's update counter when the snapshot was taken.
pub update_count: usize,
/// Memory-savings figure from the accumulator; `Display` multiplies it by 100 to print a percentage.
pub memory_savings: f32,
/// Configured number of accumulation steps per update.
pub accumulation_steps: usize,
}
impl std::fmt::Display for OptimizerStats {
    /// Writes a one-line summary with the step count, update count and memory savings as a
    /// percentage with one decimal place. `accumulation_steps` is not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `memory_savings` is treated as a fraction here, so scale it for the `%` display.
        let percent_saved = self.memory_savings * 100.0;
        f.write_fmt(format_args!(
            "Steps: {} | Updates: {} | Memory saved: {:.1}%",
            self.step_count, self.update_count, percent_saved
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Names and shapes of the three parameters every test registers.
    fn create_param_shapes() -> Vec<(String, Vec<usize>)> {
        vec![
            ("layer1.weight", vec![64, 128]),
            ("layer1.bias", vec![64]),
            ("layer2.weight", vec![32, 64]),
        ]
        .into_iter()
        .map(|(name, shape)| (name.to_string(), shape))
        .collect()
    }

    /// One standard-normal gradient per registered parameter, with matching shapes.
    fn create_mock_gradients(device: &Device) -> HashMap<String, Tensor> {
        create_param_shapes()
            .into_iter()
            .map(|(name, shape)| {
                let grad = Tensor::randn(0.0f32, 1.0, shape.as_slice(), device).unwrap();
                (name, grad)
            })
            .collect()
    }

    /// CPU accumulator over the shared parameter shapes.
    fn cpu_accumulator(config: TernaryConfig) -> TernaryGradientAccumulator {
        TernaryGradientAccumulator::new(&create_param_shapes(), config, &Device::Cpu).unwrap()
    }

    #[test]
    fn test_accumulator_creation() {
        let accumulator = cpu_accumulator(TernaryConfig::default());
        assert_eq!(accumulator.count(), 0);
    }

    #[test]
    fn test_accumulator_accumulate() {
        let mut accumulator = cpu_accumulator(TernaryConfig::default());
        let gradients = create_mock_gradients(&Device::Cpu);

        for expected in 1..=2 {
            accumulator.accumulate(&gradients).unwrap();
            assert_eq!(accumulator.count(), expected);
        }
    }

    #[test]
    fn test_accumulator_get_accumulated() {
        let mut accumulator = cpu_accumulator(TernaryConfig::default());
        accumulator
            .accumulate(&create_mock_gradients(&Device::Cpu))
            .unwrap();

        let accumulated = accumulator.get_accumulated().unwrap();
        assert_eq!(accumulated.len(), 3);
        for (name, _shape) in create_param_shapes() {
            assert!(accumulated.contains_key(&name));
        }
    }

    #[test]
    fn test_accumulator_reset() {
        let mut accumulator = cpu_accumulator(TernaryConfig::default());
        accumulator
            .accumulate(&create_mock_gradients(&Device::Cpu))
            .unwrap();
        assert_eq!(accumulator.count(), 1);

        accumulator.reset().unwrap();
        assert_eq!(accumulator.count(), 0);
    }

    #[test]
    fn test_accumulator_memory_savings() {
        let savings = cpu_accumulator(TernaryConfig::default()).memory_savings();
        assert!(savings > 0.9, "Expected >90% savings, got {:.2}%", savings * 100.0);
    }

    #[test]
    fn test_optimizer_wrapper_step() {
        let device = Device::Cpu;
        let config = TernaryConfig::default().with_accumulation_steps(4);
        let mut wrapper =
            TernaryOptimizerWrapper::new(&create_param_shapes(), config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        // Only the fourth step completes an accumulation window.
        for step in 1..=4 {
            let should_update = wrapper.step(&gradients).unwrap();
            assert_eq!(should_update, step == 4);
        }

        let accumulated = wrapper.get_gradients_for_update().unwrap();
        assert_eq!(accumulated.len(), 3);

        // The step after an update starts a new window.
        assert!(!wrapper.step(&gradients).unwrap());
    }

    #[test]
    fn test_optimizer_wrapper_stats() {
        let device = Device::Cpu;
        let config = TernaryConfig::default().with_accumulation_steps(2);
        let mut wrapper =
            TernaryOptimizerWrapper::new(&create_param_shapes(), config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        for _ in 0..2 {
            wrapper.step(&gradients).unwrap();
        }
        let _ = wrapper.get_gradients_for_update().unwrap();

        let stats = wrapper.get_stats();
        assert_eq!(stats.step_count, 2);
        assert_eq!(stats.update_count, 1);
        assert!(stats.memory_savings > 0.9);
    }

    #[test]
    fn test_stochastic_vs_deterministic() {
        let gradients = create_mock_gradients(&Device::Cpu);

        // Both rounding modes must yield one result per registered parameter.
        for &stochastic in &[true, false] {
            let config = TernaryConfig::default().with_stochastic_rounding(stochastic);
            let mut accumulator = cpu_accumulator(config);
            accumulator.accumulate(&gradients).unwrap();
            assert_eq!(accumulator.get_accumulated().unwrap().len(), 3);
        }
    }
}