use std::mem::MaybeUninit;
use rayon::prelude::*;
use rten_shape_inference::ops as shape_ops;
use rten_simd::SimdOp;
use rten_tensor::prelude::*;
use rten_tensor::{AssumeInit, NdTensor, NdTensorView, Scalar, Tensor, TensorView};
use rten_vecmath as vecmath;
use crate::buffer_pool::BufferPool;
use crate::infer_shapes::{InferShapes, UnaryOp};
use crate::operator::{
IntoOpResult, OpError, OpRunContext, Operator, OutputList, OutputType, OutputTypeList,
OutputTypesContext,
};
use crate::ops::{map_value_view, resolve_axis};
use crate::value::{DataType, Value, ValueType, ValueView};
/// Convert a quantized integer value back to a float.
///
/// This computes `y = (x - zero_point) * scale`.
pub trait Dequantize<To> {
    fn dequantize(self, scale: To, zero_point: Self) -> To;
}

/// Implement [`Dequantize`] with an `f32` output for one or more integer
/// source types.
macro_rules! impl_dequantize_f32 {
    ($($src:ty),+ $(,)?) => {
        $(
            impl Dequantize<f32> for $src {
                fn dequantize(self, scale: f32, zero_point: Self) -> f32 {
                    // Subtract in i32. For the narrow source types (i8, u8)
                    // this guarantees the difference cannot overflow.
                    let centered = (self as i32) - zero_point as i32;
                    (centered as f32) * scale
                }
            }
        )+
    };
}

impl_dequantize_f32!(i32, i8, u8);
/// Dequantize `input` to `f32` using the ONNX `DequantizeLinear` formula
/// `y = (x - zero_point) * scale`.
///
/// `scale` and `zero_point` must either be scalars (per-tensor quantization)
/// or 1D tensors whose length matches the size of `input` along `axis`
/// (per-axis quantization). A missing `zero_point` defaults to zero, per the
/// ONNX specification.
///
/// Returns an error if `scale` / `zero_point` shapes are inconsistent, if
/// `axis` is out of range, or if `scale` has more than one dimension
/// (blocked quantization is unsupported).
pub fn dequantize_linear<T: Copy + Default + Dequantize<f32> + Scalar>(
    pool: &BufferPool,
    input: TensorView<T>,
    scale: TensorView<f32>,
    zero_point: Option<TensorView<T>>,
    axis: isize,
) -> Result<Tensor<f32>, OpError> {
    if let Some(zero_point) = zero_point.as_ref()
        && zero_point.shape() != scale.shape()
    {
        return Err(OpError::InvalidValue(
            "scale and zero_point must have same shape",
        ));
    }

    match scale.ndim() {
        // Per-tensor dequantization with scalar parameters.
        0 => {
            let scale = *scale.item().unwrap();
            // Fix: a missing zero point defaults to zero instead of
            // panicking. The zero point input is optional in ONNX.
            let zero_point = zero_point
                .and_then(|z| z.item().copied())
                .unwrap_or_default();
            Ok(input.map_in(pool, |x| x.dequantize(scale, zero_point)))
        }
        // Per-axis dequantization with a vector of scales / zero points.
        1 => {
            let axis = resolve_axis(input.ndim(), axis)?;
            let scale: NdTensorView<f32, 1> = scale.try_into().unwrap();

            // Fix: validate the scale length up front. Without this check the
            // zips below would silently stop early when `scale` is shorter
            // than the axis, and `assume_init` would then be called on
            // partially-initialized memory (undefined behavior).
            if scale.size(0) != input.size(axis) {
                return Err(OpError::InvalidValue(
                    "scale length must match size of input along axis",
                ));
            }

            let zero = NdTensor::from(T::default());
            let zero_point: NdTensorView<T, 1> = zero_point
                .map(|zp| {
                    let zp_vec: NdTensorView<T, 1> = zp.try_into().unwrap();
                    zp_vec
                })
                .unwrap_or(zero.broadcast(scale.shape()));

            let mut output = Tensor::uninit_in(pool, input.shape());
            output
                .axis_iter_mut(axis)
                .zip(input.axis_iter(axis))
                .zip(scale.iter())
                .zip(zero_point.iter())
                .for_each(|(((mut out_slice, in_slice), &scale), &zero_point)| {
                    for (y, &x) in out_slice.iter_mut().zip(in_slice.iter()) {
                        y.write(x.dequantize(scale, zero_point));
                    }
                });

            // SAFETY: Every output element was written above. The scale
            // length check guarantees the zips cover the full axis extent,
            // and the zero point length matches the scale length (checked at
            // the top of the function).
            Ok(unsafe { output.assume_init() })
        }
        _ => Err(OpError::UnsupportedValue(
            "Blocked dequantization is not supported",
        )),
    }
}
/// Operator which dequantizes integer tensors to `f32`. See
/// [`dequantize_linear`].
#[derive(Debug)]
pub struct DequantizeLinear {
    // Axis along which per-axis scales / zero points apply when `scale` is a
    // 1D tensor. Not used for per-tensor (scalar) parameters.
    pub axis: isize,
}
impl Operator for DequantizeLinear {
    fn name(&self) -> &str {
        "DequantizeLinear"
    }

