Skip to main content

edgefirst_decoder/decoder/
configs.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5use std::fmt::Display;
6
7use serde::{Deserialize, Serialize};
8
9/// Deserialize dshape from either array-of-tuples or array-of-single-key-dicts.
10///
11/// The metadata spec produces `[{"batch": 1}, {"num_features": 84}]` (dict format),
12/// while serde's default `Vec<(A, B)>` expects `[["batch", 1]]` (tuple format).
13/// This deserializer accepts both.
14pub fn deserialize_dshape<'de, D>(deserializer: D) -> Result<Vec<(DimName, usize)>, D::Error>
15where
16    D: serde::Deserializer<'de>,
17{
18    #[derive(Deserialize)]
19    #[serde(untagged)]
20    enum DShapeItem {
21        Tuple(DimName, usize),
22        Map(HashMap<DimName, usize>),
23    }
24
25    let items: Vec<DShapeItem> = Vec::deserialize(deserializer)?;
26    items
27        .into_iter()
28        .map(|item| match item {
29            DShapeItem::Tuple(name, size) => Ok((name, size)),
30            DShapeItem::Map(map) => {
31                if map.len() != 1 {
32                    return Err(serde::de::Error::custom(
33                        "dshape map entry must have exactly one key",
34                    ));
35                }
36                let (name, size) = map.into_iter().next().unwrap();
37                Ok((name, size))
38            }
39        })
40        .collect()
41}
42
/// Quantization parameters stored as an `(f32, i32)` pair.
///
/// Being a two-field tuple struct, serde represents it as a two-element
/// sequence. Conversions to and from a plain `(f32, i32)` tuple are
/// provided through `From`.
///
/// NOTE(review): the fields are presumably `(scale, zero_point)` — confirm
/// against the code that consumes them.
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy)]
pub struct QuantTuple(pub f32, pub i32);
45impl From<QuantTuple> for (f32, i32) {
46    fn from(value: QuantTuple) -> Self {
47        (value.0, value.1)
48    }
49}
50
51impl From<(f32, i32)> for QuantTuple {
52    fn from(value: (f32, i32)) -> Self {
53        QuantTuple(value.0, value.1)
54    }
55}
56
/// Configuration entry for a segmentation output.
///
/// Every field is optional on deserialization; a missing field takes its
/// `Default` value.
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
pub struct Segmentation {
    /// Decoder family; [`DecoderType`]'s default is used when absent.
    #[serde(default)]
    pub decoder: DecoderType,
    /// Optional quantization parameters; `None` when absent.
    #[serde(default)]
    pub quantization: Option<QuantTuple>,
    /// Shape as a list of sizes; empty when absent.
    #[serde(default)]
    pub shape: Vec<usize>,
    /// Named dimensions as `(name, size)` pairs; empty when absent.
    /// Parsed by [`deserialize_dshape`], so both the tuple form and the
    /// single-key-map form are accepted.
    #[serde(default, deserialize_with = "deserialize_dshape")]
    pub dshape: Vec<(DimName, usize)>,
}
68
/// Configuration entry for a prototype-mask (`protos`) output.
///
/// Every field is optional on deserialization; a missing field takes its
/// `Default` value.
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
pub struct Protos {
    /// Decoder family; [`DecoderType`]'s default is used when absent.
    #[serde(default)]
    pub decoder: DecoderType,
    /// Optional quantization parameters; `None` when absent.
    #[serde(default)]
    pub quantization: Option<QuantTuple>,
    /// Shape as a list of sizes; empty when absent.
    #[serde(default)]
    pub shape: Vec<usize>,
    /// Named dimensions as `(name, size)` pairs; empty when absent.
    /// Parsed by [`deserialize_dshape`], so both the tuple form and the
    /// single-key-map form are accepted.
    #[serde(default, deserialize_with = "deserialize_dshape")]
    pub dshape: Vec<(DimName, usize)>,
}
80
/// Configuration entry for a mask-coefficient output.
///
/// Every field is optional on deserialization; a missing field takes its
/// `Default` value.
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
pub struct MaskCoefficients {
    /// Decoder family; [`DecoderType`]'s default is used when absent.
    #[serde(default)]
    pub decoder: DecoderType,
    /// Optional quantization parameters; `None` when absent.
    #[serde(default)]
    pub quantization: Option<QuantTuple>,
    /// Shape as a list of sizes; empty when absent.
    #[serde(default)]
    pub shape: Vec<usize>,
    /// Named dimensions as `(name, size)` pairs; empty when absent.
    /// Parsed by [`deserialize_dshape`], so both the tuple form and the
    /// single-key-map form are accepted.
    #[serde(default, deserialize_with = "deserialize_dshape")]
    pub dshape: Vec<(DimName, usize)>,
}
92
/// Configuration entry for a mask output.
///
/// Every field is optional on deserialization; a missing field takes its
/// `Default` value.
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
pub struct Mask {
    /// Decoder family; [`DecoderType`]'s default is used when absent.
    #[serde(default)]
    pub decoder: DecoderType,
    /// Optional quantization parameters; `None` when absent.
    #[serde(default)]
    pub quantization: Option<QuantTuple>,
    /// Shape as a list of sizes; empty when absent.
    #[serde(default)]
    pub shape: Vec<usize>,
    /// Named dimensions as `(name, size)` pairs; empty when absent.
    /// Parsed by [`deserialize_dshape`], so both the tuple form and the
    /// single-key-map form are accepted.
    #[serde(default, deserialize_with = "deserialize_dshape")]
    pub dshape: Vec<(DimName, usize)>,
}
104
/// Configuration entry for a detection output.
///
/// Every field is optional on deserialization; a missing field takes its
/// `Default` value.
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
pub struct Detection {
    /// Optional anchors, each a pair of `f32` values; `None` when absent.
    #[serde(default)]
    pub anchors: Option<Vec<[f32; 2]>>,
    /// Decoder family; [`DecoderType`]'s default is used when absent.
    #[serde(default)]
    pub decoder: DecoderType,
    /// Optional quantization parameters; `None` when absent.
    #[serde(default)]
    pub quantization: Option<QuantTuple>,
    /// Shape as a list of sizes; empty when absent.
    #[serde(default)]
    pub shape: Vec<usize>,
    /// Named dimensions as `(name, size)` pairs; empty when absent.
    /// Parsed by [`deserialize_dshape`], so both the tuple form and the
    /// single-key-map form are accepted.
    #[serde(default, deserialize_with = "deserialize_dshape")]
    pub dshape: Vec<(DimName, usize)>,
    /// Whether box coordinates are normalized to [0,1] range.
    /// - `Some(true)`: Coordinates in [0,1] range relative to model input
    /// - `Some(false)`: Pixel coordinates relative to model input
    ///   (letterboxed)
    /// - `None`: Unknown, caller must infer (e.g., check whether any
    ///   coordinate exceeds 1.0)
    #[serde(default)]
    pub normalized: Option<bool>,
}
126
/// Configuration entry for a scores output.
///
/// Every field is optional on deserialization; a missing field takes its
/// `Default` value.
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
pub struct Scores {
    /// Decoder family; [`DecoderType`]'s default is used when absent.
    #[serde(default)]
    pub decoder: DecoderType,
    /// Optional quantization parameters; `None` when absent.
    #[serde(default)]
    pub quantization: Option<QuantTuple>,
    /// Shape as a list of sizes; empty when absent.
    #[serde(default)]
    pub shape: Vec<usize>,
    /// Named dimensions as `(name, size)` pairs; empty when absent.
    /// Parsed by [`deserialize_dshape`], so both the tuple form and the
    /// single-key-map form are accepted.
    #[serde(default, deserialize_with = "deserialize_dshape")]
    pub dshape: Vec<(DimName, usize)>,
}
138
/// Configuration entry for a boxes output.
///
/// Every field is optional on deserialization; a missing field takes its
/// `Default` value.
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
pub struct Boxes {
    /// Decoder family; [`DecoderType`]'s default is used when absent.
    #[serde(default)]
    pub decoder: DecoderType,
    /// Optional quantization parameters; `None` when absent.
    #[serde(default)]
    pub quantization: Option<QuantTuple>,
    /// Shape as a list of sizes; empty when absent.
    #[serde(default)]
    pub shape: Vec<usize>,
    /// Named dimensions as `(name, size)` pairs; empty when absent.
    /// Parsed by [`deserialize_dshape`], so both the tuple form and the
    /// single-key-map form are accepted.
    #[serde(default, deserialize_with = "deserialize_dshape")]
    pub dshape: Vec<(DimName, usize)>,
    /// Whether box coordinates are normalized to [0,1] range.
    /// - `Some(true)`: Coordinates in [0,1] range relative to model input
    /// - `Some(false)`: Pixel coordinates relative to model input
    ///   (letterboxed)
    /// - `None`: Unknown, caller must infer (e.g., check whether any
    ///   coordinate exceeds 1.0)
    #[serde(default)]
    pub normalized: Option<bool>,
}
158
/// Configuration entry for a classes output.
///
/// Every field is optional on deserialization; a missing field takes its
/// `Default` value.
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
pub struct Classes {
    /// Decoder family; [`DecoderType`]'s default is used when absent.
    #[serde(default)]
    pub decoder: DecoderType,
    /// Optional quantization parameters; `None` when absent.
    #[serde(default)]
    pub quantization: Option<QuantTuple>,
    /// Shape as a list of sizes; empty when absent.
    #[serde(default)]
    pub shape: Vec<usize>,
    /// Named dimensions as `(name, size)` pairs; empty when absent.
    /// Parsed by [`deserialize_dshape`], so both the tuple form and the
    /// single-key-map form are accepted.
    #[serde(default, deserialize_with = "deserialize_dshape")]
    pub dshape: Vec<(DimName, usize)>,
}
170
/// Name of a dimension, used as the key in `dshape` entries.
///
/// Each variant is serialized under the snake_case string given in its
/// `serde(rename)` attribute; the `Display` impl prints the same string.
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
pub enum DimName {
    /// Serialized as `"batch"`.
    #[serde(rename = "batch")]
    Batch,
    /// Serialized as `"height"`.
    #[serde(rename = "height")]
    Height,
    /// Serialized as `"width"`.
    #[serde(rename = "width")]
    Width,
    /// Serialized as `"num_classes"`.
    #[serde(rename = "num_classes")]
    NumClasses,
    /// Serialized as `"num_features"`.
    #[serde(rename = "num_features")]
    NumFeatures,
    /// Serialized as `"num_boxes"`.
    #[serde(rename = "num_boxes")]
    NumBoxes,
    /// Serialized as `"num_protos"`.
    #[serde(rename = "num_protos")]
    NumProtos,
    /// Serialized as `"num_anchors_x_features"`.
    #[serde(rename = "num_anchors_x_features")]
    NumAnchorsXFeatures,
    /// Serialized as `"padding"`.
    #[serde(rename = "padding")]
    Padding,
    /// Serialized as `"box_coords"`.
    #[serde(rename = "box_coords")]
    BoxCoords,
}
194
195impl Display for DimName {
196    /// Formats the DimName for display
197    /// # Examples
198    /// ```rust
199    /// # use edgefirst_decoder::configs::DimName;
200    /// let dim = DimName::Height;
201    /// assert_eq!(format!("{}", dim), "height");
202    /// # let s = format!("{} {} {} {} {} {} {} {} {} {}", DimName::Batch, DimName::Height, DimName::Width, DimName::NumClasses, DimName::NumFeatures, DimName::NumBoxes, DimName::NumProtos, DimName::NumAnchorsXFeatures, DimName::Padding, DimName::BoxCoords);
203    /// # assert_eq!(s, "batch height width num_classes num_features num_boxes num_protos num_anchors_x_features padding box_coords");
204    /// ```
205    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206        match self {
207            DimName::Batch => write!(f, "batch"),
208            DimName::Height => write!(f, "height"),
209            DimName::Width => write!(f, "width"),
210            DimName::NumClasses => write!(f, "num_classes"),
211            DimName::NumFeatures => write!(f, "num_features"),
212            DimName::NumBoxes => write!(f, "num_boxes"),
213            DimName::NumProtos => write!(f, "num_protos"),
214            DimName::NumAnchorsXFeatures => write!(f, "num_anchors_x_features"),
215            DimName::Padding => write!(f, "padding"),
216            DimName::BoxCoords => write!(f, "box_coords"),
217        }
218    }
219}
220
/// Decoder family named in a configuration entry.
///
/// `Default` yields [`DecoderType::Ultralytics`], which is therefore the
/// value used when a config omits its `decoder` field.
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq, Default)]
pub enum DecoderType {
    /// Serialized as `"modelpack"`.
    #[serde(rename = "modelpack")]
    ModelPack,
    /// Serialized as `"ultralytics"`; `"yolov8"` is also accepted when
    /// deserializing.
    #[default]
    #[serde(rename = "ultralytics", alias = "yolov8")]
    Ultralytics,
}
229
/// Decoder version for Ultralytics models.
///
/// Specifies the YOLO architecture version, which determines the decoding
/// strategy:
/// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
///   NMS
/// - `Yolo26`: End-to-end models with NMS embedded in the model
///   architecture
///
/// When `decoder_version` is set to `Yolo26`, the decoder uses end-to-end
/// model types regardless of the `nms` setting.
///
/// Variants are serialized as `"yolov5"`, `"yolov8"`, `"yolo11"` and
/// `"yolo26"`. Each variant carries an explicit `rename`, which takes
/// precedence over the container-level `rename_all`.
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
#[serde(rename_all = "lowercase")]
pub enum DecoderVersion {
    /// YOLOv5 - anchor-free DFL decoder, requires external NMS
    #[serde(rename = "yolov5")]
    Yolov5,
    /// YOLOv8 - anchor-free DFL decoder, requires external NMS
    #[serde(rename = "yolov8")]
    Yolov8,
    /// YOLO11 - anchor-free DFL decoder, requires external NMS
    #[serde(rename = "yolo11")]
    Yolo11,
    /// YOLO26 - end-to-end model with embedded NMS (one-to-one matching
    /// heads)
    #[serde(rename = "yolo26")]
    Yolo26,
}
258
259impl DecoderVersion {
260    /// Returns true if this version uses end-to-end inference (embedded
261    /// NMS).
262    pub fn is_end_to_end(&self) -> bool {
263        matches!(self, DecoderVersion::Yolo26)
264    }
265}
266
/// NMS (Non-Maximum Suppression) mode for filtering overlapping detections.
///
/// This enum is used with `Option<Nms>`:
/// - `Some(Nms::Auto)` — resolve from config or fall back to `ClassAgnostic`
/// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS: suppress overlapping
///   boxes regardless of class label
/// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
///   share the same class label AND overlap above the IoU threshold
/// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
///
/// Variants are serialized in snake_case (`"auto"`, `"class_agnostic"`,
/// `"class_aware"`). `Nms::default()` yields [`Nms::ClassAgnostic`].
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum Nms {
    /// Let the builder resolve NMS mode from the model config (e.g.
    /// `edgefirst.json`).  Falls back to [`Nms::ClassAgnostic`] when no
    /// config specifies a mode.  This is the builder default — callers
    /// should only use an explicit variant when they need to override
    /// the config.
    Auto,
    /// Suppress overlapping boxes regardless of class label (default
    /// concrete behavior).
    #[default]
    ClassAgnostic,
    /// Only suppress boxes with the same class label that overlap.
    ClassAware,
}
292
/// Set of outputs a model exposes, together with the configuration entry
/// for each of those outputs.
///
/// Variant names encode the decoder family (`ModelPack…` / `Yolo…`) and
/// which outputs are present; the fields hold the per-output configs.
#[derive(Debug, Clone, PartialEq)]
pub enum ModelType {
    /// ModelPack detection plus segmentation, with boxes and scores as
    /// separate entries.
    ModelPackSegDet {
        boxes: Boxes,
        scores: Scores,
        segmentation: Segmentation,
    },
    /// ModelPack detection plus segmentation, with detection described by
    /// a list of [`Detection`] entries.
    ModelPackSegDetSplit {
        detection: Vec<Detection>,
        segmentation: Segmentation,
    },
    /// ModelPack detection only, with boxes and scores as separate
    /// entries.
    ModelPackDet {
        boxes: Boxes,
        scores: Scores,
    },
    /// ModelPack detection only, described by a list of [`Detection`]
    /// entries.
    ModelPackDetSplit {
        detection: Vec<Detection>,
    },
    /// ModelPack segmentation only.
    ModelPackSeg {
        segmentation: Segmentation,
    },
    /// YOLO detection described by a single [`Detection`] entry.
    YoloDet {
        boxes: Detection,
    },
    /// YOLO detection plus segmentation: a single [`Detection`] entry and
    /// the prototype masks.
    YoloSegDet {
        boxes: Detection,
        protos: Protos,
    },
    /// YOLO detection with boxes and scores as separate entries.
    YoloSplitDet {
        boxes: Boxes,
        scores: Scores,
    },
    /// YOLO detection plus segmentation with boxes, scores, mask
    /// coefficients and prototype masks all as separate entries.
    YoloSplitSegDet {
        boxes: Boxes,
        scores: Scores,
        mask_coeff: MaskCoefficients,
        protos: Protos,
    },
    /// 2-way split YOLO segmentation detection.
    /// Combined detection tensor (boxes + scores) with separate mask
    /// coefficients and prototype masks.
    /// - detection: [1, nc+4, N] — boxes and scores combined
    /// - mask_coeff: [1, 32, N] — mask coefficients (separate tensor)
    /// - protos: [1, H/4, W/4, 32] — prototype masks
    YoloSegDet2Way {
        boxes: Detection,
        mask_coeff: MaskCoefficients,
        protos: Protos,
    },
    /// End-to-end YOLO detection (post-NMS output from model).
    ///
    /// Input shape: (1, N, 6+) where columns are
    /// [x1, y1, x2, y2, conf, class, ...]
    YoloEndToEndDet {
        boxes: Detection,
    },
    /// End-to-end YOLO detection + segmentation (post-NMS output from
    /// model).
    ///
    /// Input shape: (1, N, 6 + num_protos) where columns are
    /// [x1, y1, x2, y2, conf, class, mask_coeff_0, ..., mask_coeff_31]
    YoloEndToEndSegDet {
        boxes: Detection,
        protos: Protos,
    },
    /// Split end-to-end YOLO detection (onnx2tf splits [1,N,6] into 3
    /// tensors).
    ///
    /// boxes: [batch, N, 4] xyxy, scores: [batch, N, 1],
    /// classes: [batch, N, 1]
    YoloSplitEndToEndDet {
        boxes: Boxes,
        scores: Scores,
        classes: Classes,
    },
    /// Split end-to-end YOLO seg detection (onnx2tf splits into 5
    /// tensors)
    YoloSplitEndToEndSegDet {
        boxes: Boxes,
        scores: Scores,
        classes: Classes,
        mask_coeff: MaskCoefficients,
        protos: Protos,
    },
    /// Per-scale (physical-output-decomposition) YOLO model. The
    /// per-scale subsystem (`crates/decoder/src/per_scale/`) owns
    /// model decoding entirely; this variant exists as a marker so the
    /// `Decoder::model_type` field has a sensible value for per-scale
    /// Decoders that bypass the legacy `ModelType`-driven dispatch.
    PerScale,
}