Skip to main content

burn_mpsgraph/ops/
qtensor.rs

1use burn_backend::ops::{FloatTensorOps, QTensorOps};
2use burn_backend::tensor::{Device, FloatTensor, QuantizedTensor};
3use burn_backend::{DType, ExecutionError, TensorData};
4use burn_std::quantization::QuantScheme;
5use burn_std::{Shape, Slice};
6use std::future::Future;
7
8use crate::bridge;
9use crate::{MpsGraph, MpsGraphQTensor, MpsGraphTensor};
10
11impl QTensorOps<MpsGraph> for MpsGraph {
12    fn q_from_data(data: TensorData, device: &Device<MpsGraph>) -> QuantizedTensor<MpsGraph> {
13        let scheme = match data.dtype { DType::QFloat(s) => s, _ => panic!("Expected QFloat") };
14        let shape = Shape::from(data.shape.clone());
15        MpsGraphQTensor {
16            tensor: bridge::tensor_from_bytes(data.as_bytes(), shape, DType::I8, *device),
17            scheme,
18        }
19    }
20
21    fn quantize(tensor: FloatTensor<MpsGraph>, scheme: &QuantScheme, qparams: burn_backend::tensor::quantization::QuantizationParametersPrimitive<MpsGraph>) -> QuantizedTensor<MpsGraph> {
22        let sb = bridge::tensor_to_bytes(&qparams.scales);
23        let scale = f32::from_le_bytes([sb[0],sb[1],sb[2],sb[3]]);
24        let fb = bridge::tensor_to_bytes(&tensor);
25        let n = tensor.num_elements();
26        let mut qb = vec![0u8; n];
27        for i in 0..n {
28            let f = f32::from_le_bytes([fb[i*4],fb[i*4+1],fb[i*4+2],fb[i*4+3]]);
29            qb[i] = (f / scale).round().clamp(-128.0, 127.0) as i8 as u8;
30        }
31        MpsGraphQTensor {
32            tensor: bridge::tensor_from_bytes(&qb, tensor.shape.clone(), DType::I8, tensor.device),
33            scheme: *scheme,
34        }
35    }
36
37    fn dequantize(tensor: QuantizedTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
38        let qb = bridge::tensor_to_bytes(&tensor.tensor);
39        let n = tensor.tensor.num_elements();
40        let mut fb = vec![0u8; n * 4];
41        for i in 0..n {
42            let f = (qb[i] as i8) as f32;
43            fb[i*4..i*4+4].copy_from_slice(&f.to_le_bytes());
44        }
45        bridge::tensor_from_bytes(&fb, tensor.tensor.shape.clone(), DType::F32, tensor.tensor.device)
46    }
47
48    fn q_device(t: &QuantizedTensor<MpsGraph>) -> Device<MpsGraph> { t.tensor.device }
49    fn q_to_device(t: QuantizedTensor<MpsGraph>, d: &Device<MpsGraph>) -> QuantizedTensor<MpsGraph> {
50        { let buf = unsafe { crate::ffi::retain(t.tensor.buffer) }; MpsGraphQTensor { tensor: MpsGraphTensor { buffer: buf, shape: t.tensor.shape.clone(), dtype: t.tensor.dtype, device: *d }, scheme: t.scheme } }
51    }
52
53    fn q_reshape(t: QuantizedTensor<MpsGraph>, shape: Shape) -> QuantizedTensor<MpsGraph> {
54        let bytes = bridge::tensor_to_bytes(&t.tensor);
55        MpsGraphQTensor { tensor: bridge::tensor_from_bytes(&bytes, shape, t.tensor.dtype, t.tensor.device), scheme: t.scheme }
56    }
57
58    fn q_into_data(t: QuantizedTensor<MpsGraph>) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send {
59        async move { Ok(TensorData::from_bytes_vec(bridge::tensor_to_bytes(&t.tensor), t.tensor.shape.clone(), DType::QFloat(t.scheme))) }
60    }
61
62    fn q_swap_dims(t: QuantizedTensor<MpsGraph>, d1: usize, d2: usize) -> QuantizedTensor<MpsGraph> {
63        let f = Self::dequantize(t.clone()); let s = MpsGraph::float_swap_dims(f,d1,d2); Self::quantize_dynamic(s, &t.scheme)
64    }
65    fn q_permute(t: QuantizedTensor<MpsGraph>, axes: &[usize]) -> QuantizedTensor<MpsGraph> {
66        let f = Self::dequantize(t.clone()); let p = MpsGraph::float_permute(f, axes); Self::quantize_dynamic(p, &t.scheme)
67    }
68    fn q_flip(t: QuantizedTensor<MpsGraph>, axes: &[usize]) -> QuantizedTensor<MpsGraph> {
69        let f = Self::dequantize(t.clone()); let fl = MpsGraph::float_flip(f, axes); Self::quantize_dynamic(fl, &t.scheme)
70    }
71    fn q_select(t: QuantizedTensor<MpsGraph>, dim: usize, idx: burn_backend::tensor::IntTensor<MpsGraph>) -> QuantizedTensor<MpsGraph> {
72        let f = Self::dequantize(t.clone()); let s = MpsGraph::float_select(f,dim,idx); Self::quantize_dynamic(s, &t.scheme)
73    }
74    fn q_slice(t: QuantizedTensor<MpsGraph>, slices: &[Slice]) -> QuantizedTensor<MpsGraph> {
75        let f = Self::dequantize(t.clone()); let s = MpsGraph::float_slice(f,slices); Self::quantize_dynamic(s, &t.scheme)
76    }
77    fn q_expand(t: QuantizedTensor<MpsGraph>, shape: Shape) -> QuantizedTensor<MpsGraph> {
78        let f = Self::dequantize(t.clone()); let e = MpsGraph::float_expand(f,shape); Self::quantize_dynamic(e, &t.scheme)
79    }
80}