Skip to main content

burn_mpsgraph/
backend.rs

1use alloc::string::String;
2use burn_backend::{Backend, DType, DTypeUsage, DTypeUsageSet};
3use std::sync::Mutex;
4
5use crate::{MpsGraphDevice, MpsGraphQTensor, MpsGraphTensor};
6
7extern crate alloc;
8
9/// Globally seeded RNG, protected by a mutex.
10pub(crate) static SEED: Mutex<Option<rand::rngs::StdRng>> = Mutex::new(None);
11
/// Burn backend backed by Apple's MPSGraph.
///
/// This is a field-less marker type: it carries no state of its own and only
/// serves as the type on which the `Backend` trait is implemented.
#[derive(Debug, Default, Clone, Copy)]
pub struct MpsGraph;
15
16impl Backend for MpsGraph {
17    type Device = MpsGraphDevice;
18
19    type FloatTensorPrimitive = MpsGraphTensor;
20    type FloatElem = f32;
21    type IntTensorPrimitive = MpsGraphTensor;
22    type IntElem = i32;
23    type BoolTensorPrimitive = MpsGraphTensor;
24    type BoolElem = bool;
25    type QuantizedTensorPrimitive = MpsGraphQTensor;
26
27    fn name(_device: &Self::Device) -> String { String::from("mpsgraph") }
28
29    fn seed(_device: &Self::Device, seed: u64) {
30        use rand::SeedableRng;
31        *SEED.lock().unwrap() = Some(rand::rngs::StdRng::seed_from_u64(seed));
32    }
33
34    fn dtype_usage(_device: &Self::Device, dtype: DType) -> DTypeUsageSet {
35        match dtype {
36            DType::F32 | DType::Flex32 => DTypeUsage::general() | DTypeUsage::Accelerated,
37            DType::F16              => DTypeUsage::general() | DTypeUsage::Accelerated,
38            DType::BF16             => DTypeUsage::Storage | DTypeUsage::Accelerated,
39            DType::I32 | DType::I16 | DType::I8 => DTypeUsage::general(),
40            DType::U32 | DType::U16 | DType::U8 => DTypeUsage::general(),
41            DType::I64 | DType::U64 => DTypeUsage::general(),
42            DType::Bool => DTypeUsage::general(),
43            DType::F64 => DTypeUsageSet::empty(), // Not supported by Metal
44            DType::QFloat(_) => DTypeUsageSet::empty(),
45        }
46    }
47}