//! burn-mpsgraph 0.0.1
//!
//! Apple MPSGraph backend for the Burn deep learning framework.
use alloc::string::String;
use burn_backend::{Backend, DType, DTypeUsage, DTypeUsageSet};
use std::sync::Mutex;

use crate::{MpsGraphDevice, MpsGraphQTensor, MpsGraphTensor};

extern crate alloc;

/// Globally seeded RNG, protected by a mutex.
///
/// The slot starts out as `None` and is replaced with a freshly seeded
/// `StdRng` each time the backend's `seed` function is called. Being a
/// single `static`, it is shared by every caller in the process rather than
/// being tracked per device.
// NOTE(review): how consumers treat the `None` (never seeded) state is decided
// outside this file section — confirm against the random ops.
pub(crate) static SEED: Mutex<Option<rand::rngs::StdRng>> = Mutex::new(None);

/// Apple MPSGraph backend for Burn.
///
/// This is a field-less marker type: it carries no state of its own and
/// exists to select the MPSGraph implementation of the `Backend` trait
/// (device, tensor primitives and element types) at the type level.
#[derive(Clone, Copy, Default, Debug)]
pub struct MpsGraph;

impl Backend for MpsGraph {
    type Device = MpsGraphDevice;

    type FloatTensorPrimitive = MpsGraphTensor;
    type FloatElem = f32;
    type IntTensorPrimitive = MpsGraphTensor;
    type IntElem = i32;
    type BoolTensorPrimitive = MpsGraphTensor;
    type BoolElem = bool;
    type QuantizedTensorPrimitive = MpsGraphQTensor;

    fn name(_device: &Self::Device) -> String { String::from("mpsgraph") }

    fn seed(_device: &Self::Device, seed: u64) {
        use rand::SeedableRng;
        *SEED.lock().unwrap() = Some(rand::rngs::StdRng::seed_from_u64(seed));
    }

    fn dtype_usage(_device: &Self::Device, dtype: DType) -> DTypeUsageSet {
        match dtype {
            DType::F32 | DType::Flex32 => DTypeUsage::general() | DTypeUsage::Accelerated,
            DType::F16              => DTypeUsage::general() | DTypeUsage::Accelerated,
            DType::BF16             => DTypeUsage::Storage | DTypeUsage::Accelerated,
            DType::I32 | DType::I16 | DType::I8 => DTypeUsage::general(),
            DType::U32 | DType::U16 | DType::U8 => DTypeUsage::general(),
            DType::I64 | DType::U64 => DTypeUsage::general(),
            DType::Bool => DTypeUsage::general(),
            DType::F64 => DTypeUsageSet::empty(), // Not supported by Metal
            DType::QFloat(_) => DTypeUsageSet::empty(),
        }
    }
}