1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
use hashbrown::HashMap;
use crate::{BackendIr, TensorHandle, TensorId, TensorIr, TensorStatus};
/// Keep all [tensor handles](BackendIr::Handle) in one place and ensure that all resources
/// are used optimally.
///
/// Entries are keyed by [`TensorId`] and wrapped in [`Handle`], so an id can be present in the
/// container without a backend handle having been created for it yet.
pub struct HandleContainer<H> {
// Every tracked tensor id mapped to its (possibly not yet created) backend handle.
handles: HashMap<TensorId, Handle<H>>,
// Initialized to 0 and copied as-is by `fork`; none of the methods visible in this file
// increment or otherwise modify it.
// NOTE(review): confirm whether this field is still needed.
counter: u64,
}
// Manual `Default` impl: `#[derive(Default)]` would add an `H: Default` bound we don't need.
impl<H> Default for HandleContainer<H> {
    /// Build an empty container: no registered handles and a counter set to zero.
    fn default() -> Self {
        let handles = HashMap::new();
        let counter = 0;

        Self { handles, counter }
    }
}
impl<H: Clone> HandleContainer<H> {
/// Fork the container, useful for autotune.
pub fn fork(&self) -> Self {
let mut handles = HashMap::with_capacity(self.handles.len());
for (id, handle) in self.handles.iter() {
handles.insert(*id, handle.clone());
}
Self {
handles,
counter: self.counter,
}
}
}
impl<H> core::fmt::Debug for HandleContainer<H> {
    /// Format the container as a struct showing the registered tensor ids and the counter.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Only the ids are printed: `H` is not required to implement `Debug`, and the ids
        // are what matters when debugging.
        let ids = self.handles.keys();

