use std::fmt;
use std::str::FromStr;
use serde::{Deserialize, Serialize};
/// A labelled object-area interval.
///
/// `Params::new` builds the default set of these (labels `"all"`, `"small"`,
/// `"medium"`, `"large"`), and `Params::area_range_idx` looks one up by label.
#[derive(Debug, Clone)]
pub struct AreaRange {
    /// Name of the bucket, e.g. `"all"` or `"small"`.
    pub label: String,
    /// `[lower, upper]` area bounds, e.g. `[AREA_SMALL, AREA_LARGE]` for `"medium"`.
    /// NOTE(review): whether each bound is inclusive or exclusive is decided by
    /// the code that consumes the range, not here — confirm at the point of use.
    pub range: [f64; 2],
}
/// Selects which IoU flavour (`iou_type`) an evaluation uses.
///
/// `Display` writes, and `FromStr` accepts, the lowercase names `"bbox"`,
/// `"segm"`, `"keypoints"` and `"obb"`.
///
/// NOTE(review): the serde derives carry no `rename_all` attribute, so for
/// string-based formats such as JSON serde uses the variant identifiers
/// verbatim (`"Bbox"`, `"Segm"`, `"Keypoints"`, `"Obb"`). That is a different
/// spelling from the `Display`/`FromStr` form. Confirm this is intended before
/// relying on either form in config files.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
pub enum IouType {
    /// Text form `"bbox"`; presumably (axis-aligned) bounding boxes.
    Bbox,
    /// Text form `"segm"`; presumably segmentation masks.
    Segm,
    /// Text form `"keypoints"`; `Params::new` gives this type its own
    /// `max_dets` and area-range defaults.
    Keypoints,
    /// Text form `"obb"`; presumably oriented bounding boxes — confirm.
    Obb,
}
impl fmt::Display for IouType {
    /// Writes the lowercase name of the IoU type. This is the same spelling
    /// that `FromStr` accepts. Width/fill flags are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match *self {
            IouType::Bbox => "bbox",
            IouType::Segm => "segm",
            IouType::Keypoints => "keypoints",
            IouType::Obb => "obb",
        };
        f.write_str(name)
    }
}
impl FromStr for IouType {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"bbox" => Ok(IouType::Bbox),
"segm" => Ok(IouType::Segm),
"keypoints" => Ok(IouType::Keypoints),
"obb" => Ok(IouType::Obb),
_ => Err(format!(
"Unknown iou_type: '{}'. Expected 'bbox', 'segm', 'keypoints', or 'obb'",
s
)),
}
}
}
/// Area boundary between the `"small"` and `"medium"` buckets (32 squared).
pub(crate) const AREA_SMALL: f64 = 32.0 * 32.0;
/// Area boundary between the `"medium"` and `"large"` buckets (96 squared).
pub(crate) const AREA_LARGE: f64 = 96.0 * 96.0;
/// Default per-keypoint sigmas used for OKS, one entry per keypoint (17 total).
///
/// After the first entry the values come in equal pairs. That is consistent
/// with left/right keypoint pairs, though the keypoint order itself is not
/// defined in this file.
///
/// NOTE(review): pycocotools builds these as `np.array([.26, .25, ...]) / 10.0`.
/// Some of those quotients (e.g. `0.35 / 10.0`) differ from the decimal
/// literals below in the last bit. Confirm whether bit-exact parity matters.
pub(crate) const KPT_OKS_SIGMAS: [f64; 17] = [
    0.026,
    0.025, 0.025,
    0.035, 0.035,
    0.079, 0.079,
    0.072, 0.072,
    0.062, 0.062,
    0.107, 0.107,
    0.087, 0.087,
    0.089, 0.089,
];
/// Default IoU thresholds: ten evenly spaced values from 0.50 to 0.95.
///
/// The values are computed the way `numpy.linspace(0.5, 0.95, 10)` (the
/// pycocotools default) computes them: `i * step + start`, with
/// `step = (stop - start) / (num - 1)`, and the final element pinned to `stop`.
/// They therefore agree bit-for-bit with the reference implementation.
///
/// The naive `0.5 + 0.05 * i` rounds differently at two points: it gives
/// `0.8500000000000001` instead of `0.85`, and `0.9` instead of
/// `0.8999999999999999`. Those differences can flip matches whose IoU lands
/// exactly on a threshold.
pub(crate) fn default_iou_thrs() -> Vec<f64> {
    const START: f64 = 0.5;
    const STOP: f64 = 0.95;
    const NUM: usize = 10;
    // Same step numpy derives; note this is not the literal 0.05.
    let step = (STOP - START) / (NUM - 1) as f64;
    (0..NUM)
        .map(|i| if i == NUM - 1 { STOP } else { i as f64 * step + START })
        .collect()
}
/// Evaluation parameters.
///
/// `Params::new` fills in defaults that depend on the `IouType`. Every field is
/// public, so callers may override any of them afterwards.
#[derive(Debug, Clone)]
pub struct Params {
    /// Which IoU flavour the evaluation uses.
    pub iou_type: IouType,
    /// Image ids to evaluate; `Params::new` leaves this empty.
    /// NOTE(review): whether an empty list means "all images" is decided by
    /// the consumer of `Params` — confirm.
    pub img_ids: Vec<u64>,
    /// Category ids to evaluate; `Params::new` leaves this empty (same caveat
    /// as `img_ids`).
    pub cat_ids: Vec<u64>,
    /// IoU thresholds; `Params::new` uses `default_iou_thrs()`.
    pub iou_thrs: Vec<f64>,
    /// Recall thresholds; `Params::new` uses 101 points spanning 0.0 to 1.0.
    pub rec_thrs: Vec<f64>,
    /// Detection-count limits; `Params::new` uses `[20]` for keypoints and
    /// `[1, 10, 100]` otherwise. Presumably the max detections kept per image
    /// (pycocotools `maxDets`) — confirm.
    pub max_dets: Vec<usize>,
    /// Named area buckets; look one up by label with `area_range_idx`.
    pub area_ranges: Vec<AreaRange>,
    /// Defaults to `true`. Presumably toggles per-category evaluation
    /// (pycocotools `useCats`) — confirm at the point of use.
    pub use_cats: bool,
    /// Per-keypoint sigmas for OKS; `Params::new` copies `KPT_OKS_SIGMAS`.
    pub kpt_oks_sigmas: Vec<f64>,
    /// Defaults to `false`.
    /// NOTE(review): the meaning of this flag is not visible in this file —
    /// document it where it is consumed.
    pub expand_dt: bool,
}
impl Params {
    /// Returns the index of the area range labelled `label`, or `None` if no
    /// configured range carries that label. The first match wins.
    pub fn area_range_idx(&self, label: &str) -> Option<usize> {
        self.area_ranges.iter().position(|ar| ar.label == label)
    }

