use burn_backend::ops::{FloatTensorOps, QTensorOps};
use burn_backend::tensor::{Device, FloatTensor, QuantizedTensor};
use burn_backend::{DType, ExecutionError, TensorData};
use burn_std::quantization::QuantScheme;
use burn_std::{Shape, Slice};
use std::future::Future;
use crate::bridge;
use crate::{MpsGraph, MpsGraphQTensor, MpsGraphTensor};
/// Quantized-tensor operations for the [`MpsGraph`] backend.
///
/// A quantized tensor is represented as an `I8` [`MpsGraphTensor`] paired with the
/// [`QuantScheme`] it was produced with. Quantize/dequantize run on the host by
/// reading the tensor bytes back through `bridge::tensor_to_bytes`.
impl QTensorOps<MpsGraph> for MpsGraph {
    /// Builds a quantized tensor from host-side [`TensorData`].
    ///
    /// The bytes of `data` are handed to `bridge::tensor_from_bytes` as an `I8`
    /// tensor on `device`; the quantization scheme is taken from `data.dtype`.
    ///
    /// # Panics
    ///
    /// Panics if `data.dtype` is not `DType::QFloat`.
    fn q_from_data(data: TensorData, device: &Device<MpsGraph>) -> QuantizedTensor<MpsGraph> {
        let DType::QFloat(scheme) = data.dtype else {
            panic!("Expected QFloat")
        };
        let shape = Shape::from(data.shape.clone());
        MpsGraphQTensor {
            tensor: bridge::tensor_from_bytes(data.as_bytes(), shape, DType::I8, *device),
            scheme,
        }
    }

    /// Quantizes a float tensor to signed 8-bit values using a single scale.
    ///
    /// Each element is read as a little-endian `f32`, divided by the scale,
    /// rounded, clamped to `[-128, 127]` and stored as one byte.
    ///
    /// NOTE(review): only the first 4 bytes of `qparams.scales` are used, so any
    /// additional scale values (e.g. for block-wise schemes) are ignored, and the
    /// input is assumed to hold 4-byte `f32` elements — confirm against callers.
    ///
    /// # Panics
    ///
    /// Panics if the scales tensor holds fewer than 4 bytes, or if `tensor` holds
    /// fewer than `4 * num_elements` bytes.
    fn quantize(
        tensor: FloatTensor<MpsGraph>,
        scheme: &QuantScheme,
        qparams: burn_backend::tensor::quantization::QuantizationParametersPrimitive<MpsGraph>,
    ) -> QuantizedTensor<MpsGraph> {
        let scale_bytes = bridge::tensor_to_bytes(&qparams.scales);
        let scale = f32::from_le_bytes([
            scale_bytes[0],
            scale_bytes[1],
            scale_bytes[2],
            scale_bytes[3],
        ]);

        let float_bytes = bridge::tensor_to_bytes(&tensor);
        let n = tensor.num_elements();
        // Slice once up front: this keeps the panic on a too-short buffer while
        // letting the per-element work run without further bounds checks.
        let quantized: Vec<u8> = float_bytes[..n * 4]
            .chunks_exact(4)
            .map(|chunk| {
                let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
                // `as i8 as u8` stores the two's-complement bit pattern of the level.
                (value / scale).round().clamp(-128.0, 127.0) as i8 as u8
            })
            .collect();

        MpsGraphQTensor {
            tensor: bridge::tensor_from_bytes(
                &quantized,
                tensor.shape.clone(),
                DType::I8,
                tensor.device,
            ),
            scheme: *scheme,
        }
    }

    /// Converts a quantized tensor to an `F32` tensor of the same shape and device.
    ///
    /// Each stored byte is reinterpreted as `i8` and widened to `f32`.
    ///
    /// NOTE(review): no scale is applied here — the result holds the raw integer
    /// levels, not the values originally passed to `quantize`. The quantized tensor
    /// as constructed in this impl only carries the `I8` data and the scheme, so
    /// the scale is not available at this point; confirm whether this is intended.
    ///
    /// # Panics
    ///
    /// Panics if the tensor holds fewer bytes than `num_elements`.
    fn dequantize(tensor: QuantizedTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
        let quantized = bridge::tensor_to_bytes(&tensor.tensor);
        let n = tensor.tensor.num_elements();

        let mut float_bytes = Vec::with_capacity(n * 4);
        for &byte in &quantized[..n] {
            let value = f32::from(byte as i8);
            float_bytes.extend_from_slice(&value.to_le_bytes());
        }

        bridge::tensor_from_bytes(
            &float_bytes,
            tensor.tensor.shape.clone(),
            DType::F32,
            tensor.tensor.device,
        )
    }

    /// Returns the device recorded on the underlying tensor.
    fn q_device(t: &QuantizedTensor<MpsGraph>) -> Device<MpsGraph> {
        t.tensor.device
    }