        let mut output = f.debug_struct("HandleContainer");
        output.field("handles", &ids);
        output.field("counter", &self.counter);
        output.finish()
    }
}
/// Backend [tensor handle](BackendIr::Handle) wrapper tracking their creation state
///
/// This is the value type stored by [`HandleContainer`] for each tensor id. `Clone` is derived,
/// so cloning a `Handle<H>` requires `H: Clone`.
#[derive(Clone)]
pub enum Handle<H> {
/// No [tensor handle](BackendIr::Handle) has been created yet
///
/// `HandleContainer::get_handle_ref` reports an entry in this state as `None`, while
/// `HandleContainer::get_handle` panics on it.
NotInit,
/// A [tensor handle](BackendIr::Handle) has been created
///
/// Holds the backend handle itself.
Existing(H),
}
impl<H: Clone> HandleContainer<H> {
/// Create a new HandleContainer
pub fn new() -> Self {
Self {
handles: HashMap::new(),
counter: 0,
}
}
/// Register a handle for the given [tensor id](TensorId).
pub fn register_handle(&mut self, id: TensorId, handle: H) {
self.handles.insert(id, Handle::Existing(handle));
}
/// Whether a handle exists.
pub fn has_handle(&self, id: &TensorId) -> bool {
self.handles.contains_key(id)
}
/// Get the reference to a handle.
pub fn get_handle_ref(&self, id: &TensorId) -> Option<&H> {
self.handles
.get(id)
.filter(|h| !matches!(h, Handle::NotInit))
.map(|h| match h {
Handle::Existing(handle) => handle,
Handle::NotInit => unreachable!(),
})
}
/// Get the handle for the given [tensor id](TensorId). The status is used to determine if the
/// tensor should be popped out of the current tensor map, necessary for inplace operations.
///
/// # Warnings
///
/// Make sure the status corresponds to the operation you want to execute the handle on,
/// otherwise you might remove a tensor handle that will be required in the future.
pub fn get_handle(&mut self, id: &TensorId, status: &TensorStatus) -> H {
let (id, handle) = self
.handles
.remove_entry(id)
.unwrap_or_else(|| panic!("Should have handle for tensor {id:?}"));
match handle {
Handle::Existing(handle) => match status {
TensorStatus::ReadOnly => {
self.handles.insert(id, Handle::Existing(handle.clone()));
handle
}
TensorStatus::ReadWrite => handle,
TensorStatus::NotInit => panic!(
"Cannot get uninitialized tensor {id:?}. Tensor exist but with wrong status"
),
},
Handle::NotInit => panic!("Cannot get uninitialized handle {id:?}."),
}
}
/// Get the tensor handle for the given [tensor intermediate representation](TensorIr).
pub fn get_tensor_handle(&mut self, tensor: &TensorIr) -> TensorHandle<H> {
TensorHandle {
handle: self.get_handle(&tensor.id, &tensor.status),
shape: tensor.shape.clone(),
}
}
/// Get the [float tensor](burn_backend::backend::BackendTypes::FloatTensorPrimitive) corresponding to the
/// given [tensor intermediate representation](TensorIr).
pub fn get_float_tensor<B>(&mut self, tensor: &TensorIr) -> B::FloatTensorPrimitive
where
B: BackendIr<Handle = H>,
{
B::float_tensor(self.get_tensor_handle(tensor))
}
/// Get the [int tensor](burn_backend::backend::BackendTypes::IntTensorPrimitive) corresponding to the
/// given [tensor intermediate representation](TensorIr).
pub fn get_int_tensor<B>(&mut self, tensor: &TensorIr) -> B::IntTensorPrimitive
where
B: BackendIr<Handle = H>,
{
B::int_tensor(self.get_tensor_handle(tensor))
}
/// Get the [bool tensor](burn_backend::backend::BackendTypes::BoolTensorPrimitive) corresponding to the
/// given [tensor intermediate representation](TensorIr).
pub fn get_bool_tensor<B>(&mut self, tensor: &TensorIr) -> B::BoolTensorPrimitive
where
B: BackendIr<Handle = H>,
{
B::bool_tensor(self.get_tensor_handle(tensor))
}
/// Get the [quantized tensor](burn_backend::backend::BackendTypes::QuantizedTensorPrimitive) corresponding to the
/// given [tensor intermediate representation](TensorIr).
pub fn get_quantized_tensor<B>(&mut self, tensor: &TensorIr) -> B::QuantizedTensorPrimitive
where
B: BackendIr<Handle = H>,
{
B::quantized_tensor(self.get_tensor_handle(tensor))
}
/// Register a new [float tensor](burn_backend::backend::BackendTypes::FloatTensorPrimitive) with the corresponding [tensor id](TensorId).
pub fn register_float_tensor<B>(&mut self, id: &TensorId, tensor: B::FloatTensorPrimitive)
where
B: BackendIr<Handle = H>,
{
let handle = B::float_tensor_handle(tensor);
self.handles.insert(*id, Handle::Existing(handle));
}
/// Register a new [quantized tensor](burn_backend::backend::BackendTypes::QuantizedTensorPrimitive) with the corresponding [tensor ids](TensorId).
pub fn register_quantized_tensor<B>(
&mut self,
id: &TensorId,
tensor: B::QuantizedTensorPrimitive,
) where
B: BackendIr<Handle = H>,
{
let handle = B::quantized_tensor_handle(tensor);
self.handles.insert(*id, Handle::Existing(handle));
}
/// Register a new [int tensor](burn_backend::backend::BackendTypes::IntTensorPrimitive) with the corresponding [tensor id](TensorId).
pub fn register_int_tensor<B>(&mut self, id: &TensorId, tensor: B::IntTensorPrimitive)
where
B: BackendIr<Handle = H>,
{
let handle = B::int_tensor_handle(tensor);
self.handles.insert(*id, Handle::Existing(handle));
}
/// Register a new [bool tensor](burn_backend::backend::BackendTypes::BoolTensorPrimitive) with the corresponding [tensor id](TensorId).
pub fn register_bool_tensor<B>(&mut self, id: &TensorId, tensor: B::BoolTensorPrimitive)
where
B: BackendIr<Handle = H>,
{
let handle = B::bool_tensor_handle(tensor);
self.handles.insert(*id, Handle::Existing(handle));
}
/// Remove tensor handle from container.
pub fn remove_handle(&mut self, id: TensorId) -> Option<Handle<H>> {
self.handles.remove(&id)
}
/// Remove tensor handle from container if writable
pub fn free(&mut self, tensor: &TensorIr) {
match tensor.status {
TensorStatus::ReadOnly => (),
TensorStatus::NotInit => (),
TensorStatus::ReadWrite => {
self.handles.remove(&tensor.id);
}
};
}
/// Returns the number of handles.
pub fn num_handles(&self) -> usize {
self.handles.len()
}
/// Returns the IDs of all currently registered handles.
///
/// Useful for snapshotting which handles exist at a point in time (e.g., before
/// executing on a forked context) so that newly registered output handles can
/// be detected afterwards.
pub fn handle_ids(&self) -> impl Iterator<Item = &'_ TensorId> {
self.handles.keys()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorId;

