use std::collections::{HashMap, HashSet};
use std::sync::{Arc, OnceLock};
use serde::{Deserialize, Serialize};
use crate::error::EvalError;
use crate::parity::ParityMode;
use crate::segmentation::{Segmentation, SegmentationRleCounts};
/// Newtype over a COCO image id (`images[].id` in the JSON).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ImageId(pub i64);
/// Newtype over a COCO category id (`categories[].id` in the JSON).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(transparent)]
pub struct CategoryId(pub i64);
/// Newtype over an annotation/detection id (`annotations[].id` in the JSON).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(transparent)]
pub struct AnnId(pub i64);
/// Image metadata as found in the COCO `images` array.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ImageMeta {
    pub id: ImageId,
    // Pixel dimensions of the image.
    pub width: u32,
    pub height: u32,
    // Optional in COCO JSON; omitted from output when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub file_name: Option<String>,
}
/// Category metadata as found in the COCO `categories` array.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct CategoryMeta {
    pub id: CategoryId,
    pub name: String,
    // Optional in COCO JSON; omitted from output when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub supercategory: Option<String>,
}
/// Axis-aligned bounding box in COCO `[x, y, w, h]` convention
/// (top-left corner plus width/height). Serialized as a 4-element array.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[serde(from = "[f64; 4]", into = "[f64; 4]")]
pub struct Bbox {
    pub x: f64,
    pub y: f64,
    pub w: f64,
    pub h: f64,
}
impl From<[f64; 4]> for Bbox {
    /// Build a bbox from COCO's `[x, y, w, h]` array layout.
    fn from(v: [f64; 4]) -> Self {
        Self {
            x: v[0],
            y: v[1],
            w: v[2],
            h: v[3],
        }
    }
}
impl From<Bbox> for [f64; 4] {
fn from(b: Bbox) -> Self {
[b.x, b.y, b.w, b.h]
}
}
/// One ground-truth annotation in COCO format.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct CocoAnnotation {
    pub id: AnnId,
    pub image_id: ImageId,
    pub category_id: CategoryId,
    // GT area is taken from the JSON as-is (not derived from bbox/mask here).
    pub area: f64,
    // COCO encodes this as integer 0/1; both int and bool forms accepted.
    #[serde(rename = "iscrowd", default, deserialize_with = "deserialize_bool_int")]
    pub is_crowd: bool,
    // Non-standard "ignore" field some datasets carry; `None` when absent
    // so downstream can distinguish "absent" from an explicit 0.
    #[serde(
        rename = "ignore",
        default,
        deserialize_with = "deserialize_opt_bool_int"
    )]
    pub ignore_flag: Option<bool>,
    pub bbox: Bbox,
    // Optional polygon or RLE mask.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub segmentation: Option<Segmentation>,
    // Optional keypoint triplets (x, y, visibility), flattened.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub keypoints: Option<Vec<f64>>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub num_keypoints: Option<u32>,
}
impl CocoAnnotation {
    /// Whether this annotation is excluded from matching under `mode`.
    ///
    /// Strict parity looks only at `iscrowd`; corrected parity lets an
    /// explicit `ignore` field override it, falling back to `iscrowd`
    /// when the field was absent.
    pub fn effective_ignore(&self, mode: ParityMode) -> bool {
        if let ParityMode::Corrected = mode {
            self.ignore_flag.unwrap_or(self.is_crowd)
        } else {
            self.is_crowd
        }
    }
}
/// Minimal read accessors the evaluator needs from any annotation-like record
/// (implemented by both ground-truth annotations and detections).
pub trait Annotation {
    fn image_id(&self) -> ImageId;
    fn category_id(&self) -> CategoryId;
    fn area(&self) -> f64;
    fn is_crowd(&self) -> bool;
    /// Whether the record should be ignored during matching under `mode`.
    fn effective_ignore(&self, mode: ParityMode) -> bool;
}
impl Annotation for CocoAnnotation {
fn image_id(&self) -> ImageId {
self.image_id
}
fn category_id(&self) -> CategoryId {
self.category_id
}
fn area(&self) -> f64 {
self.area
}
fn is_crowd(&self) -> bool {
self.is_crowd
}
fn effective_ignore(&self, mode: ParityMode) -> bool {
Self::effective_ignore(self, mode)
}
}
/// Read-only dataset abstraction consumed by the evaluator: image/category
/// metadata plus annotations addressable through per-image and per-category
/// index lookups.
pub trait EvalDataset: Send + Sync {
    type Annotation: Annotation;
    /// All images in the dataset.
    fn images(&self) -> &[ImageMeta];
    /// All categories in the dataset.
    fn categories(&self) -> &[CategoryMeta];
    /// All annotations; the index methods below return positions into this slice.
    fn annotations(&self) -> &[Self::Annotation];
    /// Indices into `annotations()` for one image (empty for unknown ids).
    fn ann_indices_for_image(&self, image_id: ImageId) -> &[usize];
    /// Indices into `annotations()` for one category (empty for unknown ids).
    fn ann_indices_for_category(&self, cat_id: CategoryId) -> &[usize];
    /// Iterate the annotations belonging to one image.
    fn ann_iter_for_image(&self, image_id: ImageId) -> AnnotationIter<'_, Self::Annotation> {
        AnnotationIter {
            anns: self.annotations(),
            indices: self.ann_indices_for_image(image_id).iter(),
        }
    }
    /// Iterate the annotations belonging to one category.
    fn ann_iter_for_category(&self, cat_id: CategoryId) -> AnnotationIter<'_, Self::Annotation> {
        AnnotationIter {
            anns: self.annotations(),
            indices: self.ann_indices_for_category(cat_id).iter(),
        }
    }
}
/// Iterator over a subset of an annotation slice, driven by a slice of
/// precomputed indices (see `EvalDataset::ann_iter_for_image` /
/// `ann_iter_for_category`).
pub struct AnnotationIter<'a, A> {
    // Full annotation slice being indexed into.
    anns: &'a [A],
    // Iterator over positions within `anns`.
    indices: std::slice::Iter<'a, usize>,
}
impl<'a, A> Iterator for AnnotationIter<'a, A> {
    type Item = &'a A;

    fn next(&mut self) -> Option<Self::Item> {
        // An out-of-range index ends iteration (via `get`) rather than
        // panicking; indices built by `CocoDataset` are always in range.
        self.indices.next().and_then(|&i| self.anns.get(i))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.indices.size_hint()
    }
}

impl<'a, A> ExactSizeIterator for AnnotationIter<'a, A> {}
/// Top-level COCO ground-truth JSON document (the three arrays this
/// crate reads; any other COCO fields are ignored on parse).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CocoJson {
    pub images: Vec<ImageMeta>,
    pub annotations: Vec<CocoAnnotation>,
    pub categories: Vec<CategoryMeta>,
}
/// LVIS category frequency bucket, serialized as the single letters
/// "r" / "c" / "f" used by the LVIS JSON format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Frequency {
    #[serde(rename = "r")]
    Rare,
    #[serde(rename = "c")]
    Common,
    #[serde(rename = "f")]
    Frequent,
}
impl Frequency {
pub const fn as_letter(self) -> &'static str {
match self {
Self::Rare => "r",
Self::Common => "c",
Self::Frequent => "f",
}
}
}
/// LVIS `images[]` entry as parsed from JSON: COCO image fields plus the
/// federated-annotation category lists.
#[derive(Debug, Clone, Deserialize)]
struct LvisImageRaw {
    id: ImageId,
    width: u32,
    height: u32,
    #[serde(default)]
    file_name: Option<String>,
    // Categories verified absent from this image.
    #[serde(default)]
    neg_category_ids: Option<Vec<CategoryId>>,
    // Categories present but not exhaustively annotated on this image.
    #[serde(default)]
    not_exhaustive_category_ids: Option<Vec<CategoryId>>,
}
/// LVIS `categories[]` entry: COCO category fields plus the frequency bucket.
#[derive(Debug, Clone, Deserialize)]
struct LvisCategoryRaw {
    id: CategoryId,
    name: String,
    #[serde(default)]
    supercategory: Option<String>,
    // Optional here so a missing value can be reported as a structured error
    // (see `CocoDataset::from_lvis_json_bytes`) rather than a parse failure.
    #[serde(default)]
    frequency: Option<Frequency>,
}
/// Top-level LVIS ground-truth JSON document.
#[derive(Debug, Clone, Deserialize)]
struct LvisJson {
    images: Vec<LvisImageRaw>,
    annotations: Vec<CocoAnnotation>,
    categories: Vec<LvisCategoryRaw>,
}
/// Per-image federated annotation metadata extracted from an LVIS dataset.
#[derive(Debug, Clone)]
pub struct FederatedMetadata {
    // Categories with at least one GT annotation on each image.
    pub pos_category_ids: HashMap<ImageId, HashSet<CategoryId>>,
    // Categories verified absent from each image.
    pub neg_category_ids: HashMap<ImageId, HashSet<CategoryId>>,
    // Categories present but not exhaustively annotated on each image.
    pub not_exhaustive_category_ids: HashMap<ImageId, HashSet<CategoryId>>,
    // LVIS rare/common/frequent bucket per category.
    pub category_frequency: HashMap<CategoryId, Frequency>,
}
/// Validated, indexed ground-truth dataset.
///
/// Payload vectors live behind `Arc` so clones are cheap; the canonical
/// dataset hash is memoized in a shared `OnceLock`.
#[derive(Debug, Clone)]
pub struct CocoDataset {
    images: Arc<Vec<ImageMeta>>,
    categories: Arc<Vec<CategoryMeta>>,
    annotations: Arc<Vec<CocoAnnotation>>,
    // Indices into `annotations`, keyed three ways for fast lookup.
    by_image: HashMap<ImageId, Vec<usize>>,
    by_category: HashMap<CategoryId, Vec<usize>>,
    by_image_cat: HashMap<(ImageId, CategoryId), Vec<usize>>,
    // Present only for datasets loaded via `from_lvis_json_bytes`.
    federated: Option<FederatedMetadata>,
    // Lazily computed canonical hash (see `dataset_hash`).
    cached_hash: Arc<OnceLock<[u8; 32]>>,
}
impl CocoDataset {
    /// Parse a COCO-format ground-truth JSON document.
    ///
    /// # Errors
    /// Propagates JSON errors and the validation errors of [`Self::from_parts`].
    pub fn from_json_bytes(bytes: &[u8]) -> Result<Self, EvalError> {
        let raw: CocoJson = serde_json::from_slice(bytes)?;
        Self::from_parts(raw.images, raw.annotations, raw.categories)
    }

