Skip to main content

edgefirst_decoder/decoder/
mod.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use ndarray::{ArrayView, ArrayViewD, Dimension};
5use num_traits::{AsPrimitive, Float};
6
7use crate::{DecoderError, DetectBox, ProtoData, Segmentation};
8
9pub mod config;
10pub mod configs;
11
12use configs::ModelType;
13
#[derive(Debug, Clone)]
pub struct Decoder {
    /// Parsed model architecture; drives dispatch in every decode path.
    model_type: ModelType,
    /// Intersection-over-union threshold used when suppressing overlapping
    /// boxes during NMS.
    pub iou_threshold: f32,
    /// Minimum confidence score for a candidate box to be kept.
    pub score_threshold: f32,
    /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
    /// models)
    pub nms: Option<configs::Nms>,
    /// Maximum number of candidate boxes fed into NMS after score filtering.
    /// Reduces O(N²) NMS cost when many low-confidence proposals pass the
    /// threshold (common during COCO mAP evaluation with threshold ≈ 0.001).
    /// Candidates are ranked by score; only the top `pre_nms_top_k` proceed
    /// to NMS.  Default: 300.  Ignored when `nms` is `None`.
    ///
    /// Applies to segmentation decode paths only. Detection-only models use
    /// `output_boxes.capacity()` as their implicit output limit.
    pub pre_nms_top_k: usize,
    /// Maximum number of detections returned after NMS. Matches the
    /// Ultralytics `max_det` parameter.  Default: 300.
    ///
    /// Applies to segmentation decode paths only. Detection-only models are
    /// bounded by `output_boxes.capacity()`.
    pub max_det: usize,
    /// Whether decoded boxes are in normalized [0,1] coordinates.
    /// - `Some(true)`: Coordinates in [0,1] range
    /// - `Some(false)`: Pixel coordinates
    /// - `None`: Unknown, caller must infer (e.g., check if any coordinate >
    ///   1.0)
    normalized: Option<bool>,
    /// Schema v2 merge program. Present when the decoder was built from
    /// a [`crate::schema::SchemaV2`] whose logical outputs carry
    /// physical children. Absent for flat configurations (v1 and
    /// flat-v2).
    pub(crate) decode_program: Option<merge::DecodeProgram>,
}
49
50impl PartialEq for Decoder {
51    fn eq(&self, other: &Self) -> bool {
52        // DecodeProgram has non-comparable embedded data; compare by
53        // the config-derived fields only.
54        self.model_type == other.model_type
55            && self.iou_threshold == other.iou_threshold
56            && self.score_threshold == other.score_threshold
57            && self.nms == other.nms
58            && self.pre_nms_top_k == other.pre_nms_top_k
59            && self.max_det == other.max_det
60            && self.normalized == other.normalized
61            && self.decode_program.is_some() == other.decode_program.is_some()
62    }
63}
64
/// A dynamically-dimensioned array view over any supported quantized
/// integer element type. Allows the decode entry points to accept raw
/// quantized model outputs of mixed widths through one slice type.
#[derive(Debug)]
pub(crate) enum ArrayViewDQuantized<'a> {
    UInt8(ArrayViewD<'a, u8>),
    Int8(ArrayViewD<'a, i8>),
    UInt16(ArrayViewD<'a, u16>),
    Int16(ArrayViewD<'a, i16>),
    UInt32(ArrayViewD<'a, u32>),
    Int32(ArrayViewD<'a, i32>),
}
74
75impl<'a, D> From<ArrayView<'a, u8, D>> for ArrayViewDQuantized<'a>
76where
77    D: Dimension,
78{
79    fn from(arr: ArrayView<'a, u8, D>) -> Self {
80        Self::UInt8(arr.into_dyn())
81    }
82}
83
84impl<'a, D> From<ArrayView<'a, i8, D>> for ArrayViewDQuantized<'a>
85where
86    D: Dimension,
87{
88    fn from(arr: ArrayView<'a, i8, D>) -> Self {
89        Self::Int8(arr.into_dyn())
90    }
91}
92
93impl<'a, D> From<ArrayView<'a, u16, D>> for ArrayViewDQuantized<'a>
94where
95    D: Dimension,
96{
97    fn from(arr: ArrayView<'a, u16, D>) -> Self {
98        Self::UInt16(arr.into_dyn())
99    }
100}
101
102impl<'a, D> From<ArrayView<'a, i16, D>> for ArrayViewDQuantized<'a>
103where
104    D: Dimension,
105{
106    fn from(arr: ArrayView<'a, i16, D>) -> Self {
107        Self::Int16(arr.into_dyn())
108    }
109}
110
111impl<'a, D> From<ArrayView<'a, u32, D>> for ArrayViewDQuantized<'a>
112where
113    D: Dimension,
114{
115    fn from(arr: ArrayView<'a, u32, D>) -> Self {
116        Self::UInt32(arr.into_dyn())
117    }
118}
119
120impl<'a, D> From<ArrayView<'a, i32, D>> for ArrayViewDQuantized<'a>
121where
122    D: Dimension,
123{
124    fn from(arr: ArrayView<'a, i32, D>) -> Self {
125        Self::Int32(arr.into_dyn())
126    }
127}
128
129impl<'a> ArrayViewDQuantized<'a> {
130    /// Returns the shape of the underlying array.
131    pub(crate) fn shape(&self) -> &[usize] {
132        match self {
133            ArrayViewDQuantized::UInt8(a) => a.shape(),
134            ArrayViewDQuantized::Int8(a) => a.shape(),
135            ArrayViewDQuantized::UInt16(a) => a.shape(),
136            ArrayViewDQuantized::Int16(a) => a.shape(),
137            ArrayViewDQuantized::UInt32(a) => a.shape(),
138            ArrayViewDQuantized::Int32(a) => a.shape(),
139        }
140    }
141}
142
/// Expands `$body` once per [`ArrayViewDQuantized`] variant, binding the
/// inner concretely-typed `ArrayViewD` to `$var` in each arm. This is how
/// generic decode logic is monomorphized over every quantized element type.
///
/// WARNING: Do NOT nest `with_quantized!` calls. Each level multiplies
/// monomorphized code paths by 6 (one per integer variant), so nesting
/// N levels deep produces 6^N instantiations.
///
/// Instead, dequantize each tensor sequentially with `dequant_3d!`/`dequant_4d!`
/// (6*N paths) or split into independent phases that each nest at most 2 levels.
macro_rules! with_quantized {
    ($x:expr, $var:ident, $body:expr) => {
        match $x {
            ArrayViewDQuantized::UInt8(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::Int8(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::UInt16(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::Int16(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::UInt32(x) => {
                let $var = x;
                $body
            }
            ArrayViewDQuantized::Int32(x) => {
                let $var = x;
                $body
            }
        }
    };
}
179
180mod builder;
181mod dfl;
182mod helpers;
183mod merge;
184mod postprocess;
185mod tensor_bridge;
186mod tests;
187
188pub use builder::DecoderBuilder;
189pub use config::{ConfigOutput, ConfigOutputRef, ConfigOutputs};
190
191impl Decoder {
    /// Returns the parsed model type of the decoder.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::ModelType};
    /// # fn main() -> DecoderResult<()> {
    /// #    let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
    ///     let decoder = DecoderBuilder::default()
    ///         .with_config_yaml_str(config_yaml)
    ///         .build()?;
    ///     assert!(matches!(
    ///         decoder.model_type(),
    ///         ModelType::ModelPackDetSplit { .. }
    ///     ));
    /// #    Ok(())
    /// # }
    /// ```
    pub fn model_type(&self) -> &ModelType {
        &self.model_type
    }
213
    /// Returns the box coordinate format if known from the model config.
    ///
    /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
    /// - `Some(false)`: Boxes are in pixel coordinates relative to model input
    /// - `None`: Unknown, caller must infer (e.g., check if any coordinate >
    ///   1.0)
    ///
    /// This is determined by the model config's `normalized` field, not the NMS
    /// mode. When coordinates are in pixels or unknown, the caller may need
    /// to normalize using the model input dimensions.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
    /// # fn main() -> DecoderResult<()> {
    /// #    let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
    ///     let decoder = DecoderBuilder::default()
    ///         .with_config_yaml_str(config_yaml)
    ///         .build()?;
    ///     // Config doesn't specify normalized, so it's None
    ///     assert!(decoder.normalized_boxes().is_none());
    /// #    Ok(())
    /// # }
    /// ```
    pub fn normalized_boxes(&self) -> Option<bool> {
        self.normalized
    }
242
    /// Decode quantized model outputs into detection boxes and segmentation
    /// masks. The quantized outputs can be of u8, i8, u16, i16, u32, or i32
    /// types. Clears the provided output vectors before populating them.
    ///
    /// Dispatches exhaustively on the parsed [`ModelType`]; each arm forwards
    /// the config-described output descriptors to the matching decode helper.
    pub(crate) fn decode_quantized(
        &self,
        outputs: &[ArrayViewDQuantized],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
    ) -> Result<(), DecoderError> {
        output_boxes.clear();
        output_masks.clear();
        match &self.model_type {
            // Combined seg+det variants: decode detections first, then masks.
            ModelType::ModelPackSegDet {
                boxes,
                scores,
                segmentation,
            } => {
                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
                self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
            }
            ModelType::ModelPackSegDetSplit {
                detection,
                segmentation,
            } => {
                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
                self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
            }
            ModelType::ModelPackDet { boxes, scores } => {
                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)
            }
            ModelType::ModelPackDetSplit { detection } => {
                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)
            }
            ModelType::ModelPackSeg { segmentation } => {
                self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
            }
            // YOLO variants: the helper chosen matches the output layout
            // (fused vs. split heads, with/without mask protos).
            ModelType::YoloDet { boxes } => {
                self.decode_yolo_det_quantized(outputs, boxes, output_boxes)
            }
            ModelType::YoloSegDet { boxes, protos } => self.decode_yolo_segdet_quantized(
                outputs,
                boxes,
                protos,
                output_boxes,
                output_masks,
            ),
            ModelType::YoloSplitDet { boxes, scores } => {
                self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => self.decode_yolo_split_segdet_quantized(
                outputs,
                boxes,
                scores,
                mask_coeff,
                protos,
                output_boxes,
                output_masks,
            ),
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => self.decode_yolo_segdet_2way_quantized(
                outputs,
                boxes,
                mask_coeff,
                protos,
                output_boxes,
                output_masks,
            ),
            // End-to-end variants carry NMS inside the model graph.
            ModelType::YoloEndToEndDet { boxes } => {
                self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => self
                .decode_yolo_end_to_end_segdet_quantized(
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                ),
            ModelType::YoloSplitEndToEndDet {
                boxes,
                scores,
                classes,
            } => self.decode_yolo_split_end_to_end_det_quantized(
                outputs,
                boxes,
                scores,
                classes,
                output_boxes,
            ),
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => self.decode_yolo_split_end_to_end_segdet_quantized(
                outputs,
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
                output_boxes,
                output_masks,
            ),
        }
    }
358
    /// Decode floating point model outputs into detection boxes and
    /// segmentation masks. Clears the provided output vectors before
    /// populating them.
    ///
    /// `T` is any float element type convertible to and from `f32`
    /// (per the `AsPrimitive` bounds), e.g. views produced by the tensor
    /// bridge's f16/f32/f64 paths.
    pub(crate) fn decode_float<T>(
        &self,
        outputs: &[ArrayViewD<T>],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
    ) -> Result<(), DecoderError>
    where
        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
        f32: AsPrimitive<T>,
    {
        output_boxes.clear();
        output_masks.clear();
        // Same dispatch table as `decode_quantized`, but every arm uses
        // `?;` and the method ends with a single `Ok(())`.
        match &self.model_type {
            ModelType::ModelPackSegDet {
                boxes,
                scores,
                segmentation,
            } => {
                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
            }
            ModelType::ModelPackSegDetSplit {
                detection,
                segmentation,
            } => {
                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
            }
            ModelType::ModelPackDet { boxes, scores } => {
                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
            }
            ModelType::ModelPackDetSplit { detection } => {
                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
            }
            ModelType::ModelPackSeg { segmentation } => {
                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
            }
            ModelType::YoloDet { boxes } => {
                self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
            }
            ModelType::YoloSegDet { boxes, protos } => {
                self.decode_yolo_segdet_float(outputs, boxes, protos, output_boxes, output_masks)?;
            }
            ModelType::YoloSplitDet { boxes, scores } => {
                self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                self.decode_yolo_split_segdet_float(
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                )?;
            }
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => {
                self.decode_yolo_segdet_2way_float(
                    outputs,
                    boxes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                )?;
            }
            ModelType::YoloEndToEndDet { boxes } => {
                self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                self.decode_yolo_end_to_end_segdet_float(
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                )?;
            }
            ModelType::YoloSplitEndToEndDet {
                boxes,
                scores,
                classes,
            } => {
                self.decode_yolo_split_end_to_end_det_float(
                    outputs,
                    boxes,
                    scores,
                    classes,
                    output_boxes,
                )?;
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                self.decode_yolo_split_end_to_end_segdet_float(
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                )?;
            }
        }
        Ok(())
    }
484
    /// Decodes quantized model outputs into detection boxes, returning raw
    /// `ProtoData` for segmentation models instead of materialized masks.
    ///
    /// Returns `Ok(None)` for detection-only and ModelPack models (detections
    /// are still decoded into `output_boxes`). Returns `Ok(Some(ProtoData))`
    /// for YOLO segmentation models.
    pub(crate) fn decode_quantized_proto(
        &self,
        outputs: &[ArrayViewDQuantized],
        output_boxes: &mut Vec<DetectBox>,
    ) -> Result<Option<ProtoData>, DecoderError> {
        output_boxes.clear();
        match &self.model_type {
            // Detection-only variants: decode boxes, return None for proto data.
            ModelType::ModelPackDet { boxes, scores } => {
                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
                Ok(None)
            }
            ModelType::ModelPackDetSplit { detection } => {
                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
                Ok(None)
            }
            ModelType::YoloDet { boxes } => {
                self.decode_yolo_det_quantized(outputs, boxes, output_boxes)?;
                Ok(None)
            }
            ModelType::YoloSplitDet { boxes, scores } => {
                self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)?;
                Ok(None)
            }
            ModelType::YoloEndToEndDet { boxes } => {
                self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)?;
                Ok(None)
            }
            ModelType::YoloSplitEndToEndDet {
                boxes,
                scores,
                classes,
            } => {
                self.decode_yolo_split_end_to_end_det_quantized(
                    outputs,
                    boxes,
                    scores,
                    classes,
                    output_boxes,
                )?;
                Ok(None)
            }
            // ModelPack seg/segdet variants have no YOLO proto data.
            // Note: their segmentation outputs are intentionally NOT decoded
            // here; only detections are produced.
            ModelType::ModelPackSegDet { boxes, scores, .. } => {
                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
                Ok(None)
            }
            ModelType::ModelPackSegDetSplit { detection, .. } => {
                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
                Ok(None)
            }
            ModelType::ModelPackSeg { .. } => Ok(None),

            // YOLO segmentation variants: decode boxes and hand back the raw
            // proto tensor data for the caller to materialize masks later.
            ModelType::YoloSegDet { boxes, protos } => {
                let proto =
                    self.decode_yolo_segdet_quantized_proto(outputs, boxes, protos, output_boxes)?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_yolo_split_segdet_quantized_proto(
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_yolo_segdet_2way_quantized_proto(
                    outputs,
                    boxes,
                    mask_coeff,
                    protos,
                    output_boxes,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                let proto = self.decode_yolo_end_to_end_segdet_quantized_proto(
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_yolo_split_end_to_end_segdet_quantized_proto(
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                )?;
                Ok(Some(proto))
            }
        }
    }
608
    /// Decodes floating-point model outputs into detection boxes, returning
    /// raw `ProtoData` for segmentation models instead of materialized masks.
    ///
    /// Returns `Ok(None)` for detection-only and ModelPack models (detections
    /// are still decoded into `output_boxes`). Returns `Ok(Some(ProtoData))`
    /// for YOLO segmentation models.
    pub(crate) fn decode_float_proto<T>(
        &self,
        outputs: &[ArrayViewD<T>],
        output_boxes: &mut Vec<DetectBox>,
    ) -> Result<Option<ProtoData>, DecoderError>
    where
        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + crate::yolo::FloatProtoElem,
        f32: AsPrimitive<T>,
    {
        output_boxes.clear();
        // Mirrors `decode_quantized_proto` arm-for-arm, using the float
        // decode helpers instead of the quantized ones.
        match &self.model_type {
            // Detection-only variants: decode boxes, return None for proto data.
            ModelType::ModelPackDet { boxes, scores } => {
                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
                Ok(None)
            }
            ModelType::ModelPackDetSplit { detection } => {
                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
                Ok(None)
            }
            ModelType::YoloDet { boxes } => {
                self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
                Ok(None)
            }
            ModelType::YoloSplitDet { boxes, scores } => {
                self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
                Ok(None)
            }
            ModelType::YoloEndToEndDet { boxes } => {
                self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
                Ok(None)
            }
            ModelType::YoloSplitEndToEndDet {
                boxes,
                scores,
                classes,
            } => {
                self.decode_yolo_split_end_to_end_det_float(
                    outputs,
                    boxes,
                    scores,
                    classes,
                    output_boxes,
                )?;
                Ok(None)
            }
            // ModelPack seg/segdet variants have no YOLO proto data.
            ModelType::ModelPackSegDet { boxes, scores, .. } => {
                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
                Ok(None)
            }
            ModelType::ModelPackSegDetSplit { detection, .. } => {
                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
                Ok(None)
            }
            ModelType::ModelPackSeg { .. } => Ok(None),

            // YOLO segmentation variants: decode boxes and return raw protos.
            ModelType::YoloSegDet { boxes, protos } => {
                let proto =
                    self.decode_yolo_segdet_float_proto(outputs, boxes, protos, output_boxes)?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_yolo_split_segdet_float_proto(
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_yolo_segdet_2way_float_proto(
                    outputs,
                    boxes,
                    mask_coeff,
                    protos,
                    output_boxes,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                let proto = self.decode_yolo_end_to_end_segdet_float_proto(
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_yolo_split_end_to_end_segdet_float_proto(
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                )?;
                Ok(Some(proto))
            }
        }
    }
736
737    // ========================================================================
738    // TensorDyn-based public API
739    // ========================================================================
740
741    /// Decode model outputs into detection boxes and segmentation masks.
742    ///
743    /// This is the primary decode API. Accepts `TensorDyn` outputs directly
744    /// from model inference. Automatically dispatches to quantized or float
745    /// paths based on the tensor dtype.
746    ///
747    /// # Arguments
748    ///
749    /// * `outputs` - Tensor outputs from model inference
750    /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
751    /// * `output_masks` - Destination for decoded segmentation masks (cleared first)
752    ///
753    /// # Errors
754    ///
755    /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
756    /// or the outputs don't match the decoder's model configuration.
757    pub fn decode(
758        &self,
759        outputs: &[&edgefirst_tensor::TensorDyn],
760        output_boxes: &mut Vec<DetectBox>,
761        output_masks: &mut Vec<Segmentation>,
762    ) -> Result<(), DecoderError> {
763        // Schema v2 merge path: dequantize physical children into
764        // logical float32 tensors, then feed through the float dispatch.
765        if let Some(program) = &self.decode_program {
766            let merged = program.execute(outputs)?;
767            let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
768            return self.decode_float(&views, output_boxes, output_masks);
769        }
770
771        let mapped = tensor_bridge::map_tensors(outputs)?;
772        match &mapped {
773            tensor_bridge::MappedOutputs::Quantized(maps) => {
774                let views = tensor_bridge::quantized_views(maps)?;
775                self.decode_quantized(&views, output_boxes, output_masks)
776            }
777            tensor_bridge::MappedOutputs::Float16(maps) => {
778                let views = tensor_bridge::f16_views(maps)?;
779                self.decode_float(&views, output_boxes, output_masks)
780            }
781            tensor_bridge::MappedOutputs::Float32(maps) => {
782                let views = tensor_bridge::f32_views(maps)?;
783                self.decode_float(&views, output_boxes, output_masks)
784            }
785            tensor_bridge::MappedOutputs::Float64(maps) => {
786                let views = tensor_bridge::f64_views(maps)?;
787                self.decode_float(&views, output_boxes, output_masks)
788            }
789        }
790    }
791
792    /// Decode model outputs into detection boxes, returning raw proto data
793    /// for segmentation models instead of materialized masks.
794    ///
795    /// Accepts `TensorDyn` outputs directly from model inference.
796    /// Detections are always decoded into `output_boxes` regardless of model type.
797    /// Returns `Ok(None)` for detection-only and ModelPack models.
798    /// Returns `Ok(Some(ProtoData))` for YOLO segmentation models.
799    ///
800    /// # Arguments
801    ///
802    /// * `outputs` - Tensor outputs from model inference
803    /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
804    ///
805    /// # Errors
806    ///
807    /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
808    /// or the outputs don't match the decoder's model configuration.
809    pub fn decode_proto(
810        &self,
811        outputs: &[&edgefirst_tensor::TensorDyn],
812        output_boxes: &mut Vec<DetectBox>,
813    ) -> Result<Option<ProtoData>, DecoderError> {
814        // Schema v2 merge path: dequantize physical children into
815        // logical float32 tensors, then feed through the float dispatch.
816        if let Some(program) = &self.decode_program {
817            let merged = program.execute(outputs)?;
818            let views: Vec<_> = merged.iter().map(|a| a.view()).collect();
819            return self.decode_float_proto(&views, output_boxes);
820        }
821
822        let mapped = tensor_bridge::map_tensors(outputs)?;
823        let result = match &mapped {
824            tensor_bridge::MappedOutputs::Quantized(maps) => {
825                let views = tensor_bridge::quantized_views(maps)?;
826                self.decode_quantized_proto(&views, output_boxes)
827            }
828            tensor_bridge::MappedOutputs::Float16(maps) => {
829                let views = tensor_bridge::f16_views(maps)?;
830                self.decode_float_proto(&views, output_boxes)
831            }
832            tensor_bridge::MappedOutputs::Float32(maps) => {
833                let views = tensor_bridge::f32_views(maps)?;
834                self.decode_float_proto(&views, output_boxes)
835            }
836            tensor_bridge::MappedOutputs::Float64(maps) => {
837                let views = tensor_bridge::f64_views(maps)?;
838                self.decode_float_proto(&views, output_boxes)
839            }
840        };
841        result
842    }
843}
844
845#[cfg(feature = "tracker")]
846pub use edgefirst_tracker::TrackInfo;
847
848#[cfg(feature = "tracker")]
849pub use edgefirst_tracker::Tracker;
850
#[cfg(feature = "tracker")]
impl Decoder {
    /// Decode quantized model outputs into detection boxes and segmentation
    /// masks with tracking. Clears the provided output vectors before
    /// populating them.
    ///
    /// YOLO segdet variants dispatch to dedicated tracked helpers; every
    /// other model type is decoded via the plain quantized path and the
    /// tracker is then updated from the decoded boxes.
    pub(crate) fn decode_tracked_quantized<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewDQuantized],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<(), DecoderError> {
        output_boxes.clear();
        output_masks.clear();
        output_tracks.clear();

        // yolo segdet variants require special handling to separate boxes that come from decoding vs active tracks.
        // Only boxes that come from decoding can be used for proto/mask generation.
        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => self.decode_tracked_yolo_segdet_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => self.decode_tracked_yolo_split_segdet_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                scores,
                mask_coeff,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            ModelType::YoloEndToEndSegDet { boxes, protos } => self
                .decode_tracked_yolo_end_to_end_segdet_quantized(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                ),
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => self.decode_tracked_yolo_split_end_to_end_segdet_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => self.decode_tracked_yolo_segdet_2way_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                mask_coeff,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            // Non-seg variants (detection-only, ModelPack, non-seg YOLO):
            // plain quantized decode, then feed the tracker.
            _ => {
                self.decode_quantized(outputs, output_boxes, output_masks)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
                Ok(())
            }
        }
    }

    /// This function decodes floating point model outputs into detection boxes
    /// and segmentation masks. Up to `output_boxes.capacity()` boxes and
    /// masks will be decoded. The function clears the provided output
    /// vectors before populating them with the decoded results.
    ///
    /// This function returns an `Error` if the provided outputs don't
    /// match the configuration provided by the user when building the decoder.
    ///
    /// Any quantization information in the configuration will be ignored.
    //
    // Float analogue of `decode_tracked_quantized`: the same segdet-variant
    // dispatch, separating boxes produced by decoding (usable for proto/mask
    // generation) from boxes resurrected by active tracks.
    pub(crate) fn decode_tracked_float<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewD<T>],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<(), DecoderError>
    where
        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
        f32: AsPrimitive<T>,
    {
        output_boxes.clear();
        output_masks.clear();
        output_tracks.clear();
        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => {
                self.decode_tracked_yolo_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                self.decode_tracked_yolo_split_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                self.decode_tracked_yolo_end_to_end_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                self.decode_tracked_yolo_split_end_to_end_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => {
                self.decode_tracked_yolo_segdet_2way_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            // Non-seg variants: plain float decode, then feed the tracker.
            _ => {
                self.decode_float(outputs, output_boxes, output_masks)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
            }
        }
        Ok(())
    }

    /// Decodes quantized model outputs into detection boxes, returning raw
    /// `ProtoData` for segmentation models instead of materialized masks.
    ///
    /// Returns `Ok(None)` for detection-only and ModelPack models (use
    /// `decode_quantized` for those). Returns `Ok(Some(ProtoData))` for
    /// YOLO segmentation models.
    pub(crate) fn decode_tracked_quantized_proto<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewDQuantized],
        output_boxes: &mut Vec<DetectBox>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<Option<ProtoData>, DecoderError> {
        output_boxes.clear();
        output_tracks.clear();
        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_segdet_2way_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_end_to_end_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_end_to_end_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            // Non-seg variants: decode boxes via the non-proto path, then track.
            _ => {
                let mut masks = Vec::new();
                self.decode_quantized(outputs, output_boxes, &mut masks)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
                Ok(None)
            }
        }
    }

    /// Decodes floating-point model outputs into detection boxes, returning
    /// raw `ProtoData` for segmentation models instead of materialized masks.
    ///
    /// Detections are always decoded into `output_boxes` regardless of model type.
    /// Returns `Ok(None)` for detection-only and ModelPack models. Returns
    /// `Ok(Some(ProtoData))` for YOLO segmentation models.
    pub(crate) fn decode_tracked_float_proto<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewD<T>],
        output_boxes: &mut Vec<DetectBox>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<Option<ProtoData>, DecoderError>
    where
        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + crate::yolo::FloatProtoElem,
        f32: AsPrimitive<T>,
    {
        output_boxes.clear();
        output_tracks.clear();
        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSegDet2Way {
                boxes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_segdet_2way_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_end_to_end_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_end_to_end_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            // Non-seg variants: decode boxes via the non-proto path, then track.
            _ => {
                let mut masks = Vec::new();
                self.decode_float(outputs, output_boxes, &mut masks)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
                Ok(None)
            }
        }
    }

    // ========================================================================
    // TensorDyn-based tracked public API
    // ========================================================================

    /// Decode model outputs with tracking.
    ///
    /// Accepts `TensorDyn` outputs directly from model inference. Automatically
    /// dispatches to quantized or float paths based on the tensor dtype, then
    /// updates the tracker with the decoded boxes.
    ///
    /// # Arguments
    ///
    /// * `tracker` - The tracker instance to update
    /// * `timestamp` - Current frame timestamp
    /// * `outputs` - Tensor outputs from model inference
    /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
    /// * `output_masks` - Destination for decoded segmentation masks (cleared first)
    /// * `output_tracks` - Destination for track info (cleared first)
    ///
    /// # Errors
    ///
    /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
    /// or the outputs don't match the decoder's model configuration.
    //
    // NOTE(review): unlike `decode_proto`, this entry point does not check
    // `self.decode_program` (the schema-v2 merge path) before mapping
    // tensors — confirm whether tracked decoding is intentionally
    // unsupported for v2 schemas or whether the merge path should be
    // executed here as well.
    pub fn decode_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[&edgefirst_tensor::TensorDyn],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<(), DecoderError> {
        let mapped = tensor_bridge::map_tensors(outputs)?;
        match &mapped {
            tensor_bridge::MappedOutputs::Quantized(maps) => {
                let views = tensor_bridge::quantized_views(maps)?;
                self.decode_tracked_quantized(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float16(maps) => {
                let views = tensor_bridge::f16_views(maps)?;
                self.decode_tracked_float(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float32(maps) => {
                let views = tensor_bridge::f32_views(maps)?;
                self.decode_tracked_float(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float64(maps) => {
                let views = tensor_bridge::f64_views(maps)?;
                self.decode_tracked_float(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )
            }
        }
    }

    /// Decode model outputs with tracking, returning raw proto data for
    /// segmentation models.
    ///
    /// Accepts `TensorDyn` outputs directly from model inference.
    /// Returns `Ok(None)` for detection-only and ModelPack models.
    /// Returns `Ok(Some(ProtoData))` for YOLO segmentation models.
    ///
    /// # Arguments
    ///
    /// * `tracker` - The tracker instance to update
    /// * `timestamp` - Current frame timestamp
    /// * `outputs` - Tensor outputs from model inference
    /// * `output_boxes` - Destination for decoded detection boxes (cleared first)
    /// * `output_tracks` - Destination for track info (cleared first)
    ///
    /// # Errors
    ///
    /// Returns `DecoderError` if tensor mapping fails, dtypes are unsupported,
    /// or the outputs don't match the decoder's model configuration.
    //
    // NOTE(review): like `decode_tracked`, this skips the `decode_program`
    // (schema-v2 merge) path that `decode_proto` performs — verify that is
    // intentional for tracked use.
    pub fn decode_proto_tracked<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[&edgefirst_tensor::TensorDyn],
        output_boxes: &mut Vec<DetectBox>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<Option<ProtoData>, DecoderError> {
        let mapped = tensor_bridge::map_tensors(outputs)?;
        match &mapped {
            tensor_bridge::MappedOutputs::Quantized(maps) => {
                let views = tensor_bridge::quantized_views(maps)?;
                self.decode_tracked_quantized_proto(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float16(maps) => {
                let views = tensor_bridge::f16_views(maps)?;
                self.decode_tracked_float_proto(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float32(maps) => {
                let views = tensor_bridge::f32_views(maps)?;
                self.decode_tracked_float_proto(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_tracks,
                )
            }
            tensor_bridge::MappedOutputs::Float64(maps) => {
                let views = tensor_bridge::f64_views(maps)?;
                self.decode_tracked_float_proto(
                    tracker,
                    timestamp,
                    &views,
                    output_boxes,
                    output_tracks,
                )
            }
        }
    }
}