1use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13#[derive(Debug, Clone, Default, Serialize, Deserialize)]
18pub struct CocoDataset {
19 #[serde(default)]
21 pub info: CocoInfo,
22 #[serde(default)]
24 pub licenses: Vec<CocoLicense>,
25 pub images: Vec<CocoImage>,
27 #[serde(default)]
29 pub annotations: Vec<CocoAnnotation>,
30 #[serde(default)]
32 pub categories: Vec<CocoCategory>,
33}
34
35#[derive(Debug, Clone, Default, Serialize, Deserialize)]
37pub struct CocoInfo {
38 #[serde(default)]
40 pub year: Option<u32>,
41 #[serde(default)]
43 pub version: Option<String>,
44 #[serde(default)]
46 pub description: Option<String>,
47 #[serde(default)]
49 pub contributor: Option<String>,
50 #[serde(default)]
52 pub url: Option<String>,
53 #[serde(default)]
55 pub date_created: Option<String>,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct CocoLicense {
61 pub id: u32,
63 pub name: String,
65 #[serde(default)]
67 pub url: Option<String>,
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize, Default)]
74pub struct CocoImage {
75 pub id: u64,
77 pub width: u32,
79 pub height: u32,
81 #[serde(default)]
88 pub file_name: String,
89 #[serde(default)]
91 pub license: Option<u32>,
92 #[serde(default)]
94 pub flickr_url: Option<String>,
95 #[serde(default)]
97 pub coco_url: Option<String>,
98 #[serde(default)]
100 pub date_captured: Option<String>,
101 #[serde(default, skip_serializing_if = "Option::is_none")]
103 pub neg_category_ids: Option<Vec<u32>>,
104 #[serde(default, skip_serializing_if = "Option::is_none")]
106 pub not_exhaustive_category_ids: Option<Vec<u32>>,
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize, Default)]
113pub struct CocoCategory {
114 pub id: u32,
116 pub name: String,
118 #[serde(default)]
120 pub supercategory: Option<String>,
121 #[serde(default, skip_serializing_if = "Option::is_none")]
123 pub synset: Option<String>,
124 #[serde(default, skip_serializing_if = "Option::is_none")]
126 pub frequency: Option<String>,
127 #[serde(default, skip_serializing_if = "Option::is_none")]
129 pub synonyms: Option<Vec<String>>,
130 #[serde(default, skip_serializing_if = "Option::is_none")]
132 pub def: Option<String>,
133 #[serde(default, skip_serializing_if = "Option::is_none")]
135 pub image_count: Option<u32>,
136 #[serde(default, skip_serializing_if = "Option::is_none")]
138 pub instance_count: Option<u32>,
139}
140
141#[derive(Debug, Clone, Default, Serialize, Deserialize)]
147pub struct CocoAnnotation {
148 pub id: u64,
150 pub image_id: u64,
152 pub category_id: u32,
154 pub bbox: [f64; 4],
156 #[serde(default)]
158 pub area: f64,
159 #[serde(default)]
161 pub iscrowd: u8,
162 #[serde(default, skip_serializing_if = "Option::is_none")]
164 pub segmentation: Option<CocoSegmentation>,
165 #[serde(default, skip_serializing_if = "Option::is_none")]
167 pub score: Option<f64>,
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize)]
177#[serde(untagged)]
178pub enum CocoSegmentation {
179 Polygon(Vec<Vec<f64>>),
183 Rle(CocoRle),
185 CompressedRle(CocoCompressedRle),
187}
188
189#[derive(Debug, Clone, Serialize, Deserialize)]
194pub struct CocoRle {
195 pub counts: Vec<u32>,
197 pub size: [u32; 2],
199}
200
201#[derive(Debug, Clone, Serialize, Deserialize)]
205pub struct CocoCompressedRle {
206 pub counts: String,
208 pub size: [u32; 2],
210}
211
212#[derive(Debug, Clone)]
216pub struct CocoIndex {
217 pub images: HashMap<u64, CocoImage>,
219 pub categories: HashMap<u32, CocoCategory>,
221 pub label_indices: HashMap<u32, u64>,
223 pub annotations_by_image: HashMap<u64, Vec<CocoAnnotation>>,
225 pub frequencies: HashMap<u32, String>,
227}
228
229impl CocoIndex {
230 pub fn from_dataset(dataset: &CocoDataset) -> Self {
235 let images: HashMap<_, _> = dataset
236 .images
237 .iter()
238 .map(|img| (img.id, img.clone()))
239 .collect();
240
241 let categories: HashMap<_, _> = dataset
242 .categories
243 .iter()
244 .map(|cat| (cat.id, cat.clone()))
245 .collect();
246
247 let label_indices: HashMap<_, _> = dataset
249 .categories
250 .iter()
251 .map(|c| (c.id, c.id as u64))
252 .collect();
253
254 let frequencies: HashMap<_, _> = dataset
255 .categories
256 .iter()
257 .filter_map(|c| c.frequency.as_ref().map(|f| (c.id, f.clone())))
258 .collect();
259
260 let mut annotations_by_image: HashMap<u64, Vec<CocoAnnotation>> = HashMap::new();
261 for ann in &dataset.annotations {
262 annotations_by_image
263 .entry(ann.image_id)
264 .or_default()
265 .push(ann.clone());
266 }
267
268 Self {
269 images,
270 categories,
271 label_indices,
272 annotations_by_image,
273 frequencies,
274 }
275 }
276
277 pub fn label_name(&self, category_id: u32) -> Option<&str> {
279 self.categories.get(&category_id).map(|c| c.name.as_str())
280 }
281
282 pub fn label_index(&self, category_id: u32) -> Option<u64> {
284 self.label_indices.get(&category_id).copied()
285 }
286
287 pub fn annotations_for_image(&self, image_id: u64) -> &[CocoAnnotation] {
289 self.annotations_by_image
290 .get(&image_id)
291 .map(|v| v.as_slice())
292 .unwrap_or(&[])
293 }
294
295 pub fn frequency(&self, category_id: u32) -> Option<&str> {
297 self.frequencies.get(&category_id).map(|s| s.as_str())
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_coco_dataset_default() {
        let dataset = CocoDataset::default();
        assert!(dataset.images.is_empty());
        assert!(dataset.annotations.is_empty());
        assert!(dataset.categories.is_empty());
    }

    #[test]
    fn test_minimal_dataset_deserialize() {
        // Only `images` is required; every other top-level key falls back to its default.
        let dataset: CocoDataset = serde_json::from_str(r#"{"images": []}"#).unwrap();
        assert!(dataset.images.is_empty());
        assert!(dataset.licenses.is_empty());
        assert!(dataset.annotations.is_empty());
        assert!(dataset.categories.is_empty());
        assert!(dataset.info.year.is_none());

        // `images` carries no `#[serde(default)]`, so omitting it is an error.
        assert!(serde_json::from_str::<CocoDataset>("{}").is_err());
    }

    #[test]
    fn test_coco_index_from_dataset() {
        let dataset = CocoDataset {
            images: vec![
                CocoImage {
                    id: 1,
                    width: 640,
                    height: 480,
                    file_name: "image1.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 2,
                    width: 800,
                    height: 600,
                    file_name: "image2.jpg".to_string(),
                    ..Default::default()
                },
            ],
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "person".to_string(),
                    supercategory: Some("human".to_string()),
                    ..Default::default()
                },
                CocoCategory {
                    id: 2,
                    name: "car".to_string(),
                    supercategory: Some("vehicle".to_string()),
                    ..Default::default()
                },
            ],
            annotations: vec![
                CocoAnnotation {
                    id: 100,
                    image_id: 1,
                    category_id: 1,
                    bbox: [10.0, 20.0, 100.0, 200.0],
                    area: 20000.0,
                    iscrowd: 0,
                    segmentation: None,
                    score: None,
                },
                CocoAnnotation {
                    id: 101,
                    image_id: 1,
                    category_id: 2,
                    bbox: [50.0, 60.0, 150.0, 100.0],
                    area: 15000.0,
                    iscrowd: 0,
                    segmentation: None,
                    score: None,
                },
            ],
            ..Default::default()
        };

        let index = CocoIndex::from_dataset(&dataset);

        assert_eq!(index.images.len(), 2);
        assert_eq!(index.images.get(&1).unwrap().file_name, "image1.jpg");

        assert_eq!(index.categories.len(), 2);
        assert_eq!(index.label_name(1), Some("person"));
        assert_eq!(index.label_name(2), Some("car"));

        assert_eq!(index.label_index(1), Some(1));
        assert_eq!(index.label_index(2), Some(2));

        // Plain COCO categories carry no LVIS frequency.
        assert_eq!(index.frequency(1), None);

        let anns = index.annotations_for_image(1);
        assert_eq!(anns.len(), 2);
        // Annotations keep their dataset order within an image.
        assert_eq!(anns[0].id, 100);
        assert_eq!(anns[1].id, 101);

        let anns = index.annotations_for_image(2);
        assert!(anns.is_empty());
    }

    #[test]
    fn test_coco_index_frequency() {
        let dataset = CocoDataset {
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "aerosol_can".to_string(),
                    frequency: Some("c".to_string()),
                    ..Default::default()
                },
                CocoCategory {
                    id: 2,
                    name: "person".to_string(),
                    ..Default::default()
                },
            ],
            ..Default::default()
        };

        let index = CocoIndex::from_dataset(&dataset);

        assert_eq!(index.frequency(1), Some("c"));
        assert_eq!(index.frequency(2), None);
        assert_eq!(index.frequency(3), None);
    }

    #[test]
    fn test_coco_segmentation_polygon_deserialize() {
        let json = r#"[[100.0, 200.0, 150.0, 250.0, 100.0, 250.0]]"#;
        let seg: CocoSegmentation = serde_json::from_str(json).unwrap();

        match seg {
            CocoSegmentation::Polygon(polys) => {
                assert_eq!(polys.len(), 1);
                assert_eq!(polys[0].len(), 6);
            }
            _ => panic!("Expected polygon segmentation"),
        }
    }

    #[test]
    fn test_coco_segmentation_rle_deserialize() {
        let json = r#"{"counts": [10, 20, 30, 40], "size": [100, 200]}"#;
        let seg: CocoSegmentation = serde_json::from_str(json).unwrap();

        match seg {
            CocoSegmentation::Rle(rle) => {
                assert_eq!(rle.counts, vec![10, 20, 30, 40]);
                assert_eq!(rle.size, [100, 200]);
            }
            _ => panic!("Expected RLE segmentation"),
        }
    }

    #[test]
    fn test_coco_segmentation_compressed_rle_deserialize() {
        // A string `counts` must fall through `Rle` and land on `CompressedRle`.
        let json = r#"{"counts": "PPYo0", "size": [427, 640]}"#;
        let seg: CocoSegmentation = serde_json::from_str(json).unwrap();

        match seg {
            CocoSegmentation::CompressedRle(rle) => {
                assert_eq!(rle.counts, "PPYo0");
                assert_eq!(rle.size, [427, 640]);
            }
            _ => panic!("Expected compressed RLE segmentation"),
        }
    }

    #[test]
    fn test_coco_annotation_roundtrip() {
        let ann = CocoAnnotation {
            id: 12345,
            image_id: 67890,
            category_id: 1,
            bbox: [100.5, 200.5, 50.0, 80.0],
            area: 4000.0,
            iscrowd: 0,
            segmentation: Some(CocoSegmentation::Polygon(vec![vec![
                100.0, 200.0, 150.0, 200.0, 150.0, 280.0, 100.0, 280.0,
            ]])),
            score: None,
        };

        let json = serde_json::to_string(&ann).unwrap();
        // `score` is `None` and marked `skip_serializing_if`, so it must not be emitted.
        assert!(!json.contains("score"));

        let restored: CocoAnnotation = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.id, ann.id);
        assert_eq!(restored.image_id, ann.image_id);
        assert_eq!(restored.category_id, ann.category_id);
        assert_eq!(restored.bbox, ann.bbox);
        assert_eq!(restored.area, ann.area);
        assert_eq!(restored.iscrowd, ann.iscrowd);
        assert_eq!(restored.score, None);

        match restored.segmentation {
            Some(CocoSegmentation::Polygon(polys)) => {
                assert_eq!(
                    polys,
                    vec![vec![100.0, 200.0, 150.0, 200.0, 150.0, 280.0, 100.0, 280.0]]
                );
            }
            _ => panic!("Expected polygon segmentation after roundtrip"),
        }
    }

    #[test]
    fn test_coco_index_preserves_category_id() {
        let dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                width: 640,
                height: 480,
                file_name: "img.jpg".to_string(),
                ..Default::default()
            }],
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "person".to_string(),
                    supercategory: None,
                    ..Default::default()
                },
                CocoCategory {
                    id: 3,
                    name: "car".to_string(),
                    supercategory: None,
                    ..Default::default()
                },
                CocoCategory {
                    id: 90,
                    name: "toothbrush".to_string(),
                    supercategory: None,
                    ..Default::default()
                },
            ],
            annotations: vec![],
            ..Default::default()
        };

        let index = CocoIndex::from_dataset(&dataset);

        // Sparse COCO ids map to themselves rather than being compacted.
        assert_eq!(index.label_index(1), Some(1));
        assert_eq!(index.label_index(3), Some(3));
        assert_eq!(index.label_index(90), Some(90));
        assert_eq!(index.label_index(2), None);
        assert_eq!(index.label_index(50), None);
    }

    #[test]
    fn test_lvis_image_deserialize() {
        let json = r#"{
            "id": 397133,
            "width": 640,
            "height": 480,
            "file_name": "000000397133.jpg",
            "neg_category_ids": [5, 12, 87],
            "not_exhaustive_category_ids": [3, 45]
        }"#;
        let image: CocoImage = serde_json::from_str(json).unwrap();
        assert_eq!(image.neg_category_ids, Some(vec![5, 12, 87]));
        assert_eq!(image.not_exhaustive_category_ids, Some(vec![3, 45]));
    }

    #[test]
    fn test_lvis_category_deserialize() {
        let json = r#"{
            "id": 1,
            "name": "aerosol_can",
            "synset": "aerosol.n.02",
            "frequency": "c",
            "synonyms": ["aerosol_can", "spray_can"],
            "def": "a dispenser that holds a substance under pressure",
            "image_count": 57,
            "instance_count": 98
        }"#;
        let cat: CocoCategory = serde_json::from_str(json).unwrap();
        assert_eq!(cat.synset, Some("aerosol.n.02".to_string()));
        assert_eq!(cat.frequency, Some("c".to_string()));
        assert_eq!(
            cat.synonyms,
            Some(vec!["aerosol_can".to_string(), "spray_can".to_string()])
        );
        assert_eq!(
            cat.def,
            Some("a dispenser that holds a substance under pressure".to_string())
        );
        assert_eq!(cat.image_count, Some(57));
        assert_eq!(cat.instance_count, Some(98));
    }

    #[test]
    fn test_standard_coco_still_parses() {
        let json = r#"{"id": 1, "name": "person", "supercategory": "human"}"#;
        let cat: CocoCategory = serde_json::from_str(json).unwrap();
        assert_eq!(cat.name, "person");
        assert_eq!(cat.synset, None);
        assert_eq!(cat.frequency, None);
        assert_eq!(cat.synonyms, None);
        assert_eq!(cat.def, None);
    }
}