Skip to main content

edgefirst_decoder/decoder/
mod.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use ndarray::{ArrayView, ArrayViewD, Dimension};
5use num_traits::{AsPrimitive, Float};
6
7use crate::{DecoderError, DetectBox, ProtoData, Segmentation};
8
9pub mod config;
10pub mod configs;
11
12use configs::ModelType;
13
/// Decoder that turns raw model output tensors into detection boxes and
/// segmentation results.
///
/// The `decode*` methods match on the stored model type to select the decode
/// routine to run. The documentation examples in this module obtain
/// instances through [`DecoderBuilder`].
#[derive(Debug, Clone)]
pub struct Decoder {
    /// Parsed model type; every `decode*` method dispatches on this value.
    model_type: ModelType,
    /// IoU threshold. NOTE(review): presumably the overlap cut-off used by
    /// NMS — the code consuming it is not in this file, confirm there.
    pub iou_threshold: f32,
    /// Score threshold. NOTE(review): presumably the minimum confidence a
    /// detection needs to be kept — confirm against the decode helpers.
    pub score_threshold: f32,
    /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
    /// models)
    pub nms: Option<configs::Nms>,
    /// Whether decoded boxes are in normalized [0,1] coordinates.
    /// - `Some(true)`: Coordinates in [0,1] range
    /// - `Some(false)`: Pixel coordinates
    /// - `None`: Unknown, caller must infer (e.g., check if any coordinate >
    ///   1.0)
    normalized: Option<bool>,
    /// Schema v2 merge program. Present when the decoder was built from
    /// a [`crate::schema::SchemaV2`] whose logical outputs carry
    /// physical children. Absent for flat configurations (v1 and
    /// flat-v2).
    ///
    /// When present, `decode` and `decode_proto` execute it first and feed
    /// its results through the float decode path.
    pub(crate) decode_program: Option<merge::DecodeProgram>,
}
34
35impl PartialEq for Decoder {
36    fn eq(&self, other: &Self) -> bool {
37        // DecodeProgram has non-comparable embedded data; compare by
38        // the config-derived fields only.
39        self.model_type == other.model_type
40            && self.iou_threshold == other.iou_threshold
41            && self.score_threshold == other.score_threshold
42            && self.nms == other.nms
43            && self.normalized == other.normalized
44            && self.decode_program.is_some() == other.decode_program.is_some()
45    }
46}
47
/// Borrowed, dynamic-dimensional array view tagged with its integer element
/// type.
///
/// One variant exists per supported element type, which lets a single slice
/// (for example the `outputs` argument of `Decoder::decode_quantized`) hold
/// views of different integer types.
#[derive(Debug)]
pub(crate) enum ArrayViewDQuantized<'a> {
    /// View over `u8` elements.
    UInt8(ArrayViewD<'a, u8>),
    /// View over `i8` elements.
    Int8(ArrayViewD<'a, i8>),
    /// View over `u16` elements.
    UInt16(ArrayViewD<'a, u16>),
    /// View over `i16` elements.
    Int16(ArrayViewD<'a, i16>),
    /// View over `u32` elements.
    UInt32(ArrayViewD<'a, u32>),
    /// View over `i32` elements.
    Int32(ArrayViewD<'a, i32>),
}
57
58impl<'a, D> From<ArrayView<'a, u8, D>> for ArrayViewDQuantized<'a>
59where
60    D: Dimension,
61{
62    fn from(arr: ArrayView<'a, u8, D>) -> Self {
63        Self::UInt8(arr.into_dyn())
64    }
65}
66
67impl<'a, D> From<ArrayView<'a, i8, D>> for ArrayViewDQuantized<'a>
68where
69    D: Dimension,
70{
71    fn from(arr: ArrayView<'a, i8, D>) -> Self {
72        Self::Int8(arr.into_dyn())
73    }
74}
75
76impl<'a, D> From<ArrayView<'a, u16, D>> for ArrayViewDQuantized<'a>
77where
78    D: Dimension,
79{
80    fn from(arr: ArrayView<'a, u16, D>) -> Self {
81        Self::UInt16(arr.into_dyn())
82    }
83}
84
85impl<'a, D> From<ArrayView<'a, i16, D>> for ArrayViewDQuantized<'a>
86where
87    D: Dimension,
88{
89    fn from(arr: ArrayView<'a, i16, D>) -> Self {
90        Self::Int16(arr.into_dyn())
91    }
92}
93
94impl<'a, D> From<ArrayView<'a, u32, D>> for ArrayViewDQuantized<'a>
95where
96    D: Dimension,
97{
98    fn from(arr: ArrayView<'a, u32, D>) -> Self {
99        Self::UInt32(arr.into_dyn())
100    }
101}
102
103impl<'a, D> From<ArrayView<'a, i32, D>> for ArrayViewDQuantized<'a>
104where
105    D: Dimension,
106{
107    fn from(arr: ArrayView<'a, i32, D>) -> Self {
108        Self::Int32(arr.into_dyn())
109    }
110}
111
112impl<'a> ArrayViewDQuantized<'a> {
113    /// Returns the shape of the underlying array.
114    pub(crate) fn shape(&self) -> &[usize] {
115        match self {
116            ArrayViewDQuantized::UInt8(a) => a.shape(),
117            ArrayViewDQuantized::Int8(a) => a.shape(),
118            ArrayViewDQuantized::UInt16(a) => a.shape(),
119            ArrayViewDQuantized::Int16(a) => a.shape(),
120            ArrayViewDQuantized::UInt32(a) => a.shape(),
121            ArrayViewDQuantized::Int32(a) => a.shape(),
122        }
123    }
124}
125
/// Runs `$body` against the typed view stored inside an
/// `ArrayViewDQuantized`.
///
/// `$x` is the value to match on, `$var` is the identifier that `$body` uses
/// to refer to the inner view, and `$body` is the expression evaluated for
/// the matching variant. `$body` is expanded once per variant, so it is
/// compiled separately for each of the six integer element types.
///
/// WARNING: Do NOT nest `with_quantized!` calls. Each level multiplies
/// monomorphized code paths by 6 (one per integer variant), so nesting
/// N levels deep produces 6^N instantiations.
///
/// Instead, dequantize each tensor sequentially with `dequant_3d!`/`dequant_4d!`
/// (6*N paths) or split into independent phases that each nest at most 2 levels.
macro_rules! with_quantized {
    ($x:expr, $var:ident, $body:expr) => {
        match $x {
            ArrayViewDQuantized::UInt8(x) => {
                // Rebind under the caller-chosen name so `$body` can see it.
                let $var = x;
                $body
            }
            ArrayViewDQuantized::Int8(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::UInt16(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::Int16(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::UInt32(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::Int32(x) => {
                let $var = x;
                $body
            }
        }
    };
}
162
163mod builder;
164mod dfl;
165mod helpers;
166mod merge;
167mod postprocess;
168mod tensor_bridge;
169mod tests;
170
171pub use builder::DecoderBuilder;
172pub use config::{ConfigOutput, ConfigOutputRef, ConfigOutputs};
173
174impl Decoder {
    /// Returns the model type this decoder was configured with.
    ///
    /// The returned reference borrows from the decoder; the `decode*`
    /// methods dispatch on this same value.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::ModelType};
    /// # fn main() -> DecoderResult<()> {
    /// #    let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
    ///     let decoder = DecoderBuilder::default()
    ///         .with_config_yaml_str(config_yaml)
    ///         .build()?;
    ///     assert!(matches!(
    ///         decoder.model_type(),
    ///         ModelType::ModelPackDetSplit { .. }
    ///     ));
    /// #    Ok(())
    /// # }
    /// ```
    pub fn model_type(&self) -> &ModelType {
        &self.model_type
    }
196
    /// Returns the box coordinate format, if it is known from the model
    /// config.
    ///
    /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
    /// - `Some(false)`: Boxes are in pixel coordinates relative to model input
    /// - `None`: Unknown, caller must infer (e.g., check if any coordinate >
    ///   1.0)
    ///
    /// This is determined by the model config's `normalized` field, not the NMS
    /// mode. When coordinates are in pixels or unknown, the caller may need
    /// to normalize using the model input dimensions.
    ///
    /// The value is returned as stored; this method performs no inference of
    /// its own.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
    /// # fn main() -> DecoderResult<()> {
    /// #    let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
    ///     let decoder = DecoderBuilder::default()
    ///         .with_config_yaml_str(config_yaml)
    ///         .build()?;
    ///     // Config doesn't specify normalized, so it's None
    ///     assert!(decoder.normalized_boxes().is_none());
    /// #    Ok(())
    /// # }
    /// ```
    pub fn normalized_boxes(&self) -> Option<bool> {
        self.normalized
    }
225
226    /// Decode quantized model outputs into detection boxes and segmentation
227    /// masks. The quantized outputs can be of u8, i8, u16, i16, u32, or i32
228    /// types. Clears the provided output vectors before populating them.
229    pub(crate) fn decode_quantized(
230        &self,
231        outputs: &[ArrayViewDQuantized],
232        output_boxes: &mut Vec<DetectBox>,
233        output_masks: &mut Vec<Segmentation>,
234    ) -> Result<(), DecoderError> {
235        output_boxes.clear();
236        output_masks.clear();
237        match &self.model_type {
238            ModelType::ModelPackSegDet {
239                boxes,
240                scores,
241                segmentation,
242            } => {
243                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
244                self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
245            }
246            ModelType::ModelPackSegDetSplit {
247                detection,
248                segmentation,
249            } => {
250                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
251                self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
252            }
253            ModelType::ModelPackDet { boxes, scores } => {
254                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)
255            }
256            ModelType::ModelPackDetSplit { detection } => {
257                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)
258            }
259            ModelType::ModelPackSeg { segmentation } => {
260                self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
261            }
262            ModelType::YoloDet { boxes } => {
263                self.decode_yolo_det_quantized(outputs, boxes, output_boxes)
264            }
265            ModelType::YoloSegDet { boxes, protos } => self.decode_yolo_segdet_quantized(
266                outputs,
267                boxes,
268                protos,
269                output_boxes,
270                output_masks,
271            ),
272            ModelType::YoloSplitDet { boxes, scores } => {
273                self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)
274            }
275            ModelType::YoloSplitSegDet {
276                boxes,
277                scores,
278                mask_coeff,
279                protos,
280            } => self.decode_yolo_split_segdet_quantized(
281                outputs,
282                boxes,
283                scores,
284                mask_coeff,
285                protos,
286                output_boxes,
287                output_masks,
288            ),
289            ModelType::YoloSegDet2Way {
290                boxes,
291                mask_coeff,
292                protos,
293            } => self.decode_yolo_segdet_2way_quantized(
294                outputs,
295                boxes,
296                mask_coeff,
297                protos,
298                output_boxes,
299                output_masks,
300            ),
301            ModelType::YoloEndToEndDet { boxes } => {
302                self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)
303            }
304            ModelType::YoloEndToEndSegDet { boxes, protos } => self
305                .decode_yolo_end_to_end_segdet_quantized(
306                    outputs,
307                    boxes,
308                    protos,
309                    output_boxes,
310                    output_masks,
311                ),
312            ModelType::YoloSplitEndToEndDet {
313                boxes,
314                scores,
315                classes,
316            } => self.decode_yolo_split_end_to_end_det_quantized(
317                outputs,
318                boxes,
319                scores,
320                classes,
321                output_boxes,
322            ),
323            ModelType::YoloSplitEndToEndSegDet {
324                boxes,
325                scores,
326                classes,
327                mask_coeff,
328                protos,
329            } => self.decode_yolo_split_end_to_end_segdet_quantized(
330                outputs,
331                boxes,
332                scores,
333                classes,
334                mask_coeff,
335                protos,
336                output_boxes,
337                output_masks,
338            ),
339        }
340    }
341
342    /// Decode floating point model outputs into detection boxes and
343    /// segmentation masks. Clears the provided output vectors before
344    /// populating them.
345    pub(crate) fn decode_float<T>(
346        &self,
347        outputs: &[ArrayViewD<T>],
348        output_boxes: &mut Vec<DetectBox>,
349        output_masks: &mut Vec<Segmentation>,
350    ) -> Result<(), DecoderError>
351    where
352        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
353        f32: AsPrimitive<T>,
354    {
355        output_boxes.clear();
356        output_masks.clear();
357        match &self.model_type {
358            ModelType::ModelPackSegDet {
359                boxes,
360                scores,
361                segmentation,
362            } => {
363                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
364                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
365            }
366            ModelType::ModelPackSegDetSplit {
367                detection,
368                segmentation,
369            } => {
370                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
371                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
372            }
373            ModelType::ModelPackDet { boxes, scores } => {
374                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
375            }
376            ModelType::ModelPackDetSplit { detection } => {
377                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
378            }
379            ModelType::ModelPackSeg { segmentation } => {
380                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
381            }
382            ModelType::YoloDet { boxes } => {
383                self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
384            }
385            ModelType::YoloSegDet { boxes, protos } => {
386                self.decode_yolo_segdet_float(outputs, boxes, protos, output_boxes, output_masks)?;
387            }
388            ModelType::YoloSplitDet { boxes, scores } => {
389                self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
390            }
391            ModelType::YoloSplitSegDet {
392                boxes,
393                scores,
394                mask_coeff,
395                protos,
396            } => {
397                self.decode_yolo_split_segdet_float(
398                    outputs,
399                    boxes,
400                    scores,
401                    mask_coeff,
402                    protos,
403                    output_boxes,
404                    output_masks,
405                )?;
406            }
407            ModelType::YoloSegDet2Way {
408                boxes,
409                mask_coeff,
410                protos,
411            } => {
412                self.decode_yolo_segdet_2way_float(
413                    outputs,
414                    boxes,
415                    mask_coeff,
416                    protos,
417                    output_boxes,
418                    output_masks,
419                )?;
420            }
421            ModelType::YoloEndToEndDet { boxes } => {
422                self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
423            }
424            ModelType::YoloEndToEndSegDet { boxes, protos } => {
425                self.decode_yolo_end_to_end_segdet_float(
426                    outputs,
427                    boxes,
428                    protos,
429                    output_boxes,
430                    output_masks,
431                )?;
432            }
433            ModelType::YoloSplitEndToEndDet {
434                boxes,
435                scores,
436                classes,
437            } => {
438                self.decode_yolo_split_end_to_end_det_float(
439                    outputs,
440                    boxes,
441                    scores,
442                    classes,
443                    output_boxes,
444                )?;
445            }
446            ModelType::YoloSplitEndToEndSegDet {
447                boxes,
448                scores,
449                classes,
450                mask_coeff,
451                protos,
452            } => {
453                self.decode_yolo_split_end_to_end_segdet_float(
454                    outputs,
455                    boxes,
456                    scores,
457                    classes,
458                    mask_coeff,
459                    protos,
460                    output_boxes,
461                    output_masks,
462                )?;
463            }
464        }
465        Ok(())
466    }
467
468    /// Decodes quantized model outputs into detection boxes, returning raw
469    /// `ProtoData` for segmentation models instead of materialized masks.
470    ///
471    /// Returns `Ok(None)` for detection-only and ModelPack models (detections
472    /// are still decoded into `output_boxes`). Returns `Ok(Some(ProtoData))`
473    /// for YOLO segmentation models.
474    pub(crate) fn decode_quantized_proto(
475        &self,
476        outputs: &[ArrayViewDQuantized],
477        output_boxes: &mut Vec<DetectBox>,
478    ) -> Result<Option<ProtoData>, DecoderError> {
479        output_boxes.clear();
480        match &self.model_type {
481            // Detection-only variants: decode boxes, return None for proto data.
482            ModelType::ModelPackDet { boxes, scores } => {
483                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
484                Ok(None)
485            }
486            ModelType::ModelPackDetSplit { detection } => {
487                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
488                Ok(None)
489            }
490            ModelType::YoloDet { boxes } => {
491                self.decode_yolo_det_quantized(outputs, boxes, output_boxes)?;
492                Ok(None)
493            }
494            ModelType::YoloSplitDet { boxes, scores } => {
495                self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)?;
496                Ok(None)
497            }
498            ModelType::YoloEndToEndDet { boxes } => {
499                self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)?;
500                Ok(None)
501            }
502            ModelType::YoloSplitEndToEndDet {
503                boxes,
504                scores,
505                classes,
506            } => {
507                self.decode_yolo_split_end_to_end_det_quantized(
508                    outputs,
509                    boxes,
510                    scores,
511                    classes,
512                    output_boxes,
513                )?;
514                Ok(None)
515            }
516            // ModelPack seg/segdet variants have no YOLO proto data.
517            ModelType::ModelPackSegDet { boxes, scores, .. } => {
518                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
519                Ok(None)
520            }
521            ModelType::ModelPackSegDetSplit { detection, .. } => {
522                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
523                Ok(None)
524            }
525            ModelType::ModelPackSeg { .. } => Ok(None),
526
527            ModelType::YoloSegDet { boxes, protos } => {
528                let proto =
529                    self.decode_yolo_segdet_quantized_proto(outputs, boxes, protos, output_boxes)?;
530                Ok(Some(proto))
531            }
532            ModelType::YoloSplitSegDet {
533                boxes,
534                scores,
535                mask_coeff,
536                protos,
537            } => {
538                let proto = self.decode_yolo_split_segdet_quantized_proto(
539                    outputs,
540                    boxes,
541                    scores,
542                    mask_coeff,
543                    protos,
544                    output_boxes,
545                )?;
546                Ok(Some(proto))
547            }
548            ModelType::YoloSegDet2Way {
549                boxes,
550                mask_coeff,
551                protos,
552            } => {
553                let proto = self.decode_yolo_segdet_2way_quantized_proto(
554                    outputs,
555                    boxes,
556                    mask_coeff,
557                    protos,
558                    output_boxes,
559                )?;
560                Ok(Some(proto))
561            }
562            ModelType::YoloEndToEndSegDet { boxes, protos } => {
563                let proto = self.decode_yolo_end_to_end_segdet_quantized_proto(
564                    outputs,
565                    boxes,
566                    protos,
567                    output_boxes,
568                )?;
569                Ok(Some(proto))
570            }
571            ModelType::YoloSplitEndToEndSegDet {
572                boxes,
573                scores,
574                classes,
575                mask_coeff,
576                protos,
577            } => {
578                let proto = self.decode_yolo_split_end_to_end_segdet_quantized_proto(
579                    outputs,
580                    boxes,
581                    scores,
582                    classes,
583                    mask_coeff,
584                    protos,
585                    output_boxes,
586                )?;
587                Ok(Some(proto))
588            }
589        }
590    }
591
592    /// Decodes floating-point model outputs into detection boxes, returning
593    /// raw `ProtoData` for segmentation models instead of materialized masks.
594    ///
595    /// Returns `Ok(None)` for detection-only and ModelPack models (detections
596    /// are still decoded into `output_boxes`). Returns `Ok(Some(ProtoData))`
597    /// for YOLO segmentation models.
598    pub(crate) fn decode_float_proto<T>(
599        &self,
600        outputs: &[ArrayViewD<T>],
601        output_boxes: &mut Vec<DetectBox>,
602    ) -> Result<Option<ProtoData>, DecoderError>
603    where
604        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + crate::yolo::FloatProtoElem,
605        f32: AsPrimitive<T>,
606    {
607        output_boxes.clear();
608        match &self.model_type {
609            // Detection-only variants: decode boxes, return None for proto data.
610            ModelType::ModelPackDet { boxes, scores } => {
611                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
612                Ok(None)
613            }
614            ModelType::ModelPackDetSplit { detection } => {
615                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
616                Ok(None)
617            }
618            ModelType::YoloDet { boxes } => {
619                self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
620                Ok(None)
621            }
622            ModelType::YoloSplitDet { boxes, scores } => {
623                self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
624                Ok(None)
625            }
626            ModelType::YoloEndToEndDet { boxes } => {
627                self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
628                Ok(None)
629            }
630            ModelType::YoloSplitEndToEndDet {
631                boxes,
632                scores,
633                classes,
634            } => {
635                self.decode_yolo_split_end_to_end_det_float(
636                    outputs,
637                    boxes,
638                    scores,
639                    classes,
640                    output_boxes,
641                )?;
642                Ok(None)
643            }
644            // ModelPack seg/segdet variants have no YOLO proto data.
645            ModelType::ModelPackSegDet { boxes, scores, .. } => {
646                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
647                Ok(None)
648            }
649            ModelType::ModelPackSegDetSplit { detection, .. } => {
650                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
651                Ok(None)
652            }
653            ModelType::ModelPackSeg { .. } => Ok(None),
654
655            ModelType::YoloSegDet { boxes, protos } => {
656                let proto =
657                    self.decode_yolo_segdet_float_proto(outputs, boxes, protos, output_boxes)?;
658                Ok(Some(proto))
659            }
660            ModelType::YoloSplitSegDet {
661                boxes,
662                scores,
663                mask_coeff,
664                protos,
665            } => {
666                let proto = self.decode_yolo_split_segdet_float_proto(
667                    outputs,
668                    boxes,
669                    scores,
670                    mask_coeff,
671                    protos,
672                    output_boxes,
673                )?;
674                Ok(Some(proto))
675            }
676            ModelType::YoloSegDet2Way {
677                boxes,
678                mask_coeff,
679                protos,
680            } => {
681                let proto = self.decode_yolo_segdet_2way_float_proto(
682                    outputs,
683                    boxes,
684                    mask_coeff,
685                    protos,
686                    output_boxes,
687                )?;
688                Ok(Some(proto))
689            }
690            ModelType::YoloEndToEndSegDet { boxes, protos } => {
691                let proto = self.decode_yolo_end_to_end_segdet_float_proto(
692                    outputs,
693                    boxes,
694                    protos,
695                    output_boxes,
696                )?;
697                Ok(Some(proto))
698            }
699            ModelType::YoloSplitEndToEndSegDet {
700                boxes,
701                scores,
702                classes,
703                mask_coeff,
704                protos,
705            } => {
706                let proto = self.decode_yolo_split_end_to_end_segdet_float_proto(
707                    outputs,
708                    boxes,
709                    scores,
710                    classes,
711                    mask_coeff,
712                    protos,
713                    output_boxes,
714                )?;
715                Ok(Some(proto))
716            }
717        }
718    }
719
720    // ========================================================================
721    // TensorDyn-based public API
722    // ========================================================================
723
724    /// Decode model outputs into detection boxes and segmentation masks.
725    ///
726    /// This is the primary decode API. Accepts `TensorDyn` outputs directly
727    /// from model inference. Automatically dispatches to quantized or float
728    /// paths based on the tensor dtype.
729    ///
730    /// # Arguments
731    ///
732    /// * `outputs` - Tensor outputs from model inference
733    /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
734    /// * `output_masks` - Destination for decoded segmentation masks (cleared first)
735    ///
736    /// # Errors
737    ///
738    /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
739    /// or the outputs don't match the decoder's model configuration.
740    pub fn decode(
741        &self,
742        outputs: &[&edgefirst_tensor::TensorDyn],
743        output_boxes: &mut Vec<DetectBox>,
744        output_masks: &mut Vec<Segmentation>,
745    ) -> Result<(), DecoderError> {
746        // Schema v2 merge path: dequantize physical children into
747        // logical float32 tensors, then feed through the float dispatch.
748        if let Some(program) = &self.decode_program {
749            let merged = program.execute(outputs)?;
750            let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
751            return self.decode_float(&views, output_boxes, output_masks);
752        }
753
754        let mapped = tensor_bridge::map_tensors(outputs)?;
755        match &mapped {
756            tensor_bridge::MappedOutputs::Quantized(maps) => {
757                let views = tensor_bridge::quantized_views(maps)?;
758                self.decode_quantized(&views, output_boxes, output_masks)
759            }
760            tensor_bridge::MappedOutputs::Float16(maps) => {
761                let views = tensor_bridge::f16_views(maps)?;
762                self.decode_float(&views, output_boxes, output_masks)
763            }
764            tensor_bridge::MappedOutputs::Float32(maps) => {
765                let views = tensor_bridge::f32_views(maps)?;
766                self.decode_float(&views, output_boxes, output_masks)
767            }
768            tensor_bridge::MappedOutputs::Float64(maps) => {
769                let views = tensor_bridge::f64_views(maps)?;
770                self.decode_float(&views, output_boxes, output_masks)
771            }
772        }
773    }
774
775    /// Decode model outputs into detection boxes, returning raw proto data
776    /// for segmentation models instead of materialized masks.
777    ///
778    /// Accepts `TensorDyn` outputs directly from model inference.
779    /// Detections are always decoded into `output_boxes` regardless of model type.
780    /// Returns `Ok(None)` for detection-only and ModelPack models.
781    /// Returns `Ok(Some(ProtoData))` for YOLO segmentation models.
782    ///
783    /// # Arguments
784    ///
785    /// * `outputs` - Tensor outputs from model inference
786    /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
787    ///
788    /// # Errors
789    ///
790    /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
791    /// or the outputs don't match the decoder's model configuration.
792    pub fn decode_proto(
793        &self,
794        outputs: &[&edgefirst_tensor::TensorDyn],
795        output_boxes: &mut Vec<DetectBox>,
796    ) -> Result<Option<ProtoData>, DecoderError> {
797        // Schema v2 merge path: dequantize physical children into
798        // logical float32 tensors, then feed through the float dispatch.
799        if let Some(program) = &self.decode_program {
800            let merged = program.execute(outputs)?;
801            let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
802            return self.decode_float_proto(&views, output_boxes);
803        }
804
805        let mapped = tensor_bridge::map_tensors(outputs)?;
806        match &mapped {
807            tensor_bridge::MappedOutputs::Quantized(maps) => {
808                let views = tensor_bridge::quantized_views(maps)?;
809                self.decode_quantized_proto(&views, output_boxes)
810            }
811            tensor_bridge::MappedOutputs::Float16(maps) => {
812                let views = tensor_bridge::f16_views(maps)?;
813                self.decode_float_proto(&views, output_boxes)
814            }
815            tensor_bridge::MappedOutputs::Float32(maps) => {
816                let views = tensor_bridge::f32_views(maps)?;
817                self.decode_float_proto(&views, output_boxes)
818            }
819            tensor_bridge::MappedOutputs::Float64(maps) => {
820                let views = tensor_bridge::f64_views(maps)?;
821                self.decode_float_proto(&views, output_boxes)
822            }
823        }
824    }
825}
826
827#[cfg(feature = "tracker")]
828pub use edgefirst_tracker::TrackInfo;
829
830#[cfg(feature = "tracker")]
831pub use edgefirst_tracker::Tracker;
832
833#[cfg(feature = "tracker")]
834impl Decoder {
835    /// Decode quantized model outputs into detection boxes and segmentation
836    /// masks with tracking. Clears the provided output vectors before
837    /// populating them.
838    pub(crate) fn decode_tracked_quantized<TR: edgefirst_tracker::Tracker<DetectBox>>(
839        &self,
840        tracker: &mut TR,
841        timestamp: u64,
842        outputs: &[ArrayViewDQuantized],
843        output_boxes: &mut Vec<DetectBox>,
844        output_masks: &mut Vec<Segmentation>,
845        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
846    ) -> Result<(), DecoderError> {
847        output_boxes.clear();
848        output_masks.clear();
849        output_tracks.clear();
850
851        // yolo segdet variants require special handling to separate boxes that come from decoding vs active tracks.
852        // Only boxes that come from decoding can be used for proto/mask generation.
853        match &self.model_type {
854            ModelType::YoloSegDet { boxes, protos } => self.decode_tracked_yolo_segdet_quantized(
855                tracker,
856                timestamp,
857                outputs,
858                boxes,
859                protos,
860                output_boxes,
861                output_masks,
862                output_tracks,
863            ),
864            ModelType::YoloSplitSegDet {
865                boxes,
866                scores,
867                mask_coeff,
868                protos,
869            } => self.decode_tracked_yolo_split_segdet_quantized(
870                tracker,
871                timestamp,
872                outputs,
873                boxes,
874                scores,
875                mask_coeff,
876                protos,
877                output_boxes,
878                output_masks,
879                output_tracks,
880            ),
881            ModelType::YoloEndToEndSegDet { boxes, protos } => self
882                .decode_tracked_yolo_end_to_end_segdet_quantized(
883                    tracker,
884                    timestamp,
885                    outputs,
886                    boxes,
887                    protos,
888                    output_boxes,
889                    output_masks,
890                    output_tracks,
891                ),
892            ModelType::YoloSplitEndToEndSegDet {
893                boxes,
894                scores,
895                classes,
896                mask_coeff,
897                protos,
898            } => self.decode_tracked_yolo_split_end_to_end_segdet_quantized(
899                tracker,
900                timestamp,
901                outputs,
902                boxes,
903                scores,
904                classes,
905                mask_coeff,
906                protos,
907                output_boxes,
908                output_masks,
909                output_tracks,
910            ),
911            ModelType::YoloSegDet2Way {
912                boxes,
913                mask_coeff,
914                protos,
915            } => self.decode_tracked_yolo_segdet_2way_quantized(
916                tracker,
917                timestamp,
918                outputs,
919                boxes,
920                mask_coeff,
921                protos,
922                output_boxes,
923                output_masks,
924                output_tracks,
925            ),
926            _ => {
927                self.decode_quantized(outputs, output_boxes, output_masks)?;
928                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
929                Ok(())
930            }
931        }
932    }
933
934    /// This function decodes floating point model outputs into detection boxes
935    /// and segmentation masks. Up to `output_boxes.capacity()` boxes and
936    /// masks will be decoded. The function clears the provided output
937    /// vectors before populating them with the decoded results.
938    ///
939    /// This function returns an `Error` if the provided outputs don't
940    /// match the configuration provided by the user when building the decoder.
941    ///
942    /// Any quantization information in the configuration will be ignored.
943    pub(crate) fn decode_tracked_float<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
944        &self,
945        tracker: &mut TR,
946        timestamp: u64,
947        outputs: &[ArrayViewD<T>],
948        output_boxes: &mut Vec<DetectBox>,
949        output_masks: &mut Vec<Segmentation>,
950        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
951    ) -> Result<(), DecoderError>
952    where
953        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
954        f32: AsPrimitive<T>,
955    {
956        output_boxes.clear();
957        output_masks.clear();
958        output_tracks.clear();
959        match &self.model_type {
960            ModelType::YoloSegDet { boxes, protos } => {
961                self.decode_tracked_yolo_segdet_float(
962                    tracker,
963                    timestamp,
964                    outputs,
965                    boxes,
966                    protos,
967                    output_boxes,
968                    output_masks,
969                    output_tracks,
970                )?;
971            }
972            ModelType::YoloSplitSegDet {
973                boxes,
974                scores,
975                mask_coeff,
976                protos,
977            } => {
978                self.decode_tracked_yolo_split_segdet_float(
979                    tracker,
980                    timestamp,
981                    outputs,
982                    boxes,
983                    scores,
984                    mask_coeff,
985                    protos,
986                    output_boxes,
987                    output_masks,
988                    output_tracks,
989                )?;
990            }
991            ModelType::YoloEndToEndSegDet { boxes, protos } => {
992                self.decode_tracked_yolo_end_to_end_segdet_float(
993                    tracker,
994                    timestamp,
995                    outputs,
996                    boxes,
997                    protos,
998                    output_boxes,
999                    output_masks,
1000                    output_tracks,
1001                )?;
1002            }
1003            ModelType::YoloSplitEndToEndSegDet {
1004                boxes,
1005                scores,
1006                classes,
1007                mask_coeff,
1008                protos,
1009            } => {
1010                self.decode_tracked_yolo_split_end_to_end_segdet_float(
1011                    tracker,
1012                    timestamp,
1013                    outputs,
1014                    boxes,
1015                    scores,
1016                    classes,
1017                    mask_coeff,
1018                    protos,
1019                    output_boxes,
1020                    output_masks,
1021                    output_tracks,
1022                )?;
1023            }
1024            ModelType::YoloSegDet2Way {
1025                boxes,
1026                mask_coeff,
1027                protos,
1028            } => {
1029                self.decode_tracked_yolo_segdet_2way_float(
1030                    tracker,
1031                    timestamp,
1032                    outputs,
1033                    boxes,
1034                    mask_coeff,
1035                    protos,
1036                    output_boxes,
1037                    output_masks,
1038                    output_tracks,
1039                )?;
1040            }
1041            _ => {
1042                self.decode_float(outputs, output_boxes, output_masks)?;
1043                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1044            }
1045        }
1046        Ok(())
1047    }
1048
1049    /// Decodes quantized model outputs into detection boxes, returning raw
1050    /// `ProtoData` for segmentation models instead of materialized masks.
1051    ///
1052    /// Returns `Ok(None)` for detection-only and ModelPack models (use
1053    /// `decode_quantized` for those). Returns `Ok(Some(ProtoData))` for
1054    /// YOLO segmentation models.
1055    pub(crate) fn decode_tracked_quantized_proto<TR: edgefirst_tracker::Tracker<DetectBox>>(
1056        &self,
1057        tracker: &mut TR,
1058        timestamp: u64,
1059        outputs: &[ArrayViewDQuantized],
1060        output_boxes: &mut Vec<DetectBox>,
1061        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1062    ) -> Result<Option<ProtoData>, DecoderError> {
1063        output_boxes.clear();
1064        output_tracks.clear();
1065        match &self.model_type {
1066            ModelType::YoloSegDet { boxes, protos } => {
1067                let proto = self.decode_tracked_yolo_segdet_quantized_proto(
1068                    tracker,
1069                    timestamp,
1070                    outputs,
1071                    boxes,
1072                    protos,
1073                    output_boxes,
1074                    output_tracks,
1075                )?;
1076                Ok(Some(proto))
1077            }
1078            ModelType::YoloSplitSegDet {
1079                boxes,
1080                scores,
1081                mask_coeff,
1082                protos,
1083            } => {
1084                let proto = self.decode_tracked_yolo_split_segdet_quantized_proto(
1085                    tracker,
1086                    timestamp,
1087                    outputs,
1088                    boxes,
1089                    scores,
1090                    mask_coeff,
1091                    protos,
1092                    output_boxes,
1093                    output_tracks,
1094                )?;
1095                Ok(Some(proto))
1096            }
1097            ModelType::YoloSegDet2Way {
1098                boxes,
1099                mask_coeff,
1100                protos,
1101            } => {
1102                let proto = self.decode_tracked_yolo_segdet_2way_quantized_proto(
1103                    tracker,
1104                    timestamp,
1105                    outputs,
1106                    boxes,
1107                    mask_coeff,
1108                    protos,
1109                    output_boxes,
1110                    output_tracks,
1111                )?;
1112                Ok(Some(proto))
1113            }
1114            ModelType::YoloEndToEndSegDet { boxes, protos } => {
1115                let proto = self.decode_tracked_yolo_end_to_end_segdet_quantized_proto(
1116                    tracker,
1117                    timestamp,
1118                    outputs,
1119                    boxes,
1120                    protos,
1121                    output_boxes,
1122                    output_tracks,
1123                )?;
1124                Ok(Some(proto))
1125            }
1126            ModelType::YoloSplitEndToEndSegDet {
1127                boxes,
1128                scores,
1129                classes,
1130                mask_coeff,
1131                protos,
1132            } => {
1133                let proto = self.decode_tracked_yolo_split_end_to_end_segdet_quantized_proto(
1134                    tracker,
1135                    timestamp,
1136                    outputs,
1137                    boxes,
1138                    scores,
1139                    classes,
1140                    mask_coeff,
1141                    protos,
1142                    output_boxes,
1143                    output_tracks,
1144                )?;
1145                Ok(Some(proto))
1146            }
1147            // Non-seg variants: decode boxes via the non-proto path, then track.
1148            _ => {
1149                let mut masks = Vec::new();
1150                self.decode_quantized(outputs, output_boxes, &mut masks)?;
1151                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1152                Ok(None)
1153            }
1154        }
1155    }
1156
1157    /// Decodes floating-point model outputs into detection boxes, returning
1158    /// raw `ProtoData` for segmentation models instead of materialized masks.
1159    ///
1160    /// Detections are always decoded into `output_boxes` regardless of model type.
1161    /// Returns `Ok(None)` for detection-only and ModelPack models. Returns
1162    /// `Ok(Some(ProtoData))` for YOLO segmentation models.
1163    pub(crate) fn decode_tracked_float_proto<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
1164        &self,
1165        tracker: &mut TR,
1166        timestamp: u64,
1167        outputs: &[ArrayViewD<T>],
1168        output_boxes: &mut Vec<DetectBox>,
1169        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1170    ) -> Result<Option<ProtoData>, DecoderError>
1171    where
1172        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + crate::yolo::FloatProtoElem,
1173        f32: AsPrimitive<T>,
1174    {
1175        output_boxes.clear();
1176        output_tracks.clear();
1177        match &self.model_type {
1178            ModelType::YoloSegDet { boxes, protos } => {
1179                let proto = self.decode_tracked_yolo_segdet_float_proto(
1180                    tracker,
1181                    timestamp,
1182                    outputs,
1183                    boxes,
1184                    protos,
1185                    output_boxes,
1186                    output_tracks,
1187                )?;
1188                Ok(Some(proto))
1189            }
1190            ModelType::YoloSplitSegDet {
1191                boxes,
1192                scores,
1193                mask_coeff,
1194                protos,
1195            } => {
1196                let proto = self.decode_tracked_yolo_split_segdet_float_proto(
1197                    tracker,
1198                    timestamp,
1199                    outputs,
1200                    boxes,
1201                    scores,
1202                    mask_coeff,
1203                    protos,
1204                    output_boxes,
1205                    output_tracks,
1206                )?;
1207                Ok(Some(proto))
1208            }
1209            ModelType::YoloSegDet2Way {
1210                boxes,
1211                mask_coeff,
1212                protos,
1213            } => {
1214                let proto = self.decode_tracked_yolo_segdet_2way_float_proto(
1215                    tracker,
1216                    timestamp,
1217                    outputs,
1218                    boxes,
1219                    mask_coeff,
1220                    protos,
1221                    output_boxes,
1222                    output_tracks,
1223                )?;
1224                Ok(Some(proto))
1225            }
1226            ModelType::YoloEndToEndSegDet { boxes, protos } => {
1227                let proto = self.decode_tracked_yolo_end_to_end_segdet_float_proto(
1228                    tracker,
1229                    timestamp,
1230                    outputs,
1231                    boxes,
1232                    protos,
1233                    output_boxes,
1234                    output_tracks,
1235                )?;
1236                Ok(Some(proto))
1237            }
1238            ModelType::YoloSplitEndToEndSegDet {
1239                boxes,
1240                scores,
1241                classes,
1242                mask_coeff,
1243                protos,
1244            } => {
1245                let proto = self.decode_tracked_yolo_split_end_to_end_segdet_float_proto(
1246                    tracker,
1247                    timestamp,
1248                    outputs,
1249                    boxes,
1250                    scores,
1251                    classes,
1252                    mask_coeff,
1253                    protos,
1254                    output_boxes,
1255                    output_tracks,
1256                )?;
1257                Ok(Some(proto))
1258            }
1259            // Non-seg variants: decode boxes via the non-proto path, then track.
1260            _ => {
1261                let mut masks = Vec::new();
1262                self.decode_float(outputs, output_boxes, &mut masks)?;
1263                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1264                Ok(None)
1265            }
1266        }
1267    }
1268
1269    // ========================================================================
1270    // TensorDyn-based tracked public API
1271    // ========================================================================
1272
1273    /// Decode model outputs with tracking.
1274    ///
1275    /// Accepts `TensorDyn` outputs directly from model inference. Automatically
1276    /// dispatches to quantized or float paths based on the tensor dtype, then
1277    /// updates the tracker with the decoded boxes.
1278    ///
1279    /// # Arguments
1280    ///
1281    /// * `tracker` - The tracker instance to update
1282    /// * `timestamp` - Current frame timestamp
1283    /// * `outputs` - Tensor outputs from model inference
1284    /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
1285    /// * `output_masks` - Destination for decoded segmentation masks (cleared first)
1286    /// * `output_tracks` - Destination for track info (cleared first)
1287    ///
1288    /// # Errors
1289    ///
1290    /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
1291    /// or the outputs don't match the decoder's model configuration.
1292    pub fn decode_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
1293        &self,
1294        tracker: &mut TR,
1295        timestamp: u64,
1296        outputs: &[&edgefirst_tensor::TensorDyn],
1297        output_boxes: &mut Vec<DetectBox>,
1298        output_masks: &mut Vec<Segmentation>,
1299        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1300    ) -> Result<(), DecoderError> {
1301        let mapped = tensor_bridge::map_tensors(outputs)?;
1302        match &mapped {
1303            tensor_bridge::MappedOutputs::Quantized(maps) => {
1304                let views = tensor_bridge::quantized_views(maps)?;
1305                self.decode_tracked_quantized(
1306                    tracker,
1307                    timestamp,
1308                    &views,
1309                    output_boxes,
1310                    output_masks,
1311                    output_tracks,
1312                )
1313            }
1314            tensor_bridge::MappedOutputs::Float16(maps) => {
1315                let views = tensor_bridge::f16_views(maps)?;
1316                self.decode_tracked_float(
1317                    tracker,
1318                    timestamp,
1319                    &views,
1320                    output_boxes,
1321                    output_masks,
1322                    output_tracks,
1323                )
1324            }
1325            tensor_bridge::MappedOutputs::Float32(maps) => {
1326                let views = tensor_bridge::f32_views(maps)?;
1327                self.decode_tracked_float(
1328                    tracker,
1329                    timestamp,
1330                    &views,
1331                    output_boxes,
1332                    output_masks,
1333                    output_tracks,
1334                )
1335            }
1336            tensor_bridge::MappedOutputs::Float64(maps) => {
1337                let views = tensor_bridge::f64_views(maps)?;
1338                self.decode_tracked_float(
1339                    tracker,
1340                    timestamp,
1341                    &views,
1342                    output_boxes,
1343                    output_masks,
1344                    output_tracks,
1345                )
1346            }
1347        }
1348    }
1349
1350    /// Decode model outputs with tracking, returning raw proto data for
1351    /// segmentation models.
1352    ///
1353    /// Accepts `TensorDyn` outputs directly from model inference.
1354    /// Returns `Ok(None)` for detection-only and ModelPack models.
1355    /// Returns `Ok(Some(ProtoData))` for YOLO segmentation models.
1356    ///
1357    /// # Arguments
1358    ///
1359    /// * `tracker` - The tracker instance to update
1360    /// * `timestamp` - Current frame timestamp
1361    /// * `outputs` - Tensor outputs from model inference
1362    /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
1363    /// * `output_tracks` - Destination for track info (cleared first)
1364    ///
1365    /// # Errors
1366    ///
1367    /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
1368    /// or the outputs don't match the decoder's model configuration.
1369    pub fn decode_proto_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
1370        &self,
1371        tracker: &mut TR,
1372        timestamp: u64,
1373        outputs: &[&edgefirst_tensor::TensorDyn],
1374        output_boxes: &mut Vec<DetectBox>,
1375        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1376    ) -> Result<Option<ProtoData>, DecoderError> {
1377        let mapped = tensor_bridge::map_tensors(outputs)?;
1378        match &mapped {
1379            tensor_bridge::MappedOutputs::Quantized(maps) => {
1380                let views = tensor_bridge::quantized_views(maps)?;
1381                self.decode_tracked_quantized_proto(
1382                    tracker,
1383                    timestamp,
1384                    &views,
1385                    output_boxes,
1386                    output_tracks,
1387                )
1388            }
1389            tensor_bridge::MappedOutputs::Float16(maps) => {
1390                let views = tensor_bridge::f16_views(maps)?;
1391                self.decode_tracked_float_proto(
1392                    tracker,
1393                    timestamp,
1394                    &views,
1395                    output_boxes,
1396                    output_tracks,
1397                )
1398            }
1399            tensor_bridge::MappedOutputs::Float32(maps) => {
1400                let views = tensor_bridge::f32_views(maps)?;
1401                self.decode_tracked_float_proto(
1402                    tracker,
1403                    timestamp,
1404                    &views,
1405                    output_boxes,
1406                    output_tracks,
1407                )
1408            }
1409            tensor_bridge::MappedOutputs::Float64(maps) => {
1410                let views = tensor_bridge::f64_views(maps)?;
1411                self.decode_tracked_float_proto(
1412                    tracker,
1413                    timestamp,
1414                    &views,
1415                    output_boxes,
1416                    output_tracks,
1417                )
1418            }
1419        }
1420    }
1421}