1use super::types::CocoDataset;
10use crate::Error;
11use std::{
12 collections::HashSet,
13 fs::File,
14 io::{BufReader, Read},
15 path::Path,
16};
17
/// Options controlling how COCO annotation data is loaded and post-processed.
///
/// The derived [`Default`] leaves validation off, applies no image limit and
/// no category filter.
#[derive(Debug, Clone, Default)]
pub struct CocoReadOptions {
    /// When `true`, the loaded dataset is checked by `validate_dataset` and a
    /// failed check is returned as an error instead of the dataset.
    pub validate: bool,
    /// Maximum number of images to keep; `0` means "no limit".
    pub max_images: usize,
    /// Names of the categories to keep; an empty list keeps every category.
    pub category_filter: Vec<String>,
}
28
29pub struct CocoReader {
44 options: CocoReadOptions,
45}
46
47impl CocoReader {
48 pub fn new() -> Self {
50 Self {
51 options: CocoReadOptions::default(),
52 }
53 }
54
55 pub fn with_options(options: CocoReadOptions) -> Self {
57 Self { options }
58 }
59
60 pub fn read_json<P: AsRef<Path>>(&self, path: P) -> Result<CocoDataset, Error> {
68 let file = File::open(path.as_ref())?;
69 let reader = BufReader::with_capacity(64 * 1024, file);
70 let mut dataset: CocoDataset = serde_json::from_reader(reader)?;
71 fill_missing_file_names(&mut dataset);
72
73 if self.options.validate {
74 validate_dataset(&dataset)?;
75 }
76
77 Ok(self.apply_filters(dataset))
78 }
79
80 pub fn read_annotations_zip<P: AsRef<Path>>(&self, path: P) -> Result<CocoDataset, Error> {
93 let file = File::open(path.as_ref())?;
94 let mut archive = zip::ZipArchive::new(file)?;
95
96 let mut merged = CocoDataset::default();
97
98 for i in 0..archive.len() {
99 let mut entry = archive.by_index(i)?;
100 let name = entry.name().to_string();
101
102 if name.ends_with(".json") && name.contains("instances") {
104 let mut contents = String::new();
105 entry.read_to_string(&mut contents)?;
106
107 let mut dataset: CocoDataset = serde_json::from_str(&contents)?;
108 fill_missing_file_names(&mut dataset);
109 merge_datasets(&mut merged, dataset);
110 }
111 }
112
113 if self.options.validate {
114 validate_dataset(&merged)?;
115 }
116
117 Ok(self.apply_filters(merged))
118 }
119
120 pub fn list_images<P: AsRef<Path>>(
128 &self,
129 path: P,
130 ) -> Result<Vec<(String, std::path::PathBuf)>, Error> {
131 let path = path.as_ref();
132 let mut images = Vec::new();
133
134 if path.is_dir() {
135 for entry in walkdir::WalkDir::new(path)
137 .into_iter()
138 .filter_map(|e| e.ok())
139 .filter(|e| e.file_type().is_file())
140 {
141 let filename = entry.file_name().to_string_lossy().to_lowercase();
142 if filename.ends_with(".jpg")
143 || filename.ends_with(".jpeg")
144 || filename.ends_with(".png")
145 {
146 let rel_path = entry
147 .path()
148 .strip_prefix(path)
149 .unwrap_or(entry.path())
150 .to_string_lossy()
151 .to_string();
152 images.push((rel_path, entry.path().to_path_buf()));
153 }
154 }
155 } else if path.extension().is_some_and(|e| e == "zip") {
156 let file = File::open(path)?;
158 let mut archive = zip::ZipArchive::new(file)?;
159
160 for i in 0..archive.len() {
161 let entry = archive.by_index(i)?;
162 let name = entry.name().to_string();
163 let name_lower = name.to_lowercase();
164
165 if !entry.is_dir()
166 && (name_lower.ends_with(".jpg")
167 || name_lower.ends_with(".jpeg")
168 || name_lower.ends_with(".png"))
169 {
170 images.push((name.clone(), path.join(&name)));
171 }
172 }
173 }
174
175 Ok(images)
176 }
177
178 pub fn read_image_from_zip<P: AsRef<Path>>(
187 &self,
188 zip_path: P,
189 image_name: &str,
190 ) -> Result<Vec<u8>, Error> {
191 let file = File::open(zip_path.as_ref())?;
192 let mut archive = zip::ZipArchive::new(file)?;
193
194 let mut entry = archive.by_name(image_name)?;
195 let mut buffer = Vec::with_capacity(entry.size() as usize);
196 entry.read_to_end(&mut buffer)?;
197
198 Ok(buffer)
199 }
200
201 fn apply_filters(&self, mut dataset: CocoDataset) -> CocoDataset {
203 if self.options.max_images > 0 && dataset.images.len() > self.options.max_images {
205 let image_ids: HashSet<_> = dataset
206 .images
207 .iter()
208 .take(self.options.max_images)
209 .map(|i| i.id)
210 .collect();
211
212 dataset.images.truncate(self.options.max_images);
213 dataset
214 .annotations
215 .retain(|a| image_ids.contains(&a.image_id));
216 }
217
218 if !self.options.category_filter.is_empty() {
220 let category_ids: HashSet<_> = dataset
221 .categories
222 .iter()
223 .filter(|c| self.options.category_filter.contains(&c.name))
224 .map(|c| c.id)
225 .collect();
226
227 dataset
228 .categories
229 .retain(|c| self.options.category_filter.contains(&c.name));
230 dataset
231 .annotations
232 .retain(|a| category_ids.contains(&a.category_id));
233 }
234
235 dataset
236 }
237}
238
239impl Default for CocoReader {
240 fn default() -> Self {
241 Self::new()
242 }
243}
244
245fn validate_dataset(dataset: &CocoDataset) -> Result<(), Error> {
247 let image_ids: HashSet<_> = dataset.images.iter().map(|i| i.id).collect();
248 let category_ids: HashSet<_> = dataset.categories.iter().map(|c| c.id).collect();
249
250 for ann in &dataset.annotations {
251 if !image_ids.contains(&ann.image_id) {
252 return Err(Error::CocoError(format!(
253 "Annotation {} references non-existent image_id {}",
254 ann.id, ann.image_id
255 )));
256 }
257
258 if !category_ids.contains(&ann.category_id) {
259 return Err(Error::CocoError(format!(
260 "Annotation {} references non-existent category_id {}",
261 ann.id, ann.category_id
262 )));
263 }
264
265 if ann.bbox[2] <= 0.0 || ann.bbox[3] <= 0.0 {
267 return Err(Error::CocoError(format!(
268 "Annotation {} has invalid bbox dimensions",
269 ann.id
270 )));
271 }
272 }
273
274 Ok(())
275}
276
/// Derives a relative file name from an image URL by dropping the scheme and
/// host, e.g. `http://images.cocodataset.org/val2017/1.jpg` -> `val2017/1.jpg`.
///
/// Any query string (`?...`) or fragment (`#...`) is discarded first, since
/// neither is part of the resource path.
///
/// Returns `None` when no safe relative file path can be derived:
/// * the URL has no path after the host, or the path is empty,
/// * the path starts or ends with `/` (absolute, or names a directory),
/// * the path contains `\` or `:` (Windows separators / drive prefixes),
/// * any path component is something other than a plain name (`.`, `..`,
///   a root or a prefix).
fn derive_file_name_from_coco_url(url: &str) -> Option<String> {
    use std::path::Component;

    // `split` always yields at least one item, so the fallback is never hit.
    let without_query = url.split(['?', '#']).next().unwrap_or(url);
    let after_scheme = without_query
        .split_once("://")
        .map_or(without_query, |(_, rest)| rest);
    let (_host, path) = after_scheme.split_once('/')?;

    if path.is_empty() || path.starts_with('/') || path.ends_with('/') {
        return None;
    }

    // Reject anything a Windows host could read as a separator or a drive
    // prefix, regardless of the platform this code runs on.
    if path.contains('\\') || path.contains(':') {
        return None;
    }

    // Only plain name components are acceptable; this rules out traversal
    // (`..`), a leading `.`, and any rooted or prefixed form.
    let only_plain_names = Path::new(path)
        .components()
        .all(|component| matches!(component, Component::Normal(_)));

    only_plain_names.then(|| path.to_string())
}
339
340fn fill_missing_file_names(dataset: &mut CocoDataset) {
347 for image in &mut dataset.images {
348 if !image.file_name.is_empty() {
349 continue;
350 }
351 if let Some(derived) = image
352 .coco_url
353 .as_deref()
354 .and_then(derive_file_name_from_coco_url)
355 {
356 image.file_name = derived;
357 }
358 }
359}
360
/// Infers a dataset group (split) name from an annotation file name.
///
/// Two strategies are tried in order:
/// 1. If the file stem starts with a well-known COCO prefix (`instances_`,
///    `person_keypoints_`, `captions_`, `panoptic_`), the remainder with its
///    trailing numeric characters removed is the group
///    (`instances_train2017.json` -> `train`).
/// 2. Otherwise the lower-cased `filename` is searched for the substrings
///    `train`, `val` and `test`, in that order.
///
/// Returns `None` if the name has no UTF-8 file stem or neither strategy
/// produces a group.
pub fn infer_group_from_filename(filename: &str) -> Option<String> {
    const KNOWN_PREFIXES: [&str; 4] = [
        "instances_",
        "person_keypoints_",
        "captions_",
        "panoptic_",
    ];
    const FALLBACK_GROUPS: [&str; 3] = ["train", "val", "test"];

    let stem = Path::new(filename).file_stem()?.to_str()?;

    // A prefix whose remainder is purely numeric yields nothing and falls
    // through to the substring search below.
    let from_prefix = KNOWN_PREFIXES
        .iter()
        .copied()
        .filter_map(|prefix| stem.strip_prefix(prefix))
        .map(|rest| rest.trim_end_matches(char::is_numeric))
        .find(|group| !group.is_empty());

    if let Some(group) = from_prefix {
        return Some(group.to_string());
    }

    let lowered = filename.to_lowercase();
    FALLBACK_GROUPS
        .iter()
        .copied()
        .find(|keyword| lowered.contains(*keyword))
        .map(String::from)
}
424
/// Infers a dataset group (split) name from the folder that directly
/// contains an image, e.g. `train2017/0001.jpg` -> `train`.
///
/// Trailing numeric characters are removed from the folder name. A folder
/// name made up solely of numeric characters is returned unchanged.
///
/// Returns `None` if the path has no parent folder name or that name is not
/// valid UTF-8.
pub fn infer_group_from_folder(image_path: &str) -> Option<String> {
    let folder = Path::new(image_path)
        .parent()
        .and_then(|dir| dir.file_name())
        .and_then(|name| name.to_str())
        .filter(|name| !name.is_empty())?;

    let without_digits = folder.trim_end_matches(char::is_numeric);
    let group = if without_digits.is_empty() {
        folder
    } else {
        without_digits
    };

    Some(group.to_string())
}
462
463pub fn read_coco_directory<P: AsRef<Path>>(
485 path: P,
486 options: &CocoReadOptions,
487) -> Result<Vec<(CocoDataset, String)>, Error> {
488 let path = path.as_ref();
489 let mut results = Vec::new();
490
491 let annotations_dir = path.join("annotations");
493 let search_dirs: Vec<&Path> = if annotations_dir.is_dir() {
494 vec![annotations_dir.as_path(), path]
495 } else {
496 vec![path]
497 };
498
499 for search_dir in search_dirs {
500 if !search_dir.is_dir() {
501 continue;
502 }
503
504 for entry in std::fs::read_dir(search_dir)? {
505 let entry = entry?;
506 let file_path = entry.path();
507
508 if !file_path.is_file() {
509 continue;
510 }
511
512 let filename = file_path.file_name().and_then(|s| s.to_str()).unwrap_or("");
513
514 if filename.ends_with(".json") && filename.contains("instances") {
516 let group =
517 infer_group_from_filename(filename).unwrap_or_else(|| "default".to_string());
518
519 let reader = CocoReader::with_options(options.clone());
520 let dataset = reader.read_json(&file_path)?;
521
522 results.push((dataset, group));
523 }
524 }
525 }
526
527 if results.is_empty() {
528 return Err(Error::MissingAnnotations(format!(
529 "No COCO annotation files found in {}",
530 path.display()
531 )));
532 }
533
534 Ok(results)
535}
536
537fn merge_datasets(target: &mut CocoDataset, source: CocoDataset) {
539 if target.info.description.is_none() {
541 target.info = source.info;
542 }
543
544 let existing_ids: HashSet<_> = target.images.iter().map(|i| i.id).collect();
546 for image in source.images {
547 if !existing_ids.contains(&image.id) {
548 target.images.push(image);
549 }
550 }
551
552 let existing_cats: HashSet<_> = target.categories.iter().map(|c| c.id).collect();
554 for cat in source.categories {
555 if !existing_cats.contains(&cat.id) {
556 target.categories.push(cat);
557 }
558 }
559
560 target.annotations.extend(source.annotations);
562
563 let existing_licenses: HashSet<_> = target.licenses.iter().map(|l| l.id).collect();
565 for lic in source.licenses {
566 if !existing_licenses.contains(&lic.id) {
567 target.licenses.push(lic);
568 }
569 }
570}
571
#[cfg(test)]
mod tests {
    use super::*;
    use crate::coco::{CocoAnnotation, CocoCategory, CocoImage};

