Skip to main content

edgefirst_decoder/decoder/
mod.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use ndarray::{ArrayView, ArrayViewD, Dimension};
5use num_traits::{AsPrimitive, Float};
6
7use crate::{DecoderError, DetectBox, ProtoData, Segmentation};
8
9pub mod config;
10pub mod configs;
11
12use configs::ModelType;
13
14#[derive(Debug, Clone, PartialEq)]
15pub struct Decoder {
16    model_type: ModelType,
17    pub iou_threshold: f32,
18    pub score_threshold: f32,
19    /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
20    /// models)
21    pub nms: Option<configs::Nms>,
22    /// Whether decoded boxes are in normalized [0,1] coordinates.
23    /// - `Some(true)`: Coordinates in [0,1] range
24    /// - `Some(false)`: Pixel coordinates
25    /// - `None`: Unknown, caller must infer (e.g., check if any coordinate >
26    ///   1.0)
27    normalized: Option<bool>,
28}
29
30#[derive(Debug)]
31pub enum ArrayViewDQuantized<'a> {
32    UInt8(ArrayViewD<'a, u8>),
33    Int8(ArrayViewD<'a, i8>),
34    UInt16(ArrayViewD<'a, u16>),
35    Int16(ArrayViewD<'a, i16>),
36    UInt32(ArrayViewD<'a, u32>),
37    Int32(ArrayViewD<'a, i32>),
38}
39
40impl<'a, D> From<ArrayView<'a, u8, D>> for ArrayViewDQuantized<'a>
41where
42    D: Dimension,
43{
44    fn from(arr: ArrayView<'a, u8, D>) -> Self {
45        Self::UInt8(arr.into_dyn())
46    }
47}
48
49impl<'a, D> From<ArrayView<'a, i8, D>> for ArrayViewDQuantized<'a>
50where
51    D: Dimension,
52{
53    fn from(arr: ArrayView<'a, i8, D>) -> Self {
54        Self::Int8(arr.into_dyn())
55    }
56}
57
58impl<'a, D> From<ArrayView<'a, u16, D>> for ArrayViewDQuantized<'a>
59where
60    D: Dimension,
61{
62    fn from(arr: ArrayView<'a, u16, D>) -> Self {
63        Self::UInt16(arr.into_dyn())
64    }
65}
66
67impl<'a, D> From<ArrayView<'a, i16, D>> for ArrayViewDQuantized<'a>
68where
69    D: Dimension,
70{
71    fn from(arr: ArrayView<'a, i16, D>) -> Self {
72        Self::Int16(arr.into_dyn())
73    }
74}
75
76impl<'a, D> From<ArrayView<'a, u32, D>> for ArrayViewDQuantized<'a>
77where
78    D: Dimension,
79{
80    fn from(arr: ArrayView<'a, u32, D>) -> Self {
81        Self::UInt32(arr.into_dyn())
82    }
83}
84
85impl<'a, D> From<ArrayView<'a, i32, D>> for ArrayViewDQuantized<'a>
86where
87    D: Dimension,
88{
89    fn from(arr: ArrayView<'a, i32, D>) -> Self {
90        Self::Int32(arr.into_dyn())
91    }
92}
93
94impl<'a> ArrayViewDQuantized<'a> {
95    /// Returns the shape of the underlying array.
96    ///
97    /// # Examples
98    /// ```rust
99    /// # use edgefirst_decoder::ArrayViewDQuantized;
100    /// # use ndarray::Array2;
101    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
102    /// let arr = Array2::from_shape_vec((2, 3), vec![1u8, 2, 3, 4, 5, 6])?;
103    /// let view = ArrayViewDQuantized::from(arr.view().into_dyn());
104    /// assert_eq!(view.shape(), &[2, 3]);
105    /// # Ok(())
106    /// # }
107    /// ```
108    pub fn shape(&self) -> &[usize] {
109        match self {
110            ArrayViewDQuantized::UInt8(a) => a.shape(),
111            ArrayViewDQuantized::Int8(a) => a.shape(),
112            ArrayViewDQuantized::UInt16(a) => a.shape(),
113            ArrayViewDQuantized::Int16(a) => a.shape(),
114            ArrayViewDQuantized::UInt32(a) => a.shape(),
115            ArrayViewDQuantized::Int32(a) => a.shape(),
116        }
117    }
118}
119
/// Evaluates `$body` with `$var` bound to the typed array view held inside
/// the `ArrayViewDQuantized` value `$x`.
///
/// `$body` is expanded once per integer variant, so it must type-check for
/// all six element types.
///
/// WARNING: Do NOT nest `with_quantized!` calls. Each level multiplies
/// monomorphized code paths by 6 (one per integer variant), so nesting
/// N levels deep produces 6^N instantiations.
///
/// Instead, dequantize each tensor sequentially with `dequant_3d!`/`dequant_4d!`
/// (6*N paths) or split into independent phases that each nest at most 2 levels.
macro_rules! with_quantized {
    ($x:expr, $var:ident, $body:expr) => {
        match $x {
            ArrayViewDQuantized::UInt8($var) => $body,
            ArrayViewDQuantized::Int8($var) => $body,
            ArrayViewDQuantized::UInt16($var) => $body,
            ArrayViewDQuantized::Int16($var) => $body,
            ArrayViewDQuantized::UInt32($var) => $body,
            ArrayViewDQuantized::Int32($var) => $body,
        }
    };
}
156
157mod builder;
158mod helpers;
159mod postprocess;
160mod tests;
161
162pub use builder::DecoderBuilder;
163pub use config::{ConfigOutput, ConfigOutputRef, ConfigOutputs};
164
165impl Decoder {
166    /// This function returns the parsed model type of the decoder.
167    ///
168    /// # Examples
169    ///
170    /// ```rust
171    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::ModelType};
172    /// # fn main() -> DecoderResult<()> {
173    /// #    let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
174    ///     let decoder = DecoderBuilder::default()
175    ///         .with_config_yaml_str(config_yaml)
176    ///         .build()?;
177    ///     assert!(matches!(
178    ///         decoder.model_type(),
179    ///         ModelType::ModelPackDetSplit { .. }
180    ///     ));
181    /// #    Ok(())
182    /// # }
183    /// ```
184    pub fn model_type(&self) -> &ModelType {
185        &self.model_type
186    }
187
188    /// Returns the box coordinate format if known from the model config.
189    ///
190    /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
191    /// - `Some(false)`: Boxes are in pixel coordinates relative to model input
192    /// - `None`: Unknown, caller must infer (e.g., check if any coordinate >
193    ///   1.0)
194    ///
195    /// This is determined by the model config's `normalized` field, not the NMS
196    /// mode. When coordinates are in pixels or unknown, the caller may need
197    /// to normalize using the model input dimensions.
198    ///
199    /// # Examples
200    ///
201    /// ```rust
202    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
203    /// # fn main() -> DecoderResult<()> {
204    /// #    let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
205    ///     let decoder = DecoderBuilder::default()
206    ///         .with_config_yaml_str(config_yaml)
207    ///         .build()?;
208    ///     // Config doesn't specify normalized, so it's None
209    ///     assert!(decoder.normalized_boxes().is_none());
210    /// #    Ok(())
211    /// # }
212    /// ```
213    pub fn normalized_boxes(&self) -> Option<bool> {
214        self.normalized
215    }
216
217    /// This function decodes quantized model outputs into detection boxes and
218    /// segmentation masks. The quantized outputs can be of u8, i8, u16, i16,
219    /// u32, or i32 types. Up to `output_boxes.capacity()` boxes and masks
220    /// will be decoded. The function clears the provided output vectors
221    /// before populating them with the decoded results.
222    ///
223    /// This function returns a `DecoderError` if the the provided outputs don't
224    /// match the configuration provided by the user when building the decoder.
225    ///
226    /// # Examples
227    ///
228    /// ```rust
229    /// # use edgefirst_decoder::{BoundingBox, DecoderBuilder, DetectBox, DecoderResult};
230    /// # use ndarray::Array4;
231    /// # fn main() -> DecoderResult<()> {
232    /// #    let detect0 = include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split_9x15x18.bin"));
233    /// #    let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec())?;
234    /// #
235    /// #    let detect1 = include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split_17x30x18.bin"));
236    /// #    let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec())?;
237    /// #    let model_output = vec![
238    /// #        detect1.view().into_dyn().into(),
239    /// #        detect0.view().into_dyn().into(),
240    /// #    ];
241    /// let decoder = DecoderBuilder::default()
242    ///     .with_config_yaml_str(include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string())
243    ///     .with_score_threshold(0.45)
244    ///     .with_iou_threshold(0.45)
245    ///     .build()?;
246    ///
247    /// let mut output_boxes: Vec<_> = Vec::with_capacity(10);
248    /// let mut output_masks: Vec<_> = Vec::with_capacity(10);
249    /// decoder.decode_quantized(&model_output, &mut output_boxes, &mut output_masks)?;
250    /// assert!(output_boxes[0].equal_within_delta(
251    ///     &DetectBox {
252    ///         bbox: BoundingBox {
253    ///             xmin: 0.43171933,
254    ///             ymin: 0.68243736,
255    ///             xmax: 0.5626645,
256    ///             ymax: 0.808863,
257    ///         },
258    ///         score: 0.99240804,
259    ///         label: 0
260    ///     },
261    ///     1e-6
262    /// ));
263    /// #    Ok(())
264    /// # }
265    /// ```
266    pub fn decode_quantized(
267        &self,
268        outputs: &[ArrayViewDQuantized],
269        output_boxes: &mut Vec<DetectBox>,
270        output_masks: &mut Vec<Segmentation>,
271    ) -> Result<(), DecoderError> {
272        output_boxes.clear();
273        output_masks.clear();
274        match &self.model_type {
275            ModelType::ModelPackSegDet {
276                boxes,
277                scores,
278                segmentation,
279            } => {
280                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
281                self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
282            }
283            ModelType::ModelPackSegDetSplit {
284                detection,
285                segmentation,
286            } => {
287                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
288                self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
289            }
290            ModelType::ModelPackDet { boxes, scores } => {
291                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)
292            }
293            ModelType::ModelPackDetSplit { detection } => {
294                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)
295            }
296            ModelType::ModelPackSeg { segmentation } => {
297                self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
298            }
299            ModelType::YoloDet { boxes } => {
300                self.decode_yolo_det_quantized(outputs, boxes, output_boxes)
301            }
302            ModelType::YoloSegDet { boxes, protos } => self.decode_yolo_segdet_quantized(
303                outputs,
304                boxes,
305                protos,
306                output_boxes,
307                output_masks,
308            ),
309            ModelType::YoloSplitDet { boxes, scores } => {
310                self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)
311            }
312            ModelType::YoloSplitSegDet {
313                boxes,
314                scores,
315                mask_coeff,
316                protos,
317            } => self.decode_yolo_split_segdet_quantized(
318                outputs,
319                boxes,
320                scores,
321                mask_coeff,
322                protos,
323                output_boxes,
324                output_masks,
325            ),
326            ModelType::YoloEndToEndDet { boxes } => {
327                self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)
328            }
329            ModelType::YoloEndToEndSegDet { boxes, protos } => self
330                .decode_yolo_end_to_end_segdet_quantized(
331                    outputs,
332                    boxes,
333                    protos,
334                    output_boxes,
335                    output_masks,
336                ),
337            ModelType::YoloSplitEndToEndDet {
338                boxes,
339                scores,
340                classes,
341            } => self.decode_yolo_split_end_to_end_det_quantized(
342                outputs,
343                boxes,
344                scores,
345                classes,
346                output_boxes,
347            ),
348            ModelType::YoloSplitEndToEndSegDet {
349                boxes,
350                scores,
351                classes,
352                mask_coeff,
353                protos,
354            } => self.decode_yolo_split_end_to_end_segdet_quantized(
355                outputs,
356                boxes,
357                scores,
358                classes,
359                mask_coeff,
360                protos,
361                output_boxes,
362                output_masks,
363            ),
364        }
365    }
366
367    /// This function decodes floating point model outputs into detection boxes
368    /// and segmentation masks. Up to `output_boxes.capacity()` boxes and
369    /// masks will be decoded. The function clears the provided output
370    /// vectors before populating them with the decoded results.
371    ///
372    /// This function returns an `Error` if the the provided outputs don't
373    /// match the configuration provided by the user when building the decoder.
374    ///
375    /// Any quantization information in the configuration will be ignored.
376    ///
377    /// # Examples
378    ///
379    /// ```rust
380    /// # use edgefirst_decoder::{BoundingBox, DecoderBuilder, DetectBox, DecoderResult, configs, configs::{DecoderType, DecoderVersion}, dequantize_cpu, Quantization};
381    /// # use ndarray::Array3;
382    /// # fn main() -> DecoderResult<()> {
383    /// #   let out = include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/yolov8s_80_classes.bin"));
384    /// #   let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
385    /// #   let mut out_dequant = vec![0.0_f64; 84 * 8400];
386    /// #   let quant = Quantization::new(0.0040811873, -123);
387    /// #   dequantize_cpu(out, quant, &mut out_dequant);
388    /// #   let model_output_f64 = Array3::from_shape_vec((1, 84, 8400), out_dequant)?.into_dyn();
389    ///    let decoder = DecoderBuilder::default()
390    ///     .with_config_yolo_det(configs::Detection {
391    ///         decoder: DecoderType::Ultralytics,
392    ///         quantization: None,
393    ///         shape: vec![1, 84, 8400],
394    ///         anchors: None,
395    ///         dshape: Vec::new(),
396    ///         normalized: Some(true),
397    ///     },
398    ///     Some(DecoderVersion::Yolo11))
399    ///     .with_score_threshold(0.25)
400    ///     .with_iou_threshold(0.7)
401    ///     .build()?;
402    ///
403    /// let mut output_boxes: Vec<_> = Vec::with_capacity(10);
404    /// let mut output_masks: Vec<_> = Vec::with_capacity(10);
405    /// let model_output_f64 = vec![model_output_f64.view().into()];
406    /// decoder.decode_float(&model_output_f64, &mut output_boxes, &mut output_masks)?;    
407    /// assert!(output_boxes[0].equal_within_delta(
408    ///        &DetectBox {
409    ///            bbox: BoundingBox {
410    ///                xmin: 0.5285137,
411    ///                ymin: 0.05305544,
412    ///                xmax: 0.87541467,
413    ///                ymax: 0.9998909,
414    ///            },
415    ///            score: 0.5591227,
416    ///            label: 0
417    ///        },
418    ///        1e-6
419    ///    ));
420    ///
421    /// #    Ok(())
422    /// # }
423    pub fn decode_float<T>(
424        &self,
425        outputs: &[ArrayViewD<T>],
426        output_boxes: &mut Vec<DetectBox>,
427        output_masks: &mut Vec<Segmentation>,
428    ) -> Result<(), DecoderError>
429    where
430        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
431        f32: AsPrimitive<T>,
432    {
433        output_boxes.clear();
434        output_masks.clear();
435        match &self.model_type {
436            ModelType::ModelPackSegDet {
437                boxes,
438                scores,
439                segmentation,
440            } => {
441                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
442                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
443            }
444            ModelType::ModelPackSegDetSplit {
445                detection,
446                segmentation,
447            } => {
448                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
449                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
450            }
451            ModelType::ModelPackDet { boxes, scores } => {
452                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
453            }
454            ModelType::ModelPackDetSplit { detection } => {
455                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
456            }
457            ModelType::ModelPackSeg { segmentation } => {
458                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
459            }
460            ModelType::YoloDet { boxes } => {
461                self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
462            }
463            ModelType::YoloSegDet { boxes, protos } => {
464                self.decode_yolo_segdet_float(outputs, boxes, protos, output_boxes, output_masks)?;
465            }
466            ModelType::YoloSplitDet { boxes, scores } => {
467                self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
468            }
469            ModelType::YoloSplitSegDet {
470                boxes,
471                scores,
472                mask_coeff,
473                protos,
474            } => {
475                self.decode_yolo_split_segdet_float(
476                    outputs,
477                    boxes,
478                    scores,
479                    mask_coeff,
480                    protos,
481                    output_boxes,
482                    output_masks,
483                )?;
484            }
485            ModelType::YoloEndToEndDet { boxes } => {
486                self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
487            }
488            ModelType::YoloEndToEndSegDet { boxes, protos } => {
489                self.decode_yolo_end_to_end_segdet_float(
490                    outputs,
491                    boxes,
492                    protos,
493                    output_boxes,
494                    output_masks,
495                )?;
496            }
497            ModelType::YoloSplitEndToEndDet {
498                boxes,
499                scores,
500                classes,
501            } => {
502                self.decode_yolo_split_end_to_end_det_float(
503                    outputs,
504                    boxes,
505                    scores,
506                    classes,
507                    output_boxes,
508                )?;
509            }
510            ModelType::YoloSplitEndToEndSegDet {
511                boxes,
512                scores,
513                classes,
514                mask_coeff,
515                protos,
516            } => {
517                self.decode_yolo_split_end_to_end_segdet_float(
518                    outputs,
519                    boxes,
520                    scores,
521                    classes,
522                    mask_coeff,
523                    protos,
524                    output_boxes,
525                    output_masks,
526                )?;
527            }
528        }
529        Ok(())
530    }
531
532    /// Decodes quantized model outputs into detection boxes, returning raw
533    /// `ProtoData` for segmentation models instead of materialized masks.
534    ///
535    /// Returns `Ok(None)` for detection-only and ModelPack models (use
536    /// `decode_quantized` for those). Returns `Ok(Some(ProtoData))` for
537    /// YOLO segmentation models.
538    pub fn decode_quantized_proto(
539        &self,
540        outputs: &[ArrayViewDQuantized],
541        output_boxes: &mut Vec<DetectBox>,
542    ) -> Result<Option<ProtoData>, DecoderError> {
543        output_boxes.clear();
544        match &self.model_type {
545            // Detection-only and ModelPack variants: no proto data
546            ModelType::ModelPackSegDet { .. }
547            | ModelType::ModelPackSegDetSplit { .. }
548            | ModelType::ModelPackDet { .. }
549            | ModelType::ModelPackDetSplit { .. }
550            | ModelType::ModelPackSeg { .. }
551            | ModelType::YoloDet { .. }
552            | ModelType::YoloSplitDet { .. }
553            | ModelType::YoloEndToEndDet { .. }
554            | ModelType::YoloSplitEndToEndDet { .. } => Ok(None),
555
556            ModelType::YoloSegDet { boxes, protos } => {
557                let proto =
558                    self.decode_yolo_segdet_quantized_proto(outputs, boxes, protos, output_boxes)?;
559                Ok(Some(proto))
560            }
561            ModelType::YoloSplitSegDet {
562                boxes,
563                scores,
564                mask_coeff,
565                protos,
566            } => {
567                let proto = self.decode_yolo_split_segdet_quantized_proto(
568                    outputs,
569                    boxes,
570                    scores,
571                    mask_coeff,
572                    protos,
573                    output_boxes,
574                )?;
575                Ok(Some(proto))
576            }
577            ModelType::YoloEndToEndSegDet { boxes, protos } => {
578                let proto = self.decode_yolo_end_to_end_segdet_quantized_proto(
579                    outputs,
580                    boxes,
581                    protos,
582                    output_boxes,
583                )?;
584                Ok(Some(proto))
585            }
586            ModelType::YoloSplitEndToEndSegDet {
587                boxes,
588                scores,
589                classes,
590                mask_coeff,
591                protos,
592            } => {
593                let proto = self.decode_yolo_split_end_to_end_segdet_quantized_proto(
594                    outputs,
595                    boxes,
596                    scores,
597                    classes,
598                    mask_coeff,
599                    protos,
600                    output_boxes,
601                )?;
602                Ok(Some(proto))
603            }
604        }
605    }
606
607    /// Decodes floating-point model outputs into detection boxes, returning
608    /// raw `ProtoData` for segmentation models instead of materialized masks.
609    ///
610    /// Returns `Ok(None)` for detection-only and ModelPack models. Returns
611    /// `Ok(Some(ProtoData))` for YOLO segmentation models.
612    pub fn decode_float_proto<T>(
613        &self,
614        outputs: &[ArrayViewD<T>],
615        output_boxes: &mut Vec<DetectBox>,
616    ) -> Result<Option<ProtoData>, DecoderError>
617    where
618        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
619        f32: AsPrimitive<T>,
620    {
621        output_boxes.clear();
622        match &self.model_type {
623            // Detection-only and ModelPack variants: no proto data
624            ModelType::ModelPackSegDet { .. }
625            | ModelType::ModelPackSegDetSplit { .. }
626            | ModelType::ModelPackDet { .. }
627            | ModelType::ModelPackDetSplit { .. }
628            | ModelType::ModelPackSeg { .. }
629            | ModelType::YoloDet { .. }
630            | ModelType::YoloSplitDet { .. }
631            | ModelType::YoloEndToEndDet { .. }
632            | ModelType::YoloSplitEndToEndDet { .. } => Ok(None),
633
634            ModelType::YoloSegDet { boxes, protos } => {
635                let proto =
636                    self.decode_yolo_segdet_float_proto(outputs, boxes, protos, output_boxes)?;
637                Ok(Some(proto))
638            }
639            ModelType::YoloSplitSegDet {
640                boxes,
641                scores,
642                mask_coeff,
643                protos,
644            } => {
645                let proto = self.decode_yolo_split_segdet_float_proto(
646                    outputs,
647                    boxes,
648                    scores,
649                    mask_coeff,
650                    protos,
651                    output_boxes,
652                )?;
653                Ok(Some(proto))
654            }
655            ModelType::YoloEndToEndSegDet { boxes, protos } => {
656                let proto = self.decode_yolo_end_to_end_segdet_float_proto(
657                    outputs,
658                    boxes,
659                    protos,
660                    output_boxes,
661                )?;
662                Ok(Some(proto))
663            }
664            ModelType::YoloSplitEndToEndSegDet {
665                boxes,
666                scores,
667                classes,
668                mask_coeff,
669                protos,
670            } => {
671                let proto = self.decode_yolo_split_end_to_end_segdet_float_proto(
672                    outputs,
673                    boxes,
674                    scores,
675                    classes,
676                    mask_coeff,
677                    protos,
678                    output_boxes,
679                )?;
680                Ok(Some(proto))
681            }
682        }
683    }
684}
685
686#[cfg(feature = "tracker")]
687pub use edgefirst_tracker::TrackInfo;
688
689#[cfg(feature = "tracker")]
690pub use edgefirst_tracker::Tracker;
691
692#[cfg(feature = "tracker")]
693impl Decoder {
694    /// This function decodes quantized model outputs into detection boxes and
695    /// segmentation masks. The quantized outputs can be of u8, i8, u16, i16,
696    /// u32, or i32 types. Up to `output_boxes.capacity()` boxes and masks
697    /// will be decoded. The function clears the provided output vectors
698    /// before populating them with the decoded results.
699    ///
700    /// This function returns a `DecoderError` if the the provided outputs don't
701    /// match the configuration provided by the user when building the decoder.
702    ///
703    /// # Examples
704    ///
705    /// ```rust
706    /// # use edgefirst_decoder::{BoundingBox, DecoderBuilder, DetectBox, DecoderResult};
707    /// # use ndarray::Array4;
708    /// # fn main() -> DecoderResult<()> {
709    /// #    let detect0 = include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split_9x15x18.bin"));
710    /// #    let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec())?;
711    /// #
712    /// #    let detect1 = include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split_17x30x18.bin"));
713    /// #    let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec())?;
714    /// #    let model_output = vec![
715    /// #        detect1.view().into_dyn().into(),
716    /// #        detect0.view().into_dyn().into(),
717    /// #    ];
718    /// let decoder = DecoderBuilder::default()
719    ///     .with_config_yaml_str(include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string())
720    ///     .with_score_threshold(0.45)
721    ///     .with_iou_threshold(0.45)
722    ///     .build()?;
723    ///
724    /// let mut output_boxes: Vec<_> = Vec::with_capacity(10);
725    /// let mut output_masks: Vec<_> = Vec::with_capacity(10);
726    /// decoder.decode_quantized(&model_output, &mut output_boxes, &mut output_masks)?;
727    /// assert!(output_boxes[0].equal_within_delta(
728    ///     &DetectBox {
729    ///         bbox: BoundingBox {
730    ///             xmin: 0.43171933,
731    ///             ymin: 0.68243736,
732    ///             xmax: 0.5626645,
733    ///             ymax: 0.808863,
734    ///         },
735    ///         score: 0.99240804,
736    ///         label: 0
737    ///     },
738    ///     1e-6
739    /// ));
740    /// #    Ok(())
741    /// # }
742    /// ```
743    pub fn decode_tracked_quantized<TR: edgefirst_tracker::Tracker<DetectBox>>(
744        &self,
745        tracker: &mut TR,
746        timestamp: u64,
747        outputs: &[ArrayViewDQuantized],
748        output_boxes: &mut Vec<DetectBox>,
749        output_masks: &mut Vec<Segmentation>,
750        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
751    ) -> Result<(), DecoderError> {
752        output_boxes.clear();
753        output_masks.clear();
754        output_tracks.clear();
755
756        // yolo segdet variants require special handling to separate boxes that come from decoding vs active tracks.
757        // Only boxes that come from decoding can be used for proto/mask generation.
758        match &self.model_type {
759            ModelType::YoloSegDet { boxes, protos } => self.decode_tracked_yolo_segdet_quantized(
760                tracker,
761                timestamp,
762                outputs,
763                boxes,
764                protos,
765                output_boxes,
766                output_masks,
767                output_tracks,
768            ),
769            ModelType::YoloSplitSegDet {
770                boxes,
771                scores,
772                mask_coeff,
773                protos,
774            } => self.decode_tracked_yolo_split_segdet_quantized(
775                tracker,
776                timestamp,
777                outputs,
778                boxes,
779                scores,
780                mask_coeff,
781                protos,
782                output_boxes,
783                output_masks,
784                output_tracks,
785            ),
786            ModelType::YoloEndToEndSegDet { boxes, protos } => self
787                .decode_tracked_yolo_end_to_end_segdet_quantized(
788                    tracker,
789                    timestamp,
790                    outputs,
791                    boxes,
792                    protos,
793                    output_boxes,
794                    output_masks,
795                    output_tracks,
796                ),
797            ModelType::YoloSplitEndToEndSegDet {
798                boxes,
799                scores,
800                classes,
801                mask_coeff,
802                protos,
803            } => self.decode_tracked_yolo_split_end_to_end_segdet_quantized(
804                tracker,
805                timestamp,
806                outputs,
807                boxes,
808                scores,
809                classes,
810                mask_coeff,
811                protos,
812                output_boxes,
813                output_masks,
814                output_tracks,
815            ),
816            _ => {
817                self.decode_quantized(outputs, output_boxes, output_masks)?;
818                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
819                Ok(())
820            }
821        }
822    }
823
824    /// This function decodes floating point model outputs into detection boxes
825    /// and segmentation masks. Up to `output_boxes.capacity()` boxes and
826    /// masks will be decoded. The function clears the provided output
827    /// vectors before populating them with the decoded results.
828    ///
829    /// This function returns an `Error` if the provided outputs don't
830    /// match the configuration provided by the user when building the decoder.
831    ///
832    /// Any quantization information in the configuration will be ignored.
833    ///
834    /// # Examples
835    ///
836    /// ```rust
837    /// # use edgefirst_decoder::{BoundingBox, DecoderBuilder, DetectBox, DecoderResult, configs, configs::{DecoderType, DecoderVersion}, dequantize_cpu, Quantization};
838    /// # use ndarray::Array3;
839    /// # fn main() -> DecoderResult<()> {
840    /// #   let out = include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/yolov8s_80_classes.bin"));
841    /// #   let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
842    /// #   let mut out_dequant = vec![0.0_f64; 84 * 8400];
843    /// #   let quant = Quantization::new(0.0040811873, -123);
844    /// #   dequantize_cpu(out, quant, &mut out_dequant);
845    /// #   let model_output_f64 = Array3::from_shape_vec((1, 84, 8400), out_dequant)?.into_dyn();
846    ///    let decoder = DecoderBuilder::default()
847    ///     .with_config_yolo_det(configs::Detection {
848    ///         decoder: DecoderType::Ultralytics,
849    ///         quantization: None,
850    ///         shape: vec![1, 84, 8400],
851    ///         anchors: None,
852    ///         dshape: Vec::new(),
853    ///         normalized: Some(true),
854    ///     },
855    ///     Some(DecoderVersion::Yolo11))
856    ///     .with_score_threshold(0.25)
857    ///     .with_iou_threshold(0.7)
858    ///     .build()?;
859    ///
860    /// let mut output_boxes: Vec<_> = Vec::with_capacity(10);
861    /// let mut output_masks: Vec<_> = Vec::with_capacity(10);
862    /// let model_output_f64 = vec![model_output_f64.view().into()];
863    /// decoder.decode_float(&model_output_f64, &mut output_boxes, &mut output_masks)?;    
864    /// assert!(output_boxes[0].equal_within_delta(
865    ///        &DetectBox {
866    ///            bbox: BoundingBox {
867    ///                xmin: 0.5285137,
868    ///                ymin: 0.05305544,
869    ///                xmax: 0.87541467,
870    ///                ymax: 0.9998909,
871    ///            },
872    ///            score: 0.5591227,
873    ///            label: 0
874    ///        },
875    ///        1e-6
876    ///    ));
877    ///
878    /// #    Ok(())
879    /// # }
880    pub fn decode_tracked_float<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
881        &self,
882        tracker: &mut TR,
883        timestamp: u64,
884        outputs: &[ArrayViewD<T>],
885        output_boxes: &mut Vec<DetectBox>,
886        output_masks: &mut Vec<Segmentation>,
887        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
888    ) -> Result<(), DecoderError>
889    where
890        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
891        f32: AsPrimitive<T>,
892    {
893        output_boxes.clear();
894        output_masks.clear();
895        output_tracks.clear();
896        match &self.model_type {
897            ModelType::YoloSegDet { boxes, protos } => {
898                self.decode_tracked_yolo_segdet_float(
899                    tracker,
900                    timestamp,
901                    outputs,
902                    boxes,
903                    protos,
904                    output_boxes,
905                    output_masks,
906                    output_tracks,
907                )?;
908            }
909            ModelType::YoloSplitSegDet {
910                boxes,
911                scores,
912                mask_coeff,
913                protos,
914            } => {
915                self.decode_tracked_yolo_split_segdet_float(
916                    tracker,
917                    timestamp,
918                    outputs,
919                    boxes,
920                    scores,
921                    mask_coeff,
922                    protos,
923                    output_boxes,
924                    output_masks,
925                    output_tracks,
926                )?;
927            }
928            ModelType::YoloEndToEndSegDet { boxes, protos } => {
929                self.decode_tracked_yolo_end_to_end_segdet_float(
930                    tracker,
931                    timestamp,
932                    outputs,
933                    boxes,
934                    protos,
935                    output_boxes,
936                    output_masks,
937                    output_tracks,
938                )?;
939            }
940            ModelType::YoloSplitEndToEndSegDet {
941                boxes,
942                scores,
943                classes,
944                mask_coeff,
945                protos,
946            } => {
947                self.decode_tracked_yolo_split_end_to_end_segdet_float(
948                    tracker,
949                    timestamp,
950                    outputs,
951                    boxes,
952                    scores,
953                    classes,
954                    mask_coeff,
955                    protos,
956                    output_boxes,
957                    output_masks,
958                    output_tracks,
959                )?;
960            }
961            _ => {
962                self.decode_float(outputs, output_boxes, output_masks)?;
963                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
964            }
965        }
966        Ok(())
967    }
968
969    /// Decodes quantized model outputs into detection boxes, returning raw
970    /// `ProtoData` for segmentation models instead of materialized masks.
971    ///
972    /// Returns `Ok(None)` for detection-only and ModelPack models (use
973    /// `decode_quantized` for those). Returns `Ok(Some(ProtoData))` for
974    /// YOLO segmentation models.
975    pub fn decode_tracked_quantized_proto<TR: edgefirst_tracker::Tracker<DetectBox>>(
976        &self,
977        tracker: &mut TR,
978        timestamp: u64,
979        outputs: &[ArrayViewDQuantized],
980        output_boxes: &mut Vec<DetectBox>,
981        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
982    ) -> Result<Option<ProtoData>, DecoderError> {
983        output_boxes.clear();
984        output_tracks.clear();
985        match &self.model_type {
986            // Detection-only and ModelPack variants: no proto data
987            ModelType::ModelPackSegDet { .. }
988            | ModelType::ModelPackSegDetSplit { .. }
989            | ModelType::ModelPackDet { .. }
990            | ModelType::ModelPackDetSplit { .. }
991            | ModelType::ModelPackSeg { .. }
992            | ModelType::YoloDet { .. }
993            | ModelType::YoloSplitDet { .. }
994            | ModelType::YoloEndToEndDet { .. }
995            | ModelType::YoloSplitEndToEndDet { .. } => Ok(None),
996
997            ModelType::YoloSegDet { boxes, protos } => {
998                let proto = self.decode_tracked_yolo_segdet_quantized_proto(
999                    tracker,
1000                    timestamp,
1001                    outputs,
1002                    boxes,
1003                    protos,
1004                    output_boxes,
1005                    output_tracks,
1006                )?;
1007                Ok(Some(proto))
1008            }
1009            ModelType::YoloSplitSegDet {
1010                boxes,
1011                scores,
1012                mask_coeff,
1013                protos,
1014            } => {
1015                let proto = self.decode_tracked_yolo_split_segdet_quantized_proto(
1016                    tracker,
1017                    timestamp,
1018                    outputs,
1019                    boxes,
1020                    scores,
1021                    mask_coeff,
1022                    protos,
1023                    output_boxes,
1024                    output_tracks,
1025                )?;
1026                Ok(Some(proto))
1027            }
1028            ModelType::YoloEndToEndSegDet { boxes, protos } => {
1029                let proto = self.decode_tracked_yolo_end_to_end_segdet_quantized_proto(
1030                    tracker,
1031                    timestamp,
1032                    outputs,
1033                    boxes,
1034                    protos,
1035                    output_boxes,
1036                    output_tracks,
1037                )?;
1038                Ok(Some(proto))
1039            }
1040            ModelType::YoloSplitEndToEndSegDet {
1041                boxes,
1042                scores,
1043                classes,
1044                mask_coeff,
1045                protos,
1046            } => {
1047                let proto = self.decode_tracked_yolo_split_end_to_end_segdet_quantized_proto(
1048                    tracker,
1049                    timestamp,
1050                    outputs,
1051                    boxes,
1052                    scores,
1053                    classes,
1054                    mask_coeff,
1055                    protos,
1056                    output_boxes,
1057                    output_tracks,
1058                )?;
1059                Ok(Some(proto))
1060            }
1061        }
1062    }
1063
1064    /// Decodes floating-point model outputs into detection boxes, returning
1065    /// raw `ProtoData` for segmentation models instead of materialized masks.
1066    ///
1067    /// Returns `Ok(None)` for detection-only and ModelPack models. Returns
1068    /// `Ok(Some(ProtoData))` for YOLO segmentation models.
1069    pub fn decode_tracked_float_proto<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
1070        &self,
1071        tracker: &mut TR,
1072        timestamp: u64,
1073        outputs: &[ArrayViewD<T>],
1074        output_boxes: &mut Vec<DetectBox>,
1075        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
1076    ) -> Result<Option<ProtoData>, DecoderError>
1077    where
1078        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
1079        f32: AsPrimitive<T>,
1080    {
1081        output_boxes.clear();
1082        output_tracks.clear();
1083        match &self.model_type {
1084            // Detection-only and ModelPack variants: no proto data
1085            ModelType::ModelPackSegDet { .. }
1086            | ModelType::ModelPackSegDetSplit { .. }
1087            | ModelType::ModelPackDet { .. }
1088            | ModelType::ModelPackDetSplit { .. }
1089            | ModelType::ModelPackSeg { .. }
1090            | ModelType::YoloDet { .. }
1091            | ModelType::YoloSplitDet { .. }
1092            | ModelType::YoloEndToEndDet { .. }
1093            | ModelType::YoloSplitEndToEndDet { .. } => Ok(None),
1094
1095            ModelType::YoloSegDet { boxes, protos } => {
1096                let proto = self.decode_tracked_yolo_segdet_float_proto(
1097                    tracker,
1098                    timestamp,
1099                    outputs,
1100                    boxes,
1101                    protos,
1102                    output_boxes,
1103                    output_tracks,
1104                )?;
1105                Ok(Some(proto))
1106            }
1107            ModelType::YoloSplitSegDet {
1108                boxes,
1109                scores,
1110                mask_coeff,
1111                protos,
1112            } => {
1113                let proto = self.decode_tracked_yolo_split_segdet_float_proto(
1114                    tracker,
1115                    timestamp,
1116                    outputs,
1117                    boxes,
1118                    scores,
1119                    mask_coeff,
1120                    protos,
1121                    output_boxes,
1122                    output_tracks,
1123                )?;
1124                Ok(Some(proto))
1125            }
1126            ModelType::YoloEndToEndSegDet { boxes, protos } => {
1127                let proto = self.decode_tracked_yolo_end_to_end_segdet_float_proto(
1128                    tracker,
1129                    timestamp,
1130                    outputs,
1131                    boxes,
1132                    protos,
1133                    output_boxes,
1134                    output_tracks,
1135                )?;
1136                Ok(Some(proto))
1137            }
1138            ModelType::YoloSplitEndToEndSegDet {
1139                boxes,
1140                scores,
1141                classes,
1142                mask_coeff,
1143                protos,
1144            } => {
1145                let proto = self.decode_tracked_yolo_split_end_to_end_segdet_float_proto(
1146                    tracker,
1147                    timestamp,
1148                    outputs,
1149                    boxes,
1150                    scores,
1151                    classes,
1152                    mask_coeff,
1153                    protos,
1154                    output_boxes,
1155                    output_tracks,
1156                )?;
1157                Ok(Some(proto))
1158            }
1159        }
1160    }
1161}