Skip to main content

burn_mpsgraph/
tensor.rs

1use burn_backend::{DType, QTensorPrimitive, Shape, TensorMetadata};
2use burn_std::quantization::QuantScheme;
3
4use crate::ffi::{self, Id};
5use crate::MpsGraphDevice;
6
7/// Element size in bytes for a given DType.
8pub fn elem_size(dtype: DType) -> usize {
9    match dtype {
10        DType::F64 => 8,
11        DType::F32 | DType::Flex32 => 4,
12        DType::F16 | DType::BF16 => 2,
13        DType::I64 | DType::U64 => 8,
14        DType::I32 | DType::U32 => 4,
15        DType::I16 | DType::U16 => 2,
16        DType::I8 | DType::U8 | DType::Bool => 1,
17        _ => panic!("Unsupported dtype: {:?}", dtype),
18    }
19}
20
21/// GPU-resident tensor backed by an `MTLBuffer`.
22///
23/// On Apple Silicon the buffer uses shared memory so both CPU and GPU can
24/// access the same physical pages — no PCIe copy.
25pub struct MpsGraphTensor {
26    /// Retained MTLBuffer pointer.
27    pub buffer: Id,
28    pub shape: Shape,
29    pub dtype: DType,
30    pub device: MpsGraphDevice,
31}
32
33// MTLBuffer is thread-safe (shared storage mode).
34unsafe impl Send for MpsGraphTensor {}
35unsafe impl Sync for MpsGraphTensor {}
36
37impl MpsGraphTensor {
38    pub fn num_elements(&self) -> usize { self.shape.num_elements() }
39}
40
41impl Clone for MpsGraphTensor {
42    fn clone(&self) -> Self {
43        Self {
44            buffer: unsafe { ffi::retain(self.buffer) },
45            shape: self.shape.clone(),
46            dtype: self.dtype,
47            device: self.device,
48        }
49    }
50}
51
52impl Drop for MpsGraphTensor {
53    fn drop(&mut self) {
54        unsafe { ffi::release(self.buffer); }
55    }
56}
57
58impl std::fmt::Debug for MpsGraphTensor {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        f.debug_struct("MpsGraphTensor")
61            .field("shape", &self.shape)
62            .field("dtype", &self.dtype)
63            .field("device", &self.device)
64            .finish()
65    }
66}
67
68impl TensorMetadata for MpsGraphTensor {
69    fn dtype(&self) -> DType { self.dtype }
70    fn shape(&self) -> Shape { self.shape.clone() }
71    fn rank(&self) -> usize { self.shape.num_dims() }
72}
73
74/// Quantized tensor.
75#[derive(Debug, Clone)]
76pub struct MpsGraphQTensor {
77    pub tensor: MpsGraphTensor,
78    pub scheme: QuantScheme,
79}
80
81impl TensorMetadata for MpsGraphQTensor {
82    fn dtype(&self) -> DType { DType::QFloat(self.scheme) }
83    fn shape(&self) -> Shape { self.tensor.shape.clone() }
84    fn rank(&self) -> usize { self.tensor.shape.num_dims() }
85}
86
87impl QTensorPrimitive for MpsGraphQTensor {
88    fn scheme(&self) -> &QuantScheme { &self.scheme }
89}