use burn_backend::{DType, FloatDType, IntDType, Shape, quantization::QuantScheme};
use burn_backend::{Element, QTensorPrimitive, TensorData, TensorMetadata};
use crate::{CandleDevice, element::CandleElement};
/// Tensor primitive that wraps a [`candle_core::Tensor`].
#[derive(Clone, Debug)]
pub struct CandleTensor {
    /// The wrapped candle tensor; visible to the rest of this crate only.
    pub(crate) tensor: candle_core::Tensor,
}
impl TensorMetadata for CandleTensor {
fn dtype(&self) -> DType {
match self.tensor.dtype() {
candle_core::DType::U8 => DType::U8,
candle_core::DType::U32 => DType::U32,
candle_core::DType::I64 => DType::I64,
candle_core::DType::BF16 => DType::BF16,
candle_core::DType::F16 => DType::F16,
candle_core::DType::F32 => DType::F32,
candle_core::DType::F64 => DType::F64,
}
}
fn shape(&self) -> Shape {
Shape::from(self.tensor.dims().to_vec())
}
fn rank(&self) -> usize {
self.tensor.dims().len()
}
}
impl QTensorPrimitive for CandleTensor {
    /// Quantization scheme accessor required by [`QTensorPrimitive`].
    ///
    /// # Panics
    ///
    /// Always panics: this implementation does not provide a quantization scheme.
    fn scheme(&self) -> &QuantScheme {
        unimplemented!("Quantization is not supported")
    }
}
impl CandleTensor {
    /// Wraps an existing candle tensor.
    pub fn new(tensor: candle_core::Tensor) -> Self {
        Self { tensor }
    }

    /// Builds a tensor on `device` from `data`, reading its contents as elements of type `E`.
    ///
    /// # Panics
    ///
    /// Panics if `data` cannot be viewed as a slice of `E`, or if candle fails to
    /// create the tensor from that slice with the shape stored in `data`.
    pub fn from_data<E: CandleElement>(data: TensorData, device: CandleDevice) -> Self {
        // The shape is cloned because `data` stays borrowed by the slice below.
        let candle_shape: candle_core::Shape = data.shape.clone().into();
        let values = data.as_slice::<E>().unwrap();
        let candle_device: candle_core::Device = device.into();

        let created = candle_core::Tensor::from_slice(values, candle_shape, &candle_device);
        Self::new(created.unwrap())
    }
}
/// Conversion of a dtype description into a [`candle_core::DType`].
pub(crate) trait IntoDType {
    /// Attempts the conversion, returning a candle error when it is not possible.
    fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error>;

    /// Converts into a candle dtype.
    ///
    /// # Panics
    ///
    /// Panics if [`IntoDType::try_into_dtype`] returns an error.
    fn into_dtype(self) -> candle_core::DType
    where
        Self: Sized,
    {
        let converted = self.try_into_dtype();
        converted.unwrap()
    }
}
impl IntoDType for IntDType {
fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error> {
let dtype: DType = self.into();
dtype.try_into_dtype()
}
}
impl IntoDType for FloatDType {
fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error> {
let dtype: DType = self.into();
dtype.try_into_dtype()
}
}
impl IntoDType for DType {
fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error> {
match self {
DType::F64 => Ok(candle_core::DType::F64),
DType::F32 => Ok(candle_core::DType::F32),
DType::Flex32 => Ok(candle_core::DType::F32),
DType::F16 => Ok(candle_core::DType::F16),
DType::BF16 => Ok(candle_core::DType::BF16),
DType::I64 => Ok(candle_core::DType::I64),
DType::U32 => Ok(candle_core::DType::U32),
DType::U8 => Ok(candle_core::DType::U8),
_ => Err(candle_core::Error::Msg(format!(
"Unsupported dtype {self:?}"
))),
}
}
}