1use super::types::CocoDataset;
10use crate::Error;
11use std::{
12 collections::HashSet,
13 fs::File,
14 io::{BufReader, Read},
15 path::Path,
16};
17
/// Options controlling how a [`CocoReader`] loads and post-processes a dataset.
///
/// The `Default` value disables validation and applies no filtering.
#[derive(Clone, Debug, Default)]
pub struct CocoReadOptions {
    /// When `true`, the loaded dataset is validated before any filter is applied.
    pub validate: bool,

    /// Maximum number of images to keep; `0` means "no limit".
    pub max_images: usize,

    /// Category names to keep (exact match); an empty list keeps every category.
    pub category_filter: Vec<String>,
}
28
29pub struct CocoReader {
44 options: CocoReadOptions,
45}
46
47impl CocoReader {
48 pub fn new() -> Self {
50 Self {
51 options: CocoReadOptions::default(),
52 }
53 }
54
55 pub fn with_options(options: CocoReadOptions) -> Self {
57 Self { options }
58 }
59
60 pub fn read_json<P: AsRef<Path>>(&self, path: P) -> Result<CocoDataset, Error> {
68 let file = File::open(path.as_ref())?;
69 let reader = BufReader::with_capacity(64 * 1024, file);
70 let mut dataset: CocoDataset = serde_json::from_reader(reader)?;
71 fill_missing_file_names(&mut dataset);
72
73 if self.options.validate {
74 validate_dataset(&dataset)?;
75 }
76
77 Ok(self.apply_filters(dataset))
78 }
79
80 pub fn read_annotations_zip<P: AsRef<Path>>(&self, path: P) -> Result<CocoDataset, Error> {
93 let file = File::open(path.as_ref())?;
94 let mut archive = zip::ZipArchive::new(file)?;
95
96 let mut merged = CocoDataset::default();
97
98 for i in 0..archive.len() {
99 let mut entry = archive.by_index(i)?;
100 let name = entry.name().to_string();
101
102 if name.ends_with(".json") && name.contains("instances") {
104 let mut contents = String::new();
105 entry.read_to_string(&mut contents)?;
106
107 let mut dataset: CocoDataset = serde_json::from_str(&contents)?;
108 fill_missing_file_names(&mut dataset);
109 merge_datasets(&mut merged, dataset);
110 }
111 }
112
113 if self.options.validate {
114 validate_dataset(&merged)?;
115 }
116
117 Ok(self.apply_filters(merged))
118 }
119
120 pub fn list_images<P: AsRef<Path>>(
128 &self,
129 path: P,
130 ) -> Result<Vec<(String, std::path::PathBuf)>, Error> {
131 let path = path.as_ref();
132 let mut images = Vec::new();
133
134 if path.is_dir() {
135 for entry in walkdir::WalkDir::new(path)
137 .into_iter()
138 .filter_map(|e| e.ok())
139 .filter(|e| e.file_type().is_file())
140 {
141 let filename = entry.file_name().to_string_lossy().to_lowercase();
142 if filename.ends_with(".jpg")
143 || filename.ends_with(".jpeg")
144 || filename.ends_with(".png")
145 {
146 let rel_path = entry
147 .path()
148 .strip_prefix(path)
149 .unwrap_or(entry.path())
150 .to_string_lossy()
151 .to_string();
152 images.push((rel_path, entry.path().to_path_buf()));
153 }
154 }
155 } else if path.extension().is_some_and(|e| e == "zip") {
156 let file = File::open(path)?;
158 let mut archive = zip::ZipArchive::new(file)?;
159
160 for i in 0..archive.len() {
161 let entry = archive.by_index(i)?;
162 let name = entry.name().to_string();
163 let name_lower = name.to_lowercase();
164
165 if !entry.is_dir()
166 && (name_lower.ends_with(".jpg")
167 || name_lower.ends_with(".jpeg")
168 || name_lower.ends_with(".png"))
169 {
170 images.push((name.clone(), path.join(&name)));
171 }
172 }
173 }
174
175 Ok(images)
176 }
177
178 pub fn read_image_from_zip<P: AsRef<Path>>(
187 &self,
188 zip_path: P,
189 image_name: &str,
190 ) -> Result<Vec<u8>, Error> {
191 let file = File::open(zip_path.as_ref())?;
192 let mut archive = zip::ZipArchive::new(file)?;
193
194 let mut entry = archive.by_name(image_name)?;
195 let mut buffer = Vec::with_capacity(entry.size() as usize);
196 entry.read_to_end(&mut buffer)?;
197
198 Ok(buffer)
199 }
200
201 fn apply_filters(&self, mut dataset: CocoDataset) -> CocoDataset {
203 if self.options.max_images > 0 && dataset.images.len() > self.options.max_images {
205 let image_ids: HashSet<_> = dataset
206 .images
207 .iter()
208 .take(self.options.max_images)
209 .map(|i| i.id)
210 .collect();
211
212 dataset.images.truncate(self.options.max_images);
213 dataset
214 .annotations
215 .retain(|a| image_ids.contains(&a.image_id));
216 }
217
218 if !self.options.category_filter.is_empty() {
220 let category_ids: HashSet<_> = dataset
221 .categories
222 .iter()
223 .filter(|c| self.options.category_filter.contains(&c.name))
224 .map(|c| c.id)
225 .collect();
226
227 dataset
228 .categories
229 .retain(|c| self.options.category_filter.contains(&c.name));
230 dataset
231 .annotations
232 .retain(|a| category_ids.contains(&a.category_id));
233 }
234
235 dataset
236 }
237}
238
239impl Default for CocoReader {
240 fn default() -> Self {
241 Self::new()
242 }
243}
244
245fn validate_dataset(dataset: &CocoDataset) -> Result<(), Error> {
247 let image_ids: HashSet<_> = dataset.images.iter().map(|i| i.id).collect();
248 let category_ids: HashSet<_> = dataset.categories.iter().map(|c| c.id).collect();
249
250 for ann in &dataset.annotations {
251 if !image_ids.contains(&ann.image_id) {
252 return Err(Error::CocoError(format!(
253 "Annotation {} references non-existent image_id {}",
254 ann.id, ann.image_id
255 )));
256 }
257
258 if !category_ids.contains(&ann.category_id) {
259 return Err(Error::CocoError(format!(
260 "Annotation {} references non-existent category_id {}",
261 ann.id, ann.category_id
262 )));
263 }
264
265 if ann.bbox[2] <= 0.0 || ann.bbox[3] <= 0.0 {
267 return Err(Error::CocoError(format!(
268 "Annotation {} has invalid bbox dimensions",
269 ann.id
270 )));
271 }
272 }
273
274 Ok(())
275}
276
/// Derives a relative file name from a COCO image URL.
///
/// The scheme (everything up to and including the first `://`, if present)
/// and the host (everything up to the next `/`) are dropped; the remainder is
/// returned unchanged, e.g.
/// `http://images.cocodataset.org/val2017/1.jpg` -> `val2017/1.jpg`.
/// No percent-decoding or query-string stripping is performed.
///
/// Returns `None` when the URL has no path, or when the path could escape or
/// fail to name a file below a base directory:
///
/// * it contains `\` or `:` (Windows separators / drive prefixes),
/// * any `/`-separated segment is empty (leading, doubled or trailing `/`),
/// * any segment is `.` or `..`.
///
/// Segments are checked textually rather than through `Path::components`,
/// which silently drops interior `.` segments and repeated or trailing
/// separators and so would accept inputs such as `dir/` or `a/./b`.
fn derive_file_name_from_coco_url(url: &str) -> Option<String> {
    let after_scheme = url.split_once("://").map_or(url, |(_, rest)| rest);
    let (_host, path) = after_scheme.split_once('/')?;

    if path.contains('\\') || path.contains(':') {
        return None;
    }

    // An empty path splits into a single empty segment and is rejected here too.
    let only_plain_segments = path
        .split('/')
        .all(|segment| !matches!(segment, "" | "." | ".."));

    only_plain_segments.then(|| path.to_string())
}
339
340fn fill_missing_file_names(dataset: &mut CocoDataset) {
347 for image in &mut dataset.images {
348 if !image.file_name.is_empty() {
349 continue;
350 }
351 if let Some(derived) = image
352 .coco_url
353 .as_deref()
354 .and_then(derive_file_name_from_coco_url)
355 {
356 image.file_name = derived;
357 }
358 }
359}
360
361pub fn infer_group_from_filename(filename: &str) -> Option<String> {
375 let stem = Path::new(filename).file_stem()?.to_str()?;
376
377 if let Some(rest) = stem.strip_prefix("instances_") {
380 let group = rest.trim_end_matches(char::is_numeric);
381 if !group.is_empty() {
382 return Some(group.to_string());
383 }
384 }
385
386 if let Some(rest) = stem.strip_prefix("person_keypoints_") {
388 let group = rest.trim_end_matches(char::is_numeric);
389 if !group.is_empty() {
390 return Some(group.to_string());
391 }
392 }
393
394 if let Some(rest) = stem.strip_prefix("captions_") {
396 let group = rest.trim_end_matches(char::is_numeric);
397 if !group.is_empty() {
398 return Some(group.to_string());
399 }
400 }
401
402 if let Some(rest) = stem.strip_prefix("panoptic_") {
404 let group = rest.trim_end_matches(char::is_numeric);
405 if !group.is_empty() {
406 return Some(group.to_string());
407 }
408 }
409
410 let lower = filename.to_lowercase();
412 if lower.contains("train") {
413 return Some("train".to_string());
414 }
415 if lower.contains("val") {
416 return Some("val".to_string());
417 }
418 if lower.contains("test") {
419 return Some("test".to_string());
420 }
421
422 None
423}
424
/// Infers a dataset group (split) name from the folder that contains an image.
///
/// The name of the image's immediate parent directory is taken and its
/// trailing numeric characters are removed (`train2017/0001.jpg` -> `train`).
/// If the folder name consists only of numeric characters it is returned
/// unchanged (`2017/0001.jpg` -> `2017`).
///
/// Returns `None` when the path has no named parent directory.
pub fn infer_group_from_folder(image_path: &str) -> Option<String> {
    let folder = Path::new(image_path)
        .parent()
        .and_then(Path::file_name)
        .and_then(|name| name.to_str())
        .filter(|name| !name.is_empty())?;

    let without_digits = folder.trim_end_matches(char::is_numeric);
    let group = if without_digits.is_empty() {
        folder
    } else {
        without_digits
    };

    Some(group.to_string())
}
462
463pub fn read_coco_directory<P: AsRef<Path>>(
485 path: P,
486 options: &CocoReadOptions,
487) -> Result<Vec<(CocoDataset, String)>, Error> {
488 let path = path.as_ref();
489 let mut results = Vec::new();
490
491 let annotations_dir = path.join("annotations");
493 let search_dirs: Vec<&Path> = if annotations_dir.is_dir() {
494 vec![annotations_dir.as_path(), path]
495 } else {
496 vec![path]
497 };
498
499 for search_dir in search_dirs {
500 if !search_dir.is_dir() {
501 continue;
502 }
503
504 for entry in std::fs::read_dir(search_dir)? {
505 let entry = entry?;
506 let file_path = entry.path();
507
508 if !file_path.is_file() {
509 continue;
510 }
511
512 let filename = file_path.file_name().and_then(|s| s.to_str()).unwrap_or("");
513
514 if filename.ends_with(".json") && filename.contains("instances") {
516 let group =
517 infer_group_from_filename(filename).unwrap_or_else(|| "default".to_string());
518
519 let reader = CocoReader::with_options(options.clone());
520 let dataset = reader.read_json(&file_path)?;
521
522 results.push((dataset, group));
523 }
524 }
525 }
526
527 if results.is_empty() {
528 return Err(Error::MissingAnnotations(format!(
529 "No COCO annotation files found in {}",
530 path.display()
531 )));
532 }
533
534 Ok(results)
535}
536
537fn merge_datasets(target: &mut CocoDataset, source: CocoDataset) {
539 if target.info.description.is_none() {
541 target.info = source.info;
542 }
543
544 let existing_ids: HashSet<_> = target.images.iter().map(|i| i.id).collect();
546 for image in source.images {
547 if !existing_ids.contains(&image.id) {
548 target.images.push(image);
549 }
550 }
551
552 let existing_cats: HashSet<_> = target.categories.iter().map(|c| c.id).collect();
554 for cat in source.categories {
555 if !existing_cats.contains(&cat.id) {
556 target.categories.push(cat);
557 }
558 }
559
560 target.annotations.extend(source.annotations);
562
563 let existing_licenses: HashSet<_> = target.licenses.iter().map(|l| l.id).collect();
565 for lic in source.licenses {
566 if !existing_licenses.contains(&lic.id) {
567 target.licenses.push(lic);
568 }
569 }
570}
571
#[cfg(test)]
mod tests {
    use super::*;
    use crate::coco::{CocoAnnotation, CocoCategory, CocoImage};