    /// Builds the default evaluation parameters for `iou_type`.
    ///
    /// * `Keypoints`: `max_dets = [20]` and area ranges `all`, `medium`,
    ///   `large` (there is no `small` bucket).
    /// * `Bbox`, `Segm`, `Obb`: `max_dets = [1, 10, 100]` and area ranges
    ///   `all`, `small`, `medium`, `large`.
    ///
    /// In every case:
    /// * the image and category id lists start empty;
    /// * `use_cats` is `true` and `expand_dt` is `false`;
    /// * `kpt_oks_sigmas` is a copy of `KPT_OKS_SIGMAS`;
    /// * IoU thresholds come from `default_iou_thrs()`;
    /// * recall thresholds are the 101 values of
    ///   `numpy.linspace(0.0, 1.0, 101)`, computed as `i * 0.01`. That matches
    ///   pycocotools bit-for-bit; e.g. index 57 is `0.5700000000000001`, not
    ///   `57 / 100`.
    pub fn new(iou_type: IouType) -> Self {
        // Upper bound for the open-ended buckets (1e5 squared).
        const AREA_MAX: f64 = 1e10;
        let area = |label: &str, lo: f64, hi: f64| AreaRange {
            label: label.to_owned(),
            range: [lo, hi],
        };

        // Exhaustive on purpose: a new IouType variant must choose its defaults.
        let (max_dets, area_ranges) = match iou_type {
            IouType::Keypoints => (
                vec![20],
                vec![
                    area("all", 0.0, AREA_MAX),
                    area("medium", AREA_SMALL, AREA_LARGE),
                    area("large", AREA_LARGE, AREA_MAX),
                ],
            ),
            IouType::Bbox | IouType::Segm | IouType::Obb => (
                vec![1, 10, 100],
                vec![
                    area("all", 0.0, AREA_MAX),
                    area("small", 0.0, AREA_SMALL),
                    area("medium", AREA_SMALL, AREA_LARGE),
                    area("large", AREA_LARGE, AREA_MAX),
                ],
            ),
        };

        Params {
            iou_type,
            img_ids: Vec::new(),
            cat_ids: Vec::new(),
            iou_thrs: default_iou_thrs(),
            rec_thrs: Self::numpy_linspace(0.0, 1.0, 101),
            max_dets,
            area_ranges,
            use_cats: true,
            kpt_oks_sigmas: KPT_OKS_SIGMAS.to_vec(),
            expand_dt: false,
        }
    }

    /// `num` evenly spaced values from `start` to `stop` (inclusive).
    ///
    /// Rounding matches `numpy.linspace` with `endpoint=True`: each value is
    /// `i * step + start`, with `step = (stop - start) / (num - 1)`, and the
    /// last value is pinned to `stop`.
    fn numpy_linspace(start: f64, stop: f64, num: usize) -> Vec<f64> {
        match num {
            0 => Vec::new(),
            1 => vec![start],
            _ => {
                let step = (stop - start) / (num - 1) as f64;
                (0..num)
                    .map(|i| if i == num - 1 { stop } else { i as f64 * step + start })
                    .collect()
            }
        }
    }
}