Skip to main content

burn_ir/
backend.rs

1use burn_backend::{
2    Backend, Shape,
3    tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
4};
5
6/// A tensor representation containing a reference to a tensor resource with a given shape.
7#[derive(Clone)]
8pub struct TensorHandle<H: Clone> {
9    /// The type that can be used to point to a tensor of any kind.
10    pub handle: H,
11    /// The shape associated to the tensor.
12    pub shape: Shape,
13}
14
15/// Backend extension trait that allows an existing [backend](Backend) to use the Burn tensor
16/// intermediate representation for compilation purpose or other...
17pub trait BackendIr: Backend {
18    /// The type that can be used to point to a tensor of any kind.
19    type Handle: Sync + Send + Clone;
20
21    /// Convert a [handle](BackendIr::Handle) to a [float tensor](Backend::FloatTensorPrimitive).
22    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self>;
23    /// Convert a [handle](BackendIr::Handle) to an [int tensor](Backend::IntTensorPrimitive).
24    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self>;
25    /// Convert a [handle](BackendIr::Handle) to a [bool tensor](Backend::BoolTensorPrimitive).
26    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self>;
27    /// Convert a [handle](BackendIr::Handle) to a [quantized tensor](Backend::QuantizedTensorPrimitive).
28    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self>;
29
30    /// Convert a [float tensor](Backend::FloatTensorPrimitive) to a [handle](BackendIr::Handle).
31    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle;
32    /// Convert an [int tensor](Backend::IntTensorPrimitive) to a [handle](BackendIr::Handle).
33    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle;
34    /// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](BackendIr::Handle).
35    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle;
36    /// Convert a [quantized tensor](Backend::QuantizedTensorPrimitive) to a [handle](BackendIr::Handle).
37    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle;
38}
39
40/// Handle which points to a backend tensor primitive kind.
41#[derive(Clone, Debug)]
42pub enum HandleKind<B: Backend> {
43    /// Float tensor handle.
44    Float(B::FloatTensorPrimitive),
45    /// Int tensor handle.
46    Int(B::IntTensorPrimitive),
47    /// Bool tensor handle.
48    Bool(B::BoolTensorPrimitive),
49    /// Quantized tensor handle.
50    Quantized(B::QuantizedTensorPrimitive),
51}
52
53impl<B: Backend> HandleKind<B> {
54    /// Returns the handle kind name.
55    pub fn name(&self) -> &str {
56        match self {
57            HandleKind::Float(_) => "float",
58            HandleKind::Int(_) => "int",
59            HandleKind::Bool(_) => "bool",
60            HandleKind::Quantized(_) => "quantized",
61        }
62    }
63}