edgefirst_client/coco/
writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4//! Streaming COCO JSON/ZIP writers.
5//!
6//! Provides efficient writing of COCO annotation files to JSON or ZIP archives.
7
8use super::types::{
9    CocoAnnotation, CocoCategory, CocoDataset, CocoImage, CocoInfo, CocoSegmentation,
10};
11use crate::Error;
12use std::{
13    fs::File,
14    io::{BufWriter, Write},
15    path::Path,
16};
17use zip::{CompressionMethod, write::SimpleFileOptions};
18
/// Options for COCO writing.
///
/// The default is `compress: true, pretty: false`: deflate-compressed ZIP
/// entries and compact JSON.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CocoWriteOptions {
    /// Deflate-compress ZIP entries; when `false` entries are stored
    /// uncompressed. Has no effect on plain JSON output.
    pub compress: bool,
    /// Pretty-print JSON with indentation.
    pub pretty: bool,
}

impl Default for CocoWriteOptions {
    fn default() -> Self {
        Self {
            compress: true,
            pretty: false,
        }
    }
}
36
37/// COCO writer for generating JSON and ZIP files.
38///
39/// # Example
40///
41/// ```rust,no_run
42/// use edgefirst_client::coco::{CocoDataset, CocoWriter};
43///
44/// let writer = CocoWriter::new();
45/// let dataset = CocoDataset::default();
46/// writer.write_json(&dataset, "annotations.json")?;
47/// # Ok::<(), edgefirst_client::Error>(())
48/// ```
49pub struct CocoWriter {
50    options: CocoWriteOptions,
51}
52
53impl CocoWriter {
54    /// Create a new COCO writer with default options.
55    pub fn new() -> Self {
56        Self {
57            options: CocoWriteOptions::default(),
58        }
59    }
60
61    /// Create a new COCO writer with custom options.
62    pub fn with_options(options: CocoWriteOptions) -> Self {
63        Self { options }
64    }
65
66    /// Write COCO dataset to a JSON file.
67    ///
68    /// # Arguments
69    /// * `dataset` - The COCO dataset to write
70    /// * `path` - Output file path
71    pub fn write_json<P: AsRef<Path>>(&self, dataset: &CocoDataset, path: P) -> Result<(), Error> {
72        // Ensure parent directory exists
73        if let Some(parent) = path.as_ref().parent()
74            && !parent.as_os_str().is_empty()
75        {
76            std::fs::create_dir_all(parent)?;
77        }
78
79        let file = File::create(path.as_ref())?;
80        let writer = BufWriter::with_capacity(64 * 1024, file);
81
82        if self.options.pretty {
83            serde_json::to_writer_pretty(writer, dataset)?;
84        } else {
85            serde_json::to_writer(writer, dataset)?;
86        }
87
88        Ok(())
89    }
90
91    /// Write COCO dataset to a ZIP file with images.
92    ///
93    /// Creates a ZIP archive with:
94    /// - `annotations/instances.json` - The COCO annotations
95    /// - Images at their original relative paths
96    ///
97    /// # Arguments
98    /// * `dataset` - The COCO dataset to write
99    /// * `images` - Iterator of `(filename, image_data)` pairs
100    /// * `path` - Output ZIP file path
101    pub fn write_zip<P: AsRef<Path>>(
102        &self,
103        dataset: &CocoDataset,
104        images: impl Iterator<Item = (String, Vec<u8>)>,
105        path: P,
106    ) -> Result<(), Error> {
107        // Ensure parent directory exists
108        if let Some(parent) = path.as_ref().parent()
109            && !parent.as_os_str().is_empty()
110        {
111            std::fs::create_dir_all(parent)?;
112        }
113
114        let file = File::create(path.as_ref())?;
115        let mut zip = zip::ZipWriter::new(file);
116
117        let options = if self.options.compress {
118            SimpleFileOptions::default().compression_method(CompressionMethod::Deflated)
119        } else {
120            SimpleFileOptions::default().compression_method(CompressionMethod::Stored)
121        };
122
123        // Write annotations
124        zip.start_file("annotations/instances.json", options)?;
125        let json = if self.options.pretty {
126            serde_json::to_string_pretty(dataset)?
127        } else {
128            serde_json::to_string(dataset)?
129        };
130        zip.write_all(json.as_bytes())?;
131
132        // Write images
133        for (filename, data) in images {
134            zip.start_file(&filename, options)?;
135            zip.write_all(&data)?;
136        }
137
138        zip.finish()?;
139        Ok(())
140    }
141
142    /// Write COCO dataset to a ZIP file with images from a source directory.
143    ///
144    /// # Arguments
145    /// * `dataset` - The COCO dataset to write
146    /// * `images_dir` - Directory containing source images
147    /// * `path` - Output ZIP file path
148    pub fn write_zip_from_dir<P: AsRef<Path>>(
149        &self,
150        dataset: &CocoDataset,
151        images_dir: P,
152        path: P,
153    ) -> Result<(), Error> {
154        let images_dir = images_dir.as_ref();
155
156        // Collect image data
157        let images = dataset.images.iter().filter_map(|img| {
158            let img_path = images_dir.join(&img.file_name);
159            std::fs::read(&img_path)
160                .ok()
161                .map(|data| (format!("images/{}", img.file_name), data))
162        });
163
164        self.write_zip(dataset, images, path)
165    }
166
167    /// Split a dataset by group and write each group to its own directory.
168    ///
169    /// Creates a directory structure like:
170    /// ```text
171    /// output_dir/
172    /// ├── train/
173    /// │   ├── annotations/instances_train.json
174    /// │   └── images/
175    /// │       └── *.jpg
176    /// └── val/
177    ///     ├── annotations/instances_val.json
178    ///     └── images/
179    ///         └── *.jpg
180    /// ```
181    ///
182    /// # Arguments
183    /// * `dataset` - The COCO dataset to split
184    /// * `group_assignments` - Parallel array of group names for each image
185    /// * `images_source` - Optional source directory containing images to copy
186    /// * `output_dir` - Output root directory
187    ///
188    /// # Returns
189    /// HashMap of group name → number of images written
190    pub fn write_split_by_group<P: AsRef<Path>>(
191        &self,
192        dataset: &CocoDataset,
193        group_assignments: &[String],
194        images_source: Option<&Path>,
195        output_dir: P,
196    ) -> Result<std::collections::HashMap<String, usize>, Error> {
197        use std::collections::{HashMap, HashSet};
198
199        let output_dir = output_dir.as_ref();
200
201        // Validate input
202        if dataset.images.len() != group_assignments.len() {
203            return Err(Error::CocoError(format!(
204                "Image count ({}) does not match group assignment count ({})",
205                dataset.images.len(),
206                group_assignments.len()
207            )));
208        }
209
210        // Build groups
211        let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
212        for (idx, group) in group_assignments.iter().enumerate() {
213            groups.entry(group.clone()).or_default().push(idx);
214        }
215
216        let mut result = HashMap::new();
217
218        for (group_name, image_indices) in &groups {
219            // Create subdirectory structure
220            let group_dir = output_dir.join(group_name);
221            let annotations_dir = group_dir.join("annotations");
222            let images_dir = group_dir.join("images");
223
224            std::fs::create_dir_all(&annotations_dir)?;
225            std::fs::create_dir_all(&images_dir)?;
226
227            // Build subset dataset for this group
228            let image_ids: HashSet<u64> = image_indices
229                .iter()
230                .map(|&idx| dataset.images[idx].id)
231                .collect();
232
233            let subset = CocoDataset {
234                info: dataset.info.clone(),
235                licenses: dataset.licenses.clone(),
236                images: image_indices
237                    .iter()
238                    .map(|&idx| dataset.images[idx].clone())
239                    .collect(),
240                annotations: dataset
241                    .annotations
242                    .iter()
243                    .filter(|ann| image_ids.contains(&ann.image_id))
244                    .cloned()
245                    .collect(),
246                categories: dataset.categories.clone(),
247            };
248
249            // Write annotations JSON
250            let ann_file = annotations_dir.join(format!("instances_{}.json", group_name));
251            self.write_json(&subset, &ann_file)?;
252
253            // Copy images if source provided
254            if let Some(source) = images_source {
255                for &idx in image_indices {
256                    let image = &dataset.images[idx];
257                    let src_path = source.join(&image.file_name);
258                    let dst_path = images_dir.join(&image.file_name);
259
260                    if src_path.exists() {
261                        std::fs::copy(&src_path, &dst_path)?;
262                    }
263                }
264            }
265
266            result.insert(group_name.clone(), image_indices.len());
267        }
268
269        Ok(result)
270    }
271
272    /// Split a dataset by group and write each group to its own ZIP archive.
273    ///
274    /// Creates ZIP archives like:
275    /// - `output_dir/train.zip` containing train split
276    /// - `output_dir/val.zip` containing val split
277    ///
278    /// # Arguments
279    /// * `dataset` - The COCO dataset to split
280    /// * `group_assignments` - Parallel array of group names for each image
281    /// * `images_source` - Optional source directory containing images to
282    ///   include
283    /// * `output_dir` - Output directory for ZIP files
284    ///
285    /// # Returns
286    /// HashMap of group name → number of images written
287    pub fn write_split_by_group_zip<P: AsRef<Path>>(
288        &self,
289        dataset: &CocoDataset,
290        group_assignments: &[String],
291        images_source: Option<&Path>,
292        output_dir: P,
293    ) -> Result<std::collections::HashMap<String, usize>, Error> {
294        use std::collections::{HashMap, HashSet};
295
296        let output_dir = output_dir.as_ref();
297        std::fs::create_dir_all(output_dir)?;
298
299        // Validate input
300        if dataset.images.len() != group_assignments.len() {
301            return Err(Error::CocoError(format!(
302                "Image count ({}) does not match group assignment count ({})",
303                dataset.images.len(),
304                group_assignments.len()
305            )));
306        }
307
308        // Build groups
309        let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
310        for (idx, group) in group_assignments.iter().enumerate() {
311            groups.entry(group.clone()).or_default().push(idx);
312        }
313
314        let mut result = HashMap::new();
315
316        for (group_name, image_indices) in &groups {
317            // Build subset dataset for this group
318            let image_ids: HashSet<u64> = image_indices
319                .iter()
320                .map(|&idx| dataset.images[idx].id)
321                .collect();
322
323            let subset = CocoDataset {
324                info: dataset.info.clone(),
325                licenses: dataset.licenses.clone(),
326                images: image_indices
327                    .iter()
328                    .map(|&idx| dataset.images[idx].clone())
329                    .collect(),
330                annotations: dataset
331                    .annotations
332                    .iter()
333                    .filter(|ann| image_ids.contains(&ann.image_id))
334                    .cloned()
335                    .collect(),
336                categories: dataset.categories.clone(),
337            };
338
339            // Collect images if source provided
340            let images: Vec<(String, Vec<u8>)> = if let Some(source) = images_source {
341                image_indices
342                    .iter()
343                    .filter_map(|&idx| {
344                        let image = &dataset.images[idx];
345                        let src_path = source.join(&image.file_name);
346                        std::fs::read(&src_path)
347                            .ok()
348                            .map(|data| (format!("images/{}", image.file_name), data))
349                    })
350                    .collect()
351            } else {
352                vec![]
353            };
354
355            // Write ZIP
356            let zip_path = output_dir.join(format!("{}.zip", group_name));
357            self.write_zip(&subset, images.into_iter(), &zip_path)?;
358
359            result.insert(group_name.clone(), image_indices.len());
360        }
361
362        Ok(result)
363    }
364}
365
366impl Default for CocoWriter {
367    fn default() -> Self {
368        Self::new()
369    }
370}
371
372/// Builder for constructing a COCO dataset.
373///
374/// Provides a convenient API for incrementally building a COCO dataset.
375#[derive(Debug, Default)]
376pub struct CocoDatasetBuilder {
377    dataset: CocoDataset,
378    next_image_id: u64,
379    next_annotation_id: u64,
380    next_category_id: u32,
381}
382
383impl CocoDatasetBuilder {
384    /// Create a new dataset builder.
385    pub fn new() -> Self {
386        Self {
387            dataset: CocoDataset::default(),
388            next_image_id: 1,
389            next_annotation_id: 1,
390            next_category_id: 1,
391        }
392    }
393
394    /// Set dataset info.
395    pub fn info(mut self, info: CocoInfo) -> Self {
396        self.dataset.info = info;
397        self
398    }
399
400    /// Add a category, returning its ID.
401    pub fn add_category(&mut self, name: &str, supercategory: Option<&str>) -> u32 {
402        // Check if category already exists
403        for cat in &self.dataset.categories {
404            if cat.name == name {
405                return cat.id;
406            }
407        }
408
409        let id = self.next_category_id;
410        self.next_category_id += 1;
411
412        self.dataset.categories.push(CocoCategory {
413            id,
414            name: name.to_string(),
415            supercategory: supercategory.map(String::from),
416        });
417
418        id
419    }
420
421    /// Add an image, returning its ID.
422    pub fn add_image(&mut self, file_name: &str, width: u32, height: u32) -> u64 {
423        let id = self.next_image_id;
424        self.next_image_id += 1;
425
426        self.dataset.images.push(CocoImage {
427            id,
428            width,
429            height,
430            file_name: file_name.to_string(),
431            ..Default::default()
432        });
433
434        id
435    }
436
437    /// Add an annotation, returning its ID.
438    pub fn add_annotation(
439        &mut self,
440        image_id: u64,
441        category_id: u32,
442        bbox: [f64; 4],
443        segmentation: Option<CocoSegmentation>,
444    ) -> u64 {
445        let id = self.next_annotation_id;
446        self.next_annotation_id += 1;
447
448        let area = bbox[2] * bbox[3]; // Default area from bbox
449
450        self.dataset.annotations.push(CocoAnnotation {
451            id,
452            image_id,
453            category_id,
454            bbox,
455            area,
456            iscrowd: 0,
457            segmentation,
458        });
459
460        id
461    }
462
463    /// Build the final dataset.
464    pub fn build(self) -> CocoDataset {
465        self.dataset
466    }
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_writer_default() {
        let writer = CocoWriter::new();
        assert!(writer.options.compress);
        assert!(!writer.options.pretty);
    }

    #[test]
    fn test_write_json() {
        let temp_dir = TempDir::new().unwrap();
        let output_path = temp_dir.path().join("test.json");

        let dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                width: 640,
                height: 480,
                file_name: "test.jpg".to_string(),
                ..Default::default()
            }],
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                supercategory: None,
            }],
            annotations: vec![CocoAnnotation {
                id: 1,
                image_id: 1,
                category_id: 1,
                bbox: [10.0, 20.0, 100.0, 80.0],
                area: 8000.0,
                iscrowd: 0,
                segmentation: None,
            }],
            ..Default::default()
        };

        let writer = CocoWriter::new();
        writer.write_json(&dataset, &output_path).unwrap();

        // Verify file was created
        assert!(output_path.exists());

        // Read it back and verify
        let contents = std::fs::read_to_string(&output_path).unwrap();
        let restored: CocoDataset = serde_json::from_str(&contents).unwrap();
        assert_eq!(restored.images.len(), 1);
        assert_eq!(restored.annotations.len(), 1);
        assert_eq!(restored.categories.len(), 1);
    }

    #[test]
    fn test_write_json_pretty() {
        let temp_dir = TempDir::new().unwrap();
        let output_path = temp_dir.path().join("test_pretty.json");

        let dataset = CocoDataset::default();

        let writer = CocoWriter::with_options(CocoWriteOptions {
            pretty: true,
            compress: false,
        });
        writer.write_json(&dataset, &output_path).unwrap();

        let contents = std::fs::read_to_string(&output_path).unwrap();
        assert!(contents.contains('\n')); // Pretty-printed should have newlines
    }

    #[test]
    fn test_dataset_builder() {
        let mut builder = CocoDatasetBuilder::new();

        // Add categories
        let person_id = builder.add_category("person", Some("human"));
        let car_id = builder.add_category("car", Some("vehicle"));

        assert_eq!(person_id, 1);
        assert_eq!(car_id, 2);

        // Adding same category returns existing ID
        let person_id2 = builder.add_category("person", None);
        assert_eq!(person_id2, 1);

        // Add images
        let img1 = builder.add_image("image1.jpg", 640, 480);
        let img2 = builder.add_image("image2.jpg", 800, 600);

        assert_eq!(img1, 1);
        assert_eq!(img2, 2);

        // Add annotations
        let ann1 = builder.add_annotation(img1, person_id, [10.0, 20.0, 100.0, 80.0], None);
        let ann2 = builder.add_annotation(img1, car_id, [50.0, 60.0, 150.0, 100.0], None);

        assert_eq!(ann1, 1);
        assert_eq!(ann2, 2);

        // Build final dataset
        let dataset = builder.build();

        assert_eq!(dataset.categories.len(), 2);
        assert_eq!(dataset.images.len(), 2);
        assert_eq!(dataset.annotations.len(), 2);
    }

    #[test]
    fn test_write_zip() {
        let temp_dir = TempDir::new().unwrap();
        let output_path = temp_dir.path().join("test.zip");

        let dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                width: 100,
                height: 100,
                file_name: "test.jpg".to_string(),
                ..Default::default()
            }],
            ..Default::default()
        };

        // Create a fake image
        let images = vec![("images/test.jpg".to_string(), vec![0xFF, 0xD8, 0xFF])];

        let writer = CocoWriter::new();
        writer
            .write_zip(&dataset, images.into_iter(), &output_path)
            .unwrap();

        // Verify ZIP was created
        assert!(output_path.exists());

        // Verify contents
        let file = std::fs::File::open(&output_path).unwrap();
        let mut archive = zip::ZipArchive::new(file).unwrap();

        // Should contain annotations and image
        assert!(archive.by_name("annotations/instances.json").is_ok());
        assert!(archive.by_name("images/test.jpg").is_ok());
    }

    #[test]
    fn test_write_split_by_group() {
        let temp_dir = TempDir::new().unwrap();
        let output_dir = temp_dir.path().join("split_output");

        let dataset = CocoDataset {
            images: vec![
                CocoImage {
                    id: 1,
                    width: 640,
                    height: 480,
                    file_name: "train1.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 2,
                    width: 640,
                    height: 480,
                    file_name: "train2.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 3,
                    width: 800,
                    height: 600,
                    file_name: "val1.jpg".to_string(),
                    ..Default::default()
                },
            ],
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                supercategory: None,
            }],
            annotations: vec![
                CocoAnnotation {
                    id: 1,
                    image_id: 1,
                    category_id: 1,
                    bbox: [10.0, 20.0, 100.0, 80.0],
                    ..Default::default()
                },
                CocoAnnotation {
                    id: 2,
                    image_id: 2,
                    category_id: 1,
                    bbox: [20.0, 30.0, 100.0, 80.0],
                    ..Default::default()
                },
                CocoAnnotation {
                    id: 3,
                    image_id: 3,
                    category_id: 1,
                    bbox: [30.0, 40.0, 100.0, 80.0],
                    ..Default::default()
                },
            ],
            ..Default::default()
        };

        let groups = vec!["train".to_string(), "train".to_string(), "val".to_string()];

        let writer = CocoWriter::new();
        let result = writer
            .write_split_by_group(&dataset, &groups, None, &output_dir)
            .unwrap();

        // Verify counts
        assert_eq!(result.get("train"), Some(&2));
        assert_eq!(result.get("val"), Some(&1));

        // Verify directory structure
        assert!(
            output_dir
                .join("train/annotations/instances_train.json")
                .exists()
        );
        assert!(
            output_dir
                .join("val/annotations/instances_val.json")
                .exists()
        );

        // Verify train JSON content
        let train_json =
            std::fs::read_to_string(output_dir.join("train/annotations/instances_train.json"))
                .unwrap();
        let train_data: CocoDataset = serde_json::from_str(&train_json).unwrap();
        assert_eq!(train_data.images.len(), 2);
        assert_eq!(train_data.annotations.len(), 2);

        // Verify val JSON content
        let val_json =
            std::fs::read_to_string(output_dir.join("val/annotations/instances_val.json")).unwrap();
        let val_data: CocoDataset = serde_json::from_str(&val_json).unwrap();
        assert_eq!(val_data.images.len(), 1);
        assert_eq!(val_data.annotations.len(), 1);
    }

    #[test]
    fn test_write_split_by_group_copies_available_images() {
        let temp_dir = TempDir::new().unwrap();
        let source_dir = temp_dir.path().join("source");
        let output_dir = temp_dir.path().join("split_output");
        std::fs::create_dir_all(&source_dir).unwrap();
        std::fs::write(source_dir.join("present.jpg"), [0xFF_u8, 0xD8, 0xFF]).unwrap();

        let dataset = CocoDataset {
            images: vec![
                CocoImage {
                    id: 1,
                    file_name: "present.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 2,
                    file_name: "missing.jpg".to_string(),
                    ..Default::default()
                },
            ],
            ..Default::default()
        };
        let groups = vec!["train".to_string(), "train".to_string()];

        let result = CocoWriter::new()
            .write_split_by_group(&dataset, &groups, Some(source_dir.as_path()), &output_dir)
            .unwrap();

        // Both images are counted even though only one exists in the source.
        assert_eq!(result.get("train"), Some(&2));
        assert_eq!(
            std::fs::read(output_dir.join("train/images/present.jpg")).unwrap(),
            vec![0xFF_u8, 0xD8, 0xFF]
        );
        assert!(!output_dir.join("train/images/missing.jpg").exists());
    }

    #[test]
    fn test_write_split_by_group_zip() {
        use std::io::Read;

        let temp_dir = TempDir::new().unwrap();
        let output_dir = temp_dir.path().join("zips");

        let dataset = CocoDataset {
            images: vec![
                CocoImage {
                    id: 1,
                    file_name: "a.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 2,
                    file_name: "b.jpg".to_string(),
                    ..Default::default()
                },
            ],
            annotations: vec![CocoAnnotation {
                id: 1,
                image_id: 2,
                category_id: 1,
                ..Default::default()
            }],
            ..Default::default()
        };
        let groups = vec!["train".to_string(), "val".to_string()];

        let result = CocoWriter::new()
            .write_split_by_group_zip(&dataset, &groups, None, &output_dir)
            .unwrap();

        assert_eq!(result.get("train"), Some(&1));
        assert_eq!(result.get("val"), Some(&1));
        assert!(output_dir.join("train.zip").exists());

        let file = std::fs::File::open(output_dir.join("val.zip")).unwrap();
        let mut archive = zip::ZipArchive::new(file).unwrap();
        let mut json = String::new();
        let mut entry = archive.by_name("annotations/instances.json").unwrap();
        entry.read_to_string(&mut json).unwrap();

        let val: CocoDataset = serde_json::from_str(&json).unwrap();
        assert_eq!(val.images.len(), 1);
        assert_eq!(val.annotations.len(), 1);
    }

    #[test]
    fn test_write_split_by_group_mismatch() {
        let temp_dir = TempDir::new().unwrap();
        let output_dir = temp_dir.path().join("split_output");

        let dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                ..Default::default()
            }],
            ..Default::default()
        };

        // Wrong number of group assignments
        let groups = vec!["train".to_string(), "val".to_string()];

        let writer = CocoWriter::new();
        let result = writer.write_split_by_group(&dataset, &groups, None, &output_dir);

        assert!(result.is_err());
        // Validation must fail before anything is written.
        assert!(!output_dir.exists());
    }
}