    // ---- reader construction -------------------------------------------

    #[test]
    fn test_reader_default() {
        let reader = CocoReader::new();
        assert!(!reader.options.validate);
        assert_eq!(reader.options.max_images, 0);
        assert!(reader.options.category_filter.is_empty());
    }

    #[test]
    fn test_reader_with_options() {
        let options = CocoReadOptions {
            validate: true,
            max_images: 100,
            category_filter: vec!["person".to_string()],
        };
        let reader = CocoReader::with_options(options);
        assert!(reader.options.validate);
        assert_eq!(reader.options.max_images, 100);
        assert_eq!(reader.options.category_filter, vec!["person".to_string()]);
    }

    // ---- file names derived from URLs ----------------------------------

    #[test]
    fn test_derive_file_name_from_coco_url() {
        assert_eq!(
            derive_file_name_from_coco_url(
                "http://images.cocodataset.org/val2017/000000397133.jpg"
            ),
            Some("val2017/000000397133.jpg".to_string())
        );
        assert_eq!(
            derive_file_name_from_coco_url(
                "https://images.cocodataset.org/train2017/000000000009.jpg"
            ),
            Some("train2017/000000000009.jpg".to_string())
        );
        assert_eq!(derive_file_name_from_coco_url("host-only"), None);
        assert_eq!(derive_file_name_from_coco_url("http://host/"), None);
    }

