Skip to main content

edgefirst_client/coco/
writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4//! Streaming COCO JSON/ZIP writers.
5//!
6//! Provides efficient writing of COCO annotation files to JSON or ZIP archives.
7
8use super::types::{
9    CocoAnnotation, CocoCategory, CocoDataset, CocoImage, CocoInfo, CocoSegmentation,
10};
11use crate::Error;
12use std::{
13    fs::File,
14    io::{BufWriter, Write},
15    path::Path,
16};
17use zip::{CompressionMethod, write::SimpleFileOptions};
18
/// Configuration controlling how COCO output is serialized.
#[derive(Debug, Clone)]
pub struct CocoWriteOptions {
    /// When `true`, ZIP entries are Deflate-compressed; when `false`, they
    /// are stored uncompressed.
    pub compress: bool,
    /// When `true`, JSON is emitted indented rather than compact.
    pub pretty: bool,
}

impl Default for CocoWriteOptions {
    /// Defaults to compressed ZIP entries and compact (non-pretty) JSON.
    fn default() -> Self {
        let (compress, pretty) = (true, false);
        Self { compress, pretty }
    }
}
36
37/// COCO writer for generating JSON and ZIP files.
38///
39/// # Example
40///
41/// ```rust,no_run
42/// use edgefirst_client::coco::{CocoDataset, CocoWriter};
43///
44/// let writer = CocoWriter::new();
45/// let dataset = CocoDataset::default();
46/// writer.write_json(&dataset, "annotations.json")?;
47/// # Ok::<(), edgefirst_client::Error>(())
48/// ```
49pub struct CocoWriter {
50    options: CocoWriteOptions,
51}
52
53impl CocoWriter {
54    /// Create a new COCO writer with default options.
55    pub fn new() -> Self {
56        Self {
57            options: CocoWriteOptions::default(),
58        }
59    }
60
61    /// Create a new COCO writer with custom options.
62    pub fn with_options(options: CocoWriteOptions) -> Self {
63        Self { options }
64    }
65
66    /// Write COCO dataset to a JSON file.
67    ///
68    /// # Arguments
69    /// * `dataset` - The COCO dataset to write
70    /// * `path` - Output file path
71    pub fn write_json<P: AsRef<Path>>(&self, dataset: &CocoDataset, path: P) -> Result<(), Error> {
72        // Ensure parent directory exists
73        if let Some(parent) = path.as_ref().parent()
74            && !parent.as_os_str().is_empty()
75        {
76            std::fs::create_dir_all(parent)?;
77        }
78
79        let file = File::create(path.as_ref())?;
80        let writer = BufWriter::with_capacity(64 * 1024, file);
81
82        if self.options.pretty {
83            serde_json::to_writer_pretty(writer, dataset)?;
84        } else {
85            serde_json::to_writer(writer, dataset)?;
86        }
87
88        Ok(())
89    }
90
91    /// Write COCO dataset to a ZIP file with images.
92    ///
93    /// Creates a ZIP archive with:
94    /// - `annotations/instances.json` - The COCO annotations
95    /// - Images at their original relative paths
96    ///
97    /// # Arguments
98    /// * `dataset` - The COCO dataset to write
99    /// * `images` - Iterator of `(filename, image_data)` pairs
100    /// * `path` - Output ZIP file path
101    pub fn write_zip<P: AsRef<Path>>(
102        &self,
103        dataset: &CocoDataset,
104        images: impl Iterator<Item = (String, Vec<u8>)>,
105        path: P,
106    ) -> Result<(), Error> {
107        // Ensure parent directory exists
108        if let Some(parent) = path.as_ref().parent()
109            && !parent.as_os_str().is_empty()
110        {
111            std::fs::create_dir_all(parent)?;
112        }
113
114        let file = File::create(path.as_ref())?;
115        let mut zip = zip::ZipWriter::new(file);
116
117        let options = if self.options.compress {
118            SimpleFileOptions::default().compression_method(CompressionMethod::Deflated)
119        } else {
120            SimpleFileOptions::default().compression_method(CompressionMethod::Stored)
121        };
122
123        // Write annotations
124        zip.start_file("annotations/instances.json", options)?;
125        let json = if self.options.pretty {
126            serde_json::to_string_pretty(dataset)?
127        } else {
128            serde_json::to_string(dataset)?
129        };
130        zip.write_all(json.as_bytes())?;
131
132        // Write images
133        for (filename, data) in images {
134            zip.start_file(&filename, options)?;
135            zip.write_all(&data)?;
136        }
137
138        zip.finish()?;
139        Ok(())
140    }
141
142    /// Write COCO dataset to a ZIP file with images from a source directory.
143    ///
144    /// # Arguments
145    /// * `dataset` - The COCO dataset to write
146    /// * `images_dir` - Directory containing source images
147    /// * `path` - Output ZIP file path
148    pub fn write_zip_from_dir<P: AsRef<Path>>(
149        &self,
150        dataset: &CocoDataset,
151        images_dir: P,
152        path: P,
153    ) -> Result<(), Error> {
154        let images_dir = images_dir.as_ref();
155
156        // Collect image data
157        let images = dataset.images.iter().filter_map(|img| {
158            let img_path = images_dir.join(&img.file_name);
159            std::fs::read(&img_path)
160                .ok()
161                .map(|data| (format!("images/{}", img.file_name), data))
162        });
163
164        self.write_zip(dataset, images, path)
165    }
166
167    /// Split a dataset by group and write each group to its own directory.
168    ///
169    /// Creates a directory structure like:
170    /// ```text
171    /// output_dir/
172    /// ├── train/
173    /// │   ├── annotations/instances_train.json
174    /// │   └── images/
175    /// │       └── *.jpg
176    /// └── val/
177    ///     ├── annotations/instances_val.json
178    ///     └── images/
179    ///         └── *.jpg
180    /// ```
181    ///
182    /// # Arguments
183    /// * `dataset` - The COCO dataset to split
184    /// * `group_assignments` - Parallel array of group names for each image
185    /// * `images_source` - Optional source directory containing images to copy
186    /// * `output_dir` - Output root directory
187    ///
188    /// # Returns
189    /// HashMap of group name → number of images written
190    pub fn write_split_by_group<P: AsRef<Path>>(
191        &self,
192        dataset: &CocoDataset,
193        group_assignments: &[String],
194        images_source: Option<&Path>,
195        output_dir: P,
196    ) -> Result<std::collections::HashMap<String, usize>, Error> {
197        use std::collections::{HashMap, HashSet};
198
199        let output_dir = output_dir.as_ref();
200
201        // Validate input
202        if dataset.images.len() != group_assignments.len() {
203            return Err(Error::CocoError(format!(
204                "Image count ({}) does not match group assignment count ({})",
205                dataset.images.len(),
206                group_assignments.len()
207            )));
208        }
209
210        // Build groups
211        let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
212        for (idx, group) in group_assignments.iter().enumerate() {
213            groups.entry(group.clone()).or_default().push(idx);
214        }
215
216        let mut result = HashMap::new();
217
218        for (group_name, image_indices) in &groups {
219            // Create subdirectory structure
220            let group_dir = output_dir.join(group_name);
221            let annotations_dir = group_dir.join("annotations");
222            let images_dir = group_dir.join("images");
223
224            std::fs::create_dir_all(&annotations_dir)?;
225            std::fs::create_dir_all(&images_dir)?;
226
227            // Build subset dataset for this group
228            let image_ids: HashSet<u64> = image_indices
229                .iter()
230                .map(|&idx| dataset.images[idx].id)
231                .collect();
232
233            let subset = CocoDataset {
234                info: dataset.info.clone(),
235                licenses: dataset.licenses.clone(),
236                images: image_indices
237                    .iter()
238                    .map(|&idx| dataset.images[idx].clone())
239                    .collect(),
240                annotations: dataset
241                    .annotations
242                    .iter()
243                    .filter(|ann| image_ids.contains(&ann.image_id))
244                    .cloned()
245                    .collect(),
246                categories: dataset.categories.clone(),
247            };
248
249            // Write annotations JSON
250            let ann_file = annotations_dir.join(format!("instances_{}.json", group_name));
251            self.write_json(&subset, &ann_file)?;
252
253            // Copy images if source provided
254            if let Some(source) = images_source {
255                for &idx in image_indices {
256                    let image = &dataset.images[idx];
257                    let src_path = source.join(&image.file_name);
258                    let dst_path = images_dir.join(&image.file_name);
259
260                    if src_path.exists() {
261                        std::fs::copy(&src_path, &dst_path)?;
262                    }
263                }
264            }
265
266            result.insert(group_name.clone(), image_indices.len());
267        }
268
269        Ok(result)
270    }
271
272    /// Split a dataset by group and write each group to its own ZIP archive.
273    ///
274    /// Creates ZIP archives like:
275    /// - `output_dir/train.zip` containing train split
276    /// - `output_dir/val.zip` containing val split
277    ///
278    /// # Arguments
279    /// * `dataset` - The COCO dataset to split
280    /// * `group_assignments` - Parallel array of group names for each image
281    /// * `images_source` - Optional source directory containing images to
282    ///   include
283    /// * `output_dir` - Output directory for ZIP files
284    ///
285    /// # Returns
286    /// HashMap of group name → number of images written
287    pub fn write_split_by_group_zip<P: AsRef<Path>>(
288        &self,
289        dataset: &CocoDataset,
290        group_assignments: &[String],
291        images_source: Option<&Path>,
292        output_dir: P,
293    ) -> Result<std::collections::HashMap<String, usize>, Error> {
294        use std::collections::{HashMap, HashSet};
295
296        let output_dir = output_dir.as_ref();
297        std::fs::create_dir_all(output_dir)?;
298
299        // Validate input
300        if dataset.images.len() != group_assignments.len() {
301            return Err(Error::CocoError(format!(
302                "Image count ({}) does not match group assignment count ({})",
303                dataset.images.len(),
304                group_assignments.len()
305            )));
306        }
307
308        // Build groups
309        let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
310        for (idx, group) in group_assignments.iter().enumerate() {
311            groups.entry(group.clone()).or_default().push(idx);
312        }
313
314        let mut result = HashMap::new();
315
316        for (group_name, image_indices) in &groups {
317            // Build subset dataset for this group
318            let image_ids: HashSet<u64> = image_indices
319                .iter()
320                .map(|&idx| dataset.images[idx].id)
321                .collect();
322
323            let subset = CocoDataset {
324                info: dataset.info.clone(),
325                licenses: dataset.licenses.clone(),
326                images: image_indices
327                    .iter()
328                    .map(|&idx| dataset.images[idx].clone())
329                    .collect(),
330                annotations: dataset
331                    .annotations
332                    .iter()
333                    .filter(|ann| image_ids.contains(&ann.image_id))
334                    .cloned()
335                    .collect(),
336                categories: dataset.categories.clone(),
337            };
338
339            // Collect images if source provided
340            let images: Vec<(String, Vec<u8>)> = if let Some(source) = images_source {
341                image_indices
342                    .iter()
343                    .filter_map(|&idx| {
344                        let image = &dataset.images[idx];
345                        let src_path = source.join(&image.file_name);
346                        std::fs::read(&src_path)
347                            .ok()
348                            .map(|data| (format!("images/{}", image.file_name), data))
349                    })
350                    .collect()
351            } else {
352                vec![]
353            };
354
355            // Write ZIP
356            let zip_path = output_dir.join(format!("{}.zip", group_name));
357            self.write_zip(&subset, images.into_iter(), &zip_path)?;
358
359            result.insert(group_name.clone(), image_indices.len());
360        }
361
362        Ok(result)
363    }
364}
365
366impl Default for CocoWriter {
367    fn default() -> Self {
368        Self::new()
369    }
370}
371
372/// Builder for constructing a COCO dataset.
373///
374/// Provides a convenient API for incrementally building a COCO dataset.
375#[derive(Debug, Default)]
376pub struct CocoDatasetBuilder {
377    dataset: CocoDataset,
378    next_image_id: u64,
379    next_annotation_id: u64,
380    next_category_id: u32,
381}
382
383impl CocoDatasetBuilder {
384    /// Create a new dataset builder.
385    pub fn new() -> Self {
386        Self {
387            dataset: CocoDataset::default(),
388            next_image_id: 1,
389            next_annotation_id: 1,
390            next_category_id: 1,
391        }
392    }
393
394    /// Set dataset info.
395    pub fn info(mut self, info: CocoInfo) -> Self {
396        self.dataset.info = info;
397        self
398    }
399
400    /// Add a category, returning its ID.
401    pub fn add_category(&mut self, name: &str, supercategory: Option<&str>) -> u32 {
402        // Check if category already exists
403        for cat in &self.dataset.categories {
404            if cat.name == name {
405                return cat.id;
406            }
407        }
408
409        let id = self.next_category_id;
410        self.next_category_id += 1;
411
412        self.dataset.categories.push(CocoCategory {
413            id,
414            name: name.to_string(),
415            supercategory: supercategory.map(String::from),
416            ..Default::default()
417        });
418
419        id
420    }
421
422    /// Add a category with a specific ID, returning its ID.
423    ///
424    /// Used when round-tripping datasets that have explicit category IDs
425    /// (e.g., COCO, LVIS) to preserve the original numbering.
426    pub fn add_category_with_id(
427        &mut self,
428        id: u32,
429        name: &str,
430        supercategory: Option<&str>,
431    ) -> u32 {
432        // Check if category already exists (by id or name)
433        for cat in &self.dataset.categories {
434            if cat.id == id || cat.name == name {
435                return cat.id;
436            }
437        }
438
439        self.dataset.categories.push(CocoCategory {
440            id,
441            name: name.to_string(),
442            supercategory: supercategory.map(String::from),
443            ..Default::default()
444        });
445
446        // Keep next_category_id above all known IDs
447        if id >= self.next_category_id {
448            self.next_category_id = id + 1;
449        }
450
451        id
452    }
453
454    /// Add an image, returning its ID.
455    pub fn add_image(&mut self, file_name: &str, width: u32, height: u32) -> u64 {
456        let id = self.next_image_id;
457        self.next_image_id += 1;
458
459        self.dataset.images.push(CocoImage {
460            id,
461            width,
462            height,
463            file_name: file_name.to_string(),
464            ..Default::default()
465        });
466
467        id
468    }
469
470    /// Add an annotation, returning its ID.
471    pub fn add_annotation(
472        &mut self,
473        image_id: u64,
474        category_id: u32,
475        bbox: [f64; 4],
476        segmentation: Option<CocoSegmentation>,
477    ) -> u64 {
478        self.add_annotation_with_iscrowd(image_id, category_id, bbox, segmentation, 0)
479    }
480
481    /// Add an annotation with an explicit iscrowd flag, returning its ID.
482    pub fn add_annotation_with_iscrowd(
483        &mut self,
484        image_id: u64,
485        category_id: u32,
486        bbox: [f64; 4],
487        segmentation: Option<CocoSegmentation>,
488        iscrowd: u8,
489    ) -> u64 {
490        let id = self.next_annotation_id;
491        self.next_annotation_id += 1;
492
493        let area = bbox[2] * bbox[3]; // Default area from bbox
494
495        self.dataset.annotations.push(CocoAnnotation {
496            id,
497            image_id,
498            category_id,
499            bbox,
500            area,
501            iscrowd,
502            segmentation,
503            score: None,
504        });
505
506        id
507    }
508
509    /// Set the score on an annotation by ID.
510    pub fn set_annotation_score(&mut self, annotation_id: u64, score: f64) {
511        if let Some(ann) = self
512            .dataset
513            .annotations
514            .iter_mut()
515            .find(|a| a.id == annotation_id)
516        {
517            ann.score = Some(score);
518        }
519    }
520
521    /// Set LVIS annotation metadata on an image.
522    pub fn set_image_neg_categories(
523        &mut self,
524        image_id: u64,
525        neg_category_ids: Option<Vec<u32>>,
526        not_exhaustive_category_ids: Option<Vec<u32>>,
527    ) {
528        if let Some(img) = self.dataset.images.iter_mut().find(|i| i.id == image_id) {
529            img.neg_category_ids = neg_category_ids;
530            img.not_exhaustive_category_ids = not_exhaustive_category_ids;
531        }
532    }
533
534    /// Set LVIS metadata on a category by name.
535    ///
536    /// Only updates fields that are `Some`; leaves existing values intact
537    /// for fields passed as `None`.
538    pub fn set_category_metadata(
539        &mut self,
540        name: &str,
541        synset: Option<String>,
542        frequency: Option<String>,
543        synonyms: Option<Vec<String>>,
544        def: Option<String>,
545    ) {
546        if let Some(cat) = self.dataset.categories.iter_mut().find(|c| c.name == name) {
547            if synset.is_some() {
548                cat.synset = synset;
549            }
550            if frequency.is_some() {
551                cat.frequency = frequency;
552            }
553            if synonyms.is_some() {
554                cat.synonyms = synonyms;
555            }
556            if def.is_some() {
557                cat.def = def;
558            }
559        }
560    }
561
562    /// Set the supercategory on a category by name.
563    pub fn set_category_supercategory(&mut self, name: &str, supercategory: &str) {
564        if let Some(cat) = self.dataset.categories.iter_mut().find(|c| c.name == name) {
565            cat.supercategory = Some(supercategory.to_string());
566        }
567    }
568
569    /// Build the final dataset.
570    pub fn build(self) -> CocoDataset {
571        self.dataset
572    }
573}
574
575#[cfg(test)]
576mod tests {
577    use super::*;
578    use tempfile::TempDir;
579
580    #[test]
581    fn test_writer_default() {
582        let writer = CocoWriter::new();
583        assert!(writer.options.compress);
584        assert!(!writer.options.pretty);
585    }
586
587    #[test]
588    fn test_write_json() {
589        let temp_dir = TempDir::new().unwrap();
590        let output_path = temp_dir.path().join("test.json");
591
592        let dataset = CocoDataset {
593            images: vec![CocoImage {
594                id: 1,
595                width: 640,
596                height: 480,
597                file_name: "test.jpg".to_string(),
598                ..Default::default()
599            }],
600            categories: vec![CocoCategory {
601                id: 1,
602                name: "person".to_string(),
603                supercategory: None,
604                ..Default::default()
605            }],
606            annotations: vec![CocoAnnotation {
607                id: 1,
608                image_id: 1,
609                category_id: 1,
610                bbox: [10.0, 20.0, 100.0, 80.0],
611                area: 8000.0,
612                iscrowd: 0,
613                segmentation: None,
614                score: None,
615            }],
616            ..Default::default()
617        };
618
619        let writer = CocoWriter::new();
620        writer.write_json(&dataset, &output_path).unwrap();
621
622        // Verify file was created
623        assert!(output_path.exists());
624
625        // Read it back and verify
626        let contents = std::fs::read_to_string(&output_path).unwrap();
627        let restored: CocoDataset = serde_json::from_str(&contents).unwrap();
628        assert_eq!(restored.images.len(), 1);
629        assert_eq!(restored.annotations.len(), 1);
630    }
631
632    #[test]
633    fn test_write_json_pretty() {
634        let temp_dir = TempDir::new().unwrap();
635        let output_path = temp_dir.path().join("test_pretty.json");
636
637        let dataset = CocoDataset::default();
638
639        let writer = CocoWriter::with_options(CocoWriteOptions {
640            pretty: true,
641            compress: false,
642        });
643        writer.write_json(&dataset, &output_path).unwrap();
644
645        let contents = std::fs::read_to_string(&output_path).unwrap();
646        assert!(contents.contains('\n')); // Pretty-printed should have newlines
647    }
648
649    #[test]
650    fn test_dataset_builder() {
651        let mut builder = CocoDatasetBuilder::new();
652
653        // Add categories
654        let person_id = builder.add_category("person", Some("human"));
655        let car_id = builder.add_category("car", Some("vehicle"));
656
657        assert_eq!(person_id, 1);
658        assert_eq!(car_id, 2);
659
660        // Adding same category returns existing ID
661        let person_id2 = builder.add_category("person", None);
662        assert_eq!(person_id2, 1);
663
664        // Add images
665        let img1 = builder.add_image("image1.jpg", 640, 480);
666        let img2 = builder.add_image("image2.jpg", 800, 600);
667
668        assert_eq!(img1, 1);
669        assert_eq!(img2, 2);
670
671        // Add annotations
672        let ann1 = builder.add_annotation(img1, person_id, [10.0, 20.0, 100.0, 80.0], None);
673        let ann2 = builder.add_annotation(img1, car_id, [50.0, 60.0, 150.0, 100.0], None);
674
675        assert_eq!(ann1, 1);
676        assert_eq!(ann2, 2);
677
678        // Build final dataset
679        let dataset = builder.build();
680
681        assert_eq!(dataset.categories.len(), 2);
682        assert_eq!(dataset.images.len(), 2);
683        assert_eq!(dataset.annotations.len(), 2);
684    }
685
686    #[test]
687    fn test_write_zip() {
688        let temp_dir = TempDir::new().unwrap();
689        let output_path = temp_dir.path().join("test.zip");
690
691        let dataset = CocoDataset {
692            images: vec![CocoImage {
693                id: 1,
694                width: 100,
695                height: 100,
696                file_name: "test.jpg".to_string(),
697                ..Default::default()
698            }],
699            ..Default::default()
700        };
701
702        // Create a fake image
703        let images = vec![("images/test.jpg".to_string(), vec![0xFF, 0xD8, 0xFF])];
704
705        let writer = CocoWriter::new();
706        writer
707            .write_zip(&dataset, images.into_iter(), &output_path)
708            .unwrap();
709
710        // Verify ZIP was created
711        assert!(output_path.exists());
712
713        // Verify contents
714        let file = std::fs::File::open(&output_path).unwrap();
715        let mut archive = zip::ZipArchive::new(file).unwrap();
716
717        // Should contain annotations and image
718        assert!(archive.by_name("annotations/instances.json").is_ok());
719        assert!(archive.by_name("images/test.jpg").is_ok());
720    }
721
722    #[test]
723    fn test_write_split_by_group() {
724        let temp_dir = TempDir::new().unwrap();
725        let output_dir = temp_dir.path().join("split_output");
726
727        let dataset = CocoDataset {
728            images: vec![
729                CocoImage {
730                    id: 1,
731                    width: 640,
732                    height: 480,
733                    file_name: "train1.jpg".to_string(),
734                    ..Default::default()
735                },
736                CocoImage {
737                    id: 2,
738                    width: 640,
739                    height: 480,
740                    file_name: "train2.jpg".to_string(),
741                    ..Default::default()
742                },
743                CocoImage {
744                    id: 3,
745                    width: 800,
746                    height: 600,
747                    file_name: "val1.jpg".to_string(),
748                    ..Default::default()
749                },
750            ],
751            categories: vec![CocoCategory {
752                id: 1,
753                name: "person".to_string(),
754                supercategory: None,
755                ..Default::default()
756            }],
757            annotations: vec![
758                CocoAnnotation {
759                    id: 1,
760                    image_id: 1,
761                    category_id: 1,
762                    bbox: [10.0, 20.0, 100.0, 80.0],
763                    ..Default::default()
764                },
765                CocoAnnotation {
766                    id: 2,
767                    image_id: 2,
768                    category_id: 1,
769                    bbox: [20.0, 30.0, 100.0, 80.0],
770                    ..Default::default()
771                },
772                CocoAnnotation {
773                    id: 3,
774                    image_id: 3,
775                    category_id: 1,
776                    bbox: [30.0, 40.0, 100.0, 80.0],
777                    ..Default::default()
778                },
779            ],
780            ..Default::default()
781        };
782
783        let groups = vec!["train".to_string(), "train".to_string(), "val".to_string()];
784
785        let writer = CocoWriter::new();
786        let result = writer
787            .write_split_by_group(&dataset, &groups, None, &output_dir)
788            .unwrap();
789
790        // Verify counts
791        assert_eq!(result.get("train"), Some(&2));
792        assert_eq!(result.get("val"), Some(&1));
793
794        // Verify directory structure
795        assert!(
796            output_dir
797                .join("train/annotations/instances_train.json")
798                .exists()
799        );
800        assert!(
801            output_dir
802                .join("val/annotations/instances_val.json")
803                .exists()
804        );
805
806        // Verify train JSON content
807        let train_json =
808            std::fs::read_to_string(output_dir.join("train/annotations/instances_train.json"))
809                .unwrap();
810        let train_data: CocoDataset = serde_json::from_str(&train_json).unwrap();
811        assert_eq!(train_data.images.len(), 2);
812        assert_eq!(train_data.annotations.len(), 2);
813
814        // Verify val JSON content
815        let val_json =
816            std::fs::read_to_string(output_dir.join("val/annotations/instances_val.json")).unwrap();
817        let val_data: CocoDataset = serde_json::from_str(&val_json).unwrap();
818        assert_eq!(val_data.images.len(), 1);
819        assert_eq!(val_data.annotations.len(), 1);
820    }
821
822    #[test]
823    fn test_write_split_by_group_mismatch() {
824        let dataset = CocoDataset {
825            images: vec![CocoImage {
826                id: 1,
827                ..Default::default()
828            }],
829            ..Default::default()
830        };
831
832        // Wrong number of group assignments
833        let groups = vec!["train".to_string(), "val".to_string()];
834
835        let writer = CocoWriter::new();
836        let result =
837            writer.write_split_by_group(&dataset, &groups, None, std::path::Path::new("/tmp/test"));
838
839        assert!(result.is_err());
840    }
841}