use std::collections::HashMap;
use crate::autograd::AutogradError;
use crate::tensor::{GradId, Tensor};
/// Per-`GradId` storage for gradient tensors produced during a backward pass.
///
/// `Default` is derived so an empty store can be built with
/// `GradientStore::default()` as well as `GradientStore::new()`.
#[derive(Default)]
pub struct GradientStore {
    /// At most one gradient tensor per id.
    grads: HashMap<GradId, Tensor>,
}
impl GradientStore {
    /// Creates an empty store with no gradients recorded.
    pub fn new() -> Self {
        Self {
            grads: HashMap::new(),
        }
    }

    /// Adds `grad` into the gradient recorded for `id`.
    ///
    /// If nothing is stored for `id` yet, `grad` is stored as-is. Otherwise the
    /// stored tensor is replaced by `existing.add(&grad)`. This is presumably an
    /// element-wise sum; confirm against `Tensor::add`.
    ///
    /// # Errors
    ///
    /// Returns `AutogradError::ShapeMismatch` when a gradient is already stored
    /// for `id` and its shape differs from `grad`'s shape. In the error,
    /// `expected` is the stored shape and `found` is the incoming one. The store
    /// is left unchanged in that case.
    pub fn accumulate(&mut self, id: GradId, grad: Tensor) -> Result<(), AutogradError> {
        // One `get_mut` lookup serves both cases. The previous version did a
        // `get` followed by an `insert` that hashed and probed the same key
        // again just to overwrite the value.
        match self.grads.get_mut(&id) {
            Some(existing) => {
                if existing.shape() != grad.shape() {
                    return Err(AutogradError::ShapeMismatch {
                        grad_id: id,
                        expected: existing.shape().to_vec(),
                        found: grad.shape().to_vec(),
                    });
                }
                // The right-hand side is fully evaluated before the slot is
                // overwritten. If `add` panics, the stored gradient is untouched.
                *existing = existing.add(&grad);
            }
            None => {
                self.grads.insert(id, grad);
            }
        }
        Ok(())
    }

    /// Removes the gradient stored for `id` and returns it.
    ///
    /// Returns `None` if nothing was stored for `id`.
    pub fn remove(&mut self, id: GradId) -> Option<Tensor> {
        self.grads.remove(&id)
    }

    /// Returns a reference to the gradient stored for `id`, if any.
    pub fn get(&self, id: GradId) -> Option<&Tensor> {
        self.grads.get(&id)
    }

    /// Returns `true` if no gradients are stored.
    pub fn is_empty(&self) -> bool {
        self.grads.is_empty()
    }

    /// Returns the number of ids that currently have a stored gradient.
    pub fn len(&self) -> usize {
        self.grads.len()
    }

    /// Stores `grad` for `id` unconditionally and discards any previous gradient.
    ///
    /// Unlike [`GradientStore::accumulate`], this performs no shape check.
    pub fn replace(&mut self, id: GradId, grad: Tensor) {
        self.grads.insert(id, grad);
    }
}