burn_tensor/repr/
backend.rs

1use crate::{
2    backend::Backend,
3    ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
4    Shape,
5};
6
7/// A tensor representation containing a reference to a tensor resource with a given shape.
8#[derive(Clone)]
9pub struct TensorHandle<H: Clone> {
10    /// The type that can be used to point to a tensor of any kind.
11    pub handle: H,
12    /// The shape associated to the tensor.
13    pub shape: Shape,
14}
15
16/// Backend extension trait that allows an existing [backend](Backend) to use the Burn tensor representation
17/// for compilation purpose or other...
18pub trait ReprBackend: Backend {
19    /// The type that can be used to point to a tensor of any kind.
20    type Handle: Sync + Send + Clone;
21
22    /// Convert a [handle](ReprBackend::Handle) to a [float tensor](Backend::FloatTensorPrimitive).
23    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self>;
24    /// Convert a [handle](ReprBackend::Handle) to an [int tensor](Backend::IntTensorPrimitive).
25    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self>;
26    /// Convert a [handle](ReprBackend::Handle) to a [bool tensor](Backend::BoolTensorPrimitive).
27    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self>;
28    /// Convert a [handle](ReprBackend::Handle) to a [quantized tensor](Backend::QuantizedTensorPrimitive).
29    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self>;
30
31    /// Convert a [float tensor](Backend::FloatTensorPrimitive) to a [handle](ReprBackend::Handle).
32    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle;
33    /// Convert an [int tensor](Backend::IntTensorPrimitive) to a [handle](ReprBackend::Handle).
34    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle;
35    /// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](ReprBackend::Handle).
36    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle;
37    /// Convert a [quantized tensor](Backend::QuantizedTensorPrimitive) to a [handle](ReprBackend::Handle).
38    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle;
39}
40
41/// Handle which points to a backend tensor primitive kind.
42#[derive(Clone, Debug)]
43pub enum HandleKind<B: Backend> {
44    /// Float tensor handle.
45    Float(B::FloatTensorPrimitive),
46    /// Int tensor handle.
47    Int(B::IntTensorPrimitive),
48    /// Bool tensor handle.
49    Bool(B::BoolTensorPrimitive),
50    /// Quantized tensor handle.
51    Quantized(B::QuantizedTensorPrimitive),
52}
53
54impl<B: Backend> HandleKind<B> {
55    /// Returns the handle kind name.
56    pub fn name(&self) -> &str {
57        match self {
58            HandleKind::Float(_) => "float",
59            HandleKind::Int(_) => "int",
60            HandleKind::Bool(_) => "bool",
61            HandleKind::Quantized(_) => "quantized",
62        }
63    }
64}