    /// Build a dataset from already-parsed parts, validating that every
    /// annotation references a known image and category, and building the
    /// per-image / per-category / per-(image, category) indices.
    ///
    /// # Errors
    /// Returns `EvalError::InvalidAnnotation` for a dangling reference.
    pub fn from_parts(
        images: Vec<ImageMeta>,
        annotations: Vec<CocoAnnotation>,
        categories: Vec<CategoryMeta>,
    ) -> Result<Self, EvalError> {
        let known_images: HashSet<ImageId> = images.iter().map(|i| i.id).collect();
        let known_categories: HashSet<CategoryId> = categories.iter().map(|c| c.id).collect();
        let mut by_image: HashMap<ImageId, Vec<usize>> = HashMap::with_capacity(images.len());
        let mut by_category: HashMap<CategoryId, Vec<usize>> =
            HashMap::with_capacity(categories.len());
        let mut by_image_cat: HashMap<(ImageId, CategoryId), Vec<usize>> = HashMap::new();
        for (idx, ann) in annotations.iter().enumerate() {
            if !known_images.contains(&ann.image_id) {
                return Err(EvalError::InvalidAnnotation {
                    detail: format!(
                        "annotation id={} references unknown image_id={}",
                        ann.id.0, ann.image_id.0
                    ),
                });
            }
            if !known_categories.contains(&ann.category_id) {
                return Err(EvalError::InvalidAnnotation {
                    detail: format!(
                        "annotation id={} references unknown category_id={}",
                        ann.id.0, ann.category_id.0
                    ),
                });
            }
            // Pushing in enumeration order keeps every index vector sorted
            // by annotation position.
            by_image.entry(ann.image_id).or_default().push(idx);
            by_category.entry(ann.category_id).or_default().push(idx);
            by_image_cat
                .entry((ann.image_id, ann.category_id))
                .or_default()
                .push(idx);
        }
        Ok(Self {
            images: Arc::new(images),
            categories: Arc::new(categories),
            annotations: Arc::new(annotations),
            by_image,
            by_category,
            by_image_cat,
            federated: None,
            cached_hash: Arc::new(OnceLock::new()),
        })
    }

    /// Parse an LVIS-format JSON document, capturing federated metadata
    /// (per-image pos/neg/not-exhaustive category sets and per-category
    /// frequency buckets) on top of the core dataset.
    ///
    /// # Errors
    /// - `EvalError::MissingFrequency` when any category lacks `frequency`
    ///   (all offending ids reported, sorted).
    /// - `EvalError::LvisFederatedConflict` when the per-image sets are
    ///   mutually inconsistent.
    /// - Everything `from_parts` can return.
    pub fn from_lvis_json_bytes(bytes: &[u8]) -> Result<Self, EvalError> {
        let raw: LvisJson = serde_json::from_slice(bytes)?;
        let images: Vec<ImageMeta> = raw
            .images
            .iter()
            .map(|im| ImageMeta {
                id: im.id,
                width: im.width,
                height: im.height,
                file_name: im.file_name.clone(),
            })
            .collect();
        let categories: Vec<CategoryMeta> = raw
            .categories
            .iter()
            .map(|c| CategoryMeta {
                id: c.id,
                name: c.name.clone(),
                supercategory: c.supercategory.clone(),
            })
            .collect();
        // LVIS evaluation needs a frequency bucket for every category;
        // collect all offenders so the error is actionable in one pass.
        let mut missing_freq: Vec<i64> = raw
            .categories
            .iter()
            .filter(|c| c.frequency.is_none())
            .map(|c| c.id.0)
            .collect();
        if !missing_freq.is_empty() {
            missing_freq.sort_unstable();
            return Err(EvalError::MissingFrequency {
                category_ids: missing_freq,
            });
        }
        let category_frequency: HashMap<CategoryId, Frequency> = raw
            .categories
            .iter()
            .filter_map(|c| c.frequency.map(|f| (c.id, f)))
            .collect();
        let mut dataset = Self::from_parts(images, raw.annotations, categories)?;
        // `pos`: categories with at least one GT annotation per image.
        // Seed every image id so annotation-free images still get an entry.
        let mut pos: HashMap<ImageId, HashSet<CategoryId>> =
            HashMap::with_capacity(raw.images.len());
        for im in &raw.images {
            pos.entry(im.id).or_default();
        }
        for ann in dataset.annotations.iter() {
            pos.entry(ann.image_id).or_default().insert(ann.category_id);
        }
        let mut neg: HashMap<ImageId, HashSet<CategoryId>> =
            HashMap::with_capacity(raw.images.len());
        let mut nel: HashMap<ImageId, HashSet<CategoryId>> =
            HashMap::with_capacity(raw.images.len());
        for im in &raw.images {
            let neg_set: HashSet<CategoryId> = im
                .neg_category_ids
                .as_deref()
                .unwrap_or(&[])
                .iter()
                .copied()
                .collect();
            let nel_set: HashSet<CategoryId> = im
                .not_exhaustive_category_ids
                .as_deref()
                .unwrap_or(&[])
                .iter()
                .copied()
                .collect();
            neg.insert(im.id, neg_set);
            nel.insert(im.id, nel_set);
        }
        // Consistency checks across the three sets. All maps were seeded for
        // every image above, so borrow the `pos` sets directly instead of
        // cloning one HashSet per image (the previous code cloned here).
        let empty: HashSet<CategoryId> = HashSet::new();
        for im in &raw.images {
            let image_id = im.id;
            let pos_i = pos.get(&image_id).unwrap_or(&empty);
            let neg_i = &neg[&image_id];
            let nel_i = &nel[&image_id];
            if let Some(c) = pos_i.intersection(neg_i).next().copied() {
                return Err(EvalError::LvisFederatedConflict {
                    image_id: image_id.0,
                    category_id: c.0,
                    detail: "category has GT on image but is also in neg_category_ids",
                });
            }
            if let Some(c) = nel_i.difference(pos_i).next().copied() {
                return Err(EvalError::LvisFederatedConflict {
                    image_id: image_id.0,
                    category_id: c.0,
                    detail:
                        "category in not_exhaustive_category_ids but not in pos (no GT on image)",
                });
            }
            if let Some(c) = nel_i.intersection(neg_i).next().copied() {
                return Err(EvalError::LvisFederatedConflict {
                    image_id: image_id.0,
                    category_id: c.0,
                    detail: "category in both not_exhaustive_category_ids and neg_category_ids",
                });
            }
        }
        dataset.federated = Some(FederatedMetadata {
            pos_category_ids: pos,
            neg_category_ids: neg,
            not_exhaustive_category_ids: nel,
            category_frequency,
        });
        Ok(dataset)
    }

    /// Federated (LVIS) metadata; `None` for plain COCO datasets.
    pub fn federated(&self) -> Option<&FederatedMetadata> {
        self.federated.as_ref()
    }

    /// Per-image positive category sets (LVIS only).
    pub fn pos_category_ids(&self) -> Option<&HashMap<ImageId, HashSet<CategoryId>>> {
        self.federated.as_ref().map(|f| &f.pos_category_ids)
    }

    /// Per-image negative category sets (LVIS only).
    pub fn neg_category_ids(&self) -> Option<&HashMap<ImageId, HashSet<CategoryId>>> {
        self.federated.as_ref().map(|f| &f.neg_category_ids)
    }

    /// Per-image not-exhaustively-annotated category sets (LVIS only).
    pub fn not_exhaustive_category_ids(&self) -> Option<&HashMap<ImageId, HashSet<CategoryId>>> {
        self.federated
            .as_ref()
            .map(|f| &f.not_exhaustive_category_ids)
    }

    /// Per-category frequency buckets (LVIS only).
    pub fn category_frequency(&self) -> Option<&HashMap<CategoryId, Frequency>> {
        self.federated.as_ref().map(|f| &f.category_frequency)
    }

    /// True when this dataset was loaded with federated (LVIS) metadata.
    pub fn is_federated(&self) -> bool {
        self.federated.is_some()
    }

    /// Clone out a serializable COCO-format view of the dataset.
    pub fn to_json_value(&self) -> CocoJson {
        CocoJson {
            images: (*self.images).clone(),
            annotations: (*self.annotations).clone(),
            categories: (*self.categories).clone(),
        }
    }
}
impl EvalDataset for CocoDataset {
    type Annotation = CocoAnnotation;

    fn images(&self) -> &[ImageMeta] {
        self.images.as_slice()
    }

    fn categories(&self) -> &[CategoryMeta] {
        self.categories.as_slice()
    }

    fn annotations(&self) -> &[CocoAnnotation] {
        self.annotations.as_slice()
    }

