edgefirst_decoder/decoder/mod.rs
1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use ndarray::{ArrayView, ArrayViewD, Dimension};
5use num_traits::{AsPrimitive, Float};
6
7use crate::{DecoderError, DetectBox, ProtoData, Segmentation};
8
9pub mod config;
10pub mod configs;
11
12use configs::ModelType;
13
/// Post-processing decoder that turns raw model output tensors into
/// [`DetectBox`]es and, for segmentation models, [`Segmentation`] masks or
/// raw [`ProtoData`].
///
/// Built with [`DecoderBuilder`]. [`Decoder::decode`] and
/// [`Decoder::decode_proto`] pick one of three paths, checked in order:
/// 1. `per_scale` — the per-scale fast path, when present;
/// 2. `decode_program` — the schema-v2 merge program, when present;
/// 3. the legacy path, which dispatches on `model_type`.
#[derive(Debug)]
pub struct Decoder {
    /// Output layout the legacy path dispatches on (see `decode_quantized`,
    /// `decode_float` and their `_proto` counterparts).
    model_type: ModelType,
    /// IoU threshold forwarded to the NMS / post-processing stage.
    pub iou_threshold: f32,
    /// Score threshold used by the pre-NMS score filter (see
    /// `pre_nms_top_k`).
    pub score_threshold: f32,
    /// NMS mode (always a concrete variant after build — `Nms::Auto` is
    /// resolved during `DecoderBuilder::build()` and never stored here):
    /// - `Some(ClassAgnostic)` — class-agnostic NMS
    /// - `Some(ClassAware)` — class-aware NMS
    /// - `None` — NMS bypassed (end-to-end models)
    pub nms: Option<configs::Nms>,
    /// Maximum number of candidate boxes fed into NMS after score filtering.
    /// Reduces O(N²) NMS cost when many low-confidence proposals pass the
    /// threshold (common during COCO mAP evaluation with threshold ≈ 0.001).
    /// Candidates are ranked by score; only the top `pre_nms_top_k` proceed
    /// to NMS. Default: 300. Ignored when `nms` is `None`.
    ///
    /// # ⚠️ Validation vs Deployment
    ///
    /// The default of 300 is tuned for **deployment** (score threshold ≥ 0.25)
    /// where few anchors pass the score filter, making top-K a no-op in
    /// practice while bounding worst-case NMS cost.
    ///
    /// For **mAP evaluation** (score threshold ≈ 0.001), most of the 8 400
    /// YOLO anchors pass the score filter. At `pre_nms_top_k = 300`, roughly
    /// 74 % of candidates that would survive NMS are discarded *before* NMS
    /// runs, causing **~9 pp box mAP loss** — a measurement artifact, not a
    /// model quality issue.
    ///
    /// | Use case | `pre_nms_top_k` | `score_threshold` |
    /// |----------|----------------:|------------------:|
    /// | Deployment | 300 (default) | ≥ 0.25 |
    /// | COCO mAP evaluation | 8 400 (all anchors) | 0.001 |
    /// | Unbounded | 0 (no limit) | any |
    ///
    /// Post-processing latency scales with the number of candidates entering
    /// NMS. At deployment thresholds the candidate count is already small, so
    /// raising `pre_nms_top_k` has negligible cost. At validation thresholds
    /// the increase is measurable but necessary for correct recall.
    pub pre_nms_top_k: usize,
    /// Maximum number of detections returned after NMS. Matches the
    /// Ultralytics `max_det` parameter. Default: 300.
    ///
    /// This bound applies uniformly across all segmentation and detection
    /// decode paths reached via [`Decoder::decode`] / [`Decoder::decode_proto`].
    /// The output `Vec`'s capacity is only an allocation hint; the post-NMS
    /// detection count is bounded solely by `max_det` (EDGEAI-1302).
    pub max_det: usize,
    /// Whether decoded boxes are in normalized [0,1] coordinates.
    /// - `Some(true)`: Coordinates in [0,1] range
    /// - `Some(false)`: Pixel coordinates
    /// - `None`: Unknown, caller must infer (e.g., check if any coordinate >
    ///   1.0)
    normalized: Option<bool>,
    /// Model input spatial dimensions `(width, height)`, captured from
    /// the schema's `input.shape` / `input.dshape` at builder time.
    /// Required to honour `normalized: false`: pixel-space box coords
    /// emitted by the model are divided by these dimensions before NMS
    /// so the post-NMS bbox is in `[0, 1]`. `None` when no schema input
    /// spec is available — the legacy >2.0 reject in `protobox` then
    /// preserves the previous safety net (EDGEAI-1303).
    ///
    /// NOTE(review): the `Decoder::input_dims()` docs describe this division
    /// as applied to *post*-NMS boxes; confirm which stage is accurate.
    input_dims: Option<(usize, usize)>,
    /// Schema v2 merge program. Present when the decoder was built from
    /// a [`crate::schema::SchemaV2`] whose logical outputs carry
    /// physical children. Absent for flat configurations (v1 and
    /// flat-v2).
    pub(crate) decode_program: Option<merge::DecodeProgram>,
    /// Per-scale fast path. Constructed at build time from a schema-v2
    /// document with per-scale children. Wrapped in `Mutex` because
    /// `Decoder::decode_proto` and `Decoder::decode` are `&self` but
    /// the per-scale buffers are mutated per-frame.
    pub(crate) per_scale: Option<std::sync::Mutex<crate::per_scale::PerScaleDecoder>>,
}
87
88impl PartialEq for Decoder {
89 fn eq(&self, other: &Self) -> bool {
90 // DecodeProgram and PerScaleDecoder have non-comparable embedded
91 // data; compare by the config-derived fields only.
92 self.model_type == other.model_type
93 && self.iou_threshold == other.iou_threshold
94 && self.score_threshold == other.score_threshold
95 && self.nms == other.nms
96 && self.pre_nms_top_k == other.pre_nms_top_k
97 && self.max_det == other.max_det
98 && self.normalized == other.normalized
99 && self.input_dims == other.input_dims
100 && self.decode_program.is_some() == other.decode_program.is_some()
101 && self.per_scale.is_some() == other.per_scale.is_some()
102 }
103}
104
105impl Clone for Decoder {
106 /// Cloning a `Decoder` preserves the legacy decode path
107 /// (`decode_program`) but drops the per-scale fast path:
108 /// `PerScaleDecoder` owns mutable per-frame scratch buffers and is
109 /// not `Clone`. Decoders built from a per-scale schema should be
110 /// rebuilt via [`DecoderBuilder`] rather than cloned to preserve the
111 /// fast path; cloning is intended for tests and rare configs.
112 fn clone(&self) -> Self {
113 Self {
114 model_type: self.model_type.clone(),
115 iou_threshold: self.iou_threshold,
116 score_threshold: self.score_threshold,
117 nms: self.nms,
118 pre_nms_top_k: self.pre_nms_top_k,
119 max_det: self.max_det,
120 normalized: self.normalized,
121 input_dims: self.input_dims,
122 decode_program: self.decode_program.clone(),
123 per_scale: None,
124 }
125 }
126}
127
/// Borrowed, dynamically-dimensioned view over an integer (quantized)
/// tensor, tagged by its element type.
///
/// Built from an `ArrayView` of any supported integer type through the
/// `From` impls that follow. It is consumed with the `with_quantized!`
/// macro, which expands its body once per variant.
#[derive(Debug)]
pub(crate) enum ArrayViewDQuantized<'a> {
    /// Unsigned 8-bit elements.
    UInt8(ArrayViewD<'a, u8>),
    /// Signed 8-bit elements.
    Int8(ArrayViewD<'a, i8>),
    /// Unsigned 16-bit elements.
    UInt16(ArrayViewD<'a, u16>),
    /// Signed 16-bit elements.
    Int16(ArrayViewD<'a, i16>),
    /// Unsigned 32-bit elements.
    UInt32(ArrayViewD<'a, u32>),
    /// Signed 32-bit elements.
    Int32(ArrayViewD<'a, i32>),
}
137
138impl<'a, D> From<ArrayView<'a, u8, D>> for ArrayViewDQuantized<'a>
139where
140 D: Dimension,
141{
142 fn from(arr: ArrayView<'a, u8, D>) -> Self {
143 Self::UInt8(arr.into_dyn())
144 }
145}
146
147impl<'a, D> From<ArrayView<'a, i8, D>> for ArrayViewDQuantized<'a>
148where
149 D: Dimension,
150{
151 fn from(arr: ArrayView<'a, i8, D>) -> Self {
152 Self::Int8(arr.into_dyn())
153 }
154}
155
156impl<'a, D> From<ArrayView<'a, u16, D>> for ArrayViewDQuantized<'a>
157where
158 D: Dimension,
159{
160 fn from(arr: ArrayView<'a, u16, D>) -> Self {
161 Self::UInt16(arr.into_dyn())
162 }
163}
164
165impl<'a, D> From<ArrayView<'a, i16, D>> for ArrayViewDQuantized<'a>
166where
167 D: Dimension,
168{
169 fn from(arr: ArrayView<'a, i16, D>) -> Self {
170 Self::Int16(arr.into_dyn())
171 }
172}
173
174impl<'a, D> From<ArrayView<'a, u32, D>> for ArrayViewDQuantized<'a>
175where
176 D: Dimension,
177{
178 fn from(arr: ArrayView<'a, u32, D>) -> Self {
179 Self::UInt32(arr.into_dyn())
180 }
181}
182
183impl<'a, D> From<ArrayView<'a, i32, D>> for ArrayViewDQuantized<'a>
184where
185 D: Dimension,
186{
187 fn from(arr: ArrayView<'a, i32, D>) -> Self {
188 Self::Int32(arr.into_dyn())
189 }
190}
191
192impl<'a> ArrayViewDQuantized<'a> {
193 /// Returns the shape of the underlying array.
194 pub(crate) fn shape(&self) -> &[usize] {
195 match self {
196 ArrayViewDQuantized::UInt8(a) => a.shape(),
197 ArrayViewDQuantized::Int8(a) => a.shape(),
198 ArrayViewDQuantized::UInt16(a) => a.shape(),
199 ArrayViewDQuantized::Int16(a) => a.shape(),
200 ArrayViewDQuantized::UInt32(a) => a.shape(),
201 ArrayViewDQuantized::Int32(a) => a.shape(),
202 }
203 }
204}
205
/// Expands `$body` once per `ArrayViewDQuantized` variant, with `$var` bound
/// to that variant's typed `ArrayViewD`. Generic code is written once and
/// instantiated for every supported integer element type.
///
/// WARNING: Do NOT nest `with_quantized!` calls. Each level multiplies
/// monomorphized code paths by 6 (one per integer variant), so nesting
/// N levels deep produces 6^N instantiations.
///
/// Instead, dequantize each tensor sequentially with `dequant_3d!`/`dequant_4d!`
/// (6*N paths) or split into independent phases that each nest at most 2 levels.
macro_rules! with_quantized {
    ($x:expr, $var:ident, $body:expr) => {
        match $x {
            ArrayViewDQuantized::UInt8($var) => $body,
            ArrayViewDQuantized::Int8($var) => $body,
            ArrayViewDQuantized::UInt16($var) => $body,
            ArrayViewDQuantized::Int16($var) => $body,
            ArrayViewDQuantized::UInt32($var) => $body,
            ArrayViewDQuantized::Int32($var) => $body,
        }
    };
}
242
243mod builder;
244mod dfl;
245mod helpers;
246mod merge;
247mod per_scale_bridge;
248mod postprocess;
249mod tensor_bridge;
250mod tests;
251
252pub use builder::DecoderBuilder;
253pub use config::{ConfigOutput, ConfigOutputRef, ConfigOutputs};
254
255impl Decoder {
256 /// Static label identifying which dispatch path `decode` / `decode_proto`
257 /// will take, used as a tracing-span attribute. Lets profiling tools
258 /// distinguish `per_scale` (the fast path), `decode_program` (schema-v2
259 /// merge), and `legacy` (config-driven) without requiring callers to
260 /// inspect the model.
261 fn decode_path_label(&self) -> &'static str {
262 if self.per_scale.is_some() {
263 "per_scale"
264 } else if self.decode_program.is_some() {
265 "decode_program"
266 } else {
267 "legacy"
268 }
269 }
270
271 /// This function returns the parsed model type of the decoder.
272 ///
273 /// # Examples
274 ///
275 /// ```rust
276 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::ModelType};
277 /// # fn main() -> DecoderResult<()> {
278 /// # let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
279 /// let decoder = DecoderBuilder::default()
280 /// .with_config_yaml_str(config_yaml)
281 /// .build()?;
282 /// assert!(matches!(
283 /// decoder.model_type(),
284 /// ModelType::ModelPackDetSplit { .. }
285 /// ));
286 /// # Ok(())
287 /// # }
288 /// ```
289 pub fn model_type(&self) -> &ModelType {
290 &self.model_type
291 }
292
293 /// Returns the box coordinate format if known from the model config.
294 ///
295 /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
296 /// - `Some(false)`: Boxes are in pixel coordinates relative to model input
297 /// - `None`: Unknown, caller must infer (e.g., check if any coordinate >
298 /// 1.0)
299 ///
300 /// This is determined by the model config's `normalized` field, not the NMS
301 /// mode. When coordinates are in pixels or unknown, the caller may need
302 /// to normalize using the model input dimensions.
303 ///
304 /// # Examples
305 ///
306 /// ```rust
307 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
308 /// # fn main() -> DecoderResult<()> {
309 /// # let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
310 /// let decoder = DecoderBuilder::default()
311 /// .with_config_yaml_str(config_yaml)
312 /// .build()?;
313 /// // Config doesn't specify normalized, so it's None
314 /// assert!(decoder.normalized_boxes().is_none());
315 /// # Ok(())
316 /// # }
317 /// ```
318 pub fn normalized_boxes(&self) -> Option<bool> {
319 self.normalized
320 }
321
322 /// Model input dimensions `(width, height)` captured from the
323 /// schema's `input.shape` / `input.dshape`, or `None` when the
324 /// schema did not declare an input spec (e.g. flat YAML configs
325 /// or `DecoderBuilder::add_output(...)` programmatic builds).
326 ///
327 /// Used together with [`normalized_boxes`](Self::normalized_boxes):
328 /// when the decoder reports `normalized_boxes() == Some(false)` and
329 /// `input_dims()` is `Some((w, h))`, the decoder divides post-NMS
330 /// bbox coordinates by `(w, h)` so they enter the canonical `[0, 1]`
331 /// range before mask cropping (EDGEAI-1303). When `input_dims()` is
332 /// `None`, the decoder cannot perform the division and the existing
333 /// `protobox` `> 2.0` reject acts as a safety net.
334 ///
335 /// # Examples
336 ///
337 /// ```rust
338 /// # use edgefirst_decoder::{schema::SchemaV2, DecoderBuilder, DecoderResult};
339 /// # fn main() -> DecoderResult<()> {
340 /// let json = r#"{
341 /// "schema_version": 2,
342 /// "nms": "class_agnostic",
343 /// "input": {
344 /// "shape": [1, 640, 640, 3],
345 /// "dshape": [{"batch": 1}, {"height": 640}, {"width": 640}, {"num_features": 3}]
346 /// },
347 /// "outputs": [{
348 /// "name": "out", "type": "detection",
349 /// "shape": [1, 38, 256],
350 /// "dshape": [{"batch": 1}, {"num_features": 38}, {"num_boxes": 256}],
351 /// "decoder": "ultralytics", "encoding": "direct", "normalized": false
352 /// }]
353 /// }"#;
354 /// let schema: SchemaV2 = serde_json::from_str(json).unwrap();
355 /// let decoder = DecoderBuilder::default().with_schema(schema).build()?;
356 /// assert_eq!(decoder.input_dims(), Some((640, 640)));
357 /// # Ok(())
358 /// # }
359 /// ```
360 pub fn input_dims(&self) -> Option<(usize, usize)> {
361 self.input_dims
362 }
363
364 /// Decode quantized model outputs into detection boxes and segmentation
365 /// masks. The quantized outputs can be of u8, i8, u16, i16, u32, or i32
366 /// types. Clears the provided output vectors before populating them.
367 pub(crate) fn decode_quantized(
368 &self,
369 outputs: &[ArrayViewDQuantized],
370 output_boxes: &mut Vec<DetectBox>,
371 output_masks: &mut Vec<Segmentation>,
372 ) -> Result<(), DecoderError> {
373 output_boxes.clear();
374 output_masks.clear();
375 match &self.model_type {
376 ModelType::ModelPackSegDet {
377 boxes,
378 scores,
379 segmentation,
380 } => {
381 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
382 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
383 }
384 ModelType::ModelPackSegDetSplit {
385 detection,
386 segmentation,
387 } => {
388 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
389 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
390 }
391 ModelType::ModelPackDet { boxes, scores } => {
392 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)
393 }
394 ModelType::ModelPackDetSplit { detection } => {
395 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)
396 }
397 ModelType::ModelPackSeg { segmentation } => {
398 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
399 }
400 ModelType::YoloDet { boxes } => {
401 self.decode_yolo_det_quantized(outputs, boxes, output_boxes)
402 }
403 ModelType::YoloSegDet { boxes, protos } => self.decode_yolo_segdet_quantized(
404 outputs,
405 boxes,
406 protos,
407 output_boxes,
408 output_masks,
409 ),
410 ModelType::YoloSplitDet { boxes, scores } => {
411 self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)
412 }
413 ModelType::YoloSplitSegDet {
414 boxes,
415 scores,
416 mask_coeff,
417 protos,
418 } => self.decode_yolo_split_segdet_quantized(
419 outputs,
420 boxes,
421 scores,
422 mask_coeff,
423 protos,
424 output_boxes,
425 output_masks,
426 ),
427 ModelType::YoloSegDet2Way {
428 boxes,
429 mask_coeff,
430 protos,
431 } => self.decode_yolo_segdet_2way_quantized(
432 outputs,
433 boxes,
434 mask_coeff,
435 protos,
436 output_boxes,
437 output_masks,
438 ),
439 ModelType::YoloEndToEndDet { boxes } => {
440 self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)
441 }
442 ModelType::YoloEndToEndSegDet { boxes, protos } => self
443 .decode_yolo_end_to_end_segdet_quantized(
444 outputs,
445 boxes,
446 protos,
447 output_boxes,
448 output_masks,
449 ),
450 ModelType::YoloSplitEndToEndDet {
451 boxes,
452 scores,
453 classes,
454 } => self.decode_yolo_split_end_to_end_det_quantized(
455 outputs,
456 boxes,
457 scores,
458 classes,
459 output_boxes,
460 ),
461 ModelType::YoloSplitEndToEndSegDet {
462 boxes,
463 scores,
464 classes,
465 mask_coeff,
466 protos,
467 } => self.decode_yolo_split_end_to_end_segdet_quantized(
468 outputs,
469 boxes,
470 scores,
471 classes,
472 mask_coeff,
473 protos,
474 output_boxes,
475 output_masks,
476 ),
477 ModelType::PerScale => Err(DecoderError::Internal(
478 "per-scale path must be intercepted before ModelType dispatch".into(),
479 )),
480 }
481 }
482
    /// Legacy-path dispatcher for floating-point model outputs.
    ///
    /// Clears `output_boxes` and `output_masks`, then routes `outputs` to
    /// the float decoder that matches `self.model_type`. ModelPack seg+det
    /// variants run the detection decoder before the segmentation decoder.
    /// The first error short-circuits, so later stages do not run.
    ///
    /// `decode` also reaches this function for the schema-v2 merge path,
    /// passing f32 views of the merged tensors.
    ///
    /// # Errors
    ///
    /// Propagates errors from the selected decoder. Returns
    /// `DecoderError::Internal` for `ModelType::PerScale`.
    pub(crate) fn decode_float<T>(
        &self,
        outputs: &[ArrayViewD<T>],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
    ) -> Result<(), DecoderError>
    where
        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
        f32: AsPrimitive<T>,
    {
        output_boxes.clear();
        output_masks.clear();
        match &self.model_type {
            // Detection first, then segmentation; `?` skips the mask stage
            // if box decoding fails.
            ModelType::ModelPackSegDet {
                boxes,
                scores,
                segmentation,
            } => {
                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
            }
            ModelType::ModelPackSegDetSplit {
                detection,
                segmentation,
            } => {
                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
            }
            ModelType::ModelPackDet { boxes, scores } => {
                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
            }
            ModelType::ModelPackDetSplit { detection } => {
                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
            }
            ModelType::ModelPackSeg { segmentation } => {
                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
            }
            ModelType::YoloDet { boxes } => {
                self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
            }
            ModelType::YoloSegDet { boxes, protos } => {
                self.decode_yolo_segdet_float(outputs, boxes, protos, output_boxes, output_masks)?;
            }
            ModelType::YoloSplitDet { boxes, scores } => {
                self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                self.decode_yolo_split_segdet_float(
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                )?;
            }
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => {
                self.decode_yolo_segdet_2way_float(
                    outputs,
                    boxes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                )?;
            }
            ModelType::YoloEndToEndDet { boxes } => {
                self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                self.decode_yolo_end_to_end_segdet_float(
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                )?;
            }
            ModelType::YoloSplitEndToEndDet {
                boxes,
                scores,
                classes,
            } => {
                self.decode_yolo_split_end_to_end_det_float(
                    outputs,
                    boxes,
                    scores,
                    classes,
                    output_boxes,
                )?;
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                self.decode_yolo_split_end_to_end_segdet_float(
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                )?;
            }
            // Per-scale models must be routed through the `per_scale` fast
            // path in `decode` / `decode_proto`; reaching here is an
            // internal routing error, not a model error.
            ModelType::PerScale => {
                return Err(DecoderError::Internal(
                    "per-scale path must be intercepted before ModelType dispatch".into(),
                ));
            }
        }
        Ok(())
    }
