use burn_backend::{
Backend, Shape,
tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
};
/// A backend-specific tensor handle paired with a [`Shape`].
///
/// `H` is the handle type of a backend (for example `BackendIr::Handle`).
/// The `Clone` requirement is carried by the derived `Clone` impl rather than
/// by the struct definition, so the type can be named for any `H`.
#[derive(Clone)]
pub struct TensorHandle<H> {
    /// The backend-specific handle.
    pub handle: H,
    /// The shape associated with `handle`.
    pub shape: Shape,
}
/// Extension of [`Backend`] that converts tensor primitives to and from an
/// opaque handle type.
///
/// Every tensor kind (float, int, bool, quantized) has a pair of conversions:
/// one that builds the primitive from a [`TensorHandle`], and one that turns
/// the primitive back into a [`BackendIr::Handle`].
pub trait BackendIr: Backend {
    /// Backend-specific handle type; required to be `Clone`, `Send` and `Sync`.
    type Handle: Clone + Send + Sync;

    /// Builds a float tensor primitive from a handle and its shape.
    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self>;
    /// Converts a float tensor primitive into its handle.
    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle;

    /// Builds an int tensor primitive from a handle and its shape.
    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self>;
    /// Converts an int tensor primitive into its handle.
    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle;

    /// Builds a bool tensor primitive from a handle and its shape.
    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self>;
    /// Converts a bool tensor primitive into its handle.
    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle;

    /// Builds a quantized tensor primitive from a handle and its shape.
    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self>;
    /// Converts a quantized tensor primitive into its handle.
    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle;
}
/// A tensor primitive of backend `B`, tagged with its tensor kind.
///
/// Each variant wraps the corresponding primitive associated type of `B`.
#[derive(Clone, Debug)]
pub enum HandleKind<B: Backend> {
    /// Wraps a `B::FloatTensorPrimitive`.
    Float(B::FloatTensorPrimitive),
    /// Wraps a `B::IntTensorPrimitive`.
    Int(B::IntTensorPrimitive),
    /// Wraps a `B::BoolTensorPrimitive`.
    Bool(B::BoolTensorPrimitive),
    /// Wraps a `B::QuantizedTensorPrimitive`.
    Quantized(B::QuantizedTensorPrimitive),
}
impl<B: Backend> HandleKind<B> {
    /// Returns a short lowercase name for the kind of tensor this handle wraps:
    /// `"float"`, `"int"`, `"bool"`, or `"quantized"`.
    ///
    /// The result is a string literal, so it is returned as `&'static str`
    /// instead of being tied to the borrow of `self`. That lets callers keep the
    /// name after the `HandleKind` is moved or dropped.
    pub fn name(&self) -> &'static str {
        match self {
            Self::Float(_) => "float",
            Self::Int(_) => "int",
            Self::Bool(_) => "bool",
            Self::Quantized(_) => "quantized",
        }
    }
}