use std::collections::HashMap;
use ndarray::ArrayViewMut2;
use super::Similarity;
use crate::error::EvalError;
/// Per-keypoint OKS sigmas for the 17-keypoint COCO person skeleton.
///
/// These are pycocotools' default `kpt_oks_sigmas` (`[.26, .25, ...] / 10`) in COCO keypoint
/// order. Each left/right pair of joints shares one value. [`OksSimilarity`] falls back to this
/// table for any category without an explicit override.
pub const COCO_PERSON_SIGMAS: [f64; 17] = [
    0.026, // nose
    0.025, 0.025, // eyes
    0.035, 0.035, // ears
    0.079, 0.079, // shoulders
    0.072, 0.072, // elbows
    0.062, 0.062, // wrists
    0.107, 0.107, // hips
    0.087, 0.087, // knees
    0.089, 0.089, // ankles
];
/// One ground-truth or detected instance, as scored by [`OksSimilarity`].
#[derive(Debug, Clone, PartialEq)]
pub struct OksAnn {
    /// Category id; selects which sigma table applies to this annotation.
    pub category_id: i64,
    /// Flattened `(x, y, v)` triples, one per keypoint. A keypoint counts as visible when
    /// `v > 0`.
    pub keypoints: Vec<f64>,
    /// Number of labelled keypoints as carried by the annotation. It is not read by the OKS
    /// computation in this module.
    pub num_keypoints: u32,
    /// Box as `[x, y, width, height]`. The OKS computation reads it only on the ground-truth
    /// side, when no gt keypoint is visible.
    pub bbox: [f64; 4],
    /// Object area, used as the OKS scale. The OKS computation reads it only on the
    /// ground-truth side.
    pub area: f64,
}
/// Object Keypoint Similarity scorer with optional per-category sigma tables.
///
/// Categories that have no entry in `sigmas` are scored with [`COCO_PERSON_SIGMAS`]. The
/// `Default` value holds an empty map, so every category uses the COCO person sigmas.
#[derive(Debug, Clone, Default)]
pub struct OksSimilarity {
    /// Maps a category id to one sigma per keypoint. Keypoint vectors of that category are
    /// expected to hold `3 * sigmas.len()` values.
    pub sigmas: HashMap<i64, Vec<f64>>,
}
impl OksSimilarity {
    /// Creates a scorer from a table of per-category sigma overrides.
    ///
    /// Each value holds one sigma per keypoint. Categories missing from `sigmas` fall back to
    /// [`COCO_PERSON_SIGMAS`].
    #[must_use]
    pub fn new(sigmas: HashMap<i64, Vec<f64>>) -> Self {
        Self { sigmas }
    }

    /// Returns the sigma table for `category_id`: the override if one is registered, otherwise
    /// [`COCO_PERSON_SIGMAS`].
    #[inline]
    fn sigmas_for(&self, category_id: i64) -> &[f64] {
        match self.sigmas.get(&category_id) {
            Some(custom) => custom.as_slice(),
            None => &COCO_PERSON_SIGMAS[..],
        }
    }
}
impl Similarity for OksSimilarity {
    type Annotation = OksAnn;

