use super::dequant::{
dequant_f16_to_f16, dequant_f16_to_f32, dequant_f32_to_f16, dequant_f32_to_f32,
dequant_i16_to_f16, dequant_i16_to_f32, dequant_i8_to_f16, dequant_i8_to_f32,
dequant_u16_to_f16, dequant_u16_to_f32, dequant_u8_to_f16, dequant_u8_to_f32,
};
use crate::per_scale::plan::LevelPlan;
use crate::Quantization;
use half::f16;
/// Generates a scalar per-level mask-coefficient decode wrapper.
///
/// Arguments, in order: the name of the generated function, the input element
/// type, the output element type, and the dequantization kernel to forward to.
macro_rules! impl_mc_level {
    ($fn_name:ident, $src:ty, $out:ty, $kernel:ident) => {
        /// Dequantizes one level's mask-coefficient buffer from `input` into `dst`.
        ///
        /// Both slices are expected to hold `level.h * level.w * num_mc`
        /// elements; this is verified in debug builds only. The conversion
        /// itself is delegated to the kernel supplied to the generating macro.
        // NOTE(review): dead_code is allowed presumably because not every
        // generated variant is referenced by the dispatcher — confirm.
        #[allow(dead_code)]
        pub(crate) fn $fn_name(
            input: &[$src],
            q: Quantization,
            num_mc: usize,
            level: &LevelPlan,
            dst: &mut [$out],
        ) {
            // Element count implied by the level geometry and coefficient count.
            let expected_len = level.h * level.w * num_mc;
            debug_assert_eq!(input.len(), expected_len);
            debug_assert_eq!(dst.len(), expected_len);
            $kernel(input, q, dst);
        }
    };
}
// Scalar per-level mask-coefficient decoders producing `f32` output, one per
// supported input element type.
impl_mc_level!(decode_mc_level_i8_to_f32, i8, f32, dequant_i8_to_f32);
impl_mc_level!(decode_mc_level_u8_to_f32, u8, f32, dequant_u8_to_f32);
impl_mc_level!(decode_mc_level_i16_to_f32, i16, f32, dequant_i16_to_f32);
impl_mc_level!(decode_mc_level_u16_to_f32, u16, f32, dequant_u16_to_f32);
impl_mc_level!(decode_mc_level_f16_to_f32, f16, f32, dequant_f16_to_f32);
impl_mc_level!(decode_mc_level_f32_to_f32, f32, f32, dequant_f32_to_f32);
// The same set of input types, producing `f16` output.
impl_mc_level!(decode_mc_level_i8_to_f16, i8, f16, dequant_i8_to_f16);
impl_mc_level!(decode_mc_level_u8_to_f16, u8, f16, dequant_u8_to_f16);
impl_mc_level!(decode_mc_level_i16_to_f16, i16, f16, dequant_i16_to_f16);
impl_mc_level!(decode_mc_level_u16_to_f16, u16, f16, dequant_u16_to_f16);
impl_mc_level!(decode_mc_level_f16_to_f16, f16, f16, dequant_f16_to_f16);
impl_mc_level!(decode_mc_level_f32_to_f16, f32, f16, dequant_f32_to_f16);
/// Generates a scalar proto-tensor decode wrapper.
///
/// Arguments, in order: the name of the generated function, the input element
/// type, the output element type, and the dequantization kernel to forward to.
macro_rules! impl_proto {
    ($fn_name:ident, $src:ty, $out:ty, $kernel:ident) => {
        /// Dequantizes the proto buffer `input` into `dst`.
        ///
        /// `input` and `dst` are expected to have the same length; this is
        /// verified in debug builds only. The conversion itself is delegated
        /// to the kernel supplied to the generating macro.
        // NOTE(review): dead_code is allowed presumably because not every
        // generated variant is referenced by the dispatcher — confirm.
        #[allow(dead_code)]
        pub(crate) fn $fn_name(input: &[$src], q: Quantization, dst: &mut [$out]) {
            debug_assert_eq!(input.len(), dst.len());
            $kernel(input, q, dst);
        }
    };
}
// Scalar proto decoders producing `f32` output, one per supported input
// element type.
impl_proto!(decode_proto_i8_to_f32, i8, f32, dequant_i8_to_f32);
impl_proto!(decode_proto_u8_to_f32, u8, f32, dequant_u8_to_f32);
impl_proto!(decode_proto_i16_to_f32, i16, f32, dequant_i16_to_f32);
impl_proto!(decode_proto_u16_to_f32, u16, f32, dequant_u16_to_f32);
impl_proto!(decode_proto_f16_to_f32, f16, f32, dequant_f16_to_f32);
impl_proto!(decode_proto_f32_to_f32, f32, f32, dequant_f32_to_f32);
// The same set of input types, producing `f16` output.
impl_proto!(decode_proto_i8_to_f16, i8, f16, dequant_i8_to_f16);
impl_proto!(decode_proto_u8_to_f16, u8, f16, dequant_u8_to_f16);
impl_proto!(decode_proto_i16_to_f16, i16, f16, dequant_i16_to_f16);
impl_proto!(decode_proto_u16_to_f16, u16, f16, dequant_u16_to_f16);
impl_proto!(decode_proto_f16_to_f16, f16, f16, dequant_f16_to_f16);
impl_proto!(decode_proto_f32_to_f16, f32, f16, dequant_f32_to_f16);
/// Generates an aarch64-only per-level mask-coefficient decode wrapper that
/// forwards to a kernel in `crate::per_scale::kernels::neon_baseline`.
///
/// Arguments, in order: the name of the generated function, the input element
/// type, and the name of the NEON kernel. Output is always `f32`.
#[cfg(target_arch = "aarch64")]
macro_rules! impl_mc_level_neon {
    ($fn_name:ident, $src:ty, $kernel:ident) => {
        /// Dequantizes one level's mask-coefficient buffer from `input` into
        /// `dst` using the NEON kernel supplied to the generating macro.
        ///
        /// Both slices are expected to hold `level.h * level.w * num_mc`
        /// elements; this is verified in debug builds only.
        // NOTE(review): dead_code is allowed presumably because not every
        // generated variant is referenced by the dispatcher — confirm.
        #[allow(dead_code)]
        pub(crate) fn $fn_name(
            input: &[$src],
            q: Quantization,
            num_mc: usize,
            level: &LevelPlan,
            dst: &mut [f32],
        ) {
            // Element count implied by the level geometry and coefficient count.
            let expected_len = level.h * level.w * num_mc;
            debug_assert_eq!(input.len(), expected_len);
            debug_assert_eq!(dst.len(), expected_len);
            // SAFETY: NOTE(review): the kernel's safety contract is not visible
            // from this file. If it requires `input.len() == dst.len()`, that
            // is only checked in debug builds here — confirm the kernel is
            // sound for mismatched lengths or that callers guarantee equality.
            unsafe {
                crate::per_scale::kernels::neon_baseline::$kernel(input, q, dst);
            }
        }
    };
}
// aarch64-only NEON per-level mask-coefficient decoders. Only the integer
// input types (i8, u8, i16, u16) are instantiated, all producing `f32`.
#[cfg(target_arch = "aarch64")]
impl_mc_level_neon!(decode_mc_level_i8_to_f32_neon, i8, dequant_i8_to_f32_neon);
#[cfg(target_arch = "aarch64")]
impl_mc_level_neon!(decode_mc_level_u8_to_f32_neon, u8, dequant_u8_to_f32_neon);
#[cfg(target_arch = "aarch64")]
impl_mc_level_neon!(
decode_mc_level_i16_to_f32_neon,
i16,
dequant_i16_to_f32_neon
);
#[cfg(target_arch = "aarch64")]
impl_mc_level_neon!(
decode_mc_level_u16_to_f32_neon,
u16,
dequant_u16_to_f32_neon
);
/// Generates an aarch64-only proto-tensor decode wrapper that forwards to a
/// kernel in `crate::per_scale::kernels::neon_baseline`.
///
/// Arguments, in order: the name of the generated function, the input element
/// type, and the name of the NEON kernel. Output is always `f32`.
#[cfg(target_arch = "aarch64")]
macro_rules! impl_proto_neon {
    ($fn_name:ident, $src:ty, $kernel:ident) => {
        /// Dequantizes the proto buffer `input` into `dst` using the NEON
        /// kernel supplied to the generating macro.
        ///
        /// `input` and `dst` are expected to have the same length; this is
        /// verified in debug builds only.
        // NOTE(review): dead_code is allowed presumably because not every
        // generated variant is referenced by the dispatcher — confirm.
        #[allow(dead_code)]
        pub(crate) fn $fn_name(input: &[$src], q: Quantization, dst: &mut [f32]) {
            debug_assert_eq!(input.len(), dst.len());
            // SAFETY: NOTE(review): the kernel's safety contract is not visible
            // from this file. If it requires `input.len() == dst.len()`, that
            // is only checked in debug builds here — confirm the kernel is
            // sound for mismatched lengths or that callers guarantee equality.
            unsafe {
                crate::per_scale::kernels::neon_baseline::$kernel(input, q, dst);
            }
        }
    };
}
// aarch64-only NEON proto decoders. Only the integer input types
// (i8, u8, i16, u16) are instantiated, all producing `f32`.
#[cfg(target_arch = "aarch64")]
impl_proto_neon!(decode_proto_i8_to_f32_neon, i8, dequant_i8_to_f32_neon);
#[cfg(target_arch = "aarch64")]
impl_proto_neon!(decode_proto_u8_to_f32_neon, u8, dequant_u8_to_f32_neon);
#[cfg(target_arch = "aarch64")]
impl_proto_neon!(decode_proto_i16_to_f32_neon, i16, dequant_i16_to_f32_neon);
#[cfg(target_arch = "aarch64")]
impl_proto_neon!(decode_proto_u16_to_f32_neon, u16, dequant_u16_to_f32_neon);
#[cfg(test)]
mod tests {
    use super::*;
    use crate::per_scale::kernels::grids::make_anchor_grid;

