use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::error::EvalError;
use crate::parity::ParityMode;
use crate::segmentation::Segmentation;
/// Identifier of an image.
///
/// Newtype over `i64`; `#[serde(transparent)]` makes it (de)serialize as the
/// bare integer rather than as a wrapper object.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ImageId(pub i64);
/// Identifier of a category.
///
/// Newtype over `i64`; `#[serde(transparent)]` makes it (de)serialize as the
/// bare integer rather than as a wrapper object.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(transparent)]
pub struct CategoryId(pub i64);
/// Identifier of an annotation (ground truth or detection).
///
/// Newtype over `i64`; `#[serde(transparent)]` makes it (de)serialize as the
/// bare integer rather than as a wrapper object.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(transparent)]
pub struct AnnId(pub i64);
/// Metadata record for one image entry of a dataset.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ImageMeta {
/// Identifier that annotations refer to through their `image_id`.
pub id: ImageId,
/// Image width (presumably in pixels — confirm against the data source).
pub width: u32,
/// Image height (presumably in pixels — confirm against the data source).
pub height: u32,
/// Optional file name; `None` when the field is missing on input, and
/// omitted from the serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub file_name: Option<String>,
}
/// Metadata record for one category entry of a dataset.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct CategoryMeta {
/// Identifier that annotations refer to through their `category_id`.
pub id: CategoryId,
/// Category name.
pub name: String,
/// Optional parent grouping; `None` when the field is missing on input, and
/// omitted from the serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub supercategory: Option<String>,
}
/// Bounding box described by an origin (`x`, `y`) and an extent (`w`, `h`).
///
/// The serde `from`/`into` attributes route (de)serialization through
/// `[f64; 4]`, so on the wire a box is the array `[x, y, w, h]` instead of an
/// object with named fields.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[serde(from = "[f64; 4]", into = "[f64; 4]")]
pub struct Bbox {
/// Horizontal origin (presumably the left edge, as in COCO — confirm).
pub x: f64,
/// Vertical origin (presumably the top edge, as in COCO — confirm).
pub y: f64,
/// Box width.
pub w: f64,
/// Box height.
pub h: f64,
}
impl From<[f64; 4]> for Bbox {
fn from([x, y, w, h]: [f64; 4]) -> Self {
Self { x, y, w, h }
}
}
impl From<Bbox> for [f64; 4] {
fn from(b: Bbox) -> Self {
[b.x, b.y, b.w, b.h]
}
}
/// One ground-truth annotation entry of a dataset.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct CocoAnnotation {
/// Identifier of this annotation.
pub id: AnnId,
/// Image this annotation belongs to.
pub image_id: ImageId,
/// Category this annotation is labelled with.
pub category_id: CategoryId,
/// Area as stored in the input; it is not recomputed from `bbox` here.
pub area: f64,
/// Crowd flag, stored under the JSON key `iscrowd`. Defaults to `false`
/// when the key is missing; parsing goes through `deserialize_bool_int`.
#[serde(rename = "iscrowd", default, deserialize_with = "deserialize_bool_int")]
pub is_crowd: bool,
/// Explicit ignore flag, stored under the JSON key `ignore`. `None` when
/// the key is missing; parsing goes through `deserialize_opt_bool_int`.
#[serde(
rename = "ignore",
default,
deserialize_with = "deserialize_opt_bool_int"
)]
pub ignore_flag: Option<bool>,
/// Bounding box of the annotated object.
pub bbox: Bbox,
/// Optional segmentation; omitted from the serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub segmentation: Option<Segmentation>,
/// Optional keypoint values; omitted from the serialized output when `None`.
/// NOTE(review): presumably flattened (x, y, visibility) triplets as in
/// COCO — the layout is not established in this file, confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub keypoints: Option<Vec<f64>>,
/// Optional keypoint count; omitted from the serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub num_keypoints: Option<u32>,
}
impl CocoAnnotation {
    /// Decides whether this annotation counts as ignored under `mode`.
    ///
    /// * `ParityMode::Strict` looks only at `is_crowd`; an explicit
    ///   `ignore_flag` has no effect.
    /// * `ParityMode::Corrected` honours `ignore_flag` when it is present and
    ///   falls back to `is_crowd` otherwise.
    pub fn effective_ignore(&self, mode: ParityMode) -> bool {
        match (mode, self.ignore_flag) {
            (ParityMode::Corrected, Some(explicit)) => explicit,
            (ParityMode::Strict, _) | (ParityMode::Corrected, None) => self.is_crowd,
        }
    }
}
/// Read-only view shared by ground-truth annotations and detections.
pub trait Annotation {
/// Image the annotation belongs to.
fn image_id(&self) -> ImageId;
/// Category the annotation is labelled with.
fn category_id(&self) -> CategoryId;
/// Area associated with the annotation.
fn area(&self) -> f64;
/// Whether the annotation is flagged as a crowd region.
fn is_crowd(&self) -> bool;
/// Whether the annotation should be treated as ignored under `mode`.
fn effective_ignore(&self, mode: ParityMode) -> bool;
}
impl Annotation for CocoAnnotation {
fn image_id(&self) -> ImageId {
self.image_id
}
fn category_id(&self) -> CategoryId {
self.category_id
}
fn area(&self) -> f64 {
self.area
}
fn is_crowd(&self) -> bool {
self.is_crowd
}
fn effective_ignore(&self, mode: ParityMode) -> bool {
Self::effective_ignore(self, mode)
}
}
/// Dataset abstraction: metadata slices, one flat annotation slice, and
/// per-image / per-category lists of positions into that annotation slice.
pub trait EvalDataset: Send + Sync {
    /// Concrete annotation type stored by the dataset.
    type Annotation: Annotation;

