use super::Var;
use crate::error::Result;
use crate::runtime::Runtime;
use crate::tensor::{Tensor, TensorId};
use std::collections::HashMap;
/// Gradient store whose entries are kept as [`Var`]s rather than plain
/// [`Tensor`]s, keyed by [`TensorId`].
///
/// NOTE(review): presumably the key is the id of the tensor each gradient was
/// computed for, and `Var` values are kept so gradients can themselves take part
/// in further differentiation — confirm against callers. `to_grad_store` turns
/// this into a tensor-only `GradStore`.
pub struct VarGradStore<R: Runtime> {
    /// One gradient per tensor id; at most one entry per id.
    grads: HashMap<TensorId, Var<R>>,
}
impl<R: Runtime> VarGradStore<R> {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self {
            grads: HashMap::new(),
        }
    }

    /// Returns the gradient stored for `id` as a [`Var`], or `None` if there is none.
    pub fn get_var(&self, id: TensorId) -> Option<&Var<R>> {
        self.grads.get(&id)
    }

    /// Returns the underlying [`Tensor`] of the gradient stored for `id`, or
    /// `None` if there is none.
    ///
    /// Convenience over [`Self::get_var`] when only the tensor value is needed.
    pub fn get(&self, id: TensorId) -> Option<&Tensor<R>> {
        self.get_var(id).map(|v| v.tensor())
    }

    /// Stores `grad` as the gradient for `id`.
    ///
    /// Any gradient previously stored under `id` is replaced and dropped; use
    /// [`Self::try_accumulate`] to combine with it instead.
    pub fn insert(&mut self, id: TensorId, grad: Var<R>) {
        self.grads.insert(id, grad);
    }

    /// Returns `true` if a gradient is stored for `id`.
    pub fn contains(&self, id: TensorId) -> bool {
        self.grads.contains_key(&id)
    }

    /// Removes and returns the gradient stored for `id`, if any.
    pub fn remove(&mut self, id: TensorId) -> Option<Var<R>> {
        self.grads.remove(&id)
    }

    /// Iterates over the ids that currently have a gradient, in arbitrary order.
    pub fn keys(&self) -> impl Iterator<Item = &TensorId> {
        self.grads.keys()
    }

    /// Number of stored gradients.
    pub fn len(&self) -> usize {
        self.grads.len()
    }

    /// Returns `true` if no gradients are stored.
    pub fn is_empty(&self) -> bool {
        self.grads.is_empty()
    }

    /// Removes every stored gradient.
    pub fn clear(&mut self) {
        self.grads.clear();
    }

    /// Accumulates `grad` into the gradient stored for `id`.
    ///
    /// If nothing is stored for `id` yet, `grad` is inserted as-is and `add_fn`
    /// is not called. Otherwise the stored gradient is taken out of the store and
    /// `add_fn(existing, grad)` is called; its result becomes the new gradient
    /// for `id`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `add_fn`. In that case the previously
    /// stored gradient has already been moved into `add_fn` and is **not**
    /// restored, so after an error the store holds no entry for `id`.
    pub fn try_accumulate<F>(&mut self, id: TensorId, grad: Var<R>, add_fn: F) -> Result<()>
    where
        F: FnOnce(Var<R>, Var<R>) -> Result<Var<R>>,
    {
        use std::collections::hash_map::Entry;

        match self.grads.entry(id) {
            Entry::Occupied(slot) => {
                // `add_fn` takes the existing gradient by value, so it has to be
                // moved out of the map and the combined result re-inserted.
                let existing = slot.remove();
                let accumulated = add_fn(existing, grad)?;
                self.grads.insert(id, accumulated);
            }
            Entry::Vacant(slot) => {
                slot.insert(grad);
            }
        }
        Ok(())
    }

    /// Consumes the store and converts it into a [`super::GradStore`] that holds
    /// only the tensors: each entry's tensor is cloned out of its `Var`.
    pub fn into_grad_store(self) -> super::GradStore<R> {
        let mut store = super::GradStore::new();
        for (id, var) in self.grads {
            store.insert(id, var.tensor().clone());
        }
        store
    }

    /// Consumes the store and converts it into a [`super::GradStore`].
    ///
    /// Same as [`Self::into_grad_store`], which is the idiomatic name for a
    /// conversion that takes `self` by value. This method is kept so existing
    /// callers keep working.
    pub fn to_grad_store(self) -> super::GradStore<R> {
        self.into_grad_store()
    }

    /// Iterates over `(id, gradient)` pairs by reference, in arbitrary order.
    pub fn iter(&self) -> impl Iterator<Item = (&TensorId, &Var<R>)> {
        self.grads.iter()
    }
}
impl<R: Runtime> Default for VarGradStore<R> {
fn default() -> Self {
Self::new()
}
}
impl<R: Runtime> IntoIterator for VarGradStore<R> {
    type Item = (TensorId, Var<R>);
    type IntoIter = std::collections::hash_map::IntoIter<TensorId, Var<R>>;

    /// Consumes the store, yielding owned `(id, gradient)` pairs in arbitrary order.
    fn into_iter(self) -> Self::IntoIter {
        let Self { grads } = self;
        grads.into_iter()
    }
}