Skip to main content

edgefirst_client/coco/
arrow.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4//! COCO to EdgeFirst Arrow format conversion.
5//!
6//! Provides high-performance conversion between COCO JSON and EdgeFirst Arrow
7//! format, supporting async operations and progress tracking.
8
9use super::{
10    convert::{
11        box2d_to_coco_bbox, coco_bbox_to_box2d, coco_segmentation_to_mask_data,
12        coco_segmentation_to_polygon, polygon_to_coco_polygon,
13    },
14    reader::CocoReader,
15    types::{CocoImage, CocoIndex, CocoInfo, CocoSegmentation},
16    writer::{CocoDatasetBuilder, CocoWriter},
17};
18use crate::{Annotation, Box2d, Error, Polygon, Progress, Sample};
19use polars::prelude::*;
20use std::{
21    collections::{BTreeMap, HashMap},
22    path::Path,
23    sync::{
24        Arc,
25        atomic::{AtomicUsize, Ordering},
26    },
27};
28use tokio::sync::{Semaphore, mpsc::Sender};
29
/// Version tag stored under the `schema_version` key of the Arrow IPC
/// schema metadata produced by [`coco_to_arrow`].
pub const SCHEMA_VERSION: &str = "2026.04";

/// Polygon geometry of a single row: a list of rings, where every ring is a
/// list of `(x, y)` points.
type PolygonRings = Vec<Vec<(f32, f32)>>;
35
36/// Options for COCO to Arrow conversion.
37#[derive(Debug, Clone)]
38pub struct CocoToArrowOptions {
39    /// Include segmentation masks in output.
40    pub include_masks: bool,
41    /// Group name for all samples (e.g., "train", "val").
42    pub group: Option<String>,
43    /// Maximum number of parallel workers.
44    pub max_workers: usize,
45}
46
47impl Default for CocoToArrowOptions {
48    fn default() -> Self {
49        Self {
50            include_masks: true,
51            group: None,
52            max_workers: max_workers(),
53        }
54    }
55}
56
57/// Options for Arrow to COCO conversion.
58#[derive(Debug, Clone)]
59pub struct ArrowToCocoOptions {
60    /// Filter by group names (empty = all).
61    pub groups: Vec<String>,
62    /// Include segmentation masks in output.
63    pub include_masks: bool,
64    /// COCO info section.
65    pub info: Option<CocoInfo>,
66}
67
68impl Default for ArrowToCocoOptions {
69    fn default() -> Self {
70        Self {
71            groups: vec![],
72            include_masks: true,
73            info: None,
74        }
75    }
76}
77
/// Determine the default maximum number of parallel workers.
///
/// The `MAX_COCO_WORKERS` environment variable takes precedence when it holds
/// a positive integer (surrounding whitespace is ignored). Otherwise half of
/// the available CPU parallelism is used, clamped to `2..=8`; if parallelism
/// cannot be queried, 4 CPUs are assumed.
///
/// A value of `0` is rejected and falls back to the CPU-based default: the
/// worker count sizes a semaphore, and a semaphore with zero permits would
/// leave every worker waiting forever.
fn max_workers() -> usize {
    std::env::var("MAX_COCO_WORKERS")
        .ok()
        .and_then(|v| v.trim().parse::<usize>().ok())
        .filter(|&n| n > 0)
        .unwrap_or_else(|| {
            let cpus = std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4);
            (cpus / 2).clamp(2, 8)
        })
}
90
91/// Convert COCO annotations to EdgeFirst Arrow format.
92///
93/// This is a high-performance async conversion that uses parallel workers
94/// for parsing and transforming annotations.
95///
96/// # Arguments
97/// * `coco_path` - Path to COCO annotation JSON file or ZIP archive
98/// * `output_path` - Output Arrow file path
99/// * `options` - Conversion options
100/// * `progress` - Optional progress channel
101///
102/// # Returns
103/// Number of samples converted
104pub async fn coco_to_arrow<P: AsRef<Path>>(
105    coco_path: P,
106    output_path: P,
107    options: &CocoToArrowOptions,
108    progress: Option<Sender<Progress>>,
109) -> Result<usize, Error> {
110    let coco_path = coco_path.as_ref();
111    let output_path = output_path.as_ref();
112
113    // Read COCO dataset
114    let reader = CocoReader::new();
115    let dataset = if coco_path.extension().is_some_and(|e| e == "zip") {
116        reader.read_annotations_zip(coco_path)?
117    } else {
118        reader.read_json(coco_path)?
119    };
120
121    // Build index for efficient lookups
122    let index = Arc::new(CocoIndex::from_dataset(&dataset));
123    let total_images = dataset.images.len();
124
125    // Send initial progress
126    if let Some(ref p) = progress {
127        let _ = p
128            .send(Progress {
129                current: 0,
130                total: total_images,
131                status: None,
132            })
133            .await;
134    }
135
136    // Process images in parallel
137    let sem = Arc::new(Semaphore::new(options.max_workers));
138    let current = Arc::new(AtomicUsize::new(0));
139    let include_masks = options.include_masks;
140    let group = options.group.clone();
141
142    let mut tasks = Vec::with_capacity(total_images);
143
144    for image in dataset.images {
145        let sem = sem.clone();
146        let index = index.clone();
147        let current = current.clone();
148        let progress = progress.clone();
149        let total = total_images;
150        let group = group.clone();
151
152        let task = tokio::spawn(async move {
153            let _permit = sem.acquire().await.map_err(Error::SemaphoreError)?;
154
155            // Convert this image's annotations to EdgeFirst samples
156            let samples =
157                convert_image_annotations(&image, &index, include_masks, group.as_deref());
158
159            // Update progress
160            let c = current.fetch_add(1, Ordering::SeqCst) + 1;
161            if let Some(ref p) = progress {
162                let _ = p
163                    .send(Progress {
164                        current: c,
165                        total,
166                        status: None,
167                    })
168                    .await;
169            }
170
171            Ok::<Vec<Sample>, Error>(samples)
172        });
173
174        tasks.push(task);
175    }
176
177    // Collect all samples
178    let mut all_samples = Vec::with_capacity(total_images);
179    for task in tasks {
180        let samples = task.await??;
181        all_samples.extend(samples);
182    }
183
184    // Convert to DataFrame
185    let df = crate::samples_dataframe(&all_samples)?;
186
187    // Build schema-level metadata
188    let mut metadata: BTreeMap<PlSmallStr, PlSmallStr> = BTreeMap::new();
189    metadata.insert(
190        PlSmallStr::from("schema_version"),
191        PlSmallStr::from(SCHEMA_VERSION),
192    );
193
194    // Build category_metadata JSON from all categories.
195    // Includes id, frequency, and any LVIS fields (synset, synonyms, def).
196    // All categories are stored so that categories without annotations
197    // (e.g., those only referenced in neg_category_ids) can be
198    // reconstructed during Arrow→COCO export.
199    if !dataset.categories.is_empty() {
200        let cat_meta: HashMap<String, serde_json::Value> = dataset
201            .categories
202            .iter()
203            .map(|c| {
204                let mut entry = serde_json::Map::new();
205                entry.insert("id".to_string(), serde_json::json!(c.id));
206                if let Some(ref f) = c.frequency {
207                    entry.insert(
208                        "frequency".to_string(),
209                        serde_json::Value::String(f.clone()),
210                    );
211                }
212                if let Some(ref s) = c.synset {
213                    entry.insert("synset".to_string(), serde_json::Value::String(s.clone()));
214                }
215                if let Some(ref syns) = c.synonyms {
216                    entry.insert("synonyms".to_string(), serde_json::json!(syns));
217                }
218                if let Some(ref d) = c.def {
219                    entry.insert(
220                        "definition".to_string(),
221                        serde_json::Value::String(d.clone()),
222                    );
223                }
224                if let Some(ref sc) = c.supercategory {
225                    entry.insert(
226                        "supercategory".to_string(),
227                        serde_json::Value::String(sc.clone()),
228                    );
229                }
230                // Note: image_count and instance_count are intentionally not
231                // stored — they are recomputable statistics that can be derived
232                // from the annotations at any time.
233                (c.name.clone(), serde_json::Value::Object(entry))
234            })
235            .collect();
236
237        let json = serde_json::to_string(&cat_meta).unwrap_or_default();
238        metadata.insert(
239            PlSmallStr::from("category_metadata"),
240            PlSmallStr::from(json.as_str()),
241        );
242    }
243
244    // Write labels metadata: sorted list of category names by category_id.
245    if !dataset.categories.is_empty() {
246        let mut cats: Vec<_> = dataset.categories.iter().collect();
247        cats.sort_by_key(|c| c.id);
248        let labels: Vec<String> = cats.iter().map(|c| c.name.clone()).collect();
249        let labels_json = serde_json::to_string(&labels).unwrap_or_default();
250        metadata.insert(PlSmallStr::from("labels"), PlSmallStr::from(labels_json));
251    }
252
253    // Write Arrow file
254    if let Some(parent) = output_path.parent()
255        && !parent.as_os_str().is_empty()
256    {
257        std::fs::create_dir_all(parent)?;
258    }
259    let mut file = std::fs::File::create(output_path)?;
260    let mut writer = IpcWriter::new(&mut file);
261    writer.set_custom_schema_metadata(Arc::new(metadata));
262    writer.finish(&mut df.clone())?;
263
264    Ok(all_samples.len())
265}
266
267/// Convert a single image's annotations to EdgeFirst samples.
268fn convert_image_annotations(
269    image: &CocoImage,
270    index: &CocoIndex,
271    include_masks: bool,
272    group: Option<&str>,
273) -> Vec<Sample> {
274    let annotations = index.annotations_for_image(image.id);
275    let sample_name = sample_name_from_filename(&image.file_name);
276
277    // Translate LVIS image-level fields to label_index lists
278    let neg_label_indices = image.neg_category_ids.as_ref().map(|ids| {
279        ids.iter()
280            .filter_map(|&id| index.label_index(id).map(|idx| idx as u32))
281            .collect::<Vec<u32>>()
282    });
283    let not_exhaustive_label_indices = image.not_exhaustive_category_ids.as_ref().map(|ids| {
284        ids.iter()
285            .filter_map(|&id| index.label_index(id).map(|idx| idx as u32))
286            .collect::<Vec<u32>>()
287    });
288
289    let mut samples: Vec<Sample> = annotations
290        .iter()
291        .filter_map(|ann| {
292            let label = index.label_name(ann.category_id)?;
293            let label_index = index.label_index(ann.category_id);
294
295            // Convert bbox
296            let box2d = coco_bbox_to_box2d(&ann.bbox, image.width, image.height);
297
298            // Convert segmentation based on type:
299            // - Polygon → annotation.polygon (normalized coords)
300            // - RLE/CompressedRle → annotation.mask (PNG-encoded MaskData)
301            let (polygon, mask) = if include_masks {
302                if let Some(seg) = &ann.segmentation {
303                    match seg {
304                        CocoSegmentation::Polygon(_) => {
305                            let poly =
306                                coco_segmentation_to_polygon(seg, image.width, image.height).ok();
307                            (poly, None)
308                        }
309                        CocoSegmentation::Rle(_) | CocoSegmentation::CompressedRle(_) => {
310                            let mask_data = coco_segmentation_to_mask_data(seg).ok().flatten();
311                            (None, mask_data)
312                        }
313                    }
314                } else {
315                    (None, None)
316                }
317            } else {
318                (None, None)
319            };
320
321            let mut annotation = Annotation::new();
322            annotation.set_name(Some(sample_name.clone()));
323            annotation.set_label(Some(label.to_string()));
324            annotation.set_label_index(label_index);
325            annotation.set_box2d(Some(box2d));
326            annotation.set_polygon(polygon);
327            annotation.set_mask(mask);
328            annotation.set_group(group.map(String::from));
329            annotation.set_iscrowd(Some(ann.iscrowd != 0));
330            annotation.set_category_frequency(index.frequency(ann.category_id).map(String::from));
331
332            // Map COCO score to appropriate geometry score field
333            if let Some(score) = ann.score {
334                let score_f32 = score as f32;
335                if annotation.mask().is_some() {
336                    annotation.set_mask_score(Some(score_f32));
337                } else if annotation.polygon().is_some() {
338                    annotation.set_polygon_score(Some(score_f32));
339                } else {
340                    annotation.set_box2d_score(Some(score_f32));
341                }
342            }
343
344            let mut sample = Sample {
345                image_name: Some(sample_name.clone()),
346                width: Some(image.width),
347                height: Some(image.height),
348                group: group.map(String::from),
349                annotations: vec![annotation],
350                ..Default::default()
351            };
352            sample.neg_label_indices = neg_label_indices.clone();
353            sample.not_exhaustive_label_indices = not_exhaustive_label_indices.clone();
354
355            Some(sample)
356        })
357        .collect();
358
359    // Emit sentinel for images with no annotations but with neg/exhaustive data.
360    // Without this, neg_category_ids would be silently lost for images that have
361    // verified-negative labels but no positive annotations.
362    if samples.is_empty()
363        && (image.neg_category_ids.is_some() || image.not_exhaustive_category_ids.is_some())
364    {
365        let mut sample = Sample {
366            image_name: Some(sample_name.clone()),
367            width: Some(image.width),
368            height: Some(image.height),
369            group: group.map(String::from),
370            ..Default::default()
371        };
372        sample.neg_label_indices = neg_label_indices;
373        sample.not_exhaustive_label_indices = not_exhaustive_label_indices;
374        samples.push(sample);
375    }
376
377    samples
378}
379
/// Derive a sample name from an image file name.
///
/// Returns the final path component without its last extension (e.g.
/// `"train2017/000001.jpg"` → `"000001"`). If the path has no file stem
/// (such as `""` or `".."`), the input is returned unchanged.
fn sample_name_from_filename(filename: &str) -> String {
    match Path::new(filename).file_stem().and_then(|stem| stem.to_str()) {
        Some(stem) => stem.to_owned(),
        None => filename.to_owned(),
    }
}
388
389/// Convert EdgeFirst Arrow format to COCO annotations.
390///
391/// Reads an Arrow file and produces COCO JSON output. LVIS extension fields
392/// are preserved when present in the Arrow file: `neg_category_ids`,
393/// `not_exhaustive_category_ids`, category `frequency`, annotation `iscrowd`,
394/// `supercategory`, and category metadata (`synset`, `synonyms`, `def`).
395///
396/// # Arguments
397/// * `arrow_path` - Path to EdgeFirst Arrow file
398/// * `output_path` - Output COCO JSON file path
399/// * `options` - Conversion options
400/// * `progress` - Optional progress channel
401///
402/// # Returns
403/// Number of annotations converted
404pub async fn arrow_to_coco<P: AsRef<Path>>(
405    arrow_path: P,
406    output_path: P,
407    options: &ArrowToCocoOptions,
408    progress: Option<Sender<Progress>>,
409) -> Result<usize, Error> {
410    let arrow_path = arrow_path.as_ref();
411    let output_path = output_path.as_ref();
412
413    // Read file-level metadata (must be done before consuming the reader)
414    let (schema_version, category_metadata_json, labels_metadata_json) = {
415        let mut meta_file = std::fs::File::open(arrow_path)?;
416        let mut reader = IpcReader::new(&mut meta_file);
417        let meta = reader.custom_metadata().ok().flatten();
418        let sv = meta.as_ref().and_then(|m| {
419            m.get(&PlSmallStr::from("schema_version"))
420                .map(|s| s.to_string())
421        });
422        let cm = meta.as_ref().and_then(|m| {
423            m.get(&PlSmallStr::from("category_metadata"))
424                .map(|s| s.to_string())
425        });
426        let lm = meta
427            .as_ref()
428            .and_then(|m| m.get(&PlSmallStr::from("labels")).map(|s| s.to_string()));
429        (sv, cm, lm)
430    };
431
432    // Determine format version: absent → 2025.10, present → use value
433    let is_legacy = schema_version.is_none();
434
435    // Read Arrow file
436    let mut file = std::fs::File::open(arrow_path)?;
437    let df = IpcReader::new(&mut file).finish()?;
438
439    // Get group column for filtering
440    let groups_to_filter: std::collections::HashSet<_> = options.groups.iter().cloned().collect();
441
442    let total_rows = df.height();
443
444    if let Some(ref p) = progress {
445        let _ = p
446            .send(Progress {
447                current: 0,
448                total: total_rows,
449                status: None,
450            })
451            .await;
452    }
453
454    // Extract columns - all at once for O(n) instead of O(n²) per-row access
455    let names: Vec<String> = df
456        .column("name")?
457        .str()?
458        .into_iter()
459        .map(|s| s.unwrap_or_default().to_string())
460        .collect();
461
462    let labels: Vec<String> = df
463        .column("label")
464        .ok()
465        .and_then(|c| c.cast(&DataType::String).ok())
466        .map(|c| {
467            c.str()
468                .ok()
469                .map(|s| {
470                    s.into_iter()
471                        .map(|v| v.unwrap_or_default().to_string())
472                        .collect()
473                })
474                .unwrap_or_else(|| vec![String::new(); total_rows])
475        })
476        .unwrap_or_else(|| vec![String::new(); total_rows]);
477
478    let label_indices: Vec<Option<u64>> = df
479        .column("label_index")
480        .ok()
481        .map(|c| {
482            c.u64()
483                .ok()
484                .map(|s| s.into_iter().collect())
485                .unwrap_or_else(|| vec![None; total_rows])
486        })
487        .unwrap_or_else(|| vec![None; total_rows]);
488
489    // Get group column for filtering
490    let groups: Vec<String> = df
491        .column("group")
492        .ok()
493        .and_then(|c| c.cast(&DataType::String).ok())
494        .map(|c| {
495            c.str()
496                .ok()
497                .map(|s| {
498                    s.into_iter()
499                        .map(|v| v.unwrap_or_default().to_string())
500                        .collect()
501                })
502                .unwrap_or_default()
503        })
504        .unwrap_or_else(|| vec!["".to_string(); total_rows]);
505
506    // Extract all box2d values upfront (O(n) instead of O(n²))
507    let box2ds = df
508        .column("box2d")
509        .ok()
510        .map(extract_all_box2ds)
511        .transpose()?
512        .unwrap_or_else(|| vec![[0.0; 4]; total_rows]);
513
514    // Extract segmentation data based on schema version
515    //
516    // 2025.10 (legacy): mask column is List(Float32) with NaN-separated polygon coords
517    // 2026.04+:         polygon column is List(List(Float32)), mask column is Binary (PNG)
518    let legacy_masks: Option<Vec<Vec<f32>>> = if is_legacy && options.include_masks {
519        df.column("mask").ok().map(extract_all_masks).transpose()?
520    } else {
521        None
522    };
523
524    let polygons_2026: Option<Vec<Option<PolygonRings>>> = if !is_legacy && options.include_masks {
525        df.column("polygon")
526            .ok()
527            .map(|c| extract_all_polygons(c, total_rows))
528    } else {
529        None
530    };
531
532    let mask_binary_2026: Option<Vec<Option<Vec<u8>>>> = if !is_legacy && options.include_masks {
533        df.column("mask")
534            .ok()
535            .map(|c| extract_all_binary_masks(c, total_rows))
536    } else {
537        None
538    };
539
540    // Extract all sizes upfront if present
541    let sizes = df
542        .column("size")
543        .ok()
544        .and_then(|c| extract_all_sizes(c).ok());
545
546    // Extract iscrowd column (optional, Boolean in 2026.04, UInt32 in older schemas)
547    let iscrowds: Vec<u8> = df
548        .column("iscrowd")
549        .ok()
550        .map(|c| {
551            // Try Boolean first (2026.04 schema), then fall back to UInt32 (older schemas)
552            if let Ok(bool_ca) = c.bool() {
553                bool_ca
554                    .into_iter()
555                    .map(|v| if v.unwrap_or(false) { 1 } else { 0 })
556                    .collect()
557            } else {
558                c.u32()
559                    .ok()
560                    .map(|s| s.into_iter().map(|v| v.unwrap_or(0) as u8).collect())
561                    .unwrap_or_else(|| vec![0; total_rows])
562            }
563        })
564        .unwrap_or_else(|| vec![0; total_rows]);
565
566    // Extract category_frequency column (optional, Categorical/String)
567    let category_frequencies: Vec<Option<String>> = df
568        .column("category_frequency")
569        .ok()
570        .and_then(|c| c.cast(&DataType::String).ok())
571        .map(|c| {
572            c.str()
573                .ok()
574                .map(|s| s.into_iter().map(|v| v.map(String::from)).collect())
575                .unwrap_or_else(|| vec![None; total_rows])
576        })
577        .unwrap_or_else(|| vec![None; total_rows]);
578
579    // Extract neg_label_indices column (optional, List<UInt32>)
580    let neg_label_indices: Vec<Option<Vec<u32>>> = df
581        .column("neg_label_indices")
582        .ok()
583        .map(|c| extract_list_u32_column(c, total_rows))
584        .unwrap_or_else(|| vec![None; total_rows]);
585
586    // Extract not_exhaustive_label_indices column (optional, List<UInt32>)
587    let not_exhaustive_label_indices: Vec<Option<Vec<u32>>> = df
588        .column("not_exhaustive_label_indices")
589        .ok()
590        .map(|c| extract_list_u32_column(c, total_rows))
591        .unwrap_or_else(|| vec![None; total_rows]);
592
593    // Extract score columns (2026.04 schema)
594    let box2d_scores: Vec<Option<f32>> = extract_f32_column(&df, "box2d_score", total_rows);
595    let box3d_scores: Vec<Option<f32>> = extract_f32_column(&df, "box3d_score", total_rows);
596    let polygon_scores: Vec<Option<f32>> = extract_f32_column(&df, "polygon_score", total_rows);
597    let mask_scores: Vec<Option<f32>> = extract_f32_column(&df, "mask_score", total_rows);
598
599    // Build COCO dataset
600    let mut builder = CocoDatasetBuilder::new();
601
602    if let Some(info) = &options.info {
603        builder = builder.info(info.clone());
604    }
605
606    // Group-filter predicate: returns true if this row should be skipped
607    let skip_row = |i: usize| -> bool {
608        !groups_to_filter.is_empty() && !groups_to_filter.contains(&groups[i])
609    };
610
611    // Track unique images and categories
612    let mut image_dimensions: HashMap<String, (u32, u32)> = HashMap::new();
613    let mut image_ids: HashMap<String, u64> = HashMap::new();
614    let mut category_ids: HashMap<String, u32> = HashMap::new();
615
616    // First pass: collect unique images and categories
617    for i in 0..total_rows {
618        if skip_row(i) {
619            continue;
620        }
621
622        let name = &names[i];
623        let label = &labels[i];
624
625        // Get or estimate image dimensions
626        if !image_ids.contains_key(name) {
627            let (width, height) = sizes
628                .as_ref()
629                .and_then(|s| s.get(i).copied())
630                .unwrap_or((0, 0));
631
632            let id = builder.add_image(name, width, height);
633            image_ids.insert(name.clone(), id);
634            image_dimensions.insert(name.clone(), (width, height));
635        }
636
637        if !label.is_empty() && !category_ids.contains_key(label) {
638            let id = if let Some(Some(idx)) = label_indices.get(i) {
639                builder.add_category_with_id(*idx as u32, label, None)
640            } else {
641                builder.add_category(label, None)
642            };
643            category_ids.insert(label.clone(), id);
644        }
645    }
646
647    // Second pass: create annotations
648    let mut last_progress_update = 0;
649    for i in 0..total_rows {
650        if skip_row(i) {
651            continue;
652        }
653
654        let name = &names[i];
655        let label = &labels[i];
656
657        // Skip sentinel rows (empty label = image with neg/exhaustive data but no annotations)
658        if label.is_empty() {
659            continue;
660        }
661
662        let image_id = *image_ids.get(name).unwrap_or(&0);
663        let category_id = *category_ids.get(label).unwrap_or(&0);
664        let (width, height) = *image_dimensions.get(name).unwrap_or(&(1, 1));
665
666        // Convert box2d from Arrow center-normalized [cx, cy, w, h] to COCO format
667        // Arrow stores center-point, Box2d expects top-left
668        let bbox = box2ds.get(i).map(|box2d| {
669            let cx = box2d[0];
670            let cy = box2d[1];
671            let w = box2d[2];
672            let h = box2d[3];
673            // Convert from center-point to top-left format
674            let left = cx - w / 2.0;
675            let top = cy - h / 2.0;
676            let ef_box2d = Box2d::new(left, top, w, h);
677            box2d_to_coco_bbox(&ef_box2d, width, height)
678        });
679
680        // Build segmentation based on schema version
681        let segmentation = if options.include_masks {
682            if is_legacy {
683                // 2025.10: mask column contains NaN-separated flat polygon coords
684                legacy_masks.as_ref().and_then(|m| {
685                    m.get(i).and_then(|coords| {
686                        if coords.is_empty() {
687                            None
688                        } else {
689                            let rings = crate::unflatten_polygon_coordinates(coords);
690                            let polygon = Polygon::new(rings);
691                            let coco_poly = polygon_to_coco_polygon(&polygon, width, height);
692                            if coco_poly.is_empty() {
693                                None
694                            } else {
695                                Some(CocoSegmentation::Polygon(coco_poly))
696                            }
697                        }
698                    })
699                })
700            } else {
701                // 2026.04+: try mask (Binary/PNG → RLE) first, then polygon column
702                let mask_seg = mask_binary_2026.as_ref().and_then(|masks| {
703                    masks.get(i).and_then(|opt_bytes| {
704                        opt_bytes
705                            .as_ref()
706                            .and_then(|png_bytes| png_to_rle_segmentation(png_bytes, i))
707                    })
708                });
709
710                if mask_seg.is_some() {
711                    mask_seg
712                } else {
713                    // Fall back to polygon column
714                    polygons_2026.as_ref().and_then(|polys| {
715                        polys.get(i).and_then(|opt_rings| {
716                            opt_rings.as_ref().and_then(|rings| {
717                                if rings.is_empty() {
718                                    return None;
719                                }
720                                let polygon = Polygon::new(rings.clone());
721                                let coco_poly = polygon_to_coco_polygon(&polygon, width, height);
722                                if coco_poly.is_empty() {
723                                    None
724                                } else {
725                                    Some(CocoSegmentation::Polygon(coco_poly))
726                                }
727                            })
728                        })
729                    })
730                }
731            }
732        } else {
733            None
734        };
735
736        // Determine the score: use first non-null from available score columns
737        let score: Option<f64> = mask_scores[i]
738            .or(polygon_scores[i])
739            .or(box3d_scores[i])
740            .or(box2d_scores[i])
741            .map(|s| s as f64);
742
743        if let Some(bbox) = bbox {
744            let iscrowd = iscrowds[i];
745            let ann_id = builder.add_annotation_with_iscrowd(
746                image_id,
747                category_id,
748                bbox,
749                segmentation,
750                iscrowd,
751            );
752
753            // Set score on the annotation if present
754            if let Some(score_val) = score {
755                builder.set_annotation_score(ann_id, score_val);
756            }
757        }
758
759        // Update progress every 1000 rows to reduce overhead
760        if let Some(ref p) = progress
761            && (i - last_progress_update >= 1000 || i == total_rows - 1)
762        {
763            let _ = p
764                .send(Progress {
765                    current: i + 1,
766                    total: total_rows,
767                    status: None,
768                })
769                .await;
770            last_progress_update = i;
771        }
772    }
773
774    // Send final progress event (may not have fired if last rows were filtered)
775    if let Some(ref p) = progress
776        && last_progress_update < total_rows.saturating_sub(1)
777    {
778        let _ = p
779            .send(Progress {
780                current: total_rows,
781                total: total_rows,
782                status: None,
783            })
784            .await;
785    }
786
787    // Third pass: set LVIS image-level fields (neg/not-exhaustive category IDs)
788    // Since label_index == category_id, we can use the values directly.
789    {
790        let mut processed_images: std::collections::HashSet<u64> = std::collections::HashSet::new();
791        for i in 0..total_rows {
792            if skip_row(i) {
793                continue;
794            }
795            let name = &names[i];
796            if let Some(&image_id) = image_ids.get(name) {
797                if !processed_images.insert(image_id) {
798                    continue;
799                }
800                let neg = neg_label_indices[i].clone();
801                let not_exhaustive = not_exhaustive_label_indices[i].clone();
802                if neg.is_some() || not_exhaustive.is_some() {
803                    builder.set_image_neg_categories(image_id, neg, not_exhaustive);
804                }
805            }
806        }
807    }
808
809    // Set category frequency from the category_frequency column.
810    // Build a map of category_name -> frequency from the first occurrence.
811    {
812        let mut freq_map: HashMap<String, String> = HashMap::new();
813        for i in 0..total_rows {
814            if skip_row(i) {
815                continue;
816            }
817            let label = &labels[i];
818            if !label.is_empty()
819                && !freq_map.contains_key(label)
820                && let Some(ref freq) = category_frequencies[i]
821            {
822                freq_map.insert(label.clone(), freq.clone());
823            }
824        }
825        for (name, freq) in &freq_map {
826            builder.set_category_metadata(name, None, Some(freq.clone()), None, None);
827        }
828    }
829
830    // Set category metadata from file-level metadata JSON
831    // (id, frequency, synset, synonyms, def, supercategory).
832    // Also creates categories that exist in metadata but have no annotations
833    // (e.g., categories only referenced in neg_category_ids).
834    // set_category_metadata only updates fields that are Some, so frequency
835    // set from the column above is preserved for categories that had annotations.
836    if let Some(ref json_str) = category_metadata_json
837        && let Ok(meta) = serde_json::from_str::<HashMap<String, serde_json::Value>>(json_str)
838    {
839        for (cat_name, value) in &meta {
840            let supercategory = value.get("supercategory").and_then(|v| v.as_str());
841
842            // If this category doesn't exist yet, create it with the stored id
843            if !category_ids.contains_key(cat_name.as_str()) {
844                let cat_id = value.get("id").and_then(|v| v.as_u64()).map(|id| id as u32);
845                let id = if let Some(cat_id) = cat_id {
846                    builder.add_category_with_id(cat_id, cat_name, supercategory)
847                } else {
848                    builder.add_category(cat_name, supercategory)
849                };
850                category_ids.insert(cat_name.clone(), id);
851            } else {
852                // Category already exists — set supercategory if present in metadata
853                if let Some(sc) = supercategory {
854                    builder.set_category_supercategory(cat_name, sc);
855                }
856            }
857
858            let synset = value
859                .get("synset")
860                .and_then(|v| v.as_str())
861                .map(String::from);
862            let frequency = value
863                .get("frequency")
864                .and_then(|v| v.as_str())
865                .map(String::from);
866            let synonyms = value.get("synonyms").and_then(|v| {
867                v.as_array().map(|arr| {
868                    arr.iter()
869                        .filter_map(|s| s.as_str().map(String::from))
870                        .collect()
871                })
872            });
873            let def = value
874                .get("definition")
875                .and_then(|v| v.as_str())
876                .map(String::from);
877
878            builder.set_category_metadata(cat_name, synset, frequency, synonyms, def);
879        }
880    }
881
882    // Populate category names from labels metadata if categories weren't set
883    // from category_metadata (e.g., older files that only have labels list).
884    if category_metadata_json.is_none()
885        && let Some(ref labels_json) = labels_metadata_json
886        && let Ok(label_names) = serde_json::from_str::<Vec<String>>(labels_json)
887    {
888        for label_name in &label_names {
889            if !category_ids.contains_key(label_name) {
890                let id = builder.add_category(label_name, None);
891                category_ids.insert(label_name.clone(), id);
892            }
893        }
894    }
895
896    let dataset = builder.build();
897    let annotation_count = dataset.annotations.len();
898
899    // Write output
900    let writer = CocoWriter::new();
901    writer.write_json(&dataset, output_path)?;
902
903    Ok(annotation_count)
904}
905
906/// Extract all box2d values from a column at once (O(n) instead of O(n²)).
907fn extract_all_box2ds(col: &Column) -> Result<Vec<[f32; 4]>, Error> {
908    let arr = col.array()?;
909    let mut result = Vec::with_capacity(arr.len());
910
911    for inner in arr.amortized_iter() {
912        let values = if let Some(inner) = inner {
913            let series = inner.as_ref();
914            let vals: Vec<f32> = series
915                .f32()
916                .map_err(|e| Error::CocoError(format!("box2d cast error: {}", e)))?
917                .into_iter()
918                .map(|v| v.unwrap_or(0.0))
919                .collect();
920
921            if vals.len() == 4 {
922                [vals[0], vals[1], vals[2], vals[3]]
923            } else {
924                [0.0, 0.0, 0.0, 0.0]
925            }
926        } else {
927            [0.0, 0.0, 0.0, 0.0]
928        };
929        result.push(values);
930    }
931
932    Ok(result)
933}
934
935/// Extract all mask coordinates from a column at once (O(n) instead of O(n²)).
936fn extract_all_masks(col: &Column) -> Result<Vec<Vec<f32>>, Error> {
937    let list = col.list()?;
938    let mut result = Vec::with_capacity(list.len());
939
940    for i in 0..list.len() {
941        let coords = match list.get_as_series(i) {
942            Some(series) => series
943                .f32()
944                .map_err(|e| Error::CocoError(format!("mask cast error: {}", e)))?
945                .into_iter()
946                .map(|v| v.unwrap_or(f32::NAN))
947                .collect(),
948            None => vec![],
949        };
950        result.push(coords);
951    }
952
953    Ok(result)
954}
955
956/// Extract all image sizes from a column at once.
957fn extract_all_sizes(col: &Column) -> Result<Vec<(u32, u32)>, Error> {
958    let arr = col.array()?;
959    let mut result = Vec::with_capacity(arr.len());
960
961    for inner in arr.amortized_iter() {
962        let size = if let Some(inner) = inner {
963            let series = inner.as_ref();
964            let values: Vec<u32> = series
965                .u32()
966                .map_err(|e| Error::CocoError(format!("size cast error: {}", e)))?
967                .into_iter()
968                .map(|v| v.unwrap_or(0))
969                .collect();
970
971            if values.len() >= 2 {
972                (values[0], values[1])
973            } else {
974                (0, 0)
975            }
976        } else {
977            (0, 0)
978        };
979        result.push(size);
980    }
981
982    Ok(result)
983}
984
985/// Extract a List<UInt32> column into a vector of optional Vec<u32>.
986fn extract_list_u32_column(col: &Column, total_rows: usize) -> Vec<Option<Vec<u32>>> {
987    col.list()
988        .ok()
989        .map(|list| {
990            (0..list.len())
991                .map(|i| {
992                    list.get_as_series(i).and_then(|series| {
993                        series
994                            .u32()
995                            .ok()
996                            .map(|ca| ca.into_iter().flatten().collect::<Vec<u32>>())
997                    })
998                })
999                .collect()
1000        })
1001        .unwrap_or_else(|| vec![None; total_rows])
1002}
1003
1004/// Extract polygon rings from a `List(List(Float32))` column (2026.04 schema).
1005///
1006/// Each row is an optional list of rings; each ring is a list of flat `[x, y, x, y, ...]`
1007/// coordinate pairs.
1008fn extract_all_polygons(col: &Column, total_rows: usize) -> Vec<Option<PolygonRings>> {
1009    let outer_list = match col.list() {
1010        Ok(l) => l,
1011        Err(_) => return vec![None; total_rows],
1012    };
1013
1014    let mut result = Vec::with_capacity(total_rows);
1015    for i in 0..outer_list.len() {
1016        let rings = outer_list.get_as_series(i).and_then(|ring_series| {
1017            let inner_list = ring_series.list().ok()?;
1018            let mut rings = Vec::new();
1019            for j in 0..inner_list.len() {
1020                if let Some(coords_series) = inner_list.get_as_series(j)
1021                    && let Ok(f32_ca) = coords_series.f32()
1022                {
1023                    let coords: Vec<f32> = f32_ca.into_iter().map(|v| v.unwrap_or(0.0)).collect();
1024                    // Convert flat [x, y, x, y, ...] to Vec<(f32, f32)>
1025                    let points: Vec<(f32, f32)> = coords
1026                        .chunks(2)
1027                        .filter(|c| c.len() == 2)
1028                        .map(|c| (c[0], c[1]))
1029                        .collect();
1030                    if !points.is_empty() {
1031                        rings.push(points);
1032                    }
1033                }
1034            }
1035            if rings.is_empty() { None } else { Some(rings) }
1036        });
1037        result.push(rings);
1038    }
1039    result
1040}
1041
1042/// Extract binary mask data from a `Binary` column (2026.04 schema — PNG bytes).
1043fn extract_all_binary_masks(col: &Column, total_rows: usize) -> Vec<Option<Vec<u8>>> {
1044    let binary_ca = match col.binary() {
1045        Ok(b) => b,
1046        Err(_) => return vec![None; total_rows],
1047    };
1048
1049    (0..binary_ca.len())
1050        .map(|i| binary_ca.get(i).map(|bytes| bytes.to_vec()))
1051        .collect()
1052}
1053
1054/// Extract an optional Float32 column by name.
1055fn extract_f32_column(df: &DataFrame, name: &str, total_rows: usize) -> Vec<Option<f32>> {
1056    df.column(name)
1057        .ok()
1058        .and_then(|c| c.f32().ok())
1059        .map(|ca| ca.into_iter().collect())
1060        .unwrap_or_else(|| vec![None; total_rows])
1061}
1062
1063/// Decode a PNG mask (Binary column bytes) into COCO RLE segmentation.
1064///
1065/// Validates the PNG, decodes pixels, binarizes if needed (8-bit or 16-bit),
1066/// and encodes as COCO RLE. Returns `None` for empty or invalid data (with
1067/// a warning log for invalid cases).
1068fn png_to_rle_segmentation(png_bytes: &[u8], row_index: usize) -> Option<CocoSegmentation> {
1069    if png_bytes.is_empty() {
1070        return None;
1071    }
1072
1073    let mask_data = match crate::MaskData::from_png_checked(png_bytes.to_vec()) {
1074        Ok(m) => m,
1075        Err(e) => {
1076            log::warn!("Skipping invalid PNG mask at row {}: {}", row_index, e);
1077            return None;
1078        }
1079    };
1080
1081    let mw = mask_data.width();
1082    let mh = mask_data.height();
1083    let bit_depth = mask_data.bit_depth();
1084
1085    let decoded = match mask_data.decode() {
1086        Ok(d) => d,
1087        Err(e) => {
1088            log::warn!("Failed to decode PNG mask at row {}: {}", row_index, e);
1089            return None;
1090        }
1091    };
1092
1093    let binary_mask = match bit_depth {
1094        1 => decoded,
1095        8 => {
1096            log::warn!(
1097                "Binarizing 8-bit mask for row {} — score data is lost",
1098                row_index
1099            );
1100            decoded
1101                .iter()
1102                .map(|&v| if v >= 128 { 1 } else { 0 })
1103                .collect()
1104        }
1105        16 => {
1106            log::warn!(
1107                "Binarizing 16-bit mask for row {} — score data is lost",
1108                row_index
1109            );
1110            decoded
1111                .chunks(2)
1112                .map(|pair| {
1113                    let val = if pair.len() == 2 {
1114                        u16::from_be_bytes([pair[0], pair[1]])
1115                    } else {
1116                        0
1117                    };
1118                    if val >= 32768 { 1u8 } else { 0u8 }
1119                })
1120                .collect()
1121        }
1122        _ => decoded,
1123    };
1124
1125    match super::convert::encode_rle(&binary_mask, mw, mh) {
1126        Ok(rle) => Some(CocoSegmentation::Rle(rle)),
1127        Err(e) => {
1128            log::warn!("Failed to encode RLE for row {}: {}", row_index, e);
1129            None
1130        }
1131    }
1132}
1133
1134#[cfg(test)]
1135mod tests {
1136    use super::*;
1137    use crate::coco::{CocoAnnotation, CocoCategory, CocoDataset};
1138    use tempfile::TempDir;
1139
1140    // =========================================================================
1141    // unflatten_polygon_coords tests
1142    // =========================================================================
1143
1144    #[test]
1145    fn test_unflatten_polygon_coords_empty() {
1146        let coords: Vec<f32> = vec![];
1147        let result = crate::unflatten_polygon_coordinates(&coords);
1148        assert!(result.is_empty());
1149    }
1150
1151    #[test]
1152    fn test_unflatten_polygon_coords_single_polygon() {
1153        // Simple rectangle: 4 points
1154        let coords = vec![0.1, 0.2, 0.3, 0.2, 0.3, 0.4, 0.1, 0.4];
1155        let result = crate::unflatten_polygon_coordinates(&coords);
1156
1157        assert_eq!(result.len(), 1);
1158        assert_eq!(result[0].len(), 4);
1159        assert_eq!(result[0][0], (0.1, 0.2));
1160        assert_eq!(result[0][3], (0.1, 0.4));
1161    }
1162
1163    #[test]
1164    fn test_unflatten_polygon_coords_multiple_polygons() {
1165        // Two triangles separated by NaN
1166        let coords = vec![
1167            0.1,
1168            0.1,
1169            0.2,
1170            0.1,
1171            0.15,
1172            0.2,      // First triangle
1173            f32::NAN, // Separator
1174            0.5,
1175            0.5,
1176            0.6,
1177            0.5,
1178            0.55,
1179            0.6, // Second triangle
1180        ];
1181        let result = crate::unflatten_polygon_coordinates(&coords);
1182
1183        assert_eq!(result.len(), 2);
1184        assert_eq!(result[0].len(), 3);
1185        assert_eq!(result[1].len(), 3);
1186        assert_eq!(result[0][0], (0.1, 0.1));
1187        assert_eq!(result[1][0], (0.5, 0.5));
1188    }
1189
1190    #[test]
1191    fn test_unflatten_polygon_coords_leading_nan() {
1192        // NaN at the start should be handled gracefully
1193        let coords = vec![f32::NAN, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6];
1194        let result = crate::unflatten_polygon_coordinates(&coords);
1195
1196        assert_eq!(result.len(), 1);
1197        assert_eq!(result[0].len(), 3);
1198    }
1199
1200    #[test]
1201    fn test_unflatten_polygon_coords_trailing_nan() {
1202        // NaN at the end
1203        let coords = vec![0.1, 0.2, 0.3, 0.4, f32::NAN];
1204        let result = crate::unflatten_polygon_coordinates(&coords);
1205
1206        assert_eq!(result.len(), 1);
1207        assert_eq!(result[0].len(), 2);
1208    }
1209
1210    #[test]
1211    fn test_unflatten_polygon_coords_consecutive_nans() {
1212        // Multiple NaNs in a row
1213        let coords = vec![0.1, 0.2, f32::NAN, f32::NAN, 0.3, 0.4];
1214        let result = crate::unflatten_polygon_coordinates(&coords);
1215
1216        assert_eq!(result.len(), 2);
1217        assert_eq!(result[0].len(), 1);
1218        assert_eq!(result[1].len(), 1);
1219    }
1220
1221    #[test]
1222    fn test_unflatten_polygon_coords_odd_values() {
1223        // Odd number of coordinates (trailing x without y)
1224        let coords = vec![0.1, 0.2, 0.3, 0.4, 0.5];
1225        let result = crate::unflatten_polygon_coordinates(&coords);
1226
1227        assert_eq!(result.len(), 1);
1228        assert_eq!(result[0].len(), 2); // Only complete pairs
1229    }
1230
1231    // =========================================================================
1232    // convert_image_annotations tests
1233    // =========================================================================
1234
1235    #[test]
1236    fn test_convert_image_annotations_basic() {
1237        let image = CocoImage {
1238            id: 1,
1239            width: 640,
1240            height: 480,
1241            file_name: "test_image.jpg".to_string(),
1242            ..Default::default()
1243        };
1244
1245        let dataset = CocoDataset {
1246            images: vec![image.clone()],
1247            categories: vec![CocoCategory {
1248                id: 1,
1249                name: "cat".to_string(),
1250                supercategory: Some("animal".to_string()),
1251                ..Default::default()
1252            }],
1253            annotations: vec![CocoAnnotation {
1254                id: 1,
1255                image_id: 1,
1256                category_id: 1,
1257                bbox: [100.0, 100.0, 200.0, 200.0],
1258                area: 40000.0,
1259                iscrowd: 0,
1260                segmentation: None,
1261                score: None,
1262            }],
1263            ..Default::default()
1264        };
1265
1266        let index = CocoIndex::from_dataset(&dataset);
1267        let samples = convert_image_annotations(&image, &index, true, Some("train"));
1268
1269        assert_eq!(samples.len(), 1);
1270        assert_eq!(samples[0].image_name, Some("test_image".to_string()));
1271        assert_eq!(samples[0].group, Some("train".to_string()));
1272        assert_eq!(samples[0].annotations.len(), 1);
1273        assert_eq!(samples[0].annotations[0].label(), Some(&"cat".to_string()));
1274    }
1275
1276    #[test]
1277    fn test_convert_image_annotations_with_mask() {
1278        let image = CocoImage {
1279            id: 1,
1280            width: 100,
1281            height: 100,
1282            file_name: "masked.jpg".to_string(),
1283            ..Default::default()
1284        };
1285
1286        let dataset = CocoDataset {
1287            images: vec![image.clone()],
1288            categories: vec![CocoCategory {
1289                id: 1,
1290                name: "object".to_string(),
1291                supercategory: None,
1292                ..Default::default()
1293            }],
1294            annotations: vec![CocoAnnotation {
1295                id: 1,
1296                image_id: 1,
1297                category_id: 1,
1298                bbox: [10.0, 10.0, 50.0, 50.0],
1299                area: 2500.0,
1300                iscrowd: 0,
1301                segmentation: Some(CocoSegmentation::Polygon(vec![vec![
1302                    10.0, 10.0, 60.0, 10.0, 60.0, 60.0, 10.0, 60.0,
1303                ]])),
1304                score: None,
1305            }],
1306            ..Default::default()
1307        };
1308
1309        let index = CocoIndex::from_dataset(&dataset);
1310
1311        // With masks enabled
1312        let samples_with_mask = convert_image_annotations(&image, &index, true, None);
1313        assert!(samples_with_mask[0].annotations[0].polygon().is_some());
1314
1315        // With masks disabled
1316        let samples_no_mask = convert_image_annotations(&image, &index, false, None);
1317        assert!(samples_no_mask[0].annotations[0].polygon().is_none());
1318    }
1319
1320    #[test]
1321    fn test_convert_image_annotations_no_annotations() {
1322        let image = CocoImage {
1323            id: 1,
1324            width: 640,
1325            height: 480,
1326            file_name: "empty.jpg".to_string(),
1327            ..Default::default()
1328        };
1329
1330        let dataset = CocoDataset {
1331            images: vec![image.clone()],
1332            categories: vec![],
1333            annotations: vec![],
1334            ..Default::default()
1335        };
1336
1337        let index = CocoIndex::from_dataset(&dataset);
1338        let samples = convert_image_annotations(&image, &index, true, None);
1339
1340        assert!(samples.is_empty());
1341    }
1342
1343    // =========================================================================
1344    // sample_name_from_filename tests
1345    // =========================================================================
1346
1347    #[test]
1348    fn test_sample_name_from_filename() {
1349        assert_eq!(
1350            sample_name_from_filename("000000397133.jpg"),
1351            "000000397133"
1352        );
1353        assert_eq!(sample_name_from_filename("train2017/image.jpg"), "image");
1354        assert_eq!(sample_name_from_filename("test"), "test");
1355    }
1356
1357    #[test]
1358    fn test_sample_name_from_filename_nested_path() {
1359        assert_eq!(
1360            sample_name_from_filename("a/b/c/deep_image.png"),
1361            "deep_image"
1362        );
1363    }
1364
1365    #[test]
1366    fn test_sample_name_from_filename_no_extension() {
1367        assert_eq!(sample_name_from_filename("no_extension"), "no_extension");
1368    }
1369
1370    // =========================================================================
1371    // Options tests
1372    // =========================================================================
1373
1374    #[test]
1375    fn test_coco_to_arrow_options_default() {
1376        let options = CocoToArrowOptions::default();
1377        assert!(options.include_masks);
1378        assert!(options.group.is_none());
1379        assert!(options.max_workers >= 2);
1380    }
1381
1382    #[test]
1383    fn test_arrow_to_coco_options_default() {
1384        let options = ArrowToCocoOptions::default();
1385        assert!(options.groups.is_empty());
1386        assert!(options.include_masks);
1387        assert!(options.info.is_none());
1388    }
1389
1390    #[test]
1391    fn test_max_workers() {
1392        let workers = max_workers();
1393        assert!(workers >= 2);
1394        assert!(workers <= 8);
1395    }
1396
1397    #[tokio::test]
1398    async fn test_coco_to_arrow_minimal() {
1399        let temp_dir = TempDir::new().unwrap();
1400
1401        // Create minimal COCO JSON
1402        let coco_json = r#"{
1403            "images": [
1404                {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
1405            ],
1406            "annotations": [
1407                {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0}
1408            ],
1409            "categories": [
1410                {"id": 1, "name": "person", "supercategory": "human"}
1411            ]
1412        }"#;
1413
1414        let coco_path = temp_dir.path().join("test.json");
1415        std::fs::write(&coco_path, coco_json).unwrap();
1416
1417        let arrow_path = temp_dir.path().join("output.arrow");
1418
1419        let options = CocoToArrowOptions::default();
1420        let count = coco_to_arrow(&coco_path, &arrow_path, &options, None)
1421            .await
1422            .unwrap();
1423
1424        assert_eq!(count, 1);
1425        assert!(arrow_path.exists());
1426
1427        // Verify Arrow contents
1428        let mut file = std::fs::File::open(&arrow_path).unwrap();
1429        let df = IpcReader::new(&mut file).finish().unwrap();
1430        assert_eq!(df.height(), 1);
1431    }
1432
1433    #[tokio::test]
1434    async fn test_arrow_to_coco_roundtrip() {
1435        let temp_dir = TempDir::new().unwrap();
1436
1437        // Create COCO JSON
1438        let original = CocoDataset {
1439            images: vec![CocoImage {
1440                id: 1,
1441                width: 640,
1442                height: 480,
1443                file_name: "test.jpg".to_string(),
1444                ..Default::default()
1445            }],
1446            annotations: vec![CocoAnnotation {
1447                id: 1,
1448                image_id: 1,
1449                category_id: 1,
1450                bbox: [100.0, 50.0, 200.0, 150.0],
1451                area: 30000.0,
1452                iscrowd: 0,
1453                segmentation: Some(CocoSegmentation::Polygon(vec![vec![
1454                    100.0, 50.0, 300.0, 50.0, 300.0, 200.0, 100.0, 200.0,
1455                ]])),
1456                score: None,
1457            }],
1458            categories: vec![CocoCategory {
1459                id: 1,
1460                name: "person".to_string(),
1461                supercategory: Some("human".to_string()),
1462                ..Default::default()
1463            }],
1464            ..Default::default()
1465        };
1466
1467        // Write original COCO
1468        let coco_path = temp_dir.path().join("original.json");
1469        let writer = CocoWriter::new();
1470        writer.write_json(&original, &coco_path).unwrap();
1471
1472        // Convert to Arrow
1473        let arrow_path = temp_dir.path().join("converted.arrow");
1474        let options = CocoToArrowOptions::default();
1475        coco_to_arrow(&coco_path, &arrow_path, &options, None)
1476            .await
1477            .unwrap();
1478
1479        // Convert back to COCO
1480        let restored_path = temp_dir.path().join("restored.json");
1481        let options = ArrowToCocoOptions::default();
1482        arrow_to_coco(&arrow_path, &restored_path, &options, None)
1483            .await
1484            .unwrap();
1485
1486        // Verify restored data
1487        let reader = CocoReader::new();
1488        let restored = reader.read_json(&restored_path).unwrap();
1489
1490        assert_eq!(restored.images.len(), 1);
1491        assert_eq!(restored.annotations.len(), 1);
1492        assert_eq!(restored.categories.len(), 1);
1493
1494        // Check category name preserved
1495        assert_eq!(restored.categories[0].name, "person");
1496    }
1497
1498    // =========================================================================
1499    // Arrow IPC file metadata tests
1500    // =========================================================================
1501
1502    #[tokio::test]
1503    async fn test_coco_to_arrow_schema_version_metadata() {
1504        let temp_dir = TempDir::new().unwrap();
1505
1506        // Create minimal COCO JSON (no LVIS fields)
1507        let coco_json = r#"{
1508            "images": [
1509                {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
1510            ],
1511            "annotations": [
1512                {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0}
1513            ],
1514            "categories": [
1515                {"id": 1, "name": "person", "supercategory": "human"}
1516            ]
1517        }"#;
1518
1519        let coco_path = temp_dir.path().join("test.json");
1520        std::fs::write(&coco_path, coco_json).unwrap();
1521
1522        let arrow_path = temp_dir.path().join("output.arrow");
1523        let options = CocoToArrowOptions::default();
1524        coco_to_arrow(&coco_path, &arrow_path, &options, None)
1525            .await
1526            .unwrap();
1527
1528        // Read back and verify schema_version metadata
1529        let mut file = std::fs::File::open(&arrow_path).unwrap();
1530        let mut reader = IpcReader::new(&mut file);
1531        let custom_meta = reader.custom_metadata().unwrap();
1532        assert!(custom_meta.is_some(), "custom metadata should be present");
1533
1534        let meta = custom_meta.unwrap();
1535        assert_eq!(
1536            meta.get(&PlSmallStr::from("schema_version")),
1537            Some(&PlSmallStr::from(SCHEMA_VERSION)),
1538            "schema_version metadata should be '2026.04'"
1539        );
1540
1541        // category_metadata is always present when there are categories
1542        assert!(
1543            meta.contains_key(&PlSmallStr::from("category_metadata")),
1544            "category_metadata should be present even without LVIS fields"
1545        );
1546    }
1547
1548    #[tokio::test]
1549    async fn test_coco_to_arrow_category_metadata_lvis() {
1550        let temp_dir = TempDir::new().unwrap();
1551
1552        // Create COCO JSON with LVIS category fields
1553        let coco_json = r#"{
1554            "images": [
1555                {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
1556            ],
1557            "annotations": [
1558                {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0},
1559                {"id": 2, "image_id": 1, "category_id": 2, "bbox": [50, 60, 80, 40], "area": 3200, "iscrowd": 0}
1560            ],
1561            "categories": [
1562                {
1563                    "id": 1,
1564                    "name": "aerosol_can",
1565                    "synset": "aerosol.n.02",
1566                    "synonyms": ["aerosol_can", "spray_can"],
1567                    "def": "a dispenser that holds a substance under pressure"
1568                },
1569                {
1570                    "id": 2,
1571                    "name": "person",
1572                    "supercategory": "human"
1573                }
1574            ]
1575        }"#;
1576
1577        let coco_path = temp_dir.path().join("lvis.json");
1578        std::fs::write(&coco_path, coco_json).unwrap();
1579
1580        let arrow_path = temp_dir.path().join("lvis_output.arrow");
1581        let options = CocoToArrowOptions::default();
1582        coco_to_arrow(&coco_path, &arrow_path, &options, None)
1583            .await
1584            .unwrap();
1585
1586        // Read back and verify metadata
1587        let mut file = std::fs::File::open(&arrow_path).unwrap();
1588        let mut reader = IpcReader::new(&mut file);
1589        let custom_meta = reader.custom_metadata().unwrap();
1590        assert!(custom_meta.is_some(), "custom metadata should be present");
1591
1592        let meta = custom_meta.unwrap();
1593
1594        // schema_version is always present
1595        assert_eq!(
1596            meta.get(&PlSmallStr::from("schema_version")),
1597            Some(&PlSmallStr::from(SCHEMA_VERSION)),
1598        );
1599
1600        // category_metadata should be present (aerosol_can has LVIS fields)
1601        let cat_meta_str = meta
1602            .get(&PlSmallStr::from("category_metadata"))
1603            .expect("category_metadata should be present for LVIS data");
1604
1605        let cat_meta: HashMap<String, serde_json::Value> =
1606            serde_json::from_str(cat_meta_str.as_str()).unwrap();
1607
1608        // Both categories should be present (all categories are now stored)
1609        assert!(
1610            cat_meta.contains_key("aerosol_can"),
1611            "aerosol_can should be in category_metadata"
1612        );
1613        assert!(
1614            cat_meta.contains_key("person"),
1615            "person should also be in category_metadata"
1616        );
1617
1618        // Verify aerosol_can entry contents
1619        let aerosol = cat_meta.get("aerosol_can").unwrap();
1620        assert_eq!(
1621            aerosol.get("synset").and_then(|v| v.as_str()),
1622            Some("aerosol.n.02")
1623        );
1624        assert_eq!(
1625            aerosol.get("definition").and_then(|v| v.as_str()),
1626            Some("a dispenser that holds a substance under pressure")
1627        );
1628        let synonyms = aerosol.get("synonyms").and_then(|v| v.as_array()).unwrap();
1629        assert_eq!(synonyms.len(), 2);
1630        assert_eq!(synonyms[0].as_str(), Some("aerosol_can"));
1631        assert_eq!(synonyms[1].as_str(), Some("spray_can"));
1632    }
1633
1634    // =========================================================================
1635    // LVIS round-trip tests
1636    // =========================================================================
1637
1638    #[tokio::test]
1639    async fn test_coco_arrow_roundtrip_lvis_supercategory() {
1640        let temp_dir = TempDir::new().unwrap();
1641
1642        // Create COCO JSON with supercategory
1643        let coco_json = r#"{
1644            "images": [
1645                {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
1646            ],
1647            "annotations": [
1648                {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0}
1649            ],
1650            "categories": [
1651                {"id": 1, "name": "person", "supercategory": "human"}
1652            ]
1653        }"#;
1654
1655        let coco_path = temp_dir.path().join("original.json");
1656        std::fs::write(&coco_path, coco_json).unwrap();
1657
1658        // Convert to Arrow
1659        let arrow_path = temp_dir.path().join("converted.arrow");
1660        let options = CocoToArrowOptions::default();
1661        coco_to_arrow(&coco_path, &arrow_path, &options, None)
1662            .await
1663            .unwrap();
1664
1665        // Convert back to COCO
1666        let restored_path = temp_dir.path().join("restored.json");
1667        let options = ArrowToCocoOptions::default();
1668        arrow_to_coco(&arrow_path, &restored_path, &options, None)
1669            .await
1670            .unwrap();
1671
1672        // Verify supercategory is preserved
1673        let reader = CocoReader::new();
1674        let restored = reader.read_json(&restored_path).unwrap();
1675
1676        assert_eq!(restored.categories.len(), 1);
1677        assert_eq!(restored.categories[0].name, "person");
1678        assert_eq!(
1679            restored.categories[0].supercategory,
1680            Some("human".to_string()),
1681            "supercategory should survive COCO→Arrow→COCO round-trip"
1682        );
1683    }
1684
1685    #[tokio::test]
1686    async fn test_coco_arrow_roundtrip_neg_categories_no_annotations() {
1687        let temp_dir = TempDir::new().unwrap();
1688
1689        // Create COCO JSON: image has neg_category_ids but NO annotations
1690        let coco_json = r#"{
1691            "images": [
1692                {
1693                    "id": 1,
1694                    "width": 640,
1695                    "height": 480,
1696                    "file_name": "empty.jpg",
1697                    "neg_category_ids": [1, 2]
1698                }
1699            ],
1700            "annotations": [],
1701            "categories": [
1702                {"id": 1, "name": "cat", "supercategory": "animal"},
1703                {"id": 2, "name": "dog", "supercategory": "animal"}
1704            ]
1705        }"#;
1706
1707        let coco_path = temp_dir.path().join("original.json");
1708        std::fs::write(&coco_path, coco_json).unwrap();
1709
1710        // Convert to Arrow
1711        let arrow_path = temp_dir.path().join("converted.arrow");
1712        let options = CocoToArrowOptions::default();
1713        let sample_count = coco_to_arrow(&coco_path, &arrow_path, &options, None)
1714            .await
1715            .unwrap();
1716
1717        // Should have 1 sentinel sample (image with neg data but no annotations)
1718        assert_eq!(
1719            sample_count, 1,
1720            "sentinel row should be emitted for image with neg data"
1721        );
1722
1723        // Convert back to COCO
1724        let restored_path = temp_dir.path().join("restored.json");
1725        let options = ArrowToCocoOptions::default();
1726        arrow_to_coco(&arrow_path, &restored_path, &options, None)
1727            .await
1728            .unwrap();
1729
1730        // Verify neg_category_ids survived the round-trip
1731        let reader = CocoReader::new();
1732        let restored = reader.read_json(&restored_path).unwrap();
1733
1734        assert_eq!(restored.images.len(), 1);
1735        assert_eq!(restored.annotations.len(), 0, "no annotations expected");
1736        assert_eq!(restored.categories.len(), 2, "both categories should exist");
1737
1738        let neg = restored.images[0].neg_category_ids.as_ref();
1739        assert!(
1740            neg.is_some(),
1741            "neg_category_ids should survive round-trip for zero-annotation image"
1742        );
1743        let neg_ids = neg.unwrap();
1744        assert_eq!(neg_ids.len(), 2, "should have 2 neg categories");
1745        assert!(neg_ids.contains(&1), "neg_category_ids should contain 1");
1746        assert!(neg_ids.contains(&2), "neg_category_ids should contain 2");
1747
1748        // Verify supercategory survives for annotation-free categories
1749        for cat in &restored.categories {
1750            assert_eq!(
1751                cat.supercategory,
1752                Some("animal".to_string()),
1753                "supercategory should survive round-trip for annotation-free category '{}'",
1754                cat.name
1755            );
1756        }
1757    }
1758
1759    #[test]
1760    fn test_convert_image_annotations_neg_only_no_annotations() {
1761        let image = CocoImage {
1762            id: 1,
1763            width: 640,
1764            height: 480,
1765            file_name: "neg_only.jpg".to_string(),
1766            neg_category_ids: Some(vec![1, 2]),
1767            ..Default::default()
1768        };
1769
1770        let dataset = CocoDataset {
1771            images: vec![image.clone()],
1772            categories: vec![
1773                CocoCategory {
1774                    id: 1,
1775                    name: "cat".to_string(),
1776                    supercategory: Some("animal".to_string()),
1777                    ..Default::default()
1778                },
1779                CocoCategory {
1780                    id: 2,
1781                    name: "dog".to_string(),
1782                    supercategory: Some("animal".to_string()),
1783                    ..Default::default()
1784                },
1785            ],
1786            annotations: vec![],
1787            ..Default::default()
1788        };
1789
1790        let index = CocoIndex::from_dataset(&dataset);
1791        let samples = convert_image_annotations(&image, &index, true, None);
1792
1793        // Should emit 1 sentinel sample (no annotations but has neg data)
1794        assert_eq!(
1795            samples.len(),
1796            1,
1797            "sentinel row should be emitted for neg-only image"
1798        );
1799        assert_eq!(samples[0].image_name, Some("neg_only".to_string()));
1800        assert!(
1801            samples[0].annotations.is_empty(),
1802            "sentinel should have no annotations"
1803        );
1804        assert!(
1805            samples[0].neg_label_indices.is_some(),
1806            "sentinel should preserve neg_label_indices"
1807        );
1808        assert_eq!(samples[0].neg_label_indices.as_ref().unwrap().len(), 2);
1809    }
1810}