    /// Empty slice (not an error) for unknown image ids.
    fn ann_indices_for_image(&self, image_id: ImageId) -> &[usize] {
        match self.by_image.get(&image_id) {
            Some(v) => v.as_slice(),
            None => &[],
        }
    }

    /// Empty slice (not an error) for unknown category ids.
    fn ann_indices_for_category(&self, cat_id: CategoryId) -> &[usize] {
        match self.by_category.get(&cat_id) {
            Some(v) => v.as_slice(),
            None => &[],
        }
    }
}
impl CocoDataset {
    /// Annotation indices for one (image, category) pair; empty when the
    /// pair has no annotations.
    pub fn ann_indices_for(&self, image: ImageId, cat: CategoryId) -> &[usize] {
        match self.by_image_cat.get(&(image, cat)) {
            Some(v) => v.as_slice(),
            None => &[],
        }
    }
}
// Domain-separation tags mixed into the canonical dataset hash so that the
// different sections cannot collide with one another.
const HASH_TAG_DATASET: &[u8; 4] = b"DSET";
const HASH_TAG_IMAGES: &[u8; 4] = b"IMGS";
const HASH_TAG_CATEGORIES: &[u8; 4] = b"CATS";
const HASH_TAG_ANNOTATIONS: &[u8; 4] = b"ANNS";
const HASH_TAG_FEDERATED: &[u8; 4] = b"FEDM";
// Bump whenever the canonical byte layout written below changes.
const HASH_CANONICAL_VERSION: u8 = 1;
// Primitive encoders for the canonical hash. All integers are written
// little-endian; floats are written as their raw IEEE-754 bit patterns so
// the hash is exact (note: this makes -0.0 and 0.0 hash differently, and
// gives NaN payloads distinct hashes).
#[inline]
fn hash_u8(h: &mut blake3::Hasher, v: u8) {
    h.update(&[v]);
}
#[inline]
fn hash_u32(h: &mut blake3::Hasher, v: u32) {
    h.update(&v.to_le_bytes());
}
#[inline]
fn hash_i64(h: &mut blake3::Hasher, v: i64) {
    h.update(&v.to_le_bytes());
}
#[inline]
fn hash_u64(h: &mut blake3::Hasher, v: u64) {
    h.update(&v.to_le_bytes());
}
#[inline]
fn hash_f64(h: &mut blake3::Hasher, v: f64) {
    h.update(&v.to_bits().to_le_bytes());
}
#[inline]
fn hash_bool(h: &mut blake3::Hasher, v: bool) {
    hash_u8(h, u8::from(v));
}
// Length-prefixed so variable-length fields cannot run into each other.
#[inline]
fn hash_bytes(h: &mut blake3::Hasher, bytes: &[u8]) {
    hash_u64(h, bytes.len() as u64);
    h.update(bytes);
}
#[inline]
fn hash_string(h: &mut blake3::Hasher, s: &str) {
    hash_bytes(h, s.as_bytes());
}
// Options are encoded as a 0/1 presence byte followed by the value.
#[inline]
fn hash_option<T>(
    h: &mut blake3::Hasher,
    opt: Option<T>,
    write: impl FnOnce(&mut blake3::Hasher, T),
) {
    match opt {
        None => hash_u8(h, 0),
        Some(v) => {
            hash_u8(h, 1);
            write(h, v);
        }
    }
}
fn hash_bbox(h: &mut blake3::Hasher, b: &Bbox) {
    hash_f64(h, b.x);
    hash_f64(h, b.y);
    hash_f64(h, b.w);
    hash_f64(h, b.h);
}
/// Canonically hash an optional segmentation. A leading variant byte keeps
/// the four shapes disjoint: 0 = absent, 1 = polygons, 2 = compressed RLE,
/// 3 = uncompressed RLE.
fn hash_segmentation(h: &mut blake3::Hasher, seg: Option<&Segmentation>) {
    match seg {
        None => hash_u8(h, 0),
        Some(Segmentation::Polygons(polys)) => {
            hash_u8(h, 1);
            // Length-prefix both the polygon list and each polygon.
            hash_u64(h, polys.len() as u64);
            for poly in polys {
                hash_u64(h, poly.len() as u64);
                for &v in poly {
                    hash_f64(h, v);
                }
            }
        }
        Some(Segmentation::Rle(rle)) => {
            let [rh, rw] = rle.size;
            match &rle.counts {
                SegmentationRleCounts::Compressed(s) => {
                    hash_u8(h, 2);
                    hash_u32(h, rh);
                    hash_u32(h, rw);
                    hash_string(h, s);
                }
                SegmentationRleCounts::Uncompressed(counts) => {
                    hash_u8(h, 3);
                    hash_u32(h, rh);
                    hash_u32(h, rw);
                    hash_u64(h, counts.len() as u64);
                    for &c in counts.iter() {
                        hash_u32(h, c);
                    }
                }
            }
        }
    }
}
/// Hash a tagged, length-prefixed sequence of items in ascending id order,
/// making the hash independent of the input ordering (assuming unique ids).
fn hash_id_sorted<T>(
    h: &mut blake3::Hasher,
    tag: &[u8; 4],
    items: &[T],
    key: impl Fn(&T) -> i64,
    write: impl Fn(&mut blake3::Hasher, &T),
) {
    h.update(tag);
    let mut order: Vec<usize> = Vec::with_capacity(items.len());
    order.extend(0..items.len());
    order.sort_unstable_by_key(|&i| key(&items[i]));
    hash_u64(h, items.len() as u64);
    for &idx in order.iter() {
        write(h, &items[idx]);
    }
}
// The exhaustive destructuring in the three writers below is deliberate:
// adding a field to the struct breaks compilation here, forcing the canonical
// encoding (and HASH_CANONICAL_VERSION) to be revisited.
fn hash_image_meta(h: &mut blake3::Hasher, im: &ImageMeta) {
    let ImageMeta {
        id,
        width,
        height,
        file_name,
    } = im;
    hash_i64(h, id.0);
    hash_u32(h, *width);
    hash_u32(h, *height);
    hash_option(h, file_name.as_deref(), hash_string);
}
fn hash_category_meta(h: &mut blake3::Hasher, c: &CategoryMeta) {
    let CategoryMeta {
        id,
        name,
        supercategory,
    } = c;
    hash_i64(h, id.0);
    hash_string(h, name);
    hash_option(h, supercategory.as_deref(), hash_string);
}
fn hash_coco_annotation(h: &mut blake3::Hasher, a: &CocoAnnotation) {
    let CocoAnnotation {
        id,
        image_id,
        category_id,
        area,
        is_crowd,
        ignore_flag,
        bbox,
        segmentation,
        keypoints,
        num_keypoints,
    } = a;
    hash_i64(h, id.0);
    hash_i64(h, image_id.0);
    hash_i64(h, category_id.0);
    hash_f64(h, *area);
    hash_bool(h, *is_crowd);
    hash_option(h, *ignore_flag, hash_bool);
    hash_bbox(h, bbox);
    hash_segmentation(h, segmentation.as_ref());
    // Keypoints: length-prefixed list of raw f64 values.
    hash_option(h, keypoints.as_deref(), |h, kps| {
        hash_u64(h, kps.len() as u64);
        for &v in kps {
            hash_f64(h, v);
        }
    });
    hash_option(h, *num_keypoints, hash_u32);
}
/// Canonically hash the LVIS federated metadata. HashMap/HashSet iteration
/// order is nondeterministic, so every map is materialized and sorted by id
/// before hashing to keep the digest stable across runs.
fn hash_federated(h: &mut blake3::Hasher, fed: &FederatedMetadata) {
    h.update(HASH_TAG_FEDERATED);
    // Frequency table: (category_id, letter) pairs sorted by id.
    let mut freq_pairs: Vec<(i64, &Frequency)> = fed
        .category_frequency
        .iter()
        .map(|(k, v)| (k.0, v))
        .collect();
    freq_pairs.sort_unstable_by_key(|(k, _)| *k);
    hash_u64(h, freq_pairs.len() as u64);
    for (cid, freq) in freq_pairs {
        hash_i64(h, cid);
        hash_u8(h, freq.as_letter().as_bytes()[0]);
    }
    // The three image->category-set maps share the same encoding; each gets
    // a 3-byte section tag for domain separation.
    type FedSection<'a> = (&'a [u8; 3], &'a HashMap<ImageId, HashSet<CategoryId>>);
    let sections: [FedSection<'_>; 3] = [
        (b"POS", &fed.pos_category_ids),
        (b"NEG", &fed.neg_category_ids),
        (b"NEX", &fed.not_exhaustive_category_ids),
    ];
    for (tag, map) in sections {
        h.update(tag);
        // Sort category ids within each image, then entries by image id.
        let mut entries: Vec<(i64, Vec<i64>)> = map
            .iter()
            .map(|(image_id, cats)| {
                let mut cat_ids: Vec<i64> = cats.iter().map(|c| c.0).collect();
                cat_ids.sort_unstable();
                (image_id.0, cat_ids)
            })
            .collect();
        entries.sort_unstable_by_key(|(image_id, _)| *image_id);
        hash_u64(h, entries.len() as u64);
        for (image_id, cat_ids) in entries {
            hash_i64(h, image_id);
            hash_u64(h, cat_ids.len() as u64);
            for cid in cat_ids {
                hash_i64(h, cid);
            }
        }
    }
}
impl CocoDataset {
    /// Canonical BLAKE3 digest of the dataset contents.
    ///
    /// The digest is computed once and memoized in the shared `OnceLock`,
    /// so clones of this dataset reuse the same cached value.
    pub fn dataset_hash(&self) -> [u8; 32] {
        *self.cached_hash.get_or_init(|| self.compute_dataset_hash())
    }