    /// Fills `out[[g, d]]` with the Object Keypoint Similarity between `gts[g]` and `dts[d]`.
    ///
    /// Follows pycocotools' `COCOeval.computeOks`, including its two modes:
    ///
    /// * **At least one visible gt keypoint** (visibility `> 0`): the score is the mean of
    ///   `exp(-d² / (2 · (area + ε) · (2σ)²))` over the visible gt keypoints only.
    /// * **No visible gt keypoint**: each detected keypoint is scored by its distance to the gt
    ///   box grown by one width/height on every side (zero inside that region). The mean is
    ///   taken over all `k` keypoints.
    ///
    /// Sigmas always come from the *ground truth's* category. If `gts` or `dts` is empty,
    /// nothing is written and no annotation is validated.
    ///
    /// # Errors
    ///
    /// Returns [`EvalError::DimensionMismatch`], before writing anything to `out`, when:
    /// * `out` is not `gts.len() x dts.len()`;
    /// * an annotation's keypoint vector is not `3 * sigmas.len()` for its own category; or
    /// * annotations disagree on keypoint count. Every detection is scored against every
    ///   ground truth, so a mismatched pair cannot be scored consistently.
    fn compute(
        &self,
        gts: &[OksAnn],
        dts: &[OksAnn],
        out: &mut ArrayViewMut2<'_, f64>,
    ) -> Result<(), EvalError> {
        if out.nrows() != gts.len() || out.ncols() != dts.len() {
            return Err(EvalError::DimensionMismatch {
                detail: format!(
                    "OKS output is {}x{}, expected {}x{}",
                    out.nrows(),
                    out.ncols(),
                    gts.len(),
                    dts.len()
                ),
            });
        }
        if gts.is_empty() || dts.is_empty() {
            return Ok(());
        }
        // Validate everything first so `out` is never left half-written on error. This also
        // guarantees every keypoint vector has exactly `3 * vars.len()` values below.
        validate_oks_keypoints(self, gts, dts)?;

        for (g, gt) in gts.iter().enumerate() {
            let sigmas = self.sigmas_for(gt.category_id);
            // pycocotools: `vars = (sigmas * 2) ** 2`.
            let vars: Vec<f64> = sigmas.iter().map(|s| (2.0 * s).powi(2)).collect();
            // pycocotools adds `np.spacing(1)` (== f64::EPSILON) so zero-area gts stay finite.
            let area_norm = gt.area + f64::EPSILON;
            let visible = gt.keypoints.chunks_exact(3).filter(|t| t[2] > 0.0).count();
            let denom_count = if visible > 0 { visible } else { sigmas.len() };
            if denom_count == 0 {
                // Only possible with an empty sigma table (hence no keypoints at all).
                // pycocotools would compute 0/0 here; report zero similarity instead.
                for d in 0..dts.len() {
                    out[[g, d]] = 0.0;
                }
                continue;
            }
            let inv_denom = 1.0 / (denom_count as f64);
            // Fallback region for gts without visible keypoints: the gt box grown by one
            // width/height on every side.
            let [bx, by, bw, bh] = gt.bbox;
            let (x0, x1) = (bx - bw, bx + 2.0 * bw);
            let (y0, y1) = (by - bh, by + 2.0 * bh);

            for (d, dt) in dts.iter().enumerate() {
                let mut e_sum = 0.0_f64;
                if visible > 0 {
                    // Standard OKS: only keypoints visible in the gt contribute. This uses the
                    // same `> 0.0` predicate as `visible`, so the sum and the denominator
                    // agree even for a NaN visibility flag.
                    let scored = gt
                        .keypoints
                        .chunks_exact(3)
                        .zip(dt.keypoints.chunks_exact(3))
                        .zip(&vars)
                        .filter(|((gt_t, _), _)| gt_t[2] > 0.0);
                    for ((gt_t, dt_t), &var) in scored {
                        let dx = dt_t[0] - gt_t[0];
                        let dy = dt_t[1] - gt_t[1];
                        let e = (dx * dx + dy * dy) / var / area_norm / 2.0;
                        e_sum += (-e).exp();
                    }
                } else {
                    // No visible gt keypoint: penalise each detected keypoint by its distance
                    // to the expanded box (zero inside it).
                    for (dt_t, &var) in dt.keypoints.chunks_exact(3).zip(&vars) {
                        let (xd, yd) = (dt_t[0], dt_t[1]);
                        let dx = (x0 - xd).max(0.0) + (xd - x1).max(0.0);
                        let dy = (y0 - yd).max(0.0) + (yd - y1).max(0.0);
                        let e = (dx * dx + dy * dy) / var / area_norm / 2.0;
                        e_sum += (-e).exp();
                    }
                }
                out[[g, d]] = e_sum * inv_denom;
            }
        }
        Ok(())
    }
}

