burn_ir/backend.rs
1use burn_backend::{
2 Backend, Shape,
3 tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
4};
5
6/// A tensor representation containing a reference to a tensor resource with a given shape.
7#[derive(Clone)]
8pub struct TensorHandle<H: Clone> {
9 /// The type that can be used to point to a tensor of any kind.
10 pub handle: H,
11 /// The shape associated to the tensor.
12 pub shape: Shape,
13}
14
15/// Backend extension trait that allows an existing [backend](Backend) to use the Burn tensor
16/// intermediate representation for compilation purpose or other...
17pub trait BackendIr: Backend {
18 /// The type that can be used to point to a tensor of any kind.
19 type Handle: Sync + Send + Clone;
20
21 /// Convert a [handle](BackendIr::Handle) to a [float tensor](Backend::FloatTensorPrimitive).
22 fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self>;
23 /// Convert a [handle](BackendIr::Handle) to an [int tensor](Backend::IntTensorPrimitive).
24 fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self>;
25 /// Convert a [handle](BackendIr::Handle) to a [bool tensor](Backend::BoolTensorPrimitive).
26 fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self>;
27 /// Convert a [handle](BackendIr::Handle) to a [quantized tensor](Backend::QuantizedTensorPrimitive).
28 fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self>;
29
30 /// Convert a [float tensor](Backend::FloatTensorPrimitive) to a [handle](BackendIr::Handle).
31 fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle;
32 /// Convert an [int tensor](Backend::IntTensorPrimitive) to a [handle](BackendIr::Handle).
33 fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle;
34 /// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](BackendIr::Handle).
35 fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle;
36 /// Convert a [quantized tensor](Backend::QuantizedTensorPrimitive) to a [handle](BackendIr::Handle).
37 fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle;
38}
39
40/// Handle which points to a backend tensor primitive kind.
41#[derive(Clone, Debug)]
42pub enum HandleKind<B: Backend> {
43 /// Float tensor handle.
44 Float(B::FloatTensorPrimitive),
45 /// Int tensor handle.
46 Int(B::IntTensorPrimitive),
47 /// Bool tensor handle.
48 Bool(B::BoolTensorPrimitive),
49 /// Quantized tensor handle.
50 Quantized(B::QuantizedTensorPrimitive),
51}
52
53impl<B: Backend> HandleKind<B> {
54 /// Returns the handle kind name.
55 pub fn name(&self) -> &str {
56 match self {
57 HandleKind::Float(_) => "float",
58 HandleKind::Int(_) => "int",
59 HandleKind::Bool(_) => "bool",
60 HandleKind::Quantized(_) => "quantized",
61 }
62 }
63}