    /// The "person" category (id 1) shared by several fixtures.
    fn person_category() -> CocoCategory {
        CocoCategory {
            id: 1,
            name: "person".to_string(),
            supercategory: None,
            ..Default::default()
        }
    }

    /// Asserts the file name derived from `url`.
    #[track_caller]
    fn assert_derived(url: &str, expected: Option<&str>) {
        assert_eq!(derive_file_name_from_coco_url(url).as_deref(), expected);
    }

    /// Asserts the group inferred from `filename`.
    #[track_caller]
    fn assert_group(filename: &str, expected: Option<&str>) {
        assert_eq!(infer_group_from_filename(filename).as_deref(), expected);
    }

    #[test]
    fn test_reader_default() {
        let CocoReadOptions {
            validate,
            max_images,
            category_filter,
        } = CocoReader::new().options;

        assert!(!validate);
        assert_eq!(max_images, 0);
        assert!(category_filter.is_empty());
    }

    #[test]
    fn test_reader_with_options() {
        let reader = CocoReader::with_options(CocoReadOptions {
            validate: true,
            max_images: 100,
            category_filter: vec!["person".to_string()],
        });

        assert!(reader.options.validate);
        assert_eq!(reader.options.max_images, 100);
    }

    #[test]
    fn test_derive_file_name_from_coco_url() {
        assert_derived(
            "http://images.cocodataset.org/val2017/000000397133.jpg",
            Some("val2017/000000397133.jpg"),
        );
        assert_derived(
            "https://images.cocodataset.org/train2017/000000000009.jpg",
            Some("train2017/000000000009.jpg"),
        );
        assert_derived("host-only", None);
        assert_derived("http://host/", None);
    }

