1use alloc::string::String;
2use burn_backend::{Backend, DType, DTypeUsage, DTypeUsageSet};
3use std::sync::Mutex;
4
5use crate::{MpsGraphDevice, MpsGraphQTensor, MpsGraphTensor};
6
7extern crate alloc;
8
9pub(crate) static SEED: Mutex<Option<rand::rngs::StdRng>> = Mutex::new(None);
11
/// Zero-sized marker type through which the MPSGraph implementation of `Backend` is selected.
#[derive(Debug, Default, Clone, Copy)]
pub struct MpsGraph;
15
16impl Backend for MpsGraph {
17 type Device = MpsGraphDevice;
18
19 type FloatTensorPrimitive = MpsGraphTensor;
20 type FloatElem = f32;
21 type IntTensorPrimitive = MpsGraphTensor;
22 type IntElem = i32;
23 type BoolTensorPrimitive = MpsGraphTensor;
24 type BoolElem = bool;
25 type QuantizedTensorPrimitive = MpsGraphQTensor;
26
27 fn name(_device: &Self::Device) -> String { String::from("mpsgraph") }
28
29 fn seed(_device: &Self::Device, seed: u64) {
30 use rand::SeedableRng;
31 *SEED.lock().unwrap() = Some(rand::rngs::StdRng::seed_from_u64(seed));
32 }
33
34 fn dtype_usage(_device: &Self::Device, dtype: DType) -> DTypeUsageSet {
35 match dtype {
36 DType::F32 | DType::Flex32 => DTypeUsage::general() | DTypeUsage::Accelerated,
37 DType::F16 => DTypeUsage::general() | DTypeUsage::Accelerated,
38 DType::BF16 => DTypeUsage::Storage | DTypeUsage::Accelerated,
39 DType::I32 | DType::I16 | DType::I8 => DTypeUsage::general(),
40 DType::U32 | DType::U16 | DType::U8 => DTypeUsage::general(),
41 DType::I64 | DType::U64 => DTypeUsage::general(),
42 DType::Bool => DTypeUsage::general(),
43 DType::F64 => DTypeUsageSet::empty(), DType::QFloat(_) => DTypeUsageSet::empty(),
45 }
46 }
47}