1use super::types::{
9 CocoAnnotation, CocoCategory, CocoDataset, CocoImage, CocoInfo, CocoSegmentation,
10};
11use crate::Error;
12use std::{
13 fs::File,
14 io::{BufWriter, Write},
15 path::Path,
16};
17use zip::{CompressionMethod, write::SimpleFileOptions};
18
/// Knobs controlling how [`CocoWriter`] serializes datasets.
#[derive(Debug, Clone)]
pub struct CocoWriteOptions {
    /// Deflate-compress zip entries; when `false`, entries are stored uncompressed.
    pub compress: bool,
    /// Emit indented, human-readable JSON instead of compact JSON.
    pub pretty: bool,
}

impl Default for CocoWriteOptions {
    /// Compressed archives with compact JSON.
    fn default() -> Self {
        Self {
            pretty: false,
            compress: true,
        }
    }
}
36
37pub struct CocoWriter {
50 options: CocoWriteOptions,
51}
52
53impl CocoWriter {
54 pub fn new() -> Self {
56 Self {
57 options: CocoWriteOptions::default(),
58 }
59 }
60
61 pub fn with_options(options: CocoWriteOptions) -> Self {
63 Self { options }
64 }
65
66 pub fn write_json<P: AsRef<Path>>(&self, dataset: &CocoDataset, path: P) -> Result<(), Error> {
72 if let Some(parent) = path.as_ref().parent()
74 && !parent.as_os_str().is_empty()
75 {
76 std::fs::create_dir_all(parent)?;
77 }
78
79 let file = File::create(path.as_ref())?;
80 let writer = BufWriter::with_capacity(64 * 1024, file);
81
82 if self.options.pretty {
83 serde_json::to_writer_pretty(writer, dataset)?;
84 } else {
85 serde_json::to_writer(writer, dataset)?;
86 }
87
88 Ok(())
89 }
90
91 pub fn write_zip<P: AsRef<Path>>(
102 &self,
103 dataset: &CocoDataset,
104 images: impl Iterator<Item = (String, Vec<u8>)>,
105 path: P,
106 ) -> Result<(), Error> {
107 if let Some(parent) = path.as_ref().parent()
109 && !parent.as_os_str().is_empty()
110 {
111 std::fs::create_dir_all(parent)?;
112 }
113
114 let file = File::create(path.as_ref())?;
115 let mut zip = zip::ZipWriter::new(file);
116
117 let options = if self.options.compress {
118 SimpleFileOptions::default().compression_method(CompressionMethod::Deflated)
119 } else {
120 SimpleFileOptions::default().compression_method(CompressionMethod::Stored)
121 };
122
123 zip.start_file("annotations/instances.json", options)?;
125 let json = if self.options.pretty {
126 serde_json::to_string_pretty(dataset)?
127 } else {
128 serde_json::to_string(dataset)?
129 };
130 zip.write_all(json.as_bytes())?;
131
132 for (filename, data) in images {
134 zip.start_file(&filename, options)?;
135 zip.write_all(&data)?;
136 }
137
138 zip.finish()?;
139 Ok(())
140 }
141
142 pub fn write_zip_from_dir<P: AsRef<Path>>(
149 &self,
150 dataset: &CocoDataset,
151 images_dir: P,
152 path: P,
153 ) -> Result<(), Error> {
154 let images_dir = images_dir.as_ref();
155
156 let images = dataset.images.iter().filter_map(|img| {
158 let img_path = images_dir.join(&img.file_name);
159 std::fs::read(&img_path)
160 .ok()
161 .map(|data| (format!("images/{}", img.file_name), data))
162 });
163
164 self.write_zip(dataset, images, path)
165 }
166
167 pub fn write_split_by_group<P: AsRef<Path>>(
191 &self,
192 dataset: &CocoDataset,
193 group_assignments: &[String],
194 images_source: Option<&Path>,
195 output_dir: P,
196 ) -> Result<std::collections::HashMap<String, usize>, Error> {
197 use std::collections::{HashMap, HashSet};
198
199 let output_dir = output_dir.as_ref();
200
201 if dataset.images.len() != group_assignments.len() {
203 return Err(Error::CocoError(format!(
204 "Image count ({}) does not match group assignment count ({})",
205 dataset.images.len(),
206 group_assignments.len()
207 )));
208 }
209
210 let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
212 for (idx, group) in group_assignments.iter().enumerate() {
213 groups.entry(group.clone()).or_default().push(idx);
214 }
215
216 let mut result = HashMap::new();
217
218 for (group_name, image_indices) in &groups {
219 let group_dir = output_dir.join(group_name);
221 let annotations_dir = group_dir.join("annotations");
222 let images_dir = group_dir.join("images");
223
224 std::fs::create_dir_all(&annotations_dir)?;
225 std::fs::create_dir_all(&images_dir)?;
226
227 let image_ids: HashSet<u64> = image_indices
229 .iter()
230 .map(|&idx| dataset.images[idx].id)
231 .collect();
232
233 let subset = CocoDataset {
234 info: dataset.info.clone(),
235 licenses: dataset.licenses.clone(),
236 images: image_indices
237 .iter()
238 .map(|&idx| dataset.images[idx].clone())
239 .collect(),
240 annotations: dataset
241 .annotations
242 .iter()
243 .filter(|ann| image_ids.contains(&ann.image_id))
244 .cloned()
245 .collect(),
246 categories: dataset.categories.clone(),
247 };
248
249 let ann_file = annotations_dir.join(format!("instances_{}.json", group_name));
251 self.write_json(&subset, &ann_file)?;
252
253 if let Some(source) = images_source {
255 for &idx in image_indices {
256 let image = &dataset.images[idx];
257 let src_path = source.join(&image.file_name);
258 let dst_path = images_dir.join(&image.file_name);
259
260 if src_path.exists() {
261 std::fs::copy(&src_path, &dst_path)?;
262 }
263 }
264 }
265
266 result.insert(group_name.clone(), image_indices.len());
267 }
268
269 Ok(result)
270 }
271
272 pub fn write_split_by_group_zip<P: AsRef<Path>>(
288 &self,
289 dataset: &CocoDataset,
290 group_assignments: &[String],
291 images_source: Option<&Path>,
292 output_dir: P,
293 ) -> Result<std::collections::HashMap<String, usize>, Error> {
294 use std::collections::{HashMap, HashSet};
295
296 let output_dir = output_dir.as_ref();
297 std::fs::create_dir_all(output_dir)?;
298
299 if dataset.images.len() != group_assignments.len() {
301 return Err(Error::CocoError(format!(
302 "Image count ({}) does not match group assignment count ({})",
303 dataset.images.len(),
304 group_assignments.len()
305 )));
306 }
307
308 let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
310 for (idx, group) in group_assignments.iter().enumerate() {
311 groups.entry(group.clone()).or_default().push(idx);
312 }
313
314 let mut result = HashMap::new();
315
316 for (group_name, image_indices) in &groups {
317 let image_ids: HashSet<u64> = image_indices
319 .iter()
320 .map(|&idx| dataset.images[idx].id)
321 .collect();
322
323 let subset = CocoDataset {
324 info: dataset.info.clone(),
325 licenses: dataset.licenses.clone(),
326 images: image_indices
327 .iter()
328 .map(|&idx| dataset.images[idx].clone())
329 .collect(),
330 annotations: dataset
331 .annotations
332 .iter()
333 .filter(|ann| image_ids.contains(&ann.image_id))
334 .cloned()
335 .collect(),
336 categories: dataset.categories.clone(),
337 };
338
339 let images: Vec<(String, Vec<u8>)> = if let Some(source) = images_source {
341 image_indices
342 .iter()
343 .filter_map(|&idx| {
344 let image = &dataset.images[idx];
345 let src_path = source.join(&image.file_name);
346 std::fs::read(&src_path)
347 .ok()
348 .map(|data| (format!("images/{}", image.file_name), data))
349 })
350 .collect()
351 } else {
352 vec![]
353 };
354
355 let zip_path = output_dir.join(format!("{}.zip", group_name));
357 self.write_zip(&subset, images.into_iter(), &zip_path)?;
358
359 result.insert(group_name.clone(), image_indices.len());
360 }
361
362 Ok(result)
363 }
364}
365
366impl Default for CocoWriter {
367 fn default() -> Self {
368 Self::new()
369 }
370}
371
372#[derive(Debug, Default)]
376pub struct CocoDatasetBuilder {
377 dataset: CocoDataset,
378 next_image_id: u64,
379 next_annotation_id: u64,
380 next_category_id: u32,
381}
382
383impl CocoDatasetBuilder {
384 pub fn new() -> Self {
386 Self {
387 dataset: CocoDataset::default(),
388 next_image_id: 1,
389 next_annotation_id: 1,
390 next_category_id: 1,
391 }
392 }
393
394 pub fn info(mut self, info: CocoInfo) -> Self {
396 self.dataset.info = info;
397 self
398 }
399
400 pub fn add_category(&mut self, name: &str, supercategory: Option<&str>) -> u32 {
402 for cat in &self.dataset.categories {
404 if cat.name == name {
405 return cat.id;
406 }
407 }
408
409 let id = self.next_category_id;
410 self.next_category_id += 1;
411
412 self.dataset.categories.push(CocoCategory {
413 id,
414 name: name.to_string(),
415 supercategory: supercategory.map(String::from),
416 });
417
418 id
419 }
420
421 pub fn add_image(&mut self, file_name: &str, width: u32, height: u32) -> u64 {
423 let id = self.next_image_id;
424 self.next_image_id += 1;
425
426 self.dataset.images.push(CocoImage {
427 id,
428 width,
429 height,
430 file_name: file_name.to_string(),
431 ..Default::default()
432 });
433
434 id
435 }
436
437 pub fn add_annotation(
439 &mut self,
440 image_id: u64,
441 category_id: u32,
442 bbox: [f64; 4],
443 segmentation: Option<CocoSegmentation>,
444 ) -> u64 {
445 let id = self.next_annotation_id;
446 self.next_annotation_id += 1;
447
448 let area = bbox[2] * bbox[3]; self.dataset.annotations.push(CocoAnnotation {
451 id,
452 image_id,
453 category_id,
454 bbox,
455 area,
456 iscrowd: 0,
457 segmentation,
458 });
459
460 id
461 }
462
463 pub fn build(self) -> CocoDataset {
465 self.dataset
466 }
467}
468
#[cfg(test)]
mod tests {
    //! Round-trip tests: datasets are written into a temporary directory and read back with
    //! `serde_json` / `zip`.
    use super::*;
    use tempfile::TempDir;

