use ndarray::{Array2, ArrayView2, ArrayView3};
use num_traits::{AsPrimitive, Float, PrimInt};
use crate::{
byte::{nms_int, postprocess_boxes_quant, quantize_score_threshold},
configs::Detection,
dequant_detect_box,
float::{nms_float, postprocess_boxes_float},
BBoxTypeTrait, DecoderError, DetectBox, Quantization, XYWH, XYXY,
};
/// Per-output settings for the split ModelPack detection decoder.
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct ModelPackDetectionConfig {
    /// One `[w, h]` pair per anchor slot in a grid cell; the decoder multiplies
    /// the decoded box width by element 0 and the height by element 1.
    pub(crate) anchors: Vec<[f32; 2]>,
    /// Quantization applied by the quantized split decoder as
    /// `(x - zero_point) * scale`; `None` means raw values are used directly.
    pub(crate) quantization: Option<Quantization>,
}
impl TryFrom<&Detection> for ModelPackDetectionConfig {
    type Error = DecoderError;

    /// Builds a split-decoder config from a detection output description,
    /// copying its anchors and converting its quantization parameters.
    ///
    /// # Errors
    /// Returns [`DecoderError::InvalidConfig`] when the detection has no anchors,
    /// or when its anchor list is empty. The split decoder divides the channel
    /// count by the number of anchors, so an empty list can never decode.
    fn try_from(value: &Detection) -> Result<Self, DecoderError> {
        // Borrow first so nothing is cloned on the error paths.
        let anchors = value.anchors.as_ref().ok_or_else(|| {
            DecoderError::InvalidConfig("ModelPack Split Detection missing anchors".to_string())
        })?;
        if anchors.is_empty() {
            return Err(DecoderError::InvalidConfig(
                "ModelPack Split Detection anchors must not be empty".to_string(),
            ));
        }
        Ok(Self {
            anchors: anchors.clone(),
            quantization: value.quantization.map(Quantization::from),
        })
    }
}
/// Decodes a combined (non-split) quantized ModelPack detection head whose
/// boxes are encoded as `XYXY` corners.
///
/// Each tensor argument is paired with its own quantization parameters. The
/// decoded detections (at most `max_det`) replace the contents of
/// `output_boxes`.
pub(crate) fn decode_modelpack_det<
    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
