// Source file: burn_autodiff/runtime/memory_management.rs

1use crate::{
2    NodeId,
3    collections::{HashMap, HashSet},
4    graph::Parent,
5    tensor::NodeRefCount,
6};
7use alloc::{borrow::ToOwned, sync::Arc, vec, vec::Vec};
8use core::mem;
9
10#[derive(Default, Debug)]
11pub struct GraphMemoryManagement {
12    nodes: HashMap<NodeRefCount, Vec<NodeId>>,
13    leaves: HashSet<NodeId>,
14    statuses: HashMap<NodeId, NodeMemoryStatus>,
15}
16
/// Status assigned to a node while deciding which parts of the graph can be freed.
#[derive(Clone, Debug, PartialEq)]
enum NodeMemoryStatus {
    /// The node must be kept: it is still referenced, or it is an ancestor of a node
    /// that must be kept.
    Useful,
    /// The node, or one of its ancestors, is no longer registered, so its backward
    /// call is impossible.
    Unavailable,
    /// Neither of the above has been established for this node.
    Unknown,
}
23
24impl GraphMemoryManagement {
25    pub fn extend(&mut self, other: Self) {
26        self.nodes.extend(other.nodes);
27        self.leaves.extend(other.leaves);
28        self.statuses.extend(other.statuses);
29    }
30
31    /// Register a new node with its parent.
32    pub fn register(&mut self, node: NodeRefCount, parents: &[Parent]) {
33        let node_id = *node.as_ref();
34
35        for parent in parents.iter() {
36            self.leaves.remove(&parent.id);
37        }
38
39        self.leaves.insert(node_id);
40        self.nodes
41            .insert(node, parents.iter().map(|p| p.id).collect());
42    }
43
44    /// Free the node from the state.
45    pub fn consume_node(&mut self, node_id: NodeId) {
46        if !self.is_referenced(node_id) {
47            self.leaves.remove(&node_id);
48            self.nodes.remove(&node_id);
49        }
50    }
51
52    /// Free all nodes whose backward call has become impossible
53    ///
54    /// This function goes into three steps, which must happen for all leaves
55    /// before going into the next step. Then it deletes what can be safely deleted
56    pub(crate) fn free_unavailable_nodes(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {
57        let leaves = self.leaves.clone();
58        let mut new_leaves = HashSet::new();
59        let mut deletables = Vec::new();
60
61        // When consuming nodes with a backward pass, some other backward passes become
62        // unavailable because some of their parents have been consumed. They are
63        // identified here.
64        for leaf in leaves.clone() {
65            self.unavailable_propagation(leaf);
66        }
67
68        // Among the available nodes that remain, some may be useless if no
69        // available node with a tensor reference exist in their descendance.
70        // But some may seem useless from some leaf but be useful from another one,
71        // hence the need to iterate on all leaves.
72        self.useful_propagation(leaves.clone());
73
74        // New leaves are the roots of a useful backward sub-tree.
75        // Deletables are everything not marked as useful.
76        for leaf in leaves {
77            self.identify_leaves_and_deletables(leaf, &mut new_leaves, &mut deletables);
78        }
79
80        // Replace leaves by the new ones and delete everything not useful anymore
81        mem::swap(&mut self.leaves, &mut new_leaves);
82
83        self.clear_unused_roots(&mut deletables);
84
85        self.statuses.clear();
86        for node_to_delete in deletables {
87            self.nodes.remove(&node_to_delete);
88            on_free_graph(&node_to_delete)
89        }
90    }
91
92    pub(crate) fn free_unused_roots(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {
93        let mut deletables = Vec::new();
94        self.clear_unused_roots(&mut deletables);
95
96        for node_id in deletables {
97            self.nodes.remove(&node_id);
98            on_free_graph(&node_id);
99        }
100    }
101
102    fn clear_unused_roots(&self, to_delete: &mut Vec<NodeId>) {
103        for (id, parents) in self.nodes.iter() {
104            let is_useful = matches!(
105                self.statuses.get(id.as_ref()),
106                Some(NodeMemoryStatus::Useful)
107            );
108
109            // Check if parents are either empty or absent from self.nodes
110            let parents_absent = parents.iter().all(|p| !self.nodes.contains_key(p));
111
112            if !is_useful && Arc::strong_count(id) == 1 && parents_absent {
113                to_delete.push(*id.as_ref())
114            }
115        }
116    }
117
118    fn unavailable_propagation(&mut self, node_id: NodeId) -> NodeMemoryStatus {
119        // If already visited
120        if let Some(status) = self.statuses.get(&node_id) {
121            return status.clone();
122        }
123
124        match self.nodes.get(&node_id).cloned() {
125            // If node exists and any of its parents is unavailable, it is unavailable as well
126            // If node exists but the parents vec is empty, it is a tensor that never had parents;
127            //  the status remains unknown
128            Some(parents) => {
129                let mut node_status = NodeMemoryStatus::Unknown;
130                for parent in parents {
131                    let parent_status = self.unavailable_propagation(parent);
132                    if let NodeMemoryStatus::Unavailable = parent_status {
133                        node_status = NodeMemoryStatus::Unavailable;
134                    }
135                }
136                self.statuses.insert(node_id, node_status.clone());
137                node_status
138            }
139            // If node does not exist, it was
140            // deleted, so this and all its descendants are unavailable
141            None => {
142                self.statuses.insert(node_id, NodeMemoryStatus::Unavailable);
143                NodeMemoryStatus::Unavailable
144            }
145        }
146    }
147
148    fn useful_propagation(&mut self, leaves: HashSet<NodeId>) {
149        // Accumulate visited nodes
150        let mut explored = HashSet::new();
151        let mut tagged_useful = HashSet::new();
152
153        // Queue of nodes to visit
154        let mut to_tag_useful = PopNodeSet::default();
155        let mut to_explore = PopNodeSet::new(leaves);
156
157        // Utilitary function to iterate over a node's parents
158        let parents = |node_id| {
159            self.nodes
160                .get(&node_id)
161                .cloned()
162                .unwrap_or_default()
163                .into_iter()
164        };
165
166        loop {
167            // Pop a node id, greedily looking at tag_useful ones first
168            let (node_id, status) = match to_tag_useful.pop() {
169                Some(node_id) => (node_id, NodeMemoryStatus::Useful),
170                None => match to_explore.pop() {
171                    Some(node_id) => {
172                        let node_status = self
173                            .statuses
174                            .get(&node_id)
175                            .expect("All nodes should have received a status during unavailable_propagation")
176                            .to_owned();
177
178                        if let NodeMemoryStatus::Unknown = node_status {
179                            match self.is_referenced(node_id) {
180                                true => (node_id, NodeMemoryStatus::Useful),
181                                false => (node_id, NodeMemoryStatus::Unknown),
182                            }
183                        } else {
184                            (node_id, node_status)
185                        }
186                    }
187                    None => {
188                        // There are no nodes in the queues anymore
189                        break;
190                    }
191                },
192            };
193
194            match status {
195                NodeMemoryStatus::Useful => {
196                    tagged_useful.insert(node_id);
197                    for parent in parents(node_id) {
198                        // The node can be explored, as long as it's not already tagged useful
199                        if !(tagged_useful.contains(&parent) || to_tag_useful.contains(&parent)) {
200                            to_tag_useful.insert(parent);
201                        }
202                    }
203                }
204                _ => {
205                    explored.insert(node_id);
206                    for parent in parents(node_id) {
207                        if !(explored.contains(&parent) || to_explore.contains(&parent)) {
208                            to_explore.insert(parent);
209                        }
210                    }
211                }
212            }
213
214            self.statuses.insert(node_id, status);
215        }
216    }
217
218    fn identify_leaves_and_deletables(
219        &self,
220        leaf_id: NodeId,
221        new_leaves: &mut HashSet<NodeId>,
222        to_delete: &mut Vec<NodeId>,
223    ) {
224        let mut visited = HashSet::new();
225        let mut to_visit = vec![leaf_id];
226
227        while let Some(node_id) = to_visit.pop() {
228            visited.insert(node_id);
229
230            match self
231                .statuses
232                .get(&node_id)
233                .expect("Node should have status")
234            {
235                NodeMemoryStatus::Useful => {
236                    new_leaves.insert(node_id);
237                }
238                _ => {
239                    to_delete.push(node_id);
240
241                    for parent in self
242                        .nodes
243                        .get(&node_id)
244                        .cloned()
245                        .unwrap_or_default()
246                        .into_iter()
247                    {
248                        if !visited.contains(&parent) {
249                            to_visit.push(parent);
250                        }
251                    }
252                }
253            };
254        }
255    }
256
257    fn is_referenced(&self, node_id: NodeId) -> bool {
258        match self.nodes.get_key_value(&node_id) {
259            Some((key, _value)) => Arc::strong_count(key) > 1,
260            None => panic!("Node should be in the nodes map"),
261        }
262    }
263
264    pub(crate) fn maybe_useful(&self) -> bool {
265        self.nodes.keys().any(|node| Arc::strong_count(node) > 1)
266    }
267}
268
269/// Wrapper over hash set for fast popping of any node
270#[derive(new, Default)]
271struct PopNodeSet {
272    hash_set: HashSet<NodeId>,
273}
274
275impl PopNodeSet {
276    #[inline(always)]
277    fn pop(&mut self) -> Option<NodeId> {
278        self.hash_set
279            .iter()
280            .next()
281            .copied()
282            .and_then(|node_id| self.hash_set.take(&node_id))
283    }
284
285    #[inline(always)]
286    fn contains(&self, node_id: &NodeId) -> bool {
287        self.hash_set.contains(node_id)
288    }
289
290    #[inline(always)]
291    fn insert(&mut self, node_id: NodeId) {
292        self.hash_set.insert(node_id);
293    }
294}