    /// Reads the full contents of entry `name` from `archive`, panicking if it is missing.
    fn read_zip_entry(archive: &mut zip::ZipArchive<std::fs::File>, name: &str) -> Vec<u8> {
        use std::io::Read;

        let mut entry = archive.by_name(name).unwrap();
        let mut buf = Vec::new();
        entry.read_to_end(&mut buf).unwrap();
        buf
    }

    /// Three images (train1, train2, val1), one annotation per image, one category.
    fn sample_split_dataset() -> CocoDataset {
        CocoDataset {
            images: vec![
                CocoImage {
                    id: 1,
                    width: 640,
                    height: 480,
                    file_name: "train1.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 2,
                    width: 640,
                    height: 480,
                    file_name: "train2.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 3,
                    width: 800,
                    height: 600,
                    file_name: "val1.jpg".to_string(),
                    ..Default::default()
                },
            ],
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                supercategory: None,
            }],
            annotations: vec![
                CocoAnnotation {
                    id: 1,
                    image_id: 1,
                    category_id: 1,
                    bbox: [10.0, 20.0, 100.0, 80.0],
                    ..Default::default()
                },
                CocoAnnotation {
                    id: 2,
                    image_id: 2,
                    category_id: 1,
                    bbox: [20.0, 30.0, 100.0, 80.0],
                    ..Default::default()
                },
                CocoAnnotation {
                    id: 3,
                    image_id: 3,
                    category_id: 1,
                    bbox: [30.0, 40.0, 100.0, 80.0],
                    ..Default::default()
                },
            ],
            ..Default::default()
        }
    }

    #[test]
    fn test_writer_default() {
        let writer = CocoWriter::new();
        assert!(writer.options.compress);
        assert!(!writer.options.pretty);
    }

    #[test]
    fn test_write_json() {
        let temp_dir = TempDir::new().unwrap();
        let output_path = temp_dir.path().join("test.json");

        let dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                width: 640,
                height: 480,
                file_name: "test.jpg".to_string(),
                ..Default::default()
            }],
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                supercategory: None,
            }],
            annotations: vec![CocoAnnotation {
                id: 1,
                image_id: 1,
                category_id: 1,
                bbox: [10.0, 20.0, 100.0, 80.0],
                area: 8000.0,
                iscrowd: 0,
                segmentation: None,
            }],
            ..Default::default()
        };

        let writer = CocoWriter::new();
        writer.write_json(&dataset, &output_path).unwrap();

        assert!(output_path.exists());

        let contents = std::fs::read_to_string(&output_path).unwrap();
        let restored: CocoDataset = serde_json::from_str(&contents).unwrap();
        assert_eq!(restored.images.len(), 1);
        assert_eq!(restored.annotations.len(), 1);
    }

    #[test]
    fn test_write_json_pretty() {
        let temp_dir = TempDir::new().unwrap();
        let output_path = temp_dir.path().join("test_pretty.json");

        let dataset = CocoDataset::default();

        let writer = CocoWriter::with_options(CocoWriteOptions {
            pretty: true,
            compress: false,
        });
        writer.write_json(&dataset, &output_path).unwrap();

        let contents = std::fs::read_to_string(&output_path).unwrap();
        // Pretty output spans multiple lines; compact output never contains a newline.
        assert!(contents.contains('\n'));
    }

    #[test]
    fn test_dataset_builder() {
        let mut builder = CocoDatasetBuilder::new();

        let person_id = builder.add_category("person", Some("human"));
        let car_id = builder.add_category("car", Some("vehicle"));

        assert_eq!(person_id, 1);
        assert_eq!(car_id, 2);

        // Re-registering an existing name returns the original ID.
        let person_id2 = builder.add_category("person", None);
        assert_eq!(person_id2, 1);

        let img1 = builder.add_image("image1.jpg", 640, 480);
        let img2 = builder.add_image("image2.jpg", 800, 600);

        assert_eq!(img1, 1);
        assert_eq!(img2, 2);

        let ann1 = builder.add_annotation(img1, person_id, [10.0, 20.0, 100.0, 80.0], None);
        let ann2 = builder.add_annotation(img1, car_id, [50.0, 60.0, 150.0, 100.0], None);

        assert_eq!(ann1, 1);
        assert_eq!(ann2, 2);

        let dataset = builder.build();

        assert_eq!(dataset.categories.len(), 2);
        assert_eq!(dataset.images.len(), 2);
        assert_eq!(dataset.annotations.len(), 2);
    }

    #[test]
    fn test_write_zip() {
        let temp_dir = TempDir::new().unwrap();
        let output_path = temp_dir.path().join("test.zip");

        let dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                width: 100,
                height: 100,
                file_name: "test.jpg".to_string(),
                ..Default::default()
            }],
            ..Default::default()
        };

        let images = vec![("images/test.jpg".to_string(), vec![0xFF, 0xD8, 0xFF])];

        let writer = CocoWriter::new();
        writer
            .write_zip(&dataset, images.into_iter(), &output_path)
            .unwrap();

        assert!(output_path.exists());

        let file = std::fs::File::open(&output_path).unwrap();
        let mut archive = zip::ZipArchive::new(file).unwrap();

        assert!(archive.by_name("annotations/instances.json").is_ok());
        assert!(archive.by_name("images/test.jpg").is_ok());
    }

    #[test]
    fn test_write_zip_from_dir() {
        let temp_dir = TempDir::new().unwrap();
        let images_dir = temp_dir.path().join("images");
        std::fs::create_dir_all(&images_dir).unwrap();
        std::fs::write(images_dir.join("present.jpg"), [1u8, 2, 3]).unwrap();
        let output_path = temp_dir.path().join("out.zip");

        let dataset = CocoDataset {
            images: vec![
                CocoImage {
                    id: 1,
                    file_name: "present.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 2,
                    file_name: "missing.jpg".to_string(),
                    ..Default::default()
                },
            ],
            ..Default::default()
        };

        CocoWriter::new()
            .write_zip_from_dir(&dataset, &images_dir, &output_path)
            .unwrap();

        let file = std::fs::File::open(&output_path).unwrap();
        let mut archive = zip::ZipArchive::new(file).unwrap();

        // Annotations + the one readable image; the missing image is skipped.
        assert_eq!(archive.len(), 2);
        assert_eq!(
            read_zip_entry(&mut archive, "images/present.jpg"),
            vec![1u8, 2, 3]
        );
        assert!(archive.by_name("images/missing.jpg").is_err());
    }

    #[test]
    fn test_write_split_by_group() {
        let temp_dir = TempDir::new().unwrap();
        let output_dir = temp_dir.path().join("split_output");

        let dataset = sample_split_dataset();
        let groups = vec!["train".to_string(), "train".to_string(), "val".to_string()];

        let writer = CocoWriter::new();
        let result = writer
            .write_split_by_group(&dataset, &groups, None, &output_dir)
            .unwrap();

        assert_eq!(result.get("train"), Some(&2));
        assert_eq!(result.get("val"), Some(&1));

        assert!(
            output_dir
                .join("train/annotations/instances_train.json")
                .exists()
        );
        assert!(
            output_dir
                .join("val/annotations/instances_val.json")
                .exists()
        );

        let train_json =
            std::fs::read_to_string(output_dir.join("train/annotations/instances_train.json"))
                .unwrap();
        let train_data: CocoDataset = serde_json::from_str(&train_json).unwrap();
        assert_eq!(train_data.images.len(), 2);
        assert_eq!(train_data.annotations.len(), 2);

        let val_json =
            std::fs::read_to_string(output_dir.join("val/annotations/instances_val.json")).unwrap();
        let val_data: CocoDataset = serde_json::from_str(&val_json).unwrap();
        assert_eq!(val_data.images.len(), 1);
        assert_eq!(val_data.annotations.len(), 1);
    }

    #[test]
    fn test_write_split_by_group_zip() {
        let temp_dir = TempDir::new().unwrap();
        let source_dir = temp_dir.path().join("source");
        std::fs::create_dir_all(&source_dir).unwrap();
        std::fs::write(source_dir.join("train1.jpg"), [0xFFu8, 0xD8, 0xFF]).unwrap();
        let output_dir = temp_dir.path().join("zips");

        let dataset = sample_split_dataset();
        let groups = vec!["train".to_string(), "train".to_string(), "val".to_string()];

        let result = CocoWriter::new()
            .write_split_by_group_zip(&dataset, &groups, Some(source_dir.as_path()), &output_dir)
            .unwrap();

        assert_eq!(result.get("train"), Some(&2));
        assert_eq!(result.get("val"), Some(&1));
        assert!(output_dir.join("val.zip").exists());

        let file = std::fs::File::open(output_dir.join("train.zip")).unwrap();
        let mut archive = zip::ZipArchive::new(file).unwrap();

        // Only train1.jpg exists in the source directory; train2.jpg is skipped.
        assert_eq!(
            read_zip_entry(&mut archive, "images/train1.jpg"),
            vec![0xFFu8, 0xD8, 0xFF]
        );
        assert!(archive.by_name("images/train2.jpg").is_err());

        let json = read_zip_entry(&mut archive, "annotations/instances.json");
        let train: CocoDataset = serde_json::from_slice(&json).unwrap();
        assert_eq!(train.images.len(), 2);
        assert_eq!(train.annotations.len(), 2);
    }

    #[test]
    fn test_write_split_by_group_mismatch() {
        let temp_dir = TempDir::new().unwrap();
        let output_dir = temp_dir.path().join("never_created");

        let dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                ..Default::default()
            }],
            ..Default::default()
        };

        // Two assignments for one image.
        let groups = vec!["train".to_string(), "val".to_string()];

        let writer = CocoWriter::new();
        let result = writer.write_split_by_group(&dataset, &groups, None, &output_dir);

        assert!(result.is_err());
        // Validation happens before anything is written.
        assert!(!output_dir.exists());
    }
}