1use super::types::{
9 CocoAnnotation, CocoCategory, CocoDataset, CocoImage, CocoInfo, CocoSegmentation,
10};
11use crate::Error;
12use std::{
13 fs::File,
14 io::{BufWriter, Write},
15 path::Path,
16};
17use zip::{CompressionMethod, write::SimpleFileOptions};
18
/// Options controlling how [`CocoWriter`] serializes its output.
#[derive(Debug, Clone)]
pub struct CocoWriteOptions {
    /// When `true`, ZIP entries are written with `Deflated` compression;
    /// when `false`, they are written `Stored` (uncompressed).
    pub compress: bool,
    /// When `true`, JSON is emitted in the indented "pretty" form instead of
    /// the compact form.
    pub pretty: bool,
}

impl Default for CocoWriteOptions {
    /// Defaults to compressed archives and compact JSON.
    fn default() -> Self {
        let compress = true;
        let pretty = false;
        Self { compress, pretty }
    }
}
36
37pub struct CocoWriter {
50 options: CocoWriteOptions,
51}
52
53impl CocoWriter {
54 pub fn new() -> Self {
56 Self {
57 options: CocoWriteOptions::default(),
58 }
59 }
60
61 pub fn with_options(options: CocoWriteOptions) -> Self {
63 Self { options }
64 }
65
66 pub fn write_json<P: AsRef<Path>>(&self, dataset: &CocoDataset, path: P) -> Result<(), Error> {
72 if let Some(parent) = path.as_ref().parent()
74 && !parent.as_os_str().is_empty()
75 {
76 std::fs::create_dir_all(parent)?;
77 }
78
79 let file = File::create(path.as_ref())?;
80 let writer = BufWriter::with_capacity(64 * 1024, file);
81
82 if self.options.pretty {
83 serde_json::to_writer_pretty(writer, dataset)?;
84 } else {
85 serde_json::to_writer(writer, dataset)?;
86 }
87
88 Ok(())
89 }
90
91 pub fn write_zip<P: AsRef<Path>>(
102 &self,
103 dataset: &CocoDataset,
104 images: impl Iterator<Item = (String, Vec<u8>)>,
105 path: P,
106 ) -> Result<(), Error> {
107 if let Some(parent) = path.as_ref().parent()
109 && !parent.as_os_str().is_empty()
110 {
111 std::fs::create_dir_all(parent)?;
112 }
113
114 let file = File::create(path.as_ref())?;
115 let mut zip = zip::ZipWriter::new(file);
116
117 let options = if self.options.compress {
118 SimpleFileOptions::default().compression_method(CompressionMethod::Deflated)
119 } else {
120 SimpleFileOptions::default().compression_method(CompressionMethod::Stored)
121 };
122
123 zip.start_file("annotations/instances.json", options)?;
125 let json = if self.options.pretty {
126 serde_json::to_string_pretty(dataset)?
127 } else {
128 serde_json::to_string(dataset)?
129 };
130 zip.write_all(json.as_bytes())?;
131
132 for (filename, data) in images {
134 zip.start_file(&filename, options)?;
135 zip.write_all(&data)?;
136 }
137
138 zip.finish()?;
139 Ok(())
140 }
141
142 pub fn write_zip_from_dir<P: AsRef<Path>>(
149 &self,
150 dataset: &CocoDataset,
151 images_dir: P,
152 path: P,
153 ) -> Result<(), Error> {
154 let images_dir = images_dir.as_ref();
155
156 let images = dataset.images.iter().filter_map(|img| {
158 let img_path = images_dir.join(&img.file_name);
159 std::fs::read(&img_path)
160 .ok()
161 .map(|data| (format!("images/{}", img.file_name), data))
162 });
163
164 self.write_zip(dataset, images, path)
165 }
166
167 pub fn write_split_by_group<P: AsRef<Path>>(
191 &self,
192 dataset: &CocoDataset,
193 group_assignments: &[String],
194 images_source: Option<&Path>,
195 output_dir: P,
196 ) -> Result<std::collections::HashMap<String, usize>, Error> {
197 use std::collections::{HashMap, HashSet};
198
199 let output_dir = output_dir.as_ref();
200
201 if dataset.images.len() != group_assignments.len() {
203 return Err(Error::CocoError(format!(
204 "Image count ({}) does not match group assignment count ({})",
205 dataset.images.len(),
206 group_assignments.len()
207 )));
208 }
209
210 let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
212 for (idx, group) in group_assignments.iter().enumerate() {
213 groups.entry(group.clone()).or_default().push(idx);
214 }
215
216 let mut result = HashMap::new();
217
218 for (group_name, image_indices) in &groups {
219 let group_dir = output_dir.join(group_name);
221 let annotations_dir = group_dir.join("annotations");
222 let images_dir = group_dir.join("images");
223
224 std::fs::create_dir_all(&annotations_dir)?;
225 std::fs::create_dir_all(&images_dir)?;
226
227 let image_ids: HashSet<u64> = image_indices
229 .iter()
230 .map(|&idx| dataset.images[idx].id)
231 .collect();
232
233 let subset = CocoDataset {
234 info: dataset.info.clone(),
235 licenses: dataset.licenses.clone(),
236 images: image_indices
237 .iter()
238 .map(|&idx| dataset.images[idx].clone())
239 .collect(),
240 annotations: dataset
241 .annotations
242 .iter()
243 .filter(|ann| image_ids.contains(&ann.image_id))
244 .cloned()
245 .collect(),
246 categories: dataset.categories.clone(),
247 };
248
249 let ann_file = annotations_dir.join(format!("instances_{}.json", group_name));
251 self.write_json(&subset, &ann_file)?;
252
253 if let Some(source) = images_source {
255 for &idx in image_indices {
256 let image = &dataset.images[idx];
257 let src_path = source.join(&image.file_name);
258 let dst_path = images_dir.join(&image.file_name);
259
260 if src_path.exists() {
261 std::fs::copy(&src_path, &dst_path)?;
262 }
263 }
264 }
265
266 result.insert(group_name.clone(), image_indices.len());
267 }
268
269 Ok(result)
270 }
271
272 pub fn write_split_by_group_zip<P: AsRef<Path>>(
288 &self,
289 dataset: &CocoDataset,
290 group_assignments: &[String],
291 images_source: Option<&Path>,
292 output_dir: P,
293 ) -> Result<std::collections::HashMap<String, usize>, Error> {
294 use std::collections::{HashMap, HashSet};
295
296 let output_dir = output_dir.as_ref();
297 std::fs::create_dir_all(output_dir)?;
298
299 if dataset.images.len() != group_assignments.len() {
301 return Err(Error::CocoError(format!(
302 "Image count ({}) does not match group assignment count ({})",
303 dataset.images.len(),
304 group_assignments.len()
305 )));
306 }
307
308 let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
310 for (idx, group) in group_assignments.iter().enumerate() {
311 groups.entry(group.clone()).or_default().push(idx);
312 }
313
314 let mut result = HashMap::new();
315
316 for (group_name, image_indices) in &groups {
317 let image_ids: HashSet<u64> = image_indices
319 .iter()
320 .map(|&idx| dataset.images[idx].id)
321 .collect();
322
323 let subset = CocoDataset {
324 info: dataset.info.clone(),
325 licenses: dataset.licenses.clone(),
326 images: image_indices
327 .iter()
328 .map(|&idx| dataset.images[idx].clone())
329 .collect(),
330 annotations: dataset
331 .annotations
332 .iter()
333 .filter(|ann| image_ids.contains(&ann.image_id))
334 .cloned()
335 .collect(),
336 categories: dataset.categories.clone(),
337 };
338
339 let images: Vec<(String, Vec<u8>)> = if let Some(source) = images_source {
341 image_indices
342 .iter()
343 .filter_map(|&idx| {
344 let image = &dataset.images[idx];
345 let src_path = source.join(&image.file_name);
346 std::fs::read(&src_path)
347 .ok()
348 .map(|data| (format!("images/{}", image.file_name), data))
349 })
350 .collect()
351 } else {
352 vec![]
353 };
354
355 let zip_path = output_dir.join(format!("{}.zip", group_name));
357 self.write_zip(&subset, images.into_iter(), &zip_path)?;
358
359 result.insert(group_name.clone(), image_indices.len());
360 }
361
362 Ok(result)
363 }
364}
365
366impl Default for CocoWriter {
367 fn default() -> Self {
368 Self::new()
369 }
370}
371
372#[derive(Debug, Default)]
376pub struct CocoDatasetBuilder {
377 dataset: CocoDataset,
378 next_image_id: u64,
379 next_annotation_id: u64,
380 next_category_id: u32,
381}
382
383impl CocoDatasetBuilder {
384 pub fn new() -> Self {
386 Self {
387 dataset: CocoDataset::default(),
388 next_image_id: 1,
389 next_annotation_id: 1,
390 next_category_id: 1,
391 }
392 }
393
394 pub fn info(mut self, info: CocoInfo) -> Self {
396 self.dataset.info = info;
397 self
398 }
399
400 pub fn add_category(&mut self, name: &str, supercategory: Option<&str>) -> u32 {
402 for cat in &self.dataset.categories {
404 if cat.name == name {
405 return cat.id;
406 }
407 }
408
409 let id = self.next_category_id;
410 self.next_category_id += 1;
411
412 self.dataset.categories.push(CocoCategory {
413 id,
414 name: name.to_string(),
415 supercategory: supercategory.map(String::from),
416 ..Default::default()
417 });
418
419 id
420 }
421
422 pub fn add_category_with_id(
427 &mut self,
428 id: u32,
429 name: &str,
430 supercategory: Option<&str>,
431 ) -> u32 {
432 for cat in &self.dataset.categories {
434 if cat.id == id || cat.name == name {
435 return cat.id;
436 }
437 }
438
439 self.dataset.categories.push(CocoCategory {
440 id,
441 name: name.to_string(),
442 supercategory: supercategory.map(String::from),
443 ..Default::default()
444 });
445
446 if id >= self.next_category_id {
448 self.next_category_id = id + 1;
449 }
450
451 id
452 }
453
454 pub fn add_image(&mut self, file_name: &str, width: u32, height: u32) -> u64 {
456 let id = self.next_image_id;
457 self.next_image_id += 1;
458
459 self.dataset.images.push(CocoImage {
460 id,
461 width,
462 height,
463 file_name: file_name.to_string(),
464 ..Default::default()
465 });
466
467 id
468 }
469
470 pub fn add_annotation(
472 &mut self,
473 image_id: u64,
474 category_id: u32,
475 bbox: [f64; 4],
476 segmentation: Option<CocoSegmentation>,
477 ) -> u64 {
478 self.add_annotation_with_iscrowd(image_id, category_id, bbox, segmentation, 0)
479 }
480
481 pub fn add_annotation_with_iscrowd(
483 &mut self,
484 image_id: u64,
485 category_id: u32,
486 bbox: [f64; 4],
487 segmentation: Option<CocoSegmentation>,
488 iscrowd: u8,
489 ) -> u64 {
490 self.add_annotation_with_id(None, image_id, category_id, bbox, segmentation, iscrowd)
491 }
492
493 pub fn add_annotation_with_id(
511 &mut self,
512 id: Option<u64>,
513 image_id: u64,
514 category_id: u32,
515 bbox: [f64; 4],
516 segmentation: Option<CocoSegmentation>,
517 iscrowd: u8,
518 ) -> u64 {
519 let id = match id {
525 Some(explicit) => {
526 if explicit >= self.next_annotation_id {
527 self.next_annotation_id = explicit.saturating_add(1);
528 }
529 explicit
530 }
531 None => {
532 let auto = self.next_annotation_id;
533 self.next_annotation_id = self.next_annotation_id.saturating_add(1);
534 auto
535 }
536 };
537
538 let area = bbox[2] * bbox[3]; self.dataset.annotations.push(CocoAnnotation {
541 id,
542 image_id,
543 category_id,
544 bbox,
545 area,
546 iscrowd,
547 segmentation,
548 score: None,
549 });
550
551 id
552 }
553
554 pub fn set_annotation_score(&mut self, annotation_id: u64, score: f64) {
556 if let Some(ann) = self
557 .dataset
558 .annotations
559 .iter_mut()
560 .find(|a| a.id == annotation_id)
561 {
562 ann.score = Some(score);
563 }
564 }
565
566 pub fn set_image_neg_categories(
568 &mut self,
569 image_id: u64,
570 neg_category_ids: Option<Vec<u32>>,
571 not_exhaustive_category_ids: Option<Vec<u32>>,
572 ) {
573 if let Some(img) = self.dataset.images.iter_mut().find(|i| i.id == image_id) {
574 img.neg_category_ids = neg_category_ids;
575 img.not_exhaustive_category_ids = not_exhaustive_category_ids;
576 }
577 }
578
579 pub fn set_category_metadata(
584 &mut self,
585 name: &str,
586 synset: Option<String>,
587 frequency: Option<String>,
588 synonyms: Option<Vec<String>>,
589 def: Option<String>,
590 ) {
591 if let Some(cat) = self.dataset.categories.iter_mut().find(|c| c.name == name) {
592 if synset.is_some() {
593 cat.synset = synset;
594 }
595 if frequency.is_some() {
596 cat.frequency = frequency;
597 }
598 if synonyms.is_some() {
599 cat.synonyms = synonyms;
600 }
601 if def.is_some() {
602 cat.def = def;
603 }
604 }
605 }
606
607 pub fn set_category_supercategory(&mut self, name: &str, supercategory: &str) {
609 if let Some(cat) = self.dataset.categories.iter_mut().find(|c| c.name == name) {
610 cat.supercategory = Some(supercategory.to_string());
611 }
612 }
613
614 pub fn build(self) -> CocoDataset {
616 self.dataset
617 }
618}
619
620#[cfg(test)]
621mod tests {
622 use super::*;
623 use tempfile::TempDir;
624
/// A freshly created writer compresses archives and emits compact JSON.
#[test]
fn test_writer_default() {
    let CocoWriter { options } = CocoWriter::new();
    assert!(options.compress);
    assert!(!options.pretty);
}
631
/// Writing a small dataset as JSON produces a file that parses back with the
/// same number of images and annotations.
#[test]
fn test_write_json() {
    let scratch = TempDir::new().unwrap();
    let json_path = scratch.path().join("test.json");

    let image = CocoImage {
        id: 1,
        width: 640,
        height: 480,
        file_name: "test.jpg".to_string(),
        ..Default::default()
    };
    let category = CocoCategory {
        id: 1,
        name: "person".to_string(),
        supercategory: None,
        ..Default::default()
    };
    let annotation = CocoAnnotation {
        id: 1,
        image_id: 1,
        category_id: 1,
        bbox: [10.0, 20.0, 100.0, 80.0],
        area: 8000.0,
        iscrowd: 0,
        segmentation: None,
        score: None,
    };
    let dataset = CocoDataset {
        images: vec![image],
        categories: vec![category],
        annotations: vec![annotation],
        ..Default::default()
    };

    CocoWriter::new().write_json(&dataset, &json_path).unwrap();

    assert!(json_path.exists());

    // Round-trip through the file to make sure it holds valid COCO JSON.
    let raw = std::fs::read_to_string(&json_path).unwrap();
    let restored: CocoDataset = serde_json::from_str(&raw).unwrap();
    assert_eq!(restored.images.len(), 1);
    assert_eq!(restored.annotations.len(), 1);
}
676
/// With `pretty` enabled the JSON output spans multiple lines.
#[test]
fn test_write_json_pretty() {
    let scratch = TempDir::new().unwrap();
    let json_path = scratch.path().join("test_pretty.json");

    let options = CocoWriteOptions {
        pretty: true,
        compress: false,
    };
    CocoWriter::with_options(options)
        .write_json(&CocoDataset::default(), &json_path)
        .unwrap();

    let raw = std::fs::read_to_string(&json_path).unwrap();
    assert!(raw.contains('\n'));
}
693
/// The builder hands out sequential ids starting at 1 and de-duplicates
/// categories by name.
#[test]
fn test_dataset_builder() {
    let mut builder = CocoDatasetBuilder::new();

    let person = builder.add_category("person", Some("human"));
    let car = builder.add_category("car", Some("vehicle"));
    assert_eq!(person, 1);
    assert_eq!(car, 2);

    // Re-adding an existing name returns the original id.
    let person_again = builder.add_category("person", None);
    assert_eq!(person_again, 1);

    let first_image = builder.add_image("image1.jpg", 640, 480);
    let second_image = builder.add_image("image2.jpg", 800, 600);
    assert_eq!(first_image, 1);
    assert_eq!(second_image, 2);

    let first_ann = builder.add_annotation(first_image, person, [10.0, 20.0, 100.0, 80.0], None);
    let second_ann = builder.add_annotation(first_image, car, [50.0, 60.0, 150.0, 100.0], None);
    assert_eq!(first_ann, 1);
    assert_eq!(second_ann, 2);

    let built = builder.build();
    assert_eq!(built.categories.len(), 2);
    assert_eq!(built.images.len(), 2);
    assert_eq!(built.annotations.len(), 2);
}
730
/// An explicit annotation id is kept verbatim and pushes the automatic
/// counter past it.
#[test]
fn test_add_annotation_with_explicit_id_preserves_id_and_advances_counter() {
    let mut builder = CocoDatasetBuilder::new();
    let category = builder.add_category("dog", None);
    let image = builder.add_image("image.jpg", 640, 480);

    let explicit = builder.add_annotation_with_id(
        Some(9_876_543_210),
        image,
        category,
        [10.0, 20.0, 100.0, 80.0],
        None,
        0,
    );
    assert_eq!(explicit, 9_876_543_210);

    let ann_auto = builder.add_annotation(image, category, [0.0, 0.0, 1.0, 1.0], None);
    assert!(
        ann_auto > 9_876_543_210,
        "auto-generated annotation id ({ann_auto}) must be greater than the largest explicit id"
    );

    // Passing `None` behaves like the automatic path and keeps counting up.
    let next_auto =
        builder.add_annotation_with_id(None, image, category, [0.0, 0.0, 1.0, 1.0], None, 0);
    assert_eq!(next_auto, ann_auto + 1);

    let built = builder.build();
    assert_eq!(built.annotations.len(), 3);
    assert_eq!(built.annotations[0].id, 9_876_543_210);
}
765
/// An explicit id of `u64::MAX` neither panics nor wraps the counter; the
/// next automatic id saturates at `u64::MAX`.
#[test]
fn test_add_annotation_with_explicit_id_at_u64_max_does_not_panic_or_wrap() {
    let mut builder = CocoDatasetBuilder::new();
    let category = builder.add_category("dog", None);
    let image = builder.add_image("image.jpg", 640, 480);
    let unit_box = [0.0, 0.0, 1.0, 1.0];

    let at_max = builder.add_annotation_with_id(Some(u64::MAX), image, category, unit_box, None, 0);
    assert_eq!(at_max, u64::MAX);

    let after_max = builder.add_annotation(image, category, unit_box, None);
    assert_eq!(after_max, u64::MAX);
}
789
/// The written archive contains both the annotations file and the supplied
/// image entry.
#[test]
fn test_write_zip() {
    let scratch = TempDir::new().unwrap();
    let archive_path = scratch.path().join("test.zip");

    let image = CocoImage {
        id: 1,
        width: 100,
        height: 100,
        file_name: "test.jpg".to_string(),
        ..Default::default()
    };
    let dataset = CocoDataset {
        images: vec![image],
        ..Default::default()
    };

    // A few JPEG-looking bytes are enough; the content is never decoded.
    let payloads = vec![("images/test.jpg".to_string(), vec![0xFF, 0xD8, 0xFF])];

    CocoWriter::new()
        .write_zip(&dataset, payloads.into_iter(), &archive_path)
        .unwrap();

    assert!(archive_path.exists());

    let handle = std::fs::File::open(&archive_path).unwrap();
    let mut archive = zip::ZipArchive::new(handle).unwrap();

    assert!(archive.by_name("annotations/instances.json").is_ok());
    assert!(archive.by_name("images/test.jpg").is_ok());
}
825
/// Splitting three images into "train" (2) and "val" (1) writes one
/// annotations file per group with matching image and annotation counts.
#[test]
fn test_write_split_by_group() {
    let scratch = TempDir::new().unwrap();
    let split_root = scratch.path().join("split_output");

    let make_image = |id: u64, width: u32, height: u32, file_name: &str| CocoImage {
        id,
        width,
        height,
        file_name: file_name.to_string(),
        ..Default::default()
    };
    let make_annotation = |id: u64, image_id: u64, bbox: [f64; 4]| CocoAnnotation {
        id,
        image_id,
        category_id: 1,
        bbox,
        ..Default::default()
    };

    let dataset = CocoDataset {
        images: vec![
            make_image(1, 640, 480, "train1.jpg"),
            make_image(2, 640, 480, "train2.jpg"),
            make_image(3, 800, 600, "val1.jpg"),
        ],
        categories: vec![CocoCategory {
            id: 1,
            name: "person".to_string(),
            supercategory: None,
            ..Default::default()
        }],
        annotations: vec![
            make_annotation(1, 1, [10.0, 20.0, 100.0, 80.0]),
            make_annotation(2, 2, [20.0, 30.0, 100.0, 80.0]),
            make_annotation(3, 3, [30.0, 40.0, 100.0, 80.0]),
        ],
        ..Default::default()
    };

    let groups = ["train", "train", "val"].map(String::from).to_vec();

    let counts = CocoWriter::new()
        .write_split_by_group(&dataset, &groups, None, &split_root)
        .unwrap();

    assert_eq!(counts.get("train"), Some(&2));
    assert_eq!(counts.get("val"), Some(&1));

    let train_file = split_root.join("train/annotations/instances_train.json");
    let val_file = split_root.join("val/annotations/instances_val.json");
    assert!(train_file.exists());
    assert!(val_file.exists());

    // Each subset must contain only its own images and their annotations.
    let train_raw = std::fs::read_to_string(&train_file).unwrap();
    let train_subset: CocoDataset = serde_json::from_str(&train_raw).unwrap();
    assert_eq!(train_subset.images.len(), 2);
    assert_eq!(train_subset.annotations.len(), 2);

    let val_raw = std::fs::read_to_string(&val_file).unwrap();
    let val_subset: CocoDataset = serde_json::from_str(&val_raw).unwrap();
    assert_eq!(val_subset.images.len(), 1);
    assert_eq!(val_subset.annotations.len(), 1);
}
925
/// A group list whose length differs from the image count is rejected.
#[test]
fn test_write_split_by_group_mismatch() {
    let only_image = CocoImage {
        id: 1,
        ..Default::default()
    };
    let dataset = CocoDataset {
        images: vec![only_image],
        ..Default::default()
    };

    // Two assignments for a single image.
    let groups = ["train", "val"].map(String::from).to_vec();

    let target = std::path::Path::new("/tmp/test");
    let outcome = CocoWriter::new().write_split_by_group(&dataset, &groups, None, target);

    assert!(outcome.is_err());
}
945}