use super::box_primitives::{
dist2bbox_anchor_f16, dist2bbox_anchor_f32, weighted_sum_4sides_f16, weighted_sum_4sides_f32,
};
use super::dequant::{
dequant_f16_to_f16, dequant_f16_to_f32, dequant_f32_to_f16, dequant_f32_to_f32,
dequant_i16_to_f16, dequant_i16_to_f32, dequant_i8_to_f16, dequant_i8_to_f32,
dequant_u16_to_f16, dequant_u16_to_f32, dequant_u8_to_f16, dequant_u8_to_f32,
};
use super::softmax::{softmax_inplace_f16, softmax_inplace_f32};
use crate::per_scale::plan::LevelPlan;
use crate::Quantization;
use half::f16;
/// Upper bound on `level.reg_max` accepted by the decoders below: each one
/// dequantizes an anchor's 4 * reg_max DFL bins into a fixed stack buffer of
/// `4 * MAX_REG_MAX` elements instead of heap-allocating per call.
const MAX_REG_MAX: usize = 64;
// Generates a scalar DFL box decoder producing f32 boxes.
//
// The generated `fn` takes one level's raw box tensor (`$i` elements laid out
// anchor-major as 4 sides x reg_max bins per anchor), dequantizes each anchor
// with `$dequant`, softmaxes each side's bin distribution, collapses it to an
// expected distance via `weighted_sum_4sides_f32`, and converts the resulting
// LTRB distances to an xywh box scaled by the level stride.
macro_rules! impl_dfl_level_f32 {
    ($name:ident, $i:ty, $dequant:ident) => {
        #[allow(dead_code)]
        pub(crate) fn $name(input: &[$i], q: Quantization, level: &LevelPlan, dst: &mut [f32]) {
            let (h, w, reg_max) = (level.h, level.w, level.reg_max);
            // One anchor spans 4 sides, each with `reg_max` distribution bins.
            let bins_per_anchor = 4 * reg_max;
            debug_assert_eq!(input.len(), h * w * bins_per_anchor);
            debug_assert_eq!(dst.len(), 4 * h * w);
            debug_assert!(reg_max <= MAX_REG_MAX);
            // Stack scratch reused across anchors; only the leading
            // `bins_per_anchor` entries are live each iteration.
            let mut scratch = [0.0_f32; 4 * MAX_REG_MAX];
            let anchors = input
                .chunks_exact(bins_per_anchor)
                .zip(dst.chunks_exact_mut(4))
                .enumerate();
            for (anchor, (raw, out)) in anchors {
                $dequant(raw, q, &mut scratch[..bins_per_anchor]);
                // Normalise each side's bins independently.
                for side_bins in scratch[..bins_per_anchor].chunks_exact_mut(reg_max) {
                    softmax_inplace_f32(side_bins);
                }
                let ltrb = weighted_sum_4sides_f32(&scratch[..bins_per_anchor], reg_max);
                let xywh = dist2bbox_anchor_f32(
                    ltrb,
                    level.grid_x[anchor],
                    level.grid_y[anchor],
                    level.stride,
                );
                out.copy_from_slice(&xywh);
            }
        }
    };
}
// Generates a scalar DFL box decoder producing f16 boxes.
//
// Identical pipeline to `impl_dfl_level_f32` (dequant -> per-side softmax ->
// expected distance -> dist2bbox), except the working precision is `f16` and
// the grid coordinates/stride are narrowed from f32 before the box transform.
macro_rules! impl_dfl_level_f16 {
    ($name:ident, $i:ty, $dequant:ident) => {
        #[allow(dead_code)]
        pub(crate) fn $name(input: &[$i], q: Quantization, level: &LevelPlan, dst: &mut [f16]) {
            let (h, w, reg_max) = (level.h, level.w, level.reg_max);
            // One anchor spans 4 sides, each with `reg_max` distribution bins.
            let bins_per_anchor = 4 * reg_max;
            debug_assert_eq!(input.len(), h * w * bins_per_anchor);
            debug_assert_eq!(dst.len(), 4 * h * w);
            debug_assert!(reg_max <= MAX_REG_MAX);
            // Stack scratch reused across anchors; only the leading
            // `bins_per_anchor` entries are live each iteration.
            let mut scratch = [f16::ZERO; 4 * MAX_REG_MAX];
            let anchors = input
                .chunks_exact(bins_per_anchor)
                .zip(dst.chunks_exact_mut(4))
                .enumerate();
            for (anchor, (raw, out)) in anchors {
                $dequant(raw, q, &mut scratch[..bins_per_anchor]);
                // Normalise each side's bins independently.
                for side_bins in scratch[..bins_per_anchor].chunks_exact_mut(reg_max) {
                    softmax_inplace_f16(side_bins);
                }
                let ltrb = weighted_sum_4sides_f16(&scratch[..bins_per_anchor], reg_max);
                // Grid/stride are stored as f32 in the plan; narrow once here.
                let gx = f16::from_f32(level.grid_x[anchor]);
                let gy = f16::from_f32(level.grid_y[anchor]);
                let stride = f16::from_f32(level.stride);
                let xywh = dist2bbox_anchor_f16(ltrb, gx, gy, stride);
                out.copy_from_slice(&xywh);
            }
        }
    };
}
// DFL decoder instantiations: one function per (input element type -> output
// precision) pairing, each delegating dequantization to the matching kernel.
impl_dfl_level_f32!(decode_box_level_dfl_i8_to_f32, i8, dequant_i8_to_f32);
impl_dfl_level_f32!(decode_box_level_dfl_u8_to_f32, u8, dequant_u8_to_f32);
impl_dfl_level_f32!(decode_box_level_dfl_i16_to_f32, i16, dequant_i16_to_f32);
impl_dfl_level_f32!(decode_box_level_dfl_u16_to_f32, u16, dequant_u16_to_f32);
impl_dfl_level_f32!(decode_box_level_dfl_f16_to_f32, f16, dequant_f16_to_f32);
impl_dfl_level_f32!(decode_box_level_dfl_f32_to_f32, f32, dequant_f32_to_f32);
impl_dfl_level_f16!(decode_box_level_dfl_i8_to_f16, i8, dequant_i8_to_f16);
impl_dfl_level_f16!(decode_box_level_dfl_u8_to_f16, u8, dequant_u8_to_f16);
impl_dfl_level_f16!(decode_box_level_dfl_i16_to_f16, i16, dequant_i16_to_f16);
impl_dfl_level_f16!(decode_box_level_dfl_u16_to_f16, u16, dequant_u16_to_f16);
impl_dfl_level_f16!(decode_box_level_dfl_f16_to_f16, f16, dequant_f16_to_f16);
impl_dfl_level_f16!(decode_box_level_dfl_f32_to_f16, f32, dequant_f32_to_f16);
// Generates a scalar decoder for boxes encoded as direct LTRB distances
// (no DFL distribution), producing f32 boxes: dequantize 4 values per anchor
// and run them straight through `dist2bbox_anchor_f32`.
macro_rules! impl_ltrb_level_f32 {
    ($name:ident, $i:ty, $dequant:ident) => {
        #[allow(dead_code)]
        pub(crate) fn $name(input: &[$i], q: Quantization, level: &LevelPlan, dst: &mut [f32]) {
            debug_assert_eq!(input.len(), level.h * level.w * 4);
            debug_assert_eq!(dst.len(), 4 * level.h * level.w);
            // Reused per-anchor dequantization scratch.
            let mut ltrb = [0.0_f32; 4];
            let anchors = input
                .chunks_exact(4)
                .zip(dst.chunks_exact_mut(4))
                .enumerate();
            for (anchor, (raw, out)) in anchors {
                $dequant(raw, q, &mut ltrb);
                let xywh = dist2bbox_anchor_f32(
                    ltrb,
                    level.grid_x[anchor],
                    level.grid_y[anchor],
                    level.stride,
                );
                out.copy_from_slice(&xywh);
            }
        }
    };
}
// Generates a scalar decoder for direct-LTRB boxes producing f16 boxes.
// Same pipeline as `impl_ltrb_level_f32`, with the plan's f32 grid
// coordinates and stride narrowed to f16 before the box transform.
macro_rules! impl_ltrb_level_f16 {
    ($name:ident, $i:ty, $dequant:ident) => {
        #[allow(dead_code)]
        pub(crate) fn $name(input: &[$i], q: Quantization, level: &LevelPlan, dst: &mut [f16]) {
            debug_assert_eq!(input.len(), level.h * level.w * 4);
            debug_assert_eq!(dst.len(), 4 * level.h * level.w);
            // Reused per-anchor dequantization scratch.
            let mut ltrb = [f16::ZERO; 4];
            let anchors = input
                .chunks_exact(4)
                .zip(dst.chunks_exact_mut(4))
                .enumerate();
            for (anchor, (raw, out)) in anchors {
                $dequant(raw, q, &mut ltrb);
                // Grid/stride are stored as f32 in the plan; narrow once here.
                let gx = f16::from_f32(level.grid_x[anchor]);
                let gy = f16::from_f32(level.grid_y[anchor]);
                let stride = f16::from_f32(level.stride);
                let xywh = dist2bbox_anchor_f16(ltrb, gx, gy, stride);
                out.copy_from_slice(&xywh);
            }
        }
    };
}
// Direct-LTRB decoder instantiations: one function per (input element type ->
// output precision) pairing, mirroring the DFL set above.
impl_ltrb_level_f32!(decode_box_level_ltrb_i8_to_f32, i8, dequant_i8_to_f32);
impl_ltrb_level_f32!(decode_box_level_ltrb_u8_to_f32, u8, dequant_u8_to_f32);
impl_ltrb_level_f32!(decode_box_level_ltrb_i16_to_f32, i16, dequant_i16_to_f32);
impl_ltrb_level_f32!(decode_box_level_ltrb_u16_to_f32, u16, dequant_u16_to_f32);
impl_ltrb_level_f32!(decode_box_level_ltrb_f16_to_f32, f16, dequant_f16_to_f32);
impl_ltrb_level_f32!(decode_box_level_ltrb_f32_to_f32, f32, dequant_f32_to_f32);
impl_ltrb_level_f16!(decode_box_level_ltrb_i8_to_f16, i8, dequant_i8_to_f16);
impl_ltrb_level_f16!(decode_box_level_ltrb_u8_to_f16, u8, dequant_u8_to_f16);
impl_ltrb_level_f16!(decode_box_level_ltrb_i16_to_f16, i16, dequant_i16_to_f16);
impl_ltrb_level_f16!(decode_box_level_ltrb_u16_to_f16, u16, dequant_u16_to_f16);
impl_ltrb_level_f16!(decode_box_level_ltrb_f16_to_f16, f16, dequant_f16_to_f16);
impl_ltrb_level_f16!(decode_box_level_ltrb_f32_to_f16, f32, dequant_f32_to_f16);
#[cfg(target_arch = "aarch64")]
// Generates a DFL box decoder identical in structure to `impl_dfl_level_f32`,
// but with the dequantization and per-side softmax delegated to NEON kernels
// from `neon_baseline`. The weighted sum and dist2bbox steps remain scalar.
macro_rules! impl_dfl_level_f32_neon {
    ($name:ident, $i:ty, $dequant_neon:ident, $softmax:ident) => {
        #[allow(dead_code)]
        pub(crate) fn $name(input: &[$i], q: Quantization, level: &LevelPlan, dst: &mut [f32]) {
            let h = level.h;
            let w = level.w;
            let reg_max = level.reg_max;
            // Input layout: anchor-major, 4 sides x reg_max bins per anchor.
            debug_assert_eq!(input.len(), h * w * 4 * reg_max);
            debug_assert_eq!(dst.len(), 4 * h * w);
            // The stack scratch below only holds up to 4 * MAX_REG_MAX bins.
            debug_assert!(reg_max <= MAX_REG_MAX);
            let mut deq: [f32; 4 * MAX_REG_MAX] = [0.0_f32; 4 * MAX_REG_MAX];
            for anchor in 0..(h * w) {
                let in_base = anchor * 4 * reg_max;
                // SAFETY: the `neon_baseline` kernels are `unsafe fn`s; their
                // contract is not visible here. Presumably they require NEON,
                // which is mandatory on aarch64 (this macro is cfg-gated) —
                // TODO(review): confirm against the kernels' documented
                // preconditions (e.g. any slice-length requirements).
                unsafe {
                    crate::per_scale::kernels::neon_baseline::$dequant_neon(
                        &input[in_base..in_base + 4 * reg_max],
                        q,
                        &mut deq[..4 * reg_max],
                    );
                    // Softmax each side's bin distribution independently.
                    for side in 0..4 {
                        crate::per_scale::kernels::neon_baseline::$softmax(
                            &mut deq[side * reg_max..(side + 1) * reg_max],
                        );
                    }
                }
                // Collapse each side's distribution to its expected distance,
                // then convert LTRB distances to an xywh box at this anchor.
                let ltrb = weighted_sum_4sides_f32(&deq[..4 * reg_max], reg_max);
                let gx = level.grid_x[anchor];
                let gy = level.grid_y[anchor];
                let xywh = dist2bbox_anchor_f32(ltrb, gx, gy, level.stride);
                let out_base = anchor * 4;
                dst[out_base..out_base + 4].copy_from_slice(&xywh);
            }
        }
    };
}
// NEON DFL decoder instantiations. The `_fp16` suffix selects the
// `softmax_inplace_f32_neon_fp16` kernel variant; dequantization kernels are
// shared between the plain and `_fp16` instantiations.
#[cfg(target_arch = "aarch64")]
impl_dfl_level_f32_neon!(
    decode_box_level_dfl_i8_to_f32_neon,
    i8,
    dequant_i8_to_f32_neon,
    softmax_inplace_f32_neon
);
#[cfg(target_arch = "aarch64")]
impl_dfl_level_f32_neon!(
    decode_box_level_dfl_u8_to_f32_neon,
    u8,
    dequant_u8_to_f32_neon,
    softmax_inplace_f32_neon
);
#[cfg(target_arch = "aarch64")]
impl_dfl_level_f32_neon!(
    decode_box_level_dfl_i16_to_f32_neon,
    i16,
    dequant_i16_to_f32_neon,
    softmax_inplace_f32_neon
);
#[cfg(target_arch = "aarch64")]
impl_dfl_level_f32_neon!(
    decode_box_level_dfl_u16_to_f32_neon,
    u16,
    dequant_u16_to_f32_neon,
    softmax_inplace_f32_neon
);
#[cfg(target_arch = "aarch64")]
impl_dfl_level_f32_neon!(
    decode_box_level_dfl_i8_to_f32_neon_fp16,
    i8,
    dequant_i8_to_f32_neon,
    softmax_inplace_f32_neon_fp16
);
#[cfg(target_arch = "aarch64")]
impl_dfl_level_f32_neon!(
    decode_box_level_dfl_u8_to_f32_neon_fp16,
    u8,
    dequant_u8_to_f32_neon,
    softmax_inplace_f32_neon_fp16
);
#[cfg(target_arch = "aarch64")]
impl_dfl_level_f32_neon!(
    decode_box_level_dfl_i16_to_f32_neon_fp16,
    i16,
    dequant_i16_to_f32_neon,
    softmax_inplace_f32_neon_fp16
);
#[cfg(target_arch = "aarch64")]
impl_dfl_level_f32_neon!(
    decode_box_level_dfl_u16_to_f32_neon_fp16,
    u16,
    dequant_u16_to_f32_neon,
    softmax_inplace_f32_neon_fp16
);
#[cfg(target_arch = "aarch64")]
// Generates a direct-LTRB decoder that dequantizes the whole level in one
// NEON kernel call, then runs the scalar dist2bbox per anchor.
macro_rules! impl_ltrb_level_f32_neon {
    ($name:ident, $i:ty, $dequant_neon:ident) => {
        #[allow(dead_code)]
        pub(crate) fn $name(input: &[$i], q: Quantization, level: &LevelPlan, dst: &mut [f32]) {
            let h = level.h;
            let w = level.w;
            debug_assert_eq!(input.len(), h * w * 4);
            debug_assert_eq!(dst.len(), 4 * h * w);
            let n = h * w * 4;
            // NOTE(review): this heap-allocates one f32 per input element on
            // every call, unlike the scalar variants which use a small stack
            // scratch. If this shows up in profiles, consider a caller-owned
            // or tiled buffer — but the NEON kernel's length requirements are
            // not visible here, so verify before tiling.
            let mut deq: Vec<f32> = vec![0.0_f32; n];
            // SAFETY: the `neon_baseline` kernel is an `unsafe fn`; its
            // contract is not visible here. Presumably it requires NEON,
            // which is mandatory on aarch64 (this macro is cfg-gated) —
            // TODO(review): confirm its documented preconditions.
            unsafe {
                crate::per_scale::kernels::neon_baseline::$dequant_neon(input, q, &mut deq);
            }
            for anchor in 0..(h * w) {
                // Input and output share the same 4-wide per-anchor stride,
                // so one base index serves both.
                let base = anchor * 4;
                let dq = [deq[base], deq[base + 1], deq[base + 2], deq[base + 3]];
                let gx = level.grid_x[anchor];
                let gy = level.grid_y[anchor];
                let xywh = dist2bbox_anchor_f32(dq, gx, gy, level.stride);
                dst[base..base + 4].copy_from_slice(&xywh);
            }
        }
    };
}
// NEON direct-LTRB decoder instantiations, one per quantized input type.
#[cfg(target_arch = "aarch64")]
impl_ltrb_level_f32_neon!(
    decode_box_level_ltrb_i8_to_f32_neon,
    i8,
    dequant_i8_to_f32_neon
);
#[cfg(target_arch = "aarch64")]
impl_ltrb_level_f32_neon!(
    decode_box_level_ltrb_u8_to_f32_neon,
    u8,
    dequant_u8_to_f32_neon
);
#[cfg(target_arch = "aarch64")]
impl_ltrb_level_f32_neon!(
    decode_box_level_ltrb_i16_to_f32_neon,
    i16,
    dequant_i16_to_f32_neon
);
#[cfg(target_arch = "aarch64")]
impl_ltrb_level_f32_neon!(
    decode_box_level_ltrb_u16_to_f32_neon,
    u16,
    dequant_u16_to_f32_neon
);
#[cfg(test)]
mod tests {
    use super::*;
    use crate::per_scale::kernels::grids::make_anchor_grid;
    use crate::per_scale::plan::LevelPlan;
    use crate::Quantization;

