Skip to main content

edgefirst_decoder/decoder/
mod.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use ndarray::{ArrayView, ArrayViewD, Dimension};
5use num_traits::{AsPrimitive, Float};
6
7use crate::{DecoderError, DetectBox, ProtoData, Segmentation};
8
9pub mod config;
10pub mod configs;
11
12use configs::ModelType;
13
/// Post-processing decoder that turns raw model output tensors into
/// detection boxes and, for segmentation models, masks or prototype data.
///
/// Which decode routine runs is decided by the configured model type (see
/// `model_type()` and the `decode*` methods). Instances can be built with
/// `DecoderBuilder`, as shown in the examples on `model_type()`.
#[derive(Debug, Clone, PartialEq)]
pub struct Decoder {
    /// Parsed model output layout. Every `decode*` method matches on this
    /// value to select the decode routine; read it via `model_type()`.
    model_type: ModelType,
    /// Intersection-over-union threshold.
    // NOTE(review): presumably consumed by the NMS step inside the decode
    // helpers (not visible in this file) — confirm.
    pub iou_threshold: f32,
    /// Detection score threshold.
    // NOTE(review): presumably detections scoring below this are discarded by
    // the decode helpers — confirm the exact comparison used.
    pub score_threshold: f32,
    /// NMS mode: `Some(mode)` applies NMS, `None` bypasses NMS (for
    /// end-to-end models).
    pub nms: Option<configs::Nms>,
    /// Whether decoded boxes are in normalized [0,1] coordinates; read it via
    /// `normalized_boxes()`.
    /// - `Some(true)`: Coordinates in [0,1] range
    /// - `Some(false)`: Pixel coordinates
    /// - `None`: Unknown, caller must infer (e.g., check if any coordinate >
    ///   1.0)
    normalized: Option<bool>,
}
29
/// Dynamic-dimension `ndarray` view over integer model output, tagged with
/// its element type.
///
/// Each variant wraps an `ArrayViewD` of one integer type (`u8`, `i8`,
/// `u16`, `i16`, `u32`, `i32`). Typed views are converted into it through
/// the `From` impls in this module. Code that needs the concrete element
/// type dispatches over the variants with `with_quantized!`.
#[derive(Debug)]
pub(crate) enum ArrayViewDQuantized<'a> {
    /// Unsigned 8-bit elements.
    UInt8(ArrayViewD<'a, u8>),
    /// Signed 8-bit elements.
    Int8(ArrayViewD<'a, i8>),
    /// Unsigned 16-bit elements.
    UInt16(ArrayViewD<'a, u16>),
    /// Signed 16-bit elements.
    Int16(ArrayViewD<'a, i16>),
    /// Unsigned 32-bit elements.
    UInt32(ArrayViewD<'a, u32>),
    /// Signed 32-bit elements.
    Int32(ArrayViewD<'a, i32>),
}
39
40impl<'a, D> From<ArrayView<'a, u8, D>> for ArrayViewDQuantized<'a>
41where
42    D: Dimension,
43{
44    fn from(arr: ArrayView<'a, u8, D>) -> Self {
45        Self::UInt8(arr.into_dyn())
46    }
47}
48
49impl<'a, D> From<ArrayView<'a, i8, D>> for ArrayViewDQuantized<'a>
50where
51    D: Dimension,
52{
53    fn from(arr: ArrayView<'a, i8, D>) -> Self {
54        Self::Int8(arr.into_dyn())
55    }
56}
57
58impl<'a, D> From<ArrayView<'a, u16, D>> for ArrayViewDQuantized<'a>
59where
60    D: Dimension,
61{
62    fn from(arr: ArrayView<'a, u16, D>) -> Self {
63        Self::UInt16(arr.into_dyn())
64    }
65}
66
67impl<'a, D> From<ArrayView<'a, i16, D>> for ArrayViewDQuantized<'a>
68where
69    D: Dimension,
70{
71    fn from(arr: ArrayView<'a, i16, D>) -> Self {
72        Self::Int16(arr.into_dyn())
73    }
74}
75
76impl<'a, D> From<ArrayView<'a, u32, D>> for ArrayViewDQuantized<'a>
77where
78    D: Dimension,
79{
80    fn from(arr: ArrayView<'a, u32, D>) -> Self {
81        Self::UInt32(arr.into_dyn())
82    }
83}
84
85impl<'a, D> From<ArrayView<'a, i32, D>> for ArrayViewDQuantized<'a>
86where
87    D: Dimension,
88{
89    fn from(arr: ArrayView<'a, i32, D>) -> Self {
90        Self::Int32(arr.into_dyn())
91    }
92}
93
94impl<'a> ArrayViewDQuantized<'a> {
95    /// Returns the shape of the underlying array.
96    pub(crate) fn shape(&self) -> &[usize] {
97        match self {
98            ArrayViewDQuantized::UInt8(a) => a.shape(),
99            ArrayViewDQuantized::Int8(a) => a.shape(),
100            ArrayViewDQuantized::UInt16(a) => a.shape(),
101            ArrayViewDQuantized::Int16(a) => a.shape(),
102            ArrayViewDQuantized::UInt32(a) => a.shape(),
103            ArrayViewDQuantized::Int32(a) => a.shape(),
104        }
105    }
106}
107
/// Runs `$body` with `$var` bound to the typed view inside an
/// `ArrayViewDQuantized`, whatever its integer element type.
///
/// `$body` is expanded once per variant (six copies), so each copy is
/// monomorphized for its own element type.
///
/// WARNING: never nest `with_quantized!` invocations. Every nesting level
/// multiplies the generated code paths by 6, so N levels yield 6^N
/// instantiations. Dequantize tensors one after another with
/// `dequant_3d!`/`dequant_4d!` instead (6*N paths), or split the work into
/// independent phases that nest at most 2 levels deep.
macro_rules! with_quantized {
    ($x:expr, $var:ident, $body:expr) => {
        match $x {
            // The caller's identifier is bound directly by each pattern, so
            // `$body` sees it with the concrete view type of that arm.
            ArrayViewDQuantized::UInt8($var) => $body,
            ArrayViewDQuantized::Int8($var) => $body,
            ArrayViewDQuantized::UInt16($var) => $body,
            ArrayViewDQuantized::Int16($var) => $body,
            ArrayViewDQuantized::UInt32($var) => $body,
            ArrayViewDQuantized::Int32($var) => $body,
        }
    };
}
144
145mod builder;
146mod helpers;
147mod postprocess;
148mod tensor_bridge;
149mod tests;
150
151pub use builder::DecoderBuilder;
152pub use config::{ConfigOutput, ConfigOutputRef, ConfigOutputs};
153
154impl Decoder {
155    /// This function returns the parsed model type of the decoder.
156    ///
157    /// # Examples
158    ///
159    /// ```rust
160    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::ModelType};
161    /// # fn main() -> DecoderResult<()> {
162    /// #    let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
163    ///     let decoder = DecoderBuilder::default()
164    ///         .with_config_yaml_str(config_yaml)
165    ///         .build()?;
166    ///     assert!(matches!(
167    ///         decoder.model_type(),
168    ///         ModelType::ModelPackDetSplit { .. }
169    ///     ));
170    /// #    Ok(())
171    /// # }
172    /// ```
173    pub fn model_type(&self) -> &ModelType {
174        &self.model_type
175    }
176
177    /// Returns the box coordinate format if known from the model config.
178    ///
179    /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
180    /// - `Some(false)`: Boxes are in pixel coordinates relative to model input
181    /// - `None`: Unknown, caller must infer (e.g., check if any coordinate >
182    ///   1.0)
183    ///
184    /// This is determined by the model config's `normalized` field, not the NMS
185    /// mode. When coordinates are in pixels or unknown, the caller may need
186    /// to normalize using the model input dimensions.
187    ///
188    /// # Examples
189    ///
190    /// ```rust
191    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
192    /// # fn main() -> DecoderResult<()> {
193    /// #    let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
194    ///     let decoder = DecoderBuilder::default()
195    ///         .with_config_yaml_str(config_yaml)
196    ///         .build()?;
197    ///     // Config doesn't specify normalized, so it's None
198    ///     assert!(decoder.normalized_boxes().is_none());
199    /// #    Ok(())
200    /// # }
201    /// ```
202    pub fn normalized_boxes(&self) -> Option<bool> {
203        self.normalized
204    }
205
206    /// Decode quantized model outputs into detection boxes and segmentation
207    /// masks. The quantized outputs can be of u8, i8, u16, i16, u32, or i32
208    /// types. Clears the provided output vectors before populating them.
209    pub(crate) fn decode_quantized(
210        &self,
211        outputs: &[ArrayViewDQuantized],
212        output_boxes: &mut Vec<DetectBox>,
213        output_masks: &mut Vec<Segmentation>,
214    ) -> Result<(), DecoderError> {
215        output_boxes.clear();
216        output_masks.clear();
217        match &self.model_type {
218            ModelType::ModelPackSegDet {
219                boxes,
220                scores,
221                segmentation,
222            } => {
223                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
224                self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
225            }
226            ModelType::ModelPackSegDetSplit {
227                detection,
228                segmentation,
229            } => {
230                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
231                self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
232            }
233            ModelType::ModelPackDet { boxes, scores } => {
234                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)
235            }
236            ModelType::ModelPackDetSplit { detection } => {
237                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)
238            }
239            ModelType::ModelPackSeg { segmentation } => {
240                self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
241            }
242            ModelType::YoloDet { boxes } => {
243                self.decode_yolo_det_quantized(outputs, boxes, output_boxes)
244            }
245            ModelType::YoloSegDet { boxes, protos } => self.decode_yolo_segdet_quantized(
246                outputs,
247                boxes,
248                protos,
249                output_boxes,
250                output_masks,
251            ),
252            ModelType::YoloSplitDet { boxes, scores } => {
253                self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)
254            }
255            ModelType::YoloSplitSegDet {
256                boxes,
257                scores,
258                mask_coeff,
259                protos,
260            } => self.decode_yolo_split_segdet_quantized(
261                outputs,
262                boxes,
263                scores,
264                mask_coeff,
265                protos,
266                output_boxes,
267                output_masks,
268            ),
269            ModelType::YoloSegDet2Way {
270                boxes,
271                mask_coeff,
272                protos,
273            } => self.decode_yolo_segdet_2way_quantized(
274                outputs,
275                boxes,
276                mask_coeff,
277                protos,
278                output_boxes,
279                output_masks,
280            ),
281            ModelType::YoloEndToEndDet { boxes } => {
282                self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)
283            }
284            ModelType::YoloEndToEndSegDet { boxes, protos } => self
285                .decode_yolo_end_to_end_segdet_quantized(
286                    outputs,
287                    boxes,
288                    protos,
289                    output_boxes,
290                    output_masks,
291                ),
292            ModelType::YoloSplitEndToEndDet {
293                boxes,
294                scores,
295                classes,
296            } => self.decode_yolo_split_end_to_end_det_quantized(
297                outputs,
298                boxes,
299                scores,
300                classes,
301                output_boxes,
302            ),
303            ModelType::YoloSplitEndToEndSegDet {
304                boxes,
305                scores,
306                classes,
307                mask_coeff,
308                protos,
309            } => self.decode_yolo_split_end_to_end_segdet_quantized(
310                outputs,
311                boxes,
312                scores,
313                classes,
314                mask_coeff,
315                protos,
316                output_boxes,
317                output_masks,
318            ),
319        }
320    }
321
322    /// Decode floating point model outputs into detection boxes and
323    /// segmentation masks. Clears the provided output vectors before
324    /// populating them.
325    pub(crate) fn decode_float<T>(
326        &self,
327        outputs: &[ArrayViewD<T>],
328        output_boxes: &mut Vec<DetectBox>,
329        output_masks: &mut Vec<Segmentation>,
330    ) -> Result<(), DecoderError>
331    where
332        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
333        f32: AsPrimitive<T>,
334    {
335        output_boxes.clear();
336        output_masks.clear();
337        match &self.model_type {
338            ModelType::ModelPackSegDet {
339                boxes,
340                scores,
341                segmentation,
342            } => {
343                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
344                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
345            }
346            ModelType::ModelPackSegDetSplit {
347                detection,
348                segmentation,
349            } => {
350                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
351                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
352            }
353            ModelType::ModelPackDet { boxes, scores } => {
354                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
355            }
356            ModelType::ModelPackDetSplit { detection } => {
357                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
358            }
359            ModelType::ModelPackSeg { segmentation } => {
360                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
361            }
362            ModelType::YoloDet { boxes } => {
363                self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
364            }
365            ModelType::YoloSegDet { boxes, protos } => {
366                self.decode_yolo_segdet_float(outputs, boxes, protos, output_boxes, output_masks)?;
367            }
368            ModelType::YoloSplitDet { boxes, scores } => {
369                self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
370            }
371            ModelType::YoloSplitSegDet {
372                boxes,
373                scores,
374                mask_coeff,
375                protos,
376            } => {
377                self.decode_yolo_split_segdet_float(
378                    outputs,
379                    boxes,
380                    scores,
381                    mask_coeff,
382                    protos,
383                    output_boxes,
384                    output_masks,
385                )?;
386            }
387            ModelType::YoloSegDet2Way {
388                boxes,
389                mask_coeff,
390                protos,
391            } => {
392                self.decode_yolo_segdet_2way_float(
393                    outputs,
394                    boxes,
395                    mask_coeff,
396                    protos,
397                    output_boxes,
398                    output_masks,
399                )?;
400            }
401            ModelType::YoloEndToEndDet { boxes } => {
402                self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
403            }
404            ModelType::YoloEndToEndSegDet { boxes, protos } => {
405                self.decode_yolo_end_to_end_segdet_float(
406                    outputs,
407                    boxes,
408                    protos,
409                    output_boxes,
410                    output_masks,
411                )?;
412            }
413            ModelType::YoloSplitEndToEndDet {
414                boxes,
415                scores,
416                classes,
417            } => {
418                self.decode_yolo_split_end_to_end_det_float(
419                    outputs,
420                    boxes,
421                    scores,
422                    classes,
423                    output_boxes,
424                )?;
425            }
426            ModelType::YoloSplitEndToEndSegDet {
427                boxes,
428                scores,
429                classes,
430                mask_coeff,
431                protos,
432            } => {
433                self.decode_yolo_split_end_to_end_segdet_float(
434                    outputs,
435                    boxes,
436                    scores,
437                    classes,
438                    mask_coeff,
439                    protos,
440                    output_boxes,
441                    output_masks,
442                )?;
443            }
444        }
445        Ok(())
446    }
447
448    /// Decodes quantized model outputs into detection boxes, returning raw
449    /// `ProtoData` for segmentation models instead of materialized masks.
450    ///
451    /// Returns `Ok(None)` for detection-only and ModelPack models (detections
452    /// are still decoded into `output_boxes`). Returns `Ok(Some(ProtoData))`
453    /// for YOLO segmentation models.
454    pub(crate) fn decode_quantized_proto(
455        &self,
456        outputs: &[ArrayViewDQuantized],
457        output_boxes: &mut Vec<DetectBox>,
458    ) -> Result<Option<ProtoData>, DecoderError> {
459        output_boxes.clear();
460        match &self.model_type {
461            // Detection-only variants: decode boxes, return None for proto data.
462            ModelType::ModelPackDet { boxes, scores } => {
463                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
464                Ok(None)
465            }
466            ModelType::ModelPackDetSplit { detection } => {
467                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
468                Ok(None)
469            }
470            ModelType::YoloDet { boxes } => {
471                self.decode_yolo_det_quantized(outputs, boxes, output_boxes)?;
472                Ok(None)
473            }
474            ModelType::YoloSplitDet { boxes, scores } => {
475                self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)?;
476                Ok(None)
477            }
478            ModelType::YoloEndToEndDet { boxes } => {
479                self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)?;
480                Ok(None)
481            }
482            ModelType::YoloSplitEndToEndDet {
483                boxes,
484                scores,
485                classes,
486            } => {
487                self.decode_yolo_split_end_to_end_det_quantized(
488                    outputs,
489                    boxes,
490                    scores,
491                    classes,
492                    output_boxes,
493                )?;
494                Ok(None)
495            }
496            // ModelPack seg/segdet variants have no YOLO proto data.
497            ModelType::ModelPackSegDet { boxes, scores, .. } => {
498                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
499                Ok(None)
500            }
501            ModelType::ModelPackSegDetSplit { detection, .. } => {
502                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
503                Ok(None)
504            }
505            ModelType::ModelPackSeg { .. } => Ok(None),
506
507            ModelType::YoloSegDet { boxes, protos } => {
508                let proto =
509                    self.decode_yolo_segdet_quantized_proto(outputs, boxes, protos, output_boxes)?;
510                Ok(Some(proto))
511            }
512            ModelType::YoloSplitSegDet {
513                boxes,
514                scores,
515                mask_coeff,
516                protos,
517            } => {
518                let proto = self.decode_yolo_split_segdet_quantized_proto(
519                    outputs,
520                    boxes,
521                    scores,
522                    mask_coeff,
523                    protos,
524                    output_boxes,
525                )?;
526                Ok(Some(proto))
527            }
528            ModelType::YoloSegDet2Way {
529                boxes,
530                mask_coeff,
531                protos,
532            } => {
533                let proto = self.decode_yolo_segdet_2way_quantized_proto(
534                    outputs,
535                    boxes,
536                    mask_coeff,
537                    protos,
538                    output_boxes,
539                )?;
540                Ok(Some(proto))
541            }
542            ModelType::YoloEndToEndSegDet { boxes, protos } => {
543                let proto = self.decode_yolo_end_to_end_segdet_quantized_proto(
544                    outputs,
545                    boxes,
546                    protos,
547                    output_boxes,
548                )?;
549                Ok(Some(proto))
550            }
551            ModelType::YoloSplitEndToEndSegDet {
552                boxes,
553                scores,
554                classes,
555                mask_coeff,
556                protos,
557            } => {
558                let proto = self.decode_yolo_split_end_to_end_segdet_quantized_proto(
559                    outputs,
560                    boxes,
561                    scores,
562                    classes,
563                    mask_coeff,
564                    protos,
565                    output_boxes,
566                )?;
567                Ok(Some(proto))
568            }
569        }
570    }
571
572    /// Decodes floating-point model outputs into detection boxes, returning
573    /// raw `ProtoData` for segmentation models instead of materialized masks.
574    ///
575    /// Returns `Ok(None)` for detection-only and ModelPack models (detections
576    /// are still decoded into `output_boxes`). Returns `Ok(Some(ProtoData))`
577    /// for YOLO segmentation models.
578    pub(crate) fn decode_float_proto<T>(
579        &self,
580        outputs: &[ArrayViewD<T>],
581        output_boxes: &mut Vec<DetectBox>,
582    ) -> Result<Option<ProtoData>, DecoderError>
583    where
584        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
585        f32: AsPrimitive<T>,
586    {
587        output_boxes.clear();
588        match &self.model_type {
589            // Detection-only variants: decode boxes, return None for proto data.
590            ModelType::ModelPackDet { boxes, scores } => {
591                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
592                Ok(None)
593            }
594            ModelType::ModelPackDetSplit { detection } => {
595                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
596                Ok(None)
597            }
598            ModelType::YoloDet { boxes } => {
599                self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
600                Ok(None)
601            }
602            ModelType::YoloSplitDet { boxes, scores } => {
603                self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
604                Ok(None)
605            }
606            ModelType::YoloEndToEndDet { boxes } => {
607                self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
608                Ok(None)
609            }
610            ModelType::YoloSplitEndToEndDet {
611                boxes,
612                scores,
613                classes,
614            } => {
615                self.decode_yolo_split_end_to_end_det_float(
616                    outputs,
617                    boxes,
618                    scores,
619                    classes,
620                    output_boxes,
621                )?;
622                Ok(None)
623            }
624            // ModelPack seg/segdet variants have no YOLO proto data.
625            ModelType::ModelPackSegDet { boxes, scores, .. } => {
626                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
627                Ok(None)
628            }
629            ModelType::ModelPackSegDetSplit { detection, .. } => {
630                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
631                Ok(None)
632            }
633            ModelType::ModelPackSeg { .. } => Ok(None),
634
635            ModelType::YoloSegDet { boxes, protos } => {
636                let proto =
637                    self.decode_yolo_segdet_float_proto(outputs, boxes, protos, output_boxes)?;
638                Ok(Some(proto))
639            }
640            ModelType::YoloSplitSegDet {
641                boxes,
642                scores,
643                mask_coeff,
644                protos,
645            } => {
646                let proto = self.decode_yolo_split_segdet_float_proto(
647                    outputs,
648                    boxes,
649                    scores,
650                    mask_coeff,
651                    protos,
652                    output_boxes,
653                )?;
654                Ok(Some(proto))
655            }
656            ModelType::YoloSegDet2Way {
657                boxes,
658                mask_coeff,
659                protos,
660            } => {
661                let proto = self.decode_yolo_segdet_2way_float_proto(
662                    outputs,
663                    boxes,
664                    mask_coeff,
665                    protos,
666                    output_boxes,
667                )?;
668                Ok(Some(proto))
669            }
670            ModelType::YoloEndToEndSegDet { boxes, protos } => {
671                let proto = self.decode_yolo_end_to_end_segdet_float_proto(
672                    outputs,
673                    boxes,
674                    protos,
675                    output_boxes,
676                )?;
677                Ok(Some(proto))
678            }
679            ModelType::YoloSplitEndToEndSegDet {
680                boxes,
681                scores,
682                classes,
683                mask_coeff,
684                protos,
685            } => {
686                let proto = self.decode_yolo_split_end_to_end_segdet_float_proto(
687                    outputs,
688                    boxes,
689                    scores,
690                    classes,
691                    mask_coeff,
692                    protos,
693                    output_boxes,
694                )?;
695                Ok(Some(proto))
696            }
697        }
698    }
699
700    // ========================================================================
701    // TensorDyn-based public API
702    // ========================================================================
703
704    /// Decode model outputs into detection boxes and segmentation masks.
705    ///
706    /// This is the primary decode API. Accepts `TensorDyn` outputs directly
707    /// from model inference. Automatically dispatches to quantized or float
708    /// paths based on the tensor dtype.
709    ///
710    /// # Arguments
711    ///
712    /// * `outputs` - Tensor outputs from model inference
713    /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
714    /// * `output_masks` - Destination for decoded segmentation masks (cleared first)
715    ///
716    /// # Errors
717    ///
718    /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
719    /// or the outputs don't match the decoder's model configuration.
720    pub fn decode(
721        &self,
722        outputs: &[&edgefirst_tensor::TensorDyn],
723        output_boxes: &mut Vec<DetectBox>,
724        output_masks: &mut Vec<Segmentation>,
725    ) -> Result<(), DecoderError> {
726        let mapped = tensor_bridge::map_tensors(outputs)?;
727        match &mapped {
728            tensor_bridge::MappedOutputs::Quantized(maps) => {
729                let views = tensor_bridge::quantized_views(maps)?;
730                self.decode_quantized(&views, output_boxes, output_masks)
731            }
732            tensor_bridge::MappedOutputs::Float32(maps) => {
733                let views = tensor_bridge::f32_views(maps)?;
734                self.decode_float(&views, output_boxes, output_masks)
735            }
736            tensor_bridge::MappedOutputs::Float64(maps) => {
737                let views = tensor_bridge::f64_views(maps)?;
738                self.decode_float(&views, output_boxes, output_masks)
739            }
740        }
741    }
742
743    /// Decode model outputs into detection boxes, returning raw proto data
744    /// for segmentation models instead of materialized masks.
745    ///
746    /// Accepts `TensorDyn` outputs directly from model inference.
747    /// Detections are always decoded into `output_boxes` regardless of model type.
748    /// Returns `Ok(None)` for detection-only and ModelPack models.
749    /// Returns `Ok(Some(ProtoData))` for YOLO segmentation models.
750    ///
751    /// # Arguments
752    ///
753    /// * `outputs` - Tensor outputs from model inference
754    /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
755    ///
756    /// # Errors
757    ///
758    /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
759    /// or the outputs don't match the decoder's model configuration.
760    pub fn decode_proto(
761        &self,
762        outputs: &[&edgefirst_tensor::TensorDyn],
763        output_boxes: &mut Vec<DetectBox>,
764    ) -> Result<Option<ProtoData>, DecoderError> {
765        let mapped = tensor_bridge::map_tensors(outputs)?;
766        match &mapped {
767            tensor_bridge::MappedOutputs::Quantized(maps) => {
768                let views = tensor_bridge::quantized_views(maps)?;
769                self.decode_quantized_proto(&views, output_boxes)
770            }
771            tensor_bridge::MappedOutputs::Float32(maps) => {
772                let views = tensor_bridge::f32_views(maps)?;
773                self.decode_float_proto(&views, output_boxes)
774            }
775            tensor_bridge::MappedOutputs::Float64(maps) => {
776                let views = tensor_bridge::f64_views(maps)?;
777                self.decode_float_proto(&views, output_boxes)
778            }
779        }
780    }
781}
782
783#[cfg(feature = "tracker")]
784pub use edgefirst_tracker::TrackInfo;
785
786#[cfg(feature = "tracker")]
787pub use edgefirst_tracker::Tracker;
788
789#[cfg(feature = "tracker")]
790impl Decoder {
791    /// Decode quantized model outputs into detection boxes and segmentation
792    /// masks with tracking. Clears the provided output vectors before
793    /// populating them.
794    pub(crate) fn decode_tracked_quantized<TR: edgefirst_tracker::Tracker<DetectBox>>(
795        &self,
796        tracker: &mut TR,
797        timestamp: u64,
798        outputs: &[ArrayViewDQuantized],
799        output_boxes: &mut Vec<DetectBox>,
800        output_masks: &mut Vec<Segmentation>,
801        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
802    ) -> Result<(), DecoderError> {
803        output_boxes.clear();
804        output_masks.clear();
805        output_tracks.clear();
806
807        // yolo segdet variants require special handling to separate boxes that come from decoding vs active tracks.
808        // Only boxes that come from decoding can be used for proto/mask generation.
809        match &self.model_type {
810            ModelType::YoloSegDet { boxes, protos } => self.decode_tracked_yolo_segdet_quantized(
811                tracker,
812                timestamp,
813                outputs,
814                boxes,
815                protos,
816                output_boxes,
817                output_masks,
818                output_tracks,
819            ),
820            ModelType::YoloSplitSegDet {
821                boxes,
822                scores,
823                mask_coeff,
824                protos,
825            } => self.decode_tracked_yolo_split_segdet_quantized(
826                tracker,
827                timestamp,
828                outputs,
829                boxes,
830                scores,
831                mask_coeff,
832                protos,
833                output_boxes,
834                output_masks,
835                output_tracks,
836            ),
837            ModelType::YoloEndToEndSegDet { boxes, protos } => self
838                .decode_tracked_yolo_end_to_end_segdet_quantized(
839                    tracker,
840                    timestamp,
841                    outputs,
842                    boxes,
843                    protos,
844                    output_boxes,
845                    output_masks,
846                    output_tracks,
847                ),
848            ModelType::YoloSplitEndToEndSegDet {
849                boxes,
850                scores,
851                classes,
852                mask_coeff,
853                protos,
854            } => self.decode_tracked_yolo_split_end_to_end_segdet_quantized(
855                tracker,
856                timestamp,
857                outputs,
858                boxes,
859                scores,
860                classes,
861                mask_coeff,
862                protos,
863                output_boxes,
864                output_masks,
865                output_tracks,
866            ),
867            ModelType::YoloSegDet2Way {
868                boxes,
869                mask_coeff,
870                protos,
871            } => self.decode_tracked_yolo_segdet_2way_quantized(
872                tracker,
873                timestamp,
874                outputs,
875                boxes,
876                mask_coeff,
877                protos,
878                output_boxes,
879                output_masks,
880                output_tracks,
881            ),
882            _ => {
883                self.decode_quantized(outputs, output_boxes, output_masks)?;
884                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
885                Ok(())
886            }
887        }
888    }
889
890    /// This function decodes floating point model outputs into detection boxes
891    /// and segmentation masks. Up to `output_boxes.capacity()` boxes and
892    /// masks will be decoded. The function clears the provided output
893    /// vectors before populating them with the decoded results.
894    ///
895    /// This function returns an `Error` if the provided outputs don't
896    /// match the configuration provided by the user when building the decoder.
897    ///
898    /// Any quantization information in the configuration will be ignored.
899    pub(crate) fn decode_tracked_float<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
900        &self,
901        tracker: &mut TR,
902        timestamp: u64,
903        outputs: &[ArrayViewD<T>],
904        output_boxes: &mut Vec<DetectBox>,
905        output_masks: &mut Vec<Segmentation>,
906        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
907    ) -> Result<(), DecoderError>
908    where
909        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
910        f32: AsPrimitive<T>,
911    {
912        output_boxes.clear();
913        output_masks.clear();
914        output_tracks.clear();
915        match &self.model_type {
916            ModelType::YoloSegDet { boxes, protos } => {
917                self.decode_tracked_yolo_segdet_float(
918                    tracker,
919                    timestamp,
920                    outputs,
921                    boxes,
922                    protos,
923                    output_boxes,
924                    output_masks,
925                    output_tracks,
926                )?;
927            }
928            ModelType::YoloSplitSegDet {
929                boxes,
930                scores,
931                mask_coeff,
932                protos,
933            } => {
934                self.decode_tracked_yolo_split_segdet_float(
935                    tracker,
936                    timestamp,
937                    outputs,
938                    boxes,
939                    scores,
940                    mask_coeff,
941                    protos,
942                    output_boxes,
943                    output_masks,
944                    output_tracks,
945                )?;
946            }
947            ModelType::YoloEndToEndSegDet { boxes, protos } => {
948                self.decode_tracked_yolo_end_to_end_segdet_float(
949                    tracker,
950                    timestamp,
951                    outputs,
952                    boxes,
953                    protos,
954                    output_boxes,
955                    output_masks,
956                    output_tracks,
957                )?;
958            }
959            ModelType::YoloSplitEndToEndSegDet {
960                boxes,
961                scores,
962                classes,
963                mask_coeff,
964                protos,
965            } => {
966                self.decode_tracked_yolo_split_end_to_end_segdet_float(
967                    tracker,
968                    timestamp,
969                    outputs,
970                    boxes,
971                    scores,
972                    classes,
973                    mask_coeff,
974                    protos,
975                    output_boxes,
976                    output_masks,
977                    output_tracks,
978                )?;
979            }
980            ModelType::YoloSegDet2Way {
981                boxes,
982                mask_coeff,
983                protos,
984            } => {
985                self.decode_tracked_yolo_segdet_2way_float(
986                    tracker,
987                    timestamp,
988                    outputs,
989                    boxes,
990                    mask_coeff,
991                    protos,
992                    output_boxes,
993                    output_masks,
994                    output_tracks,
995                )?;
996            }
997            _ => {
998                self.decode_float(outputs, output_boxes, output_masks)?;
999                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1000            }
1001        }
1002        Ok(())
1003    }
1004
1005    /// Decodes quantized model outputs into detection boxes, returning raw
1006    /// `ProtoData` for segmentation models instead of materialized masks.
1007    ///
1008    /// Returns `Ok(None)` for detection-only and ModelPack models (use
1009    /// `decode_quantized` for those). Returns `Ok(Some(ProtoData))` for
1010    /// YOLO segmentation models.
1011    pub(crate) fn decode_tracked_quantized_proto<TR: edgefirst_tracker::Tracker<DetectBox>>(
1012        &self,
1013        tracker: &mut TR,
1014        timestamp: u64,
1015        outputs: &[ArrayViewDQuantized],
1016        output_boxes: &mut Vec<DetectBox>,
1017        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1018    ) -> Result<Option<ProtoData>, DecoderError> {
1019        output_boxes.clear();
1020        output_tracks.clear();
1021        match &self.model_type {
1022            ModelType::YoloSegDet { boxes, protos } => {
1023                let proto = self.decode_tracked_yolo_segdet_quantized_proto(
1024                    tracker,
1025                    timestamp,
1026                    outputs,
1027                    boxes,
1028                    protos,
1029                    output_boxes,
1030                    output_tracks,
1031                )?;
1032                Ok(Some(proto))
1033            }
1034            ModelType::YoloSplitSegDet {
1035                boxes,
1036                scores,
1037                mask_coeff,
1038                protos,
1039            } => {
1040                let proto = self.decode_tracked_yolo_split_segdet_quantized_proto(
1041                    tracker,
1042                    timestamp,
1043                    outputs,
1044                    boxes,
1045                    scores,
1046                    mask_coeff,
1047                    protos,
1048                    output_boxes,
1049                    output_tracks,
1050                )?;
1051                Ok(Some(proto))
1052            }
1053            ModelType::YoloSegDet2Way {
1054                boxes,
1055                mask_coeff,
1056                protos,
1057            } => {
1058                let proto = self.decode_tracked_yolo_segdet_2way_quantized_proto(
1059                    tracker,
1060                    timestamp,
1061                    outputs,
1062                    boxes,
1063                    mask_coeff,
1064                    protos,
1065                    output_boxes,
1066                    output_tracks,
1067                )?;
1068                Ok(Some(proto))
1069            }
1070            ModelType::YoloEndToEndSegDet { boxes, protos } => {
1071                let proto = self.decode_tracked_yolo_end_to_end_segdet_quantized_proto(
1072                    tracker,
1073                    timestamp,
1074                    outputs,
1075                    boxes,
1076                    protos,
1077                    output_boxes,
1078                    output_tracks,
1079                )?;
1080                Ok(Some(proto))
1081            }
1082            ModelType::YoloSplitEndToEndSegDet {
1083                boxes,
1084                scores,
1085                classes,
1086                mask_coeff,
1087                protos,
1088            } => {
1089                let proto = self.decode_tracked_yolo_split_end_to_end_segdet_quantized_proto(
1090                    tracker,
1091                    timestamp,
1092                    outputs,
1093                    boxes,
1094                    scores,
1095                    classes,
1096                    mask_coeff,
1097                    protos,
1098                    output_boxes,
1099                    output_tracks,
1100                )?;
1101                Ok(Some(proto))
1102            }
1103            // Non-seg variants: decode boxes via the non-proto path, then track.
1104            _ => {
1105                let mut masks = Vec::new();
1106                self.decode_quantized(outputs, output_boxes, &mut masks)?;
1107                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1108                Ok(None)
1109            }
1110        }
1111    }
1112
1113    /// Decodes floating-point model outputs into detection boxes, returning
1114    /// raw `ProtoData` for segmentation models instead of materialized masks.
1115    ///
1116    /// Detections are always decoded into `output_boxes` regardless of model type.
1117    /// Returns `Ok(None)` for detection-only and ModelPack models. Returns
1118    /// `Ok(Some(ProtoData))` for YOLO segmentation models.
1119    pub(crate) fn decode_tracked_float_proto<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
1120        &self,
1121        tracker: &mut TR,
1122        timestamp: u64,
1123        outputs: &[ArrayViewD<T>],
1124        output_boxes: &mut Vec<DetectBox>,
1125        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1126    ) -> Result<Option<ProtoData>, DecoderError>
1127    where
1128        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
1129        f32: AsPrimitive<T>,
1130    {
1131        output_boxes.clear();
1132        output_tracks.clear();
1133        match &self.model_type {
1134            ModelType::YoloSegDet { boxes, protos } => {
1135                let proto = self.decode_tracked_yolo_segdet_float_proto(
1136                    tracker,
1137                    timestamp,
1138                    outputs,
1139                    boxes,
1140                    protos,
1141                    output_boxes,
1142                    output_tracks,
1143                )?;
1144                Ok(Some(proto))
1145            }
1146            ModelType::YoloSplitSegDet {
1147                boxes,
1148                scores,
1149                mask_coeff,
1150                protos,
1151            } => {
1152                let proto = self.decode_tracked_yolo_split_segdet_float_proto(
1153                    tracker,
1154                    timestamp,
1155                    outputs,
1156                    boxes,
1157                    scores,
1158                    mask_coeff,
1159                    protos,
1160                    output_boxes,
1161                    output_tracks,
1162                )?;
1163                Ok(Some(proto))
1164            }
1165            ModelType::YoloSegDet2Way {
1166                boxes,
1167                mask_coeff,
1168                protos,
1169            } => {
1170                let proto = self.decode_tracked_yolo_segdet_2way_float_proto(
1171                    tracker,
1172                    timestamp,
1173                    outputs,
1174                    boxes,
1175                    mask_coeff,
1176                    protos,
1177                    output_boxes,
1178                    output_tracks,
1179                )?;
1180                Ok(Some(proto))
1181            }
1182            ModelType::YoloEndToEndSegDet { boxes, protos } => {
1183                let proto = self.decode_tracked_yolo_end_to_end_segdet_float_proto(
1184                    tracker,
1185                    timestamp,
1186                    outputs,
1187                    boxes,
1188                    protos,
1189                    output_boxes,
1190                    output_tracks,
1191                )?;
1192                Ok(Some(proto))
1193            }
1194            ModelType::YoloSplitEndToEndSegDet {
1195                boxes,
1196                scores,
1197                classes,
1198                mask_coeff,
1199                protos,
1200            } => {
1201                let proto = self.decode_tracked_yolo_split_end_to_end_segdet_float_proto(
1202                    tracker,
1203                    timestamp,
1204                    outputs,
1205                    boxes,
1206                    scores,
1207                    classes,
1208                    mask_coeff,
1209                    protos,
1210                    output_boxes,
1211                    output_tracks,
1212                )?;
1213                Ok(Some(proto))
1214            }
1215            // Non-seg variants: decode boxes via the non-proto path, then track.
1216            _ => {
1217                let mut masks = Vec::new();
1218                self.decode_float(outputs, output_boxes, &mut masks)?;
1219                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
1220                Ok(None)
1221            }
1222        }
1223    }
1224
1225    // ========================================================================
1226    // TensorDyn-based tracked public API
1227    // ========================================================================
1228
1229    /// Decode model outputs with tracking.
1230    ///
1231    /// Accepts `TensorDyn` outputs directly from model inference. Automatically
1232    /// dispatches to quantized or float paths based on the tensor dtype, then
1233    /// updates the tracker with the decoded boxes.
1234    ///
1235    /// # Arguments
1236    ///
1237    /// * `tracker` - The tracker instance to update
1238    /// * `timestamp` - Current frame timestamp
1239    /// * `outputs` - Tensor outputs from model inference
1240    /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
1241    /// * `output_masks` - Destination for decoded segmentation masks (cleared first)
1242    /// * `output_tracks` - Destination for track info (cleared first)
1243    ///
1244    /// # Errors
1245    ///
1246    /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
1247    /// or the outputs don't match the decoder's model configuration.
1248    pub fn decode_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
1249        &self,
1250        tracker: &mut TR,
1251        timestamp: u64,
1252        outputs: &[&edgefirst_tensor::TensorDyn],
1253        output_boxes: &mut Vec<DetectBox>,
1254        output_masks: &mut Vec<Segmentation>,
1255        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1256    ) -> Result<(), DecoderError> {
1257        let mapped = tensor_bridge::map_tensors(outputs)?;
1258        match &mapped {
1259            tensor_bridge::MappedOutputs::Quantized(maps) => {
1260                let views = tensor_bridge::quantized_views(maps)?;
1261                self.decode_tracked_quantized(
1262                    tracker,
1263                    timestamp,
1264                    &views,
1265                    output_boxes,
1266                    output_masks,
1267                    output_tracks,
1268                )
1269            }
1270            tensor_bridge::MappedOutputs::Float32(maps) => {
1271                let views = tensor_bridge::f32_views(maps)?;
1272                self.decode_tracked_float(
1273                    tracker,
1274                    timestamp,
1275                    &views,
1276                    output_boxes,
1277                    output_masks,
1278                    output_tracks,
1279                )
1280            }
1281            tensor_bridge::MappedOutputs::Float64(maps) => {
1282                let views = tensor_bridge::f64_views(maps)?;
1283                self.decode_tracked_float(
1284                    tracker,
1285                    timestamp,
1286                    &views,
1287                    output_boxes,
1288                    output_masks,
1289                    output_tracks,
1290                )
1291            }
1292        }
1293    }
1294
1295    /// Decode model outputs with tracking, returning raw proto data for
1296    /// segmentation models.
1297    ///
1298    /// Accepts `TensorDyn` outputs directly from model inference.
1299    /// Returns `Ok(None)` for detection-only and ModelPack models.
1300    /// Returns `Ok(Some(ProtoData))` for YOLO segmentation models.
1301    ///
1302    /// # Arguments
1303    ///
1304    /// * `tracker` - The tracker instance to update
1305    /// * `timestamp` - Current frame timestamp
1306    /// * `outputs` - Tensor outputs from model inference
1307    /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
1308    /// * `output_tracks` - Destination for track info (cleared first)
1309    ///
1310    /// # Errors
1311    ///
1312    /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
1313    /// or the outputs don't match the decoder's model configuration.
1314    pub fn decode_proto_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
1315        &self,
1316        tracker: &mut TR,
1317        timestamp: u64,
1318        outputs: &[&edgefirst_tensor::TensorDyn],
1319        output_boxes: &mut Vec<DetectBox>,
1320        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1321    ) -> Result<Option<ProtoData>, DecoderError> {
1322        let mapped = tensor_bridge::map_tensors(outputs)?;
1323        match &mapped {
1324            tensor_bridge::MappedOutputs::Quantized(maps) => {
1325                let views = tensor_bridge::quantized_views(maps)?;
1326                self.decode_tracked_quantized_proto(
1327                    tracker,
1328                    timestamp,
1329                    &views,
1330                    output_boxes,
1331                    output_tracks,
1332                )
1333            }
1334            tensor_bridge::MappedOutputs::Float32(maps) => {
1335                let views = tensor_bridge::f32_views(maps)?;
1336                self.decode_tracked_float_proto(
1337                    tracker,
1338                    timestamp,
1339                    &views,
1340                    output_boxes,
1341                    output_tracks,
1342                )
1343            }
1344            tensor_bridge::MappedOutputs::Float64(maps) => {
1345                let views = tensor_bridge::f64_views(maps)?;
1346                self.decode_tracked_float_proto(
1347                    tracker,
1348                    timestamp,
1349                    &views,
1350                    output_boxes,
1351                    output_tracks,
1352                )
1353            }
1354        }
1355    }
1356}