Skip to main content

edgefirst_client/coco/
reader.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4//! Streaming COCO JSON/ZIP readers.
5//!
6//! Provides memory-efficient reading of COCO annotation files from JSON files
7//! or ZIP archives without requiring full extraction.
8
9use super::types::CocoDataset;
10use crate::Error;
11use std::{
12    collections::HashSet,
13    fs::File,
14    io::{BufReader, Read},
15    path::Path,
16};
17
/// Tunable behaviour for [`CocoReader`] and [`read_coco_directory`].
///
/// The `Default` value performs no validation, keeps every image, and
/// keeps every category.
#[derive(Debug, Clone, Default)]
pub struct CocoReadOptions {
    /// When `true`, the parsed dataset is checked for consistency before
    /// it is returned.
    pub validate: bool,
    /// Upper bound on the number of images kept; `0` disables the limit.
    pub max_images: usize,
    /// Category names to keep; an empty list keeps every category.
    pub category_filter: Vec<String>,
}
28
29/// Streaming COCO reader for large datasets.
30///
31/// Supports reading from JSON files and ZIP archives.
32///
33/// # Example
34///
35/// ```rust,no_run
36/// use edgefirst_client::coco::CocoReader;
37///
38/// let reader = CocoReader::new();
39/// let dataset = reader.read_json("annotations/instances_val2017.json")?;
40/// println!("Loaded {} images", dataset.images.len());
41/// # Ok::<(), edgefirst_client::Error>(())
42/// ```
43pub struct CocoReader {
44    options: CocoReadOptions,
45}
46
47impl CocoReader {
48    /// Create a new COCO reader with default options.
49    pub fn new() -> Self {
50        Self {
51            options: CocoReadOptions::default(),
52        }
53    }
54
55    /// Create a new COCO reader with custom options.
56    pub fn with_options(options: CocoReadOptions) -> Self {
57        Self { options }
58    }
59
60    /// Read COCO dataset from a JSON file.
61    ///
62    /// # Arguments
63    /// * `path` - Path to the COCO JSON annotation file
64    ///
65    /// # Returns
66    /// Parsed `CocoDataset` structure
67    pub fn read_json<P: AsRef<Path>>(&self, path: P) -> Result<CocoDataset, Error> {
68        let file = File::open(path.as_ref())?;
69        let reader = BufReader::with_capacity(64 * 1024, file);
70        let mut dataset: CocoDataset = serde_json::from_reader(reader)?;
71        fill_missing_file_names(&mut dataset);
72
73        if self.options.validate {
74            validate_dataset(&dataset)?;
75        }
76
77        Ok(self.apply_filters(dataset))
78    }
79
80    /// Read COCO annotations from a ZIP file.
81    ///
82    /// Looks for annotation JSON files in standard COCO locations:
83    /// - `annotations/instances_*.json`
84    /// - `annotations/*.json`
85    /// - Root level `*.json` files
86    ///
87    /// # Arguments
88    /// * `path` - Path to the ZIP archive containing annotations
89    ///
90    /// # Returns
91    /// Merged `CocoDataset` from all annotation files found
92    pub fn read_annotations_zip<P: AsRef<Path>>(&self, path: P) -> Result<CocoDataset, Error> {
93        let file = File::open(path.as_ref())?;
94        let mut archive = zip::ZipArchive::new(file)?;
95
96        let mut merged = CocoDataset::default();
97
98        for i in 0..archive.len() {
99            let mut entry = archive.by_index(i)?;
100            let name = entry.name().to_string();
101
102            // Only process JSON files containing annotations
103            if name.ends_with(".json") && name.contains("instances") {
104                let mut contents = String::new();
105                entry.read_to_string(&mut contents)?;
106
107                let mut dataset: CocoDataset = serde_json::from_str(&contents)?;
108                fill_missing_file_names(&mut dataset);
109                merge_datasets(&mut merged, dataset);
110            }
111        }
112
113        if self.options.validate {
114            validate_dataset(&merged)?;
115        }
116
117        Ok(self.apply_filters(merged))
118    }
119
120    /// List image files in a COCO ZIP or folder.
121    ///
122    /// # Arguments
123    /// * `path` - Path to COCO images folder or ZIP archive
124    ///
125    /// # Returns
126    /// Vector of `(relative_path, absolute_path)` for each image
127    pub fn list_images<P: AsRef<Path>>(
128        &self,
129        path: P,
130    ) -> Result<Vec<(String, std::path::PathBuf)>, Error> {
131        let path = path.as_ref();
132        let mut images = Vec::new();
133
134        if path.is_dir() {
135            // Walk directory
136            for entry in walkdir::WalkDir::new(path)
137                .into_iter()
138                .filter_map(|e| e.ok())
139                .filter(|e| e.file_type().is_file())
140            {
141                let filename = entry.file_name().to_string_lossy().to_lowercase();
142                if filename.ends_with(".jpg")
143                    || filename.ends_with(".jpeg")
144                    || filename.ends_with(".png")
145                {
146                    let rel_path = entry
147                        .path()
148                        .strip_prefix(path)
149                        .unwrap_or(entry.path())
150                        .to_string_lossy()
151                        .to_string();
152                    images.push((rel_path, entry.path().to_path_buf()));
153                }
154            }
155        } else if path.extension().is_some_and(|e| e == "zip") {
156            // List from ZIP
157            let file = File::open(path)?;
158            let mut archive = zip::ZipArchive::new(file)?;
159
160            for i in 0..archive.len() {
161                let entry = archive.by_index(i)?;
162                let name = entry.name().to_string();
163                let name_lower = name.to_lowercase();
164
165                if !entry.is_dir()
166                    && (name_lower.ends_with(".jpg")
167                        || name_lower.ends_with(".jpeg")
168                        || name_lower.ends_with(".png"))
169                {
170                    images.push((name.clone(), path.join(&name)));
171                }
172            }
173        }
174
175        Ok(images)
176    }
177
178    /// Read a single image from a ZIP archive.
179    ///
180    /// # Arguments
181    /// * `zip_path` - Path to the ZIP archive
182    /// * `image_name` - Name of the image file within the archive
183    ///
184    /// # Returns
185    /// Raw image bytes
186    pub fn read_image_from_zip<P: AsRef<Path>>(
187        &self,
188        zip_path: P,
189        image_name: &str,
190    ) -> Result<Vec<u8>, Error> {
191        let file = File::open(zip_path.as_ref())?;
192        let mut archive = zip::ZipArchive::new(file)?;
193
194        let mut entry = archive.by_name(image_name)?;
195        let mut buffer = Vec::with_capacity(entry.size() as usize);
196        entry.read_to_end(&mut buffer)?;
197
198        Ok(buffer)
199    }
200
201    /// Apply filters from options to the dataset.
202    fn apply_filters(&self, mut dataset: CocoDataset) -> CocoDataset {
203        // Apply max_images filter
204        if self.options.max_images > 0 && dataset.images.len() > self.options.max_images {
205            let image_ids: HashSet<_> = dataset
206                .images
207                .iter()
208                .take(self.options.max_images)
209                .map(|i| i.id)
210                .collect();
211
212            dataset.images.truncate(self.options.max_images);
213            dataset
214                .annotations
215                .retain(|a| image_ids.contains(&a.image_id));
216        }
217
218        // Apply category filter
219        if !self.options.category_filter.is_empty() {
220            let category_ids: HashSet<_> = dataset
221                .categories
222                .iter()
223                .filter(|c| self.options.category_filter.contains(&c.name))
224                .map(|c| c.id)
225                .collect();
226
227            dataset
228                .categories
229                .retain(|c| self.options.category_filter.contains(&c.name));
230            dataset
231                .annotations
232                .retain(|a| category_ids.contains(&a.category_id));
233        }
234
235        dataset
236    }
237}
238
239impl Default for CocoReader {
240    fn default() -> Self {
241        Self::new()
242    }
243}
244
245/// Validate a COCO dataset for consistency.
246fn validate_dataset(dataset: &CocoDataset) -> Result<(), Error> {
247    let image_ids: HashSet<_> = dataset.images.iter().map(|i| i.id).collect();
248    let category_ids: HashSet<_> = dataset.categories.iter().map(|c| c.id).collect();
249
250    for ann in &dataset.annotations {
251        if !image_ids.contains(&ann.image_id) {
252            return Err(Error::CocoError(format!(
253                "Annotation {} references non-existent image_id {}",
254                ann.id, ann.image_id
255            )));
256        }
257
258        if !category_ids.contains(&ann.category_id) {
259            return Err(Error::CocoError(format!(
260                "Annotation {} references non-existent category_id {}",
261                ann.id, ann.category_id
262            )));
263        }
264
265        // Validate bbox
266        if ann.bbox[2] <= 0.0 || ann.bbox[3] <= 0.0 {
267            return Err(Error::CocoError(format!(
268                "Annotation {} has invalid bbox dimensions",
269                ann.id
270            )));
271        }
272    }
273
274    Ok(())
275}
276
/// Turn an LVIS `coco_url` into the relative path COCO would store in
/// `file_name`.
///
/// LVIS records only a URL such as
/// `http://images.cocodataset.org/val2017/000000397133.jpg`. The part
/// after the host (`val2017/000000397133.jpg`) is the relative image path,
/// so the scheme (if any) and the host are stripped and the remainder is
/// returned.
///
/// Because callers join the result onto an images base directory, the
/// remainder is only returned when it cannot leave that directory.
/// `None` is returned when:
///
/// - there is no `/` after the host, or nothing follows it;
/// - the remainder starts with `/` (`Path::join` would discard the base);
/// - the remainder contains a backslash or a colon. These are rejected as
///   plain text because Windows separators and drive prefixes are not
///   recognised as path structure on non-Windows builds;
/// - any path component is something other than a normal name, i.e. a
///   `..`, a leading `.`, a root, or a prefix.
fn derive_file_name_from_coco_url(url: &str) -> Option<String> {
    use std::path::Component;

    let without_scheme = match url.split_once("://") {
        Some((_scheme, rest)) => rest,
        None => url,
    };
    let (_authority, relative) = without_scheme.split_once('/')?;

    // Textual checks first: these catch inputs whose danger would not show
    // up as a distinct path component on every platform.
    let textually_unsafe = relative.is_empty()
        || relative.starts_with('/')
        || relative.contains(|c: char| matches!(c, '\\' | ':'));
    if textually_unsafe {
        return None;
    }

    // Structural check: only plain name components may remain.
    let only_plain_names = Path::new(relative)
        .components()
        .all(|part| matches!(part, Component::Normal(_)));

    if only_plain_names {
        Some(relative.to_string())
    } else {
        None
    }
}
339
340/// Populate empty `file_name` fields from `coco_url` (LVIS compatibility).
341///
342/// LVIS annotation JSONs omit `file_name`, relying on consumers to parse
343/// `coco_url`. This runs once immediately after deserialisation so every
344/// downstream stage (index building, Arrow conversion, studio upload) can
345/// continue to assume `file_name` is populated.
346fn fill_missing_file_names(dataset: &mut CocoDataset) {
347    for image in &mut dataset.images {
348        if !image.file_name.is_empty() {
349            continue;
350        }
351        if let Some(derived) = image
352            .coco_url
353            .as_deref()
354            .and_then(derive_file_name_from_coco_url)
355        {
356            image.file_name = derived;
357        }
358    }
359}
360
/// Guess the dataset split ("group") from a COCO annotation file name.
///
/// Two strategies are tried in order:
///
/// 1. **Known prefixes.** When the file stem starts with `instances_`,
///    `person_keypoints_`, `captions_` or `panoptic_`, the rest of the
///    stem with trailing digits removed is the group:
///    - `instances_train2017.json` → `"train"`
///    - `instances_val2017.json` → `"val"`
///    - `instances_test2017.json` → `"test"`
///    - `person_keypoints_train2017.json` → `"train"`
/// 2. **Fallback.** Otherwise (or when nothing is left after removing the
///    digits) the lower-cased input is searched for `train`, then `val`,
///    then `test`, and the first one found is returned.
///
/// # Arguments
/// * `filename` - The annotation file name
///
/// # Returns
/// The inferred group name, or `None` when neither strategy matches
pub fn infer_group_from_filename(filename: &str) -> Option<String> {
    const KNOWN_PREFIXES: [&str; 4] = [
        "instances_",
        "person_keypoints_",
        "captions_",
        "panoptic_",
    ];
    const FALLBACK_SPLITS: [&str; 3] = ["train", "val", "test"];

    let stem = Path::new(filename).file_stem()?.to_str()?;

    // <prefix><group><year>: drop the prefix, then the trailing year.
    let from_prefix = KNOWN_PREFIXES
        .iter()
        .filter_map(|prefix| stem.strip_prefix(*prefix))
        .map(|rest| rest.trim_end_matches(char::is_numeric))
        .find(|group| !group.is_empty());
    if let Some(group) = from_prefix {
        return Some(group.to_string());
    }

    // The fallback deliberately searches the whole input, not just the stem.
    let lowered = filename.to_lowercase();
    FALLBACK_SPLITS
        .iter()
        .copied()
        .find(|split| lowered.contains(*split))
        .map(str::to_string)
}
424
/// Guess the dataset split ("group") from the folder an image lives in.
///
/// The name of the image's immediate parent folder is used, with trailing
/// digits (typically the year) removed. A folder name made only of digits
/// is returned unchanged.
///
/// # Examples
/// - `train2017/000000001.jpg` → "train"
/// - `val2017/000000002.jpg` → "val"
/// - `test2017/000000003.jpg` → "test"
/// - `custom_split/image.jpg` → "custom_split"
///
/// # Arguments
/// * `image_path` - Relative path to the image (e.g.,
///   "train2017/000000001.jpg")
///
/// # Returns
/// Inferred group name, or None if the path has no usable parent folder
/// name.
pub fn infer_group_from_folder(image_path: &str) -> Option<String> {
    let folder = Path::new(image_path)
        .parent()
        .and_then(Path::file_name)
        .and_then(|name| name.to_str())
        .filter(|name| !name.is_empty())?;

    // "train2017" → "train"; an all-digit folder keeps its full name.
    let without_year = folder.trim_end_matches(char::is_numeric);
    let group = if without_year.is_empty() {
        folder
    } else {
        without_year
    };

    Some(group.to_string())
}
462
463/// Read all COCO annotation files from a directory.
464///
465/// Discovers and reads annotation files from standard COCO directory
466/// structures:
467///
468/// ```text
469/// coco_dir/
470/// ├── annotations/
471/// │   ├── instances_train2017.json
472/// │   └── instances_val2017.json
473/// └── ...
474/// ```
475///
476/// # Arguments
477///
478/// * `path` - Path to the COCO directory
479/// * `options` - Read options
480///
481/// # Returns
482///
483/// Vector of `(CocoDataset, inferred_group)` pairs
484pub fn read_coco_directory<P: AsRef<Path>>(
485    path: P,
486    options: &CocoReadOptions,
487) -> Result<Vec<(CocoDataset, String)>, Error> {
488    let path = path.as_ref();
489    let mut results = Vec::new();
490
491    // Look for annotation files
492    let annotations_dir = path.join("annotations");
493    let search_dirs: Vec<&Path> = if annotations_dir.is_dir() {
494        vec![annotations_dir.as_path(), path]
495    } else {
496        vec![path]
497    };
498
499    for search_dir in search_dirs {
500        if !search_dir.is_dir() {
501            continue;
502        }
503
504        for entry in std::fs::read_dir(search_dir)? {
505            let entry = entry?;
506            let file_path = entry.path();
507
508            if !file_path.is_file() {
509                continue;
510            }
511
512            let filename = file_path.file_name().and_then(|s| s.to_str()).unwrap_or("");
513
514            // Only process instance annotation files
515            if filename.ends_with(".json") && filename.contains("instances") {
516                let group =
517                    infer_group_from_filename(filename).unwrap_or_else(|| "default".to_string());
518
519                let reader = CocoReader::with_options(options.clone());
520                let dataset = reader.read_json(&file_path)?;
521
522                results.push((dataset, group));
523            }
524        }
525    }
526
527    if results.is_empty() {
528        return Err(Error::MissingAnnotations(format!(
529            "No COCO annotation files found in {}",
530            path.display()
531        )));
532    }
533
534    Ok(results)
535}
536
537/// Merge a source dataset into a target dataset.
538fn merge_datasets(target: &mut CocoDataset, source: CocoDataset) {
539    // Take info if not set
540    if target.info.description.is_none() {
541        target.info = source.info;
542    }
543
544    // Merge images (deduplicate by id)
545    let existing_ids: HashSet<_> = target.images.iter().map(|i| i.id).collect();
546    for image in source.images {
547        if !existing_ids.contains(&image.id) {
548            target.images.push(image);
549        }
550    }
551
552    // Merge categories (deduplicate by id)
553    let existing_cats: HashSet<_> = target.categories.iter().map(|c| c.id).collect();
554    for cat in source.categories {
555        if !existing_cats.contains(&cat.id) {
556            target.categories.push(cat);
557        }
558    }
559
560    // Merge annotations (always append - IDs should be globally unique)
561    target.annotations.extend(source.annotations);
562
563    // Merge licenses
564    let existing_licenses: HashSet<_> = target.licenses.iter().map(|l| l.id).collect();
565    for lic in source.licenses {
566        if !existing_licenses.contains(&lic.id) {
567            target.licenses.push(lic);
568        }
569    }
570}
571
#[cfg(test)]
mod tests {
    use super::*;
    use crate::coco::{CocoAnnotation, CocoCategory, CocoImage};

