Skip to main content

edgefirst_client/coco/
writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4//! Streaming COCO JSON/ZIP writers.
5//!
6//! Provides efficient writing of COCO annotation files to JSON or ZIP archives.
7
8use super::types::{
9    CocoAnnotation, CocoCategory, CocoDataset, CocoImage, CocoInfo, CocoSegmentation,
10};
11use crate::Error;
12use std::{
13    fs::File,
14    io::{BufWriter, Write},
15    path::Path,
16};
17use zip::{CompressionMethod, write::SimpleFileOptions};
18
/// Options controlling how a COCO writer serializes its output.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CocoWriteOptions {
    /// Deflate-compress ZIP entries; when `false`, entries are stored
    /// uncompressed. Has no effect on plain JSON output.
    pub compress: bool,
    /// Pretty-print JSON with indentation (applies to both plain JSON files
    /// and the annotations entry inside ZIP archives).
    pub pretty: bool,
}
27
28impl Default for CocoWriteOptions {
29    fn default() -> Self {
30        Self {
31            compress: true,
32            pretty: false,
33        }
34    }
35}
36
37/// COCO writer for generating JSON and ZIP files.
38///
39/// # Example
40///
41/// ```rust,no_run
42/// use edgefirst_client::coco::{CocoDataset, CocoWriter};
43///
44/// let writer = CocoWriter::new();
45/// let dataset = CocoDataset::default();
46/// writer.write_json(&dataset, "annotations.json")?;
47/// # Ok::<(), edgefirst_client::Error>(())
48/// ```
49pub struct CocoWriter {
50    options: CocoWriteOptions,
51}
52
53impl CocoWriter {
54    /// Create a new COCO writer with default options.
55    pub fn new() -> Self {
56        Self {
57            options: CocoWriteOptions::default(),
58        }
59    }
60
61    /// Create a new COCO writer with custom options.
62    pub fn with_options(options: CocoWriteOptions) -> Self {
63        Self { options }
64    }
65
66    /// Write COCO dataset to a JSON file.
67    ///
68    /// # Arguments
69    /// * `dataset` - The COCO dataset to write
70    /// * `path` - Output file path
71    pub fn write_json<P: AsRef<Path>>(&self, dataset: &CocoDataset, path: P) -> Result<(), Error> {
72        // Ensure parent directory exists
73        if let Some(parent) = path.as_ref().parent()
74            && !parent.as_os_str().is_empty()
75        {
76            std::fs::create_dir_all(parent)?;
77        }
78
79        let file = File::create(path.as_ref())?;
80        let writer = BufWriter::with_capacity(64 * 1024, file);
81
82        if self.options.pretty {
83            serde_json::to_writer_pretty(writer, dataset)?;
84        } else {
85            serde_json::to_writer(writer, dataset)?;
86        }
87
88        Ok(())
89    }
90
91    /// Write COCO dataset to a ZIP file with images.
92    ///
93    /// Creates a ZIP archive with:
94    /// - `annotations/instances.json` - The COCO annotations
95    /// - Images at their original relative paths
96    ///
97    /// # Arguments
98    /// * `dataset` - The COCO dataset to write
99    /// * `images` - Iterator of `(filename, image_data)` pairs
100    /// * `path` - Output ZIP file path
101    pub fn write_zip<P: AsRef<Path>>(
102        &self,
103        dataset: &CocoDataset,
104        images: impl Iterator<Item = (String, Vec<u8>)>,
105        path: P,
106    ) -> Result<(), Error> {
107        // Ensure parent directory exists
108        if let Some(parent) = path.as_ref().parent()
109            && !parent.as_os_str().is_empty()
110        {
111            std::fs::create_dir_all(parent)?;
112        }
113
114        let file = File::create(path.as_ref())?;
115        let mut zip = zip::ZipWriter::new(file);
116
117        let options = if self.options.compress {
118            SimpleFileOptions::default().compression_method(CompressionMethod::Deflated)
119        } else {
120            SimpleFileOptions::default().compression_method(CompressionMethod::Stored)
121        };
122
123        // Write annotations
124        zip.start_file("annotations/instances.json", options)?;
125        let json = if self.options.pretty {
126            serde_json::to_string_pretty(dataset)?
127        } else {
128            serde_json::to_string(dataset)?
129        };
130        zip.write_all(json.as_bytes())?;
131
132        // Write images
133        for (filename, data) in images {
134            zip.start_file(&filename, options)?;
135            zip.write_all(&data)?;
136        }
137
138        zip.finish()?;
139        Ok(())
140    }
141
142    /// Write COCO dataset to a ZIP file with images from a source directory.
143    ///
144    /// # Arguments
145    /// * `dataset` - The COCO dataset to write
146    /// * `images_dir` - Directory containing source images
147    /// * `path` - Output ZIP file path
148    pub fn write_zip_from_dir<P: AsRef<Path>>(
149        &self,
150        dataset: &CocoDataset,
151        images_dir: P,
152        path: P,
153    ) -> Result<(), Error> {
154        let images_dir = images_dir.as_ref();
155
156        // Collect image data
157        let images = dataset.images.iter().filter_map(|img| {
158            let img_path = images_dir.join(&img.file_name);
159            std::fs::read(&img_path)
160                .ok()
161                .map(|data| (format!("images/{}", img.file_name), data))
162        });
163
164        self.write_zip(dataset, images, path)
165    }
166
167    /// Split a dataset by group and write each group to its own directory.
168    ///
169    /// Creates a directory structure like:
170    /// ```text
171    /// output_dir/
172    /// ├── train/
173    /// │   ├── annotations/instances_train.json
174    /// │   └── images/
175    /// │       └── *.jpg
176    /// └── val/
177    ///     ├── annotations/instances_val.json
178    ///     └── images/
179    ///         └── *.jpg
180    /// ```
181    ///
182    /// # Arguments
183    /// * `dataset` - The COCO dataset to split
184    /// * `group_assignments` - Parallel array of group names for each image
185    /// * `images_source` - Optional source directory containing images to copy
186    /// * `output_dir` - Output root directory
187    ///
188    /// # Returns
189    /// HashMap of group name → number of images written
190    pub fn write_split_by_group<P: AsRef<Path>>(
191        &self,
192        dataset: &CocoDataset,
193        group_assignments: &[String],
194        images_source: Option<&Path>,
195        output_dir: P,
196    ) -> Result<std::collections::HashMap<String, usize>, Error> {
197        use std::collections::{HashMap, HashSet};
198
199        let output_dir = output_dir.as_ref();
200
201        // Validate input
202        if dataset.images.len() != group_assignments.len() {
203            return Err(Error::CocoError(format!(
204                "Image count ({}) does not match group assignment count ({})",
205                dataset.images.len(),
206                group_assignments.len()
207            )));
208        }
209
210        // Build groups
211        let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
212        for (idx, group) in group_assignments.iter().enumerate() {
213            groups.entry(group.clone()).or_default().push(idx);
214        }
215
216        let mut result = HashMap::new();
217
218        for (group_name, image_indices) in &groups {
219            // Create subdirectory structure
220            let group_dir = output_dir.join(group_name);
221            let annotations_dir = group_dir.join("annotations");
222            let images_dir = group_dir.join("images");
223
224            std::fs::create_dir_all(&annotations_dir)?;
225            std::fs::create_dir_all(&images_dir)?;
226
227            // Build subset dataset for this group
228            let image_ids: HashSet<u64> = image_indices
229                .iter()
230                .map(|&idx| dataset.images[idx].id)
231                .collect();
232
233            let subset = CocoDataset {
234                info: dataset.info.clone(),
235                licenses: dataset.licenses.clone(),
236                images: image_indices
237                    .iter()
238                    .map(|&idx| dataset.images[idx].clone())
239                    .collect(),
240                annotations: dataset
241                    .annotations
242                    .iter()
243                    .filter(|ann| image_ids.contains(&ann.image_id))
244                    .cloned()
245                    .collect(),
246                categories: dataset.categories.clone(),
247            };
248
249            // Write annotations JSON
250            let ann_file = annotations_dir.join(format!("instances_{}.json", group_name));
251            self.write_json(&subset, &ann_file)?;
252
253            // Copy images if source provided
254            if let Some(source) = images_source {
255                for &idx in image_indices {
256                    let image = &dataset.images[idx];
257                    let src_path = source.join(&image.file_name);
258                    let dst_path = images_dir.join(&image.file_name);
259
260                    if src_path.exists() {
261                        std::fs::copy(&src_path, &dst_path)?;
262                    }
263                }
264            }
265
266            result.insert(group_name.clone(), image_indices.len());
267        }
268
269        Ok(result)
270    }
271
272    /// Split a dataset by group and write each group to its own ZIP archive.
273    ///
274    /// Creates ZIP archives like:
275    /// - `output_dir/train.zip` containing train split
276    /// - `output_dir/val.zip` containing val split
277    ///
278    /// # Arguments
279    /// * `dataset` - The COCO dataset to split
280    /// * `group_assignments` - Parallel array of group names for each image
281    /// * `images_source` - Optional source directory containing images to
282    ///   include
283    /// * `output_dir` - Output directory for ZIP files
284    ///
285    /// # Returns
286    /// HashMap of group name → number of images written
287    pub fn write_split_by_group_zip<P: AsRef<Path>>(
288        &self,
289        dataset: &CocoDataset,
290        group_assignments: &[String],
291        images_source: Option<&Path>,
292        output_dir: P,
293    ) -> Result<std::collections::HashMap<String, usize>, Error> {
294        use std::collections::{HashMap, HashSet};
295
296        let output_dir = output_dir.as_ref();
297        std::fs::create_dir_all(output_dir)?;
298
299        // Validate input
300        if dataset.images.len() != group_assignments.len() {
301            return Err(Error::CocoError(format!(
302                "Image count ({}) does not match group assignment count ({})",
303                dataset.images.len(),
304                group_assignments.len()
305            )));
306        }
307
308        // Build groups
309        let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
310        for (idx, group) in group_assignments.iter().enumerate() {
311            groups.entry(group.clone()).or_default().push(idx);
312        }
313
314        let mut result = HashMap::new();
315
316        for (group_name, image_indices) in &groups {
317            // Build subset dataset for this group
318            let image_ids: HashSet<u64> = image_indices
319                .iter()
320                .map(|&idx| dataset.images[idx].id)
321                .collect();
322
323            let subset = CocoDataset {
324                info: dataset.info.clone(),
325                licenses: dataset.licenses.clone(),
326                images: image_indices
327                    .iter()
328                    .map(|&idx| dataset.images[idx].clone())
329                    .collect(),
330                annotations: dataset
331                    .annotations
332                    .iter()
333                    .filter(|ann| image_ids.contains(&ann.image_id))
334                    .cloned()
335                    .collect(),
336                categories: dataset.categories.clone(),
337            };
338
339            // Collect images if source provided
340            let images: Vec<(String, Vec<u8>)> = if let Some(source) = images_source {
341                image_indices
342                    .iter()
343                    .filter_map(|&idx| {
344                        let image = &dataset.images[idx];
345                        let src_path = source.join(&image.file_name);
346                        std::fs::read(&src_path)
347                            .ok()
348                            .map(|data| (format!("images/{}", image.file_name), data))
349                    })
350                    .collect()
351            } else {
352                vec![]
353            };
354
355            // Write ZIP
356            let zip_path = output_dir.join(format!("{}.zip", group_name));
357            self.write_zip(&subset, images.into_iter(), &zip_path)?;
358
359            result.insert(group_name.clone(), image_indices.len());
360        }
361
362        Ok(result)
363    }
364}
365
366impl Default for CocoWriter {
367    fn default() -> Self {
368        Self::new()
369    }
370}
371
372/// Builder for constructing a COCO dataset.
373///
374/// Provides a convenient API for incrementally building a COCO dataset.
375#[derive(Debug, Default)]
376pub struct CocoDatasetBuilder {
377    dataset: CocoDataset,
378    next_image_id: u64,
379    next_annotation_id: u64,
380    next_category_id: u32,
381}
382
383impl CocoDatasetBuilder {
384    /// Create a new dataset builder.
385    pub fn new() -> Self {
386        Self {
387            dataset: CocoDataset::default(),
388            next_image_id: 1,
389            next_annotation_id: 1,
390            next_category_id: 1,
391        }
392    }
393
394    /// Set dataset info.
395    pub fn info(mut self, info: CocoInfo) -> Self {
396        self.dataset.info = info;
397        self
398    }
399
400    /// Add a category, returning its ID.
401    pub fn add_category(&mut self, name: &str, supercategory: Option<&str>) -> u32 {
402        // Check if category already exists
403        for cat in &self.dataset.categories {
404            if cat.name == name {
405                return cat.id;
406            }
407        }
408
409        let id = self.next_category_id;
410        self.next_category_id += 1;
411
412        self.dataset.categories.push(CocoCategory {
413            id,
414            name: name.to_string(),
415            supercategory: supercategory.map(String::from),
416            ..Default::default()
417        });
418
419        id
420    }
421
422    /// Add a category with a specific ID, returning its ID.
423    ///
424    /// Used when round-tripping datasets that have explicit category IDs
425    /// (e.g., COCO, LVIS) to preserve the original numbering.
426    pub fn add_category_with_id(
427        &mut self,
428        id: u32,
429        name: &str,
430        supercategory: Option<&str>,
431    ) -> u32 {
432        // Check if category already exists (by id or name)
433        for cat in &self.dataset.categories {
434            if cat.id == id || cat.name == name {
435                return cat.id;
436            }
437        }
438
439        self.dataset.categories.push(CocoCategory {
440            id,
441            name: name.to_string(),
442            supercategory: supercategory.map(String::from),
443            ..Default::default()
444        });
445
446        // Keep next_category_id above all known IDs
447        if id >= self.next_category_id {
448            self.next_category_id = id + 1;
449        }
450
451        id
452    }
453
454    /// Add an image, returning its ID.
455    pub fn add_image(&mut self, file_name: &str, width: u32, height: u32) -> u64 {
456        let id = self.next_image_id;
457        self.next_image_id += 1;
458
459        self.dataset.images.push(CocoImage {
460            id,
461            width,
462            height,
463            file_name: file_name.to_string(),
464            ..Default::default()
465        });
466
467        id
468    }
469
470    /// Add an annotation, returning its ID.
471    pub fn add_annotation(
472        &mut self,
473        image_id: u64,
474        category_id: u32,
475        bbox: [f64; 4],
476        segmentation: Option<CocoSegmentation>,
477    ) -> u64 {
478        self.add_annotation_with_iscrowd(image_id, category_id, bbox, segmentation, 0)
479    }
480
481    /// Add an annotation with an explicit iscrowd flag, returning its ID.
482    pub fn add_annotation_with_iscrowd(
483        &mut self,
484        image_id: u64,
485        category_id: u32,
486        bbox: [f64; 4],
487        segmentation: Option<CocoSegmentation>,
488        iscrowd: u8,
489    ) -> u64 {
490        self.add_annotation_with_id(None, image_id, category_id, bbox, segmentation, iscrowd)
491    }
492
493    /// Add an annotation with an optional explicit ID and iscrowd flag,
494    /// returning its ID.
495    ///
496    /// When `id` is `Some(explicit)`, the caller-supplied ID is used directly
497    /// and `next_annotation_id` is bumped past it so that subsequent
498    /// auto-generated IDs do not collide. When `id` is `None`, an
499    /// auto-incremented ID is assigned (the standard path).
500    ///
501    /// Used to round-trip COCO/LVIS datasets through Arrow while preserving
502    /// the original annotation `id` for downstream tools that rely on a
503    /// stable per-instance identifier (most notably prompted-segmentation
504    /// workflows where the ID links a predicted mask back to the
505    /// ground-truth instance that prompted it).
506    ///
507    /// Caller is responsible for ensuring uniqueness across explicit IDs
508    /// (mirrors [`add_category_with_id`](Self::add_category_with_id)
509    /// semantics).
510    pub fn add_annotation_with_id(
511        &mut self,
512        id: Option<u64>,
513        image_id: u64,
514        category_id: u32,
515        bbox: [f64; 4],
516        segmentation: Option<CocoSegmentation>,
517        iscrowd: u8,
518    ) -> u64 {
519        // Use saturating_add when advancing the counter so an explicit ID
520        // at u64::MAX (or the auto path approaching it) cannot wrap to 0
521        // and silently collide with the small IDs that the auto path
522        // started at. Saturating produces a duplicate ID at exhaustion,
523        // which is bad but strictly better than the alternative.
524        let id = match id {
525            Some(explicit) => {
526                if explicit >= self.next_annotation_id {
527                    self.next_annotation_id = explicit.saturating_add(1);
528                }
529                explicit
530            }
531            None => {
532                let auto = self.next_annotation_id;
533                self.next_annotation_id = self.next_annotation_id.saturating_add(1);
534                auto
535            }
536        };
537
538        let area = bbox[2] * bbox[3]; // Default area from bbox
539
540        self.dataset.annotations.push(CocoAnnotation {
541            id,
542            image_id,
543            category_id,
544            bbox,
545            area,
546            iscrowd,
547            segmentation,
548            score: None,
549        });
550
551        id
552    }
553
554    /// Set the score on an annotation by ID.
555    pub fn set_annotation_score(&mut self, annotation_id: u64, score: f64) {
556        if let Some(ann) = self
557            .dataset
558            .annotations
559            .iter_mut()
560            .find(|a| a.id == annotation_id)
561        {
562            ann.score = Some(score);
563        }
564    }
565
566    /// Set LVIS annotation metadata on an image.
567    pub fn set_image_neg_categories(
568        &mut self,
569        image_id: u64,
570        neg_category_ids: Option<Vec<u32>>,
571        not_exhaustive_category_ids: Option<Vec<u32>>,
572    ) {
573        if let Some(img) = self.dataset.images.iter_mut().find(|i| i.id == image_id) {
574            img.neg_category_ids = neg_category_ids;
575            img.not_exhaustive_category_ids = not_exhaustive_category_ids;
576        }
577    }
578
579    /// Set LVIS metadata on a category by name.
580    ///
581    /// Only updates fields that are `Some`; leaves existing values intact
582    /// for fields passed as `None`.
583    pub fn set_category_metadata(
584        &mut self,
585        name: &str,
586        synset: Option<String>,
587        frequency: Option<String>,
588        synonyms: Option<Vec<String>>,
589        def: Option<String>,
590    ) {
591        if let Some(cat) = self.dataset.categories.iter_mut().find(|c| c.name == name) {
592            if synset.is_some() {
593                cat.synset = synset;
594            }
595            if frequency.is_some() {
596                cat.frequency = frequency;
597            }
598            if synonyms.is_some() {
599                cat.synonyms = synonyms;
600            }
601            if def.is_some() {
602                cat.def = def;
603            }
604        }
605    }
606
607    /// Set the supercategory on a category by name.
608    pub fn set_category_supercategory(&mut self, name: &str, supercategory: &str) {
609        if let Some(cat) = self.dataset.categories.iter_mut().find(|c| c.name == name) {
610            cat.supercategory = Some(supercategory.to_string());
611        }
612    }
613
614    /// Build the final dataset.
615    pub fn build(self) -> CocoDataset {
616        self.dataset
617    }
618}
619
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_writer_default() {
        let writer = CocoWriter::new();
        assert!(writer.options.compress);
        assert!(!writer.options.pretty);
    }

    #[test]
    fn test_write_json() {
        let temp_dir = TempDir::new().unwrap();
        let output_path = temp_dir.path().join("test.json");

        let dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                width: 640,
                height: 480,
                file_name: "test.jpg".to_string(),
                ..Default::default()
            }],
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                supercategory: None,
                ..Default::default()
            }],
            annotations: vec![CocoAnnotation {
                id: 1,
                image_id: 1,
                category_id: 1,
                bbox: [10.0, 20.0, 100.0, 80.0],
                area: 8000.0,
                iscrowd: 0,
                segmentation: None,
                score: None,
            }],
            ..Default::default()
        };

        let writer = CocoWriter::new();
        writer.write_json(&dataset, &output_path).unwrap();

        // Verify file was created
        assert!(output_path.exists());

        // Read it back and verify
        let contents = std::fs::read_to_string(&output_path).unwrap();
        let restored: CocoDataset = serde_json::from_str(&contents).unwrap();
        assert_eq!(restored.images.len(), 1);
        assert_eq!(restored.annotations.len(), 1);
    }

    #[test]
    fn test_write_json_creates_parent_dirs() {
        let temp_dir = TempDir::new().unwrap();
        let output_path = temp_dir.path().join("nested/deeper/test.json");

        CocoWriter::new()
            .write_json(&CocoDataset::default(), &output_path)
            .unwrap();

        let contents = std::fs::read_to_string(&output_path).unwrap();
        let restored: CocoDataset = serde_json::from_str(&contents).unwrap();
        assert!(restored.images.is_empty());
    }

    #[test]
    fn test_write_json_pretty() {
        let temp_dir = TempDir::new().unwrap();
        let output_path = temp_dir.path().join("test_pretty.json");

        let dataset = CocoDataset::default();

        let writer = CocoWriter::with_options(CocoWriteOptions {
            pretty: true,
            compress: false,
        });
        writer.write_json(&dataset, &output_path).unwrap();

        let contents = std::fs::read_to_string(&output_path).unwrap();
        assert!(contents.contains('\n')); // Pretty-printed should have newlines
    }

    #[test]
    fn test_dataset_builder() {
        let mut builder = CocoDatasetBuilder::new();

        // Add categories
        let person_id = builder.add_category("person", Some("human"));
        let car_id = builder.add_category("car", Some("vehicle"));

        assert_eq!(person_id, 1);
        assert_eq!(car_id, 2);

        // Adding same category returns existing ID
        let person_id2 = builder.add_category("person", None);
        assert_eq!(person_id2, 1);

        // Add images
        let img1 = builder.add_image("image1.jpg", 640, 480);
        let img2 = builder.add_image("image2.jpg", 800, 600);

        assert_eq!(img1, 1);
        assert_eq!(img2, 2);

        // Add annotations
        let ann1 = builder.add_annotation(img1, person_id, [10.0, 20.0, 100.0, 80.0], None);
        let ann2 = builder.add_annotation(img1, car_id, [50.0, 60.0, 150.0, 100.0], None);

        assert_eq!(ann1, 1);
        assert_eq!(ann2, 2);

        // Build final dataset
        let dataset = builder.build();

        assert_eq!(dataset.categories.len(), 2);
        assert_eq!(dataset.images.len(), 2);
        assert_eq!(dataset.annotations.len(), 2);
    }

    #[test]
    fn test_add_annotation_with_explicit_id_preserves_id_and_advances_counter() {
        let mut builder = CocoDatasetBuilder::new();
        let cat = builder.add_category("dog", None);
        let img = builder.add_image("image.jpg", 640, 480);

        // Explicit ID is preserved verbatim.
        let ann = builder.add_annotation_with_id(
            Some(9_876_543_210),
            img,
            cat,
            [10.0, 20.0, 100.0, 80.0],
            None,
            0,
        );
        assert_eq!(ann, 9_876_543_210);

        // Subsequent auto-generated IDs do not collide with the explicit one.
        let ann_auto = builder.add_annotation(img, cat, [0.0, 0.0, 1.0, 1.0], None);
        assert!(
            ann_auto > 9_876_543_210,
            "auto-generated annotation id ({ann_auto}) must be greater than the largest explicit id"
        );

        // None preserves the existing auto-increment behaviour even when the
        // counter has been bumped past a large explicit ID.
        let ann_none =
            builder.add_annotation_with_id(None, img, cat, [0.0, 0.0, 1.0, 1.0], None, 0);
        assert_eq!(ann_none, ann_auto + 1);

        let dataset = builder.build();
        assert_eq!(dataset.annotations.len(), 3);
        assert_eq!(dataset.annotations[0].id, 9_876_543_210);
    }

    #[test]
    fn test_add_annotation_with_explicit_id_at_u64_max_does_not_panic_or_wrap() {
        // Guards the overflow path: an explicit ID at u64::MAX must not
        // panic (debug builds) or wrap to 0 (release builds) when the counter
        // is bumped past it. The saturating implementation produces a
        // duplicate ID at exhaustion. That is undesirable but strictly better
        // than wrapping to 0, which would silently collide with the small IDs
        // the auto path issues first.
        let mut builder = CocoDatasetBuilder::new();
        let cat = builder.add_category("dog", None);
        let img = builder.add_image("image.jpg", 640, 480);

        let max =
            builder.add_annotation_with_id(Some(u64::MAX), img, cat, [0.0, 0.0, 1.0, 1.0], None, 0);
        assert_eq!(max, u64::MAX);

        // The next auto-assigned ID saturates rather than wrapping;
        // we explicitly tolerate the duplicate since the alternative
        // (wrap-to-0) is worse.
        let saturated = builder.add_annotation(img, cat, [0.0, 0.0, 1.0, 1.0], None);
        assert_eq!(saturated, u64::MAX);
    }

    #[test]
    fn test_write_zip() {
        let temp_dir = TempDir::new().unwrap();
        let output_path = temp_dir.path().join("test.zip");

        let dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                width: 100,
                height: 100,
                file_name: "test.jpg".to_string(),
                ..Default::default()
            }],
            ..Default::default()
        };

        // Create a fake image
        let images = vec![("images/test.jpg".to_string(), vec![0xFF, 0xD8, 0xFF])];

        let writer = CocoWriter::new();
        writer
            .write_zip(&dataset, images.into_iter(), &output_path)
            .unwrap();

        // Verify ZIP was created
        assert!(output_path.exists());

        // Verify contents
        let file = std::fs::File::open(&output_path).unwrap();
        let mut archive = zip::ZipArchive::new(file).unwrap();

        // Should contain annotations and image
        assert!(archive.by_name("images/test.jpg").is_ok());

        // The annotations entry must be complete, parseable JSON.
        let restored: CocoDataset =
            serde_json::from_reader(archive.by_name("annotations/instances.json").unwrap())
                .unwrap();
        assert_eq!(restored.images.len(), 1);
    }

    #[test]
    fn test_write_split_by_group() {
        let temp_dir = TempDir::new().unwrap();
        let output_dir = temp_dir.path().join("split_output");

        let dataset = CocoDataset {
            images: vec![
                CocoImage {
                    id: 1,
                    width: 640,
                    height: 480,
                    file_name: "train1.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 2,
                    width: 640,
                    height: 480,
                    file_name: "train2.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 3,
                    width: 800,
                    height: 600,
                    file_name: "val1.jpg".to_string(),
                    ..Default::default()
                },
            ],
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                supercategory: None,
                ..Default::default()
            }],
            annotations: vec![
                CocoAnnotation {
                    id: 1,
                    image_id: 1,
                    category_id: 1,
                    bbox: [10.0, 20.0, 100.0, 80.0],
                    ..Default::default()
                },
                CocoAnnotation {
                    id: 2,
                    image_id: 2,
                    category_id: 1,
                    bbox: [20.0, 30.0, 100.0, 80.0],
                    ..Default::default()
                },
                CocoAnnotation {
                    id: 3,
                    image_id: 3,
                    category_id: 1,
                    bbox: [30.0, 40.0, 100.0, 80.0],
                    ..Default::default()
                },
            ],
            ..Default::default()
        };

        let groups = vec!["train".to_string(), "train".to_string(), "val".to_string()];

        let writer = CocoWriter::new();
        let result = writer
            .write_split_by_group(&dataset, &groups, None, &output_dir)
            .unwrap();

        // Verify counts
        assert_eq!(result.get("train"), Some(&2));
        assert_eq!(result.get("val"), Some(&1));

        // Verify directory structure
        assert!(
            output_dir
                .join("train/annotations/instances_train.json")
                .exists()
        );
        assert!(
            output_dir
                .join("val/annotations/instances_val.json")
                .exists()
        );

        // Verify train JSON content
        let train_json =
            std::fs::read_to_string(output_dir.join("train/annotations/instances_train.json"))
                .unwrap();
        let train_data: CocoDataset = serde_json::from_str(&train_json).unwrap();
        assert_eq!(train_data.images.len(), 2);
        assert_eq!(train_data.annotations.len(), 2);

        // Verify val JSON content
        let val_json =
            std::fs::read_to_string(output_dir.join("val/annotations/instances_val.json")).unwrap();
        let val_data: CocoDataset = serde_json::from_str(&val_json).unwrap();
        assert_eq!(val_data.images.len(), 1);
        assert_eq!(val_data.annotations.len(), 1);
    }

    #[test]
    fn test_write_split_by_group_zip() {
        let temp_dir = TempDir::new().unwrap();
        let source_dir = temp_dir.path().join("source");
        let output_dir = temp_dir.path().join("zips");

        std::fs::create_dir_all(&source_dir).unwrap();
        std::fs::write(source_dir.join("a.jpg"), [0xFF, 0xD8, 0xFF]).unwrap();
        std::fs::write(source_dir.join("b.jpg"), [0xFF, 0xD8, 0xFF]).unwrap();
        // "missing.jpg" is deliberately absent: unreadable images are skipped.

        let dataset = CocoDataset {
            images: vec![
                CocoImage {
                    id: 1,
                    file_name: "a.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 2,
                    file_name: "b.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 3,
                    file_name: "missing.jpg".to_string(),
                    ..Default::default()
                },
            ],
            annotations: vec![
                CocoAnnotation {
                    id: 1,
                    image_id: 1,
                    category_id: 1,
                    ..Default::default()
                },
                CocoAnnotation {
                    id: 2,
                    image_id: 3,
                    category_id: 1,
                    ..Default::default()
                },
            ],
            ..Default::default()
        };

        let groups = vec!["train".to_string(), "val".to_string(), "train".to_string()];

        let writer = CocoWriter::new();
        let result = writer
            .write_split_by_group_zip(&dataset, &groups, Some(source_dir.as_path()), &output_dir)
            .unwrap();

        assert_eq!(result.get("train"), Some(&2));
        assert_eq!(result.get("val"), Some(&1));

        let train_file = std::fs::File::open(output_dir.join("train.zip")).unwrap();
        let mut train = zip::ZipArchive::new(train_file).unwrap();
        assert!(train.by_name("images/a.jpg").is_ok());
        assert!(train.by_name("images/b.jpg").is_err());
        assert!(train.by_name("images/missing.jpg").is_err());
        let train_data: CocoDataset =
            serde_json::from_reader(train.by_name("annotations/instances.json").unwrap())
                .unwrap();
        assert_eq!(train_data.images.len(), 2);
        assert_eq!(train_data.annotations.len(), 2);

        let val_file = std::fs::File::open(output_dir.join("val.zip")).unwrap();
        let mut val = zip::ZipArchive::new(val_file).unwrap();
        assert!(val.by_name("images/b.jpg").is_ok());
        let val_data: CocoDataset =
            serde_json::from_reader(val.by_name("annotations/instances.json").unwrap()).unwrap();
        assert_eq!(val_data.images.len(), 1);
        assert!(val_data.annotations.is_empty());
    }

    #[test]
    fn test_write_split_by_group_mismatch() {
        // Use a scratch directory rather than a hard-coded `/tmp` path so the
        // test is portable and cannot touch shared state.
        let temp_dir = TempDir::new().unwrap();
        let output_dir = temp_dir.path().join("split_output");

        let dataset = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                ..Default::default()
            }],
            ..Default::default()
        };

        // Wrong number of group assignments
        let groups = vec!["train".to_string(), "val".to_string()];

        let writer = CocoWriter::new();
        let result = writer.write_split_by_group(&dataset, &groups, None, &output_dir);

        assert!(result.is_err());
        // Input is validated before anything is written.
        assert!(!output_dir.exists());
    }
}