    /// Builds a single-anchor (1x1) level plan with the given DFL bin count,
    /// stride, and box channel count, for exercising the decoders in isolation.
    fn level_1x1(reg_max: usize, stride: f32, box_channels: usize) -> LevelPlan {
        let (gx, gy) = make_anchor_grid(1, 1);
        LevelPlan {
            stride,
            h: 1,
            w: 1,
            reg_max,
            anchor_offset: 0,
            grid_x: gx,
            grid_y: gy,
            box_shape: vec![1, 1, 1, box_channels].into_boxed_slice(),
            score_shape: vec![1, 1, 1, 80].into_boxed_slice(),
            mc_shape: None,
            layout: crate::per_scale::plan::Layout::Nhwc,
        }
    }

    /// A near-one-hot logit (100.0) at bin 5 on each side softmaxes to ~1.0
    /// probability at that bin, so each expected distance is 5. With the
    /// anchor at (0.5, 0.5) and stride 8: xc = 0.5 * 8 = 4, w = 10 * 8 = 80.
    #[test]
    fn dfl_one_hot_at_bin_5_yields_distance_5() {
        let mut logits = [0.0_f32; 64];
        for side in 0..4 {
            logits[side * 16 + 5] = 100.0;
        }
        let mut out_boxes = [0.0_f32; 4];
        decode_box_level_dfl_f32_to_f32(
            &logits,
            Quantization::identity(),
            &level_1x1(16, 8.0, 64),
            &mut out_boxes,
        );
        assert!((out_boxes[0] - 4.0).abs() < 1e-3);
        assert!((out_boxes[2] - 80.0).abs() < 1e-3);
    }