    /// Re-tags the tensor with device `d`, sharing the same underlying buffer.
    ///
    /// No data is copied: the existing buffer handle is retained and reused.
    fn q_to_device(
        t: QuantizedTensor<MpsGraph>,
        d: &Device<MpsGraph>,
    ) -> QuantizedTensor<MpsGraph> {
        // SAFETY: `t` is owned by this function and still alive here, so its buffer
        // handle is presumed valid for `retain` — TODO confirm against the
        // contract of `crate::ffi::retain`.
        let buf = unsafe { crate::ffi::retain(t.tensor.buffer) };
        MpsGraphQTensor {
            tensor: MpsGraphTensor {
                buffer: buf,
                shape: t.tensor.shape.clone(),
                dtype: t.tensor.dtype,
                device: *d,
            },
            scheme: t.scheme,
        }
    }

    /// Reinterprets the tensor's bytes under a new `shape`.
    ///
    /// The data is read back to the host and re-uploaded with the same dtype and
    /// device; element order is unchanged.
    fn q_reshape(t: QuantizedTensor<MpsGraph>, shape: Shape) -> QuantizedTensor<MpsGraph> {
        let bytes = bridge::tensor_to_bytes(&t.tensor);
        MpsGraphQTensor {
            tensor: bridge::tensor_from_bytes(&bytes, shape, t.tensor.dtype, t.tensor.device),
            scheme: t.scheme,
        }
    }

    /// Reads the tensor back into [`TensorData`] tagged as `DType::QFloat(scheme)`.
    ///
    /// The read-back happens when the returned future is polled.
    fn q_into_data(
        t: QuantizedTensor<MpsGraph>,
    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send {
        async move {
            let bytes = bridge::tensor_to_bytes(&t.tensor);
            Ok(TensorData::from_bytes_vec(
                bytes,
                t.tensor.shape.clone(),
                DType::QFloat(t.scheme),
            ))
        }
    }

    // The layout operations below all follow the same route: dequantize, apply the
    // float implementation, then re-quantize with `quantize_dynamic` under the
    // original scheme.
    //
    // NOTE(review): because the result is re-quantized, the stored integer levels
    // may differ from the input's — confirm this is acceptable for pure layout ops.

    /// Swaps dimensions `d1` and `d2` via the float implementation.
    fn q_swap_dims(
        t: QuantizedTensor<MpsGraph>,
        d1: usize,
        d2: usize,
    ) -> QuantizedTensor<MpsGraph> {
        // The scheme is `Copy`; grab it first so `t` can be moved without cloning.
        let scheme = t.scheme;
        let swapped = MpsGraph::float_swap_dims(Self::dequantize(t), d1, d2);
        Self::quantize_dynamic(swapped, &scheme)
    }

    /// Permutes the dimensions according to `axes` via the float implementation.
    fn q_permute(t: QuantizedTensor<MpsGraph>, axes: &[usize]) -> QuantizedTensor<MpsGraph> {
        let scheme = t.scheme;
        let permuted = MpsGraph::float_permute(Self::dequantize(t), axes);
        Self::quantize_dynamic(permuted, &scheme)
    }

    /// Flips the tensor along `axes` via the float implementation.
    fn q_flip(t: QuantizedTensor<MpsGraph>, axes: &[usize]) -> QuantizedTensor<MpsGraph> {
        let scheme = t.scheme;
        let flipped = MpsGraph::float_flip(Self::dequantize(t), axes);
        Self::quantize_dynamic(flipped, &scheme)
    }

    /// Selects entries along `dim` using `idx` via the float implementation.
    fn q_select(
        t: QuantizedTensor<MpsGraph>,
        dim: usize,
        idx: burn_backend::tensor::IntTensor<MpsGraph>,
    ) -> QuantizedTensor<MpsGraph> {
        let scheme = t.scheme;
        let selected = MpsGraph::float_select(Self::dequantize(t), dim, idx);
        Self::quantize_dynamic(selected, &scheme)
    }

    /// Slices the tensor with `slices` via the float implementation.
    fn q_slice(t: QuantizedTensor<MpsGraph>, slices: &[Slice]) -> QuantizedTensor<MpsGraph> {
        let scheme = t.scheme;
        let sliced = MpsGraph::float_slice(Self::dequantize(t), slices);
        Self::quantize_dynamic(sliced, &scheme)
    }

    /// Expands the tensor to `shape` via the float implementation.
    fn q_expand(t: QuantizedTensor<MpsGraph>, shape: Shape) -> QuantizedTensor<MpsGraph> {
        let scheme = t.scheme;
        let expanded = MpsGraph::float_expand(Self::dequantize(t), shape);
        Self::quantize_dynamic(expanded, &scheme)
    }
}