    // Inputs are (input, scale, zero_point), with zero_point optional.
    fn max_inputs(&self) -> Option<usize> {
        Some(3)
    }

    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        let inputs = ctx.inputs();
        let input = inputs.require(0)?;
        let scale = inputs.require_as(1)?;
        // Dispatch on the input's integer element type. The zero point, if
        // present, is fetched with the same element type as the input.
        map_value_view!(input, x, [Int8Tensor, UInt8Tensor, Int32Tensor], {
            let zero_point = inputs.get_as(2)?;
            dequantize_linear(ctx.pool(), x, scale, zero_point, self.axis).into_op_result()
        })
    }

    // Output is always f32, regardless of the input's integer type.
    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        Some([OutputType::Fixed(ValueType::Tensor(DataType::Float))].into())
    }

    // Output shape matches the input shape.
    fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
        Some(&UnaryOp)
    }
}
/// Quantize float values to an integer type `To`.
pub trait Quantize<To> {
    /// Quantize a single element.
    ///
    /// `inv_scale` is the reciprocal of the quantization scale; callers
    /// precompute it so hot loops multiply instead of divide.
    fn quantize(self, inv_scale: Self, zero_point: To) -> To;

    /// Quantize each element of `src` into `dest`, returning the initialized
    /// destination slice.
    ///
    /// The default implementation applies [`Quantize::quantize`] elementwise;
    /// implementations may override it with a vectorized version.
    ///
    /// Panics if `src` and `dest` have different lengths.
    fn quantize_slice<'a>(
        src: &[Self],
        dest: &'a mut [MaybeUninit<To>],
        inv_scale: Self,
        zero_point: To,
    ) -> &'a mut [To]
    where
        Self: Copy + Sized,
        To: Copy,
    {
        assert_eq!(src.len(), dest.len());
        for (x, y) in src.iter().zip(dest.iter_mut()) {
            y.write(x.quantize(inv_scale, zero_point));
        }
        // SAFETY: The zip over equal-length slices wrote every element of
        // `dest`.
        unsafe { dest.assume_init() }
    }
}
impl Quantize<u8> for f32 {
    /// Quantize to `u8`: scale, round to nearest with ties to even, add the
    /// zero point and saturate via Rust's float-to-int `as` cast.
    fn quantize(self, inv_scale: Self, zero_point: u8) -> u8 {
        let shifted = (self * inv_scale).round_ties_even() + f32::from(zero_point);
        shifted as u8
    }