    /// Absolute tolerance used when comparing dequantized outputs.
    const TOLERANCE: f32 = 1e-6;

    /// Decodes a 2x2 level with two mask coefficients per cell and spot-checks
    /// selected outputs; the expected values follow `(x + 10) * 0.5`.
    #[test]
    fn mc_level_i8_dequant_round_trip() {
        let quant = Quantization::new(0.5, -10);
        let num_mc = 2;
        let (grid_x, grid_y) = make_anchor_grid(2, 2);
        let level = LevelPlan {
            stride: 8.0,
            h: 2,
            w: 2,
            reg_max: 16,
            anchor_offset: 0,
            grid_x,
            grid_y,
            box_shape: vec![1, 2, 2, 64].into_boxed_slice(),
            score_shape: vec![1, 2, 2, 80].into_boxed_slice(),
            mc_shape: Some(vec![1, 2, 2, num_mc].into_boxed_slice()),
            layout: crate::per_scale::plan::Layout::Nhwc,
        };
        let input = vec![-10_i8, 0, 10, 20, -5, 5, 100, -100];
        let mut out = vec![0.0_f32; 8];

        decode_mc_level_i8_to_f32(&input, quant, num_mc, &level, &mut out);

        // (output index, expected dequantized value)
        let checks: [(usize, f32); 4] = [(0, 0.0), (1, 5.0), (2, 10.0), (7, -45.0)];
        for &(idx, want) in checks.iter() {
            assert!(
                (out[idx] - want).abs() < TOLERANCE,
                "out[{}] = {}, expected {}",
                idx,
                out[idx],
                want
            );
        }
    }

    /// Decodes a four-element proto buffer; the expected values follow `x * 0.1`.
    #[test]
    fn proto_i8_dequant_works() {
        let quant = Quantization::new(0.1, 0);
        let input = vec![0_i8, 10, -10, 50];
        let mut out = vec![0.0_f32; 4];

        decode_proto_i8_to_f32(&input, quant, &mut out);

        let expected: [f32; 4] = [0.0, 1.0, -1.0, 5.0];
        for (idx, (&got, &want)) in out.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - want).abs() < TOLERANCE,
                "out[{}] = {}, expected {}",
                idx,
                got,
                want
            );
        }
    }
}