// burn_tensor/repr/handle.rs
use crate::{
repr::{
backend::ReprBackend,
tensor::{TensorDescription, TensorId, TensorStatus},
},
Shape,
};
use alloc::vec::Vec;
use hashbrown::HashMap;
#[cfg(target_has_atomic = "ptr")]
use alloc::sync::Arc;
#[cfg(not(target_has_atomic = "ptr"))]
use portable_atomic_util::Arc;
use super::{QuantizedKind, QuantizedTensorDescription, TensorHandle};
/// Registry mapping [`TensorId`]s to backend handles of type `H`.
///
/// Each entry is either a registered handle or a [`Handle::NotInit`] placeholder.
/// `counter` is the source of fresh ids.
pub struct HandleContainer<H> {
    handles: HashMap<TensorId, Handle<H>>,
    counter: u64,
    /// Ids whose handles `free_orphans` may release. Nothing in this type's visible
    /// methods pushes to it; presumably callers fill it (the field is public) — verify.
    pub handles_orphan: Vec<TensorId>,
}

// Implemented by hand instead of `#[derive(Default)]`: the derive would add an
// `H: Default` bound even though no field requires one, so the container could not
// be defaulted for handle types that lack `Default`.
impl<H> Default for HandleContainer<H> {
    fn default() -> Self {
        Self {
            handles: HashMap::new(),
            counter: 0,
            handles_orphan: Vec::new(),
        }
    }
}
impl<H> core::fmt::Debug for HandleContainer<H> {
    /// Debug view of the container.
    ///
    /// Only the registered tensor ids are listed, not the handles, so `H` itself does
    /// not need to implement `Debug`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let registered_ids = self.handles.keys();
        let mut builder = f.debug_struct("HandleContainer");
        builder.field("handles", &registered_ids);
        builder.field("counter", &self.counter);
        builder.field("handles_orphan", &self.handles_orphan);
        builder.finish()
    }
}
/// A slot stored in a [`HandleContainer`] for a given tensor id.
pub enum Handle<H> {
    /// Placeholder for an id that has been reserved (see `create_tensor_uninit`)
    /// but has no backend handle yet.
    NotInit,
    /// A backend handle registered for the id.
    Existing(H),
}
impl<H: Clone> HandleContainer<H> {
    /// Creates an empty container: no handles, no orphans, and the id counter at zero.
    pub fn new() -> Self {
        Self {
            handles: HashMap::new(),
            handles_orphan: Vec::new(),
            counter: 0,
        }
    }

    /// Registers `handle` as the initialized handle for `id`.
    ///
    /// Any entry previously stored under `id` (initialized or not) is replaced.
    pub fn register_handle(&mut self, id: TensorId, handle: H) {
        self.handles.insert(id, Handle::Existing(handle));
    }

    /// Returns the handle registered for `id`, according to the access `status`.
    ///
    /// * [`TensorStatus::ReadOnly`]: the entry stays registered and a clone is returned.
    /// * [`TensorStatus::ReadWrite`]: the entry is removed and its handle moved out.
    ///
    /// # Panics
    ///
    /// Panics if no entry exists for `id`, if the entry is [`Handle::NotInit`], or if
    /// `status` is [`TensorStatus::NotInit`]. On panic the map is left unmodified.
    pub fn get_handle(&mut self, id: &TensorId, status: &TensorStatus) -> H {
        match status {
            // Read-only access: a single lookup, the entry stays in place.
            TensorStatus::ReadOnly => match self.handles.get(id) {
                Some(Handle::Existing(handle)) => handle.clone(),
                Some(Handle::NotInit) => panic!("Cannot get uninitialized handle."),
                None => panic!("Should have handle for tensor {:?}", id),
            },
            // Read-write access: ownership moves to the caller, so the entry is removed.
            TensorStatus::ReadWrite => match self.handles.remove(id) {
                Some(Handle::Existing(handle)) => handle,
                Some(Handle::NotInit) => panic!("Cannot get uninitialized handle."),
                None => panic!("Should have handle for tensor {:?}", id),
            },
            // Always an error. The checks keep the same precedence as the other arms,
            // so a missing or uninitialized entry is reported first.
            TensorStatus::NotInit => match self.handles.get(id) {
                Some(Handle::Existing(_)) => panic!("Cannot get uninitialized tensor."),
                Some(Handle::NotInit) => panic!("Cannot get uninitialized handle."),
                None => panic!("Should have handle for tensor {:?}", id),
            },
        }
    }

    /// Fetches the handle for `tensor` (with [`Self::get_handle`] semantics for its
    /// status) and pairs it with the tensor's shape.
    pub fn get_tensor_handle(&mut self, tensor: &TensorDescription) -> TensorHandle<H> {
        TensorHandle {
            handle: self.get_handle(&tensor.id, &tensor.status),
            shape: Shape::from(&tensor.shape),
        }
    }

    /// Fetches the handle for `tensor` and converts it into backend `B`'s float
    /// tensor primitive.
    pub fn get_float_tensor<B>(&mut self, tensor: &TensorDescription) -> B::FloatTensorPrimitive
    where
        B: ReprBackend<Handle = H>,
    {
        B::float_tensor(self.get_tensor_handle(tensor))
    }

