use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::Arc;

use crate::tensor::TensorId;
use crate::buffer::{Buffer, CpuBuffer};
use crate::errors::{EtensorError, EtensorResult};
pub struct Gradients {
grads: HashMap<TensorId, Buffer>,
}
impl Gradients {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
Self {
grads: HashMap::new(),
}
}
pub fn get(&self, id: &TensorId) -> Option<&Buffer> {
self.grads.get(id)
}
pub fn remove(&mut self, id: &TensorId) -> Option<Buffer> {
self.grads.remove(id)
}
pub fn insert(&mut self, id: TensorId, new_grad: Buffer) -> EtensorResult<()> {
if let Some(existing_grad) = self.grads.get_mut(&id) {
match (existing_grad, &new_grad) {
(Buffer::Cpu(CpuBuffer::F32(existing_arc)), Buffer::Cpu(CpuBuffer::F32(new_arc))) => {
let existing_vec = Arc::make_mut(existing_arc);
if existing_vec.len() != new_arc.len() {
return Err(EtensorError::InternalError(
"Gradient accumulation failed: Buffer lengths mismatch!".to_string()
));
}
for (a, b) in existing_vec.iter_mut().zip(new_arc.iter()) {
*a += b;
}
Ok(())
}
_ => {
Err(EtensorError::InternalError(
"Gradient accumulation for non-F32 or GPU buffers is deferred to Phase 4 Dispatcher.".to_string()
))
}
}
} else {
self.grads.insert(id, new_grad);
Ok(())
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A first insertion stores the buffer unchanged.
    #[test]
    fn test_gradient_insertion_and_retrieval() {
        let mut grads = Gradients::new();
        let id = TensorId::new();
        let grad_buf = Buffer::from_f32_vec(vec![1.5, 2.5]);
        grads.insert(id, grad_buf).unwrap();
        let retrieved = grads.get(&id).unwrap();
        let slice = retrieved.as_f32_slice().unwrap();
        assert_eq!(slice, &[1.5, 2.5]);
    }

    /// A second insertion for the same id sums element-wise into the first.
    #[test]
    fn test_gradient_accumulation_on_collision() {
        let mut grads = Gradients::new();
        let id = TensorId::new();
        let grad1 = Buffer::from_f32_vec(vec![1.0, 2.0]);
        grads.insert(id, grad1).unwrap();
        let grad2 = Buffer::from_f32_vec(vec![3.0, 4.0]);
        grads.insert(id, grad2).unwrap();
        let accumulated = grads.get(&id).unwrap();
        let slice = accumulated.as_f32_slice().unwrap();
        assert_eq!(slice, &[4.0, 6.0], "Gradients failed to sum together upon collision!");
    }

    /// Accumulating buffers of different lengths fails and leaves the stored gradient intact.
    #[test]
    fn test_gradient_length_mismatch_is_rejected() {
        let mut grads = Gradients::new();
        let id = TensorId::new();
        grads.insert(id, Buffer::from_f32_vec(vec![1.0, 2.0])).unwrap();
        let result = grads.insert(id, Buffer::from_f32_vec(vec![1.0]));
        assert!(result.is_err(), "Mismatched gradient lengths must not be accumulated!");
        let slice = grads.get(&id).unwrap().as_f32_slice().unwrap();
        assert_eq!(slice, &[1.0, 2.0]);
    }

    /// `remove` hands back the stored gradient exactly once.
    #[test]
    fn test_gradient_remove() {
        let mut grads = Gradients::new();
        let id = TensorId::new();
        grads.insert(id, Buffer::from_f32_vec(vec![0.5])).unwrap();
        let removed = grads.remove(&id).expect("gradient should be present");
        assert_eq!(removed.as_f32_slice().unwrap(), &[0.5]);
        assert!(grads.get(&id).is_none());
        assert!(grads.remove(&id).is_none());
    }
}