    /// All image metadata records.
    fn images(&self) -> &[ImageMeta];

    /// All category metadata records.
    fn categories(&self) -> &[CategoryMeta];

    /// All annotations, addressed by position by the index methods below.
    fn annotations(&self) -> &[Self::Annotation];

    /// Positions (into [`EvalDataset::annotations`]) of the annotations on `image_id`.
    fn ann_indices_for_image(&self, image_id: ImageId) -> &[usize];

    /// Positions (into [`EvalDataset::annotations`]) of the annotations of `cat_id`.
    fn ann_indices_for_category(&self, cat_id: CategoryId) -> &[usize];

    /// Iterates over the annotations of `image_id`, in index order.
    fn ann_iter_for_image(&self, image_id: ImageId) -> AnnotationIter<'_, Self::Annotation> {
        let anns = self.annotations();
        let indices = self.ann_indices_for_image(image_id).iter();
        AnnotationIter { anns, indices }
    }

    /// Iterates over the annotations of `cat_id`, in index order.
    fn ann_iter_for_category(&self, cat_id: CategoryId) -> AnnotationIter<'_, Self::Annotation> {
        let anns = self.annotations();
        let indices = self.ann_indices_for_category(cat_id).iter();
        AnnotationIter { anns, indices }
    }
}
/// Iterator that turns a list of positions into references to annotations.
///
/// Each call to `next` takes the next position from `indices` and looks it up
/// in `anns`. The lookup is checked: a position outside `anns` makes `next`
/// return `None` at that point instead of panicking, so the reported length is
/// exact only when every index is in range.
#[derive(Debug)]
pub struct AnnotationIter<'a, A> {
    /// Annotation storage the indices point into.
    anns: &'a [A],
    /// Positions still to be yielded, in order.
    indices: std::slice::Iter<'a, usize>,
}

// Implemented by hand: `#[derive(Clone)]` would require `A: Clone`, although
// only a shared slice reference and an index cursor are duplicated.
impl<A> Clone for AnnotationIter<'_, A> {
    fn clone(&self) -> Self {
        Self {
            anns: self.anns,
            indices: self.indices.clone(),
        }
    }
}

impl<'a, A> Iterator for AnnotationIter<'a, A> {
    type Item = &'a A;

    fn next(&mut self) -> Option<Self::Item> {
        let idx = *self.indices.next()?;
        self.anns.get(idx)
    }

    /// Forwards the exact remaining count of the underlying index cursor.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.indices.size_hint()
    }
}

impl<'a, A> ExactSizeIterator for AnnotationIter<'a, A> {}
/// Top-level shape of a ground-truth JSON document: the three arrays
/// `images`, `annotations` and `categories`.
///
/// This is the plain (de)serialization form; `CocoDataset` is the validated,
/// indexed form built from it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CocoJson {
/// Image metadata records.
pub images: Vec<ImageMeta>,
/// Annotation records.
pub annotations: Vec<CocoAnnotation>,
/// Category metadata records.
pub categories: Vec<CategoryMeta>,
}
/// Ground-truth dataset with lookup indices.
///
/// The three record vectors are held behind `Arc`. Every index maps a key to
/// the positions (into `annotations`) of the annotations matching that key.
#[derive(Debug, Clone)]
pub struct CocoDataset {
// Image metadata records.
images: Arc<Vec<ImageMeta>>,
// Category metadata records.
categories: Arc<Vec<CategoryMeta>>,
// All annotations; the indices below store positions into this vector.
annotations: Arc<Vec<CocoAnnotation>>,
// Annotation positions grouped by image.
by_image: HashMap<ImageId, Vec<usize>>,
// Annotation positions grouped by category.
by_category: HashMap<CategoryId, Vec<usize>>,
// Annotation positions grouped by (image, category) pair.
by_image_cat: HashMap<(ImageId, CategoryId), Vec<usize>>,
}
impl CocoDataset {
    /// Parses a ground-truth JSON document and builds the indexed dataset.
    ///
    /// # Errors
    /// Fails when `bytes` does not deserialize into [`CocoJson`], or when
    /// [`CocoDataset::from_parts`] rejects the parsed records.
    pub fn from_json_bytes(bytes: &[u8]) -> Result<Self, EvalError> {
        let CocoJson {
            images,
            annotations,
            categories,
        } = serde_json::from_slice::<CocoJson>(bytes)?;
        Self::from_parts(images, annotations, categories)
    }

