1use std::collections::HashMap;
5use std::fmt::Display;
6
7use serde::{Deserialize, Serialize};
8
9pub fn deserialize_dshape<'de, D>(deserializer: D) -> Result<Vec<(DimName, usize)>, D::Error>
15where
16 D: serde::Deserializer<'de>,
17{
18 #[derive(Deserialize)]
19 #[serde(untagged)]
20 enum DShapeItem {
21 Tuple(DimName, usize),
22 Map(HashMap<DimName, usize>),
23 }
24
25 let items: Vec<DShapeItem> = Vec::deserialize(deserializer)?;
26 items
27 .into_iter()
28 .map(|item| match item {
29 DShapeItem::Tuple(name, size) => Ok((name, size)),
30 DShapeItem::Map(map) => {
31 if map.len() != 1 {
32 return Err(serde::de::Error::custom(
33 "dshape map entry must have exactly one key",
34 ));
35 }
36 let (name, size) = map.into_iter().next().unwrap();
37 Ok((name, size))
38 }
39 })
40 .collect()
41}
42
43#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy)]
44pub struct QuantTuple(pub f32, pub i32);
45impl From<QuantTuple> for (f32, i32) {
46 fn from(value: QuantTuple) -> Self {
47 (value.0, value.1)
48 }
49}
50
51impl From<(f32, i32)> for QuantTuple {
52 fn from(value: (f32, i32)) -> Self {
53 QuantTuple(value.0, value.1)
54 }
55}
56
57#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
58pub struct Segmentation {
59 #[serde(default)]
60 pub decoder: DecoderType,
61 #[serde(default)]
62 pub quantization: Option<QuantTuple>,
63 #[serde(default)]
64 pub shape: Vec<usize>,
65 #[serde(default, deserialize_with = "deserialize_dshape")]
66 pub dshape: Vec<(DimName, usize)>,
67}
68
69#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
70pub struct Protos {
71 #[serde(default)]
72 pub decoder: DecoderType,
73 #[serde(default)]
74 pub quantization: Option<QuantTuple>,
75 #[serde(default)]
76 pub shape: Vec<usize>,
77 #[serde(default, deserialize_with = "deserialize_dshape")]
78 pub dshape: Vec<(DimName, usize)>,
79}
80
81#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
82pub struct MaskCoefficients {
83 #[serde(default)]
84 pub decoder: DecoderType,
85 #[serde(default)]
86 pub quantization: Option<QuantTuple>,
87 #[serde(default)]
88 pub shape: Vec<usize>,
89 #[serde(default, deserialize_with = "deserialize_dshape")]
90 pub dshape: Vec<(DimName, usize)>,
91}
92
93#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
94pub struct Mask {
95 #[serde(default)]
96 pub decoder: DecoderType,
97 #[serde(default)]
98 pub quantization: Option<QuantTuple>,
99 #[serde(default)]
100 pub shape: Vec<usize>,
101 #[serde(default, deserialize_with = "deserialize_dshape")]
102 pub dshape: Vec<(DimName, usize)>,
103}
104
105#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
106pub struct Detection {
107 #[serde(default)]
108 pub anchors: Option<Vec<[f32; 2]>>,
109 #[serde(default)]
110 pub decoder: DecoderType,
111 #[serde(default)]
112 pub quantization: Option<QuantTuple>,
113 #[serde(default)]
114 pub shape: Vec<usize>,
115 #[serde(default, deserialize_with = "deserialize_dshape")]
116 pub dshape: Vec<(DimName, usize)>,
117 #[serde(default)]
124 pub normalized: Option<bool>,
125}
126
127#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
128pub struct Scores {
129 #[serde(default)]
130 pub decoder: DecoderType,
131 #[serde(default)]
132 pub quantization: Option<QuantTuple>,
133 #[serde(default)]
134 pub shape: Vec<usize>,
135 #[serde(default, deserialize_with = "deserialize_dshape")]
136 pub dshape: Vec<(DimName, usize)>,
137}
138
139#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
140pub struct Boxes {
141 #[serde(default)]
142 pub decoder: DecoderType,
143 #[serde(default)]
144 pub quantization: Option<QuantTuple>,
145 #[serde(default)]
146 pub shape: Vec<usize>,
147 #[serde(default, deserialize_with = "deserialize_dshape")]
148 pub dshape: Vec<(DimName, usize)>,
149 #[serde(default)]
156 pub normalized: Option<bool>,
157}
158
159#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
160pub struct Classes {
161 #[serde(default)]
162 pub decoder: DecoderType,
163 #[serde(default)]
164 pub quantization: Option<QuantTuple>,
165 #[serde(default)]
166 pub shape: Vec<usize>,
167 #[serde(default, deserialize_with = "deserialize_dshape")]
168 pub dshape: Vec<(DimName, usize)>,
169}
170
171#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
172pub enum DimName {
173 #[serde(rename = "batch")]
174 Batch,
175 #[serde(rename = "height")]
176 Height,
177 #[serde(rename = "width")]
178 Width,
179 #[serde(rename = "num_classes")]
180 NumClasses,
181 #[serde(rename = "num_features")]
182 NumFeatures,
183 #[serde(rename = "num_boxes")]
184 NumBoxes,
185 #[serde(rename = "num_protos")]
186 NumProtos,
187 #[serde(rename = "num_anchors_x_features")]
188 NumAnchorsXFeatures,
189 #[serde(rename = "padding")]
190 Padding,
191 #[serde(rename = "box_coords")]
192 BoxCoords,
193}
194
195impl Display for DimName {
196 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206 match self {
207 DimName::Batch => write!(f, "batch"),
208 DimName::Height => write!(f, "height"),
209 DimName::Width => write!(f, "width"),
210 DimName::NumClasses => write!(f, "num_classes"),
211 DimName::NumFeatures => write!(f, "num_features"),
212 DimName::NumBoxes => write!(f, "num_boxes"),
213 DimName::NumProtos => write!(f, "num_protos"),
214 DimName::NumAnchorsXFeatures => write!(f, "num_anchors_x_features"),
215 DimName::Padding => write!(f, "padding"),
216 DimName::BoxCoords => write!(f, "box_coords"),
217 }
218 }
219}
220
221#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq, Default)]
222pub enum DecoderType {
223 #[serde(rename = "modelpack")]
224 ModelPack,
225 #[default]
226 #[serde(rename = "ultralytics", alias = "yolov8")]
227 Ultralytics,
228}
229
230#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
242#[serde(rename_all = "lowercase")]
243pub enum DecoderVersion {
244 #[serde(rename = "yolov5")]
246 Yolov5,
247 #[serde(rename = "yolov8")]
249 Yolov8,
250 #[serde(rename = "yolo11")]
252 Yolo11,
253 #[serde(rename = "yolo26")]
256 Yolo26,
257}
258
259impl DecoderVersion {
260 pub fn is_end_to_end(&self) -> bool {
263 matches!(self, DecoderVersion::Yolo26)
264 }
265}
266
267#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq, Default)]
277#[serde(rename_all = "snake_case")]
278pub enum Nms {
279 Auto,
285 #[default]
288 ClassAgnostic,
289 ClassAware,
291}
292
293#[derive(Debug, Clone, PartialEq)]
294pub enum ModelType {
295 ModelPackSegDet {
296 boxes: Boxes,
297 scores: Scores,
298 segmentation: Segmentation,
299 },
300 ModelPackSegDetSplit {
301 detection: Vec<Detection>,
302 segmentation: Segmentation,
303 },
304 ModelPackDet {
305 boxes: Boxes,
306 scores: Scores,
307 },
308 ModelPackDetSplit {
309 detection: Vec<Detection>,
310 },
311 ModelPackSeg {
312 segmentation: Segmentation,
313 },
314 YoloDet {
315 boxes: Detection,
316 },
317 YoloSegDet {
318 boxes: Detection,
319 protos: Protos,
320 },
321 YoloSplitDet {
322 boxes: Boxes,
323 scores: Scores,
324 },
325 YoloSplitSegDet {
326 boxes: Boxes,
327 scores: Scores,
328 mask_coeff: MaskCoefficients,
329 protos: Protos,
330 },
331 YoloSegDet2Way {
338 boxes: Detection,
339 mask_coeff: MaskCoefficients,
340 protos: Protos,
341 },
342 YoloEndToEndDet {
346 boxes: Detection,
347 },
348 YoloEndToEndSegDet {
352 boxes: Detection,
353 protos: Protos,
354 },
355 YoloSplitEndToEndDet {
359 boxes: Boxes,
360 scores: Scores,
361 classes: Classes,
362 },
363 YoloSplitEndToEndSegDet {
366 boxes: Boxes,
367 scores: Scores,
368 classes: Classes,
369 mask_coeff: MaskCoefficients,
370 protos: Protos,
371 },
372 PerScale,
378}