    // Hash layout: "DSET" tag, version byte, then the id-sorted image,
    // category and annotation sections, then a presence byte plus the
    // federated section when LVIS metadata is loaded.
    fn compute_dataset_hash(&self) -> [u8; 32] {
        let mut h = blake3::Hasher::new();
        h.update(HASH_TAG_DATASET);
        hash_u8(&mut h, HASH_CANONICAL_VERSION);
        hash_id_sorted(
            &mut h,
            HASH_TAG_IMAGES,
            &self.images,
            |im| im.id.0,
            hash_image_meta,
        );
        hash_id_sorted(
            &mut h,
            HASH_TAG_CATEGORIES,
            &self.categories,
            |c| c.id.0,
            hash_category_meta,
        );
        hash_id_sorted(
            &mut h,
            HASH_TAG_ANNOTATIONS,
            &self.annotations,
            |a| a.id.0,
            hash_coco_annotation,
        );
        match self.federated.as_ref() {
            None => hash_u8(&mut h, 0),
            Some(fed) => {
                hash_u8(&mut h, 1);
                hash_federated(&mut h, fed);
            }
        }
        *h.finalize().as_bytes()
    }
}
/// One normalized detection record (after id assignment and area derivation
/// in `CocoDetections::from_inputs`).
#[derive(Debug, Clone, PartialEq)]
pub struct CocoDetection {
    pub id: AnnId,
    pub image_id: ImageId,
    pub category_id: CategoryId,
    pub score: f64,
    pub bbox: Bbox,
    // Derived as bbox.w * bbox.h by `from_inputs`.
    pub area: f64,
    pub segmentation: Option<Segmentation>,
    pub keypoints: Option<Vec<f64>>,
    pub num_keypoints: Option<u32>,
}
impl Annotation for CocoDetection {
    fn image_id(&self) -> ImageId {
        self.image_id
    }
    fn category_id(&self) -> CategoryId {
        self.category_id
    }
    fn area(&self) -> f64 {
        self.area
    }
    // Detections are never crowd regions.
    fn is_crowd(&self) -> bool {
        false
    }
    // Detections are never ignored, regardless of parity mode.
    fn effective_ignore(&self, _: ParityMode) -> bool {
        false
    }
}
/// Raw detection as it appears in a COCO results JSON array, before
/// normalization (optional id, no derived area).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DetectionInput {
    // Auto-assigned sequentially when absent (see `CocoDetections::from_inputs`).
    #[serde(default)]
    pub id: Option<AnnId>,
    pub image_id: ImageId,
    pub category_id: CategoryId,
    pub score: f64,
    pub bbox: Bbox,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub segmentation: Option<Segmentation>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub keypoints: Option<Vec<f64>>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub num_keypoints: Option<u32>,
}
/// Indexed collection of detections, analogous to `CocoDataset` for GT.
#[derive(Debug, Clone)]
pub struct CocoDetections {
    detections: Arc<Vec<CocoDetection>>,
    // Indices into `detections`, keyed per (image, category) and per image.
    by_image_cat: HashMap<(ImageId, CategoryId), Vec<usize>>,
    by_image: HashMap<ImageId, Vec<usize>>,
}
impl CocoDetections {
    /// Parse detections from a COCO results-style JSON array.
    ///
    /// # Errors
    /// Propagates JSON errors and the validation errors of [`Self::from_inputs`].
    pub fn from_json_bytes(bytes: &[u8]) -> Result<Self, EvalError> {
        let raw: Vec<DetectionInput> = serde_json::from_slice(bytes)?;
        Self::from_inputs(raw)
    }

    /// Validate raw inputs and normalize them into indexed detection records.
    ///
    /// Missing ids are auto-assigned 1, 2, 3, … in input order (auto ids are
    /// not checked for collision with user-supplied ids — presumably callers
    /// use one style or the other; TODO confirm). `area` is derived as
    /// `bbox.w * bbox.h`.
    ///
    /// # Errors
    /// `EvalError::NonFinite` when any score is NaN or infinite.
    pub fn from_inputs(inputs: Vec<DetectionInput>) -> Result<Self, EvalError> {
        let mut detections = Vec::with_capacity(inputs.len());
        let mut next_auto = 1i64;
        for input in inputs {
            if !input.score.is_finite() {
                return Err(EvalError::NonFinite {
                    context: "detection score",
                });
            }
            let id = match input.id {
                Some(id) => id,
                None => {
                    let id = AnnId(next_auto);
                    next_auto += 1;
                    id
                }
            };
            detections.push(CocoDetection {
                id,
                image_id: input.image_id,
                category_id: input.category_id,
                score: input.score,
                bbox: input.bbox,
                area: input.bbox.w * input.bbox.h,
                segmentation: input.segmentation,
                keypoints: input.keypoints,
                num_keypoints: input.num_keypoints,
            });
        }
        // Index construction is shared with `from_records` rather than
        // duplicated here (the previous code repeated the same loop).
        Ok(Self::from_records(detections))
    }

    /// Build the per-(image, category) and per-image indices over `records`.
    /// Indices within each bucket follow input order.
    pub fn from_records(records: Vec<CocoDetection>) -> Self {
        let mut by_image_cat: HashMap<(ImageId, CategoryId), Vec<usize>> = HashMap::new();
        let mut by_image: HashMap<ImageId, Vec<usize>> = HashMap::new();
        for (idx, dt) in records.iter().enumerate() {
            by_image_cat
                .entry((dt.image_id, dt.category_id))
                .or_default()
                .push(idx);
            by_image.entry(dt.image_id).or_default().push(idx);
        }
        Self {
            detections: Arc::new(records),
            by_image_cat,
            by_image,
        }
    }

    /// All detection records, in input order.
    pub fn detections(&self) -> &[CocoDetection] {
        &self.detections
    }

    /// Detection indices for one (image, category) pair; empty when absent.
    pub fn indices_for(&self, image: ImageId, cat: CategoryId) -> &[usize] {
        self.by_image_cat
            .get(&(image, cat))
            .map_or(&[][..], Vec::as_slice)
    }

    /// Detection indices for one image; empty when absent.
    pub fn indices_for_image(&self, image: ImageId) -> &[usize] {
        self.by_image.get(&image).map_or(&[][..], Vec::as_slice)
    }

