burn-mpsgraph 0.0.1

Apple MPSGraph backend for the Burn deep learning framework
Documentation
use burn_backend::{DType, QTensorPrimitive, Shape, TensorMetadata};
use burn_std::quantization::QuantScheme;

use crate::ffi::{self, Id};
use crate::MpsGraphDevice;

/// Returns the number of bytes one element of `dtype` occupies.
///
/// Supported types, grouped by width:
/// - 8 bytes: `F64`, `I64`, `U64`
/// - 4 bytes: `F32`, `Flex32`, `I32`, `U32`
/// - 2 bytes: `F16`, `BF16`, `I16`, `U16`
/// - 1 byte: `I8`, `U8`, `Bool`
///
/// # Panics
///
/// Panics for any other dtype (for example the quantized `QFloat` variants),
/// since no fixed per-element width is defined for them here.
pub fn elem_size(dtype: DType) -> usize {
    match dtype {
        DType::F64 | DType::I64 | DType::U64 => 8,
        DType::F32 | DType::Flex32 | DType::I32 | DType::U32 => 4,
        DType::F16 | DType::BF16 | DType::I16 | DType::U16 => 2,
        DType::I8 | DType::U8 | DType::Bool => 1,
        unsupported => panic!("Unsupported dtype: {:?}", unsupported),
    }
}

/// GPU-resident tensor backed by an `MTLBuffer`.
///
/// On Apple Silicon the buffer uses shared memory so both CPU and GPU can
/// access the same physical pages — no PCIe copy.
///
/// Ownership: each `MpsGraphTensor` owns one retained reference to `buffer`.
/// `Clone` retains the same buffer again, so the clone aliases the storage and
/// no data is copied. `Drop` releases that reference.
///
/// NOTE(review): the buffer's storage mode is decided wherever the buffer is
/// allocated, not in this file. Confirm it really is shared storage.
pub struct MpsGraphTensor {
    /// Retained MTLBuffer pointer.
    ///
    /// This field is public, so code that builds the struct directly must hand
    /// over an already-retained (+1) reference. `Drop` releases it
    /// unconditionally.
    pub buffer: Id,
    /// Logical shape of the tensor.
    pub shape: Shape,
    /// Element type of the data; [`elem_size`] maps it to a byte width.
    /// NOTE(review): presumably `buffer` holds at least
    /// `shape.num_elements() * elem_size(dtype)` bytes. Confirm at the
    /// allocation sites.
    pub dtype: DType,
    /// Device this tensor is associated with.
    pub device: MpsGraphDevice,
}

// MTLBuffer is thread-safe (shared storage mode).
//
// SAFETY: these impls assert that moving or sharing the `buffer` handle across
// threads is sound. Within this file the handle is only touched through
// `ffi::retain` (in `Clone`) and `ffi::release` (in `Drop`).
// NOTE(review): soundness relies on two things, both to be confirmed in `ffi`
// and the device type:
// - those calls are thread-safe reference-count operations;
// - the other fields (`Shape`, `DType`, `MpsGraphDevice`) are `Send + Sync`
//   themselves, because these impls override the auto-trait check for them
//   too.
// This type does no synchronization of the buffer *contents*. Concurrent
// CPU/GPU reads and writes must be ordered by callers.
unsafe impl Send for MpsGraphTensor {}
unsafe impl Sync for MpsGraphTensor {}

impl MpsGraphTensor {
    /// Total number of scalar elements, as reported by the tensor's `Shape`.
    pub fn num_elements(&self) -> usize {
        let shape = &self.shape;
        shape.num_elements()
    }
}

impl Clone for MpsGraphTensor {
    /// Makes a shallow copy that shares the same `MTLBuffer`.
    ///
    /// The buffer is retained rather than copied, so both tensors alias the
    /// same buffer object. Each one releases its own reference when dropped.
    fn clone(&self) -> Self {
        // SAFETY: `self.buffer` holds a retained reference for as long as
        // `self` is alive, so retaining it again is valid. The extra reference
        // is owned by the new tensor and balanced by its `Drop`.
        let buffer = unsafe { ffi::retain(self.buffer) };
        let shape = self.shape.clone();
        Self {
            buffer,
            shape,
            dtype: self.dtype,
            device: self.device,
        }
    }
}

impl Drop for MpsGraphTensor {
    /// Gives back the buffer reference this tensor owns.
    fn drop(&mut self) {
        let handle = self.buffer;
        // SAFETY: `handle` is the retained reference owned by this tensor. It
        // was presumably taken when the tensor was created, or by
        // `Clone::clone`. It is released exactly once, here, and `self` is not
        // used afterwards.
        unsafe { ffi::release(handle) };
    }
}

impl std::fmt::Debug for MpsGraphTensor {
    /// Prints `shape`, `dtype` and `device`. The raw `buffer` handle is not
    /// included in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MpsGraphTensor");
        out.field("shape", &self.shape);
        out.field("dtype", &self.dtype);
        out.field("device", &self.device);
        out.finish()
    }
}

impl TensorMetadata for MpsGraphTensor {
    fn dtype(&self) -> DType { self.dtype }
    fn shape(&self) -> Shape { self.shape.clone() }
    fn rank(&self) -> usize { self.shape.num_dims() }
}

/// Quantized tensor.
///
/// Pairs an inner [`MpsGraphTensor`] with the [`QuantScheme`] that describes
/// its quantization. `Clone` is derived, so cloning goes through
/// `MpsGraphTensor`'s `Clone` impl. That retains the same buffer instead of
/// copying the data.
#[derive(Debug, Clone)]
pub struct MpsGraphQTensor {
    /// Underlying storage tensor. Presumably it holds the quantized values;
    /// confirm against the quantize/dequantize ops.
    pub tensor: MpsGraphTensor,
    /// Quantization scheme; reported via `TensorMetadata::dtype` as
    /// `DType::QFloat(scheme)`.
    pub scheme: QuantScheme,
}

impl TensorMetadata for MpsGraphQTensor {
    fn dtype(&self) -> DType { DType::QFloat(self.scheme) }
    fn shape(&self) -> Shape { self.tensor.shape.clone() }
    fn rank(&self) -> usize { self.tensor.shape.num_dims() }
}

impl QTensorPrimitive for MpsGraphQTensor {
    /// Quantization scheme stored alongside the tensor.
    fn scheme(&self) -> &QuantScheme {
        let Self { scheme, .. } = self;
        scheme
    }
}