    /// Fetches the handle for `tensor` and converts it into backend `B`'s int
    /// tensor primitive.
    pub fn get_int_tensor<B>(&mut self, tensor: &TensorDescription) -> B::IntTensorPrimitive
    where
        B: ReprBackend<Handle = H>,
    {
        B::int_tensor(self.get_tensor_handle(tensor))
    }

    /// Fetches the handle for `tensor` and converts it into backend `B`'s bool
    /// tensor primitive.
    pub fn get_bool_tensor<B>(&mut self, tensor: &TensorDescription) -> B::BoolTensorPrimitive
    where
        B: ReprBackend<Handle = H>,
    {
        B::bool_tensor(self.get_tensor_handle(tensor))
    }

    /// Fetches the handles of a quantized tensor and assembles backend `B`'s
    /// quantized primitive using `tensor.scheme`.
    ///
    /// The handles are fetched in this order: the value tensor, the scale, and the
    /// offset (only when one is described). Each follows [`Self::get_handle`]
    /// semantics for its own status.
    pub fn get_quantized_tensor<B>(
        &mut self,
        tensor: &QuantizedTensorDescription,
    ) -> B::QuantizedTensorPrimitive
    where
        B: ReprBackend<Handle = H>,
    {
        let handles = QuantizedKind {
            tensor: self.get_tensor_handle(&tensor.tensor),
            scale: self.get_tensor_handle(&tensor.qparams.scale),
            offset: tensor
                .qparams
                .offset
                .as_ref()
                .map(|offset| self.get_tensor_handle(offset)),
        };
        B::quantized_tensor(handles, tensor.scheme.clone())
    }

    /// Converts a float primitive of backend `B` into a handle and registers it
    /// under `id`.
    pub fn register_float_tensor<B>(&mut self, id: &TensorId, tensor: B::FloatTensorPrimitive)
    where
        B: ReprBackend<Handle = H>,
    {
        self.register_handle(*id, B::float_tensor_handle(tensor));
    }

    /// Converts a quantized primitive of backend `B` into handles and registers
    /// them: the value tensor and scale always, and the offset when present.
    pub fn register_quantized_tensor<B>(
        &mut self,
        id: &QuantizedKind<TensorId>,
        tensor: B::QuantizedTensorPrimitive,
    ) where
        B: ReprBackend<Handle = H>,
    {
        let handles = B::quantized_tensor_handle(tensor);
        self.register_handle(id.tensor, handles.tensor);
        self.register_handle(id.scale, handles.scale);
        // The offset is registered only when both an id and a handle are present.
        // NOTE(review): a mismatch (one side Some, the other None) is silently
        // ignored here; confirm that this is intended.
        if let (Some(offset_id), Some(offset_handle)) = (id.offset, handles.offset) {
            self.register_handle(offset_id, offset_handle);
        }
    }

    /// Converts an int primitive of backend `B` into a handle and registers it
    /// under `id`.
    pub fn register_int_tensor<B>(&mut self, id: &TensorId, tensor: B::IntTensorPrimitive)
    where
        B: ReprBackend<Handle = H>,
    {
        self.register_handle(*id, B::int_tensor_handle(tensor));
    }

    /// Converts a bool primitive of backend `B` into a handle and registers it
    /// under `id`.
    pub fn register_bool_tensor<B>(&mut self, id: &TensorId, tensor: B::BoolTensorPrimitive)
    where
        B: ReprBackend<Handle = H>,
    {
        self.register_handle(*id, B::bool_tensor_handle(tensor));
    }

    /// Reserves a fresh tensor id and registers a [`Handle::NotInit`] placeholder
    /// for it. Returns the id wrapped in an `Arc`.
    ///
    /// Ids are taken from this container's counter, which is incremented on every call.
    pub fn create_tensor_uninit(&mut self) -> Arc<TensorId> {
        let id = TensorId::new(self.counter);
        self.counter += 1;
        self.handles.insert(id, Handle::NotInit);
        Arc::new(id)
    }

    /// Releases the entry of `tensor` when its status is [`TensorStatus::ReadWrite`].
    ///
    /// Read-only and uninitialized tensors are left untouched.
    pub fn free(&mut self, tensor: &TensorDescription) {
        match tensor.status {
            TensorStatus::ReadOnly | TensorStatus::NotInit => (),
            TensorStatus::ReadWrite => {
                self.handles.remove(&tensor.id);
            }
        }
    }

    /// Releases the entry of every orphaned id that is not listed in `remaining`.
    ///
    /// Ids that appear in `remaining` stay in `handles_orphan`, in their original
    /// order, for a later call.
    pub fn free_orphans(&mut self, remaining: &[&TensorId]) {
        // Borrow the two fields separately so the retain closure can mutate `handles`.
        let handles = &mut self.handles;
        self.handles_orphan.retain(|id| {
            let keep = remaining.contains(&id);
            if !keep {
                handles.remove(id);
            }
            keep
        });
    }
}