    #[test]
    fn test_derive_file_name_from_coco_url_rejects_traversal() {
        assert_eq!(
            derive_file_name_from_coco_url("http://host/../etc/passwd"),
            None
        );
        assert_eq!(
            derive_file_name_from_coco_url("http://host/val2017/../../etc/passwd"),
            None
        );
        assert_eq!(derive_file_name_from_coco_url("http://host/./foo.jpg"), None);
    }

    #[test]
    fn test_derive_file_name_from_coco_url_rejects_absolute_and_windows() {
        assert_eq!(
            derive_file_name_from_coco_url("http://host//etc/passwd"),
            None
        );
        assert_eq!(
            derive_file_name_from_coco_url("http://host/val2017\\..\\..\\etc"),
            None
        );
        assert_eq!(
            derive_file_name_from_coco_url("http://host/C:/Windows/System32"),
            None
        );
    }

    #[test]
    fn test_fill_missing_file_names_from_lvis_json() {
        // An image record without `file_name`, carrying LVIS-only fields.
        let json = r#"{
            "images": [
                {
                    "id": 397133,
                    "width": 640,
                    "height": 427,
                    "coco_url": "http://images.cocodataset.org/val2017/000000397133.jpg",
                    "neg_category_ids": [279, 899],
                    "not_exhaustive_category_ids": [914]
                }
            ],
            "annotations": [],
            "categories": []
        }"#;
        let mut dataset: CocoDataset = serde_json::from_str(json).unwrap();
        assert_eq!(dataset.images[0].file_name, "");
        fill_missing_file_names(&mut dataset);
        assert_eq!(dataset.images[0].file_name, "val2017/000000397133.jpg");
        assert_eq!(
            dataset.images[0].neg_category_ids.as_deref(),
            Some(&[279u32, 899][..])
        );
    }

    #[test]
    fn test_fill_missing_file_names_preserves_existing() {
        let mut dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                width: 640,
                height: 480,
                file_name: "custom/path.jpg".to_string(),
                coco_url: Some("http://images.cocodataset.org/val2017/foo.jpg".to_string()),
                ..Default::default()
            }],
            ..Default::default()
        };
        fill_missing_file_names(&mut dataset);
        assert_eq!(dataset.images[0].file_name, "custom/path.jpg");
    }

    // ---- validation ------------------------------------------------------

    #[test]
    fn test_validate_dataset_valid() {
        let dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                width: 640,
                height: 480,
                file_name: "test.jpg".to_string(),
                ..Default::default()
            }],
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                supercategory: None,
                ..Default::default()
            }],
            annotations: vec![CocoAnnotation {
                id: 1,
                image_id: 1,
                category_id: 1,
                bbox: [10.0, 20.0, 100.0, 80.0],
                area: 8000.0,
                iscrowd: 0,
                segmentation: None,
                score: None,
            }],
            ..Default::default()
        };

        assert!(validate_dataset(&dataset).is_ok());
    }

    #[test]
    fn test_validate_dataset_missing_image() {
        let dataset = CocoDataset {
            images: vec![],
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                supercategory: None,
                ..Default::default()
            }],
            annotations: vec![CocoAnnotation {
                id: 1,
                // No image with this id exists.
                image_id: 999,
                category_id: 1,
                bbox: [10.0, 20.0, 100.0, 80.0],
                ..Default::default()
            }],
            ..Default::default()
        };

        assert!(validate_dataset(&dataset).is_err());
    }

    #[test]
    fn test_validate_dataset_missing_category() {
        let dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                ..Default::default()
            }],
            categories: vec![],
            annotations: vec![CocoAnnotation {
                id: 1,
                image_id: 1,
                // No category with this id exists.
                category_id: 42,
                bbox: [10.0, 20.0, 100.0, 80.0],
                ..Default::default()
            }],
            ..Default::default()
        };

        assert!(validate_dataset(&dataset).is_err());
    }

    #[test]
    fn test_validate_dataset_invalid_bbox() {
        let dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                ..Default::default()
            }],
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                ..Default::default()
            }],
            annotations: vec![CocoAnnotation {
                id: 1,
                image_id: 1,
                category_id: 1,
                // Zero width.
                bbox: [10.0, 20.0, 0.0, 80.0],
                ..Default::default()
            }],
            ..Default::default()
        };

        assert!(validate_dataset(&dataset).is_err());
    }

    // ---- merging and filtering -----------------------------------------

    #[test]
    fn test_merge_datasets() {
        let mut target = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                width: 640,
                height: 480,
                file_name: "img1.jpg".to_string(),
                ..Default::default()
            }],
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                supercategory: None,
                ..Default::default()
            }],
            annotations: vec![],
            ..Default::default()
        };

        let source = CocoDataset {
            images: vec![
                CocoImage {
                    // Same id as the image already in `target`.
                    id: 1,
                    width: 640,
                    height: 480,
                    file_name: "img1.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    // New image.
                    id: 2,
                    width: 800,
                    height: 600,
                    file_name: "img2.jpg".to_string(),
                    ..Default::default()
                },
            ],
            categories: vec![CocoCategory {
                id: 2,
                name: "car".to_string(),
                supercategory: None,
                ..Default::default()
            }],
            annotations: vec![],
            ..Default::default()
        };

        merge_datasets(&mut target, source);

        assert_eq!(target.images.len(), 2);
        assert_eq!(target.categories.len(), 2);
    }

    #[test]
    fn test_apply_max_images_filter() {
        let reader = CocoReader::with_options(CocoReadOptions {
            max_images: 2,
            ..Default::default()
        });

        let dataset = CocoDataset {
            images: vec![
                CocoImage {
                    id: 1,
                    ..Default::default()
                },
                CocoImage {
                    id: 2,
                    ..Default::default()
                },
                CocoImage {
                    id: 3,
                    ..Default::default()
                },
            ],
            annotations: vec![
                CocoAnnotation {
                    id: 1,
                    image_id: 1,
                    ..Default::default()
                },
                CocoAnnotation {
                    id: 2,
                    image_id: 2,
                    ..Default::default()
                },
                CocoAnnotation {
                    id: 3,
                    image_id: 3,
                    ..Default::default()
                },
            ],
            ..Default::default()
        };

        let filtered = reader.apply_filters(dataset);
        assert_eq!(filtered.images.len(), 2);
        assert_eq!(filtered.annotations.len(), 2);
    }

    #[test]
    fn test_apply_category_filter() {
        let reader = CocoReader::with_options(CocoReadOptions {
            category_filter: vec!["person".to_string()],
            ..Default::default()
        });

        let dataset = CocoDataset {
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "person".to_string(),
                    ..Default::default()
                },
                CocoCategory {
                    id: 2,
                    name: "car".to_string(),
                    ..Default::default()
                },
            ],
            annotations: vec![
                CocoAnnotation {
                    id: 1,
                    category_id: 1,
                    ..Default::default()
                },
                CocoAnnotation {
                    id: 2,
                    category_id: 2,
                    ..Default::default()
                },
            ],
            ..Default::default()
        };

        let filtered = reader.apply_filters(dataset);
        assert_eq!(filtered.categories.len(), 1);
        assert_eq!(filtered.categories[0].name, "person");
        assert_eq!(filtered.annotations.len(), 1);
        assert_eq!(filtered.annotations[0].category_id, 1);
    }

    // ---- group inference -------------------------------------------------

    #[test]
    fn test_infer_group_from_filename_instances() {
        assert_eq!(
            infer_group_from_filename("instances_train2017.json"),
            Some("train".to_string())
        );
        assert_eq!(
            infer_group_from_filename("instances_val2017.json"),
            Some("val".to_string())
        );
        assert_eq!(
            infer_group_from_filename("instances_test2017.json"),
            Some("test".to_string())
        );
    }

    #[test]
    fn test_infer_group_from_filename_keypoints() {
        assert_eq!(
            infer_group_from_filename("person_keypoints_train2017.json"),
            Some("train".to_string())
        );
        assert_eq!(
            infer_group_from_filename("person_keypoints_val2017.json"),
            Some("val".to_string())
        );
    }

    #[test]
    fn test_infer_group_from_filename_captions() {
        assert_eq!(
            infer_group_from_filename("captions_train2017.json"),
            Some("train".to_string())
        );
        assert_eq!(
            infer_group_from_filename("captions_val2017.json"),
            Some("val".to_string())
        );
    }

    #[test]
    fn test_infer_group_from_filename_panoptic() {
        assert_eq!(
            infer_group_from_filename("panoptic_train2017.json"),
            Some("train".to_string())
        );
        assert_eq!(
            infer_group_from_filename("panoptic_val2017.json"),
            Some("val".to_string())
        );
    }

    #[test]
    fn test_infer_group_from_filename_fallback() {
        // No known prefix: the substring search decides.
        assert_eq!(
            infer_group_from_filename("my_custom_train_annotations.json"),
            Some("train".to_string())
        );
        assert_eq!(
            infer_group_from_filename("validation_data.json"),
            Some("val".to_string())
        );
    }

    #[test]
    fn test_infer_group_from_filename_no_match() {
        assert_eq!(infer_group_from_filename("annotations.json"), None);
        assert_eq!(infer_group_from_filename("data.json"), None);
    }

    #[test]
    fn test_infer_group_from_folder() {
        assert_eq!(
            infer_group_from_folder("train2017/000000000009.jpg"),
            Some("train".to_string())
        );
        assert_eq!(
            infer_group_from_folder("images/val2017/x.jpg"),
            Some("val".to_string())
        );
        // A purely numeric folder name is kept as-is.
        assert_eq!(
            infer_group_from_folder("2017/x.jpg"),
            Some("2017".to_string())
        );
        // No containing folder.
        assert_eq!(infer_group_from_folder("x.jpg"), None);
    }
}