burn_autodiff/runtime/
memory_management.rs1use crate::{
2 NodeId,
3 collections::{HashMap, HashSet},
4 graph::Parent,
5 tensor::NodeRefCount,
6};
7use alloc::{borrow::ToOwned, sync::Arc, vec, vec::Vec};
8use core::mem;
9
10#[derive(Default, Debug)]
11pub struct GraphMemoryManagement {
12 nodes: HashMap<NodeRefCount, Vec<NodeId>>,
13 leaves: HashSet<NodeId>,
14 statuses: HashMap<NodeId, NodeMemoryStatus>,
15}
16
/// Classification given to a node while deciding what can be freed.
///
/// `Eq` is derived alongside `PartialEq` because equality on this field-less
/// enum is total.
#[derive(Debug, Clone, PartialEq, Eq)]
enum NodeMemoryStatus {
    /// The node must be kept: it is still referenced, or it is an ancestor
    /// of a node that was tagged useful.
    Useful,
    /// The node is absent from the node map, or at least one of its
    /// ancestors is.
    Unavailable,
    /// The node and all of its ancestors are registered, but nothing has
    /// marked it useful.
    Unknown,
}
23
24impl GraphMemoryManagement {
25 pub fn extend(&mut self, other: Self) {
26 self.nodes.extend(other.nodes);
27 self.leaves.extend(other.leaves);
28 self.statuses.extend(other.statuses);
29 }
30
31 pub fn register(&mut self, node: NodeRefCount, parents: &[Parent]) {
33 let node_id = *node.as_ref();
34
35 for parent in parents.iter() {
36 self.leaves.remove(&parent.id);
37 }
38
39 self.leaves.insert(node_id);
40 self.nodes
41 .insert(node, parents.iter().map(|p| p.id).collect());
42 }
43
44 pub fn consume_node(&mut self, node_id: NodeId) {
46 if !self.is_referenced(node_id) {
47 self.leaves.remove(&node_id);
48 self.nodes.remove(&node_id);
49 }
50 }
51
52 pub(crate) fn free_unavailable_nodes(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {
57 let leaves = self.leaves.clone();
58 let mut new_leaves = HashSet::new();
59 let mut deletables = Vec::new();
60
61 for leaf in leaves.clone() {
65 self.unavailable_propagation(leaf);
66 }
67
68 self.useful_propagation(leaves.clone());
73
74 for leaf in leaves {
77 self.identify_leaves_and_deletables(leaf, &mut new_leaves, &mut deletables);
78 }
79
80 mem::swap(&mut self.leaves, &mut new_leaves);
82
83 self.clear_unused_roots(&mut deletables);
84
85 self.statuses.clear();
86 for node_to_delete in deletables {
87 self.nodes.remove(&node_to_delete);
88 on_free_graph(&node_to_delete)
89 }
90 }
91
92 pub(crate) fn free_unused_roots(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {
93 let mut deletables = Vec::new();
94 self.clear_unused_roots(&mut deletables);
95
96 for node_id in deletables {
97 self.nodes.remove(&node_id);
98 on_free_graph(&node_id);
99 }
100 }
101
102 fn clear_unused_roots(&self, to_delete: &mut Vec<NodeId>) {
103 for (id, parents) in self.nodes.iter() {
104 let is_useful = matches!(
105 self.statuses.get(id.as_ref()),
106 Some(NodeMemoryStatus::Useful)
107 );
108
109 let parents_absent = parents.iter().all(|p| !self.nodes.contains_key(p));
111
112 if !is_useful && Arc::strong_count(id) == 1 && parents_absent {
113 to_delete.push(*id.as_ref())
114 }
115 }
116 }
117
118 fn unavailable_propagation(&mut self, node_id: NodeId) -> NodeMemoryStatus {
119 if let Some(status) = self.statuses.get(&node_id) {
121 return status.clone();
122 }
123
124 match self.nodes.get(&node_id).cloned() {
125 Some(parents) => {
129 let mut node_status = NodeMemoryStatus::Unknown;
130 for parent in parents {
131 let parent_status = self.unavailable_propagation(parent);
132 if let NodeMemoryStatus::Unavailable = parent_status {
133 node_status = NodeMemoryStatus::Unavailable;
134 }
135 }
136 self.statuses.insert(node_id, node_status.clone());
137 node_status
138 }
139 None => {
142 self.statuses.insert(node_id, NodeMemoryStatus::Unavailable);
143 NodeMemoryStatus::Unavailable
144 }
145 }
146 }
147
148 fn useful_propagation(&mut self, leaves: HashSet<NodeId>) {
149 let mut explored = HashSet::new();
151 let mut tagged_useful = HashSet::new();
152
153 let mut to_tag_useful = PopNodeSet::default();
155 let mut to_explore = PopNodeSet::new(leaves);
156
157 let parents = |node_id| {
159 self.nodes
160 .get(&node_id)
161 .cloned()
162 .unwrap_or_default()
163 .into_iter()
164 };
165
166 loop {
167 let (node_id, status) = match to_tag_useful.pop() {
169 Some(node_id) => (node_id, NodeMemoryStatus::Useful),
170 None => match to_explore.pop() {
171 Some(node_id) => {
172 let node_status = self
173 .statuses
174 .get(&node_id)
175 .expect("All nodes should have received a status during unavailable_propagation")
176 .to_owned();
177
178 if let NodeMemoryStatus::Unknown = node_status {
179 match self.is_referenced(node_id) {
180 true => (node_id, NodeMemoryStatus::Useful),
181 false => (node_id, NodeMemoryStatus::Unknown),
182 }
183 } else {
184 (node_id, node_status)
185 }
186 }
187 None => {
188 break;
190 }
191 },
192 };
193
194 match status {
195 NodeMemoryStatus::Useful => {
196 tagged_useful.insert(node_id);
197 for parent in parents(node_id) {
198 if !(tagged_useful.contains(&parent) || to_tag_useful.contains(&parent)) {
200 to_tag_useful.insert(parent);
201 }
202 }
203 }
204 _ => {
205 explored.insert(node_id);
206 for parent in parents(node_id) {
207 if !(explored.contains(&parent) || to_explore.contains(&parent)) {
208 to_explore.insert(parent);
209 }
210 }
211 }
212 }
213
214 self.statuses.insert(node_id, status);
215 }
216 }
217
218 fn identify_leaves_and_deletables(
219 &self,
220 leaf_id: NodeId,
221 new_leaves: &mut HashSet<NodeId>,
222 to_delete: &mut Vec<NodeId>,
223 ) {
224 let mut visited = HashSet::new();
225 let mut to_visit = vec![leaf_id];
226
227 while let Some(node_id) = to_visit.pop() {
228 visited.insert(node_id);
229
230 match self
231 .statuses
232 .get(&node_id)
233 .expect("Node should have status")
234 {
235 NodeMemoryStatus::Useful => {
236 new_leaves.insert(node_id);
237 }
238 _ => {
239 to_delete.push(node_id);
240
241 for parent in self
242 .nodes
243 .get(&node_id)
244 .cloned()
245 .unwrap_or_default()
246 .into_iter()
247 {
248 if !visited.contains(&parent) {
249 to_visit.push(parent);
250 }
251 }
252 }
253 };
254 }
255 }
256
257 fn is_referenced(&self, node_id: NodeId) -> bool {
258 match self.nodes.get_key_value(&node_id) {
259 Some((key, _value)) => Arc::strong_count(key) > 1,
260 None => panic!("Node should be in the nodes map"),
261 }
262 }
263
264 pub(crate) fn maybe_useful(&self) -> bool {
265 self.nodes.keys().any(|node| Arc::strong_count(node) > 1)
266 }
267}
268
269#[derive(new, Default)]
271struct PopNodeSet {
272 hash_set: HashSet<NodeId>,
273}
274
275impl PopNodeSet {
276 #[inline(always)]
277 fn pop(&mut self) -> Option<NodeId> {
278 self.hash_set
279 .iter()
280 .next()
281 .copied()
282 .and_then(|node_id| self.hash_set.take(&node_id))
283 }
284
285 #[inline(always)]
286 fn contains(&self, node_id: &NodeId) -> bool {
287 self.hash_set.contains(node_id)
288 }
289
290 #[inline(always)]
291 fn insert(&mut self, node_id: NodeId) {
292 self.hash_set.insert(node_id);
293 }
294}