use std::marker::PhantomData;
use burn_tensor::{
quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy},
Element, Shape, TensorData,
};
use crate::{element::CandleElement, CandleDevice};
/// Wrapper around a [`candle_core::Tensor`] tagged with a burn element type `E` and a
/// const-generic rank `D`.
///
/// Neither `E` nor `D` is stored at runtime: `E` is carried by a `PhantomData` marker and `D`
/// exists only in the type.
#[derive(Clone, Debug)]
pub struct CandleTensor<E: CandleElement, const D: usize> {
    /// The wrapped candle tensor that holds the actual values.
    pub(crate) tensor: candle_core::Tensor,
    /// Zero-sized marker binding this wrapper to element type `E`.
    phantom: PhantomData<E>,
}
impl<E: CandleElement, const D: usize> CandleTensor<E, D> {
    /// Wraps an existing candle tensor.
    ///
    /// No check is performed here that `tensor` actually has rank `D` or a dtype matching `E`;
    /// the caller is trusted to pass a consistent tensor.
    pub fn new(tensor: candle_core::Tensor) -> Self {
        Self {
            tensor,
            phantom: PhantomData,
        }
    }

    /// Builds a tensor on `device` from burn [`TensorData`].
    ///
    /// The data is converted to element type `E` and then copied into a new candle tensor
    /// with the same shape as `data`.
    ///
    /// # Panics
    ///
    /// Panics if the converted data cannot be read as a slice of `E`, or if candle rejects
    /// the slice/shape/device combination when creating the tensor.
    pub fn from_data(data: TensorData, device: CandleDevice) -> Self {
        // Capture the shape first: `convert` consumes `data`.
        let candle_shape: candle_core::Shape = data.shape.clone().into();
        let data = data.convert::<E>();
        let values = data
            .as_slice::<E>()
            .expect("TensorData converted to E should be readable as a slice of E");
        let tensor = candle_core::Tensor::from_slice(values, candle_shape, &device.into())
            .expect("failed to create candle tensor from TensorData");
        Self::new(tensor)
    }

    /// Returns the shape of the wrapped candle tensor as a burn `Shape<D>`.
    ///
    /// # Panics
    ///
    /// Panics if the candle tensor's number of dimensions differs from `D`.
    pub(crate) fn shape(&self) -> Shape<D> {
        let dims: [usize; D] = self
            .tensor
            .dims()
            .try_into()
            .expect("candle tensor rank does not match the const generic rank D");
        Shape::from(dims)
    }
}
/// Quantized tensor for the candle backend: `u8` values of rank `D` together with the
/// [`QuantizationScheme`] associated with them.
#[derive(Debug, Clone)]
pub struct CandleQTensor<const D: usize> {
    /// The quantized values, held as a `u8` candle tensor of rank `D`.
    pub qtensor: CandleTensor<u8, D>,
    /// The quantization scheme returned by `QTensorPrimitive::scheme`.
    pub scheme: QuantizationScheme,
}
impl<const D: usize> QTensorPrimitive for CandleQTensor<D> {
    /// Returns a reference to the stored quantization scheme.
    fn scheme(&self) -> &QuantizationScheme {
        &self.scheme
    }

    /// Not yet supported for the candle backend.
    ///
    /// # Panics
    ///
    /// Always panics.
    // NOTE(review): `CandleQTensor` stores only `qtensor` and `scheme`; a real implementation
    // presumably needs quantization parameters (e.g. scale/offset) that are not kept here —
    // confirm before implementing.
    fn strategy(&self) -> QuantizationStrategy {
        todo!("strategy() is not yet supported for CandleQTensor")
    }
}