Skip to main content

edgefirst_client/coco/
reader.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.

//! Streaming COCO JSON/ZIP readers.
//!
//! Provides memory-efficient reading of COCO annotation files from JSON files
//! or ZIP archives without requiring full extraction.

use super::types::CocoDataset;
use crate::Error;
use std::{
    collections::HashSet,
    fs::File,
    io::{BufReader, Read},
    path::Path,
};
17
/// Options controlling how a COCO dataset is read.
///
/// The `Default` value performs no validation, applies no image limit and
/// keeps every category.
#[derive(Debug, Clone, Default)]
pub struct CocoReadOptions {
    /// When `true`, the parsed dataset is validated before it is returned.
    pub validate: bool,
    /// Maximum number of images to keep (`0` = unlimited).
    pub max_images: usize,
    /// Names of the categories to keep (empty = keep all).
    pub category_filter: Vec<String>,
}
28
29/// Streaming COCO reader for large datasets.
30///
31/// Supports reading from JSON files and ZIP archives.
32///
33/// # Example
34///
35/// ```rust,no_run
36/// use edgefirst_client::coco::CocoReader;
37///
38/// let reader = CocoReader::new();
39/// let dataset = reader.read_json("annotations/instances_val2017.json")?;
40/// println!("Loaded {} images", dataset.images.len());
41/// # Ok::<(), edgefirst_client::Error>(())
42/// ```
43pub struct CocoReader {
44    options: CocoReadOptions,
45}
46
47impl CocoReader {
48    /// Create a new COCO reader with default options.
49    pub fn new() -> Self {
50        Self {
51            options: CocoReadOptions::default(),
52        }
53    }
54
55    /// Create a new COCO reader with custom options.
56    pub fn with_options(options: CocoReadOptions) -> Self {
57        Self { options }
58    }
59
60    /// Read COCO dataset from a JSON file.
61    ///
62    /// # Arguments
63    /// * `path` - Path to the COCO JSON annotation file
64    ///
65    /// # Returns
66    /// Parsed `CocoDataset` structure
67    pub fn read_json<P: AsRef<Path>>(&self, path: P) -> Result<CocoDataset, Error> {
68        let file = File::open(path.as_ref())?;
69        let reader = BufReader::with_capacity(64 * 1024, file);
70        let dataset: CocoDataset = serde_json::from_reader(reader)?;
71
72        if self.options.validate {
73            validate_dataset(&dataset)?;
74        }
75
76        Ok(self.apply_filters(dataset))
77    }
78
79    /// Read COCO annotations from a ZIP file.
80    ///
81    /// Looks for annotation JSON files in standard COCO locations:
82    /// - `annotations/instances_*.json`
83    /// - `annotations/*.json`
84    /// - Root level `*.json` files
85    ///
86    /// # Arguments
87    /// * `path` - Path to the ZIP archive containing annotations
88    ///
89    /// # Returns
90    /// Merged `CocoDataset` from all annotation files found
91    pub fn read_annotations_zip<P: AsRef<Path>>(&self, path: P) -> Result<CocoDataset, Error> {
92        let file = File::open(path.as_ref())?;
93        let mut archive = zip::ZipArchive::new(file)?;
94
95        let mut merged = CocoDataset::default();
96
97        for i in 0..archive.len() {
98            let mut entry = archive.by_index(i)?;
99            let name = entry.name().to_string();
100
101            // Only process JSON files containing annotations
102            if name.ends_with(".json") && name.contains("instances") {
103                let mut contents = String::new();
104                entry.read_to_string(&mut contents)?;
105
106                let dataset: CocoDataset = serde_json::from_str(&contents)?;
107                merge_datasets(&mut merged, dataset);
108            }
109        }
110
111        if self.options.validate {
112            validate_dataset(&merged)?;
113        }
114
115        Ok(self.apply_filters(merged))
116    }
117
118    /// List image files in a COCO ZIP or folder.
119    ///
120    /// # Arguments
121    /// * `path` - Path to COCO images folder or ZIP archive
122    ///
123    /// # Returns
124    /// Vector of `(relative_path, absolute_path)` for each image
125    pub fn list_images<P: AsRef<Path>>(
126        &self,
127        path: P,
128    ) -> Result<Vec<(String, std::path::PathBuf)>, Error> {
129        let path = path.as_ref();
130        let mut images = Vec::new();
131
132        if path.is_dir() {
133            // Walk directory
134            for entry in walkdir::WalkDir::new(path)
135                .into_iter()
136                .filter_map(|e| e.ok())
137                .filter(|e| e.file_type().is_file())
138            {
139                let filename = entry.file_name().to_string_lossy().to_lowercase();
140                if filename.ends_with(".jpg")
141                    || filename.ends_with(".jpeg")
142                    || filename.ends_with(".png")
143                {
144                    let rel_path = entry
145                        .path()
146                        .strip_prefix(path)
147                        .unwrap_or(entry.path())
148                        .to_string_lossy()
149                        .to_string();
150                    images.push((rel_path, entry.path().to_path_buf()));
151                }
152            }
153        } else if path.extension().is_some_and(|e| e == "zip") {
154            // List from ZIP
155            let file = File::open(path)?;
156            let mut archive = zip::ZipArchive::new(file)?;
157
158            for i in 0..archive.len() {
159                let entry = archive.by_index(i)?;
160                let name = entry.name().to_string();
161                let name_lower = name.to_lowercase();
162
163                if !entry.is_dir()
164                    && (name_lower.ends_with(".jpg")
165                        || name_lower.ends_with(".jpeg")
166                        || name_lower.ends_with(".png"))
167                {
168                    images.push((name.clone(), path.join(&name)));
169                }
170            }
171        }
172
173        Ok(images)
174    }
175
176    /// Read a single image from a ZIP archive.
177    ///
178    /// # Arguments
179    /// * `zip_path` - Path to the ZIP archive
180    /// * `image_name` - Name of the image file within the archive
181    ///
182    /// # Returns
183    /// Raw image bytes
184    pub fn read_image_from_zip<P: AsRef<Path>>(
185        &self,
186        zip_path: P,
187        image_name: &str,
188    ) -> Result<Vec<u8>, Error> {
189        let file = File::open(zip_path.as_ref())?;
190        let mut archive = zip::ZipArchive::new(file)?;
191
192        let mut entry = archive.by_name(image_name)?;
193        let mut buffer = Vec::with_capacity(entry.size() as usize);
194        entry.read_to_end(&mut buffer)?;
195
196        Ok(buffer)
197    }
198
199    /// Apply filters from options to the dataset.
200    fn apply_filters(&self, mut dataset: CocoDataset) -> CocoDataset {
201        // Apply max_images filter
202        if self.options.max_images > 0 && dataset.images.len() > self.options.max_images {
203            let image_ids: HashSet<_> = dataset
204                .images
205                .iter()
206                .take(self.options.max_images)
207                .map(|i| i.id)
208                .collect();
209
210            dataset.images.truncate(self.options.max_images);
211            dataset
212                .annotations
213                .retain(|a| image_ids.contains(&a.image_id));
214        }
215
216        // Apply category filter
217        if !self.options.category_filter.is_empty() {
218            let category_ids: HashSet<_> = dataset
219                .categories
220                .iter()
221                .filter(|c| self.options.category_filter.contains(&c.name))
222                .map(|c| c.id)
223                .collect();
224
225            dataset
226                .categories
227                .retain(|c| self.options.category_filter.contains(&c.name));
228            dataset
229                .annotations
230                .retain(|a| category_ids.contains(&a.category_id));
231        }
232
233        dataset
234    }
235}
236
237impl Default for CocoReader {
238    fn default() -> Self {
239        Self::new()
240    }
241}
242
243/// Validate a COCO dataset for consistency.
244fn validate_dataset(dataset: &CocoDataset) -> Result<(), Error> {
245    let image_ids: HashSet<_> = dataset.images.iter().map(|i| i.id).collect();
246    let category_ids: HashSet<_> = dataset.categories.iter().map(|c| c.id).collect();
247
248    for ann in &dataset.annotations {
249        if !image_ids.contains(&ann.image_id) {
250            return Err(Error::CocoError(format!(
251                "Annotation {} references non-existent image_id {}",
252                ann.id, ann.image_id
253            )));
254        }
255
256        if !category_ids.contains(&ann.category_id) {
257            return Err(Error::CocoError(format!(
258                "Annotation {} references non-existent category_id {}",
259                ann.id, ann.category_id
260            )));
261        }
262
263        // Validate bbox
264        if ann.bbox[2] <= 0.0 || ann.bbox[3] <= 0.0 {
265            return Err(Error::CocoError(format!(
266                "Annotation {} has invalid bbox dimensions",
267                ann.id
268            )));
269        }
270    }
271
272    Ok(())
273}
274
/// Infer group name from COCO annotation filename.
///
/// Extracts the split name from standard COCO naming conventions by removing
/// a known annotation-type prefix and any trailing digits from the file stem:
/// - `instances_train2017.json` → `"train"`
/// - `instances_val2017.json` → `"val"`
/// - `instances_test2017.json` → `"test"`
/// - `person_keypoints_train2017.json` → `"train"`
/// - `captions_val2017.json` → `"val"`
/// - `panoptic_train2017.json` → `"train"`
///
/// If no prefix yields a non-empty name, the whole `filename` is searched
/// case-insensitively for the substrings `train`, `val` and `test`, in that
/// order (so `validation_data.json` → `"val"`).
///
/// # Arguments
/// * `filename` - The annotation file name
///
/// # Returns
/// The inferred group name if extraction succeeds, otherwise `None`. `None`
/// is also returned when `filename` has no UTF-8 file stem.
pub fn infer_group_from_filename(filename: &str) -> Option<String> {
    /// Annotation-type prefixes of the form `<prefix><group><year>.json`.
    const KNOWN_PREFIXES: [&str; 4] = [
        "instances_",
        "person_keypoints_",
        "captions_",
        "panoptic_",
    ];
    /// Split names tried by the substring fallback, in priority order.
    const FALLBACK_GROUPS: [&str; 3] = ["train", "val", "test"];

    let stem = Path::new(filename).file_stem()?.to_str()?;

    // Try common COCO patterns: strip the prefix, then the trailing year.
    let from_prefix = KNOWN_PREFIXES
        .iter()
        .filter_map(|prefix| stem.strip_prefix(*prefix))
        .map(|rest| rest.trim_end_matches(char::is_numeric))
        .find(|group| !group.is_empty());
    if let Some(group) = from_prefix {
        return Some(group.to_string());
    }

    // Fallback: look for train/val/test anywhere in the filename
    let lower = filename.to_lowercase();
    FALLBACK_GROUPS
        .iter()
        .find(|group| lower.contains(**group))
        .map(|group| (*group).to_string())
}
338
/// Infer the group name from an image folder path.
///
/// Takes the name of the folder that directly contains the image and strips
/// trailing digits (typically a year) from it. A folder name consisting only
/// of digits is returned unchanged.
///
/// # Examples
/// - `train2017/000000001.jpg` → "train"
/// - `val2017/000000002.jpg` → "val"
/// - `test2017/000000003.jpg` → "test"
/// - `custom_split/image.jpg` → "custom_split"
/// - `2017/image.jpg` → "2017"
///
/// # Arguments
/// * `image_path` - Relative path to the image (e.g.,
///   "train2017/000000001.jpg")
///
/// # Returns
/// Inferred group name, or None if no folder component is found.
pub fn infer_group_from_folder(image_path: &str) -> Option<String> {
    // Parent folder name (e.g., "train2017" from "train2017/image.jpg").
    // `file_name()` is `None` for a bare file name, and never empty otherwise.
    let folder = Path::new(image_path).parent()?.file_name()?.to_str()?;

    // Strip trailing year numbers (e.g., "train2017" → "train"); if that
    // leaves nothing, the folder was all digits and is used as-is.
    let trimmed = folder.trim_end_matches(char::is_numeric);
    let group = if trimmed.is_empty() { folder } else { trimmed };

    Some(group.to_string())
}
376
377/// Read all COCO annotation files from a directory.
378///
379/// Discovers and reads annotation files from standard COCO directory
380/// structures:
381///
382/// ```text
383/// coco_dir/
384/// ├── annotations/
385/// │   ├── instances_train2017.json
386/// │   └── instances_val2017.json
387/// └── ...
388/// ```
389///
390/// # Arguments
391///
392/// * `path` - Path to the COCO directory
393/// * `options` - Read options
394///
395/// # Returns
396///
397/// Vector of `(CocoDataset, inferred_group)` pairs
398pub fn read_coco_directory<P: AsRef<Path>>(
399    path: P,
400    options: &CocoReadOptions,
401) -> Result<Vec<(CocoDataset, String)>, Error> {
402    let path = path.as_ref();
403    let mut results = Vec::new();
404
405    // Look for annotation files
406    let annotations_dir = path.join("annotations");
407    let search_dirs: Vec<&Path> = if annotations_dir.is_dir() {
408        vec![annotations_dir.as_path(), path]
409    } else {
410        vec![path]
411    };
412
413    for search_dir in search_dirs {
414        if !search_dir.is_dir() {
415            continue;
416        }
417
418        for entry in std::fs::read_dir(search_dir)? {
419            let entry = entry?;
420            let file_path = entry.path();
421
422            if !file_path.is_file() {
423                continue;
424            }
425
426            let filename = file_path.file_name().and_then(|s| s.to_str()).unwrap_or("");
427
428            // Only process instance annotation files
429            if filename.ends_with(".json") && filename.contains("instances") {
430                let group =
431                    infer_group_from_filename(filename).unwrap_or_else(|| "default".to_string());
432
433                let reader = CocoReader::with_options(options.clone());
434                let dataset = reader.read_json(&file_path)?;
435
436                results.push((dataset, group));
437            }
438        }
439    }
440
441    if results.is_empty() {
442        return Err(Error::MissingAnnotations(format!(
443            "No COCO annotation files found in {}",
444            path.display()
445        )));
446    }
447
448    Ok(results)
449}
450
451/// Merge a source dataset into a target dataset.
452fn merge_datasets(target: &mut CocoDataset, source: CocoDataset) {
453    // Take info if not set
454    if target.info.description.is_none() {
455        target.info = source.info;
456    }
457
458    // Merge images (deduplicate by id)
459    let existing_ids: HashSet<_> = target.images.iter().map(|i| i.id).collect();
460    for image in source.images {
461        if !existing_ids.contains(&image.id) {
462            target.images.push(image);
463        }
464    }
465
466    // Merge categories (deduplicate by id)
467    let existing_cats: HashSet<_> = target.categories.iter().map(|c| c.id).collect();
468    for cat in source.categories {
469        if !existing_cats.contains(&cat.id) {
470            target.categories.push(cat);
471        }
472    }
473
474    // Merge annotations (always append - IDs should be globally unique)
475    target.annotations.extend(source.annotations);
476
477    // Merge licenses
478    let existing_licenses: HashSet<_> = target.licenses.iter().map(|l| l.id).collect();
479    for lic in source.licenses {
480        if !existing_licenses.contains(&lic.id) {
481            target.licenses.push(lic);
482        }
483    }
484}
485
/// Unit tests for the reader options, validation, merging, filtering and
/// group-name inference helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::coco::{CocoAnnotation, CocoCategory, CocoImage};

    #[test]
    fn test_reader_default() {
        let reader = CocoReader::new();
        assert!(!reader.options.validate);
        assert_eq!(reader.options.max_images, 0);
        assert!(reader.options.category_filter.is_empty());
    }

    #[test]
    fn test_reader_with_options() {
        let options = CocoReadOptions {
            validate: true,
            max_images: 100,
            category_filter: vec!["person".to_string()],
        };
        let reader = CocoReader::with_options(options);
        assert!(reader.options.validate);
        assert_eq!(reader.options.max_images, 100);
        assert_eq!(reader.options.category_filter, vec!["person".to_string()]);
    }

    #[test]
    fn test_validate_dataset_valid() {
        let dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                width: 640,
                height: 480,
                file_name: "test.jpg".to_string(),
                ..Default::default()
            }],
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                supercategory: None,
                ..Default::default()
            }],
            annotations: vec![CocoAnnotation {
                id: 1,
                image_id: 1,
                category_id: 1,
                bbox: [10.0, 20.0, 100.0, 80.0],
                area: 8000.0,
                iscrowd: 0,
                segmentation: None,
                score: None,
            }],
            ..Default::default()
        };

        assert!(validate_dataset(&dataset).is_ok());
    }

    #[test]
    fn test_validate_dataset_missing_image() {
        let dataset = CocoDataset {
            images: vec![],
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                supercategory: None,
                ..Default::default()
            }],
            annotations: vec![CocoAnnotation {
                id: 1,
                image_id: 999, // Non-existent
                category_id: 1,
                bbox: [10.0, 20.0, 100.0, 80.0],
                ..Default::default()
            }],
            ..Default::default()
        };

        assert!(validate_dataset(&dataset).is_err());
    }

    #[test]
    fn test_merge_datasets() {
        let mut target = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                width: 640,
                height: 480,
                file_name: "img1.jpg".to_string(),
                ..Default::default()
            }],
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                supercategory: None,
                ..Default::default()
            }],
            annotations: vec![],
            ..Default::default()
        };

        let source = CocoDataset {
            images: vec![
                CocoImage {
                    id: 1, // Duplicate - should not be added
                    width: 640,
                    height: 480,
                    file_name: "img1.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 2, // New - should be added
                    width: 800,
                    height: 600,
                    file_name: "img2.jpg".to_string(),
                    ..Default::default()
                },
            ],
            categories: vec![CocoCategory {
                id: 2,
                name: "car".to_string(),
                supercategory: None,
                ..Default::default()
            }],
            annotations: vec![],
            ..Default::default()
        };

        merge_datasets(&mut target, source);

        assert_eq!(target.images.len(), 2);
        assert_eq!(target.categories.len(), 2);
    }

    #[test]
    fn test_apply_max_images_filter() {
        let reader = CocoReader::with_options(CocoReadOptions {
            max_images: 2,
            ..Default::default()
        });

        let dataset = CocoDataset {
            images: vec![
                CocoImage {
                    id: 1,
                    ..Default::default()
                },
                CocoImage {
                    id: 2,
                    ..Default::default()
                },
                CocoImage {
                    id: 3,
                    ..Default::default()
                },
            ],
            annotations: vec![
                CocoAnnotation {
                    id: 1,
                    image_id: 1,
                    ..Default::default()
                },
                CocoAnnotation {
                    id: 2,
                    image_id: 2,
                    ..Default::default()
                },
                CocoAnnotation {
                    id: 3,
                    image_id: 3,
                    ..Default::default()
                },
            ],
            ..Default::default()
        };

        let filtered = reader.apply_filters(dataset);
        assert_eq!(filtered.images.len(), 2);
        assert_eq!(filtered.annotations.len(), 2);
    }

    #[test]
    fn test_apply_category_filter() {
        let reader = CocoReader::with_options(CocoReadOptions {
            category_filter: vec!["person".to_string()],
            ..Default::default()
        });

        let dataset = CocoDataset {
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "person".to_string(),
                    ..Default::default()
                },
                CocoCategory {
                    id: 2,
                    name: "car".to_string(),
                    ..Default::default()
                },
            ],
            annotations: vec![
                CocoAnnotation {
                    id: 1,
                    category_id: 1,
                    ..Default::default()
                },
                CocoAnnotation {
                    id: 2,
                    category_id: 2,
                    ..Default::default()
                },
            ],
            ..Default::default()
        };

        let filtered = reader.apply_filters(dataset);
        assert_eq!(filtered.categories.len(), 1);
        assert_eq!(filtered.categories[0].name, "person");
        assert_eq!(filtered.annotations.len(), 1);
        assert_eq!(filtered.annotations[0].id, 1);
    }

    #[test]
    fn test_infer_group_from_filename_instances() {
        assert_eq!(
            infer_group_from_filename("instances_train2017.json"),
            Some("train".to_string())
        );
        assert_eq!(
            infer_group_from_filename("instances_val2017.json"),
            Some("val".to_string())
        );
        assert_eq!(
            infer_group_from_filename("instances_test2017.json"),
            Some("test".to_string())
        );
    }

    #[test]
    fn test_infer_group_from_filename_keypoints() {
        assert_eq!(
            infer_group_from_filename("person_keypoints_train2017.json"),
            Some("train".to_string())
        );
        assert_eq!(
            infer_group_from_filename("person_keypoints_val2017.json"),
            Some("val".to_string())
        );
    }

    #[test]
    fn test_infer_group_from_filename_captions() {
        assert_eq!(
            infer_group_from_filename("captions_train2017.json"),
            Some("train".to_string())
        );
        assert_eq!(
            infer_group_from_filename("captions_val2017.json"),
            Some("val".to_string())
        );
    }

    #[test]
    fn test_infer_group_from_filename_panoptic() {
        assert_eq!(
            infer_group_from_filename("panoptic_train2017.json"),
            Some("train".to_string())
        );
        assert_eq!(
            infer_group_from_filename("panoptic_val2017.json"),
            Some("val".to_string())
        );
    }

    #[test]
    fn test_infer_group_from_filename_fallback() {
        // Falls back to looking for train/val/test in filename
        assert_eq!(
            infer_group_from_filename("my_custom_train_annotations.json"),
            Some("train".to_string())
        );
        assert_eq!(
            infer_group_from_filename("validation_data.json"),
            Some("val".to_string())
        );
    }

    #[test]
    fn test_infer_group_from_filename_no_match() {
        // No recognizable pattern
        assert_eq!(infer_group_from_filename("annotations.json"), None);
        assert_eq!(infer_group_from_filename("data.json"), None);
    }

    #[test]
    fn test_infer_group_from_folder() {
        assert_eq!(
            infer_group_from_folder("train2017/000000001.jpg"),
            Some("train".to_string())
        );
        assert_eq!(
            infer_group_from_folder("val2017/000000002.jpg"),
            Some("val".to_string())
        );
        assert_eq!(
            infer_group_from_folder("custom_split/image.jpg"),
            Some("custom_split".to_string())
        );
        // All-digit folder names are kept as-is
        assert_eq!(
            infer_group_from_folder("2017/image.jpg"),
            Some("2017".to_string())
        );
        // No folder component
        assert_eq!(infer_group_from_folder("image.jpg"), None);
    }
}