use alloc::vec::Vec;
use burn_std::{
Shape, Slice,
quantization::{QuantPropagation, QuantScheme},
};
use crate::tensor::{
BoolTensor, Device, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor,
quantization::{Calibration, QuantizationParametersPrimitive, compute_q_params, compute_range},
};
use crate::{
Backend, ExecutionError, QTensorPrimitive, TensorData, TensorMetadata, TensorPrimitive,
};
/// Runs a float operation on dequantized tensor(s) and requantizes the result.
///
/// Arguments:
/// - `ty` — a type providing `dequantize(q) -> f` and `quantize_dynamic(f, &scheme) -> q`
///   (inside `QTensorOps` default methods this is `Self`).
/// - `float_op` — a closure called with the dequantized tensor(s).
/// - one or two tensor expressions.
///
/// The first tensor must expose `scheme()`. Its cloned scheme is used to requantize the
/// output of `float_op`; the second tensor's scheme is ignored. Each tensor expression is
/// evaluated exactly once, so passing a call or any other non-trivial expression is safe.
#[macro_export]
macro_rules! dequant_op_quant {
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Bind the first operand once: its scheme must be read before `dequantize`
        // consumes it.
        let lhs_q = $t1;
        let scheme = lhs_q.scheme().clone();
        let lhs_f = <$ty>::dequantize(lhs_q);
        let rhs_f = <$ty>::dequantize($t2);
        // `float_op` is normally a closure literal, which is invoked immediately here.
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(lhs_f, rhs_f);
        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        // Bind the operand once: its scheme must be read before `dequantize` consumes it.
        let input_q = $tensor;
        let scheme = input_q.scheme().clone();
        let input_f = <$ty>::dequantize(input_q);
        // `float_op` is normally a closure literal, which is invoked immediately here.
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(input_f);
        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
}
/// Runs a float operation on dequantized tensor(s) and lets the first tensor's propagation
/// setting decide whether the result is requantized.
///
/// Arguments:
/// - `ty` — a type providing `dequantize(q) -> f` and `quantize_dynamic(f, &scheme) -> q`
///   (inside `QTensorOps` default methods this is `Self`).
/// - `float_op` — a closure called with the dequantized tensor(s).
/// - one or two tensor expressions.
///
/// The first tensor must expose `scheme()` and `propagation()`. The result is selected as
/// follows:
/// - `QuantPropagation::Propagate`: `TensorPrimitive::QFloat`, requantized with the first
///   tensor's scheme.
/// - `QuantPropagation::Inhibit`: `TensorPrimitive::Float`, holding the raw float output.
///
/// `QuantPropagation` and `TensorPrimitive` are referenced unqualified, so both must be in
/// scope at the call site. Each tensor expression is evaluated exactly once.
#[macro_export]
macro_rules! dequant_op_flow {
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Bind the first operand once: its scheme and propagation setting must both be
        // read before `dequantize` consumes it.
        let lhs_q = $t1;
        let scheme = lhs_q.scheme().clone();
        let propagation = lhs_q.propagation();
        let lhs_f = <$ty>::dequantize(lhs_q);
        let rhs_f = <$ty>::dequantize($t2);
        // `float_op` is normally a closure literal, which is invoked immediately here.
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(lhs_f, rhs_f);
        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
        }
    }};
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        // Bind the operand once: its scheme and propagation setting must both be read
        // before `dequantize` consumes it.
        let input_q = $tensor;
        let scheme = input_q.scheme().clone();
        let propagation = input_q.propagation();
        let input_f = <$ty>::dequantize(input_q);
        // `float_op` is normally a closure literal, which is invoked immediately here.
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(input_f);
        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
        }
    }};
}
pub trait QTensorOps<B: Backend> {
fn q_from_data(data: TensorData, device: &Device<B>) -> QuantizedTensor<B>;
fn quantize(
tensor: FloatTensor<B>,
scheme: &QuantScheme,
qparams: QuantizationParametersPrimitive<B>,
) -> QuantizedTensor<B>;
fn quantize_dynamic(tensor: FloatTensor<B>, scheme: &QuantScheme) -> QuantizedTensor<B> {
let (min, max) = compute_range::<B>(scheme, tensor.clone(), &Calibration::MinMax);
let qparams = compute_q_params(scheme, min, max);
Self::quantize(tensor, scheme, qparams)
}
fn dequantize(tensor: QuantizedTensor<B>) -> FloatTensor<B>;
fn q_device(tensor: &QuantizedTensor<B>) -> Device<B>;
fn q_to_device(tensor: QuantizedTensor<B>, device: &Device<B>) -> QuantizedTensor<B>;
fn q_reshape(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
fn q_into_data(
tensor: QuantizedTensor<B>,
) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
fn q_detach(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
tensor
}
fn q_set_require_grad(tensor: QuantizedTensor<B>, _require_grad: bool) -> QuantizedTensor<B> {
tensor
}
fn q_is_require_grad(_tensor: &QuantizedTensor<B>) -> bool {
false
}
fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
fn q_transpose(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
let ndims = tensor.shape().num_dims();
Self::q_swap_dims(tensor, ndims - 2, ndims - 1)
}
fn q_swap_dims(tensor: QuantizedTensor<B>, dim1: usize, dim2: usize) -> QuantizedTensor<B>;
fn q_permute(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
fn q_flip(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
fn q_select(
tensor: QuantizedTensor<B>,
dim: usize,
indices: IntTensor<B>,
) -> QuantizedTensor<B>;
fn q_slice(tensor: QuantizedTensor<B>, slices: &[Slice]) -> QuantizedTensor<B>;
fn q_gather(
dim: usize,
tensor: QuantizedTensor<B>,
indices: IntTensor<B>,
) -> QuantizedTensor<B> {
dequant_op_quant!(
ty Self,
float_op |tensor| B::float_gather(dim, tensor, indices),
tensor
)
}
fn q_repeat_dim(tensor: QuantizedTensor<B>, dim: usize, times: usize) -> QuantizedTensor<B> {
dequant_op_quant!(
ty Self,
float_op |tensor| B::float_repeat_dim(tensor, dim, times),
tensor
)
}
fn q_add(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |lhs, rhs| B::float_add(lhs, rhs),
lhs,
rhs
)
}
fn q_add_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_add_scalar(tensor, rhs),
lhs
)
}
fn q_clamp_min(tensor: QuantizedTensor<B>, min: FloatElem<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_clamp_min(tensor, min),
tensor
)
}
fn q_clamp_max(tensor: QuantizedTensor<B>, max: FloatElem<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_clamp_max(tensor, max),
tensor
)
}
fn q_clamp(
tensor: QuantizedTensor<B>,
min: FloatElem<B>,
max: FloatElem<B>,
) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_clamp(tensor, min, max),
tensor
)
}
fn q_sub(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |lhs, rhs| B::float_sub(lhs, rhs),
lhs,
rhs
)
}
fn q_sub_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_sub_scalar(tensor, rhs),
lhs
)
}
fn q_mul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |lhs, rhs| B::float_mul(lhs, rhs),
lhs,
rhs
)
}
fn q_mul_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_mul_scalar(tensor, rhs),
lhs
)
}
fn q_div(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |lhs, rhs| B::float_div(lhs, rhs),
lhs,
rhs
)
}
fn q_div_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_div_scalar(tensor, rhs),
lhs
)
}
fn q_matmul(lhs: TensorPrimitive<B>, rhs: TensorPrimitive<B>) -> TensorPrimitive<B> {
let mut propagation = QuantPropagation::Inhibit;
let mut scheme = QuantScheme::default();
let lhs = match lhs {
TensorPrimitive::Float(lhs) => lhs,
TensorPrimitive::QFloat(lhs) => {
propagation = lhs.propagation();
scheme = *lhs.scheme();
Self::dequantize(lhs)
}
};
let rhs = match rhs {
TensorPrimitive::Float(rhs) => rhs,
TensorPrimitive::QFloat(rhs) => {
propagation = rhs.propagation();
scheme = *rhs.scheme();
Self::dequantize(rhs)
}
};
let out_f = B::float_matmul(lhs, rhs);
match propagation {
QuantPropagation::Propagate => {
TensorPrimitive::QFloat(<Self>::quantize_dynamic(out_f, &scheme))
}
QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
}
}
fn q_neg(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_neg(tensor),
tensor
)
}
fn q_recip(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_recip(tensor),
tensor
)
}
fn q_sum(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_sum(tensor),
tensor
)
}
fn q_sum_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_sum_dim(tensor, dim),
tensor
)
}
fn q_prod(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_prod(tensor),
tensor
)
}
fn q_prod_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_prod_dim(tensor, dim),
tensor
)
}
fn q_mean(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_mean(tensor),
tensor
)
}
fn q_mean_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_mean_dim(tensor, dim),
tensor
)
}
fn q_cumsum(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_cumsum(tensor, dim),
tensor
)
}
fn q_cumprod(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_cumprod(tensor, dim),
tensor
)
}
fn q_cummin(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_cummin(tensor, dim),
tensor
)
}
fn q_cummax(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_cummax(tensor, dim),
tensor
)
}
fn q_exp(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_exp(tensor),
tensor
)
}
fn q_log(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_log(tensor),
tensor
)
}
fn q_log1p(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_log1p(tensor),
tensor
)
}
fn q_powf(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |lhs, rhs| B::float_powf(lhs, rhs),
lhs,
rhs
)
}
fn q_powi(lhs: QuantizedTensor<B>, rhs: IntTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_powi(tensor, rhs),
lhs
)
}
fn q_powi_scalar(lhs: QuantizedTensor<B>, rhs: IntElem<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_powi_scalar(tensor, rhs),
lhs
)
}
fn q_powf_scalar(tensor: QuantizedTensor<B>, value: f32) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_powf_scalar(tensor, value),
tensor
)
}
fn q_sqrt(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_sqrt(tensor),
tensor
)
}
fn q_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
dequant_op_quant!(
ty Self,
float_op |tensor| B::float_abs(tensor),
tensor
)
}
fn q_cos(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_cos(tensor),
tensor
)
}
fn q_sin(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_sin(tensor),
tensor
)
}
fn q_tan(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_tan(tensor),
tensor
)
}
fn q_cosh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_cosh(tensor),
tensor
)
}
fn q_sinh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_sinh(tensor),
tensor
)
}
fn q_tanh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_tanh(tensor),
tensor
)
}
fn q_erf(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
dequant_op_flow!(
ty Self,
float_op |tensor| B::float_erf(tensor),
tensor
)
}
fn q_cat(tensors: Vec<QuantizedTensor<B>>, dim: usize) -> QuantizedTensor<B> {
let scheme = *tensors.first().unwrap().scheme();
let tensor_f = tensors
.into_iter()
.map(|tensor| Self::dequantize(tensor))
.collect();
let out_f = B::float_cat(tensor_f, dim);
Self::quantize_dynamic(out_f, &scheme)
}
fn q_argmax(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
let tensor_f = Self::dequantize(tensor);
B::float_argmax(tensor_f, dim)
}
fn q_argmin(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
let tensor_f = Self::dequantize(tensor);
B::float_argmin(tensor_f, dim)
}
fn q_max(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
let shape = tensor.shape();
let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
B::q_max_dim(tensor, 0)
}
fn q_max_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
let index = B::q_argmax(tensor.clone(), dim);
B::q_gather(dim, tensor, index)
}
fn q_max_dim_with_indices(
tensor: QuantizedTensor<B>,
dim: usize,
) -> (QuantizedTensor<B>, IntTensor<B>) {
let index = B::q_argmax(tensor.clone(), dim);
let values = B::q_gather(dim, tensor, index.clone());
(values, index)
}
fn q_min(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
let shape = tensor.shape();
let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
B::q_min_dim(tensor, 0)
}
fn q_min_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
let index = B::q_argmin(tensor.clone(), dim);
B::q_gather(dim, tensor, index)
}
fn q_min_dim_with_indices(
tensor: QuantizedTensor<B>,
dim: usize,
) -> (QuantizedTensor<B>, IntTensor<B>) {
let index = B::q_argmin(tensor.clone(), dim);
let values = B::q_gather(dim, tensor, index.clone());
(values, index)
}
fn q_max_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
let shape = tensor.shape();
let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
B::q_max_abs_dim(tensor, 0)
}
fn q_max_abs_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
let index = B::q_argmax(B::q_abs(tensor.clone()), dim);
B::q_gather(dim, tensor, index)
}
fn q_any(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
let tensor_f = Self::dequantize(tensor);
B::float_any(tensor_f)
}
fn q_any_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
let tensor_f = Self::dequantize(tensor);
B::float_any_dim(tensor_f, dim)
}
fn q_all(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
let tensor_f = Self::dequantize(tensor);
B::float_all(tensor_f)
}
fn q_all_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
let tensor_f = Self::dequantize(tensor);
B::float_all_dim(tensor_f, dim)
}
fn q_sort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> QuantizedTensor<B> {
dequant_op_quant!(
ty Self,
float_op |tensor| B::float_sort(tensor, dim, descending),
tensor
)
}
fn q_sort_with_indices(
tensor: QuantizedTensor<B>,
dim: usize,
descending: bool,
) -> (QuantizedTensor<B>, IntTensor<B>) {
let scheme = *tensor.scheme();
let tensor_f = Self::dequantize(tensor);
let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending);
(Self::quantize_dynamic(out_f, &scheme), indices)
}
fn q_argsort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
let tensor_f = Self::dequantize(tensor);
B::float_argsort(tensor_f, dim, descending)
}
}