use alloc::vec::Vec;
use core::ops::Range;
use crate::{
Device, Shape, TensorData, TensorMetadata, TensorPrimitive,
backend::Backend,
quantization::{
Calibration, QTensorPrimitive, QuantPropagation, QuantScheme,
QuantizationParametersPrimitive,
},
};
use super::{BoolTensor, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor};
/// Applies a float operation to quantized input(s) and always re-quantizes the result.
///
/// Expansion steps:
/// 1. clone the scheme of the (first) quantized operand,
/// 2. dequantize every operand with `<$ty>::dequantize`,
/// 3. call `$float_op` on the dequantized value(s),
/// 4. re-quantize the output with `<$ty>::quantize_dynamic` using the cloned scheme.
///
/// Two forms are accepted: a binary one (`$t1, $t2`) and a unary one (`$tensor`).
#[macro_export]
macro_rules! dequant_op_quant {
    // Binary form: the output scheme is the one of the first operand.
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Grab the scheme before the operand is handed over to `dequantize`.
        let out_scheme = $t1.scheme().clone();
        let (lhs_float, rhs_float) = (<$ty>::dequantize($t1), <$ty>::dequantize($t2));
        #[allow(clippy::redundant_closure_call)]
        let result_float = $float_op(lhs_float, rhs_float);
        <$ty>::quantize_dynamic(result_float, &out_scheme)
    }};
    // Unary form.
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        // Grab the scheme before the operand is handed over to `dequantize`.
        let out_scheme = $tensor.scheme().clone();
        let input_float = <$ty>::dequantize($tensor);
        #[allow(clippy::redundant_closure_call)]
        let result_float = $float_op(input_float);
        <$ty>::quantize_dynamic(result_float, &out_scheme)
    }};
}
/// Applies a float operation to quantized input(s) and lets the quantization scheme decide
/// whether the output stays quantized.
///
/// Expansion steps:
/// 1. clone the scheme of the (first) quantized operand,
/// 2. dequantize every operand with `<$ty>::dequantize`,
/// 3. call `$float_op` on the dequantized value(s),
/// 4. inspect `scheme.propagation`:
///    - `QuantPropagation::Propagate` re-quantizes the output with
///      `<$ty>::quantize_dynamic` and wraps it in `TensorPrimitive::QFloat`,
///    - `QuantPropagation::Inhibit` wraps the float output in `TensorPrimitive::Float`.
///
/// `QuantPropagation` and `TensorPrimitive` are referenced unqualified, so both names must be
/// in scope wherever the macro is invoked.
#[macro_export]
macro_rules! dequant_op_flow {
    // Binary form: the first operand's scheme drives the output.
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Grab the scheme before the operand is handed over to `dequantize`.
        let lhs_scheme = $t1.scheme().clone();
        let (lhs_float, rhs_float) = (<$ty>::dequantize($t1), <$ty>::dequantize($t2));
        #[allow(clippy::redundant_closure_call)]
        let result_float = $float_op(lhs_float, rhs_float);
        match lhs_scheme.propagation {
            QuantPropagation::Inhibit => TensorPrimitive::Float(result_float),
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(<$ty>::quantize_dynamic(result_float, &lhs_scheme))
            }
        }
    }};
    // Unary form.
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        // Grab the scheme before the operand is handed over to `dequantize`.
        let in_scheme = $tensor.scheme().clone();
        let input_float = <$ty>::dequantize($tensor);
        #[allow(clippy::redundant_closure_call)]
        let result_float = $float_op(input_float);
        match in_scheme.propagation {
            QuantPropagation::Inhibit => TensorPrimitive::Float(result_float),
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(<$ty>::quantize_dynamic(result_float, &in_scheme))
            }
        }
    }};
}
/// Operations on quantized tensors for a backend `B`.
///
/// Methods without a body must be supplied by the backend. Every method with a body is a
/// default that falls back to the float implementation: the input is dequantized, the
/// matching `B::float_*` operation is run, and the output is either re-quantized
/// (`dequant_op_quant!`) or returned according to the scheme's propagation setting
/// (`dequant_op_flow!`). Backends may override any default.
pub trait QTensorOps<B: Backend> {
    /// Builds a quantized tensor from `data` on the given `device`.
    fn q_from_data(data: TensorData, device: &Device<B>) -> QuantizedTensor<B>;