    /// Direct LTRB distances [1, 1, 3, 3] at anchor (0.5, 0.5), stride 8:
    /// xc = ((0.5 - 1) + (0.5 + 3)) / 2 * 8 = 12, w = (1 + 3) * 8 = 32.
    #[test]
    fn ltrb_direct_distances_yield_dist2bbox() {
        let logits = [1.0_f32, 1.0, 3.0, 3.0];
        let mut out_boxes = [0.0_f32; 4];
        decode_box_level_ltrb_f32_to_f32(
            &logits,
            Quantization::identity(),
            &level_1x1(1, 8.0, 4),
            &mut out_boxes,
        );
        assert!((out_boxes[0] - 12.0).abs() < 1e-3);
        assert!((out_boxes[2] - 32.0).abs() < 1e-3);
    }

    /// Uniform logits over reg_max=4 bins softmax to a uniform distribution,
    /// giving expected distance (0 + 1 + 2 + 3) / 4 = 1.5 on every side.
    /// At stride 4 each anchor's box is centred on its grid cell:
    /// anchor 0 at (0.5, 0.5) -> xc = 2, w = 12; anchor 1 at (1.5, 0.5) ->
    /// xc = 6; anchor 2 at (0.5, 1.5) -> yc = 6.
    #[test]
    fn dfl_2x2_uniform_logits_each_anchor_at_centre() {
        let logits = [1.0_f32; 2 * 2 * 4 * 4];
        let (gx, gy) = make_anchor_grid(2, 2);
        let lvl = LevelPlan {
            stride: 4.0,
            h: 2,
            w: 2,
            reg_max: 4,
            anchor_offset: 0,
            grid_x: gx,
            grid_y: gy,
            box_shape: vec![1, 2, 2, 16].into_boxed_slice(),
            score_shape: vec![1, 2, 2, 80].into_boxed_slice(),
            mc_shape: None,
            layout: crate::per_scale::plan::Layout::Nhwc,
        };
        let mut out = [0.0_f32; 16];
        decode_box_level_dfl_f32_to_f32(&logits, Quantization::identity(), &lvl, &mut out);
        assert!((out[0] - 2.0).abs() < 1e-3);
        assert!((out[1] - 2.0).abs() < 1e-3);
        assert!((out[2] - 12.0).abs() < 1e-3);
        assert!((out[4] - 6.0).abs() < 1e-3, "expected xc=6, got {}", out[4]);
        assert!((out[9] - 6.0).abs() < 1e-3, "expected yc=6, got {}", out[9]);
    }

