use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Root of a COCO-style annotation document.
///
/// On deserialization only `images` is mandatory: every other field is marked
/// `#[serde(default)]` and falls back to its `Default` value when the key is missing.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CocoDataset {
/// Dataset-level metadata; defaults to a [`CocoInfo`] whose fields are all `None`.
#[serde(default)]
pub info: CocoInfo,
/// License records; defaults to an empty list.
#[serde(default)]
pub licenses: Vec<CocoLicense>,
/// Image records. Required: there is no serde default, so input lacking this key is rejected.
pub images: Vec<CocoImage>,
/// Annotation records, each carrying the `image_id` it belongs to; defaults to an empty list.
#[serde(default)]
pub annotations: Vec<CocoAnnotation>,
/// Category records; defaults to an empty list.
#[serde(default)]
pub categories: Vec<CocoCategory>,
}
/// Free-form metadata describing a dataset.
///
/// Every field is optional and deserializes to `None` when its key is missing.
/// No field uses `skip_serializing_if`, so `None` values are still written on serialization.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CocoInfo {
/// Year associated with the dataset, if given.
#[serde(default)]
pub year: Option<u32>,
/// Version string, if given.
#[serde(default)]
pub version: Option<String>,
/// Human-readable description, if given.
#[serde(default)]
pub description: Option<String>,
/// Contributor name(s), if given.
#[serde(default)]
pub contributor: Option<String>,
/// URL for the dataset, if given.
#[serde(default)]
pub url: Option<String>,
/// Creation date kept as an unparsed string; no format is enforced by this type.
#[serde(default)]
pub date_created: Option<String>,
}
/// A license entry of a dataset.
///
/// `id` and `name` are required on deserialization; `url` is optional.
/// Unlike most sibling types, this struct does not derive `Default`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CocoLicense {
/// Numeric license identifier (presumably what `CocoImage::license` refers to — confirm).
pub id: u32,
/// License name.
pub name: String,
/// Link to the license text; `None` when the key is missing.
#[serde(default)]
pub url: Option<String>,
}
/// One image record of a dataset.
///
/// `id`, `width` and `height` are required on deserialization; all other fields
/// fall back to their defaults when missing.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CocoImage {
/// Image identifier; annotations refer to it through `CocoAnnotation::image_id`.
pub id: u64,
/// Image width (presumably in pixels — confirm).
pub width: u32,
/// Image height (presumably in pixels — confirm).
pub height: u32,
/// File name of the image; defaults to the empty string when missing.
#[serde(default)]
pub file_name: String,
/// Numeric license reference; `None` when missing.
#[serde(default)]
pub license: Option<u32>,
/// Flickr URL of the image; `None` when missing.
#[serde(default)]
pub flickr_url: Option<String>,
/// COCO URL of the image; `None` when missing.
#[serde(default)]
pub coco_url: Option<String>,
/// Capture date kept as an unparsed string; `None` when missing.
#[serde(default)]
pub date_captured: Option<String>,
/// Optional list of category ids (the tests in this file treat it as an LVIS image field).
/// Left out of serialized output entirely when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub neg_category_ids: Option<Vec<u32>>,
/// Optional list of category ids (the tests in this file treat it as an LVIS image field).
/// Left out of serialized output entirely when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub not_exhaustive_category_ids: Option<Vec<u32>>,
}
/// One category (label) record of a dataset.
///
/// `id` and `name` are required on deserialization. The fields from `synset` onwards are
/// optional extras (the tests in this file treat them as LVIS category metadata) and are
/// left out of serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CocoCategory {
/// Category identifier; annotations refer to it through `CocoAnnotation::category_id`.
pub id: u32,
/// Category name, returned by `CocoIndex::label_name`.
pub name: String,
/// Name of the parent grouping; `None` when missing, but still written on serialization.
#[serde(default)]
pub supercategory: Option<String>,
/// Synset identifier string (e.g. `"aerosol.n.02"` in the tests).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub synset: Option<String>,
/// Frequency label kept as a raw string (e.g. `"c"` in the tests); exposed via `CocoIndex::frequency`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub frequency: Option<String>,
/// Alternative names for the category.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub synonyms: Option<Vec<String>>,
/// Textual definition of the category.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub def: Option<String>,
/// Image count for the category (presumably the number of images containing it — confirm).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub image_count: Option<u32>,
/// Instance count for the category (presumably the number of annotated instances — confirm).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub instance_count: Option<u32>,
}
/// One annotation record attached to an image.
///
/// `id`, `image_id`, `category_id` and `bbox` are required on deserialization;
/// the remaining fields fall back to their defaults when missing.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CocoAnnotation {
/// Annotation identifier.
pub id: u64,
/// Identifier of the image this annotation belongs to; `CocoIndex` groups annotations by it.
pub image_id: u64,
/// Identifier of the category assigned to this annotation.
pub category_id: u32,
/// Bounding box as four numbers (presumably `[x, y, width, height]` per COCO convention — confirm).
pub bbox: [f64; 4],
/// Area value; defaults to `0.0` when missing.
#[serde(default)]
pub area: f64,
/// Crowd flag stored as an integer; defaults to `0` when missing.
#[serde(default)]
pub iscrowd: u8,
/// Optional segmentation mask; left out of serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub segmentation: Option<CocoSegmentation>,
/// Optional score (presumably a detection confidence for prediction files — confirm);
/// left out of serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub score: Option<f64>,
}
/// Segmentation payload of an annotation, in one of three shapes.
///
/// The enum is `#[serde(untagged)]`: it is serialized without a variant tag, and on
/// deserialization serde tries the variants in the order declared here and keeps the
/// first one whose shape matches. Reordering the variants can therefore change parsing.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CocoSegmentation {
/// A list of number lists (presumably flattened `x, y` polygon vertices — confirm).
Polygon(Vec<Vec<f64>>),
/// An object whose `counts` is a list of integers.
Rle(CocoRle),
/// An object whose `counts` is a string.
CompressedRle(CocoCompressedRle),
}
/// Segmentation object with numeric `counts`; both fields are required on deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CocoRle {
/// Integer counts (presumably uncompressed run lengths — confirm).
pub counts: Vec<u32>,
/// Two-element size (presumably `[height, width]` — confirm).
pub size: [u32; 2],
}
/// Segmentation object with string `counts`; both fields are required on deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CocoCompressedRle {
/// Counts kept as an opaque string (presumably a compressed RLE encoding — confirm);
/// this type does not decode it.
pub counts: String,
/// Two-element size (presumably `[height, width]` — confirm).
pub size: [u32; 2],
}
/// Lookup tables over a [`CocoDataset`], built by `CocoIndex::from_dataset`.
///
/// The index owns clones of the dataset records rather than borrowing from the dataset.
#[derive(Debug, Clone)]
pub struct CocoIndex {
/// Image id → image record.
pub images: HashMap<u64, CocoImage>,
/// Category id → category record.
pub categories: HashMap<u32, CocoCategory>,
/// Category id → label index; `from_dataset` stores the category id itself, widened to `u64`.
pub label_indices: HashMap<u32, u64>,
/// Image id → annotations of that image, in dataset order. Images without annotations have no entry.
pub annotations_by_image: HashMap<u64, Vec<CocoAnnotation>>,
/// Category id → frequency string, only for categories whose `frequency` is `Some`.
pub frequencies: HashMap<u32, String>,
}
impl CocoIndex {
    /// Builds every lookup table from `dataset`.
    ///
    /// Records are cloned into the index, so the result does not borrow from `dataset`.
    /// When two images or two categories share an id, the later record replaces the
    /// earlier one. A category's frequency entry is only written when its `frequency`
    /// is `Some`, so a later duplicate without a frequency does not erase an earlier one.
    pub fn from_dataset(dataset: &CocoDataset) -> Self {
        let mut images = HashMap::new();
        for image in &dataset.images {
            images.insert(image.id, image.clone());
        }

        // All three category-keyed tables are filled in one walk over the categories.
        let mut categories = HashMap::new();
        let mut label_indices = HashMap::new();
        let mut frequencies = HashMap::new();
        for category in &dataset.categories {
            let id = category.id;
            categories.insert(id, category.clone());
            // The label index is the category id itself, widened losslessly.
            label_indices.insert(id, u64::from(id));
            if let Some(freq) = &category.frequency {
                frequencies.insert(id, freq.clone());
            }
        }

        // Group annotations by owning image, keeping dataset order within each group.
        let annotations_by_image = dataset.annotations.iter().fold(
            HashMap::<u64, Vec<CocoAnnotation>>::new(),
            |mut grouped, annotation| {
                grouped
                    .entry(annotation.image_id)
                    .or_default()
                    .push(annotation.clone());
                grouped
            },
        );

        Self {
            images,
            categories,
            label_indices,
            annotations_by_image,
            frequencies,
        }
    }

