Skip to main content

burn_ir/
handle.rs

1use hashbrown::HashMap;
2
3use crate::{BackendIr, TensorHandle, TensorId, TensorIr, TensorStatus};
4
5/// Keep all [tensor handles](BackendIr::Handle) in one place and ensure that all resources
6/// are used optimally.
7pub struct HandleContainer<H> {
8    handles: HashMap<TensorId, Handle<H>>,
9    counter: u64,
10}
11
12// Hand-written perfect derive as we don't require `H: Default`.
13impl<H> Default for HandleContainer<H> {
14    fn default() -> Self {
15        Self {
16            handles: HashMap::new(),
17            counter: 0,
18        }
19    }
20}
21
22impl<H: Clone> HandleContainer<H> {
23    /// Fork the container, useful for autotune.
24    pub fn fork(&self) -> Self {
25        let mut handles = HashMap::with_capacity(self.handles.len());
26
27        for (id, handle) in self.handles.iter() {
28            handles.insert(*id, handle.clone());
29        }
30
31        Self {
32            handles,
33            counter: self.counter,
34        }
35    }
36}
37
38impl<H> core::fmt::Debug for HandleContainer<H> {
39    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
40        f.debug_struct("HandleContainer")
41            .field("handles", &self.handles.keys()) // only care about the IDs when debugging
42            .field("counter", &self.counter)
43            .finish()
44    }
45}
46
/// Slot for a backend [tensor handle](BackendIr::Handle) that records whether the handle
/// has been created yet.
#[derive(Clone)]
pub enum Handle<H> {
    /// Placeholder: the [tensor handle](BackendIr::Handle) has not been created yet.
    NotInit,
    /// The created [tensor handle](BackendIr::Handle).
    Existing(H),
}
55
56impl<H: Clone> HandleContainer<H> {
57    /// Create a new HandleContainer
58    pub fn new() -> Self {
59        Self {
60            handles: HashMap::new(),
61            counter: 0,
62        }
63    }
64
65    /// Register a handle for the given [tensor id](TensorId).
66    pub fn register_handle(&mut self, id: TensorId, handle: H) {
67        self.handles.insert(id, Handle::Existing(handle));
68    }
69
70    /// Whether a handle exists.
71    pub fn has_handle(&self, id: &TensorId) -> bool {
72        self.handles.contains_key(id)
73    }
74
75    /// Get the reference to a handle.
76    pub fn get_handle_ref(&self, id: &TensorId) -> Option<&H> {
77        self.handles
78            .get(id)
79            .filter(|h| !matches!(h, Handle::NotInit))
80            .map(|h| match h {
81                Handle::Existing(handle) => handle,
82                Handle::NotInit => unreachable!(),
83            })
84    }
85
86    /// Get the handle for the given [tensor id](TensorId). The status is used to determine if the
87    /// tensor should be popped out of the current tensor map, necessary for inplace operations.
88    ///
89    /// # Warnings
90    ///
91    /// Make sure the status corresponds to the operation you want to execute the handle on,
92    /// otherwise you might remove a tensor handle that will be required in the future.
93    pub fn get_handle(&mut self, id: &TensorId, status: &TensorStatus) -> H {
94        let (id, handle) = self
95            .handles
96            .remove_entry(id)
97            .unwrap_or_else(|| panic!("Should have handle for tensor {id:?}"));
98
99        match handle {
100            Handle::Existing(handle) => match status {
101                TensorStatus::ReadOnly => {
102                    self.handles.insert(id, Handle::Existing(handle.clone()));
103                    handle
104                }
105                TensorStatus::ReadWrite => handle,
106                TensorStatus::NotInit => panic!(
107                    "Cannot get uninitialized tensor {id:?}. Tensor exist but with wrong status"
108                ),
109            },
110            Handle::NotInit => panic!("Cannot get uninitialized handle {id:?}."),
111        }
112    }
113
114    /// Get the tensor handle for the given [tensor intermediate representation](TensorIr).
115    pub fn get_tensor_handle(&mut self, tensor: &TensorIr) -> TensorHandle<H> {
116        TensorHandle {
117            handle: self.get_handle(&tensor.id, &tensor.status),
118            shape: tensor.shape.clone(),
119        }
120    }
121
122    /// Get the [float tensor](burn_backend::backend::BackendTypes::FloatTensorPrimitive) corresponding to the
123    /// given [tensor intermediate representation](TensorIr).
124    pub fn get_float_tensor<B>(&mut self, tensor: &TensorIr) -> B::FloatTensorPrimitive
125    where
126        B: BackendIr<Handle = H>,
127    {
128        B::float_tensor(self.get_tensor_handle(tensor))
129    }
130
131    /// Get the [int tensor](burn_backend::backend::BackendTypes::IntTensorPrimitive) corresponding to the
132    /// given [tensor intermediate representation](TensorIr).
133    pub fn get_int_tensor<B>(&mut self, tensor: &TensorIr) -> B::IntTensorPrimitive
134    where
135        B: BackendIr<Handle = H>,
136    {
137        B::int_tensor(self.get_tensor_handle(tensor))
138    }
139
140    /// Get the [bool tensor](burn_backend::backend::BackendTypes::BoolTensorPrimitive) corresponding to the
141    /// given [tensor intermediate representation](TensorIr).
142    pub fn get_bool_tensor<B>(&mut self, tensor: &TensorIr) -> B::BoolTensorPrimitive
143    where
144        B: BackendIr<Handle = H>,
145    {
146        B::bool_tensor(self.get_tensor_handle(tensor))
147    }
148
149    /// Get the [quantized tensor](burn_backend::backend::BackendTypes::QuantizedTensorPrimitive) corresponding to the
150    /// given [tensor intermediate representation](TensorIr).
151    pub fn get_quantized_tensor<B>(&mut self, tensor: &TensorIr) -> B::QuantizedTensorPrimitive
152    where
153        B: BackendIr<Handle = H>,
154    {
155        B::quantized_tensor(self.get_tensor_handle(tensor))
156    }
157
158    /// Register a new [float tensor](burn_backend::backend::BackendTypes::FloatTensorPrimitive) with the corresponding [tensor id](TensorId).
159    pub fn register_float_tensor<B>(&mut self, id: &TensorId, tensor: B::FloatTensorPrimitive)
160    where
161        B: BackendIr<Handle = H>,
162    {
163        let handle = B::float_tensor_handle(tensor);
164        self.handles.insert(*id, Handle::Existing(handle));
165    }
166
167    /// Register a new [quantized tensor](burn_backend::backend::BackendTypes::QuantizedTensorPrimitive) with the corresponding [tensor ids](TensorId).
168    pub fn register_quantized_tensor<B>(
169        &mut self,
170        id: &TensorId,
171        tensor: B::QuantizedTensorPrimitive,
172    ) where
173        B: BackendIr<Handle = H>,
174    {
175        let handle = B::quantized_tensor_handle(tensor);
176        self.handles.insert(*id, Handle::Existing(handle));
177    }
178
179    /// Register a new [int tensor](burn_backend::backend::BackendTypes::IntTensorPrimitive) with the corresponding [tensor id](TensorId).
180    pub fn register_int_tensor<B>(&mut self, id: &TensorId, tensor: B::IntTensorPrimitive)
181    where
182        B: BackendIr<Handle = H>,
183    {
184        let handle = B::int_tensor_handle(tensor);
185        self.handles.insert(*id, Handle::Existing(handle));
186    }
187
188    /// Register a new [bool tensor](burn_backend::backend::BackendTypes::BoolTensorPrimitive) with the corresponding [tensor id](TensorId).
189    pub fn register_bool_tensor<B>(&mut self, id: &TensorId, tensor: B::BoolTensorPrimitive)
190    where
191        B: BackendIr<Handle = H>,
192    {
193        let handle = B::bool_tensor_handle(tensor);
194        self.handles.insert(*id, Handle::Existing(handle));
195    }
196
197    /// Remove tensor handle from container.
198    pub fn remove_handle(&mut self, id: TensorId) -> Option<Handle<H>> {
199        self.handles.remove(&id)
200    }
201
202    /// Remove tensor handle from container if writable
203    pub fn free(&mut self, tensor: &TensorIr) {
204        match tensor.status {
205            TensorStatus::ReadOnly => (),
206            TensorStatus::NotInit => (),
207            TensorStatus::ReadWrite => {
208                self.handles.remove(&tensor.id);
209            }
210        };
211    }
212
213    /// Returns the number of handles.
214    pub fn num_handles(&self) -> usize {
215        self.handles.len()
216    }
217
218    /// Returns the IDs of all currently registered handles.
219    ///
220    /// Useful for snapshotting which handles exist at a point in time (e.g., before
221    /// executing on a forked context) so that newly registered output handles can
222    /// be detected afterwards.
223    pub fn handle_ids(&self) -> impl Iterator<Item = &'_ TensorId> {
224        self.handles.keys()
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{TensorId, TensorStatus};

    /// Helper to create a TensorId for tests.
    fn tid(value: u64) -> TensorId {
        TensorId::new(value)
    }

    #[test]
    fn fork_clones_existing_handles() {
        let mut container = HandleContainer::<String>::new();
        container.register_handle(tid(1), "input_a".to_string());
        container.register_handle(tid(2), "input_b".to_string());

        let fork = container.fork();

        assert_eq!(fork.num_handles(), 2);
        assert!(fork.get_handle_ref(&tid(1)).is_some());
        assert!(fork.get_handle_ref(&tid(2)).is_some());
    }

    #[test]
    fn fork_is_isolated_from_original() {
        // This test documents the core of the autotune clone bug:
        // output handles registered in a fork do NOT appear in the original.
        let mut container = HandleContainer::<String>::new();
        container.register_handle(tid(1), "input_a".to_string());

        let mut fork = container.fork();

        // Simulate an optimization registering output handles in the fork.
        fork.register_handle(tid(100), "output_x".to_string());
        fork.register_handle(tid(101), "output_y".to_string());

        // The fork has the output handles.
        assert_eq!(fork.num_handles(), 3);
        assert!(fork.get_handle_ref(&tid(100)).is_some());
        assert!(fork.get_handle_ref(&tid(101)).is_some());

        // But the original does NOT — these output handles are lost.
        assert_eq!(container.num_handles(), 1);
        assert!(container.get_handle_ref(&tid(100)).is_none());
        assert!(container.get_handle_ref(&tid(101)).is_none());
    }

    #[test]
    fn fork_mutations_do_not_affect_original() {
        let mut container = HandleContainer::<String>::new();
        container.register_handle(tid(1), "original_value".to_string());

        let mut fork = container.fork();

        // Overwrite a handle in the fork (e.g., inplace output reuse).
        fork.register_handle(tid(1), "modified_in_fork".to_string());

        // Original is unchanged.
        assert_eq!(
            container.get_handle_ref(&tid(1)),
            Some(&"original_value".to_string())
        );
        assert_eq!(
            fork.get_handle_ref(&tid(1)),
            Some(&"modified_in_fork".to_string())
        );
    }

    #[test]
    fn double_fork_is_fully_isolated() {
        // Simulates what happens when UnsafeTuneContext::get() is called on a Fork:
        // it forks again, creating a second level of isolation.
        let mut container = HandleContainer::<String>::new();
        container.register_handle(tid(1), "input".to_string());

        let fork1 = container.fork();
        let mut fork2 = fork1.fork();

        fork2.register_handle(tid(200), "deep_output".to_string());

        assert!(fork1.get_handle_ref(&tid(200)).is_none());
        assert!(container.get_handle_ref(&tid(200)).is_none());
        assert!(fork2.get_handle_ref(&tid(200)).is_some());
    }

    #[test]
    fn get_handle_read_only_keeps_handle_registered() {
        let mut container = HandleContainer::<String>::new();
        container.register_handle(tid(1), "input".to_string());

        let handle = container.get_handle(&tid(1), &TensorStatus::ReadOnly);

        assert_eq!(handle, "input");
        assert_eq!(container.num_handles(), 1);
        assert_eq!(
            container.get_handle_ref(&tid(1)),
            Some(&"input".to_string())
        );
    }

    #[test]
    fn get_handle_read_write_removes_handle() {
        let mut container = HandleContainer::<String>::new();
        container.register_handle(tid(1), "input".to_string());

        let handle = container.get_handle(&tid(1), &TensorStatus::ReadWrite);

        assert_eq!(handle, "input");
        assert!(!container.has_handle(&tid(1)));
        assert_eq!(container.num_handles(), 0);
    }

    #[test]
    #[should_panic(expected = "Should have handle for tensor")]
    fn get_handle_missing_panics() {
        let mut container = HandleContainer::<String>::new();
        let _ = container.get_handle(&tid(1), &TensorStatus::ReadOnly);
    }

    #[test]
    #[should_panic(expected = "Cannot get uninitialized tensor")]
    fn get_handle_not_init_status_panics() {
        let mut container = HandleContainer::<String>::new();
        container.register_handle(tid(1), "input".to_string());
        let _ = container.get_handle(&tid(1), &TensorStatus::NotInit);
    }

    #[test]
    fn remove_handle_returns_registered_entry_once() {
        let mut container = HandleContainer::<String>::new();
        container.register_handle(tid(1), "input".to_string());

        let removed = container.remove_handle(tid(1));

        assert!(matches!(removed, Some(Handle::Existing(ref h)) if h == "input"));
        assert!(container.remove_handle(tid(1)).is_none());
        assert_eq!(container.num_handles(), 0);
    }
}