    /// Vectorized slice quantization, delegated to `rten_vecmath`.
    fn quantize_slice<'a>(
        src: &[f32],
        dest: &'a mut [MaybeUninit<u8>],
        inv_scale: f32,
        zero_point: u8,
    ) -> &'a mut [u8] {
        vecmath::Quantize::new(src, dest, inv_scale, zero_point).dispatch()
    }
}
impl Quantize<i8> for f32 {
    /// Quantize to `i8`: scale, round to nearest with ties to even, add the
    /// zero point and saturate via Rust's float-to-int `as` cast.
    ///
    /// Slices use the trait's default scalar loop (no vectorized override).
    fn quantize(self, inv_scale: Self, zero_point: i8) -> i8 {
        let shifted = (self * inv_scale).round_ties_even() + f32::from(zero_point);
        shifted as i8
    }
}
/// Quantize `input` to integer type `T` using the ONNX `QuantizeLinear`
/// formula `y = saturate(round(x / scale) + zero_point)`.
///
/// `scale` and `zero_point` must either be scalars (per-tensor quantization)
/// or 1D tensors whose length matches the size of `input` along `axis`
/// (per-axis quantization). A missing `zero_point` defaults to zero, per the
/// ONNX specification.
///
/// Returns an error if `scale` / `zero_point` shapes are inconsistent, if
/// `axis` is out of range, or if `scale` has more than one dimension
/// (blocked quantization is unsupported).
pub fn quantize_linear<T: Copy + Default + Send + Sync + Scalar>(
    pool: &BufferPool,
    input: TensorView<f32>,
    scale: TensorView<f32>,
    zero_point: Option<TensorView<T>>,
    axis: isize,
) -> Result<Tensor<T>, OpError>
where
    f32: Quantize<T>,
{
    if let Some(zero_point) = zero_point.as_ref()
        && zero_point.shape() != scale.shape()
    {
        return Err(OpError::InvalidValue(
            "scale and zero_point must have same shape",
        ));
    }

    match scale.ndim() {
        // Per-tensor quantization with scalar parameters.
        0 => {
            // Multiply by the reciprocal instead of dividing per element.
            let inv_scale = 1. / *scale.item().unwrap();
            // Fix: a missing zero point defaults to zero instead of
            // panicking. This path is reachable from `QuantizeLinear` when
            // only the `output_dtype` attribute is specified.
            let zero_point = zero_point
                .and_then(|z| z.item().copied())
                .unwrap_or_default();
            if let Some(data) = input.data() {
                // Fast path for contiguous inputs: quantize in parallel
                // chunks using the (possibly vectorized) `quantize_slice`.
                let mut buf = pool.alloc(data.len());
                let buf_uninit = &mut buf.spare_capacity_mut()[..data.len()];
                let chunk_size = 4096;
                buf_uninit
                    .par_chunks_mut(chunk_size)
                    .zip(data.par_chunks(chunk_size))
                    .for_each(|(out_data, data)| {
                        Quantize::quantize_slice(data, out_data, inv_scale, zero_point);
                    });
                let buf_uninit_len = buf_uninit.len();
                // SAFETY: `quantize_slice` initialized all `buf_uninit_len`
                // elements of the buffer's spare capacity.
                unsafe { buf.set_len(buf_uninit_len) };
                Ok(Tensor::from_data(input.shape(), buf))
            } else {
                Ok(input.map_in(pool, |x| x.quantize(inv_scale, zero_point)))
            }
        }
        // Per-axis quantization with a vector of scales / zero points.
        1 => {
            let axis = resolve_axis(input.ndim(), axis)?;
            let scale: NdTensorView<f32, 1> = scale.try_into().unwrap();

            // Fix: validate the scale length up front. Without this check the
            // zips below would silently stop early when `scale` is shorter
            // than the axis, and `assume_init` would then be called on
            // partially-initialized memory (undefined behavior).
            if scale.size(0) != input.size(axis) {
                return Err(OpError::InvalidValue(
                    "scale length must match size of input along axis",
                ));
            }

            let zero = NdTensor::from(T::default());
            let zero_point: NdTensorView<T, 1> = zero_point
                .map(|zp| {
                    let zp_vec: NdTensorView<T, 1> = zp.try_into().unwrap();
                    zp_vec
                })
                .unwrap_or(zero.broadcast(scale.shape()));

            let mut output = Tensor::uninit_in(pool, input.shape());
            output
                .axis_iter_mut(axis)
                .into_par_iter()
                .zip(input.axis_iter(axis))
                .zip(scale.iter())
                .zip(zero_point.iter())
                .for_each(|(((mut out_slice, in_slice), &scale), &zero_point)| {
                    let inv_scale = 1. / scale;
                    for (y, &x) in out_slice.iter_mut().zip(in_slice.iter()) {
                        y.write(x.quantize(inv_scale, zero_point));
                    }
                });

            // SAFETY: Every output element was written above. The scale
            // length check guarantees the zips cover the full axis extent,
            // and the zero point length matches the scale length (checked at
            // the top of the function).
            Ok(unsafe { output.assume_init() })
        }
        _ => Err(OpError::UnsupportedValue(
            "Blocked quantization is not supported",
        )),
    }
}
/// Operator which quantizes `f32` tensors to `u8` or `i8`. See
/// [`quantize_linear`].
#[derive(Debug)]
pub struct QuantizeLinear {
    // Axis along which per-axis scales / zero points apply when `scale` is a
    // 1D tensor.
    pub axis: isize,
    // Requested output type. If `None`, the type is taken from the zero
    // point input.
    pub output_dtype: Option<DataType>,
}
impl Operator for QuantizeLinear {
    fn name(&self) -> &str {
        "QuantizeLinear"
    }