    /// With scale 0.1 and zero-point -128, i8 value -128 dequantizes to 0.0
    /// and 127 to (127 - (-128)) * 0.1 = 25.5, so the i8 path must match the
    /// f32 path fed the equivalent pre-dequantized logits.
    #[test]
    fn dfl_i8_dequant_then_decode_matches_f32_input() {
        let q = Quantization::new(0.1, -128);
        let mut input_i8 = [-128_i8; 64];
        for side in 0..4 {
            input_i8[side * 16 + 5] = 127;
        }
        let mut out_i8 = [0.0_f32; 4];
        decode_box_level_dfl_i8_to_f32(&input_i8, q, &level_1x1(16, 8.0, 64), &mut out_i8);
        let mut logits_f32 = [0.0_f32; 64];
        for side in 0..4 {
            logits_f32[side * 16 + 5] = 25.5;
        }
        let mut out_f32 = [0.0_f32; 4];
        decode_box_level_dfl_f32_to_f32(
            &logits_f32,
            Quantization::identity(),
            &level_1x1(16, 8.0, 64),
            &mut out_f32,
        );
        for i in 0..4 {
            assert!(
                (out_i8[i] - out_f32[i]).abs() < 1e-2,
                "box[{i}]: i8={} f32={}",
                out_i8[i],
                out_f32[i]
            );
        }
    }
}