1use burn_backend::{DType, QTensorPrimitive, Shape, TensorMetadata};
2use burn_std::quantization::QuantScheme;
3
4use crate::ffi::{self, Id};
5use crate::MpsGraphDevice;
6
7pub fn elem_size(dtype: DType) -> usize {
9 match dtype {
10 DType::F64 => 8,
11 DType::F32 | DType::Flex32 => 4,
12 DType::F16 | DType::BF16 => 2,
13 DType::I64 | DType::U64 => 8,
14 DType::I32 | DType::U32 => 4,
15 DType::I16 | DType::U16 => 2,
16 DType::I8 | DType::U8 | DType::Bool => 1,
17 _ => panic!("Unsupported dtype: {:?}", dtype),
18 }
19}
20
/// Tensor primitive for this backend: an FFI buffer handle plus the host-side
/// metadata (`shape`, `dtype`, `device`) describing how to interpret it.
///
/// The `buffer` handle is reference-managed by this type: `Clone` calls
/// `ffi::retain` and `Drop` calls `ffi::release` on it.
pub struct MpsGraphTensor {
    /// Opaque FFI handle to the tensor's storage.
    /// NOTE(review): presumably an MTLBuffer / Objective-C object — confirm.
    pub buffer: Id,
    /// Logical shape of the tensor.
    pub shape: Shape,
    /// Element type; `elem_size(dtype)` gives the per-element byte width.
    pub dtype: DType,
    /// Device this tensor belongs to.
    pub device: MpsGraphDevice,
}
32
// SAFETY: NOTE(review): `MpsGraphTensor` holds a raw FFI handle (`Id`), which is
// presumably why these impls are written by hand instead of being auto-derived.
// They assert that the handle may be moved to another thread (`Send`) and shared
// by reference across threads (`Sync`), and that the `ffi::retain` /
// `ffi::release` calls made from `Clone` / `Drop` are safe to run concurrently.
// This presumably relies on the underlying Objective-C reference counting being
// atomic — confirm, and confirm nothing mutates the buffer through a shared
// `&MpsGraphTensor` without synchronization.
unsafe impl Send for MpsGraphTensor {}
// SAFETY: same justification as the `Send` impl directly above.
unsafe impl Sync for MpsGraphTensor {}
36
37impl MpsGraphTensor {
38 pub fn num_elements(&self) -> usize { self.shape.num_elements() }
39}
40
41impl Clone for MpsGraphTensor {
42 fn clone(&self) -> Self {
43 Self {
44 buffer: unsafe { ffi::retain(self.buffer) },
45 shape: self.shape.clone(),
46 dtype: self.dtype,
47 device: self.device,
48 }
49 }
50}
51
52impl Drop for MpsGraphTensor {
53 fn drop(&mut self) {
54 unsafe { ffi::release(self.buffer); }
55 }
56}
57
58impl std::fmt::Debug for MpsGraphTensor {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 f.debug_struct("MpsGraphTensor")
61 .field("shape", &self.shape)
62 .field("dtype", &self.dtype)
63 .field("device", &self.device)
64 .finish()
65 }
66}
67
68impl TensorMetadata for MpsGraphTensor {
69 fn dtype(&self) -> DType { self.dtype }
70 fn shape(&self) -> Shape { self.shape.clone() }
71 fn rank(&self) -> usize { self.shape.num_dims() }
72}
73
/// Quantized tensor primitive: a storage tensor paired with the quantization
/// scheme describing it.
///
/// Through `TensorMetadata`, it reports `DType::QFloat(scheme)` as its dtype
/// rather than the storage tensor's own `dtype`.
#[derive(Debug, Clone)]
pub struct MpsGraphQTensor {
    /// Underlying storage. Its `dtype` field describes the stored elements and
    /// may differ from the `QFloat` dtype reported for the quantized tensor.
    pub tensor: MpsGraphTensor,
    /// Quantization scheme; embedded in the reported `DType::QFloat`.
    pub scheme: QuantScheme,
}
80
81impl TensorMetadata for MpsGraphQTensor {
82 fn dtype(&self) -> DType { DType::QFloat(self.scheme) }
83 fn shape(&self) -> Shape { self.tensor.shape.clone() }
84 fn rank(&self) -> usize { self.tensor.shape.num_dims() }
85}
86
87impl QTensorPrimitive for MpsGraphQTensor {
88 fn scheme(&self) -> &QuantScheme { &self.scheme }
89}