>(
    boxes_tensor: (ArrayView2<BOX>, Quantization),
    scores_tensor: (ArrayView2<SCORE>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    max_det: usize,
    output_boxes: &mut Vec<DetectBox>,
) where
    f32: AsPrimitive<SCORE>,
{
    impl_modelpack_quant::<XYXY, BOX, SCORE>(
        boxes_tensor,
        scores_tensor,
        score_threshold,
        iou_threshold,
        max_det,
        output_boxes,
    );
}
/// Decodes a combined (non-split) floating-point ModelPack detection head
/// whose boxes are encoded as `XYXY` corners.
///
/// The decoded detections (at most `max_det`) replace the contents of
/// `output_boxes`.
pub(crate) fn decode_modelpack_float<
    BOX: Float + AsPrimitive<f32> + Send + Sync,
    SCORE: Float + AsPrimitive<f32> + Send + Sync,
>(
    boxes_tensor: ArrayView2<BOX>,
    scores_tensor: ArrayView2<SCORE>,
    score_threshold: f32,
    iou_threshold: f32,
    max_det: usize,
    output_boxes: &mut Vec<DetectBox>,
) where
    f32: AsPrimitive<SCORE>,
{
    impl_modelpack_float::<XYXY, BOX, SCORE>(
        boxes_tensor,
        scores_tensor,
        score_threshold,
        iou_threshold,
        max_det,
        output_boxes,
    );
}
/// Test-only decoder for split (per-scale) quantized ModelPack outputs.
///
/// Each output is dequantized with its matching config before decoding. Boxes
/// are produced as `XYWH`, and the results (at most `max_det`) replace the
/// contents of `output_boxes`.
#[cfg(test)]
pub(crate) fn decode_modelpack_split_quant<D: AsPrimitive<f32>>(
    outputs: &[ArrayView3<D>],
    configs: &[ModelPackDetectionConfig],
    score_threshold: f32,
    iou_threshold: f32,
    max_det: usize,
    output_boxes: &mut Vec<DetectBox>,
) {
    impl_modelpack_split_quant::<XYWH, D>(
        outputs,
        configs,
        score_threshold,
        iou_threshold,
        max_det,
        output_boxes,
    );
}
/// Decodes split (per-scale) ModelPack outputs that need no dequantization.
///
/// `outputs[i]` is decoded with the anchors in `configs[i]`. Boxes are
/// produced as `XYWH`, and the results (at most `max_det`) replace the
/// contents of `output_boxes`.
pub(crate) fn decode_modelpack_split_float<D: AsPrimitive<f32>>(
    outputs: &[ArrayView3<D>],
    configs: &[ModelPackDetectionConfig],
    score_threshold: f32,
    iou_threshold: f32,
    max_det: usize,
    output_boxes: &mut Vec<DetectBox>,
) {
    impl_modelpack_split_float::<XYWH, D>(
        outputs,
        configs,
        score_threshold,
        iou_threshold,
        max_det,
        output_boxes,
    )
}
/// Shared pipeline for quantized, combined ModelPack heads.
///
/// The pipeline runs in four steps:
/// 1. Convert `score_threshold` with the score quantization.
/// 2. Filter candidates with `postprocess_boxes_quant`.
/// 3. Run integer NMS.
/// 4. Dequantize each surviving box with the score quantization.
///
/// `output_boxes` is cleared and then filled with at most `max_det` entries.
pub(crate) fn impl_modelpack_quant<
    B: BBoxTypeTrait,
    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
>(
    boxes: (ArrayView2<BOX>, Quantization),
    scores: (ArrayView2<SCORE>, Quantization),
    score_threshold: f32,
    iou_threshold: f32,
    max_det: usize,
    output_boxes: &mut Vec<DetectBox>,
) where
    f32: AsPrimitive<SCORE>,
{
    let (box_view, box_quant) = boxes;
    let (score_view, score_quant) = scores;

    // Bring the threshold into the same quantized representation as the scores.
    let quant_threshold = quantize_score_threshold(score_threshold, score_quant);
    let candidates =
        postprocess_boxes_quant::<B, _, _>(quant_threshold, box_view, score_view, box_quant);
    let kept = nms_int(iou_threshold, Some(max_det), candidates);

    output_boxes.clear();
    output_boxes.extend(
        kept.into_iter()
            .take(max_det)
            .map(|det| dequant_detect_box(&det, score_quant)),
    );
}
/// Shared pipeline for floating-point, combined ModelPack heads.
///
/// Candidates are filtered with `postprocess_boxes_float` and then pruned with
/// float NMS. `output_boxes` is cleared and then filled with at most `max_det`
/// entries.
pub(crate) fn impl_modelpack_float<
    B: BBoxTypeTrait,
    BOX: Float + AsPrimitive<f32> + Send + Sync,
    SCORE: Float + AsPrimitive<f32> + Send + Sync,
>(
    boxes_tensor: ArrayView2<BOX>,
    scores_tensor: ArrayView2<SCORE>,
    score_threshold: f32,
    iou_threshold: f32,
    max_det: usize,
    output_boxes: &mut Vec<DetectBox>,
) where
    f32: AsPrimitive<SCORE>,
{
    let candidates = postprocess_boxes_float::<B, _, _>(
        score_threshold.as_(),
        boxes_tensor,
        scores_tensor,
    );
    let kept = nms_float(iou_threshold, Some(max_det), candidates);

    output_boxes.clear();
    output_boxes.extend(kept.into_iter().take(max_det));
}
/// Test-only pipeline for split quantized ModelPack outputs.
///
/// The outputs are first decoded into dense float box/score arrays. These are
/// then filtered and pruned with float NMS. `output_boxes` is cleared and then
/// filled with at most `max_det` entries.
#[cfg(test)]
pub(crate) fn impl_modelpack_split_quant<B: BBoxTypeTrait, D: AsPrimitive<f32>>(
    outputs: &[ArrayView3<D>],
    configs: &[ModelPackDetectionConfig],
    score_threshold: f32,
    iou_threshold: f32,
    max_det: usize,
    output_boxes: &mut Vec<DetectBox>,
) {
    let (dense_boxes, dense_scores) = postprocess_modelpack_split_quant(outputs, configs);
    let candidates = postprocess_boxes_float::<B, _, _>(
        score_threshold,
        dense_boxes.view(),
        dense_scores.view(),
    );
    let kept = nms_float(iou_threshold, Some(max_det), candidates);

    output_boxes.clear();
    output_boxes.extend(kept.into_iter().take(max_det));
}
/// Pipeline for split ModelPack outputs that need no dequantization.
///
/// The outputs are first decoded into dense float box/score arrays. These are
/// then filtered and pruned with float NMS. `output_boxes` is cleared and then
/// filled with at most `max_det` entries.
pub(crate) fn impl_modelpack_split_float<B: BBoxTypeTrait, D: AsPrimitive<f32>>(
    outputs: &[ArrayView3<D>],
    configs: &[ModelPackDetectionConfig],
    score_threshold: f32,
    iou_threshold: f32,
    max_det: usize,
    output_boxes: &mut Vec<DetectBox>,
) {
    let (dense_boxes, dense_scores) = postprocess_modelpack_split_float(outputs, configs);
    let candidates = postprocess_boxes_float::<B, _, _>(
        score_threshold,
        dense_boxes.view(),
        dense_scores.view(),
    );
    let kept = nms_float(iou_threshold, Some(max_det), candidates);

    output_boxes.clear();
    output_boxes.extend(kept.into_iter().take(max_det));
}
/// Decodes split (per-scale) ModelPack detection outputs.
///
/// Before the sigmoid, each output is dequantized with its config's
/// quantization parameters. Outputs whose config has no quantization are used
/// as-is. See [`postprocess_modelpack_split_float`] for the layout, the return
/// value and the panic conditions.
#[cfg(test)]
pub(crate) fn postprocess_modelpack_split_quant<T: AsPrimitive<f32>>(
    outputs: &[ArrayView3<T>],
    config: &[ModelPackDetectionConfig],
) -> (Array2<f32>, Array2<f32>) {
    postprocess_modelpack_split_impl(outputs, config, true)
}

/// Decodes split (per-scale) ModelPack detection outputs that need no
/// dequantization; any quantization stored in `config` is ignored.
///
/// `outputs[i]` is paired with `config[i]`, and extra entries on either side
/// are ignored. Each output is read as `(height, width, na * (nc + 5))`, where
/// `na` is the config's anchor count. Each anchor slot holds
/// `[x, y, w, h, objectness, class scores...]`, and every value goes through a
/// sigmoid first.
///
/// Returns `(boxes, scores)`:
/// - `boxes` has shape `(N, 4)` with rows `[cx, cy, w, h]`. Centres are divided
///   by the grid width/height; sizes are scaled by the slot's anchor.
/// - `scores` has shape `(N, nc)` with entries `objectness * class_prob`. When
///   `nc == 1`, the objectness alone is used.
///
/// With no outputs, both arrays are empty.
///
/// # Panics
/// Panics if any of the following holds:
/// - a config has no anchors;
/// - an output's channel count is not a multiple of its anchor count;
/// - an output has fewer than 5 channels per anchor;
/// - the outputs disagree on the number of classes.
pub(crate) fn postprocess_modelpack_split_float<T: AsPrimitive<f32>>(
    outputs: &[ArrayView3<T>],
    config: &[ModelPackDetectionConfig],
) -> (Array2<f32>, Array2<f32>) {
    postprocess_modelpack_split_impl(outputs, config, false)
}

/// Shared decoder behind the split ModelPack postprocessors. `dequantize`
/// selects whether each config's quantization is applied to the raw values.
fn postprocess_modelpack_split_impl<T: AsPrimitive<f32>>(
    outputs: &[ArrayView3<T>],
    config: &[ModelPackDetectionConfig],
    dequantize: bool,
) -> (Array2<f32>, Array2<f32>) {
    // First pass: validate every layer's layout and count the anchor slots, so
    // that both buffers are allocated exactly once.
    let mut total_slots = 0;
    let mut num_classes: Option<usize> = None;
    for (p, detail) in outputs.iter().zip(config) {
        let (height, width, channels) = p.dim();
        let na = detail.anchors.len();
        assert!(na > 0, "ModelPack split detection requires at least one anchor");
        assert_eq!(
            channels % na,
            0,
            "ModelPack split output channels must be a multiple of the anchor count"
        );
        let nc = (channels / na)
            .checked_sub(5)
            .expect("ModelPack split output needs at least 5 channels per anchor");
        if let Some(expected) = num_classes {
            assert_eq!(
                nc, expected,
                "ModelPack split outputs disagree on the number of classes"
            );
        }
        num_classes = Some(nc);
        total_slots += height * width * na;
    }
    // With no outputs there are no detections. A zero class count then yields
    // empty (0, 4) / (0, 0) arrays instead of a divide-by-zero panic.
    let nc = num_classes.unwrap_or(0);
    let slot_len = nc + 5;

    let mut bboxes = Vec::with_capacity(total_slots * 4);
    let mut bscores = Vec::with_capacity(total_slots * nc);
    for (p, detail) in outputs.iter().zip(config) {
        let na = detail.anchors.len();
        let (height, width, _) = p.dim();

        let quant = detail.quantization.as_ref().filter(|_| dequantize);
        let p_sigmoid = match quant {
            Some(quant) => {
                // (x - zero_point) * scale, with the constant term hoisted.
                let scaled_zero = -quant.zero_point as f32 * quant.scale;
                p.mapv(|x| fast_sigmoid_impl(x.as_() * quant.scale + scaled_zero))
            }
            None => p.mapv(|x| fast_sigmoid_impl(x.as_())),
        };
        let p_sigmoid = p_sigmoid.as_standard_layout();
        let p = p_sigmoid
            .as_slice()
            .expect("Sigmoids are not in standard layout");

        let div_width = 1.0 / width as f32;
        let div_height = 1.0 / height as f32;
        for (i, (pred, anchor)) in p
            .chunks_exact(slot_len)
            .zip(detail.anchors.iter().cycle())
            .enumerate()
        {
            // Standard layout orders the slots as (y, x, anchor), so each run
            // of `na` consecutive slots belongs to one grid cell.
            let cell = i / na;
            let gx = (cell % width) as f32 - 0.5;
            let gy = (cell / width) as f32 - 0.5;
            bboxes.push((pred[0] * 2.0 + gx) * div_width);
            bboxes.push((pred[1] * 2.0 + gy) * div_height);
            bboxes.push(pred[2] * pred[2] * 4.0 * anchor[0]);
            bboxes.push(pred[3] * pred[3] * 4.0 * anchor[1]);
            if nc == 1 {
                bscores.push(pred[4]);
            } else {
                let obj = pred[4];
                bscores.extend(pred[5..].iter().map(|&class_prob| class_prob * obj));
            }
        }
    }

    // Derive the row count from the boxes, so the score shape stays valid even
    // when `nc` is 0.
    let num_boxes = bboxes.len() / 4;
    let bboxes = Array2::from_shape_vec((num_boxes, 4), bboxes)
        .expect("Failed to create bboxes array");
    let bscores = Array2::from_shape_vec((num_boxes, nc), bscores)
        .expect("Failed to create bscores array");
    (bboxes, bscores)
}
#[inline(always)]
fn fast_sigmoid_impl(f: f32) -> f32 {
if f.abs() > 80.0 {
f.signum() * 0.5 + 0.5
} else {
1.0 / (1.0 + fast_math::exp_raw(-f))
}
}
/// Converts a ModelPack segmentation output into a per-pixel class mask.
///
/// `segmentation` is read as `(height, width, channels)`. Each pixel of the
/// returned `(height, width)` mask holds the channel index chosen by
/// `argmax` over that pixel's channel values.
///
/// # Panics
/// Panics if there are fewer than 2 channels. Also panics if there are more
/// than 256 channels, because the class index would not fit in a `u8` and
/// would otherwise silently wrap.
pub(crate) fn modelpack_segmentation_to_mask(segmentation: ArrayView3<u8>) -> Array2<u8> {
    use argminmax::ArgMinMax;
    let (height, width, channels) = segmentation.dim();
    assert!(
        channels > 1,
        "Model Instance Segmentation should have shape (H, W, x) where x > 1"
    );
    assert!(
        channels <= usize::from(u8::MAX) + 1,
        "Model Instance Segmentation supports at most 256 channels (class index must fit in u8)"
    );
    let segmentation = segmentation.as_standard_layout();
    let seg = segmentation
        .as_slice()
        .expect("Segmentation is not in standard layout");
    let argmax = seg
        .chunks_exact(channels)
        // Lossless: the assertion above bounds every index to 0..=255.
        .map(|pixel| pixel.argmax() as u8)
        .collect::<Vec<_>>();
    Array2::from_shape_vec((height, width), argmax).expect("Failed to create mask array")
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod modelpack_tests {
#![allow(clippy::excessive_precision)]
// Unit tests for the parent module's config conversion, sigmoid
// approximation and segmentation-mask helpers.
use ndarray::Array3;
use crate::configs::{DecoderType, DimName};
use super::*;
// Converting a `Detection` keeps its anchors and quantization; a detection
// without anchors must be rejected with `InvalidConfig`.
#[test]
fn test_detection_config() {
let det = Detection {
anchors: Some(vec![[0.1, 0.13], [0.16, 0.30], [0.33, 0.23]]),
quantization: Some((0.1, 128).into()),
decoder: DecoderType::ModelPack,
shape: vec![1, 9, 17, 18],
dshape: vec![
(DimName::Batch, 1),
(DimName::Height, 9),
(DimName::Width, 17),
(DimName::NumAnchorsXFeatures, 18),
],
normalized: Some(true),
};
let config = ModelPackDetectionConfig::try_from(&det).unwrap();
assert_eq!(
config,
ModelPackDetectionConfig {
anchors: vec![[0.1, 0.13], [0.16, 0.30], [0.33, 0.23]],
quantization: Some(Quantization::new(0.1, 128)),
}
);
// Same detection, but with the anchors removed.
let det = Detection {
anchors: None,
quantization: Some((0.1, 128).into()),
decoder: DecoderType::ModelPack,
shape: vec![1, 9, 17, 18],
dshape: vec![
(DimName::Batch, 1),
(DimName::Height, 9),
(DimName::Width, 17),
(DimName::NumAnchorsXFeatures, 18),
],
normalized: Some(true),
};
let result = ModelPackDetectionConfig::try_from(&det);
assert!(
matches!(result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection missing anchors")
);
}
// `fast_sigmoid_impl` must stay within 5e-4 of the exact sigmoid on
// [-255, 255] sampled in 0.1 steps; this range also covers the saturated
// |x| > 80 branch.
#[test]
fn test_fast_sigmoid() {
fn full_sigmoid(x: f32) -> f32 {
1.0 / (1.0 + (-x).exp())
}
for i in -2550..=2550 {
let x = i as f32 * 0.1;
let fast = fast_sigmoid_impl(x);
let full = full_sigmoid(x);
let diff = (fast - full).abs();
assert!(
diff < 0.0005,
"Fast sigmoid differs from full sigmoid by {} at input {}",
diff,
x
);
}
}
// A 2x2 image with 3 channels: each pixel maps to the index of its largest
// channel. The all-zero last pixel is expected to map to 0.
#[test]
fn test_modelpack_segmentation_to_mask() {
let seg = Array3::from_shape_vec(
(2, 2, 3),
// Pixels in row-major order: [0,10,5], [20,15,25], [30,5,10], [0,0,0].
vec![
0u8, 10, 5, 20, 15, 25, 30, 5, 10, 0, 0, 0, ],
)
.unwrap();
let mask = modelpack_segmentation_to_mask(seg.view());
let expected_mask = Array2::from_shape_vec((2, 2), vec![1u8, 2, 0, 0]).unwrap();
assert_eq!(mask, expected_mask);
}
// A single-channel input must trip the channel-count assertion.
#[test]
#[should_panic(
expected = "Model Instance Segmentation should have shape (H, W, x) where x > 1"
)]
fn test_modelpack_segmentation_to_mask_invalid() {
let seg = Array3::from_shape_vec((2, 2, 1), vec![0u8, 10, 20, 30]).unwrap();
let _ = modelpack_segmentation_to_mask(seg.view());
}
}