    /// Assert that every `(file name, expected group)` pair holds for
    /// `infer_group_from_filename`.
    fn assert_groups(cases: &[(&str, Option<&str>)]) {
        for (filename, expected) in cases {
            assert_eq!(
                infer_group_from_filename(filename).as_deref(),
                *expected,
                "filename: {filename}"
            );
        }
    }

    /// Assert that every URL in `urls` is rejected by the sanitiser.
    fn assert_rejected(urls: &[&str]) {
        for url in urls {
            assert_eq!(derive_file_name_from_coco_url(url), None, "url: {url}");
        }
    }

    #[test]
    fn test_reader_default() {
        let CocoReader { options } = CocoReader::new();
        assert!(!options.validate);
        assert_eq!(options.max_images, 0);
        assert!(options.category_filter.is_empty());
    }

    #[test]
    fn test_reader_with_options() {
        let reader = CocoReader::with_options(CocoReadOptions {
            validate: true,
            max_images: 100,
            category_filter: vec!["person".to_string()],
        });
        assert!(reader.options.validate);
        assert_eq!(reader.options.max_images, 100);
    }

    #[test]
    fn test_derive_file_name_from_coco_url() {
        let accepted = [
            (
                "http://images.cocodataset.org/val2017/000000397133.jpg",
                "val2017/000000397133.jpg",
            ),
            (
                "https://images.cocodataset.org/train2017/000000000009.jpg",
                "train2017/000000000009.jpg",
            ),
        ];
        for (url, expected) in accepted {
            assert_eq!(
                derive_file_name_from_coco_url(url),
                Some(expected.to_string())
            );
        }

        // No path after the host at all.
        assert_rejected(&["host-only", "http://host/"]);
    }

