edgefirst_client/coco/
reader.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4//! Streaming COCO JSON/ZIP readers.
5//!
6//! Provides memory-efficient reading of COCO annotation files from JSON files
7//! or ZIP archives without requiring full extraction.
8
9use super::types::CocoDataset;
10use crate::Error;
11use std::{
12    collections::HashSet,
13    fs::File,
14    io::{BufReader, Read},
15    path::Path,
16};
17
/// Tunable behaviour for [`CocoReader`].
///
/// The derived `Default` disables validation, applies no image limit and
/// keeps every category.
#[derive(Debug, Clone, Default)]
pub struct CocoReadOptions {
    /// When `true`, the parsed dataset is checked for dangling image/category
    /// references and non-positive bbox sizes before any filter is applied.
    pub validate: bool,
    /// Upper bound on the number of images kept (`0` means "no limit").
    ///
    /// Only the first `max_images` images, in file order, are retained;
    /// annotations that point at a discarded image are dropped as well.
    pub max_images: usize,
    /// Category names to keep (an empty list keeps every category).
    ///
    /// Names are compared exactly against each category's `name`.
    pub category_filter: Vec<String>,
}
28
29/// Streaming COCO reader for large datasets.
30///
31/// Supports reading from JSON files and ZIP archives.
32///
33/// # Example
34///
35/// ```rust,no_run
36/// use edgefirst_client::coco::CocoReader;
37///
38/// let reader = CocoReader::new();
39/// let dataset = reader.read_json("annotations/instances_val2017.json")?;
40/// println!("Loaded {} images", dataset.images.len());
41/// # Ok::<(), edgefirst_client::Error>(())
42/// ```
43pub struct CocoReader {
44    options: CocoReadOptions,
45}
46
47impl CocoReader {
48    /// Create a new COCO reader with default options.
49    pub fn new() -> Self {
50        Self {
51            options: CocoReadOptions::default(),
52        }
53    }
54
55    /// Create a new COCO reader with custom options.
56    pub fn with_options(options: CocoReadOptions) -> Self {
57        Self { options }
58    }
59
60    /// Read COCO dataset from a JSON file.
61    ///
62    /// # Arguments
63    /// * `path` - Path to the COCO JSON annotation file
64    ///
65    /// # Returns
66    /// Parsed `CocoDataset` structure
67    pub fn read_json<P: AsRef<Path>>(&self, path: P) -> Result<CocoDataset, Error> {
68        let file = File::open(path.as_ref())?;
69        let reader = BufReader::with_capacity(64 * 1024, file);
70        let dataset: CocoDataset = serde_json::from_reader(reader)?;
71
72        if self.options.validate {
73            validate_dataset(&dataset)?;
74        }
75
76        Ok(self.apply_filters(dataset))
77    }
78
79    /// Read COCO annotations from a ZIP file.
80    ///
81    /// Looks for annotation JSON files in standard COCO locations:
82    /// - `annotations/instances_*.json`
83    /// - `annotations/*.json`
84    /// - Root level `*.json` files
85    ///
86    /// # Arguments
87    /// * `path` - Path to the ZIP archive containing annotations
88    ///
89    /// # Returns
90    /// Merged `CocoDataset` from all annotation files found
91    pub fn read_annotations_zip<P: AsRef<Path>>(&self, path: P) -> Result<CocoDataset, Error> {
92        let file = File::open(path.as_ref())?;
93        let mut archive = zip::ZipArchive::new(file)?;
94
95        let mut merged = CocoDataset::default();
96
97        for i in 0..archive.len() {
98            let mut entry = archive.by_index(i)?;
99            let name = entry.name().to_string();
100
101            // Only process JSON files containing annotations
102            if name.ends_with(".json") && name.contains("instances") {
103                let mut contents = String::new();
104                entry.read_to_string(&mut contents)?;
105
106                let dataset: CocoDataset = serde_json::from_str(&contents)?;
107                merge_datasets(&mut merged, dataset);
108            }
109        }
110
111        if self.options.validate {
112            validate_dataset(&merged)?;
113        }
114
115        Ok(self.apply_filters(merged))
116    }
117
118    /// List image files in a COCO ZIP or folder.
119    ///
120    /// # Arguments
121    /// * `path` - Path to COCO images folder or ZIP archive
122    ///
123    /// # Returns
124    /// Vector of `(relative_path, absolute_path)` for each image
125    pub fn list_images<P: AsRef<Path>>(
126        &self,
127        path: P,
128    ) -> Result<Vec<(String, std::path::PathBuf)>, Error> {
129        let path = path.as_ref();
130        let mut images = Vec::new();
131
132        if path.is_dir() {
133            // Walk directory
134            for entry in walkdir::WalkDir::new(path)
135                .into_iter()
136                .filter_map(|e| e.ok())
137                .filter(|e| e.file_type().is_file())
138            {
139                let filename = entry.file_name().to_string_lossy().to_lowercase();
140                if filename.ends_with(".jpg")
141                    || filename.ends_with(".jpeg")
142                    || filename.ends_with(".png")
143                {
144                    let rel_path = entry
145                        .path()
146                        .strip_prefix(path)
147                        .unwrap_or(entry.path())
148                        .to_string_lossy()
149                        .to_string();
150                    images.push((rel_path, entry.path().to_path_buf()));
151                }
152            }
153        } else if path.extension().is_some_and(|e| e == "zip") {
154            // List from ZIP
155            let file = File::open(path)?;
156            let mut archive = zip::ZipArchive::new(file)?;
157
158            for i in 0..archive.len() {
159                let entry = archive.by_index(i)?;
160                let name = entry.name().to_string();
161                let name_lower = name.to_lowercase();
162
163                if !entry.is_dir()
164                    && (name_lower.ends_with(".jpg")
165                        || name_lower.ends_with(".jpeg")
166                        || name_lower.ends_with(".png"))
167                {
168                    images.push((name.clone(), path.join(&name)));
169                }
170            }
171        }
172
173        Ok(images)
174    }
175
176    /// Read a single image from a ZIP archive.
177    ///
178    /// # Arguments
179    /// * `zip_path` - Path to the ZIP archive
180    /// * `image_name` - Name of the image file within the archive
181    ///
182    /// # Returns
183    /// Raw image bytes
184    pub fn read_image_from_zip<P: AsRef<Path>>(
185        &self,
186        zip_path: P,
187        image_name: &str,
188    ) -> Result<Vec<u8>, Error> {
189        let file = File::open(zip_path.as_ref())?;
190        let mut archive = zip::ZipArchive::new(file)?;
191
192        let mut entry = archive.by_name(image_name)?;
193        let mut buffer = Vec::with_capacity(entry.size() as usize);
194        entry.read_to_end(&mut buffer)?;
195
196        Ok(buffer)
197    }
198
199    /// Apply filters from options to the dataset.
200    fn apply_filters(&self, mut dataset: CocoDataset) -> CocoDataset {
201        // Apply max_images filter
202        if self.options.max_images > 0 && dataset.images.len() > self.options.max_images {
203            let image_ids: HashSet<_> = dataset
204                .images
205                .iter()
206                .take(self.options.max_images)
207                .map(|i| i.id)
208                .collect();
209
210            dataset.images.truncate(self.options.max_images);
211            dataset
212                .annotations
213                .retain(|a| image_ids.contains(&a.image_id));
214        }
215
216        // Apply category filter
217        if !self.options.category_filter.is_empty() {
218            let category_ids: HashSet<_> = dataset
219                .categories
220                .iter()
221                .filter(|c| self.options.category_filter.contains(&c.name))
222                .map(|c| c.id)
223                .collect();
224
225            dataset
226                .categories
227                .retain(|c| self.options.category_filter.contains(&c.name));
228            dataset
229                .annotations
230                .retain(|a| category_ids.contains(&a.category_id));
231        }
232
233        dataset
234    }
235}
236
237impl Default for CocoReader {
238    fn default() -> Self {
239        Self::new()
240    }
241}
242
243/// Validate a COCO dataset for consistency.
244fn validate_dataset(dataset: &CocoDataset) -> Result<(), Error> {
245    let image_ids: HashSet<_> = dataset.images.iter().map(|i| i.id).collect();
246    let category_ids: HashSet<_> = dataset.categories.iter().map(|c| c.id).collect();
247
248    for ann in &dataset.annotations {
249        if !image_ids.contains(&ann.image_id) {
250            return Err(Error::CocoError(format!(
251                "Annotation {} references non-existent image_id {}",
252                ann.id, ann.image_id
253            )));
254        }
255
256        if !category_ids.contains(&ann.category_id) {
257            return Err(Error::CocoError(format!(
258                "Annotation {} references non-existent category_id {}",
259                ann.id, ann.category_id
260            )));
261        }
262
263        // Validate bbox
264        if ann.bbox[2] <= 0.0 || ann.bbox[3] <= 0.0 {
265            return Err(Error::CocoError(format!(
266                "Annotation {} has invalid bbox dimensions",
267                ann.id
268            )));
269        }
270    }
271
272    Ok(())
273}
274
/// Infer group name from COCO annotation filename.
///
/// Extracts the split name from standard COCO naming conventions:
/// - `instances_train2017.json` → `"train"`
/// - `instances_val2017.json` → `"val"`
/// - `instances_test2017.json` → `"test"`
/// - `person_keypoints_train2017.json` → `"train"`
/// - `captions_val2017.json` → `"val"`
/// - `panoptic_train2017.json` → `"train"`
///
/// When the file stem starts with none of the known prefixes (or nothing is
/// left once the trailing digits are removed), the whole `filename` is
/// searched case-insensitively for `train`, `val` and `test`, in that order.
///
/// # Arguments
/// * `filename` - The annotation file name
///
/// # Returns
/// The inferred group name if extraction succeeds, otherwise `None`
pub fn infer_group_from_filename(filename: &str) -> Option<String> {
    // Prefixes of the form `<prefix><group><year>.json`, tried in order.
    const KNOWN_PREFIXES: [&str; 4] = [
        "instances_",
        "person_keypoints_",
        "captions_",
        "panoptic_",
    ];
    // Split names recognised anywhere in the file name, in priority order.
    const FALLBACK_SPLITS: [&str; 3] = ["train", "val", "test"];

    let stem = Path::new(filename).file_stem()?.to_str()?;

    // Remove the known prefix and the trailing year digits; what is left is
    // the group, unless it is empty.
    let prefixed = KNOWN_PREFIXES
        .iter()
        .filter_map(|&prefix| stem.strip_prefix(prefix))
        .map(|rest| rest.trim_end_matches(char::is_numeric))
        .find(|group| !group.is_empty());
    if let Some(group) = prefixed {
        return Some(group.to_string());
    }

    // Fallback: look for train/val/test anywhere in the filename
    let lower = filename.to_lowercase();
    FALLBACK_SPLITS
        .iter()
        .find(|&&split| lower.contains(split))
        .map(|&split| split.to_string())
}
338
/// Derive a group name from the folder that directly contains an image.
///
/// The immediate parent folder of `image_path` is taken and any trailing
/// digits (typically a year) are removed. A folder name made up entirely of
/// digits is returned unchanged.
///
/// # Examples
/// - `train2017/000000001.jpg` → "train"
/// - `val2017/000000002.jpg` → "val"
/// - `test2017/000000003.jpg` → "test"
/// - `custom_split/image.jpg` → "custom_split"
///
/// # Arguments
/// * `image_path` - Relative path to the image (e.g.,
///   "train2017/000000001.jpg")
///
/// # Returns
/// Inferred group name, or None if no folder component is found.
pub fn infer_group_from_folder(image_path: &str) -> Option<String> {
    // Name of the directory that directly holds the image, if there is one.
    let folder = Path::new(image_path)
        .parent()
        .and_then(Path::file_name)
        .and_then(|name| name.to_str())
        .filter(|name| !name.is_empty())?;

    // "train2017" → "train"; an all-digit folder keeps its original name.
    let without_year = folder.trim_end_matches(char::is_numeric);
    let group = if without_year.is_empty() {
        folder
    } else {
        without_year
    };

    Some(group.to_string())
}
376
377/// Read all COCO annotation files from a directory.
378///
379/// Discovers and reads annotation files from standard COCO directory
380/// structures:
381///
382/// ```text
383/// coco_dir/
384/// ├── annotations/
385/// │   ├── instances_train2017.json
386/// │   └── instances_val2017.json
387/// └── ...
388/// ```
389///
390/// # Arguments
391///
392/// * `path` - Path to the COCO directory
393/// * `options` - Read options
394///
395/// # Returns
396///
397/// Vector of `(CocoDataset, inferred_group)` pairs
398pub fn read_coco_directory<P: AsRef<Path>>(
399    path: P,
400    options: &CocoReadOptions,
401) -> Result<Vec<(CocoDataset, String)>, Error> {
402    let path = path.as_ref();
403    let mut results = Vec::new();
404
405    // Look for annotation files
406    let annotations_dir = path.join("annotations");
407    let search_dirs: Vec<&Path> = if annotations_dir.is_dir() {
408        vec![annotations_dir.as_path(), path]
409    } else {
410        vec![path]
411    };
412
413    for search_dir in search_dirs {
414        if !search_dir.is_dir() {
415            continue;
416        }
417
418        for entry in std::fs::read_dir(search_dir)? {
419            let entry = entry?;
420            let file_path = entry.path();
421
422            if !file_path.is_file() {
423                continue;
424            }
425
426            let filename = file_path.file_name().and_then(|s| s.to_str()).unwrap_or("");
427
428            // Only process instance annotation files
429            if filename.ends_with(".json") && filename.contains("instances") {
430                let group =
431                    infer_group_from_filename(filename).unwrap_or_else(|| "default".to_string());
432
433                let reader = CocoReader::with_options(options.clone());
434                let dataset = reader.read_json(&file_path)?;
435
436                results.push((dataset, group));
437            }
438        }
439    }
440
441    if results.is_empty() {
442        return Err(Error::MissingAnnotations(format!(
443            "No COCO annotation files found in {}",
444            path.display()
445        )));
446    }
447
448    Ok(results)
449}
450
451/// Merge a source dataset into a target dataset.
452fn merge_datasets(target: &mut CocoDataset, source: CocoDataset) {
453    // Take info if not set
454    if target.info.description.is_none() {
455        target.info = source.info;
456    }
457
458    // Merge images (deduplicate by id)
459    let existing_ids: HashSet<_> = target.images.iter().map(|i| i.id).collect();
460    for image in source.images {
461        if !existing_ids.contains(&image.id) {
462            target.images.push(image);
463        }
464    }
465
466    // Merge categories (deduplicate by id)
467    let existing_cats: HashSet<_> = target.categories.iter().map(|c| c.id).collect();
468    for cat in source.categories {
469        if !existing_cats.contains(&cat.id) {
470            target.categories.push(cat);
471        }
472    }
473
474    // Merge annotations (always append - IDs should be globally unique)
475    target.annotations.extend(source.annotations);
476
477    // Merge licenses
478    let existing_licenses: HashSet<_> = target.licenses.iter().map(|l| l.id).collect();
479    for lic in source.licenses {
480        if !existing_licenses.contains(&lic.id) {
481            target.licenses.push(lic);
482        }
483    }
484}
485
// Unit tests for the reader options, dataset validation, dataset merging,
// the `max_images` filter and group inference from annotation file names.
// Every fixture is built in memory; no test touches the file system.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::coco::{CocoAnnotation, CocoCategory, CocoImage};

    // `CocoReader::new()` must start with validation off, no image limit and
    // an empty category filter.
    #[test]
    fn test_reader_default() {
        let reader = CocoReader::new();
        assert!(!reader.options.validate);
        assert_eq!(reader.options.max_images, 0);
        assert!(reader.options.category_filter.is_empty());
    }

    // Options handed to `with_options` must be stored unchanged.
    #[test]
    fn test_reader_with_options() {
        let options = CocoReadOptions {
            validate: true,
            max_images: 100,
            category_filter: vec!["person".to_string()],
        };
        let reader = CocoReader::with_options(options.clone());
        assert!(reader.options.validate);
        assert_eq!(reader.options.max_images, 100);
    }

    // One image, one category and one annotation that references both with a
    // positive-size bbox: validation succeeds.
    #[test]
    fn test_validate_dataset_valid() {
        let dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                width: 640,
                height: 480,
                file_name: "test.jpg".to_string(),
                ..Default::default()
            }],
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                supercategory: None,
            }],
            annotations: vec![CocoAnnotation {
                id: 1,
                image_id: 1,
                category_id: 1,
                bbox: [10.0, 20.0, 100.0, 80.0],
                area: 8000.0,
                iscrowd: 0,
                segmentation: None,
            }],
            ..Default::default()
        };

        assert!(validate_dataset(&dataset).is_ok());
    }

    // An annotation whose `image_id` matches no image must fail validation.
    #[test]
    fn test_validate_dataset_missing_image() {
        let dataset = CocoDataset {
            images: vec![],
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                supercategory: None,
            }],
            annotations: vec![CocoAnnotation {
                id: 1,
                image_id: 999, // Non-existent
                category_id: 1,
                bbox: [10.0, 20.0, 100.0, 80.0],
                ..Default::default()
            }],
            ..Default::default()
        };

        assert!(validate_dataset(&dataset).is_err());
    }

    // Merging skips the image whose id already exists in the target and adds
    // the new image and the new category.
    #[test]
    fn test_merge_datasets() {
        let mut target = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                width: 640,
                height: 480,
                file_name: "img1.jpg".to_string(),
                ..Default::default()
            }],
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                supercategory: None,
            }],
            annotations: vec![],
            ..Default::default()
        };

        let source = CocoDataset {
            images: vec![
                CocoImage {
                    id: 1, // Duplicate - should not be added
                    width: 640,
                    height: 480,
                    file_name: "img1.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 2, // New - should be added
                    width: 800,
                    height: 600,
                    file_name: "img2.jpg".to_string(),
                    ..Default::default()
                },
            ],
            categories: vec![CocoCategory {
                id: 2,
                name: "car".to_string(),
                supercategory: None,
            }],
            annotations: vec![],
            ..Default::default()
        };

        merge_datasets(&mut target, source);

        // Image 1 (from target) + image 2 (from source); categories 1 and 2.
        assert_eq!(target.images.len(), 2);
        assert_eq!(target.categories.len(), 2);
    }

    // With `max_images: 2` only the first two of three images survive, and
    // the annotation attached to the third image is dropped with it.
    #[test]
    fn test_apply_max_images_filter() {
        let reader = CocoReader::with_options(CocoReadOptions {
            max_images: 2,
            ..Default::default()
        });

        let dataset = CocoDataset {
            images: vec![
                CocoImage {
                    id: 1,
                    ..Default::default()
                },
                CocoImage {
                    id: 2,
                    ..Default::default()
                },
                CocoImage {
                    id: 3,
                    ..Default::default()
                },
            ],
            annotations: vec![
                CocoAnnotation {
                    id: 1,
                    image_id: 1,
                    ..Default::default()
                },
                CocoAnnotation {
                    id: 2,
                    image_id: 2,
                    ..Default::default()
                },
                CocoAnnotation {
                    id: 3,
                    image_id: 3,
                    ..Default::default()
                },
            ],
            ..Default::default()
        };

        let filtered = reader.apply_filters(dataset);
        assert_eq!(filtered.images.len(), 2);
        assert_eq!(filtered.annotations.len(), 2);
    }

    // `instances_<group><year>.json` yields `<group>`.
    #[test]
    fn test_infer_group_from_filename_instances() {
        assert_eq!(
            infer_group_from_filename("instances_train2017.json"),
            Some("train".to_string())
        );
        assert_eq!(
            infer_group_from_filename("instances_val2017.json"),
            Some("val".to_string())
        );
        assert_eq!(
            infer_group_from_filename("instances_test2017.json"),
            Some("test".to_string())
        );
    }

    // `person_keypoints_<group><year>.json` yields `<group>`.
    #[test]
    fn test_infer_group_from_filename_keypoints() {
        assert_eq!(
            infer_group_from_filename("person_keypoints_train2017.json"),
            Some("train".to_string())
        );
        assert_eq!(
            infer_group_from_filename("person_keypoints_val2017.json"),
            Some("val".to_string())
        );
    }

    // `captions_<group><year>.json` yields `<group>`.
    #[test]
    fn test_infer_group_from_filename_captions() {
        assert_eq!(
            infer_group_from_filename("captions_train2017.json"),
            Some("train".to_string())
        );
        assert_eq!(
            infer_group_from_filename("captions_val2017.json"),
            Some("val".to_string())
        );
    }

    // `panoptic_<group><year>.json` yields `<group>`.
    #[test]
    fn test_infer_group_from_filename_panoptic() {
        assert_eq!(
            infer_group_from_filename("panoptic_train2017.json"),
            Some("train".to_string())
        );
        assert_eq!(
            infer_group_from_filename("panoptic_val2017.json"),
            Some("val".to_string())
        );
    }

    // Names without a known prefix are still classified when they contain
    // "train" / "val" / "test" somewhere ("validation" contains "val").
    #[test]
    fn test_infer_group_from_filename_fallback() {
        // Falls back to looking for train/val/test in filename
        assert_eq!(
            infer_group_from_filename("my_custom_train_annotations.json"),
            Some("train".to_string())
        );
        assert_eq!(
            infer_group_from_filename("validation_data.json"),
            Some("val".to_string())
        );
    }

    // Names with neither a known prefix nor a split keyword give `None`.
    #[test]
    fn test_infer_group_from_filename_no_match() {
        // No recognizable pattern
        assert_eq!(infer_group_from_filename("annotations.json"), None);
        assert_eq!(infer_group_from_filename("data.json"), None);
    }
}