burn_mpsgraph/ops/
qtensor.rs1use burn_backend::ops::{FloatTensorOps, QTensorOps};
2use burn_backend::tensor::{Device, FloatTensor, QuantizedTensor};
3use burn_backend::{DType, ExecutionError, TensorData};
4use burn_std::quantization::QuantScheme;
5use burn_std::{Shape, Slice};
6use std::future::Future;
7
8use crate::bridge;
9use crate::{MpsGraph, MpsGraphQTensor, MpsGraphTensor};
10
11impl QTensorOps<MpsGraph> for MpsGraph {
12 fn q_from_data(data: TensorData, device: &Device<MpsGraph>) -> QuantizedTensor<MpsGraph> {
13 let scheme = match data.dtype { DType::QFloat(s) => s, _ => panic!("Expected QFloat") };
14 let shape = Shape::from(data.shape.clone());
15 MpsGraphQTensor {
16 tensor: bridge::tensor_from_bytes(data.as_bytes(), shape, DType::I8, *device),
17 scheme,
18 }
19 }
20
21 fn quantize(tensor: FloatTensor<MpsGraph>, scheme: &QuantScheme, qparams: burn_backend::tensor::quantization::QuantizationParametersPrimitive<MpsGraph>) -> QuantizedTensor<MpsGraph> {
22 let sb = bridge::tensor_to_bytes(&qparams.scales);
23 let scale = f32::from_le_bytes([sb[0],sb[1],sb[2],sb[3]]);
24 let fb = bridge::tensor_to_bytes(&tensor);
25 let n = tensor.num_elements();
26 let mut qb = vec![0u8; n];
27 for i in 0..n {
28 let f = f32::from_le_bytes([fb[i*4],fb[i*4+1],fb[i*4+2],fb[i*4+3]]);
29 qb[i] = (f / scale).round().clamp(-128.0, 127.0) as i8 as u8;
30 }
31 MpsGraphQTensor {
32 tensor: bridge::tensor_from_bytes(&qb, tensor.shape.clone(), DType::I8, tensor.device),
33 scheme: *scheme,
34 }
35 }
36
37 fn dequantize(tensor: QuantizedTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
38 let qb = bridge::tensor_to_bytes(&tensor.tensor);
39 let n = tensor.tensor.num_elements();
40 let mut fb = vec![0u8; n * 4];
41 for i in 0..n {
42 let f = (qb[i] as i8) as f32;
43 fb[i*4..i*4+4].copy_from_slice(&f.to_le_bytes());
44 }
45 bridge::tensor_from_bytes(&fb, tensor.tensor.shape.clone(), DType::F32, tensor.tensor.device)
46 }
47
48 fn q_device(t: &QuantizedTensor<MpsGraph>) -> Device<MpsGraph> { t.tensor.device }
49 fn q_to_device(t: QuantizedTensor<MpsGraph>, d: &Device<MpsGraph>) -> QuantizedTensor<MpsGraph> {
50 { let buf = unsafe { crate::ffi::retain(t.tensor.buffer) }; MpsGraphQTensor { tensor: MpsGraphTensor { buffer: buf, shape: t.tensor.shape.clone(), dtype: t.tensor.dtype, device: *d }, scheme: t.scheme } }
51 }
52
53 fn q_reshape(t: QuantizedTensor<MpsGraph>, shape: Shape) -> QuantizedTensor<MpsGraph> {
54 let bytes = bridge::tensor_to_bytes(&t.tensor);
55 MpsGraphQTensor { tensor: bridge::tensor_from_bytes(&bytes, shape, t.tensor.dtype, t.tensor.device), scheme: t.scheme }
56 }
57
58 fn q_into_data(t: QuantizedTensor<MpsGraph>) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send {
59 async move { Ok(TensorData::from_bytes_vec(bridge::tensor_to_bytes(&t.tensor), t.tensor.shape.clone(), DType::QFloat(t.scheme))) }
60 }
61
62 fn q_swap_dims(t: QuantizedTensor<MpsGraph>, d1: usize, d2: usize) -> QuantizedTensor<MpsGraph> {
63 let f = Self::dequantize(t.clone()); let s = MpsGraph::float_swap_dims(f,d1,d2); Self::quantize_dynamic(s, &t.scheme)
64 }
65 fn q_permute(t: QuantizedTensor<MpsGraph>, axes: &[usize]) -> QuantizedTensor<MpsGraph> {
66 let f = Self::dequantize(t.clone()); let p = MpsGraph::float_permute(f, axes); Self::quantize_dynamic(p, &t.scheme)
67 }
68 fn q_flip(t: QuantizedTensor<MpsGraph>, axes: &[usize]) -> QuantizedTensor<MpsGraph> {
69 let f = Self::dequantize(t.clone()); let fl = MpsGraph::float_flip(f, axes); Self::quantize_dynamic(fl, &t.scheme)
70 }
71 fn q_select(t: QuantizedTensor<MpsGraph>, dim: usize, idx: burn_backend::tensor::IntTensor<MpsGraph>) -> QuantizedTensor<MpsGraph> {
72 let f = Self::dequantize(t.clone()); let s = MpsGraph::float_select(f,dim,idx); Self::quantize_dynamic(s, &t.scheme)
73 }
74 fn q_slice(t: QuantizedTensor<MpsGraph>, slices: &[Slice]) -> QuantizedTensor<MpsGraph> {
75 let f = Self::dequantize(t.clone()); let s = MpsGraph::float_slice(f,slices); Self::quantize_dynamic(s, &t.scheme)
76 }
77 fn q_expand(t: QuantizedTensor<MpsGraph>, shape: Shape) -> QuantizedTensor<MpsGraph> {
78 let f = Self::dequantize(t.clone()); let e = MpsGraph::float_expand(f,shape); Self::quantize_dynamic(e, &t.scheme)
79 }
80}