Skip to main content

edgefirst_decoder/decoder/
mod.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use ndarray::{ArrayView, ArrayViewD, Dimension};
5use num_traits::{AsPrimitive, Float};
6
7use crate::{DecoderError, DetectBox, ProtoData, Segmentation};
8
9pub mod config;
10pub mod configs;
11
12use configs::ModelType;
13
14#[derive(Debug, Clone)]
15pub struct Decoder {
16    model_type: ModelType,
17    pub iou_threshold: f32,
18    pub score_threshold: f32,
19    /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
20    /// models)
21    pub nms: Option<configs::Nms>,
22    /// Whether decoded boxes are in normalized [0,1] coordinates.
23    /// - `Some(true)`: Coordinates in [0,1] range
24    /// - `Some(false)`: Pixel coordinates
25    /// - `None`: Unknown, caller must infer (e.g., check if any coordinate >
26    ///   1.0)
27    normalized: Option<bool>,
28    /// Schema v2 merge program. Present when the decoder was built from
29    /// a [`crate::schema::SchemaV2`] whose logical outputs carry
30    /// physical children. Absent for flat configurations (v1 and
31    /// flat-v2).
32    pub(crate) decode_program: Option<merge::DecodeProgram>,
33}
34
35impl PartialEq for Decoder {
36    fn eq(&self, other: &Self) -> bool {
37        // DecodeProgram has non-comparable embedded data; compare by
38        // the config-derived fields only.
39        self.model_type == other.model_type
40            && self.iou_threshold == other.iou_threshold
41            && self.score_threshold == other.score_threshold
42            && self.nms == other.nms
43            && self.normalized == other.normalized
44            && self.decode_program.is_some() == other.decode_program.is_some()
45    }
46}
47
48#[derive(Debug)]
49pub(crate) enum ArrayViewDQuantized<'a> {
50    UInt8(ArrayViewD<'a, u8>),
51    Int8(ArrayViewD<'a, i8>),
52    UInt16(ArrayViewD<'a, u16>),
53    Int16(ArrayViewD<'a, i16>),
54    UInt32(ArrayViewD<'a, u32>),
55    Int32(ArrayViewD<'a, i32>),
56}
57
58impl<'a, D> From<ArrayView<'a, u8, D>> for ArrayViewDQuantized<'a>
59where
60    D: Dimension,
61{
62    fn from(arr: ArrayView<'a, u8, D>) -> Self {
63        Self::UInt8(arr.into_dyn())
64    }
65}
66
67impl<'a, D> From<ArrayView<'a, i8, D>> for ArrayViewDQuantized<'a>
68where
69    D: Dimension,
70{
71    fn from(arr: ArrayView<'a, i8, D>) -> Self {
72        Self::Int8(arr.into_dyn())
73    }
74}
75
76impl<'a, D> From<ArrayView<'a, u16, D>> for ArrayViewDQuantized<'a>
77where
78    D: Dimension,
79{
80    fn from(arr: ArrayView<'a, u16, D>) -> Self {
81        Self::UInt16(arr.into_dyn())
82    }
83}
84
85impl<'a, D> From<ArrayView<'a, i16, D>> for ArrayViewDQuantized<'a>
86where
87    D: Dimension,
88{
89    fn from(arr: ArrayView<'a, i16, D>) -> Self {
90        Self::Int16(arr.into_dyn())
91    }
92}
93
94impl<'a, D> From<ArrayView<'a, u32, D>> for ArrayViewDQuantized<'a>
95where
96    D: Dimension,
97{
98    fn from(arr: ArrayView<'a, u32, D>) -> Self {
99        Self::UInt32(arr.into_dyn())
100    }
101}
102
103impl<'a, D> From<ArrayView<'a, i32, D>> for ArrayViewDQuantized<'a>
104where
105    D: Dimension,
106{
107    fn from(arr: ArrayView<'a, i32, D>) -> Self {
108        Self::Int32(arr.into_dyn())
109    }
110}
111
112impl<'a> ArrayViewDQuantized<'a> {
113    /// Returns the shape of the underlying array.
114    pub(crate) fn shape(&self) -> &[usize] {
115        match self {
116            ArrayViewDQuantized::UInt8(a) => a.shape(),
117            ArrayViewDQuantized::Int8(a) => a.shape(),
118            ArrayViewDQuantized::UInt16(a) => a.shape(),
119            ArrayViewDQuantized::Int16(a) => a.shape(),
120            ArrayViewDQuantized::UInt32(a) => a.shape(),
121            ArrayViewDQuantized::Int32(a) => a.shape(),
122        }
123    }
124}
125
/// Evaluates `$body` with `$var` bound to the typed array view stored in
/// the quantized value `$x`, whichever integer variant it holds.
///
/// `$body` is expanded once per variant, so it is compiled separately for
/// each of the six integer element types.
///
/// WARNING: Do NOT nest `with_quantized!` calls. Every nesting level
/// multiplies the monomorphized code paths by 6 (one per integer variant),
/// so N nested levels produce 6^N instantiations.
///
/// Instead, dequantize each tensor sequentially with
/// `dequant_3d!`/`dequant_4d!` (6*N paths), or split the work into
/// independent phases that each nest at most 2 levels.
macro_rules! with_quantized {
    ($x:expr, $var:ident, $body:expr) => {
        match $x {
            ArrayViewDQuantized::UInt8($var) => { $body }
            ArrayViewDQuantized::Int8($var) => { $body }
            ArrayViewDQuantized::UInt16($var) => { $body }
            ArrayViewDQuantized::Int16($var) => { $body }
            ArrayViewDQuantized::UInt32($var) => { $body }
            ArrayViewDQuantized::Int32($var) => { $body }
        }
    };
}
162
163mod builder;
164mod dfl;
165mod helpers;
166mod merge;
167mod postprocess;
168mod tensor_bridge;
169mod tests;
170
171pub use builder::DecoderBuilder;
172pub use config::{ConfigOutput, ConfigOutputRef, ConfigOutputs};
173
174impl Decoder {
175    /// This function returns the parsed model type of the decoder.
176    ///
177    /// # Examples
178    ///
179    /// ```rust
180    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::ModelType};
181    /// # fn main() -> DecoderResult<()> {
182    /// #    let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
183    ///     let decoder = DecoderBuilder::default()
184    ///         .with_config_yaml_str(config_yaml)
185    ///         .build()?;
186    ///     assert!(matches!(
187    ///         decoder.model_type(),
188    ///         ModelType::ModelPackDetSplit { .. }
189    ///     ));
190    /// #    Ok(())
191    /// # }
192    /// ```
193    pub fn model_type(&self) -> &ModelType {
194        &self.model_type
195    }
196
197    /// Returns the box coordinate format if known from the model config.
198    ///
199    /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
200    /// - `Some(false)`: Boxes are in pixel coordinates relative to model input
201    /// - `None`: Unknown, caller must infer (e.g., check if any coordinate >
202    ///   1.0)
203    ///
204    /// This is determined by the model config's `normalized` field, not the NMS
205    /// mode. When coordinates are in pixels or unknown, the caller may need
206    /// to normalize using the model input dimensions.
207    ///
208    /// # Examples
209    ///
210    /// ```rust
211    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
212    /// # fn main() -> DecoderResult<()> {
213    /// #    let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
214    ///     let decoder = DecoderBuilder::default()
215    ///         .with_config_yaml_str(config_yaml)
216    ///         .build()?;
217    ///     // Config doesn't specify normalized, so it's None
218    ///     assert!(decoder.normalized_boxes().is_none());
219    /// #    Ok(())
220    /// # }
221    /// ```
222    pub fn normalized_boxes(&self) -> Option<bool> {
223        self.normalized
224    }
225
226    /// Decode quantized model outputs into detection boxes and segmentation
227    /// masks. The quantized outputs can be of u8, i8, u16, i16, u32, or i32
228    /// types. Clears the provided output vectors before populating them.
229    pub(crate) fn decode_quantized(
230        &self,
231        outputs: &[ArrayViewDQuantized],
232        output_boxes: &mut Vec<DetectBox>,
233        output_masks: &mut Vec<Segmentation>,
234    ) -> Result<(), DecoderError> {
235        output_boxes.clear();
236        output_masks.clear();
237        match &self.model_type {
238            ModelType::ModelPackSegDet {
239                boxes,
240                scores,
241                segmentation,
242            } => {
243                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
244                self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
245            }
246            ModelType::ModelPackSegDetSplit {
247                detection,
248                segmentation,
249            } => {
250                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
251                self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
252            }
253            ModelType::ModelPackDet { boxes, scores } => {
254                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)
255            }
256            ModelType::ModelPackDetSplit { detection } => {
257                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)
258            }
259            ModelType::ModelPackSeg { segmentation } => {
260                self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
261            }
262            ModelType::YoloDet { boxes } => {
263                self.decode_yolo_det_quantized(outputs, boxes, output_boxes)
264            }
265            ModelType::YoloSegDet { boxes, protos } => self.decode_yolo_segdet_quantized(
266                outputs,
267                boxes,
268                protos,
269                output_boxes,
270                output_masks,
271            ),
272            ModelType::YoloSplitDet { boxes, scores } => {
273                self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)
274            }
275            ModelType::YoloSplitSegDet {
276                boxes,
277                scores,
278                mask_coeff,
279                protos,
280            } => self.decode_yolo_split_segdet_quantized(
281                outputs,
282                boxes,
283                scores,
284                mask_coeff,
285                protos,
286                output_boxes,
287                output_masks,
288            ),
289            ModelType::YoloSegDet2Way {
290                boxes,
291                mask_coeff,
292                protos,
293            } => self.decode_yolo_segdet_2way_quantized(
294                outputs,
295                boxes,
296                mask_coeff,
297                protos,
298                output_boxes,
299                output_masks,
300            ),
301            ModelType::YoloEndToEndDet { boxes } => {
302                self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)
303            }
304            ModelType::YoloEndToEndSegDet { boxes, protos } => self
305                .decode_yolo_end_to_end_segdet_quantized(
306                    outputs,
307                    boxes,
308                    protos,
309                    output_boxes,
310                    output_masks,
311                ),
312            ModelType::YoloSplitEndToEndDet {
313                boxes,
314                scores,
315                classes,
316            } => self.decode_yolo_split_end_to_end_det_quantized(
317                outputs,
318                boxes,
319                scores,
320                classes,
321                output_boxes,
322            ),
323            ModelType::YoloSplitEndToEndSegDet {
324                boxes,
325                scores,
326                classes,
327                mask_coeff,
328                protos,
329            } => self.decode_yolo_split_end_to_end_segdet_quantized(
330                outputs,
331                boxes,
332                scores,
333                classes,
334                mask_coeff,
335                protos,
336                output_boxes,
337                output_masks,
338            ),
339        }
340    }
341
342    /// Decode floating point model outputs into detection boxes and
343    /// segmentation masks. Clears the provided output vectors before
344    /// populating them.
345    pub(crate) fn decode_float<T>(
346        &self,
347        outputs: &[ArrayViewD<T>],
348        output_boxes: &mut Vec<DetectBox>,
349        output_masks: &mut Vec<Segmentation>,
350    ) -> Result<(), DecoderError>
351    where
352        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
353        f32: AsPrimitive<T>,
354    {
355        output_boxes.clear();
356        output_masks.clear();
357        match &self.model_type {
358            ModelType::ModelPackSegDet {
359                boxes,
360                scores,
361                segmentation,
362            } => {
363                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
364                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
365            }
366            ModelType::ModelPackSegDetSplit {
367                detection,
368                segmentation,
369            } => {
370                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
371                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
372            }
373            ModelType::ModelPackDet { boxes, scores } => {
374                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
375            }
376            ModelType::ModelPackDetSplit { detection } => {
377                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
378            }
379            ModelType::ModelPackSeg { segmentation } => {
380                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
381            }
382            ModelType::YoloDet { boxes } => {
383                self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
384            }
385            ModelType::YoloSegDet { boxes, protos } => {
386                self.decode_yolo_segdet_float(outputs, boxes, protos, output_boxes, output_masks)?;
387            }
388            ModelType::YoloSplitDet { boxes, scores } => {
389                self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
390            }
391            ModelType::YoloSplitSegDet {
392                boxes,
393                scores,
394                mask_coeff,
395                protos,
396            } => {
397                self.decode_yolo_split_segdet_float(
398                    outputs,
399                    boxes,
400                    scores,
401                    mask_coeff,
402                    protos,
403                    output_boxes,
404                    output_masks,
405                )?;
406            }
407            ModelType::YoloSegDet2Way {
408                boxes,
409                mask_coeff,
410                protos,
411            } => {
412                self.decode_yolo_segdet_2way_float(
413                    outputs,
414                    boxes,
415                    mask_coeff,
416                    protos,
417                    output_boxes,
418                    output_masks,
419                )?;
420            }
421            ModelType::YoloEndToEndDet { boxes } => {
422                self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
423            }
424            ModelType::YoloEndToEndSegDet { boxes, protos } => {
425                self.decode_yolo_end_to_end_segdet_float(
426                    outputs,
427                    boxes,
428                    protos,
429                    output_boxes,
430                    output_masks,
431                )?;
432            }
433            ModelType::YoloSplitEndToEndDet {
434                boxes,
435                scores,
436                classes,
437            } => {
438                self.decode_yolo_split_end_to_end_det_float(
439                    outputs,
440                    boxes,
441                    scores,
442                    classes,
443                    output_boxes,
444                )?;
445            }
446            ModelType::YoloSplitEndToEndSegDet {
447                boxes,
448                scores,
449                classes,
450                mask_coeff,
451                protos,
452            } => {
453                self.decode_yolo_split_end_to_end_segdet_float(
454                    outputs,
455                    boxes,
456                    scores,
457                    classes,
458                    mask_coeff,
459                    protos,
460                    output_boxes,
461                    output_masks,
462                )?;
463            }
464        }
465        Ok(())
466    }
467
468    /// Decodes quantized model outputs into detection boxes, returning raw
469    /// `ProtoData` for segmentation models instead of materialized masks.
470    ///
471    /// Returns `Ok(None)` for detection-only and ModelPack models (detections
472    /// are still decoded into `output_boxes`). Returns `Ok(Some(ProtoData))`
473    /// for YOLO segmentation models.
474    pub(crate) fn decode_quantized_proto(
475        &self,
476        outputs: &[ArrayViewDQuantized],
477        output_boxes: &mut Vec<DetectBox>,
478    ) -> Result<Option<ProtoData>, DecoderError> {
479        output_boxes.clear();
480        match &self.model_type {
481            // Detection-only variants: decode boxes, return None for proto data.
482            ModelType::ModelPackDet { boxes, scores } => {
483                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
484                Ok(None)
485            }
486            ModelType::ModelPackDetSplit { detection } => {
487                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
488                Ok(None)
489            }
490            ModelType::YoloDet { boxes } => {
491                self.decode_yolo_det_quantized(outputs, boxes, output_boxes)?;
492                Ok(None)
493            }
494            ModelType::YoloSplitDet { boxes, scores } => {
495                self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)?;
496                Ok(None)
497            }
498            ModelType::YoloEndToEndDet { boxes } => {
499                self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)?;
500                Ok(None)
501            }
502            ModelType::YoloSplitEndToEndDet {
503                boxes,
504                scores,
505                classes,
506            } => {
507                self.decode_yolo_split_end_to_end_det_quantized(
508                    outputs,
509                    boxes,
510                    scores,
511                    classes,
512                    output_boxes,
513                )?;
514                Ok(None)
515            }
516            // ModelPack seg/segdet variants have no YOLO proto data.
517            ModelType::ModelPackSegDet { boxes, scores, .. } => {
518                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
519                Ok(None)
520            }
521            ModelType::ModelPackSegDetSplit { detection, .. } => {
522                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
523                Ok(None)
524            }
525            ModelType::ModelPackSeg { .. } => Ok(None),
526
527            ModelType::YoloSegDet { boxes, protos } => {
528                let proto =
529                    self.decode_yolo_segdet_quantized_proto(outputs, boxes, protos, output_boxes)?;
530                Ok(Some(proto))
531            }
532            ModelType::YoloSplitSegDet {
533                boxes,
534                scores,
535                mask_coeff,
536                protos,
537            } => {
538                let proto = self.decode_yolo_split_segdet_quantized_proto(
539                    outputs,
540                    boxes,
541                    scores,
542                    mask_coeff,
543                    protos,
544                    output_boxes,
545                )?;
546                Ok(Some(proto))
547            }
548            ModelType::YoloSegDet2Way {
549                boxes,
550                mask_coeff,
551                protos,
552            } => {
553                let proto = self.decode_yolo_segdet_2way_quantized_proto(
554                    outputs,
555                    boxes,
556                    mask_coeff,
557                    protos,
558                    output_boxes,
559                )?;
560                Ok(Some(proto))
561            }
562            ModelType::YoloEndToEndSegDet { boxes, protos } => {
563                let proto = self.decode_yolo_end_to_end_segdet_quantized_proto(
564                    outputs,
565                    boxes,
566                    protos,
567                    output_boxes,
568                )?;
569                Ok(Some(proto))
570            }
571            ModelType::YoloSplitEndToEndSegDet {
572                boxes,
573                scores,
574                classes,
575                mask_coeff,
576                protos,
577            } => {
578                let proto = self.decode_yolo_split_end_to_end_segdet_quantized_proto(
579                    outputs,
580                    boxes,
581                    scores,
582                    classes,
583                    mask_coeff,
584                    protos,
585                    output_boxes,
586                )?;
587                Ok(Some(proto))
588            }
589        }
590    }
591
592    /// Decodes floating-point model outputs into detection boxes, returning
593    /// raw `ProtoData` for segmentation models instead of materialized masks.
594    ///
595    /// Returns `Ok(None)` for detection-only and ModelPack models (detections
596    /// are still decoded into `output_boxes`). Returns `Ok(Some(ProtoData))`
597    /// for YOLO segmentation models.
598    pub(crate) fn decode_float_proto<T>(
599        &self,
600        outputs: &[ArrayViewD<T>],
601        output_boxes: &mut Vec<DetectBox>,
602    ) -> Result<Option<ProtoData>, DecoderError>
603    where
604        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
605        f32: AsPrimitive<T>,
606    {
607        output_boxes.clear();
608        match &self.model_type {
609            // Detection-only variants: decode boxes, return None for proto data.
610            ModelType::ModelPackDet { boxes, scores } => {
611                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
612                Ok(None)
613            }
614            ModelType::ModelPackDetSplit { detection } => {
615                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
616                Ok(None)
617            }
618            ModelType::YoloDet { boxes } => {
619                self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
620                Ok(None)
621            }
622            ModelType::YoloSplitDet { boxes, scores } => {
623                self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
624                Ok(None)
625            }
626            ModelType::YoloEndToEndDet { boxes } => {
627                self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
628                Ok(None)
629            }
630            ModelType::YoloSplitEndToEndDet {
631                boxes,
632                scores,
633                classes,
634            } => {
635                self.decode_yolo_split_end_to_end_det_float(
636                    outputs,
637                    boxes,
638                    scores,
639                    classes,
640                    output_boxes,
641                )?;
642                Ok(None)
643            }
644            // ModelPack seg/segdet variants have no YOLO proto data.
645            ModelType::ModelPackSegDet { boxes, scores, .. } => {
646                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
647                Ok(None)
648            }
649            ModelType::ModelPackSegDetSplit { detection, .. } => {
650                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
651                Ok(None)
652            }
653            ModelType::ModelPackSeg { .. } => Ok(None),
654
655            ModelType::YoloSegDet { boxes, protos } => {
656                let proto =
657                    self.decode_yolo_segdet_float_proto(outputs, boxes, protos, output_boxes)?;
658                Ok(Some(proto))
659            }
660            ModelType::YoloSplitSegDet {
661                boxes,
662                scores,
663                mask_coeff,
664                protos,
665            } => {
666                let proto = self.decode_yolo_split_segdet_float_proto(
667                    outputs,
668                    boxes,
669                    scores,
670                    mask_coeff,
671                    protos,
672                    output_boxes,
673                )?;
674                Ok(Some(proto))
675            }
676            ModelType::YoloSegDet2Way {
677                boxes,
678                mask_coeff,
679                protos,
680            } => {
681                let proto = self.decode_yolo_segdet_2way_float_proto(
682                    outputs,
683                    boxes,
684                    mask_coeff,
685                    protos,
686                    output_boxes,
687                )?;
688                Ok(Some(proto))
689            }
690            ModelType::YoloEndToEndSegDet { boxes, protos } => {
691                let proto = self.decode_yolo_end_to_end_segdet_float_proto(
692                    outputs,
693                    boxes,
694                    protos,
695                    output_boxes,
696                )?;
697                Ok(Some(proto))
698            }
699            ModelType::YoloSplitEndToEndSegDet {
700                boxes,
701                scores,
702                classes,
703                mask_coeff,
704                protos,
705            } => {
706                let proto = self.decode_yolo_split_end_to_end_segdet_float_proto(
707                    outputs,
708                    boxes,
709                    scores,
710                    classes,
711                    mask_coeff,
712                    protos,
713                    output_boxes,
714                )?;
715                Ok(Some(proto))
716            }
717        }
718    }
719
720    // ========================================================================
721    // TensorDyn-based public API
722    // ========================================================================
723
724    /// Decode model outputs into detection boxes and segmentation masks.
725    ///
726    /// This is the primary decode API. Accepts `TensorDyn` outputs directly
727    /// from model inference. Automatically dispatches to quantized or float
728    /// paths based on the tensor dtype.
729    ///
730    /// # Arguments
731    ///
732    /// * `outputs` - Tensor outputs from model inference
733    /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
734    /// * `output_masks` - Destination for decoded segmentation masks (cleared first)
735    ///
736    /// # Errors
737    ///
738    /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
739    /// or the outputs don't match the decoder's model configuration.
740    pub fn decode(
741        &self,
742        outputs: &[&edgefirst_tensor::TensorDyn],
743        output_boxes: &mut Vec<DetectBox>,
744        output_masks: &mut Vec<Segmentation>,
745    ) -> Result<(), DecoderError> {
746        // Schema v2 merge path: dequantize physical children into
747        // logical float32 tensors, then feed through the float dispatch.
748        if let Some(program) = &self.decode_program {
749            let merged = program.execute(outputs)?;
750            let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
751            return self.decode_float(&views, output_boxes, output_masks);
752        }
753
754        let mapped = tensor_bridge::map_tensors(outputs)?;
755        match &mapped {
756            tensor_bridge::MappedOutputs::Quantized(maps) => {
757                let views = tensor_bridge::quantized_views(maps)?;
758                self.decode_quantized(&views, output_boxes, output_masks)
759            }
760            tensor_bridge::MappedOutputs::Float32(maps) => {
761                let views = tensor_bridge::f32_views(maps)?;
762                self.decode_float(&views, output_boxes, output_masks)
763            }
764            tensor_bridge::MappedOutputs::Float64(maps) => {
765                let views = tensor_bridge::f64_views(maps)?;
766                self.decode_float(&views, output_boxes, output_masks)
767            }
768        }
769    }
770
771    /// Decode model outputs into detection boxes, returning raw proto data
772    /// for segmentation models instead of materialized masks.
773    ///
774    /// Accepts `TensorDyn` outputs directly from model inference.
775    /// Detections are always decoded into `output_boxes` regardless of model type.
776    /// Returns `Ok(None)` for detection-only and ModelPack models.
777    /// Returns `Ok(Some(ProtoData))` for YOLO segmentation models.
778    ///
779    /// # Arguments
780    ///
781    /// * `outputs` - Tensor outputs from model inference
782    /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
783    ///
784    /// # Errors
785    ///
786    /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
787    /// or the outputs don't match the decoder's model configuration.
788    pub fn decode_proto(
789        &self,
790        outputs: &[&edgefirst_tensor::TensorDyn],
791        output_boxes: &mut Vec<DetectBox>,
792    ) -> Result<Option<ProtoData>, DecoderError> {
793        let mapped = tensor_bridge::map_tensors(outputs)?;
794        match &mapped {
795            tensor_bridge::MappedOutputs::Quantized(maps) => {
796                let views = tensor_bridge::quantized_views(maps)?;
797                self.decode_quantized_proto(&views, output_boxes)
798            }
799            tensor_bridge::MappedOutputs::Float32(maps) => {
800                let views = tensor_bridge::f32_views(maps)?;
801                self.decode_float_proto(&views, output_boxes)
802            }
803            tensor_bridge::MappedOutputs::Float64(maps) => {
804                let views = tensor_bridge::f64_views(maps)?;
805                self.decode_float_proto(&views, output_boxes)
806            }
807        }
808    }
809}
810
811#[cfg(feature = "tracker")]
812pub use edgefirst_tracker::TrackInfo;
813
814#[cfg(feature = "tracker")]
815pub use edgefirst_tracker::Tracker;
816
817#[cfg(feature = "tracker")]
818impl Decoder {
819    /// Decode quantized model outputs into detection boxes and segmentation
820    /// masks with tracking. Clears the provided output vectors before
821    /// populating them.
822    pub(crate) fn decode_tracked_quantized<TR: edgefirst_tracker::Tracker<DetectBox>>(
823        &self,
824        tracker: &mut TR,
825        timestamp: u64,
826        outputs: &[ArrayViewDQuantized],
827        output_boxes: &mut Vec<DetectBox>,
828        output_masks: &mut Vec<Segmentation>,
829        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
830    ) -> Result<(), DecoderError> {
831        output_boxes.clear();
832        output_masks.clear();
833        output_tracks.clear();
834
835        // yolo segdet variants require special handling to separate boxes that come from decoding vs active tracks.
836        // Only boxes that come from decoding can be used for proto/mask generation.
837        match &self.model_type {
838            ModelType::YoloSegDet { boxes, protos } => self.decode_tracked_yolo_segdet_quantized(
839                tracker,
840                timestamp,
841                outputs,
842                boxes,
843                protos,
844                output_boxes,
845                output_masks,
846                output_tracks,
847            ),
848            ModelType::YoloSplitSegDet {
849                boxes,
850                scores,
851                mask_coeff,
852                protos,
853            } => self.decode_tracked_yolo_split_segdet_quantized(
854                tracker,
855                timestamp,
856                outputs,
857                boxes,
858                scores,
859                mask_coeff,
860                protos,
861                output_boxes,
862                output_masks,
863                output_tracks,
864            ),
865            ModelType::YoloEndToEndSegDet { boxes, protos } => self
866                .decode_tracked_yolo_end_to_end_segdet_quantized(
867                    tracker,
868                    timestamp,
869                    outputs,
870                    boxes,
871                    protos,
872                    output_boxes,
873                    output_masks,
874                    output_tracks,
875                ),
876            ModelType::YoloSplitEndToEndSegDet {
877                boxes,
878                scores,
879                classes,
880                mask_coeff,
881                protos,
882            } => self.decode_tracked_yolo_split_end_to_end_segdet_quantized(
883                tracker,
884                timestamp,
885                outputs,
886                boxes,
887                scores,
888                classes,
889                mask_coeff,
890                protos,
891                output_boxes,
892                output_masks,
893                output_tracks,
894            ),
895            ModelType::YoloSegDet2Way {
896                boxes,
897                mask_coeff,
898                protos,
899            } => self.decode_tracked_yolo_segdet_2way_quantized(
900                tracker,
901                timestamp,
902                outputs,
903                boxes,
904                mask_coeff,
905                protos,
906                output_boxes,
907                output_masks,
908                output_tracks,
909            ),
910            _ => {
911                self.decode_quantized(outputs, output_boxes, output_masks)?;
912                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
913                Ok(())
914            }
915        }
916    }
917
918    /// This function decodes floating point model outputs into detection boxes
919    /// and segmentation masks. Up to `output_boxes.capacity()` boxes and
920    /// masks will be decoded. The function clears the provided output
921    /// vectors before populating them with the decoded results.
922    ///
923    /// This function returns an `Error` if the provided outputs don't
924    /// match the configuration provided by the user when building the decoder.
925    ///
926    /// Any quantization information in the configuration will be ignored.
927    pub(crate) fn decode_tracked_float<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
928        &self,
929        tracker: &mut TR,
930        timestamp: u64,
931        outputs: &[ArrayViewD<T>],
932        output_boxes: &mut Vec<DetectBox>,
933        output_masks: &mut Vec<Segmentation>,
934        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
935    ) -> Result<(), DecoderError>
936    where
937        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
938        f32: AsPrimitive<T>,
939    {
940        output_boxes.clear();
941        output_masks.clear();
942        output_tracks.clear();
943        match &self.model_type {
944            ModelType::YoloSegDet { boxes, protos } => {
945                self.decode_tracked_yolo_segdet_float(
946                    tracker,
947                    timestamp,
948                    outputs,
949                    boxes,
950                    protos,
951                    output_boxes,
952                    output_masks,
953                    output_tracks,
954                )?;
955            }
956            ModelType::YoloSplitSegDet {
957                boxes,
958                scores,
959                mask_coeff,
960                protos,
961            } => {
962                self.decode_tracked_yolo_split_segdet_float(
963                    tracker,
964                    timestamp,
965                    outputs,
966                    boxes,
967                    scores,
968                    mask_coeff,
969                    protos,
970                    output_boxes,
971                    output_masks,
972                    output_tracks,
973                )?;
974            }
975            ModelType::YoloEndToEndSegDet { boxes, protos } => {
976                self.decode_tracked_yolo_end_to_end_segdet_float(
977                    tracker,
978                    timestamp,
979                    outputs,
980                    boxes,
981                    protos,
982                    output_boxes,
983                    output_masks,
984                    output_tracks,
985                )?;
986            }
987            ModelType::YoloSplitEndToEndSegDet {
988                boxes,
989                scores,
990                classes,
991                mask_coeff,
992                protos,
993            } => {
994                self.decode_tracked_yolo_split_end_to_end_segdet_float(
995                    tracker,
996                    timestamp,
997                    outputs,
998                    boxes,
999                    scores,
1000                    classes,
1001                    mask_coeff,
1002                    protos,
1003                    output_boxes,
1004                    output_masks,
1005                    output_tracks,
1006                )?;
1007            }
1008            ModelType::YoloSegDet2Way {
1009                boxes,
1010                mask_coeff,
1011                protos,
1012            } => {
1013                self.decode_tracked_yolo_segdet_2way_float(
1014                    tracker,
1015                    timestamp,
1016                    outputs,
1017                    boxes,
1018                    mask_coeff,
1019                    protos,
1020                    output_boxes,
1021                    output_masks,
1022                    output_tracks,
1023                )?;
1024            }
1025            _ => {
1026                self.decode_float(outputs, output_boxes, output_masks)?;
1027                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1028            }
1029        }
1030        Ok(())
1031    }
1032
1033    /// Decodes quantized model outputs into detection boxes, returning raw
1034    /// `ProtoData` for segmentation models instead of materialized masks.
1035    ///
1036    /// Returns `Ok(None)` for detection-only and ModelPack models (use
1037    /// `decode_quantized` for those). Returns `Ok(Some(ProtoData))` for
1038    /// YOLO segmentation models.
1039    pub(crate) fn decode_tracked_quantized_proto<TR: edgefirst_tracker::Tracker<DetectBox>>(
1040        &self,
1041        tracker: &mut TR,
1042        timestamp: u64,
1043        outputs: &[ArrayViewDQuantized],
1044        output_boxes: &mut Vec<DetectBox>,
1045        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1046    ) -> Result<Option<ProtoData>, DecoderError> {
1047        output_boxes.clear();
1048        output_tracks.clear();
1049        match &self.model_type {
1050            ModelType::YoloSegDet { boxes, protos } => {
1051                let proto = self.decode_tracked_yolo_segdet_quantized_proto(
1052                    tracker,
1053                    timestamp,
1054                    outputs,
1055                    boxes,
1056                    protos,
1057                    output_boxes,
1058                    output_tracks,
1059                )?;
1060                Ok(Some(proto))
1061            }
1062            ModelType::YoloSplitSegDet {
1063                boxes,
1064                scores,
1065                mask_coeff,
1066                protos,
1067            } => {
1068                let proto = self.decode_tracked_yolo_split_segdet_quantized_proto(
1069                    tracker,
1070                    timestamp,
1071                    outputs,
1072                    boxes,
1073                    scores,
1074                    mask_coeff,
1075                    protos,
1076                    output_boxes,
1077                    output_tracks,
1078                )?;
1079                Ok(Some(proto))
1080            }
1081            ModelType::YoloSegDet2Way {
1082                boxes,
1083                mask_coeff,
1084                protos,
1085            } => {
1086                let proto = self.decode_tracked_yolo_segdet_2way_quantized_proto(
1087                    tracker,
1088                    timestamp,
1089                    outputs,
1090                    boxes,
1091                    mask_coeff,
1092                    protos,
1093                    output_boxes,
1094                    output_tracks,
1095                )?;
1096                Ok(Some(proto))
1097            }
1098            ModelType::YoloEndToEndSegDet { boxes, protos } => {
1099                let proto = self.decode_tracked_yolo_end_to_end_segdet_quantized_proto(
1100                    tracker,
1101                    timestamp,
1102                    outputs,
1103                    boxes,
1104                    protos,
1105                    output_boxes,
1106                    output_tracks,
1107                )?;
1108                Ok(Some(proto))
1109            }
1110            ModelType::YoloSplitEndToEndSegDet {
1111                boxes,
1112                scores,
1113                classes,
1114                mask_coeff,
1115                protos,
1116            } => {
1117                let proto = self.decode_tracked_yolo_split_end_to_end_segdet_quantized_proto(
1118                    tracker,
1119                    timestamp,
1120                    outputs,
1121                    boxes,
1122                    scores,
1123                    classes,
1124                    mask_coeff,
1125                    protos,
1126                    output_boxes,
1127                    output_tracks,
1128                )?;
1129                Ok(Some(proto))
1130            }
1131            // Non-seg variants: decode boxes via the non-proto path, then track.
1132            _ => {
1133                let mut masks = Vec::new();
1134                self.decode_quantized(outputs, output_boxes, &mut masks)?;
1135                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1136                Ok(None)
1137            }
1138        }
1139    }
1140
1141    /// Decodes floating-point model outputs into detection boxes, returning
1142    /// raw `ProtoData` for segmentation models instead of materialized masks.
1143    ///
1144    /// Detections are always decoded into `output_boxes` regardless of model type.
1145    /// Returns `Ok(None)` for detection-only and ModelPack models. Returns
1146    /// `Ok(Some(ProtoData))` for YOLO segmentation models.
1147    pub(crate) fn decode_tracked_float_proto<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
1148        &self,
1149        tracker: &mut TR,
1150        timestamp: u64,
1151        outputs: &[ArrayViewD<T>],
1152        output_boxes: &mut Vec<DetectBox>,
1153        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1154    ) -> Result<Option<ProtoData>, DecoderError>
1155    where
1156        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
1157        f32: AsPrimitive<T>,
1158    {
1159        output_boxes.clear();
1160        output_tracks.clear();
1161        match &self.model_type {
1162            ModelType::YoloSegDet { boxes, protos } => {
1163                let proto = self.decode_tracked_yolo_segdet_float_proto(
1164                    tracker,
1165                    timestamp,
1166                    outputs,
1167                    boxes,
1168                    protos,
1169                    output_boxes,
1170                    output_tracks,
1171                )?;
1172                Ok(Some(proto))
1173            }
1174            ModelType::YoloSplitSegDet {
1175                boxes,
1176                scores,
1177                mask_coeff,
1178                protos,
1179            } => {
1180                let proto = self.decode_tracked_yolo_split_segdet_float_proto(
1181                    tracker,
1182                    timestamp,
1183                    outputs,
1184                    boxes,
1185                    scores,
1186                    mask_coeff,
1187                    protos,
1188                    output_boxes,
1189                    output_tracks,
1190                )?;
1191                Ok(Some(proto))
1192            }
1193            ModelType::YoloSegDet2Way {
1194                boxes,
1195                mask_coeff,
1196                protos,
1197            } => {
1198                let proto = self.decode_tracked_yolo_segdet_2way_float_proto(
1199                    tracker,
1200                    timestamp,
1201                    outputs,
1202                    boxes,
1203                    mask_coeff,
1204                    protos,
1205                    output_boxes,
1206                    output_tracks,
1207                )?;
1208                Ok(Some(proto))
1209            }
1210            ModelType::YoloEndToEndSegDet { boxes, protos } => {
1211                let proto = self.decode_tracked_yolo_end_to_end_segdet_float_proto(
1212                    tracker,
1213                    timestamp,
1214                    outputs,
1215                    boxes,
1216                    protos,
1217                    output_boxes,
1218                    output_tracks,
1219                )?;
1220                Ok(Some(proto))
1221            }
1222            ModelType::YoloSplitEndToEndSegDet {
1223                boxes,
1224                scores,
1225                classes,
1226                mask_coeff,
1227                protos,
1228            } => {
1229                let proto = self.decode_tracked_yolo_split_end_to_end_segdet_float_proto(
1230                    tracker,
1231                    timestamp,
1232                    outputs,
1233                    boxes,
1234                    scores,
1235                    classes,
1236                    mask_coeff,
1237                    protos,
1238                    output_boxes,
1239                    output_tracks,
1240                )?;
1241                Ok(Some(proto))
1242            }
1243            // Non-seg variants: decode boxes via the non-proto path, then track.
1244            _ => {
1245                let mut masks = Vec::new();
1246                self.decode_float(outputs, output_boxes, &mut masks)?;
1247                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1248                Ok(None)
1249            }
1250        }
1251    }
1252
1253    // ========================================================================
1254    // TensorDyn-based tracked public API
1255    // ========================================================================
1256
1257    /// Decode model outputs with tracking.
1258    ///
1259    /// Accepts `TensorDyn` outputs directly from model inference. Automatically
1260    /// dispatches to quantized or float paths based on the tensor dtype, then
1261    /// updates the tracker with the decoded boxes.
1262    ///
1263    /// # Arguments
1264    ///
1265    /// * `tracker` - The tracker instance to update
1266    /// * `timestamp` - Current frame timestamp
1267    /// * `outputs` - Tensor outputs from model inference
1268    /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
1269    /// * `output_masks` - Destination for decoded segmentation masks (cleared first)
1270    /// * `output_tracks` - Destination for track info (cleared first)
1271    ///
1272    /// # Errors
1273    ///
1274    /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
1275    /// or the outputs don't match the decoder's model configuration.
1276    pub fn decode_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
1277        &self,
1278        tracker: &mut TR,
1279        timestamp: u64,
1280        outputs: &[&edgefirst_tensor::TensorDyn],
1281        output_boxes: &mut Vec<DetectBox>,
1282        output_masks: &mut Vec<Segmentation>,
1283        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1284    ) -> Result<(), DecoderError> {
1285        let mapped = tensor_bridge::map_tensors(outputs)?;
1286        match &mapped {
1287            tensor_bridge::MappedOutputs::Quantized(maps) => {
1288                let views = tensor_bridge::quantized_views(maps)?;
1289                self.decode_tracked_quantized(
1290                    tracker,
1291                    timestamp,
1292                    &views,
1293                    output_boxes,
1294                    output_masks,
1295                    output_tracks,
1296                )
1297            }
1298            tensor_bridge::MappedOutputs::Float32(maps) => {
1299                let views = tensor_bridge::f32_views(maps)?;
1300                self.decode_tracked_float(
1301                    tracker,
1302                    timestamp,
1303                    &views,
1304                    output_boxes,
1305                    output_masks,
1306                    output_tracks,
1307                )
1308            }
1309            tensor_bridge::MappedOutputs::Float64(maps) => {
1310                let views = tensor_bridge::f64_views(maps)?;
1311                self.decode_tracked_float(
1312                    tracker,
1313                    timestamp,
1314                    &views,
1315                    output_boxes,
1316                    output_masks,
1317                    output_tracks,
1318                )
1319            }
1320        }
1321    }
1322
1323    /// Decode model outputs with tracking, returning raw proto data for
1324    /// segmentation models.
1325    ///
1326    /// Accepts `TensorDyn` outputs directly from model inference.
1327    /// Returns `Ok(None)` for detection-only and ModelPack models.
1328    /// Returns `Ok(Some(ProtoData))` for YOLO segmentation models.
1329    ///
1330    /// # Arguments
1331    ///
1332    /// * `tracker` - The tracker instance to update
1333    /// * `timestamp` - Current frame timestamp
1334    /// * `outputs` - Tensor outputs from model inference
1335    /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
1336    /// * `output_tracks` - Destination for track info (cleared first)
1337    ///
1338    /// # Errors
1339    ///
1340    /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
1341    /// or the outputs don't match the decoder's model configuration.
1342    pub fn decode_proto_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
1343        &self,
1344        tracker: &mut TR,
1345        timestamp: u64,
1346        outputs: &[&edgefirst_tensor::TensorDyn],
1347        output_boxes: &mut Vec<DetectBox>,
1348        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1349    ) -> Result<Option<ProtoData>, DecoderError> {
1350        let mapped = tensor_bridge::map_tensors(outputs)?;
1351        match &mapped {
1352            tensor_bridge::MappedOutputs::Quantized(maps) => {
1353                let views = tensor_bridge::quantized_views(maps)?;
1354                self.decode_tracked_quantized_proto(
1355                    tracker,
1356                    timestamp,
1357                    &views,
1358                    output_boxes,
1359                    output_tracks,
1360                )
1361            }
1362            tensor_bridge::MappedOutputs::Float32(maps) => {
1363                let views = tensor_bridge::f32_views(maps)?;
1364                self.decode_tracked_float_proto(
1365                    tracker,
1366                    timestamp,
1367                    &views,
1368                    output_boxes,
1369                    output_tracks,
1370                )
1371            }
1372            tensor_bridge::MappedOutputs::Float64(maps) => {
1373                let views = tensor_bridge::f64_views(maps)?;
1374                self.decode_tracked_float_proto(
1375                    tracker,
1376                    timestamp,
1377                    &views,
1378                    output_boxes,
1379                    output_tracks,
1380                )
1381            }
1382        }
1383    }
1384}