Skip to main content

lumen_core/grad/
store.rs

1use std::{collections::HashMap, ops::Index};
2use crate::{FloatDType, Tensor, TensorId};
3
4#[derive(Debug, Clone)]
5pub struct GradStore<T: FloatDType>(HashMap<TensorId, Tensor<T>>);
6
7impl<T: FloatDType> GradStore<T> {
8    /// Create a new gradient store
9    pub fn new() -> Self {
10        GradStore(HashMap::new())
11    }
12
13    /// Get the gradient tensor associated with the given tensor
14    pub fn get(&self, tensor: &Tensor<T>) -> Option<&Tensor<T>> {
15        self.0.get(&tensor.id())
16    }
17
18    pub fn get_by_index(&self, index: usize) -> Option<&Tensor<T>> {
19        self.0.get(&index)
20    }
21
22    /// Remove the gradient tensor associated with the given tensor, returning it if it exists
23    pub fn remove(&mut self, tensor: &Tensor<T>) -> Option<Tensor<T>> {
24        self.0.remove(&tensor.id())
25    }
26
27    /// Insert a gradient tensor associated with the given tensor, returning the previous gradient tensor if it existed
28    pub fn insert(&mut self, tensor: &Tensor<T>, grad: Tensor<T>) -> Option<Tensor<T>> {
29        self.0.insert(tensor.id(), grad)
30    }
31
32    /// Get the gradient tensor associated with the given tensor, or, if it does not exist,
33    /// insert a tensor of zeroes, with the same shape and type as the given tensors and return it
34    pub fn or_insert(&mut self, tensor: &Tensor<T>) -> crate::Result<&mut Tensor<T>> {
35        use std::collections::hash_map::Entry;
36        let grad = match self.0.entry(tensor.id()) {
37            Entry::Occupied(entry) => entry.into_mut(),
38            Entry::Vacant(entry) => {
39                let grad = tensor.zeros_like()?;
40                entry.insert(grad)
41            }
42        };
43        Ok(grad)
44    }
45
46    /// Get the tensor ids of the stored gradient tensors
47    pub fn get_ids(&self) -> impl Iterator<Item = &TensorId> {
48        self.0.keys()
49    }
50
51    pub fn tensors(&self) -> impl Iterator<Item = &Tensor<T>> {
52        self.0.values()
53    }
54
55    pub fn iter(&self) -> std::collections::hash_map::Iter<'_, TensorId, Tensor<T>> {
56        self.0.iter()
57    }
58
59    pub fn len(&self) -> usize {
60        self.0.len()
61    }
62}
63
64impl<T: FloatDType> Index<&Tensor<T>> for GradStore<T> {
65    type Output = Tensor<T>;
66    fn index(&self, index: &Tensor<T>) -> &Self::Output {
67        self.get(index).unwrap()
68    }
69}