1use super::{
10 convert::{
11 box2d_to_coco_bbox, coco_bbox_to_box2d, coco_segmentation_to_mask_data,
12 coco_segmentation_to_polygon, polygon_to_coco_polygon,
13 },
14 reader::CocoReader,
15 types::{CocoImage, CocoIndex, CocoInfo, CocoSegmentation},
16 writer::{CocoDatasetBuilder, CocoWriter},
17};
18use crate::{Annotation, Box2d, Error, Polygon, Progress, Sample};
19use polars::prelude::*;
20use std::{
21 collections::{BTreeMap, HashMap},
22 path::Path,
23 sync::{
24 Arc,
25 atomic::{AtomicUsize, Ordering},
26 },
27};
28use tokio::sync::{Semaphore, mpsc::Sender};
29
/// Schema version string stored by `coco_to_arrow` in the Arrow file's custom
/// schema metadata under the `schema_version` key.
///
/// `arrow_to_coco` treats a file that lacks this metadata key as a legacy file.
pub const SCHEMA_VERSION: &str = "2026.04";
32
/// Rings of one polygon: each inner vector holds the points of a single ring
/// as `(f32, f32)` pairs (presumably `(x, y)` — confirm against `Polygon::new`).
type PolygonRings = Vec<Vec<(f32, f32)>>;
35
/// Options controlling `coco_to_arrow`.
#[derive(Debug, Clone)]
pub struct CocoToArrowOptions {
    /// When `false`, neither polygons nor masks are converted from the COCO
    /// segmentation field.
    pub include_masks: bool,
    /// Group value assigned to every produced sample and annotation.
    pub group: Option<String>,
    /// Number of semaphore permits bounding how many image conversion tasks
    /// run at the same time.
    pub max_workers: usize,
}
46
47impl Default for CocoToArrowOptions {
48 fn default() -> Self {
49 Self {
50 include_masks: true,
51 group: None,
52 max_workers: max_workers(),
53 }
54 }
55}
56
/// Options controlling `arrow_to_coco`.
#[derive(Debug, Clone)]
pub struct ArrowToCocoOptions {
    /// When non-empty, only rows whose `group` value appears in this list are
    /// exported; an empty list exports every row.
    pub groups: Vec<String>,
    /// When `false`, no segmentation is written for any annotation.
    pub include_masks: bool,
    /// Optional COCO `info` section handed to the dataset builder.
    pub info: Option<CocoInfo>,
}
67
68impl Default for ArrowToCocoOptions {
69 fn default() -> Self {
70 Self {
71 groups: vec![],
72 include_masks: true,
73 info: None,
74 }
75 }
76}
77
/// Returns the default number of concurrent conversion workers.
///
/// The `MAX_COCO_WORKERS` environment variable takes precedence when it holds
/// a positive integer (surrounding whitespace is ignored). A value of `0`, an
/// unparsable value, or an unset variable falls back to half of the available
/// CPUs, clamped to the range `2..=8` (4 CPUs are assumed when the parallelism
/// cannot be queried).
fn max_workers() -> usize {
    std::env::var("MAX_COCO_WORKERS")
        .ok()
        .and_then(|v| v.trim().parse::<usize>().ok())
        // Zero workers would become a zero-permit semaphore that never lets
        // any task run, so treat it like an invalid override.
        .filter(|&n| n > 0)
        .unwrap_or_else(|| {
            let cpus = std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4);
            (cpus / 2).clamp(2, 8)
        })
}
90
91pub async fn coco_to_arrow<P: AsRef<Path>>(
105 coco_path: P,
106 output_path: P,
107 options: &CocoToArrowOptions,
108 progress: Option<Sender<Progress>>,
109) -> Result<usize, Error> {
110 let coco_path = coco_path.as_ref();
111 let output_path = output_path.as_ref();
112
113 let reader = CocoReader::new();
115 let dataset = if coco_path.extension().is_some_and(|e| e == "zip") {
116 reader.read_annotations_zip(coco_path)?
117 } else {
118 reader.read_json(coco_path)?
119 };
120
121 let index = Arc::new(CocoIndex::from_dataset(&dataset));
123 let total_images = dataset.images.len();
124
125 if let Some(ref p) = progress {
127 let _ = p
128 .send(Progress {
129 current: 0,
130 total: total_images,
131 status: None,
132 })
133 .await;
134 }
135
136 let sem = Arc::new(Semaphore::new(options.max_workers));
138 let current = Arc::new(AtomicUsize::new(0));
139 let include_masks = options.include_masks;
140 let group = options.group.clone();
141
142 let mut tasks = Vec::with_capacity(total_images);
143
144 for image in dataset.images {
145 let sem = sem.clone();
146 let index = index.clone();
147 let current = current.clone();
148 let progress = progress.clone();
149 let total = total_images;
150 let group = group.clone();
151
152 let task = tokio::spawn(async move {
153 let _permit = sem.acquire().await.map_err(Error::SemaphoreError)?;
154
155 let samples =
157 convert_image_annotations(&image, &index, include_masks, group.as_deref());
158
159 let c = current.fetch_add(1, Ordering::SeqCst) + 1;
161 if let Some(ref p) = progress {
162 let _ = p
163 .send(Progress {
164 current: c,
165 total,
166 status: None,
167 })
168 .await;
169 }
170
171 Ok::<Vec<Sample>, Error>(samples)
172 });
173
174 tasks.push(task);
175 }
176
177 let mut all_samples = Vec::with_capacity(total_images);
179 for task in tasks {
180 let samples = task.await??;
181 all_samples.extend(samples);
182 }
183
184 let df = crate::samples_dataframe(&all_samples)?;
186
187 let mut metadata: BTreeMap<PlSmallStr, PlSmallStr> = BTreeMap::new();
189 metadata.insert(
190 PlSmallStr::from("schema_version"),
191 PlSmallStr::from(SCHEMA_VERSION),
192 );
193
194 if !dataset.categories.is_empty() {
200 let cat_meta: HashMap<String, serde_json::Value> = dataset
201 .categories
202 .iter()
203 .map(|c| {
204 let mut entry = serde_json::Map::new();
205 entry.insert("id".to_string(), serde_json::json!(c.id));
206 if let Some(ref f) = c.frequency {
207 entry.insert(
208 "frequency".to_string(),
209 serde_json::Value::String(f.clone()),
210 );
211 }
212 if let Some(ref s) = c.synset {
213 entry.insert("synset".to_string(), serde_json::Value::String(s.clone()));
214 }
215 if let Some(ref syns) = c.synonyms {
216 entry.insert("synonyms".to_string(), serde_json::json!(syns));
217 }
218 if let Some(ref d) = c.def {
219 entry.insert(
220 "definition".to_string(),
221 serde_json::Value::String(d.clone()),
222 );
223 }
224 if let Some(ref sc) = c.supercategory {
225 entry.insert(
226 "supercategory".to_string(),
227 serde_json::Value::String(sc.clone()),
228 );
229 }
230 (c.name.clone(), serde_json::Value::Object(entry))
234 })
235 .collect();
236
237 let json = serde_json::to_string(&cat_meta).unwrap_or_default();
238 metadata.insert(
239 PlSmallStr::from("category_metadata"),
240 PlSmallStr::from(json.as_str()),
241 );
242 }
243
244 if !dataset.categories.is_empty() {
246 let mut cats: Vec<_> = dataset.categories.iter().collect();
247 cats.sort_by_key(|c| c.id);
248 let labels: Vec<String> = cats.iter().map(|c| c.name.clone()).collect();
249 let labels_json = serde_json::to_string(&labels).unwrap_or_default();
250 metadata.insert(PlSmallStr::from("labels"), PlSmallStr::from(labels_json));
251 }
252
253 if let Some(parent) = output_path.parent()
255 && !parent.as_os_str().is_empty()
256 {
257 std::fs::create_dir_all(parent)?;
258 }
259 let mut file = std::fs::File::create(output_path)?;
260 let mut writer = IpcWriter::new(&mut file);
261 writer.set_custom_schema_metadata(Arc::new(metadata));
262 writer.finish(&mut df.clone())?;
263
264 Ok(all_samples.len())
265}
266
267fn convert_image_annotations(
269 image: &CocoImage,
270 index: &CocoIndex,
271 include_masks: bool,
272 group: Option<&str>,
273) -> Vec<Sample> {
274 let annotations = index.annotations_for_image(image.id);
275 let sample_name = sample_name_from_filename(&image.file_name);
276
277 let neg_label_indices = image.neg_category_ids.as_ref().map(|ids| {
279 ids.iter()
280 .filter_map(|&id| index.label_index(id).map(|idx| idx as u32))
281 .collect::<Vec<u32>>()
282 });
283 let not_exhaustive_label_indices = image.not_exhaustive_category_ids.as_ref().map(|ids| {
284 ids.iter()
285 .filter_map(|&id| index.label_index(id).map(|idx| idx as u32))
286 .collect::<Vec<u32>>()
287 });
288
289 let mut samples: Vec<Sample> = annotations
290 .iter()
291 .filter_map(|ann| {
292 let label = index.label_name(ann.category_id)?;
293 let label_index = index.label_index(ann.category_id);
294
295 let box2d = coco_bbox_to_box2d(&ann.bbox, image.width, image.height);
297
298 let (polygon, mask) = if include_masks {
302 if let Some(seg) = &ann.segmentation {
303 match seg {
304 CocoSegmentation::Polygon(_) => {
305 let poly =
306 coco_segmentation_to_polygon(seg, image.width, image.height).ok();
307 (poly, None)
308 }
309 CocoSegmentation::Rle(_) | CocoSegmentation::CompressedRle(_) => {
310 let mask_data = coco_segmentation_to_mask_data(seg).ok().flatten();
311 (None, mask_data)
312 }
313 }
314 } else {
315 (None, None)
316 }
317 } else {
318 (None, None)
319 };
320
321 let mut annotation = Annotation::new();
322 annotation.set_name(Some(sample_name.clone()));
323 annotation.set_object_id(Some(ann.id.to_string()));
324 annotation.set_label(Some(label.to_string()));
325 annotation.set_label_index(label_index);
326 annotation.set_box2d(Some(box2d));
327 annotation.set_polygon(polygon);
328 annotation.set_mask(mask);
329 annotation.set_group(group.map(String::from));
330 annotation.set_iscrowd(Some(ann.iscrowd != 0));
331 annotation.set_category_frequency(index.frequency(ann.category_id).map(String::from));
332
333 if let Some(score) = ann.score {
335 let score_f32 = score as f32;
336 if annotation.mask().is_some() {
337 annotation.set_mask_score(Some(score_f32));
338 } else if annotation.polygon().is_some() {
339 annotation.set_polygon_score(Some(score_f32));
340 } else {
341 annotation.set_box2d_score(Some(score_f32));
342 }
343 }
344
345 let mut sample = Sample {
346 image_name: Some(sample_name.clone()),
347 width: Some(image.width),
348 height: Some(image.height),
349 group: group.map(String::from),
350 annotations: vec![annotation],
351 ..Default::default()
352 };
353 sample.neg_label_indices = neg_label_indices.clone();
354 sample.not_exhaustive_label_indices = not_exhaustive_label_indices.clone();
355
356 Some(sample)
357 })
358 .collect();
359
360 if samples.is_empty()
364 && (image.neg_category_ids.is_some() || image.not_exhaustive_category_ids.is_some())
365 {
366 let mut sample = Sample {
367 image_name: Some(sample_name.clone()),
368 width: Some(image.width),
369 height: Some(image.height),
370 group: group.map(String::from),
371 ..Default::default()
372 };
373 sample.neg_label_indices = neg_label_indices;
374 sample.not_exhaustive_label_indices = not_exhaustive_label_indices;
375 samples.push(sample);
376 }
377
378 samples
379}
380
/// Derives a sample name from an image file name: the final path component
/// without its last extension.
///
/// Falls back to the unchanged input when no UTF-8 file stem can be
/// extracted (for example an empty string or `..`).
fn sample_name_from_filename(filename: &str) -> String {
    match Path::new(filename).file_stem().and_then(|stem| stem.to_str()) {
        Some(stem) => stem.to_owned(),
        None => filename.to_owned(),
    }
}
389
390pub async fn arrow_to_coco<P: AsRef<Path>>(
406 arrow_path: P,
407 output_path: P,
408 options: &ArrowToCocoOptions,
409 progress: Option<Sender<Progress>>,
410) -> Result<usize, Error> {
411 let arrow_path = arrow_path.as_ref();
412 let output_path = output_path.as_ref();
413
414 let (schema_version, category_metadata_json, labels_metadata_json) = {
416 let mut meta_file = std::fs::File::open(arrow_path)?;
417 let mut reader = IpcReader::new(&mut meta_file);
418 let meta = reader.custom_metadata().ok().flatten();
419 let sv = meta.as_ref().and_then(|m| {
420 m.get(&PlSmallStr::from("schema_version"))
421 .map(|s| s.to_string())
422 });
423 let cm = meta.as_ref().and_then(|m| {
424 m.get(&PlSmallStr::from("category_metadata"))
425 .map(|s| s.to_string())
426 });
427 let lm = meta
428 .as_ref()
429 .and_then(|m| m.get(&PlSmallStr::from("labels")).map(|s| s.to_string()));
430 (sv, cm, lm)
431 };
432
433 let is_legacy = schema_version.is_none();
435
436 let mut file = std::fs::File::open(arrow_path)?;
438 let df = IpcReader::new(&mut file).finish()?;
439
440 let groups_to_filter: std::collections::HashSet<_> = options.groups.iter().cloned().collect();
442
443 let total_rows = df.height();
444
445 if let Some(ref p) = progress {
446 let _ = p
447 .send(Progress {
448 current: 0,
449 total: total_rows,
450 status: None,
451 })
452 .await;
453 }
454
455 let names: Vec<String> = df
457 .column("name")?
458 .str()?
459 .into_iter()
460 .map(|s| s.unwrap_or_default().to_string())
461 .collect();
462
463 let labels: Vec<String> = df
464 .column("label")
465 .ok()
466 .and_then(|c| c.cast(&DataType::String).ok())
467 .map(|c| {
468 c.str()
469 .ok()
470 .map(|s| {
471 s.into_iter()
472 .map(|v| v.unwrap_or_default().to_string())
473 .collect()
474 })
475 .unwrap_or_else(|| vec![String::new(); total_rows])
476 })
477 .unwrap_or_else(|| vec![String::new(); total_rows]);
478
479 let label_indices: Vec<Option<u64>> = df
480 .column("label_index")
481 .ok()
482 .map(|c| {
483 c.u64()
484 .ok()
485 .map(|s| s.into_iter().collect())
486 .unwrap_or_else(|| vec![None; total_rows])
487 })
488 .unwrap_or_else(|| vec![None; total_rows]);
489
490 let groups: Vec<String> = df
492 .column("group")
493 .ok()
494 .and_then(|c| c.cast(&DataType::String).ok())
495 .map(|c| {
496 c.str()
497 .ok()
498 .map(|s| {
499 s.into_iter()
500 .map(|v| v.unwrap_or_default().to_string())
501 .collect()
502 })
503 .unwrap_or_default()
504 })
505 .unwrap_or_else(|| vec!["".to_string(); total_rows]);
506
507 let box2ds = df
509 .column("box2d")
510 .ok()
511 .map(extract_all_box2ds)
512 .transpose()?
513 .unwrap_or_else(|| vec![[0.0; 4]; total_rows]);
514
515 let legacy_masks: Option<Vec<Vec<f32>>> = if is_legacy && options.include_masks {
520 df.column("mask").ok().map(extract_all_masks).transpose()?
521 } else {
522 None
523 };
524
525 let polygons_2026: Option<Vec<Option<PolygonRings>>> = if !is_legacy && options.include_masks {
526 df.column("polygon")
527 .ok()
528 .map(|c| extract_all_polygons(c, total_rows))
529 } else {
530 None
531 };
532
533 let mask_binary_2026: Option<Vec<Option<Vec<u8>>>> = if !is_legacy && options.include_masks {
534 df.column("mask")
535 .ok()
536 .map(|c| extract_all_binary_masks(c, total_rows))
537 } else {
538 None
539 };
540
541 let sizes = df
543 .column("size")
544 .ok()
545 .and_then(|c| extract_all_sizes(c).ok());
546
547 let iscrowds: Vec<u8> = df
549 .column("iscrowd")
550 .ok()
551 .map(|c| {
552 if let Ok(bool_ca) = c.bool() {
554 bool_ca
555 .into_iter()
556 .map(|v| if v.unwrap_or(false) { 1 } else { 0 })
557 .collect()
558 } else {
559 c.u32()
560 .ok()
561 .map(|s| s.into_iter().map(|v| v.unwrap_or(0) as u8).collect())
562 .unwrap_or_else(|| vec![0; total_rows])
563 }
564 })
565 .unwrap_or_else(|| vec![0; total_rows]);
566
567 let category_frequencies: Vec<Option<String>> = df
569 .column("category_frequency")
570 .ok()
571 .and_then(|c| c.cast(&DataType::String).ok())
572 .map(|c| {
573 c.str()
574 .ok()
575 .map(|s| s.into_iter().map(|v| v.map(String::from)).collect())
576 .unwrap_or_else(|| vec![None; total_rows])
577 })
578 .unwrap_or_else(|| vec![None; total_rows]);
579
580 let neg_label_indices: Vec<Option<Vec<u32>>> = df
582 .column("neg_label_indices")
583 .ok()
584 .map(|c| extract_list_u32_column(c, total_rows))
585 .unwrap_or_else(|| vec![None; total_rows]);
586
587 let not_exhaustive_label_indices: Vec<Option<Vec<u32>>> = df
589 .column("not_exhaustive_label_indices")
590 .ok()
591 .map(|c| extract_list_u32_column(c, total_rows))
592 .unwrap_or_else(|| vec![None; total_rows]);
593
594 let box2d_scores: Vec<Option<f32>> = extract_f32_column(&df, "box2d_score", total_rows);
596 let box3d_scores: Vec<Option<f32>> = extract_f32_column(&df, "box3d_score", total_rows);
597 let polygon_scores: Vec<Option<f32>> = extract_f32_column(&df, "polygon_score", total_rows);
598 let mask_scores: Vec<Option<f32>> = extract_f32_column(&df, "mask_score", total_rows);
599
600 let object_id_u64s: Vec<Option<u64>> = df
611 .column("object_id")
612 .ok()
613 .and_then(|c| c.cast(&DataType::String).ok())
614 .map(|c| {
615 c.str()
616 .ok()
617 .map(|s| {
618 s.into_iter()
619 .map(|v| v.and_then(|s| s.parse::<u64>().ok()))
620 .collect()
621 })
622 .unwrap_or_else(|| vec![None; total_rows])
623 })
624 .unwrap_or_else(|| vec![None; total_rows]);
625
626 let mut builder = CocoDatasetBuilder::new();
628
629 if let Some(info) = &options.info {
630 builder = builder.info(info.clone());
631 }
632
633 let skip_row = |i: usize| -> bool {
635 !groups_to_filter.is_empty() && !groups_to_filter.contains(&groups[i])
636 };
637
638 let mut image_dimensions: HashMap<String, (u32, u32)> = HashMap::new();
640 let mut image_ids: HashMap<String, u64> = HashMap::new();
641 let mut category_ids: HashMap<String, u32> = HashMap::new();
642
643 for i in 0..total_rows {
645 if skip_row(i) {
646 continue;
647 }
648
649 let name = &names[i];
650 let label = &labels[i];
651
652 if !image_ids.contains_key(name) {
654 let (width, height) = sizes
655 .as_ref()
656 .and_then(|s| s.get(i).copied())
657 .unwrap_or((0, 0));
658
659 let id = builder.add_image(name, width, height);
660 image_ids.insert(name.clone(), id);
661 image_dimensions.insert(name.clone(), (width, height));
662 }
663
664 if !label.is_empty() && !category_ids.contains_key(label) {
665 let id = if let Some(Some(idx)) = label_indices.get(i) {
666 builder.add_category_with_id(*idx as u32, label, None)
667 } else {
668 builder.add_category(label, None)
669 };
670 category_ids.insert(label.clone(), id);
671 }
672 }
673
674 let mut last_progress_update = 0;
676 for i in 0..total_rows {
677 if skip_row(i) {
678 continue;
679 }
680
681 let name = &names[i];
682 let label = &labels[i];
683
684 if label.is_empty() {
686 continue;
687 }
688
689 let image_id = *image_ids.get(name).unwrap_or(&0);
690 let category_id = *category_ids.get(label).unwrap_or(&0);
691 let (width, height) = *image_dimensions.get(name).unwrap_or(&(1, 1));
692
693 let bbox = box2ds.get(i).map(|box2d| {
696 let cx = box2d[0];
697 let cy = box2d[1];
698 let w = box2d[2];
699 let h = box2d[3];
700 let left = cx - w / 2.0;
702 let top = cy - h / 2.0;
703 let ef_box2d = Box2d::new(left, top, w, h);
704 box2d_to_coco_bbox(&ef_box2d, width, height)
705 });
706
707 let segmentation = if options.include_masks {
709 if is_legacy {
710 legacy_masks.as_ref().and_then(|m| {
712 m.get(i).and_then(|coords| {
713 if coords.is_empty() {
714 None
715 } else {
716 let rings = crate::unflatten_polygon_coordinates(coords);
717 let polygon = Polygon::new(rings);
718 let coco_poly = polygon_to_coco_polygon(&polygon, width, height);
719 if coco_poly.is_empty() {
720 None
721 } else {
722 Some(CocoSegmentation::Polygon(coco_poly))
723 }
724 }
725 })
726 })
727 } else {
728 let mask_seg = mask_binary_2026.as_ref().and_then(|masks| {
730 masks.get(i).and_then(|opt_bytes| {
731 opt_bytes
732 .as_ref()
733 .and_then(|png_bytes| png_to_rle_segmentation(png_bytes, i))
734 })
735 });
736
737 if mask_seg.is_some() {
738 mask_seg
739 } else {
740 polygons_2026.as_ref().and_then(|polys| {
742 polys.get(i).and_then(|opt_rings| {
743 opt_rings.as_ref().and_then(|rings| {
744 if rings.is_empty() {
745 return None;
746 }
747 let polygon = Polygon::new(rings.clone());
748 let coco_poly = polygon_to_coco_polygon(&polygon, width, height);
749 if coco_poly.is_empty() {
750 None
751 } else {
752 Some(CocoSegmentation::Polygon(coco_poly))
753 }
754 })
755 })
756 })
757 }
758 }
759 } else {
760 None
761 };
762
763 let score: Option<f64> = mask_scores[i]
765 .or(polygon_scores[i])
766 .or(box3d_scores[i])
767 .or(box2d_scores[i])
768 .map(|s| s as f64);
769
770 if let Some(bbox) = bbox {
771 let iscrowd = iscrowds[i];
772 let ann_id = builder.add_annotation_with_id(
773 object_id_u64s[i],
774 image_id,
775 category_id,
776 bbox,
777 segmentation,
778 iscrowd,
779 );
780
781 if let Some(score_val) = score {
783 builder.set_annotation_score(ann_id, score_val);
784 }
785 }
786
787 if let Some(ref p) = progress
789 && (i - last_progress_update >= 1000 || i == total_rows - 1)
790 {
791 let _ = p
792 .send(Progress {
793 current: i + 1,
794 total: total_rows,
795 status: None,
796 })
797 .await;
798 last_progress_update = i;
799 }
800 }
801
802 if let Some(ref p) = progress
804 && last_progress_update < total_rows.saturating_sub(1)
805 {
806 let _ = p
807 .send(Progress {
808 current: total_rows,
809 total: total_rows,
810 status: None,
811 })
812 .await;
813 }
814
815 {
818 let mut processed_images: std::collections::HashSet<u64> = std::collections::HashSet::new();
819 for i in 0..total_rows {
820 if skip_row(i) {
821 continue;
822 }
823 let name = &names[i];
824 if let Some(&image_id) = image_ids.get(name) {
825 if !processed_images.insert(image_id) {
826 continue;
827 }
828 let neg = neg_label_indices[i].clone();
829 let not_exhaustive = not_exhaustive_label_indices[i].clone();
830 if neg.is_some() || not_exhaustive.is_some() {
831 builder.set_image_neg_categories(image_id, neg, not_exhaustive);
832 }
833 }
834 }
835 }
836
837 {
840 let mut freq_map: HashMap<String, String> = HashMap::new();
841 for i in 0..total_rows {
842 if skip_row(i) {
843 continue;
844 }
845 let label = &labels[i];
846 if !label.is_empty()
847 && !freq_map.contains_key(label)
848 && let Some(ref freq) = category_frequencies[i]
849 {
850 freq_map.insert(label.clone(), freq.clone());
851 }
852 }
853 for (name, freq) in &freq_map {
854 builder.set_category_metadata(name, None, Some(freq.clone()), None, None);
855 }
856 }
857
858 if let Some(ref json_str) = category_metadata_json
865 && let Ok(meta) = serde_json::from_str::<HashMap<String, serde_json::Value>>(json_str)
866 {
867 for (cat_name, value) in &meta {
868 let supercategory = value.get("supercategory").and_then(|v| v.as_str());
869
870 if !category_ids.contains_key(cat_name.as_str()) {
872 let cat_id = value.get("id").and_then(|v| v.as_u64()).map(|id| id as u32);
873 let id = if let Some(cat_id) = cat_id {
874 builder.add_category_with_id(cat_id, cat_name, supercategory)
875 } else {
876 builder.add_category(cat_name, supercategory)
877 };
878 category_ids.insert(cat_name.clone(), id);
879 } else {
880 if let Some(sc) = supercategory {
882 builder.set_category_supercategory(cat_name, sc);
883 }
884 }
885
886 let synset = value
887 .get("synset")
888 .and_then(|v| v.as_str())
889 .map(String::from);
890 let frequency = value
891 .get("frequency")
892 .and_then(|v| v.as_str())
893 .map(String::from);
894 let synonyms = value.get("synonyms").and_then(|v| {
895 v.as_array().map(|arr| {
896 arr.iter()
897 .filter_map(|s| s.as_str().map(String::from))
898 .collect()
899 })
900 });
901 let def = value
902 .get("definition")
903 .and_then(|v| v.as_str())
904 .map(String::from);
905
906 builder.set_category_metadata(cat_name, synset, frequency, synonyms, def);
907 }
908 }
909
910 if category_metadata_json.is_none()
913 && let Some(ref labels_json) = labels_metadata_json
914 && let Ok(label_names) = serde_json::from_str::<Vec<String>>(labels_json)
915 {
916 for label_name in &label_names {
917 if !category_ids.contains_key(label_name) {
918 let id = builder.add_category(label_name, None);
919 category_ids.insert(label_name.clone(), id);
920 }
921 }
922 }
923
924 let dataset = builder.build();
925 let annotation_count = dataset.annotations.len();
926
927 let writer = CocoWriter::new();
929 writer.write_json(&dataset, output_path)?;
930
931 Ok(annotation_count)
932}
933
934fn extract_all_box2ds(col: &Column) -> Result<Vec<[f32; 4]>, Error> {
936 let arr = col.array()?;
937 let mut result = Vec::with_capacity(arr.len());
938
939 for inner in arr.amortized_iter() {
940 let values = if let Some(inner) = inner {
941 let series = inner.as_ref();
942 let vals: Vec<f32> = series
943 .f32()
944 .map_err(|e| Error::CocoError(format!("box2d cast error: {}", e)))?
945 .into_iter()
946 .map(|v| v.unwrap_or(0.0))
947 .collect();
948
949 if vals.len() == 4 {
950 [vals[0], vals[1], vals[2], vals[3]]
951 } else {
952 [0.0, 0.0, 0.0, 0.0]
953 }
954 } else {
955 [0.0, 0.0, 0.0, 0.0]
956 };
957 result.push(values);
958 }
959
960 Ok(result)
961}
962
963fn extract_all_masks(col: &Column) -> Result<Vec<Vec<f32>>, Error> {
965 let list = col.list()?;
966 let mut result = Vec::with_capacity(list.len());
967
968 for i in 0..list.len() {
969 let coords = match list.get_as_series(i) {
970 Some(series) => series
971 .f32()
972 .map_err(|e| Error::CocoError(format!("mask cast error: {}", e)))?
973 .into_iter()
974 .map(|v| v.unwrap_or(f32::NAN))
975 .collect(),
976 None => vec![],
977 };
978 result.push(coords);
979 }
980
981 Ok(result)
982}
983
984fn extract_all_sizes(col: &Column) -> Result<Vec<(u32, u32)>, Error> {
986 let arr = col.array()?;
987 let mut result = Vec::with_capacity(arr.len());
988
989 for inner in arr.amortized_iter() {
990 let size = if let Some(inner) = inner {
991 let series = inner.as_ref();
992 let values: Vec<u32> = series
993 .u32()
994 .map_err(|e| Error::CocoError(format!("size cast error: {}", e)))?
995 .into_iter()
996 .map(|v| v.unwrap_or(0))
997 .collect();
998
999 if values.len() >= 2 {
1000 (values[0], values[1])
1001 } else {
1002 (0, 0)
1003 }
1004 } else {
1005 (0, 0)
1006 };
1007 result.push(size);
1008 }
1009
1010 Ok(result)
1011}
1012
1013fn extract_list_u32_column(col: &Column, total_rows: usize) -> Vec<Option<Vec<u32>>> {
1015 col.list()
1016 .ok()
1017 .map(|list| {
1018 (0..list.len())
1019 .map(|i| {
1020 list.get_as_series(i).and_then(|series| {
1021 series
1022 .u32()
1023 .ok()
1024 .map(|ca| ca.into_iter().flatten().collect::<Vec<u32>>())
1025 })
1026 })
1027 .collect()
1028 })
1029 .unwrap_or_else(|| vec![None; total_rows])
1030}
1031
1032fn extract_all_polygons(col: &Column, total_rows: usize) -> Vec<Option<PolygonRings>> {
1037 let outer_list = match col.list() {
1038 Ok(l) => l,
1039 Err(_) => return vec![None; total_rows],
1040 };
1041
1042 let mut result = Vec::with_capacity(total_rows);
1043 for i in 0..outer_list.len() {
1044 let rings = outer_list.get_as_series(i).and_then(|ring_series| {
1045 let inner_list = ring_series.list().ok()?;
1046 let mut rings = Vec::new();
1047 for j in 0..inner_list.len() {
1048 if let Some(coords_series) = inner_list.get_as_series(j)
1049 && let Ok(f32_ca) = coords_series.f32()
1050 {
1051 let coords: Vec<f32> = f32_ca.into_iter().map(|v| v.unwrap_or(0.0)).collect();
1052 let points: Vec<(f32, f32)> = coords
1054 .chunks(2)
1055 .filter(|c| c.len() == 2)
1056 .map(|c| (c[0], c[1]))
1057 .collect();
1058 if !points.is_empty() {
1059 rings.push(points);
1060 }
1061 }
1062 }
1063 if rings.is_empty() { None } else { Some(rings) }
1064 });
1065 result.push(rings);
1066 }
1067 result
1068}
1069
1070fn extract_all_binary_masks(col: &Column, total_rows: usize) -> Vec<Option<Vec<u8>>> {
1072 let binary_ca = match col.binary() {
1073 Ok(b) => b,
1074 Err(_) => return vec![None; total_rows],
1075 };
1076
1077 (0..binary_ca.len())
1078 .map(|i| binary_ca.get(i).map(|bytes| bytes.to_vec()))
1079 .collect()
1080}
1081
1082fn extract_f32_column(df: &DataFrame, name: &str, total_rows: usize) -> Vec<Option<f32>> {
1084 df.column(name)
1085 .ok()
1086 .and_then(|c| c.f32().ok())
1087 .map(|ca| ca.into_iter().collect())
1088 .unwrap_or_else(|| vec![None; total_rows])
1089}
1090
1091fn png_to_rle_segmentation(png_bytes: &[u8], row_index: usize) -> Option<CocoSegmentation> {
1097 if png_bytes.is_empty() {
1098 return None;
1099 }
1100
1101 let mask_data = match crate::MaskData::from_png_checked(png_bytes.to_vec()) {
1102 Ok(m) => m,
1103 Err(e) => {
1104 log::warn!("Skipping invalid PNG mask at row {}: {}", row_index, e);
1105 return None;
1106 }
1107 };
1108
1109 let mw = mask_data.width();
1110 let mh = mask_data.height();
1111 let bit_depth = mask_data.bit_depth();
1112
1113 let decoded = match mask_data.decode() {
1114 Ok(d) => d,
1115 Err(e) => {
1116 log::warn!("Failed to decode PNG mask at row {}: {}", row_index, e);
1117 return None;
1118 }
1119 };
1120
1121 let binary_mask = match bit_depth {
1122 1 => decoded,
1123 8 => {
1124 log::warn!(
1125 "Binarizing 8-bit mask for row {} — score data is lost",
1126 row_index
1127 );
1128 decoded
1129 .iter()
1130 .map(|&v| if v >= 128 { 1 } else { 0 })
1131 .collect()
1132 }
1133 16 => {
1134 log::warn!(
1135 "Binarizing 16-bit mask for row {} — score data is lost",
1136 row_index
1137 );
1138 decoded
1139 .chunks(2)
1140 .map(|pair| {
1141 let val = if pair.len() == 2 {
1142 u16::from_be_bytes([pair[0], pair[1]])
1143 } else {
1144 0
1145 };
1146 if val >= 32768 { 1u8 } else { 0u8 }
1147 })
1148 .collect()
1149 }
1150 _ => decoded,
1151 };
1152
1153 match super::convert::encode_rle(&binary_mask, mw, mh) {
1154 Ok(rle) => Some(CocoSegmentation::Rle(rle)),
1155 Err(e) => {
1156 log::warn!("Failed to encode RLE for row {}: {}", row_index, e);
1157 None
1158 }
1159 }
1160}
1161
1162#[cfg(test)]
1163mod tests {
1164 use super::*;
1165 use crate::coco::{CocoAnnotation, CocoCategory, CocoDataset};
1166 use tempfile::TempDir;
1167
/// An empty coordinate buffer produces no rings.
#[test]
fn test_unflatten_polygon_coords_empty() {
    let flat: Vec<f32> = Vec::new();
    let rings = crate::unflatten_polygon_coordinates(&flat);
    assert!(rings.is_empty());
}
1178
/// Eight values without a separator form a single ring of four points.
#[test]
fn test_unflatten_polygon_coords_single_polygon() {
    let flat = vec![0.1, 0.2, 0.3, 0.2, 0.3, 0.4, 0.1, 0.4];
    let rings = crate::unflatten_polygon_coordinates(&flat);

    assert_eq!(rings.len(), 1);
    let ring = &rings[0];
    assert_eq!(ring.len(), 4);
    assert_eq!(ring[0], (0.1, 0.2));
    assert_eq!(ring[3], (0.1, 0.4));
}
1190
/// A NaN between two coordinate runs splits them into two rings.
#[test]
fn test_unflatten_polygon_coords_multiple_polygons() {
    let flat = vec![
        // first ring
        0.1,
        0.1,
        0.2,
        0.1,
        0.15,
        0.2,
        // separator
        f32::NAN,
        // second ring
        0.5,
        0.5,
        0.6,
        0.5,
        0.55,
        0.6,
    ];
    let rings = crate::unflatten_polygon_coordinates(&flat);

    assert_eq!(rings.len(), 2);
    assert_eq!(rings[0].len(), 3);
    assert_eq!(rings[1].len(), 3);
    assert_eq!(rings[0][0], (0.1, 0.1));
    assert_eq!(rings[1][0], (0.5, 0.5));
}
1217
/// A leading NaN does not create an extra, empty ring.
#[test]
fn test_unflatten_polygon_coords_leading_nan() {
    let flat = vec![f32::NAN, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6];
    let rings = crate::unflatten_polygon_coordinates(&flat);

    assert_eq!(rings.len(), 1);
    assert_eq!(rings[0].len(), 3);
}
1227
/// A trailing NaN does not create an extra, empty ring.
#[test]
fn test_unflatten_polygon_coords_trailing_nan() {
    let flat = vec![0.1, 0.2, 0.3, 0.4, f32::NAN];
    let rings = crate::unflatten_polygon_coordinates(&flat);

    assert_eq!(rings.len(), 1);
    assert_eq!(rings[0].len(), 2);
}
1237
/// Two consecutive NaNs act as a single separator.
#[test]
fn test_unflatten_polygon_coords_consecutive_nans() {
    let flat = vec![0.1, 0.2, f32::NAN, f32::NAN, 0.3, 0.4];
    let rings = crate::unflatten_polygon_coordinates(&flat);

    assert_eq!(rings.len(), 2);
    assert_eq!(rings[0].len(), 1);
    assert_eq!(rings[1].len(), 1);
}
1248
/// With an odd number of values the unpaired last value is ignored.
#[test]
fn test_unflatten_polygon_coords_odd_values() {
    let flat = vec![0.1, 0.2, 0.3, 0.4, 0.5];
    let rings = crate::unflatten_polygon_coordinates(&flat);

    assert_eq!(rings.len(), 1);
    assert_eq!(rings[0].len(), 2);
}
1258
/// One box-only annotation produces one sample carrying the image stem, the
/// group, the label and the COCO annotation id as object id.
#[test]
fn test_convert_image_annotations_basic() {
    let img = CocoImage {
        id: 1,
        width: 640,
        height: 480,
        file_name: "test_image.jpg".to_string(),
        ..Default::default()
    };
    let category = CocoCategory {
        id: 1,
        name: "cat".to_string(),
        supercategory: Some("animal".to_string()),
        ..Default::default()
    };
    let coco_ann = CocoAnnotation {
        id: 42,
        image_id: 1,
        category_id: 1,
        bbox: [100.0, 100.0, 200.0, 200.0],
        area: 40000.0,
        iscrowd: 0,
        segmentation: None,
        score: None,
    };
    let dataset = CocoDataset {
        images: vec![img.clone()],
        categories: vec![category],
        annotations: vec![coco_ann],
        ..Default::default()
    };

    let index = CocoIndex::from_dataset(&dataset);
    let converted = convert_image_annotations(&img, &index, true, Some("train"));

    assert_eq!(converted.len(), 1);
    let sample = &converted[0];
    assert_eq!(sample.image_name, Some("test_image".to_string()));
    assert_eq!(sample.group, Some("train".to_string()));
    assert_eq!(sample.annotations.len(), 1);
    let first = &sample.annotations[0];
    assert_eq!(first.label(), Some(&"cat".to_string()));
    assert_eq!(
        first.object_id(),
        Some(&"42".to_string()),
        "object_id must be populated from COCO annotation id to enable \
         prediction-to-prompt linking in prompted-segmentation workflows",
    );
}
1309
/// A polygon segmentation is converted only when masks are requested.
#[test]
fn test_convert_image_annotations_with_mask() {
    let img = CocoImage {
        id: 1,
        width: 100,
        height: 100,
        file_name: "masked.jpg".to_string(),
        ..Default::default()
    };
    let category = CocoCategory {
        id: 1,
        name: "object".to_string(),
        supercategory: None,
        ..Default::default()
    };
    let square = CocoSegmentation::Polygon(vec![vec![
        10.0, 10.0, 60.0, 10.0, 60.0, 60.0, 10.0, 60.0,
    ]]);
    let coco_ann = CocoAnnotation {
        id: 1,
        image_id: 1,
        category_id: 1,
        bbox: [10.0, 10.0, 50.0, 50.0],
        area: 2500.0,
        iscrowd: 0,
        segmentation: Some(square),
        score: None,
    };
    let dataset = CocoDataset {
        images: vec![img.clone()],
        categories: vec![category],
        annotations: vec![coco_ann],
        ..Default::default()
    };

    let index = CocoIndex::from_dataset(&dataset);

    let with_masks = convert_image_annotations(&img, &index, true, None);
    assert!(with_masks[0].annotations[0].polygon().is_some());

    let without_masks = convert_image_annotations(&img, &index, false, None);
    assert!(without_masks[0].annotations[0].polygon().is_none());
}
1353
/// An annotation id larger than `u32::MAX` is carried over as the object id.
#[test]
fn test_convert_image_annotations_object_id_from_lvis_large_id() {
    let large_id: u64 = 9_876_543_210;

    let img = CocoImage {
        id: 397133,
        width: 640,
        height: 480,
        file_name: "000000397133.jpg".to_string(),
        ..Default::default()
    };
    let category = CocoCategory {
        id: 16,
        name: "dog".to_string(),
        synset: Some("dog.n.01".to_string()),
        frequency: Some("f".to_string()),
        ..Default::default()
    };
    let coco_ann = CocoAnnotation {
        id: large_id,
        image_id: 397133,
        category_id: 16,
        bbox: [192.81, 224.8, 74.73, 33.43],
        area: 1035.7,
        iscrowd: 0,
        segmentation: None,
        score: None,
    };
    let dataset = CocoDataset {
        images: vec![img.clone()],
        categories: vec![category],
        annotations: vec![coco_ann],
        ..Default::default()
    };

    let index = CocoIndex::from_dataset(&dataset);
    let converted = convert_image_annotations(&img, &index, true, None);

    assert_eq!(converted.len(), 1);
    assert_eq!(converted[0].annotations.len(), 1);
    assert_eq!(
        converted[0].annotations[0].object_id(),
        Some(&large_id.to_string()),
    );
}
1401
/// An image in a dataset with no annotations and no categories converts to
/// an empty sample list.
#[test]
fn test_convert_image_annotations_no_annotations() {
    let image = CocoImage {
        file_name: String::from("empty.jpg"),
        id: 1,
        width: 640,
        height: 480,
        ..Default::default()
    };
    let dataset = CocoDataset {
        images: vec![image.clone()],
        categories: Vec::new(),
        annotations: Vec::new(),
        ..Default::default()
    };
    let index = CocoIndex::from_dataset(&dataset);

    assert!(convert_image_annotations(&image, &index, true, None).is_empty());
}
1424
/// Table of file names and the sample names expected from
/// `sample_name_from_filename`.
#[test]
fn test_sample_name_from_filename() {
    let cases = [
        ("000000397133.jpg", "000000397133"),
        ("train2017/image.jpg", "image"),
        ("test", "test"),
    ];

    for (file_name, expected) in cases {
        assert_eq!(sample_name_from_filename(file_name), expected);
    }
}
1438
/// A file nested several directories deep yields only its extension-less
/// file name.
#[test]
fn test_sample_name_from_filename_nested_path() {
    let name = sample_name_from_filename("a/b/c/deep_image.png");
    assert_eq!(name, "deep_image");
}
1446
/// A name without an extension is returned unchanged.
#[test]
fn test_sample_name_from_filename_no_extension() {
    let bare = "no_extension";
    assert_eq!(sample_name_from_filename(bare), "no_extension");
}
1451
/// Checks the defaults of `CocoToArrowOptions`: masks enabled, no group, and
/// a worker count taken from `max_workers()`.
///
/// `max_workers()` honours the `MAX_COCO_WORKERS` environment variable, so a
/// fixed `>= 2` bound fails spuriously when the variable is set to a smaller
/// value. The default is therefore compared against `max_workers()` itself,
/// and the lower bound is only enforced when no valid override is active.
#[test]
fn test_coco_to_arrow_options_default() {
    let options = CocoToArrowOptions::default();
    assert!(options.include_masks);
    assert!(options.group.is_none());
    assert_eq!(options.max_workers, max_workers());

    // Same parsing as `max_workers()`: an unset or unparsable value means the
    // CPU-derived fallback (clamped to at least 2) is in effect.
    let has_override = std::env::var("MAX_COCO_WORKERS")
        .ok()
        .and_then(|value| value.parse::<usize>().ok())
        .is_some();
    if !has_override {
        assert!(options.max_workers >= 2);
    }
}
1463
/// Checks the defaults of `ArrowToCocoOptions`: no group filter, masks
/// enabled, and no dataset info.
#[test]
fn test_arrow_to_coco_options_default() {
    let ArrowToCocoOptions {
        groups,
        include_masks,
        info,
    } = ArrowToCocoOptions::default();

    assert!(groups.is_empty());
    assert!(include_masks);
    assert!(info.is_none());
}
1471
/// Checks the worker count returned by `max_workers()`.
///
/// `max_workers()` returns the value of `MAX_COCO_WORKERS` verbatim when it
/// parses as a `usize`, and only otherwise falls back to a CPU-derived count
/// clamped to `2..=8`. Asserting the clamp unconditionally made this test
/// fail whenever the variable held a valid value outside that range, so the
/// expectation is chosen to match whichever path is in effect.
#[test]
fn test_max_workers() {
    let workers = max_workers();

    // Mirror the override parsing used by `max_workers()`.
    let env_override = std::env::var("MAX_COCO_WORKERS")
        .ok()
        .and_then(|value| value.parse::<usize>().ok());

    match env_override {
        Some(expected) => assert_eq!(workers, expected),
        None => assert!((2..=8).contains(&workers)),
    }
}
1478
1479 #[tokio::test]
1480 async fn test_coco_to_arrow_minimal() {
1481 let temp_dir = TempDir::new().unwrap();
1482
1483 let coco_json = r#"{
1485 "images": [
1486 {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
1487 ],
1488 "annotations": [
1489 {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0}
1490 ],
1491 "categories": [
1492 {"id": 1, "name": "person", "supercategory": "human"}
1493 ]
1494 }"#;
1495
1496 let coco_path = temp_dir.path().join("test.json");
1497 std::fs::write(&coco_path, coco_json).unwrap();
1498
1499 let arrow_path = temp_dir.path().join("output.arrow");
1500
1501 let options = CocoToArrowOptions::default();
1502 let count = coco_to_arrow(&coco_path, &arrow_path, &options, None)
1503 .await
1504 .unwrap();
1505
1506 assert_eq!(count, 1);
1507 assert!(arrow_path.exists());
1508
1509 let mut file = std::fs::File::open(&arrow_path).unwrap();
1511 let df = IpcReader::new(&mut file).finish().unwrap();
1512 assert_eq!(df.height(), 1);
1513 }
1514
1515 #[tokio::test]
1516 async fn test_arrow_to_coco_roundtrip() {
1517 let temp_dir = TempDir::new().unwrap();
1518
1519 let original = CocoDataset {
1521 images: vec![CocoImage {
1522 id: 1,
1523 width: 640,
1524 height: 480,
1525 file_name: "test.jpg".to_string(),
1526 ..Default::default()
1527 }],
1528 annotations: vec![CocoAnnotation {
1529 id: 1,
1530 image_id: 1,
1531 category_id: 1,
1532 bbox: [100.0, 50.0, 200.0, 150.0],
1533 area: 30000.0,
1534 iscrowd: 0,
1535 segmentation: Some(CocoSegmentation::Polygon(vec![vec![
1536 100.0, 50.0, 300.0, 50.0, 300.0, 200.0, 100.0, 200.0,
1537 ]])),
1538 score: None,
1539 }],
1540 categories: vec![CocoCategory {
1541 id: 1,
1542 name: "person".to_string(),
1543 supercategory: Some("human".to_string()),
1544 ..Default::default()
1545 }],
1546 ..Default::default()
1547 };
1548
1549 let coco_path = temp_dir.path().join("original.json");
1551 let writer = CocoWriter::new();
1552 writer.write_json(&original, &coco_path).unwrap();
1553
1554 let arrow_path = temp_dir.path().join("converted.arrow");
1556 let options = CocoToArrowOptions::default();
1557 coco_to_arrow(&coco_path, &arrow_path, &options, None)
1558 .await
1559 .unwrap();
1560
1561 let restored_path = temp_dir.path().join("restored.json");
1563 let options = ArrowToCocoOptions::default();
1564 arrow_to_coco(&arrow_path, &restored_path, &options, None)
1565 .await
1566 .unwrap();
1567
1568 let reader = CocoReader::new();
1570 let restored = reader.read_json(&restored_path).unwrap();
1571
1572 assert_eq!(restored.images.len(), 1);
1573 assert_eq!(restored.annotations.len(), 1);
1574 assert_eq!(restored.categories.len(), 1);
1575
1576 assert_eq!(restored.categories[0].name, "person");
1578 }
1579
1580 #[tokio::test]
1581 async fn test_arrow_to_coco_roundtrip_preserves_annotation_id() {
1582 let temp_dir = TempDir::new().unwrap();
1588
1589 let large_id: u64 = 9_876_543_210;
1590 let original = CocoDataset {
1591 images: vec![CocoImage {
1592 id: 1,
1593 width: 640,
1594 height: 480,
1595 file_name: "test.jpg".to_string(),
1596 ..Default::default()
1597 }],
1598 annotations: vec![
1599 CocoAnnotation {
1600 id: 1,
1601 image_id: 1,
1602 category_id: 1,
1603 bbox: [10.0, 20.0, 100.0, 80.0],
1604 area: 8000.0,
1605 iscrowd: 0,
1606 segmentation: None,
1607 score: None,
1608 },
1609 CocoAnnotation {
1610 id: large_id,
1611 image_id: 1,
1612 category_id: 1,
1613 bbox: [200.0, 200.0, 100.0, 100.0],
1614 area: 10000.0,
1615 iscrowd: 0,
1616 segmentation: None,
1617 score: None,
1618 },
1619 ],
1620 categories: vec![CocoCategory {
1621 id: 1,
1622 name: "person".to_string(),
1623 supercategory: Some("human".to_string()),
1624 ..Default::default()
1625 }],
1626 ..Default::default()
1627 };
1628
1629 let coco_path = temp_dir.path().join("original.json");
1630 let writer = CocoWriter::new();
1631 writer.write_json(&original, &coco_path).unwrap();
1632
1633 let arrow_path = temp_dir.path().join("converted.arrow");
1634 coco_to_arrow(
1635 &coco_path,
1636 &arrow_path,
1637 &CocoToArrowOptions::default(),
1638 None,
1639 )
1640 .await
1641 .unwrap();
1642
1643 let restored_path = temp_dir.path().join("restored.json");
1644 arrow_to_coco(
1645 &arrow_path,
1646 &restored_path,
1647 &ArrowToCocoOptions::default(),
1648 None,
1649 )
1650 .await
1651 .unwrap();
1652
1653 let restored = CocoReader::new().read_json(&restored_path).unwrap();
1654 assert_eq!(restored.annotations.len(), 2);
1655
1656 let restored_ids: std::collections::HashSet<u64> =
1657 restored.annotations.iter().map(|a| a.id).collect();
1658 assert!(
1659 restored_ids.contains(&1),
1660 "small annotation id (1) must round-trip; got {restored_ids:?}"
1661 );
1662 assert!(
1663 restored_ids.contains(&large_id),
1664 "33-bit LVIS-scale annotation id ({large_id}) must round-trip; got {restored_ids:?}"
1665 );
1666 }
1667
1668 #[tokio::test]
1673 async fn test_coco_to_arrow_schema_version_metadata() {
1674 let temp_dir = TempDir::new().unwrap();
1675
1676 let coco_json = r#"{
1678 "images": [
1679 {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
1680 ],
1681 "annotations": [
1682 {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0}
1683 ],
1684 "categories": [
1685 {"id": 1, "name": "person", "supercategory": "human"}
1686 ]
1687 }"#;
1688
1689 let coco_path = temp_dir.path().join("test.json");
1690 std::fs::write(&coco_path, coco_json).unwrap();
1691
1692 let arrow_path = temp_dir.path().join("output.arrow");
1693 let options = CocoToArrowOptions::default();
1694 coco_to_arrow(&coco_path, &arrow_path, &options, None)
1695 .await
1696 .unwrap();
1697
1698 let mut file = std::fs::File::open(&arrow_path).unwrap();
1700 let mut reader = IpcReader::new(&mut file);
1701 let custom_meta = reader.custom_metadata().unwrap();
1702 assert!(custom_meta.is_some(), "custom metadata should be present");
1703
1704 let meta = custom_meta.unwrap();
1705 assert_eq!(
1706 meta.get(&PlSmallStr::from("schema_version")),
1707 Some(&PlSmallStr::from(SCHEMA_VERSION)),
1708 "schema_version metadata should be '2026.04'"
1709 );
1710
1711 assert!(
1713 meta.contains_key(&PlSmallStr::from("category_metadata")),
1714 "category_metadata should be present even without LVIS fields"
1715 );
1716 }
1717
1718 #[tokio::test]
1719 async fn test_coco_to_arrow_category_metadata_lvis() {
1720 let temp_dir = TempDir::new().unwrap();
1721
1722 let coco_json = r#"{
1724 "images": [
1725 {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
1726 ],
1727 "annotations": [
1728 {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0},
1729 {"id": 2, "image_id": 1, "category_id": 2, "bbox": [50, 60, 80, 40], "area": 3200, "iscrowd": 0}
1730 ],
1731 "categories": [
1732 {
1733 "id": 1,
1734 "name": "aerosol_can",
1735 "synset": "aerosol.n.02",
1736 "synonyms": ["aerosol_can", "spray_can"],
1737 "def": "a dispenser that holds a substance under pressure"
1738 },
1739 {
1740 "id": 2,
1741 "name": "person",
1742 "supercategory": "human"
1743 }
1744 ]
1745 }"#;
1746
1747 let coco_path = temp_dir.path().join("lvis.json");
1748 std::fs::write(&coco_path, coco_json).unwrap();
1749
1750 let arrow_path = temp_dir.path().join("lvis_output.arrow");
1751 let options = CocoToArrowOptions::default();
1752 coco_to_arrow(&coco_path, &arrow_path, &options, None)
1753 .await
1754 .unwrap();
1755
1756 let mut file = std::fs::File::open(&arrow_path).unwrap();
1758 let mut reader = IpcReader::new(&mut file);
1759 let custom_meta = reader.custom_metadata().unwrap();
1760 assert!(custom_meta.is_some(), "custom metadata should be present");
1761
1762 let meta = custom_meta.unwrap();
1763
1764 assert_eq!(
1766 meta.get(&PlSmallStr::from("schema_version")),
1767 Some(&PlSmallStr::from(SCHEMA_VERSION)),
1768 );
1769
1770 let cat_meta_str = meta
1772 .get(&PlSmallStr::from("category_metadata"))
1773 .expect("category_metadata should be present for LVIS data");
1774
1775 let cat_meta: HashMap<String, serde_json::Value> =
1776 serde_json::from_str(cat_meta_str.as_str()).unwrap();
1777
1778 assert!(
1780 cat_meta.contains_key("aerosol_can"),
1781 "aerosol_can should be in category_metadata"
1782 );
1783 assert!(
1784 cat_meta.contains_key("person"),
1785 "person should also be in category_metadata"
1786 );
1787
1788 let aerosol = cat_meta.get("aerosol_can").unwrap();
1790 assert_eq!(
1791 aerosol.get("synset").and_then(|v| v.as_str()),
1792 Some("aerosol.n.02")
1793 );
1794 assert_eq!(
1795 aerosol.get("definition").and_then(|v| v.as_str()),
1796 Some("a dispenser that holds a substance under pressure")
1797 );
1798 let synonyms = aerosol.get("synonyms").and_then(|v| v.as_array()).unwrap();
1799 assert_eq!(synonyms.len(), 2);
1800 assert_eq!(synonyms[0].as_str(), Some("aerosol_can"));
1801 assert_eq!(synonyms[1].as_str(), Some("spray_can"));
1802 }
1803
1804 #[tokio::test]
1809 async fn test_coco_arrow_roundtrip_lvis_supercategory() {
1810 let temp_dir = TempDir::new().unwrap();
1811
1812 let coco_json = r#"{
1814 "images": [
1815 {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
1816 ],
1817 "annotations": [
1818 {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0}
1819 ],
1820 "categories": [
1821 {"id": 1, "name": "person", "supercategory": "human"}
1822 ]
1823 }"#;
1824
1825 let coco_path = temp_dir.path().join("original.json");
1826 std::fs::write(&coco_path, coco_json).unwrap();
1827
1828 let arrow_path = temp_dir.path().join("converted.arrow");
1830 let options = CocoToArrowOptions::default();
1831 coco_to_arrow(&coco_path, &arrow_path, &options, None)
1832 .await
1833 .unwrap();
1834
1835 let restored_path = temp_dir.path().join("restored.json");
1837 let options = ArrowToCocoOptions::default();
1838 arrow_to_coco(&arrow_path, &restored_path, &options, None)
1839 .await
1840 .unwrap();
1841
1842 let reader = CocoReader::new();
1844 let restored = reader.read_json(&restored_path).unwrap();
1845
1846 assert_eq!(restored.categories.len(), 1);
1847 assert_eq!(restored.categories[0].name, "person");
1848 assert_eq!(
1849 restored.categories[0].supercategory,
1850 Some("human".to_string()),
1851 "supercategory should survive COCO→Arrow→COCO round-trip"
1852 );
1853 }
1854
1855 #[tokio::test]
1856 async fn test_coco_arrow_roundtrip_neg_categories_no_annotations() {
1857 let temp_dir = TempDir::new().unwrap();
1858
1859 let coco_json = r#"{
1861 "images": [
1862 {
1863 "id": 1,
1864 "width": 640,
1865 "height": 480,
1866 "file_name": "empty.jpg",
1867 "neg_category_ids": [1, 2]
1868 }
1869 ],
1870 "annotations": [],
1871 "categories": [
1872 {"id": 1, "name": "cat", "supercategory": "animal"},
1873 {"id": 2, "name": "dog", "supercategory": "animal"}
1874 ]
1875 }"#;
1876
1877 let coco_path = temp_dir.path().join("original.json");
1878 std::fs::write(&coco_path, coco_json).unwrap();
1879
1880 let arrow_path = temp_dir.path().join("converted.arrow");
1882 let options = CocoToArrowOptions::default();
1883 let sample_count = coco_to_arrow(&coco_path, &arrow_path, &options, None)
1884 .await
1885 .unwrap();
1886
1887 assert_eq!(
1889 sample_count, 1,
1890 "sentinel row should be emitted for image with neg data"
1891 );
1892
1893 let restored_path = temp_dir.path().join("restored.json");
1895 let options = ArrowToCocoOptions::default();
1896 arrow_to_coco(&arrow_path, &restored_path, &options, None)
1897 .await
1898 .unwrap();
1899
1900 let reader = CocoReader::new();
1902 let restored = reader.read_json(&restored_path).unwrap();
1903
1904 assert_eq!(restored.images.len(), 1);
1905 assert_eq!(restored.annotations.len(), 0, "no annotations expected");
1906 assert_eq!(restored.categories.len(), 2, "both categories should exist");
1907
1908 let neg = restored.images[0].neg_category_ids.as_ref();
1909 assert!(
1910 neg.is_some(),
1911 "neg_category_ids should survive round-trip for zero-annotation image"
1912 );
1913 let neg_ids = neg.unwrap();
1914 assert_eq!(neg_ids.len(), 2, "should have 2 neg categories");
1915 assert!(neg_ids.contains(&1), "neg_category_ids should contain 1");
1916 assert!(neg_ids.contains(&2), "neg_category_ids should contain 2");
1917
1918 for cat in &restored.categories {
1920 assert_eq!(
1921 cat.supercategory,
1922 Some("animal".to_string()),
1923 "supercategory should survive round-trip for annotation-free category '{}'",
1924 cat.name
1925 );
1926 }
1927 }
1928
/// An image with `neg_category_ids` but no annotations converts to exactly
/// one annotation-free sample that keeps two negative label indices.
#[test]
fn test_convert_image_annotations_neg_only_no_annotations() {
    let image = CocoImage {
        file_name: String::from("neg_only.jpg"),
        neg_category_ids: Some(vec![1, 2]),
        id: 1,
        width: 640,
        height: 480,
        ..Default::default()
    };

    let cat = CocoCategory {
        id: 1,
        name: String::from("cat"),
        supercategory: Some(String::from("animal")),
        ..Default::default()
    };
    let dog = CocoCategory {
        id: 2,
        name: String::from("dog"),
        supercategory: Some(String::from("animal")),
        ..Default::default()
    };
    let dataset = CocoDataset {
        images: vec![image.clone()],
        categories: vec![cat, dog],
        annotations: Vec::new(),
        ..Default::default()
    };

    let index = CocoIndex::from_dataset(&dataset);
    let samples = convert_image_annotations(&image, &index, true, None);

    assert_eq!(
        samples.len(),
        1,
        "sentinel row should be emitted for neg-only image"
    );

    let sentinel = &samples[0];
    assert_eq!(sentinel.image_name.as_deref(), Some("neg_only"));
    assert!(
        sentinel.annotations.is_empty(),
        "sentinel should have no annotations"
    );
    assert!(
        sentinel.neg_label_indices.is_some(),
        "sentinel should preserve neg_label_indices"
    );

    let neg_labels = sentinel.neg_label_indices.as_ref().unwrap();
    assert_eq!(neg_labels.len(), 2);
}
1980}