// burn_tensor/repr/handle.rs

1use crate::{
2    repr::{
3        backend::ReprBackend,
4        tensor::{TensorDescription, TensorId, TensorStatus},
5    },
6    Shape,
7};
8use alloc::vec::Vec;
9use hashbrown::HashMap;
10
11#[cfg(target_has_atomic = "ptr")]
12use alloc::sync::Arc;
13
14#[cfg(not(target_has_atomic = "ptr"))]
15use portable_atomic_util::Arc;
16
17use super::TensorHandle;
18
19/// Keep all [tensor handles](ReprBackend::Handle) in one place and ensure that all resources
20/// are used optimally.
21#[derive(Default)]
22pub struct HandleContainer<H> {
23    handles: HashMap<TensorId, Handle<H>>,
24    counter: u64,
25    /// Handle candidates to be freed.
26    pub handles_orphan: Vec<TensorId>,
27}
28
29impl<H> core::fmt::Debug for HandleContainer<H> {
30    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
31        f.debug_struct("HandleContainer")
32            .field("handles", &self.handles.keys()) // only care about the IDs when debugging
33            .field("counter", &self.counter)
34            .field("handles_orphan", &self.handles_orphan)
35            .finish()
36    }
37}
38
/// Wrapper around a backend [tensor handle](ReprBackend::Handle) recording whether that
/// handle has been created yet.
pub enum Handle<H> {
    /// Placeholder: the tensor is known, but no backend handle exists for it yet.
    NotInit,
    /// The backend handle, once it has been created.
    Existing(H),
}
46
47impl<H: Clone> HandleContainer<H> {
48    /// Create a new HandleContainer
49    pub fn new() -> Self {
50        Self {
51            handles: HashMap::new(),
52            handles_orphan: Vec::new(),
53            counter: 0,
54        }
55    }
56
57    /// Register a handle for the given [tensor id](TensorId).
58    pub fn register_handle(&mut self, id: TensorId, handle: H) {
59        self.handles.insert(id, Handle::Existing(handle));
60    }
61
62    /// Get the handle for the given [tensor id](TensorId). The status is used to determine if the
63    /// tensor should be popped out of the current tensor map, necessary for inplace operations.
64    ///
65    /// # Warnings
66    ///
67    /// Make sure the status corresponds to the operation you want to execute the handle on,
68    /// otherwise you might remove a tensor handle that will be required in the future.
69    pub fn get_handle(&mut self, id: &TensorId, status: &TensorStatus) -> H {
70        let (id, handle) = self
71            .handles
72            .remove_entry(id)
73            .unwrap_or_else(|| panic!("Should have handle for tensor {:?}", id));
74
75        match handle {
76            Handle::Existing(handle) => match status {
77                TensorStatus::ReadOnly => {
78                    self.handles.insert(id, Handle::Existing(handle.clone()));
79                    handle
80                }
81                TensorStatus::ReadWrite => handle,
82                TensorStatus::NotInit => panic!("Cannot get uninitialized tensor."),
83            },
84            Handle::NotInit => panic!("Cannot get uninitialized handle."),
85        }
86    }
87
88    /// Get the tensor handle for the given [tensor description](TensorDescription).
89    pub fn get_tensor_handle(&mut self, tensor: &TensorDescription) -> TensorHandle<H> {
90        TensorHandle {
91            handle: self.get_handle(&tensor.id, &tensor.status),
92            shape: Shape::from(&tensor.shape),
93        }
94    }
95
96    /// Get the [float tensor](crate::backend::Backend::FloatTensorPrimitive) corresponding to the
97    /// given [tensor description](TensorDescription).
98    pub fn get_float_tensor<B>(&mut self, tensor: &TensorDescription) -> B::FloatTensorPrimitive
99    where
100        B: ReprBackend<Handle = H>,
101    {
102        B::float_tensor(self.get_tensor_handle(tensor))
103    }
104
105    /// Get the [int tensor](crate::backend::Backend::IntTensorPrimitive) corresponding to the
106    /// given [tensor description](TensorDescription).
107    pub fn get_int_tensor<B>(&mut self, tensor: &TensorDescription) -> B::IntTensorPrimitive
108    where
109        B: ReprBackend<Handle = H>,
110    {
111        B::int_tensor(self.get_tensor_handle(tensor))
112    }
113
114    /// Get the [bool tensor](crate::backend::Backend::BoolTensorPrimitive) corresponding to the
115    /// given [tensor description](TensorDescription).
116    pub fn get_bool_tensor<B>(&mut self, tensor: &TensorDescription) -> B::BoolTensorPrimitive
117    where
118        B: ReprBackend<Handle = H>,
119    {
120        B::bool_tensor(self.get_tensor_handle(tensor))
121    }
122
123    /// Get the [quantized tensor](crate::backend::Backend::QuantizedTensorPrimitive) corresponding to the
124    /// given [tensor description](TensorDescription).
125    pub fn get_quantized_tensor<B>(
126        &mut self,
127        tensor: &TensorDescription,
128    ) -> B::QuantizedTensorPrimitive
129    where
130        B: ReprBackend<Handle = H>,
131    {
132        B::quantized_tensor(self.get_tensor_handle(tensor))
133    }
134
135    /// Register a new [float tensor](crate::backend::Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId).
136    pub fn register_float_tensor<B>(&mut self, id: &TensorId, tensor: B::FloatTensorPrimitive)
137    where
138        B: ReprBackend<Handle = H>,
139    {
140        let handle = B::float_tensor_handle(tensor);
141        self.handles.insert(*id, Handle::Existing(handle));
142    }
143
144    /// Register a new [quantized tensor](crate::backend::Backend::QuantizedTensorPrimitive) with the corresponding [tensor ids](TensorId).
145    pub fn register_quantized_tensor<B>(
146        &mut self,
147        id: &TensorId,
148        tensor: B::QuantizedTensorPrimitive,
149    ) where
150        B: ReprBackend<Handle = H>,
151    {
152        let handle = B::quantized_tensor_handle(tensor);
153        self.handles.insert(*id, Handle::Existing(handle));
154    }
155
156    /// Register a new [int tensor](crate::backend::Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId).
157    pub fn register_int_tensor<B>(&mut self, id: &TensorId, tensor: B::IntTensorPrimitive)
158    where
159        B: ReprBackend<Handle = H>,
160    {
161        let handle = B::int_tensor_handle(tensor);
162        self.handles.insert(*id, Handle::Existing(handle));
163    }
164
165    /// Register a new [bool tensor](crate::backend::Backend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId).
166    pub fn register_bool_tensor<B>(&mut self, id: &TensorId, tensor: B::BoolTensorPrimitive)
167    where
168        B: ReprBackend<Handle = H>,
169    {
170        let handle = B::bool_tensor_handle(tensor);
171        self.handles.insert(*id, Handle::Existing(handle));
172    }
173
174    /// Lazily create a new empty tensor and return its corresponding [tensor id](TensorId).
175    pub fn create_tensor_uninit(&mut self) -> Arc<TensorId> {
176        let id = TensorId::new(self.counter);
177        self.counter += 1;
178        self.handles.insert(id, Handle::NotInit);
179
180        Arc::new(id)
181    }
182
183    /// Remove tensor handle from container if writable
184    pub fn free(&mut self, tensor: &TensorDescription) {
185        match tensor.status {
186            TensorStatus::ReadOnly => (),
187            TensorStatus::NotInit => (),
188            TensorStatus::ReadWrite => {
189                self.handles.remove(&tensor.id);
190            }
191        }
192    }
193
194    /// Remove tensor handle from container if not in use
195    pub fn free_orphans(&mut self, remaining: &[&TensorId]) {
196        let mut handles_orphan = Vec::new();
197
198        // TODO: Optimization => Change the for loop order depending of the length of each.
199        for id in self.handles_orphan.drain(..) {
200            if remaining.contains(&&id) {
201                handles_orphan.push(id);
202            } else {
203                self.handles.remove(&id);
204            }
205        }
206
207        self.handles_orphan = handles_orphan;
208    }
209}