acme_tensor/actions/grad/
store.rs1use crate::prelude::TensorId;
6use crate::TensorBase;
7use acme::prelude::Store;
8use core::borrow::{Borrow, BorrowMut};
9use core::ops::{Deref, DerefMut, Index, IndexMut};
10use std::collections::btree_map::{BTreeMap, Entry, Keys, Values};
11
12#[derive(Clone, Debug)]
13pub struct TensorGrad<T> {
14 pub(crate) store: BTreeMap<TensorId, TensorBase<T>>,
15}
16
17impl<T> Default for TensorGrad<T> {
18 fn default() -> Self {
19 Self::new()
20 }
21}
22
23impl<T> TensorGrad<T> {
24 pub fn new() -> Self {
25 Self {
26 store: BTreeMap::new(),
27 }
28 }
29 pub fn clear(&mut self) {
31 self.store.clear()
32 }
33 pub fn entry(&mut self, key: TensorId) -> Entry<'_, TensorId, TensorBase<T>> {
35 self.store.entry(key)
36 }
37 pub fn get_tensor(&self, item: &TensorBase<T>) -> Option<&TensorBase<T>> {
39 self.store.get(&item.id())
40 }
41 pub fn insert_tensor(&mut self, tensor: TensorBase<T>) -> Option<TensorBase<T>> {
43 self.insert(tensor.id(), tensor)
44 }
45 pub fn is_empty(&self) -> bool {
47 self.store.is_empty()
48 }
49 pub fn keys(&self) -> Keys<'_, TensorId, TensorBase<T>> {
51 self.store.keys()
52 }
53 pub fn len(&self) -> usize {
55 self.store.len()
56 }
57 pub fn or_insert(&mut self, tensor: TensorBase<T>) -> &mut TensorBase<T> {
60 self.entry(tensor.id()).or_insert(tensor)
61 }
62 pub fn or_insert_default(&mut self, tensor: &TensorBase<T>) -> &mut TensorBase<T>
65 where
66 T: Clone + Default,
67 {
68 self.entry(tensor.id()).or_insert(tensor.default_like())
69 }
70 pub fn or_insert_zeros(&mut self, tensor: &TensorBase<T>) -> &mut TensorBase<T>
73 where
74 T: Clone + num::Zero,
75 {
76 self.entry(tensor.id()).or_insert(tensor.zeros_like())
77 }
78 pub fn remove(&mut self, key: &TensorId) -> Option<TensorBase<T>> {
80 self.store.remove(key)
81 }
82 pub fn remove_tensor(&mut self, tensor: &TensorBase<T>) -> Option<TensorBase<T>> {
84 self.remove(&tensor.id())
85 }
86
87 pub fn values(&self) -> Values<'_, TensorId, TensorBase<T>> {
88 self.store.values()
89 }
90}
91
92impl<T> AsRef<BTreeMap<TensorId, TensorBase<T>>> for TensorGrad<T> {
93 fn as_ref(&self) -> &BTreeMap<TensorId, TensorBase<T>> {
94 &self.store
95 }
96}
97
98impl<T> AsMut<BTreeMap<TensorId, TensorBase<T>>> for TensorGrad<T> {
99 fn as_mut(&mut self) -> &mut BTreeMap<TensorId, TensorBase<T>> {
100 &mut self.store
101 }
102}
103
104impl<T> Borrow<BTreeMap<TensorId, TensorBase<T>>> for TensorGrad<T> {
105 fn borrow(&self) -> &BTreeMap<TensorId, TensorBase<T>> {
106 &self.store
107 }
108}
109
110impl<T> BorrowMut<BTreeMap<TensorId, TensorBase<T>>> for TensorGrad<T> {
111 fn borrow_mut(&mut self) -> &mut BTreeMap<TensorId, TensorBase<T>> {
112 &mut self.store
113 }
114}
115
116impl<T> Deref for TensorGrad<T> {
117 type Target = BTreeMap<TensorId, TensorBase<T>>;
118
119 fn deref(&self) -> &Self::Target {
120 &self.store
121 }
122}
123
124impl<T> DerefMut for TensorGrad<T> {
125 fn deref_mut(&mut self) -> &mut Self::Target {
126 &mut self.store
127 }
128}
129
130impl<T> Extend<(TensorId, TensorBase<T>)> for TensorGrad<T> {
131 fn extend<I: IntoIterator<Item = (TensorId, TensorBase<T>)>>(&mut self, iter: I) {
132 self.store.extend(iter)
133 }
134}
135
136impl<T> FromIterator<(TensorId, TensorBase<T>)> for TensorGrad<T> {
137 fn from_iter<I: IntoIterator<Item = (TensorId, TensorBase<T>)>>(iter: I) -> Self {
138 Self {
139 store: BTreeMap::from_iter(iter),
140 }
141 }
142}
143
144impl<T> Index<&TensorId> for TensorGrad<T> {
145 type Output = TensorBase<T>;
146
147 fn index(&self, index: &TensorId) -> &Self::Output {
148 &self.store[index]
149 }
150}
151
152impl<T> IndexMut<&TensorId> for TensorGrad<T> {
153 fn index_mut(&mut self, index: &TensorId) -> &mut Self::Output {
154 self.get_mut(index).expect("Tensor not found")
155 }
156}
157
158impl<T> IntoIterator for TensorGrad<T> {
159 type Item = (TensorId, TensorBase<T>);
160 type IntoIter = std::collections::btree_map::IntoIter<TensorId, TensorBase<T>>;
161
162 fn into_iter(self) -> Self::IntoIter {
163 self.store.into_iter()
164 }
165}
166
167impl<T> Store<TensorId, TensorBase<T>> for TensorGrad<T> {
168 fn get(&self, key: &TensorId) -> Option<&TensorBase<T>> {
169 self.store.get(key)
170 }
171
172 fn get_mut(&mut self, key: &TensorId) -> Option<&mut TensorBase<T>> {
173 self.store.get_mut(key)
174 }
175
176 fn insert(&mut self, key: TensorId, value: TensorBase<T>) -> Option<TensorBase<T>> {
177 self.store.insert(key, value)
178 }
179
180 fn remove(&mut self, key: &TensorId) -> Option<TensorBase<T>> {
181 self.remove(key)
182 }
183}