use burn_tensor::backend::Backend;
use burn_tensor::TensorMetadata;
use mlx_rs::Array;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::device::MlxDevice;
/// Seed value most recently stored by `Mlx::seed` and read back by `get_seed`; starts at 0.
static SEED: AtomicU64 = AtomicU64::new(0);
/// Tensor primitive of the MLX backend: an MLX array stored next to its dimensions.
///
/// Both fields are public, so code that replaces `array` directly is responsible for
/// keeping `shape` consistent with it.
#[derive(Clone, Debug)]
pub struct MlxTensorPrimitive {
    /// The wrapped MLX array.
    pub array: Array,
    /// Dimension sizes as `usize`; [`MlxTensorPrimitive::new`] derives them from `array.shape()`.
    pub shape: Vec<usize>,
}
impl MlxTensorPrimitive {
    /// Wraps `array`, caching each of its dimensions converted to `usize`.
    pub fn new(array: Array) -> Self {
        let dims: Vec<usize> = array.shape().iter().map(|&dim| dim as usize).collect();
        Self { array, shape: dims }
    }

    /// Borrows the wrapped MLX array.
    pub fn array(&self) -> &Array {
        &self.array
    }

    /// Borrows the cached dimension sizes.
    pub fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }
}
// NOTE(review): these impls assert that `MlxTensorPrimitive` (and therefore the wrapped
// `mlx_rs::Array`) may be moved to and shared between threads. Nothing in this file
// establishes that invariant; confirm it against mlx_rs's threading guarantees.
unsafe impl Send for MlxTensorPrimitive {}
unsafe impl Sync for MlxTensorPrimitive {}
impl TensorMetadata for MlxTensorPrimitive {
    /// Translates the wrapped array's MLX dtype into the corresponding burn dtype.
    ///
    /// Every float, signed-integer, unsigned-integer and bool dtype is mapped to its
    /// same-width burn counterpart. Previously the narrow and unsigned integer dtypes
    /// fell through to the catch-all arm and were reported as `F32`.
    fn dtype(&self) -> burn_tensor::DType {
        match self.array.dtype() {
            mlx_rs::Dtype::Float32 => burn_tensor::DType::F32,
            mlx_rs::Dtype::Float16 => burn_tensor::DType::F16,
            mlx_rs::Dtype::Bfloat16 => burn_tensor::DType::BF16,
            mlx_rs::Dtype::Float64 => burn_tensor::DType::F64,
            mlx_rs::Dtype::Int8 => burn_tensor::DType::I8,
            mlx_rs::Dtype::Int16 => burn_tensor::DType::I16,
            mlx_rs::Dtype::Int32 => burn_tensor::DType::I32,
            mlx_rs::Dtype::Int64 => burn_tensor::DType::I64,
            mlx_rs::Dtype::Uint8 => burn_tensor::DType::U8,
            mlx_rs::Dtype::Uint16 => burn_tensor::DType::U16,
            mlx_rs::Dtype::Uint32 => burn_tensor::DType::U32,
            mlx_rs::Dtype::Uint64 => burn_tensor::DType::U64,
            mlx_rs::Dtype::Bool => burn_tensor::DType::Bool,
            // Anything still unmapped (e.g. complex) keeps the original best-effort
            // fallback rather than panicking.
            _ => burn_tensor::DType::F32,
        }
    }

    /// Returns a burn `Shape` built from a copy of the cached dimensions.
    fn shape(&self) -> burn_tensor::Shape {
        burn_tensor::Shape::from(self.shape.clone())
    }
}
/// Quantized tensor primitive of the MLX backend: a tensor primitive plus a scheme marker.
#[derive(Clone, Debug)]
pub struct MlxQuantizedTensorPrimitive {
    /// The wrapped tensor primitive.
    pub tensor: MlxTensorPrimitive,
    /// Scheme marker of this module's own `QuantizationScheme` type
    /// (not `burn_tensor::quantization::QuantizationScheme`).
    pub scheme: QuantizationScheme,
}
/// Module-local quantization scheme marker.
#[derive(Clone, Copy, Debug, Default)]
pub enum QuantizationScheme {
    /// The only variant, and the value produced by `Default`.
    #[default]
    None,
}
// NOTE(review): these impls assert that `MlxQuantizedTensorPrimitive` may be moved to and
// shared between threads. No justification appears in this file; soundness depends on
// the fields it contains, so confirm before relying on it.
unsafe impl Send for MlxQuantizedTensorPrimitive {}
unsafe impl Sync for MlxQuantizedTensorPrimitive {}
impl TensorMetadata for MlxQuantizedTensorPrimitive {
fn dtype(&self) -> burn_tensor::DType {
self.tensor.dtype()
}
fn shape(&self) -> burn_tensor::Shape {
burn_tensor::Shape::from(self.tensor.shape.clone())
}
}
impl burn_tensor::quantization::QTensorPrimitive for MlxQuantizedTensorPrimitive {
fn scheme(&self) -> &burn_tensor::quantization::QuantizationScheme {
static SYMMETRIC: burn_tensor::quantization::QuantizationScheme =
burn_tensor::quantization::QuantizationScheme::PerTensorSymmetric(
burn_tensor::quantization::QuantizationType::QInt8,
);
&SYMMETRIC
}
}
/// Zero-sized marker type that the burn `Backend` trait is implemented on for MLX.
#[derive(Clone, Copy, Debug, Default)]
pub struct Mlx;
impl Backend for Mlx {
    type Device = MlxDevice;

    // Float, int and bool tensors all share the same primitive type.
    type FloatTensorPrimitive = MlxTensorPrimitive;
    type FloatElem = f32;

    type IntTensorPrimitive = MlxTensorPrimitive;
    type IntElem = i32;

    type BoolTensorPrimitive = MlxTensorPrimitive;
    type BoolElem = bool;

    type QuantizedTensorPrimitive = MlxQuantizedTensorPrimitive;
    type QuantizedEncoding = i8;

    /// Returns the backend identifier, `"mlx"`.
    fn name() -> String {
        String::from("mlx")
    }

    /// Stores `seed` in the module-level `SEED` and then passes it to `mlx_rs::random::seed`.
    fn seed(seed: u64) {
        SEED.store(seed, Ordering::SeqCst);
        mlx_rs::random::seed(seed);
    }

    /// Currently a no-op: `device` is ignored and nothing is waited on.
    fn sync(device: &Self::Device) {
        // NOTE(review): if callers expect this to block until pending MLX work has
        // finished, that is not implemented here — confirm the intended contract.
        let _ = device;
    }
}
/// Returns the value currently held in the module-level `SEED`.
///
/// This is whatever was last stored by `Mlx::seed`, or 0 if nothing has been stored.
pub fn get_seed() -> u64 {
    SEED.load(Ordering::SeqCst)
}