use crate::error::Result;
use crate::runtime::Runtime;
use crate::tensor::{Tensor, TensorId};
use std::collections::HashMap;
/// Keyed collection of gradient tensors, holding at most one entry per [`TensorId`].
///
/// The store owns the tensors placed in it. All access goes through the
/// inherent methods; the backing map is private.
pub struct GradStore<R>
where
    R: Runtime,
{
    /// Backing map from tensor id to the gradient currently held for that id.
    grads: HashMap<TensorId, Tensor<R>>,
}
impl<R: Runtime> GradStore<R> {
    /// Creates a store with no entries.
    pub fn new() -> Self {
        Self {
            grads: HashMap::new(),
        }
    }

    /// Returns a reference to the gradient stored under `id`, or `None` if
    /// the store has no entry for it.
    pub fn get(&self, id: TensorId) -> Option<&Tensor<R>> {
        self.grads.get(&id)
    }

    /// Stores `grad` under `id`.
    ///
    /// Any gradient previously stored under `id` is replaced and dropped.
    pub fn insert(&mut self, id: TensorId, grad: Tensor<R>) {
        self.grads.insert(id, grad);
    }

    /// Returns `true` if the store has an entry for `id`.
    pub fn contains(&self, id: TensorId) -> bool {
        self.grads.contains_key(&id)
    }

    /// Removes the entry for `id` and returns its gradient, or `None` if
    /// there was no such entry.
    pub fn remove(&mut self, id: TensorId) -> Option<Tensor<R>> {
        self.grads.remove(&id)
    }

    /// Iterates over the ids that currently have an entry.
    ///
    /// The order is that of the underlying `HashMap` and is therefore
    /// unspecified.
    pub fn keys(&self) -> impl Iterator<Item = &TensorId> {
        self.grads.keys()
    }

    /// Returns the number of entries in the store.
    pub fn len(&self) -> usize {
        self.grads.len()
    }

    /// Returns `true` if the store has no entries.
    pub fn is_empty(&self) -> bool {
        self.grads.is_empty()
    }

    /// Removes every entry from the store.
    pub fn clear(&mut self) {
        self.grads.clear();
    }

    /// Merges `grad` into the entry for `id`.
    ///
    /// If a gradient is already stored under `id`, it is taken out of the
    /// store and replaced by `add_fn(existing, grad)`. Otherwise `grad` is
    /// stored unchanged and `add_fn` is never called.
    pub fn accumulate<F>(&mut self, id: TensorId, grad: Tensor<R>, add_fn: F)
    where
        F: FnOnce(Tensor<R>, Tensor<R>) -> Tensor<R>,
    {
        // The previous value is moved out of the map because `add_fn`
        // consumes both operands by value.
        let total = match self.grads.remove(&id) {
            Some(previous) => add_fn(previous, grad),
            None => grad,
        };
        self.grads.insert(id, total);
    }

    /// Fallible counterpart of [`GradStore::accumulate`].
    ///
    /// If a gradient is already stored under `id`, it is taken out of the
    /// store and replaced by the `Ok` value of `add_fn(existing, grad)`.
    /// Otherwise `grad` is stored unchanged and `add_fn` is never called.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `add_fn`. In that case the previous
    /// gradient has already been removed and handed to `add_fn`, so the
    /// store is left without an entry for `id`.
    pub fn try_accumulate<F>(&mut self, id: TensorId, grad: Tensor<R>, add_fn: F) -> Result<()>
    where
        F: FnOnce(Tensor<R>, Tensor<R>) -> Result<Tensor<R>>,
    {
        // `?` returns before the insert below, which is why a failed merge
        // leaves no entry behind.
        let total = match self.grads.remove(&id) {
            Some(previous) => add_fn(previous, grad)?,
            None => grad,
        };
        self.grads.insert(id, total);
        Ok(())
    }

    /// Stores `grad` under `id`, overwriting any existing gradient.
    ///
    /// Despite its name, this method takes no combining function and performs
    /// no accumulation: it behaves exactly like [`GradStore::insert`].
    pub fn accumulate_or_insert(&mut self, id: TensorId, grad: Tensor<R>) {
        self.insert(id, grad);
    }
}
impl<R: Runtime> Default for GradStore<R> {
fn default() -> Self {
Self::new()
}
}