    /// Shorthand for building a [`TensorId`] from a raw value in tests.
    fn tid(value: u64) -> TensorId {
        TensorId::new(value)
    }

    /// A fork starts out with every handle registered in the source container.
    #[test]
    fn fork_clones_existing_handles() {
        let mut source = HandleContainer::<String>::new();
        source.register_handle(tid(1), "input_a".to_string());
        source.register_handle(tid(2), "input_b".to_string());

        let forked = source.fork();

        assert_eq!(forked.num_handles(), 2);
        for raw in [1, 2] {
            assert!(forked.get_handle_ref(&tid(raw)).is_some());
        }
    }

    /// Documents the core of the autotune clone bug: output handles registered in a fork
    /// do NOT appear in the original.
    #[test]
    fn fork_is_isolated_from_original() {
        let mut source = HandleContainer::<String>::new();
        source.register_handle(tid(1), "input_a".to_string());

        // Simulate an optimization registering output handles in the fork.
        let mut forked = source.fork();
        forked.register_handle(tid(100), "output_x".to_string());
        forked.register_handle(tid(101), "output_y".to_string());

        // The fork sees the outputs next to the inherited input.
        assert_eq!(forked.num_handles(), 3);
        for raw in [100, 101] {
            assert!(forked.get_handle_ref(&tid(raw)).is_some());
        }

        // The original never receives them: these output handles are lost.
        assert_eq!(source.num_handles(), 1);
        for raw in [100, 101] {
            assert!(source.get_handle_ref(&tid(raw)).is_none());
        }
    }

    /// Replacing a handle inside a fork leaves the original's handle untouched.
    #[test]
    fn fork_mutations_do_not_affect_original() {
        let mut source = HandleContainer::<String>::new();
        source.register_handle(tid(1), "original_value".to_string());

        // Overwrite a handle in the fork (e.g., inplace output reuse).
        let mut forked = source.fork();
        forked.register_handle(tid(1), "modified_in_fork".to_string());

        let in_source = source.get_handle_ref(&tid(1)).map(String::as_str);
        let in_fork = forked.get_handle_ref(&tid(1)).map(String::as_str);

        assert_eq!(in_source, Some("original_value"));
        assert_eq!(in_fork, Some("modified_in_fork"));
    }

    /// Simulates what happens when UnsafeTuneContext::get() is called on a Fork:
    /// it forks again, creating a second level of isolation.
    #[test]
    fn double_fork_is_fully_isolated() {
        let mut root = HandleContainer::<String>::new();
        root.register_handle(tid(1), "input".to_string());

        let first = root.fork();
        let mut second = first.fork();
        second.register_handle(tid(200), "deep_output".to_string());

        let deep = tid(200);
        assert!(first.get_handle_ref(&deep).is_none());
        assert!(root.get_handle_ref(&deep).is_none());
        assert!(second.get_handle_ref(&deep).is_some());
    }
}