    /// Builds the dataset from already-parsed records.
    ///
    /// Every annotation is checked against the known image and category ids
    /// and then recorded, by position, in the per-image, per-category and
    /// per-(image, category) indices.
    ///
    /// # Errors
    /// Returns `EvalError::InvalidAnnotation` for the first annotation whose
    /// `image_id` or `category_id` does not appear in `images` / `categories`.
    pub fn from_parts(
        images: Vec<ImageMeta>,
        annotations: Vec<CocoAnnotation>,
        categories: Vec<CategoryMeta>,
    ) -> Result<Self, EvalError> {
        let image_ids: HashSet<ImageId> = images.iter().map(|img| img.id).collect();
        let category_ids: HashSet<CategoryId> = categories.iter().map(|cat| cat.id).collect();

        let mut by_image: HashMap<ImageId, Vec<usize>> = HashMap::with_capacity(images.len());
        let mut by_category: HashMap<CategoryId, Vec<usize>> =
            HashMap::with_capacity(categories.len());
        let mut by_image_cat: HashMap<(ImageId, CategoryId), Vec<usize>> = HashMap::new();

        for (pos, ann) in annotations.iter().enumerate() {
            let (image_id, category_id) = (ann.image_id, ann.category_id);

            // The image reference is validated before the category reference.
            if !image_ids.contains(&image_id) {
                let detail = format!(
                    "annotation id={} references unknown image_id={}",
                    ann.id.0, image_id.0
                );
                return Err(EvalError::InvalidAnnotation { detail });
            }
            if !category_ids.contains(&category_id) {
                let detail = format!(
                    "annotation id={} references unknown category_id={}",
                    ann.id.0, category_id.0
                );
                return Err(EvalError::InvalidAnnotation { detail });
            }

            by_image.entry(image_id).or_default().push(pos);
            by_category.entry(category_id).or_default().push(pos);
            by_image_cat
                .entry((image_id, category_id))
                .or_default()
                .push(pos);
        }

        Ok(Self {
            images: Arc::new(images),
            categories: Arc::new(categories),
            annotations: Arc::new(annotations),
            by_image,
            by_category,
            by_image_cat,
        })
    }

    /// Copies the stored records back into the plain [`CocoJson`] form.
    pub fn to_json_value(&self) -> CocoJson {
        let images = self.images.to_vec();
        let annotations = self.annotations.to_vec();
        let categories = self.categories.to_vec();
        CocoJson {
            images,
            annotations,
            categories,
        }
    }
}
/// [`EvalDataset`] view over the stored records and prebuilt indices.
impl EvalDataset for CocoDataset {
    type Annotation = CocoAnnotation;

    fn images(&self) -> &[ImageMeta] {
        self.images.as_slice()
    }

    fn categories(&self) -> &[CategoryMeta] {
        self.categories.as_slice()
    }

    fn annotations(&self) -> &[CocoAnnotation] {
        self.annotations.as_slice()
    }

    /// Returns an empty slice when no annotation is indexed under `image_id`.
    fn ann_indices_for_image(&self, image_id: ImageId) -> &[usize] {
        match self.by_image.get(&image_id) {
            Some(positions) => positions.as_slice(),
            None => &[],
        }
    }

    /// Returns an empty slice when no annotation is indexed under `cat_id`.
    fn ann_indices_for_category(&self, cat_id: CategoryId) -> &[usize] {
        match self.by_category.get(&cat_id) {
            Some(positions) => positions.as_slice(),
            None => &[],
        }
    }
}
impl CocoDataset {
    /// Positions (into the annotation slice) of the annotations that belong to
    /// `image` and are labelled `cat`; empty when the pair has none.
    pub fn ann_indices_for(&self, image: ImageId, cat: CategoryId) -> &[usize] {
        let key = (image, cat);
        match self.by_image_cat.get(&key) {
            Some(positions) => positions.as_slice(),
            None => &[],
        }
    }
}
/// One detection after loading: a `DetectionInput` with a resolved id and a
/// derived area (see `CocoDetections::from_inputs`).
#[derive(Debug, Clone, PartialEq)]
pub struct CocoDetection {
/// Identifier of the detection.
pub id: AnnId,
/// Image the detection was produced for.
pub image_id: ImageId,
/// Predicted category.
pub category_id: CategoryId,
/// Confidence score.
pub score: f64,
/// Predicted bounding box.
pub bbox: Bbox,
/// Area associated with the detection.
pub area: f64,
/// Optional predicted segmentation.
pub segmentation: Option<Segmentation>,
/// Optional predicted keypoint values.
pub keypoints: Option<Vec<f64>>,
/// Optional keypoint count.
pub num_keypoints: Option<u32>,
}
/// [`Annotation`] view of a detection: it is never a crowd region and is
/// never ignored, whatever the parity mode.
impl Annotation for CocoDetection {
    fn image_id(&self) -> ImageId {
        self.image_id
    }