    #[test]
    fn test_derive_file_name_from_coco_url_rejects_traversal() {
        // Parent-directory segments, at the start or in the middle.
        assert_derived("http://host/../etc/passwd", None);
        assert_derived("http://host/val2017/../../etc/passwd", None);
        // A leading current-directory segment.
        assert_derived("http://host/./foo.jpg", None);
    }

    #[test]
    fn test_derive_file_name_from_coco_url_rejects_absolute_and_windows() {
        // Path that would be absolute once the host is stripped.
        assert_derived("http://host//etc/passwd", None);
        // Backslash separators.
        assert_derived("http://host/val2017\\..\\..\\etc", None);
        // Drive-letter style prefix.
        assert_derived("http://host/C:/Windows/System32", None);
    }

    #[test]
    fn test_fill_missing_file_names_from_lvis_json() {
        let json = r#"{
            "images": [
                {
                    "id": 397133,
                    "width": 640,
                    "height": 427,
                    "coco_url": "http://images.cocodataset.org/val2017/000000397133.jpg",
                    "neg_category_ids": [279, 899],
                    "not_exhaustive_category_ids": [914]
                }
            ],
            "annotations": [],
            "categories": []
        }"#;

        let mut dataset: CocoDataset = serde_json::from_str(json).unwrap();
        assert_eq!(dataset.images[0].file_name, "");

