1use super::types::CocoDataset;
10use crate::Error;
11use std::{
12 collections::HashSet,
13 fs::File,
14 io::{BufReader, Read},
15 path::Path,
16};
17
/// Options controlling how a [`CocoReader`] post-processes a dataset after it
/// has been loaded.
///
/// The [`Default`] value disables validation and applies no filtering.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CocoReadOptions {
    /// When `true`, the loaded dataset is validated before filters are applied
    /// and the read fails on the first inconsistency found.
    pub validate: bool,
    /// Maximum number of images to keep; `0` disables the limit. Annotations
    /// that reference a dropped image are removed as well.
    pub max_images: usize,
    /// Category names to keep; an empty list disables the filter. Categories
    /// with other names, and the annotations that reference them, are removed.
    pub category_filter: Vec<String>,
}
28
29pub struct CocoReader {
44 options: CocoReadOptions,
45}
46
47impl CocoReader {
48 pub fn new() -> Self {
50 Self {
51 options: CocoReadOptions::default(),
52 }
53 }
54
55 pub fn with_options(options: CocoReadOptions) -> Self {
57 Self { options }
58 }
59
60 pub fn read_json<P: AsRef<Path>>(&self, path: P) -> Result<CocoDataset, Error> {
68 let file = File::open(path.as_ref())?;
69 let reader = BufReader::with_capacity(64 * 1024, file);
70 let dataset: CocoDataset = serde_json::from_reader(reader)?;
71
72 if self.options.validate {
73 validate_dataset(&dataset)?;
74 }
75
76 Ok(self.apply_filters(dataset))
77 }
78
79 pub fn read_annotations_zip<P: AsRef<Path>>(&self, path: P) -> Result<CocoDataset, Error> {
92 let file = File::open(path.as_ref())?;
93 let mut archive = zip::ZipArchive::new(file)?;
94
95 let mut merged = CocoDataset::default();
96
97 for i in 0..archive.len() {
98 let mut entry = archive.by_index(i)?;
99 let name = entry.name().to_string();
100
101 if name.ends_with(".json") && name.contains("instances") {
103 let mut contents = String::new();
104 entry.read_to_string(&mut contents)?;
105
106 let dataset: CocoDataset = serde_json::from_str(&contents)?;
107 merge_datasets(&mut merged, dataset);
108 }
109 }
110
111 if self.options.validate {
112 validate_dataset(&merged)?;
113 }
114
115 Ok(self.apply_filters(merged))
116 }
117
118 pub fn list_images<P: AsRef<Path>>(
126 &self,
127 path: P,
128 ) -> Result<Vec<(String, std::path::PathBuf)>, Error> {
129 let path = path.as_ref();
130 let mut images = Vec::new();
131
132 if path.is_dir() {
133 for entry in walkdir::WalkDir::new(path)
135 .into_iter()
136 .filter_map(|e| e.ok())
137 .filter(|e| e.file_type().is_file())
138 {
139 let filename = entry.file_name().to_string_lossy().to_lowercase();
140 if filename.ends_with(".jpg")
141 || filename.ends_with(".jpeg")
142 || filename.ends_with(".png")
143 {
144 let rel_path = entry
145 .path()
146 .strip_prefix(path)
147 .unwrap_or(entry.path())
148 .to_string_lossy()
149 .to_string();
150 images.push((rel_path, entry.path().to_path_buf()));
151 }
152 }
153 } else if path.extension().is_some_and(|e| e == "zip") {
154 let file = File::open(path)?;
156 let mut archive = zip::ZipArchive::new(file)?;
157
158 for i in 0..archive.len() {
159 let entry = archive.by_index(i)?;
160 let name = entry.name().to_string();
161 let name_lower = name.to_lowercase();
162
163 if !entry.is_dir()
164 && (name_lower.ends_with(".jpg")
165 || name_lower.ends_with(".jpeg")
166 || name_lower.ends_with(".png"))
167 {
168 images.push((name.clone(), path.join(&name)));
169 }
170 }
171 }
172
173 Ok(images)
174 }
175
176 pub fn read_image_from_zip<P: AsRef<Path>>(
185 &self,
186 zip_path: P,
187 image_name: &str,
188 ) -> Result<Vec<u8>, Error> {
189 let file = File::open(zip_path.as_ref())?;
190 let mut archive = zip::ZipArchive::new(file)?;
191
192 let mut entry = archive.by_name(image_name)?;
193 let mut buffer = Vec::with_capacity(entry.size() as usize);
194 entry.read_to_end(&mut buffer)?;
195
196 Ok(buffer)
197 }
198
199 fn apply_filters(&self, mut dataset: CocoDataset) -> CocoDataset {
201 if self.options.max_images > 0 && dataset.images.len() > self.options.max_images {
203 let image_ids: HashSet<_> = dataset
204 .images
205 .iter()
206 .take(self.options.max_images)
207 .map(|i| i.id)
208 .collect();
209
210 dataset.images.truncate(self.options.max_images);
211 dataset
212 .annotations
213 .retain(|a| image_ids.contains(&a.image_id));
214 }
215
216 if !self.options.category_filter.is_empty() {
218 let category_ids: HashSet<_> = dataset
219 .categories
220 .iter()
221 .filter(|c| self.options.category_filter.contains(&c.name))
222 .map(|c| c.id)
223 .collect();
224
225 dataset
226 .categories
227 .retain(|c| self.options.category_filter.contains(&c.name));
228 dataset
229 .annotations
230 .retain(|a| category_ids.contains(&a.category_id));
231 }
232
233 dataset
234 }
235}
236
237impl Default for CocoReader {
238 fn default() -> Self {
239 Self::new()
240 }
241}
242
243fn validate_dataset(dataset: &CocoDataset) -> Result<(), Error> {
245 let image_ids: HashSet<_> = dataset.images.iter().map(|i| i.id).collect();
246 let category_ids: HashSet<_> = dataset.categories.iter().map(|c| c.id).collect();
247
248 for ann in &dataset.annotations {
249 if !image_ids.contains(&ann.image_id) {
250 return Err(Error::CocoError(format!(
251 "Annotation {} references non-existent image_id {}",
252 ann.id, ann.image_id
253 )));
254 }
255
256 if !category_ids.contains(&ann.category_id) {
257 return Err(Error::CocoError(format!(
258 "Annotation {} references non-existent category_id {}",
259 ann.id, ann.category_id
260 )));
261 }
262
263 if ann.bbox[2] <= 0.0 || ann.bbox[3] <= 0.0 {
265 return Err(Error::CocoError(format!(
266 "Annotation {} has invalid bbox dimensions",
267 ann.id
268 )));
269 }
270 }
271
272 Ok(())
273}
274
/// Infers a dataset group (such as `train` or `val`) from an annotation file
/// name.
///
/// Two strategies are tried in order:
///
/// 1. If the file stem starts with a known annotation prefix (`instances_`,
///    `person_keypoints_`, `captions_`, `panoptic_`), the remainder with any
///    trailing numeric characters removed is the group, e.g.
///    `instances_train2017.json` gives `train`. The remainder keeps its
///    original case. An empty remainder falls through to step 2.
/// 2. Otherwise the whole lower-cased `filename` is searched for the
///    substrings `train`, `val` and `test`, in that order.
///
/// Returns `None` when `filename` has no UTF-8 file stem or neither strategy
/// matches.
pub fn infer_group_from_filename(filename: &str) -> Option<String> {
    const ANNOTATION_PREFIXES: [&str; 4] = [
        "instances_",
        "person_keypoints_",
        "captions_",
        "panoptic_",
    ];
    // Order matters: the first keyword contained in the name wins.
    const SPLIT_KEYWORDS: [&str; 3] = ["train", "val", "test"];

    let stem = Path::new(filename).file_stem().and_then(|s| s.to_str())?;

    let from_prefix = ANNOTATION_PREFIXES
        .iter()
        .filter_map(|prefix| stem.strip_prefix(*prefix))
        .map(|rest| rest.trim_end_matches(char::is_numeric))
        .find(|group| !group.is_empty());
    if let Some(group) = from_prefix {
        return Some(group.to_owned());
    }

    let lowered = filename.to_lowercase();
    SPLIT_KEYWORDS
        .iter()
        .find(|keyword| lowered.contains(**keyword))
        .map(|keyword| (*keyword).to_owned())
}
338
/// Infers a dataset group from the folder that directly contains
/// `image_path`.
///
/// The folder name with trailing numeric characters removed is the group, so
/// `train2017/0001.jpg` gives `train`. A folder name made up entirely of
/// numeric characters is returned unchanged. Returns `None` when the path has
/// no parent folder name or that name is not valid UTF-8.
pub fn infer_group_from_folder(image_path: &str) -> Option<String> {
    let folder = Path::new(image_path)
        .parent()
        .and_then(|dir| dir.file_name())
        .and_then(|name| name.to_str())
        .filter(|name| !name.is_empty())?;

    let trimmed = folder.trim_end_matches(char::is_numeric);
    // Stripping everything would leave no group name; keep the folder as is.
    let group = if trimmed.is_empty() { folder } else { trimmed };

    Some(group.to_owned())
}
376
377pub fn read_coco_directory<P: AsRef<Path>>(
399 path: P,
400 options: &CocoReadOptions,
401) -> Result<Vec<(CocoDataset, String)>, Error> {
402 let path = path.as_ref();
403 let mut results = Vec::new();
404
405 let annotations_dir = path.join("annotations");
407 let search_dirs: Vec<&Path> = if annotations_dir.is_dir() {
408 vec![annotations_dir.as_path(), path]
409 } else {
410 vec![path]
411 };
412
413 for search_dir in search_dirs {
414 if !search_dir.is_dir() {
415 continue;
416 }
417
418 for entry in std::fs::read_dir(search_dir)? {
419 let entry = entry?;
420 let file_path = entry.path();
421
422 if !file_path.is_file() {
423 continue;
424 }
425
426 let filename = file_path.file_name().and_then(|s| s.to_str()).unwrap_or("");
427
428 if filename.ends_with(".json") && filename.contains("instances") {
430 let group =
431 infer_group_from_filename(filename).unwrap_or_else(|| "default".to_string());
432
433 let reader = CocoReader::with_options(options.clone());
434 let dataset = reader.read_json(&file_path)?;
435
436 results.push((dataset, group));
437 }
438 }
439 }
440
441 if results.is_empty() {
442 return Err(Error::MissingAnnotations(format!(
443 "No COCO annotation files found in {}",
444 path.display()
445 )));
446 }
447
448 Ok(results)
449}
450
451fn merge_datasets(target: &mut CocoDataset, source: CocoDataset) {
453 if target.info.description.is_none() {
455 target.info = source.info;
456 }
457
458 let existing_ids: HashSet<_> = target.images.iter().map(|i| i.id).collect();
460 for image in source.images {
461 if !existing_ids.contains(&image.id) {
462 target.images.push(image);
463 }
464 }
465
466 let existing_cats: HashSet<_> = target.categories.iter().map(|c| c.id).collect();
468 for cat in source.categories {
469 if !existing_cats.contains(&cat.id) {
470 target.categories.push(cat);
471 }
472 }
473
474 target.annotations.extend(source.annotations);
476
477 let existing_licenses: HashSet<_> = target.licenses.iter().map(|l| l.id).collect();
479 for lic in source.licenses {
480 if !existing_licenses.contains(&lic.id) {
481 target.licenses.push(lic);
482 }
483 }
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use crate::coco::{CocoAnnotation, CocoCategory, CocoImage};

    /// The "person" category (id 1) shared by several fixtures.
    fn person_category() -> CocoCategory {
        CocoCategory {
            id: 1,
            name: String::from("person"),
            supercategory: None,
        }
    }

    /// Checks `infer_group_from_filename` against `(filename, expected)` pairs.
    fn assert_inferred_groups(cases: &[(&str, Option<&str>)]) {
        for &(filename, expected) in cases {
            assert_eq!(infer_group_from_filename(filename).as_deref(), expected);
        }
    }

    #[test]
    fn test_reader_default() {
        let options = CocoReader::new().options;

        assert!(!options.validate);
        assert_eq!(options.max_images, 0);
        assert!(options.category_filter.is_empty());
    }

    #[test]
    fn test_reader_with_options() {
        let reader = CocoReader::with_options(CocoReadOptions {
            validate: true,
            max_images: 100,
            category_filter: vec![String::from("person")],
        });

        assert!(reader.options.validate);
        assert_eq!(reader.options.max_images, 100);
    }

    #[test]
    fn test_validate_dataset_valid() {
        let image = CocoImage {
            id: 1,
            width: 640,
            height: 480,
            file_name: String::from("test.jpg"),
            ..Default::default()
        };
        let annotation = CocoAnnotation {
            id: 1,
            image_id: 1,
            category_id: 1,
            bbox: [10.0, 20.0, 100.0, 80.0],
            area: 8000.0,
            iscrowd: 0,
            segmentation: None,
        };
        let dataset = CocoDataset {
            images: vec![image],
            categories: vec![person_category()],
            annotations: vec![annotation],
            ..Default::default()
        };

        assert!(validate_dataset(&dataset).is_ok());
    }

    #[test]
    fn test_validate_dataset_missing_image() {
        // The annotation points at an image id that is not in the dataset.
        let orphan = CocoAnnotation {
            id: 1,
            image_id: 999,
            category_id: 1,
            bbox: [10.0, 20.0, 100.0, 80.0],
            ..Default::default()
        };
        let dataset = CocoDataset {
            images: Vec::new(),
            categories: vec![person_category()],
            annotations: vec![orphan],
            ..Default::default()
        };

        assert!(validate_dataset(&dataset).is_err());
    }

    #[test]
    fn test_merge_datasets() {
        let image = |id, width, height, file_name: &str| CocoImage {
            id,
            width,
            height,
            file_name: file_name.to_owned(),
            ..Default::default()
        };
        let car = CocoCategory {
            id: 2,
            name: String::from("car"),
            supercategory: None,
        };

        let mut target = CocoDataset {
            images: vec![image(1, 640, 480, "img1.jpg")],
            categories: vec![person_category()],
            annotations: Vec::new(),
            ..Default::default()
        };
        // Image 1 repeats the target's image; image 2 and "car" are new.
        let source = CocoDataset {
            images: vec![
                image(1, 640, 480, "img1.jpg"),
                image(2, 800, 600, "img2.jpg"),
            ],
            categories: vec![car],
            annotations: Vec::new(),
            ..Default::default()
        };

        merge_datasets(&mut target, source);

        assert_eq!(target.images.len(), 2);
        assert_eq!(target.categories.len(), 2);
    }

    #[test]
    fn test_apply_max_images_filter() {
        let image = |id| CocoImage {
            id,
            ..Default::default()
        };
        let annotation = |id, image_id| CocoAnnotation {
            id,
            image_id,
            ..Default::default()
        };

        let dataset = CocoDataset {
            images: vec![image(1), image(2), image(3)],
            annotations: vec![annotation(1, 1), annotation(2, 2), annotation(3, 3)],
            ..Default::default()
        };
        let reader = CocoReader::with_options(CocoReadOptions {
            max_images: 2,
            ..Default::default()
        });

        let filtered = reader.apply_filters(dataset);

        assert_eq!(filtered.images.len(), 2);
        assert_eq!(filtered.annotations.len(), 2);
    }

    #[test]
    fn test_infer_group_from_filename_instances() {
        assert_inferred_groups(&[
            ("instances_train2017.json", Some("train")),
            ("instances_val2017.json", Some("val")),
            ("instances_test2017.json", Some("test")),
        ]);
    }

    #[test]
    fn test_infer_group_from_filename_keypoints() {
        assert_inferred_groups(&[
            ("person_keypoints_train2017.json", Some("train")),
            ("person_keypoints_val2017.json", Some("val")),
        ]);
    }

    #[test]
    fn test_infer_group_from_filename_captions() {
        assert_inferred_groups(&[
            ("captions_train2017.json", Some("train")),
            ("captions_val2017.json", Some("val")),
        ]);
    }

    #[test]
    fn test_infer_group_from_filename_panoptic() {
        assert_inferred_groups(&[
            ("panoptic_train2017.json", Some("train")),
            ("panoptic_val2017.json", Some("val")),
        ]);
    }

    #[test]
    fn test_infer_group_from_filename_fallback() {
        // No known prefix: the keyword search over the whole name applies.
        assert_inferred_groups(&[
            ("my_custom_train_annotations.json", Some("train")),
            ("validation_data.json", Some("val")),
        ]);
    }

    #[test]
    fn test_infer_group_from_filename_no_match() {
        assert_inferred_groups(&[("annotations.json", None), ("data.json", None)]);
    }
}