    fn category_id(&self) -> CategoryId {
        self.category_id
    }

    fn area(&self) -> f64 {
        self.area
    }

    fn is_crowd(&self) -> bool {
        false
    }

    fn effective_ignore(&self, _mode: ParityMode) -> bool {
        false
    }
}
/// One detection record as supplied by the caller or read from a results file.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DetectionInput {
/// Optional identifier; `None` when the field is missing on input, in which
/// case `CocoDetections::from_inputs` assigns one.
#[serde(default)]
pub id: Option<AnnId>,
/// Image the detection was produced for.
pub image_id: ImageId,
/// Predicted category.
pub category_id: CategoryId,
/// Confidence score.
pub score: f64,
/// Predicted bounding box.
pub bbox: Bbox,
/// Optional predicted segmentation; omitted from the serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub segmentation: Option<Segmentation>,
/// Optional predicted keypoint values; omitted from the serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub keypoints: Option<Vec<f64>>,
/// Optional keypoint count; omitted from the serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub num_keypoints: Option<u32>,
}
/// Collection of detections with lookup indices.
///
/// Both indices map a key to positions into the `detections` vector.
#[derive(Debug, Clone)]
pub struct CocoDetections {
// All detections, in input order.
detections: Arc<Vec<CocoDetection>>,
// Detection positions grouped by (image, category) pair.
by_image_cat: HashMap<(ImageId, CategoryId), Vec<usize>>,
// Detection positions grouped by image.
by_image: HashMap<ImageId, Vec<usize>>,
}
impl CocoDetections {
pub fn from_json_bytes(bytes: &[u8]) -> Result<Self, EvalError> {
let raw: Vec<DetectionInput> = serde_json::from_slice(bytes)?;
Self::from_inputs(raw)
}
pub fn from_inputs(inputs: Vec<DetectionInput>) -> Result<Self, EvalError> {
let mut detections = Vec::with_capacity(inputs.len());
let mut next_auto = 1i64;
for input in inputs {
if !input.score.is_finite() {
return Err(EvalError::NonFinite {
context: "detection score",
});
}
let id = match input.id {
Some(id) => id,
None => {
let id = AnnId(next_auto);
next_auto += 1;
id
}
};
detections.push(CocoDetection {
id,
image_id: input.image_id,
category_id: input.category_id,
score: input.score,
bbox: input.bbox,
area: input.bbox.w * input.bbox.h,
segmentation: input.segmentation,
keypoints: input.keypoints,
num_keypoints: input.num_keypoints,
});
}
let mut by_image_cat: HashMap<(ImageId, CategoryId), Vec<usize>> = HashMap::new();
let mut by_image: HashMap<ImageId, Vec<usize>> = HashMap::new();
for (idx, dt) in detections.iter().enumerate() {
by_image_cat
.entry((dt.image_id, dt.category_id))
.or_default()
.push(idx);
by_image.entry(dt.image_id).or_default().push(idx);
}
Ok(Self {
detections: Arc::new(detections),
by_image_cat,
by_image,
})
}
pub fn detections(&self) -> &[CocoDetection] {
&self.detections
}
pub fn indices_for(&self, image: ImageId, cat: CategoryId) -> &[usize] {
self.by_image_cat
.get(&(image, cat))
.map_or(&[][..], Vec::as_slice)
}
pub fn indices_for_image(&self, image: ImageId) -> &[usize] {
self.by_image.get(&image).map_or(&[][..], Vec::as_slice)
}
}
/// Deserialization helper for flag fields that may arrive either as a JSON
/// boolean or as an integer.
///
/// `#[serde(untagged)]` makes serde try the variants in declaration order.
#[derive(Deserialize)]
#[serde(untagged)]
enum BoolOrInt {
/// The value was a boolean.
Bool(bool),
/// The value was an integer; only 0 and 1 are accepted by `into_bool`.
Int(i64),
}
impl BoolOrInt {
    /// Collapses the parsed value into a `bool`.
    ///
    /// Booleans pass through unchanged; the integers 0 and 1 map to `false`
    /// and `true`. Any other integer produces a custom deserialization error.
    fn into_bool<E: serde::de::Error>(self) -> Result<bool, E> {
        let number = match self {
            Self::Bool(flag) => return Ok(flag),
            Self::Int(number) => number,
        };
        match number {
            0 => Ok(false),
            1 => Ok(true),
            other => Err(E::custom(format!(
                "expected 0 or 1 for COCO bool field, got {other}"
            ))),
        }
    }
}
/// `deserialize_with` helper for a required flag: accepts a boolean or the
/// integers 0 / 1 and yields a `bool`.
fn deserialize_bool_int<'de, D>(de: D) -> Result<bool, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let parsed = BoolOrInt::deserialize(de)?;
    parsed.into_bool()
}
/// `deserialize_with` helper for an optional flag: an absent value stays
/// `None`; a present one must be a boolean or the integers 0 / 1.
fn deserialize_opt_bool_int<'de, D>(de: D) -> Result<Option<bool>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    match Option::<BoolOrInt>::deserialize(de)? {
        Some(parsed) => parsed.into_bool().map(Some),
        None => Ok(None),
    }
}
// Unit and property tests for dataset loading, index construction, detection
// loading and the bool-or-int flag parsing.
#[cfg(test)]
mod tests {
use super::*;
use proptest::prelude::*;
// Fixture: one 200x200 image with two annotations of the same category — a
// regular 50x50 box and a crowd box covering the whole image.
const CROWD_REGION_GT: &str = r#"{
"images": [
{"id": 1, "width": 200, "height": 200, "file_name": "img1.png"}
],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [100, 100, 50, 50], "area": 2500, "iscrowd": 0},
{"id": 2, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 200, 200], "area": 40000, "iscrowd": 1}
],
"categories": [
{"id": 1, "name": "widget", "supercategory": "thing"}
]
}"#;
// Loads the shared crowd-region fixture, panicking if it fails to parse.
fn load_crowd_region() -> CocoDataset {
CocoDataset::from_json_bytes(CROWD_REGION_GT.as_bytes()).unwrap()
}
// The fixture parses into the expected record counts and metadata values.
#[test]
fn loads_crowd_region_fixture() {
let ds = load_crowd_region();
assert_eq!(ds.images().len(), 1);
assert_eq!(ds.categories().len(), 1);
assert_eq!(ds.annotations().len(), 2);
assert_eq!(ds.images()[0].file_name.as_deref(), Some("img1.png"));
assert_eq!(ds.categories()[0].name, "widget");
}
// The per-image index and iterator yield both annotations in input order.
#[test]
fn by_image_index_returns_both_anns() {
let ds = load_crowd_region();
let idxs = ds.ann_indices_for_image(ImageId(1));
assert_eq!(idxs.len(), 2);
let anns: Vec<_> = ds.ann_iter_for_image(ImageId(1)).collect();
assert_eq!(anns.len(), 2);
assert_eq!(anns[0].id, AnnId(1));
assert_eq!(anns[1].id, AnnId(2));
}
// The per-category index covers both annotations of the single category.
#[test]
fn by_category_index_returns_both_anns() {
let ds = load_crowd_region();
let idxs = ds.ann_indices_for_category(CategoryId(1));
assert_eq!(idxs.len(), 2);
}
// Ids that do not exist in the dataset give an empty slice, not a panic.
#[test]
fn unknown_image_returns_empty_slice() {
let ds = load_crowd_region();
assert!(ds.ann_indices_for_image(ImageId(999)).is_empty());
assert!(ds.ann_indices_for_category(CategoryId(999)).is_empty());
}
// A known image/category without any annotation also gives an empty slice.
#[test]
fn empty_image_or_category_returns_empty_slice_not_missing() {
const ONLY_EMPTY_IMG: &str = r#"{
"images": [{"id": 7, "width": 1, "height": 1}],
"annotations": [],
"categories": [{"id": 3, "name": "thing"}]
}"#;
let ds = CocoDataset::from_json_bytes(ONLY_EMPTY_IMG.as_bytes()).unwrap();
assert!(ds.ann_indices_for_image(ImageId(7)).is_empty());
assert!(ds.ann_indices_for_category(CategoryId(3)).is_empty());
}
// Loading fails with InvalidAnnotation when an annotation names an image
// that is not listed, and the message mentions the offending id.
#[test]
fn rejects_annotation_referencing_unknown_image() {
const BAD: &str = r#"{
"images": [{"id": 1, "width": 10, "height": 10}],
"annotations": [
{"id": 1, "image_id": 99, "category_id": 1,
"bbox": [0, 0, 1, 1], "area": 1, "iscrowd": 0}
],
"categories": [{"id": 1, "name": "thing"}]
}"#;
let err = CocoDataset::from_json_bytes(BAD.as_bytes()).unwrap_err();
match err {
EvalError::InvalidAnnotation { detail } => {
assert!(detail.contains("image_id=99"), "msg: {detail}");
}
other => panic!("expected InvalidAnnotation, got {other:?}"),
}
}
// Same as above, for an annotation naming a category that is not listed.
#[test]
fn rejects_annotation_referencing_unknown_category() {
const BAD: &str = r#"{
"images": [{"id": 1, "width": 10, "height": 10}],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 42,
"bbox": [0, 0, 1, 1], "area": 1, "iscrowd": 0}
],
"categories": [{"id": 1, "name": "thing"}]
}"#;
let err = CocoDataset::from_json_bytes(BAD.as_bytes()).unwrap_err();
match err {
EvalError::InvalidAnnotation { detail } => {
assert!(detail.contains("category_id=42"), "msg: {detail}");
}
other => panic!("expected InvalidAnnotation, got {other:?}"),
}
}
// Serializing via to_json_value and re-parsing reproduces the same records.
#[test]
fn round_trips_through_json() {
let ds = load_crowd_region();
let json = serde_json::to_string(&ds.to_json_value()).unwrap();
let again = CocoDataset::from_json_bytes(json.as_bytes()).unwrap();
assert_eq!(ds.images(), again.images());
assert_eq!(ds.categories(), again.categories());
assert_eq!(ds.annotations(), again.annotations());
}
// With iscrowd=0 and ignore=1: Strict disregards the explicit ignore flag,
// Corrected honours it.
#[test]
fn d1_strict_mode_drops_explicit_ignore_field() {
const ANN_JSON: &str = r#"{
"images": [{"id": 1, "width": 10, "height": 10}],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 1, 1], "area": 1,
"iscrowd": 0, "ignore": 1}
],
"categories": [{"id": 1, "name": "thing"}]
}"#;
let ds = CocoDataset::from_json_bytes(ANN_JSON.as_bytes()).unwrap();
let ann = &ds.annotations()[0];
assert!(!ann.effective_ignore(ParityMode::Strict));
assert!(ann.effective_ignore(ParityMode::Corrected));
}
// With iscrowd=1 and no ignore field: both modes report the annotation as
// ignored.
#[test]
fn d1_strict_mode_uses_iscrowd_when_ignore_absent() {
const ANN_JSON: &str = r#"{
"images": [{"id": 1, "width": 10, "height": 10}],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 1, 1], "area": 1, "iscrowd": 1}
],
"categories": [{"id": 1, "name": "thing"}]
}"#;
let ds = CocoDataset::from_json_bytes(ANN_JSON.as_bytes()).unwrap();
let ann = &ds.annotations()[0];
assert!(ann.effective_ignore(ParityMode::Strict));
assert!(ann.effective_ignore(ParityMode::Corrected));
}
// The (image, category) index returns exactly the matching annotations, in
// input order, and is empty for unknown pairs.
#[test]
fn ann_indices_for_image_cat_returns_correct_subset() {
const TWO_CATS: &str = r#"{
"images": [{"id": 1, "width": 10, "height": 10}],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 1, 1], "area": 1, "iscrowd": 0},
{"id": 2, "image_id": 1, "category_id": 2,
"bbox": [0, 0, 1, 1], "area": 1, "iscrowd": 0},
{"id": 3, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 1, 1], "area": 1, "iscrowd": 0}
],
"categories": [
{"id": 1, "name": "a"}, {"id": 2, "name": "b"}
]
}"#;
let ds = CocoDataset::from_json_bytes(TWO_CATS.as_bytes()).unwrap();
let cat1: Vec<AnnId> = ds
.ann_indices_for(ImageId(1), CategoryId(1))
.iter()
.map(|&i| ds.annotations()[i].id)
.collect();
assert_eq!(cat1, vec![AnnId(1), AnnId(3)]);
let cat2: Vec<AnnId> = ds
.ann_indices_for(ImageId(1), CategoryId(2))
.iter()
.map(|&i| ds.annotations()[i].id)
.collect();
assert_eq!(cat2, vec![AnnId(2)]);
assert!(ds.ann_indices_for(ImageId(1), CategoryId(99)).is_empty());
assert!(ds.ann_indices_for(ImageId(99), CategoryId(1)).is_empty());
}
// Builds a minimal detection input (no id, no segmentation, no keypoints)
// from an image id, category id, score and (x, y, w, h) tuple.
fn dt_input(image: i64, cat: i64, score: f64, bbox: (f64, f64, f64, f64)) -> DetectionInput {
DetectionInput {
id: None,
image_id: ImageId(image),
category_id: CategoryId(cat),
score,
bbox: Bbox {
x: bbox.0,
y: bbox.1,
w: bbox.2,
h: bbox.3,
},
segmentation: None,
keypoints: None,
num_keypoints: None,
}
}
// Inputs without ids receive 1, 2, ... in input order.
#[test]
fn j1_auto_assigns_ids_when_absent() {
let dts = CocoDetections::from_inputs(vec![
dt_input(1, 1, 0.9, (0.0, 0.0, 1.0, 1.0)),
dt_input(1, 1, 0.8, (0.0, 0.0, 1.0, 1.0)),
])
.unwrap();
let ids: Vec<AnnId> = dts.detections().iter().map(|d| d.id).collect();
assert_eq!(ids, vec![AnnId(1), AnnId(2)]);
}
// Ids supplied by the caller are kept unchanged.
#[test]
fn j1_preserves_user_supplied_ids() {
let mut a = dt_input(1, 1, 0.9, (0.0, 0.0, 1.0, 1.0));
a.id = Some(AnnId(42));
let mut b = dt_input(1, 1, 0.8, (0.0, 0.0, 1.0, 1.0));
b.id = Some(AnnId(7));
let dts = CocoDetections::from_inputs(vec![a, b]).unwrap();
let ids: Vec<AnnId> = dts.detections().iter().map(|d| d.id).collect();
assert_eq!(ids, vec![AnnId(42), AnnId(7)]);
}
// The detection area is the bounding box width times height (4 * 5 = 20).
#[test]
fn j3_derives_area_from_bbox() {
let dts =
CocoDetections::from_inputs(vec![dt_input(1, 1, 0.5, (10.0, 10.0, 4.0, 5.0))]).unwrap();
assert_eq!(dts.detections()[0].area, 20.0);
}
// A NaN score is rejected with the NonFinite error.
#[test]
fn rejects_non_finite_score() {
let err = CocoDetections::from_inputs(vec![dt_input(1, 1, f64::NAN, (0.0, 0.0, 1.0, 1.0))])
.unwrap_err();
assert!(matches!(
err,
EvalError::NonFinite {
context: "detection score"
}
));
}
// Detection indices are grouped per (image, category) and per image.
#[test]
fn detections_indices_per_image_cat() {
let dts = CocoDetections::from_inputs(vec![
dt_input(1, 1, 0.9, (0.0, 0.0, 1.0, 1.0)),
dt_input(1, 2, 0.8, (0.0, 0.0, 1.0, 1.0)),
dt_input(2, 1, 0.7, (0.0, 0.0, 1.0, 1.0)),
])
.unwrap();
assert_eq!(dts.indices_for(ImageId(1), CategoryId(1)), &[0]);
assert_eq!(dts.indices_for(ImageId(1), CategoryId(2)), &[1]);
assert_eq!(dts.indices_for(ImageId(2), CategoryId(1)), &[2]);
assert!(dts.indices_for(ImageId(99), CategoryId(1)).is_empty());
let img1: Vec<usize> = dts.indices_for_image(ImageId(1)).to_vec();
assert_eq!(img1, vec![0, 1]);
}
// A JSON array of detections loads; the entry without an id gets id 1, the
// explicit id 7 is kept, and the area is derived from the box (2 * 3 = 6).
#[test]
fn loads_detections_from_json_array() {
const JSON: &str = r#"[
{"image_id": 1, "category_id": 1, "score": 0.9,
"bbox": [0, 0, 2, 3]},
{"id": 7, "image_id": 1, "category_id": 1, "score": 0.5,
"bbox": [1, 1, 1, 1]}
]"#;
let dts = CocoDetections::from_json_bytes(JSON.as_bytes()).unwrap();
let ds = dts.detections();
assert_eq!(ds[0].id, AnnId(1)); assert_eq!(ds[0].area, 6.0); assert_eq!(ds[1].id, AnnId(7)); assert!(!ds[0].is_crowd()); assert!(ds[0].segmentation.is_none());
}
// A polygon segmentation on a ground-truth annotation parses and converts
// to an RLE whose area matches the 4x4 square.
#[test]
fn gt_loads_polygon_segmentation() {
const JSON: &str = r#"{
"images": [{"id": 1, "width": 10, "height": 10}],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 4, 4], "area": 16, "iscrowd": 0,
"segmentation": [[0, 0, 4, 0, 4, 4, 0, 4]]}
],
"categories": [{"id": 1, "name": "thing"}]
}"#;
let ds = CocoDataset::from_json_bytes(JSON.as_bytes()).unwrap();
let seg = ds.annotations()[0].segmentation.as_ref().unwrap();
let rle = seg.to_rle(10, 10).unwrap();
assert_eq!(rle.area(), 16);
}
// A compressed-RLE segmentation (counts string produced by
// vernier_mask::encode_counts) parses and keeps its size and area.
#[test]
fn gt_loads_compressed_rle_segmentation() {
let counts_str = String::from_utf8(vernier_mask::encode_counts(&[0, 16])).unwrap();
let json = format!(
r#"{{
"images": [{{"id": 1, "width": 4, "height": 4}}],
"annotations": [
{{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 4, 4], "area": 16, "iscrowd": 1,
"segmentation": {{"size": [4, 4], "counts": "{counts_str}"}}}}
],
"categories": [{{"id": 1, "name": "thing"}}]
}}"#
);
let ds = CocoDataset::from_json_bytes(json.as_bytes()).unwrap();
let seg = ds.annotations()[0].segmentation.as_ref().unwrap();
let rle = seg.to_rle(4, 4).unwrap();
assert_eq!((rle.h, rle.w), (4, 4));
assert_eq!(rle.area(), 16);
}
// Annotations carrying a segmentation survive a serialize/parse round trip.
#[test]
fn gt_segmentation_round_trips_through_to_json_value() {
const JSON: &str = r#"{
"images": [{"id": 1, "width": 10, "height": 10}],
"annotations": [
{"id": 1, "image_id": 1, "category_id": 1,
"bbox": [0, 0, 4, 4], "area": 16, "iscrowd": 0,
"segmentation": [[0, 0, 4, 0, 4, 4, 0, 4]]}
],
"categories": [{"id": 1, "name": "thing"}]
}"#;
let ds = CocoDataset::from_json_bytes(JSON.as_bytes()).unwrap();
let serialized = serde_json::to_string(&ds.to_json_value()).unwrap();
let again = CocoDataset::from_json_bytes(serialized.as_bytes()).unwrap();
assert_eq!(ds.annotations(), again.annotations());
}
// A missing segmentation field on ground truth loads as None.
#[test]
fn gt_without_segmentation_field_loads_as_none() {
let ds = load_crowd_region();
assert!(ds.annotations().iter().all(|a| a.segmentation.is_none()));
}
// A detection with a compressed-RLE segmentation keeps that segmentation.
#[test]
fn dt_loads_compressed_rle_segmentation() {
const JSON: &str = r#"[
{"image_id": 1, "category_id": 1, "score": 0.9,
"bbox": [0, 0, 4, 4],
"segmentation": {"size": [4, 4], "counts": "04L4"}}
]"#;
let dts = CocoDetections::from_json_bytes(JSON.as_bytes()).unwrap();
assert!(dts.detections()[0].segmentation.is_some());
}
// A missing segmentation field on a detection loads as None.
#[test]
fn dt_without_segmentation_loads_as_none() {
const JSON: &str = r#"[
{"image_id": 1, "category_id": 1, "score": 0.9, "bbox": [0, 0, 1, 1]}
]"#;
let dts = CocoDetections::from_json_bytes(JSON.as_bytes()).unwrap();
assert!(dts.detections()[0].segmentation.is_none());
}
// Strategy producing image metadata with ids in 1..1000 and no file name.
fn arb_image() -> impl Strategy<Value = ImageMeta> {
(1i64..1000, 1u32..2048, 1u32..2048).prop_map(|(id, w, h)| ImageMeta {
id: ImageId(id),
width: w,
height: h,
file_name: None,
})
}
// Strategy producing categories with ids in 1..100 and short lowercase names.
fn arb_category() -> impl Strategy<Value = CategoryMeta> {
(1i64..100, "[a-z]{1,8}").prop_map(|(id, name)| CategoryMeta {
id: CategoryId(id),
name,
supercategory: None,
})
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(64))]
// Property: for any dataset built from valid references, the per-image
// and per-category indices each cover every annotation position exactly
// once, and every indexed position points at an annotation with the
// matching image / category id.
#[test]
fn index_invariants_hold(
images in proptest::collection::vec(arb_image(), 1..6),
categories in proptest::collection::vec(arb_category(), 1..6),
n_anns in 0usize..40,
ann_seed in any::<u64>(),
) {
// The generated ids may repeat; keep one record per id.
let mut images = images;
images.sort_by_key(|i| i.id);
images.dedup_by_key(|i| i.id);
let mut categories = categories;
categories.sort_by_key(|c| c.id);
categories.dedup_by_key(|c| c.id);
// Deterministic multiply-add recurrence seeded from `ann_seed`, used to
// pick an image and a category for each annotation.
let mut state = ann_seed.wrapping_add(1);
let mut next = || {
state = state.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
state
};
let mut annotations = Vec::with_capacity(n_anns);
for ann_idx in 0..n_anns {
let img = &images[(next() as usize) % images.len()];
let cat = &categories[(next() as usize) % categories.len()];
annotations.push(CocoAnnotation {
id: AnnId(ann_idx as i64 + 1),
image_id: img.id,
category_id: cat.id,
area: 1.0,
is_crowd: false,
ignore_flag: None,
bbox: Bbox { x: 0.0, y: 0.0, w: 1.0, h: 1.0 },
segmentation: None,
keypoints: None,
num_keypoints: None,
});
}
let ds = CocoDataset::from_parts(
images.clone(), annotations.clone(), categories.clone()
).unwrap();
// Union of the per-image index lists must equal 0..annotations.len().
let mut seen_img: Vec<usize> = images.iter()
.flat_map(|i| ds.ann_indices_for_image(i.id).iter().copied())
.collect();
seen_img.sort_unstable();
let expected: Vec<usize> = (0..annotations.len()).collect();
prop_assert_eq!(&seen_img, &expected);
// Same coverage check for the per-category index lists.
let mut seen_cat: Vec<usize> = categories.iter()
.flat_map(|c| ds.ann_indices_for_category(c.id).iter().copied())
.collect();
seen_cat.sort_unstable();
prop_assert_eq!(&seen_cat, &expected);
// Every indexed position must belong to the key it is filed under.
for img in &images {
for &idx in ds.ann_indices_for_image(img.id) {
prop_assert_eq!(ds.annotations()[idx].image_id, img.id);
}
}
for cat in &categories {
for &idx in ds.ann_indices_for_category(cat.id) {
prop_assert_eq!(ds.annotations()[idx].category_id, cat.id);
}
}
}
}
}