use alloc::string::String;
use burn_backend::{Backend, DType, DTypeUsage, DTypeUsageSet};
use std::sync::Mutex;
use crate::{MpsGraphDevice, MpsGraphQTensor, MpsGraphTensor};
extern crate alloc;
/// Shared random-number-generator state for this backend.
///
/// Starts out empty (`None`). `MpsGraph::seed` replaces the contents with a
/// freshly seeded `StdRng`; access is serialized through the `Mutex`.
pub(crate) static SEED: Mutex<Option<rand::rngs::StdRng>> = Mutex::new(Option::None);
/// Unit type that implements `Backend` for the `mpsgraph` backend.
///
/// It has no fields, so it is freely copyable and default-constructible; all
/// backend behavior comes from its `Backend` implementation.
#[derive(Debug, Default, Clone, Copy)]
pub struct MpsGraph;
impl Backend for MpsGraph {
type Device = MpsGraphDevice;
type FloatTensorPrimitive = MpsGraphTensor;
type FloatElem = f32;
type IntTensorPrimitive = MpsGraphTensor;
type IntElem = i32;
type BoolTensorPrimitive = MpsGraphTensor;
type BoolElem = bool;
type QuantizedTensorPrimitive = MpsGraphQTensor;
fn name(_device: &Self::Device) -> String { String::from("mpsgraph") }
fn seed(_device: &Self::Device, seed: u64) {
use rand::SeedableRng;
*SEED.lock().unwrap() = Some(rand::rngs::StdRng::seed_from_u64(seed));
}
fn dtype_usage(_device: &Self::Device, dtype: DType) -> DTypeUsageSet {
match dtype {
DType::F32 | DType::Flex32 => DTypeUsage::general() | DTypeUsage::Accelerated,
DType::F16 => DTypeUsage::general() | DTypeUsage::Accelerated,
DType::BF16 => DTypeUsage::Storage | DTypeUsage::Accelerated,
DType::I32 | DType::I16 | DType::I8 => DTypeUsage::general(),
DType::U32 | DType::U16 | DType::U8 => DTypeUsage::general(),
DType::I64 | DType::U64 => DTypeUsage::general(),
DType::Bool => DTypeUsage::general(),
DType::F64 => DTypeUsageSet::empty(), DType::QFloat(_) => DTypeUsageSet::empty(),
}
}
}