    /// LVIS-style trimming: keep at most `max_dets` highest-scoring
    /// detections per image; `max_dets < 0` means no limit. Output is
    /// ordered by ascending image id, then by descending score within each
    /// image (stable, so equal scores keep their input order).
    pub fn lvis_trim(&self, max_dets: i64) -> CocoDetections {
        if max_dets < 0 {
            return self.clone();
        }
        let cap = max_dets as usize;
        let mut by_image_groups: HashMap<ImageId, Vec<usize>> = HashMap::new();
        for (idx, dt) in self.detections.iter().enumerate() {
            by_image_groups.entry(dt.image_id).or_default().push(idx);
        }
        // Deterministic output order: process images by ascending id.
        let mut image_ids: Vec<ImageId> = by_image_groups.keys().copied().collect();
        image_ids.sort_unstable_by_key(|i| i.0);
        let upper_bound = self
            .detections
            .len()
            .min(cap.saturating_mul(image_ids.len()));
        let mut out: Vec<CocoDetection> = Vec::with_capacity(upper_bound);
        for image_id in image_ids {
            let mut group = by_image_groups.remove(&image_id).unwrap_or_default();
            // Stable sort by descending score. NaN scores (possible only via
            // `from_records`, which does not validate) compare Equal.
            group.sort_by(|&a, &b| {
                self.detections[b]
                    .score
                    .partial_cmp(&self.detections[a].score)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            for &idx in group.iter().take(cap) {
                out.push(self.detections[idx].clone());
            }
        }
        CocoDetections::from_records(out)
    }
}
/// Helper for COCO's loosely-typed boolean fields (`iscrowd`, `ignore`),
/// which may appear as JSON `true`/`false` or as integers.
#[derive(Deserialize)]
#[serde(untagged)]
enum BoolOrInt {
    Bool(bool),
    Int(i64),
}
impl BoolOrInt {
    /// Collapse to a plain bool; integers other than 0 or 1 are rejected
    /// with a deserialization error.
    fn into_bool<E: serde::de::Error>(self) -> Result<bool, E> {
        match self {
            Self::Bool(b) => Ok(b),
            Self::Int(n) if n == 0 || n == 1 => Ok(n == 1),
            Self::Int(other) => Err(E::custom(format!(
                "expected 0 or 1 for COCO bool field, got {other}"
            ))),
        }
    }
}
/// Deserialize a required COCO bool field that may be `true`/`false` or `0`/`1`.
fn deserialize_bool_int<'de, D>(de: D) -> Result<bool, D::Error>
where
    D: serde::Deserializer<'de>,
{
    BoolOrInt::deserialize(de)?.into_bool()
}
/// Same as `deserialize_bool_int`, but preserves absence as `None`.
fn deserialize_opt_bool_int<'de, D>(de: D) -> Result<Option<bool>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    Option::<BoolOrInt>::deserialize(de)?
        .map(BoolOrInt::into_bool)
        .transpose()
}
#[cfg(test)]
mod tests {
use super::*;
use proptest::prelude::*;
// Shared fixture: one 200x200 image, one category, one normal annotation
// plus one full-image crowd annotation.
const CROWD_REGION_GT: &str = r#"{
"images": [
{"id": 1, "width": 200, "height": 200, "file_name": "img1.png"}
],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [100, 100, 50, 50], "area": 2500, "iscrowd": 0},
{"id": 2, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 200, 200], "area": 40000, "iscrowd": 1}
],
"categories": [
{"id": 1, "name": "widget", "supercategory": "thing"}
]
}"#;
// Parses the fixture above; panics only if the fixture itself is broken.
fn load_crowd_region() -> CocoDataset {
CocoDataset::from_json_bytes(CROWD_REGION_GT.as_bytes()).unwrap()
}
// Fixture parses into the expected counts and metadata.
#[test]
fn loads_crowd_region_fixture() {
let ds = load_crowd_region();
assert_eq!(ds.images().len(), 1);
assert_eq!(ds.categories().len(), 1);
assert_eq!(ds.annotations().len(), 2);
assert_eq!(ds.images()[0].file_name.as_deref(), Some("img1.png"));
assert_eq!(ds.categories()[0].name, "widget");
}
// Per-image index and iterator both yield the annotations in input order.
#[test]
fn by_image_index_returns_both_anns() {
let ds = load_crowd_region();
let idxs = ds.ann_indices_for_image(ImageId(1));
assert_eq!(idxs.len(), 2);
let anns: Vec<_> = ds.ann_iter_for_image(ImageId(1)).collect();
assert_eq!(anns.len(), 2);
assert_eq!(anns[0].id, AnnId(1));
assert_eq!(anns[1].id, AnnId(2));
}
// Per-category index covers all annotations of the category.
#[test]
fn by_category_index_returns_both_anns() {
let ds = load_crowd_region();
let idxs = ds.ann_indices_for_category(CategoryId(1));
assert_eq!(idxs.len(), 2);
}
// Unknown ids yield empty slices, not panics or errors.
#[test]
fn unknown_image_returns_empty_slice() {
let ds = load_crowd_region();
assert!(ds.ann_indices_for_image(ImageId(999)).is_empty());
assert!(ds.ann_indices_for_category(CategoryId(999)).is_empty());
}
// Known-but-annotation-free ids also yield empty slices.
#[test]
fn empty_image_or_category_returns_empty_slice_not_missing() {
const ONLY_EMPTY_IMG: &str = r#"{
"images": [{"id": 7, "width": 1, "height": 1}],
"annotations": [],
"categories": [{"id": 3, "name": "thing"}]
}"#;
let ds = CocoDataset::from_json_bytes(ONLY_EMPTY_IMG.as_bytes()).unwrap();
assert!(ds.ann_indices_for_image(ImageId(7)).is_empty());
assert!(ds.ann_indices_for_category(CategoryId(3)).is_empty());
}
// Dangling image reference is rejected with a descriptive error.
#[test]
fn rejects_annotation_referencing_unknown_image() {
const BAD: &str = r#"{
"images": [{"id": 1, "width": 10, "height": 10}],
"annotations": [
{"id": 1, "image_id": 99, "category_id": 1,
"bbox": [0, 0, 1, 1], "area": 1, "iscrowd": 0}
],
"categories": [{"id": 1, "name": "thing"}]
}"#;
let err = CocoDataset::from_json_bytes(BAD.as_bytes()).unwrap_err();
match err {
EvalError::InvalidAnnotation { detail } => {
assert!(detail.contains("image_id=99"), "msg: {detail}");
}
other => panic!("expected InvalidAnnotation, got {other:?}"),
}
}
// Dangling category reference is rejected with a descriptive error.
#[test]
fn rejects_annotation_referencing_unknown_category() {
const BAD: &str = r#"{
"images": [{"id": 1, "width": 10, "height": 10}],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 42,
"bbox": [0, 0, 1, 1], "area": 1, "iscrowd": 0}
],
"categories": [{"id": 1, "name": "thing"}]
}"#;
let err = CocoDataset::from_json_bytes(BAD.as_bytes()).unwrap_err();
match err {
EvalError::InvalidAnnotation { detail } => {
assert!(detail.contains("category_id=42"), "msg: {detail}");
}
other => panic!("expected InvalidAnnotation, got {other:?}"),
}
}
// serialize -> parse -> compare: to_json_value is a faithful view.
#[test]
fn round_trips_through_json() {
let ds = load_crowd_region();
let json = serde_json::to_string(&ds.to_json_value()).unwrap();
let again = CocoDataset::from_json_bytes(json.as_bytes()).unwrap();
assert_eq!(ds.images(), again.images());
assert_eq!(ds.categories(), again.categories());
assert_eq!(ds.annotations(), again.annotations());
}
// Strict parity ignores the non-standard "ignore" field; corrected honors it.
#[test]
fn d1_strict_mode_drops_explicit_ignore_field() {
const ANN_JSON: &str = r#"{
"images": [{"id": 1, "width": 10, "height": 10}],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 1, 1], "area": 1,
"iscrowd": 0, "ignore": 1}
],
"categories": [{"id": 1, "name": "thing"}]
}"#;
let ds = CocoDataset::from_json_bytes(ANN_JSON.as_bytes()).unwrap();
let ann = &ds.annotations()[0];
assert!(!ann.effective_ignore(ParityMode::Strict));
assert!(ann.effective_ignore(ParityMode::Corrected));
}
// Without "ignore", both modes fall back to iscrowd.
#[test]
fn d1_strict_mode_uses_iscrowd_when_ignore_absent() {
const ANN_JSON: &str = r#"{
"images": [{"id": 1, "width": 10, "height": 10}],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 1, 1], "area": 1, "iscrowd": 1}
],
"categories": [{"id": 1, "name": "thing"}]
}"#;
let ds = CocoDataset::from_json_bytes(ANN_JSON.as_bytes()).unwrap();
let ann = &ds.annotations()[0];
assert!(ann.effective_ignore(ParityMode::Strict));
assert!(ann.effective_ignore(ParityMode::Corrected));
}
// (image, category) index returns exactly the matching subset, in order.
#[test]
fn ann_indices_for_image_cat_returns_correct_subset() {
const TWO_CATS: &str = r#"{
"images": [{"id": 1, "width": 10, "height": 10}],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 1, 1], "area": 1, "iscrowd": 0},
{"id": 2, "image_id": 1, "category_id": 2,
"bbox": [0, 0, 1, 1], "area": 1, "iscrowd": 0},
{"id": 3, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 1, 1], "area": 1, "iscrowd": 0}
],
"categories": [
{"id": 1, "name": "a"}, {"id": 2, "name": "b"}
]
}"#;
let ds = CocoDataset::from_json_bytes(TWO_CATS.as_bytes()).unwrap();
let cat1: Vec<AnnId> = ds
.ann_indices_for(ImageId(1), CategoryId(1))
.iter()
.map(|&i| ds.annotations()[i].id)
.collect();
assert_eq!(cat1, vec![AnnId(1), AnnId(3)]);
let cat2: Vec<AnnId> = ds
.ann_indices_for(ImageId(1), CategoryId(2))
.iter()
.map(|&i| ds.annotations()[i].id)
.collect();
assert_eq!(cat2, vec![AnnId(2)]);
assert!(ds.ann_indices_for(ImageId(1), CategoryId(99)).is_empty());
assert!(ds.ann_indices_for(ImageId(99), CategoryId(1)).is_empty());
}
// Test helper: build a minimal DetectionInput (no id, no mask/keypoints).
fn dt_input(image: i64, cat: i64, score: f64, bbox: (f64, f64, f64, f64)) -> DetectionInput {
DetectionInput {
id: None,
image_id: ImageId(image),
category_id: CategoryId(cat),
score,
bbox: Bbox {
x: bbox.0,
y: bbox.1,
w: bbox.2,
h: bbox.3,
},
segmentation: None,
keypoints: None,
num_keypoints: None,
}
}
// Missing ids are auto-assigned 1, 2, ... in input order.
#[test]
fn j1_auto_assigns_ids_when_absent() {
let dts = CocoDetections::from_inputs(vec![
dt_input(1, 1, 0.9, (0.0, 0.0, 1.0, 1.0)),
dt_input(1, 1, 0.8, (0.0, 0.0, 1.0, 1.0)),
])
.unwrap();
let ids: Vec<AnnId> = dts.detections().iter().map(|d| d.id).collect();
assert_eq!(ids, vec![AnnId(1), AnnId(2)]);
}
// Explicit ids are kept untouched.
#[test]
fn j1_preserves_user_supplied_ids() {
let mut a = dt_input(1, 1, 0.9, (0.0, 0.0, 1.0, 1.0));
a.id = Some(AnnId(42));
let mut b = dt_input(1, 1, 0.8, (0.0, 0.0, 1.0, 1.0));
b.id = Some(AnnId(7));
let dts = CocoDetections::from_inputs(vec![a, b]).unwrap();
let ids: Vec<AnnId> = dts.detections().iter().map(|d| d.id).collect();
assert_eq!(ids, vec![AnnId(42), AnnId(7)]);
}
// area = w * h, derived from the bbox.
#[test]
fn j3_derives_area_from_bbox() {
let dts =
CocoDetections::from_inputs(vec![dt_input(1, 1, 0.5, (10.0, 10.0, 4.0, 5.0))]).unwrap();
assert_eq!(dts.detections()[0].area, 20.0);
}
// NaN scores are rejected at construction.
#[test]
fn rejects_non_finite_score() {
let err = CocoDetections::from_inputs(vec![dt_input(1, 1, f64::NAN, (0.0, 0.0, 1.0, 1.0))])
.unwrap_err();
assert!(matches!(
err,
EvalError::NonFinite {
context: "detection score"
}
));
}
// Both detection indices resolve to the right records.
#[test]
fn detections_indices_per_image_cat() {
let dts = CocoDetections::from_inputs(vec![
dt_input(1, 1, 0.9, (0.0, 0.0, 1.0, 1.0)),
dt_input(1, 2, 0.8, (0.0, 0.0, 1.0, 1.0)),
dt_input(2, 1, 0.7, (0.0, 0.0, 1.0, 1.0)),
])
.unwrap();
assert_eq!(dts.indices_for(ImageId(1), CategoryId(1)), &[0]);
assert_eq!(dts.indices_for(ImageId(1), CategoryId(2)), &[1]);
assert_eq!(dts.indices_for(ImageId(2), CategoryId(1)), &[2]);
assert!(dts.indices_for(ImageId(99), CategoryId(1)).is_empty());
let img1: Vec<usize> = dts.indices_for_image(ImageId(1)).to_vec();
assert_eq!(img1, vec![0, 1]);
}
// JSON array parsing mixes auto-assigned and explicit ids correctly.
#[test]
fn loads_detections_from_json_array() {
const JSON: &str = r#"[
{"image_id": 1, "category_id": 1, "score": 0.9,
"bbox": [0, 0, 2, 3]},
{"id": 7, "image_id": 1, "category_id": 1, "score": 0.5,
"bbox": [1, 1, 1, 1]}
]"#;
let dts = CocoDetections::from_json_bytes(JSON.as_bytes()).unwrap();
let ds = dts.detections();
assert_eq!(ds[0].id, AnnId(1)); assert_eq!(ds[0].area, 6.0); assert_eq!(ds[1].id, AnnId(7)); assert!(!ds[0].is_crowd()); assert!(ds[0].segmentation.is_none());
}
// Polygon segmentation parses and rasterizes to the expected area.
#[test]
fn gt_loads_polygon_segmentation() {
const JSON: &str = r#"{
"images": [{"id": 1, "width": 10, "height": 10}],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 4, 4], "area": 16, "iscrowd": 0,
"segmentation": [[0, 0, 4, 0, 4, 4, 0, 4]]}
],
"categories": [{"id": 1, "name": "thing"}]
}"#;
let ds = CocoDataset::from_json_bytes(JSON.as_bytes()).unwrap();
let seg = ds.annotations()[0].segmentation.as_ref().unwrap();
let rle = seg.to_rle(10, 10).unwrap();
assert_eq!(rle.area(), 16);
}
// Compressed RLE segmentation (counts encoded by vernier_mask) parses.
#[test]
fn gt_loads_compressed_rle_segmentation() {
let counts_str = String::from_utf8(vernier_mask::encode_counts(&[0, 16])).unwrap();
let json = format!(
r#"{{
"images": [{{"id": 1, "width": 4, "height": 4}}],
"annotations": [
{{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 4, 4], "area": 16, "iscrowd": 1,
"segmentation": {{"size": [4, 4], "counts": "{counts_str}"}}}}
],
"categories": [{{"id": 1, "name": "thing"}}]
}}"#
);
let ds = CocoDataset::from_json_bytes(json.as_bytes()).unwrap();
let seg = ds.annotations()[0].segmentation.as_ref().unwrap();
let rle = seg.to_rle(4, 4).unwrap();
assert_eq!((rle.h, rle.w), (4, 4));
assert_eq!(rle.area(), 16);
}
// Segmentations survive the to_json_value serialization round trip.
#[test]
fn gt_segmentation_round_trips_through_to_json_value() {
const JSON: &str = r#"{
"images": [{"id": 1, "width": 10, "height": 10}],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 4, 4], "area": 16, "iscrowd": 0,
"segmentation": [[0, 0, 4, 0, 4, 4, 0, 4]]}
],
"categories": [{"id": 1, "name": "thing"}]
}"#;
let ds = CocoDataset::from_json_bytes(JSON.as_bytes()).unwrap();
let serialized = serde_json::to_string(&ds.to_json_value()).unwrap();
let again = CocoDataset::from_json_bytes(serialized.as_bytes()).unwrap();
assert_eq!(ds.annotations(), again.annotations());
}
// Absent segmentation field deserializes to None.
#[test]
fn gt_without_segmentation_field_loads_as_none() {
let ds = load_crowd_region();
assert!(ds.annotations().iter().all(|a| a.segmentation.is_none()));
}
// Detections accept RLE segmentations too.
#[test]
fn dt_loads_compressed_rle_segmentation() {
const JSON: &str = r#"[
{"image_id": 1, "category_id": 1, "score": 0.9,
"bbox": [0, 0, 4, 4],
"segmentation": {"size": [4, 4], "counts": "04L4"}}
]"#;
let dts = CocoDetections::from_json_bytes(JSON.as_bytes()).unwrap();
assert!(dts.detections()[0].segmentation.is_some());
}
// Absent segmentation on a detection also deserializes to None.
#[test]
fn dt_without_segmentation_loads_as_none() {
const JSON: &str = r#"[
{"image_id": 1, "category_id": 1, "score": 0.9, "bbox": [0, 0, 1, 1]}
]"#;
let dts = CocoDetections::from_json_bytes(JSON.as_bytes()).unwrap();
assert!(dts.detections()[0].segmentation.is_none());
}
/// Proptest strategy: images with a small id space and plausible dimensions.
fn arb_image() -> impl Strategy<Value = ImageMeta> {
    let parts = (1i64..1000, 1u32..2048, 1u32..2048);
    parts.prop_map(|(raw_id, width, height)| ImageMeta {
        id: ImageId(raw_id),
        width,
        height,
        file_name: None,
    })
}
/// Proptest strategy: categories with a small id space and short lowercase names.
fn arb_category() -> impl Strategy<Value = CategoryMeta> {
    let parts = (1i64..100, "[a-z]{1,8}");
    parts.prop_map(|(raw_id, name)| CategoryMeta {
        id: CategoryId(raw_id),
        name,
        supercategory: None,
    })
}
/// Builds the smallest useful annotation: a 5x5 box at the origin with
/// area 25, no crowd flag, and no segmentation/keypoint payload.
fn make_min_annotation(
    id: AnnId,
    image_id: ImageId,
    category_id: CategoryId,
) -> CocoAnnotation {
    let bbox = Bbox {
        x: 0.0,
        y: 0.0,
        w: 5.0,
        h: 5.0,
    };
    CocoAnnotation {
        id,
        image_id,
        category_id,
        area: 25.0,
        is_crowd: false,
        ignore_flag: None,
        bbox,
        segmentation: None,
        keypoints: None,
        num_keypoints: None,
    }
}
// Property: the per-image and per-category annotation indices built by
// `CocoDataset::from_parts` must together cover every annotation exactly once
// per axis, and each indexed annotation must carry the id it was filed under.
proptest! {
#![proptest_config(ProptestConfig::with_cases(64))]
#[test]
fn index_invariants_hold(
images in proptest::collection::vec(arb_image(), 1..6),
categories in proptest::collection::vec(arb_category(), 1..6),
n_anns in 0usize..40,
ann_seed in any::<u64>(),
) {
// Deduplicate by id so from_parts receives unique images/categories.
let mut images = images;
images.sort_by_key(|i| i.id);
images.dedup_by_key(|i| i.id);
let mut categories = categories;
categories.sort_by_key(|c| c.id);
categories.dedup_by_key(|c| c.id);
// Hand-rolled LCG (Knuth MMIX multiplier/increment) so the assignment of
// annotations to images/categories is deterministic for a given seed.
let mut state = ann_seed.wrapping_add(1);
let mut next = || {
state = state.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
state
};
let mut annotations = Vec::with_capacity(n_anns);
for ann_idx in 0..n_anns {
// Pick a pseudo-random (image, category) pair for each annotation.
let img = &images[(next() as usize) % images.len()];
let cat = &categories[(next() as usize) % categories.len()];
annotations.push(CocoAnnotation {
id: AnnId(ann_idx as i64 + 1),
image_id: img.id,
category_id: cat.id,
area: 1.0,
is_crowd: false,
ignore_flag: None,
bbox: Bbox { x: 0.0, y: 0.0, w: 1.0, h: 1.0 },
segmentation: None,
keypoints: None,
num_keypoints: None,
});
}
let ds = CocoDataset::from_parts(
images.clone(), annotations.clone(), categories.clone()
).unwrap();
// Union of the per-image index lists must be exactly 0..n_anns, i.e. the
// image index partitions the annotation list.
let mut seen_img: Vec<usize> = images.iter()
.flat_map(|i| ds.ann_indices_for_image(i.id).iter().copied())
.collect();
seen_img.sort_unstable();
let expected: Vec<usize> = (0..annotations.len()).collect();
prop_assert_eq!(&seen_img, &expected);
// Same coverage property for the per-category index.
let mut seen_cat: Vec<usize> = categories.iter()
.flat_map(|c| ds.ann_indices_for_category(c.id).iter().copied())
.collect();
seen_cat.sort_unstable();
prop_assert_eq!(&seen_cat, &expected);
// Every index entry must point at an annotation with the matching image id.
for img in &images {
for &idx in ds.ann_indices_for_image(img.id) {
prop_assert_eq!(ds.annotations()[idx].image_id, img.id);
}
}
// ...and likewise for the category index.
for cat in &categories {
for &idx in ds.ann_indices_for_category(cat.id) {
prop_assert_eq!(ds.annotations()[idx].category_id, cat.id);
}
}
}
}
// Minimal valid LVIS-style federated dataset shared by the lvis_* tests:
// two images with complementary neg / not-exhaustive category lists, one
// annotation per image, and per-category "frequency" tags ("f" / "r").
const LVIS_MIN_VALID: &str = r#"{
"images": [
{"id": 1, "width": 100, "height": 100,
"neg_category_ids": [2], "not_exhaustive_category_ids": []},
{"id": 2, "width": 100, "height": 100,
"neg_category_ids": [], "not_exhaustive_category_ids": [2]}
],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 10, 10], "area": 100, "iscrowd": 0},
{"id": 2, "image_id": 2, "category_id": 2,
"bbox": [0, 0, 20, 20], "area": 400, "iscrowd": 0}
],
"categories": [
{"id": 1, "name": "a", "frequency": "f"},
{"id": 2, "name": "b", "frequency": "r"}
]
}"#;
#[test]
fn lvis_loads_minimal_valid_dataset() {
    let ds = CocoDataset::from_lvis_json_bytes(LVIS_MIN_VALID.as_bytes()).unwrap();
    // Basic shape of the parsed dataset.
    assert!(ds.is_federated());
    assert_eq!(ds.images().len(), 2);
    assert_eq!(ds.categories().len(), 2);
    assert_eq!(ds.annotations().len(), 2);
    let (img1, img2) = (ImageId(1), ImageId(2));
    let (cat1, cat2) = (CategoryId(1), CategoryId(2));
    // pos sets derive from the GT annotations (one category per image here).
    let pos = ds.pos_category_ids().unwrap();
    assert_eq!(pos[&img1], HashSet::from([cat1]));
    assert_eq!(pos[&img2], HashSet::from([cat2]));
    // neg / not-exhaustive sets mirror the per-image metadata in the fixture.
    let neg = ds.neg_category_ids().unwrap();
    assert_eq!(neg[&img1], HashSet::from([cat2]));
    assert_eq!(neg[&img2], HashSet::new());
    let nel = ds.not_exhaustive_category_ids().unwrap();
    assert_eq!(nel[&img1], HashSet::new());
    assert_eq!(nel[&img2], HashSet::from([cat2]));
    // Frequency tags "f"/"r" map onto the Frequency enum.
    let freq = ds.category_frequency().unwrap();
    assert_eq!(freq[&cat1], Frequency::Frequent);
    assert_eq!(freq[&cat2], Frequency::Rare);
}
#[test]
fn aa1_pos_derived_from_gts_does_not_include_zero_ann_categories() {
    // pos sets come only from actual GT annotations, never from metadata alone.
    let ds = CocoDataset::from_lvis_json_bytes(LVIS_MIN_VALID.as_bytes()).unwrap();
    let pos = ds.pos_category_ids().unwrap();
    for (img, cat) in [(ImageId(1), CategoryId(2)), (ImageId(2), CategoryId(1))] {
        assert!(!pos[&img].contains(&cat));
    }
}
#[test]
fn from_json_bytes_leaves_federated_metadata_none() {
    // The plain COCO loader must ignore the LVIS-only fields entirely,
    // leaving all federated accessors None even on LVIS-shaped input.
    let ds = CocoDataset::from_json_bytes(LVIS_MIN_VALID.as_bytes()).unwrap();
    assert!(!ds.is_federated());
    assert!(ds.pos_category_ids().is_none());
    assert!(ds.neg_category_ids().is_none());
    assert!(ds.not_exhaustive_category_ids().is_none());
    assert!(ds.category_frequency().is_none());
}
#[test]
fn aa7_pos_intersect_neg_rejected() {
    // A category with a GT annotation on an image cannot also appear in that
    // image's neg_category_ids; the loader must reject the conflict.
    const BAD: &str = r#"{
"images": [
{"id": 1, "width": 10, "height": 10,
"neg_category_ids": [1], "not_exhaustive_category_ids": []}
],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 5, 5], "area": 25, "iscrowd": 0}
],
"categories": [{"id": 1, "name": "a", "frequency": "f"}]
}"#;
    match CocoDataset::from_lvis_json_bytes(BAD.as_bytes()).unwrap_err() {
        EvalError::LvisFederatedConflict { image_id, category_id, detail } => {
            assert_eq!((image_id, category_id), (1, 1));
            assert!(detail.contains("GT"));
        }
        other => panic!("expected LvisFederatedConflict, got {other:?}"),
    }
}
#[test]
fn aa7_not_exhaustive_outside_pos_rejected() {
    // not_exhaustive_category_ids may only name categories that actually have
    // GT annotations on the image; category 2 has none here.
    const BAD: &str = r#"{
"images": [
{"id": 1, "width": 10, "height": 10,
"neg_category_ids": [], "not_exhaustive_category_ids": [2]}
],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 5, 5], "area": 25, "iscrowd": 0}
],
"categories": [
{"id": 1, "name": "a", "frequency": "f"},
{"id": 2, "name": "b", "frequency": "r"}
]
}"#;
    match CocoDataset::from_lvis_json_bytes(BAD.as_bytes()).unwrap_err() {
        EvalError::LvisFederatedConflict { image_id, category_id, detail } => {
            assert_eq!((image_id, category_id), (1, 2));
            assert!(detail.contains("not_exhaustive"));
        }
        other => panic!("expected LvisFederatedConflict, got {other:?}"),
    }
}
#[test]
fn ab6_missing_frequency_collects_all_offenders() {
    // Both categories lack "frequency"; the error must report every offender,
    // sorted by id, rather than stopping at the first one.
    const BAD: &str = r#"{
"images": [
{"id": 1, "width": 10, "height": 10,
"neg_category_ids": [], "not_exhaustive_category_ids": []}
],
"annotations": [],
"categories": [
{"id": 7, "name": "g"},
{"id": 3, "name": "c"}
]
}"#;
    match CocoDataset::from_lvis_json_bytes(BAD.as_bytes()).unwrap_err() {
        EvalError::MissingFrequency { category_ids } => {
            assert_eq!(category_ids, vec![3, 7]);
        }
        other => panic!("expected MissingFrequency, got {other:?}"),
    }
}
#[test]
fn lvis_loader_treats_absent_neg_field_as_empty() {
    // Images without neg / not-exhaustive lists load with empty sets,
    // not an error.
    const TOLERANT: &str = r#"{
"images": [{"id": 1, "width": 10, "height": 10}],
"annotations": [],
"categories": [{"id": 1, "name": "a", "frequency": "c"}]
}"#;
    let ds = CocoDataset::from_lvis_json_bytes(TOLERANT.as_bytes()).unwrap();
    assert!(ds.neg_category_ids().unwrap()[&ImageId(1)].is_empty());
    assert!(ds.not_exhaustive_category_ids().unwrap()[&ImageId(1)].is_empty());
}
#[test]
fn frequency_round_trips_serde() {
    // Each variant must serialize to its single-letter LVIS tag and parse
    // back to the same variant.
    let table = [
        (Frequency::Rare, "\"r\""),
        (Frequency::Common, "\"c\""),
        (Frequency::Frequent, "\"f\""),
    ];
    for (f, tag) in table {
        let s = serde_json::to_string(&f).unwrap();
        assert_eq!(s, tag);
        let back: Frequency = serde_json::from_str(&s).unwrap();
        assert_eq!(back, f);
    }
}
#[test]
fn ac2_q1_trims_500_single_category_to_300() {
    // 500 detections with strictly descending scores 1.0, 0.999, ..., 0.501.
    let inputs: Vec<_> = (0..500)
        .map(|i| dt_input(1, 1, 1.0 - (i as f64) / 1000.0, (0.0, 0.0, 1.0, 1.0)))
        .collect();
    let dts = CocoDetections::from_inputs(inputs).unwrap();
    let trimmed = dts.lvis_trim(300);
    assert_eq!(trimmed.detections().len(), 300);
    let scores: Vec<f64> = trimmed.detections().iter().map(|d| d.score).collect();
    for w in scores.windows(2) {
        assert!(
            w[0] >= w[1],
            "lvis_trim must preserve score-descending order"
        );
    }
    // Top score survives; the 300th kept score is 1.0 - 299/1000.
    assert!((scores[0] - 1.0).abs() < 1e-12);
    assert!((scores[299] - 0.701).abs() < 1e-12);
}
#[test]
fn ac3_q2_cross_class_crowding_keeps_300_total_across_classes() {
    // 250 low-score cat-1 detections plus 350 high-score cat-2 detections,
    // all on the same image: the 300 cap is shared across classes.
    let mut inputs = Vec::with_capacity(600);
    inputs.extend(
        (0..250).map(|i| dt_input(1, 1, 0.5 - (i as f64) * 0.002, (0.0, 0.0, 1.0, 1.0))),
    );
    inputs.extend(
        (0..350).map(|i| dt_input(1, 2, 1.0 - (i as f64) * 0.002, (0.0, 0.0, 1.0, 1.0))),
    );
    let trimmed = CocoDetections::from_inputs(inputs).unwrap().lvis_trim(300);
    assert_eq!(trimmed.detections().len(), 300);
    let count_of = |cat: i64| {
        trimmed
            .detections()
            .iter()
            .filter(|d| d.category_id == CategoryId(cat))
            .count()
    };
    let (n_cat1, n_cat2) = (count_of(1), count_of(2));
    assert_eq!(n_cat1 + n_cat2, 300);
    assert!(n_cat1 > 0, "cat 1 must keep at least its top-score entries");
    assert!(n_cat2 > 0, "cat 2 must keep its high-score entries");
}
#[test]
fn ac5_negative_max_dets_disables_trim() {
    // A negative cap means "no per-image limit": all 50 survive, in order.
    let inputs: Vec<_> = (0..50)
        .map(|i| dt_input(1, 1, i as f64 / 100.0, (0.0, 0.0, 1.0, 1.0)))
        .collect();
    let trimmed = CocoDetections::from_inputs(inputs).unwrap().lvis_trim(-1);
    assert_eq!(trimmed.detections().len(), 50);
    for (i, dt) in trimmed.detections().iter().enumerate() {
        assert!((dt.score - (i as f64 / 100.0)).abs() < 1e-12);
    }
}
#[test]
fn ac5_max_dets_at_capacity_is_no_op() {
    // A cap far above the detection count leaves everything in place.
    let inputs: Vec<_> = (0..10)
        .map(|i| dt_input(1, 1, i as f64 / 10.0, (0.0, 0.0, 1.0, 1.0)))
        .collect();
    let dts = CocoDetections::from_inputs(inputs).unwrap();
    assert_eq!(dts.lvis_trim(100).detections().len(), 10);
}
#[test]
fn ac4_stable_sort_preserves_input_order_for_score_ties() {
    // Two detections with identical scores, tagged with distinct ids so the
    // output order is observable.
    let (mut a, mut b) = (
        dt_input(1, 1, 0.5, (0.0, 0.0, 1.0, 1.0)),
        dt_input(1, 1, 0.5, (1.0, 0.0, 1.0, 1.0)),
    );
    a.id = Some(AnnId(100));
    b.id = Some(AnnId(200));
    let trimmed = CocoDetections::from_inputs(vec![a, b]).unwrap().lvis_trim(2);
    let ids: Vec<AnnId> = trimmed.detections().iter().map(|d| d.id).collect();
    assert_eq!(
        ids,
        vec![AnnId(100), AnnId(200)],
        "AC4: stable sort must preserve input order on score ties"
    );
}
#[test]
fn lvis_trim_groups_by_image_id() {
    // Five detections on each of three images; a cap of 2 applies per image,
    // keeping 6 total.
    let mut inputs = Vec::with_capacity(15);
    for img in 1..=3i64 {
        inputs.extend((0..5).map(|i| {
            let score = 1.0 - (img as f64) * 0.01 - (i as f64) * 0.001;
            dt_input(img, img, score, (0.0, 0.0, 1.0, 1.0))
        }));
    }
    let trimmed = CocoDetections::from_inputs(inputs).unwrap().lvis_trim(2);
    assert_eq!(trimmed.detections().len(), 6);
    for img in 1..=3i64 {
        let n = trimmed
            .detections()
            .iter()
            .filter(|d| d.image_id == ImageId(img))
            .count();
        assert_eq!(n, 2, "image {img} must trim to 2");
    }
}
#[test]
fn lvis_trim_zero_max_dets_keeps_nothing() {
    // A cap of zero drops every detection.
    let inputs = vec![
        dt_input(1, 1, 0.9, (0.0, 0.0, 1.0, 1.0)),
        dt_input(1, 1, 0.5, (0.0, 0.0, 1.0, 1.0)),
    ];
    let trimmed = CocoDetections::from_inputs(inputs).unwrap().lvis_trim(0);
    assert!(trimmed.detections().is_empty());
}
#[test]
fn lvis_loader_inherits_invalid_annotation_validation() {
    // image_id 99 does not exist, so the validation shared with the plain
    // COCO loader must reject the annotation.
    const BAD: &str = r#"{
"images": [
{"id": 1, "width": 10, "height": 10,
"neg_category_ids": [], "not_exhaustive_category_ids": []}
],
"annotations": [
{"id": 1, "image_id": 99, "category_id": 1,
"bbox": [0, 0, 1, 1], "area": 1, "iscrowd": 0}
],
"categories": [{"id": 1, "name": "a", "frequency": "f"}]
}"#;
    let err = CocoDataset::from_lvis_json_bytes(BAD.as_bytes()).unwrap_err();
    assert!(matches!(err, EvalError::InvalidAnnotation { .. }));
}
#[test]
fn dataset_hash_is_stable_for_equal_inputs() {
    // Two independent loads of the same fixture must hash identically.
    assert_eq!(
        load_crowd_region().dataset_hash(),
        load_crowd_region().dataset_hash()
    );
}
#[test]
fn dataset_hash_caches_via_arc_clone() {
    // A cloned handle shares the cached hash; both must agree.
    let original = load_crowd_region();
    let clone = original.clone();
    assert_eq!(original.dataset_hash(), clone.dataset_hash());
}
#[test]
fn dataset_hash_invariant_to_image_order() {
    // Identical content; only the order of the "images" array differs.
    let order_a = r#"{
"images": [
{"id": 1, "width": 10, "height": 10},
{"id": 2, "width": 20, "height": 20}
],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 5, 5], "area": 25, "iscrowd": 0}
],
"categories": [{"id": 1, "name": "x"}]
}"#;
    let order_b = r#"{
