Skip to main content

edgefirst_decoder/decoder/
configs.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5use std::fmt::Display;
6
7use serde::{Deserialize, Serialize};
8
9/// Deserialize dshape from either array-of-tuples or array-of-single-key-dicts.
10///
11/// The metadata spec produces `[{"batch": 1}, {"num_features": 84}]` (dict format),
12/// while serde's default `Vec<(A, B)>` expects `[["batch", 1]]` (tuple format).
13/// This deserializer accepts both.
14pub fn deserialize_dshape<'de, D>(deserializer: D) -> Result<Vec<(DimName, usize)>, D::Error>
15where
16    D: serde::Deserializer<'de>,
17{
18    #[derive(Deserialize)]
19    #[serde(untagged)]
20    enum DShapeItem {
21        Tuple(DimName, usize),
22        Map(HashMap<DimName, usize>),
23    }
24
25    let items: Vec<DShapeItem> = Vec::deserialize(deserializer)?;
26    items
27        .into_iter()
28        .map(|item| match item {
29            DShapeItem::Tuple(name, size) => Ok((name, size)),
30            DShapeItem::Map(map) => {
31                if map.len() != 1 {
32                    return Err(serde::de::Error::custom(
33                        "dshape map entry must have exactly one key",
34                    ));
35                }
36                let (name, size) = map.into_iter().next().unwrap();
37                Ok((name, size))
38            }
39        })
40        .collect()
41}
42
/// Quantization parameters stored as an `(f32, i32)` pair.
///
/// Converts to and from a plain `(f32, i32)` tuple via `From`.
///
/// NOTE(review): presumably `(scale, zero_point)` — confirm against the code
/// that consumes these values.
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy)]
pub struct QuantTuple(pub f32, pub i32);
45impl From<QuantTuple> for (f32, i32) {
46    fn from(value: QuantTuple) -> Self {
47        (value.0, value.1)
48    }
49}
50
51impl From<(f32, i32)> for QuantTuple {
52    fn from(value: (f32, i32)) -> Self {
53        QuantTuple(value.0, value.1)
54    }
55}
56
57#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
58pub struct Segmentation {
59    #[serde(default)]
60    pub decoder: DecoderType,
61    #[serde(default)]
62    pub quantization: Option<QuantTuple>,
63    #[serde(default)]
64    pub shape: Vec<usize>,
65    #[serde(default, deserialize_with = "deserialize_dshape")]
66    pub dshape: Vec<(DimName, usize)>,
67}
68
69#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
70pub struct Protos {
71    #[serde(default)]
72    pub decoder: DecoderType,
73    #[serde(default)]
74    pub quantization: Option<QuantTuple>,
75    #[serde(default)]
76    pub shape: Vec<usize>,
77    #[serde(default, deserialize_with = "deserialize_dshape")]
78    pub dshape: Vec<(DimName, usize)>,
79}
80
81#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
82pub struct MaskCoefficients {
83    #[serde(default)]
84    pub decoder: DecoderType,
85    #[serde(default)]
86    pub quantization: Option<QuantTuple>,
87    #[serde(default)]
88    pub shape: Vec<usize>,
89    #[serde(default, deserialize_with = "deserialize_dshape")]
90    pub dshape: Vec<(DimName, usize)>,
91}
92
93#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
94pub struct Mask {
95    #[serde(default)]
96    pub decoder: DecoderType,
97    #[serde(default)]
98    pub quantization: Option<QuantTuple>,
99    #[serde(default)]
100    pub shape: Vec<usize>,
101    #[serde(default, deserialize_with = "deserialize_dshape")]
102    pub dshape: Vec<(DimName, usize)>,
103}
104
105#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
106pub struct Detection {
107    #[serde(default)]
108    pub anchors: Option<Vec<[f32; 2]>>,
109    #[serde(default)]
110    pub decoder: DecoderType,
111    #[serde(default)]
112    pub quantization: Option<QuantTuple>,
113    #[serde(default)]
114    pub shape: Vec<usize>,
115    #[serde(default, deserialize_with = "deserialize_dshape")]
116    pub dshape: Vec<(DimName, usize)>,
117    /// Whether box coordinates are normalized to [0,1] range.
118    /// - `Some(true)`: Coordinates in [0,1] range relative to model input
119    /// - `Some(false)`: Pixel coordinates relative to model input
120    ///   (letterboxed)
121    /// - `None`: Unknown, caller must infer (e.g., check if any coordinate
122    ///   > 1.0)
123    #[serde(default)]
124    pub normalized: Option<bool>,
125}
126
127#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
128pub struct Scores {
129    #[serde(default)]
130    pub decoder: DecoderType,
131    #[serde(default)]
132    pub quantization: Option<QuantTuple>,
133    #[serde(default)]
134    pub shape: Vec<usize>,
135    #[serde(default, deserialize_with = "deserialize_dshape")]
136    pub dshape: Vec<(DimName, usize)>,
137}
138
139#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
140pub struct Boxes {
141    #[serde(default)]
142    pub decoder: DecoderType,
143    #[serde(default)]
144    pub quantization: Option<QuantTuple>,
145    #[serde(default)]
146    pub shape: Vec<usize>,
147    #[serde(default, deserialize_with = "deserialize_dshape")]
148    pub dshape: Vec<(DimName, usize)>,
149    /// Whether box coordinates are normalized to [0,1] range.
150    /// - `Some(true)`: Coordinates in [0,1] range relative to model input
151    /// - `Some(false)`: Pixel coordinates relative to model input
152    ///   (letterboxed)
153    /// - `None`: Unknown, caller must infer (e.g., check if any coordinate
154    ///   > 1.0)
155    #[serde(default)]
156    pub normalized: Option<bool>,
157}
158
159#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
160pub struct Classes {
161    #[serde(default)]
162    pub decoder: DecoderType,
163    #[serde(default)]
164    pub quantization: Option<QuantTuple>,
165    #[serde(default)]
166    pub shape: Vec<usize>,
167    #[serde(default, deserialize_with = "deserialize_dshape")]
168    pub dshape: Vec<(DimName, usize)>,
169}
170
171#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
172pub enum DimName {
173    #[serde(rename = "batch")]
174    Batch,
175    #[serde(rename = "height")]
176    Height,
177    #[serde(rename = "width")]
178    Width,
179    #[serde(rename = "num_classes")]
180    NumClasses,
181    #[serde(rename = "num_features")]
182    NumFeatures,
183    #[serde(rename = "num_boxes")]
184    NumBoxes,
185    #[serde(rename = "num_protos")]
186    NumProtos,
187    #[serde(rename = "num_anchors_x_features")]
188    NumAnchorsXFeatures,
189    #[serde(rename = "padding")]
190    Padding,
191    #[serde(rename = "box_coords")]
192    BoxCoords,
193}
194
195impl Display for DimName {
196    /// Formats the DimName for display
197    /// # Examples
198    /// ```rust
199    /// # use edgefirst_decoder::configs::DimName;
200    /// let dim = DimName::Height;
201    /// assert_eq!(format!("{}", dim), "height");
202    /// # let s = format!("{} {} {} {} {} {} {} {} {} {}", DimName::Batch, DimName::Height, DimName::Width, DimName::NumClasses, DimName::NumFeatures, DimName::NumBoxes, DimName::NumProtos, DimName::NumAnchorsXFeatures, DimName::Padding, DimName::BoxCoords);
203    /// # assert_eq!(s, "batch height width num_classes num_features num_boxes num_protos num_anchors_x_features padding box_coords");
204    /// ```
205    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206        match self {
207            DimName::Batch => write!(f, "batch"),
208            DimName::Height => write!(f, "height"),
209            DimName::Width => write!(f, "width"),
210            DimName::NumClasses => write!(f, "num_classes"),
211            DimName::NumFeatures => write!(f, "num_features"),
212            DimName::NumBoxes => write!(f, "num_boxes"),
213            DimName::NumProtos => write!(f, "num_protos"),
214            DimName::NumAnchorsXFeatures => write!(f, "num_anchors_x_features"),
215            DimName::Padding => write!(f, "padding"),
216            DimName::BoxCoords => write!(f, "box_coords"),
217        }
218    }
219}
220
221#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq, Default)]
222pub enum DecoderType {
223    #[serde(rename = "modelpack")]
224    ModelPack,
225    #[default]
226    #[serde(rename = "ultralytics", alias = "yolov8")]
227    Ultralytics,
228}
229
230/// Decoder version for Ultralytics models.
231///
232/// Specifies the YOLO architecture version, which determines the decoding
233/// strategy:
234/// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
235///   NMS
236/// - `Yolo26`: End-to-end models with NMS embedded in the model
237///   architecture
238///
239/// When `decoder_version` is set to `Yolo26`, the decoder uses end-to-end
240/// model types regardless of the `nms` setting.
241#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
242#[serde(rename_all = "lowercase")]
243pub enum DecoderVersion {
244    /// YOLOv5 - anchor-free DFL decoder, requires external NMS
245    #[serde(rename = "yolov5")]
246    Yolov5,
247    /// YOLOv8 - anchor-free DFL decoder, requires external NMS
248    #[serde(rename = "yolov8")]
249    Yolov8,
250    /// YOLO11 - anchor-free DFL decoder, requires external NMS
251    #[serde(rename = "yolo11")]
252    Yolo11,
253    /// YOLO26 - end-to-end model with embedded NMS (one-to-one matching
254    /// heads)
255    #[serde(rename = "yolo26")]
256    Yolo26,
257}
258
259impl DecoderVersion {
260    /// Returns true if this version uses end-to-end inference (embedded
261    /// NMS).
262    pub fn is_end_to_end(&self) -> bool {
263        matches!(self, DecoderVersion::Yolo26)
264    }
265}
266
/// NMS (Non-Maximum Suppression) mode for filtering overlapping detections.
///
/// This enum is used with `Option<Nms>`:
/// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS (default): suppress
///   overlapping boxes regardless of class label
/// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
///   share the same class label AND overlap above the IoU threshold
/// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
///
/// Variants are (de)serialized under their snake_case names
/// (`"class_agnostic"`, `"class_aware"`).
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum Nms {
    /// Suppress overlapping boxes regardless of class label (default HAL
    /// behavior). Serialized as `"class_agnostic"`.
    #[default]
    ClassAgnostic,
    /// Only suppress boxes with the same class label that overlap.
    /// Serialized as `"class_aware"`.
    ClassAware,
}
285
/// Output-tensor configurations grouped by model layout.
///
/// Each variant bundles the per-output configuration structs that belong to
/// one layout; the variant name identifies the layout.
#[derive(Debug, Clone, PartialEq)]
pub enum ModelType {
    /// ModelPack layout with separate boxes and scores outputs plus a
    /// segmentation output.
    ModelPackSegDet {
        boxes: Boxes,
        scores: Scores,
        segmentation: Segmentation,
    },
    /// ModelPack layout with a list of detection outputs plus a segmentation
    /// output.
    ModelPackSegDetSplit {
        detection: Vec<Detection>,
        segmentation: Segmentation,
    },
    /// ModelPack layout with separate boxes and scores outputs.
    ModelPackDet {
        boxes: Boxes,
        scores: Scores,
    },
    /// ModelPack layout with a list of detection outputs.
    ModelPackDetSplit {
        detection: Vec<Detection>,
    },
    /// ModelPack layout with only a segmentation output.
    ModelPackSeg {
        segmentation: Segmentation,
    },
    /// YOLO layout with a single detection output.
    YoloDet {
        boxes: Detection,
    },
    /// YOLO layout with a detection output and prototype masks.
    YoloSegDet {
        boxes: Detection,
        protos: Protos,
    },
    /// YOLO layout with separate boxes and scores outputs.
    YoloSplitDet {
        boxes: Boxes,
        scores: Scores,
    },
    /// YOLO layout with separate boxes, scores, mask-coefficient and
    /// prototype-mask outputs.
    YoloSplitSegDet {
        boxes: Boxes,
        scores: Scores,
        mask_coeff: MaskCoefficients,
        protos: Protos,
    },
    /// 2-way split YOLO segmentation detection.
    ///
    /// Combined detection tensor (boxes + scores) with separate mask
    /// coefficients and prototype masks.
    /// - detection: [1, nc+4, N] — boxes and scores combined
    /// - mask_coeff: [1, 32, N] — mask coefficients (separate tensor)
    /// - protos: [1, H/4, W/4, 32] — prototype masks
    YoloSegDet2Way {
        boxes: Detection,
        mask_coeff: MaskCoefficients,
        protos: Protos,
    },
    /// End-to-end YOLO detection (post-NMS output from model).
    ///
    /// Input shape: (1, N, 6+) where columns are
    /// [x1, y1, x2, y2, conf, class, ...].
    YoloEndToEndDet {
        boxes: Detection,
    },
    /// End-to-end YOLO detection + segmentation (post-NMS output from
    /// model).
    ///
    /// Input shape: (1, N, 6 + num_protos) where columns are
    /// [x1, y1, x2, y2, conf, class, mask_coeff_0, ..., mask_coeff_31].
    YoloEndToEndSegDet {
        boxes: Detection,
        protos: Protos,
    },
    /// Split end-to-end YOLO detection (onnx2tf splits [1,N,6] into 3
    /// tensors).
    ///
    /// boxes: [batch, N, 4] xyxy, scores: [batch, N, 1],
    /// classes: [batch, N, 1].
    YoloSplitEndToEndDet {
        boxes: Boxes,
        scores: Scores,
        classes: Classes,
    },
    /// Split end-to-end YOLO seg detection (onnx2tf splits into 5
    /// tensors).
    YoloSplitEndToEndSegDet {
        boxes: Boxes,
        scores: Scores,
        classes: Classes,
        mask_coeff: MaskCoefficients,
        protos: Protos,
    },
}
365}