acme_tensor/actions/grad/
store.rs

1/*
2    Appellation: store <mod>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use crate::prelude::TensorId;
6use crate::TensorBase;
7use acme::prelude::Store;
8use core::borrow::{Borrow, BorrowMut};
9use core::ops::{Deref, DerefMut, Index, IndexMut};
10use std::collections::btree_map::{BTreeMap, Entry, Keys, Values};
11
12#[derive(Clone, Debug)]
13pub struct TensorGrad<T> {
14    pub(crate) store: BTreeMap<TensorId, TensorBase<T>>,
15}
16
17impl<T> Default for TensorGrad<T> {
18    fn default() -> Self {
19        Self::new()
20    }
21}
22
23impl<T> TensorGrad<T> {
24    pub fn new() -> Self {
25        Self {
26            store: BTreeMap::new(),
27        }
28    }
29    /// Clears the store, removing all values.
30    pub fn clear(&mut self) {
31        self.store.clear()
32    }
33    /// Returns a reference to the value corresponding to the key.
34    pub fn entry(&mut self, key: TensorId) -> Entry<'_, TensorId, TensorBase<T>> {
35        self.store.entry(key)
36    }
37    /// Returns a reference to the value corresponding to the key.
38    pub fn get_tensor(&self, item: &TensorBase<T>) -> Option<&TensorBase<T>> {
39        self.store.get(&item.id())
40    }
41    /// Inserts a tensor into the store.
42    pub fn insert_tensor(&mut self, tensor: TensorBase<T>) -> Option<TensorBase<T>> {
43        self.insert(tensor.id(), tensor)
44    }
45    /// Returns true if the store contains no elements.
46    pub fn is_empty(&self) -> bool {
47        self.store.is_empty()
48    }
49    /// Returns an iterator over the store's keys
50    pub fn keys(&self) -> Keys<'_, TensorId, TensorBase<T>> {
51        self.store.keys()
52    }
53    /// Returns the number of elements in the store.
54    pub fn len(&self) -> usize {
55        self.store.len()
56    }
57    /// If the store does not have a tensor with the given id, insert it.
58    /// Returns a mutable reference to the tensor.
59    pub fn or_insert(&mut self, tensor: TensorBase<T>) -> &mut TensorBase<T> {
60        self.entry(tensor.id()).or_insert(tensor)
61    }
62    /// If the store does not have a tensor with the given id, insert a tensor with the same shape
63    /// and dtype as the given tensor, where all elements are default.
64    pub fn or_insert_default(&mut self, tensor: &TensorBase<T>) -> &mut TensorBase<T>
65    where
66        T: Clone + Default,
67    {
68        self.entry(tensor.id()).or_insert(tensor.default_like())
69    }
70    /// If the store does not have a tensor with the given id, insert a tensor with the same shape
71    /// and dtype as the given tensor, where all elements are zeros.
72    pub fn or_insert_zeros(&mut self, tensor: &TensorBase<T>) -> &mut TensorBase<T>
73    where
74        T: Clone + num::Zero,
75    {
76        self.entry(tensor.id()).or_insert(tensor.zeros_like())
77    }
78    /// Remove an element from the store.
79    pub fn remove(&mut self, key: &TensorId) -> Option<TensorBase<T>> {
80        self.store.remove(key)
81    }
82    /// Remove a tensor from the store.
83    pub fn remove_tensor(&mut self, tensor: &TensorBase<T>) -> Option<TensorBase<T>> {
84        self.remove(&tensor.id())
85    }
86
87    pub fn values(&self) -> Values<'_, TensorId, TensorBase<T>> {
88        self.store.values()
89    }
90}
91
92impl<T> AsRef<BTreeMap<TensorId, TensorBase<T>>> for TensorGrad<T> {
93    fn as_ref(&self) -> &BTreeMap<TensorId, TensorBase<T>> {
94        &self.store
95    }
96}
97
98impl<T> AsMut<BTreeMap<TensorId, TensorBase<T>>> for TensorGrad<T> {
99    fn as_mut(&mut self) -> &mut BTreeMap<TensorId, TensorBase<T>> {
100        &mut self.store
101    }
102}
103
104impl<T> Borrow<BTreeMap<TensorId, TensorBase<T>>> for TensorGrad<T> {
105    fn borrow(&self) -> &BTreeMap<TensorId, TensorBase<T>> {
106        &self.store
107    }
108}
109
110impl<T> BorrowMut<BTreeMap<TensorId, TensorBase<T>>> for TensorGrad<T> {
111    fn borrow_mut(&mut self) -> &mut BTreeMap<TensorId, TensorBase<T>> {
112        &mut self.store
113    }
114}
115
116impl<T> Deref for TensorGrad<T> {
117    type Target = BTreeMap<TensorId, TensorBase<T>>;
118
119    fn deref(&self) -> &Self::Target {
120        &self.store
121    }
122}
123
124impl<T> DerefMut for TensorGrad<T> {
125    fn deref_mut(&mut self) -> &mut Self::Target {
126        &mut self.store
127    }
128}
129
130impl<T> Extend<(TensorId, TensorBase<T>)> for TensorGrad<T> {
131    fn extend<I: IntoIterator<Item = (TensorId, TensorBase<T>)>>(&mut self, iter: I) {
132        self.store.extend(iter)
133    }
134}
135
136impl<T> FromIterator<(TensorId, TensorBase<T>)> for TensorGrad<T> {
137    fn from_iter<I: IntoIterator<Item = (TensorId, TensorBase<T>)>>(iter: I) -> Self {
138        Self {
139            store: BTreeMap::from_iter(iter),
140        }
141    }
142}
143
144impl<T> Index<&TensorId> for TensorGrad<T> {
145    type Output = TensorBase<T>;
146
147    fn index(&self, index: &TensorId) -> &Self::Output {
148        &self.store[index]
149    }
150}
151
152impl<T> IndexMut<&TensorId> for TensorGrad<T> {
153    fn index_mut(&mut self, index: &TensorId) -> &mut Self::Output {
154        self.get_mut(index).expect("Tensor not found")
155    }
156}
157
158impl<T> IntoIterator for TensorGrad<T> {
159    type Item = (TensorId, TensorBase<T>);
160    type IntoIter = std::collections::btree_map::IntoIter<TensorId, TensorBase<T>>;
161
162    fn into_iter(self) -> Self::IntoIter {
163        self.store.into_iter()
164    }
165}
166
167impl<T> Store<TensorId, TensorBase<T>> for TensorGrad<T> {
168    fn get(&self, key: &TensorId) -> Option<&TensorBase<T>> {
169        self.store.get(key)
170    }
171
172    fn get_mut(&mut self, key: &TensorId) -> Option<&mut TensorBase<T>> {
173        self.store.get_mut(key)
174    }
175
176    fn insert(&mut self, key: TensorId, value: TensorBase<T>) -> Option<TensorBase<T>> {
177        self.store.insert(key, value)
178    }
179
180    fn remove(&mut self, key: &TensorId) -> Option<TensorBase<T>> {
181        self.remove(key)
182    }
183}