"images": [
{"id": 2, "width": 20, "height": 20},
{"id": 1, "width": 10, "height": 10}
],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 5, 5], "area": 25, "iscrowd": 0}
],
"categories": [{"id": 1, "name": "x"}]
}"#;
    let hash_of =
        |json: &str| CocoDataset::from_json_bytes(json.as_bytes()).unwrap().dataset_hash();
    assert_eq!(hash_of(order_a), hash_of(order_b));
}
#[test]
fn dataset_hash_invariant_to_annotation_order() {
    // Identical content; only the order of the "annotations" array differs.
    let order_a = r#"{
"images": [{"id": 1, "width": 200, "height": 200}],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 5, 5], "area": 25, "iscrowd": 0},
{"id": 2, "image_id": 1, "category_id": 1,
"bbox": [10, 10, 5, 5], "area": 25, "iscrowd": 0}
],
"categories": [{"id": 1, "name": "x"}]
}"#;
    let order_b = r#"{
"images": [{"id": 1, "width": 200, "height": 200}],
"annotations": [
{"id": 2, "image_id": 1, "category_id": 1,
"bbox": [10, 10, 5, 5], "area": 25, "iscrowd": 0},
{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 5, 5], "area": 25, "iscrowd": 0}
],
"categories": [{"id": 1, "name": "x"}]
}"#;
    let hash_of =
        |json: &str| CocoDataset::from_json_bytes(json.as_bytes()).unwrap().dataset_hash();
    assert_eq!(hash_of(order_a), hash_of(order_b));
}
#[test]
fn dataset_hash_changes_when_bbox_changes_by_one_pixel() {
    // The only difference is the bbox x coordinate: 10 vs 11.
    let base = r#"{
"images": [{"id": 1, "width": 200, "height": 200}],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [10, 10, 5, 5], "area": 25, "iscrowd": 0}
],
"categories": [{"id": 1, "name": "x"}]
}"#;
    let shifted = r#"{