    /// Returns the name of the category with id `category_id`, or `None` if it is unknown.
    pub fn label_name(&self, category_id: u32) -> Option<&str> {
        let category = self.categories.get(&category_id)?;
        Some(category.name.as_str())
    }

    /// Returns the label index stored for `category_id`, or `None` if it is unknown.
    pub fn label_index(&self, category_id: u32) -> Option<u64> {
        let index = self.label_indices.get(&category_id)?;
        Some(*index)
    }

    /// Returns the annotations recorded for `image_id`; an empty slice when there are none.
    pub fn annotations_for_image(&self, image_id: u64) -> &[CocoAnnotation] {
        match self.annotations_by_image.get(&image_id) {
            Some(list) => list.as_slice(),
            None => &[],
        }
    }

    /// Returns the frequency string of `category_id`, or `None` if the category is
    /// unknown or has no frequency.
    pub fn frequency(&self, category_id: u32) -> Option<&str> {
        self.frequencies.get(&category_id).map(String::as_str)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A default dataset has no images, annotations or categories.
    #[test]
    fn test_coco_dataset_default() {
        let dataset = CocoDataset::default();
        assert!(dataset.images.is_empty());
        assert!(dataset.annotations.is_empty());
        assert!(dataset.categories.is_empty());
    }

    /// The index exposes images, category names, label indices and per-image annotations.
    #[test]
    fn test_coco_index_from_dataset() {
        let dataset = CocoDataset {
            images: vec![
                CocoImage {
                    id: 1,
                    width: 640,
                    height: 480,
                    file_name: "image1.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 2,
                    width: 800,
                    height: 600,
                    file_name: "image2.jpg".to_string(),
                    ..Default::default()
                },
            ],
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "person".to_string(),
                    supercategory: Some("human".to_string()),
                    ..Default::default()
                },
                CocoCategory {
                    id: 2,
                    name: "car".to_string(),
                    supercategory: Some("vehicle".to_string()),
                    ..Default::default()
                },
            ],
            annotations: vec![
                CocoAnnotation {
                    id: 100,
                    image_id: 1,
                    category_id: 1,
                    bbox: [10.0, 20.0, 100.0, 200.0],
                    area: 20000.0,
                    iscrowd: 0,
                    segmentation: None,
                    score: None,
                },
                CocoAnnotation {
                    id: 101,
                    image_id: 1,
                    category_id: 2,
                    bbox: [50.0, 60.0, 150.0, 100.0],
                    area: 15000.0,
                    iscrowd: 0,
                    segmentation: None,
                    score: None,
                },
            ],
            ..Default::default()
        };
        let index = CocoIndex::from_dataset(&dataset);