    #[test]
    fn test_derive_file_name_from_coco_url_rejects_traversal() {
        assert_rejected(&[
            // Parent-dir segments would let a later `Path::join` escape
            // the images base directory.
            "http://host/../etc/passwd",
            "http://host/val2017/../../etc/passwd",
            // A leading `.` segment has no legitimate use either.
            "http://host/./foo.jpg",
        ]);
    }

    #[test]
    fn test_derive_file_name_from_coco_url_rejects_absolute_and_windows() {
        assert_rejected(&[
            // Double slash after the host encodes an absolute path, which
            // makes `Path::join` drop the base on Unix.
            "http://host//etc/passwd",
            // Backslashes are not separators on Linux, so a component
            // check alone would miss this traversal.
            "http://host/val2017\\..\\..\\etc",
            // A drive prefix makes `Path::join` drop the base on Windows.
            "http://host/C:/Windows/System32",
        ]);
    }

    #[test]
    fn test_fill_missing_file_names_from_lvis_json() {
        // An LVIS image record carries `coco_url` but no `file_name`.
        let json = r#"{
            "images": [
                {
                    "id": 397133,
                    "width": 640,
                    "height": 427,
                    "coco_url": "http://images.cocodataset.org/val2017/000000397133.jpg",
                    "neg_category_ids": [279, 899],
                    "not_exhaustive_category_ids": [914]
                }
            ],
            "annotations": [],
            "categories": []
        }"#;
        let mut dataset: CocoDataset = serde_json::from_str(json).unwrap();
        assert_eq!(dataset.images[0].file_name, "");