    /// Quantizes a float `tensor` using `scheme` and the supplied quantization parameters.
    fn quantize(
        tensor: FloatTensor<B>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<B>,
    ) -> QuantizedTensor<B>;

    /// Quantizes a float `tensor`, computing the quantization parameters on the fly.
    ///
    /// The value range is computed from the tensor itself with [`Calibration::MinMax`], the
    /// quantization parameters are derived from that range, and the result is passed to
    /// [`Self::quantize`].
    fn quantize_dynamic(tensor: FloatTensor<B>, scheme: &QuantScheme) -> QuantizedTensor<B> {
        let (min, max) =
            scheme.compute_range_primitive::<B>(tensor.clone(), &Calibration::MinMax);
        let qparams = scheme.compute_q_params_primitive(min, max);
        Self::quantize(tensor, scheme, qparams)
    }

    /// Converts a quantized tensor back to a float tensor.
    fn dequantize(tensor: QuantizedTensor<B>) -> FloatTensor<B>;

    /// Returns the device associated with `tensor`.
    fn q_device(tensor: &QuantizedTensor<B>) -> Device<B>;

    /// Moves `tensor` to `device`.
    fn q_to_device(tensor: QuantizedTensor<B>, device: &Device<B>) -> QuantizedTensor<B>;

    /// Reshapes `tensor` to `shape`.
    fn q_reshape(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;

    /// Returns a future resolving to the tensor's contents as [`TensorData`].
    fn q_into_data(tensor: QuantizedTensor<B>) -> impl Future<Output = TensorData> + Send;

    /// Detaches `tensor` from the autodiff graph.
    ///
    /// The default returns the tensor unchanged.
    fn q_detach(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
        tensor
    }

    /// Sets whether `tensor` requires gradients.
    ///
    /// The default ignores `_require_grad` and returns the tensor unchanged.
    fn q_set_require_grad(tensor: QuantizedTensor<B>, _require_grad: bool) -> QuantizedTensor<B> {
        tensor
    }

    /// Reports whether `tensor` requires gradients.
    ///
    /// The default always returns `false`.
    fn q_is_require_grad(_tensor: &QuantizedTensor<B>) -> bool {
        false
    }

    /// Expands `tensor` to `shape`.
    fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;

    /// Transposes `tensor` by swapping its last two dimensions via [`Self::q_swap_dims`].
    ///
    /// # Panics
    ///
    /// Panics if the tensor has fewer than 2 dimensions.
    fn q_transpose(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
        let ndims = tensor.shape().num_dims();
        // Fail loudly instead of underflowing `ndims - 2` (which would wrap in release builds).
        assert!(
            ndims >= 2,
            "q_transpose requires a tensor with at least 2 dimensions, got {ndims}"
        );
        Self::q_swap_dims(tensor, ndims - 2, ndims - 1)
    }

    /// Swaps dimensions `dim1` and `dim2` of `tensor`.
    fn q_swap_dims(tensor: QuantizedTensor<B>, dim1: usize, dim2: usize) -> QuantizedTensor<B>;

    /// Permutes the dimensions of `tensor` according to `axes`.
    fn q_permute(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;

    /// Flips `tensor` along the given `axes`.
    fn q_flip(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;

    /// Selects elements of `tensor` along `dim` using `indices`.
    fn q_select(
        tensor: QuantizedTensor<B>,
        dim: usize,
        indices: IntTensor<B>,
    ) -> QuantizedTensor<B>;

    /// Slices `tensor` with the given per-dimension `ranges`.
    fn q_slice(tensor: QuantizedTensor<B>, ranges: &[Range<usize>]) -> QuantizedTensor<B>;

    /// Gathers elements of `tensor` along `dim` using `indices`.
    ///
    /// Default: dequantize, run `B::float_gather`, then re-quantize with the input's scheme.
    fn q_gather(
        dim: usize,
        tensor: QuantizedTensor<B>,
        indices: IntTensor<B>,
    ) -> QuantizedTensor<B> {
        dequant_op_quant!(
            ty Self,
            float_op |tensor| B::float_gather(dim, tensor, indices),
            tensor
        )
    }

    /// Repeats `tensor` `times` times along `dim`.
    ///
    /// Default: dequantize, run `B::float_repeat_dim`, then re-quantize with the input's scheme.
    fn q_repeat_dim(tensor: QuantizedTensor<B>, dim: usize, times: usize) -> QuantizedTensor<B> {
        dequant_op_quant!(
            ty Self,
            float_op |tensor| B::float_repeat_dim(tensor, dim, times),
            tensor
        )
    }

    /// Adds two quantized tensors.
    ///
    /// Default: dequantize both, run `B::float_add`; the output kind follows the propagation
    /// setting of the `lhs` scheme.
    fn q_add(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |lhs, rhs| B::float_add(lhs, rhs),
            lhs,
            rhs
        )
    }

    /// Adds the scalar `rhs` to `lhs`.
    ///
    /// Default: dequantize, run `B::float_add_scalar`; output kind follows the scheme's
    /// propagation setting.
    fn q_add_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_add_scalar(tensor, rhs),
            lhs
        )
    }