    // Inputs are (input, y_scale, y_zero_point), with y_zero_point optional.
    fn max_inputs(&self) -> Option<usize> {
        Some(3)
    }

    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        let inputs = ctx.inputs();
        let pool = ctx.pool();
        let input = inputs.require_as(0)?;
        let y_scale = inputs.require_as(1)?;
        let y_zero_point = inputs.get(2);
        // The output type is determined by the zero point's element type
        // and/or the `output_dtype` attribute; when both are present they
        // must agree. If neither is given, or they conflict, this falls
        // through to the `UnsupportedType` arm.
        match (y_zero_point, self.output_dtype) {
            (Some(ValueView::UInt8Tensor(y_zero_point)), Some(DataType::UInt8) | None) => {
                quantize_linear(pool, input, y_scale, Some(y_zero_point.view()), self.axis)
                    .into_op_result()
            }
            (None, Some(DataType::UInt8)) => {
                quantize_linear::<u8>(pool, input.view(), y_scale.view(), None, self.axis)
                    .into_op_result()
            }
            (Some(ValueView::Int8Tensor(y_zero_point)), Some(DataType::Int8) | None) => {
                quantize_linear(
                    pool,
                    input.view(),
                    y_scale.view(),
                    Some(y_zero_point.view()),
                    self.axis,
                )
                .into_op_result()
            }
            (None, Some(DataType::Int8)) => {
                quantize_linear::<i8>(pool, input.view(), y_scale.view(), None, self.axis)
                    .into_op_result()
            }
            _ => Err(OpError::UnsupportedType),
        }
    }

    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        // NOTE(review): defaults to Int8 when `output_dtype` is unset, while
        // `run` also accepts a u8 zero point in that case — confirm the
        // default matches callers' expectations.
        let dtype = self.output_dtype.unwrap_or(DataType::Int8);
        Some([OutputType::Fixed(ValueType::Tensor(dtype))].into())
    }

    // Output shape matches the input shape.
    fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
        Some(&UnaryOp)
    }
}
/// Conversion to a target type which clamps out-of-range values instead of
/// wrapping.
pub trait SaturatingCast<To> {
    fn saturating_cast(self) -> To;
}

