1use super::{
10 convert::{
11 box2d_to_coco_bbox, coco_bbox_to_box2d, coco_segmentation_to_mask_data,
12 coco_segmentation_to_polygon, polygon_to_coco_polygon,
13 },
14 reader::CocoReader,
15 types::{CocoImage, CocoIndex, CocoInfo, CocoSegmentation},
16 writer::{CocoDatasetBuilder, CocoWriter},
17};
18use crate::{Annotation, Box2d, Error, Polygon, Progress, Sample};
19use polars::prelude::*;
20use std::{
21 collections::{BTreeMap, HashMap},
22 path::Path,
23 sync::{
24 Arc,
25 atomic::{AtomicUsize, Ordering},
26 },
27};
28use tokio::sync::{Semaphore, mpsc::Sender};
29
/// Version tag stored under the `schema_version` key of the Arrow schema metadata
/// written by `coco_to_arrow`.
///
/// `arrow_to_coco` treats files that lack this key as the legacy layout.
pub const SCHEMA_VERSION: &str = "2026.04";

/// Polygon rings: one inner `Vec` of `(x, y)` points per ring.
type PolygonRings = Vec<Vec<(f32, f32)>>;
35
/// Tuning knobs for `coco_to_arrow`.
#[derive(Debug, Clone)]
pub struct CocoToArrowOptions {
    /// When `true`, COCO segmentations are converted as well: polygons become
    /// annotation polygons and RLE segmentations become masks.
    pub include_masks: bool,
    /// Group name stamped onto every produced sample and annotation.
    pub group: Option<String>,
    /// Maximum number of images converted concurrently.
    pub max_workers: usize,
}
46
47impl Default for CocoToArrowOptions {
48 fn default() -> Self {
49 Self {
50 include_masks: true,
51 group: None,
52 max_workers: max_workers(),
53 }
54 }
55}
56
57#[derive(Debug, Clone)]
59pub struct ArrowToCocoOptions {
60 pub groups: Vec<String>,
62 pub include_masks: bool,
64 pub info: Option<CocoInfo>,
66}
67
68impl Default for ArrowToCocoOptions {
69 fn default() -> Self {
70 Self {
71 groups: vec![],
72 include_masks: true,
73 info: None,
74 }
75 }
76}
77
/// Default number of concurrent image conversions for `coco_to_arrow`.
///
/// The `MAX_COCO_WORKERS` environment variable wins when it holds a positive
/// integer (surrounding whitespace is ignored). Otherwise the value is half the
/// available parallelism (falling back to 4 CPUs when that cannot be queried),
/// clamped to `2..=8`.
fn max_workers() -> usize {
    std::env::var("MAX_COCO_WORKERS")
        .ok()
        .and_then(|v| v.trim().parse::<usize>().ok())
        // Zero would create a semaphore with no permits and deadlock every
        // conversion task, so treat it like an unset/invalid override.
        .filter(|&n| n > 0)
        .unwrap_or_else(|| {
            let cpus = std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4);
            (cpus / 2).clamp(2, 8)
        })
}
90
91pub async fn coco_to_arrow<P: AsRef<Path>>(
105 coco_path: P,
106 output_path: P,
107 options: &CocoToArrowOptions,
108 progress: Option<Sender<Progress>>,
109) -> Result<usize, Error> {
110 let coco_path = coco_path.as_ref();
111 let output_path = output_path.as_ref();
112
113 let reader = CocoReader::new();
115 let dataset = if coco_path.extension().is_some_and(|e| e == "zip") {
116 reader.read_annotations_zip(coco_path)?
117 } else {
118 reader.read_json(coco_path)?
119 };
120
121 let index = Arc::new(CocoIndex::from_dataset(&dataset));
123 let total_images = dataset.images.len();
124
125 if let Some(ref p) = progress {
127 let _ = p
128 .send(Progress {
129 current: 0,
130 total: total_images,
131 status: None,
132 })
133 .await;
134 }
135
136 let sem = Arc::new(Semaphore::new(options.max_workers));
138 let current = Arc::new(AtomicUsize::new(0));
139 let include_masks = options.include_masks;
140 let group = options.group.clone();
141
142 let mut tasks = Vec::with_capacity(total_images);
143
144 for image in dataset.images {
145 let sem = sem.clone();
146 let index = index.clone();
147 let current = current.clone();
148 let progress = progress.clone();
149 let total = total_images;
150 let group = group.clone();
151
152 let task = tokio::spawn(async move {
153 let _permit = sem.acquire().await.map_err(Error::SemaphoreError)?;
154
155 let samples =
157 convert_image_annotations(&image, &index, include_masks, group.as_deref());
158
159 let c = current.fetch_add(1, Ordering::SeqCst) + 1;
161 if let Some(ref p) = progress {
162 let _ = p
163 .send(Progress {
164 current: c,
165 total,
166 status: None,
167 })
168 .await;
169 }
170
171 Ok::<Vec<Sample>, Error>(samples)
172 });
173
174 tasks.push(task);
175 }
176
177 let mut all_samples = Vec::with_capacity(total_images);
179 for task in tasks {
180 let samples = task.await??;
181 all_samples.extend(samples);
182 }
183
184 let df = crate::samples_dataframe(&all_samples)?;
186
187 let mut metadata: BTreeMap<PlSmallStr, PlSmallStr> = BTreeMap::new();
189 metadata.insert(
190 PlSmallStr::from("schema_version"),
191 PlSmallStr::from(SCHEMA_VERSION),
192 );
193
194 if !dataset.categories.is_empty() {
200 let cat_meta: HashMap<String, serde_json::Value> = dataset
201 .categories
202 .iter()
203 .map(|c| {
204 let mut entry = serde_json::Map::new();
205 entry.insert("id".to_string(), serde_json::json!(c.id));
206 if let Some(ref f) = c.frequency {
207 entry.insert(
208 "frequency".to_string(),
209 serde_json::Value::String(f.clone()),
210 );
211 }
212 if let Some(ref s) = c.synset {
213 entry.insert("synset".to_string(), serde_json::Value::String(s.clone()));
214 }
215 if let Some(ref syns) = c.synonyms {
216 entry.insert("synonyms".to_string(), serde_json::json!(syns));
217 }
218 if let Some(ref d) = c.def {
219 entry.insert(
220 "definition".to_string(),
221 serde_json::Value::String(d.clone()),
222 );
223 }
224 if let Some(ref sc) = c.supercategory {
225 entry.insert(
226 "supercategory".to_string(),
227 serde_json::Value::String(sc.clone()),
228 );
229 }
230 (c.name.clone(), serde_json::Value::Object(entry))
234 })
235 .collect();
236
237 let json = serde_json::to_string(&cat_meta).unwrap_or_default();
238 metadata.insert(
239 PlSmallStr::from("category_metadata"),
240 PlSmallStr::from(json.as_str()),
241 );
242 }
243
244 if !dataset.categories.is_empty() {
246 let mut cats: Vec<_> = dataset.categories.iter().collect();
247 cats.sort_by_key(|c| c.id);
248 let labels: Vec<String> = cats.iter().map(|c| c.name.clone()).collect();
249 let labels_json = serde_json::to_string(&labels).unwrap_or_default();
250 metadata.insert(PlSmallStr::from("labels"), PlSmallStr::from(labels_json));
251 }
252
253 if let Some(parent) = output_path.parent()
255 && !parent.as_os_str().is_empty()
256 {
257 std::fs::create_dir_all(parent)?;
258 }
259 let mut file = std::fs::File::create(output_path)?;
260 let mut writer = IpcWriter::new(&mut file);
261 writer.set_custom_schema_metadata(Arc::new(metadata));
262 writer.finish(&mut df.clone())?;
263
264 Ok(all_samples.len())
265}
266
267fn convert_image_annotations(
269 image: &CocoImage,
270 index: &CocoIndex,
271 include_masks: bool,
272 group: Option<&str>,
273) -> Vec<Sample> {
274 let annotations = index.annotations_for_image(image.id);
275 let sample_name = sample_name_from_filename(&image.file_name);
276
277 let neg_label_indices = image.neg_category_ids.as_ref().map(|ids| {
279 ids.iter()
280 .filter_map(|&id| index.label_index(id).map(|idx| idx as u32))
281 .collect::<Vec<u32>>()
282 });
283 let not_exhaustive_label_indices = image.not_exhaustive_category_ids.as_ref().map(|ids| {
284 ids.iter()
285 .filter_map(|&id| index.label_index(id).map(|idx| idx as u32))
286 .collect::<Vec<u32>>()
287 });
288
289 let mut samples: Vec<Sample> = annotations
290 .iter()
291 .filter_map(|ann| {
292 let label = index.label_name(ann.category_id)?;
293 let label_index = index.label_index(ann.category_id);
294
295 let box2d = coco_bbox_to_box2d(&ann.bbox, image.width, image.height);
297
298 let (polygon, mask) = if include_masks {
302 if let Some(seg) = &ann.segmentation {
303 match seg {
304 CocoSegmentation::Polygon(_) => {
305 let poly =
306 coco_segmentation_to_polygon(seg, image.width, image.height).ok();
307 (poly, None)
308 }
309 CocoSegmentation::Rle(_) | CocoSegmentation::CompressedRle(_) => {
310 let mask_data = coco_segmentation_to_mask_data(seg).ok().flatten();
311 (None, mask_data)
312 }
313 }
314 } else {
315 (None, None)
316 }
317 } else {
318 (None, None)
319 };
320
321 let mut annotation = Annotation::new();
322 annotation.set_name(Some(sample_name.clone()));
323 annotation.set_object_id(Some(ann.id.to_string()));
324 annotation.set_label(Some(label.to_string()));
325 annotation.set_label_index(label_index);
326 annotation.set_box2d(Some(box2d));
327 annotation.set_polygon(polygon);
328 annotation.set_mask(mask);
329 annotation.set_group(group.map(String::from));
330 annotation.set_iscrowd(Some(ann.iscrowd != 0));
331 annotation.set_category_frequency(index.frequency(ann.category_id).map(String::from));
332
333 if let Some(score) = ann.score {
335 let score_f32 = score as f32;
336 if annotation.mask().is_some() {
337 annotation.set_mask_score(Some(score_f32));
338 } else if annotation.polygon().is_some() {
339 annotation.set_polygon_score(Some(score_f32));
340 } else {
341 annotation.set_box2d_score(Some(score_f32));
342 }
343 }
344
345 let mut sample = Sample {
346 image_name: Some(sample_name.clone()),
347 width: Some(image.width),
348 height: Some(image.height),
349 group: group.map(String::from),
350 annotations: vec![annotation],
351 ..Default::default()
352 };
353 sample.neg_label_indices = neg_label_indices.clone();
354 sample.not_exhaustive_label_indices = not_exhaustive_label_indices.clone();
355
356 Some(sample)
357 })
358 .collect();
359
360 if samples.is_empty() {
365 let mut sample = Sample {
366 image_name: Some(sample_name.clone()),
367 width: Some(image.width),
368 height: Some(image.height),
369 group: group.map(String::from),
370 ..Default::default()
371 };
372 sample.neg_label_indices = neg_label_indices;
373 sample.not_exhaustive_label_indices = not_exhaustive_label_indices;
374 samples.push(sample);
375 }
376
377 samples
378}
379
/// Derives a sample name from an image file name by taking its file stem.
///
/// Directory components and the last extension are dropped. When no stem exists
/// (for example an empty string or a path ending in `..`), the input is
/// returned unchanged.
fn sample_name_from_filename(filename: &str) -> String {
    match Path::new(filename).file_stem().and_then(|stem| stem.to_str()) {
        Some(stem) => stem.to_owned(),
        None => filename.to_owned(),
    }
}
388
389pub async fn arrow_to_coco<P: AsRef<Path>>(
405 arrow_path: P,
406 output_path: P,
407 options: &ArrowToCocoOptions,
408 progress: Option<Sender<Progress>>,
409) -> Result<usize, Error> {
410 let arrow_path = arrow_path.as_ref();
411 let output_path = output_path.as_ref();
412
413 let (schema_version, category_metadata_json, labels_metadata_json) = {
415 let mut meta_file = std::fs::File::open(arrow_path)?;
416 let mut reader = IpcReader::new(&mut meta_file);
417 let meta = reader.custom_metadata().ok().flatten();
418 let sv = meta.as_ref().and_then(|m| {
419 m.get(&PlSmallStr::from("schema_version"))
420 .map(|s| s.to_string())
421 });
422 let cm = meta.as_ref().and_then(|m| {
423 m.get(&PlSmallStr::from("category_metadata"))
424 .map(|s| s.to_string())
425 });
426 let lm = meta
427 .as_ref()
428 .and_then(|m| m.get(&PlSmallStr::from("labels")).map(|s| s.to_string()));
429 (sv, cm, lm)
430 };
431
432 let is_legacy = schema_version.is_none();
434
435 let mut file = std::fs::File::open(arrow_path)?;
437 let df = IpcReader::new(&mut file).finish()?;
438
439 let groups_to_filter: std::collections::HashSet<_> = options.groups.iter().cloned().collect();
441
442 let total_rows = df.height();
443
444 if let Some(ref p) = progress {
445 let _ = p
446 .send(Progress {
447 current: 0,
448 total: total_rows,
449 status: None,
450 })
451 .await;
452 }
453
454 let names: Vec<String> = df
456 .column("name")?
457 .str()?
458 .iter()
459 .map(|s| s.unwrap_or_default().to_string())
460 .collect();
461
462 let labels: Vec<String> = df
463 .column("label")
464 .ok()
465 .and_then(|c| c.cast(&DataType::String).ok())
466 .map(|c| {
467 c.str()
468 .ok()
469 .map(|s| {
470 s.iter()
471 .map(|v| v.unwrap_or_default().to_string())
472 .collect()
473 })
474 .unwrap_or_else(|| vec![String::new(); total_rows])
475 })
476 .unwrap_or_else(|| vec![String::new(); total_rows]);
477
478 let label_indices: Vec<Option<u64>> = df
479 .column("label_index")
480 .ok()
481 .map(|c| {
482 c.u64()
483 .ok()
484 .map(|s| s.iter().collect())
485 .unwrap_or_else(|| vec![None; total_rows])
486 })
487 .unwrap_or_else(|| vec![None; total_rows]);
488
489 let groups: Vec<String> = df
491 .column("group")
492 .ok()
493 .and_then(|c| c.cast(&DataType::String).ok())
494 .map(|c| {
495 c.str()
496 .ok()
497 .map(|s| {
498 s.iter()
499 .map(|v| v.unwrap_or_default().to_string())
500 .collect()
501 })
502 .unwrap_or_default()
503 })
504 .unwrap_or_else(|| vec!["".to_string(); total_rows]);
505
506 let box2ds = df
508 .column("box2d")
509 .ok()
510 .map(extract_all_box2ds)
511 .transpose()?
512 .unwrap_or_else(|| vec![[0.0; 4]; total_rows]);
513
514 let legacy_masks: Option<Vec<Vec<f32>>> = if is_legacy && options.include_masks {
519 df.column("mask").ok().map(extract_all_masks).transpose()?
520 } else {
521 None
522 };
523
524 let polygons_2026: Option<Vec<Option<PolygonRings>>> = if !is_legacy && options.include_masks {
525 df.column("polygon")
526 .ok()
527 .map(|c| extract_all_polygons(c, total_rows))
528 } else {
529 None
530 };
531
532 let mask_binary_2026: Option<Vec<Option<Vec<u8>>>> = if !is_legacy && options.include_masks {
533 df.column("mask")
534 .ok()
535 .map(|c| extract_all_binary_masks(c, total_rows))
536 } else {
537 None
538 };
539
540 let sizes = df
542 .column("size")
543 .ok()
544 .and_then(|c| extract_all_sizes(c).ok());
545
546 let iscrowds: Vec<u8> = df
548 .column("iscrowd")
549 .ok()
550 .map(|c| {
551 if let Ok(bool_ca) = c.bool() {
553 bool_ca
554 .iter()
555 .map(|v| if v.unwrap_or(false) { 1 } else { 0 })
556 .collect()
557 } else {
558 c.u32()
559 .ok()
560 .map(|s| s.iter().map(|v| v.unwrap_or(0) as u8).collect())
561 .unwrap_or_else(|| vec![0; total_rows])
562 }
563 })
564 .unwrap_or_else(|| vec![0; total_rows]);
565
566 let category_frequencies: Vec<Option<String>> = df
568 .column("category_frequency")
569 .ok()
570 .and_then(|c| c.cast(&DataType::String).ok())
571 .map(|c| {
572 c.str()
573 .ok()
574 .map(|s| s.iter().map(|v| v.map(String::from)).collect())
575 .unwrap_or_else(|| vec![None; total_rows])
576 })
577 .unwrap_or_else(|| vec![None; total_rows]);
578
579 let neg_label_indices: Vec<Option<Vec<u32>>> = df
581 .column("neg_label_indices")
582 .ok()
583 .map(|c| extract_list_u32_column(c, total_rows))
584 .unwrap_or_else(|| vec![None; total_rows]);
585
586 let not_exhaustive_label_indices: Vec<Option<Vec<u32>>> = df
588 .column("not_exhaustive_label_indices")
589 .ok()
590 .map(|c| extract_list_u32_column(c, total_rows))
591 .unwrap_or_else(|| vec![None; total_rows]);
592
593 let box2d_scores: Vec<Option<f32>> = extract_f32_column(&df, "box2d_score", total_rows);
595 let box3d_scores: Vec<Option<f32>> = extract_f32_column(&df, "box3d_score", total_rows);
596 let polygon_scores: Vec<Option<f32>> = extract_f32_column(&df, "polygon_score", total_rows);
597 let mask_scores: Vec<Option<f32>> = extract_f32_column(&df, "mask_score", total_rows);
598
599 let object_id_u64s: Vec<Option<u64>> = df
610 .column("object_id")
611 .ok()
612 .and_then(|c| c.cast(&DataType::String).ok())
613 .map(|c| {
614 c.str()
615 .ok()
616 .map(|s| {
617 s.iter()
618 .map(|v| v.and_then(|s| s.parse::<u64>().ok()))
619 .collect()
620 })
621 .unwrap_or_else(|| vec![None; total_rows])
622 })
623 .unwrap_or_else(|| vec![None; total_rows]);
624
625 let mut builder = CocoDatasetBuilder::new();
627
628 if let Some(info) = &options.info {
629 builder = builder.info(info.clone());
630 }
631
632 let skip_row = |i: usize| -> bool {
634 !groups_to_filter.is_empty() && !groups_to_filter.contains(&groups[i])
635 };
636
637 let mut image_dimensions: HashMap<String, (u32, u32)> = HashMap::new();
639 let mut image_ids: HashMap<String, u64> = HashMap::new();
640 let mut category_ids: HashMap<String, u32> = HashMap::new();
641
642 for i in 0..total_rows {
644 if skip_row(i) {
645 continue;
646 }
647
648 let name = &names[i];
649 let label = &labels[i];
650
651 if !image_ids.contains_key(name) {
653 let (width, height) = sizes
654 .as_ref()
655 .and_then(|s| s.get(i).copied())
656 .unwrap_or((0, 0));
657
658 let id = builder.add_image(name, width, height);
659 image_ids.insert(name.clone(), id);
660 image_dimensions.insert(name.clone(), (width, height));
661 }
662
663 if !label.is_empty() && !category_ids.contains_key(label) {
664 let id = if let Some(Some(idx)) = label_indices.get(i) {
665 builder.add_category_with_id(*idx as u32, label, None)
666 } else {
667 builder.add_category(label, None)
668 };
669 category_ids.insert(label.clone(), id);
670 }
671 }
672
673 let mut last_progress_update = 0;
675 for i in 0..total_rows {
676 if skip_row(i) {
677 continue;
678 }
679
680 let name = &names[i];
681 let label = &labels[i];
682
683 if label.is_empty() {
685 continue;
686 }
687
688 let image_id = *image_ids.get(name).unwrap_or(&0);
689 let category_id = *category_ids.get(label).unwrap_or(&0);
690 let (width, height) = *image_dimensions.get(name).unwrap_or(&(1, 1));
691
692 let bbox = box2ds.get(i).map(|box2d| {
695 let cx = box2d[0];
696 let cy = box2d[1];
697 let w = box2d[2];
698 let h = box2d[3];
699 let left = cx - w / 2.0;
701 let top = cy - h / 2.0;
702 let ef_box2d = Box2d::new(left, top, w, h);
703 box2d_to_coco_bbox(&ef_box2d, width, height)
704 });
705
706 let segmentation = if options.include_masks {
708 if is_legacy {
709 legacy_masks.as_ref().and_then(|m| {
711 m.get(i).and_then(|coords| {
712 if coords.is_empty() {
713 None
714 } else {
715 let rings = crate::unflatten_polygon_coordinates(coords);
716 let polygon = Polygon::new(rings);
717 let coco_poly = polygon_to_coco_polygon(&polygon, width, height);
718 if coco_poly.is_empty() {
719 None
720 } else {
721 Some(CocoSegmentation::Polygon(coco_poly))
722 }
723 }
724 })
725 })
726 } else {
727 let mask_seg = mask_binary_2026.as_ref().and_then(|masks| {
729 masks.get(i).and_then(|opt_bytes| {
730 opt_bytes
731 .as_ref()
732 .and_then(|png_bytes| png_to_rle_segmentation(png_bytes, i))
733 })
734 });
735
736 if mask_seg.is_some() {
737 mask_seg
738 } else {
739 polygons_2026.as_ref().and_then(|polys| {
741 polys.get(i).and_then(|opt_rings| {
742 opt_rings.as_ref().and_then(|rings| {
743 if rings.is_empty() {
744 return None;
745 }
746 let polygon = Polygon::new(rings.clone());
747 let coco_poly = polygon_to_coco_polygon(&polygon, width, height);
748 if coco_poly.is_empty() {
749 None
750 } else {
751 Some(CocoSegmentation::Polygon(coco_poly))
752 }
753 })
754 })
755 })
756 }
757 }
758 } else {
759 None
760 };
761
762 let score: Option<f64> = mask_scores[i]
764 .or(polygon_scores[i])
765 .or(box3d_scores[i])
766 .or(box2d_scores[i])
767 .map(|s| s as f64);
768
769 if let Some(bbox) = bbox {
770 let iscrowd = iscrowds[i];
771 let ann_id = builder.add_annotation_with_id(
772 object_id_u64s[i],
773 image_id,
774 category_id,
775 bbox,
776 segmentation,
777 iscrowd,
778 );
779
780 if let Some(score_val) = score {
782 builder.set_annotation_score(ann_id, score_val);
783 }
784 }
785
786 if let Some(ref p) = progress
788 && (i - last_progress_update >= 1000 || i == total_rows - 1)
789 {
790 let _ = p
791 .send(Progress {
792 current: i + 1,
793 total: total_rows,
794 status: None,
795 })
796 .await;
797 last_progress_update = i;
798 }
799 }
800
801 if let Some(ref p) = progress
803 && last_progress_update < total_rows.saturating_sub(1)
804 {
805 let _ = p
806 .send(Progress {
807 current: total_rows,
808 total: total_rows,
809 status: None,
810 })
811 .await;
812 }
813
814 {
817 let mut processed_images: std::collections::HashSet<u64> = std::collections::HashSet::new();
818 for i in 0..total_rows {
819 if skip_row(i) {
820 continue;
821 }
822 let name = &names[i];
823 if let Some(&image_id) = image_ids.get(name) {
824 if !processed_images.insert(image_id) {
825 continue;
826 }
827 let neg = neg_label_indices[i].clone();
828 let not_exhaustive = not_exhaustive_label_indices[i].clone();
829 if neg.is_some() || not_exhaustive.is_some() {
830 builder.set_image_neg_categories(image_id, neg, not_exhaustive);
831 }
832 }
833 }
834 }
835
836 {
839 let mut freq_map: HashMap<String, String> = HashMap::new();
840 for i in 0..total_rows {
841 if skip_row(i) {
842 continue;
843 }
844 let label = &labels[i];
845 if !label.is_empty()
846 && !freq_map.contains_key(label)
847 && let Some(ref freq) = category_frequencies[i]
848 {
849 freq_map.insert(label.clone(), freq.clone());
850 }
851 }
852 for (name, freq) in &freq_map {
853 builder.set_category_metadata(name, None, Some(freq.clone()), None, None);
854 }
855 }
856
857 if let Some(ref json_str) = category_metadata_json
864 && let Ok(meta) = serde_json::from_str::<HashMap<String, serde_json::Value>>(json_str)
865 {
866 for (cat_name, value) in &meta {
867 let supercategory = value.get("supercategory").and_then(|v| v.as_str());
868
869 if !category_ids.contains_key(cat_name.as_str()) {
871 let cat_id = value.get("id").and_then(|v| v.as_u64()).map(|id| id as u32);
872 let id = if let Some(cat_id) = cat_id {
873 builder.add_category_with_id(cat_id, cat_name, supercategory)
874 } else {
875 builder.add_category(cat_name, supercategory)
876 };
877 category_ids.insert(cat_name.clone(), id);
878 } else {
879 if let Some(sc) = supercategory {
881 builder.set_category_supercategory(cat_name, sc);
882 }
883 }
884
885 let synset = value
886 .get("synset")
887 .and_then(|v| v.as_str())
888 .map(String::from);
889 let frequency = value
890 .get("frequency")
891 .and_then(|v| v.as_str())
892 .map(String::from);
893 let synonyms = value.get("synonyms").and_then(|v| {
894 v.as_array().map(|arr| {
895 arr.iter()
896 .filter_map(|s| s.as_str().map(String::from))
897 .collect()
898 })
899 });
900 let def = value
901 .get("definition")
902 .and_then(|v| v.as_str())
903 .map(String::from);
904
905 builder.set_category_metadata(cat_name, synset, frequency, synonyms, def);
906 }
907 }
908
909 if category_metadata_json.is_none()
912 && let Some(ref labels_json) = labels_metadata_json
913 && let Ok(label_names) = serde_json::from_str::<Vec<String>>(labels_json)
914 {
915 for label_name in &label_names {
916 if !category_ids.contains_key(label_name) {
917 let id = builder.add_category(label_name, None);
918 category_ids.insert(label_name.clone(), id);
919 }
920 }
921 }
922
923 let dataset = builder.build();
924 let annotation_count = dataset.annotations.len();
925
926 let writer = CocoWriter::new();
928 writer.write_json(&dataset, output_path)?;
929
930 Ok(annotation_count)
931}
932
933fn extract_all_box2ds(col: &Column) -> Result<Vec<[f32; 4]>, Error> {
935 let arr = col.array()?;
936 let mut result = Vec::with_capacity(arr.len());
937
938 for inner in arr.amortized_iter() {
939 let values = if let Some(inner) = inner {
940 let series = inner.as_ref();
941 let vals: Vec<f32> = series
942 .f32()
943 .map_err(|e| Error::CocoError(format!("box2d cast error: {}", e)))?
944 .iter()
945 .map(|v| v.unwrap_or(0.0))
946 .collect();
947
948 if vals.len() == 4 {
949 [vals[0], vals[1], vals[2], vals[3]]
950 } else {
951 [0.0, 0.0, 0.0, 0.0]
952 }
953 } else {
954 [0.0, 0.0, 0.0, 0.0]
955 };
956 result.push(values);
957 }
958
959 Ok(result)
960}
961
962fn extract_all_masks(col: &Column) -> Result<Vec<Vec<f32>>, Error> {
964 let list = col.list()?;
965 let mut result = Vec::with_capacity(list.len());
966
967 for i in 0..list.len() {
968 let coords = match list.get_as_series(i) {
969 Some(series) => series
970 .f32()
971 .map_err(|e| Error::CocoError(format!("mask cast error: {}", e)))?
972 .iter()
973 .map(|v| v.unwrap_or(f32::NAN))
974 .collect(),
975 None => vec![],
976 };
977 result.push(coords);
978 }
979
980 Ok(result)
981}
982
983fn extract_all_sizes(col: &Column) -> Result<Vec<(u32, u32)>, Error> {
985 let arr = col.array()?;
986 let mut result = Vec::with_capacity(arr.len());
987
988 for inner in arr.amortized_iter() {
989 let size = if let Some(inner) = inner {
990 let series = inner.as_ref();
991 let values: Vec<u32> = series
992 .u32()
993 .map_err(|e| Error::CocoError(format!("size cast error: {}", e)))?
994 .iter()
995 .map(|v| v.unwrap_or(0))
996 .collect();
997
998 if values.len() >= 2 {
999 (values[0], values[1])
1000 } else {
1001 (0, 0)
1002 }
1003 } else {
1004 (0, 0)
1005 };
1006 result.push(size);
1007 }
1008
1009 Ok(result)
1010}
1011
1012fn extract_list_u32_column(col: &Column, total_rows: usize) -> Vec<Option<Vec<u32>>> {
1014 col.list()
1015 .ok()
1016 .map(|list| {
1017 (0..list.len())
1018 .map(|i| {
1019 list.get_as_series(i).and_then(|series| {
1020 series
1021 .u32()
1022 .ok()
1023 .map(|ca| ca.iter().flatten().collect::<Vec<u32>>())
1024 })
1025 })
1026 .collect()
1027 })
1028 .unwrap_or_else(|| vec![None; total_rows])
1029}
1030
1031fn extract_all_polygons(col: &Column, total_rows: usize) -> Vec<Option<PolygonRings>> {
1036 let outer_list = match col.list() {
1037 Ok(l) => l,
1038 Err(_) => return vec![None; total_rows],
1039 };
1040
1041 let mut result = Vec::with_capacity(total_rows);
1042 for i in 0..outer_list.len() {
1043 let rings = outer_list.get_as_series(i).and_then(|ring_series| {
1044 let inner_list = ring_series.list().ok()?;
1045 let mut rings = Vec::new();
1046 for j in 0..inner_list.len() {
1047 if let Some(coords_series) = inner_list.get_as_series(j)
1048 && let Ok(f32_ca) = coords_series.f32()
1049 {
1050 let coords: Vec<f32> = f32_ca.iter().map(|v| v.unwrap_or(0.0)).collect();
1051 let points: Vec<(f32, f32)> = coords
1053 .chunks(2)
1054 .filter(|c| c.len() == 2)
1055 .map(|c| (c[0], c[1]))
1056 .collect();
1057 if !points.is_empty() {
1058 rings.push(points);
1059 }
1060 }
1061 }
1062 if rings.is_empty() { None } else { Some(rings) }
1063 });
1064 result.push(rings);
1065 }
1066 result
1067}
1068
1069fn extract_all_binary_masks(col: &Column, total_rows: usize) -> Vec<Option<Vec<u8>>> {
1071 let binary_ca = match col.binary() {
1072 Ok(b) => b,
1073 Err(_) => return vec![None; total_rows],
1074 };
1075
1076 (0..binary_ca.len())
1077 .map(|i| binary_ca.get(i).map(|bytes| bytes.to_vec()))
1078 .collect()
1079}
1080
1081fn extract_f32_column(df: &DataFrame, name: &str, total_rows: usize) -> Vec<Option<f32>> {
1083 df.column(name)
1084 .ok()
1085 .and_then(|c| c.f32().ok())
1086 .map(|ca| ca.iter().collect())
1087 .unwrap_or_else(|| vec![None; total_rows])
1088}
1089
1090fn png_to_rle_segmentation(png_bytes: &[u8], row_index: usize) -> Option<CocoSegmentation> {
1096 if png_bytes.is_empty() {
1097 return None;
1098 }
1099
1100 let mask_data = match crate::MaskData::from_png_checked(png_bytes.to_vec()) {
1101 Ok(m) => m,
1102 Err(e) => {
1103 log::warn!("Skipping invalid PNG mask at row {}: {}", row_index, e);
1104 return None;
1105 }
1106 };
1107
1108 let mw = mask_data.width();
1109 let mh = mask_data.height();
1110 let bit_depth = mask_data.bit_depth();
1111
1112 let decoded = match mask_data.decode() {
1113 Ok(d) => d,
1114 Err(e) => {
1115 log::warn!("Failed to decode PNG mask at row {}: {}", row_index, e);
1116 return None;
1117 }
1118 };
1119
1120 let binary_mask = match bit_depth {
1121 1 => decoded,
1122 8 => {
1123 log::warn!(
1124 "Binarizing 8-bit mask for row {} — score data is lost",
1125 row_index
1126 );
1127 decoded
1128 .iter()
1129 .map(|&v| if v >= 128 { 1 } else { 0 })
1130 .collect()
1131 }
1132 16 => {
1133 log::warn!(
1134 "Binarizing 16-bit mask for row {} — score data is lost",
1135 row_index
1136 );
1137 decoded
1138 .chunks(2)
1139 .map(|pair| {
1140 let val = if pair.len() == 2 {
1141 u16::from_be_bytes([pair[0], pair[1]])
1142 } else {
1143 0
1144 };
1145 if val >= 32768 { 1u8 } else { 0u8 }
1146 })
1147 .collect()
1148 }
1149 _ => decoded,
1150 };
1151
1152 match super::convert::encode_rle(&binary_mask, mw, mh) {
1153 Ok(rle) => Some(CocoSegmentation::Rle(rle)),
1154 Err(e) => {
1155 log::warn!("Failed to encode RLE for row {}: {}", row_index, e);
1156 None
1157 }
1158 }
1159}
1160
1161#[cfg(test)]
1162mod tests {
1163 use super::*;
1164 use crate::coco::{CocoAnnotation, CocoCategory, CocoDataset};
1165 use tempfile::TempDir;
1166
1167 #[test]
1172 fn test_unflatten_polygon_coords_empty() {
1173 let coords: Vec<f32> = vec![];
1174 let result = crate::unflatten_polygon_coordinates(&coords);
1175 assert!(result.is_empty());
1176 }
1177
1178 #[test]
1179 fn test_unflatten_polygon_coords_single_polygon() {
1180 let coords = vec![0.1, 0.2, 0.3, 0.2, 0.3, 0.4, 0.1, 0.4];
1182 let result = crate::unflatten_polygon_coordinates(&coords);
1183
1184 assert_eq!(result.len(), 1);
1185 assert_eq!(result[0].len(), 4);
1186 assert_eq!(result[0][0], (0.1, 0.2));
1187 assert_eq!(result[0][3], (0.1, 0.4));
1188 }
1189
1190 #[test]
1191 fn test_unflatten_polygon_coords_multiple_polygons() {
1192 let coords = vec![
1194 0.1,
1195 0.1,
1196 0.2,
1197 0.1,
1198 0.15,
1199 0.2, f32::NAN, 0.5,
1202 0.5,
1203 0.6,
1204 0.5,
1205 0.55,
1206 0.6, ];
1208 let result = crate::unflatten_polygon_coordinates(&coords);
1209
1210 assert_eq!(result.len(), 2);
1211 assert_eq!(result[0].len(), 3);
1212 assert_eq!(result[1].len(), 3);
1213 assert_eq!(result[0][0], (0.1, 0.1));
1214 assert_eq!(result[1][0], (0.5, 0.5));
1215 }
1216
1217 #[test]
1218 fn test_unflatten_polygon_coords_leading_nan() {
1219 let coords = vec![f32::NAN, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6];
1221 let result = crate::unflatten_polygon_coordinates(&coords);
1222
1223 assert_eq!(result.len(), 1);
1224 assert_eq!(result[0].len(), 3);
1225 }
1226
1227 #[test]
1228 fn test_unflatten_polygon_coords_trailing_nan() {
1229 let coords = vec![0.1, 0.2, 0.3, 0.4, f32::NAN];
1231 let result = crate::unflatten_polygon_coordinates(&coords);
1232
1233 assert_eq!(result.len(), 1);
1234 assert_eq!(result[0].len(), 2);
1235 }
1236
1237 #[test]
1238 fn test_unflatten_polygon_coords_consecutive_nans() {
1239 let coords = vec![0.1, 0.2, f32::NAN, f32::NAN, 0.3, 0.4];
1241 let result = crate::unflatten_polygon_coordinates(&coords);
1242
1243 assert_eq!(result.len(), 2);
1244 assert_eq!(result[0].len(), 1);
1245 assert_eq!(result[1].len(), 1);
1246 }
1247
1248 #[test]
1249 fn test_unflatten_polygon_coords_odd_values() {
1250 let coords = vec![0.1, 0.2, 0.3, 0.4, 0.5];
1252 let result = crate::unflatten_polygon_coordinates(&coords);
1253
1254 assert_eq!(result.len(), 1);
1255 assert_eq!(result[0].len(), 2); }
1257
1258 #[test]
1263 fn test_convert_image_annotations_basic() {
1264 let image = CocoImage {
1265 id: 1,
1266 width: 640,
1267 height: 480,
1268 file_name: "test_image.jpg".to_string(),
1269 ..Default::default()
1270 };
1271
1272 let dataset = CocoDataset {
1273 images: vec![image.clone()],
1274 categories: vec![CocoCategory {
1275 id: 1,
1276 name: "cat".to_string(),
1277 supercategory: Some("animal".to_string()),
1278 ..Default::default()
1279 }],
1280 annotations: vec![CocoAnnotation {
1281 id: 42,
1282 image_id: 1,
1283 category_id: 1,
1284 bbox: [100.0, 100.0, 200.0, 200.0],
1285 area: 40000.0,
1286 iscrowd: 0,
1287 segmentation: None,
1288 score: None,
1289 }],
1290 ..Default::default()
1291 };
1292
1293 let index = CocoIndex::from_dataset(&dataset);
1294 let samples = convert_image_annotations(&image, &index, true, Some("train"));
1295
1296 assert_eq!(samples.len(), 1);
1297 assert_eq!(samples[0].image_name, Some("test_image".to_string()));
1298 assert_eq!(samples[0].group, Some("train".to_string()));
1299 assert_eq!(samples[0].annotations.len(), 1);
1300 assert_eq!(samples[0].annotations[0].label(), Some(&"cat".to_string()));
1301 assert_eq!(
1302 samples[0].annotations[0].object_id(),
1303 Some(&"42".to_string()),
1304 "object_id must be populated from COCO annotation id to enable \
1305 prediction-to-prompt linking in prompted-segmentation workflows",
1306 );
1307 }
1308
1309 #[test]
1310 fn test_convert_image_annotations_with_mask() {
1311 let image = CocoImage {
1312 id: 1,
1313 width: 100,
1314 height: 100,
1315 file_name: "masked.jpg".to_string(),
1316 ..Default::default()
1317 };
1318
1319 let dataset = CocoDataset {
1320 images: vec![image.clone()],
1321 categories: vec![CocoCategory {
1322 id: 1,
1323 name: "object".to_string(),
1324 supercategory: None,
1325 ..Default::default()
1326 }],
1327 annotations: vec![CocoAnnotation {
1328 id: 1,
1329 image_id: 1,
1330 category_id: 1,
1331 bbox: [10.0, 10.0, 50.0, 50.0],
1332 area: 2500.0,
1333 iscrowd: 0,
1334 segmentation: Some(CocoSegmentation::Polygon(vec![vec![
1335 10.0, 10.0, 60.0, 10.0, 60.0, 60.0, 10.0, 60.0,
1336 ]])),
1337 score: None,
1338 }],
1339 ..Default::default()
1340 };
1341
1342 let index = CocoIndex::from_dataset(&dataset);
1343
1344 let samples_with_mask = convert_image_annotations(&image, &index, true, None);
1346 assert!(samples_with_mask[0].annotations[0].polygon().is_some());
1347
1348 let samples_no_mask = convert_image_annotations(&image, &index, false, None);
1350 assert!(samples_no_mask[0].annotations[0].polygon().is_none());
1351 }
1352
1353 #[test]
1354 fn test_convert_image_annotations_object_id_from_lvis_large_id() {
1355 let image = CocoImage {
1360 id: 397133,
1361 width: 640,
1362 height: 480,
1363 file_name: "000000397133.jpg".to_string(),
1364 ..Default::default()
1365 };
1366
1367 let large_id: u64 = 9_876_543_210;
1368 let dataset = CocoDataset {
1369 images: vec![image.clone()],
1370 categories: vec![CocoCategory {
1371 id: 16,
1372 name: "dog".to_string(),
1373 synset: Some("dog.n.01".to_string()),
1374 frequency: Some("f".to_string()),
1375 ..Default::default()
1376 }],
1377 annotations: vec![CocoAnnotation {
1378 id: large_id,
1379 image_id: 397133,
1380 category_id: 16,
1381 bbox: [192.81, 224.8, 74.73, 33.43],
1382 area: 1035.7,
1383 iscrowd: 0,
1384 segmentation: None,
1385 score: None,
1386 }],
1387 ..Default::default()
1388 };
1389
1390 let index = CocoIndex::from_dataset(&dataset);
1391 let samples = convert_image_annotations(&image, &index, true, None);
1392
1393 assert_eq!(samples.len(), 1);
1394 assert_eq!(samples[0].annotations.len(), 1);
1395 assert_eq!(
1396 samples[0].annotations[0].object_id(),
1397 Some(&large_id.to_string()),
1398 );
1399 }
1400
1401 #[test]
1402 fn test_convert_image_annotations_no_annotations() {
1403 let image = CocoImage {
1404 id: 1,
1405 width: 640,
1406 height: 480,
1407 file_name: "empty.jpg".to_string(),
1408 ..Default::default()
1409 };
1410
1411 let dataset = CocoDataset {
1412 images: vec![image.clone()],
1413 categories: vec![],
1414 annotations: vec![],
1415 ..Default::default()
1416 };
1417
1418 let index = CocoIndex::from_dataset(&dataset);
1419 let samples = convert_image_annotations(&image, &index, true, None);
1420
1421 assert_eq!(samples.len(), 1);
1424 assert_eq!(samples[0].image_name, Some("empty".to_string()));
1425 assert!(samples[0].annotations.is_empty());
1426 assert_eq!(samples[0].group, None);
1427 }
1428
1429 #[test]
1434 fn test_sample_name_from_filename() {
1435 assert_eq!(
1436 sample_name_from_filename("000000397133.jpg"),
1437 "000000397133"
1438 );
1439 assert_eq!(sample_name_from_filename("train2017/image.jpg"), "image");
1440 assert_eq!(sample_name_from_filename("test"), "test");
1441 }
1442
1443 #[test]
1444 fn test_sample_name_from_filename_nested_path() {
1445 assert_eq!(
1446 sample_name_from_filename("a/b/c/deep_image.png"),
1447 "deep_image"
1448 );
1449 }
1450
1451 #[test]
1452 fn test_sample_name_from_filename_no_extension() {
1453 assert_eq!(sample_name_from_filename("no_extension"), "no_extension");
1454 }
1455
1456 #[test]
1461 fn test_coco_to_arrow_options_default() {
1462 let options = CocoToArrowOptions::default();
1463 assert!(options.include_masks);
1464 assert!(options.group.is_none());
1465 assert!(options.max_workers >= 2);
1466 }
1467
1468 #[test]
1469 fn test_arrow_to_coco_options_default() {
1470 let options = ArrowToCocoOptions::default();
1471 assert!(options.groups.is_empty());
1472 assert!(options.include_masks);
1473 assert!(options.info.is_none());
1474 }
1475
1476 #[test]
1477 fn test_max_workers() {
1478 let workers = max_workers();
1479 assert!(workers >= 2);
1480 assert!(workers <= 8);
1481 }
1482
1483 #[tokio::test]
1484 async fn test_coco_to_arrow_minimal() {
1485 let temp_dir = TempDir::new().unwrap();
1486
1487 let coco_json = r#"{
1489 "images": [
1490 {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
1491 ],
1492 "annotations": [
1493 {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0}
1494 ],
1495 "categories": [
1496 {"id": 1, "name": "person", "supercategory": "human"}
1497 ]
1498 }"#;
1499
1500 let coco_path = temp_dir.path().join("test.json");
1501 std::fs::write(&coco_path, coco_json).unwrap();
1502
1503 let arrow_path = temp_dir.path().join("output.arrow");
1504
1505 let options = CocoToArrowOptions::default();
1506 let count = coco_to_arrow(&coco_path, &arrow_path, &options, None)
1507 .await
1508 .unwrap();
1509
1510 assert_eq!(count, 1);
1511 assert!(arrow_path.exists());
1512
1513 let mut file = std::fs::File::open(&arrow_path).unwrap();
1515 let df = IpcReader::new(&mut file).finish().unwrap();
1516 assert_eq!(df.height(), 1);
1517 }
1518
1519 #[tokio::test]
1520 async fn test_arrow_to_coco_roundtrip() {
1521 let temp_dir = TempDir::new().unwrap();
1522
1523 let original = CocoDataset {
1525 images: vec![CocoImage {
1526 id: 1,
1527 width: 640,
1528 height: 480,
1529 file_name: "test.jpg".to_string(),
1530 ..Default::default()
1531 }],
1532 annotations: vec![CocoAnnotation {
1533 id: 1,
1534 image_id: 1,
1535 category_id: 1,
1536 bbox: [100.0, 50.0, 200.0, 150.0],
1537 area: 30000.0,
1538 iscrowd: 0,
1539 segmentation: Some(CocoSegmentation::Polygon(vec![vec![
1540 100.0, 50.0, 300.0, 50.0, 300.0, 200.0, 100.0, 200.0,
1541 ]])),
1542 score: None,
1543 }],
1544 categories: vec![CocoCategory {
1545 id: 1,
1546 name: "person".to_string(),
1547 supercategory: Some("human".to_string()),
1548 ..Default::default()
1549 }],
1550 ..Default::default()
1551 };
1552
1553 let coco_path = temp_dir.path().join("original.json");
1555 let writer = CocoWriter::new();
1556 writer.write_json(&original, &coco_path).unwrap();
1557
1558 let arrow_path = temp_dir.path().join("converted.arrow");
1560 let options = CocoToArrowOptions::default();
1561 coco_to_arrow(&coco_path, &arrow_path, &options, None)
1562 .await
1563 .unwrap();
1564
1565 let restored_path = temp_dir.path().join("restored.json");
1567 let options = ArrowToCocoOptions::default();
1568 arrow_to_coco(&arrow_path, &restored_path, &options, None)
1569 .await
1570 .unwrap();
1571
1572 let reader = CocoReader::new();
1574 let restored = reader.read_json(&restored_path).unwrap();
1575
1576 assert_eq!(restored.images.len(), 1);
1577 assert_eq!(restored.annotations.len(), 1);
1578 assert_eq!(restored.categories.len(), 1);
1579
1580 assert_eq!(restored.categories[0].name, "person");
1582 }
1583
1584 #[tokio::test]
1585 async fn test_arrow_to_coco_roundtrip_preserves_annotation_id() {
1586 let temp_dir = TempDir::new().unwrap();
1592
1593 let large_id: u64 = 9_876_543_210;
1594 let original = CocoDataset {
1595 images: vec![CocoImage {
1596 id: 1,
1597 width: 640,
1598 height: 480,
1599 file_name: "test.jpg".to_string(),
1600 ..Default::default()
1601 }],
1602 annotations: vec![
1603 CocoAnnotation {
1604 id: 1,
1605 image_id: 1,
1606 category_id: 1,
1607 bbox: [10.0, 20.0, 100.0, 80.0],
1608 area: 8000.0,
1609 iscrowd: 0,
1610 segmentation: None,
1611 score: None,
1612 },
1613 CocoAnnotation {
1614 id: large_id,
1615 image_id: 1,
1616 category_id: 1,
1617 bbox: [200.0, 200.0, 100.0, 100.0],
1618 area: 10000.0,
1619 iscrowd: 0,
1620 segmentation: None,
1621 score: None,
1622 },
1623 ],
1624 categories: vec![CocoCategory {
1625 id: 1,
1626 name: "person".to_string(),
1627 supercategory: Some("human".to_string()),
1628 ..Default::default()
1629 }],
1630 ..Default::default()
1631 };
1632
1633 let coco_path = temp_dir.path().join("original.json");
1634 let writer = CocoWriter::new();
1635 writer.write_json(&original, &coco_path).unwrap();
1636
1637 let arrow_path = temp_dir.path().join("converted.arrow");
1638 coco_to_arrow(
1639 &coco_path,
1640 &arrow_path,
1641 &CocoToArrowOptions::default(),
1642 None,
1643 )
1644 .await
1645 .unwrap();
1646
1647 let restored_path = temp_dir.path().join("restored.json");
1648 arrow_to_coco(
1649 &arrow_path,
1650 &restored_path,
1651 &ArrowToCocoOptions::default(),
1652 None,
1653 )
1654 .await
1655 .unwrap();
1656
1657 let restored = CocoReader::new().read_json(&restored_path).unwrap();
1658 assert_eq!(restored.annotations.len(), 2);
1659
1660 let restored_ids: std::collections::HashSet<u64> =
1661 restored.annotations.iter().map(|a| a.id).collect();
1662 assert!(
1663 restored_ids.contains(&1),
1664 "small annotation id (1) must round-trip; got {restored_ids:?}"
1665 );
1666 assert!(
1667 restored_ids.contains(&large_id),
1668 "33-bit LVIS-scale annotation id ({large_id}) must round-trip; got {restored_ids:?}"
1669 );
1670 }
1671
1672 #[tokio::test]
1677 async fn test_coco_to_arrow_schema_version_metadata() {
1678 let temp_dir = TempDir::new().unwrap();
1679
1680 let coco_json = r#"{
1682 "images": [
1683 {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
1684 ],
1685 "annotations": [
1686 {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0}
1687 ],
1688 "categories": [
1689 {"id": 1, "name": "person", "supercategory": "human"}
1690 ]
1691 }"#;
1692
1693 let coco_path = temp_dir.path().join("test.json");
1694 std::fs::write(&coco_path, coco_json).unwrap();
1695
1696 let arrow_path = temp_dir.path().join("output.arrow");
1697 let options = CocoToArrowOptions::default();
1698 coco_to_arrow(&coco_path, &arrow_path, &options, None)
1699 .await
1700 .unwrap();
1701
1702 let mut file = std::fs::File::open(&arrow_path).unwrap();
1704 let mut reader = IpcReader::new(&mut file);
1705 let custom_meta = reader.custom_metadata().unwrap();
1706 assert!(custom_meta.is_some(), "custom metadata should be present");
1707
1708 let meta = custom_meta.unwrap();
1709 assert_eq!(
1710 meta.get(&PlSmallStr::from("schema_version")),
1711 Some(&PlSmallStr::from(SCHEMA_VERSION)),
1712 "schema_version metadata should be '2026.04'"
1713 );
1714
1715 assert!(
1717 meta.contains_key(&PlSmallStr::from("category_metadata")),
1718 "category_metadata should be present even without LVIS fields"
1719 );
1720 }
1721
1722 #[tokio::test]
1723 async fn test_coco_to_arrow_category_metadata_lvis() {
1724 let temp_dir = TempDir::new().unwrap();
1725
1726 let coco_json = r#"{
1728 "images": [
1729 {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
1730 ],
1731 "annotations": [
1732 {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0},
1733 {"id": 2, "image_id": 1, "category_id": 2, "bbox": [50, 60, 80, 40], "area": 3200, "iscrowd": 0}
1734 ],
1735 "categories": [
1736 {
1737 "id": 1,
1738 "name": "aerosol_can",
1739 "synset": "aerosol.n.02",
1740 "synonyms": ["aerosol_can", "spray_can"],
1741 "def": "a dispenser that holds a substance under pressure"
1742 },
1743 {
1744 "id": 2,
1745 "name": "person",
1746 "supercategory": "human"
1747 }
1748 ]
1749 }"#;
1750
1751 let coco_path = temp_dir.path().join("lvis.json");
1752 std::fs::write(&coco_path, coco_json).unwrap();
1753
1754 let arrow_path = temp_dir.path().join("lvis_output.arrow");
1755 let options = CocoToArrowOptions::default();
1756 coco_to_arrow(&coco_path, &arrow_path, &options, None)
1757 .await
1758 .unwrap();
1759
1760 let mut file = std::fs::File::open(&arrow_path).unwrap();
1762 let mut reader = IpcReader::new(&mut file);
1763 let custom_meta = reader.custom_metadata().unwrap();
1764 assert!(custom_meta.is_some(), "custom metadata should be present");
1765
1766 let meta = custom_meta.unwrap();
1767
1768 assert_eq!(
1770 meta.get(&PlSmallStr::from("schema_version")),
1771 Some(&PlSmallStr::from(SCHEMA_VERSION)),
1772 );
1773
1774 let cat_meta_str = meta
1776 .get(&PlSmallStr::from("category_metadata"))
1777 .expect("category_metadata should be present for LVIS data");
1778
1779 let cat_meta: HashMap<String, serde_json::Value> =
1780 serde_json::from_str(cat_meta_str.as_str()).unwrap();
1781
1782 assert!(
1784 cat_meta.contains_key("aerosol_can"),
1785 "aerosol_can should be in category_metadata"
1786 );
1787 assert!(
1788 cat_meta.contains_key("person"),
1789 "person should also be in category_metadata"
1790 );
1791
1792 let aerosol = cat_meta.get("aerosol_can").unwrap();
1794 assert_eq!(
1795 aerosol.get("synset").and_then(|v| v.as_str()),
1796 Some("aerosol.n.02")
1797 );
1798 assert_eq!(
1799 aerosol.get("definition").and_then(|v| v.as_str()),
1800 Some("a dispenser that holds a substance under pressure")
1801 );
1802 let synonyms = aerosol.get("synonyms").and_then(|v| v.as_array()).unwrap();
1803 assert_eq!(synonyms.len(), 2);
1804 assert_eq!(synonyms[0].as_str(), Some("aerosol_can"));
1805 assert_eq!(synonyms[1].as_str(), Some("spray_can"));
1806 }
1807
1808 #[tokio::test]
1813 async fn test_coco_arrow_roundtrip_lvis_supercategory() {
1814 let temp_dir = TempDir::new().unwrap();
1815
1816 let coco_json = r#"{
1818 "images": [
1819 {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
1820 ],
1821 "annotations": [
1822 {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0}
1823 ],
1824 "categories": [
1825 {"id": 1, "name": "person", "supercategory": "human"}
1826 ]
1827 }"#;
1828
1829 let coco_path = temp_dir.path().join("original.json");
1830 std::fs::write(&coco_path, coco_json).unwrap();
1831
1832 let arrow_path = temp_dir.path().join("converted.arrow");
1834 let options = CocoToArrowOptions::default();
1835 coco_to_arrow(&coco_path, &arrow_path, &options, None)
1836 .await
1837 .unwrap();
1838
1839 let restored_path = temp_dir.path().join("restored.json");
1841 let options = ArrowToCocoOptions::default();
1842 arrow_to_coco(&arrow_path, &restored_path, &options, None)
1843 .await
1844 .unwrap();
1845
1846 let reader = CocoReader::new();
1848 let restored = reader.read_json(&restored_path).unwrap();
1849
1850 assert_eq!(restored.categories.len(), 1);
1851 assert_eq!(restored.categories[0].name, "person");
1852 assert_eq!(
1853 restored.categories[0].supercategory,
1854 Some("human".to_string()),
1855 "supercategory should survive COCO→Arrow→COCO round-trip"
1856 );
1857 }
1858
1859 #[tokio::test]
1860 async fn test_coco_arrow_roundtrip_neg_categories_no_annotations() {
1861 let temp_dir = TempDir::new().unwrap();
1862
1863 let coco_json = r#"{
1865 "images": [
1866 {
1867 "id": 1,
1868 "width": 640,
1869 "height": 480,
1870 "file_name": "empty.jpg",
1871 "neg_category_ids": [1, 2]
1872 }
1873 ],
1874 "annotations": [],
1875 "categories": [
1876 {"id": 1, "name": "cat", "supercategory": "animal"},
1877 {"id": 2, "name": "dog", "supercategory": "animal"}
1878 ]
1879 }"#;
1880
1881 let coco_path = temp_dir.path().join("original.json");
1882 std::fs::write(&coco_path, coco_json).unwrap();
1883
1884 let arrow_path = temp_dir.path().join("converted.arrow");
1886 let options = CocoToArrowOptions::default();
1887 let sample_count = coco_to_arrow(&coco_path, &arrow_path, &options, None)
1888 .await
1889 .unwrap();
1890
1891 assert_eq!(
1893 sample_count, 1,
1894 "sentinel row should be emitted for image with neg data"
1895 );
1896
1897 let restored_path = temp_dir.path().join("restored.json");
1899 let options = ArrowToCocoOptions::default();
1900 arrow_to_coco(&arrow_path, &restored_path, &options, None)
1901 .await
1902 .unwrap();
1903
1904 let reader = CocoReader::new();
1906 let restored = reader.read_json(&restored_path).unwrap();
1907
1908 assert_eq!(restored.images.len(), 1);
1909 assert_eq!(restored.annotations.len(), 0, "no annotations expected");
1910 assert_eq!(restored.categories.len(), 2, "both categories should exist");
1911
1912 let neg = restored.images[0].neg_category_ids.as_ref();
1913 assert!(
1914 neg.is_some(),
1915 "neg_category_ids should survive round-trip for zero-annotation image"
1916 );
1917 let neg_ids = neg.unwrap();
1918 assert_eq!(neg_ids.len(), 2, "should have 2 neg categories");
1919 assert!(neg_ids.contains(&1), "neg_category_ids should contain 1");
1920 assert!(neg_ids.contains(&2), "neg_category_ids should contain 2");
1921
1922 for cat in &restored.categories {
1924 assert_eq!(
1925 cat.supercategory,
1926 Some("animal".to_string()),
1927 "supercategory should survive round-trip for annotation-free category '{}'",
1928 cat.name
1929 );
1930 }
1931 }
1932
1933 #[test]
1934 fn test_convert_image_annotations_neg_only_no_annotations() {
1935 let image = CocoImage {
1936 id: 1,
1937 width: 640,
1938 height: 480,
1939 file_name: "neg_only.jpg".to_string(),
1940 neg_category_ids: Some(vec![1, 2]),
1941 ..Default::default()
1942 };
1943
1944 let dataset = CocoDataset {
1945 images: vec![image.clone()],
1946 categories: vec![
1947 CocoCategory {
1948 id: 1,
1949 name: "cat".to_string(),
1950 supercategory: Some("animal".to_string()),
1951 ..Default::default()
1952 },
1953 CocoCategory {
1954 id: 2,
1955 name: "dog".to_string(),
1956 supercategory: Some("animal".to_string()),
1957 ..Default::default()
1958 },
1959 ],
1960 annotations: vec![],
1961 ..Default::default()
1962 };
1963
1964 let index = CocoIndex::from_dataset(&dataset);
1965 let samples = convert_image_annotations(&image, &index, true, None);
1966
1967 assert_eq!(
1969 samples.len(),
1970 1,
1971 "sentinel row should be emitted for neg-only image"
1972 );
1973 assert_eq!(samples[0].image_name, Some("neg_only".to_string()));
1974 assert!(
1975 samples[0].annotations.is_empty(),
1976 "sentinel should have no annotations"
1977 );
1978 assert!(
1979 samples[0].neg_label_indices.is_some(),
1980 "sentinel should preserve neg_label_indices"
1981 );
1982 assert_eq!(samples[0].neg_label_indices.as_ref().unwrap().len(), 2);
1983 }
1984
1985 #[test]
1986 fn test_convert_image_annotations_no_annotations_emits_placeholder() {
1987 let image = CocoImage {
1991 id: 1,
1992 width: 640,
1993 height: 480,
1994 file_name: "empty.jpg".to_string(),
1995 ..Default::default()
1996 };
1997
1998 let dataset = CocoDataset {
1999 images: vec![image.clone()],
2000 categories: vec![CocoCategory {
2001 id: 1,
2002 name: "person".to_string(),
2003 ..Default::default()
2004 }],
2005 annotations: vec![],
2006 ..Default::default()
2007 };
2008
2009 let index = CocoIndex::from_dataset(&dataset);
2010 let samples = convert_image_annotations(&image, &index, true, Some("train"));
2011
2012 assert_eq!(
2013 samples.len(),
2014 1,
2015 "placeholder row must be emitted for an unannotated image"
2016 );
2017 assert_eq!(samples[0].image_name, Some("empty".to_string()));
2018 assert!(
2019 samples[0].annotations.is_empty(),
2020 "placeholder should have no annotations"
2021 );
2022 assert_eq!(
2023 samples[0].group,
2024 Some("train".to_string()),
2025 "group must be preserved on the placeholder row"
2026 );
2027 }
2028
2029 #[tokio::test]
2030 async fn test_coco_to_arrow_includes_unannotated_images() {
2031 let temp_dir = TempDir::new().unwrap();
2034
2035 let coco_json = r#"{
2036 "images": [
2037 {"id": 1, "width": 640, "height": 480, "file_name": "annotated.jpg"},
2038 {"id": 2, "width": 640, "height": 480, "file_name": "empty.jpg"}
2039 ],
2040 "annotations": [
2041 {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0}
2042 ],
2043 "categories": [
2044 {"id": 1, "name": "person", "supercategory": "human"}
2045 ]
2046 }"#;
2047
2048 let coco_path = temp_dir.path().join("test.json");
2049 std::fs::write(&coco_path, coco_json).unwrap();
2050 let arrow_path = temp_dir.path().join("out.arrow");
2051
2052 let options = CocoToArrowOptions {
2053 group: Some("train".to_string()),
2054 ..Default::default()
2055 };
2056 let count = coco_to_arrow(&coco_path, &arrow_path, &options, None)
2057 .await
2058 .unwrap();
2059
2060 assert_eq!(count, 2, "every image must produce at least one row");
2062
2063 let mut file = std::fs::File::open(&arrow_path).unwrap();
2064 let df = IpcReader::new(&mut file).finish().unwrap();
2065 assert_eq!(df.height(), 2);
2066
2067 let names = df.column("name").unwrap().str().unwrap();
2069 let empty_row = (0..df.height()).find(|&i| names.get(i) == Some("empty"));
2070 assert!(
2071 empty_row.is_some(),
2072 "unannotated image 'empty' must appear in the Arrow output"
2073 );
2074 let i = empty_row.unwrap();
2075 let group_col = df.column("group").unwrap().cast(&DataType::String).unwrap();
2076 assert_eq!(
2077 group_col.str().unwrap().get(i),
2078 Some("train"),
2079 "group must be set on the unannotated image's row"
2080 );
2081 let label_col = df.column("label").unwrap().cast(&DataType::String).unwrap();
2082 assert_eq!(
2083 label_col.str().unwrap().get(i),
2084 None,
2085 "unannotated image row must have a null label"
2086 );
2087 }
2088}