    /// Clamps `tensor` from below at `min`.
    ///
    /// Default: dequantize, run `B::float_clamp_min`; output kind follows the scheme's
    /// propagation setting.
    fn q_clamp_min(tensor: QuantizedTensor<B>, min: FloatElem<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_clamp_min(tensor, min),
            tensor
        )
    }

    /// Clamps `tensor` from above at `max`.
    ///
    /// Default: dequantize, run `B::float_clamp_max`; output kind follows the scheme's
    /// propagation setting.
    fn q_clamp_max(tensor: QuantizedTensor<B>, max: FloatElem<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_clamp_max(tensor, max),
            tensor
        )
    }

    /// Clamps `tensor` between `min` and `max`.
    ///
    /// Default: dequantize, run `B::float_clamp`; output kind follows the scheme's
    /// propagation setting.
    fn q_clamp(
        tensor: QuantizedTensor<B>,
        min: FloatElem<B>,
        max: FloatElem<B>,
    ) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_clamp(tensor, min, max),
            tensor
        )
    }

    /// Subtracts `rhs` from `lhs`.
    ///
    /// Default: dequantize both, run `B::float_sub`; the output kind follows the propagation
    /// setting of the `lhs` scheme.
    fn q_sub(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |lhs, rhs| B::float_sub(lhs, rhs),
            lhs,
            rhs
        )
    }

    /// Subtracts the scalar `rhs` from `lhs`.
    ///
    /// Default: dequantize, run `B::float_sub_scalar`; output kind follows the scheme's
    /// propagation setting.
    fn q_sub_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_sub_scalar(tensor, rhs),
            lhs
        )
    }

    /// Multiplies two quantized tensors.
    ///
    /// Default: dequantize both, run `B::float_mul`; the output kind follows the propagation
    /// setting of the `lhs` scheme.
    fn q_mul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |lhs, rhs| B::float_mul(lhs, rhs),
            lhs,
            rhs
        )
    }

    /// Multiplies `lhs` by the scalar `rhs`.
    ///
    /// Default: dequantize, run `B::float_mul_scalar`; output kind follows the scheme's
    /// propagation setting.
    fn q_mul_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_mul_scalar(tensor, rhs),
            lhs
        )
    }

    /// Divides `lhs` by `rhs`.
    ///
    /// Default: dequantize both, run `B::float_div`; the output kind follows the propagation
    /// setting of the `lhs` scheme.
    fn q_div(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |lhs, rhs| B::float_div(lhs, rhs),
            lhs,
            rhs
        )
    }

    /// Divides `lhs` by the scalar `rhs`.
    ///
    /// Default: dequantize, run `B::float_div_scalar`; output kind follows the scheme's
    /// propagation setting.
    fn q_div_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_div_scalar(tensor, rhs),
            lhs
        )
    }

    /// Matrix-multiplies two quantized tensors.
    ///
    /// Default: dequantize both, run `B::float_matmul`; the output kind follows the
    /// propagation setting of the `lhs` scheme.
    fn q_matmul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |lhs, rhs| B::float_matmul(lhs, rhs),
            lhs,
            rhs
        )
    }

    /// Negates `tensor`.
    ///
    /// Default: dequantize, run `B::float_neg`; output kind follows the scheme's propagation
    /// setting.
    fn q_neg(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_neg(tensor),
            tensor
        )
    }

    /// Computes the reciprocal of `tensor`.
    ///
    /// Default: dequantize, run `B::float_recip`; output kind follows the scheme's
    /// propagation setting.
    fn q_recip(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_recip(tensor),
            tensor
        )
    }

    /// Sums all elements of `tensor`.
    ///
    /// Default: dequantize, run `B::float_sum`; output kind follows the scheme's propagation
    /// setting.
    fn q_sum(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_sum(tensor),
            tensor
        )
    }

    /// Sums `tensor` along `dim`.
    ///
    /// Default: dequantize, run `B::float_sum_dim`; output kind follows the scheme's
    /// propagation setting.
    fn q_sum_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_sum_dim(tensor, dim),
            tensor
        )
    }

    /// Multiplies all elements of `tensor` together.
    ///
    /// Default: dequantize, run `B::float_prod`; output kind follows the scheme's propagation
    /// setting.
    fn q_prod(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_prod(tensor),
            tensor
        )
    }

    /// Multiplies the elements of `tensor` along `dim`.
    ///
    /// Default: dequantize, run `B::float_prod_dim`; output kind follows the scheme's
    /// propagation setting.
    fn q_prod_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_prod_dim(tensor, dim),
            tensor
        )
    }

    /// Computes the mean of all elements of `tensor`.
    ///
    /// Default: dequantize, run `B::float_mean`; output kind follows the scheme's propagation
    /// setting.
    fn q_mean(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_mean(tensor),
            tensor
        )
    }

    /// Computes the mean of `tensor` along `dim`.
    ///
    /// Default: dequantize, run `B::float_mean_dim`; output kind follows the scheme's
    /// propagation setting.
    fn q_mean_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_mean_dim(tensor, dim),
            tensor
        )
    }

    /// Element-wise exponential.
    ///
    /// Default: dequantize, run `B::float_exp`; output kind follows the scheme's propagation
    /// setting.
    fn q_exp(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_exp(tensor),
            tensor
        )
    }

    /// Element-wise logarithm.
    ///
    /// Default: dequantize, run `B::float_log`; output kind follows the scheme's propagation
    /// setting.
    fn q_log(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_log(tensor),
            tensor
        )
    }

    /// Element-wise `log1p`.
    ///
    /// Default: dequantize, run `B::float_log1p`; output kind follows the scheme's
    /// propagation setting.
    fn q_log1p(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_log1p(tensor),
            tensor
        )
    }

    /// Raises `lhs` to the power of `rhs`, both quantized.
    ///
    /// Default: dequantize both, run `B::float_powf`; the output kind follows the propagation
    /// setting of the `lhs` scheme.
    fn q_powf(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |lhs, rhs| B::float_powf(lhs, rhs),
            lhs,
            rhs
        )
    }

    /// Raises `lhs` to the integer-tensor power `rhs`.
    ///
    /// Default: dequantize `lhs`, run `B::float_powi`; output kind follows the scheme's
    /// propagation setting.
    fn q_powi(lhs: QuantizedTensor<B>, rhs: IntTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_powi(tensor, rhs),
            lhs
        )
    }

    /// Raises `lhs` to the integer scalar power `rhs`.
    ///
    /// Default: dequantize, run `B::float_powi_scalar`; output kind follows the scheme's
    /// propagation setting.
    fn q_powi_scalar(lhs: QuantizedTensor<B>, rhs: IntElem<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_powi_scalar(tensor, rhs),
            lhs
        )
    }

    /// Raises `tensor` to the float scalar power `value`.
    ///
    /// Default: dequantize, run `B::float_powf_scalar`; output kind follows the scheme's
    /// propagation setting.
    fn q_powf_scalar(tensor: QuantizedTensor<B>, value: f32) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_powf_scalar(tensor, value),
            tensor
        )
    }

    /// Element-wise square root.
    ///
    /// Default: dequantize, run `B::float_sqrt`; output kind follows the scheme's propagation
    /// setting.
    fn q_sqrt(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_sqrt(tensor),
            tensor
        )
    }

    /// Element-wise absolute value.
    ///
    /// Default: dequantize, run `B::float_abs`, then re-quantize with the input's scheme.
    fn q_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
        dequant_op_quant!(
            ty Self,
            float_op |tensor| B::float_abs(tensor),
            tensor
        )
    }

    /// Element-wise cosine.
    ///
    /// Default: dequantize, run `B::float_cos`; output kind follows the scheme's propagation
    /// setting.
    fn q_cos(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_cos(tensor),
            tensor
        )
    }

    /// Element-wise sine.
    ///
    /// Default: dequantize, run `B::float_sin`; output kind follows the scheme's propagation
    /// setting.
    fn q_sin(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_sin(tensor),
            tensor
        )
    }

    /// Element-wise tangent.
    ///
    /// Default: dequantize, run `B::float_tan`; output kind follows the scheme's propagation
    /// setting.
    fn q_tan(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_tan(tensor),
            tensor
        )
    }

    /// Element-wise hyperbolic cosine.
    ///
    /// Default: dequantize, run `B::float_cosh`; output kind follows the scheme's propagation
    /// setting.
    fn q_cosh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_cosh(tensor),
            tensor
        )
    }

    /// Element-wise hyperbolic sine.
    ///
    /// Default: dequantize, run `B::float_sinh`; output kind follows the scheme's propagation
    /// setting.
    fn q_sinh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_sinh(tensor),
            tensor
        )
    }

    /// Element-wise hyperbolic tangent.
    ///
    /// Default: dequantize, run `B::float_tanh`; output kind follows the scheme's propagation
    /// setting.
    fn q_tanh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_tanh(tensor),
            tensor
        )
    }

    /// Element-wise error function.
    ///
    /// Default: dequantize, run `B::float_erf`; output kind follows the scheme's propagation
    /// setting.
    fn q_erf(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_erf(tensor),
            tensor
        )
    }

    /// Concatenates `tensors` along `dim`.
    ///
    /// Default: dequantize every input, run `B::float_cat`, then re-quantize with the scheme
    /// of the first tensor.
    ///
    /// # Panics
    ///
    /// Panics if `tensors` is empty.
    fn q_cat(tensors: Vec<QuantizedTensor<B>>, dim: usize) -> QuantizedTensor<B> {
        // The output scheme is taken from the first input, so at least one is required.
        let scheme = *tensors
            .first()
            .expect("q_cat requires at least one tensor to concatenate")
            .scheme();
        let tensor_f = tensors
            .into_iter()
            .map(|tensor| Self::dequantize(tensor))
            .collect();
        let out_f = B::float_cat(tensor_f, dim);
        Self::quantize_dynamic(out_f, &scheme)
    }

    /// Returns the indices of the maximum elements along `dim`.
    ///
    /// Default: dequantize, then run `B::float_argmax`.
    fn q_argmax(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
        let tensor_f = Self::dequantize(tensor);
        B::float_argmax(tensor_f, dim)
    }

    /// Returns the indices of the minimum elements along `dim`.
    ///
    /// Default: dequantize, then run `B::float_argmin`.
    fn q_argmin(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
        let tensor_f = Self::dequantize(tensor);
        B::float_argmin(tensor_f, dim)
    }

    /// Returns the maximum element of `tensor`.
    ///
    /// Default: flatten to one dimension and reduce with `B::q_max_dim` along dimension 0.
    fn q_max(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
        let shape = tensor.shape();
        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
        B::q_max_dim(tensor, 0)
    }

    /// Returns the maximum elements of `tensor` along `dim`.
    ///
    /// Default: compute `B::q_argmax`, then gather those positions with `B::q_gather`.
    fn q_max_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
        let index = B::q_argmax(tensor.clone(), dim);
        B::q_gather(dim, tensor, index)
    }

    /// Returns the maximum elements of `tensor` along `dim` together with their indices.
    ///
    /// Default: compute `B::q_argmax`, then gather those positions with `B::q_gather`.
    fn q_max_dim_with_indices(
        tensor: QuantizedTensor<B>,
        dim: usize,
    ) -> (QuantizedTensor<B>, IntTensor<B>) {
        let index = B::q_argmax(tensor.clone(), dim);
        let values = B::q_gather(dim, tensor, index.clone());
        (values, index)
    }

    /// Returns the minimum element of `tensor`.
    ///
    /// Default: flatten to one dimension and reduce with `B::q_min_dim` along dimension 0.
    fn q_min(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
        let shape = tensor.shape();
        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
        B::q_min_dim(tensor, 0)
    }

    /// Returns the minimum elements of `tensor` along `dim`.
    ///
    /// Default: compute `B::q_argmin`, then gather those positions with `B::q_gather`.
    fn q_min_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
        let index = B::q_argmin(tensor.clone(), dim);
        B::q_gather(dim, tensor, index)
    }

    /// Returns the minimum elements of `tensor` along `dim` together with their indices.
    ///
    /// Default: compute `B::q_argmin`, then gather those positions with `B::q_gather`.
    fn q_min_dim_with_indices(
        tensor: QuantizedTensor<B>,
        dim: usize,
    ) -> (QuantizedTensor<B>, IntTensor<B>) {
        let index = B::q_argmin(tensor.clone(), dim);
        let values = B::q_gather(dim, tensor, index.clone());
        (values, index)
    }

    /// Returns the maximum absolute value found in `tensor`.
    ///
    /// Default: flatten to one dimension and reduce with `B::q_max_abs_dim` along dimension 0.
    fn q_max_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
        let shape = tensor.shape();
        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
        B::q_max_abs_dim(tensor, 0)
    }

    /// Returns the maximum absolute values of `tensor` along `dim`.
    ///
    /// Default: take `B::q_abs` of the input and reduce it with `B::q_max_dim`.
    fn q_max_abs_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
        // Reduce over the absolute values themselves. Gathering from the original signed
        // tensor would hand back a negative number whenever the largest magnitude is negative.
        B::q_max_dim(B::q_abs(tensor), dim)
    }

    /// Tests whether any element of `tensor` evaluates to true.
    ///
    /// Default: dequantize, then run `B::float_any`.
    fn q_any(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
        let tensor_f = Self::dequantize(tensor);
        B::float_any(tensor_f)
    }

    /// Tests whether any element along `dim` evaluates to true.
    ///
    /// Default: dequantize, then run `B::float_any_dim`.
    fn q_any_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
        let tensor_f = Self::dequantize(tensor);
        B::float_any_dim(tensor_f, dim)
    }

    /// Tests whether all elements of `tensor` evaluate to true.
    ///
    /// Default: dequantize, then run `B::float_all`.
    fn q_all(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
        let tensor_f = Self::dequantize(tensor);
        B::float_all(tensor_f)
    }

    /// Tests whether all elements along `dim` evaluate to true.
    ///
    /// Default: dequantize, then run `B::float_all_dim`.
    fn q_all_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
        let tensor_f = Self::dequantize(tensor);
        B::float_all_dim(tensor_f, dim)
    }

    /// Sorts `tensor` along `dim`, in descending order when `descending` is set.
    ///
    /// Default: dequantize, run `B::float_sort`, then re-quantize with the input's scheme.
    fn q_sort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> QuantizedTensor<B> {
        dequant_op_quant!(
            ty Self,
            float_op |tensor| B::float_sort(tensor, dim, descending),
            tensor
        )
    }

    /// Sorts `tensor` along `dim` and also returns the indices produced by the sort.
    ///
    /// Default: dequantize, run `B::float_sort_with_indices`, then re-quantize the sorted
    /// values with the input's scheme. The indices are returned as produced by the float op.
    fn q_sort_with_indices(
        tensor: QuantizedTensor<B>,
        dim: usize,
        descending: bool,
    ) -> (QuantizedTensor<B>, IntTensor<B>) {
        let scheme = *tensor.scheme();
        let tensor_f = Self::dequantize(tensor);
        let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending);
        (Self::quantize_dynamic(out_f, &scheme), indices)
    }

    /// Returns the indices that would sort `tensor` along `dim`.
    ///
    /// Default: dequantize, then run `B::float_argsort`.
    fn q_argsort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
        let tensor_f = Self::dequantize(tensor);
        B::float_argsort(tensor_f, dim, descending)
    }
}