1use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// Top-level contents of a COCO-format annotation file.
///
/// Every field except `images` carries `#[serde(default)]`. JSON that omits
/// `info`, `licenses`, `annotations`, or `categories` therefore still
/// deserializes, and those fields fall back to their defaults. A document
/// without an `images` key is a deserialization error.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CocoDataset {
    /// Dataset-level metadata; an all-`None` [`CocoInfo`] when absent.
    #[serde(default)]
    pub info: CocoInfo,
    /// License entries; empty when absent.
    #[serde(default)]
    pub licenses: Vec<CocoLicense>,
    /// Image records. Required: unlike its siblings this field has no serde default.
    pub images: Vec<CocoImage>,
    /// Annotation records; empty when absent.
    #[serde(default)]
    pub annotations: Vec<CocoAnnotation>,
    /// Category records; empty when absent.
    #[serde(default)]
    pub categories: Vec<CocoCategory>,
}
34
/// Optional dataset metadata (the `info` section of a COCO file).
///
/// Every field is optional and becomes `None` when missing from the input.
/// No field uses `skip_serializing_if`, so `None` values are written out as `null`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CocoInfo {
    /// Dataset year.
    #[serde(default)]
    pub year: Option<u32>,
    /// Free-form dataset version string.
    #[serde(default)]
    pub version: Option<String>,
    /// Human-readable dataset description.
    #[serde(default)]
    pub description: Option<String>,
    /// Dataset contributor.
    #[serde(default)]
    pub contributor: Option<String>,
    /// Dataset homepage URL.
    #[serde(default)]
    pub url: Option<String>,
    /// Creation date, kept as the raw string (no date parsing is done here).
    #[serde(default)]
    pub date_created: Option<String>,
}
57
/// One entry of the `licenses` array.
///
/// `id` and `name` are required when deserializing; `url` is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CocoLicense {
    /// License id (presumably what `CocoImage::license` refers to — confirm).
    pub id: u32,
    /// License name.
    pub name: String,
    /// Optional license URL; serialized as `null` when `None`.
    #[serde(default)]
    pub url: Option<String>,
}
69
/// One entry of the `images` array.
///
/// `id`, `width`, `height`, and `file_name` are required when deserializing.
/// The standard COCO optional fields serialize as `null` when `None`. The LVIS
/// extension fields are omitted from serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CocoImage {
    /// Image id; annotations point at it through `CocoAnnotation::image_id`.
    pub id: u64,
    /// Image width (presumably in pixels — confirm against the data source).
    pub width: u32,
    /// Image height (presumably in pixels — confirm against the data source).
    pub height: u32,
    /// Image file name, e.g. `"000000397133.jpg"`.
    pub file_name: String,
    /// Optional license reference (presumably a `CocoLicense::id` — confirm).
    #[serde(default)]
    pub license: Option<u32>,
    /// Optional Flickr URL of the image.
    #[serde(default)]
    pub flickr_url: Option<String>,
    /// Optional COCO-hosted URL of the image.
    #[serde(default)]
    pub coco_url: Option<String>,
    /// Capture date, kept as the raw string.
    #[serde(default)]
    pub date_captured: Option<String>,
    /// LVIS extension: category ids listed as negative for this image
    /// (presumably verified absent — confirm against the LVIS spec).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub neg_category_ids: Option<Vec<u32>>,
    /// LVIS extension: category ids whose annotations on this image are
    /// flagged as not exhaustive.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub not_exhaustive_category_ids: Option<Vec<u32>>,
}
102
/// One entry of the `categories` array.
///
/// `id` and `name` are required. `supercategory` is optional and serializes
/// as `null` when `None`. The remaining fields are LVIS extensions that are
/// omitted from serialized output when `None`, so plain COCO categories parse
/// and serialize without them.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CocoCategory {
    /// Category id; annotations point at it through `CocoAnnotation::category_id`.
    pub id: u32,
    /// Category name, e.g. `"person"`.
    pub name: String,
    /// Optional parent grouping, e.g. `"vehicle"`.
    #[serde(default)]
    pub supercategory: Option<String>,
    /// LVIS extension: WordNet-style synset, e.g. `"aerosol.n.02"`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub synset: Option<String>,
    /// LVIS extension: frequency tag, e.g. `"c"` (presumably the LVIS
    /// rare/common/frequent buckets `r`/`c`/`f` — confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub frequency: Option<String>,
    /// LVIS extension: alternative names for the category.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub synonyms: Option<Vec<String>>,
    /// LVIS extension: textual definition of the category.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub def: Option<String>,
    /// LVIS extension: number of images containing this category.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub image_count: Option<u32>,
    /// LVIS extension: number of annotated instances of this category.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub instance_count: Option<u32>,
}
134
/// One entry of the `annotations` array.
///
/// `id`, `image_id`, `category_id`, and `bbox` are required. `area` and
/// `iscrowd` default to zero when missing. `segmentation` and `score` are
/// optional and are omitted from serialized output when `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CocoAnnotation {
    /// Annotation id.
    pub id: u64,
    /// Id of the `CocoImage` this annotation belongs to.
    pub image_id: u64,
    /// Id of the `CocoCategory` this annotation is labelled with.
    pub category_id: u32,
    /// Bounding box. Presumably `[x, y, width, height]` per COCO convention;
    /// the tests pair `[10, 20, 100, 200]` with area 20000. Confirm.
    pub bbox: [f64; 4],
    /// Region area; `0.0` when absent from the input.
    #[serde(default)]
    pub area: f64,
    /// Crowd flag; `0` when absent from the input.
    #[serde(default)]
    pub iscrowd: u8,
    /// Optional mask in polygon or RLE form; see [`CocoSegmentation`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub segmentation: Option<CocoSegmentation>,
    /// Optional confidence score (presumably only present in prediction /
    /// result files — confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub score: Option<f64>,
}
163
/// Segmentation mask in one of three JSON shapes.
///
/// The enum is `#[serde(untagged)]`, so the JSON carries no discriminator.
/// When deserializing, serde tries the variants in declaration order and keeps
/// the first one that parses. The shapes are:
/// - an array of arrays, which becomes `Polygon`;
/// - an object with an integer-array `counts`, which becomes `Rle`;
/// - an object with a string `counts`, which becomes `CompressedRle`.
///
/// Keep that order in mind when adding variants.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CocoSegmentation {
    /// One or more polygons, each a flat coordinate list
    /// (presumably `[x1, y1, x2, y2, ...]` — confirm).
    Polygon(Vec<Vec<f64>>),
    /// Uncompressed run-length encoding.
    Rle(CocoRle),
    /// Run-length encoding with `counts` packed into a string.
    CompressedRle(CocoCompressedRle),
}
182
/// Uncompressed run-length-encoded mask (`counts` is a JSON integer array).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CocoRle {
    /// Run lengths.
    pub counts: Vec<u32>,
    /// Mask dimensions (presumably `[height, width]` per COCO convention — confirm).
    pub size: [u32; 2],
}
194
/// Run-length-encoded mask whose `counts` is stored as an opaque string.
/// The string is kept as-is; this type does not decode it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CocoCompressedRle {
    /// Encoded run lengths, kept as the raw string.
    pub counts: String,
    /// Mask dimensions (presumably `[height, width]` per COCO convention — confirm).
    pub size: [u32; 2],
}
205
/// Lookup tables over a [`CocoDataset`], built by `CocoIndex::from_dataset`.
///
/// The index owns clones of the dataset records.
#[derive(Debug, Clone)]
pub struct CocoIndex {
    /// Image records keyed by `CocoImage::id`.
    pub images: HashMap<u64, CocoImage>,
    /// Category records keyed by `CocoCategory::id`.
    pub categories: HashMap<u32, CocoCategory>,
    /// Category id to label index. `from_dataset` maps every id to itself, so
    /// sparse ids such as 1, 3, 90 are kept rather than renumbered.
    pub label_indices: HashMap<u32, u64>,
    /// Annotations grouped by `CocoAnnotation::image_id`, in input order.
    pub annotations_by_image: HashMap<u64, Vec<CocoAnnotation>>,
    /// Category id to `CocoCategory::frequency`. Only categories that carry
    /// a frequency appear here.
    pub frequencies: HashMap<u32, String>,
}
222
223impl CocoIndex {
224 pub fn from_dataset(dataset: &CocoDataset) -> Self {
229 let images: HashMap<_, _> = dataset
230 .images
231 .iter()
232 .map(|img| (img.id, img.clone()))
233 .collect();
234
235 let categories: HashMap<_, _> = dataset
236 .categories
237 .iter()
238 .map(|cat| (cat.id, cat.clone()))
239 .collect();
240
241 let label_indices: HashMap<_, _> = dataset
243 .categories
244 .iter()
245 .map(|c| (c.id, c.id as u64))
246 .collect();
247
248 let frequencies: HashMap<_, _> = dataset
249 .categories
250 .iter()
251 .filter_map(|c| c.frequency.as_ref().map(|f| (c.id, f.clone())))
252 .collect();
253
254 let mut annotations_by_image: HashMap<u64, Vec<CocoAnnotation>> = HashMap::new();
255 for ann in &dataset.annotations {
256 annotations_by_image
257 .entry(ann.image_id)
258 .or_default()
259 .push(ann.clone());
260 }
261
262 Self {
263 images,
264 categories,
265 label_indices,
266 annotations_by_image,
267 frequencies,
268 }
269 }
270
271 pub fn label_name(&self, category_id: u32) -> Option<&str> {
273 self.categories.get(&category_id).map(|c| c.name.as_str())
274 }
275
276 pub fn label_index(&self, category_id: u32) -> Option<u64> {
278 self.label_indices.get(&category_id).copied()
279 }
280
281 pub fn annotations_for_image(&self, image_id: u64) -> &[CocoAnnotation] {
283 self.annotations_by_image
284 .get(&image_id)
285 .map(|v| v.as_slice())
286 .unwrap_or(&[])
287 }
288
289 pub fn frequency(&self, category_id: u32) -> Option<&str> {
291 self.frequencies.get(&category_id).map(|s| s.as_str())
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_coco_dataset_default() {
        let dataset = CocoDataset::default();
        assert!(dataset.images.is_empty());
        assert!(dataset.annotations.is_empty());
        assert!(dataset.categories.is_empty());
    }

    /// Only `images` is required at the top level; every other section defaults.
    #[test]
    fn test_dataset_requires_only_images() {
        let dataset: CocoDataset = serde_json::from_str(r#"{"images": []}"#).unwrap();
        assert!(dataset.images.is_empty());
        assert!(dataset.licenses.is_empty());
        assert!(dataset.annotations.is_empty());
        assert!(dataset.categories.is_empty());
        assert!(dataset.info.year.is_none());

        // `images` has no serde default, so omitting it must fail.
        assert!(serde_json::from_str::<CocoDataset>("{}").is_err());
    }

    #[test]
    fn test_coco_index_from_dataset() {
        let dataset = CocoDataset {
            images: vec![
                CocoImage {
                    id: 1,
                    width: 640,
                    height: 480,
                    file_name: "image1.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 2,
                    width: 800,
                    height: 600,
                    file_name: "image2.jpg".to_string(),
                    ..Default::default()
                },
            ],
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "person".to_string(),
                    supercategory: Some("human".to_string()),
                    ..Default::default()
                },
                CocoCategory {
                    id: 2,
                    name: "car".to_string(),
                    supercategory: Some("vehicle".to_string()),
                    ..Default::default()
                },
            ],
            annotations: vec![
                CocoAnnotation {
                    id: 100,
                    image_id: 1,
                    category_id: 1,
                    bbox: [10.0, 20.0, 100.0, 200.0],
                    area: 20000.0,
                    iscrowd: 0,
                    segmentation: None,
                    score: None,
                },
                CocoAnnotation {
                    id: 101,
                    image_id: 1,
                    category_id: 2,
                    bbox: [50.0, 60.0, 150.0, 100.0],
                    area: 15000.0,
                    iscrowd: 0,
                    segmentation: None,
                    score: None,
                },
            ],
            ..Default::default()
        };

        let index = CocoIndex::from_dataset(&dataset);

        assert_eq!(index.images.len(), 2);
        assert_eq!(index.images.get(&1).unwrap().file_name, "image1.jpg");

        assert_eq!(index.categories.len(), 2);
        assert_eq!(index.label_name(1), Some("person"));
        assert_eq!(index.label_name(2), Some("car"));

        assert_eq!(index.label_index(2), Some(2));
        assert_eq!(index.label_index(1), Some(1));

        let anns = index.annotations_for_image(1);
        assert_eq!(anns.len(), 2);

        let anns = index.annotations_for_image(2);
        assert!(anns.is_empty());
    }

    #[test]
    fn test_coco_segmentation_polygon_deserialize() {
        let json = r#"[[100.0, 200.0, 150.0, 250.0, 100.0, 250.0]]"#;
        let seg: CocoSegmentation = serde_json::from_str(json).unwrap();

        match seg {
            CocoSegmentation::Polygon(polys) => {
                assert_eq!(polys.len(), 1);
                assert_eq!(polys[0].len(), 6);
            }
            _ => panic!("Expected polygon segmentation"),
        }
    }

    #[test]
    fn test_coco_segmentation_rle_deserialize() {
        let json = r#"{"counts": [10, 20, 30, 40], "size": [100, 200]}"#;
        let seg: CocoSegmentation = serde_json::from_str(json).unwrap();

        match seg {
            CocoSegmentation::Rle(rle) => {
                assert_eq!(rle.counts, vec![10, 20, 30, 40]);
                assert_eq!(rle.size, [100, 200]);
            }
            _ => panic!("Expected RLE segmentation"),
        }
    }

    /// A string `counts` must fall through the `Rle` variant to `CompressedRle`.
    #[test]
    fn test_coco_segmentation_compressed_rle_deserialize() {
        let json = r#"{"counts": "PPl12c0", "size": [427, 640]}"#;
        let seg: CocoSegmentation = serde_json::from_str(json).unwrap();

        match seg {
            CocoSegmentation::CompressedRle(rle) => {
                assert_eq!(rle.counts, "PPl12c0");
                assert_eq!(rle.size, [427, 640]);
            }
            _ => panic!("Expected compressed RLE segmentation"),
        }
    }

    #[test]
    fn test_coco_annotation_roundtrip() {
        let ann = CocoAnnotation {
            id: 12345,
            image_id: 67890,
            category_id: 1,
            bbox: [100.5, 200.5, 50.0, 80.0],
            area: 4000.0,
            iscrowd: 0,
            segmentation: Some(CocoSegmentation::Polygon(vec![vec![
                100.0, 200.0, 150.0, 200.0, 150.0, 280.0, 100.0, 280.0,
            ]])),
            score: None,
        };

        let json = serde_json::to_string(&ann).unwrap();
        // `score: None` is skipped on output rather than written as null.
        assert!(!json.contains("score"));

        let restored: CocoAnnotation = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.id, ann.id);
        assert_eq!(restored.image_id, ann.image_id);
        assert_eq!(restored.category_id, ann.category_id);
        assert_eq!(restored.bbox, ann.bbox);
        assert_eq!(restored.area, ann.area);
        assert_eq!(restored.iscrowd, ann.iscrowd);
        assert_eq!(restored.score, None);
        match restored.segmentation {
            Some(CocoSegmentation::Polygon(polys)) => assert_eq!(
                polys,
                vec![vec![100.0, 200.0, 150.0, 200.0, 150.0, 280.0, 100.0, 280.0]]
            ),
            other => panic!("Expected polygon segmentation, got {:?}", other),
        }
    }

    #[test]
    fn test_coco_index_preserves_category_id() {
        let dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                width: 640,
                height: 480,
                file_name: "img.jpg".to_string(),
                ..Default::default()
            }],
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "person".to_string(),
                    supercategory: None,
                    ..Default::default()
                },
                CocoCategory {
                    id: 3,
                    name: "car".to_string(),
                    supercategory: None,
                    ..Default::default()
                },
                CocoCategory {
                    id: 90,
                    name: "toothbrush".to_string(),
                    supercategory: None,
                    ..Default::default()
                },
            ],
            annotations: vec![],
            ..Default::default()
        };

        let index = CocoIndex::from_dataset(&dataset);

        assert_eq!(index.label_index(1), Some(1));
        assert_eq!(index.label_index(3), Some(3));
        assert_eq!(index.label_index(90), Some(90));
        assert_eq!(index.label_index(2), None);
        assert_eq!(index.label_index(50), None);
    }

    /// Only categories that carry a `frequency` are reachable via `CocoIndex::frequency`.
    #[test]
    fn test_coco_index_frequency() {
        let dataset = CocoDataset {
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "aerosol_can".to_string(),
                    frequency: Some("c".to_string()),
                    ..Default::default()
                },
                CocoCategory {
                    id: 2,
                    name: "person".to_string(),
                    ..Default::default()
                },
            ],
            ..Default::default()
        };

        let index = CocoIndex::from_dataset(&dataset);

        assert_eq!(index.frequency(1), Some("c"));
        assert_eq!(index.frequency(2), None);
        assert_eq!(index.frequency(3), None);
        assert_eq!(index.frequencies.len(), 1);
    }

    #[test]
    fn test_lvis_image_deserialize() {
        let json = r#"{
            "id": 397133,
            "width": 640,
            "height": 480,
            "file_name": "000000397133.jpg",
            "neg_category_ids": [5, 12, 87],
            "not_exhaustive_category_ids": [3, 45]
        }"#;
        let image: CocoImage = serde_json::from_str(json).unwrap();
        assert_eq!(image.neg_category_ids, Some(vec![5, 12, 87]));
        assert_eq!(image.not_exhaustive_category_ids, Some(vec![3, 45]));
    }

    #[test]
    fn test_lvis_category_deserialize() {
        let json = r#"{
            "id": 1,
            "name": "aerosol_can",
            "synset": "aerosol.n.02",
            "frequency": "c",
            "synonyms": ["aerosol_can", "spray_can"],
            "def": "a dispenser that holds a substance under pressure",
            "image_count": 57,
            "instance_count": 98
        }"#;
        let cat: CocoCategory = serde_json::from_str(json).unwrap();
        assert_eq!(cat.synset, Some("aerosol.n.02".to_string()));
        assert_eq!(cat.frequency, Some("c".to_string()));
        assert_eq!(
            cat.synonyms,
            Some(vec!["aerosol_can".to_string(), "spray_can".to_string()])
        );
        assert_eq!(
            cat.def,
            Some("a dispenser that holds a substance under pressure".to_string())
        );
        assert_eq!(cat.image_count, Some(57));
        assert_eq!(cat.instance_count, Some(98));
    }

    #[test]
    fn test_standard_coco_still_parses() {
        let json = r#"{"id": 1, "name": "person", "supercategory": "human"}"#;
        let cat: CocoCategory = serde_json::from_str(json).unwrap();
        assert_eq!(cat.name, "person");
        assert_eq!(cat.synset, None);
        assert_eq!(cat.frequency, None);
        assert_eq!(cat.synonyms, None);
        assert_eq!(cat.def, None);
    }
}