1use std::{collections::HashMap, ops::Index};
2use crate::{FloatDType, Tensor, TensorId};
3
4#[derive(Debug, Clone)]
5pub struct GradStore<T: FloatDType>(HashMap<TensorId, Tensor<T>>);
6
7impl<T: FloatDType> GradStore<T> {
8 pub fn new() -> Self {
10 GradStore(HashMap::new())
11 }
12
13 pub fn get(&self, tensor: &Tensor<T>) -> Option<&Tensor<T>> {
15 self.0.get(&tensor.id())
16 }
17
18 pub fn get_by_index(&self, index: usize) -> Option<&Tensor<T>> {
19 self.0.get(&index)
20 }
21
22 pub fn remove(&mut self, tensor: &Tensor<T>) -> Option<Tensor<T>> {
24 self.0.remove(&tensor.id())
25 }
26
27 pub fn insert(&mut self, tensor: &Tensor<T>, grad: Tensor<T>) -> Option<Tensor<T>> {
29 self.0.insert(tensor.id(), grad)
30 }
31
32 pub fn or_insert(&mut self, tensor: &Tensor<T>) -> crate::Result<&mut Tensor<T>> {
35 use std::collections::hash_map::Entry;
36 let grad = match self.0.entry(tensor.id()) {
37 Entry::Occupied(entry) => entry.into_mut(),
38 Entry::Vacant(entry) => {
39 let grad = tensor.zeros_like()?;
40 entry.insert(grad)
41 }
42 };
43 Ok(grad)
44 }
45
46 pub fn get_ids(&self) -> impl Iterator<Item = &TensorId> {
48 self.0.keys()
49 }
50
51 pub fn tensors(&self) -> impl Iterator<Item = &Tensor<T>> {
52 self.0.values()
53 }
54
55 pub fn iter(&self) -> std::collections::hash_map::Iter<'_, TensorId, Tensor<T>> {
56 self.0.iter()
57 }
58
59 pub fn len(&self) -> usize {
60 self.0.len()
61 }
62}
63
64impl<T: FloatDType> Index<&Tensor<T>> for GradStore<T> {
65 type Output = Tensor<T>;
66 fn index(&self, index: &Tensor<T>) -> &Self::Output {
67 self.get(index).unwrap()
68 }
69}