1use super::types::{
9 CocoAnnotation, CocoCategory, CocoDataset, CocoImage, CocoInfo, CocoSegmentation,
10};
11use crate::Error;
12use std::{
13 fs::File,
14 io::{BufWriter, Write},
15 path::Path,
16};
17use zip::{CompressionMethod, write::SimpleFileOptions};
18
/// Options controlling how a [`CocoWriter`] serializes datasets.
#[derive(Clone, Debug)]
pub struct CocoWriteOptions {
    /// Use Deflate for ZIP entries (`true`) or store them uncompressed (`false`).
    pub compress: bool,
    /// Emit indented, human-readable JSON instead of compact JSON.
    pub pretty: bool,
}

impl Default for CocoWriteOptions {
    /// Compressed archives with compact JSON.
    fn default() -> Self {
        let compress = true;
        let pretty = false;
        CocoWriteOptions { compress, pretty }
    }
}
36
37pub struct CocoWriter {
50 options: CocoWriteOptions,
51}
52
53impl CocoWriter {
54 pub fn new() -> Self {
56 Self {
57 options: CocoWriteOptions::default(),
58 }
59 }
60
61 pub fn with_options(options: CocoWriteOptions) -> Self {
63 Self { options }
64 }
65
66 pub fn write_json<P: AsRef<Path>>(&self, dataset: &CocoDataset, path: P) -> Result<(), Error> {
72 if let Some(parent) = path.as_ref().parent()
74 && !parent.as_os_str().is_empty()
75 {
76 std::fs::create_dir_all(parent)?;
77 }
78
79 let file = File::create(path.as_ref())?;
80 let writer = BufWriter::with_capacity(64 * 1024, file);
81
82 if self.options.pretty {
83 serde_json::to_writer_pretty(writer, dataset)?;
84 } else {
85 serde_json::to_writer(writer, dataset)?;
86 }
87
88 Ok(())
89 }
90
91 pub fn write_zip<P: AsRef<Path>>(
102 &self,
103 dataset: &CocoDataset,
104 images: impl Iterator<Item = (String, Vec<u8>)>,
105 path: P,
106 ) -> Result<(), Error> {
107 if let Some(parent) = path.as_ref().parent()
109 && !parent.as_os_str().is_empty()
110 {
111 std::fs::create_dir_all(parent)?;
112 }
113
114 let file = File::create(path.as_ref())?;
115 let mut zip = zip::ZipWriter::new(file);
116
117 let options = if self.options.compress {
118 SimpleFileOptions::default().compression_method(CompressionMethod::Deflated)
119 } else {
120 SimpleFileOptions::default().compression_method(CompressionMethod::Stored)
121 };
122
123 zip.start_file("annotations/instances.json", options)?;
125 let json = if self.options.pretty {
126 serde_json::to_string_pretty(dataset)?
127 } else {
128 serde_json::to_string(dataset)?
129 };
130 zip.write_all(json.as_bytes())?;
131
132 for (filename, data) in images {
134 zip.start_file(&filename, options)?;
135 zip.write_all(&data)?;
136 }
137
138 zip.finish()?;
139 Ok(())
140 }
141
142 pub fn write_zip_from_dir<P: AsRef<Path>>(
149 &self,
150 dataset: &CocoDataset,
151 images_dir: P,
152 path: P,
153 ) -> Result<(), Error> {
154 let images_dir = images_dir.as_ref();
155
156 let images = dataset.images.iter().filter_map(|img| {
158 let img_path = images_dir.join(&img.file_name);
159 std::fs::read(&img_path)
160 .ok()
161 .map(|data| (format!("images/{}", img.file_name), data))
162 });
163
164 self.write_zip(dataset, images, path)
165 }
166
167 pub fn write_split_by_group<P: AsRef<Path>>(
191 &self,
192 dataset: &CocoDataset,
193 group_assignments: &[String],
194 images_source: Option<&Path>,
195 output_dir: P,
196 ) -> Result<std::collections::HashMap<String, usize>, Error> {
197 use std::collections::{HashMap, HashSet};
198
199 let output_dir = output_dir.as_ref();
200
201 if dataset.images.len() != group_assignments.len() {
203 return Err(Error::CocoError(format!(
204 "Image count ({}) does not match group assignment count ({})",
205 dataset.images.len(),
206 group_assignments.len()
207 )));
208 }
209
210 let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
212 for (idx, group) in group_assignments.iter().enumerate() {
213 groups.entry(group.clone()).or_default().push(idx);
214 }
215
216 let mut result = HashMap::new();
217
218 for (group_name, image_indices) in &groups {
219 let group_dir = output_dir.join(group_name);
221 let annotations_dir = group_dir.join("annotations");
222 let images_dir = group_dir.join("images");
223
224 std::fs::create_dir_all(&annotations_dir)?;
225 std::fs::create_dir_all(&images_dir)?;
226
227 let image_ids: HashSet<u64> = image_indices
229 .iter()
230 .map(|&idx| dataset.images[idx].id)
231 .collect();
232
233 let subset = CocoDataset {
234 info: dataset.info.clone(),
235 licenses: dataset.licenses.clone(),
236 images: image_indices
237 .iter()
238 .map(|&idx| dataset.images[idx].clone())
239 .collect(),
240 annotations: dataset
241 .annotations
242 .iter()
243 .filter(|ann| image_ids.contains(&ann.image_id))
244 .cloned()
245 .collect(),
246 categories: dataset.categories.clone(),
247 };
248
249 let ann_file = annotations_dir.join(format!("instances_{}.json", group_name));
251 self.write_json(&subset, &ann_file)?;
252
253 if let Some(source) = images_source {
255 for &idx in image_indices {
256 let image = &dataset.images[idx];
257 let src_path = source.join(&image.file_name);
258 let dst_path = images_dir.join(&image.file_name);
259
260 if src_path.exists() {
261 std::fs::copy(&src_path, &dst_path)?;
262 }
263 }
264 }
265
266 result.insert(group_name.clone(), image_indices.len());
267 }
268
269 Ok(result)
270 }
271
272 pub fn write_split_by_group_zip<P: AsRef<Path>>(
288 &self,
289 dataset: &CocoDataset,
290 group_assignments: &[String],
291 images_source: Option<&Path>,
292 output_dir: P,
293 ) -> Result<std::collections::HashMap<String, usize>, Error> {
294 use std::collections::{HashMap, HashSet};
295
296 let output_dir = output_dir.as_ref();
297 std::fs::create_dir_all(output_dir)?;
298
299 if dataset.images.len() != group_assignments.len() {
301 return Err(Error::CocoError(format!(
302 "Image count ({}) does not match group assignment count ({})",
303 dataset.images.len(),
304 group_assignments.len()
305 )));
306 }
307
308 let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
310 for (idx, group) in group_assignments.iter().enumerate() {
311 groups.entry(group.clone()).or_default().push(idx);
312 }
313
314 let mut result = HashMap::new();
315
316 for (group_name, image_indices) in &groups {
317 let image_ids: HashSet<u64> = image_indices
319 .iter()
320 .map(|&idx| dataset.images[idx].id)
321 .collect();
322
323 let subset = CocoDataset {
324 info: dataset.info.clone(),
325 licenses: dataset.licenses.clone(),
326 images: image_indices
327 .iter()
328 .map(|&idx| dataset.images[idx].clone())
329 .collect(),
330 annotations: dataset
331 .annotations
332 .iter()
333 .filter(|ann| image_ids.contains(&ann.image_id))
334 .cloned()
335 .collect(),
336 categories: dataset.categories.clone(),
337 };
338
339 let images: Vec<(String, Vec<u8>)> = if let Some(source) = images_source {
341 image_indices
342 .iter()
343 .filter_map(|&idx| {
344 let image = &dataset.images[idx];
345 let src_path = source.join(&image.file_name);
346 std::fs::read(&src_path)
347 .ok()
348 .map(|data| (format!("images/{}", image.file_name), data))
349 })
350 .collect()
351 } else {
352 vec![]
353 };
354
355 let zip_path = output_dir.join(format!("{}.zip", group_name));
357 self.write_zip(&subset, images.into_iter(), &zip_path)?;
358
359 result.insert(group_name.clone(), image_indices.len());
360 }
361
362 Ok(result)
363 }
364}
365
366impl Default for CocoWriter {
367 fn default() -> Self {
368 Self::new()
369 }
370}
371
372#[derive(Debug, Default)]
376pub struct CocoDatasetBuilder {
377 dataset: CocoDataset,
378 next_image_id: u64,
379 next_annotation_id: u64,
380 next_category_id: u32,
381}
382
383impl CocoDatasetBuilder {
384 pub fn new() -> Self {
386 Self {
387 dataset: CocoDataset::default(),
388 next_image_id: 1,
389 next_annotation_id: 1,
390 next_category_id: 1,
391 }
392 }
393
394 pub fn info(mut self, info: CocoInfo) -> Self {
396 self.dataset.info = info;
397 self
398 }
399
400 pub fn add_category(&mut self, name: &str, supercategory: Option<&str>) -> u32 {
402 for cat in &self.dataset.categories {
404 if cat.name == name {
405 return cat.id;
406 }
407 }
408
409 let id = self.next_category_id;
410 self.next_category_id += 1;
411
412 self.dataset.categories.push(CocoCategory {
413 id,
414 name: name.to_string(),
415 supercategory: supercategory.map(String::from),
416 ..Default::default()
417 });
418
419 id
420 }
421
422 pub fn add_category_with_id(
427 &mut self,
428 id: u32,
429 name: &str,
430 supercategory: Option<&str>,
431 ) -> u32 {
432 for cat in &self.dataset.categories {
434 if cat.id == id || cat.name == name {
435 return cat.id;
436 }
437 }
438
439 self.dataset.categories.push(CocoCategory {
440 id,
441 name: name.to_string(),
442 supercategory: supercategory.map(String::from),
443 ..Default::default()
444 });
445
446 if id >= self.next_category_id {
448 self.next_category_id = id + 1;
449 }
450
451 id
452 }
453
454 pub fn add_image(&mut self, file_name: &str, width: u32, height: u32) -> u64 {
456 let id = self.next_image_id;
457 self.next_image_id += 1;
458
459 self.dataset.images.push(CocoImage {
460 id,
461 width,
462 height,
463 file_name: file_name.to_string(),
464 ..Default::default()
465 });
466
467 id
468 }
469
470 pub fn add_annotation(
472 &mut self,
473 image_id: u64,
474 category_id: u32,
475 bbox: [f64; 4],
476 segmentation: Option<CocoSegmentation>,
477 ) -> u64 {
478 self.add_annotation_with_iscrowd(image_id, category_id, bbox, segmentation, 0)
479 }
480
481 pub fn add_annotation_with_iscrowd(
483 &mut self,
484 image_id: u64,
485 category_id: u32,
486 bbox: [f64; 4],
487 segmentation: Option<CocoSegmentation>,
488 iscrowd: u8,
489 ) -> u64 {
490 let id = self.next_annotation_id;
491 self.next_annotation_id += 1;
492
493 let area = bbox[2] * bbox[3]; self.dataset.annotations.push(CocoAnnotation {
496 id,
497 image_id,
498 category_id,
499 bbox,
500 area,
501 iscrowd,
502 segmentation,
503 score: None,
504 });
505
506 id
507 }
508
509 pub fn set_annotation_score(&mut self, annotation_id: u64, score: f64) {
511 if let Some(ann) = self
512 .dataset
513 .annotations
514 .iter_mut()
515 .find(|a| a.id == annotation_id)
516 {
517 ann.score = Some(score);
518 }
519 }
520
521 pub fn set_image_neg_categories(
523 &mut self,
524 image_id: u64,
525 neg_category_ids: Option<Vec<u32>>,
526 not_exhaustive_category_ids: Option<Vec<u32>>,
527 ) {
528 if let Some(img) = self.dataset.images.iter_mut().find(|i| i.id == image_id) {
529 img.neg_category_ids = neg_category_ids;
530 img.not_exhaustive_category_ids = not_exhaustive_category_ids;
531 }
532 }
533
534 pub fn set_category_metadata(
539 &mut self,
540 name: &str,
541 synset: Option<String>,
542 frequency: Option<String>,
543 synonyms: Option<Vec<String>>,
544 def: Option<String>,
545 ) {
546 if let Some(cat) = self.dataset.categories.iter_mut().find(|c| c.name == name) {
547 if synset.is_some() {
548 cat.synset = synset;
549 }
550 if frequency.is_some() {
551 cat.frequency = frequency;
552 }
553 if synonyms.is_some() {
554 cat.synonyms = synonyms;
555 }
556 if def.is_some() {
557 cat.def = def;
558 }
559 }
560 }
561
562 pub fn set_category_supercategory(&mut self, name: &str, supercategory: &str) {
564 if let Some(cat) = self.dataset.categories.iter_mut().find(|c| c.name == name) {
565 cat.supercategory = Some(supercategory.to_string());
566 }
567 }
568
569 pub fn build(self) -> CocoDataset {
571 self.dataset
572 }
573}
574
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_writer_default() {
        let writer = CocoWriter::new();
        assert!(writer.options.compress);
        assert!(!writer.options.pretty);
    }

    #[test]
    fn test_write_json() {
        let temp_dir = TempDir::new().unwrap();
        let output_path = temp_dir.path().join("test.json");

        let dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                width: 640,
                height: 480,
                file_name: "test.jpg".to_string(),
                ..Default::default()
            }],
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                supercategory: None,
                ..Default::default()
            }],
            annotations: vec![CocoAnnotation {
                id: 1,
                image_id: 1,
                category_id: 1,
                bbox: [10.0, 20.0, 100.0, 80.0],
                area: 8000.0,
                iscrowd: 0,
                segmentation: None,
                score: None,
            }],
            ..Default::default()
        };

        let writer = CocoWriter::new();
        writer.write_json(&dataset, &output_path).unwrap();

        assert!(output_path.exists());

        let contents = std::fs::read_to_string(&output_path).unwrap();
        let restored: CocoDataset = serde_json::from_str(&contents).unwrap();
        assert_eq!(restored.images.len(), 1);
        assert_eq!(restored.annotations.len(), 1);
    }

    #[test]
    fn test_write_json_pretty() {
        let temp_dir = TempDir::new().unwrap();
        let output_path = temp_dir.path().join("test_pretty.json");

        let dataset = CocoDataset::default();

        let writer = CocoWriter::with_options(CocoWriteOptions {
            pretty: true,
            compress: false,
        });
        writer.write_json(&dataset, &output_path).unwrap();

        let contents = std::fs::read_to_string(&output_path).unwrap();
        assert!(contents.contains('\n'));
    }

    #[test]
    fn test_dataset_builder() {
        let mut builder = CocoDatasetBuilder::new();

        let person_id = builder.add_category("person", Some("human"));
        let car_id = builder.add_category("car", Some("vehicle"));

        assert_eq!(person_id, 1);
        assert_eq!(car_id, 2);

        // Re-adding an existing name returns the original ID.
        let person_id2 = builder.add_category("person", None);
        assert_eq!(person_id2, 1);

        let img1 = builder.add_image("image1.jpg", 640, 480);
        let img2 = builder.add_image("image2.jpg", 800, 600);

        assert_eq!(img1, 1);
        assert_eq!(img2, 2);

        let ann1 = builder.add_annotation(img1, person_id, [10.0, 20.0, 100.0, 80.0], None);
        let ann2 = builder.add_annotation(img1, car_id, [50.0, 60.0, 150.0, 100.0], None);

        assert_eq!(ann1, 1);
        assert_eq!(ann2, 2);

        let dataset = builder.build();

        assert_eq!(dataset.categories.len(), 2);
        assert_eq!(dataset.images.len(), 2);
        assert_eq!(dataset.annotations.len(), 2);
    }

    #[test]
    fn test_builder_category_with_id_advances_counter() {
        let mut builder = CocoDatasetBuilder::new();
        assert_eq!(builder.add_category_with_id(5, "dog", None), 5);
        // Automatic IDs continue after the explicit one.
        assert_eq!(builder.add_category("cat", None), 6);
    }

    #[test]
    fn test_builder_setters() {
        let mut builder = CocoDatasetBuilder::new();
        let cat = builder.add_category("person", None);
        let img1 = builder.add_image("a.jpg", 10, 10);
        let img2 = builder.add_image("b.jpg", 10, 10);
        let ann1 = builder.add_annotation(img1, cat, [0.0, 0.0, 2.0, 3.0], None);
        let ann2 = builder.add_annotation(img2, cat, [0.0, 0.0, 4.0, 5.0], None);

        builder.set_annotation_score(ann2, 0.75);
        // Unknown IDs are ignored.
        builder.set_annotation_score(999, 0.5);
        builder.set_image_neg_categories(img2, Some(vec![cat]), None);
        builder.set_category_supercategory("person", "human");

        let dataset = builder.build();
        assert_eq!(dataset.annotations[0].id, ann1);
        assert_eq!(dataset.annotations[0].area, 6.0);
        assert_eq!(dataset.annotations[0].score, None);
        assert_eq!(dataset.annotations[1].score, Some(0.75));
        assert_eq!(dataset.images[0].neg_category_ids, None);
        assert_eq!(dataset.images[1].neg_category_ids, Some(vec![cat]));
        assert_eq!(
            dataset.categories[0].supercategory.as_deref(),
            Some("human")
        );
    }

    #[test]
    fn test_write_zip() {
        let temp_dir = TempDir::new().unwrap();
        let output_path = temp_dir.path().join("test.zip");

        let dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                width: 100,
                height: 100,
                file_name: "test.jpg".to_string(),
                ..Default::default()
            }],
            ..Default::default()
        };

        let images = vec![("images/test.jpg".to_string(), vec![0xFF, 0xD8, 0xFF])];

        let writer = CocoWriter::new();
        writer
            .write_zip(&dataset, images.into_iter(), &output_path)
            .unwrap();

        assert!(output_path.exists());

        let file = std::fs::File::open(&output_path).unwrap();
        let mut archive = zip::ZipArchive::new(file).unwrap();

        assert!(archive.by_name("annotations/instances.json").is_ok());
        assert!(archive.by_name("images/test.jpg").is_ok());
    }

    #[test]
    fn test_write_split_by_group() {
        let temp_dir = TempDir::new().unwrap();
        let output_dir = temp_dir.path().join("split_output");

        let dataset = CocoDataset {
            images: vec![
                CocoImage {
                    id: 1,
                    width: 640,
                    height: 480,
                    file_name: "train1.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 2,
                    width: 640,
                    height: 480,
                    file_name: "train2.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 3,
                    width: 800,
                    height: 600,
                    file_name: "val1.jpg".to_string(),
                    ..Default::default()
                },
            ],
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                supercategory: None,
                ..Default::default()
            }],
            annotations: vec![
                CocoAnnotation {
                    id: 1,
                    image_id: 1,
                    category_id: 1,
                    bbox: [10.0, 20.0, 100.0, 80.0],
                    ..Default::default()
                },
                CocoAnnotation {
                    id: 2,
                    image_id: 2,
                    category_id: 1,
                    bbox: [20.0, 30.0, 100.0, 80.0],
                    ..Default::default()
                },
                CocoAnnotation {
                    id: 3,
                    image_id: 3,
                    category_id: 1,
                    bbox: [30.0, 40.0, 100.0, 80.0],
                    ..Default::default()
                },
            ],
            ..Default::default()
        };

        let groups = vec!["train".to_string(), "train".to_string(), "val".to_string()];

        let writer = CocoWriter::new();
        let result = writer
            .write_split_by_group(&dataset, &groups, None, &output_dir)
            .unwrap();

        assert_eq!(result.get("train"), Some(&2));
        assert_eq!(result.get("val"), Some(&1));

        assert!(
            output_dir
                .join("train/annotations/instances_train.json")
                .exists()
        );
        assert!(
            output_dir
                .join("val/annotations/instances_val.json")
                .exists()
        );

        let train_json =
            std::fs::read_to_string(output_dir.join("train/annotations/instances_train.json"))
                .unwrap();
        let train_data: CocoDataset = serde_json::from_str(&train_json).unwrap();
        assert_eq!(train_data.images.len(), 2);
        assert_eq!(train_data.annotations.len(), 2);

        let val_json =
            std::fs::read_to_string(output_dir.join("val/annotations/instances_val.json")).unwrap();
        let val_data: CocoDataset = serde_json::from_str(&val_json).unwrap();
        assert_eq!(val_data.images.len(), 1);
        assert_eq!(val_data.annotations.len(), 1);
    }

    #[test]
    fn test_write_split_by_group_copies_images() {
        let temp_dir = TempDir::new().unwrap();
        let source_dir = temp_dir.path().join("source");
        std::fs::create_dir_all(&source_dir).unwrap();
        std::fs::write(source_dir.join("a.jpg"), [0xFF, 0xD8, 0xFF]).unwrap();
        // "b.jpg" is intentionally missing: missing source images are skipped.
        let output_dir = temp_dir.path().join("split_images");

        let dataset = CocoDataset {
            images: vec![
                CocoImage {
                    id: 1,
                    file_name: "a.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 2,
                    file_name: "b.jpg".to_string(),
                    ..Default::default()
                },
            ],
            ..Default::default()
        };
        let groups = vec!["train".to_string(), "train".to_string()];

        let result = CocoWriter::new()
            .write_split_by_group(&dataset, &groups, Some(source_dir.as_path()), &output_dir)
            .unwrap();

        assert_eq!(result.get("train"), Some(&2));
        assert!(output_dir.join("train/images/a.jpg").exists());
        assert!(!output_dir.join("train/images/b.jpg").exists());
    }

    #[test]
    fn test_write_split_by_group_zip() {
        use std::io::Read;

        let temp_dir = TempDir::new().unwrap();
        let source_dir = temp_dir.path().join("source");
        std::fs::create_dir_all(&source_dir).unwrap();
        std::fs::write(source_dir.join("a.jpg"), [0xFF, 0xD8, 0xFF]).unwrap();
        // "b.jpg" is intentionally missing: unreadable images are skipped.
        let output_dir = temp_dir.path().join("zips");

        let dataset = CocoDataset {
            images: vec![
                CocoImage {
                    id: 1,
                    file_name: "a.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 2,
                    file_name: "b.jpg".to_string(),
                    ..Default::default()
                },
            ],
            annotations: vec![CocoAnnotation {
                id: 1,
                image_id: 2,
                category_id: 1,
                ..Default::default()
            }],
            ..Default::default()
        };
        let groups = vec!["train".to_string(), "val".to_string()];

        let result = CocoWriter::new()
            .write_split_by_group_zip(&dataset, &groups, Some(source_dir.as_path()), &output_dir)
            .unwrap();
        assert_eq!(result.get("train"), Some(&1));
        assert_eq!(result.get("val"), Some(&1));

        let train_file = std::fs::File::open(output_dir.join("train.zip")).unwrap();
        let mut train = zip::ZipArchive::new(train_file).unwrap();
        assert!(train.by_name("annotations/instances.json").is_ok());
        assert!(train.by_name("images/a.jpg").is_ok());

        let val_file = std::fs::File::open(output_dir.join("val.zip")).unwrap();
        let mut val = zip::ZipArchive::new(val_file).unwrap();
        assert!(val.by_name("images/b.jpg").is_err());

        let mut json = String::new();
        val.by_name("annotations/instances.json")
            .unwrap()
            .read_to_string(&mut json)
            .unwrap();
        let val_data: CocoDataset = serde_json::from_str(&json).unwrap();
        assert_eq!(val_data.images.len(), 1);
        assert_eq!(val_data.annotations.len(), 1);
    }

    #[test]
    fn test_write_split_by_group_mismatch() {
        let temp_dir = TempDir::new().unwrap();
        let output_dir = temp_dir.path().join("mismatch");

        let dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                ..Default::default()
            }],
            ..Default::default()
        };

        let groups = vec!["train".to_string(), "val".to_string()];

        let writer = CocoWriter::new();
        let result = writer.write_split_by_group(&dataset, &groups, None, &output_dir);

        assert!(result.is_err());
        // Validation happens before anything is written.
        assert!(!output_dir.exists());
    }
}