use burn_backend::{DType, QTensorPrimitive, Shape, TensorMetadata};
use burn_std::quantization::QuantScheme;
use crate::ffi::{self, Id};
use crate::MpsGraphDevice;
/// Returns the storage size, in bytes, of a single element of `dtype`.
///
/// # Panics
///
/// Panics with `"Unsupported dtype: ..."` for any dtype that is not one of
/// the float, integer or boolean variants listed in the match below.
pub fn elem_size(dtype: DType) -> usize {
    // Arms are grouped by width rather than by numeric family.
    match dtype {
        DType::F64 | DType::I64 | DType::U64 => 8,
        DType::F32 | DType::Flex32 | DType::I32 | DType::U32 => 4,
        DType::F16 | DType::BF16 | DType::I16 | DType::U16 => 2,
        DType::I8 | DType::U8 | DType::Bool => 1,
        unsupported => panic!("Unsupported dtype: {:?}", unsupported),
    }
}
/// Tensor primitive made of an FFI buffer handle plus the metadata
/// (shape, dtype, device) describing it.
///
/// The `Clone` impl in this file passes `buffer` through `ffi::retain` and
/// the `Drop` impl passes it to `ffi::release`, so every live value is
/// paired with exactly one `release` call.
pub struct MpsGraphTensor {
/// FFI object handle for the tensor's storage.
/// NOTE(review): presumably a reference-counted Objective-C object — confirm in `crate::ffi`.
pub buffer: Id,
/// Dimensions of the tensor.
pub shape: Shape,
/// Element type of the tensor.
pub dtype: DType,
/// Device this tensor is associated with.
pub device: MpsGraphDevice,
}
// SAFETY (NOTE(review)): these impls assert that an `MpsGraphTensor` —
// including its `Id` buffer handle — may be moved to and shared between
// threads. Nothing visible in this file establishes that guarantee; confirm
// it against the threading rules of the object behind `Id` and of
// `ffi::retain` / `ffi::release`.
unsafe impl Send for MpsGraphTensor {}
unsafe impl Sync for MpsGraphTensor {}
impl MpsGraphTensor {
    /// Returns the total number of elements, as reported by the tensor's shape.
    pub fn num_elements(&self) -> usize {
        let dims = &self.shape;
        dims.num_elements()
    }
}
impl Clone for MpsGraphTensor {
    /// Creates another tensor value over the same buffer handle.
    ///
    /// The handle is passed through `ffi::retain` and the returned handle is
    /// stored in the clone; shape, dtype and device are copied over.
    fn clone(&self) -> Self {
        // The retain call happens first, before any other field is copied.
        // NOTE(review): presumably this bumps a reference count so the
        // clone's `Drop` can release independently — confirm in `crate::ffi`.
        let buffer = unsafe { ffi::retain(self.buffer) };
        let shape = self.shape.clone();

        Self {
            buffer,
            shape,
            dtype: self.dtype,
            device: self.device,
        }
    }
}
impl Drop for MpsGraphTensor {
    /// Hands the buffer handle to `ffi::release` when the tensor goes away.
    fn drop(&mut self) {
        let handle = self.buffer;
        // Each tensor value releases its handle exactly once, here.
        unsafe {
            ffi::release(handle);
        }
    }
}
impl std::fmt::Debug for MpsGraphTensor {
    /// Formats the tensor as `MpsGraphTensor { shape, dtype, device }`.
    ///
    /// The `buffer` handle is not included in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MpsGraphTensor");
        out.field("shape", &self.shape);
        out.field("dtype", &self.dtype);
        out.field("device", &self.device);
        out.finish()
    }
}
impl TensorMetadata for MpsGraphTensor {
    /// Element type stored in the tensor.
    fn dtype(&self) -> DType {
        let Self { dtype, .. } = self;
        *dtype
    }

    /// Owned copy of the tensor's shape.
    fn shape(&self) -> Shape {
        let dims = &self.shape;
        dims.clone()
    }

    /// Number of dimensions in the tensor's shape.
    fn rank(&self) -> usize {
        let dims = &self.shape;
        dims.num_dims()
    }
}
/// Quantized tensor primitive: an [`MpsGraphTensor`] paired with the
/// quantization scheme that its `TensorMetadata::dtype` reports.
#[derive(Debug, Clone)]
pub struct MpsGraphQTensor {
/// Underlying tensor; its shape is used as the quantized tensor's shape.
/// NOTE(review): presumably holds the quantized values — confirm against the code that builds it.
pub tensor: MpsGraphTensor,
/// Quantization scheme, exposed through `QTensorPrimitive::scheme` and
/// wrapped in `DType::QFloat` by `TensorMetadata::dtype`.
pub scheme: QuantScheme,
}
impl TensorMetadata for MpsGraphQTensor {
    /// Reports `DType::QFloat` carrying this tensor's quantization scheme,
    /// not the dtype of the inner tensor.
    fn dtype(&self) -> DType {
        let scheme = self.scheme;
        DType::QFloat(scheme)
    }

    /// Owned copy of the inner tensor's shape.
    fn shape(&self) -> Shape {
        let dims = &self.tensor.shape;
        dims.clone()
    }

    /// Number of dimensions in the inner tensor's shape.
    fn rank(&self) -> usize {
        let dims = &self.tensor.shape;
        dims.num_dims()
    }
}
impl QTensorPrimitive for MpsGraphQTensor {
    /// Borrows the quantization scheme stored alongside the tensor.
    fn scheme(&self) -> &QuantScheme {
        let Self { scheme, .. } = self;
        scheme
    }
}