1use std::collections::HashMap;
5use std::fmt::Display;
6
7use serde::{Deserialize, Serialize};
8
9pub fn deserialize_dshape<'de, D>(deserializer: D) -> Result<Vec<(DimName, usize)>, D::Error>
15where
16 D: serde::Deserializer<'de>,
17{
18 #[derive(Deserialize)]
19 #[serde(untagged)]
20 enum DShapeItem {
21 Tuple(DimName, usize),
22 Map(HashMap<DimName, usize>),
23 }
24
25 let items: Vec<DShapeItem> = Vec::deserialize(deserializer)?;
26 items
27 .into_iter()
28 .map(|item| match item {
29 DShapeItem::Tuple(name, size) => Ok((name, size)),
30 DShapeItem::Map(map) => {
31 if map.len() != 1 {
32 return Err(serde::de::Error::custom(
33 "dshape map entry must have exactly one key",
34 ));
35 }
36 let (name, size) = map.into_iter().next().unwrap();
37 Ok((name, size))
38 }
39 })
40 .collect()
41}
42
43#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy)]
44pub struct QuantTuple(pub f32, pub i32);
45impl From<QuantTuple> for (f32, i32) {
46 fn from(value: QuantTuple) -> Self {
47 (value.0, value.1)
48 }
49}
50
51impl From<(f32, i32)> for QuantTuple {
52 fn from(value: (f32, i32)) -> Self {
53 QuantTuple(value.0, value.1)
54 }
55}
56
57#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
58pub struct Segmentation {
59 #[serde(default)]
60 pub decoder: DecoderType,
61 #[serde(default)]
62 pub quantization: Option<QuantTuple>,
63 #[serde(default)]
64 pub shape: Vec<usize>,
65 #[serde(default, deserialize_with = "deserialize_dshape")]
66 pub dshape: Vec<(DimName, usize)>,
67}
68
69#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
70pub struct Protos {
71 #[serde(default)]
72 pub decoder: DecoderType,
73 #[serde(default)]
74 pub quantization: Option<QuantTuple>,
75 #[serde(default)]
76 pub shape: Vec<usize>,
77 #[serde(default, deserialize_with = "deserialize_dshape")]
78 pub dshape: Vec<(DimName, usize)>,
79}
80
81#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
82pub struct MaskCoefficients {
83 #[serde(default)]
84 pub decoder: DecoderType,
85 #[serde(default)]
86 pub quantization: Option<QuantTuple>,
87 #[serde(default)]
88 pub shape: Vec<usize>,
89 #[serde(default, deserialize_with = "deserialize_dshape")]
90 pub dshape: Vec<(DimName, usize)>,
91}
92
93#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
94pub struct Mask {
95 #[serde(default)]
96 pub decoder: DecoderType,
97 #[serde(default)]
98 pub quantization: Option<QuantTuple>,
99 #[serde(default)]
100 pub shape: Vec<usize>,
101 #[serde(default, deserialize_with = "deserialize_dshape")]
102 pub dshape: Vec<(DimName, usize)>,
103}
104
105#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
106pub struct Detection {
107 #[serde(default)]
108 pub anchors: Option<Vec<[f32; 2]>>,
109 #[serde(default)]
110 pub decoder: DecoderType,
111 #[serde(default)]
112 pub quantization: Option<QuantTuple>,
113 #[serde(default)]
114 pub shape: Vec<usize>,
115 #[serde(default, deserialize_with = "deserialize_dshape")]
116 pub dshape: Vec<(DimName, usize)>,
117 #[serde(default)]
124 pub normalized: Option<bool>,
125}
126
127#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
128pub struct Scores {
129 #[serde(default)]
130 pub decoder: DecoderType,
131 #[serde(default)]
132 pub quantization: Option<QuantTuple>,
133 #[serde(default)]
134 pub shape: Vec<usize>,
135 #[serde(default, deserialize_with = "deserialize_dshape")]
136 pub dshape: Vec<(DimName, usize)>,
137}
138
139#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
140pub struct Boxes {
141 #[serde(default)]
142 pub decoder: DecoderType,
143 #[serde(default)]
144 pub quantization: Option<QuantTuple>,
145 #[serde(default)]
146 pub shape: Vec<usize>,
147 #[serde(default, deserialize_with = "deserialize_dshape")]
148 pub dshape: Vec<(DimName, usize)>,
149 #[serde(default)]
156 pub normalized: Option<bool>,
157}
158
159#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
160pub struct Classes {
161 #[serde(default)]
162 pub decoder: DecoderType,
163 #[serde(default)]
164 pub quantization: Option<QuantTuple>,
165 #[serde(default)]
166 pub shape: Vec<usize>,
167 #[serde(default, deserialize_with = "deserialize_dshape")]
168 pub dshape: Vec<(DimName, usize)>,
169}
170
171#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
172pub enum DimName {
173 #[serde(rename = "batch")]
174 Batch,
175 #[serde(rename = "height")]
176 Height,
177 #[serde(rename = "width")]
178 Width,
179 #[serde(rename = "num_classes")]
180 NumClasses,
181 #[serde(rename = "num_features")]
182 NumFeatures,
183 #[serde(rename = "num_boxes")]
184 NumBoxes,
185 #[serde(rename = "num_protos")]
186 NumProtos,
187 #[serde(rename = "num_anchors_x_features")]
188 NumAnchorsXFeatures,
189 #[serde(rename = "padding")]
190 Padding,
191 #[serde(rename = "box_coords")]
192 BoxCoords,
193}
194
195impl Display for DimName {
196 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206 match self {
207 DimName::Batch => write!(f, "batch"),
208 DimName::Height => write!(f, "height"),
209 DimName::Width => write!(f, "width"),
210 DimName::NumClasses => write!(f, "num_classes"),
211 DimName::NumFeatures => write!(f, "num_features"),
212 DimName::NumBoxes => write!(f, "num_boxes"),
213 DimName::NumProtos => write!(f, "num_protos"),
214 DimName::NumAnchorsXFeatures => write!(f, "num_anchors_x_features"),
215 DimName::Padding => write!(f, "padding"),
216 DimName::BoxCoords => write!(f, "box_coords"),
217 }
218 }
219}
220
221#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq, Default)]
222pub enum DecoderType {
223 #[serde(rename = "modelpack")]
224 ModelPack,
225 #[default]
226 #[serde(rename = "ultralytics", alias = "yolov8")]
227 Ultralytics,
228}
229
230#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
242#[serde(rename_all = "lowercase")]
243pub enum DecoderVersion {
244 #[serde(rename = "yolov5")]
246 Yolov5,
247 #[serde(rename = "yolov8")]
249 Yolov8,
250 #[serde(rename = "yolo11")]
252 Yolo11,
253 #[serde(rename = "yolo26")]
256 Yolo26,
257}
258
259impl DecoderVersion {
260 pub fn is_end_to_end(&self) -> bool {
263 matches!(self, DecoderVersion::Yolo26)
264 }
265}
266
267#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq, Default)]
276#[serde(rename_all = "snake_case")]
277pub enum Nms {
278 #[default]
281 ClassAgnostic,
282 ClassAware,
284}
285
/// The set of output descriptors a model exposes, grouped by layout.
///
/// `ModelPack*` variants and `Yolo*` variants are distinguished by name only;
/// the descriptors themselves each carry a `decoder` field.
/// `Split` variants hold their outputs as several separate descriptors
/// (e.g. distinct `Boxes`/`Scores`, or a `Vec<Detection>`) rather than one
/// combined `Detection`. `SegDet` variants pair detection outputs with
/// segmentation/mask outputs.
/// NOTE(review): `EndToEnd` variants presumably correspond to versions where
/// `DecoderVersion::is_end_to_end()` is true — confirm against the code that
/// builds a `ModelType`.
#[derive(Debug, Clone, PartialEq)]
pub enum ModelType {
    /// Combined boxes and scores plus a segmentation output.
    ModelPackSegDet {
        boxes: Boxes,
        scores: Scores,
        segmentation: Segmentation,
    },
    /// Detection split over several `Detection` outputs, plus segmentation.
    ModelPackSegDetSplit {
        detection: Vec<Detection>,
        segmentation: Segmentation,
    },
    /// Detection only, as separate boxes and scores outputs.
    ModelPackDet {
        boxes: Boxes,
        scores: Scores,
    },
    /// Detection only, split over several `Detection` outputs.
    ModelPackDetSplit {
        detection: Vec<Detection>,
    },
    /// Segmentation only.
    ModelPackSeg {
        segmentation: Segmentation,
    },
    /// Detection from a single combined `Detection` output.
    YoloDet {
        boxes: Detection,
    },
    /// Single combined detection output plus protos.
    YoloSegDet {
        boxes: Detection,
        protos: Protos,
    },
    /// Detection as separate boxes and scores outputs.
    YoloSplitDet {
        boxes: Boxes,
        scores: Scores,
    },
    /// Separate boxes, scores and mask coefficients, plus protos.
    YoloSplitSegDet {
        boxes: Boxes,
        scores: Scores,
        mask_coeff: MaskCoefficients,
        protos: Protos,
    },
    /// End-to-end detection from a single combined output.
    YoloEndToEndDet {
        boxes: Detection,
    },
    /// End-to-end detection from a single combined output, plus protos.
    YoloEndToEndSegDet {
        boxes: Detection,
        protos: Protos,
    },
    /// End-to-end detection as separate boxes, scores and classes outputs.
    YoloSplitEndToEndDet {
        boxes: Boxes,
        scores: Scores,
        classes: Classes,
    },
    /// End-to-end split detection plus mask coefficients and protos.
    YoloSplitEndToEndSegDet {
        boxes: Boxes,
        scores: Scores,
        classes: Classes,
        mask_coeff: MaskCoefficients,
        protos: Protos,
    },
}
355
356#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
357#[serde(rename_all = "lowercase")]
358pub enum DataType {
359 Raw = 0,
360 Int8 = 1,
361 UInt8 = 2,
362 Int16 = 3,
363 UInt16 = 4,
364 Float16 = 5,
365 Int32 = 6,
366 UInt32 = 7,
367 Float32 = 8,
368 Int64 = 9,
369 UInt64 = 10,
370 Float64 = 11,
371 String = 12,
372}