burn_tensor/repr/backend.rs
1use crate::{
2 backend::Backend,
3 ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
4 Shape,
5};
6
7/// A tensor representation containing a reference to a tensor resource with a given shape.
8#[derive(Clone)]
9pub struct TensorHandle<H: Clone> {
10 /// The type that can be used to point to a tensor of any kind.
11 pub handle: H,
12 /// The shape associated to the tensor.
13 pub shape: Shape,
14}
15
16/// Backend extension trait that allows an existing [backend](Backend) to use the Burn tensor representation
17/// for compilation purpose or other...
18pub trait ReprBackend: Backend {
19 /// The type that can be used to point to a tensor of any kind.
20 type Handle: Sync + Send + Clone;
21
22 /// Convert a [handle](ReprBackend::Handle) to a [float tensor](Backend::FloatTensorPrimitive).
23 fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self>;
24 /// Convert a [handle](ReprBackend::Handle) to an [int tensor](Backend::IntTensorPrimitive).
25 fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self>;
26 /// Convert a [handle](ReprBackend::Handle) to a [bool tensor](Backend::BoolTensorPrimitive).
27 fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self>;
28 /// Convert a [handle](ReprBackend::Handle) to a [quantized tensor](Backend::QuantizedTensorPrimitive).
29 fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self>;
30
31 /// Convert a [float tensor](Backend::FloatTensorPrimitive) to a [handle](ReprBackend::Handle).
32 fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle;
33 /// Convert an [int tensor](Backend::IntTensorPrimitive) to a [handle](ReprBackend::Handle).
34 fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle;
35 /// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](ReprBackend::Handle).
36 fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle;
37 /// Convert a [quantized tensor](Backend::QuantizedTensorPrimitive) to a [handle](ReprBackend::Handle).
38 fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle;
39}
40
41/// Handle which points to a backend tensor primitive kind.
42#[derive(Clone, Debug)]
43pub enum HandleKind<B: Backend> {
44 /// Float tensor handle.
45 Float(B::FloatTensorPrimitive),
46 /// Int tensor handle.
47 Int(B::IntTensorPrimitive),
48 /// Bool tensor handle.
49 Bool(B::BoolTensorPrimitive),
50 /// Quantized tensor handle.
51 Quantized(B::QuantizedTensorPrimitive),
52}
53
54impl<B: Backend> HandleKind<B> {
55 /// Returns the handle kind name.
56 pub fn name(&self) -> &str {
57 match self {
58 HandleKind::Float(_) => "float",
59 HandleKind::Int(_) => "int",
60 HandleKind::Bool(_) => "bool",
61 HandleKind::Quantized(_) => "quantized",
62 }
63 }
64}