        fill_missing_file_names(&mut dataset);

        let image = &dataset.images[0];
        assert_eq!(image.file_name, "val2017/000000397133.jpg");
        // The LVIS-only fields must survive deserialisation.
        assert_eq!(
            image.neg_category_ids.as_deref(),
            Some(&[279u32, 899][..])
        );
    }

    #[test]
    fn test_fill_missing_file_names_preserves_existing() {
        // A record that already has `file_name` must not be overwritten.
        let named = CocoImage {
            id: 1,
            width: 640,
            height: 480,
            file_name: "custom/path.jpg".to_string(),
            coco_url: Some("http://images.cocodataset.org/val2017/foo.jpg".to_string()),
            ..Default::default()
        };
        let mut dataset = CocoDataset {
            images: vec![named],
            ..Default::default()
        };

        fill_missing_file_names(&mut dataset);

        assert_eq!(dataset.images[0].file_name, "custom/path.jpg");
    }

    #[test]
    fn test_validate_dataset_valid() {
        let image = CocoImage {
            id: 1,
            width: 640,
            height: 480,
            file_name: "test.jpg".to_string(),
            ..Default::default()
        };
        let person = CocoCategory {
            id: 1,
            name: "person".to_string(),
            supercategory: None,
            ..Default::default()
        };
        let annotation = CocoAnnotation {
            id: 1,
            image_id: 1,
            category_id: 1,
            bbox: [10.0, 20.0, 100.0, 80.0],
            area: 8000.0,
            iscrowd: 0,
            segmentation: None,
            score: None,
        };
        let dataset = CocoDataset {
            images: vec![image],
            categories: vec![person],
            annotations: vec![annotation],
            ..Default::default()
        };

        assert!(validate_dataset(&dataset).is_ok());
    }

    #[test]
    fn test_validate_dataset_missing_image() {
        let person = CocoCategory {
            id: 1,
            name: "person".to_string(),
            supercategory: None,
            ..Default::default()
        };
        // `image_id` 999 does not exist in the (empty) image list.
        let orphan = CocoAnnotation {
            id: 1,
            image_id: 999,
            category_id: 1,
            bbox: [10.0, 20.0, 100.0, 80.0],
            ..Default::default()
        };
        let dataset = CocoDataset {
            images: Vec::new(),
            categories: vec![person],
            annotations: vec![orphan],
            ..Default::default()
        };

        assert!(validate_dataset(&dataset).is_err());
    }

    #[test]
    fn test_merge_datasets() {
        let first_image = CocoImage {
            id: 1,
            width: 640,
            height: 480,
            file_name: "img1.jpg".to_string(),
            ..Default::default()
        };
        let mut target = CocoDataset {
            images: vec![first_image],
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                supercategory: None,
                ..Default::default()
            }],
            annotations: Vec::new(),
            ..Default::default()
        };

        // Same id as the image already in `target`: must be skipped.
        let duplicate_image = CocoImage {
            id: 1,
            width: 640,
            height: 480,
            file_name: "img1.jpg".to_string(),
            ..Default::default()
        };
        // New id: must be appended.
        let new_image = CocoImage {
            id: 2,
            width: 800,
            height: 600,
            file_name: "img2.jpg".to_string(),
            ..Default::default()
        };
        let source = CocoDataset {
            images: vec![duplicate_image, new_image],
            categories: vec![CocoCategory {
                id: 2,
                name: "car".to_string(),
                supercategory: None,
                ..Default::default()
            }],
            annotations: Vec::new(),
            ..Default::default()
        };

        merge_datasets(&mut target, source);

        assert_eq!(target.images.len(), 2);
        assert_eq!(target.categories.len(), 2);
    }

    #[test]
    fn test_apply_max_images_filter() {
        let images: Vec<CocoImage> = (1..=3)
            .map(|id| CocoImage {
                id,
                ..Default::default()
            })
            .collect();
        // One annotation per image, as `(annotation id, image id)`.
        let annotations: Vec<CocoAnnotation> = vec![(1, 1), (2, 2), (3, 3)]
            .into_iter()
            .map(|(id, image_id)| CocoAnnotation {
                id,
                image_id,
                ..Default::default()
            })
            .collect();
        let dataset = CocoDataset {
            images,
            annotations,
            ..Default::default()
        };

        let reader = CocoReader::with_options(CocoReadOptions {
            max_images: 2,
            ..Default::default()
        });
        let filtered = reader.apply_filters(dataset);

        assert_eq!(filtered.images.len(), 2);
        assert_eq!(filtered.annotations.len(), 2);
    }

    #[test]
    fn test_infer_group_from_filename_instances() {
        assert_groups(&[
            ("instances_train2017.json", Some("train")),
            ("instances_val2017.json", Some("val")),
            ("instances_test2017.json", Some("test")),
        ]);
    }

    #[test]
    fn test_infer_group_from_filename_keypoints() {
        assert_groups(&[
            ("person_keypoints_train2017.json", Some("train")),
            ("person_keypoints_val2017.json", Some("val")),
        ]);
    }

    #[test]
    fn test_infer_group_from_filename_captions() {
        assert_groups(&[
            ("captions_train2017.json", Some("train")),
            ("captions_val2017.json", Some("val")),
        ]);
    }

    #[test]
    fn test_infer_group_from_filename_panoptic() {
        assert_groups(&[
            ("panoptic_train2017.json", Some("train")),
            ("panoptic_val2017.json", Some("val")),
        ]);
    }

    #[test]
    fn test_infer_group_from_filename_fallback() {
        // No known prefix: the split name is searched for in the file name.
        assert_groups(&[
            ("my_custom_train_annotations.json", Some("train")),
            ("validation_data.json", Some("val")),
        ]);
    }

    #[test]
    fn test_infer_group_from_filename_no_match() {
        // Neither a known prefix nor a split name is present.
        assert_groups(&[("annotations.json", None), ("data.json", None)]);
    }
}