        assert_eq!(index.images.len(), 2);
        assert_eq!(index.images.get(&1).unwrap().file_name, "image1.jpg");

        assert_eq!(index.categories.len(), 2);
        assert_eq!(index.label_name(1), Some("person"));
        assert_eq!(index.label_name(2), Some("car"));
        assert_eq!(index.label_name(3), None);

        assert_eq!(index.label_index(2), Some(2));
        assert_eq!(index.label_index(1), Some(1));

        // Annotations keep dataset order within their image.
        let anns = index.annotations_for_image(1);
        assert_eq!(anns.len(), 2);
        assert_eq!(anns[0].id, 100);
        assert_eq!(anns[1].id, 101);

        // An image without annotations and an unknown image both yield an empty slice.
        let anns = index.annotations_for_image(2);
        assert!(anns.is_empty());
        assert!(index.annotations_for_image(999).is_empty());
    }

    /// `frequency` is only available for categories that carry one.
    #[test]
    fn test_coco_index_frequency_lookup() {
        let dataset = CocoDataset {
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "aerosol_can".to_string(),
                    frequency: Some("c".to_string()),
                    ..Default::default()
                },
                CocoCategory {
                    id: 2,
                    name: "person".to_string(),
                    ..Default::default()
                },
            ],
            ..Default::default()
        };
        let index = CocoIndex::from_dataset(&dataset);

        assert_eq!(index.frequency(1), Some("c"));
        assert_eq!(index.frequency(2), None);
        assert_eq!(index.frequency(99), None);
    }

    /// A nested number array parses as the polygon variant.
    #[test]
    fn test_coco_segmentation_polygon_deserialize() {
        let json = r#"[[100.0, 200.0, 150.0, 250.0, 100.0, 250.0]]"#;
        let seg: CocoSegmentation = serde_json::from_str(json).unwrap();
        match seg {
            CocoSegmentation::Polygon(polys) => {
                assert_eq!(polys.len(), 1);
                assert_eq!(polys[0].len(), 6);
            }
            _ => panic!("Expected polygon segmentation"),
        }
    }

    /// An object with numeric `counts` parses as the RLE variant.
    #[test]
    fn test_coco_segmentation_rle_deserialize() {
        let json = r#"{"counts": [10, 20, 30, 40], "size": [100, 200]}"#;
        let seg: CocoSegmentation = serde_json::from_str(json).unwrap();
        match seg {
            CocoSegmentation::Rle(rle) => {
                assert_eq!(rle.counts, vec![10, 20, 30, 40]);
                assert_eq!(rle.size, [100, 200]);
            }
            _ => panic!("Expected RLE segmentation"),
        }
    }

    /// An object with string `counts` falls through to the compressed-RLE variant.
    #[test]
    fn test_coco_segmentation_compressed_rle_deserialize() {
        let json = r#"{"counts": "abc123", "size": [480, 640]}"#;
        let seg: CocoSegmentation = serde_json::from_str(json).unwrap();
        match seg {
            CocoSegmentation::CompressedRle(rle) => {
                assert_eq!(rle.counts, "abc123");
                assert_eq!(rle.size, [480, 640]);
            }
            _ => panic!("Expected compressed RLE segmentation"),
        }
    }

    /// Serializing and re-parsing an annotation preserves its fields.
    #[test]
    fn test_coco_annotation_roundtrip() {
        let ann = CocoAnnotation {
            id: 12345,
            image_id: 67890,
            category_id: 1,
            bbox: [100.5, 200.5, 50.0, 80.0],
            area: 4000.0,
            iscrowd: 0,
            segmentation: Some(CocoSegmentation::Polygon(vec![vec![
                100.0, 200.0, 150.0, 200.0, 150.0, 280.0, 100.0, 280.0,
            ]])),
            score: None,
        };
        let json = serde_json::to_string(&ann).unwrap();
        // `score` is `None`, so `skip_serializing_if` must keep the key out of the output.
        assert!(!json.contains("\"score\""));

        let restored: CocoAnnotation = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.id, ann.id);
        assert_eq!(restored.image_id, ann.image_id);
        assert_eq!(restored.category_id, ann.category_id);
        assert_eq!(restored.bbox, ann.bbox);
        assert_eq!(restored.area, ann.area);
        assert_eq!(restored.iscrowd, ann.iscrowd);
        assert!(restored.score.is_none());
        match restored.segmentation {
            Some(CocoSegmentation::Polygon(polys)) => {
                assert_eq!(
                    polys,
                    vec![vec![
                        100.0, 200.0, 150.0, 200.0, 150.0, 280.0, 100.0, 280.0,
                    ]]
                );
            }
            other => panic!("Expected polygon segmentation, got {:?}", other),
        }
    }

    /// Label indices equal the original (possibly sparse) category ids.
    #[test]
    fn test_coco_index_preserves_category_id() {
        let dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                width: 640,
                height: 480,
                file_name: "img.jpg".to_string(),
                ..Default::default()
            }],
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "person".to_string(),
                    supercategory: None,
                    ..Default::default()
                },
                CocoCategory {
                    id: 3,
                    name: "car".to_string(),
                    supercategory: None,
                    ..Default::default()
                },
                CocoCategory {
                    id: 90,
                    name: "toothbrush".to_string(),
                    supercategory: None,
                    ..Default::default()
                },
            ],
            annotations: vec![],
            ..Default::default()
        };
        let index = CocoIndex::from_dataset(&dataset);

        assert_eq!(index.label_index(1), Some(1));
        assert_eq!(index.label_index(3), Some(3));
        assert_eq!(index.label_index(90), Some(90));

        // Ids that are absent from the dataset have no label index.
        assert_eq!(index.label_index(2), None);
        assert_eq!(index.label_index(50), None);
    }

    /// The optional LVIS image fields are parsed when present.
    #[test]
    fn test_lvis_image_deserialize() {
        let json = r#"{
"id": 397133,
"width": 640,
"height": 480,
"file_name": "000000397133.jpg",
"neg_category_ids": [5, 12, 87],
"not_exhaustive_category_ids": [3, 45]
}"#;
        let image: CocoImage = serde_json::from_str(json).unwrap();
        assert_eq!(image.neg_category_ids, Some(vec![5, 12, 87]));
        assert_eq!(image.not_exhaustive_category_ids, Some(vec![3, 45]));
    }

    /// The optional LVIS category fields are parsed when present.
    #[test]
    fn test_lvis_category_deserialize() {
        let json = r#"{
"id": 1,
"name": "aerosol_can",
"synset": "aerosol.n.02",
"frequency": "c",
"synonyms": ["aerosol_can", "spray_can"],
"def": "a dispenser that holds a substance under pressure",
"image_count": 57,
"instance_count": 98
}"#;
        let cat: CocoCategory = serde_json::from_str(json).unwrap();
        assert_eq!(cat.synset, Some("aerosol.n.02".to_string()));
        assert_eq!(cat.frequency, Some("c".to_string()));
        assert_eq!(
            cat.synonyms,
            Some(vec!["aerosol_can".to_string(), "spray_can".to_string()])
        );
        assert_eq!(
            cat.def,
            Some("a dispenser that holds a substance under pressure".to_string())
        );
        assert_eq!(cat.image_count, Some(57));
        assert_eq!(cat.instance_count, Some(98));
    }

    /// A plain COCO category without the optional extras still parses, leaving them `None`.
    #[test]
    fn test_standard_coco_still_parses() {
        let json = r#"{"id": 1, "name": "person", "supercategory": "human"}"#;
        let cat: CocoCategory = serde_json::from_str(json).unwrap();
        assert_eq!(cat.name, "person");
        assert_eq!(cat.synset, None);
        assert_eq!(cat.frequency, None);
        assert_eq!(cat.synonyms, None);
        assert_eq!(cat.def, None);
    }
}