Skip to main content

edgefirst_client/coco/
types.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4//! COCO JSON data structures for serde serialization/deserialization.
5//!
6//! Supports object detection and instance segmentation annotation types.
7//! Keypoints, captions, and panoptic segmentation are NOT supported in this
8//! version.
9
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13/// Top-level COCO dataset structure.
14///
15/// This is the root structure for COCO annotation files like
16/// `instances_train2017.json`.
17#[derive(Debug, Clone, Default, Serialize, Deserialize)]
18pub struct CocoDataset {
19    /// Dataset metadata (optional but commonly present).
20    #[serde(default)]
21    pub info: CocoInfo,
22    /// License information for the images.
23    #[serde(default)]
24    pub licenses: Vec<CocoLicense>,
25    /// List of images in the dataset.
26    pub images: Vec<CocoImage>,
27    /// List of annotations (one per object instance).
28    #[serde(default)]
29    pub annotations: Vec<CocoAnnotation>,
30    /// List of object categories/classes.
31    #[serde(default)]
32    pub categories: Vec<CocoCategory>,
33}
34
35/// Dataset metadata.
36#[derive(Debug, Clone, Default, Serialize, Deserialize)]
37pub struct CocoInfo {
38    /// Year the dataset was created.
39    #[serde(default)]
40    pub year: Option<u32>,
41    /// Version string.
42    #[serde(default)]
43    pub version: Option<String>,
44    /// Dataset description.
45    #[serde(default)]
46    pub description: Option<String>,
47    /// Dataset contributor.
48    #[serde(default)]
49    pub contributor: Option<String>,
50    /// Dataset URL.
51    #[serde(default)]
52    pub url: Option<String>,
53    /// Date the dataset was created.
54    #[serde(default)]
55    pub date_created: Option<String>,
56}
57
58/// License information.
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct CocoLicense {
61    /// Unique license ID.
62    pub id: u32,
63    /// License name.
64    pub name: String,
65    /// License URL.
66    #[serde(default)]
67    pub url: Option<String>,
68}
69
70/// Image metadata.
71///
72/// Each image has a unique ID and associated metadata.
73#[derive(Debug, Clone, Serialize, Deserialize, Default)]
74pub struct CocoImage {
75    /// Unique image ID.
76    pub id: u64,
77    /// Image width in pixels.
78    pub width: u32,
79    /// Image height in pixels.
80    pub height: u32,
81    /// Filename (relative path within the images folder).
82    pub file_name: String,
83    /// License ID (references `CocoLicense.id`).
84    #[serde(default)]
85    pub license: Option<u32>,
86    /// Flickr URL (if from Flickr).
87    #[serde(default)]
88    pub flickr_url: Option<String>,
89    /// COCO download URL.
90    #[serde(default)]
91    pub coco_url: Option<String>,
92    /// Date the image was captured.
93    #[serde(default)]
94    pub date_captured: Option<String>,
95    /// LVIS: Categories verified as absent from this image.
96    #[serde(default, skip_serializing_if = "Option::is_none")]
97    pub neg_category_ids: Option<Vec<u32>>,
98    /// LVIS: Categories with possibly incomplete annotation in this image.
99    #[serde(default, skip_serializing_if = "Option::is_none")]
100    pub not_exhaustive_category_ids: Option<Vec<u32>>,
101}
102
103/// Category definition.
104///
105/// Categories define the object classes used in the dataset.
106#[derive(Debug, Clone, Serialize, Deserialize, Default)]
107pub struct CocoCategory {
108    /// Unique category ID.
109    pub id: u32,
110    /// Category name (e.g., "person", "car").
111    pub name: String,
112    /// Parent category name (e.g., "human" for "person").
113    #[serde(default)]
114    pub supercategory: Option<String>,
115    /// LVIS: WordNet synset identifier (e.g., "aerosol.n.02").
116    #[serde(default, skip_serializing_if = "Option::is_none")]
117    pub synset: Option<String>,
118    /// LVIS: Frequency group — "f" (frequent), "c" (common), "r" (rare).
119    #[serde(default, skip_serializing_if = "Option::is_none")]
120    pub frequency: Option<String>,
121    /// LVIS: Alternate names for this category.
122    #[serde(default, skip_serializing_if = "Option::is_none")]
123    pub synonyms: Option<Vec<String>>,
124    /// LVIS: Natural language definition.
125    #[serde(default, skip_serializing_if = "Option::is_none")]
126    pub def: Option<String>,
127    /// LVIS: Number of images containing this category.
128    #[serde(default, skip_serializing_if = "Option::is_none")]
129    pub image_count: Option<u32>,
130    /// LVIS: Total annotated instances of this category.
131    #[serde(default, skip_serializing_if = "Option::is_none")]
132    pub instance_count: Option<u32>,
133}
134
135/// Annotation for object detection and instance segmentation.
136///
137/// Each annotation represents a single object instance in an image.
138///
139/// Note: Keypoints, captions, and panoptic fields are NOT supported.
140#[derive(Debug, Clone, Default, Serialize, Deserialize)]
141pub struct CocoAnnotation {
142    /// Unique annotation ID.
143    pub id: u64,
144    /// ID of the image containing this object.
145    pub image_id: u64,
146    /// Category ID of this object.
147    pub category_id: u32,
148    /// Bounding box: `[x, y, width, height]` in pixels (top-left corner).
149    pub bbox: [f64; 4],
150    /// Area of the segmentation mask in pixels².
151    #[serde(default)]
152    pub area: f64,
153    /// Whether this is a crowd annotation (0 = single instance, 1 = crowd).
154    #[serde(default)]
155    pub iscrowd: u8,
156    /// Segmentation mask (polygon or RLE format).
157    #[serde(default, skip_serializing_if = "Option::is_none")]
158    pub segmentation: Option<CocoSegmentation>,
159    /// Detection confidence score (present in COCO detection results).
160    #[serde(default, skip_serializing_if = "Option::is_none")]
161    pub score: Option<f64>,
162}
163
164/// Segmentation format: polygon array or RLE.
165///
166/// COCO supports two segmentation formats:
167/// - **Polygon**: For single instances (`iscrowd=0`), uses nested coordinate
168///   arrays
169/// - **RLE**: For crowds (`iscrowd=1`), uses run-length encoding
170#[derive(Debug, Clone, Serialize, Deserialize)]
171#[serde(untagged)]
172pub enum CocoSegmentation {
173    /// Polygon format: `[[x1,y1,x2,y2,...], [x3,y3,...]]`
174    ///
175    /// Multiple polygons represent disjoint regions of the same object.
176    Polygon(Vec<Vec<f64>>),
177    /// Uncompressed RLE format with counts array.
178    Rle(CocoRle),
179    /// Compressed RLE format with LEB128-encoded counts string.
180    CompressedRle(CocoCompressedRle),
181}
182
183/// Uncompressed RLE (Run-Length Encoding) segmentation.
184///
185/// The counts array alternates between background and foreground pixel runs,
186/// starting with background. The encoding is **column-major** (Fortran order).
187#[derive(Debug, Clone, Serialize, Deserialize)]
188pub struct CocoRle {
189    /// Run-length counts: `[bg_run, fg_run, bg_run, fg_run, ...]`
190    pub counts: Vec<u32>,
191    /// Image size as `[height, width]` (NOT `[width, height]`!)
192    pub size: [u32; 2],
193}
194
195/// Compressed RLE segmentation (LEB128 encoded).
196///
197/// Used by pycocotools for more compact storage.
198#[derive(Debug, Clone, Serialize, Deserialize)]
199pub struct CocoCompressedRle {
200    /// LEB128-encoded counts string.
201    pub counts: String,
202    /// Image size as `[height, width]`.
203    pub size: [u32; 2],
204}
205
206/// Lookup tables for efficient COCO data access.
207///
208/// Builds indexes from a `CocoDataset` for O(1) lookups.
209#[derive(Debug, Clone)]
210pub struct CocoIndex {
211    /// `image_id` → `CocoImage`
212    pub images: HashMap<u64, CocoImage>,
213    /// `category_id` → `CocoCategory`
214    pub categories: HashMap<u32, CocoCategory>,
215    /// `category_id` → `label_index` (source-faithful, preserves original category_id)
216    pub label_indices: HashMap<u32, u64>,
217    /// `image_id` → `Vec<CocoAnnotation>`
218    pub annotations_by_image: HashMap<u64, Vec<CocoAnnotation>>,
219    /// `category_id` → frequency group ("f", "c", "r")
220    pub frequencies: HashMap<u32, String>,
221}
222
223impl CocoIndex {
224    /// Build lookup index from a `CocoDataset`.
225    ///
226    /// Creates efficient lookup tables for accessing images, categories,
227    /// and annotations by their IDs.
228    pub fn from_dataset(dataset: &CocoDataset) -> Self {
229        let images: HashMap<_, _> = dataset
230            .images
231            .iter()
232            .map(|img| (img.id, img.clone()))
233            .collect();
234
235        let categories: HashMap<_, _> = dataset
236            .categories
237            .iter()
238            .map(|cat| (cat.id, cat.clone()))
239            .collect();
240
241        // Preserve source category_id as label_index (source-faithful)
242        let label_indices: HashMap<_, _> = dataset
243            .categories
244            .iter()
245            .map(|c| (c.id, c.id as u64))
246            .collect();
247
248        let frequencies: HashMap<_, _> = dataset
249            .categories
250            .iter()
251            .filter_map(|c| c.frequency.as_ref().map(|f| (c.id, f.clone())))
252            .collect();
253
254        let mut annotations_by_image: HashMap<u64, Vec<CocoAnnotation>> = HashMap::new();
255        for ann in &dataset.annotations {
256            annotations_by_image
257                .entry(ann.image_id)
258                .or_default()
259                .push(ann.clone());
260        }
261
262        Self {
263            images,
264            categories,
265            label_indices,
266            annotations_by_image,
267            frequencies,
268        }
269    }
270
271    /// Get the label name for a category ID.
272    pub fn label_name(&self, category_id: u32) -> Option<&str> {
273        self.categories.get(&category_id).map(|c| c.name.as_str())
274    }
275
276    /// Get the label index for a category ID.
277    pub fn label_index(&self, category_id: u32) -> Option<u64> {
278        self.label_indices.get(&category_id).copied()
279    }
280
281    /// Get annotations for an image.
282    pub fn annotations_for_image(&self, image_id: u64) -> &[CocoAnnotation] {
283        self.annotations_by_image
284            .get(&image_id)
285            .map(|v| v.as_slice())
286            .unwrap_or(&[])
287    }
288
289    /// Get the frequency group for a category ID.
290    pub fn frequency(&self, category_id: u32) -> Option<&str> {
291        self.frequencies.get(&category_id).map(|s| s.as_str())
292    }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_coco_dataset_default() {
        let dataset = CocoDataset::default();
        assert!(dataset.images.is_empty());
        assert!(dataset.annotations.is_empty());
        assert!(dataset.categories.is_empty());
    }

    #[test]
    fn test_coco_dataset_minimal_deserialize() {
        // Only `images` is required; every other top-level key has a default.
        let dataset: CocoDataset = serde_json::from_str(r#"{"images": []}"#).unwrap();
        assert!(dataset.images.is_empty());
        assert!(dataset.licenses.is_empty());
        assert!(dataset.annotations.is_empty());
        assert!(dataset.categories.is_empty());
        assert_eq!(dataset.info.year, None);
    }

    #[test]
    fn test_coco_index_from_dataset() {
        let dataset = CocoDataset {
            images: vec![
                CocoImage {
                    id: 1,
                    width: 640,
                    height: 480,
                    file_name: "image1.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 2,
                    width: 800,
                    height: 600,
                    file_name: "image2.jpg".to_string(),
                    ..Default::default()
                },
            ],
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "person".to_string(),
                    supercategory: Some("human".to_string()),
                    ..Default::default()
                },
                CocoCategory {
                    id: 2,
                    name: "car".to_string(),
                    supercategory: Some("vehicle".to_string()),
                    ..Default::default()
                },
            ],
            annotations: vec![
                CocoAnnotation {
                    id: 100,
                    image_id: 1,
                    category_id: 1,
                    bbox: [10.0, 20.0, 100.0, 200.0],
                    area: 20000.0,
                    iscrowd: 0,
                    segmentation: None,
                    score: None,
                },
                CocoAnnotation {
                    id: 101,
                    image_id: 1,
                    category_id: 2,
                    bbox: [50.0, 60.0, 150.0, 100.0],
                    area: 15000.0,
                    iscrowd: 0,
                    segmentation: None,
                    score: None,
                },
            ],
            ..Default::default()
        };

        let index = CocoIndex::from_dataset(&dataset);

        // Check images lookup
        assert_eq!(index.images.len(), 2);
        assert_eq!(index.images.get(&1).unwrap().file_name, "image1.jpg");

        // Check categories lookup
        assert_eq!(index.categories.len(), 2);
        assert_eq!(index.label_name(1), Some("person"));
        assert_eq!(index.label_name(2), Some("car"));

        // Check source-faithful label indices (category_id preserved)
        assert_eq!(index.label_index(2), Some(2)); // car
        assert_eq!(index.label_index(1), Some(1)); // person

        // Check annotations by image, in source order
        let anns = index.annotations_for_image(1);
        assert_eq!(anns.len(), 2);
        assert_eq!(anns[0].id, 100);
        assert_eq!(anns[1].id, 101);

        let anns = index.annotations_for_image(2);
        assert!(anns.is_empty());
    }

    #[test]
    fn test_coco_index_frequency_lookup() {
        let dataset = CocoDataset {
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "aerosol_can".to_string(),
                    frequency: Some("c".to_string()),
                    ..Default::default()
                },
                CocoCategory {
                    id: 2,
                    name: "person".to_string(),
                    ..Default::default()
                },
            ],
            ..Default::default()
        };

        let index = CocoIndex::from_dataset(&dataset);

        assert_eq!(index.frequency(1), Some("c"));
        // Category without a frequency group, and an unknown category.
        assert_eq!(index.frequency(2), None);
        assert_eq!(index.frequency(99), None);
    }

    #[test]
    fn test_coco_segmentation_polygon_deserialize() {
        let json = r#"[[100.0, 200.0, 150.0, 250.0, 100.0, 250.0]]"#;
        let seg: CocoSegmentation = serde_json::from_str(json).unwrap();

        match seg {
            CocoSegmentation::Polygon(polys) => {
                assert_eq!(polys.len(), 1);
                assert_eq!(polys[0].len(), 6);
            }
            _ => panic!("Expected polygon segmentation"),
        }
    }

    #[test]
    fn test_coco_segmentation_rle_deserialize() {
        let json = r#"{"counts": [10, 20, 30, 40], "size": [100, 200]}"#;
        let seg: CocoSegmentation = serde_json::from_str(json).unwrap();

        match seg {
            CocoSegmentation::Rle(rle) => {
                assert_eq!(rle.counts, vec![10, 20, 30, 40]);
                assert_eq!(rle.size, [100, 200]);
            }
            _ => panic!("Expected RLE segmentation"),
        }
    }

    #[test]
    fn test_coco_segmentation_compressed_rle_deserialize() {
        // A string `counts` must fall through the untagged `Rle` variant.
        let json = r#"{"counts": "PPYo11O0O100O1O1N2O0O", "size": [480, 640]}"#;
        let seg: CocoSegmentation = serde_json::from_str(json).unwrap();

        match seg {
            CocoSegmentation::CompressedRle(rle) => {
                assert_eq!(rle.counts, "PPYo11O0O100O1O1N2O0O");
                assert_eq!(rle.size, [480, 640]);
            }
            _ => panic!("Expected compressed RLE segmentation"),
        }
    }

    #[test]
    fn test_coco_annotation_roundtrip() {
        let ann = CocoAnnotation {
            id: 12345,
            image_id: 67890,
            category_id: 1,
            bbox: [100.5, 200.5, 50.0, 80.0],
            area: 4000.0,
            iscrowd: 0,
            segmentation: Some(CocoSegmentation::Polygon(vec![vec![
                100.0, 200.0, 150.0, 200.0, 150.0, 280.0, 100.0, 280.0,
            ]])),
            score: None,
        };

        let json = serde_json::to_string(&ann).unwrap();
        // `score: None` must be omitted from the output entirely.
        assert!(!json.contains("\"score\""));

        let restored: CocoAnnotation = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.id, ann.id);
        assert_eq!(restored.image_id, ann.image_id);
        assert_eq!(restored.category_id, ann.category_id);
        assert_eq!(restored.bbox, ann.bbox);
        assert_eq!(restored.area, ann.area);
        assert_eq!(restored.iscrowd, ann.iscrowd);
        assert_eq!(restored.score, None);
        match restored.segmentation {
            Some(CocoSegmentation::Polygon(polys)) => {
                assert_eq!(polys.len(), 1);
                assert_eq!(
                    polys[0],
                    vec![100.0, 200.0, 150.0, 200.0, 150.0, 280.0, 100.0, 280.0]
                );
            }
            other => panic!("Expected polygon segmentation, got {other:?}"),
        }
    }

    #[test]
    fn test_coco_annotation_optional_fields_default() {
        // Integer bbox values must be accepted for the f64 array.
        let json = r#"{"id": 1, "image_id": 2, "category_id": 3, "bbox": [0, 0, 10, 20]}"#;
        let ann: CocoAnnotation = serde_json::from_str(json).unwrap();

        assert_eq!(ann.bbox, [0.0, 0.0, 10.0, 20.0]);
        assert_eq!(ann.area, 0.0);
        assert_eq!(ann.iscrowd, 0);
        assert!(ann.segmentation.is_none());
        assert!(ann.score.is_none());
    }

    #[test]
    fn test_coco_index_preserves_category_id() {
        // Non-contiguous category IDs (typical of COCO/LVIS datasets)
        let dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                width: 640,
                height: 480,
                file_name: "img.jpg".to_string(),
                ..Default::default()
            }],
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "person".to_string(),
                    supercategory: None,
                    ..Default::default()
                },
                CocoCategory {
                    id: 3,
                    name: "car".to_string(),
                    supercategory: None,
                    ..Default::default()
                },
                CocoCategory {
                    id: 90,
                    name: "toothbrush".to_string(),
                    supercategory: None,
                    ..Default::default()
                },
            ],
            annotations: vec![],
            ..Default::default()
        };

        let index = CocoIndex::from_dataset(&dataset);

        // label_index must equal original category_id, NOT alphabetical order
        assert_eq!(index.label_index(1), Some(1)); // person
        assert_eq!(index.label_index(3), Some(3)); // car
        assert_eq!(index.label_index(90), Some(90)); // toothbrush

        // Unknown category returns None
        assert_eq!(index.label_index(2), None);
        assert_eq!(index.label_index(50), None);
    }

    #[test]
    fn test_lvis_image_deserialize() {
        let json = r#"{
            "id": 397133,
            "width": 640,
            "height": 480,
            "file_name": "000000397133.jpg",
            "neg_category_ids": [5, 12, 87],
            "not_exhaustive_category_ids": [3, 45]
        }"#;
        let image: CocoImage = serde_json::from_str(json).unwrap();
        assert_eq!(image.neg_category_ids, Some(vec![5, 12, 87]));
        assert_eq!(image.not_exhaustive_category_ids, Some(vec![3, 45]));
    }

    #[test]
    fn test_lvis_category_deserialize() {
        let json = r#"{
            "id": 1,
            "name": "aerosol_can",
            "synset": "aerosol.n.02",
            "frequency": "c",
            "synonyms": ["aerosol_can", "spray_can"],
            "def": "a dispenser that holds a substance under pressure",
            "image_count": 57,
            "instance_count": 98
        }"#;
        let cat: CocoCategory = serde_json::from_str(json).unwrap();
        assert_eq!(cat.synset, Some("aerosol.n.02".to_string()));
        assert_eq!(cat.frequency, Some("c".to_string()));
        assert_eq!(
            cat.synonyms,
            Some(vec!["aerosol_can".to_string(), "spray_can".to_string()])
        );
        assert_eq!(
            cat.def,
            Some("a dispenser that holds a substance under pressure".to_string())
        );
        assert_eq!(cat.image_count, Some(57));
        assert_eq!(cat.instance_count, Some(98));
    }

    #[test]
    fn test_standard_coco_still_parses() {
        let json = r#"{"id": 1, "name": "person", "supercategory": "human"}"#;
        let cat: CocoCategory = serde_json::from_str(json).unwrap();
        assert_eq!(cat.name, "person");
        assert_eq!(cat.synset, None);
        assert_eq!(cat.frequency, None);
        assert_eq!(cat.synonyms, None);
        assert_eq!(cat.def, None);
    }
}