        fill_missing_file_names(&mut dataset);

        let image = &dataset.images[0];
        assert_eq!(image.file_name, "val2017/000000397133.jpg");
        assert_eq!(
            image.neg_category_ids.as_deref(),
            Some(&[279u32, 899][..])
        );
    }

    #[test]
    fn test_fill_missing_file_names_preserves_existing() {
        let preset = CocoImage {
            id: 1,
            width: 640,
            height: 480,
            file_name: "custom/path.jpg".to_string(),
            coco_url: Some("http://images.cocodataset.org/val2017/foo.jpg".to_string()),
            ..Default::default()
        };
        let mut dataset = CocoDataset {
            images: vec![preset],
            ..Default::default()
        };

        fill_missing_file_names(&mut dataset);

        assert_eq!(dataset.images[0].file_name, "custom/path.jpg");
    }

    #[test]
    fn test_validate_dataset_valid() {
        let image = CocoImage {
            id: 1,
            width: 640,
            height: 480,
            file_name: "test.jpg".to_string(),
            ..Default::default()
        };
        let annotation = CocoAnnotation {
            id: 1,
            image_id: 1,
            category_id: 1,
            bbox: [10.0, 20.0, 100.0, 80.0],
            area: 8000.0,
            iscrowd: 0,
            segmentation: None,
            score: None,
        };
        let dataset = CocoDataset {
            images: vec![image],
            categories: vec![person_category()],
            annotations: vec![annotation],
            ..Default::default()
        };

        assert!(validate_dataset(&dataset).is_ok());
    }

    #[test]
    fn test_validate_dataset_missing_image() {
        // The annotation points at image 999, which is not in the dataset.
        let dangling = CocoAnnotation {
            id: 1,
            image_id: 999,
            category_id: 1,
            bbox: [10.0, 20.0, 100.0, 80.0],
            ..Default::default()
        };
        let dataset = CocoDataset {
            images: vec![],
            categories: vec![person_category()],
            annotations: vec![dangling],
            ..Default::default()
        };

        assert!(validate_dataset(&dataset).is_err());
    }

    #[test]
    fn test_merge_datasets() {
        let first_image = CocoImage {
            id: 1,
            width: 640,
            height: 480,
            file_name: "img1.jpg".to_string(),
            ..Default::default()
        };
        // Same id as `first_image`: must not be added a second time.
        let repeated_image = CocoImage {
            id: 1,
            width: 640,
            height: 480,
            file_name: "img1.jpg".to_string(),
            ..Default::default()
        };
        let second_image = CocoImage {
            id: 2,
            width: 800,
            height: 600,
            file_name: "img2.jpg".to_string(),
            ..Default::default()
        };
        let car_category = CocoCategory {
            id: 2,
            name: "car".to_string(),
            supercategory: None,
            ..Default::default()
        };

        let mut target = CocoDataset {
            images: vec![first_image],
            categories: vec![person_category()],
            annotations: vec![],
            ..Default::default()
        };
        let source = CocoDataset {
            images: vec![repeated_image, second_image],
            categories: vec![car_category],
            annotations: vec![],
            ..Default::default()
        };

        merge_datasets(&mut target, source);

        assert_eq!(target.images.len(), 2);
        assert_eq!(target.categories.len(), 2);
    }

    #[test]
    fn test_apply_max_images_filter() {
        let reader = CocoReader::with_options(CocoReadOptions {
            max_images: 2,
            ..Default::default()
        });

        // Three images, each referenced by exactly one annotation.
        let images = (1..=3)
            .map(|id| CocoImage {
                id,
                ..Default::default()
            })
            .collect();
        let annotations = vec![
            CocoAnnotation {
                id: 1,
                image_id: 1,
                ..Default::default()
            },
            CocoAnnotation {
                id: 2,
                image_id: 2,
                ..Default::default()
            },
            CocoAnnotation {
                id: 3,
                image_id: 3,
                ..Default::default()
            },
        ];
        let dataset = CocoDataset {
            images,
            annotations,
            ..Default::default()
        };

        let filtered = reader.apply_filters(dataset);

        assert_eq!(filtered.images.len(), 2);
        assert_eq!(filtered.annotations.len(), 2);
    }

    #[test]
    fn test_infer_group_from_filename_instances() {
        assert_group("instances_train2017.json", Some("train"));
        assert_group("instances_val2017.json", Some("val"));
        assert_group("instances_test2017.json", Some("test"));
    }

    #[test]
    fn test_infer_group_from_filename_keypoints() {
        assert_group("person_keypoints_train2017.json", Some("train"));
        assert_group("person_keypoints_val2017.json", Some("val"));
    }

    #[test]
    fn test_infer_group_from_filename_captions() {
        assert_group("captions_train2017.json", Some("train"));
        assert_group("captions_val2017.json", Some("val"));
    }

    #[test]
    fn test_infer_group_from_filename_panoptic() {
        assert_group("panoptic_train2017.json", Some("train"));
        assert_group("panoptic_val2017.json", Some("val"));
    }

    #[test]
    fn test_infer_group_from_filename_fallback() {
        // No known prefix: the keyword search decides.
        assert_group("my_custom_train_annotations.json", Some("train"));
        assert_group("validation_data.json", Some("val"));
    }

    #[test]
    fn test_infer_group_from_filename_no_match() {
        assert_group("annotations.json", None);
        assert_group("data.json", None);
    }
}