Skip to main content

edgefirst_decoder/decoder/
configs.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5use std::fmt::Display;
6
7use serde::{Deserialize, Serialize};
8
9/// Deserialize dshape from either array-of-tuples or array-of-single-key-dicts.
10///
11/// The metadata spec produces `[{"batch": 1}, {"num_features": 84}]` (dict format),
12/// while serde's default `Vec<(A, B)>` expects `[["batch", 1]]` (tuple format).
13/// This deserializer accepts both.
14pub fn deserialize_dshape<'de, D>(deserializer: D) -> Result<Vec<(DimName, usize)>, D::Error>
15where
16    D: serde::Deserializer<'de>,
17{
18    #[derive(Deserialize)]
19    #[serde(untagged)]
20    enum DShapeItem {
21        Tuple(DimName, usize),
22        Map(HashMap<DimName, usize>),
23    }
24
25    let items: Vec<DShapeItem> = Vec::deserialize(deserializer)?;
26    items
27        .into_iter()
28        .map(|item| match item {
29            DShapeItem::Tuple(name, size) => Ok((name, size)),
30            DShapeItem::Map(map) => {
31                if map.len() != 1 {
32                    return Err(serde::de::Error::custom(
33                        "dshape map entry must have exactly one key",
34                    ));
35                }
36                let (name, size) = map.into_iter().next().unwrap();
37                Ok((name, size))
38            }
39        })
40        .collect()
41}
42
43#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy)]
44pub struct QuantTuple(pub f32, pub i32);
45impl From<QuantTuple> for (f32, i32) {
46    fn from(value: QuantTuple) -> Self {
47        (value.0, value.1)
48    }
49}
50
51impl From<(f32, i32)> for QuantTuple {
52    fn from(value: (f32, i32)) -> Self {
53        QuantTuple(value.0, value.1)
54    }
55}
56
57#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
58pub struct Segmentation {
59    #[serde(default)]
60    pub decoder: DecoderType,
61    #[serde(default)]
62    pub quantization: Option<QuantTuple>,
63    #[serde(default)]
64    pub shape: Vec<usize>,
65    #[serde(default, deserialize_with = "deserialize_dshape")]
66    pub dshape: Vec<(DimName, usize)>,
67}
68
69#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
70pub struct Protos {
71    #[serde(default)]
72    pub decoder: DecoderType,
73    #[serde(default)]
74    pub quantization: Option<QuantTuple>,
75    #[serde(default)]
76    pub shape: Vec<usize>,
77    #[serde(default, deserialize_with = "deserialize_dshape")]
78    pub dshape: Vec<(DimName, usize)>,
79}
80
81#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
82pub struct MaskCoefficients {
83    #[serde(default)]
84    pub decoder: DecoderType,
85    #[serde(default)]
86    pub quantization: Option<QuantTuple>,
87    #[serde(default)]
88    pub shape: Vec<usize>,
89    #[serde(default, deserialize_with = "deserialize_dshape")]
90    pub dshape: Vec<(DimName, usize)>,
91}
92
93#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
94pub struct Mask {
95    #[serde(default)]
96    pub decoder: DecoderType,
97    #[serde(default)]
98    pub quantization: Option<QuantTuple>,
99    #[serde(default)]
100    pub shape: Vec<usize>,
101    #[serde(default, deserialize_with = "deserialize_dshape")]
102    pub dshape: Vec<(DimName, usize)>,
103}
104
105#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
106pub struct Detection {
107    #[serde(default)]
108    pub anchors: Option<Vec<[f32; 2]>>,
109    #[serde(default)]
110    pub decoder: DecoderType,
111    #[serde(default)]
112    pub quantization: Option<QuantTuple>,
113    #[serde(default)]
114    pub shape: Vec<usize>,
115    #[serde(default, deserialize_with = "deserialize_dshape")]
116    pub dshape: Vec<(DimName, usize)>,
117    /// Whether box coordinates are normalized to [0,1] range.
118    /// - `Some(true)`: Coordinates in [0,1] range relative to model input
119    /// - `Some(false)`: Pixel coordinates relative to model input
120    ///   (letterboxed)
121    /// - `None`: Unknown, caller must infer (e.g., check if any coordinate
122    ///   > 1.0)
123    #[serde(default)]
124    pub normalized: Option<bool>,
125}
126
127#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
128pub struct Scores {
129    #[serde(default)]
130    pub decoder: DecoderType,
131    #[serde(default)]
132    pub quantization: Option<QuantTuple>,
133    #[serde(default)]
134    pub shape: Vec<usize>,
135    #[serde(default, deserialize_with = "deserialize_dshape")]
136    pub dshape: Vec<(DimName, usize)>,
137}
138
139#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
140pub struct Boxes {
141    #[serde(default)]
142    pub decoder: DecoderType,
143    #[serde(default)]
144    pub quantization: Option<QuantTuple>,
145    #[serde(default)]
146    pub shape: Vec<usize>,
147    #[serde(default, deserialize_with = "deserialize_dshape")]
148    pub dshape: Vec<(DimName, usize)>,
149    /// Whether box coordinates are normalized to [0,1] range.
150    /// - `Some(true)`: Coordinates in [0,1] range relative to model input
151    /// - `Some(false)`: Pixel coordinates relative to model input
152    ///   (letterboxed)
153    /// - `None`: Unknown, caller must infer (e.g., check if any coordinate
154    ///   > 1.0)
155    #[serde(default)]
156    pub normalized: Option<bool>,
157}
158
159#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
160pub struct Classes {
161    #[serde(default)]
162    pub decoder: DecoderType,
163    #[serde(default)]
164    pub quantization: Option<QuantTuple>,
165    #[serde(default)]
166    pub shape: Vec<usize>,
167    #[serde(default, deserialize_with = "deserialize_dshape")]
168    pub dshape: Vec<(DimName, usize)>,
169}
170
171#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
172pub enum DimName {
173    #[serde(rename = "batch")]
174    Batch,
175    #[serde(rename = "height")]
176    Height,
177    #[serde(rename = "width")]
178    Width,
179    #[serde(rename = "num_classes")]
180    NumClasses,
181    #[serde(rename = "num_features")]
182    NumFeatures,
183    #[serde(rename = "num_boxes")]
184    NumBoxes,
185    #[serde(rename = "num_protos")]
186    NumProtos,
187    #[serde(rename = "num_anchors_x_features")]
188    NumAnchorsXFeatures,
189    #[serde(rename = "padding")]
190    Padding,
191    #[serde(rename = "box_coords")]
192    BoxCoords,
193    /// Any axis name the HAL does not recognise (e.g. a producer's
194    /// `channels` on the input dshape). Preserved so the dshape length still
195    /// matches the shape and the axis sorts to the canonical tail in
196    /// `swap_axes_if_needed`, but it never satisfies a required-dimension
197    /// check. Keeps metadata parsing tolerant of unknown axis names instead
198    /// of failing the whole decoder build (DE-2651).
199    #[serde(other)]
200    Unknown,
201}
202
203impl Display for DimName {
204    /// Formats the DimName for display
205    /// # Examples
206    /// ```rust
207    /// # use edgefirst_decoder::configs::DimName;
208    /// let dim = DimName::Height;
209    /// assert_eq!(format!("{}", dim), "height");
210    /// # let s = format!("{} {} {} {} {} {} {} {} {} {}", DimName::Batch, DimName::Height, DimName::Width, DimName::NumClasses, DimName::NumFeatures, DimName::NumBoxes, DimName::NumProtos, DimName::NumAnchorsXFeatures, DimName::Padding, DimName::BoxCoords);
211    /// # assert_eq!(s, "batch height width num_classes num_features num_boxes num_protos num_anchors_x_features padding box_coords");
212    /// ```
213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        match self {
215            DimName::Batch => write!(f, "batch"),
216            DimName::Height => write!(f, "height"),
217            DimName::Width => write!(f, "width"),
218            DimName::NumClasses => write!(f, "num_classes"),
219            DimName::NumFeatures => write!(f, "num_features"),
220            DimName::NumBoxes => write!(f, "num_boxes"),
221            DimName::NumProtos => write!(f, "num_protos"),
222            DimName::NumAnchorsXFeatures => write!(f, "num_anchors_x_features"),
223            DimName::Padding => write!(f, "padding"),
224            DimName::BoxCoords => write!(f, "box_coords"),
225            DimName::Unknown => write!(f, "unknown"),
226        }
227    }
228}
229
230#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq, Default)]
231pub enum DecoderType {
232    #[serde(rename = "modelpack")]
233    ModelPack,
234    #[default]
235    #[serde(rename = "ultralytics", alias = "yolov8")]
236    Ultralytics,
237}
238
239/// Decoder version for Ultralytics models.
240///
241/// Specifies the YOLO architecture version, which determines the decoding
242/// strategy:
243/// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
244///   NMS
245/// - `Yolo26`: End-to-end models with NMS embedded in the model
246///   architecture
247///
248/// When `decoder_version` is set to `Yolo26`, the decoder uses end-to-end
249/// model types regardless of the `nms` setting.
250#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
251#[serde(rename_all = "lowercase")]
252pub enum DecoderVersion {
253    /// YOLOv5 - anchor-free DFL decoder, requires external NMS
254    #[serde(rename = "yolov5")]
255    Yolov5,
256    /// YOLOv8 - anchor-free DFL decoder, requires external NMS
257    #[serde(rename = "yolov8")]
258    Yolov8,
259    /// YOLO11 - anchor-free DFL decoder, requires external NMS
260    #[serde(rename = "yolo11")]
261    Yolo11,
262    /// YOLO26 - end-to-end model with embedded NMS (one-to-one matching
263    /// heads)
264    #[serde(rename = "yolo26")]
265    Yolo26,
266}
267
268impl DecoderVersion {
269    /// Returns true if this version uses end-to-end inference (embedded
270    /// NMS).
271    pub fn is_end_to_end(&self) -> bool {
272        matches!(self, DecoderVersion::Yolo26)
273    }
274}
275
276/// NMS (Non-Maximum Suppression) mode for filtering overlapping detections.
277///
278/// This enum is used with `Option<Nms>`:
279/// - `Some(Nms::Auto)` — resolve from config or fall back to `ClassAgnostic`
280/// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS: suppress overlapping
281///   boxes regardless of class label
282/// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
283///   share the same class label AND overlap above the IoU threshold
284/// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
285#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq, Default)]
286#[serde(rename_all = "snake_case")]
287pub enum Nms {
288    /// Let the builder resolve NMS mode from the model config (e.g.
289    /// `edgefirst.json`).  Falls back to [`Nms::ClassAgnostic`] when no
290    /// config specifies a mode.  This is the builder default — callers
291    /// should only use an explicit variant when they need to override
292    /// the config.
293    Auto,
294    /// Suppress overlapping boxes regardless of class label (default
295    /// concrete behavior).
296    #[default]
297    ClassAgnostic,
298    /// Only suppress boxes with the same class label that overlap.
299    ClassAware,
300}
301
302#[derive(Debug, Clone, PartialEq)]
303pub enum ModelType {
304    ModelPackSegDet {
305        boxes: Boxes,
306        scores: Scores,
307        segmentation: Segmentation,
308    },
309    ModelPackSegDetSplit {
310        detection: Vec<Detection>,
311        segmentation: Segmentation,
312    },
313    ModelPackDet {
314        boxes: Boxes,
315        scores: Scores,
316    },
317    ModelPackDetSplit {
318        detection: Vec<Detection>,
319    },
320    ModelPackSeg {
321        segmentation: Segmentation,
322    },
323    YoloDet {
324        boxes: Detection,
325    },
326    YoloSegDet {
327        boxes: Detection,
328        protos: Protos,
329    },
330    YoloSplitDet {
331        boxes: Boxes,
332        scores: Scores,
333    },
334    YoloSplitSegDet {
335        boxes: Boxes,
336        scores: Scores,
337        mask_coeff: MaskCoefficients,
338        protos: Protos,
339    },
340    /// 2-way split YOLO segmentation detection.
341    /// Combined detection tensor (boxes + scores) with separate mask
342    /// coefficients and prototype masks.
343    /// - detection: [1, nc+4, N] — boxes and scores combined
344    /// - mask_coeff: [1, 32, N] — mask coefficients (separate tensor)
345    /// - protos: [1, H/4, W/4, 32] — prototype masks
346    YoloSegDet2Way {
347        boxes: Detection,
348        mask_coeff: MaskCoefficients,
349        protos: Protos,
350    },
351    /// End-to-end YOLO detection (post-NMS output from model)
352    /// Input shape: (1, N, 6+) where columns are [x1, y1, x2, y2, conf,
353    /// class, ...]
354    YoloEndToEndDet {
355        boxes: Detection,
356    },
357    /// End-to-end YOLO detection + segmentation (post-NMS output from
358    /// model) Input shape: (1, N, 6 + num_protos) where columns are
359    /// [x1, y1, x2, y2, conf, class, mask_coeff_0, ..., mask_coeff_31]
360    YoloEndToEndSegDet {
361        boxes: Detection,
362        protos: Protos,
363    },
364    /// Split end-to-end YOLO detection (onnx2tf splits [1,N,6] into 3
365    /// tensors) boxes: [batch, N, 4] xyxy, scores: [batch, N, 1],
366    /// classes: [batch, N, 1]
367    YoloSplitEndToEndDet {
368        boxes: Boxes,
369        scores: Scores,
370        classes: Classes,
371    },
372    /// Split end-to-end YOLO seg detection (onnx2tf splits into 5
373    /// tensors)
374    YoloSplitEndToEndSegDet {
375        boxes: Boxes,
376        scores: Scores,
377        classes: Classes,
378        mask_coeff: MaskCoefficients,
379        protos: Protos,
380    },
381    /// Per-scale (physical-output-decomposition) YOLO model. The
382    /// per-scale subsystem (`crates/decoder/src/per_scale/`) owns
383    /// model decoding entirely; this variant exists as a marker so the
384    /// `Decoder::model_type` field has a sensible value for per-scale
385    /// Decoders that bypass the legacy `ModelType`-driven dispatch.
386    PerScale,
387}