edgefirst_decoder/decoder/mod.rs
1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use ndarray::{ArrayView, ArrayViewD, Dimension};
5use num_traits::{AsPrimitive, Float};
6
7use crate::{DecoderError, DetectBox, ProtoData, Segmentation};
8
9pub mod config;
10pub mod configs;
11
12use configs::ModelType;
13
/// Decodes raw model outputs (quantized or floating point) into detection
/// boxes, segmentation masks, or raw proto data, according to the parsed
/// [`ModelType`].
#[derive(Debug, Clone, PartialEq)]
pub struct Decoder {
    /// Parsed model type; every `decode_*` entry point matches on it to pick
    /// the decode routine. Exposed read-only via [`Decoder::model_type`].
    model_type: ModelType,
    /// Intersection-over-union threshold.
    /// NOTE(review): presumably consumed by NMS when `nms` is `Some` —
    /// confirm in the postprocess module.
    pub iou_threshold: f32,
    /// Score threshold.
    /// NOTE(review): presumably detections scoring below this are discarded —
    /// confirm in the postprocess module.
    pub score_threshold: f32,
    /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
    /// models)
    pub nms: Option<configs::Nms>,
    /// Whether decoded boxes are in normalized [0,1] coordinates.
    /// - `Some(true)`: Coordinates in [0,1] range
    /// - `Some(false)`: Pixel coordinates
    /// - `None`: Unknown, caller must infer (e.g., check if any coordinate >
    ///   1.0)
    ///
    /// Exposed read-only via [`Decoder::normalized_boxes`].
    normalized: Option<bool>,
}
29
/// Borrowed, dynamically-dimensioned view of a quantized (integer) model
/// output tensor.
///
/// There is one variant per supported integer element type. These are the
/// types accepted by [`Decoder::decode_quantized`].
#[derive(Debug)]
pub enum ArrayViewDQuantized<'a> {
    /// View over an unsigned 8-bit tensor.
    UInt8(ArrayViewD<'a, u8>),
    /// View over a signed 8-bit tensor.
    Int8(ArrayViewD<'a, i8>),
    /// View over an unsigned 16-bit tensor.
    UInt16(ArrayViewD<'a, u16>),
    /// View over a signed 16-bit tensor.
    Int16(ArrayViewD<'a, i16>),
    /// View over an unsigned 32-bit tensor.
    UInt32(ArrayViewD<'a, u32>),
    /// View over a signed 32-bit tensor.
    Int32(ArrayViewD<'a, i32>),
}
39
40impl<'a, D> From<ArrayView<'a, u8, D>> for ArrayViewDQuantized<'a>
41where
42 D: Dimension,
43{
44 fn from(arr: ArrayView<'a, u8, D>) -> Self {
45 Self::UInt8(arr.into_dyn())
46 }
47}
48
49impl<'a, D> From<ArrayView<'a, i8, D>> for ArrayViewDQuantized<'a>
50where
51 D: Dimension,
52{
53 fn from(arr: ArrayView<'a, i8, D>) -> Self {
54 Self::Int8(arr.into_dyn())
55 }
56}
57
58impl<'a, D> From<ArrayView<'a, u16, D>> for ArrayViewDQuantized<'a>
59where
60 D: Dimension,
61{
62 fn from(arr: ArrayView<'a, u16, D>) -> Self {
63 Self::UInt16(arr.into_dyn())
64 }
65}
66
67impl<'a, D> From<ArrayView<'a, i16, D>> for ArrayViewDQuantized<'a>
68where
69 D: Dimension,
70{
71 fn from(arr: ArrayView<'a, i16, D>) -> Self {
72 Self::Int16(arr.into_dyn())
73 }
74}
75
76impl<'a, D> From<ArrayView<'a, u32, D>> for ArrayViewDQuantized<'a>
77where
78 D: Dimension,
79{
80 fn from(arr: ArrayView<'a, u32, D>) -> Self {
81 Self::UInt32(arr.into_dyn())
82 }
83}
84
85impl<'a, D> From<ArrayView<'a, i32, D>> for ArrayViewDQuantized<'a>
86where
87 D: Dimension,
88{
89 fn from(arr: ArrayView<'a, i32, D>) -> Self {
90 Self::Int32(arr.into_dyn())
91 }
92}
93
94impl<'a> ArrayViewDQuantized<'a> {
95 /// Returns the shape of the underlying array.
96 ///
97 /// # Examples
98 /// ```rust
99 /// # use edgefirst_decoder::ArrayViewDQuantized;
100 /// # use ndarray::Array2;
101 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
102 /// let arr = Array2::from_shape_vec((2, 3), vec![1u8, 2, 3, 4, 5, 6])?;
103 /// let view = ArrayViewDQuantized::from(arr.view().into_dyn());
104 /// assert_eq!(view.shape(), &[2, 3]);
105 /// # Ok(())
106 /// # }
107 /// ```
108 pub fn shape(&self) -> &[usize] {
109 match self {
110 ArrayViewDQuantized::UInt8(a) => a.shape(),
111 ArrayViewDQuantized::Int8(a) => a.shape(),
112 ArrayViewDQuantized::UInt16(a) => a.shape(),
113 ArrayViewDQuantized::Int16(a) => a.shape(),
114 ArrayViewDQuantized::UInt32(a) => a.shape(),
115 ArrayViewDQuantized::Int32(a) => a.shape(),
116 }
117 }
118}
119
/// Dispatches over every [`ArrayViewDQuantized`] variant. `$body` is
/// evaluated with the inner view bound to `$var`, once per element type.
///
/// WARNING: Do NOT nest `with_quantized!` calls. Each level multiplies the
/// monomorphized code paths by 6 (one per integer variant), so nesting
/// N levels deep produces 6^N instantiations.
///
/// Instead, dequantize each tensor sequentially with `dequant_3d!`/`dequant_4d!`
/// (6*N paths), or split the work into independent phases that each nest at
/// most 2 levels.
macro_rules! with_quantized {
    ($x:expr, $var:ident, $body:expr) => {
        match $x {
            ArrayViewDQuantized::UInt8($var) => $body,
            ArrayViewDQuantized::Int8($var) => $body,
            ArrayViewDQuantized::UInt16($var) => $body,
            ArrayViewDQuantized::Int16($var) => $body,
            ArrayViewDQuantized::UInt32($var) => $body,
            ArrayViewDQuantized::Int32($var) => $body,
        }
    };
}
156
157mod builder;
158mod helpers;
159mod postprocess;
160mod tests;
161
162pub use builder::DecoderBuilder;
163pub use config::{ConfigOutput, ConfigOutputRef, ConfigOutputs};
164
165impl Decoder {
166 /// This function returns the parsed model type of the decoder.
167 ///
168 /// # Examples
169 ///
170 /// ```rust
171 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::ModelType};
172 /// # fn main() -> DecoderResult<()> {
173 /// # let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
174 /// let decoder = DecoderBuilder::default()
175 /// .with_config_yaml_str(config_yaml)
176 /// .build()?;
177 /// assert!(matches!(
178 /// decoder.model_type(),
179 /// ModelType::ModelPackDetSplit { .. }
180 /// ));
181 /// # Ok(())
182 /// # }
183 /// ```
184 pub fn model_type(&self) -> &ModelType {
185 &self.model_type
186 }
187
188 /// Returns the box coordinate format if known from the model config.
189 ///
190 /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
191 /// - `Some(false)`: Boxes are in pixel coordinates relative to model input
192 /// - `None`: Unknown, caller must infer (e.g., check if any coordinate >
193 /// 1.0)
194 ///
195 /// This is determined by the model config's `normalized` field, not the NMS
196 /// mode. When coordinates are in pixels or unknown, the caller may need
197 /// to normalize using the model input dimensions.
198 ///
199 /// # Examples
200 ///
201 /// ```rust
202 /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
203 /// # fn main() -> DecoderResult<()> {
204 /// # let config_yaml = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string();
205 /// let decoder = DecoderBuilder::default()
206 /// .with_config_yaml_str(config_yaml)
207 /// .build()?;
208 /// // Config doesn't specify normalized, so it's None
209 /// assert!(decoder.normalized_boxes().is_none());
210 /// # Ok(())
211 /// # }
212 /// ```
213 pub fn normalized_boxes(&self) -> Option<bool> {
214 self.normalized
215 }
216
217 /// This function decodes quantized model outputs into detection boxes and
218 /// segmentation masks. The quantized outputs can be of u8, i8, u16, i16,
219 /// u32, or i32 types. Up to `output_boxes.capacity()` boxes and masks
220 /// will be decoded. The function clears the provided output vectors
221 /// before populating them with the decoded results.
222 ///
223 /// This function returns a `DecoderError` if the the provided outputs don't
224 /// match the configuration provided by the user when building the decoder.
225 ///
226 /// # Examples
227 ///
228 /// ```rust
229 /// # use edgefirst_decoder::{BoundingBox, DecoderBuilder, DetectBox, DecoderResult};
230 /// # use ndarray::Array4;
231 /// # fn main() -> DecoderResult<()> {
232 /// # let detect0 = include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split_9x15x18.bin"));
233 /// # let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec())?;
234 /// #
235 /// # let detect1 = include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split_17x30x18.bin"));
236 /// # let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec())?;
237 /// # let model_output = vec![
238 /// # detect1.view().into_dyn().into(),
239 /// # detect0.view().into_dyn().into(),
240 /// # ];
241 /// let decoder = DecoderBuilder::default()
242 /// .with_config_yaml_str(include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/modelpack_split.yaml")).to_string())
243 /// .with_score_threshold(0.45)
244 /// .with_iou_threshold(0.45)
245 /// .build()?;
246 ///
247 /// let mut output_boxes: Vec<_> = Vec::with_capacity(10);
248 /// let mut output_masks: Vec<_> = Vec::with_capacity(10);
249 /// decoder.decode_quantized(&model_output, &mut output_boxes, &mut output_masks)?;
250 /// assert!(output_boxes[0].equal_within_delta(
251 /// &DetectBox {
252 /// bbox: BoundingBox {
253 /// xmin: 0.43171933,
254 /// ymin: 0.68243736,
255 /// xmax: 0.5626645,
256 /// ymax: 0.808863,
257 /// },
258 /// score: 0.99240804,
259 /// label: 0
260 /// },
261 /// 1e-6
262 /// ));
263 /// # Ok(())
264 /// # }
265 /// ```
266 pub fn decode_quantized(
267 &self,
268 outputs: &[ArrayViewDQuantized],
269 output_boxes: &mut Vec<DetectBox>,
270 output_masks: &mut Vec<Segmentation>,
271 ) -> Result<(), DecoderError> {
272 output_boxes.clear();
273 output_masks.clear();
274 match &self.model_type {
275 ModelType::ModelPackSegDet {
276 boxes,
277 scores,
278 segmentation,
279 } => {
280 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
281 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
282 }
283 ModelType::ModelPackSegDetSplit {
284 detection,
285 segmentation,
286 } => {
287 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
288 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
289 }
290 ModelType::ModelPackDet { boxes, scores } => {
291 self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)
292 }
293 ModelType::ModelPackDetSplit { detection } => {
294 self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)
295 }
296 ModelType::ModelPackSeg { segmentation } => {
297 self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
298 }
299 ModelType::YoloDet { boxes } => {
300 self.decode_yolo_det_quantized(outputs, boxes, output_boxes)
301 }
302 ModelType::YoloSegDet { boxes, protos } => self.decode_yolo_segdet_quantized(
303 outputs,
304 boxes,
305 protos,
306 output_boxes,
307 output_masks,
308 ),
309 ModelType::YoloSplitDet { boxes, scores } => {
310 self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)
311 }
312 ModelType::YoloSplitSegDet {
313 boxes,
314 scores,
315 mask_coeff,
316 protos,
317 } => self.decode_yolo_split_segdet_quantized(
318 outputs,
319 boxes,
320 scores,
321 mask_coeff,
322 protos,
323 output_boxes,
324 output_masks,
325 ),
326 ModelType::YoloEndToEndDet { boxes } => {
327 self.decode_yolo_end_to_end_det_quantized(outputs, boxes, output_boxes)
328 }
329 ModelType::YoloEndToEndSegDet { boxes, protos } => self
330 .decode_yolo_end_to_end_segdet_quantized(
331 outputs,
332 boxes,
333 protos,
334 output_boxes,
335 output_masks,
336 ),
337 ModelType::YoloSplitEndToEndDet {
338 boxes,
339 scores,
340 classes,
341 } => self.decode_yolo_split_end_to_end_det_quantized(
342 outputs,
343 boxes,
344 scores,
345 classes,
346 output_boxes,
347 ),
348 ModelType::YoloSplitEndToEndSegDet {
349 boxes,
350 scores,
351 classes,
352 mask_coeff,
353 protos,
354 } => self.decode_yolo_split_end_to_end_segdet_quantized(
355 outputs,
356 boxes,
357 scores,
358 classes,
359 mask_coeff,
360 protos,
361 output_boxes,
362 output_masks,
363 ),
364 }
365 }
366
367 /// This function decodes floating point model outputs into detection boxes
368 /// and segmentation masks. Up to `output_boxes.capacity()` boxes and
369 /// masks will be decoded. The function clears the provided output
370 /// vectors before populating them with the decoded results.
371 ///
372 /// This function returns an `Error` if the the provided outputs don't
373 /// match the configuration provided by the user when building the decoder.
374 ///
375 /// Any quantization information in the configuration will be ignored.
376 ///
377 /// # Examples
378 ///
379 /// ```rust
380 /// # use edgefirst_decoder::{BoundingBox, DecoderBuilder, DetectBox, DecoderResult, configs, configs::{DecoderType, DecoderVersion}, dequantize_cpu, Quantization};
381 /// # use ndarray::Array3;
382 /// # fn main() -> DecoderResult<()> {
383 /// # let out = include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/yolov8s_80_classes.bin"));
384 /// # let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
385 /// # let mut out_dequant = vec![0.0_f64; 84 * 8400];
386 /// # let quant = Quantization::new(0.0040811873, -123);
387 /// # dequantize_cpu(out, quant, &mut out_dequant);
388 /// # let model_output_f64 = Array3::from_shape_vec((1, 84, 8400), out_dequant)?.into_dyn();
389 /// let decoder = DecoderBuilder::default()
390 /// .with_config_yolo_det(configs::Detection {
391 /// decoder: DecoderType::Ultralytics,
392 /// quantization: None,
393 /// shape: vec![1, 84, 8400],
394 /// anchors: None,
395 /// dshape: Vec::new(),
396 /// normalized: Some(true),
397 /// },
398 /// Some(DecoderVersion::Yolo11))
399 /// .with_score_threshold(0.25)
400 /// .with_iou_threshold(0.7)
401 /// .build()?;
402 ///
403 /// let mut output_boxes: Vec<_> = Vec::with_capacity(10);
404 /// let mut output_masks: Vec<_> = Vec::with_capacity(10);
405 /// let model_output_f64 = vec![model_output_f64.view().into()];
406 /// decoder.decode_float(&model_output_f64, &mut output_boxes, &mut output_masks)?;
407 /// assert!(output_boxes[0].equal_within_delta(
408 /// &DetectBox {
409 /// bbox: BoundingBox {
410 /// xmin: 0.5285137,
411 /// ymin: 0.05305544,
412 /// xmax: 0.87541467,
413 /// ymax: 0.9998909,
414 /// },
415 /// score: 0.5591227,
416 /// label: 0
417 /// },
418 /// 1e-6
419 /// ));
420 ///
421 /// # Ok(())
422 /// # }
423 pub fn decode_float<T>(
424 &self,
425 outputs: &[ArrayViewD<T>],
426 output_boxes: &mut Vec<DetectBox>,
427 output_masks: &mut Vec<Segmentation>,
428 ) -> Result<(), DecoderError>
429 where
430 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
431 f32: AsPrimitive<T>,
432 {
433 output_boxes.clear();
434 output_masks.clear();
435 match &self.model_type {
436 ModelType::ModelPackSegDet {
437 boxes,
438 scores,
439 segmentation,
440 } => {
441 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
442 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
443 }
444 ModelType::ModelPackSegDetSplit {
445 detection,
446 segmentation,
447 } => {
448 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
449 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
450 }
451 ModelType::ModelPackDet { boxes, scores } => {
452 self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
453 }
454 ModelType::ModelPackDetSplit { detection } => {
455 self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
456 }
457 ModelType::ModelPackSeg { segmentation } => {
458 self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
459 }
460 ModelType::YoloDet { boxes } => {
461 self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
462 }
463 ModelType::YoloSegDet { boxes, protos } => {
464 self.decode_yolo_segdet_float(outputs, boxes, protos, output_boxes, output_masks)?;
465 }
466 ModelType::YoloSplitDet { boxes, scores } => {
467 self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
468 }
469 ModelType::YoloSplitSegDet {
470 boxes,
471 scores,
472 mask_coeff,
473 protos,
474 } => {
475 self.decode_yolo_split_segdet_float(
476 outputs,
477 boxes,
478 scores,
479 mask_coeff,
480 protos,
481 output_boxes,
482 output_masks,
483 )?;
484 }
485 ModelType::YoloEndToEndDet { boxes } => {
486 self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
487 }
488 ModelType::YoloEndToEndSegDet { boxes, protos } => {
489 self.decode_yolo_end_to_end_segdet_float(
490 outputs,
491 boxes,
492 protos,
493 output_boxes,
494 output_masks,
495 )?;
496 }
497 ModelType::YoloSplitEndToEndDet {
498 boxes,
499 scores,
500 classes,
501 } => {
502 self.decode_yolo_split_end_to_end_det_float(
503 outputs,
504 boxes,
505 scores,
506 classes,
507 output_boxes,
508 )?;
509 }
510 ModelType::YoloSplitEndToEndSegDet {
511 boxes,
512 scores,
513 classes,
514 mask_coeff,
515 protos,
516 } => {
517 self.decode_yolo_split_end_to_end_segdet_float(
518 outputs,
519 boxes,
520 scores,
521 classes,
522 mask_coeff,
523 protos,
524 output_boxes,
525 output_masks,
526 )?;
527 }
528 }
529 Ok(())
530 }
531
532 /// Decodes quantized model outputs into detection boxes, returning raw
533 /// `ProtoData` for segmentation models instead of materialized masks.
534 ///
535 /// Returns `Ok(None)` for detection-only and ModelPack models (use
536 /// `decode_quantized` for those). Returns `Ok(Some(ProtoData))` for
537 /// YOLO segmentation models.
538 pub fn decode_quantized_proto(
539 &self,
540 outputs: &[ArrayViewDQuantized],
541 output_boxes: &mut Vec<DetectBox>,
542 ) -> Result<Option<ProtoData>, DecoderError> {
543 output_boxes.clear();
544 match &self.model_type {
545 // Detection-only and ModelPack variants: no proto data
546 ModelType::ModelPackSegDet { .. }
547 | ModelType::ModelPackSegDetSplit { .. }
548 | ModelType::ModelPackDet { .. }
549 | ModelType::ModelPackDetSplit { .. }
550 | ModelType::ModelPackSeg { .. }
551 | ModelType::YoloDet { .. }
552 | ModelType::YoloSplitDet { .. }
553 | ModelType::YoloEndToEndDet { .. }
554 | ModelType::YoloSplitEndToEndDet { .. } => Ok(None),
555
556 ModelType::YoloSegDet { boxes, protos } => {
557 let proto =
558 self.decode_yolo_segdet_quantized_proto(outputs, boxes, protos, output_boxes)?;
559 Ok(Some(proto))
560 }
561 ModelType::YoloSplitSegDet {
562 boxes,
563 scores,
564 mask_coeff,
565 protos,
566 } => {
567 let proto = self.decode_yolo_split_segdet_quantized_proto(
568 outputs,
569 boxes,
570 scores,
571 mask_coeff,
572 protos,
573 output_boxes,
574 )?;
575 Ok(Some(proto))
576 }
577 ModelType::YoloEndToEndSegDet { boxes, protos } => {
578 let proto = self.decode_yolo_end_to_end_segdet_quantized_proto(
579 outputs,
580 boxes,
581 protos,
582 output_boxes,
583 )?;
584 Ok(Some(proto))
585 }
586 ModelType::YoloSplitEndToEndSegDet {
587 boxes,
588 scores,
589 classes,
590 mask_coeff,
591 protos,
592 } => {
593 let proto = self.decode_yolo_split_end_to_end_segdet_quantized_proto(
594 outputs,
595 boxes,
596 scores,
597 classes,
598 mask_coeff,
599 protos,
600 output_boxes,
601 )?;
602 Ok(Some(proto))
603 }
604 }
605 }
606
607 /// Decodes floating-point model outputs into detection boxes, returning
608 /// raw `ProtoData` for segmentation models instead of materialized masks.
609 ///
610 /// Returns `Ok(None)` for detection-only and ModelPack models. Returns
611 /// `Ok(Some(ProtoData))` for YOLO segmentation models.
612 pub fn decode_float_proto<T>(
613 &self,
614 outputs: &[ArrayViewD<T>],
615 output_boxes: &mut Vec<DetectBox>,
616 ) -> Result<Option<ProtoData>, DecoderError>
617 where
618 T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
619 f32: AsPrimitive<T>,
620 {
621 output_boxes.clear();
622 match &self.model_type {
623 // Detection-only and ModelPack variants: no proto data
624 ModelType::ModelPackSegDet { .. }
625 | ModelType::ModelPackSegDetSplit { .. }
626 | ModelType::ModelPackDet { .. }
627 | ModelType::ModelPackDetSplit { .. }
628 | ModelType::ModelPackSeg { .. }
629 | ModelType::YoloDet { .. }
630 | ModelType::YoloSplitDet { .. }
631 | ModelType::YoloEndToEndDet { .. }
632 | ModelType::YoloSplitEndToEndDet { .. } => Ok(None),
633
634 ModelType::YoloSegDet { boxes, protos } => {
635 let proto =
636 self.decode_yolo_segdet_float_proto(outputs, boxes, protos, output_boxes)?;
637 Ok(Some(proto))
638 }
639 ModelType::YoloSplitSegDet {
640 boxes,
641 scores,
642 mask_coeff,
643 protos,
644 } => {
645 let proto = self.decode_yolo_split_segdet_float_proto(
646 outputs,
647 boxes,
648 scores,
649 mask_coeff,
650 protos,
651 output_boxes,
652 )?;
653 Ok(Some(proto))
654 }
655 ModelType::YoloEndToEndSegDet { boxes, protos } => {
656 let proto = self.decode_yolo_end_to_end_segdet_float_proto(
657 outputs,
658 boxes,
659 protos,
660 output_boxes,
661 )?;
662 Ok(Some(proto))
663 }
664 ModelType::YoloSplitEndToEndSegDet {
665 boxes,
666 scores,
667 classes,
668 mask_coeff,
669 protos,
670 } => {
671 let proto = self.decode_yolo_split_end_to_end_segdet_float_proto(
672 outputs,
673 boxes,
674 scores,
675 classes,
676 mask_coeff,
677 protos,
678 output_boxes,
679 )?;
680 Ok(Some(proto))
681 }
682 }
683 }
684}
685
686#[cfg(feature = "tracker")]
687pub use edgefirst_tracker::TrackInfo;
688
689#[cfg(feature = "tracker")]
690pub use edgefirst_tracker::Tracker;
691
#[cfg(feature = "tracker")]
impl Decoder {
    /// Decodes quantized model outputs and runs the resulting detections
    /// through `tracker`.
    ///
    /// `outputs` are the same quantized tensors (u8, i8, u16, i16, u32, or
    /// i32) accepted by [`Decoder::decode_quantized`]. `output_boxes`,
    /// `output_masks` and `output_tracks` are cleared before being
    /// populated. Up to `output_boxes.capacity()` boxes and masks will be
    /// decoded.
    ///
    /// YOLO segmentation models (`YoloSegDet`, `YoloSplitSegDet`,
    /// `YoloEndToEndSegDet`, `YoloSplitEndToEndSegDet`) use dedicated tracked
    /// decode paths. Those paths are needed because only boxes that come from
    /// decoding, not boxes from active tracks, can be used for mask
    /// generation. Every other model type is decoded with
    /// [`Decoder::decode_quantized`]. The decoded boxes and `timestamp` are
    /// then handed to `Self::update_tracker`, which presumably fills
    /// `output_tracks` (defined elsewhere; confirm there).
    ///
    /// # Errors
    ///
    /// Returns a [`DecoderError`] if the provided outputs don't match the
    /// configuration provided by the user when building the decoder.
    ///
    /// # Examples
    ///
    /// See [`Decoder::decode_quantized`] for how to build `outputs`. This
    /// method takes the same outputs plus a tracker and a timestamp.
    pub fn decode_tracked_quantized<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewDQuantized],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<(), DecoderError> {
        output_boxes.clear();
        output_masks.clear();
        output_tracks.clear();

        // yolo segdet variants require special handling to separate boxes that come from decoding vs active tracks.
        // Only boxes that come from decoding can be used for proto/mask generation.
        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => self.decode_tracked_yolo_segdet_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => self.decode_tracked_yolo_split_segdet_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                scores,
                mask_coeff,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            ModelType::YoloEndToEndSegDet { boxes, protos } => self
                .decode_tracked_yolo_end_to_end_segdet_quantized(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                ),
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => self.decode_tracked_yolo_split_end_to_end_segdet_quantized(
                tracker,
                timestamp,
                outputs,
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
                output_boxes,
                output_masks,
                output_tracks,
            ),
            // Listed explicitly (instead of `_`) so a newly added segmentation
            // variant fails to compile here rather than silently skipping the
            // mask-aware tracked path.
            ModelType::ModelPackSegDet { .. }
            | ModelType::ModelPackSegDetSplit { .. }
            | ModelType::ModelPackDet { .. }
            | ModelType::ModelPackDetSplit { .. }
            | ModelType::ModelPackSeg { .. }
            | ModelType::YoloDet { .. }
            | ModelType::YoloSplitDet { .. }
            | ModelType::YoloEndToEndDet { .. }
            | ModelType::YoloSplitEndToEndDet { .. } => {
                self.decode_quantized(outputs, output_boxes, output_masks)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
                Ok(())
            }
        }
    }

    /// Decodes floating point model outputs and runs the resulting detections
    /// through `tracker`.
    ///
    /// `outputs` are the same floating point tensors accepted by
    /// [`Decoder::decode_float`]. `output_boxes`, `output_masks` and
    /// `output_tracks` are cleared before being populated. Up to
    /// `output_boxes.capacity()` boxes and masks will be decoded.
    ///
    /// YOLO segmentation models (`YoloSegDet`, `YoloSplitSegDet`,
    /// `YoloEndToEndSegDet`, `YoloSplitEndToEndSegDet`) use dedicated tracked
    /// decode paths. Every other model type is decoded with
    /// [`Decoder::decode_float`]. The decoded boxes and `timestamp` are then
    /// handed to `Self::update_tracker`, which presumably fills
    /// `output_tracks` (defined elsewhere; confirm there).
    ///
    /// Any quantization information in the configuration will be ignored.
    ///
    /// # Errors
    ///
    /// Returns a [`DecoderError`] if the provided outputs don't match the
    /// configuration provided by the user when building the decoder.
    ///
    /// # Examples
    ///
    /// See [`Decoder::decode_float`] for how to build `outputs`. This method
    /// takes the same outputs plus a tracker and a timestamp.
    pub fn decode_tracked_float<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewD<T>],
        output_boxes: &mut Vec<DetectBox>,
        output_masks: &mut Vec<Segmentation>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<(), DecoderError>
    where
        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
        f32: AsPrimitive<T>,
    {
        output_boxes.clear();
        output_masks.clear();
        output_tracks.clear();
        match &self.model_type {
            ModelType::YoloSegDet { boxes, protos } => {
                self.decode_tracked_yolo_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                self.decode_tracked_yolo_split_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                self.decode_tracked_yolo_end_to_end_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                self.decode_tracked_yolo_split_end_to_end_segdet_float(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_masks,
                    output_tracks,
                )?;
            }
            // Listed explicitly (instead of `_`) so a newly added segmentation
            // variant fails to compile here rather than silently skipping the
            // mask-aware tracked path.
            ModelType::ModelPackSegDet { .. }
            | ModelType::ModelPackSegDetSplit { .. }
            | ModelType::ModelPackDet { .. }
            | ModelType::ModelPackDetSplit { .. }
            | ModelType::ModelPackSeg { .. }
            | ModelType::YoloDet { .. }
            | ModelType::YoloSplitDet { .. }
            | ModelType::YoloEndToEndDet { .. }
            | ModelType::YoloSplitEndToEndDet { .. } => {
                self.decode_float(outputs, output_boxes, output_masks)?;
                Self::update_tracker(tracker, timestamp, output_boxes, output_tracks);
            }
        }
        Ok(())
    }

    /// Decodes quantized model outputs with tracking, returning raw
    /// `ProtoData` for segmentation models instead of materialized masks.
    ///
    /// `output_boxes` and `output_tracks` are always cleared first.
    ///
    /// Returns `Ok(None)` for detection-only and ModelPack models. In that
    /// case nothing is decoded and `tracker` is not updated (use
    /// [`Decoder::decode_tracked_quantized`] for those models). Returns
    /// `Ok(Some(ProtoData))` for YOLO segmentation models.
    ///
    /// # Errors
    ///
    /// Returns a [`DecoderError`] if the provided outputs don't match the
    /// configuration of a YOLO segmentation model.
    pub fn decode_tracked_quantized_proto<TR: edgefirst_tracker::Tracker<DetectBox>>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewDQuantized],
        output_boxes: &mut Vec<DetectBox>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<Option<ProtoData>, DecoderError> {
        output_boxes.clear();
        output_tracks.clear();
        match &self.model_type {
            // Detection-only and ModelPack variants: no proto data
            ModelType::ModelPackSegDet { .. }
            | ModelType::ModelPackSegDetSplit { .. }
            | ModelType::ModelPackDet { .. }
            | ModelType::ModelPackDetSplit { .. }
            | ModelType::ModelPackSeg { .. }
            | ModelType::YoloDet { .. }
            | ModelType::YoloSplitDet { .. }
            | ModelType::YoloEndToEndDet { .. }
            | ModelType::YoloSplitEndToEndDet { .. } => Ok(None),

            ModelType::YoloSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_end_to_end_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_end_to_end_segdet_quantized_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
        }
    }

    /// Decodes floating-point model outputs with tracking, returning raw
    /// `ProtoData` for segmentation models instead of materialized masks.
    ///
    /// `output_boxes` and `output_tracks` are always cleared first.
    ///
    /// Returns `Ok(None)` for detection-only and ModelPack models. In that
    /// case nothing is decoded and `tracker` is not updated (use
    /// [`Decoder::decode_tracked_float`] for those models). Returns
    /// `Ok(Some(ProtoData))` for YOLO segmentation models.
    ///
    /// # Errors
    ///
    /// Returns a [`DecoderError`] if the provided outputs don't match the
    /// configuration of a YOLO segmentation model.
    pub fn decode_tracked_float_proto<TR: edgefirst_tracker::Tracker<DetectBox>, T>(
        &self,
        tracker: &mut TR,
        timestamp: u64,
        outputs: &[ArrayViewD<T>],
        output_boxes: &mut Vec<DetectBox>,
        output_tracks: &mut Vec<edgefirst_tracker::TrackInfo>,
    ) -> Result<Option<ProtoData>, DecoderError>
    where
        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
        f32: AsPrimitive<T>,
    {
        output_boxes.clear();
        output_tracks.clear();
        match &self.model_type {
            // Detection-only and ModelPack variants: no proto data
            ModelType::ModelPackSegDet { .. }
            | ModelType::ModelPackSegDetSplit { .. }
            | ModelType::ModelPackDet { .. }
            | ModelType::ModelPackDetSplit { .. }
            | ModelType::ModelPackSeg { .. }
            | ModelType::YoloDet { .. }
            | ModelType::YoloSplitDet { .. }
            | ModelType::YoloEndToEndDet { .. }
            | ModelType::YoloSplitEndToEndDet { .. } => Ok(None),

            ModelType::YoloSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitSegDet {
                boxes,
                scores,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloEndToEndSegDet { boxes, protos } => {
                let proto = self.decode_tracked_yolo_end_to_end_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
            ModelType::YoloSplitEndToEndSegDet {
                boxes,
                scores,
                classes,
                mask_coeff,
                protos,
            } => {
                let proto = self.decode_tracked_yolo_split_end_to_end_segdet_float_proto(
                    tracker,
                    timestamp,
                    outputs,
                    boxes,
                    scores,
                    classes,
                    mask_coeff,
                    protos,
                    output_boxes,
                    output_tracks,
                )?;
                Ok(Some(proto))
            }
        }
    }
}