Skip to main content

edgefirst_client/coco/
arrow.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4//! COCO to EdgeFirst Arrow format conversion.
5//!
6//! Provides high-performance conversion between COCO JSON and EdgeFirst Arrow
7//! format, supporting async operations and progress tracking.
8
9use super::{
10    convert::{
11        box2d_to_coco_bbox, coco_bbox_to_box2d, coco_segmentation_to_mask_data,
12        coco_segmentation_to_polygon, polygon_to_coco_polygon,
13    },
14    reader::CocoReader,
15    types::{CocoImage, CocoIndex, CocoInfo, CocoSegmentation},
16    writer::{CocoDatasetBuilder, CocoWriter},
17};
18use crate::{Annotation, Box2d, Error, Polygon, Progress, Sample};
19use polars::prelude::*;
20use std::{
21    collections::{BTreeMap, HashMap},
22    path::Path,
23    sync::{
24        Arc,
25        atomic::{AtomicUsize, Ordering},
26    },
27};
28use tokio::sync::{Semaphore, mpsc::Sender};
29
/// Schema version written into Arrow IPC file metadata.
///
/// `coco_to_arrow` stores this value under the `schema_version` key of the
/// file's custom schema metadata. `arrow_to_coco` treats a file without that
/// key as the legacy `2025.10` layout.
pub const SCHEMA_VERSION: &str = "2026.04";
32
/// Polygon rings for a single row: each ring is a vec of `(x, y)` coordinate pairs.
///
/// Used for the per-row values read from the `polygon` column of
/// non-legacy (2026.04+) Arrow files.
type PolygonRings = Vec<Vec<(f32, f32)>>;
35
/// Options for COCO to Arrow conversion.
///
/// The [`Default`] implementation enables masks, sets no group, and takes the
/// worker count from `max_workers()`.
#[derive(Debug, Clone)]
pub struct CocoToArrowOptions {
    /// Include segmentation masks in output.
    ///
    /// When `false`, neither the polygon nor the mask of an annotation is
    /// populated; only the bounding box is converted.
    pub include_masks: bool,
    /// Group name for all samples (e.g., "train", "val").
    ///
    /// Applied to every sample and every annotation produced by the
    /// conversion; `None` leaves the group unset.
    pub group: Option<String>,
    /// Maximum number of parallel workers.
    ///
    /// Used as the permit count of the semaphore that gates the per-image
    /// conversion tasks in `coco_to_arrow`.
    pub max_workers: usize,
}
46
47impl Default for CocoToArrowOptions {
48    fn default() -> Self {
49        Self {
50            include_masks: true,
51            group: None,
52            max_workers: max_workers(),
53        }
54    }
55}
56
/// Options for Arrow to COCO conversion.
///
/// The [`Default`] implementation applies no group filter, enables masks, and
/// sets no `info` section.
#[derive(Debug, Clone)]
pub struct ArrowToCocoOptions {
    /// Filter by group names (empty = all).
    ///
    /// When non-empty, rows whose `group` value is not in this list are
    /// skipped by `arrow_to_coco`.
    pub groups: Vec<String>,
    /// Include segmentation masks in output.
    ///
    /// When `false`, annotations are written without a segmentation.
    pub include_masks: bool,
    /// COCO info section.
    ///
    /// When `Some`, a clone is handed to the dataset builder's `info` setter.
    pub info: Option<CocoInfo>,
}
67
68impl Default for ArrowToCocoOptions {
69    fn default() -> Self {
70        Self {
71            groups: vec![],
72            include_masks: true,
73            info: None,
74        }
75    }
76}
77
/// Determine maximum number of parallel workers.
///
/// A positive integer in the `MAX_COCO_WORKERS` environment variable takes
/// precedence (surrounding whitespace is ignored). Otherwise half of the
/// available parallelism is used — assuming 4 CPUs when it cannot be
/// queried — clamped to the range `2..=8`.
///
/// An unset, unparsable, or zero value falls back to the computed default.
/// Zero is rejected because the result is used as a semaphore permit count
/// in `coco_to_arrow`: with no permits, no conversion task could ever start.
fn max_workers() -> usize {
    std::env::var("MAX_COCO_WORKERS")
        .ok()
        .and_then(|v| v.trim().parse::<usize>().ok())
        // Zero workers is never a usable setting; treat it like "unset".
        .filter(|&n| n > 0)
        .unwrap_or_else(|| {
            let cpus = std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4);
            (cpus / 2).clamp(2, 8)
        })
}
90
91/// Convert COCO annotations to EdgeFirst Arrow format.
92///
93/// This is a high-performance async conversion that uses parallel workers
94/// for parsing and transforming annotations.
95///
96/// # Arguments
97/// * `coco_path` - Path to COCO annotation JSON file or ZIP archive
98/// * `output_path` - Output Arrow file path
99/// * `options` - Conversion options
100/// * `progress` - Optional progress channel
101///
102/// # Returns
103/// Number of samples converted
104pub async fn coco_to_arrow<P: AsRef<Path>>(
105    coco_path: P,
106    output_path: P,
107    options: &CocoToArrowOptions,
108    progress: Option<Sender<Progress>>,
109) -> Result<usize, Error> {
110    let coco_path = coco_path.as_ref();
111    let output_path = output_path.as_ref();
112
113    // Read COCO dataset
114    let reader = CocoReader::new();
115    let dataset = if coco_path.extension().is_some_and(|e| e == "zip") {
116        reader.read_annotations_zip(coco_path)?
117    } else {
118        reader.read_json(coco_path)?
119    };
120
121    // Build index for efficient lookups
122    let index = Arc::new(CocoIndex::from_dataset(&dataset));
123    let total_images = dataset.images.len();
124
125    // Send initial progress
126    if let Some(ref p) = progress {
127        let _ = p
128            .send(Progress {
129                current: 0,
130                total: total_images,
131                status: None,
132            })
133            .await;
134    }
135
136    // Process images in parallel
137    let sem = Arc::new(Semaphore::new(options.max_workers));
138    let current = Arc::new(AtomicUsize::new(0));
139    let include_masks = options.include_masks;
140    let group = options.group.clone();
141
142    let mut tasks = Vec::with_capacity(total_images);
143
144    for image in dataset.images {
145        let sem = sem.clone();
146        let index = index.clone();
147        let current = current.clone();
148        let progress = progress.clone();
149        let total = total_images;
150        let group = group.clone();
151
152        let task = tokio::spawn(async move {
153            let _permit = sem.acquire().await.map_err(Error::SemaphoreError)?;
154
155            // Convert this image's annotations to EdgeFirst samples
156            let samples =
157                convert_image_annotations(&image, &index, include_masks, group.as_deref());
158
159            // Update progress
160            let c = current.fetch_add(1, Ordering::SeqCst) + 1;
161            if let Some(ref p) = progress {
162                let _ = p
163                    .send(Progress {
164                        current: c,
165                        total,
166                        status: None,
167                    })
168                    .await;
169            }
170
171            Ok::<Vec<Sample>, Error>(samples)
172        });
173
174        tasks.push(task);
175    }
176
177    // Collect all samples
178    let mut all_samples = Vec::with_capacity(total_images);
179    for task in tasks {
180        let samples = task.await??;
181        all_samples.extend(samples);
182    }
183
184    // Convert to DataFrame
185    let df = crate::samples_dataframe(&all_samples)?;
186
187    // Build schema-level metadata
188    let mut metadata: BTreeMap<PlSmallStr, PlSmallStr> = BTreeMap::new();
189    metadata.insert(
190        PlSmallStr::from("schema_version"),
191        PlSmallStr::from(SCHEMA_VERSION),
192    );
193
194    // Build category_metadata JSON from all categories.
195    // Includes id, frequency, and any LVIS fields (synset, synonyms, def).
196    // All categories are stored so that categories without annotations
197    // (e.g., those only referenced in neg_category_ids) can be
198    // reconstructed during Arrow→COCO export.
199    if !dataset.categories.is_empty() {
200        let cat_meta: HashMap<String, serde_json::Value> = dataset
201            .categories
202            .iter()
203            .map(|c| {
204                let mut entry = serde_json::Map::new();
205                entry.insert("id".to_string(), serde_json::json!(c.id));
206                if let Some(ref f) = c.frequency {
207                    entry.insert(
208                        "frequency".to_string(),
209                        serde_json::Value::String(f.clone()),
210                    );
211                }
212                if let Some(ref s) = c.synset {
213                    entry.insert("synset".to_string(), serde_json::Value::String(s.clone()));
214                }
215                if let Some(ref syns) = c.synonyms {
216                    entry.insert("synonyms".to_string(), serde_json::json!(syns));
217                }
218                if let Some(ref d) = c.def {
219                    entry.insert(
220                        "definition".to_string(),
221                        serde_json::Value::String(d.clone()),
222                    );
223                }
224                if let Some(ref sc) = c.supercategory {
225                    entry.insert(
226                        "supercategory".to_string(),
227                        serde_json::Value::String(sc.clone()),
228                    );
229                }
230                // Note: image_count and instance_count are intentionally not
231                // stored — they are recomputable statistics that can be derived
232                // from the annotations at any time.
233                (c.name.clone(), serde_json::Value::Object(entry))
234            })
235            .collect();
236
237        let json = serde_json::to_string(&cat_meta).unwrap_or_default();
238        metadata.insert(
239            PlSmallStr::from("category_metadata"),
240            PlSmallStr::from(json.as_str()),
241        );
242    }
243
244    // Write labels metadata: sorted list of category names by category_id.
245    if !dataset.categories.is_empty() {
246        let mut cats: Vec<_> = dataset.categories.iter().collect();
247        cats.sort_by_key(|c| c.id);
248        let labels: Vec<String> = cats.iter().map(|c| c.name.clone()).collect();
249        let labels_json = serde_json::to_string(&labels).unwrap_or_default();
250        metadata.insert(PlSmallStr::from("labels"), PlSmallStr::from(labels_json));
251    }
252
253    // Write Arrow file
254    if let Some(parent) = output_path.parent()
255        && !parent.as_os_str().is_empty()
256    {
257        std::fs::create_dir_all(parent)?;
258    }
259    let mut file = std::fs::File::create(output_path)?;
260    let mut writer = IpcWriter::new(&mut file);
261    writer.set_custom_schema_metadata(Arc::new(metadata));
262    writer.finish(&mut df.clone())?;
263
264    Ok(all_samples.len())
265}
266
267/// Convert a single image's annotations to EdgeFirst samples.
268fn convert_image_annotations(
269    image: &CocoImage,
270    index: &CocoIndex,
271    include_masks: bool,
272    group: Option<&str>,
273) -> Vec<Sample> {
274    let annotations = index.annotations_for_image(image.id);
275    let sample_name = sample_name_from_filename(&image.file_name);
276
277    // Translate LVIS image-level fields to label_index lists
278    let neg_label_indices = image.neg_category_ids.as_ref().map(|ids| {
279        ids.iter()
280            .filter_map(|&id| index.label_index(id).map(|idx| idx as u32))
281            .collect::<Vec<u32>>()
282    });
283    let not_exhaustive_label_indices = image.not_exhaustive_category_ids.as_ref().map(|ids| {
284        ids.iter()
285            .filter_map(|&id| index.label_index(id).map(|idx| idx as u32))
286            .collect::<Vec<u32>>()
287    });
288
289    let mut samples: Vec<Sample> = annotations
290        .iter()
291        .filter_map(|ann| {
292            let label = index.label_name(ann.category_id)?;
293            let label_index = index.label_index(ann.category_id);
294
295            // Convert bbox
296            let box2d = coco_bbox_to_box2d(&ann.bbox, image.width, image.height);
297
298            // Convert segmentation based on type:
299            // - Polygon → annotation.polygon (normalized coords)
300            // - RLE/CompressedRle → annotation.mask (PNG-encoded MaskData)
301            let (polygon, mask) = if include_masks {
302                if let Some(seg) = &ann.segmentation {
303                    match seg {
304                        CocoSegmentation::Polygon(_) => {
305                            let poly =
306                                coco_segmentation_to_polygon(seg, image.width, image.height).ok();
307                            (poly, None)
308                        }
309                        CocoSegmentation::Rle(_) | CocoSegmentation::CompressedRle(_) => {
310                            let mask_data = coco_segmentation_to_mask_data(seg).ok().flatten();
311                            (None, mask_data)
312                        }
313                    }
314                } else {
315                    (None, None)
316                }
317            } else {
318                (None, None)
319            };
320
321            let mut annotation = Annotation::new();
322            annotation.set_name(Some(sample_name.clone()));
323            annotation.set_object_id(Some(ann.id.to_string()));
324            annotation.set_label(Some(label.to_string()));
325            annotation.set_label_index(label_index);
326            annotation.set_box2d(Some(box2d));
327            annotation.set_polygon(polygon);
328            annotation.set_mask(mask);
329            annotation.set_group(group.map(String::from));
330            annotation.set_iscrowd(Some(ann.iscrowd != 0));
331            annotation.set_category_frequency(index.frequency(ann.category_id).map(String::from));
332
333            // Map COCO score to appropriate geometry score field
334            if let Some(score) = ann.score {
335                let score_f32 = score as f32;
336                if annotation.mask().is_some() {
337                    annotation.set_mask_score(Some(score_f32));
338                } else if annotation.polygon().is_some() {
339                    annotation.set_polygon_score(Some(score_f32));
340                } else {
341                    annotation.set_box2d_score(Some(score_f32));
342                }
343            }
344
345            let mut sample = Sample {
346                image_name: Some(sample_name.clone()),
347                width: Some(image.width),
348                height: Some(image.height),
349                group: group.map(String::from),
350                annotations: vec![annotation],
351                ..Default::default()
352            };
353            sample.neg_label_indices = neg_label_indices.clone();
354            sample.not_exhaustive_label_indices = not_exhaustive_label_indices.clone();
355
356            Some(sample)
357        })
358        .collect();
359
360    // Emit a placeholder row for any image with no annotations so the image is
361    // never dropped from the dataset. This preserves the image's group (dataset
362    // split) for every image, and carries any LVIS neg/exhaustive category data
363    // for images that have verified-negative labels but no positive annotations.
364    if samples.is_empty() {
365        let mut sample = Sample {
366            image_name: Some(sample_name.clone()),
367            width: Some(image.width),
368            height: Some(image.height),
369            group: group.map(String::from),
370            ..Default::default()
371        };
372        sample.neg_label_indices = neg_label_indices;
373        sample.not_exhaustive_label_indices = not_exhaustive_label_indices;
374        samples.push(sample);
375    }
376
377    samples
378}
379
/// Derive the sample name from an image file name.
///
/// The name is the file stem: the final path component without its last
/// extension. When the path has no stem (for example an empty string), the
/// input is returned unchanged.
fn sample_name_from_filename(filename: &str) -> String {
    let stem = Path::new(filename)
        .file_stem()
        .and_then(|os_stem| os_stem.to_str());

    match stem {
        Some(name) => name.to_owned(),
        None => filename.to_owned(),
    }
}
388
389/// Convert EdgeFirst Arrow format to COCO annotations.
390///
391/// Reads an Arrow file and produces COCO JSON output. LVIS extension fields
392/// are preserved when present in the Arrow file: `neg_category_ids`,
393/// `not_exhaustive_category_ids`, category `frequency`, annotation `iscrowd`,
394/// `supercategory`, and category metadata (`synset`, `synonyms`, `def`).
395///
396/// # Arguments
397/// * `arrow_path` - Path to EdgeFirst Arrow file
398/// * `output_path` - Output COCO JSON file path
399/// * `options` - Conversion options
400/// * `progress` - Optional progress channel
401///
402/// # Returns
403/// Number of annotations converted
404pub async fn arrow_to_coco<P: AsRef<Path>>(
405    arrow_path: P,
406    output_path: P,
407    options: &ArrowToCocoOptions,
408    progress: Option<Sender<Progress>>,
409) -> Result<usize, Error> {
410    let arrow_path = arrow_path.as_ref();
411    let output_path = output_path.as_ref();
412
413    // Read file-level metadata (must be done before consuming the reader)
414    let (schema_version, category_metadata_json, labels_metadata_json) = {
415        let mut meta_file = std::fs::File::open(arrow_path)?;
416        let mut reader = IpcReader::new(&mut meta_file);
417        let meta = reader.custom_metadata().ok().flatten();
418        let sv = meta.as_ref().and_then(|m| {
419            m.get(&PlSmallStr::from("schema_version"))
420                .map(|s| s.to_string())
421        });
422        let cm = meta.as_ref().and_then(|m| {
423            m.get(&PlSmallStr::from("category_metadata"))
424                .map(|s| s.to_string())
425        });
426        let lm = meta
427            .as_ref()
428            .and_then(|m| m.get(&PlSmallStr::from("labels")).map(|s| s.to_string()));
429        (sv, cm, lm)
430    };
431
432    // Determine format version: absent → 2025.10, present → use value
433    let is_legacy = schema_version.is_none();
434
435    // Read Arrow file
436    let mut file = std::fs::File::open(arrow_path)?;
437    let df = IpcReader::new(&mut file).finish()?;
438
439    // Get group column for filtering
440    let groups_to_filter: std::collections::HashSet<_> = options.groups.iter().cloned().collect();
441
442    let total_rows = df.height();
443
444    if let Some(ref p) = progress {
445        let _ = p
446            .send(Progress {
447                current: 0,
448                total: total_rows,
449                status: None,
450            })
451            .await;
452    }
453
454    // Extract columns - all at once for O(n) instead of O(n²) per-row access
455    let names: Vec<String> = df
456        .column("name")?
457        .str()?
458        .into_iter()
459        .map(|s| s.unwrap_or_default().to_string())
460        .collect();
461
462    let labels: Vec<String> = df
463        .column("label")
464        .ok()
465        .and_then(|c| c.cast(&DataType::String).ok())
466        .map(|c| {
467            c.str()
468                .ok()
469                .map(|s| {
470                    s.into_iter()
471                        .map(|v| v.unwrap_or_default().to_string())
472                        .collect()
473                })
474                .unwrap_or_else(|| vec![String::new(); total_rows])
475        })
476        .unwrap_or_else(|| vec![String::new(); total_rows]);
477
478    let label_indices: Vec<Option<u64>> = df
479        .column("label_index")
480        .ok()
481        .map(|c| {
482            c.u64()
483                .ok()
484                .map(|s| s.into_iter().collect())
485                .unwrap_or_else(|| vec![None; total_rows])
486        })
487        .unwrap_or_else(|| vec![None; total_rows]);
488
489    // Get group column for filtering
490    let groups: Vec<String> = df
491        .column("group")
492        .ok()
493        .and_then(|c| c.cast(&DataType::String).ok())
494        .map(|c| {
495            c.str()
496                .ok()
497                .map(|s| {
498                    s.into_iter()
499                        .map(|v| v.unwrap_or_default().to_string())
500                        .collect()
501                })
502                .unwrap_or_default()
503        })
504        .unwrap_or_else(|| vec!["".to_string(); total_rows]);
505
506    // Extract all box2d values upfront (O(n) instead of O(n²))
507    let box2ds = df
508        .column("box2d")
509        .ok()
510        .map(extract_all_box2ds)
511        .transpose()?
512        .unwrap_or_else(|| vec![[0.0; 4]; total_rows]);
513
514    // Extract segmentation data based on schema version
515    //
516    // 2025.10 (legacy): mask column is List(Float32) with NaN-separated polygon coords
517    // 2026.04+:         polygon column is List(List(Float32)), mask column is Binary (PNG)
518    let legacy_masks: Option<Vec<Vec<f32>>> = if is_legacy && options.include_masks {
519        df.column("mask").ok().map(extract_all_masks).transpose()?
520    } else {
521        None
522    };
523
524    let polygons_2026: Option<Vec<Option<PolygonRings>>> = if !is_legacy && options.include_masks {
525        df.column("polygon")
526            .ok()
527            .map(|c| extract_all_polygons(c, total_rows))
528    } else {
529        None
530    };
531
532    let mask_binary_2026: Option<Vec<Option<Vec<u8>>>> = if !is_legacy && options.include_masks {
533        df.column("mask")
534            .ok()
535            .map(|c| extract_all_binary_masks(c, total_rows))
536    } else {
537        None
538    };
539
540    // Extract all sizes upfront if present
541    let sizes = df
542        .column("size")
543        .ok()
544        .and_then(|c| extract_all_sizes(c).ok());
545
546    // Extract iscrowd column (optional, Boolean in 2026.04, UInt32 in older schemas)
547    let iscrowds: Vec<u8> = df
548        .column("iscrowd")
549        .ok()
550        .map(|c| {
551            // Try Boolean first (2026.04 schema), then fall back to UInt32 (older schemas)
552            if let Ok(bool_ca) = c.bool() {
553                bool_ca
554                    .into_iter()
555                    .map(|v| if v.unwrap_or(false) { 1 } else { 0 })
556                    .collect()
557            } else {
558                c.u32()
559                    .ok()
560                    .map(|s| s.into_iter().map(|v| v.unwrap_or(0) as u8).collect())
561                    .unwrap_or_else(|| vec![0; total_rows])
562            }
563        })
564        .unwrap_or_else(|| vec![0; total_rows]);
565
566    // Extract category_frequency column (optional, Categorical/String)
567    let category_frequencies: Vec<Option<String>> = df
568        .column("category_frequency")
569        .ok()
570        .and_then(|c| c.cast(&DataType::String).ok())
571        .map(|c| {
572            c.str()
573                .ok()
574                .map(|s| s.into_iter().map(|v| v.map(String::from)).collect())
575                .unwrap_or_else(|| vec![None; total_rows])
576        })
577        .unwrap_or_else(|| vec![None; total_rows]);
578
579    // Extract neg_label_indices column (optional, List<UInt32>)
580    let neg_label_indices: Vec<Option<Vec<u32>>> = df
581        .column("neg_label_indices")
582        .ok()
583        .map(|c| extract_list_u32_column(c, total_rows))
584        .unwrap_or_else(|| vec![None; total_rows]);
585
586    // Extract not_exhaustive_label_indices column (optional, List<UInt32>)
587    let not_exhaustive_label_indices: Vec<Option<Vec<u32>>> = df
588        .column("not_exhaustive_label_indices")
589        .ok()
590        .map(|c| extract_list_u32_column(c, total_rows))
591        .unwrap_or_else(|| vec![None; total_rows]);
592
593    // Extract score columns (2026.04 schema)
594    let box2d_scores: Vec<Option<f32>> = extract_f32_column(&df, "box2d_score", total_rows);
595    let box3d_scores: Vec<Option<f32>> = extract_f32_column(&df, "box3d_score", total_rows);
596    let polygon_scores: Vec<Option<f32>> = extract_f32_column(&df, "polygon_score", total_rows);
597    let mask_scores: Vec<Option<f32>> = extract_f32_column(&df, "mask_score", total_rows);
598
599    // Extract object_id column (optional, String) and parse to u64 where
600    // possible. This preserves the source COCO/LVIS annotation `id` across
601    // the Arrow→COCO round-trip so downstream tools (e.g., prompted-
602    // segmentation workflows that key on ann.id) see the original IDs.
603    //
604    // Non-numeric object_ids — produced by datasets whose instances carry
605    // string UUIDs rather than COCO numeric IDs — parse to None and fall
606    // through to auto-generated IDs in the builder. This is intentional:
607    // COCO requires numeric annotation IDs, and a UUID has no meaningful
608    // numeric projection.
609    let object_id_u64s: Vec<Option<u64>> = df
610        .column("object_id")
611        .ok()
612        .and_then(|c| c.cast(&DataType::String).ok())
613        .map(|c| {
614            c.str()
615                .ok()
616                .map(|s| {
617                    s.into_iter()
618                        .map(|v| v.and_then(|s| s.parse::<u64>().ok()))
619                        .collect()
620                })
621                .unwrap_or_else(|| vec![None; total_rows])
622        })
623        .unwrap_or_else(|| vec![None; total_rows]);
624
625    // Build COCO dataset
626    let mut builder = CocoDatasetBuilder::new();
627
628    if let Some(info) = &options.info {
629        builder = builder.info(info.clone());
630    }
631
632    // Group-filter predicate: returns true if this row should be skipped
633    let skip_row = |i: usize| -> bool {
634        !groups_to_filter.is_empty() && !groups_to_filter.contains(&groups[i])
635    };
636
637    // Track unique images and categories
638    let mut image_dimensions: HashMap<String, (u32, u32)> = HashMap::new();
639    let mut image_ids: HashMap<String, u64> = HashMap::new();
640    let mut category_ids: HashMap<String, u32> = HashMap::new();
641
642    // First pass: collect unique images and categories
643    for i in 0..total_rows {
644        if skip_row(i) {
645            continue;
646        }
647
648        let name = &names[i];
649        let label = &labels[i];
650
651        // Get or estimate image dimensions
652        if !image_ids.contains_key(name) {
653            let (width, height) = sizes
654                .as_ref()
655                .and_then(|s| s.get(i).copied())
656                .unwrap_or((0, 0));
657
658            let id = builder.add_image(name, width, height);
659            image_ids.insert(name.clone(), id);
660            image_dimensions.insert(name.clone(), (width, height));
661        }
662
663        if !label.is_empty() && !category_ids.contains_key(label) {
664            let id = if let Some(Some(idx)) = label_indices.get(i) {
665                builder.add_category_with_id(*idx as u32, label, None)
666            } else {
667                builder.add_category(label, None)
668            };
669            category_ids.insert(label.clone(), id);
670        }
671    }
672
673    // Second pass: create annotations
674    let mut last_progress_update = 0;
675    for i in 0..total_rows {
676        if skip_row(i) {
677            continue;
678        }
679
680        let name = &names[i];
681        let label = &labels[i];
682
683        // Skip sentinel rows (empty label = image with neg/exhaustive data but no annotations)
684        if label.is_empty() {
685            continue;
686        }
687
688        let image_id = *image_ids.get(name).unwrap_or(&0);
689        let category_id = *category_ids.get(label).unwrap_or(&0);
690        let (width, height) = *image_dimensions.get(name).unwrap_or(&(1, 1));
691
692        // Convert box2d from Arrow center-normalized [cx, cy, w, h] to COCO format
693        // Arrow stores center-point, Box2d expects top-left
694        let bbox = box2ds.get(i).map(|box2d| {
695            let cx = box2d[0];
696            let cy = box2d[1];
697            let w = box2d[2];
698            let h = box2d[3];
699            // Convert from center-point to top-left format
700            let left = cx - w / 2.0;
701            let top = cy - h / 2.0;
702            let ef_box2d = Box2d::new(left, top, w, h);
703            box2d_to_coco_bbox(&ef_box2d, width, height)
704        });
705
706        // Build segmentation based on schema version
707        let segmentation = if options.include_masks {
708            if is_legacy {
709                // 2025.10: mask column contains NaN-separated flat polygon coords
710                legacy_masks.as_ref().and_then(|m| {
711                    m.get(i).and_then(|coords| {
712                        if coords.is_empty() {
713                            None
714                        } else {
715                            let rings = crate::unflatten_polygon_coordinates(coords);
716                            let polygon = Polygon::new(rings);
717                            let coco_poly = polygon_to_coco_polygon(&polygon, width, height);
718                            if coco_poly.is_empty() {
719                                None
720                            } else {
721                                Some(CocoSegmentation::Polygon(coco_poly))
722                            }
723                        }
724                    })
725                })
726            } else {
727                // 2026.04+: try mask (Binary/PNG → RLE) first, then polygon column
728                let mask_seg = mask_binary_2026.as_ref().and_then(|masks| {
729                    masks.get(i).and_then(|opt_bytes| {
730                        opt_bytes
731                            .as_ref()
732                            .and_then(|png_bytes| png_to_rle_segmentation(png_bytes, i))
733                    })
734                });
735
736                if mask_seg.is_some() {
737                    mask_seg
738                } else {
739                    // Fall back to polygon column
740                    polygons_2026.as_ref().and_then(|polys| {
741                        polys.get(i).and_then(|opt_rings| {
742                            opt_rings.as_ref().and_then(|rings| {
743                                if rings.is_empty() {
744                                    return None;
745                                }
746                                let polygon = Polygon::new(rings.clone());
747                                let coco_poly = polygon_to_coco_polygon(&polygon, width, height);
748                                if coco_poly.is_empty() {
749                                    None
750                                } else {
751                                    Some(CocoSegmentation::Polygon(coco_poly))
752                                }
753                            })
754                        })
755                    })
756                }
757            }
758        } else {
759            None
760        };
761
762        // Determine the score: use first non-null from available score columns
763        let score: Option<f64> = mask_scores[i]
764            .or(polygon_scores[i])
765            .or(box3d_scores[i])
766            .or(box2d_scores[i])
767            .map(|s| s as f64);
768
769        if let Some(bbox) = bbox {
770            let iscrowd = iscrowds[i];
771            let ann_id = builder.add_annotation_with_id(
772                object_id_u64s[i],
773                image_id,
774                category_id,
775                bbox,
776                segmentation,
777                iscrowd,
778            );
779
780            // Set score on the annotation if present
781            if let Some(score_val) = score {
782                builder.set_annotation_score(ann_id, score_val);
783            }
784        }
785
786        // Update progress every 1000 rows to reduce overhead
787        if let Some(ref p) = progress
788            && (i - last_progress_update >= 1000 || i == total_rows - 1)
789        {
790            let _ = p
791                .send(Progress {
792                    current: i + 1,
793                    total: total_rows,
794                    status: None,
795                })
796                .await;
797            last_progress_update = i;
798        }
799    }
800
801    // Send final progress event (may not have fired if last rows were filtered)
802    if let Some(ref p) = progress
803        && last_progress_update < total_rows.saturating_sub(1)
804    {
805        let _ = p
806            .send(Progress {
807                current: total_rows,
808                total: total_rows,
809                status: None,
810            })
811            .await;
812    }
813
814    // Third pass: set LVIS image-level fields (neg/not-exhaustive category IDs)
815    // Since label_index == category_id, we can use the values directly.
816    {
817        let mut processed_images: std::collections::HashSet<u64> = std::collections::HashSet::new();
818        for i in 0..total_rows {
819            if skip_row(i) {
820                continue;
821            }
822            let name = &names[i];
823            if let Some(&image_id) = image_ids.get(name) {
824                if !processed_images.insert(image_id) {
825                    continue;
826                }
827                let neg = neg_label_indices[i].clone();
828                let not_exhaustive = not_exhaustive_label_indices[i].clone();
829                if neg.is_some() || not_exhaustive.is_some() {
830                    builder.set_image_neg_categories(image_id, neg, not_exhaustive);
831                }
832            }
833        }
834    }
835
836    // Set category frequency from the category_frequency column.
837    // Build a map of category_name -> frequency from the first occurrence.
838    {
839        let mut freq_map: HashMap<String, String> = HashMap::new();
840        for i in 0..total_rows {
841            if skip_row(i) {
842                continue;
843            }
844            let label = &labels[i];
845            if !label.is_empty()
846                && !freq_map.contains_key(label)
847                && let Some(ref freq) = category_frequencies[i]
848            {
849                freq_map.insert(label.clone(), freq.clone());
850            }
851        }
852        for (name, freq) in &freq_map {
853            builder.set_category_metadata(name, None, Some(freq.clone()), None, None);
854        }
855    }
856
857    // Set category metadata from file-level metadata JSON
858    // (id, frequency, synset, synonyms, def, supercategory).
859    // Also creates categories that exist in metadata but have no annotations
860    // (e.g., categories only referenced in neg_category_ids).
861    // set_category_metadata only updates fields that are Some, so frequency
862    // set from the column above is preserved for categories that had annotations.
863    if let Some(ref json_str) = category_metadata_json
864        && let Ok(meta) = serde_json::from_str::<HashMap<String, serde_json::Value>>(json_str)
865    {
866        for (cat_name, value) in &meta {
867            let supercategory = value.get("supercategory").and_then(|v| v.as_str());
868
869            // If this category doesn't exist yet, create it with the stored id
870            if !category_ids.contains_key(cat_name.as_str()) {
871                let cat_id = value.get("id").and_then(|v| v.as_u64()).map(|id| id as u32);
872                let id = if let Some(cat_id) = cat_id {
873                    builder.add_category_with_id(cat_id, cat_name, supercategory)
874                } else {
875                    builder.add_category(cat_name, supercategory)
876                };
877                category_ids.insert(cat_name.clone(), id);
878            } else {
879                // Category already exists — set supercategory if present in metadata
880                if let Some(sc) = supercategory {
881                    builder.set_category_supercategory(cat_name, sc);
882                }
883            }
884
885            let synset = value
886                .get("synset")
887                .and_then(|v| v.as_str())
888                .map(String::from);
889            let frequency = value
890                .get("frequency")
891                .and_then(|v| v.as_str())
892                .map(String::from);
893            let synonyms = value.get("synonyms").and_then(|v| {
894                v.as_array().map(|arr| {
895                    arr.iter()
896                        .filter_map(|s| s.as_str().map(String::from))
897                        .collect()
898                })
899            });
900            let def = value
901                .get("definition")
902                .and_then(|v| v.as_str())
903                .map(String::from);
904
905            builder.set_category_metadata(cat_name, synset, frequency, synonyms, def);
906        }
907    }
908
909    // Populate category names from labels metadata if categories weren't set
910    // from category_metadata (e.g., older files that only have labels list).
911    if category_metadata_json.is_none()
912        && let Some(ref labels_json) = labels_metadata_json
913        && let Ok(label_names) = serde_json::from_str::<Vec<String>>(labels_json)
914    {
915        for label_name in &label_names {
916            if !category_ids.contains_key(label_name) {
917                let id = builder.add_category(label_name, None);
918                category_ids.insert(label_name.clone(), id);
919            }
920        }
921    }
922
923    let dataset = builder.build();
924    let annotation_count = dataset.annotations.len();
925
926    // Write output
927    let writer = CocoWriter::new();
928    writer.write_json(&dataset, output_path)?;
929
930    Ok(annotation_count)
931}
932
933/// Extract all box2d values from a column at once (O(n) instead of O(n²)).
934fn extract_all_box2ds(col: &Column) -> Result<Vec<[f32; 4]>, Error> {
935    let arr = col.array()?;
936    let mut result = Vec::with_capacity(arr.len());
937
938    for inner in arr.amortized_iter() {
939        let values = if let Some(inner) = inner {
940            let series = inner.as_ref();
941            let vals: Vec<f32> = series
942                .f32()
943                .map_err(|e| Error::CocoError(format!("box2d cast error: {}", e)))?
944                .into_iter()
945                .map(|v| v.unwrap_or(0.0))
946                .collect();
947
948            if vals.len() == 4 {
949                [vals[0], vals[1], vals[2], vals[3]]
950            } else {
951                [0.0, 0.0, 0.0, 0.0]
952            }
953        } else {
954            [0.0, 0.0, 0.0, 0.0]
955        };
956        result.push(values);
957    }
958
959    Ok(result)
960}
961
962/// Extract all mask coordinates from a column at once (O(n) instead of O(n²)).
963fn extract_all_masks(col: &Column) -> Result<Vec<Vec<f32>>, Error> {
964    let list = col.list()?;
965    let mut result = Vec::with_capacity(list.len());
966
967    for i in 0..list.len() {
968        let coords = match list.get_as_series(i) {
969            Some(series) => series
970                .f32()
971                .map_err(|e| Error::CocoError(format!("mask cast error: {}", e)))?
972                .into_iter()
973                .map(|v| v.unwrap_or(f32::NAN))
974                .collect(),
975            None => vec![],
976        };
977        result.push(coords);
978    }
979
980    Ok(result)
981}
982
983/// Extract all image sizes from a column at once.
984fn extract_all_sizes(col: &Column) -> Result<Vec<(u32, u32)>, Error> {
985    let arr = col.array()?;
986    let mut result = Vec::with_capacity(arr.len());
987
988    for inner in arr.amortized_iter() {
989        let size = if let Some(inner) = inner {
990            let series = inner.as_ref();
991            let values: Vec<u32> = series
992                .u32()
993                .map_err(|e| Error::CocoError(format!("size cast error: {}", e)))?
994                .into_iter()
995                .map(|v| v.unwrap_or(0))
996                .collect();
997
998            if values.len() >= 2 {
999                (values[0], values[1])
1000            } else {
1001                (0, 0)
1002            }
1003        } else {
1004            (0, 0)
1005        };
1006        result.push(size);
1007    }
1008
1009    Ok(result)
1010}
1011
1012/// Extract a List<UInt32> column into a vector of optional Vec<u32>.
1013fn extract_list_u32_column(col: &Column, total_rows: usize) -> Vec<Option<Vec<u32>>> {
1014    col.list()
1015        .ok()
1016        .map(|list| {
1017            (0..list.len())
1018                .map(|i| {
1019                    list.get_as_series(i).and_then(|series| {
1020                        series
1021                            .u32()
1022                            .ok()
1023                            .map(|ca| ca.into_iter().flatten().collect::<Vec<u32>>())
1024                    })
1025                })
1026                .collect()
1027        })
1028        .unwrap_or_else(|| vec![None; total_rows])
1029}
1030
1031/// Extract polygon rings from a `List(List(Float32))` column (2026.04 schema).
1032///
1033/// Each row is an optional list of rings; each ring is a list of flat `[x, y, x, y, ...]`
1034/// coordinate pairs.
1035fn extract_all_polygons(col: &Column, total_rows: usize) -> Vec<Option<PolygonRings>> {
1036    let outer_list = match col.list() {
1037        Ok(l) => l,
1038        Err(_) => return vec![None; total_rows],
1039    };
1040
1041    let mut result = Vec::with_capacity(total_rows);
1042    for i in 0..outer_list.len() {
1043        let rings = outer_list.get_as_series(i).and_then(|ring_series| {
1044            let inner_list = ring_series.list().ok()?;
1045            let mut rings = Vec::new();
1046            for j in 0..inner_list.len() {
1047                if let Some(coords_series) = inner_list.get_as_series(j)
1048                    && let Ok(f32_ca) = coords_series.f32()
1049                {
1050                    let coords: Vec<f32> = f32_ca.into_iter().map(|v| v.unwrap_or(0.0)).collect();
1051                    // Convert flat [x, y, x, y, ...] to Vec<(f32, f32)>
1052                    let points: Vec<(f32, f32)> = coords
1053                        .chunks(2)
1054                        .filter(|c| c.len() == 2)
1055                        .map(|c| (c[0], c[1]))
1056                        .collect();
1057                    if !points.is_empty() {
1058                        rings.push(points);
1059                    }
1060                }
1061            }
1062            if rings.is_empty() { None } else { Some(rings) }
1063        });
1064        result.push(rings);
1065    }
1066    result
1067}
1068
1069/// Extract binary mask data from a `Binary` column (2026.04 schema — PNG bytes).
1070fn extract_all_binary_masks(col: &Column, total_rows: usize) -> Vec<Option<Vec<u8>>> {
1071    let binary_ca = match col.binary() {
1072        Ok(b) => b,
1073        Err(_) => return vec![None; total_rows],
1074    };
1075
1076    (0..binary_ca.len())
1077        .map(|i| binary_ca.get(i).map(|bytes| bytes.to_vec()))
1078        .collect()
1079}
1080
1081/// Extract an optional Float32 column by name.
1082fn extract_f32_column(df: &DataFrame, name: &str, total_rows: usize) -> Vec<Option<f32>> {
1083    df.column(name)
1084        .ok()
1085        .and_then(|c| c.f32().ok())
1086        .map(|ca| ca.into_iter().collect())
1087        .unwrap_or_else(|| vec![None; total_rows])
1088}
1089
1090/// Decode a PNG mask (Binary column bytes) into COCO RLE segmentation.
1091///
1092/// Validates the PNG, decodes pixels, binarizes if needed (8-bit or 16-bit),
1093/// and encodes as COCO RLE. Returns `None` for empty or invalid data (with
1094/// a warning log for invalid cases).
1095fn png_to_rle_segmentation(png_bytes: &[u8], row_index: usize) -> Option<CocoSegmentation> {
1096    if png_bytes.is_empty() {
1097        return None;
1098    }
1099
1100    let mask_data = match crate::MaskData::from_png_checked(png_bytes.to_vec()) {
1101        Ok(m) => m,
1102        Err(e) => {
1103            log::warn!("Skipping invalid PNG mask at row {}: {}", row_index, e);
1104            return None;
1105        }
1106    };
1107
1108    let mw = mask_data.width();
1109    let mh = mask_data.height();
1110    let bit_depth = mask_data.bit_depth();
1111
1112    let decoded = match mask_data.decode() {
1113        Ok(d) => d,
1114        Err(e) => {
1115            log::warn!("Failed to decode PNG mask at row {}: {}", row_index, e);
1116            return None;
1117        }
1118    };
1119
1120    let binary_mask = match bit_depth {
1121        1 => decoded,
1122        8 => {
1123            log::warn!(
1124                "Binarizing 8-bit mask for row {} — score data is lost",
1125                row_index
1126            );
1127            decoded
1128                .iter()
1129                .map(|&v| if v >= 128 { 1 } else { 0 })
1130                .collect()
1131        }
1132        16 => {
1133            log::warn!(
1134                "Binarizing 16-bit mask for row {} — score data is lost",
1135                row_index
1136            );
1137            decoded
1138                .chunks(2)
1139                .map(|pair| {
1140                    let val = if pair.len() == 2 {
1141                        u16::from_be_bytes([pair[0], pair[1]])
1142                    } else {
1143                        0
1144                    };
1145                    if val >= 32768 { 1u8 } else { 0u8 }
1146                })
1147                .collect()
1148        }
1149        _ => decoded,
1150    };
1151
1152    match super::convert::encode_rle(&binary_mask, mw, mh) {
1153        Ok(rle) => Some(CocoSegmentation::Rle(rle)),
1154        Err(e) => {
1155            log::warn!("Failed to encode RLE for row {}: {}", row_index, e);
1156            None
1157        }
1158    }
1159}
1160
1161#[cfg(test)]
1162mod tests {
1163    use super::*;
1164    use crate::coco::{CocoAnnotation, CocoCategory, CocoDataset};
1165    use tempfile::TempDir;
1166
1167    // =========================================================================
1168    // unflatten_polygon_coords tests
1169    // =========================================================================
1170
/// An empty coordinate buffer produces no polygons.
#[test]
fn test_unflatten_polygon_coords_empty() {
    let no_coords = Vec::<f32>::new();
    let polygons = crate::unflatten_polygon_coordinates(&no_coords);
    assert!(polygons.is_empty());
}
1177
/// A separator-free buffer is returned as a single ring.
#[test]
fn test_unflatten_polygon_coords_single_polygon() {
    // Four vertices of an axis-aligned rectangle.
    let rectangle = vec![0.1, 0.2, 0.3, 0.2, 0.3, 0.4, 0.1, 0.4];
    let polygons = crate::unflatten_polygon_coordinates(&rectangle);

    assert_eq!(polygons.len(), 1);

    let ring = &polygons[0];
    assert_eq!(ring.len(), 4);
    assert_eq!(ring[0], (0.1, 0.2));
    assert_eq!(ring[3], (0.1, 0.4));
}
1189
/// A NaN inside the buffer splits it into separate rings.
#[test]
fn test_unflatten_polygon_coords_multiple_polygons() {
    // First triangle, NaN separator, second triangle.
    let mut flat = vec![0.1, 0.1, 0.2, 0.1, 0.15, 0.2];
    flat.push(f32::NAN);
    flat.extend([0.5, 0.5, 0.6, 0.5, 0.55, 0.6]);

    let polygons = crate::unflatten_polygon_coordinates(&flat);

    assert_eq!(polygons.len(), 2);

    let (first, second) = (&polygons[0], &polygons[1]);
    assert_eq!(first.len(), 3);
    assert_eq!(second.len(), 3);
    assert_eq!(first[0], (0.1, 0.1));
    assert_eq!(second[0], (0.5, 0.5));
}
1216
/// A separator at the very start must not produce an extra, empty ring.
#[test]
fn test_unflatten_polygon_coords_leading_nan() {
    let flat = vec![f32::NAN, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6];
    let polygons = crate::unflatten_polygon_coordinates(&flat);

    assert_eq!(polygons.len(), 1);
    assert_eq!(polygons[0].len(), 3);
}
1226
/// A separator at the very end must not produce an extra, empty ring.
#[test]
fn test_unflatten_polygon_coords_trailing_nan() {
    let flat = vec![0.1, 0.2, 0.3, 0.4, f32::NAN];
    let polygons = crate::unflatten_polygon_coordinates(&flat);

    assert_eq!(polygons.len(), 1);
    assert_eq!(polygons[0].len(), 2);
}
1236
/// Back-to-back separators behave like a single separator.
#[test]
fn test_unflatten_polygon_coords_consecutive_nans() {
    let flat = vec![0.1, 0.2, f32::NAN, f32::NAN, 0.3, 0.4];
    let polygons = crate::unflatten_polygon_coordinates(&flat);

    assert_eq!(polygons.len(), 2);
    for ring in [&polygons[0], &polygons[1]] {
        assert_eq!(ring.len(), 1);
    }
}
1247
/// A trailing x value with no matching y is dropped.
#[test]
fn test_unflatten_polygon_coords_odd_values() {
    let flat = vec![0.1, 0.2, 0.3, 0.4, 0.5];
    let polygons = crate::unflatten_polygon_coordinates(&flat);

    assert_eq!(polygons.len(), 1);
    // Two complete pairs survive; the fifth value is discarded.
    assert_eq!(polygons[0].len(), 2);
}
1257
1258    // =========================================================================
1259    // convert_image_annotations tests
1260    // =========================================================================
1261
/// One image with one box annotation converts into one sample carrying the
/// expected name, group, label and object id.
#[test]
fn test_convert_image_annotations_basic() {
    let image = CocoImage {
        id: 1,
        width: 640,
        height: 480,
        file_name: "test_image.jpg".to_string(),
        ..Default::default()
    };
    let category = CocoCategory {
        id: 1,
        name: "cat".to_string(),
        supercategory: Some("animal".to_string()),
        ..Default::default()
    };
    let annotation = CocoAnnotation {
        id: 42,
        image_id: 1,
        category_id: 1,
        bbox: [100.0, 100.0, 200.0, 200.0],
        area: 40000.0,
        iscrowd: 0,
        segmentation: None,
        score: None,
    };
    let dataset = CocoDataset {
        images: vec![image.clone()],
        categories: vec![category],
        annotations: vec![annotation],
        ..Default::default()
    };

    let index = CocoIndex::from_dataset(&dataset);
    let samples = convert_image_annotations(&image, &index, true, Some("train"));

    assert_eq!(samples.len(), 1);

    let sample = &samples[0];
    assert_eq!(sample.image_name, Some("test_image".to_string()));
    assert_eq!(sample.group, Some("train".to_string()));
    assert_eq!(sample.annotations.len(), 1);

    let converted = &sample.annotations[0];
    assert_eq!(converted.label(), Some(&"cat".to_string()));
    assert_eq!(
        converted.object_id(),
        Some(&"42".to_string()),
        "object_id must be populated from COCO annotation id to enable \
         prediction-to-prompt linking in prompted-segmentation workflows",
    );
}
1308
/// The `include_masks` flag decides whether the converted annotation carries
/// a polygon.
#[test]
fn test_convert_image_annotations_with_mask() {
    let image = CocoImage {
        id: 1,
        width: 100,
        height: 100,
        file_name: "masked.jpg".to_string(),
        ..Default::default()
    };
    let category = CocoCategory {
        id: 1,
        name: "object".to_string(),
        supercategory: None,
        ..Default::default()
    };
    let square = vec![10.0, 10.0, 60.0, 10.0, 60.0, 60.0, 10.0, 60.0];
    let annotation = CocoAnnotation {
        id: 1,
        image_id: 1,
        category_id: 1,
        bbox: [10.0, 10.0, 50.0, 50.0],
        area: 2500.0,
        iscrowd: 0,
        segmentation: Some(CocoSegmentation::Polygon(vec![square])),
        score: None,
    };
    let dataset = CocoDataset {
        images: vec![image.clone()],
        categories: vec![category],
        annotations: vec![annotation],
        ..Default::default()
    };

    let index = CocoIndex::from_dataset(&dataset);

    // Masks enabled first, then disabled: the polygon is present only when enabled.
    for include_masks in [true, false] {
        let samples = convert_image_annotations(&image, &index, include_masks, None);
        assert_eq!(
            samples[0].annotations[0].polygon().is_some(),
            include_masks
        );
    }
}
1352
/// Guards the u64 annotation id against truncation on its way into
/// `Annotation.object_id`.
#[test]
fn test_convert_image_annotations_object_id_from_lvis_large_id() {
    // LVIS v1.0 annotation ids are u64 and routinely exceed the 32-bit range
    // (the public release goes well past 2 billion), so use an id that does
    // not fit in a u32. Any silent truncation between CocoAnnotation.id (u64)
    // and Annotation.object_id (String) would change the expected string.
    let large_id: u64 = 9_876_543_210;

    let image = CocoImage {
        id: 397133,
        width: 640,
        height: 480,
        file_name: "000000397133.jpg".to_string(),
        ..Default::default()
    };
    let category = CocoCategory {
        id: 16,
        name: "dog".to_string(),
        synset: Some("dog.n.01".to_string()),
        frequency: Some("f".to_string()),
        ..Default::default()
    };
    let annotation = CocoAnnotation {
        id: large_id,
        image_id: 397133,
        category_id: 16,
        bbox: [192.81, 224.8, 74.73, 33.43],
        area: 1035.7,
        iscrowd: 0,
        segmentation: None,
        score: None,
    };
    let dataset = CocoDataset {
        images: vec![image.clone()],
        categories: vec![category],
        annotations: vec![annotation],
        ..Default::default()
    };

    let index = CocoIndex::from_dataset(&dataset);
    let samples = convert_image_annotations(&image, &index, true, None);

    assert_eq!(samples.len(), 1);

    let converted = &samples[0].annotations;
    assert_eq!(converted.len(), 1);
    assert_eq!(converted[0].object_id(), Some(&large_id.to_string()));
}
1400
/// An image without annotations still yields exactly one (empty) sample.
#[test]
fn test_convert_image_annotations_no_annotations() {
    let image = CocoImage {
        id: 1,
        width: 640,
        height: 480,
        file_name: "empty.jpg".to_string(),
        ..Default::default()
    };
    let dataset = CocoDataset {
        images: vec![image.clone()],
        categories: Vec::new(),
        annotations: Vec::new(),
        ..Default::default()
    };

    let index = CocoIndex::from_dataset(&dataset);
    let samples = convert_image_annotations(&image, &index, true, None);

    // A placeholder row must be emitted so the image is never dropped from
    // the dataset.
    assert_eq!(samples.len(), 1);

    let placeholder = &samples[0];
    assert_eq!(placeholder.image_name, Some("empty".to_string()));
    assert!(placeholder.annotations.is_empty());
    assert_eq!(placeholder.group, None);
}
1428
1429    // =========================================================================
1430    // sample_name_from_filename tests
1431    // =========================================================================
1432
/// Extension and directory prefix are stripped; bare names pass through.
#[test]
fn test_sample_name_from_filename() {
    let cases = [
        ("000000397133.jpg", "000000397133"),
        ("train2017/image.jpg", "image"),
        ("test", "test"),
    ];

    for (file_name, expected) in cases {
        assert_eq!(sample_name_from_filename(file_name), expected);
    }
}
1442
/// Only the final path component contributes to the sample name.
#[test]
fn test_sample_name_from_filename_nested_path() {
    let name = sample_name_from_filename("a/b/c/deep_image.png");
    assert_eq!(name, "deep_image");
}
1450
/// A file name without an extension is returned unchanged.
#[test]
fn test_sample_name_from_filename_no_extension() {
    let name = sample_name_from_filename("no_extension");
    assert_eq!(name, "no_extension");
}
1455
1456    // =========================================================================
1457    // Options tests
1458    // =========================================================================
1459
/// Defaults: masks included, no group, at least two workers.
#[test]
fn test_coco_to_arrow_options_default() {
    let CocoToArrowOptions {
        include_masks,
        group,
        max_workers,
    } = CocoToArrowOptions::default();

    assert!(include_masks);
    assert!(group.is_none());
    assert!(max_workers >= 2);
}
1467
/// Defaults: no group filter, masks included, no dataset info.
#[test]
fn test_arrow_to_coco_options_default() {
    let defaults = ArrowToCocoOptions::default();

    assert!(defaults.groups.is_empty());
    assert!(defaults.include_masks);
    assert!(defaults.info.is_none());
}
1475
/// The worker count stays within the inclusive range 2 to 8.
#[test]
fn test_max_workers() {
    let count = max_workers();
    assert!((2..=8).contains(&count));
}
1482
1483    #[tokio::test]
1484    async fn test_coco_to_arrow_minimal() {
1485        let temp_dir = TempDir::new().unwrap();
1486
1487        // Create minimal COCO JSON
1488        let coco_json = r#"{
1489            "images": [
1490                {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
1491            ],
1492            "annotations": [
1493                {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0}
1494            ],
1495            "categories": [
1496                {"id": 1, "name": "person", "supercategory": "human"}
1497            ]
1498        }"#;
1499
1500        let coco_path = temp_dir.path().join("test.json");
1501        std::fs::write(&coco_path, coco_json).unwrap();
1502
1503        let arrow_path = temp_dir.path().join("output.arrow");
1504
1505        let options = CocoToArrowOptions::default();
1506        let count = coco_to_arrow(&coco_path, &arrow_path, &options, None)
1507            .await
1508            .unwrap();
1509
1510        assert_eq!(count, 1);
1511        assert!(arrow_path.exists());
1512
1513        // Verify Arrow contents
1514        let mut file = std::fs::File::open(&arrow_path).unwrap();
1515        let df = IpcReader::new(&mut file).finish().unwrap();
1516        assert_eq!(df.height(), 1);
1517    }
1518
1519    #[tokio::test]
1520    async fn test_arrow_to_coco_roundtrip() {
1521        let temp_dir = TempDir::new().unwrap();
1522
1523        // Create COCO JSON
1524        let original = CocoDataset {
1525            images: vec![CocoImage {
1526                id: 1,
1527                width: 640,
1528                height: 480,
1529                file_name: "test.jpg".to_string(),
1530                ..Default::default()
1531            }],
1532            annotations: vec![CocoAnnotation {
1533                id: 1,
1534                image_id: 1,
1535                category_id: 1,
1536                bbox: [100.0, 50.0, 200.0, 150.0],
1537                area: 30000.0,
1538                iscrowd: 0,
1539                segmentation: Some(CocoSegmentation::Polygon(vec![vec![
1540                    100.0, 50.0, 300.0, 50.0, 300.0, 200.0, 100.0, 200.0,
1541                ]])),
1542                score: None,
1543            }],
1544            categories: vec![CocoCategory {
1545                id: 1,
1546                name: "person".to_string(),
1547                supercategory: Some("human".to_string()),
1548                ..Default::default()
1549            }],
1550            ..Default::default()
1551        };
1552
1553        // Write original COCO
1554        let coco_path = temp_dir.path().join("original.json");
1555        let writer = CocoWriter::new();
1556        writer.write_json(&original, &coco_path).unwrap();
1557
1558        // Convert to Arrow
1559        let arrow_path = temp_dir.path().join("converted.arrow");
1560        let options = CocoToArrowOptions::default();
1561        coco_to_arrow(&coco_path, &arrow_path, &options, None)
1562            .await
1563            .unwrap();
1564
1565        // Convert back to COCO
1566        let restored_path = temp_dir.path().join("restored.json");
1567        let options = ArrowToCocoOptions::default();
1568        arrow_to_coco(&arrow_path, &restored_path, &options, None)
1569            .await
1570            .unwrap();
1571
1572        // Verify restored data
1573        let reader = CocoReader::new();
1574        let restored = reader.read_json(&restored_path).unwrap();
1575
1576        assert_eq!(restored.images.len(), 1);
1577        assert_eq!(restored.annotations.len(), 1);
1578        assert_eq!(restored.categories.len(), 1);
1579
1580        // Check category name preserved
1581        assert_eq!(restored.categories[0].name, "person");
1582    }
1583
1584    #[tokio::test]
1585    async fn test_arrow_to_coco_roundtrip_preserves_annotation_id() {
1586        // Asserts the COCO/LVIS annotation `id` survives the full
1587        // JSON → Arrow → JSON round-trip. The IDs deliberately mix a
1588        // small value (1) with a 33-bit value (9_876_543_210) to catch
1589        // any future regression that silently truncates to u32 along
1590        // the path.
1591        let temp_dir = TempDir::new().unwrap();
1592
1593        let large_id: u64 = 9_876_543_210;
1594        let original = CocoDataset {
1595            images: vec![CocoImage {
1596                id: 1,
1597                width: 640,
1598                height: 480,
1599                file_name: "test.jpg".to_string(),
1600                ..Default::default()
1601            }],
1602            annotations: vec![
1603                CocoAnnotation {
1604                    id: 1,
1605                    image_id: 1,
1606                    category_id: 1,
1607                    bbox: [10.0, 20.0, 100.0, 80.0],
1608                    area: 8000.0,
1609                    iscrowd: 0,
1610                    segmentation: None,
1611                    score: None,
1612                },
1613                CocoAnnotation {
1614                    id: large_id,
1615                    image_id: 1,
1616                    category_id: 1,
1617                    bbox: [200.0, 200.0, 100.0, 100.0],
1618                    area: 10000.0,
1619                    iscrowd: 0,
1620                    segmentation: None,
1621                    score: None,
1622                },
1623            ],
1624            categories: vec![CocoCategory {
1625                id: 1,
1626                name: "person".to_string(),
1627                supercategory: Some("human".to_string()),
1628                ..Default::default()
1629            }],
1630            ..Default::default()
1631        };
1632
1633        let coco_path = temp_dir.path().join("original.json");
1634        let writer = CocoWriter::new();
1635        writer.write_json(&original, &coco_path).unwrap();
1636
1637        let arrow_path = temp_dir.path().join("converted.arrow");
1638        coco_to_arrow(
1639            &coco_path,
1640            &arrow_path,
1641            &CocoToArrowOptions::default(),
1642            None,
1643        )
1644        .await
1645        .unwrap();
1646
1647        let restored_path = temp_dir.path().join("restored.json");
1648        arrow_to_coco(
1649            &arrow_path,
1650            &restored_path,
1651            &ArrowToCocoOptions::default(),
1652            None,
1653        )
1654        .await
1655        .unwrap();
1656
1657        let restored = CocoReader::new().read_json(&restored_path).unwrap();
1658        assert_eq!(restored.annotations.len(), 2);
1659
1660        let restored_ids: std::collections::HashSet<u64> =
1661            restored.annotations.iter().map(|a| a.id).collect();
1662        assert!(
1663            restored_ids.contains(&1),
1664            "small annotation id (1) must round-trip; got {restored_ids:?}"
1665        );
1666        assert!(
1667            restored_ids.contains(&large_id),
1668            "33-bit LVIS-scale annotation id ({large_id}) must round-trip; got {restored_ids:?}"
1669        );
1670    }
1671
1672    // =========================================================================
1673    // Arrow IPC file metadata tests
1674    // =========================================================================
1675
1676    #[tokio::test]
1677    async fn test_coco_to_arrow_schema_version_metadata() {
1678        let temp_dir = TempDir::new().unwrap();
1679
1680        // Create minimal COCO JSON (no LVIS fields)
1681        let coco_json = r#"{
1682            "images": [
1683                {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
1684            ],
1685            "annotations": [
1686                {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0}
1687            ],
1688            "categories": [
1689                {"id": 1, "name": "person", "supercategory": "human"}
1690            ]
1691        }"#;
1692
1693        let coco_path = temp_dir.path().join("test.json");
1694        std::fs::write(&coco_path, coco_json).unwrap();
1695
1696        let arrow_path = temp_dir.path().join("output.arrow");
1697        let options = CocoToArrowOptions::default();
1698        coco_to_arrow(&coco_path, &arrow_path, &options, None)
1699            .await
1700            .unwrap();
1701
1702        // Read back and verify schema_version metadata
1703        let mut file = std::fs::File::open(&arrow_path).unwrap();
1704        let mut reader = IpcReader::new(&mut file);
1705        let custom_meta = reader.custom_metadata().unwrap();
1706        assert!(custom_meta.is_some(), "custom metadata should be present");
1707
1708        let meta = custom_meta.unwrap();
1709        assert_eq!(
1710            meta.get(&PlSmallStr::from("schema_version")),
1711            Some(&PlSmallStr::from(SCHEMA_VERSION)),
1712            "schema_version metadata should be '2026.04'"
1713        );
1714
1715        // category_metadata is always present when there are categories
1716        assert!(
1717            meta.contains_key(&PlSmallStr::from("category_metadata")),
1718            "category_metadata should be present even without LVIS fields"
1719        );
1720    }
1721
1722    #[tokio::test]
1723    async fn test_coco_to_arrow_category_metadata_lvis() {
1724        let temp_dir = TempDir::new().unwrap();
1725
1726        // Create COCO JSON with LVIS category fields
1727        let coco_json = r#"{
1728            "images": [
1729                {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
1730            ],
1731            "annotations": [
1732                {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0},
1733                {"id": 2, "image_id": 1, "category_id": 2, "bbox": [50, 60, 80, 40], "area": 3200, "iscrowd": 0}
1734            ],
1735            "categories": [
1736                {
1737                    "id": 1,
1738                    "name": "aerosol_can",
1739                    "synset": "aerosol.n.02",
1740                    "synonyms": ["aerosol_can", "spray_can"],
1741                    "def": "a dispenser that holds a substance under pressure"
1742                },
1743                {
1744                    "id": 2,
1745                    "name": "person",
1746                    "supercategory": "human"
1747                }
1748            ]
1749        }"#;
1750
1751        let coco_path = temp_dir.path().join("lvis.json");
1752        std::fs::write(&coco_path, coco_json).unwrap();
1753
1754        let arrow_path = temp_dir.path().join("lvis_output.arrow");
1755        let options = CocoToArrowOptions::default();
1756        coco_to_arrow(&coco_path, &arrow_path, &options, None)
1757            .await
1758            .unwrap();
1759
1760        // Read back and verify metadata
1761        let mut file = std::fs::File::open(&arrow_path).unwrap();
1762        let mut reader = IpcReader::new(&mut file);
1763        let custom_meta = reader.custom_metadata().unwrap();
1764        assert!(custom_meta.is_some(), "custom metadata should be present");
1765
1766        let meta = custom_meta.unwrap();
1767
1768        // schema_version is always present
1769        assert_eq!(
1770            meta.get(&PlSmallStr::from("schema_version")),
1771            Some(&PlSmallStr::from(SCHEMA_VERSION)),
1772        );
1773
1774        // category_metadata should be present (aerosol_can has LVIS fields)
1775        let cat_meta_str = meta
1776            .get(&PlSmallStr::from("category_metadata"))
1777            .expect("category_metadata should be present for LVIS data");
1778
1779        let cat_meta: HashMap<String, serde_json::Value> =
1780            serde_json::from_str(cat_meta_str.as_str()).unwrap();
1781
1782        // Both categories should be present (all categories are now stored)
1783        assert!(
1784            cat_meta.contains_key("aerosol_can"),
1785            "aerosol_can should be in category_metadata"
1786        );
1787        assert!(
1788            cat_meta.contains_key("person"),
1789            "person should also be in category_metadata"
1790        );
1791
1792        // Verify aerosol_can entry contents
1793        let aerosol = cat_meta.get("aerosol_can").unwrap();
1794        assert_eq!(
1795            aerosol.get("synset").and_then(|v| v.as_str()),
1796            Some("aerosol.n.02")
1797        );
1798        assert_eq!(
1799            aerosol.get("definition").and_then(|v| v.as_str()),
1800            Some("a dispenser that holds a substance under pressure")
1801        );
1802        let synonyms = aerosol.get("synonyms").and_then(|v| v.as_array()).unwrap();
1803        assert_eq!(synonyms.len(), 2);
1804        assert_eq!(synonyms[0].as_str(), Some("aerosol_can"));
1805        assert_eq!(synonyms[1].as_str(), Some("spray_can"));
1806    }
1807
1808    // =========================================================================
1809    // LVIS round-trip tests
1810    // =========================================================================
1811
1812    #[tokio::test]
1813    async fn test_coco_arrow_roundtrip_lvis_supercategory() {
1814        let temp_dir = TempDir::new().unwrap();
1815
1816        // Create COCO JSON with supercategory
1817        let coco_json = r#"{
1818            "images": [
1819                {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
1820            ],
1821            "annotations": [
1822                {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0}
1823            ],
1824            "categories": [
1825                {"id": 1, "name": "person", "supercategory": "human"}
1826            ]
1827        }"#;
1828
1829        let coco_path = temp_dir.path().join("original.json");
1830        std::fs::write(&coco_path, coco_json).unwrap();
1831
1832        // Convert to Arrow
1833        let arrow_path = temp_dir.path().join("converted.arrow");
1834        let options = CocoToArrowOptions::default();
1835        coco_to_arrow(&coco_path, &arrow_path, &options, None)
1836            .await
1837            .unwrap();
1838
1839        // Convert back to COCO
1840        let restored_path = temp_dir.path().join("restored.json");
1841        let options = ArrowToCocoOptions::default();
1842        arrow_to_coco(&arrow_path, &restored_path, &options, None)
1843            .await
1844            .unwrap();
1845
1846        // Verify supercategory is preserved
1847        let reader = CocoReader::new();
1848        let restored = reader.read_json(&restored_path).unwrap();
1849
1850        assert_eq!(restored.categories.len(), 1);
1851        assert_eq!(restored.categories[0].name, "person");
1852        assert_eq!(
1853            restored.categories[0].supercategory,
1854            Some("human".to_string()),
1855            "supercategory should survive COCO→Arrow→COCO round-trip"
1856        );
1857    }
1858
1859    #[tokio::test]
1860    async fn test_coco_arrow_roundtrip_neg_categories_no_annotations() {
1861        let temp_dir = TempDir::new().unwrap();
1862
1863        // Create COCO JSON: image has neg_category_ids but NO annotations
1864        let coco_json = r#"{
1865            "images": [
1866                {
1867                    "id": 1,
1868                    "width": 640,
1869                    "height": 480,
1870                    "file_name": "empty.jpg",
1871                    "neg_category_ids": [1, 2]
1872                }
1873            ],
1874            "annotations": [],
1875            "categories": [
1876                {"id": 1, "name": "cat", "supercategory": "animal"},
1877                {"id": 2, "name": "dog", "supercategory": "animal"}
1878            ]
1879        }"#;
1880
1881        let coco_path = temp_dir.path().join("original.json");
1882        std::fs::write(&coco_path, coco_json).unwrap();
1883
1884        // Convert to Arrow
1885        let arrow_path = temp_dir.path().join("converted.arrow");
1886        let options = CocoToArrowOptions::default();
1887        let sample_count = coco_to_arrow(&coco_path, &arrow_path, &options, None)
1888            .await
1889            .unwrap();
1890
1891        // Should have 1 sentinel sample (image with neg data but no annotations)
1892        assert_eq!(
1893            sample_count, 1,
1894            "sentinel row should be emitted for image with neg data"
1895        );
1896
1897        // Convert back to COCO
1898        let restored_path = temp_dir.path().join("restored.json");
1899        let options = ArrowToCocoOptions::default();
1900        arrow_to_coco(&arrow_path, &restored_path, &options, None)
1901            .await
1902            .unwrap();
1903
1904        // Verify neg_category_ids survived the round-trip
1905        let reader = CocoReader::new();
1906        let restored = reader.read_json(&restored_path).unwrap();
1907
1908        assert_eq!(restored.images.len(), 1);
1909        assert_eq!(restored.annotations.len(), 0, "no annotations expected");
1910        assert_eq!(restored.categories.len(), 2, "both categories should exist");
1911
1912        let neg = restored.images[0].neg_category_ids.as_ref();
1913        assert!(
1914            neg.is_some(),
1915            "neg_category_ids should survive round-trip for zero-annotation image"
1916        );
1917        let neg_ids = neg.unwrap();
1918        assert_eq!(neg_ids.len(), 2, "should have 2 neg categories");
1919        assert!(neg_ids.contains(&1), "neg_category_ids should contain 1");
1920        assert!(neg_ids.contains(&2), "neg_category_ids should contain 2");
1921
1922        // Verify supercategory survives for annotation-free categories
1923        for cat in &restored.categories {
1924            assert_eq!(
1925                cat.supercategory,
1926                Some("animal".to_string()),
1927                "supercategory should survive round-trip for annotation-free category '{}'",
1928                cat.name
1929            );
1930        }
1931    }
1932
1933    #[test]
1934    fn test_convert_image_annotations_neg_only_no_annotations() {
1935        let image = CocoImage {
1936            id: 1,
1937            width: 640,
1938            height: 480,
1939            file_name: "neg_only.jpg".to_string(),
1940            neg_category_ids: Some(vec![1, 2]),
1941            ..Default::default()
1942        };
1943
1944        let dataset = CocoDataset {
1945            images: vec![image.clone()],
1946            categories: vec![
1947                CocoCategory {
1948                    id: 1,
1949                    name: "cat".to_string(),
1950                    supercategory: Some("animal".to_string()),
1951                    ..Default::default()
1952                },
1953                CocoCategory {
1954                    id: 2,
1955                    name: "dog".to_string(),
1956                    supercategory: Some("animal".to_string()),
1957                    ..Default::default()
1958                },
1959            ],
1960            annotations: vec![],
1961            ..Default::default()
1962        };
1963
1964        let index = CocoIndex::from_dataset(&dataset);
1965        let samples = convert_image_annotations(&image, &index, true, None);
1966
1967        // Should emit 1 sentinel sample (no annotations but has neg data)
1968        assert_eq!(
1969            samples.len(),
1970            1,
1971            "sentinel row should be emitted for neg-only image"
1972        );
1973        assert_eq!(samples[0].image_name, Some("neg_only".to_string()));
1974        assert!(
1975            samples[0].annotations.is_empty(),
1976            "sentinel should have no annotations"
1977        );
1978        assert!(
1979            samples[0].neg_label_indices.is_some(),
1980            "sentinel should preserve neg_label_indices"
1981        );
1982        assert_eq!(samples[0].neg_label_indices.as_ref().unwrap().len(), 2);
1983    }
1984
1985    #[test]
1986    fn test_convert_image_annotations_no_annotations_emits_placeholder() {
1987        // A plain image with NO annotations and NO LVIS neg/exhaustive fields
1988        // must still emit one placeholder sample so the image is never dropped
1989        // and its dataset split (group) is preserved.
1990        let image = CocoImage {
1991            id: 1,
1992            width: 640,
1993            height: 480,
1994            file_name: "empty.jpg".to_string(),
1995            ..Default::default()
1996        };
1997
1998        let dataset = CocoDataset {
1999            images: vec![image.clone()],
2000            categories: vec![CocoCategory {
2001                id: 1,
2002                name: "person".to_string(),
2003                ..Default::default()
2004            }],
2005            annotations: vec![],
2006            ..Default::default()
2007        };
2008
2009        let index = CocoIndex::from_dataset(&dataset);
2010        let samples = convert_image_annotations(&image, &index, true, Some("train"));
2011
2012        assert_eq!(
2013            samples.len(),
2014            1,
2015            "placeholder row must be emitted for an unannotated image"
2016        );
2017        assert_eq!(samples[0].image_name, Some("empty".to_string()));
2018        assert!(
2019            samples[0].annotations.is_empty(),
2020            "placeholder should have no annotations"
2021        );
2022        assert_eq!(
2023            samples[0].group,
2024            Some("train".to_string()),
2025            "group must be preserved on the placeholder row"
2026        );
2027    }
2028
2029    #[tokio::test]
2030    async fn test_coco_to_arrow_includes_unannotated_images() {
2031        // coco-to-arrow must emit one row per image even when some images have
2032        // no annotations, so dataset splits (group) cover every image.
2033        let temp_dir = TempDir::new().unwrap();
2034
2035        let coco_json = r#"{
2036            "images": [
2037                {"id": 1, "width": 640, "height": 480, "file_name": "annotated.jpg"},
2038                {"id": 2, "width": 640, "height": 480, "file_name": "empty.jpg"}
2039            ],
2040            "annotations": [
2041                {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0}
2042            ],
2043            "categories": [
2044                {"id": 1, "name": "person", "supercategory": "human"}
2045            ]
2046        }"#;
2047
2048        let coco_path = temp_dir.path().join("test.json");
2049        std::fs::write(&coco_path, coco_json).unwrap();
2050        let arrow_path = temp_dir.path().join("out.arrow");
2051
2052        let options = CocoToArrowOptions {
2053            group: Some("train".to_string()),
2054            ..Default::default()
2055        };
2056        let count = coco_to_arrow(&coco_path, &arrow_path, &options, None)
2057            .await
2058            .unwrap();
2059
2060        // 1 annotation row + 1 placeholder row for the unannotated image.
2061        assert_eq!(count, 2, "every image must produce at least one row");
2062
2063        let mut file = std::fs::File::open(&arrow_path).unwrap();
2064        let df = IpcReader::new(&mut file).finish().unwrap();
2065        assert_eq!(df.height(), 2);
2066
2067        // The unannotated image must appear with its group set and a null label.
2068        let names = df.column("name").unwrap().str().unwrap();
2069        let empty_row = (0..df.height()).find(|&i| names.get(i) == Some("empty"));
2070        assert!(
2071            empty_row.is_some(),
2072            "unannotated image 'empty' must appear in the Arrow output"
2073        );
2074        let i = empty_row.unwrap();
2075        let group_col = df.column("group").unwrap().cast(&DataType::String).unwrap();
2076        assert_eq!(
2077            group_col.str().unwrap().get(i),
2078            Some("train"),
2079            "group must be set on the unannotated image's row"
2080        );
2081        let label_col = df.column("label").unwrap().cast(&DataType::String).unwrap();
2082        assert_eq!(
2083            label_col.str().unwrap().get(i),
2084            None,
2085            "unannotated image row must have a null label"
2086        );
2087    }
2088}