use std::collections::HashMap;
use std::fmt::Display;
use serde::{Deserialize, Serialize};
/// Deserializes a `dshape` list whose entries may use either of two forms:
///
/// * a two-element sequence `[name, size]`, or
/// * a map with exactly one key, `{name: size}`.
///
/// Entries are returned in input order as `(DimName, usize)` pairs.
///
/// # Errors
///
/// Returns `D::Error` if the input is not a sequence, if an entry matches
/// neither form, or if a map entry has zero keys or more than one key.
pub fn deserialize_dshape<'de, D>(deserializer: D) -> Result<Vec<(DimName, usize)>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    /// One raw `dshape` entry, before it is normalized into a pair.
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum DShapeItem {
        Tuple(DimName, usize),
        Map(HashMap<DimName, usize>),
    }

    let raw: Vec<DShapeItem> = Vec::deserialize(deserializer)?;
    let mut dims = Vec::with_capacity(raw.len());
    for entry in raw {
        let dim = match entry {
            DShapeItem::Tuple(name, size) => (name, size),
            DShapeItem::Map(single) => {
                // Accept the map only if it yields exactly one pair. A
                // multi-key HashMap would lose the author's ordering anyway.
                let mut pairs = single.into_iter();
                match (pairs.next(), pairs.next()) {
                    (Some(pair), None) => pair,
                    _ => {
                        return Err(serde::de::Error::custom(
                            "dshape map entry must have exactly one key",
                        ));
                    }
                }
            }
        };
        dims.push(dim);
    }
    Ok(dims)
}
/// Quantization parameters stored as a two-field `(f32, i32)` tuple struct.
///
/// NOTE(review): presumably `(scale, zero_point)`; confirm against the exporter.
/// As a tuple struct, serde (de)serializes it as a two-element sequence.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct QuantTuple(pub f32, pub i32);
impl From<QuantTuple> for (f32, i32) {
fn from(value: QuantTuple) -> Self {
(value.0, value.1)
}
}
impl From<(f32, i32)> for QuantTuple {
fn from(value: (f32, i32)) -> Self {
QuantTuple(value.0, value.1)
}
}
/// Configuration for a segmentation output (used by the `ModelPack*Seg*`
/// variants of [`ModelType`]).
///
/// Because of the container-level `#[serde(default)]`, any field absent from
/// the input is taken from `Segmentation::default()`.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
#[serde(default)]
pub struct Segmentation {
    /// Decoder family for this output; `DecoderType::default()` when absent.
    pub decoder: DecoderType,
    /// Optional quantization parameters; `None` when absent.
    pub quantization: Option<QuantTuple>,
    /// Dimension sizes; empty when absent.
    pub shape: Vec<usize>,
    /// Named dimensions parsed by [`deserialize_dshape`]; empty when absent.
    #[serde(deserialize_with = "deserialize_dshape")]
    pub dshape: Vec<(DimName, usize)>,
}
/// Configuration for the `protos` output (used by the YOLO segmentation
/// variants of [`ModelType`]).
///
/// Because of the container-level `#[serde(default)]`, any field absent from
/// the input is taken from `Protos::default()`.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
#[serde(default)]
pub struct Protos {
    /// Decoder family for this output; `DecoderType::default()` when absent.
    pub decoder: DecoderType,
    /// Optional quantization parameters; `None` when absent.
    pub quantization: Option<QuantTuple>,
    /// Dimension sizes; empty when absent.
    pub shape: Vec<usize>,
    /// Named dimensions parsed by [`deserialize_dshape`]; empty when absent.
    #[serde(deserialize_with = "deserialize_dshape")]
    pub dshape: Vec<(DimName, usize)>,
}
/// Configuration for the `mask_coeff` output (used by the split and two-way
/// YOLO segmentation variants of [`ModelType`]).
///
/// Because of the container-level `#[serde(default)]`, any field absent from
/// the input is taken from `MaskCoefficients::default()`.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
#[serde(default)]
pub struct MaskCoefficients {
    /// Decoder family for this output; `DecoderType::default()` when absent.
    pub decoder: DecoderType,
    /// Optional quantization parameters; `None` when absent.
    pub quantization: Option<QuantTuple>,
    /// Dimension sizes; empty when absent.
    pub shape: Vec<usize>,
    /// Named dimensions parsed by [`deserialize_dshape`]; empty when absent.
    #[serde(deserialize_with = "deserialize_dshape")]
    pub dshape: Vec<(DimName, usize)>,
}
/// Configuration for a mask output; carries the same fields as the other
/// per-output configs.
///
/// Because of the container-level `#[serde(default)]`, any field absent from
/// the input is taken from `Mask::default()`.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
#[serde(default)]
pub struct Mask {
    /// Decoder family for this output; `DecoderType::default()` when absent.
    pub decoder: DecoderType,
    /// Optional quantization parameters; `None` when absent.
    pub quantization: Option<QuantTuple>,
    /// Dimension sizes; empty when absent.
    pub shape: Vec<usize>,
    /// Named dimensions parsed by [`deserialize_dshape`]; empty when absent.
    #[serde(deserialize_with = "deserialize_dshape")]
    pub dshape: Vec<(DimName, usize)>,
}
/// Configuration for a `Detection` output (the single `boxes` output of the
/// non-split YOLO variants, or one entry of the ModelPack split lists in
/// [`ModelType`]).
///
/// Because of the container-level `#[serde(default)]`, any field absent from
/// the input is taken from `Detection::default()`.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
#[serde(default)]
pub struct Detection {
    /// Optional anchor list, two floats per anchor; `None` when absent.
    /// NOTE(review): presumably `[width, height]` pairs; confirm.
    pub anchors: Option<Vec<[f32; 2]>>,
    /// Decoder family for this output; `DecoderType::default()` when absent.
    pub decoder: DecoderType,
    /// Optional quantization parameters; `None` when absent.
    pub quantization: Option<QuantTuple>,
    /// Dimension sizes; empty when absent.
    pub shape: Vec<usize>,
    /// Named dimensions parsed by [`deserialize_dshape`]; empty when absent.
    #[serde(deserialize_with = "deserialize_dshape")]
    pub dshape: Vec<(DimName, usize)>,
    /// Optional flag; `None` when unspecified.
    /// NOTE(review): presumably whether coordinates are normalized; confirm.
    pub normalized: Option<bool>,
}
/// Configuration for a separate `scores` output (used by the split variants
/// of [`ModelType`]).
///
/// Because of the container-level `#[serde(default)]`, any field absent from
/// the input is taken from `Scores::default()`.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
#[serde(default)]
pub struct Scores {
    /// Decoder family for this output; `DecoderType::default()` when absent.
    pub decoder: DecoderType,
    /// Optional quantization parameters; `None` when absent.
    pub quantization: Option<QuantTuple>,
    /// Dimension sizes; empty when absent.
    pub shape: Vec<usize>,
    /// Named dimensions parsed by [`deserialize_dshape`]; empty when absent.
    #[serde(deserialize_with = "deserialize_dshape")]
    pub dshape: Vec<(DimName, usize)>,
}
/// Configuration for a separate `boxes` output (used by the split variants
/// of [`ModelType`]).
///
/// Because of the container-level `#[serde(default)]`, any field absent from
/// the input is taken from `Boxes::default()`.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
#[serde(default)]
pub struct Boxes {
    /// Decoder family for this output; `DecoderType::default()` when absent.
    pub decoder: DecoderType,
    /// Optional quantization parameters; `None` when absent.
    pub quantization: Option<QuantTuple>,
    /// Dimension sizes; empty when absent.
    pub shape: Vec<usize>,
    /// Named dimensions parsed by [`deserialize_dshape`]; empty when absent.
    #[serde(deserialize_with = "deserialize_dshape")]
    pub dshape: Vec<(DimName, usize)>,
    /// Optional flag; `None` when unspecified.
    /// NOTE(review): presumably whether box coordinates are normalized; confirm.
    pub normalized: Option<bool>,
}
/// Configuration for a separate `classes` output (used by the split
/// end-to-end variants of [`ModelType`]).
///
/// Because of the container-level `#[serde(default)]`, any field absent from
/// the input is taken from `Classes::default()`.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
#[serde(default)]
pub struct Classes {
    /// Decoder family for this output; `DecoderType::default()` when absent.
    pub decoder: DecoderType,
    /// Optional quantization parameters; `None` when absent.
    pub quantization: Option<QuantTuple>,
    /// Dimension sizes; empty when absent.
    pub shape: Vec<usize>,
    /// Named dimensions parsed by [`deserialize_dshape`]; empty when absent.
    #[serde(deserialize_with = "deserialize_dshape")]
    pub dshape: Vec<(DimName, usize)>,
}
/// Name of one axis in a `dshape` entry.
///
/// Every variant (de)serializes as its snake_case name, e.g.
/// `NumAnchorsXFeatures` <-> `"num_anchors_x_features"`,
/// `BoxCoords` <-> `"box_coords"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DimName {
    Batch,
    Height,
    Width,
    NumClasses,
    NumFeatures,
    NumBoxes,
    NumProtos,
    NumAnchorsXFeatures,
    Padding,
    BoxCoords,
}
impl Display for DimName {
    /// Writes the dimension's snake_case name (the same string serde uses).
    ///
    /// The name is emitted through `Formatter::pad`, so width, fill, alignment
    /// and precision flags (e.g. `{:<14}`) are honoured, as they are for `str`.
    /// The previous `write!(f, "...")` silently ignored them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            DimName::Batch => "batch",
            DimName::Height => "height",
            DimName::Width => "width",
            DimName::NumClasses => "num_classes",
            DimName::NumFeatures => "num_features",
            DimName::NumBoxes => "num_boxes",
            DimName::NumProtos => "num_protos",
            DimName::NumAnchorsXFeatures => "num_anchors_x_features",
            DimName::Padding => "padding",
            DimName::BoxCoords => "box_coords",
        };
        f.pad(name)
    }
}
/// Decoder family associated with an output.
///
/// Serialized names are the lowercased variant names. `"yolov8"` is also
/// accepted as an input alias for [`DecoderType::Ultralytics`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum DecoderType {
    /// Serialized as `"modelpack"`.
    ModelPack,
    /// Serialized as `"ultralytics"`; also deserialized from `"yolov8"`.
    /// This is the `Default` variant.
    #[default]
    #[serde(alias = "yolov8")]
    Ultralytics,
}
/// YOLO release whose output layout a model follows.
///
/// Serialized names are the lowercased variant names: `"yolov5"`, `"yolov8"`,
/// `"yolo11"`, `"yolo26"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum DecoderVersion {
    Yolov5,
    Yolov8,
    Yolo11,
    Yolo26,
}
impl DecoderVersion {
    /// Returns `true` only for [`DecoderVersion::Yolo26`] and `false` for every
    /// other version.
    ///
    /// The match is exhaustive on purpose, so adding a new version is a
    /// compile error here until someone decides which way it goes.
    pub fn is_end_to_end(&self) -> bool {
        match self {
            DecoderVersion::Yolo26 => true,
            DecoderVersion::Yolov5 | DecoderVersion::Yolov8 | DecoderVersion::Yolo11 => false,
        }
    }
}
/// Non-maximum-suppression mode selector.
///
/// Serialized as snake_case: `"class_agnostic"` / `"class_aware"`.
/// NOTE(review): presumably selects whether suppression spans all classes or
/// runs per class; confirm against the decoder that consumes it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Nms {
    /// `"class_agnostic"`; this is the `Default` variant.
    #[default]
    ClassAgnostic,
    /// `"class_aware"`.
    ClassAware,
}
/// Recognised combinations of model outputs. Each variant carries the
/// configuration of every output it uses.
#[derive(Clone, Debug, PartialEq)]
pub enum ModelType {
    /// ModelPack: separate `boxes` and `scores` outputs plus a segmentation output.
    ModelPackSegDet {
        boxes: Boxes,
        scores: Scores,
        segmentation: Segmentation,
    },
    /// ModelPack: a list of `Detection` outputs plus a segmentation output.
    ModelPackSegDetSplit {
        detection: Vec<Detection>,
        segmentation: Segmentation,
    },
    /// ModelPack: separate `boxes` and `scores` outputs.
    ModelPackDet {
        boxes: Boxes,
        scores: Scores,
    },
    /// ModelPack: a list of `Detection` outputs.
    ModelPackDetSplit {
        detection: Vec<Detection>,
    },
    /// ModelPack: a segmentation output only.
    ModelPackSeg {
        segmentation: Segmentation,
    },
    /// YOLO: a single `Detection` output.
    YoloDet {
        boxes: Detection,
    },
    /// YOLO: a single `Detection` output plus `protos`.
    YoloSegDet {
        boxes: Detection,
        protos: Protos,
    },
    /// YOLO: separate `boxes` and `scores` outputs.
    YoloSplitDet {
        boxes: Boxes,
        scores: Scores,
    },
    /// YOLO: separate `boxes`, `scores` and `mask_coeff` outputs plus `protos`.
    YoloSplitSegDet {
        boxes: Boxes,
        scores: Scores,
        mask_coeff: MaskCoefficients,
        protos: Protos,
    },
    /// YOLO: a `Detection` output with separate `mask_coeff` and `protos` outputs.
    YoloSegDet2Way {
        boxes: Detection,
        mask_coeff: MaskCoefficients,
        protos: Protos,
    },
    /// End-to-end YOLO: a single `Detection` output.
    YoloEndToEndDet {
        boxes: Detection,
    },
    /// End-to-end YOLO: a single `Detection` output plus `protos`.
    YoloEndToEndSegDet {
        boxes: Detection,
        protos: Protos,
    },
    /// End-to-end YOLO: separate `boxes`, `scores` and `classes` outputs.
    YoloSplitEndToEndDet {
        boxes: Boxes,
        scores: Scores,
        classes: Classes,
    },
    /// End-to-end YOLO: separate `boxes`, `scores`, `classes` and `mask_coeff`
    /// outputs plus `protos`.
    YoloSplitEndToEndSegDet {
        boxes: Boxes,
        scores: Scores,
        classes: Classes,
        mask_coeff: MaskCoefficients,
        protos: Protos,
    },
}