impl SaturatingCast<u8> for f32 {
    fn saturating_cast(self) -> u8 {
        // `clamp` bounds finite values to [0, 255] and passes NaN through;
        // the `as` cast then truncates, with NaN mapping to 0.
        let bounded = self.clamp(0.0, 255.0);
        bounded as u8
    }
}
/// Outputs of [`dynamic_quantize_linear`].
pub struct DynamicQuantizeOutput<T> {
    // Quantized tensor with the same shape as the input.
    pub quantized: Tensor<T>,
    // Scalar scale that was used for quantization.
    pub scale: Tensor<f32>,
    // Scalar zero point that was used for quantization.
    pub zero_point: Tensor<T>,
}
/// Compute a quantization scale and zero point from the min / max values of
/// `input`, then quantize it. This implements the ONNX
/// `DynamicQuantizeLinear` operator.
pub fn dynamic_quantize_linear<T: Copy + Default + Send + Sync + Scalar>(
    pool: &BufferPool,
    input: TensorView<f32>,
) -> Result<DynamicQuantizeOutput<T>, OpError>
where
    f32: Quantize<T> + SaturatingCast<T>,
{
    // An empty input has no data range; return an empty quantized tensor
    // with identity parameters (scale 1, zero point 0).
    if input.is_empty() {
        return Ok(DynamicQuantizeOutput {
            quantized: Tensor::zeros(input.shape()),
            zero_point: Tensor::from(T::default()),
            scale: Tensor::from(1.),
        });
    }
    // Quantized range bounds. These are the u8 limits; `T` is expected to be
    // u8 here (as used by the `DynamicQuantizeLinear` operator below).
    let q_min = 0.;
    let q_max = 255.;
    let input = input.to_contiguous_in(pool);
    // Compute the input's min and max over parallel chunks.
    let chunk_size = 4096;
    let (x_min, x_max) = input
        .data()
        .par_chunks(chunk_size)
        .map(|chunk| vecmath::MinMax::new(chunk).dispatch())
        .reduce(
            || (f32::MAX, f32::MIN),
            |(prev_min, prev_max), (chunk_min, chunk_max)| {
                (chunk_min.min(prev_min), chunk_max.max(prev_max))
            },
        );
    // Extend the range so it always includes zero, then derive the scale and
    // zero point from it.
    let x_min_adjusted = x_min.min(q_min);
    let x_max_adjusted = x_max.max(q_min);
    let x_range = x_max_adjusted - x_min_adjusted;
    // NOTE(review): if every input element is zero, `x_range` and hence
    // `scale` are zero, and `min_scaled` below is a 0/0 NaN. The NaN then
    // flows through clamp/round and saturates to a zero point of 0 — confirm
    // this edge case is the intended behavior.
    let scale = x_range / q_max;
    let min_scaled = x_min_adjusted / scale;
    let initial_zero_point = q_min - min_scaled;
    let clipped_zero_point = initial_zero_point.clamp(q_min, q_max);
    let rounded_zero_point = clipped_zero_point.round_ties_even();
    let zero_point: T = rounded_zero_point.saturating_cast();
    let scale_tensor = Tensor::from(scale);
    let zero_point_tensor = Tensor::from(zero_point);
    // The axis argument is irrelevant here since the scale is a scalar.
    let quantized = quantize_linear(
        pool,
        input.view().into(),
        scale_tensor.view(),
        Some(zero_point_tensor.view()),
        1,
    )?;
    Ok(DynamicQuantizeOutput {
        quantized,
        scale: scale_tensor,
        zero_point: zero_point_tensor,
    })
}
/// Operator which computes quantization parameters for its input and then
/// quantizes it to `u8`. See [`dynamic_quantize_linear`].
#[derive(Debug)]
pub struct DynamicQuantizeLinear {}
impl Operator for DynamicQuantizeLinear {
fn name(&self) -> &str {
"DynamicQuantizeLinear"
}
fn max_inputs(&self) -> Option<usize> {
Some(1)
}
fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
let input = ctx.inputs().require_as(0)?;
let DynamicQuantizeOutput {
quantized,
scale,
zero_point,
} = dynamic_quantize_linear::<u8>(ctx.pool(), input)?;
let quantized: Value = quantized.into();
let scale: Value = scale.into();
let zero_point: Value = zero_point.into();
Ok([quantized, scale, zero_point].into_iter().collect())
}
fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
Some(OutputTypeList::from_slice(&[
OutputType::Fixed(ValueType::Tensor(DataType::UInt8)),
OutputType::Fixed(ValueType::Tensor(DataType::Float)),
OutputType::Fixed(ValueType::Tensor(DataType::UInt8)),
]))
}
fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
Some(&shape_ops::DynamicQuantizeLinear)
}
}
#[cfg(test)]
mod tests {
    use rten_tensor::Tensor;
    use rten_tensor::prelude::*;
    use rten_tensor::test_util::expect_equal_with_tolerance;
    use rten_testing::TestCases;
    use super::{dequantize_linear, dynamic_quantize_linear, quantize_linear};
    use crate::buffer_pool::BufferPool;
    use crate::operator::OpError;
    use crate::value::Value;