"images": [{"id": 1, "width": 200, "height": 200}],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [11, 10, 5, 5], "area": 25, "iscrowd": 0}
],
"categories": [{"id": 1, "name": "x"}]
}"#;
    let hash_of =
        |json: &str| CocoDataset::from_json_bytes(json.as_bytes()).unwrap().dataset_hash();
    assert_ne!(hash_of(base), hash_of(shifted));
}
// Property: dataset_hash must not depend on the order in which images are
// supplied to from_parts — reversing the image list yields the same hash.
proptest! {
#[test]
fn dataset_hash_invariant_under_id_shuffle(
mut images in proptest::collection::vec(arb_image(), 1..16),
categories in proptest::collection::vec(arb_category(), 1..4),
) {
// Deduplicate ids so from_parts accepts the inputs.
images.sort_by_key(|im| im.id.0);
images.dedup_by_key(|im| im.id.0);
let mut unique_categories = categories;
unique_categories.sort_by_key(|c| c.id.0);
unique_categories.dedup_by_key(|c| c.id.0);
prop_assume!(!images.is_empty());
prop_assume!(!unique_categories.is_empty());
// One minimal annotation per image, all assigned to the first category.
let cat_id = unique_categories[0].id;
let annotations: Vec<CocoAnnotation> = images
.iter()
.enumerate()
.map(|(i, im)| make_min_annotation(AnnId((i as i64) + 1), im.id, cat_id))
.collect();
// Reversal serves as a simple deterministic shuffle of the image order.
let mut shuffled = images.clone();
shuffled.reverse();
let a = CocoDataset::from_parts(
images,
annotations.clone(),
unique_categories.clone(),
).unwrap();
let b = CocoDataset::from_parts(
shuffled,
annotations,
unique_categories,
).unwrap();
prop_assert_eq!(a.dataset_hash(), b.dataset_hash());
}
}
#[test]
fn params_hash_is_stable_for_equal_inputs() {
    use crate::evaluate::OwnedEvaluateParams;
    // A clone is field-for-field equal, so the hashes must match.
    let params = OwnedEvaluateParams {
        iou_thresholds: vec![0.5, 0.55, 0.6],
        area_ranges: vec![],
        max_dets_per_image: 100,
        use_cats: true,
        retain_iou: false,
    };
    let copy = params.clone();
    assert_eq!(params.params_hash().unwrap(), copy.params_hash().unwrap());
}
#[test]
fn params_hash_changes_when_thresholds_change() {
    use crate::evaluate::OwnedEvaluateParams;
    let a = OwnedEvaluateParams {
        iou_thresholds: vec![0.5, 0.55, 0.6],
        area_ranges: vec![],
        max_dets_per_image: 100,
        use_cats: true,
        retain_iou: false,
    };
    // Appending one IoU threshold must perturb the hash.
    let mut b = a.clone();
    b.iou_thresholds.push(0.65);
    assert_ne!(a.params_hash().unwrap(), b.params_hash().unwrap());
}
#[test]
fn params_hash_changes_when_use_cats_toggles() {
    use crate::evaluate::OwnedEvaluateParams;
    let a = OwnedEvaluateParams {
        iou_thresholds: vec![0.5],
        area_ranges: vec![],
        max_dets_per_image: 100,
        use_cats: true,
        retain_iou: false,
    };
    // Flipping use_cats alone must perturb the hash.
    let b = OwnedEvaluateParams {
        use_cats: false,
        ..a.clone()
    };
    assert_ne!(a.params_hash().unwrap(), b.params_hash().unwrap());
}
}