edgefirst_decoder/decoder/mod.rs
1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use ndarray::{ArrayView, ArrayViewD, Dimension};
5use num_traits::{AsPrimitive, Float};
6
7use crate::{DecoderError, DetectBox, ProtoData, Segmentation};
8
9pub mod config;
10pub mod configs;
11
12use configs::ModelType;
13
/// Decodes raw model outputs into detection boxes and, for segmentation
/// models, masks or prototype data.
///
/// Built by [`DecoderBuilder`], which resolves the configuration stored
/// here. Equality compares only the configuration-derived fields (plus the
/// *presence* of the merge program / per-scale fast path), and cloning drops
/// the per-scale fast path.
#[derive(Debug)]
pub struct Decoder {
    /// Model layout parsed from the configuration/schema; selects the
    /// `ModelType`-specific branch of the legacy decode dispatch.
    model_type: ModelType,
    /// IoU threshold handed to NMS together with [`nms`](Self::nms).
    pub iou_threshold: f32,
    /// Threshold for the score filter that runs before NMS (see the
    /// deployment / evaluation table on `pre_nms_top_k`).
    pub score_threshold: f32,
    /// NMS mode (always a concrete variant after build — `Nms::Auto` is
    /// resolved during `DecoderBuilder::build()` and never stored here):
    /// - `Some(ClassAgnostic)` — class-agnostic NMS
    /// - `Some(ClassAware)` — class-aware NMS
    /// - `None` — NMS bypassed (end-to-end models)
    pub nms: Option<configs::Nms>,
    /// Maximum number of candidate boxes fed into NMS after score filtering.
    /// Reduces O(N²) NMS cost when many low-confidence proposals pass the
    /// threshold (common during COCO mAP evaluation with threshold ≈ 0.001).
    /// Candidates are ranked by score; only the top `pre_nms_top_k` proceed
    /// to NMS. Default: 300. Ignored when `nms` is `None`.
    ///
    /// # ⚠️ Validation vs Deployment
    ///
    /// The default of 300 is tuned for **deployment** (score threshold ≥ 0.25)
    /// where few anchors pass the score filter, making top-K a no-op in
    /// practice while bounding worst-case NMS cost.
    ///
    /// For **mAP evaluation** (score threshold ≈ 0.001), most of the 8 400
    /// YOLO anchors pass the score filter. At `pre_nms_top_k = 300`, roughly
    /// 74 % of candidates that would survive NMS are discarded *before* NMS
    /// runs, causing **~9 pp box mAP loss** — a measurement artifact, not a
    /// model quality issue.
    ///
    /// | Use case | `pre_nms_top_k` | `score_threshold` |
    /// |----------|----------------:|------------------:|
    /// | Deployment | 300 (default) | ≥ 0.25 |
    /// | COCO mAP evaluation | 8 400 (all anchors) | 0.001 |
    /// | Unbounded | 0 (no limit) | any |
    ///
    /// Post-processing latency scales with the number of candidates entering
    /// NMS. At deployment thresholds the candidate count is already small, so
    /// raising `pre_nms_top_k` has negligible cost. At validation thresholds
    /// the increase is measurable but necessary for correct recall.
    pub pre_nms_top_k: usize,
    /// Maximum number of detections returned after NMS. Matches the
    /// Ultralytics `max_det` parameter. Default: 300.
    ///
    /// This bound applies uniformly across all segmentation and detection
    /// decode paths reached via [`Decoder::decode`] / [`Decoder::decode_proto`].
    /// The output `Vec`'s capacity is only an allocation hint; the post-NMS
    /// detection count is bounded solely by `max_det` (EDGEAI-1302).
    pub max_det: usize,
    /// Coordinate space *declared by the schema/config* for the boxes the
    /// model emits. This is the raw annotation; the space callers actually
    /// receive is reported by [`Decoder::normalized_boxes`].
    /// - `Some(true)`: coordinates in the `[0, 1]` range
    /// - `Some(false)`: pixel coordinates
    /// - `None`: unknown; the caller must infer it, e.g. by checking
    ///   whether any coordinate exceeds 1.0
    normalized: Option<bool>,
    /// Model input spatial dimensions `(width, height)`, captured from
    /// the schema's `input.shape` / `input.dshape` at builder time.
    /// Required to honour `normalized: false`: on the decode paths that
    /// apply EDGEAI-1303 normalization (see [`Decoder::normalized_boxes`]
    /// for which ones), pixel-space box coordinates are divided by these
    /// dimensions so the returned bboxes are in `[0, 1]`. Other paths leave
    /// pixel coordinates untouched. `None` when no schema input spec is
    /// available — the legacy >2.0 reject in `protobox` then preserves the
    /// previous safety net (EDGEAI-1303).
    input_dims: Option<(usize, usize)>,
    /// Schema v2 merge program. Present when the decoder was built from
    /// a [`crate::schema::SchemaV2`] whose logical outputs carry
    /// physical children. Absent for flat configurations (v1 and
    /// flat-v2).
    pub(crate) decode_program: Option<merge::DecodeProgram>,
    /// Per-scale fast path. Constructed at build time from a schema-v2
    /// document with per-scale children. Wrapped in `Mutex` because
    /// `Decoder::decode_proto` and `Decoder::decode` are `&self` but
    /// the per-scale buffers are mutated per-frame; concurrent decodes on
    /// this path therefore serialize on the lock. Not preserved by `Clone`.
    pub(crate) per_scale: Option<std::sync::Mutex<crate::per_scale::PerScaleDecoder>>,
}
87
88impl PartialEq for Decoder {
89 fn eq(&self, other: &Self) -> bool {
90 // DecodeProgram and PerScaleDecoder have non-comparable embedded
91 // data; compare by the config-derived fields only.
92 self.model_type == other.model_type
93 && self.iou_threshold == other.iou_threshold
94 && self.score_threshold == other.score_threshold
95 && self.nms == other.nms
96 && self.pre_nms_top_k == other.pre_nms_top_k
97 && self.max_det == other.max_det
98 && self.normalized == other.normalized
99 && self.input_dims == other.input_dims
100 && self.decode_program.is_some() == other.decode_program.is_some()
101 && self.per_scale.is_some() == other.per_scale.is_some()
102 }
103}
104
105impl Clone for Decoder {
106 /// Cloning a `Decoder` preserves the legacy decode path
107 /// (`decode_program`) but drops the per-scale fast path:
108 /// `PerScaleDecoder` owns mutable per-frame scratch buffers and is
109 /// not `Clone`. Decoders built from a per-scale schema should be
110 /// rebuilt via [`DecoderBuilder`] rather than cloned to preserve the
111 /// fast path; cloning is intended for tests and rare configs.
112 fn clone(&self) -> Self {
113 Self {
114 model_type: self.model_type.clone(),
115 iou_threshold: self.iou_threshold,
116 score_threshold: self.score_threshold,
117 nms: self.nms,
118 pre_nms_top_k: self.pre_nms_top_k,
119 max_det: self.max_det,
120 normalized: self.normalized,
121 input_dims: self.input_dims,
122 decode_program: self.decode_program.clone(),
123 per_scale: None,
124 }
125 }
126}
127
/// Borrowed, dynamically-dimensioned view over an integer (quantized) model
/// output tensor.
///
/// There is one variant per supported element type (u8, i8, u16, i16, u32,
/// i32). The `with_quantized!` macro in this module dispatches a body over
/// every variant.
#[derive(Debug)]
pub(crate) enum ArrayViewDQuantized<'a> {
    UInt8(ArrayViewD<'a, u8>),
    Int8(ArrayViewD<'a, i8>),
    UInt16(ArrayViewD<'a, u16>),
    Int16(ArrayViewD<'a, i16>),
    UInt32(ArrayViewD<'a, u32>),
    Int32(ArrayViewD<'a, i32>),
}
137
138impl<'a, D> From<ArrayView<'a, u8, D>> for ArrayViewDQuantized<'a>
139where
140 D: Dimension,
141{
142 fn from(arr: ArrayView<'a, u8, D>) -> Self {
143 Self::UInt8(arr.into_dyn())
144 }
145}
146
147impl<'a, D> From<ArrayView<'a, i8, D>> for ArrayViewDQuantized<'a>
148where
149 D: Dimension,
150{
151 fn from(arr: ArrayView<'a, i8, D>) -> Self {
152 Self::Int8(arr.into_dyn())
153 }
154}
155
156impl<'a, D> From<ArrayView<'a, u16, D>> for ArrayViewDQuantized<'a>
157where
158 D: Dimension,
159{
160 fn from(arr: ArrayView<'a, u16, D>) -> Self {
161 Self::UInt16(arr.into_dyn())
162 }
163}
164
165impl<'a, D> From<ArrayView<'a, i16, D>> for ArrayViewDQuantized<'a>
166where
167 D: Dimension,
168{
169 fn from(arr: ArrayView<'a, i16, D>) -> Self {
170 Self::Int16(arr.into_dyn())
171 }
172}
173
174impl<'a, D> From<ArrayView<'a, u32, D>> for ArrayViewDQuantized<'a>
175where
176 D: Dimension,
177{
178 fn from(arr: ArrayView<'a, u32, D>) -> Self {
179 Self::UInt32(arr.into_dyn())
180 }
181}
182
183impl<'a, D> From<ArrayView<'a, i32, D>> for ArrayViewDQuantized<'a>
184where
185 D: Dimension,
186{
187 fn from(arr: ArrayView<'a, i32, D>) -> Self {
188 Self::Int32(arr.into_dyn())
189 }
190}
191
192impl<'a> ArrayViewDQuantized<'a> {
193 /// Returns the shape of the underlying array.
194 pub(crate) fn shape(&self) -> &[usize] {
195 match self {
196 ArrayViewDQuantized::UInt8(a) => a.shape(),
197 ArrayViewDQuantized::Int8(a) => a.shape(),
198 ArrayViewDQuantized::UInt16(a) => a.shape(),
199 ArrayViewDQuantized::Int16(a) => a.shape(),
200 ArrayViewDQuantized::UInt32(a) => a.shape(),
201 ArrayViewDQuantized::Int32(a) => a.shape(),
202 }
203 }
204}
205
/// Runs `$body` once per `ArrayViewDQuantized` variant, with `$var` bound to
/// that variant's payload. `$body` is therefore type-checked (and
/// monomorphized) once for each of the six integer element types.
///
/// WARNING: Do NOT nest `with_quantized!` calls. Each level multiplies
/// monomorphized code paths by 6 (one per integer variant), so nesting
/// N levels deep produces 6^N instantiations.
///
/// Instead, dequantize each tensor sequentially with `dequant_3d!`/`dequant_4d!`
/// (6*N paths) or split into independent phases that each nest at most 2 levels.
macro_rules! with_quantized {
    ($x:expr, $var:ident, $body:expr) => {
        match $x {
            ArrayViewDQuantized::UInt8($var) => $body,
            ArrayViewDQuantized::Int8($var) => $body,
            ArrayViewDQuantized::UInt16($var) => $body,
            ArrayViewDQuantized::Int16($var) => $body,
            ArrayViewDQuantized::UInt32($var) => $body,
            ArrayViewDQuantized::Int32($var) => $body,
        }
    };
}
242
243mod builder;
244mod dfl;
245mod helpers;
246mod merge;
247mod per_scale_bridge;
248mod postprocess;
249mod tensor_bridge;
250mod tests;
251
252pub use builder::DecoderBuilder;
253pub use config::{ConfigOutput, ConfigOutputRef, ConfigOutputs};
254
255impl Decoder {
256 /// Static label identifying which dispatch path `decode` / `decode_proto`
257 /// will take, used as a tracing-span attribute. Lets profiling tools
258 /// distinguish `per_scale` (the fast path), `decode_program` (schema-v2
259 /// merge), and `legacy` (config-driven) without requiring callers to
260 /// inspect the model.
261 fn decode_path_label(&self) -> &'static str {
262 if self.per_scale.is_some() {
263 "per_scale"
264 } else if self.decode_program.is_some() {
265 "decode_program"
266 } else {
267 "legacy"
268 }
269 }
270
271 /// This function returns the parsed model type of the decoder.
272 ///
273 /// # Examples
274 ///
275 /// ```rust
276 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::ModelType};
277 /// # fn main() -> DecoderResult<()> {
278 /// # let config_yaml = edgefirst_bench::testdata::read_to_string("modelpack_split.yaml").to_string();
279 /// let decoder = DecoderBuilder::default()
280 /// .with_config_yaml_str(config_yaml)
281 /// .build()?;
282 /// assert!(matches!(
283 /// decoder.model_type(),
284 /// ModelType::ModelPackDetSplit { .. }
285 /// ));
286 /// # Ok(())
287 /// # }
288 /// ```
289 pub fn model_type(&self) -> &ModelType {
290 &self.model_type
291 }
292
293 /// Returns the coordinate format of the boxes the decoder emits to
294 /// the caller.
295 ///
296 /// - `Some(true)`: Boxes are in normalized `[0, 1]` coordinates
297 /// - `Some(false)`: Boxes are in pixel coordinates relative to the
298 /// model input
299 /// - `None`: Unknown, caller must infer (e.g., check if any coordinate
300 /// > 1.0)
301 ///
302 /// This describes the **post-decode** coordinate space, not the raw
303 /// schema annotation. The decoder applies EDGEAI-1303 normalization
304 /// (dividing bbox channels by `(input_w, input_h)`) on a per-path
305 /// basis, not unconditionally. Four paths are known to invoke the
306 /// helper uniformly across all of their entry points (`decode`,
307 /// `decode_proto`, and — where applicable — `decode_tracked` and
308 /// `decode_tracked_proto`):
309 ///
310 /// 1. The **per-scale fast path** (DFL/LTRB → dist2bbox → sigmoid),
311 /// which emits pixel-space boxes by design and always normalizes
312 /// before returning.
313 /// 2. [`ModelType::YoloSegDet`](crate::ModelType::YoloSegDet), whose
314 /// quantized and float, tracked and untracked, masks and proto
315 /// variants each call the helper after NMS.
316 /// 3. [`ModelType::YoloSplitSegDet`](crate::ModelType::YoloSplitSegDet),
317 /// aligned across `decode`, `decode_proto`, `decode_tracked`,
318 /// and `decode_tracked_proto` for both quantized and float
319 /// variants.
320 /// 4. [`ModelType::YoloSegDet2Way`](crate::ModelType::YoloSegDet2Way),
321 /// aligned across the same four entry points and both element
322 /// type variants.
323 ///
324 /// When any of those paths is active and the schema declares
325 /// `normalized: false` with valid [`input_dims`](Self::input_dims),
326 /// this accessor reports `Some(true)` to match what the caller
327 /// actually receives.
328 ///
329 /// The remaining model types still surface the raw schema flag
330 /// because their post-decode contract differs:
331 /// [`ModelType::YoloDet`](crate::ModelType::YoloDet) and
332 /// [`ModelType::YoloSplitDet`](crate::ModelType::YoloSplitDet)
333 /// (detection-only, no protobox crop coupling), the
334 /// `YoloEndToEnd*` family (model embeds its own NMS and emits its
335 /// own coordinate space), and the `ModelPack*` family (separate
336 /// model conventions). For those, this accessor returns
337 /// `self.normalized` verbatim and leaves it to the caller to
338 /// handle pixel-space output explicitly (e.g. divide by
339 /// `input_dims()` themselves).
340 ///
341 /// # Examples
342 ///
343 /// ```rust
344 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
345 /// # fn main() -> DecoderResult<()> {
346 /// # let config_yaml = edgefirst_bench::testdata::read_to_string("modelpack_split.yaml").to_string();
347 /// let decoder = DecoderBuilder::default()
348 /// .with_config_yaml_str(config_yaml)
349 /// .build()?;
350 /// // Config doesn't specify normalized, so it's None
351 /// assert!(decoder.normalized_boxes().is_none());
352 /// # Ok(())
353 /// # }
354 /// ```
355 pub fn normalized_boxes(&self) -> Option<bool> {
356 // Four paths invoke `yolo::maybe_normalize_boxes_in_place`
357 // uniformly across every entry point that can reach them:
358 // - the per-scale fast path (always normalizes by design),
359 // - `ModelType::YoloSegDet` (helper fires in
360 // `decode`/`decode_proto` via `yolo::impl_yolo_segdet_*` and
361 // in `decode_tracked`/`decode_tracked_proto` via the
362 // `process_tracked_yolo_segmentation!` macro and
363 // `process_tracked_yolo_segdet_float`),
364 // - `ModelType::YoloSplitSegDet` (helper fires in
365 // `decode_yolo_split_segdet_*`, `impl_yolo_split_segdet_*`,
366 // `process_tracked_yolo_segmentation_split!`, and
367 // `process_tracked_yolo_segdet_split_float`), and
368 // - `ModelType::YoloSegDet2Way` (helper fires in
369 // `decode_yolo_segdet_2way_*`, the float decode routes
370 // through `impl_yolo_split_segdet_float*`,
371 // `process_tracked_yolo_segmentation_2way!`, and the
372 // inline tracked-2way float helpers).
373 // For those, `normalized == Some(false)` with valid `input_dims`
374 // upgrades to a post-decode `Some(true)`. Other paths invoke
375 // the helper inconsistently across `ModelType` variants and
376 // tracked/proto entry points — surface the raw schema flag
377 // there and let callers handle pixel-space output explicitly.
378 if self.per_scale.is_some() || self.legacy_path_normalizes_uniformly() {
379 match (self.normalized, self.input_dims) {
380 (Some(true), _) => Some(true),
381 (Some(false), Some((w, h))) if w != 0 && h != 0 => Some(true),
382 (Some(false), _) => Some(false),
383 (None, _) => None,
384 }
385 } else {
386 self.normalized
387 }
388 }
389
390 /// Returns true for legacy `ModelType` dispatch paths that are known
391 /// to call `yolo::maybe_normalize_boxes_in_place` on every entry
392 /// point (`decode`, `decode_proto`, `decode_tracked`,
393 /// `decode_tracked_proto`, both quantized and float variants).
394 ///
395 /// Used by [`normalized_boxes`](Self::normalized_boxes) to gate the
396 /// pixel→normalized upgrade for non-per-scale model types whose
397 /// post-decode contract matches the per-scale path. Extend this
398 /// list as additional `ModelType` variants are brought into
399 /// uniform-normalization alignment.
400 fn legacy_path_normalizes_uniformly(&self) -> bool {
401 matches!(
402 self.model_type,
403 ModelType::YoloSegDet { .. }
404 | ModelType::YoloSplitSegDet { .. }
405 | ModelType::YoloSegDet2Way { .. }
406 )
407 }
408
409 /// Model input dimensions `(width, height)` captured from the
410 /// schema's `input.shape` / `input.dshape`, or `None` when the
411 /// schema did not declare an input spec (e.g. flat YAML configs
412 /// or `DecoderBuilder::add_output(...)` programmatic builds).
413 ///
414 /// Drives EDGEAI-1303 normalization on the paths that invoke the
415 /// helper uniformly: when the schema declares pixel-space outputs
416 /// and `input_dims()` is `Some((w, h))`, the per-scale bridge and
417 /// the `ModelType::YoloSegDet`, `ModelType::YoloSplitSegDet`, and
418 /// `ModelType::YoloSegDet2Way` dispatch paths divide post-NMS
419 /// bbox coordinates by `(w, h)` so they enter the canonical
420 /// `[0, 1]` range before mask cropping / tracker dispatch, and
421 /// [`normalized_boxes`](Self::normalized_boxes) reports
422 /// `Some(true)` to match. The remaining legacy `ModelType`
423 /// dispatch paths (detection-only `YoloDet`/`YoloSplitDet`, the
424 /// `YoloEndToEnd*` family, and the `ModelPack*` family) do not
425 /// apply this division — see
426 /// [`normalized_boxes`](Self::normalized_boxes) for the per-path
427 /// contract. The legacy `protobox` `> 2.0` reject acts as a safety
428 /// net for paths that emit pixel-space coordinates.
429 ///
430 /// # Examples
431 ///
432 /// ```rust
433 /// # use edgefirst_decoder::{schema::SchemaV2, DecoderBuilder, DecoderResult};
434 /// # fn main() -> DecoderResult<()> {
435 /// let json = r#"{
436 /// "schema_version": 2,
437 /// "nms": "class_agnostic",
438 /// "input": {
439 /// "shape": [1, 640, 640, 3],
440 /// "dshape": [{"batch": 1}, {"height": 640}, {"width": 640}, {"num_features": 3}]
441 /// },
442 /// "outputs": [{
443 /// "name": "out", "type": "detection",
444 /// "shape": [1, 38, 256],
445 /// "dshape": [{"batch": 1}, {"num_features": 38}, {"num_boxes": 256}],
446 /// "decoder": "ultralytics", "encoding": "direct", "normalized": false
447 /// }]
448 /// }"#;
449 /// let schema: SchemaV2 = serde_json::from_str(json).unwrap();
450 /// let decoder = DecoderBuilder::default().with_schema(schema).build()?;
451 /// assert_eq!(decoder.input_dims(), Some((640, 640)));
452 /// # Ok(())
453 /// # }
454 /// ```
455 pub fn input_dims(&self) -> Option<(usize, usize)> {
456 self.input_dims
457 }
458
459 /// Decode quantized model outputs into detection boxes and segmentation
460 /// masks. The quantized outputs can be of u8, i8, u16, i16, u32, or i32
461 /// types. Clears the provided output vectors before populating them.
462 pub(crate) fn decode_quantized(
463 &self,
464 outputs: &[ArrayViewDQuantized],
465 output_boxes: &mut Vec<DetectBox>,
466 output_masks: &mut Vec<Segmentation>,
467 ) -> Result<(), DecoderError> {
468 output_boxes.clear();
469 output_masks.clear();
470 match &self.model_type {
471 ModelType::ModelPackSegDet {
472 boxes,
473 scores,
474 segmentation,
475 } => {
476 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
477 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
478 }
479 ModelType::ModelPackSegDetSplit {
480 detection,
481 segmentation,
482 } => {
483 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
484 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
485 }
486 ModelType::ModelPackDet { boxes, scores } => {
487 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)
488 }
489 ModelType::ModelPackDetSplit { detection } => {
490 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)
491 }
492 ModelType::ModelPackSeg { segmentation } => {
493 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
494 }
495 ModelType::YoloDet { boxes } => {
496 self.decode_yolo_det_quantized(outputs, boxes, output_boxes)
497 }
498 ModelType::YoloSegDet { boxes, protos } => self.decode_yolo_segdet_quantized(
499 outputs,
500 boxes,
501 protos,
502 output_boxes,
503 output_masks,
504 ),
505 ModelType::YoloSplitDet { boxes, scores } => {
506 self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)
507 }
508 ModelType::YoloSplitSegDet {
509 boxes,
510 scores,
511 mask_coeff,
512 protos,
513 } => self.decode_yolo_split_segdet_quantized(
514 outputs,
515 boxes,
516 scores,
517 mask_coeff,
518 protos,
519 output_boxes,
520 output_masks,
521 ),
522 ModelType::YoloSegDet2Way {
523 boxes,
524 mask_coeff,
525 protos,
526 } => self.decode_yolo_segdet_2way_quantized(
527 outputs,
528 boxes,
529 mask_coeff,
530 protos,
531 output_boxes,
532 output_masks,
533 ),
534 ModelType::YoloEndToEndDet { boxes } => {
535 self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)
536 }
537 ModelType::YoloEndToEndSegDet { boxes, protos } => self
538 .decode_yolo_end_to_end_segdet_quantized(
539 outputs,
540 boxes,
541 protos,
542 output_boxes,
543 output_masks,
544 ),
545 ModelType::YoloSplitEndToEndDet {
546 boxes,
547 scores,
548 classes,
549 } => self.decode_yolo_split_end_to_end_det_quantized(
550 outputs,
551 boxes,
552 scores,
553 classes,
554 output_boxes,
555 ),
556 ModelType::YoloSplitEndToEndSegDet {
557 boxes,
558 scores,
559 classes,
560 mask_coeff,
561 protos,
562 } => self.decode_yolo_split_end_to_end_segdet_quantized(
563 outputs,
564 boxes,
565 scores,
566 classes,
567 mask_coeff,
568 protos,
569 output_boxes,
570 output_masks,
571 ),
572 ModelType::PerScale => Err(DecoderError::Internal(
573 "per-scale path must be intercepted before ModelType dispatch".into(),
574 )),
575 }
576 }
577
578 /// Decode floating point model outputs into detection boxes and
579 /// segmentation masks. Clears the provided output vectors before
580 /// populating them.
581 pub(crate) fn decode_float<T>(
582 &self,
583 outputs: &[ArrayViewD<T>],
584 output_boxes: &mut Vec<DetectBox>,
585 output_masks: &mut Vec<Segmentation>,
586 ) -> Result<(), DecoderError>
587 where
588 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
589 f32: AsPrimitive<T>,
590 {
591 output_boxes.clear();
592 output_masks.clear();
593 match &self.model_type {
594 ModelType::ModelPackSegDet {
595 boxes,
596 scores,
597 segmentation,
598 } => {
599 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
600 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
601 }
602 ModelType::ModelPackSegDetSplit {
603 detection,
604 segmentation,
605 } => {
606 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
607 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
608 }
609 ModelType::ModelPackDet { boxes, scores } => {
610 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
611 }
612 ModelType::ModelPackDetSplit { detection } => {
613 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
614 }
615 ModelType::ModelPackSeg { segmentation } => {
616 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
617 }
618 ModelType::YoloDet { boxes } => {
619 self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
620 }
621 ModelType::YoloSegDet { boxes, protos } => {
622 self.decode_yolo_segdet_float(outputs, boxes, protos, output_boxes, output_masks)?;
623 }
624 ModelType::YoloSplitDet { boxes, scores } => {
625 self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
626 }
627 ModelType::YoloSplitSegDet {
628 boxes,
629 scores,
630 mask_coeff,
631 protos,
632 } => {
633 self.decode_yolo_split_segdet_float(
634 outputs,
635 boxes,
636 scores,
637 mask_coeff,
638 protos,
639 output_boxes,
640 output_masks,
641 )?;
642 }
643 ModelType::YoloSegDet2Way {
644 boxes,
645 mask_coeff,
646 protos,
647 } => {
648 self.decode_yolo_segdet_2way_float(
649 outputs,
650 boxes,
651 mask_coeff,
652 protos,
653 output_boxes,
654 output_masks,
655 )?;
656 }
657 ModelType::YoloEndToEndDet { boxes } => {
658 self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
659 }
660 ModelType::YoloEndToEndSegDet { boxes, protos } => {
661 self.decode_yolo_end_to_end_segdet_float(
662 outputs,
663 boxes,
664 protos,
665 output_boxes,
666 output_masks,
667 )?;
668 }
669 ModelType::YoloSplitEndToEndDet {
670 boxes,
671 scores,
672 classes,
673 } => {
674 self.decode_yolo_split_end_to_end_det_float(
675 outputs,
676 boxes,
677 scores,
678 classes,
679 output_boxes,
680 )?;
681 }
682 ModelType::YoloSplitEndToEndSegDet {
683 boxes,
684 scores,
685 classes,
686 mask_coeff,
687 protos,
688 } => {
689 self.decode_yolo_split_end_to_end_segdet_float(
690 outputs,
691 boxes,
692 scores,
693 classes,
694 mask_coeff,
695 protos,
696 output_boxes,
697 output_masks,
698 )?;
699 }
700 ModelType::PerScale => {
701 return Err(DecoderError::Internal(
702 "per-scale path must be intercepted before ModelType dispatch".into(),
703 ));
704 }
705 }
706 Ok(())
707 }
708
709 /// Decodes quantized model outputs into detection boxes, returning raw
710 /// `ProtoData` for segmentation models instead of materialized masks.
711 ///
712 /// Returns `Ok(None)` for detection-only and ModelPack models (detections
713 /// are still decoded into `output_boxes`). Returns `Ok(Some(ProtoData))`
714 /// for YOLO segmentation models.
715 pub(crate) fn decode_quantized_proto(
716 &self,
717 outputs: &[ArrayViewDQuantized],
718 output_boxes: &mut Vec<DetectBox>,
719 ) -> Result<Option<ProtoData>, DecoderError> {
720 output_boxes.clear();
721 match &self.model_type {
722 // Detection-only variants: decode boxes, return None for proto data.
723 ModelType::ModelPackDet { boxes, scores } => {
724 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
725 Ok(None)
726 }
727 ModelType::ModelPackDetSplit { detection } => {
728 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
729 Ok(None)
730 }
731 ModelType::YoloDet { boxes } => {
732 self.decode_yolo_det_quantized(outputs, boxes, output_boxes)?;
733 Ok(None)
734 }
735 ModelType::YoloSplitDet { boxes, scores } => {
736 self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)?;
737 Ok(None)
738 }
739 ModelType::YoloEndToEndDet { boxes } => {
740 self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)?;
741 Ok(None)
742 }
743 ModelType::YoloSplitEndToEndDet {
744 boxes,
745 scores,
746 classes,
747 } => {
748 self.decode_yolo_split_end_to_end_det_quantized(
749 outputs,
750 boxes,
751 scores,
752 classes,
753 output_boxes,
754 )?;
755 Ok(None)
756 }
757 // ModelPack seg/segdet variants have no YOLO proto data.
758 ModelType::ModelPackSegDet { boxes, scores, .. } => {
759 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
760 Ok(None)
761 }
762 ModelType::ModelPackSegDetSplit { detection, .. } => {
763 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
764 Ok(None)
765 }
766 ModelType::ModelPackSeg { .. } => Ok(None),
767
768 ModelType::YoloSegDet { boxes, protos } => {
769 let proto =
770 self.decode_yolo_segdet_quantized_proto(outputs, boxes, protos, output_boxes)?;
771 Ok(Some(proto))
772 }
773 ModelType::YoloSplitSegDet {
774 boxes,
775 scores,
776 mask_coeff,
777 protos,
778 } => {
779 let proto = self.decode_yolo_split_segdet_quantized_proto(
780 outputs,
781 boxes,
782 scores,
783 mask_coeff,
784 protos,
785 output_boxes,
786 )?;
787 Ok(Some(proto))
788 }
789 ModelType::YoloSegDet2Way {
790 boxes,
791 mask_coeff,
792 protos,
793 } => {
794 let proto = self.decode_yolo_segdet_2way_quantized_proto(
795 outputs,
796 boxes,
797 mask_coeff,
798 protos,
799 output_boxes,
800 )?;
801 Ok(Some(proto))
802 }
803 ModelType::YoloEndToEndSegDet { boxes, protos } => {
804 let proto = self.decode_yolo_end_to_end_segdet_quantized_proto(
805 outputs,
806 boxes,
807 protos,
808 output_boxes,
809 )?;
810 Ok(Some(proto))
811 }
812 ModelType::YoloSplitEndToEndSegDet {
813 boxes,
814 scores,
815 classes,
816 mask_coeff,
817 protos,
818 } => {
819 let proto = self.decode_yolo_split_end_to_end_segdet_quantized_proto(
820 outputs,
821 boxes,
822 scores,
823 classes,
824 mask_coeff,
825 protos,
826 output_boxes,
827 )?;
828 Ok(Some(proto))
829 }
830 ModelType::PerScale => Err(DecoderError::Internal(
831 "per-scale path must be intercepted before ModelType dispatch".into(),
832 )),
833 }
834 }
835
836 /// Decodes floating-point model outputs into detection boxes, returning
837 /// raw `ProtoData` for segmentation models instead of materialized masks.
838 ///
839 /// Returns `Ok(None)` for detection-only and ModelPack models (detections
840 /// are still decoded into `output_boxes`). Returns `Ok(Some(ProtoData))`
841 /// for YOLO segmentation models.
842 pub(crate) fn decode_float_proto<T>(
843 &self,
844 outputs: &[ArrayViewD<T>],
845 output_boxes: &mut Vec<DetectBox>,
846 ) -> Result<Option<ProtoData>, DecoderError>
847 where
848 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + crate::yolo::FloatProtoElem,
849 f32: AsPrimitive<T>,
850 {
851 output_boxes.clear();
852 match &self.model_type {
853 // Detection-only variants: decode boxes, return None for proto data.
854 ModelType::ModelPackDet { boxes, scores } => {
855 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
856 Ok(None)
857 }
858 ModelType::ModelPackDetSplit { detection } => {
859 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
860 Ok(None)
861 }
862 ModelType::YoloDet { boxes } => {
863 self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
864 Ok(None)
865 }
866 ModelType::YoloSplitDet { boxes, scores } => {
867 self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
868 Ok(None)
869 }
870 ModelType::YoloEndToEndDet { boxes } => {
871 self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
872 Ok(None)
873 }
874 ModelType::YoloSplitEndToEndDet {
875 boxes,
876 scores,
877 classes,
878 } => {
879 self.decode_yolo_split_end_to_end_det_float(
880 outputs,
881 boxes,
882 scores,
883 classes,
884 output_boxes,
885 )?;
886 Ok(None)
887 }
888 // ModelPack seg/segdet variants have no YOLO proto data.
889 ModelType::ModelPackSegDet { boxes, scores, .. } => {
890 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
891 Ok(None)
892 }
893 ModelType::ModelPackSegDetSplit { detection, .. } => {
894 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
895 Ok(None)
896 }
897 ModelType::ModelPackSeg { .. } => Ok(None),
898
899 ModelType::YoloSegDet { boxes, protos } => {
900 let proto =
901 self.decode_yolo_segdet_float_proto(outputs, boxes, protos, output_boxes)?;
902 Ok(Some(proto))
903 }
904 ModelType::YoloSplitSegDet {
905 boxes,
906 scores,
907 mask_coeff,
908 protos,
909 } => {
910 let proto = self.decode_yolo_split_segdet_float_proto(
911 outputs,
912 boxes,
913 scores,
914 mask_coeff,
915 protos,
916 output_boxes,
917 )?;
918 Ok(Some(proto))
919 }
920 ModelType::YoloSegDet2Way {
921 boxes,
922 mask_coeff,
923 protos,
924 } => {
925 let proto = self.decode_yolo_segdet_2way_float_proto(
926 outputs,
927 boxes,
928 mask_coeff,
929 protos,
930 output_boxes,
931 )?;
932 Ok(Some(proto))
933 }
934 ModelType::YoloEndToEndSegDet { boxes, protos } => {
935 let proto = self.decode_yolo_end_to_end_segdet_float_proto(
936 outputs,
937 boxes,
938 protos,
939 output_boxes,
940 )?;
941 Ok(Some(proto))
942 }
943 ModelType::YoloSplitEndToEndSegDet {
944 boxes,
945 scores,
946 classes,
947 mask_coeff,
948 protos,
949 } => {
950 let proto = self.decode_yolo_split_end_to_end_segdet_float_proto(
951 outputs,
952 boxes,
953 scores,
954 classes,
955 mask_coeff,
956 protos,
957 output_boxes,
958 )?;
959 Ok(Some(proto))
960 }
961 ModelType::PerScale => Err(DecoderError::Internal(
962 "per-scale path must be intercepted before ModelType dispatch".into(),
963 )),
964 }
965 }
966
967 // ========================================================================
968 // TensorDyn-based public API
969 // ========================================================================
970
971 /// Decode model outputs into detection boxes and segmentation masks.
972 ///
973 /// This is the primary decode API. Accepts `TensorDyn` outputs directly
974 /// from model inference. Automatically dispatches to quantized or float
975 /// paths based on the tensor dtype.
976 ///
977 /// # Arguments
978 ///
979 /// * `outputs` - Tensor outputs from model inference
980 /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
981 /// * `output_masks` - Destination for decoded segmentation masks (cleared first)
982 ///
983 /// # `output_boxes` / `output_masks` capacity
984 ///
985 /// The capacity of the supplied `Vec`s is **only** an allocation hint —
986 /// it is **not** a cap on the number of detections returned. The
987 /// post-NMS detection count is bounded by [`Decoder::max_det`] (set
988 /// via [`DecoderBuilder::with_max_det`], default `300`). Passing
989 /// `Vec::new()` (capacity 0) returns up to `max_det` detections;
990 /// pre-allocating with [`Vec::with_capacity`] only avoids the
991 /// reallocation when the decoder grows the buffer.
992 ///
993 /// # Errors
994 ///
995 /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
996 /// or the outputs don't match the decoder's model configuration.
997 pub fn decode(
998 &self,
999 outputs: &[&edgefirst_tensor::TensorDyn],
1000 output_boxes: &mut Vec<DetectBox>,
1001 output_masks: &mut Vec<Segmentation>,
1002 ) -> Result<(), DecoderError> {
1003 let path = self.decode_path_label();
1004 let _span = tracing::trace_span!("decoder.decode", path = path, n_outputs = outputs.len())
1005 .entered();
1006 // Per-scale fast path — selected at builder time when the schema
1007 // declares per-scale children with DFL or LTRB encoding.
1008 if let Some(per_scale_mutex) = &self.per_scale {
1009 let mut ps = per_scale_mutex
1010 .lock()
1011 .map_err(|e| DecoderError::Internal(format!("per_scale mutex poisoned: {e}")))?;
1012 let decoded = ps.run(outputs)?;
1013 return per_scale_bridge::per_scale_to_masks(
1014 &decoded,
1015 output_boxes,
1016 output_masks,
1017 self.iou_threshold,
1018 self.score_threshold,
1019 self.nms,
1020 self.pre_nms_top_k,
1021 self.max_det,
1022 self.normalized,
1023 self.input_dims,
1024 );
1025 }
1026
1027 // Schema v2 merge path: dequantize physical children into
1028 // logical float32 tensors, then feed through the float dispatch.
1029 if let Some(program) = &self.decode_program {
1030 let merged = program.execute(outputs)?;
1031 let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
1032 return self.decode_float(&views, output_boxes, output_masks);
1033 }
1034
1035 let mapped = tensor_bridge::map_tensors(outputs)?;
1036 match &mapped {
1037 tensor_bridge::MappedOutputs::Quantized(maps) => {
1038 let views = tensor_bridge::quantized_views(maps)?;
1039 self.decode_quantized(&views, output_boxes, output_masks)
1040 }
1041 tensor_bridge::MappedOutputs::Float16(maps) => {
1042 let views = tensor_bridge::f16_views(maps)?;
1043 self.decode_float(&views, output_boxes, output_masks)
1044 }
1045 tensor_bridge::MappedOutputs::Float32(maps) => {
1046 let views = tensor_bridge::f32_views(maps)?;
1047 self.decode_float(&views, output_boxes, output_masks)
1048 }
1049 tensor_bridge::MappedOutputs::Float64(maps) => {
1050 let views = tensor_bridge::f64_views(maps)?;
1051 self.decode_float(&views, output_boxes, output_masks)
1052 }
1053 }
1054 }
1055
1056 /// Decode model outputs into detection boxes, returning raw proto data
1057 /// for segmentation models instead of materialized masks.
1058 ///
1059 /// Accepts `TensorDyn` outputs directly from model inference.
1060 /// Detections are always decoded into `output_boxes` regardless of model type.
1061 /// Returns `Ok(None)` for detection-only and ModelPack models.
1062 /// Returns `Ok(Some(ProtoData))` for YOLO segmentation models.
1063 ///
1064 /// # Arguments
1065 ///
1066 /// * `outputs` - Tensor outputs from model inference
1067 /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
1068 ///
1069 /// # `output_boxes` capacity
1070 ///
1071 /// The capacity of `output_boxes` is **only** an allocation hint — it
1072 /// is **not** a cap on the number of detections returned. The
1073 /// post-NMS detection count is bounded by [`Decoder::max_det`] (set
1074 /// via [`DecoderBuilder::with_max_det`], default `300`). Passing
1075 /// `Vec::new()` (capacity 0) returns up to `max_det` detections.
1076 ///
1077 /// # Errors
1078 ///
1079 /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
1080 /// or the outputs don't match the decoder's model configuration.
1081 pub fn decode_proto(
1082 &self,
1083 outputs: &[&edgefirst_tensor::TensorDyn],
1084 output_boxes: &mut Vec<DetectBox>,
1085 ) -> Result<Option<ProtoData>, DecoderError> {
1086 let path = self.decode_path_label();
1087 let _span = tracing::trace_span!(
1088 "decoder.decode_proto",
1089 path = path,
1090 n_outputs = outputs.len()
1091 )
1092 .entered();
1093 // Per-scale fast path — selected at builder time when the schema
1094 // declares per-scale children with DFL or LTRB encoding.
1095 if let Some(per_scale_mutex) = &self.per_scale {
1096 let mut ps = per_scale_mutex
1097 .lock()
1098 .map_err(|e| DecoderError::Internal(format!("per_scale mutex poisoned: {e}")))?;
1099 let decoded = ps.run(outputs)?;
1100 return per_scale_bridge::per_scale_to_proto_data(
1101 &decoded,
1102 output_boxes,
1103 self.iou_threshold,
1104 self.score_threshold,
1105 self.nms,
1106 self.pre_nms_top_k,
1107 self.max_det,
1108 self.normalized,
1109 self.input_dims,
1110 );
1111 }
1112
1113 // Schema v2 merge path: dequantize physical children into
1114 // logical float32 tensors, then feed through the float dispatch.
1115 if let Some(program) = &self.decode_program {
1116 let merged = program.execute(outputs)?;
1117 let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
1118 return self.decode_float_proto(&views, output_boxes);
1119 }
1120
1121 let mapped = tensor_bridge::map_tensors(outputs)?;
1122 let result = match &mapped {
1123 tensor_bridge::MappedOutputs::Quantized(maps) => {
1124 let views = tensor_bridge::quantized_views(maps)?;
1125 self.decode_quantized_proto(&views, output_boxes)
1126 }
1127 tensor_bridge::MappedOutputs::Float16(maps) => {
1128 let views = tensor_bridge::f16_views(maps)?;
1129 self.decode_float_proto(&views, output_boxes)
1130 }
1131 tensor_bridge::MappedOutputs::Float32(maps) => {
1132 let views = tensor_bridge::f32_views(maps)?;
1133 self.decode_float_proto(&views, output_boxes)
1134 }
1135 tensor_bridge::MappedOutputs::Float64(maps) => {
1136 let views = tensor_bridge::f64_views(maps)?;
1137 self.decode_float_proto(&views, output_boxes)
1138 }
1139 };
1140 result
1141 }
1142
1143 /// Run the per-scale pipeline and return pre-NMS buffers as owned f32.
1144 ///
1145 /// Test-only entry point used by the parity-fixture tests to compare
1146 /// HAL stage output against the NumPy reference's stage output
1147 /// without NMS ordering noise. Returns an error if the decoder
1148 /// isn't configured for per-scale decoding.
1149 #[doc(hidden)]
1150 pub fn _testing_run_per_scale_pre_nms(
1151 &self,
1152 outputs: &[&edgefirst_tensor::TensorDyn],
1153 ) -> Result<crate::per_scale::PreNmsCapture, crate::error::DecoderError> {
1154 let mutex = self.per_scale.as_ref().ok_or_else(|| {
1155 crate::error::DecoderError::Internal("decoder not configured for per-scale".into())
1156 })?;
1157 let mut ps = mutex.lock().map_err(|e| {
1158 crate::error::DecoderError::Internal(format!("per_scale mutex poisoned: {e}"))
1159 })?;
1160 // Drop the borrowed view immediately so we can reborrow buffers below.
1161 {
1162 ps.run(outputs)?;
1163 }
1164 let total_anchors = ps.plan.total_anchors;
1165 let num_classes = ps.plan.num_classes;
1166 let num_mc = ps.plan.num_mask_coefs;
1167 Ok(ps
1168 .buffers
1169 .snapshot_owned_f32(total_anchors, num_classes, num_mc))
1170 }
1171}
1172
1173#[cfg(feature = "tracker")]
1174pub use edgefirst_tracker::TrackInfo;
1175
1176#[cfg(feature = "tracker")]
1177pub use edgefirst_tracker::Tracker;
1178
1179#[cfg(feature = "tracker")]
1180impl Decoder {
1181 /// Decode quantized model outputs into detection boxes and segmentation
1182 /// masks with tracking. Clears the provided output vectors before
1183 /// populating them.
1184 pub(crate) fn decode_tracked_quantized<TR: edgefirst_tracker::Tracker<DetectBox>>(
1185 &self,
1186 tracker: &mut TR,
1187 timestamp: u64,
1188 outputs: &[ArrayViewDQuantized],
1189 output_boxes: &mut Vec<DetectBox>,
1190 output_masks: &mut Vec<Segmentation>,
1191 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1192 ) -> Result<(), DecoderError> {
1193 output_boxes.clear();
1194 output_masks.clear();
1195 output_tracks.clear();
1196
1197 // yolo segdet variants require special handling to separate boxes that come from decoding vs active tracks.
1198 // Only boxes that come from decoding can be used for proto/mask generation.
1199 match &self.model_type {
1200 ModelType::YoloSegDet { boxes, protos } => self.decode_tracked_yolo_segdet_quantized(
1201 tracker,
1202 timestamp,
1203 outputs,
1204 boxes,
1205 protos,
1206 output_boxes,
1207 output_masks,
1208 output_tracks,
1209 ),
1210 ModelType::YoloSplitSegDet {
1211 boxes,
1212 scores,
1213 mask_coeff,
1214 protos,
1215 } => self.decode_tracked_yolo_split_segdet_quantized(
1216 tracker,
1217 timestamp,
1218 outputs,
1219 boxes,
1220 scores,
1221 mask_coeff,
1222 protos,
1223 output_boxes,
1224 output_masks,
1225 output_tracks,
1226 ),
1227 ModelType::YoloEndToEndSegDet { boxes, protos } => self
1228 .decode_tracked_yolo_end_to_end_segdet_quantized(
1229 tracker,
1230 timestamp,
1231 outputs,
1232 boxes,
1233 protos,
1234 output_boxes,
1235 output_masks,
1236 output_tracks,
1237 ),
1238 ModelType::YoloSplitEndToEndSegDet {
1239 boxes,
1240 scores,
1241 classes,
1242 mask_coeff,
1243 protos,
1244 } => self.decode_tracked_yolo_split_end_to_end_segdet_quantized(
1245 tracker,
1246 timestamp,
1247 outputs,
1248 boxes,
1249 scores,
1250 classes,
1251 mask_coeff,
1252 protos,
1253 output_boxes,
1254 output_masks,
1255 output_tracks,
1256 ),
1257 ModelType::YoloSegDet2Way {
1258 boxes,
1259 mask_coeff,
1260 protos,
1261 } => self.decode_tracked_yolo_segdet_2way_quantized(
1262 tracker,
1263 timestamp,
1264 outputs,
1265 boxes,
1266 mask_coeff,
1267 protos,
1268 output_boxes,
1269 output_masks,
1270 output_tracks,
1271 ),
1272 _ => {
1273 self.decode_quantized(outputs, output_boxes, output_masks)?;
1274 Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1275 Ok(())
1276 }
1277 }
1278 }
1279
1280 /// This function decodes floating point model outputs into detection boxes
1281 /// and segmentation masks. Up to `output_boxes.capacity()` boxes and
1282 /// masks will be decoded. The function clears the provided output
1283 /// vectors before populating them with the decoded results.
1284 ///
1285 /// This function returns an `Error` if the provided outputs don't
1286 /// match the configuration provided by the user when building the decoder.
1287 ///
1288 /// Any quantization information in the configuration will be ignored.
1289 pub(crate) fn decode_tracked_float<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
1290 &self,
1291 tracker: &mut TR,
1292 timestamp: u64,
1293 outputs: &[ArrayViewD<T>],
1294 output_boxes: &mut Vec<DetectBox>,
1295 output_masks: &mut Vec<Segmentation>,
1296 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1297 ) -> Result<(), DecoderError>
1298 where
1299 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
1300 f32: AsPrimitive<T>,
1301 {
1302 output_boxes.clear();
1303 output_masks.clear();
1304 output_tracks.clear();
1305 match &self.model_type {
1306 ModelType::YoloSegDet { boxes, protos } => {
1307 self.decode_tracked_yolo_segdet_float(
1308 tracker,
1309 timestamp,
1310 outputs,
1311 boxes,
1312 protos,
1313 output_boxes,
1314 output_masks,
1315 output_tracks,
1316 )?;
1317 }
1318 ModelType::YoloSplitSegDet {
1319 boxes,
1320 scores,
1321 mask_coeff,
1322 protos,
1323 } => {
1324 self.decode_tracked_yolo_split_segdet_float(
1325 tracker,
1326 timestamp,
1327 outputs,
1328 boxes,
1329 scores,
1330 mask_coeff,
1331 protos,
1332 output_boxes,
1333 output_masks,
1334 output_tracks,
1335 )?;
1336 }
1337 ModelType::YoloEndToEndSegDet { boxes, protos } => {
1338 self.decode_tracked_yolo_end_to_end_segdet_float(
1339 tracker,
1340 timestamp,
1341 outputs,
1342 boxes,
1343 protos,
1344 output_boxes,
1345 output_masks,
1346 output_tracks,
1347 )?;
1348 }
1349 ModelType::YoloSplitEndToEndSegDet {
1350 boxes,
1351 scores,
1352 classes,
1353 mask_coeff,
1354 protos,
1355 } => {
1356 self.decode_tracked_yolo_split_end_to_end_segdet_float(
1357 tracker,
1358 timestamp,
1359 outputs,
1360 boxes,
1361 scores,
1362 classes,
1363 mask_coeff,
1364 protos,
1365 output_boxes,
1366 output_masks,
1367 output_tracks,
1368 )?;
1369 }
1370 ModelType::YoloSegDet2Way {
1371 boxes,
1372 mask_coeff,
1373 protos,
1374 } => {
1375 self.decode_tracked_yolo_segdet_2way_float(
1376 tracker,
1377 timestamp,
1378 outputs,
1379 boxes,
1380 mask_coeff,
1381 protos,
1382 output_boxes,
1383 output_masks,
1384 output_tracks,
1385 )?;
1386 }
1387 _ => {
1388 self.decode_float(outputs, output_boxes, output_masks)?;
1389 Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1390 }
1391 }
1392 Ok(())
1393 }
1394
1395 /// Decodes quantized model outputs into detection boxes, returning raw
1396 /// `ProtoData` for segmentation models instead of materialized masks.
1397 ///
1398 /// Returns `Ok(None)` for detection-only and ModelPack models (use
1399 /// `decode_quantized` for those). Returns `Ok(Some(ProtoData))` for
1400 /// YOLO segmentation models.
1401 pub(crate) fn decode_tracked_quantized_proto<TR: edgefirst_tracker::Tracker<DetectBox>>(
1402 &self,
1403 tracker: &mut TR,
1404 timestamp: u64,
1405 outputs: &[ArrayViewDQuantized],
1406 output_boxes: &mut Vec<DetectBox>,
1407 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1408 ) -> Result<Option<ProtoData>, DecoderError> {
1409 output_boxes.clear();
1410 output_tracks.clear();
1411 match &self.model_type {
1412 ModelType::YoloSegDet { boxes, protos } => {
1413 let proto = self.decode_tracked_yolo_segdet_quantized_proto(
1414 tracker,
1415 timestamp,
1416 outputs,
1417 boxes,
1418 protos,
1419 output_boxes,
1420 output_tracks,
1421 )?;
1422 Ok(Some(proto))
1423 }
1424 ModelType::YoloSplitSegDet {
1425 boxes,
1426 scores,
1427 mask_coeff,
1428 protos,
1429 } => {
1430 let proto = self.decode_tracked_yolo_split_segdet_quantized_proto(
1431 tracker,
1432 timestamp,
1433 outputs,
1434 boxes,
1435 scores,
1436 mask_coeff,
1437 protos,
1438 output_boxes,
1439 output_tracks,
1440 )?;
1441 Ok(Some(proto))
1442 }
1443 ModelType::YoloSegDet2Way {
1444 boxes,
1445 mask_coeff,
1446 protos,
1447 } => {
1448 let proto = self.decode_tracked_yolo_segdet_2way_quantized_proto(
1449 tracker,
1450 timestamp,
1451 outputs,
1452 boxes,
1453 mask_coeff,
1454 protos,
1455 output_boxes,
1456 output_tracks,
1457 )?;
1458 Ok(Some(proto))
1459 }
1460 ModelType::YoloEndToEndSegDet { boxes, protos } => {
1461 let proto = self.decode_tracked_yolo_end_to_end_segdet_quantized_proto(
1462 tracker,
1463 timestamp,
1464 outputs,
1465 boxes,
1466 protos,
1467 output_boxes,
1468 output_tracks,
1469 )?;
1470 Ok(Some(proto))
1471 }
1472 ModelType::YoloSplitEndToEndSegDet {
1473 boxes,
1474 scores,
1475 classes,
1476 mask_coeff,
1477 protos,
1478 } => {
1479 let proto = self.decode_tracked_yolo_split_end_to_end_segdet_quantized_proto(
1480 tracker,
1481 timestamp,
1482 outputs,
1483 boxes,
1484 scores,
1485 classes,
1486 mask_coeff,
1487 protos,
1488 output_boxes,
1489 output_tracks,
1490 )?;
1491 Ok(Some(proto))
1492 }
1493 // Non-seg variants: decode boxes via the non-proto path, then track.
1494 _ => {
1495 let mut masks = Vec::new();
1496 self.decode_quantized(outputs, output_boxes, &mut masks)?;
1497 Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1498 Ok(None)
1499 }
1500 }
1501 }
1502
1503 /// Decodes floating-point model outputs into detection boxes, returning
1504 /// raw `ProtoData` for segmentation models instead of materialized masks.
1505 ///
1506 /// Detections are always decoded into `output_boxes` regardless of model type.
1507 /// Returns `Ok(None)` for detection-only and ModelPack models. Returns
1508 /// `Ok(Some(ProtoData))` for YOLO segmentation models.
1509 pub(crate) fn decode_tracked_float_proto<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
1510 &self,
1511 tracker: &mut TR,
1512 timestamp: u64,
1513 outputs: &[ArrayViewD<T>],
1514 output_boxes: &mut Vec<DetectBox>,
1515 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1516 ) -> Result<Option<ProtoData>, DecoderError>
1517 where
1518 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + crate::yolo::FloatProtoElem,
1519 f32: AsPrimitive<T>,
1520 {
1521 output_boxes.clear();
1522 output_tracks.clear();
1523 match &self.model_type {
1524 ModelType::YoloSegDet { boxes, protos } => {
1525 let proto = self.decode_tracked_yolo_segdet_float_proto(
1526 tracker,
1527 timestamp,
1528 outputs,
1529 boxes,
1530 protos,
1531 output_boxes,
1532 output_tracks,
1533 )?;
1534 Ok(Some(proto))
1535 }
1536 ModelType::YoloSplitSegDet {
1537 boxes,
1538 scores,
1539 mask_coeff,
1540 protos,
1541 } => {
1542 let proto = self.decode_tracked_yolo_split_segdet_float_proto(
1543 tracker,
1544 timestamp,
1545 outputs,
1546 boxes,
1547 scores,
1548 mask_coeff,
1549 protos,
1550 output_boxes,
1551 output_tracks,
1552 )?;
1553 Ok(Some(proto))
1554 }
1555 ModelType::YoloSegDet2Way {
1556 boxes,
1557 mask_coeff,
1558 protos,
1559 } => {
1560 let proto = self.decode_tracked_yolo_segdet_2way_float_proto(
1561 tracker,
1562 timestamp,
1563 outputs,
1564 boxes,
1565 mask_coeff,
1566 protos,
1567 output_boxes,
1568 output_tracks,
1569 )?;
1570 Ok(Some(proto))
1571 }
1572 ModelType::YoloEndToEndSegDet { boxes, protos } => {
1573 let proto = self.decode_tracked_yolo_end_to_end_segdet_float_proto(
1574 tracker,
1575 timestamp,
1576 outputs,
1577 boxes,
1578 protos,
1579 output_boxes,
1580 output_tracks,
1581 )?;
1582 Ok(Some(proto))
1583 }
1584 ModelType::YoloSplitEndToEndSegDet {
1585 boxes,
1586 scores,
1587 classes,
1588 mask_coeff,
1589 protos,
1590 } => {
1591 let proto = self.decode_tracked_yolo_split_end_to_end_segdet_float_proto(
1592 tracker,
1593 timestamp,
1594 outputs,
1595 boxes,
1596 scores,
1597 classes,
1598 mask_coeff,
1599 protos,
1600 output_boxes,
1601 output_tracks,
1602 )?;
1603 Ok(Some(proto))
1604 }
1605 // Non-seg variants: decode boxes via the non-proto path, then track.
1606 _ => {
1607 let mut masks = Vec::new();
1608 self.decode_float(outputs, output_boxes, &mut masks)?;
1609 Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1610 Ok(None)
1611 }
1612 }
1613 }
1614
1615 // ========================================================================
1616 // TensorDyn-based tracked public API
1617 // ========================================================================
1618
1619 /// Decode model outputs with tracking.
1620 ///
1621 /// Accepts `TensorDyn` outputs directly from model inference. Automatically
1622 /// dispatches to quantized or float paths based on the tensor dtype, then
1623 /// updates the tracker with the decoded boxes.
1624 ///
1625 /// # Arguments
1626 ///
1627 /// * `tracker` - The tracker instance to update
1628 /// * `timestamp` - Current frame timestamp
1629 /// * `outputs` - Tensor outputs from model inference
1630 /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
1631 /// * `output_masks` - Destination for decoded segmentation masks (cleared first)
1632 /// * `output_tracks` - Destination for track info (cleared first)
1633 ///
1634 /// # Errors
1635 ///
1636 /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
1637 /// or the outputs don't match the decoder's model configuration.
1638 pub fn decode_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
1639 &self,
1640 tracker: &mut TR,
1641 timestamp: u64,
1642 outputs: &[&edgefirst_tensor::TensorDyn],
1643 output_boxes: &mut Vec<DetectBox>,
1644 output_masks: &mut Vec<Segmentation>,
1645 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1646 ) -> Result<(), DecoderError> {
1647 // Per-scale fast path: route via the basic decode then update the
1648 // tracker. The current implementation keeps the tracker integration simple; per-frame
1649 // decoupling between detection and tracking is preserved.
1650 if self.per_scale.is_some() {
1651 output_tracks.clear();
1652 self.decode(outputs, output_boxes, output_masks)?;
1653 Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1654 return Ok(());
1655 }
1656
1657 let mapped = tensor_bridge::map_tensors(outputs)?;
1658 match &mapped {
1659 tensor_bridge::MappedOutputs::Quantized(maps) => {
1660 let views = tensor_bridge::quantized_views(maps)?;
1661 self.decode_tracked_quantized(
1662 tracker,
1663 timestamp,
1664 &views,
1665 output_boxes,
1666 output_masks,
1667 output_tracks,
1668 )
1669 }
1670 tensor_bridge::MappedOutputs::Float16(maps) => {
1671 let views = tensor_bridge::f16_views(maps)?;
1672 self.decode_tracked_float(
1673 tracker,
1674 timestamp,
1675 &views,
1676 output_boxes,
1677 output_masks,
1678 output_tracks,
1679 )
1680 }
1681 tensor_bridge::MappedOutputs::Float32(maps) => {
1682 let views = tensor_bridge::f32_views(maps)?;
1683 self.decode_tracked_float(
1684 tracker,
1685 timestamp,
1686 &views,
1687 output_boxes,
1688 output_masks,
1689 output_tracks,
1690 )
1691 }
1692 tensor_bridge::MappedOutputs::Float64(maps) => {
1693 let views = tensor_bridge::f64_views(maps)?;
1694 self.decode_tracked_float(
1695 tracker,
1696 timestamp,
1697 &views,
1698 output_boxes,
1699 output_masks,
1700 output_tracks,
1701 )
1702 }
1703 }
1704 }
1705
1706 /// Decode model outputs with tracking, returning raw proto data for
1707 /// segmentation models.
1708 ///
1709 /// Accepts `TensorDyn` outputs directly from model inference.
1710 /// Returns `Ok(None)` for detection-only and ModelPack models.
1711 /// Returns `Ok(Some(ProtoData))` for YOLO segmentation models.
1712 ///
1713 /// # Arguments
1714 ///
1715 /// * `tracker` - The tracker instance to update
1716 /// * `timestamp` - Current frame timestamp
1717 /// * `outputs` - Tensor outputs from model inference
1718 /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
1719 /// * `output_tracks` - Destination for track info (cleared first)
1720 ///
1721 /// # Errors
1722 ///
1723 /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
1724 /// or the outputs don't match the decoder's model configuration.
1725 pub fn decode_proto_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
1726 &self,
1727 tracker: &mut TR,
1728 timestamp: u64,
1729 outputs: &[&edgefirst_tensor::TensorDyn],
1730 output_boxes: &mut Vec<DetectBox>,
1731 output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1732 ) -> Result<Option<ProtoData>, DecoderError> {
1733 // Per-scale fast path: route via the basic decode_proto then
1734 // update the tracker on the resulting boxes.
1735 if self.per_scale.is_some() {
1736 output_tracks.clear();
1737 let proto = self.decode_proto(outputs, output_boxes)?;
1738 Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1739 return Ok(proto);
1740 }
1741
1742 let mapped = tensor_bridge::map_tensors(outputs)?;
1743 match &mapped {
1744 tensor_bridge::MappedOutputs::Quantized(maps) => {
1745 let views = tensor_bridge::quantized_views(maps)?;
1746 self.decode_tracked_quantized_proto(
1747 tracker,
1748 timestamp,
1749 &views,
1750 output_boxes,
1751 output_tracks,
1752 )
1753 }
1754 tensor_bridge::MappedOutputs::Float16(maps) => {
1755 let views = tensor_bridge::f16_views(maps)?;
1756 self.decode_tracked_float_proto(
1757 tracker,
1758 timestamp,
1759 &views,
1760 output_boxes,
1761 output_tracks,
1762 )
1763 }
1764 tensor_bridge::MappedOutputs::Float32(maps) => {
1765 let views = tensor_bridge::f32_views(maps)?;
1766 self.decode_tracked_float_proto(
1767 tracker,
1768 timestamp,
1769 &views,
1770 output_boxes,
1771 output_tracks,
1772 )
1773 }
1774 tensor_bridge::MappedOutputs::Float64(maps) => {
1775 let views = tensor_bridge::f64_views(maps)?;
1776 self.decode_tracked_float_proto(
1777 tracker,
1778 timestamp,
1779 &views,
1780 output_boxes,
1781 output_tracks,
1782 )
1783 }
1784 }
1785 }
1786}