    /// Check dequantization against expected values and, when dequantization
    /// succeeds, check that re-quantizing with the same parameters
    /// round-trips back to the original input.
    #[test]
    fn test_dequantize_quantize_linear() {
        #[derive(Debug)]
        struct Case {
            axis: isize,
            input: Value,
            scale: Tensor<f32>,
            zero_point: Option<Value>,
            expected: Result<Tensor<f32>, OpError>,
        }
        // Cases cover per-tensor (scalar) and per-axis (vector) parameters,
        // u8 and i8 element types, validation errors and empty inputs.
        let cases = [
            Case {
                axis: 1,
                input: Tensor::from([20u8, 30, 40]).into(),
                scale: Tensor::from(0.5),
                zero_point: Some(Tensor::from(10u8).into()),
                expected: Ok(Tensor::from([5., 10., 15.])),
            },
            Case {
                axis: 1,
                input: Tensor::from([20i8, 30, 40]).into(),
                scale: Tensor::from(0.5),
                zero_point: Some(Tensor::from(10i8).into()),
                expected: Ok(Tensor::from([5., 10., 15.])),
            },
            Case {
                axis: 0,
                input: Tensor::from([[10u8, 20], [30, 40]]).into(),
                scale: Tensor::from([0.5, 2.]),
                zero_point: Some(Tensor::from([10u8, 20]).into()),
                expected: Ok(Tensor::from([[0., 5.], [20., 40.]])),
            },
            // Mismatched scale / zero_point shapes.
            Case {
                axis: 0,
                input: Tensor::from([10u8]).into(),
                scale: Tensor::from([0.5, 2.]),
                zero_point: Some(Tensor::from([1u8, 2, 3]).into()),
                expected: Err(OpError::InvalidValue(
                    "scale and zero_point must have same shape",
                )),
            },
            // 2D scale => blocked quantization, which is unsupported.
            Case {
                axis: 0,
                input: Tensor::from([[10u8, 20], [30, 40]]).into(),
                scale: Tensor::from([[1., 2.], [3., 4.]]),
                zero_point: Some(Tensor::from([[1u8, 2], [3, 4]]).into()),
                expected: Err(OpError::UnsupportedValue(
                    "Blocked dequantization is not supported",
                )),
            },
            // Empty input.
            Case {
                axis: 0,
                input: Tensor::<u8>::zeros(&[0]).into(),
                scale: Tensor::zeros(&[0]),
                zero_point: Some(Tensor::<u8>::zeros(&[0]).into()),
                expected: Ok(Tensor::zeros(&[0])),
            },
        ];
        cases.test_each(|case| {
            let pool = BufferPool::new();
            let Case {
                input,
                scale,
                zero_point,
                axis,
                expected,
            } = case;
            // The u8 and i8 arms below are identical apart from the element
            // type.
            match input {
                Value::UInt8Tensor(input) => {
                    let zero_point: Option<Tensor<u8>> =
                        zero_point.clone().map(|zp| zp.try_into().unwrap());
                    let result = dequantize_linear(
                        &pool,
                        input.view(),
                        scale.view(),
                        zero_point.as_ref().map(|zp| zp.view()),
                        *axis,
                    );
                    assert_eq!(result, *expected);
                    // Round-trip: quantizing the dequantized output with the
                    // same parameters should reproduce the input exactly.
                    if let Ok(dequant) = result {
                        let requantized = quantize_linear(
                            &pool,
                            dequant.view(),
                            scale.view(),
                            zero_point.as_ref().map(|zp| zp.view()),
                            *axis,
                        )
                        .unwrap();
                        assert_eq!(requantized, *input);
                    }
                }
                Value::Int8Tensor(input) => {
                    let zero_point: Option<Tensor<i8>> =
                        zero_point.clone().map(|zp| zp.try_into().unwrap());
                    let result = dequantize_linear(
                        &pool,
                        input.view(),
                        scale.view(),
                        zero_point.as_ref().map(|zp| zp.view()),
                        *axis,
                    );
                    assert_eq!(result, *expected);
                    if let Ok(dequant) = result {
                        let requantized = quantize_linear(
                            &pool,
                            dequant.view(),
                            scale.view(),
                            zero_point.as_ref().map(|zp| zp.view()),
                            *axis,
                        )
                        .unwrap();
                        assert_eq!(requantized, *input);
                    }
                }
                _ => panic!("unsupported quantized type"),
            };
        })
    }

    /// Quantize with dynamically computed parameters, then dequantize
    /// manually and check the result is close to the original input.
    #[test]
    fn test_dynamic_quantize_linear() {
        #[derive(Debug)]
        struct Case {
            input: Tensor<f32>,
            max_error: f32,
        }
        let cases = [
            Case {
                input: [-2., -1., 0., 1., 2.].into(),
                max_error: 0.01,
            },
            Case {
                input: [1., 2., 3., 4., 5.].into(),
                max_error: 0.01,
            },
            Case {
                input: [-1., -2., -3., -4., -5.].into(),
                max_error: 0.01,
            },
            Case {
                input: Tensor::arange(-0.1, 0.1, Some(0.01)),
                max_error: 0.001,
            },
            Case {
                input: Tensor::from([234.56]),
                max_error: 0.,
            },
            Case {
                input: Tensor::from([-234.56]),
                max_error: 0.,
            },
            // Empty input.
            Case {
                input: Tensor::zeros(&[0]),
                max_error: 0.,
            },
        ];
        cases.test_each(|case| {
            let Case { input, max_error } = case;
            let pool = BufferPool::new();
            let output = dynamic_quantize_linear::<u8>(&pool, input.view()).unwrap();
            assert_eq!(output.quantized.shape(), input.shape());
            // Dequantize with the returned parameters and compare against
            // the original input within the per-case tolerance.
            let zero_point = *output.zero_point.item().unwrap();
            let scale = *output.scale.item().unwrap();
            let dequantized = output
                .quantized
                .map(|&q| (q as i32 - zero_point as i32) as f32 * scale);
            expect_equal_with_tolerance(&dequantized, &input, *max_error, *max_error).unwrap();
        })
    }
}