Skip to main content

burn_ir/
handle.rs

1use hashbrown::HashMap;
2
3use crate::{BackendIr, TensorHandle, TensorId, TensorIr, TensorStatus};
4
5/// Keep all [tensor handles](BackendIr::Handle) in one place and ensure that all resources
6/// are used optimally.
7#[derive(Default)]
8pub struct HandleContainer<H> {
9    handles: HashMap<TensorId, Handle<H>>,
10    counter: u64,
11}
12
13impl<H: Clone> HandleContainer<H> {
14    /// Fork the container, useful for autotune.
15    pub fn fork(&self) -> Self {
16        let mut handles = HashMap::with_capacity(self.handles.len());
17
18        for (id, handle) in self.handles.iter() {
19            handles.insert(*id, handle.clone());
20        }
21
22        Self {
23            handles,
24            counter: self.counter,
25        }
26    }
27}
28
29impl<H> core::fmt::Debug for HandleContainer<H> {
30    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
31        f.debug_struct("HandleContainer")
32            .field("handles", &self.handles.keys()) // only care about the IDs when debugging
33            .field("counter", &self.counter)
34            .finish()
35    }
36}
37
/// Backend [tensor handle](BackendIr::Handle) wrapper tracking its creation state.
///
/// A [`HandleContainer`] stores one of these per [tensor id](TensorId), which lets a tensor be
/// tracked by id before any backend handle exists for it.
#[derive(Clone)]
pub enum Handle<H> {
    /// No [tensor handle](BackendIr::Handle) has been created yet.
    NotInit,
    /// A [tensor handle](BackendIr::Handle) has been created; the variant owns it.
    Existing(H),
}
46
47impl<H: Clone> HandleContainer<H> {
48    /// Create a new HandleContainer
49    pub fn new() -> Self {
50        Self {
51            handles: HashMap::new(),
52            counter: 0,
53        }
54    }
55
56    /// Register a handle for the given [tensor id](TensorId).
57    pub fn register_handle(&mut self, id: TensorId, handle: H) {
58        self.handles.insert(id, Handle::Existing(handle));
59    }
60
61    /// Whether an handle exists.
62    pub fn has_handle(&mut self, id: &TensorId) -> bool {
63        self.handles.contains_key(id)
64    }
65
66    /// Get the reference to a handle.
67    pub fn get_handle_ref(&self, id: &TensorId) -> Option<&H> {
68        self.handles
69            .get(id)
70            .filter(|h| !matches!(h, Handle::NotInit))
71            .map(|h| match h {
72                Handle::Existing(handle) => handle,
73                Handle::NotInit => unreachable!(),
74            })
75    }
76
77    /// Get the handle for the given [tensor id](TensorId). The status is used to determine if the
78    /// tensor should be popped out of the current tensor map, necessary for inplace operations.
79    ///
80    /// # Warnings
81    ///
82    /// Make sure the status corresponds to the operation you want to execute the handle on,
83    /// otherwise you might remove a tensor handle that will be required in the future.
84    pub fn get_handle(&mut self, id: &TensorId, status: &TensorStatus) -> H {
85        let (id, handle) = self
86            .handles
87            .remove_entry(id)
88            .unwrap_or_else(|| panic!("Should have handle for tensor {id:?}"));
89
90        match handle {
91            Handle::Existing(handle) => match status {
92                TensorStatus::ReadOnly => {
93                    self.handles.insert(id, Handle::Existing(handle.clone()));
94                    handle
95                }
96                TensorStatus::ReadWrite => handle,
97                TensorStatus::NotInit => panic!(
98                    "Cannot get uninitialized tensor {id:?}. Tensor exist but with wrong status"
99                ),
100            },
101            Handle::NotInit => panic!("Cannot get uninitialized handle {id:?}."),
102        }
103    }
104
105    /// Get the tensor handle for the given [tensor intermediate representation](TensorIr).
106    pub fn get_tensor_handle(&mut self, tensor: &TensorIr) -> TensorHandle<H> {
107        TensorHandle {
108            handle: self.get_handle(&tensor.id, &tensor.status),
109            shape: tensor.shape.clone(),
110        }
111    }
112
113    /// Get the [float tensor](burn_backend::backend::Backend::FloatTensorPrimitive) corresponding to the
114    /// given [tensor intermediate representation](TensorIr).
115    pub fn get_float_tensor<B>(&mut self, tensor: &TensorIr) -> B::FloatTensorPrimitive
116    where
117        B: BackendIr<Handle = H>,
118    {
119        B::float_tensor(self.get_tensor_handle(tensor))
120    }
121
122    /// Get the [int tensor](burn_backend::backend::Backend::IntTensorPrimitive) corresponding to the
123    /// given [tensor intermediate representation](TensorIr).
124    pub fn get_int_tensor<B>(&mut self, tensor: &TensorIr) -> B::IntTensorPrimitive
125    where
126        B: BackendIr<Handle = H>,
127    {
128        B::int_tensor(self.get_tensor_handle(tensor))
129    }
130
131    /// Get the [bool tensor](burn_backend::backend::Backend::BoolTensorPrimitive) corresponding to the
132    /// given [tensor intermediate representation](TensorIr).
133    pub fn get_bool_tensor<B>(&mut self, tensor: &TensorIr) -> B::BoolTensorPrimitive
134    where
135        B: BackendIr<Handle = H>,
136    {
137        B::bool_tensor(self.get_tensor_handle(tensor))
138    }
139
140    /// Get the [quantized tensor](burn_backend::backend::Backend::QuantizedTensorPrimitive) corresponding to the
141    /// given [tensor intermediate representation](TensorIr).
142    pub fn get_quantized_tensor<B>(&mut self, tensor: &TensorIr) -> B::QuantizedTensorPrimitive
143    where
144        B: BackendIr<Handle = H>,
145    {
146        B::quantized_tensor(self.get_tensor_handle(tensor))
147    }
148
149    /// Register a new [float tensor](burn_backend::backend::Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId).
150    pub fn register_float_tensor<B>(&mut self, id: &TensorId, tensor: B::FloatTensorPrimitive)
151    where
152        B: BackendIr<Handle = H>,
153    {
154        let handle = B::float_tensor_handle(tensor);
155        self.handles.insert(*id, Handle::Existing(handle));
156    }
157
158    /// Register a new [quantized tensor](burn_backend::backend::Backend::QuantizedTensorPrimitive) with the corresponding [tensor ids](TensorId).
159    pub fn register_quantized_tensor<B>(
160        &mut self,
161        id: &TensorId,
162        tensor: B::QuantizedTensorPrimitive,
163    ) where
164        B: BackendIr<Handle = H>,
165    {
166        let handle = B::quantized_tensor_handle(tensor);
167        self.handles.insert(*id, Handle::Existing(handle));
168    }
169
170    /// Register a new [int tensor](burn_backend::backend::Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId).
171    pub fn register_int_tensor<B>(&mut self, id: &TensorId, tensor: B::IntTensorPrimitive)
172    where
173        B: BackendIr<Handle = H>,
174    {
175        let handle = B::int_tensor_handle(tensor);
176        self.handles.insert(*id, Handle::Existing(handle));
177    }
178
179    /// Register a new [bool tensor](burn_backend::backend::Backend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId).
180    pub fn register_bool_tensor<B>(&mut self, id: &TensorId, tensor: B::BoolTensorPrimitive)
181    where
182        B: BackendIr<Handle = H>,
183    {
184        let handle = B::bool_tensor_handle(tensor);
185        self.handles.insert(*id, Handle::Existing(handle));
186    }
187
188    /// Lazily create a new empty tensor and return its corresponding [tensor id](TensorId).
189    pub fn create_tensor_uninit(&mut self) -> TensorId {
190        let id = TensorId::new(self.counter);
191        self.counter += 1;
192        self.handles.insert(id, Handle::NotInit);
193        id
194    }
195
196    /// Remove tensor handle from container.
197    pub fn remove_handle(&mut self, id: TensorId) -> Option<Handle<H>> {
198        self.handles.remove(&id)
199    }
200
201    /// Remove tensor handle from container if writable
202    pub fn free(&mut self, tensor: &TensorIr) {
203        match tensor.status {
204            TensorStatus::ReadOnly => (),
205            TensorStatus::NotInit => (),
206            TensorStatus::ReadWrite => {
207                self.handles.remove(&tensor.id);
208            }
209        };
210    }
211
212    /// Returns the number of handles.
213    pub fn num_handles(&self) -> usize {
214        self.handles.len()
215    }
216}