#![allow(
unsafe_code,
clippy::doc_markdown,
clippy::too_many_lines,
clippy::if_not_else,
clippy::ptr_as_ptr,
clippy::cast_possible_truncation,
clippy::cast_sign_loss
)]
use std::collections::HashMap;
use wide::{CmpGt, f32x8};
use fast_image_resize::images::Image;
use fast_image_resize::{FilterType, PixelType, ResizeAlg, ResizeOptions, Resizer};
use ndarray::{Array2, Array3, ArrayView1, ArrayViewMut2, Zip, s};
use crate::inference::InferenceConfig;
use crate::preprocessing::{PreprocessResult, clip_coords, scale_coords};
use crate::results::{Boxes, Keypoints, Masks, Obb, Probs, Results, Speed};
use crate::task::Task;
use crate::utils::{nms_per_class, nms_rotated_per_class};
#[must_use]
#[allow(
    clippy::too_many_arguments,
    clippy::similar_names,
    clippy::implicit_hasher
)]
/// Dispatches raw model outputs to the task-specific decoder and returns the
/// assembled [`Results`].
///
/// For every task except `Classify` the function first decides whether the
/// output looks like an "end2end" head (NMS already applied inside the model):
/// either the caller says so via `end2end`, or the output shape matches the
/// known end2end layout for that task (see the `is_end2end_*_shape` helpers).
///
/// * `outputs` — one `(data, shape)` pair per model output tensor; index 0 is
///   the prediction tensor, index 1 (segment only) the mask protos.
/// * `preprocess` — letterbox scale/padding needed to map boxes back to the
///   original image.
/// * `config` — confidence/IoU thresholds, class filter, max detections.
/// * `names` — class-index → class-name map; its size seeds class-count
///   inference for non-end2end heads.
/// * `kpt_shape` — optional `(num_keypoints, kpt_dim)` override for pose; if
///   absent it is inferred from the output shape when possible.
pub fn postprocess(
    outputs: Vec<(&[f32], Vec<usize>)>,
    task: Task,
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
    end2end: bool,
    kpt_shape: Option<(usize, usize)>,
) -> Results {
    match task {
        Task::Detect => {
            let (output, shape) = &outputs[0];
            if end2end || is_end2end_detect_shape(shape) {
                postprocess_detect_end2end(
                    output,
                    shape,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                )
            } else {
                postprocess_detect(
                    output,
                    shape,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                )
            }
        }
        Task::Segment => {
            // Mask-coefficient count comes from the protos tensor (output 1,
            // laid out [batch, channels, h, w]) when present.
            let proto_channels = outputs
                .get(1)
                .and_then(|(_, s)| if s.len() == 4 { Some(s[1]) } else { None });
            if end2end || is_end2end_segment_shape(&outputs[0].1, proto_channels) {
                postprocess_segment_end2end(
                    outputs,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                )
            } else {
                postprocess_segment(
                    outputs,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                )
            }
        }
        Task::Pose => {
            let (output, shape) = &outputs[0];
            // Prefer the caller-provided keypoint layout; otherwise try to
            // infer it from the feature count (features - 6 split into nk*dim).
            let resolved_kpt = kpt_shape.or_else(|| infer_end2end_kpt_shape(shape));
            let is_end2end = end2end
                || resolved_kpt.is_some_and(|(nk, kd)| is_end2end_pose_shape(shape, nk, kd));
            if is_end2end {
                // Default to the COCO layout (17 keypoints, x/y/conf) when
                // nothing could be resolved.
                let (nk, kpt_dim) = resolved_kpt.unwrap_or((17, 3));
                postprocess_pose_end2end(
                    output,
                    shape,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                    nk,
                    kpt_dim,
                )
            } else {
                postprocess_pose(
                    output,
                    shape,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                )
            }
        }
        Task::Classify => {
            let (output, _) = &outputs[0];
            postprocess_classify(output, names, orig_img, path, speed, inference_shape)
        }
        Task::Obb => {
            let (output, shape) = &outputs[0];
            if end2end || is_end2end_obb_shape(shape) {
                postprocess_obb_end2end(
                    output,
                    shape,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                )
            } else {
                postprocess_obb(
                    output,
                    shape,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                )
            }
        }
    }
}
/// Returns `true` when `shape` matches an end2end detect head:
/// rank-3 `[batch, num_det, 6]` with a plausible (<= 4096) detection count.
fn is_end2end_detect_shape(shape: &[usize]) -> bool {
    match shape {
        [_, dets, 6] => *dets <= 4096,
        _ => false,
    }
}
/// Returns `true` when `shape` matches an end2end segment head:
/// rank-3 `[batch, num_det, 6 + num_masks]` where `num_masks` is taken from
/// the protos tensor's channel count. Without protos, never end2end.
fn is_end2end_segment_shape(shape: &[usize], proto_channels: Option<usize>) -> bool {
    let Some(nm) = proto_channels else {
        return false;
    };
    shape.len() == 3 && shape[1] <= 4096 && shape[2] == nm + 6
}
/// Returns `true` when `shape` matches an end2end pose head:
/// rank-3 `[batch, num_det, 6 + nk * kpt_dim]` with `num_det <= 4096`.
fn is_end2end_pose_shape(shape: &[usize], nk: usize, kpt_dim: usize) -> bool {
    if shape.len() != 3 {
        return false;
    }
    let expected_feats = nk * kpt_dim + 6;
    shape[2] == expected_feats && shape[1] <= 4096
}
/// Tries to infer `(num_keypoints, kpt_dim)` from an end2end pose output
/// shape `[batch, num_det, 6 + nk * kpt_dim]`.
///
/// The keypoint feature count (`feats - 6`) must be divisible by exactly one
/// of 3 (x/y/conf) or 2 (x/y); divisibility by both (or neither) is ambiguous
/// and yields `None`, as does any implausible shape.
fn infer_end2end_kpt_shape(shape: &[usize]) -> Option<(usize, usize)> {
    let &[_, dets, feats] = shape else {
        return None;
    };
    if dets == 0 || dets > 4096 || feats <= 6 {
        return None;
    }
    let kpt_feats = feats - 6;
    match (kpt_feats % 3 == 0, kpt_feats % 2 == 0) {
        (true, false) => Some((kpt_feats / 3, 3)),
        (false, true) => Some((kpt_feats / 2, 2)),
        _ => None,
    }
}
/// Returns `true` when `shape` matches an end2end OBB head:
/// rank-3 `[batch, num_det, 7]` (xyxy + conf + class + angle) with
/// `num_det <= 4096`.
fn is_end2end_obb_shape(shape: &[usize]) -> bool {
    matches!(shape, [_, dets, 7] if *dets <= 4096)
}
#[allow(
    clippy::too_many_arguments,
    clippy::similar_names,
    clippy::cast_precision_loss
)]
/// Decodes a raw (non-end2end) detect head: resolves the tensor layout,
/// extracts and NMS-filters boxes, and attaches them to a fresh `Results`.
/// Empty output or an unresolvable shape yields `Results` with no boxes.
fn postprocess_detect(
    output: &[f32],
    output_shape: &[usize],
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
) -> Results {
    let mut res = Results::new(orig_img, path, names.clone(), speed, inference_shape);
    // Work out how many classes/predictions the tensor holds and whether it
    // is laid out predictions-major ("transposed").
    let (classes, preds, transposed) = parse_detect_shape(output_shape, names.len());
    if preds == 0 || output.is_empty() {
        return res;
    }
    let detections = extract_detect_boxes(output, classes, preds, transposed, preprocess, config);
    if detections.is_empty() {
        return res;
    }
    res.boxes = Some(Boxes::new(detections, preprocess.orig_shape));
    res
}
/// Resolves a detect head's `(num_classes, num_predictions, is_transposed)`
/// from its raw output shape.
///
/// Accepts rank-2 `[features, preds]` / `[preds, features]` and rank-3
/// `[batch, features, preds]` / `[batch, preds, features]`, where
/// `features == 4 (bbox) + num_classes`. `is_transposed` is `true` when the
/// predictions axis comes first (prediction-major memory layout).
///
/// Resolution order:
/// 1. If either axis equals `4 + expected_classes`, trust that exact match.
/// 2. Otherwise fall back to "the smaller axis is the feature axis".
///    (Fix: the rank-2 fallback previously used the opposite rule — it
///    treated the *larger* axis as features, so e.g. `[8400, 84]` with 80
///    expected classes decoded as 8396 classes × 84 predictions.)
///
/// Degenerate shapes (wrong rank, or no axis able to hold the 4 bbox
/// features) return `num_predictions == 0` so callers can bail out early.
fn parse_detect_shape(shape: &[usize], expected_classes: usize) -> (usize, usize, bool) {
    // Shared axis resolution once the (a, b) candidate pair is known.
    fn resolve(a: usize, b: usize, expected_classes: usize) -> (usize, usize, bool) {
        if expected_classes == 0 {
            // No class count known: infer it from the smaller (feature) axis.
            let (features, preds, transposed) = if a < b { (a, b, false) } else { (b, a, true) };
            return (features.saturating_sub(4).max(1), preds, transposed);
        }
        let expected_features = 4 + expected_classes;
        // Prefer an exact feature-count match on either axis.
        if a == expected_features {
            return (expected_classes, b, false);
        }
        if b == expected_features {
            return (expected_classes, a, true);
        }
        // Class-count mismatch (e.g. model retrained with a different class
        // set): fall back to "smaller axis is the feature axis".
        if a < b {
            (a.saturating_sub(4), b, false)
        } else {
            (b.saturating_sub(4), a, true)
        }
    }
    match shape.len() {
        2 => {
            let (a, b) = (shape[0], shape[1]);
            // Neither axis can hold the 4 bbox features: no predictions.
            if a < 4 && b < 4 {
                return (expected_classes.max(1), 0, false);
            }
            resolve(a, b, expected_classes)
        }
        3 => {
            let (a, b) = (shape[1], shape[2]);
            if b == 0 || a < 4 {
                return (expected_classes.max(1), 0, false);
            }
            resolve(a, b, expected_classes)
        }
        _ => (expected_classes.max(1), 0, false),
    }
}
/// One pre-NMS detection candidate.
#[derive(Clone, Copy)]
struct Candidate {
    bbox: [f32; 4], // [x1, y1, x2, y2] in original-image coordinates
    score: f32,     // confidence of the winning class
    class: usize,   // index of the winning class
}
#[allow(clippy::cast_precision_loss, clippy::too_many_arguments)]
/// Decodes raw detect-head predictions into an `(n, 6)` array of
/// `[x1, y1, x2, y2, score, class]` rows in original-image coordinates,
/// clamped to the image bounds.
///
/// Pipeline: per-prediction best-class selection (confidence-thresholded),
/// cxcywh→xyxy conversion with letterbox removal, candidate capping, a
/// score-sorted greedy per-class NMS, then clamping. Both memory layouts are
/// handled: feature-major (`!is_transposed`, one contiguous slice per
/// feature) and prediction-major (`is_transposed`, one contiguous row per
/// prediction). The transposed class scan and the NMS inner loop read 8
/// lanes at a time via `f32x8` unaligned loads, with scalar tails.
fn extract_detect_boxes(
    output: &[f32],
    num_classes: usize,
    num_predictions: usize,
    is_transposed: bool,
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
) -> Array2<f32> {
    let feat_count = 4 + num_classes;
    let (scale_y, scale_x) = preprocess.scale;
    let (pad_top, pad_left) = preprocess.padding;
    let orig_shape = preprocess.orig_shape;
    // orig_shape is (height, width): clamp limits for the final boxes.
    let (max_w, max_h) = (orig_shape.1 as f32, orig_shape.0 as f32);
    let conf_thresh = config.confidence_threshold;
    let max_det = config.max_det;
    let iou_thresh = config.iou_threshold;
    let conf_v = f32x8::splat(conf_thresh);
    let mut candidates: Vec<Candidate> = Vec::with_capacity(256);
    if !is_transposed {
        // Feature-major layout: feature f of prediction i lives at
        // output[f * num_predictions + i]. Seeding max_scores with the
        // threshold means only strictly-greater scores can win.
        let mut max_scores = vec![conf_thresh; num_predictions];
        let mut max_classes = vec![0usize; num_predictions];
        for c in 0..num_classes {
            let offset = (4 + c) * num_predictions;
            let class_scores = &output[offset..offset + num_predictions];
            for (idx, &score) in class_scores.iter().enumerate() {
                if score > max_scores[idx] {
                    max_scores[idx] = score;
                    max_classes[idx] = c;
                }
            }
        }
        for (idx, &score) in max_scores.iter().enumerate() {
            if score > conf_thresh {
                let best_class = max_classes[idx];
                if !config.keep_class(best_class) {
                    continue;
                }
                // SAFETY: relies on output holding at least
                // feat_count * num_predictions elements, as implied by the
                // caller-resolved shape; the bbox rows occupy the first
                // 4 * num_predictions of them.
                let cx = unsafe { *output.get_unchecked(idx) };
                let cy = unsafe { *output.get_unchecked(num_predictions + idx) };
                let w = unsafe { *output.get_unchecked(2 * num_predictions + idx) };
                let h = unsafe { *output.get_unchecked(3 * num_predictions + idx) };
                // cxcywh -> xyxy, then undo letterbox padding and scale.
                let x1 = (cx - w * 0.5 - pad_left) / scale_x;
                let y1 = (cy - h * 0.5 - pad_top) / scale_y;
                let x2 = (cx + w * 0.5 - pad_left) / scale_x;
                let y2 = (cy + h * 0.5 - pad_top) / scale_y;
                candidates.push(Candidate {
                    bbox: [x1, y1, x2, y2],
                    score,
                    class: best_class,
                });
            }
        }
    } else {
        // Prediction-major layout: row i is [cx, cy, w, h, class scores...].
        for idx in 0..num_predictions {
            let base = idx * feat_count;
            let row_ptr = unsafe { output.as_ptr().add(base + 4) };
            let mut best_score = conf_thresh;
            let mut best_class = 0;
            for c_idx in (0..num_classes).step_by(8) {
                if num_classes - c_idx >= 8 {
                    // SAFETY: c_idx + 8 <= num_classes, so the 8-lane read
                    // stays inside this row's class scores (assuming the
                    // slice holds num_predictions * feat_count elements).
                    let scores: f32x8 =
                        unsafe { (row_ptr.add(c_idx) as *const f32x8).read_unaligned() };
                    // Scalar re-scan only when at least one lane beats the
                    // threshold — the common case skips the inner loop.
                    if scores.simd_gt(conf_v).any() {
                        for i in 0..8 {
                            let s = unsafe { *row_ptr.add(c_idx + i) };
                            if s > best_score {
                                best_score = s;
                                best_class = c_idx + i;
                            }
                        }
                    }
                } else {
                    // Scalar tail for the last < 8 classes.
                    for i in c_idx..num_classes {
                        let s = unsafe { *row_ptr.add(i) };
                        if s > best_score {
                            best_score = s;
                            best_class = i;
                        }
                    }
                }
            }
            if best_score > conf_thresh {
                if !config.keep_class(best_class) {
                    continue;
                }
                // SAFETY: base + 3 < num_predictions * feat_count by the same
                // shape assumption as above.
                let cx = unsafe { *output.get_unchecked(base) };
                let cy = unsafe { *output.get_unchecked(base + 1) };
                let w = unsafe { *output.get_unchecked(base + 2) };
                let h = unsafe { *output.get_unchecked(base + 3) };
                let x1 = (cx - w * 0.5 - pad_left) / scale_x;
                let y1 = (cy - h * 0.5 - pad_top) / scale_y;
                let x2 = (cx + w * 0.5 - pad_left) / scale_x;
                let y2 = (cy + h * 0.5 - pad_top) / scale_y;
                candidates.push(Candidate {
                    bbox: [x1, y1, x2, y2],
                    score: best_score,
                    class: best_class,
                });
            }
        }
    }
    if candidates.is_empty() {
        return Array2::zeros((0, 6));
    }
    // Cap the NMS workload: keep only the top max_det * 10 candidates by
    // score (partial selection, then truncate) before the full sort.
    let nms_limit = (max_det * 10).min(candidates.len());
    if candidates.len() > nms_limit {
        candidates.select_nth_unstable_by(nms_limit, |a, b| b.score.partial_cmp(&a.score).unwrap());
        candidates.truncate(nms_limit);
    }
    // Greedy NMS requires descending score order.
    candidates.sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
    let n = candidates.len();
    // Structure-of-arrays copies so the NMS inner loop can do contiguous
    // 8-lane loads per coordinate.
    let mut x1 = Vec::with_capacity(n);
    let mut y1 = Vec::with_capacity(n);
    let mut x2 = Vec::with_capacity(n);
    let mut y2 = Vec::with_capacity(n);
    let mut areas = Vec::with_capacity(n);
    for c in &candidates {
        x1.push(c.bbox[0]);
        y1.push(c.bbox[1]);
        x2.push(c.bbox[2]);
        y2.push(c.bbox[3]);
        areas.push((c.bbox[2] - c.bbox[0]) * (c.bbox[3] - c.bbox[1]));
    }
    let mut suppressed = vec![false; n];
    let mut keep = Vec::with_capacity(max_det);
    let iou_v = f32x8::splat(iou_thresh);
    for i in 0..n {
        if suppressed[i] {
            continue;
        }
        keep.push(i);
        if keep.len() >= max_det {
            break;
        }
        // Broadcast box i across all lanes and sweep the remaining boxes.
        let ax1 = f32x8::splat(x1[i]);
        let ay1 = f32x8::splat(y1[i]);
        let ax2 = f32x8::splat(x2[i]);
        let ay2 = f32x8::splat(y2[i]);
        let aa = f32x8::splat(areas[i]);
        let ac = candidates[i].class;
        let mut j = i + 1;
        while j < n {
            if n - j >= 8 {
                // Skip the whole 8-lane group unless it contains at least one
                // live candidate of the same class (per-class NMS).
                if (0..8).any(|k| candidates[j + k].class == ac && !suppressed[j + k]) {
                    // SAFETY: j + 8 <= n, so each unaligned 8-lane load stays
                    // inside its length-n vector.
                    let bx1 = unsafe { (x1.as_ptr().add(j) as *const f32x8).read_unaligned() };
                    let by1 = unsafe { (y1.as_ptr().add(j) as *const f32x8).read_unaligned() };
                    let bx2 = unsafe { (x2.as_ptr().add(j) as *const f32x8).read_unaligned() };
                    let by2 = unsafe { (y2.as_ptr().add(j) as *const f32x8).read_unaligned() };
                    let ba = unsafe { (areas.as_ptr().add(j) as *const f32x8).read_unaligned() };
                    // Vectorized IoU of box i against 8 boxes at once.
                    let ix1 = ax1.max(bx1);
                    let iy1 = ay1.max(by1);
                    let ix2 = ax2.min(bx2);
                    let iy2 = ay2.min(by2);
                    let iw = (ix2 - ix1).max(f32x8::ZERO);
                    let ih = (iy2 - iy1).max(f32x8::ZERO);
                    let ia = iw * ih;
                    let iou = ia / (aa + ba - ia);
                    // One bit per lane: set where IoU exceeds the threshold.
                    let mask = iou.simd_gt(iou_v).to_bitmask() as u8;
                    if mask != 0 {
                        for k in 0..8 {
                            // Suppress only same-class lanes above threshold.
                            if (mask & (1 << k)) != 0 && candidates[j + k].class == ac {
                                suppressed[j + k] = true;
                            }
                        }
                    }
                }
                j += 8;
            } else {
                // Scalar tail: fewer than 8 boxes remain after j.
                for k in j..n {
                    if !suppressed[k] && candidates[k].class == ac {
                        let ix1 = x1[i].max(x1[k]);
                        let iy1 = y1[i].max(y1[k]);
                        let ix2 = x2[i].min(x2[k]);
                        let iy2 = y2[i].min(y2[k]);
                        let iw = (ix2 - ix1).max(0.0);
                        let ih = (iy2 - iy1).max(0.0);
                        let ia = iw * ih;
                        let iou = ia / (areas[i] + areas[k] - ia);
                        if iou > iou_thresh {
                            suppressed[k] = true;
                        }
                    }
                }
                break;
            }
        }
    }
    // Materialize the kept candidates, clamped to the original image bounds.
    let num_kept = keep.len();
    let mut result = Array2::zeros((num_kept, 6));
    for (out_idx, &idx) in keep.iter().enumerate() {
        let c = &candidates[idx];
        result[[out_idx, 0]] = c.bbox[0].clamp(0.0, max_w);
        result[[out_idx, 1]] = c.bbox[1].clamp(0.0, max_h);
        result[[out_idx, 2]] = c.bbox[2].clamp(0.0, max_w);
        result[[out_idx, 3]] = c.bbox[3].clamp(0.0, max_h);
        result[[out_idx, 4]] = c.score;
        result[[out_idx, 5]] = c.class as f32;
    }
    result
}
#[allow(
    clippy::too_many_arguments,
    clippy::similar_names,
    clippy::cast_precision_loss,
    clippy::too_many_lines,
    clippy::needless_pass_by_value,
    clippy::manual_let_else,
    clippy::cast_possible_truncation
)]
/// Decodes a raw (non-end2end) segmentation head.
///
/// Expects two outputs: predictions at index 0 (features =
/// 4 bbox + num_classes + 32 mask coefficients) and mask protos at index 1
/// (`[batch, 32, mh, mw]`). Detections are confidence-filtered, NMS-pruned
/// per class, and their coefficient rows are matrix-multiplied against the
/// protos; each resulting low-res mask is sigmoid-activated, cropped to the
/// letterboxed region, resized to the original image size, and finally
/// zeroed outside its detection box. Missing/odd-shaped protos degrade
/// gracefully to box-only results with a warning.
fn postprocess_segment(
    outputs: Vec<(&[f32], Vec<usize>)>,
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
) -> Results {
    let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);
    if outputs.len() < 2 {
        eprintln!(
            "WARNING ⚠️ Segmentation model missing protos output (expected 2 outputs, got {}). Returning empty masks.",
            outputs.len()
        );
        return results;
    }
    let (output0, shape0) = &outputs[0];
    let (output1, shape1) = &outputs[1];
    // Fixed 32 mask coefficients per detection (standard segment head width).
    let num_masks = 32;
    let expected_features = 4 + names.len() + num_masks;
    // Resolve layout: exact feature-count match on either axis, else the
    // smaller axis is assumed to be the feature axis.
    let (num_preds, is_transposed) = if shape0.len() == 3 {
        let (a, b) = (shape0[1], shape0[2]);
        if a == expected_features {
            (b, false)
        } else if b == expected_features {
            (a, true)
        } else {
            if a < b { (b, false) } else { (a, true) }
        }
    } else {
        (0, false)
    };
    if output0.is_empty() || num_preds == 0 {
        return results;
    }
    // Normalize to (num_preds, expected_features) row-per-prediction form.
    let output_2d = if is_transposed {
        Array2::from_shape_vec((num_preds, expected_features), output0.to_vec())
            .unwrap_or_else(|_| Array2::zeros((0, 0)))
    } else {
        let arr = Array2::from_shape_vec((expected_features, num_preds), output0.to_vec())
            .unwrap_or_else(|_| Array2::zeros((0, 0)));
        arr.t().to_owned()
    };
    let mut candidates = Vec::new();
    for i in 0..num_preds {
        // Best class among the score columns [4, 4 + num_classes).
        let scores = output_2d.slice(s![i, 4..4 + names.len()]);
        let (best_class, best_score) = scores
            .iter()
            .enumerate()
            .max_by(|&(_, a), &(_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map_or((0, 0.0), |(idx, &score)| (idx, score));
        if best_score < config.confidence_threshold {
            continue;
        }
        // cxcywh -> xyxy, then map back to original-image space and clip.
        let cx = output_2d[[i, 0]];
        let cy = output_2d[[i, 1]];
        let w = output_2d[[i, 2]];
        let h = output_2d[[i, 3]];
        let x1 = cx - w / 2.0;
        let y1 = cy - h / 2.0;
        let x2 = cx + w / 2.0;
        let y2 = cy + h / 2.0;
        let scaled = scale_coords(&[x1, y1, x2, y2], preprocess.scale, preprocess.padding);
        let clipped = clip_coords(&scaled, preprocess.orig_shape);
        if !config.keep_class(best_class) {
            continue;
        }
        // Remember the source row index so mask coefficients can be fetched
        // after NMS.
        candidates.push((
            [clipped[0], clipped[1], clipped[2], clipped[3]],
            best_score,
            best_class,
            i,
        ));
    }
    if candidates.is_empty() {
        return results;
    }
    let nms_candidates: Vec<_> = candidates
        .iter()
        .map(|(bbox, score, class, _)| (*bbox, *score, *class))
        .collect();
    let keep_indices = nms_per_class(&nms_candidates, config.iou_threshold);
    let num_kept = keep_indices.len().min(config.max_det);
    let mut boxes_data = Array2::zeros((num_kept, 6));
    let mut mask_coeffs = Array2::zeros((num_kept, num_masks));
    for (out_idx, &keep_idx) in keep_indices.iter().take(num_kept).enumerate() {
        let (bbox, score, class, orig_idx) = &candidates[keep_idx];
        boxes_data[[out_idx, 0]] = bbox[0];
        boxes_data[[out_idx, 1]] = bbox[1];
        boxes_data[[out_idx, 2]] = bbox[2];
        boxes_data[[out_idx, 3]] = bbox[3];
        boxes_data[[out_idx, 4]] = *score;
        boxes_data[[out_idx, 5]] = *class as f32;
        // Mask coefficients follow the class scores in the feature row.
        let start = 4 + names.len();
        let coeffs = output_2d.slice(s![*orig_idx, start..start + num_masks]);
        for m in 0..num_masks {
            mask_coeffs[[out_idx, m]] = coeffs[m];
        }
    }
    results.boxes = Some(Boxes::new(boxes_data.clone(), preprocess.orig_shape));
    if shape1.len() < 4 {
        eprintln!(
            "WARNING ⚠️ Protos output has unexpected shape (expected 4 dims, got {}). Skipping mask generation.",
            shape1.len()
        );
        return results;
    }
    let mh = shape1[2];
    let mw = shape1[3];
    if shape1[1] != num_masks {
        eprintln!(
            "WARNING ⚠️ Protos output has {} mask channels, expected {}. Mask quality may be affected.",
            shape1[1], num_masks
        );
    }
    // Protos flattened to (num_masks, mh*mw) so coeffs · protos gives one
    // flat low-res mask per kept detection.
    let protos = match Array2::from_shape_vec((num_masks, mh * mw), output1.to_vec()) {
        Ok(arr) => arr,
        Err(e) => {
            eprintln!("WARNING ⚠️ Failed to create protos array: {e}. Skipping mask generation.");
            return results;
        }
    };
    let masks_flat = mask_coeffs.dot(&protos);
    let (oh, ow) = preprocess.orig_shape;
    let (th, tw) = inference_shape;
    let (pad_top, pad_left) = preprocess.padding;
    // Letterbox padding expressed in proto-resolution pixels; the crop window
    // removes the padding band on each side (hence the 2x in the size).
    let scale_w = mw as f32 / tw as f32;
    let scale_h = mh as f32 / th as f32;
    let crop_x = pad_left * scale_w;
    let crop_y = pad_top * scale_h;
    let crop_w = 2.0f32.mul_add(-crop_x, mw as f32);
    let crop_h = 2.0f32.mul_add(-crop_y, mh as f32);
    let mut masks_data = Array3::zeros((num_kept, oh as usize, ow as usize));
    // Per-detection mask upscaling runs in parallel (one resizer per task).
    Zip::from(masks_data.outer_iter_mut())
        .and(masks_flat.outer_iter())
        .and(boxes_data.outer_iter())
        .par_for_each(
            |mut mask_out: ArrayViewMut2<f32>,
             mask_flat: ArrayView1<f32>,
             box_data: ArrayView1<f32>| {
                let mut resizer = Resizer::new();
                let resize_alg = ResizeAlg::Convolution(FilterType::Bilinear);
                // Sigmoid-activate the raw mask logits.
                let f32_data: Vec<f32> = mask_flat
                    .iter()
                    .map(|&val| 1.0 / (1.0 + (-val).exp()))
                    .collect();
                let src_bytes: &[u8] = bytemuck::cast_slice(&f32_data);
                let src_image = match Image::from_vec_u8(
                    mw as u32,
                    mh as u32,
                    src_bytes.to_vec(),
                    PixelType::F32,
                ) {
                    Ok(img) => img,
                    // Resize failures leave this mask all-zero (best effort).
                    Err(_) => return,
                };
                let mut dst_image = Image::new(ow, oh, PixelType::F32);
                // Clamp the crop window into the proto bounds before resizing.
                let safe_crop_x = f64::from(crop_x.max(0.0));
                let safe_crop_y = f64::from(crop_y.max(0.0));
                let safe_crop_w = f64::from(crop_w.max(1.0).min(mw as f32));
                let safe_crop_h = f64::from(crop_h.max(1.0).min(mh as f32));
                let options = ResizeOptions::new().resize_alg(resize_alg).crop(
                    safe_crop_x,
                    safe_crop_y,
                    safe_crop_w,
                    safe_crop_h,
                );
                if resizer
                    .resize(&src_image, &mut dst_image, &options)
                    .is_err()
                {
                    return;
                }
                let dst_bytes = dst_image.buffer();
                let dst_slice: &[f32] = bytemuck::cast_slice(dst_bytes);
                // Keep mask values only inside the detection box.
                let x1 = box_data[0].max(0.0).min(ow as f32);
                let y1 = box_data[1].max(0.0).min(oh as f32);
                let x2 = box_data[2].max(0.0).min(ow as f32);
                let y2 = box_data[3].max(0.0).min(oh as f32);
                for y in 0..oh as usize {
                    for x in 0..ow as usize {
                        let val = dst_slice[y * ow as usize + x];
                        let x_f = x as f32;
                        let y_f = y as f32;
                        if x_f >= x1 && x_f <= x2 && y_f >= y1 && y_f <= y2 {
                            mask_out[[y, x]] = val;
                        }
                    }
                }
            },
        );
    results.masks = Some(Masks::new(masks_data, preprocess.orig_shape));
    results
}
#[allow(
    clippy::too_many_arguments,
    clippy::too_many_lines,
    clippy::similar_names,
    clippy::type_complexity,
    clippy::cast_precision_loss,
    clippy::doc_lazy_continuation
)]
/// Decodes a raw (non-end2end) pose head.
///
/// Assumes the COCO keypoint layout: 17 keypoints × (x, y, conf), so each
/// feature row is `4 bbox + num_classes + 51 keypoint values`. Detections are
/// confidence-filtered and NMS-pruned per class; boxes and keypoints are
/// mapped back to original-image coordinates. The class count is re-derived
/// from the actual feature width (`actual_features - 4 - 51`), which
/// overrides `names.len()` when they disagree.
fn postprocess_pose(
    output: &[f32],
    output_shape: &[usize],
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
) -> Results {
    let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);
    // Hard-coded COCO pose layout: 17 keypoints, each (x, y, conf).
    let num_keypoints = 17;
    let kpt_dim = 3;
    let kpt_features = num_keypoints * kpt_dim;
    let num_classes = names.len().max(1);
    let expected_features = 4 + num_classes + kpt_features;
    // Resolve layout: exact feature match, or "smaller-but-plausible axis is
    // the feature axis" for rank-3; plain smaller-axis rule for rank-2.
    let (num_preds, is_transposed) = if output_shape.len() == 3 {
        let (a, b) = (output_shape[1], output_shape[2]);
        if a == expected_features || (a < b && a >= 4 + kpt_features) {
            (b, false)
        } else {
            (a, true)
        }
    } else if output_shape.len() == 2 {
        let (a, b) = (output_shape[0], output_shape[1]);
        if a < b { (b, false) } else { (a, true) }
    } else {
        (0, false)
    };
    if output.is_empty() || num_preds == 0 {
        return results;
    }
    // Trust the buffer length over the declared shape for the feature width.
    let actual_features = output.len() / num_preds;
    if actual_features < 4 + kpt_features {
        eprintln!(
            "WARNING ⚠️ Pose model has insufficient features ({actual_features}), expected at least {}",
            4 + kpt_features
        );
        return results;
    }
    // Normalize to (num_preds, actual_features) row-per-prediction form.
    let output_2d = if is_transposed {
        Array2::from_shape_vec((num_preds, actual_features), output.to_vec())
            .unwrap_or_else(|_| Array2::zeros((0, 0)))
    } else {
        let arr = Array2::from_shape_vec((actual_features, num_preds), output.to_vec())
            .unwrap_or_else(|_| Array2::zeros((0, 0)));
        arr.t().to_owned()
    };
    if output_2d.is_empty() {
        return results;
    }
    // Shadow num_classes with the count implied by the actual feature width.
    let derived_classes = actual_features.saturating_sub(4 + kpt_features);
    let num_classes = derived_classes.max(1);
    let mut candidates: Vec<([f32; 4], f32, usize, Vec<[f32; 3]>)> = Vec::new();
    for i in 0..num_preds {
        // Best class among the score columns; NaN scores are treated as 0.
        let class_scores = output_2d.slice(s![i, 4..4 + num_classes]);
        let (best_class, best_score) = class_scores
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Less))
            .map_or((0, 0.0), |(idx, &score)| {
                (idx, if score.is_nan() { 0.0 } else { score })
            });
        if best_score < config.confidence_threshold {
            continue;
        }
        // cxcywh -> xyxy, mapped back to original-image space and clipped.
        let cx = output_2d[[i, 0]];
        let cy = output_2d[[i, 1]];
        let w = output_2d[[i, 2]];
        let h = output_2d[[i, 3]];
        let x1 = cx - w / 2.0;
        let y1 = cy - h / 2.0;
        let x2 = cx + w / 2.0;
        let y2 = cy + h / 2.0;
        let scaled = scale_coords(&[x1, y1, x2, y2], preprocess.scale, preprocess.padding);
        let clipped = clip_coords(&scaled, preprocess.orig_shape);
        // Keypoints follow the class scores; each is (x, y, conf) with only
        // x/y rescaled — conf passes through unchanged.
        let kpt_start = 4 + num_classes;
        let mut keypoints = Vec::with_capacity(num_keypoints);
        for k in 0..num_keypoints {
            let kpt_offset = kpt_start + k * kpt_dim;
            let kpt_x = output_2d[[i, kpt_offset]];
            let kpt_y = output_2d[[i, kpt_offset + 1]];
            let kpt_conf = output_2d[[i, kpt_offset + 2]];
            let scaled_kpt = scale_coords(
                &[kpt_x, kpt_y, kpt_x, kpt_y],
                preprocess.scale,
                preprocess.padding,
            );
            let (oh, ow) = preprocess.orig_shape;
            #[allow(clippy::cast_precision_loss)]
            let scaled_x = scaled_kpt[0].max(0.0).min(ow as f32);
            #[allow(clippy::cast_precision_loss)]
            let scaled_y = scaled_kpt[1].max(0.0).min(oh as f32);
            keypoints.push([scaled_x, scaled_y, kpt_conf]);
        }
        if !config.keep_class(best_class) {
            continue;
        }
        candidates.push((
            [clipped[0], clipped[1], clipped[2], clipped[3]],
            best_score,
            best_class,
            keypoints,
        ));
    }
    // No detections: still attach an empty keypoints container.
    if candidates.is_empty() {
        results.keypoints = Some(Keypoints::new(
            Array3::zeros((0, num_keypoints, kpt_dim)),
            preprocess.orig_shape,
        ));
        return results;
    }
    let nms_candidates: Vec<_> = candidates
        .iter()
        .map(|(bbox, score, class, _)| (*bbox, *score, *class))
        .collect();
    let keep_indices = nms_per_class(&nms_candidates, config.iou_threshold);
    let num_kept = keep_indices.len().min(config.max_det);
    let mut boxes_data = Array2::zeros((num_kept, 6));
    let mut keypoints_data = Array3::zeros((num_kept, num_keypoints, kpt_dim));
    for (out_idx, &keep_idx) in keep_indices.iter().take(num_kept).enumerate() {
        let (bbox, score, class, kpts) = &candidates[keep_idx];
        boxes_data[[out_idx, 0]] = bbox[0];
        boxes_data[[out_idx, 1]] = bbox[1];
        boxes_data[[out_idx, 2]] = bbox[2];
        boxes_data[[out_idx, 3]] = bbox[3];
        boxes_data[[out_idx, 4]] = *score;
        #[allow(clippy::cast_precision_loss)]
        let class_f32 = *class as f32;
        boxes_data[[out_idx, 5]] = class_f32;
        for (k, kpt) in kpts.iter().enumerate() {
            keypoints_data[[out_idx, k, 0]] = kpt[0];
            keypoints_data[[out_idx, k, 1]] = kpt[1];
            keypoints_data[[out_idx, k, 2]] = kpt[2];
        }
    }
    results.boxes = Some(Boxes::new(boxes_data, preprocess.orig_shape));
    results.keypoints = Some(Keypoints::new(keypoints_data, preprocess.orig_shape));
    results
}
/// Decodes a classification head into per-class probabilities.
///
/// If the raw scores already sum to ~1 they are kept as-is; otherwise (when
/// the sum is positive) they are treated as logits and passed through a
/// numerically-stable softmax.
fn postprocess_classify(
    output: &[f32],
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
) -> Results {
    let mut res = Results::new(orig_img, path, names.clone(), speed, inference_shape);
    if output.is_empty() {
        return res;
    }
    let mut scores: Vec<f32> = output.to_vec();
    let total: f32 = scores.iter().sum();
    // Heuristic: probabilities sum to ~1; anything else with a positive sum
    // is assumed to be logits.
    let looks_like_logits = (total - 1.0).abs() > 0.1 && total > 0.0;
    if looks_like_logits {
        // Subtract the max before exponentiating for numerical stability.
        let peak = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = scores.iter().map(|&v| (v - peak).exp()).collect();
        let denom: f32 = exps.iter().sum();
        // Degenerate exponentials (denom == 0) leave the raw scores intact.
        if denom > 0.0 {
            scores = exps.into_iter().map(|v| v / denom).collect();
        }
    }
    res.probs = Some(Probs::new(ndarray::Array1::from_vec(scores)));
    res
}
#[allow(
    clippy::too_many_arguments,
    clippy::too_many_lines,
    clippy::similar_names
)]
/// Decodes a raw (non-end2end) oriented-bounding-box head.
///
/// Each feature row is `[cx, cy, w, h, class scores..., angle]`. Detections
/// are confidence-filtered, pruned with rotated per-class NMS, and emitted as
/// `(n, 7)` rows of `[cx, cy, w, h, angle, score, class]` in original-image
/// coordinates. The class count is re-derived from the actual feature width
/// (`actual_features - 5`), which overrides `names.len()` when they disagree.
fn postprocess_obb(
    output: &[f32],
    output_shape: &[usize],
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
) -> Results {
    let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);
    let num_classes = names.len().max(1);
    // 4 bbox + classes + 1 angle.
    let expected_features = 4 + num_classes + 1;
    // Resolve layout: exact feature match, or "smaller-but-plausible axis is
    // the feature axis" for rank-3; plain smaller-axis rule for rank-2.
    let (num_preds, is_transposed) = if output_shape.len() == 3 {
        let (a, b) = (output_shape[1], output_shape[2]);
        if a == expected_features || (a < b && a >= 6) {
            (b, false)
        } else {
            (a, true)
        }
    } else if output_shape.len() == 2 {
        let (a, b) = (output_shape[0], output_shape[1]);
        if a < b { (b, false) } else { (a, true) }
    } else {
        (0, false)
    };
    if output.is_empty() || num_preds == 0 {
        return results;
    }
    // Trust the buffer length over the declared shape for the feature width.
    let actual_features = output.len() / num_preds;
    if actual_features < 6 {
        eprintln!(
            "WARNING ⚠️ OBB model has insufficient features ({actual_features}), expected at least 6"
        );
        return results;
    }
    // Normalize to (num_preds, actual_features) row-per-prediction form.
    let output_2d = if is_transposed {
        Array2::from_shape_vec((num_preds, actual_features), output.to_vec())
            .unwrap_or_else(|_| Array2::zeros((0, 0)))
    } else {
        let arr = Array2::from_shape_vec((actual_features, num_preds), output.to_vec())
            .unwrap_or_else(|_| Array2::zeros((0, 0)));
        arr.t().to_owned()
    };
    if output_2d.is_empty() {
        return results;
    }
    // Shadow num_classes with the count implied by the actual feature width.
    let derived_classes = actual_features.saturating_sub(5);
    let num_classes = derived_classes.max(1);
    let mut candidates: Vec<([f32; 5], f32, usize)> = Vec::new();
    for i in 0..num_preds {
        // Best class among the score columns; NaN scores are treated as 0.
        let class_scores = output_2d.slice(s![i, 4..4 + num_classes]);
        let (best_class, best_score) = class_scores
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Less))
            .map_or((0, 0.0), |(idx, &score)| {
                (idx, if score.is_nan() { 0.0 } else { score })
            });
        if best_score < config.confidence_threshold {
            continue;
        }
        let cx = output_2d[[i, 0]];
        let cy = output_2d[[i, 1]];
        let w = output_2d[[i, 2]];
        let h = output_2d[[i, 3]];
        // Angle is the last feature, right after the class scores.
        let angle = output_2d[[i, 4 + num_classes]];
        // Map the center through scale_coords (padding + scale) and divide
        // the extents by the axis scales; preprocess.scale is (y, x).
        let scaled = scale_coords(&[cx, cy, cx, cy], preprocess.scale, preprocess.padding);
        let scaled_cx = scaled[0];
        let scaled_cy = scaled[1];
        let scaled_w = w / preprocess.scale.1;
        let scaled_h = h / preprocess.scale.0;
        let (oh, ow) = preprocess.orig_shape;
        #[allow(clippy::cast_precision_loss)]
        let clipped_cx = scaled_cx.max(0.0).min(ow as f32);
        #[allow(clippy::cast_precision_loss)]
        let clipped_cy = scaled_cy.max(0.0).min(oh as f32);
        if !config.keep_class(best_class) {
            continue;
        }
        candidates.push((
            [clipped_cx, clipped_cy, scaled_w, scaled_h, angle],
            best_score,
            best_class,
        ));
    }
    // No detections: still attach an empty OBB container.
    if candidates.is_empty() {
        results.obb = Some(Obb::new(Array2::zeros((0, 7)), preprocess.orig_shape));
        return results;
    }
    let keep_indices = nms_rotated_per_class(&candidates, config.iou_threshold);
    let num_kept = keep_indices.len().min(config.max_det);
    let mut obb_data = Array2::zeros((num_kept, 7));
    for (out_idx, &keep_idx) in keep_indices.iter().take(num_kept).enumerate() {
        let (xywhr, score, class) = &candidates[keep_idx];
        obb_data[[out_idx, 0]] = xywhr[0];
        obb_data[[out_idx, 1]] = xywhr[1];
        obb_data[[out_idx, 2]] = xywhr[2];
        obb_data[[out_idx, 3]] = xywhr[3];
        obb_data[[out_idx, 4]] = xywhr[4];
        obb_data[[out_idx, 5]] = *score;
        #[allow(clippy::cast_precision_loss)]
        let class_f32 = *class as f32;
        obb_data[[out_idx, 6]] = class_f32;
    }
    results.obb = Some(Obb::new(obb_data, preprocess.orig_shape));
    results
}
/// Maps letterboxed xyxy coordinates back to original-image space by
/// removing the padding offset, then dividing out the per-axis resize scale.
/// `preprocess.scale` is `(scale_y, scale_x)`; `preprocess.padding` is
/// `(pad_top, pad_left)`.
#[inline]
fn scale_xyxy(
    x1: f32,
    y1: f32,
    x2: f32,
    y2: f32,
    preprocess: &PreprocessResult,
) -> (f32, f32, f32, f32) {
    let (sy, sx) = preprocess.scale;
    let (ty, tx) = preprocess.padding;
    let unpad_x = |x: f32| (x - tx) / sx;
    let unpad_y = |y: f32| (y - ty) / sy;
    (unpad_x(x1), unpad_y(y1), unpad_x(x2), unpad_y(y2))
}
#[allow(clippy::too_many_arguments, clippy::cast_precision_loss)]
/// Decodes an end2end (NMS-free) detect head laid out as
/// `[batch, max_det, feats]`, where the first 6 features of each row are
/// `[x1, y1, x2, y2, conf, class]`. Rows are assumed confidence-sorted,
/// so the scan stops at the first row below the confidence threshold.
fn postprocess_detect_end2end(
    output: &[f32],
    output_shape: &[usize],
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
) -> Results {
    let mut res = Results::new(orig_img, path, names.clone(), speed, inference_shape);
    if output.is_empty() || output_shape.len() != 3 {
        return res;
    }
    let (max_det, feats) = (output_shape[1], output_shape[2]);
    if max_det == 0 || feats < 6 {
        return res;
    }
    let (oh, ow) = preprocess.orig_shape;
    let (w_limit, h_limit) = (ow as f32, oh as f32);
    // Honor both the model's and the user's detection caps.
    let cap = config.max_det.min(max_det);
    let mut rows: Vec<f32> = Vec::with_capacity(cap * 6);
    for i in 0..max_det {
        let det = &output[i * feats..(i + 1) * feats];
        let conf = det[4];
        // Confidence-sorted rows: nothing useful follows the first miss.
        if conf < config.confidence_threshold {
            break;
        }
        let cls = det[5] as usize;
        if !config.keep_class(cls) {
            continue;
        }
        // Undo letterboxing, then clamp into the original image.
        let (x1, y1, x2, y2) = scale_xyxy(det[0], det[1], det[2], det[3], preprocess);
        rows.extend_from_slice(&[
            x1.clamp(0.0, w_limit),
            y1.clamp(0.0, h_limit),
            x2.clamp(0.0, w_limit),
            y2.clamp(0.0, h_limit),
            conf,
            cls as f32,
        ]);
        if rows.len() >= cap * 6 {
            break;
        }
    }
    let n = rows.len() / 6;
    if n > 0 {
        let boxes = Array2::from_shape_vec((n, 6), rows).expect("flat length matches (n, 6)");
        res.boxes = Some(Boxes::new(boxes, preprocess.orig_shape));
    }
    res
}
#[allow(
    clippy::too_many_arguments,
    clippy::cast_precision_loss,
    clippy::too_many_lines,
    clippy::needless_pass_by_value,
    clippy::similar_names,
    clippy::manual_let_else
)]
/// Decodes an end2end (NMS-free) segmentation head.
///
/// Expects predictions at index 0 laid out `[batch, max_det, 6 + num_masks]`
/// (`[x1, y1, x2, y2, conf, class, coeffs...]`, rows assumed
/// confidence-sorted) and protos at index 1 (`[batch, num_masks, mh, mw]`).
/// Kept rows' coefficients are matrix-multiplied against the protos; each
/// low-res mask is sigmoid-activated, cropped to the letterboxed region,
/// resized to the original image size, and zeroed outside its detection box.
fn postprocess_segment_end2end(
    outputs: Vec<(&[f32], Vec<usize>)>,
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
) -> Results {
    let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);
    if outputs.len() < 2 {
        eprintln!(
            "WARNING ⚠️ End2end segmentation missing protos output (got {} outputs).",
            outputs.len()
        );
        return results;
    }
    let (output0, shape0) = &outputs[0];
    let (output1, shape1) = &outputs[1];
    if shape0.len() != 3 || shape1.len() != 4 {
        return results;
    }
    let max_det = shape0[1];
    let feats = shape0[2];
    // Mask-coefficient count comes from the protos channel dim.
    let num_masks = shape1[1];
    if feats < 6 + num_masks {
        eprintln!("WARNING ⚠️ End2end segment features ({feats}) < 6 + num_masks ({num_masks}).");
        return results;
    }
    let (oh, ow) = preprocess.orig_shape;
    let (max_w, max_h) = (ow as f32, oh as f32);
    // Honor both the model's and the user's detection caps.
    let user_cap = config.max_det.min(max_det);
    let mut flat_boxes: Vec<f32> = Vec::with_capacity(user_cap * 6);
    let mut flat_coeffs: Vec<f32> = Vec::with_capacity(user_cap * num_masks);
    for i in 0..max_det {
        let base = i * feats;
        let conf = output0[base + 4];
        // Confidence-sorted rows: stop at the first miss.
        if conf < config.confidence_threshold {
            break;
        }
        let cls = output0[base + 5] as usize;
        if !config.keep_class(cls) {
            continue;
        }
        // Undo letterboxing, then clamp into the original image.
        let (x1, y1, x2, y2) = scale_xyxy(
            output0[base],
            output0[base + 1],
            output0[base + 2],
            output0[base + 3],
            preprocess,
        );
        flat_boxes.extend_from_slice(&[
            x1.clamp(0.0, max_w),
            y1.clamp(0.0, max_h),
            x2.clamp(0.0, max_w),
            y2.clamp(0.0, max_h),
            conf,
            cls as f32,
        ]);
        // Mask coefficients follow the 6 box fields in the same row.
        let coeff_start = base + 6;
        flat_coeffs.extend_from_slice(&output0[coeff_start..coeff_start + num_masks]);
        if flat_boxes.len() >= user_cap * 6 {
            break;
        }
    }
    let num_kept = flat_boxes.len() / 6;
    if num_kept == 0 {
        return results;
    }
    let boxes_data =
        Array2::from_shape_vec((num_kept, 6), flat_boxes).expect("flat length matches (n, 6)");
    let mask_coeffs = Array2::from_shape_vec((num_kept, num_masks), flat_coeffs)
        .expect("flat length matches (n, num_masks)");
    let mh = shape1[2];
    let mw = shape1[3];
    // Protos flattened to (num_masks, mh*mw) so coeffs · protos gives one
    // flat low-res mask per kept detection.
    let protos = match Array2::from_shape_vec((num_masks, mh * mw), output1.to_vec()) {
        Ok(a) => a,
        Err(e) => {
            eprintln!("WARNING ⚠️ Failed to build protos array: {e}. Skipping masks.");
            return results;
        }
    };
    let masks_flat = mask_coeffs.dot(&protos);
    let (th, tw) = inference_shape;
    let (pad_top, pad_left) = preprocess.padding;
    // Letterbox padding expressed in proto-resolution pixels; the crop window
    // removes the padding band on each side (hence the 2x in the size).
    let scale_w = mw as f32 / tw as f32;
    let scale_h = mh as f32 / th as f32;
    let crop_x = pad_left * scale_w;
    let crop_y = pad_top * scale_h;
    let crop_w = 2.0f32.mul_add(-crop_x, mw as f32);
    let crop_h = 2.0f32.mul_add(-crop_y, mh as f32);
    let mut masks_data = Array3::zeros((num_kept, oh as usize, ow as usize));
    // Per-detection mask upscaling runs in parallel (one resizer per task).
    Zip::from(masks_data.outer_iter_mut())
        .and(masks_flat.outer_iter())
        .and(boxes_data.outer_iter())
        .par_for_each(
            |mut mask_out: ArrayViewMut2<f32>,
             mask_flat: ArrayView1<f32>,
             box_data: ArrayView1<f32>| {
                let mut resizer = Resizer::new();
                let resize_alg = ResizeAlg::Convolution(FilterType::Bilinear);
                // Sigmoid-activate the raw mask logits.
                let f32_data: Vec<f32> = mask_flat
                    .iter()
                    .map(|&v| 1.0 / (1.0 + (-v).exp()))
                    .collect();
                let src_bytes: &[u8] = bytemuck::cast_slice(&f32_data);
                let src_image = match Image::from_vec_u8(
                    mw as u32,
                    mh as u32,
                    src_bytes.to_vec(),
                    PixelType::F32,
                ) {
                    Ok(i) => i,
                    // Resize failures leave this mask all-zero (best effort).
                    Err(_) => return,
                };
                let mut dst_image = Image::new(ow, oh, PixelType::F32);
                // Clamp the crop window into the proto bounds before resizing.
                let options = ResizeOptions::new().resize_alg(resize_alg).crop(
                    f64::from(crop_x.max(0.0)),
                    f64::from(crop_y.max(0.0)),
                    f64::from(crop_w.max(1.0).min(mw as f32)),
                    f64::from(crop_h.max(1.0).min(mh as f32)),
                );
                if resizer
                    .resize(&src_image, &mut dst_image, &options)
                    .is_err()
                {
                    return;
                }
                let dst_bytes = dst_image.buffer();
                let dst_slice: &[f32] = bytemuck::cast_slice(dst_bytes);
                // Keep mask values only inside the detection box.
                let x1 = box_data[0].max(0.0).min(ow as f32);
                let y1 = box_data[1].max(0.0).min(oh as f32);
                let x2 = box_data[2].max(0.0).min(ow as f32);
                let y2 = box_data[3].max(0.0).min(oh as f32);
                for y in 0..oh as usize {
                    for x in 0..ow as usize {
                        let val = dst_slice[y * ow as usize + x];
                        let xf = x as f32;
                        let yf = y as f32;
                        if xf >= x1 && xf <= x2 && yf >= y1 && yf <= y2 {
                            mask_out[[y, x]] = val;
                        }
                    }
                }
            },
        );
    results.boxes = Some(Boxes::new(boxes_data, preprocess.orig_shape));
    results.masks = Some(Masks::new(masks_data, preprocess.orig_shape));
    results
}
#[allow(
    clippy::too_many_arguments,
    clippy::cast_precision_loss,
    clippy::similar_names
)]
/// Decodes an end-to-end (NMS-free) pose head output into `Results`.
///
/// Expected layout is `[1, max_det, 4 box + conf + cls + nk * kpt_dim]` with
/// coordinates in letterboxed inference space. Boxes and keypoints are mapped
/// back to the original image and clamped to its bounds. Keypoints are
/// attached even when no detection survives (an empty `(0, nk, 3)` array),
/// while `boxes` is only set when at least one row is kept.
fn postprocess_pose_end2end(
    output: &[f32],
    output_shape: &[usize],
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
    nk: usize,
    kpt_dim: usize,
) -> Results {
    let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);
    // Reject malformed tensors and degenerate keypoint configurations up front.
    if output_shape.len() != 3 || output.is_empty() || nk == 0 || kpt_dim < 2 {
        return results;
    }
    let max_det = output_shape[1];
    let feats = output_shape[2];
    // Each row must be wide enough for box(4) + conf + cls + all keypoints;
    // this also guarantees every index used in the loop below is in bounds.
    if max_det == 0 || feats < 6 + nk * kpt_dim {
        return results;
    }
    let (oh, ow) = preprocess.orig_shape;
    let (max_w, max_h) = (ow as f32, oh as f32);
    let (scale_y, scale_x) = preprocess.scale;
    let (pad_top, pad_left) = preprocess.padding;
    let cap = config.max_det.min(max_det);
    let mut box_buf: Vec<f32> = Vec::with_capacity(cap * 6);
    let mut kpt_buf: Vec<f32> = Vec::with_capacity(cap * nk * 3);
    for det in 0..max_det {
        let row = det * feats;
        let conf = output[row + 4];
        // NOTE(review): the early break assumes rows are sorted by confidence,
        // as end2end exports conventionally are — confirm against the exporter.
        if conf < config.confidence_threshold {
            break;
        }
        let cls = output[row + 5] as usize;
        if !config.keep_class(cls) {
            continue;
        }
        // Undo the letterbox transform for the box corners.
        let (x1, y1, x2, y2) = scale_xyxy(
            output[row],
            output[row + 1],
            output[row + 2],
            output[row + 3],
            preprocess,
        );
        box_buf.extend_from_slice(&[
            x1.clamp(0.0, max_w),
            y1.clamp(0.0, max_h),
            x2.clamp(0.0, max_w),
            y2.clamp(0.0, max_h),
            conf,
            cls as f32,
        ]);
        // Keypoints follow the six box fields; every point is normalized to
        // (x, y, conf), synthesizing conf = 1.0 when the model omits it.
        let kpt_base = row + 6;
        for k in 0..nk {
            let off = kpt_base + k * kpt_dim;
            let kx = (output[off] - pad_left) / scale_x;
            let ky = (output[off + 1] - pad_top) / scale_y;
            let kc = if kpt_dim >= 3 { output[off + 2] } else { 1.0 };
            kpt_buf.extend_from_slice(&[kx.clamp(0.0, max_w), ky.clamp(0.0, max_h), kc]);
        }
        if box_buf.len() >= cap * 6 {
            break;
        }
    }
    let n = box_buf.len() / 6;
    // kpt_buf always holds exactly n * nk * 3 values, so this cannot fail.
    let kdata =
        Array3::from_shape_vec((n, nk, 3), kpt_buf).expect("flat length matches (n, nk, 3)");
    results.keypoints = Some(Keypoints::new(kdata, preprocess.orig_shape));
    if n > 0 {
        let bdata =
            Array2::from_shape_vec((n, 6), box_buf).expect("flat length matches (n, 6)");
        results.boxes = Some(Boxes::new(bdata, preprocess.orig_shape));
    }
    results
}
#[allow(clippy::too_many_arguments, clippy::cast_precision_loss)]
/// Decodes an end-to-end (NMS-free) OBB head output into `Results`.
///
/// Expected layout is `[1, max_det, >=7]` where each row is
/// `(cx, cy, w, h, conf, cls, angle)` in letterboxed inference space. The
/// center is mapped back onto the original image and clamped; width/height
/// are rescaled only; the angle passes through untouched. An `Obb` is always
/// attached, even when empty, so `results.obb` is reliably `Some`.
fn postprocess_obb_end2end(
    output: &[f32],
    output_shape: &[usize],
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
) -> Results {
    let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);
    // Output rows are collected as flat (cx, cy, w, h, angle, conf, cls) tuples.
    let mut rows: Vec<f32> = Vec::new();
    if output_shape.len() == 3 && !output.is_empty() {
        let (max_det, feats) = (output_shape[1], output_shape[2]);
        if feats >= 7 && max_det > 0 {
            let (oh, ow) = preprocess.orig_shape;
            let (max_w, max_h) = (ow as f32, oh as f32);
            let (scale_y, scale_x) = preprocess.scale;
            let (pad_top, pad_left) = preprocess.padding;
            let cap = config.max_det.min(max_det);
            rows.reserve(cap * 7);
            for det in 0..max_det {
                let base = det * feats;
                let conf = output[base + 4];
                // NOTE(review): early break assumes the exporter emits rows
                // sorted by confidence — matches the other end2end decoders.
                if conf < config.confidence_threshold {
                    break;
                }
                let cls = output[base + 5] as usize;
                if !config.keep_class(cls) {
                    continue;
                }
                // Remove letterbox padding before rescaling the center.
                let cx = (output[base] - pad_left) / scale_x;
                let cy = (output[base + 1] - pad_top) / scale_y;
                rows.extend_from_slice(&[
                    cx.clamp(0.0, max_w),
                    cy.clamp(0.0, max_h),
                    output[base + 2] / scale_x,
                    output[base + 3] / scale_y,
                    output[base + 6],
                    conf,
                    cls as f32,
                ]);
                if rows.len() >= cap * 7 {
                    break;
                }
            }
        }
    }
    let n = rows.len() / 7;
    // rows always holds a multiple of 7 values, so this cannot fail.
    let obb_data = Array2::from_shape_vec((n, 7), rows).expect("flat length matches (n, 7)");
    results.obb = Some(Obb::new(obb_data, preprocess.orig_shape));
    results
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_detect_shape() {
        // Channel-major export: [1, feats, preds] → not transposed.
        let (nc, np, transposed) = parse_detect_shape(&[1, 84, 8400], 80);
        assert_eq!(nc, 80);
        assert_eq!(np, 8400);
        assert!(!transposed);
        // Prediction-major export: [1, preds, feats] → transposed.
        let (nc, np, transposed) = parse_detect_shape(&[1, 8400, 84], 80);
        assert_eq!(nc, 80);
        assert_eq!(np, 8400);
        assert!(transposed);
    }

    #[test]
    fn test_infer_end2end_kpt_shape() {
        // 6 + 17*3 and 6 + 17*2 are recognizable COCO keypoint row widths.
        assert_eq!(infer_end2end_kpt_shape(&[1, 300, 6 + 51]), Some((17, 3)));
        assert_eq!(infer_end2end_kpt_shape(&[1, 300, 6 + 34]), Some((17, 2)));
        // A width that fits no known keypoint layout must be rejected...
        assert_eq!(infer_end2end_kpt_shape(&[1, 300, 6 + 36]), None);
        // ...as must a raw (non-end2end) pose shape and a boxes-only row.
        assert_eq!(infer_end2end_kpt_shape(&[1, 56, 8400]), None);
        assert_eq!(infer_end2end_kpt_shape(&[1, 300, 6]), None);
    }

    #[test]
    fn test_parse_detect_shape_no_metadata() {
        // With no class-count metadata (0), the count is inferred from shape.
        let (nc, np, transposed) = parse_detect_shape(&[1, 84, 8400], 0);
        assert_eq!(nc, 80);
        assert_eq!(np, 8400);
        assert!(!transposed);
        let (nc, np, transposed) = parse_detect_shape(&[1, 8400, 84], 0);
        assert_eq!(nc, 80);
        assert_eq!(np, 8400);
        assert!(transposed);
    }

    #[test]
    fn test_empty_output() {
        // Zero predictions must produce an empty (but valid) Results.
        let output: Vec<f32> = vec![];
        let preprocess = PreprocessResult {
            tensor: ndarray::Array4::zeros((1, 3, 640, 640)),
            tensor_f16: None,
            orig_shape: (480, 640),
            scale: (1.0, 1.0),
            padding: (0.0, 0.0),
        };
        let config = InferenceConfig::default();
        let names = HashMap::new();
        let orig_img = ndarray::Array3::zeros((480, 640, 3));
        let results = postprocess_detect(
            &output,
            &[1, 84, 0],
            &preprocess,
            &config,
            &names,
            orig_img,
            String::new(),
            Speed::default(),
            (640, 640),
        );
        assert!(results.is_empty());
    }

    #[test]
    fn test_nan_scores_handled() {
        // A single prediction whose score is NaN: decoding must not panic.
        let mut output: Vec<f32> = vec![0.0; 84];
        output[0] = 100.0;
        output[1] = 100.0;
        output[2] = 50.0;
        output[3] = 50.0;
        output[4] = f32::NAN;
        output[5] = 0.9;
        let preprocess = PreprocessResult {
            tensor: ndarray::Array4::zeros((1, 3, 640, 640)),
            tensor_f16: None,
            orig_shape: (640, 640),
            scale: (1.0, 1.0),
            padding: (0.0, 0.0),
        };
        let config = InferenceConfig::default();
        let mut names = HashMap::new();
        names.insert(0, "class0".to_string());
        names.insert(1, "class1".to_string());
        let orig_img = ndarray::Array3::zeros((640, 640, 3));
        let results = postprocess_detect(
            &output,
            &[1, 84, 1],
            &preprocess,
            &config,
            &names,
            orig_img,
            String::new(),
            Speed::default(),
            (640, 640),
        );
        // Reaching this point without a panic is the whole assertion.
        let _ = results;
    }

    #[test]
    fn test_malformed_shape_fallback() {
        // Both an empty shape and a 1-D shape must degrade to empty Results.
        let output: Vec<f32> = vec![0.0; 100];
        let preprocess = PreprocessResult {
            tensor: ndarray::Array4::zeros((1, 3, 640, 640)),
            tensor_f16: None,
            orig_shape: (640, 640),
            scale: (1.0, 1.0),
            padding: (0.0, 0.0),
        };
        let config = InferenceConfig::default();
        let names = HashMap::new();
        let orig_img = ndarray::Array3::zeros((640, 640, 3));
        let results = postprocess_detect(
            &output,
            &[],
            &preprocess,
            &config,
            &names,
            orig_img.clone(),
            String::new(),
            Speed::default(),
            (640, 640),
        );
        assert!(results.is_empty());
        let results = postprocess_detect(
            &output,
            &[100],
            &preprocess,
            &config,
            &names,
            orig_img,
            String::new(),
            Speed::default(),
            (640, 640),
        );
        assert!(results.is_empty());
    }

    #[test]
    fn test_postprocess_pose_logic() {
        let preds = 100;
        let feats = 56; // 4 box + 1 conf + 17 * 3 keypoint values
        // Channel-major layout: feature f of prediction 0 lives at f * preds.
        let mut output = vec![0.0; preds * feats];
        output[0] = 100.0; // cx
        output[preds] = 100.0; // cy
        output[preds * 2] = 50.0; // w
        output[preds * 3] = 50.0; // h
        output[preds * 4] = 0.9; // confidence
        for k in 0..17 {
            let f = 5 + k * 3;
            output[preds * f] = 100.0;
            output[preds * (f + 1)] = 100.0;
            output[preds * (f + 2)] = 0.8;
        }
        let preprocess = PreprocessResult {
            tensor: ndarray::Array4::zeros((1, 3, 640, 640)),
            tensor_f16: None,
            orig_shape: (640, 640),
            scale: (1.0, 1.0),
            padding: (0.0, 0.0),
        };
        let config = InferenceConfig::default();
        let mut names = HashMap::new();
        names.insert(0, "person".to_string());
        let results = postprocess_pose(
            &output,
            &[1, feats, preds],
            &preprocess,
            &config,
            &names,
            ndarray::Array3::zeros((640, 640, 3)),
            "test.jpg".to_string(),
            Speed::default(),
            (640, 640),
        );
        assert!(results.keypoints.is_some());
        let kpts = results.keypoints.unwrap();
        // Exactly one detection, 17 keypoints, (x, y, conf) per keypoint.
        assert_eq!(kpts.data.shape()[0], 1);
        assert_eq!(kpts.data.shape()[1], 17);
        assert_eq!(kpts.data.shape()[2], 3);
        #[allow(clippy::float_cmp)]
        {
            assert_eq!(kpts.data[[0, 0, 0]], 100.0);
            assert_eq!(kpts.data[[0, 0, 2]], 0.8);
        }
    }

    #[test]
    fn test_postprocess_obb_logic() {
        let preds = 100;
        let feats = 6;
        // Channel-major layout: cx, cy, w, h, conf, angle for prediction 0.
        let mut output = vec![0.0; preds * feats];
        output[0] = 100.0;
        output[preds] = 100.0;
        output[preds * 2] = 50.0;
        output[preds * 3] = 20.0;
        output[preds * 4] = 0.95;
        output[preds * 5] = std::f32::consts::FRAC_PI_4;
        let preprocess = PreprocessResult {
            tensor: ndarray::Array4::zeros((1, 3, 640, 640)),
            tensor_f16: None,
            orig_shape: (640, 640),
            scale: (1.0, 1.0),
            padding: (0.0, 0.0),
        };
        let config = InferenceConfig::default();
        let mut names = HashMap::new();
        names.insert(0, "object".to_string());
        let results = postprocess_obb(
            &output,
            &[1, feats, preds],
            &preprocess,
            &config,
            &names,
            ndarray::Array3::zeros((640, 640, 3)),
            "test.jpg".to_string(),
            Speed::default(),
            (640, 640),
        );
        assert!(results.obb.is_some());
        let obb = results.obb.unwrap();
        assert_eq!(obb.len(), 1);
        let data = obb.data.row(0);
        #[allow(clippy::float_cmp)]
        {
            assert_eq!(data[0], 100.0);
            assert_eq!(data[4], std::f32::consts::FRAC_PI_4);
            assert_eq!(data[5], 0.95);
        }
    }
}