1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
use hashbrown::HashMap;
use crate::{BackendIr, TensorHandle, TensorId, TensorIr, TensorStatus};
/// Keep all [tensor handles](BackendIr::Handle) in one place and ensure that all resources
/// are used optimally.
#[derive(Default)]
pub struct HandleContainer<H> {
handles: HashMap<TensorId, Handle<H>>,
counter: u64,
}
impl<H: Clone> HandleContainer<H> {
/// Fork the container, useful for autotune.
pub fn fork(&self) -> Self {
let mut handles = HashMap::with_capacity(self.handles.len());
for (id, handle) in self.handles.iter() {
handles.insert(*id, handle.clone());
}
Self {
handles,
counter: self.counter,
}
}
}
impl<H> core::fmt::Debug for HandleContainer<H> {
    /// Formats the container as a struct showing the registered tensor ids and the counter.
    ///
    /// The handles themselves are left out, so `H` is not required to implement `Debug`.
    fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let ids = self.handles.keys();
        let mut out = formatter.debug_struct("HandleContainer");
        out.field("handles", &ids);
        out.field("counter", &self.counter);
        out.finish()
    }
}
/// Creation state of a backend [tensor handle](BackendIr::Handle).
///
/// A tensor id can be reserved before any backend handle has been produced for it; this
/// wrapper distinguishes the reserved-only state from the materialized one.
#[derive(Clone)]
pub enum Handle<H> {
    /// The id is reserved but no [tensor handle](BackendIr::Handle) has been created yet.
    NotInit,
    /// The [tensor handle](BackendIr::Handle) has been created and is held here.
    Existing(H),
}
impl<H: Clone> HandleContainer<H> {
/// Create a new HandleContainer
pub fn new() -> Self {
Self {
handles: HashMap::new(),
counter: 0,
}
}
/// Register a handle for the given [tensor id](TensorId).
pub fn register_handle(&mut self, id: TensorId, handle: H) {
self.handles.insert(id, Handle::Existing(handle));
}
/// Whether an handle exists.
pub fn has_handle(&mut self, id: &TensorId) -> bool {
self.handles.contains_key(id)
}
/// Get the reference to a handle.
pub fn get_handle_ref(&self, id: &TensorId) -> Option<&H> {
self.handles
.get(id)
.filter(|h| !matches!(h, Handle::NotInit))
.map(|h| match h {
Handle::Existing(handle) => handle,
Handle::NotInit => unreachable!(),
})
}
/// Get the handle for the given [tensor id](TensorId). The status is used to determine if the
/// tensor should be popped out of the current tensor map, necessary for inplace operations.
///
/// # Warnings
///
/// Make sure the status corresponds to the operation you want to execute the handle on,
/// otherwise you might remove a tensor handle that will be required in the future.
pub fn get_handle(&mut self, id: &TensorId, status: &TensorStatus) -> H {
let (id, handle) = self
.handles
.remove_entry(id)
.unwrap_or_else(|| panic!("Should have handle for tensor {id:?}"));
match handle {
Handle::Existing(handle) => match status {
TensorStatus::ReadOnly => {
self.handles.insert(id, Handle::Existing(handle.clone()));
handle
}
TensorStatus::ReadWrite => handle,
TensorStatus::NotInit => panic!(
"Cannot get uninitialized tensor {id:?}. Tensor exist but with wrong status"
),
},
Handle::NotInit => panic!("Cannot get uninitialized handle {id:?}."),
}
}
/// Get the tensor handle for the given [tensor intermediate representation](TensorIr).
pub fn get_tensor_handle(&mut self, tensor: &TensorIr) -> TensorHandle<H> {
TensorHandle {
handle: self.get_handle(&tensor.id, &tensor.status),
shape: tensor.shape.clone(),
}
}
/// Get the [float tensor](burn_backend::backend::Backend::FloatTensorPrimitive) corresponding to the
/// given [tensor intermediate representation](TensorIr).
pub fn get_float_tensor<B>(&mut self, tensor: &TensorIr) -> B::FloatTensorPrimitive
where
B: BackendIr<Handle = H>,
{
B::float_tensor(self.get_tensor_handle(tensor))
}
/// Get the [int tensor](burn_backend::backend::Backend::IntTensorPrimitive) corresponding to the
/// given [tensor intermediate representation](TensorIr).
pub fn get_int_tensor<B>(&mut self, tensor: &TensorIr) -> B::IntTensorPrimitive
where
B: BackendIr<Handle = H>,
{
B::int_tensor(self.get_tensor_handle(tensor))
}
/// Get the [bool tensor](burn_backend::backend::Backend::BoolTensorPrimitive) corresponding to the
/// given [tensor intermediate representation](TensorIr).
pub fn get_bool_tensor<B>(&mut self, tensor: &TensorIr) -> B::BoolTensorPrimitive
where
B: BackendIr<Handle = H>,
{
B::bool_tensor(self.get_tensor_handle(tensor))
}
/// Get the [quantized tensor](burn_backend::backend::Backend::QuantizedTensorPrimitive) corresponding to the
/// given [tensor intermediate representation](TensorIr).
pub fn get_quantized_tensor<B>(&mut self, tensor: &TensorIr) -> B::QuantizedTensorPrimitive
where
B: BackendIr<Handle = H>,
{
B::quantized_tensor(self.get_tensor_handle(tensor))
}
/// Register a new [float tensor](burn_backend::backend::Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId).
pub fn register_float_tensor<B>(&mut self, id: &TensorId, tensor: B::FloatTensorPrimitive)
where
B: BackendIr<Handle = H>,
{
let handle = B::float_tensor_handle(tensor);
self.handles.insert(*id, Handle::Existing(handle));
}
/// Register a new [quantized tensor](burn_backend::backend::Backend::QuantizedTensorPrimitive) with the corresponding [tensor ids](TensorId).
pub fn register_quantized_tensor<B>(
&mut self,
id: &TensorId,
tensor: B::QuantizedTensorPrimitive,
) where
B: BackendIr<Handle = H>,
{
let handle = B::quantized_tensor_handle(tensor);
self.handles.insert(*id, Handle::Existing(handle));
}
/// Register a new [int tensor](burn_backend::backend::Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId).
pub fn register_int_tensor<B>(&mut self, id: &TensorId, tensor: B::IntTensorPrimitive)
where
B: BackendIr<Handle = H>,
{
let handle = B::int_tensor_handle(tensor);
self.handles.insert(*id, Handle::Existing(handle));
}
/// Register a new [bool tensor](burn_backend::backend::Backend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId).
pub fn register_bool_tensor<B>(&mut self, id: &TensorId, tensor: B::BoolTensorPrimitive)
where
B: BackendIr<Handle = H>,
{
let handle = B::bool_tensor_handle(tensor);
self.handles.insert(*id, Handle::Existing(handle));
}
/// Lazily create a new empty tensor and return its corresponding [tensor id](TensorId).
pub fn create_tensor_uninit(&mut self) -> TensorId {
let id = TensorId::new(self.counter);
self.counter += 1;
self.handles.insert(id, Handle::NotInit);
id
}
/// Remove tensor handle from container.
pub fn remove_handle(&mut self, id: TensorId) -> Option<Handle<H>> {
self.handles.remove(&id)
}
/// Remove tensor handle from container if writable
pub fn free(&mut self, tensor: &TensorIr) {
match tensor.status {
TensorStatus::ReadOnly => (),
TensorStatus::NotInit => (),
TensorStatus::ReadWrite => {
self.handles.remove(&tensor.id);
}
};
}
/// Returns the number of handles.
pub fn num_handles(&self) -> usize {
self.handles.len()
}
}