1#![allow(
9 unsafe_code,
10 clippy::doc_markdown,
11 clippy::too_many_lines,
12 clippy::if_not_else,
13 clippy::ptr_as_ptr,
14 clippy::cast_possible_truncation,
15 clippy::cast_sign_loss
16)]
17
18use std::collections::HashMap;
19
20use wide::{CmpGt, f32x8};
21
22use fast_image_resize::images::Image;
23use fast_image_resize::{FilterType, PixelType, ResizeAlg, ResizeOptions, Resizer};
24use ndarray::{Array2, Array3, ArrayView1, ArrayViewMut2, Zip, s};
25
26use crate::inference::InferenceConfig;
27use crate::preprocessing::{PreprocessResult, clip_coords, scale_coords};
28use crate::results::{Boxes, Keypoints, Masks, Obb, Probs, Results, Speed};
29use crate::task::Task;
30use crate::utils::{nms_per_class, nms_rotated_per_class};
31
/// Routes raw model outputs to the task-specific decoder and returns the
/// assembled `Results`.
///
/// `outputs` holds `(flat_tensor, shape)` pairs from the inference backend;
/// index 0 is always the primary prediction tensor and, for segmentation,
/// index 1 carries the mask prototypes. For Detect/Segment/Pose/Obb the
/// tensor shape is probed with the `is_end2end_*_shape` helpers so NMS-free
/// ("end2end") exports are decoded directly; `end2end = true` forces that
/// path regardless of shape. `kpt_shape` optionally fixes
/// `(num_keypoints, values_per_keypoint)` for pose models; when absent it is
/// inferred from the output shape.
#[must_use]
#[allow(
    clippy::too_many_arguments,
    clippy::similar_names,
    clippy::implicit_hasher
)]
pub fn postprocess(
    outputs: Vec<(&[f32], Vec<usize>)>,
    task: Task,
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
    end2end: bool,
    kpt_shape: Option<(usize, usize)>,
) -> Results {
    match task {
        Task::Detect => {
            let (output, shape) = &outputs[0];
            // A `(batch, dets, 6)` tensor means boxes are already final.
            if end2end || is_end2end_detect_shape(shape) {
                postprocess_detect_end2end(
                    output,
                    shape,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                )
            } else {
                postprocess_detect(
                    output,
                    shape,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                )
            }
        }
        Task::Segment => {
            // Prototype channel count (shape `(1, nm, h, w)`) disambiguates
            // the end2end segment layout `(batch, dets, 6 + nm)`.
            let proto_channels = outputs
                .get(1)
                .and_then(|(_, s)| if s.len() == 4 { Some(s[1]) } else { None });
            if end2end || is_end2end_segment_shape(&outputs[0].1, proto_channels) {
                postprocess_segment_end2end(
                    outputs,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                )
            } else {
                postprocess_segment(
                    outputs,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                )
            }
        }
        Task::Pose => {
            let (output, shape) = &outputs[0];
            // Prefer the caller-supplied keypoint layout; otherwise try to
            // derive it from the feature width.
            let resolved_kpt = kpt_shape.or_else(|| infer_end2end_kpt_shape(shape));
            let is_end2end = end2end
                || resolved_kpt.is_some_and(|(nk, kd)| is_end2end_pose_shape(shape, nk, kd));
            if is_end2end {
                // Fall back to the COCO convention (17 keypoints × [x, y, conf]).
                let (nk, kpt_dim) = resolved_kpt.unwrap_or((17, 3));
                postprocess_pose_end2end(
                    output,
                    shape,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                    nk,
                    kpt_dim,
                )
            } else {
                postprocess_pose(
                    output,
                    shape,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                )
            }
        }
        Task::Classify => {
            // Classification needs no geometry remapping, so the preprocess
            // metadata and config thresholds are not consulted.
            let (output, _) = &outputs[0];
            postprocess_classify(output, names, orig_img, path, speed, inference_shape)
        }
        Task::Obb => {
            let (output, shape) = &outputs[0];
            // `(batch, dets, 7)` rows are final rotated boxes.
            if end2end || is_end2end_obb_shape(shape) {
                postprocess_obb_end2end(
                    output,
                    shape,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                )
            } else {
                postprocess_obb(
                    output,
                    shape,
                    preprocess,
                    config,
                    names,
                    orig_img,
                    path,
                    speed,
                    inference_shape,
                )
            }
        }
    }
}
197
/// Heuristic: a `(batch, dets, 6)` tensor with a plausibly small detection
/// count (≤ 4096) is treated as an NMS-free end2end detect output, where
/// each row is `[x1, y1, x2, y2, conf, class]`.
fn is_end2end_detect_shape(shape: &[usize]) -> bool {
    matches!(shape, [_, dets, 6] if *dets <= 4096)
}
202
/// Heuristic: an end2end segment output is `(batch, dets, 6 + nm)`, where
/// `nm` is the prototype channel count taken from the second output tensor.
/// Without a prototype tensor the layout cannot be end2end segmentation.
fn is_end2end_segment_shape(shape: &[usize], proto_channels: Option<usize>) -> bool {
    let Some(nm) = proto_channels else {
        return false;
    };
    match shape {
        [_, dets, feats] => *feats == 6 + nm && *dets <= 4096,
        _ => false,
    }
}
211
/// Heuristic: an end2end pose output is `(batch, dets, 6 + nk * kpt_dim)` —
/// six box/score/class values followed by `nk` keypoints of `kpt_dim`
/// components each.
fn is_end2end_pose_shape(shape: &[usize], nk: usize, kpt_dim: usize) -> bool {
    matches!(shape, [_, dets, feats] if *feats == 6 + nk * kpt_dim && *dets <= 4096)
}
216
/// Attempts to derive `(num_keypoints, values_per_keypoint)` from an
/// end2end pose output shape `(batch, dets, 6 + extra)`.
///
/// The `extra` features beyond the 6 box/score/class values must divide
/// evenly by exactly one of 3 (`[x, y, conf]` keypoints) or 2 (`[x, y]`
/// keypoints). A width divisible by both (or neither) is ambiguous, so
/// `None` is returned and the caller falls back to its default.
fn infer_end2end_kpt_shape(shape: &[usize]) -> Option<(usize, usize)> {
    let [_, dets, feats] = shape else {
        return None;
    };
    if *dets == 0 || *dets > 4096 || *feats <= 6 {
        return None;
    }
    let extra = *feats - 6;
    match (extra % 3, extra % 2) {
        // Divisible by 3 only: (x, y, conf) triplets.
        (0, rem2) if rem2 != 0 => Some((extra / 3, 3)),
        // Divisible by 2 only: (x, y) pairs.
        (rem3, 0) if rem3 != 0 => Some((extra / 2, 2)),
        // Divisible by both (ambiguous) or by neither.
        _ => None,
    }
}
238
/// Heuristic: a `(batch, dets, 7)` tensor with a small detection count is an
/// NMS-free end2end OBB output (`[x, y, w, h, angle? — layout per exporter,
/// conf, class]` rows).
fn is_end2end_obb_shape(shape: &[usize]) -> bool {
    matches!(shape, [_, dets, 7] if *dets <= 4096)
}
243
244#[allow(
248 clippy::too_many_arguments,
249 clippy::similar_names,
250 clippy::cast_precision_loss
251)]
252fn postprocess_detect(
253 output: &[f32],
254 output_shape: &[usize],
255 preprocess: &PreprocessResult,
256 config: &InferenceConfig,
257 names: &HashMap<usize, String>,
258 orig_img: Array3<u8>,
259 path: String,
260 speed: Speed,
261 inference_shape: (u32, u32),
262) -> Results {
263 let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);
264
265 let (num_classes, num_predictions, is_transposed) =
267 parse_detect_shape(output_shape, names.len());
268
269 if output.is_empty() || num_predictions == 0 {
270 return results;
271 }
272
273 let boxes_data = extract_detect_boxes(
275 output,
276 num_classes,
277 num_predictions,
278 is_transposed,
279 preprocess,
280 config,
281 );
282
283 if !boxes_data.is_empty() {
284 results.boxes = Some(Boxes::new(boxes_data, preprocess.orig_shape));
285 }
286
287 results
288}
289
/// Infers `(num_classes, num_predictions, is_transposed)` from a raw detect
/// output shape.
///
/// Detect heads emit either `(4 + nc, num_preds)` (feature-major,
/// `is_transposed == false`) or `(num_preds, 4 + nc)` (prediction-major,
/// `is_transposed == true`), optionally preceded by a batch dimension
/// (3-D shapes use axes 1 and 2). `expected_classes` — normally
/// `names.len()` — is matched exactly against either axis first; when it is
/// 0 or matches neither axis, the smaller axis is assumed to be the feature
/// axis.
///
/// Returns `num_predictions == 0` for shapes that cannot hold detections;
/// callers treat that as an empty output.
fn parse_detect_shape(shape: &[usize], expected_classes: usize) -> (usize, usize, bool) {
    match shape.len() {
        2 => {
            let (a, b) = (shape[0], shape[1]);
            // Neither axis can hold the 4 box coordinates.
            if a < 4 && b < 4 {
                return (expected_classes.max(1), 0, false);
            }
            if expected_classes == 0 {
                // No class-count hint: smaller axis is the feature axis.
                let (num_features, num_preds, transposed) =
                    if a < b { (a, b, false) } else { (b, a, true) };
                let inferred_classes = num_features.saturating_sub(4);
                return (inferred_classes.max(1), num_preds, transposed);
            }
            // Prefer an exact match of `4 + expected_classes` on either axis;
            // otherwise fall back to smaller-axis-is-features. This mirrors
            // the 3-D arm below — previously a `(num_preds, features)` tensor
            // such as `[8400, 84]` was mis-read as feature-major, producing a
            // bogus class count.
            if a == 4 + expected_classes || (b != 4 + expected_classes && a < b) {
                (a.saturating_sub(4), b, false)
            } else {
                (b.saturating_sub(4), a, true)
            }
        }
        3 => {
            let (a, b) = (shape[1], shape[2]);
            // Degenerate inner dimensions cannot hold detections.
            if b == 0 || a < 4 {
                return (expected_classes.max(1), 0, false);
            }
            if expected_classes == 0 {
                let (num_features, num_preds, transposed) =
                    if a < b { (a, b, false) } else { (b, a, true) };
                let inferred_classes = num_features.saturating_sub(4);
                return (inferred_classes.max(1), num_preds, transposed);
            }
            // Same preference order as the 2-D arm: exact match, then the
            // smaller axis as features.
            if a == 4 + expected_classes || (b != 4 + expected_classes && a < b) {
                (a.saturating_sub(4), b, false)
            } else {
                (b.saturating_sub(4), a, true)
            }
        }
        // 1-D or >3-D shapes are not recognized detect layouts.
        _ => (expected_classes.max(1), 0, false),
    }
}
348
/// A single detection that survived the confidence filter, prior to NMS.
#[derive(Clone, Copy)]
struct Candidate {
    /// Corner-format box `[x1, y1, x2, y2]` already mapped back to
    /// original-image coordinates (not yet clamped to image bounds).
    bbox: [f32; 4],
    /// Confidence of the best-scoring class.
    score: f32,
    /// Index of the best-scoring class.
    class: usize,
}
355
/// Decodes a raw detect head into an `(n, 6)` array of
/// `[x1, y1, x2, y2, score, class]` rows in original-image coordinates,
/// applying confidence filtering, class filtering, and SIMD-accelerated
/// greedy per-class NMS capped at `config.max_det`.
///
/// Layout contract: `is_transposed == false` means feature-major data
/// (`[4 + nc, preds]`, each feature plane contiguous); `true` means
/// prediction-major (`[preds, 4 + nc]`, each prediction row contiguous).
/// Assumes `output.len() >= (4 + num_classes) * num_predictions` — the
/// unchecked indexing below relies on it (derived from the shape by the
/// caller; TODO confirm for truncated buffers).
#[allow(clippy::cast_precision_loss, clippy::too_many_arguments)]
fn extract_detect_boxes(
    output: &[f32],
    num_classes: usize,
    num_predictions: usize,
    is_transposed: bool,
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
) -> Array2<f32> {
    let feat_count = 4 + num_classes;
    // Letterbox parameters used to map boxes back to the original image.
    let (scale_y, scale_x) = preprocess.scale;
    let (pad_top, pad_left) = preprocess.padding;
    let orig_shape = preprocess.orig_shape;
    let (max_w, max_h) = (orig_shape.1 as f32, orig_shape.0 as f32);
    let conf_thresh = config.confidence_threshold;
    let max_det = config.max_det;
    let iou_thresh = config.iou_threshold;
    let conf_v = f32x8::splat(conf_thresh);

    let mut candidates: Vec<Candidate> = Vec::with_capacity(256);

    if !is_transposed {
        // Feature-major: sweep each class plane once, tracking the running
        // per-prediction best score. Seeding with the threshold means any
        // updated entry is automatically above it.
        let mut max_scores = vec![conf_thresh; num_predictions];
        let mut max_classes = vec![0usize; num_predictions];

        for c in 0..num_classes {
            // Class plane `c` starts after the 4 coordinate planes.
            let offset = (4 + c) * num_predictions;
            let class_scores = &output[offset..offset + num_predictions];
            for (idx, &score) in class_scores.iter().enumerate() {
                if score > max_scores[idx] {
                    max_scores[idx] = score;
                    max_classes[idx] = c;
                }
            }
        }

        for (idx, &score) in max_scores.iter().enumerate() {
            // Strictly greater: entries still at the seed value are dropped.
            if score > conf_thresh {
                let best_class = max_classes[idx];

                if !config.keep_class(best_class) {
                    continue;
                }

                // SAFETY: idx < num_predictions and the feature-major layout
                // provides 4 coordinate planes of num_predictions elements
                // each, so all four indices are < 4 * num_predictions
                // <= output.len() (see layout contract above).
                let cx = unsafe { *output.get_unchecked(idx) };
                let cy = unsafe { *output.get_unchecked(num_predictions + idx) };
                let w = unsafe { *output.get_unchecked(2 * num_predictions + idx) };
                let h = unsafe { *output.get_unchecked(3 * num_predictions + idx) };

                // cxcywh -> xyxy, then undo letterbox padding and scale.
                let x1 = (cx - w * 0.5 - pad_left) / scale_x;
                let y1 = (cy - h * 0.5 - pad_top) / scale_y;
                let x2 = (cx + w * 0.5 - pad_left) / scale_x;
                let y2 = (cy + h * 0.5 - pad_top) / scale_y;

                candidates.push(Candidate {
                    bbox: [x1, y1, x2, y2],
                    score,
                    class: best_class,
                });
            }
        }
    } else {
        // Prediction-major: scan each row's class scores 8 lanes at a time.
        for idx in 0..num_predictions {
            let base = idx * feat_count;
            // SAFETY: base + 4 <= num_predictions * feat_count <= output.len()
            // per the layout contract; the pointer is only dereferenced at
            // class offsets < num_classes below.
            let row_ptr = unsafe { output.as_ptr().add(base + 4) };
            let mut best_score = conf_thresh;
            let mut best_class = 0;

            for c_idx in (0..num_classes).step_by(8) {
                if num_classes - c_idx >= 8 {
                    // SAFETY: at least 8 class scores remain from c_idx;
                    // read_unaligned has no alignment requirement.
                    let scores: f32x8 =
                        unsafe { (row_ptr.add(c_idx) as *const f32x8).read_unaligned() };
                    // Only fall back to the scalar argmax when some lane
                    // beats the threshold.
                    if scores.simd_gt(conf_v).any() {
                        for i in 0..8 {
                            let s = unsafe { *row_ptr.add(c_idx + i) };
                            if s > best_score {
                                best_score = s;
                                best_class = c_idx + i;
                            }
                        }
                    }
                } else {
                    // Scalar tail for the last < 8 classes.
                    for i in c_idx..num_classes {
                        let s = unsafe { *row_ptr.add(i) };
                        if s > best_score {
                            best_score = s;
                            best_class = i;
                        }
                    }
                }
            }

            if best_score > conf_thresh {
                if !config.keep_class(best_class) {
                    continue;
                }

                // SAFETY: base + 3 < (idx + 1) * feat_count <= output.len().
                let cx = unsafe { *output.get_unchecked(base) };
                let cy = unsafe { *output.get_unchecked(base + 1) };
                let w = unsafe { *output.get_unchecked(base + 2) };
                let h = unsafe { *output.get_unchecked(base + 3) };

                let x1 = (cx - w * 0.5 - pad_left) / scale_x;
                let y1 = (cy - h * 0.5 - pad_top) / scale_y;
                let x2 = (cx + w * 0.5 - pad_left) / scale_x;
                let y2 = (cy + h * 0.5 - pad_top) / scale_y;

                candidates.push(Candidate {
                    bbox: [x1, y1, x2, y2],
                    score: best_score,
                    class: best_class,
                });
            }
        }
    }

    if candidates.is_empty() {
        return Array2::zeros((0, 6));
    }

    // Cap the NMS working set at 10x max_det via partial selection (O(n))
    // before the full sort.
    let nms_limit = (max_det * 10).min(candidates.len());
    if candidates.len() > nms_limit {
        candidates.select_nth_unstable_by(nms_limit, |a, b| b.score.partial_cmp(&a.score).unwrap());
        candidates.truncate(nms_limit);
    }
    // Greedy NMS requires descending score order.
    candidates.sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap());

    // Structure-of-arrays copies so IoU can be computed 8 boxes at a time.
    let n = candidates.len();
    let mut x1 = Vec::with_capacity(n);
    let mut y1 = Vec::with_capacity(n);
    let mut x2 = Vec::with_capacity(n);
    let mut y2 = Vec::with_capacity(n);
    let mut areas = Vec::with_capacity(n);

    for c in &candidates {
        x1.push(c.bbox[0]);
        y1.push(c.bbox[1]);
        x2.push(c.bbox[2]);
        y2.push(c.bbox[3]);
        areas.push((c.bbox[2] - c.bbox[0]) * (c.bbox[3] - c.bbox[1]));
    }

    // Greedy per-class NMS: each kept box suppresses later same-class boxes
    // whose IoU exceeds the threshold.
    let mut suppressed = vec![false; n];
    let mut keep = Vec::with_capacity(max_det);
    let iou_v = f32x8::splat(iou_thresh);
    for i in 0..n {
        if suppressed[i] {
            continue;
        }
        keep.push(i);
        if keep.len() >= max_det {
            break;
        }

        // Broadcast the kept box so it can be compared to 8 rivals at once.
        let ax1 = f32x8::splat(x1[i]);
        let ay1 = f32x8::splat(y1[i]);
        let ax2 = f32x8::splat(x2[i]);
        let ay2 = f32x8::splat(y2[i]);
        let aa = f32x8::splat(areas[i]);
        let ac = candidates[i].class;

        let mut j = i + 1;
        while j < n {
            if n - j >= 8 {
                // Skip the SIMD IoU if no lane is a live same-class rival.
                if (0..8).any(|k| candidates[j + k].class == ac && !suppressed[j + k]) {
                    // SAFETY: n - j >= 8 guarantees 8 readable lanes in each
                    // parallel vector; read_unaligned needs no alignment.
                    let bx1 = unsafe { (x1.as_ptr().add(j) as *const f32x8).read_unaligned() };
                    let by1 = unsafe { (y1.as_ptr().add(j) as *const f32x8).read_unaligned() };
                    let bx2 = unsafe { (x2.as_ptr().add(j) as *const f32x8).read_unaligned() };
                    let by2 = unsafe { (y2.as_ptr().add(j) as *const f32x8).read_unaligned() };
                    let ba = unsafe { (areas.as_ptr().add(j) as *const f32x8).read_unaligned() };

                    // Intersection rectangle per lane.
                    let ix1 = ax1.max(bx1);
                    let iy1 = ay1.max(by1);
                    let ix2 = ax2.min(bx2);
                    let iy2 = ay2.min(by2);

                    let iw = (ix2 - ix1).max(f32x8::ZERO);
                    let ih = (iy2 - iy1).max(f32x8::ZERO);
                    let ia = iw * ih;
                    let iou = ia / (aa + ba - ia);

                    let mask = iou.simd_gt(iou_v).to_bitmask() as u8;
                    if mask != 0 {
                        // Suppress only same-class lanes above the threshold.
                        for k in 0..8 {
                            if (mask & (1 << k)) != 0 && candidates[j + k].class == ac {
                                suppressed[j + k] = true;
                            }
                        }
                    }
                }
                j += 8;
            } else {
                // Scalar tail: fewer than 8 rivals remain.
                for k in j..n {
                    if !suppressed[k] && candidates[k].class == ac {
                        let ix1 = x1[i].max(x1[k]);
                        let iy1 = y1[i].max(y1[k]);
                        let ix2 = x2[i].min(x2[k]);
                        let iy2 = y2[i].min(y2[k]);
                        let iw = (ix2 - ix1).max(0.0);
                        let ih = (iy2 - iy1).max(0.0);
                        let ia = iw * ih;
                        let iou = ia / (areas[i] + areas[k] - ia);
                        if iou > iou_thresh {
                            suppressed[k] = true;
                        }
                    }
                }
                break;
            }
        }
    }
    // Emit kept boxes, clamped to image bounds.
    let num_kept = keep.len();
    let mut result = Array2::zeros((num_kept, 6));
    for (out_idx, &idx) in keep.iter().enumerate() {
        let c = &candidates[idx];
        result[[out_idx, 0]] = c.bbox[0].clamp(0.0, max_w);
        result[[out_idx, 1]] = c.bbox[1].clamp(0.0, max_h);
        result[[out_idx, 2]] = c.bbox[2].clamp(0.0, max_w);
        result[[out_idx, 3]] = c.bbox[3].clamp(0.0, max_h);
        result[[out_idx, 4]] = c.score;
        result[[out_idx, 5]] = c.class as f32;
    }

    result
}
596
/// Decodes a classic segmentation head: detections in `outputs[0]`
/// (`4 + num_classes + 32` features per prediction) plus mask prototypes in
/// `outputs[1]` (`(1, 32, mask_h, mask_w)`).
///
/// Boxes are confidence-filtered, NMS'd and clipped; the kept mask
/// coefficients are multiplied against the prototypes, sigmoided,
/// un-letterboxed (crop), resized to the original image, and zeroed outside
/// each box. Missing/odd-shaped protos degrade gracefully to box-only
/// results with a warning on stderr.
#[allow(
    clippy::too_many_arguments,
    clippy::similar_names,
    clippy::cast_precision_loss,
    clippy::too_many_lines,
    clippy::needless_pass_by_value,
    clippy::manual_let_else,
    clippy::cast_possible_truncation
)]
fn postprocess_segment(
    outputs: Vec<(&[f32], Vec<usize>)>,
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
) -> Results {
    let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);

    if outputs.len() < 2 {
        eprintln!(
            "WARNING ⚠️ Segmentation model missing protos output (expected 2 outputs, got {}). Returning empty masks.",
            outputs.len()
        );
        return results;
    }

    let (output0, shape0) = &outputs[0];
    let (output1, shape1) = &outputs[1];

    // Standard YOLO segmentation heads emit 32 mask coefficients per box.
    // NOTE(review): hard-coded even when shape1[1] differs (warned below).
    let num_masks = 32;
    let expected_features = 4 + names.len() + num_masks;

    // Determine prediction count and layout (feature-major vs row-major).
    let (num_preds, is_transposed) = if shape0.len() == 3 {
        let (a, b) = (shape0[1], shape0[2]);
        if a == expected_features {
            (b, false)
        } else if b == expected_features {
            (a, true)
        } else {
            // Neither axis matches exactly: larger axis is the predictions.
            if a < b { (b, false) } else { (a, true) }
        }
    } else {
        (0, false)
    };

    if output0.is_empty() || num_preds == 0 {
        return results;
    }

    // Normalize to (num_preds, features) row-major for uniform indexing.
    let output_2d = if is_transposed {
        Array2::from_shape_vec((num_preds, expected_features), output0.to_vec())
            .unwrap_or_else(|_| Array2::zeros((0, 0)))
    } else {
        let arr = Array2::from_shape_vec((expected_features, num_preds), output0.to_vec())
            .unwrap_or_else(|_| Array2::zeros((0, 0)));
        arr.t().to_owned()
    };

    // Confidence + class filter; keep the source row index so the mask
    // coefficients of NMS survivors can be sliced out later.
    let mut candidates = Vec::new();
    for i in 0..num_preds {
        let scores = output_2d.slice(s![i, 4..4 + names.len()]);
        let (best_class, best_score) = scores
            .iter()
            .enumerate()
            .max_by(|&(_, a), &(_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map_or((0, 0.0), |(idx, &score)| (idx, score));

        if best_score < config.confidence_threshold {
            continue;
        }

        // cxcywh -> xyxy, then map back through the letterbox transform.
        let cx = output_2d[[i, 0]];
        let cy = output_2d[[i, 1]];
        let w = output_2d[[i, 2]];
        let h = output_2d[[i, 3]];
        let x1 = cx - w / 2.0;
        let y1 = cy - h / 2.0;
        let x2 = cx + w / 2.0;
        let y2 = cy + h / 2.0;

        let scaled = scale_coords(&[x1, y1, x2, y2], preprocess.scale, preprocess.padding);
        let clipped = clip_coords(&scaled, preprocess.orig_shape);

        if !config.keep_class(best_class) {
            continue;
        }

        candidates.push((
            [clipped[0], clipped[1], clipped[2], clipped[3]],
            best_score,
            best_class,
            i, // original row index for fetching mask coefficients
        ));
    }

    if candidates.is_empty() {
        return results;
    }

    // NMS operates on (bbox, score, class) only.
    let nms_candidates: Vec<_> = candidates
        .iter()
        .map(|(bbox, score, class, _)| (*bbox, *score, *class))
        .collect();

    let keep_indices = nms_per_class(&nms_candidates, config.iou_threshold);
    let num_kept = keep_indices.len().min(config.max_det);

    let mut boxes_data = Array2::zeros((num_kept, 6));
    let mut mask_coeffs = Array2::zeros((num_kept, num_masks));

    for (out_idx, &keep_idx) in keep_indices.iter().take(num_kept).enumerate() {
        let (bbox, score, class, orig_idx) = &candidates[keep_idx];
        boxes_data[[out_idx, 0]] = bbox[0];
        boxes_data[[out_idx, 1]] = bbox[1];
        boxes_data[[out_idx, 2]] = bbox[2];
        boxes_data[[out_idx, 3]] = bbox[3];
        boxes_data[[out_idx, 4]] = *score;
        boxes_data[[out_idx, 5]] = *class as f32;

        // Mask coefficients sit directly after the class scores.
        let start = 4 + names.len();
        let coeffs = output_2d.slice(s![*orig_idx, start..start + num_masks]);
        for m in 0..num_masks {
            mask_coeffs[[out_idx, m]] = coeffs[m];
        }
    }

    results.boxes = Some(Boxes::new(boxes_data.clone(), preprocess.orig_shape));

    // Prototype tensor is expected as (1, num_masks, mh, mw).
    if shape1.len() < 4 {
        eprintln!(
            "WARNING ⚠️ Protos output has unexpected shape (expected 4 dims, got {}). Skipping mask generation.",
            shape1.len()
        );
        return results;
    }
    let mh = shape1[2];
    let mw = shape1[3];

    if shape1[1] != num_masks {
        eprintln!(
            "WARNING ⚠️ Protos output has {} mask channels, expected {}. Mask quality may be affected.",
            shape1[1], num_masks
        );
    }

    let protos = match Array2::from_shape_vec((num_masks, mh * mw), output1.to_vec()) {
        Ok(arr) => arr,
        Err(e) => {
            eprintln!("WARNING ⚠️ Failed to create protos array: {e}. Skipping mask generation.");
            return results;
        }
    };

    // (num_kept, num_masks) · (num_masks, mh*mw) -> per-instance mask logits.
    let masks_flat = mask_coeffs.dot(&protos);

    let (oh, ow) = preprocess.orig_shape;
    let (th, tw) = inference_shape;
    let (pad_top, pad_left) = preprocess.padding;

    // The prototypes are at mask resolution; translate the letterbox padding
    // into mask-space crop offsets so only the real image region is upscaled.
    let scale_w = mw as f32 / tw as f32;
    let scale_h = mh as f32 / th as f32;
    let crop_x = pad_left * scale_w;
    let crop_y = pad_top * scale_h;
    // Symmetric padding: remove the crop offset from both sides.
    let crop_w = 2.0f32.mul_add(-crop_x, mw as f32);
    let crop_h = 2.0f32.mul_add(-crop_y, mh as f32);

    let mut masks_data = Array3::zeros((num_kept, oh as usize, ow as usize));

    // One mask per kept box, processed in parallel:
    // sigmoid -> crop+resize to original resolution -> zero outside the box.
    Zip::from(masks_data.outer_iter_mut())
        .and(masks_flat.outer_iter())
        .and(boxes_data.outer_iter())
        .par_for_each(
            |mut mask_out: ArrayViewMut2<f32>,
             mask_flat: ArrayView1<f32>,
             box_data: ArrayView1<f32>| {
                let mut resizer = Resizer::new();
                let resize_alg = ResizeAlg::Convolution(FilterType::Bilinear);

                // Sigmoid the logits before resizing.
                let f32_data: Vec<f32> = mask_flat
                    .iter()
                    .map(|&val| 1.0 / (1.0 + (-val).exp()))
                    .collect();

                let src_bytes: &[u8] = bytemuck::cast_slice(&f32_data);

                let src_image = match Image::from_vec_u8(
                    mw as u32,
                    mh as u32,
                    src_bytes.to_vec(),
                    PixelType::F32,
                ) {
                    Ok(img) => img,
                    Err(_) => return, // skip this mask on failure
                };

                let mut dst_image = Image::new(ow, oh, PixelType::F32);

                // Clamp the crop window to valid mask bounds.
                let safe_crop_x = f64::from(crop_x.max(0.0));
                let safe_crop_y = f64::from(crop_y.max(0.0));
                let safe_crop_w = f64::from(crop_w.max(1.0).min(mw as f32));
                let safe_crop_h = f64::from(crop_h.max(1.0).min(mh as f32));

                let options = ResizeOptions::new().resize_alg(resize_alg).crop(
                    safe_crop_x,
                    safe_crop_y,
                    safe_crop_w,
                    safe_crop_h,
                );

                if resizer
                    .resize(&src_image, &mut dst_image, &options)
                    .is_err()
                {
                    return;
                }

                let dst_bytes = dst_image.buffer();
                let dst_slice: &[f32] = bytemuck::cast_slice(dst_bytes);

                // Restrict the mask to its (clamped) bounding box.
                let x1 = box_data[0].max(0.0).min(ow as f32);
                let y1 = box_data[1].max(0.0).min(oh as f32);
                let x2 = box_data[2].max(0.0).min(ow as f32);
                let y2 = box_data[3].max(0.0).min(oh as f32);

                for y in 0..oh as usize {
                    for x in 0..ow as usize {
                        let val = dst_slice[y * ow as usize + x];
                        let x_f = x as f32;
                        let y_f = y as f32;
                        if x_f >= x1 && x_f <= x2 && y_f >= y1 && y_f <= y2 {
                            mask_out[[y, x]] = val;
                        }
                    }
                }
            },
        );

    results.masks = Some(Masks::new(masks_data, preprocess.orig_shape));

    results
}
900
/// Decodes a classic pose head (`4 + num_classes + 17 * 3` features per
/// prediction) into boxes plus per-instance keypoints.
///
/// Keypoint layout per prediction: 17 triplets `[x, y, conf]` following the
/// class scores (COCO convention). The class count is re-derived from the
/// flattened buffer length so heads whose feature width differs from
/// `names.len()` still decode.
#[allow(
    clippy::too_many_arguments,
    clippy::too_many_lines,
    clippy::similar_names,
    clippy::type_complexity,
    clippy::cast_precision_loss,
    clippy::doc_lazy_continuation
)]
fn postprocess_pose(
    output: &[f32],
    output_shape: &[usize],
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
) -> Results {
    let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);

    // COCO keypoint convention: 17 keypoints, (x, y, conf) each.
    let num_keypoints = 17;
    let kpt_dim = 3;
    let kpt_features = num_keypoints * kpt_dim;
    let num_classes = names.len().max(1);
    let expected_features = 4 + num_classes + kpt_features;

    // Determine prediction count and layout from the shape.
    let (num_preds, is_transposed) = if output_shape.len() == 3 {
        let (a, b) = (output_shape[1], output_shape[2]);
        if a == expected_features || (a < b && a >= 4 + kpt_features) {
            (b, false)
        } else {
            (a, true)
        }
    } else if output_shape.len() == 2 {
        let (a, b) = (output_shape[0], output_shape[1]);
        if a < b { (b, false) } else { (a, true) }
    } else {
        (0, false)
    };

    if output.is_empty() || num_preds == 0 {
        return results;
    }

    // Trust the real buffer length over the nominal shape for feature width.
    let actual_features = output.len() / num_preds;
    if actual_features < 4 + kpt_features {
        eprintln!(
            "WARNING ⚠️ Pose model has insufficient features ({actual_features}), expected at least {}",
            4 + kpt_features
        );
        return results;
    }

    // Normalize to (num_preds, features) row-major.
    let output_2d = if is_transposed {
        Array2::from_shape_vec((num_preds, actual_features), output.to_vec())
            .unwrap_or_else(|_| Array2::zeros((0, 0)))
    } else {
        let arr = Array2::from_shape_vec((actual_features, num_preds), output.to_vec())
            .unwrap_or_else(|_| Array2::zeros((0, 0)));
        arr.t().to_owned()
    };

    if output_2d.is_empty() {
        return results;
    }

    // Re-derive the class count from the real feature width (shadows the
    // names-based estimate above).
    let derived_classes = actual_features.saturating_sub(4 + kpt_features);
    let num_classes = derived_classes.max(1);

    // (bbox, score, class, keypoints) per surviving prediction.
    let mut candidates: Vec<([f32; 4], f32, usize, Vec<[f32; 3]>)> = Vec::new();

    for i in 0..num_preds {
        // Best class, treating NaN scores as 0.
        let class_scores = output_2d.slice(s![i, 4..4 + num_classes]);
        let (best_class, best_score) = class_scores
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Less))
            .map_or((0, 0.0), |(idx, &score)| {
                (idx, if score.is_nan() { 0.0 } else { score })
            });

        if best_score < config.confidence_threshold {
            continue;
        }

        // cxcywh -> xyxy in letterboxed coordinates.
        let cx = output_2d[[i, 0]];
        let cy = output_2d[[i, 1]];
        let w = output_2d[[i, 2]];
        let h = output_2d[[i, 3]];

        let x1 = cx - w / 2.0;
        let y1 = cy - h / 2.0;
        let x2 = cx + w / 2.0;
        let y2 = cy + h / 2.0;

        // Map back to original-image coordinates and clip.
        let scaled = scale_coords(&[x1, y1, x2, y2], preprocess.scale, preprocess.padding);
        let clipped = clip_coords(&scaled, preprocess.orig_shape);

        // Keypoints follow the class scores; un-letterbox each (x, y).
        let kpt_start = 4 + num_classes;
        let mut keypoints = Vec::with_capacity(num_keypoints);
        for k in 0..num_keypoints {
            let kpt_offset = kpt_start + k * kpt_dim;
            let kpt_x = output_2d[[i, kpt_offset]];
            let kpt_y = output_2d[[i, kpt_offset + 1]];
            let kpt_conf = output_2d[[i, kpt_offset + 2]];

            // Reuse scale_coords by treating the point as a degenerate box.
            let scaled_kpt = scale_coords(
                &[kpt_x, kpt_y, kpt_x, kpt_y],
                preprocess.scale,
                preprocess.padding,
            );
            let (oh, ow) = preprocess.orig_shape;
            #[allow(clippy::cast_precision_loss)]
            let scaled_x = scaled_kpt[0].max(0.0).min(ow as f32);
            #[allow(clippy::cast_precision_loss)]
            let scaled_y = scaled_kpt[1].max(0.0).min(oh as f32);

            keypoints.push([scaled_x, scaled_y, kpt_conf]);
        }

        if !config.keep_class(best_class) {
            continue;
        }

        candidates.push((
            [clipped[0], clipped[1], clipped[2], clipped[3]],
            best_score,
            best_class,
            keypoints,
        ));
    }

    if candidates.is_empty() {
        // Still attach an empty keypoints container so downstream consumers
        // see a pose result rather than a missing field.
        results.keypoints = Some(Keypoints::new(
            Array3::zeros((0, num_keypoints, kpt_dim)),
            preprocess.orig_shape,
        ));
        return results;
    }

    // NMS on (bbox, score, class) only; keypoints ride along by index.
    let nms_candidates: Vec<_> = candidates
        .iter()
        .map(|(bbox, score, class, _)| (*bbox, *score, *class))
        .collect();
    let keep_indices = nms_per_class(&nms_candidates, config.iou_threshold);
    let num_kept = keep_indices.len().min(config.max_det);

    let mut boxes_data = Array2::zeros((num_kept, 6));
    let mut keypoints_data = Array3::zeros((num_kept, num_keypoints, kpt_dim));

    for (out_idx, &keep_idx) in keep_indices.iter().take(num_kept).enumerate() {
        let (bbox, score, class, kpts) = &candidates[keep_idx];

        boxes_data[[out_idx, 0]] = bbox[0];
        boxes_data[[out_idx, 1]] = bbox[1];
        boxes_data[[out_idx, 2]] = bbox[2];
        boxes_data[[out_idx, 3]] = bbox[3];
        boxes_data[[out_idx, 4]] = *score;
        #[allow(clippy::cast_precision_loss)]
        let class_f32 = *class as f32;
        boxes_data[[out_idx, 5]] = class_f32;

        for (k, kpt) in kpts.iter().enumerate() {
            keypoints_data[[out_idx, k, 0]] = kpt[0];
            keypoints_data[[out_idx, k, 1]] = kpt[1];
            keypoints_data[[out_idx, k, 2]] = kpt[2];
        }
    }

    results.boxes = Some(Boxes::new(boxes_data, preprocess.orig_shape));
    results.keypoints = Some(Keypoints::new(keypoints_data, preprocess.orig_shape));

    results
}
1114
1115fn postprocess_classify(
1132 output: &[f32],
1133 names: &HashMap<usize, String>,
1134 orig_img: Array3<u8>,
1135 path: String,
1136 speed: Speed,
1137 inference_shape: (u32, u32),
1138) -> Results {
1139 let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);
1140
1141 if output.is_empty() {
1142 return results;
1143 }
1144
1145 let mut probs_vec = output.to_vec();
1147
1148 let sum: f32 = probs_vec.iter().sum();
1150 if (sum - 1.0).abs() > 0.1 && sum > 0.0 {
1151 let max_val = probs_vec.iter().copied().fold(f32::NEG_INFINITY, f32::max);
1153 let exp_vals: Vec<f32> = probs_vec.iter().map(|&v| (v - max_val).exp()).collect();
1154 let exp_sum: f32 = exp_vals.iter().sum();
1155 if exp_sum > 0.0 {
1156 probs_vec = exp_vals.iter().map(|&v| v / exp_sum).collect();
1157 }
1158 }
1159
1160 let probs = ndarray::Array1::from_vec(probs_vec);
1161 results.probs = Some(Probs::new(probs));
1162
1163 results
1164}
1165
/// Decodes a classic OBB head (`4 + num_classes + 1` features per
/// prediction: box center/size, class scores, then a rotation angle) into
/// rotated boxes `[cx, cy, w, h, angle, score, class]`.
///
/// The class count is re-derived from the flattened buffer length, centers
/// are un-letterboxed and clipped, sizes divided by the resize scale, and
/// rotated per-class NMS selects the final set.
#[allow(
    clippy::too_many_arguments,
    clippy::too_many_lines,
    clippy::similar_names
)]
fn postprocess_obb(
    output: &[f32],
    output_shape: &[usize],
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
) -> Results {
    let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);

    // 4 box values + class scores + 1 trailing angle.
    let num_classes = names.len().max(1);
    let expected_features = 4 + num_classes + 1;

    // Determine prediction count and layout from the shape.
    let (num_preds, is_transposed) = if output_shape.len() == 3 {
        let (a, b) = (output_shape[1], output_shape[2]);
        if a == expected_features || (a < b && a >= 6) {
            (b, false)
        } else {
            (a, true)
        }
    } else if output_shape.len() == 2 {
        let (a, b) = (output_shape[0], output_shape[1]);
        if a < b { (b, false) } else { (a, true) }
    } else {
        (0, false)
    };

    if output.is_empty() || num_preds == 0 {
        return results;
    }

    // Trust the real buffer length over the nominal shape for feature width.
    let actual_features = output.len() / num_preds;
    if actual_features < 6 {
        eprintln!(
            "WARNING ⚠️ OBB model has insufficient features ({actual_features}), expected at least 6"
        );
        return results;
    }

    // Normalize to (num_preds, features) row-major.
    let output_2d = if is_transposed {
        Array2::from_shape_vec((num_preds, actual_features), output.to_vec())
            .unwrap_or_else(|_| Array2::zeros((0, 0)))
    } else {
        let arr = Array2::from_shape_vec((actual_features, num_preds), output.to_vec())
            .unwrap_or_else(|_| Array2::zeros((0, 0)));
        arr.t().to_owned()
    };

    if output_2d.is_empty() {
        return results;
    }

    // Re-derive the class count from the real feature width (shadows the
    // names-based estimate above): features minus 4 box values and 1 angle.
    let derived_classes = actual_features.saturating_sub(5);
    let num_classes = derived_classes.max(1);

    // (xywhr, score, class) per surviving prediction.
    let mut candidates: Vec<([f32; 5], f32, usize)> = Vec::new();
    for i in 0..num_preds {
        // Best class, treating NaN scores as 0.
        let class_scores = output_2d.slice(s![i, 4..4 + num_classes]);
        let (best_class, best_score) = class_scores
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Less))
            .map_or((0, 0.0), |(idx, &score)| {
                (idx, if score.is_nan() { 0.0 } else { score })
            });

        if best_score < config.confidence_threshold {
            continue;
        }

        let cx = output_2d[[i, 0]];
        let cy = output_2d[[i, 1]];
        let w = output_2d[[i, 2]];
        let h = output_2d[[i, 3]];
        // Angle is the single feature after the class scores.
        let angle = output_2d[[i, 4 + num_classes]];
        // Un-letterbox the center (as a degenerate box) only; the angle is
        // scale-invariant under the aspect-preserving resize.
        let scaled = scale_coords(&[cx, cy, cx, cy], preprocess.scale, preprocess.padding);
        let scaled_cx = scaled[0];
        let scaled_cy = scaled[1];

        // Sizes have no padding offset — just divide out the (x, y) scale.
        let scaled_w = w / preprocess.scale.1;
        let scaled_h = h / preprocess.scale.0;

        // Clip the center to image bounds.
        let (oh, ow) = preprocess.orig_shape;
        #[allow(clippy::cast_precision_loss)]
        let clipped_cx = scaled_cx.max(0.0).min(ow as f32);
        #[allow(clippy::cast_precision_loss)]
        let clipped_cy = scaled_cy.max(0.0).min(oh as f32);

        if !config.keep_class(best_class) {
            continue;
        }

        candidates.push((
            [clipped_cx, clipped_cy, scaled_w, scaled_h, angle],
            best_score,
            best_class,
        ));
    }

    if candidates.is_empty() {
        // Attach an empty OBB container so consumers see an (0, 7) result.
        results.obb = Some(Obb::new(Array2::zeros((0, 7)), preprocess.orig_shape));
        return results;
    }

    // Rotated-IoU NMS, then cap at max_det.
    let keep_indices = nms_rotated_per_class(&candidates, config.iou_threshold);
    let num_kept = keep_indices.len().min(config.max_det);

    // Rows: [cx, cy, w, h, angle, score, class].
    let mut obb_data = Array2::zeros((num_kept, 7));

    for (out_idx, &keep_idx) in keep_indices.iter().take(num_kept).enumerate() {
        let (xywhr, score, class) = &candidates[keep_idx];
        obb_data[[out_idx, 0]] = xywhr[0];
        obb_data[[out_idx, 1]] = xywhr[1];
        obb_data[[out_idx, 2]] = xywhr[2];
        obb_data[[out_idx, 3]] = xywhr[3];
        obb_data[[out_idx, 4]] = xywhr[4];
        obb_data[[out_idx, 5]] = *score;
        #[allow(clippy::cast_precision_loss)]
        let class_f32 = *class as f32;
        obb_data[[out_idx, 6]] = class_f32;
    }

    results.obb = Some(Obb::new(obb_data, preprocess.orig_shape));

    results
}
1338
1339#[inline]
1349fn scale_xyxy(
1350 x1: f32,
1351 y1: f32,
1352 x2: f32,
1353 y2: f32,
1354 preprocess: &PreprocessResult,
1355) -> (f32, f32, f32, f32) {
1356 let (scale_y, scale_x) = preprocess.scale;
1357 let (pad_top, pad_left) = preprocess.padding;
1358 (
1359 (x1 - pad_left) / scale_x,
1360 (y1 - pad_top) / scale_y,
1361 (x2 - pad_left) / scale_x,
1362 (y2 - pad_top) / scale_y,
1363 )
1364}
1365
/// Decodes an NMS-free detect output of shape `(1, max_det, 6)` where each
/// row is `[x1, y1, x2, y2, conf, class]` in letterboxed coordinates.
///
/// Rows are assumed to be sorted by descending confidence — presumably
/// guaranteed by end2end exports (NOTE(review): confirm the exporter
/// ordering) — so iteration stops at the first row below the threshold.
#[allow(clippy::too_many_arguments, clippy::cast_precision_loss)]
fn postprocess_detect_end2end(
    output: &[f32],
    output_shape: &[usize],
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
) -> Results {
    let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);

    if output_shape.len() != 3 || output.is_empty() {
        return results;
    }
    let max_det = output_shape[1];
    let feats = output_shape[2];
    if feats < 6 || max_det == 0 {
        return results;
    }

    let (oh, ow) = preprocess.orig_shape;
    let (max_w, max_h) = (ow as f32, oh as f32);
    // The user-configured cap may be tighter than the model's max_det.
    let user_cap = config.max_det.min(max_det);

    // Flat row-major accumulator; reshaped into (n, 6) at the end.
    let mut flat: Vec<f32> = Vec::with_capacity(user_cap * 6);
    for i in 0..max_det {
        let base = i * feats;
        let conf = output[base + 4];
        // Early exit: rows are confidence-sorted (see doc comment).
        if conf < config.confidence_threshold {
            break;
        }
        // Class id is stored as an f32 in the output row.
        let cls = output[base + 5] as usize;
        if !config.keep_class(cls) {
            continue;
        }
        // Un-letterbox, then clamp to image bounds.
        let (x1, y1, x2, y2) = scale_xyxy(
            output[base],
            output[base + 1],
            output[base + 2],
            output[base + 3],
            preprocess,
        );
        flat.extend_from_slice(&[
            x1.clamp(0.0, max_w),
            y1.clamp(0.0, max_h),
            x2.clamp(0.0, max_w),
            y2.clamp(0.0, max_h),
            conf,
            cls as f32,
        ]);
        if flat.len() >= user_cap * 6 {
            break;
        }
    }

    let n = flat.len() / 6;
    if n > 0 {
        let boxes_data = Array2::from_shape_vec((n, 6), flat).expect("flat length matches (n, 6)");
        results.boxes = Some(Boxes::new(boxes_data, preprocess.orig_shape));
    }
    results
}
1433
/// Decode an end-to-end (NMS-fused) instance-segmentation model.
///
/// Expects two outputs:
/// * `outputs[0]`: detections `[batch, max_det, feats]` with
///   `feats >= 6 + num_masks` per row:
///   `[x1, y1, x2, y2, conf, cls, coeff_0 .. coeff_num_masks]`
///   in letterboxed coordinates.
/// * `outputs[1]`: mask prototypes `[batch, num_masks, mh, mw]`.
///
/// Boxes are rescaled and clamped to the original image. Per-instance
/// masks are built as `sigmoid(coeffs · protos)`, the letterbox padding
/// is cropped away, each mask is resized to the original image size, and
/// pixels outside the instance's box are zeroed.
#[allow(
    clippy::too_many_arguments,
    clippy::cast_precision_loss,
    clippy::too_many_lines,
    clippy::needless_pass_by_value,
    clippy::similar_names,
    clippy::manual_let_else
)]
fn postprocess_segment_end2end(
    outputs: Vec<(&[f32], Vec<usize>)>,
    preprocess: &PreprocessResult,
    config: &InferenceConfig,
    names: &HashMap<usize, String>,
    orig_img: Array3<u8>,
    path: String,
    speed: Speed,
    inference_shape: (u32, u32),
) -> Results {
    let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);
    // The protos tensor is mandatory for segmentation; warn and bail without it.
    if outputs.len() < 2 {
        eprintln!(
            "WARNING ⚠️ End2end segmentation missing protos output (got {} outputs).",
            outputs.len()
        );
        return results;
    }
    let (output0, shape0) = &outputs[0];
    let (output1, shape1) = &outputs[1];

    // Detections must be rank 3, protos rank 4.
    if shape0.len() != 3 || shape1.len() != 4 {
        return results;
    }
    let max_det = shape0[1];
    let feats = shape0[2];
    let num_masks = shape1[1];
    if feats < 6 + num_masks {
        eprintln!("WARNING ⚠️ End2end segment features ({feats}) < 6 + num_masks ({num_masks}).");
        return results;
    }

    let (oh, ow) = preprocess.orig_shape;
    let (max_w, max_h) = (ow as f32, oh as f32);
    // Cap by both the user's limit and the model's detection slots.
    let user_cap = config.max_det.min(max_det);

    let mut flat_boxes: Vec<f32> = Vec::with_capacity(user_cap * 6);
    let mut flat_coeffs: Vec<f32> = Vec::with_capacity(user_cap * num_masks);

    for i in 0..max_det {
        let base = i * feats;
        let conf = output0[base + 4];
        if conf < config.confidence_threshold {
            // Rows are assumed confidence-sorted (end2end convention), so
            // everything after the first failing row is below threshold too.
            break;
        }
        let cls = output0[base + 5] as usize;
        if !config.keep_class(cls) {
            continue;
        }
        // Map the letterboxed box back to original-image coordinates.
        let (x1, y1, x2, y2) = scale_xyxy(
            output0[base],
            output0[base + 1],
            output0[base + 2],
            output0[base + 3],
            preprocess,
        );
        flat_boxes.extend_from_slice(&[
            x1.clamp(0.0, max_w),
            y1.clamp(0.0, max_h),
            x2.clamp(0.0, max_w),
            y2.clamp(0.0, max_h),
            conf,
            cls as f32,
        ]);
        // Mask coefficients follow the 6 box/score/class features in each row.
        let coeff_start = base + 6;
        flat_coeffs.extend_from_slice(&output0[coeff_start..coeff_start + num_masks]);
        if flat_boxes.len() >= user_cap * 6 {
            break;
        }
    }

    let num_kept = flat_boxes.len() / 6;
    if num_kept == 0 {
        return results;
    }

    let boxes_data =
        Array2::from_shape_vec((num_kept, 6), flat_boxes).expect("flat length matches (n, 6)");
    let mask_coeffs = Array2::from_shape_vec((num_kept, num_masks), flat_coeffs)
        .expect("flat length matches (n, num_masks)");

    let mh = shape1[2];
    let mw = shape1[3];
    // Flatten protos to (num_masks, mh*mw) so mask generation is one matmul.
    // Assumes batch size 1; a size mismatch is caught and masks are skipped.
    let protos = match Array2::from_shape_vec((num_masks, mh * mw), output1.to_vec()) {
        Ok(a) => a,
        Err(e) => {
            eprintln!("WARNING ⚠️ Failed to build protos array: {e}. Skipping masks.");
            return results;
        }
    };
    // (num_kept, num_masks) x (num_masks, mh*mw) -> per-instance mask logits.
    let masks_flat = mask_coeffs.dot(&protos);

    // Compute the crop window that removes letterbox padding from the
    // proto-resolution mask (padding scaled from inference space to mask space).
    let (th, tw) = inference_shape;
    let (pad_top, pad_left) = preprocess.padding;
    let scale_w = mw as f32 / tw as f32;
    let scale_h = mh as f32 / th as f32;
    let crop_x = pad_left * scale_w;
    let crop_y = pad_top * scale_h;
    // crop_w = mw - 2*crop_x, crop_h = mh - 2*crop_y (fused multiply-add form).
    let crop_w = 2.0f32.mul_add(-crop_x, mw as f32);
    let crop_h = 2.0f32.mul_add(-crop_y, mh as f32);

    let mut masks_data = Array3::zeros((num_kept, oh as usize, ow as usize));
    // Each detection's mask is independent, so process them in parallel.
    Zip::from(masks_data.outer_iter_mut())
        .and(masks_flat.outer_iter())
        .and(boxes_data.outer_iter())
        .par_for_each(
            |mut mask_out: ArrayViewMut2<f32>,
             mask_flat: ArrayView1<f32>,
             box_data: ArrayView1<f32>| {
                let mut resizer = Resizer::new();
                let resize_alg = ResizeAlg::Convolution(FilterType::Bilinear);
                // Sigmoid the logits so resized values stay in [0, 1].
                let f32_data: Vec<f32> = mask_flat
                    .iter()
                    .map(|&v| 1.0 / (1.0 + (-v).exp()))
                    .collect();
                let src_bytes: &[u8] = bytemuck::cast_slice(&f32_data);
                let src_image = match Image::from_vec_u8(
                    mw as u32,
                    mh as u32,
                    src_bytes.to_vec(),
                    PixelType::F32,
                ) {
                    Ok(i) => i,
                    // On failure this instance's mask simply stays all-zero.
                    Err(_) => return,
                };
                let mut dst_image = Image::new(ow, oh, PixelType::F32);
                // Crop away letterbox padding while resizing to the original size.
                let options = ResizeOptions::new().resize_alg(resize_alg).crop(
                    f64::from(crop_x.max(0.0)),
                    f64::from(crop_y.max(0.0)),
                    f64::from(crop_w.max(1.0).min(mw as f32)),
                    f64::from(crop_h.max(1.0).min(mh as f32)),
                );
                if resizer
                    .resize(&src_image, &mut dst_image, &options)
                    .is_err()
                {
                    return;
                }
                let dst_bytes = dst_image.buffer();
                let dst_slice: &[f32] = bytemuck::cast_slice(dst_bytes);
                // Clamp the box to image bounds, then copy only in-box pixels;
                // everything outside the detection's box remains zero.
                let x1 = box_data[0].max(0.0).min(ow as f32);
                let y1 = box_data[1].max(0.0).min(oh as f32);
                let x2 = box_data[2].max(0.0).min(ow as f32);
                let y2 = box_data[3].max(0.0).min(oh as f32);
                for y in 0..oh as usize {
                    for x in 0..ow as usize {
                        let val = dst_slice[y * ow as usize + x];
                        let xf = x as f32;
                        let yf = y as f32;
                        if xf >= x1 && xf <= x2 && yf >= y1 && yf <= y2 {
                            mask_out[[y, x]] = val;
                        }
                    }
                }
            },
        );

    results.boxes = Some(Boxes::new(boxes_data, preprocess.orig_shape));
    results.masks = Some(Masks::new(masks_data, preprocess.orig_shape));
    results
}
1607
1608#[allow(
1610 clippy::too_many_arguments,
1611 clippy::cast_precision_loss,
1612 clippy::similar_names
1613)]
1614fn postprocess_pose_end2end(
1615 output: &[f32],
1616 output_shape: &[usize],
1617 preprocess: &PreprocessResult,
1618 config: &InferenceConfig,
1619 names: &HashMap<usize, String>,
1620 orig_img: Array3<u8>,
1621 path: String,
1622 speed: Speed,
1623 inference_shape: (u32, u32),
1624 nk: usize,
1625 kpt_dim: usize,
1626) -> Results {
1627 let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);
1628 if output_shape.len() != 3 || output.is_empty() || nk == 0 || kpt_dim < 2 {
1629 return results;
1630 }
1631 let max_det = output_shape[1];
1632 let feats = output_shape[2];
1633 if feats < 6 + nk * kpt_dim || max_det == 0 {
1634 return results;
1635 }
1636
1637 let (oh, ow) = preprocess.orig_shape;
1638 let (max_w, max_h) = (ow as f32, oh as f32);
1639 let (scale_y, scale_x) = preprocess.scale;
1640 let (pad_top, pad_left) = preprocess.padding;
1641 let user_cap = config.max_det.min(max_det);
1642
1643 let mut flat_boxes: Vec<f32> = Vec::with_capacity(user_cap * 6);
1644 let mut flat_kpts: Vec<f32> = Vec::with_capacity(user_cap * nk * 3);
1645
1646 for i in 0..max_det {
1647 let base = i * feats;
1648 let conf = output[base + 4];
1649 if conf < config.confidence_threshold {
1650 break;
1651 }
1652 let cls = output[base + 5] as usize;
1653 if !config.keep_class(cls) {
1654 continue;
1655 }
1656 let (x1, y1, x2, y2) = scale_xyxy(
1657 output[base],
1658 output[base + 1],
1659 output[base + 2],
1660 output[base + 3],
1661 preprocess,
1662 );
1663 flat_boxes.extend_from_slice(&[
1664 x1.clamp(0.0, max_w),
1665 y1.clamp(0.0, max_h),
1666 x2.clamp(0.0, max_w),
1667 y2.clamp(0.0, max_h),
1668 conf,
1669 cls as f32,
1670 ]);
1671 let kstart = base + 6;
1672 for k in 0..nk {
1673 let off = kstart + k * kpt_dim;
1674 let sx = (output[off] - pad_left) / scale_x;
1675 let sy = (output[off + 1] - pad_top) / scale_y;
1676 let kconf = if kpt_dim >= 3 { output[off + 2] } else { 1.0 };
1677 flat_kpts.extend_from_slice(&[sx.clamp(0.0, max_w), sy.clamp(0.0, max_h), kconf]);
1678 }
1679 if flat_boxes.len() >= user_cap * 6 {
1680 break;
1681 }
1682 }
1683
1684 let n = flat_boxes.len() / 6;
1685 let kdata =
1687 Array3::from_shape_vec((n, nk, 3), flat_kpts).expect("flat length matches (n, nk, 3)");
1688 results.keypoints = Some(Keypoints::new(kdata, preprocess.orig_shape));
1689 if n > 0 {
1690 let boxes_data =
1691 Array2::from_shape_vec((n, 6), flat_boxes).expect("flat length matches (n, 6)");
1692 results.boxes = Some(Boxes::new(boxes_data, preprocess.orig_shape));
1693 }
1694 results
1695}
1696
1697#[allow(clippy::too_many_arguments, clippy::cast_precision_loss)]
1700fn postprocess_obb_end2end(
1701 output: &[f32],
1702 output_shape: &[usize],
1703 preprocess: &PreprocessResult,
1704 config: &InferenceConfig,
1705 names: &HashMap<usize, String>,
1706 orig_img: Array3<u8>,
1707 path: String,
1708 speed: Speed,
1709 inference_shape: (u32, u32),
1710) -> Results {
1711 let mut results = Results::new(orig_img, path, names.clone(), speed, inference_shape);
1712 let mut flat: Vec<f32> = Vec::new();
1713
1714 if output_shape.len() == 3 && !output.is_empty() {
1715 let max_det = output_shape[1];
1716 let feats = output_shape[2];
1717 if feats >= 7 && max_det > 0 {
1718 let (oh, ow) = preprocess.orig_shape;
1719 let (max_w, max_h) = (ow as f32, oh as f32);
1720 let (scale_y, scale_x) = preprocess.scale;
1721 let (pad_top, pad_left) = preprocess.padding;
1722 let user_cap = config.max_det.min(max_det);
1723 flat.reserve(user_cap * 7);
1724
1725 for i in 0..max_det {
1726 let base = i * feats;
1727 let conf = output[base + 4];
1728 if conf < config.confidence_threshold {
1729 break;
1730 }
1731 let cls = output[base + 5] as usize;
1732 if !config.keep_class(cls) {
1733 continue;
1734 }
1735 let cx = (output[base] - pad_left) / scale_x;
1736 let cy = (output[base + 1] - pad_top) / scale_y;
1737 flat.extend_from_slice(&[
1738 cx.clamp(0.0, max_w),
1739 cy.clamp(0.0, max_h),
1740 output[base + 2] / scale_x,
1741 output[base + 3] / scale_y,
1742 output[base + 6],
1743 conf,
1744 cls as f32,
1745 ]);
1746 if flat.len() >= user_cap * 7 {
1747 break;
1748 }
1749 }
1750 }
1751 }
1752
1753 let n = flat.len() / 7;
1754 let obb_data = Array2::from_shape_vec((n, 7), flat).expect("flat length matches (n, 7)");
1755 results.obb = Some(Obb::new(obb_data, preprocess.orig_shape));
1756 results
1757}
1758
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_detect_shape() {
        // Features-first layout [1, feats, preds] is not transposed.
        let (nc, np, transposed) = parse_detect_shape(&[1, 84, 8400], 80);
        assert_eq!((nc, np), (80, 8400));
        assert!(!transposed);

        // Predictions-first layout [1, preds, feats] is detected as transposed.
        let (nc, np, transposed) = parse_detect_shape(&[1, 8400, 84], 80);
        assert_eq!((nc, np), (80, 8400));
        assert!(transposed);
    }

    #[test]
    fn test_infer_end2end_kpt_shape() {
        // 51 = 17 * 3 extra features -> (17, 3); 34 = 17 * 2 -> (17, 2).
        assert_eq!(infer_end2end_kpt_shape(&[1, 300, 6 + 51]), Some((17, 3)));
        assert_eq!(infer_end2end_kpt_shape(&[1, 300, 6 + 34]), Some((17, 2)));
        // 36 extra features is rejected.
        assert_eq!(infer_end2end_kpt_shape(&[1, 300, 6 + 36]), None);
        // Non-end2end layout and a row with no keypoint features yield None.
        assert_eq!(infer_end2end_kpt_shape(&[1, 56, 8400]), None);
        assert_eq!(infer_end2end_kpt_shape(&[1, 300, 6]), None);
    }

    #[test]
    fn test_parse_detect_shape_no_metadata() {
        // With nc = 0 (no metadata) the class count is inferred from shape.
        let (nc, np, transposed) = parse_detect_shape(&[1, 84, 8400], 0);
        assert_eq!((nc, np), (80, 8400));
        assert!(!transposed);

        let (nc, np, transposed) = parse_detect_shape(&[1, 8400, 84], 0);
        assert_eq!((nc, np), (80, 8400));
        assert!(transposed);
    }

    #[test]
    fn test_empty_output() {
        // A zero-prediction tensor must produce an empty result set.
        let output: Vec<f32> = Vec::new();
        let preprocess = PreprocessResult {
            orig_shape: (480, 640),
            scale: (1.0, 1.0),
            padding: (0.0, 0.0),
            tensor: ndarray::Array4::zeros((1, 3, 640, 640)),
            tensor_f16: None,
        };

        let results = postprocess_detect(
            &output,
            &[1, 84, 0],
            &preprocess,
            &InferenceConfig::default(),
            &HashMap::new(),
            ndarray::Array3::zeros((480, 640, 3)),
            String::new(),
            Speed::default(),
            (640, 640),
        );

        assert!(results.is_empty());
    }

    #[test]
    fn test_nan_scores_handled() {
        // One prediction whose objectness is NaN; decoding must not panic.
        let mut output = vec![0.0_f32; 84];
        output[..6].copy_from_slice(&[100.0, 100.0, 50.0, 50.0, f32::NAN, 0.9]);

        let preprocess = PreprocessResult {
            orig_shape: (640, 640),
            scale: (1.0, 1.0),
            padding: (0.0, 0.0),
            tensor: ndarray::Array4::zeros((1, 3, 640, 640)),
            tensor_f16: None,
        };
        let mut names = HashMap::new();
        names.insert(0, "class0".to_string());
        names.insert(1, "class1".to_string());

        let results = postprocess_detect(
            &output,
            &[1, 84, 1],
            &preprocess,
            &InferenceConfig::default(),
            &names,
            ndarray::Array3::zeros((640, 640, 3)),
            String::new(),
            Speed::default(),
            (640, 640),
        );

        // Reaching this point without panicking is the success criterion.
        drop(results);
    }

    #[test]
    fn test_malformed_shape_fallback() {
        let output = vec![0.0_f32; 100];
        let preprocess = PreprocessResult {
            orig_shape: (640, 640),
            scale: (1.0, 1.0),
            padding: (0.0, 0.0),
            tensor: ndarray::Array4::zeros((1, 3, 640, 640)),
            tensor_f16: None,
        };
        let config = InferenceConfig::default();
        let names = HashMap::new();

        // Both an empty shape and a rank-1 shape must fall back to no detections.
        let bad_shapes: [&[usize]; 2] = [&[], &[100]];
        for shape in bad_shapes {
            let results = postprocess_detect(
                &output,
                shape,
                &preprocess,
                &config,
                &names,
                ndarray::Array3::zeros((640, 640, 3)),
                String::new(),
                Speed::default(),
                (640, 640),
            );
            assert!(results.is_empty());
        }
    }

    #[test]
    fn test_postprocess_pose_logic() {
        let num_preds = 100;
        let num_features = 56; // 4 box + 1 conf + 17 keypoints * 3
        let mut output = vec![0.0_f32; num_preds * num_features];

        // Channel-first layout: feature `f` of prediction 0 lives at f * num_preds.
        let mut put = |f: usize, v: f32| output[f * num_preds] = v;
        put(0, 100.0); // cx
        put(1, 100.0); // cy
        put(2, 50.0); // w
        put(3, 50.0); // h
        put(4, 0.9); // confidence
        for k in 0..17 {
            let f = 5 + k * 3;
            put(f, 100.0); // keypoint x
            put(f + 1, 100.0); // keypoint y
            put(f + 2, 0.8); // keypoint confidence
        }

        let preprocess = PreprocessResult {
            orig_shape: (640, 640),
            scale: (1.0, 1.0),
            padding: (0.0, 0.0),
            tensor: ndarray::Array4::zeros((1, 3, 640, 640)),
            tensor_f16: None,
        };
        let mut names = HashMap::new();
        names.insert(0, "person".to_string());

        let results = postprocess_pose(
            &output,
            &[1, num_features, num_preds],
            &preprocess,
            &InferenceConfig::default(),
            &names,
            ndarray::Array3::zeros((640, 640, 3)),
            "test.jpg".to_string(),
            Speed::default(),
            (640, 640),
        );

        let kpts = results.keypoints.expect("pose output should yield keypoints");
        assert_eq!(kpts.data.shape(), &[1, 17, 3]);
        #[allow(clippy::float_cmp)]
        {
            assert_eq!(kpts.data[[0, 0, 0]], 100.0);
            assert_eq!(kpts.data[[0, 0, 2]], 0.8);
        }
    }

    #[test]
    fn test_postprocess_obb_logic() {
        let num_preds = 100;
        let num_features = 6; // cx, cy, w, h, conf, angle
        let mut output = vec![0.0_f32; num_preds * num_features];

        // Channel-first layout: feature `f` of prediction 0 lives at f * num_preds.
        let mut put = |f: usize, v: f32| output[f * num_preds] = v;
        put(0, 100.0); // cx
        put(1, 100.0); // cy
        put(2, 50.0); // w
        put(3, 20.0); // h
        put(4, 0.95); // confidence
        put(5, std::f32::consts::FRAC_PI_4); // angle

        let preprocess = PreprocessResult {
            orig_shape: (640, 640),
            scale: (1.0, 1.0),
            padding: (0.0, 0.0),
            tensor: ndarray::Array4::zeros((1, 3, 640, 640)),
            tensor_f16: None,
        };
        let mut names = HashMap::new();
        names.insert(0, "object".to_string());

        let results = postprocess_obb(
            &output,
            &[1, num_features, num_preds],
            &preprocess,
            &InferenceConfig::default(),
            &names,
            ndarray::Array3::zeros((640, 640, 3)),
            "test.jpg".to_string(),
            Speed::default(),
            (640, 640),
        );

        let obb = results.obb.expect("obb output expected");
        assert_eq!(obb.len(), 1);

        let row = obb.data.row(0);
        #[allow(clippy::float_cmp)]
        {
            assert_eq!(row[0], 100.0); // cx survives identity scaling
            assert_eq!(row[4], std::f32::consts::FRAC_PI_4); // angle preserved
            assert_eq!(row[5], 0.95); // confidence preserved
        }
    }
}