613
614 /// Decodes quantized model outputs into detection boxes, returning raw
615 /// `ProtoData` for segmentation models instead of materialized masks.
616 ///
617 /// Returns `Ok(None)` for detection-only and ModelPack models (detections
618 /// are still decoded into `output_boxes`). Returns `Ok(Some(ProtoData))`
619 /// for YOLO segmentation models.
620 pub(crate) fn decode_quantized_proto(
621 &self,
622 outputs: &[ArrayViewDQuantized],
623 output_boxes: &mut Vec<DetectBox>,
624 ) -> Result<Option<ProtoData>, DecoderError> {
625 output_boxes.clear();
626 match &self.model_type {
627 // Detection-only variants: decode boxes, return None for proto data.
628 ModelType::ModelPackDet { boxes, scores } => {
629 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
630 Ok(None)
631 }
632 ModelType::ModelPackDetSplit { detection } => {
633 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
634 Ok(None)
635 }
636 ModelType::YoloDet { boxes } => {
637 self.decode_yolo_det_quantized(outputs, boxes, output_boxes)?;
638 Ok(None)
639 }
640 ModelType::YoloSplitDet { boxes, scores } => {
641 self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)?;
642 Ok(None)
643 }
644 ModelType::YoloEndToEndDet { boxes } => {
645 self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)?;
646 Ok(None)
647 }
648 ModelType::YoloSplitEndToEndDet {
649 boxes,
650 scores,
651 classes,
652 } => {
653 self.decode_yolo_split_end_to_end_det_quantized(
654 outputs,
655 boxes,
656 scores,
657 classes,
658 output_boxes,
659 )?;
660 Ok(None)
661 }
662 // ModelPack seg/segdet variants have no YOLO proto data.
663 ModelType::ModelPackSegDet { boxes, scores, .. } => {
664 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
665 Ok(None)
666 }
667 ModelType::ModelPackSegDetSplit { detection, .. } => {
668 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
669 Ok(None)
670 }
671 ModelType::ModelPackSeg { .. } => Ok(None),
672
673 ModelType::YoloSegDet { boxes, protos } => {
674 let proto =
675 self.decode_yolo_segdet_quantized_proto(outputs, boxes, protos, output_boxes)?;
676 Ok(Some(proto))
677 }
678 ModelType::YoloSplitSegDet {
679 boxes,
680 scores,
681 mask_coeff,
682 protos,
683 } => {
684 let proto = self.decode_yolo_split_segdet_quantized_proto(
685 outputs,
686 boxes,
687 scores,
688 mask_coeff,
689 protos,
690 output_boxes,
691 )?;
692 Ok(Some(proto))
693 }
694 ModelType::YoloSegDet2Way {
695 boxes,
696 mask_coeff,
697 protos,
698 } => {
699 let proto = self.decode_yolo_segdet_2way_quantized_proto(
700 outputs,
701 boxes,
702 mask_coeff,
703 protos,
704 output_boxes,
705 )?;
706 Ok(Some(proto))
707 }
708 ModelType::YoloEndToEndSegDet { boxes, protos } => {
709 let proto = self.decode_yolo_end_to_end_segdet_quantized_proto(
710 outputs,
711 boxes,
712 protos,
713 output_boxes,
714 )?;
715 Ok(Some(proto))
716 }
717 ModelType::YoloSplitEndToEndSegDet {
718 boxes,
719 scores,
720 classes,
721 mask_coeff,
722 protos,
723 } => {
724 let proto = self.decode_yolo_split_end_to_end_segdet_quantized_proto(
725 outputs,
726 boxes,
727 scores,
728 classes,
729 mask_coeff,
730 protos,
731 output_boxes,
732 )?;
733 Ok(Some(proto))
734 }
735 ModelType::PerScale => Err(DecoderError::Internal(
736 "per-scale path must be intercepted before ModelType dispatch".into(),
737 )),
738 }
739 }
740
741 /// Decodes floating-point model outputs into detection boxes, returning
742 /// raw `ProtoData` for segmentation models instead of materialized masks.
743 ///
744 /// Returns `Ok(None)` for detection-only and ModelPack models (detections
745 /// are still decoded into `output_boxes`). Returns `Ok(Some(ProtoData))`
746 /// for YOLO segmentation models.
747 pub(crate) fn decode_float_proto<T>(
748 &self,
749 outputs: &[ArrayViewD<T>],
750 output_boxes: &mut Vec<DetectBox>,
751 ) -> Result<Option<ProtoData>, DecoderError>
752 where
753 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + crate::yolo::FloatProtoElem,
754 f32: AsPrimitive<T>,
755 {
756 output_boxes.clear();
757 match &self.model_type {
758 // Detection-only variants: decode boxes, return None for proto data.
759 ModelType::ModelPackDet { boxes, scores } => {
760 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
761 Ok(None)
762 }
763 ModelType::ModelPackDetSplit { detection } => {
764 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
765 Ok(None)
766 }
767 ModelType::YoloDet { boxes } => {
768 self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
769 Ok(None)
770 }
771 ModelType::YoloSplitDet { boxes, scores } => {
772 self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
773 Ok(None)
774 }
775 ModelType::YoloEndToEndDet { boxes } => {
776 self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
777 Ok(None)
778 }
779 ModelType::YoloSplitEndToEndDet {
780 boxes,
781 scores,
782 classes,
783 } => {
784 self.decode_yolo_split_end_to_end_det_float(
785 outputs,
786 boxes,
787 scores,
788 classes,
789 output_boxes,
790 )?;
791 Ok(None)
792 }
793 // ModelPack seg/segdet variants have no YOLO proto data.
794 ModelType::ModelPackSegDet { boxes, scores, .. } => {
795 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
796 Ok(None)
797 }
798 ModelType::ModelPackSegDetSplit { detection, .. } => {
799 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
800 Ok(None)
801 }
802 ModelType::ModelPackSeg { .. } => Ok(None),
803
804 ModelType::YoloSegDet { boxes, protos } => {
805 let proto =
806 self.decode_yolo_segdet_float_proto(outputs, boxes, protos, output_boxes)?;
807 Ok(Some(proto))
808 }
809 ModelType::YoloSplitSegDet {
810 boxes,
811 scores,
812 mask_coeff,
813 protos,
814 } => {
815 let proto = self.decode_yolo_split_segdet_float_proto(
816 outputs,
817 boxes,
818 scores,
819 mask_coeff,
820 protos,
821 output_boxes,
822 )?;
823 Ok(Some(proto))
824 }
825 ModelType::YoloSegDet2Way {
826 boxes,
827 mask_coeff,
828 protos,
829 } => {
830 let proto = self.decode_yolo_segdet_2way_float_proto(
831 outputs,
832 boxes,
833 mask_coeff,
834 protos,
835 output_boxes,
836 )?;
837 Ok(Some(proto))
838 }
839 ModelType::YoloEndToEndSegDet { boxes, protos } => {
840 let proto = self.decode_yolo_end_to_end_segdet_float_proto(
841 outputs,
842 boxes,
843 protos,
844 output_boxes,
845 )?;
846 Ok(Some(proto))
847 }
848 ModelType::YoloSplitEndToEndSegDet {
849 boxes,
850 scores,
851 classes,
852 mask_coeff,
853 protos,
854 } => {
855 let proto = self.decode_yolo_split_end_to_end_segdet_float_proto(
856 outputs,
857 boxes,
858 scores,
859 classes,
860 mask_coeff,
861 protos,
862 output_boxes,
863 )?;
864 Ok(Some(proto))
865 }
866 ModelType::PerScale => Err(DecoderError::Internal(
867 "per-scale path must be intercepted before ModelType dispatch".into(),
868 )),
869 }
870 }
871
872 // ========================================================================
873 // TensorDyn-based public API
874 // ========================================================================
875
876 /// Decode model outputs into detection boxes and segmentation masks.
877 ///
878 /// This is the primary decode API. Accepts `TensorDyn` outputs directly
879 /// from model inference. Automatically dispatches to quantized or float
880 /// paths based on the tensor dtype.
881 ///
882 /// # Arguments
883 ///
884 /// * `outputs` - Tensor outputs from model inference
885 /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
886 /// * `output_masks` - Destination for decoded segmentation masks (cleared first)
887 ///
888 /// # `output_boxes` / `output_masks` capacity
889 ///
890 /// The capacity of the supplied `Vec`s is **only** an allocation hint —
891 /// it is **not** a cap on the number of detections returned. The
892 /// post-NMS detection count is bounded by [`Decoder::max_det`] (set
893 /// via [`DecoderBuilder::with_max_det`], default `300`). Passing
894 /// `Vec::new()` (capacity 0) returns up to `max_det` detections;
895 /// pre-allocating with [`Vec::with_capacity`] only avoids the
896 /// reallocation when the decoder grows the buffer.
897 ///
898 /// # Errors
899 ///
900 /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
901 /// or the outputs don't match the decoder's model configuration.
902 pub fn decode(
903 &self,
904 outputs: &[&edgefirst_tensor::TensorDyn],
905 output_boxes: &mut Vec<DetectBox>,
906 output_masks: &mut Vec<Segmentation>,
907 ) -> Result<(), DecoderError> {
908 let path = self.decode_path_label();
909 let _span = tracing::trace_span!("Decoder::decode", path = path, n_outputs = outputs.len())
910 .entered();
911 // Per-scale fast path — selected at builder time when the schema
912 // declares per-scale children with DFL or LTRB encoding.
913 if let Some(per_scale_mutex) = &self.per_scale {
914 let mut ps = per_scale_mutex
915 .lock()
916 .map_err(|e| DecoderError::Internal(format!("per_scale mutex poisoned: {e}")))?;
917 let decoded = ps.run(outputs)?;
918 return per_scale_bridge::per_scale_to_masks(
919 &decoded,
920 output_boxes,
921 output_masks,
922 self.iou_threshold,
923 self.score_threshold,
924 self.nms,
925 self.pre_nms_top_k,
926 self.max_det,
927 self.normalized,
928 self.input_dims,
929 );
930 }
931
932 // Schema v2 merge path: dequantize physical children into
933 // logical float32 tensors, then feed through the float dispatch.
934 if let Some(program) = &self.decode_program {
935 let merged = program.execute(outputs)?;
936 let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
937 return self.decode_float(&views, output_boxes, output_masks);
938 }
939
940 let mapped = tensor_bridge::map_tensors(outputs)?;
941 match &mapped {
942 tensor_bridge::MappedOutputs::Quantized(maps) => {
943 let views = tensor_bridge::quantized_views(maps)?;
944 self.decode_quantized(&views, output_boxes, output_masks)
945 }
946 tensor_bridge::MappedOutputs::Float16(maps) => {
947 let views = tensor_bridge::f16_views(maps)?;
948 self.decode_float(&views, output_boxes, output_masks)
949 }
950 tensor_bridge::MappedOutputs::Float32(maps) => {
951 let views = tensor_bridge::f32_views(maps)?;
952 self.decode_float(&views, output_boxes, output_masks)
953 }
954 tensor_bridge::MappedOutputs::Float64(maps) => {
955 let views = tensor_bridge::f64_views(maps)?;
956 self.decode_float(&views, output_boxes, output_masks)
957 }
958 }
959 }
960
961 /// Decode model outputs into detection boxes, returning raw proto data
962 /// for segmentation models instead of materialized masks.
963 ///
964 /// Accepts `TensorDyn` outputs directly from model inference.
965 /// Detections are always decoded into `output_boxes` regardless of model type.
966 /// Returns `Ok(None)` for detection-only and ModelPack models.
967 /// Returns `Ok(Some(ProtoData))` for YOLO segmentation models.
968 ///
969 /// # Arguments
970 ///
971 /// * `outputs` - Tensor outputs from model inference
972 /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
973 ///
974 /// # `output_boxes` capacity
975 ///
976 /// The capacity of `output_boxes` is **only** an allocation hint — it
977 /// is **not** a cap on the number of detections returned. The
978 /// post-NMS detection count is bounded by [`Decoder::max_det`] (set
979 /// via [`DecoderBuilder::with_max_det`], default `300`). Passing
980 /// `Vec::new()` (capacity 0) returns up to `max_det` detections.
981 ///
982 /// # Errors
983 ///
984 /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
985 /// or the outputs don't match the decoder's model configuration.
986 pub fn decode_proto(
987 &self,
988 outputs: &[&edgefirst_tensor::TensorDyn],
989 output_boxes: &mut Vec<DetectBox>,
990 ) -> Result<Option<ProtoData>, DecoderError> {
991 let path = self.decode_path_label();
992 let _span = tracing::trace_span!(
993 "Decoder::decode_proto",
994 path = path,
995 n_outputs = outputs.len()
996 )
997 .entered();
998 // Per-scale fast path — selected at builder time when the schema
999 // declares per-scale children with DFL or LTRB encoding.
1000 if let Some(per_scale_mutex) = &self.per_scale {
1001 let mut ps = per_scale_mutex
1002 .lock()
1003 .map_err(|e| DecoderError::Internal(format!("per_scale mutex poisoned: {e}")))?;
1004 let decoded = ps.run(outputs)?;
1005 return per_scale_bridge::per_scale_to_proto_data(
1006 &decoded,
1007 output_boxes,
1008 self.iou_threshold,
1009 self.score_threshold,
1010 self.nms,
1011 self.pre_nms_top_k,
1012 self.max_det,
1013 self.normalized,
1014 self.input_dims,
1015 );
1016 }
1017
1018 // Schema v2 merge path: dequantize physical children into
1019 // logical float32 tensors, then feed through the float dispatch.
1020 if let Some(program) = &self.decode_program {
1021 let merged = program.execute(outputs)?;
1022 let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
1023 return self.decode_float_proto(&views, output_boxes);
1024 }
1025
1026 let mapped = tensor_bridge::map_tensors(outputs)?;
1027 let result = match &mapped {
1028 tensor_bridge::MappedOutputs::Quantized(maps) => {
1029 let views = tensor_bridge::quantized_views(maps)?;
1030 self.decode_quantized_proto(&views, output_boxes)
1031 }
1032 tensor_bridge::MappedOutputs::Float16(maps) => {
1033 let views = tensor_bridge::f16_views(maps)?;
1034 self.decode_float_proto(&views, output_boxes)
1035 }
1036 tensor_bridge::MappedOutputs::Float32(maps) => {
1037 let views = tensor_bridge::f32_views(maps)?;
1038 self.decode_float_proto(&views, output_boxes)
1039 }
1040 tensor_bridge::MappedOutputs::Float64(maps) => {
1041 let views = tensor_bridge::f64_views(maps)?;
1042 self.decode_float_proto(&views, output_boxes)
1043 }
1044 };
1045 result
1046 }
1047
1048 /// Run the per-scale pipeline and return pre-NMS buffers as owned f32.
1049 ///
1050 /// Test-only entry point used by the parity-fixture tests to compare
1051 /// HAL stage output against the NumPy reference's stage output
1052 /// without NMS ordering noise. Returns an error if the decoder
1053 /// isn't configured for per-scale decoding.
1054 #[doc(hidden)]
1055 pub fn _testing_run_per_scale_pre_nms(
1056 &self,
1057 outputs: &[&edgefirst_tensor::TensorDyn],
1058 ) -> Result<crate::per_scale::PreNmsCapture, crate::error::DecoderError> {
1059 let mutex = self.per_scale.as_ref().ok_or_else(|| {
1060 crate::error::DecoderError::Internal("decoder not configured for per-scale".into())
1061 })?;
1062 let mut ps = mutex.lock().map_err(|e| {
1063 crate::error::DecoderError::Internal(format!("per_scale mutex poisoned: {e}"))
1064 })?;
1065 // Drop the borrowed view immediately so we can reborrow buffers below.
1066 {
1067 ps.run(outputs)?;
1068 }
1069 let total_anchors = ps.plan.total_anchors;
1070 let num_classes = ps.plan.num_classes;
1071 let num_mc = ps.plan.num_mask_coefs;
1072 Ok(ps
1073 .buffers
1074 .snapshot_owned_f32(total_anchors, num_classes, num_mc))
1075 }
1076}
1077
1078#[cfg(feature = "tracker")]
1079pub use edgefirst_tracker::TrackInfo;
1080
1081#[cfg(feature = "tracker")]
1082pub use edgefirst_tracker::Tracker;
1083
#[cfg(feature = "tracker")]
impl Decoder {
    /// Decode quantized model outputs into detection boxes and segmentation
    /// masks, then update `tracker` with the result.
    ///
    /// `output_boxes`, `output_masks` and `output_tracks` are cleared before
    /// being populated.
    ///
    /// YOLO segmentation variants are routed to dedicated tracked decoders.
    /// Every other model type is decoded with `decode_quantized`, and the
    /// resulting boxes are then fed to the tracker.
    ///
    /// # Errors
    ///
    /// Propagates any `DecoderError` from the underlying decoder (for example
    /// when the outputs don't match the decoder's model configuration).
    pub(crate) fn decode_tracked_quantized<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewDQuantized],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<(), DecoderError> {
        output_boxes.clear();
        output_masks.clear();
        output_tracks.clear();

        // yolo segdet variants require special handling to separate boxes that come from decoding vs active tracks.
        // Only boxes that come from decoding can be used for proto/mask generation.
        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => self.decode_tracked_yolo_segdet_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => self.decode_tracked_yolo_split_segdet_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                scores,
                mask_coeff,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            ModelType::YoloEndToEndSegDet { boxes, protos } => self
                .decode_tracked_yolo_end_to_end_segdet_quantized(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                ),
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => self.decode_tracked_yolo_split_end_to_end_segdet_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => self.decode_tracked_yolo_segdet_2way_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                mask_coeff,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            // Non-seg variants: decode boxes via the plain path, then track.
            _ => {
                self.decode_quantized(outputs, output_boxes, output_masks)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
                Ok(())
            }
        }
    }

    /// Decode floating-point model outputs into detection boxes and
    /// segmentation masks, then update `tracker` with the result.
    ///
    /// `output_boxes`, `output_masks` and `output_tracks` are cleared before
    /// being populated. As documented for `decode_proto`, the capacity of
    /// `output_boxes` is only an allocation hint; the post-NMS detection
    /// count is bounded by [`Decoder::max_det`], not by capacity.
    ///
    /// YOLO segmentation variants are routed to dedicated tracked decoders so
    /// that only freshly decoded boxes (not tracker-predicted ones) feed mask
    /// generation. Every other model type is decoded with `decode_float`, and
    /// the resulting boxes are then fed to the tracker.
    ///
    /// Any quantization information in the configuration will be ignored.
    ///
    /// # Errors
    ///
    /// Returns an error if the provided outputs don't match the configuration
    /// provided by the user when building the decoder.
    pub(crate) fn decode_tracked_float<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewD<T>],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<(), DecoderError>
    where
        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
        f32: AsPrimitive<T>,
    {
        output_boxes.clear();
        output_masks.clear();
        output_tracks.clear();
        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => {
                self.decode_tracked_yolo_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                self.decode_tracked_yolo_split_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                self.decode_tracked_yolo_end_to_end_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                self.decode_tracked_yolo_split_end_to_end_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => {
                self.decode_tracked_yolo_segdet_2way_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            // Non-seg variants: decode boxes via the plain path, then track.
            _ => {
                self.decode_float(outputs, output_boxes, output_masks)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
            }
        }
        Ok(())
    }

    /// Decode quantized model outputs into detection boxes with tracking,
    /// returning raw `ProtoData` for segmentation models instead of
    /// materialized masks.
    ///
    /// `output_boxes` and `output_tracks` are cleared first. Detections are
    /// always decoded into `output_boxes` (and fed to the tracker) regardless
    /// of model type. Returns `Ok(None)` for detection-only and ModelPack
    /// models, and `Ok(Some(ProtoData))` for YOLO segmentation models.
    ///
    /// # Errors
    ///
    /// Propagates any `DecoderError` from the underlying decoder.
    pub(crate) fn decode_tracked_quantized_proto<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewDQuantized],
        output_boxes: &mut Vec<DetectBox>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<Option<ProtoData>, DecoderError> {
        output_boxes.clear();
        output_tracks.clear();
        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_segdet_2way_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_end_to_end_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_end_to_end_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            // Non-seg variants: decode boxes via the non-proto path, then track.
            // The masks buffer is a required out-parameter that is discarded.
            _ => {
                let mut masks = Vec::new();
                self.decode_quantized(outputs, output_boxes, &mut masks)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
                Ok(None)
            }
        }
    }

    /// Decode floating-point model outputs into detection boxes with
    /// tracking, returning raw `ProtoData` for segmentation models instead
    /// of materialized masks.
    ///
    /// `output_boxes` and `output_tracks` are cleared first. Detections are
    /// always decoded into `output_boxes` (and fed to the tracker) regardless
    /// of model type. Returns `Ok(None)` for detection-only and ModelPack
    /// models, and `Ok(Some(ProtoData))` for YOLO segmentation models.
    ///
    /// # Errors
    ///
    /// Propagates any `DecoderError` from the underlying decoder.
    pub(crate) fn decode_tracked_float_proto<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewD<T>],
        output_boxes: &mut Vec<DetectBox>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<Option<ProtoData>, DecoderError>
    where
        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + crate::yolo::FloatProtoElem,
        f32: AsPrimitive<T>,
    {
        output_boxes.clear();
        output_tracks.clear();
        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_segdet_2way_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_end_to_end_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_end_to_end_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            // Non-seg variants: decode boxes via the non-proto path, then track.
            // The masks buffer is a required out-parameter that is discarded.
            _ => {
                let mut masks = Vec::new();
                self.decode_float(outputs, output_boxes, &mut masks)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
                Ok(None)
            }
        }
    }

    // ========================================================================
    // TensorDyn-based tracked public API
    // ========================================================================

    /// Decode model outputs with tracking.
    ///
    /// Accepts `TensorDyn` outputs directly from model inference. The same
    /// routing as the untracked API is applied:
    /// 1. per-scale decoders run the basic decode and then update the tracker;
    /// 2. schema v2 decoders first merge the physical child tensors into
    ///    logical float32 tensors, which then go through the tracked float path;
    /// 3. otherwise the tensors are mapped and dispatched to the quantized or
    ///    float path based on their dtype.
    ///
    /// # Arguments
    ///
    /// * `tracker` - The tracker instance to update
    /// * `timestamp` - Current frame timestamp
    /// * `outputs` - Tensor outputs from model inference
    /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
    /// * `output_masks` - Destination for decoded segmentation masks (cleared first)
    /// * `output_tracks` - Destination for track info (cleared first)
    ///
    /// # Errors
    ///
    /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
    /// or the outputs don't match the decoder's model configuration.
    pub fn decode_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[&edgefirst_tensor::TensorDyn],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<(), DecoderError> {
        // Per-scale fast path: route via the basic decode then update the
        // tracker. Phase 1 keeps the tracker integration simple; per-frame
        // decoupling between detection and tracking is preserved.
        if self.per_scale.is_some() {
            output_tracks.clear();
            self.decode(outputs, output_boxes, output_masks)?;
            Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
            return Ok(());
        }

        // Schema v2 merge path: dequantize the physical children into logical
        // float32 tensors (the same step `decode_proto` performs). Without it,
        // the raw physical outputs would reach decoders configured for the
        // logical tensor layout.
        if let Some(program) = &self.decode_program {
            let merged = program.execute(outputs)?;
            let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
            return self.decode_tracked_float(
                tracker,
                timestamp,
                &views,
                output_boxes,
                output_masks,
                output_tracks,
            );
        }

        let mapped = tensor_bridge::map_tensors(outputs)?;
        match &mapped {
            tensor_bridge::MappedOutputs::Quantized(maps) => {
                let views = tensor_bridge::quantized_views(maps)?;
                self.decode_tracked_quantized(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float16(maps) => {
                let views = tensor_bridge::f16_views(maps)?;
                self.decode_tracked_float(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float32(maps) => {
                let views = tensor_bridge::f32_views(maps)?;
                self.decode_tracked_float(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float64(maps) => {
                let views = tensor_bridge::f64_views(maps)?;
                self.decode_tracked_float(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )
            }
        }
    }

    /// Decode model outputs with tracking, returning raw proto data for
    /// segmentation models.
    ///
    /// Accepts `TensorDyn` outputs directly from model inference.
    /// Returns `Ok(None)` for detection-only and ModelPack models.
    /// Returns `Ok(Some(ProtoData))` for YOLO segmentation models.
    ///
    /// Routing mirrors `decode_proto`: per-scale decoders take the per-scale
    /// fast path; schema v2 decoders first merge the physical child tensors
    /// into logical float32 tensors; otherwise the tensors are dispatched on
    /// their dtype. In every case the decoded boxes update the tracker.
    ///
    /// # Arguments
    ///
    /// * `tracker` - The tracker instance to update
    /// * `timestamp` - Current frame timestamp
    /// * `outputs` - Tensor outputs from model inference
    /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
    /// * `output_tracks` - Destination for track info (cleared first)
    ///
    /// # Errors
    ///
    /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
    /// or the outputs don't match the decoder's model configuration.
    pub fn decode_proto_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[&edgefirst_tensor::TensorDyn],
        output_boxes: &mut Vec<DetectBox>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<Option<ProtoData>, DecoderError> {
        // Per-scale fast path: route via the basic decode_proto then
        // update the tracker on the resulting boxes.
        if self.per_scale.is_some() {
            output_tracks.clear();
            let proto = self.decode_proto(outputs, output_boxes)?;
            Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
            return Ok(proto);
        }

        // Schema v2 merge path (same step as in `decode_proto`): feed the
        // merged logical float32 tensors through the tracked float dispatch.
        if let Some(program) = &self.decode_program {
            let merged = program.execute(outputs)?;
            let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
            return self.decode_tracked_float_proto(
                tracker,
                timestamp,
                &views,
                output_boxes,
                output_tracks,
            );
        }

        let mapped = tensor_bridge::map_tensors(outputs)?;
        match &mapped {
            tensor_bridge::MappedOutputs::Quantized(maps) => {
                let views = tensor_bridge::quantized_views(maps)?;
                self.decode_tracked_quantized_proto(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float16(maps) => {
                let views = tensor_bridge::f16_views(maps)?;
                self.decode_tracked_float_proto(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float32(maps) => {
                let views = tensor_bridge::f32_views(maps)?;
                self.decode_tracked_float_proto(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float64(maps) => {
                let views = tensor_bridge::f64_views(maps)?;
                self.decode_tracked_float_proto(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_tracks,
                )
            }
        }
    }
}