/// Rejects keypoint vectors that the OKS loop cannot score consistently.
///
/// Each annotation must hold `3 * sigmas.len()` values for its own category. Every detection
/// is scored against every ground truth using the ground truth's sigmas, so all annotations
/// must also share the keypoint length of `gts[0]`. An empty `gts` is trivially accepted.
fn validate_oks_keypoints(
    sim: &OksSimilarity,
    gts: &[OksAnn],
    dts: &[OksAnn],
) -> Result<(), EvalError> {
    let shared_len = match gts.first() {
        Some(first) => first.keypoints.len(),
        None => return Ok(()),
    };
    for (side, anns) in [("gt", gts), ("dt", dts)] {
        for (idx, ann) in anns.iter().enumerate() {
            let k = sim.sigmas_for(ann.category_id).len();
            if ann.keypoints.len() != 3 * k {
                return Err(EvalError::DimensionMismatch {
                    detail: format!(
                        "OKS {side}[{idx}] (cat {}): keypoints len {} != 3 * sigmas len {}",
                        ann.category_id,
                        ann.keypoints.len(),
                        k
                    ),
                });
            }
            if ann.keypoints.len() != shared_len {
                return Err(EvalError::DimensionMismatch {
                    detail: format!(
                        "OKS {side}[{idx}] (cat {}): keypoints len {} != gt[0] keypoints len {}",
                        ann.category_id,
                        ann.keypoints.len(),
                        shared_len
                    ),
                });
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    // Builds an annotation from `(x, y, v)` triples; `num_keypoints` counts triples with `v > 0`.
    fn ann(cat: i64, kps: &[(f64, f64, u32)], bbox: [f64; 4], area: f64) -> OksAnn {
        let mut keypoints = Vec::with_capacity(kps.len() * 3);
        let mut visible = 0_u32;
        for (x, y, v) in kps {
            keypoints.push(*x);
            keypoints.push(*y);
            keypoints.push(f64::from(*v));
            if *v > 0 {
                visible += 1;
            }
        }
        OksAnn {
            category_id: cat,
            keypoints,
            num_keypoints: visible,
            bbox,
            area,
        }
    }

    // 17 identical keypoints, matching the length of `COCO_PERSON_SIGMAS`.
    fn const_kps(x: f64, y: f64, v: u32) -> Vec<(f64, f64, u32)> {
        vec![(x, y, v); 17]
    }

    // Runs `compute` into a freshly allocated `gts.len() x dts.len()` matrix.
    fn compute(sim: &OksSimilarity, gts: &[OksAnn], dts: &[OksAnn]) -> Array2<f64> {
        let mut out = Array2::<f64>::zeros((gts.len(), dts.len()));
        sim.compute(gts, dts, &mut out.view_mut()).unwrap();
        out
    }

    #[test]
    fn empty_gts_produces_zero_row_matrix() {
        let dts = vec![ann(1, &const_kps(0.0, 0.0, 2), [0.0, 0.0, 1.0, 1.0], 1.0); 4];
        let mut out = Array2::<f64>::from_elem((0, 4), 7.0);
        OksSimilarity::default()
            .compute(&[], &dts, &mut out.view_mut())
            .unwrap();
        assert_eq!(out.shape(), &[0, 4]);
    }

    #[test]
    fn empty_dts_produces_zero_col_matrix() {
        let gts = vec![ann(1, &const_kps(0.0, 0.0, 2), [0.0, 0.0, 1.0, 1.0], 1.0); 3];
        let mut out = Array2::<f64>::from_elem((3, 0), 7.0);
        OksSimilarity::default()
            .compute(&gts, &[], &mut out.view_mut())
            .unwrap();
        assert_eq!(out.shape(), &[3, 0]);
    }

    #[test]
    fn both_empty_produces_zero_zero_matrix() {
        let mut out = Array2::<f64>::zeros((0, 0));
        OksSimilarity::default()
            .compute(&[], &[], &mut out.view_mut())
            .unwrap();
        assert_eq!(out.shape(), &[0, 0]);
    }

    #[test]
    fn single_perfect_match_is_one() {
        let kps = const_kps(5.0, 7.0, 2);
        let g = ann(1, &kps, [0.0, 0.0, 10.0, 10.0], 100.0);
        let d = ann(1, &kps, [0.0, 0.0, 10.0, 10.0], 100.0);
        let m = compute(&OksSimilarity::default(), &[g], &[d]);
        assert!((m[[0, 0]] - 1.0).abs() < 1e-12);
    }

    // With no visible gt keypoint, detections inside the expanded gt box score 1.0.
    #[test]
    fn bbox_surrogate_path_when_no_visible_keypoints() {
        let gt_kps: Vec<_> = (0..17).map(|_| (0.0, 0.0, 0)).collect();
        let dt_kps = const_kps(5.0, 5.0, 2);
        let g = ann(1, &gt_kps, [0.0, 0.0, 10.0, 10.0], 100.0);
        let d = ann(1, &dt_kps, [0.0, 0.0, 10.0, 10.0], 100.0);
        let m = compute(&OksSimilarity::default(), &[g], &[d]);
        assert!((m[[0, 0]] - 1.0).abs() < 1e-12);
    }

    // sigma = 0.5 gives var = (2 * 0.5)^2 = 1, so a 1px offset yields e = 1 / area / 2.
    #[test]
    fn per_category_sigma_override_changes_output() {
        let gt_kps = const_kps(5.0, 5.0, 2);
        let dt_kps = const_kps(6.0, 5.0, 2);
        let g = ann(1, &gt_kps, [0.0, 0.0, 10.0, 10.0], 100.0);
        let d = ann(1, &dt_kps, [0.0, 0.0, 10.0, 10.0], 100.0);
        let default = compute(
            &OksSimilarity::default(),
            std::slice::from_ref(&g),
            std::slice::from_ref(&d),
        );
        let mut override_map = HashMap::new();
        override_map.insert(1_i64, vec![0.5_f64; 17]);
        let custom = compute(&OksSimilarity::new(override_map), &[g], &[d]);
        let area_norm = 100.0_f64 + f64::EPSILON;
        let e = 1.0_f64 / 1.0_f64 / area_norm / 2.0;
        let expected = (-e).exp();
        assert!((custom[[0, 0]] - expected).abs() < 1e-10);
        assert!((custom[[0, 0]] - default[[0, 0]]).abs() > 1e-6);
    }

    // Sigmas are chosen by the ground truth's category, even when the detection's category
    // differs (both categories here have 17 keypoints).
    #[test]
    fn gt_category_selects_sigmas() {
        let g = ann(1, &const_kps(5.0, 5.0, 2), [0.0, 0.0, 10.0, 10.0], 100.0);
        let d = ann(2, &const_kps(6.0, 5.0, 2), [0.0, 0.0, 10.0, 10.0], 100.0);
        let mut override_map = HashMap::new();
        override_map.insert(1_i64, vec![0.5_f64; 17]);
        let m = compute(&OksSimilarity::new(override_map), &[g], &[d]);
        let e = 1.0_f64 / 1.0_f64 / (100.0_f64 + f64::EPSILON) / 2.0;
        assert!((m[[0, 0]] - (-e).exp()).abs() < 1e-10);
    }

    // gt box x-range [10, 15] expands to [10 - 5, 10 + 2 * 5] = [5, 20].
    #[test]
    fn f4_bbox_expansion_is_asymmetric_on_x() {
        let gt_kps: Vec<_> = (0..17).map(|_| (0.0, 0.0, 0)).collect();
        let g = ann(1, &gt_kps, [10.0, 0.0, 5.0, 1.0], 1.0);
        let inside_kps = const_kps(19.999, 0.5, 2);
        let outside_kps = const_kps(25.0, 0.5, 2);
        let d_inside = ann(1, &inside_kps, [0.0, 0.0, 1.0, 1.0], 1.0);
        let d_outside = ann(1, &outside_kps, [0.0, 0.0, 1.0, 1.0], 1.0);
        let m = compute(&OksSimilarity::default(), &[g], &[d_inside, d_outside]);
        assert!((m[[0, 0]] - 1.0).abs() < 1e-6, "inside x1 should be ~1.0");
        assert!(m[[0, 1]] < 1.0 - 1e-6, "outside x1 should drop below 1.0");
        let lower_in = const_kps(5.001, 0.5, 2);
        let lower_out = const_kps(0.0, 0.5, 2);
        let d_lower_in = ann(1, &lower_in, [0.0, 0.0, 1.0, 1.0], 1.0);
        let d_lower_out = ann(1, &lower_out, [0.0, 0.0, 1.0, 1.0], 1.0);
        let g2 = ann(
            1,
            &(0..17).map(|_| (0.0, 0.0, 0)).collect::<Vec<_>>(),
            [10.0, 0.0, 5.0, 1.0],
            1.0,
        );
        let m2 = compute(&OksSimilarity::default(), &[g2], &[d_lower_in, d_lower_out]);
        assert!((m2[[0, 0]] - 1.0).abs() < 1e-6, "inside x0 should be ~1.0");
        assert!(m2[[0, 1]] < 1.0 - 1e-6, "outside x0 should drop below 1.0");
    }

    // 17 keypoints (51 values) against a 16-entry sigma override must be rejected.
    #[test]
    fn sigma_length_mismatch_returns_typed_error() {
        let g = ann(1, &const_kps(0.0, 0.0, 2), [0.0, 0.0, 10.0, 10.0], 100.0);
        let d = g.clone();
        let mut override_map = HashMap::new();
        override_map.insert(1_i64, vec![0.05_f64; 16]);
        let sim = OksSimilarity::new(override_map);
        let mut out = Array2::<f64>::zeros((1, 1));
        let err = sim.compute(&[g], &[d], &mut out.view_mut()).unwrap_err();
        match err {
            EvalError::DimensionMismatch { detail } => {
                assert!(
                    detail.contains("keypoints"),
                    "expected keypoints detail, got {detail}",
                );
            }
            other => panic!("expected DimensionMismatch, got {other:?}"),
        }
    }

    #[test]
    fn output_shape_mismatch_returns_typed_error() {
        let g = ann(1, &const_kps(0.0, 0.0, 2), [0.0, 0.0, 10.0, 10.0], 100.0);
        let d = g.clone();
        let mut out = Array2::<f64>::zeros((2, 3));
        let err = OksSimilarity::default()
            .compute(&[g], &[d], &mut out.view_mut())
            .unwrap_err();
        assert!(matches!(err, EvalError::DimensionMismatch { .. }));
    }

    // Zero-area gt: the epsilon in the area normaliser keeps 0 / 0 from producing NaN.
    #[test]
    fn f2_area_epsilon_handles_zero_area_gt_without_nan() {
        let kps = const_kps(0.0, 0.0, 2);
        let g = ann(1, &kps, [0.0, 0.0, 0.0, 0.0], 0.0);
        let d = ann(1, &kps, [0.0, 0.0, 0.0, 0.0], 0.0);
        let m = compute(&OksSimilarity::default(), &[g], &[d]);
        assert!(m[[0, 0]].is_finite());
        assert!((m[[0, 0]] - 1.0).abs() < 1e-12);
    }

    // Only gt keypoint 0 is visible, so the far-away detections at the other 16 slots must
    // not lower the score.
    #[test]
    fn invisible_gt_keypoints_excluded_from_standard_path() {
        let mut gt_kps = vec![(0.0, 0.0, 0); 17];
        gt_kps[0] = (5.0, 5.0, 2);
        let mut dt_kps = vec![(1000.0, 1000.0, 2); 17];
        dt_kps[0] = (5.0, 5.0, 2);
        let g = ann(1, &gt_kps, [0.0, 0.0, 10.0, 10.0], 100.0);
        let d = ann(1, &dt_kps, [0.0, 0.0, 10.0, 10.0], 100.0);
        let m = compute(&OksSimilarity::default(), &[g], &[d]);
        assert!((m[[0, 0]] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn impl_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<OksSimilarity>();
    }
}