1use super::{
10 convert::{
11 box2d_to_coco_bbox, coco_bbox_to_box2d, coco_segmentation_to_mask_data,
12 coco_segmentation_to_polygon, polygon_to_coco_polygon,
13 },
14 reader::CocoReader,
15 types::{CocoImage, CocoIndex, CocoInfo, CocoSegmentation},
16 writer::{CocoDatasetBuilder, CocoWriter},
17};
18use crate::{Annotation, Box2d, Error, Polygon, Progress, Sample};
19use polars::prelude::*;
20use std::{
21 collections::{BTreeMap, HashMap},
22 path::Path,
23 sync::{
24 Arc,
25 atomic::{AtomicUsize, Ordering},
26 },
27};
28use tokio::sync::{Semaphore, mpsc::Sender};
29
/// Schema version string stored under the `schema_version` key of the Arrow
/// file's custom metadata by `coco_to_arrow`.
///
/// `arrow_to_coco` treats a file that has no `schema_version` entry as legacy.
pub const SCHEMA_VERSION: &str = "2026.04";
32
/// A polygon expressed as a list of rings, each ring being a list of
/// coordinate pairs (presumably `(x, y)` — confirm against the writer side).
type PolygonRings = Vec<Vec<(f32, f32)>>;
35
/// Options for the COCO → Arrow conversion performed by `coco_to_arrow`.
#[derive(Debug, Clone)]
pub struct CocoToArrowOptions {
    /// When `true`, annotation segmentations are converted to polygons or
    /// masks; when `false`, only boxes and labels are carried over.
    pub include_masks: bool,
    /// Optional group name stamped on every generated sample and annotation.
    pub group: Option<String>,
    /// Number of semaphore permits bounding how many per-image conversion
    /// tasks run at the same time.
    pub max_workers: usize,
}
46
47impl Default for CocoToArrowOptions {
48 fn default() -> Self {
49 Self {
50 include_masks: true,
51 group: None,
52 max_workers: max_workers(),
53 }
54 }
55}
56
/// Options for the Arrow → COCO conversion performed by `arrow_to_coco`.
#[derive(Debug, Clone)]
pub struct ArrowToCocoOptions {
    /// Group names to export. When empty, no group filtering is applied;
    /// otherwise rows whose group is not listed are skipped.
    pub groups: Vec<String>,
    /// When `true`, segmentation data (polygons / masks) is exported;
    /// when `false`, annotations carry no segmentation.
    pub include_masks: bool,
    /// Optional dataset `info` section handed to the COCO dataset builder.
    pub info: Option<CocoInfo>,
}
67
68impl Default for ArrowToCocoOptions {
69 fn default() -> Self {
70 Self {
71 groups: vec![],
72 include_masks: true,
73 info: None,
74 }
75 }
76}
77
/// Returns the default number of concurrent conversion workers.
///
/// The `MAX_COCO_WORKERS` environment variable overrides the default when it
/// parses as a positive integer. A value of `0` is ignored: a worker count of
/// zero would create a semaphore without permits and stall every task.
///
/// Without a usable override, half of the available CPUs is used, clamped to
/// the range `2..=8` (assuming 4 CPUs when parallelism cannot be queried).
fn max_workers() -> usize {
    let from_env = std::env::var("MAX_COCO_WORKERS")
        .ok()
        .and_then(|v| v.parse::<usize>().ok())
        // Zero workers can never make progress; fall back to the default.
        .filter(|&n| n > 0);

    from_env.unwrap_or_else(|| {
        let cpus = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(4);
        (cpus / 2).clamp(2, 8)
    })
}
90
91pub async fn coco_to_arrow<P: AsRef<Path>>(
105 coco_path: P,
106 output_path: P,
107 options: &CocoToArrowOptions,
108 progress: Option<Sender<Progress>>,
109) -> Result<usize, Error> {
110 let coco_path = coco_path.as_ref();
111 let output_path = output_path.as_ref();
112
113 let reader = CocoReader::new();
115 let dataset = if coco_path.extension().is_some_and(|e| e == "zip") {
116 reader.read_annotations_zip(coco_path)?
117 } else {
118 reader.read_json(coco_path)?
119 };
120
121 let index = Arc::new(CocoIndex::from_dataset(&dataset));
123 let total_images = dataset.images.len();
124
125 if let Some(ref p) = progress {
127 let _ = p
128 .send(Progress {
129 current: 0,
130 total: total_images,
131 status: None,
132 })
133 .await;
134 }
135
136 let sem = Arc::new(Semaphore::new(options.max_workers));
138 let current = Arc::new(AtomicUsize::new(0));
139 let include_masks = options.include_masks;
140 let group = options.group.clone();
141
142 let mut tasks = Vec::with_capacity(total_images);
143
144 for image in dataset.images {
145 let sem = sem.clone();
146 let index = index.clone();
147 let current = current.clone();
148 let progress = progress.clone();
149 let total = total_images;
150 let group = group.clone();
151
152 let task = tokio::spawn(async move {
153 let _permit = sem.acquire().await.map_err(Error::SemaphoreError)?;
154
155 let samples =
157 convert_image_annotations(&image, &index, include_masks, group.as_deref());
158
159 let c = current.fetch_add(1, Ordering::SeqCst) + 1;
161 if let Some(ref p) = progress {
162 let _ = p
163 .send(Progress {
164 current: c,
165 total,
166 status: None,
167 })
168 .await;
169 }
170
171 Ok::<Vec<Sample>, Error>(samples)
172 });
173
174 tasks.push(task);
175 }
176
177 let mut all_samples = Vec::with_capacity(total_images);
179 for task in tasks {
180 let samples = task.await??;
181 all_samples.extend(samples);
182 }
183
184 let df = crate::samples_dataframe(&all_samples)?;
186
187 let mut metadata: BTreeMap<PlSmallStr, PlSmallStr> = BTreeMap::new();
189 metadata.insert(
190 PlSmallStr::from("schema_version"),
191 PlSmallStr::from(SCHEMA_VERSION),
192 );
193
194 if !dataset.categories.is_empty() {
200 let cat_meta: HashMap<String, serde_json::Value> = dataset
201 .categories
202 .iter()
203 .map(|c| {
204 let mut entry = serde_json::Map::new();
205 entry.insert("id".to_string(), serde_json::json!(c.id));
206 if let Some(ref f) = c.frequency {
207 entry.insert(
208 "frequency".to_string(),
209 serde_json::Value::String(f.clone()),
210 );
211 }
212 if let Some(ref s) = c.synset {
213 entry.insert("synset".to_string(), serde_json::Value::String(s.clone()));
214 }
215 if let Some(ref syns) = c.synonyms {
216 entry.insert("synonyms".to_string(), serde_json::json!(syns));
217 }
218 if let Some(ref d) = c.def {
219 entry.insert(
220 "definition".to_string(),
221 serde_json::Value::String(d.clone()),
222 );
223 }
224 if let Some(ref sc) = c.supercategory {
225 entry.insert(
226 "supercategory".to_string(),
227 serde_json::Value::String(sc.clone()),
228 );
229 }
230 (c.name.clone(), serde_json::Value::Object(entry))
234 })
235 .collect();
236
237 let json = serde_json::to_string(&cat_meta).unwrap_or_default();
238 metadata.insert(
239 PlSmallStr::from("category_metadata"),
240 PlSmallStr::from(json.as_str()),
241 );
242 }
243
244 if !dataset.categories.is_empty() {
246 let mut cats: Vec<_> = dataset.categories.iter().collect();
247 cats.sort_by_key(|c| c.id);
248 let labels: Vec<String> = cats.iter().map(|c| c.name.clone()).collect();
249 let labels_json = serde_json::to_string(&labels).unwrap_or_default();
250 metadata.insert(PlSmallStr::from("labels"), PlSmallStr::from(labels_json));
251 }
252
253 if let Some(parent) = output_path.parent()
255 && !parent.as_os_str().is_empty()
256 {
257 std::fs::create_dir_all(parent)?;
258 }
259 let mut file = std::fs::File::create(output_path)?;
260 let mut writer = IpcWriter::new(&mut file);
261 writer.set_custom_schema_metadata(Arc::new(metadata));
262 writer.finish(&mut df.clone())?;
263
264 Ok(all_samples.len())
265}
266
267fn convert_image_annotations(
269 image: &CocoImage,
270 index: &CocoIndex,
271 include_masks: bool,
272 group: Option<&str>,
273) -> Vec<Sample> {
274 let annotations = index.annotations_for_image(image.id);
275 let sample_name = sample_name_from_filename(&image.file_name);
276
277 let neg_label_indices = image.neg_category_ids.as_ref().map(|ids| {
279 ids.iter()
280 .filter_map(|&id| index.label_index(id).map(|idx| idx as u32))
281 .collect::<Vec<u32>>()
282 });
283 let not_exhaustive_label_indices = image.not_exhaustive_category_ids.as_ref().map(|ids| {
284 ids.iter()
285 .filter_map(|&id| index.label_index(id).map(|idx| idx as u32))
286 .collect::<Vec<u32>>()
287 });
288
289 let mut samples: Vec<Sample> = annotations
290 .iter()
291 .filter_map(|ann| {
292 let label = index.label_name(ann.category_id)?;
293 let label_index = index.label_index(ann.category_id);
294
295 let box2d = coco_bbox_to_box2d(&ann.bbox, image.width, image.height);
297
298 let (polygon, mask) = if include_masks {
302 if let Some(seg) = &ann.segmentation {
303 match seg {
304 CocoSegmentation::Polygon(_) => {
305 let poly =
306 coco_segmentation_to_polygon(seg, image.width, image.height).ok();
307 (poly, None)
308 }
309 CocoSegmentation::Rle(_) | CocoSegmentation::CompressedRle(_) => {
310 let mask_data = coco_segmentation_to_mask_data(seg).ok().flatten();
311 (None, mask_data)
312 }
313 }
314 } else {
315 (None, None)
316 }
317 } else {
318 (None, None)
319 };
320
321 let mut annotation = Annotation::new();
322 annotation.set_name(Some(sample_name.clone()));
323 annotation.set_label(Some(label.to_string()));
324 annotation.set_label_index(label_index);
325 annotation.set_box2d(Some(box2d));
326 annotation.set_polygon(polygon);
327 annotation.set_mask(mask);
328 annotation.set_group(group.map(String::from));
329 annotation.set_iscrowd(Some(ann.iscrowd != 0));
330 annotation.set_category_frequency(index.frequency(ann.category_id).map(String::from));
331
332 if let Some(score) = ann.score {
334 let score_f32 = score as f32;
335 if annotation.mask().is_some() {
336 annotation.set_mask_score(Some(score_f32));
337 } else if annotation.polygon().is_some() {
338 annotation.set_polygon_score(Some(score_f32));
339 } else {
340 annotation.set_box2d_score(Some(score_f32));
341 }
342 }
343
344 let mut sample = Sample {
345 image_name: Some(sample_name.clone()),
346 width: Some(image.width),
347 height: Some(image.height),
348 group: group.map(String::from),
349 annotations: vec![annotation],
350 ..Default::default()
351 };
352 sample.neg_label_indices = neg_label_indices.clone();
353 sample.not_exhaustive_label_indices = not_exhaustive_label_indices.clone();
354
355 Some(sample)
356 })
357 .collect();
358
359 if samples.is_empty()
363 && (image.neg_category_ids.is_some() || image.not_exhaustive_category_ids.is_some())
364 {
365 let mut sample = Sample {
366 image_name: Some(sample_name.clone()),
367 width: Some(image.width),
368 height: Some(image.height),
369 group: group.map(String::from),
370 ..Default::default()
371 };
372 sample.neg_label_indices = neg_label_indices;
373 sample.not_exhaustive_label_indices = not_exhaustive_label_indices;
374 samples.push(sample);
375 }
376
377 samples
378}
379
/// Derives a sample name from an image file name.
///
/// Returns the file stem (final path component without its last extension).
/// When no stem exists or it is not valid UTF-8, the input is returned as is.
fn sample_name_from_filename(filename: &str) -> String {
    match Path::new(filename).file_stem().and_then(|stem| stem.to_str()) {
        Some(stem) => stem.to_owned(),
        None => filename.to_owned(),
    }
}
388
/// Converts an Arrow IPC file into a COCO JSON file.
///
/// The Arrow file's custom metadata is read first. A file without a
/// `schema_version` entry is treated as legacy: its `mask` column is read as
/// flat `f32` polygon coordinates. Otherwise the `polygon` column (nested
/// lists) and the `mask` column (binary PNG data) are used, with a mask
/// taking precedence over a polygon on the same row.
///
/// Rows are processed in several passes: images and categories are
/// registered, annotations are added, per-image negative / non-exhaustive
/// categories are attached, and finally category metadata from the data
/// columns and from the file metadata is applied.
///
/// # Arguments
/// * `arrow_path` - Arrow IPC file to read.
/// * `output_path` - COCO JSON file to write.
/// * `options` - Group filter, mask export switch and optional `info`.
/// * `progress` - Optional channel receiving row-based progress updates.
///   Send failures are ignored.
///
/// # Returns
/// The number of annotations in the written COCO dataset.
///
/// # Errors
/// Fails when the Arrow file cannot be opened or read, when the `name`
/// column is missing or not a string column, when a present `box2d` (or
/// legacy `mask`) column cannot be extracted, or when writing the JSON fails.
pub async fn arrow_to_coco<P: AsRef<Path>>(
    arrow_path: P,
    output_path: P,
    options: &ArrowToCocoOptions,
    progress: Option<Sender<Progress>>,
) -> Result<usize, Error> {
    let arrow_path = arrow_path.as_ref();
    let output_path = output_path.as_ref();

    // Read the custom schema metadata with a dedicated reader; unreadable or
    // absent metadata simply yields `None` for all three entries.
    let (schema_version, category_metadata_json, labels_metadata_json) = {
        let mut meta_file = std::fs::File::open(arrow_path)?;
        let mut reader = IpcReader::new(&mut meta_file);
        let meta = reader.custom_metadata().ok().flatten();
        let sv = meta.as_ref().and_then(|m| {
            m.get(&PlSmallStr::from("schema_version"))
                .map(|s| s.to_string())
        });
        let cm = meta.as_ref().and_then(|m| {
            m.get(&PlSmallStr::from("category_metadata"))
                .map(|s| s.to_string())
        });
        let lm = meta
            .as_ref()
            .and_then(|m| m.get(&PlSmallStr::from("labels")).map(|s| s.to_string()));
        (sv, cm, lm)
    };

    // Files without a schema version are handled as the legacy layout.
    let is_legacy = schema_version.is_none();

    // Load the whole frame; every column is extracted up front below.
    let mut file = std::fs::File::open(arrow_path)?;
    let df = IpcReader::new(&mut file).finish()?;

    // An empty set means "no filtering" (see `skip_row`).
    let groups_to_filter: std::collections::HashSet<_> = options.groups.iter().cloned().collect();

    let total_rows = df.height();

    if let Some(ref p) = progress {
        let _ = p
            .send(Progress {
                current: 0,
                total: total_rows,
                status: None,
            })
            .await;
    }

    // `name` is the only mandatory column; null names become empty strings.
    let names: Vec<String> = df
        .column("name")?
        .str()?
        .into_iter()
        .map(|s| s.unwrap_or_default().to_string())
        .collect();

    // Missing or unreadable labels degrade to empty strings, which later
    // marks the row as "no annotation".
    let labels: Vec<String> = df
        .column("label")
        .ok()
        .and_then(|c| c.cast(&DataType::String).ok())
        .map(|c| {
            c.str()
                .ok()
                .map(|s| {
                    s.into_iter()
                        .map(|v| v.unwrap_or_default().to_string())
                        .collect()
                })
                .unwrap_or_else(|| vec![String::new(); total_rows])
        })
        .unwrap_or_else(|| vec![String::new(); total_rows]);

    // Only read when the column is stored as u64; any other dtype yields
    // `None` for every row.
    let label_indices: Vec<Option<u64>> = df
        .column("label_index")
        .ok()
        .map(|c| {
            c.u64()
                .ok()
                .map(|s| s.into_iter().collect())
                .unwrap_or_else(|| vec![None; total_rows])
        })
        .unwrap_or_else(|| vec![None; total_rows]);

    // NOTE(review): the inner fallback here is `unwrap_or_default()`, i.e. an
    // EMPTY vector, unlike `labels` above which falls back to `total_rows`
    // entries. `skip_row` indexes `groups[i]`, so that path would panic if it
    // were ever taken — confirm it is unreachable after the String cast.
    let groups: Vec<String> = df
        .column("group")
        .ok()
        .and_then(|c| c.cast(&DataType::String).ok())
        .map(|c| {
            c.str()
                .ok()
                .map(|s| {
                    s.into_iter()
                        .map(|v| v.unwrap_or_default().to_string())
                        .collect()
                })
                .unwrap_or_default()
        })
        .unwrap_or_else(|| vec!["".to_string(); total_rows]);

    // A missing `box2d` column yields all-zero boxes; a present but
    // unreadable one is an error.
    let box2ds = df
        .column("box2d")
        .ok()
        .map(extract_all_box2ds)
        .transpose()?
        .unwrap_or_else(|| vec![[0.0; 4]; total_rows]);

    // Legacy layout: `mask` holds flat f32 polygon coordinates.
    let legacy_masks: Option<Vec<Vec<f32>>> = if is_legacy && options.include_masks {
        df.column("mask").ok().map(extract_all_masks).transpose()?
    } else {
        None
    };

    // Current layout: `polygon` holds nested ring lists ...
    let polygons_2026: Option<Vec<Option<PolygonRings>>> = if !is_legacy && options.include_masks {
        df.column("polygon")
            .ok()
            .map(|c| extract_all_polygons(c, total_rows))
    } else {
        None
    };

    // ... and `mask` holds binary (PNG) data.
    let mask_binary_2026: Option<Vec<Option<Vec<u8>>>> = if !is_legacy && options.include_masks {
        df.column("mask")
            .ok()
            .map(|c| extract_all_binary_masks(c, total_rows))
    } else {
        None
    };

    // Image dimensions are optional; extraction errors are swallowed.
    let sizes = df
        .column("size")
        .ok()
        .and_then(|c| extract_all_sizes(c).ok());

    // `iscrowd` is accepted as a bool column or as a u32 column.
    let iscrowds: Vec<u8> = df
        .column("iscrowd")
        .ok()
        .map(|c| {
            if let Ok(bool_ca) = c.bool() {
                bool_ca
                    .into_iter()
                    .map(|v| if v.unwrap_or(false) { 1 } else { 0 })
                    .collect()
            } else {
                c.u32()
                    .ok()
                    .map(|s| s.into_iter().map(|v| v.unwrap_or(0) as u8).collect())
                    .unwrap_or_else(|| vec![0; total_rows])
            }
        })
        .unwrap_or_else(|| vec![0; total_rows]);

    let category_frequencies: Vec<Option<String>> = df
        .column("category_frequency")
        .ok()
        .and_then(|c| c.cast(&DataType::String).ok())
        .map(|c| {
            c.str()
                .ok()
                .map(|s| s.into_iter().map(|v| v.map(String::from)).collect())
                .unwrap_or_else(|| vec![None; total_rows])
        })
        .unwrap_or_else(|| vec![None; total_rows]);

    let neg_label_indices: Vec<Option<Vec<u32>>> = df
        .column("neg_label_indices")
        .ok()
        .map(|c| extract_list_u32_column(c, total_rows))
        .unwrap_or_else(|| vec![None; total_rows]);

    let not_exhaustive_label_indices: Vec<Option<Vec<u32>>> = df
        .column("not_exhaustive_label_indices")
        .ok()
        .map(|c| extract_list_u32_column(c, total_rows))
        .unwrap_or_else(|| vec![None; total_rows]);

    // Per-geometry score columns; each is all-`None` when absent.
    let box2d_scores: Vec<Option<f32>> = extract_f32_column(&df, "box2d_score", total_rows);
    let box3d_scores: Vec<Option<f32>> = extract_f32_column(&df, "box3d_score", total_rows);
    let polygon_scores: Vec<Option<f32>> = extract_f32_column(&df, "polygon_score", total_rows);
    let mask_scores: Vec<Option<f32>> = extract_f32_column(&df, "mask_score", total_rows);

    let mut builder = CocoDatasetBuilder::new();

    if let Some(info) = &options.info {
        builder = builder.info(info.clone());
    }

    // A row is skipped only when a group filter is active and the row's
    // group is not part of it.
    let skip_row = |i: usize| -> bool {
        !groups_to_filter.is_empty() && !groups_to_filter.contains(&groups[i])
    };

    let mut image_dimensions: HashMap<String, (u32, u32)> = HashMap::new();
    let mut image_ids: HashMap<String, u64> = HashMap::new();
    let mut category_ids: HashMap<String, u32> = HashMap::new();

    // Pass 1: register every image and category in order of first
    // appearance, so all ids exist before annotations are added.
    for i in 0..total_rows {
        if skip_row(i) {
            continue;
        }

        let name = &names[i];
        let label = &labels[i];

        if !image_ids.contains_key(name) {
            // Dimensions come from the first row seen for this image.
            let (width, height) = sizes
                .as_ref()
                .and_then(|s| s.get(i).copied())
                .unwrap_or((0, 0));

            let id = builder.add_image(name, width, height);
            image_ids.insert(name.clone(), id);
            image_dimensions.insert(name.clone(), (width, height));
        }

        if !label.is_empty() && !category_ids.contains_key(label) {
            // Prefer the stored label index as category id when available.
            let id = if let Some(Some(idx)) = label_indices.get(i) {
                builder.add_category_with_id(*idx as u32, label, None)
            } else {
                builder.add_category(label, None)
            };
            category_ids.insert(label.clone(), id);
        }
    }

    // Pass 2: add one annotation per labelled row.
    let mut last_progress_update = 0;
    for i in 0..total_rows {
        if skip_row(i) {
            continue;
        }

        let name = &names[i];
        let label = &labels[i];

        // Rows without a label only contribute their image (pass 1).
        if label.is_empty() {
            continue;
        }

        let image_id = *image_ids.get(name).unwrap_or(&0);
        let category_id = *category_ids.get(label).unwrap_or(&0);
        let (width, height) = *image_dimensions.get(name).unwrap_or(&(1, 1));

        // The stored box is read as [cx, cy, w, h] and shifted to a
        // left/top origin before conversion to a COCO bbox.
        let bbox = box2ds.get(i).map(|box2d| {
            let cx = box2d[0];
            let cy = box2d[1];
            let w = box2d[2];
            let h = box2d[3];
            let left = cx - w / 2.0;
            let top = cy - h / 2.0;
            let ef_box2d = Box2d::new(left, top, w, h);
            box2d_to_coco_bbox(&ef_box2d, width, height)
        });

        let segmentation = if options.include_masks {
            if is_legacy {
                // Legacy: flat coordinates are split back into rings.
                legacy_masks.as_ref().and_then(|m| {
                    m.get(i).and_then(|coords| {
                        if coords.is_empty() {
                            None
                        } else {
                            let rings = crate::unflatten_polygon_coordinates(coords);
                            let polygon = Polygon::new(rings);
                            let coco_poly = polygon_to_coco_polygon(&polygon, width, height);
                            if coco_poly.is_empty() {
                                None
                            } else {
                                Some(CocoSegmentation::Polygon(coco_poly))
                            }
                        }
                    })
                })
            } else {
                // Current layout: a decodable PNG mask wins over a polygon.
                let mask_seg = mask_binary_2026.as_ref().and_then(|masks| {
                    masks.get(i).and_then(|opt_bytes| {
                        opt_bytes
                            .as_ref()
                            .and_then(|png_bytes| png_to_rle_segmentation(png_bytes, i))
                    })
                });

                if mask_seg.is_some() {
                    mask_seg
                } else {
                    polygons_2026.as_ref().and_then(|polys| {
                        polys.get(i).and_then(|opt_rings| {
                            opt_rings.as_ref().and_then(|rings| {
                                if rings.is_empty() {
                                    return None;
                                }
                                let polygon = Polygon::new(rings.clone());
                                let coco_poly = polygon_to_coco_polygon(&polygon, width, height);
                                if coco_poly.is_empty() {
                                    None
                                } else {
                                    Some(CocoSegmentation::Polygon(coco_poly))
                                }
                            })
                        })
                    })
                }
            }
        } else {
            None
        };

        // First available score wins: mask, polygon, box3d, then box2d.
        let score: Option<f64> = mask_scores[i]
            .or(polygon_scores[i])
            .or(box3d_scores[i])
            .or(box2d_scores[i])
            .map(|s| s as f64);

        if let Some(bbox) = bbox {
            let iscrowd = iscrowds[i];
            let ann_id = builder.add_annotation_with_iscrowd(
                image_id,
                category_id,
                bbox,
                segmentation,
                iscrowd,
            );

            if let Some(score_val) = score {
                builder.set_annotation_score(ann_id, score_val);
            }
        }

        // Throttled progress: at most one update per 1000 rows, plus the
        // last row. Rows skipped above never reach this point.
        if let Some(ref p) = progress
            && (i - last_progress_update >= 1000 || i == total_rows - 1)
        {
            let _ = p
                .send(Progress {
                    current: i + 1,
                    total: total_rows,
                    status: None,
                })
                .await;
            last_progress_update = i;
        }
    }

    // Make sure a final 100% update is sent when the loop did not emit one
    // for the last row.
    if let Some(ref p) = progress
        && last_progress_update < total_rows.saturating_sub(1)
    {
        let _ = p
            .send(Progress {
                current: total_rows,
                total: total_rows,
                status: None,
            })
            .await;
    }

    // Pass 3: attach negative / non-exhaustive categories per image, taken
    // from the first non-skipped row of each image.
    {
        let mut processed_images: std::collections::HashSet<u64> = std::collections::HashSet::new();
        for i in 0..total_rows {
            if skip_row(i) {
                continue;
            }
            let name = &names[i];
            if let Some(&image_id) = image_ids.get(name) {
                if !processed_images.insert(image_id) {
                    continue;
                }
                let neg = neg_label_indices[i].clone();
                let not_exhaustive = not_exhaustive_label_indices[i].clone();
                if neg.is_some() || not_exhaustive.is_some() {
                    builder.set_image_neg_categories(image_id, neg, not_exhaustive);
                }
            }
        }
    }

    // Pass 4: per-category frequency from the data column, using the first
    // row of each label that carries a value.
    {
        let mut freq_map: HashMap<String, String> = HashMap::new();
        for i in 0..total_rows {
            if skip_row(i) {
                continue;
            }
            let label = &labels[i];
            if !label.is_empty()
                && !freq_map.contains_key(label)
                && let Some(ref freq) = category_frequencies[i]
            {
                freq_map.insert(label.clone(), freq.clone());
            }
        }
        for (name, freq) in &freq_map {
            builder.set_category_metadata(name, None, Some(freq.clone()), None, None);
        }
    }

    // Category metadata stored in the file: adds categories that never
    // appeared in a row and fills in attributes for all listed categories.
    // Applied after pass 4, so these values are set last.
    if let Some(ref json_str) = category_metadata_json
        && let Ok(meta) = serde_json::from_str::<HashMap<String, serde_json::Value>>(json_str)
    {
        for (cat_name, value) in &meta {
            let supercategory = value.get("supercategory").and_then(|v| v.as_str());

            if !category_ids.contains_key(cat_name.as_str()) {
                let cat_id = value.get("id").and_then(|v| v.as_u64()).map(|id| id as u32);
                let id = if let Some(cat_id) = cat_id {
                    builder.add_category_with_id(cat_id, cat_name, supercategory)
                } else {
                    builder.add_category(cat_name, supercategory)
                };
                category_ids.insert(cat_name.clone(), id);
            } else {
                if let Some(sc) = supercategory {
                    builder.set_category_supercategory(cat_name, sc);
                }
            }

            let synset = value
                .get("synset")
                .and_then(|v| v.as_str())
                .map(String::from);
            let frequency = value
                .get("frequency")
                .and_then(|v| v.as_str())
                .map(String::from);
            let synonyms = value.get("synonyms").and_then(|v| {
                v.as_array().map(|arr| {
                    arr.iter()
                        .filter_map(|s| s.as_str().map(String::from))
                        .collect()
                })
            });
            let def = value
                .get("definition")
                .and_then(|v| v.as_str())
                .map(String::from);

            builder.set_category_metadata(cat_name, synset, frequency, synonyms, def);
        }
    }

    // Fallback for files that only carry the `labels` list: register any
    // label that did not appear in the data.
    if category_metadata_json.is_none()
        && let Some(ref labels_json) = labels_metadata_json
        && let Ok(label_names) = serde_json::from_str::<Vec<String>>(labels_json)
    {
        for label_name in &label_names {
            if !category_ids.contains_key(label_name) {
                let id = builder.add_category(label_name, None);
                category_ids.insert(label_name.clone(), id);
            }
        }
    }

    let dataset = builder.build();
    let annotation_count = dataset.annotations.len();

    let writer = CocoWriter::new();
    writer.write_json(&dataset, output_path)?;

    Ok(annotation_count)
}
905
906fn extract_all_box2ds(col: &Column) -> Result<Vec<[f32; 4]>, Error> {
908 let arr = col.array()?;
909 let mut result = Vec::with_capacity(arr.len());
910
911 for inner in arr.amortized_iter() {
912 let values = if let Some(inner) = inner {
913 let series = inner.as_ref();
914 let vals: Vec<f32> = series
915 .f32()
916 .map_err(|e| Error::CocoError(format!("box2d cast error: {}", e)))?
917 .into_iter()
918 .map(|v| v.unwrap_or(0.0))
919 .collect();
920
921 if vals.len() == 4 {
922 [vals[0], vals[1], vals[2], vals[3]]
923 } else {
924 [0.0, 0.0, 0.0, 0.0]
925 }
926 } else {
927 [0.0, 0.0, 0.0, 0.0]
928 };
929 result.push(values);
930 }
931
932 Ok(result)
933}
934
935fn extract_all_masks(col: &Column) -> Result<Vec<Vec<f32>>, Error> {
937 let list = col.list()?;
938 let mut result = Vec::with_capacity(list.len());
939
940 for i in 0..list.len() {
941 let coords = match list.get_as_series(i) {
942 Some(series) => series
943 .f32()
944 .map_err(|e| Error::CocoError(format!("mask cast error: {}", e)))?
945 .into_iter()
946 .map(|v| v.unwrap_or(f32::NAN))
947 .collect(),
948 None => vec![],
949 };
950 result.push(coords);
951 }
952
953 Ok(result)
954}
955
956fn extract_all_sizes(col: &Column) -> Result<Vec<(u32, u32)>, Error> {
958 let arr = col.array()?;
959 let mut result = Vec::with_capacity(arr.len());
960
961 for inner in arr.amortized_iter() {
962 let size = if let Some(inner) = inner {
963 let series = inner.as_ref();
964 let values: Vec<u32> = series
965 .u32()
966 .map_err(|e| Error::CocoError(format!("size cast error: {}", e)))?
967 .into_iter()
968 .map(|v| v.unwrap_or(0))
969 .collect();
970
971 if values.len() >= 2 {
972 (values[0], values[1])
973 } else {
974 (0, 0)
975 }
976 } else {
977 (0, 0)
978 };
979 result.push(size);
980 }
981
982 Ok(result)
983}
984
985fn extract_list_u32_column(col: &Column, total_rows: usize) -> Vec<Option<Vec<u32>>> {
987 col.list()
988 .ok()
989 .map(|list| {
990 (0..list.len())
991 .map(|i| {
992 list.get_as_series(i).and_then(|series| {
993 series
994 .u32()
995 .ok()
996 .map(|ca| ca.into_iter().flatten().collect::<Vec<u32>>())
997 })
998 })
999 .collect()
1000 })
1001 .unwrap_or_else(|| vec![None; total_rows])
1002}
1003
1004fn extract_all_polygons(col: &Column, total_rows: usize) -> Vec<Option<PolygonRings>> {
1009 let outer_list = match col.list() {
1010 Ok(l) => l,
1011 Err(_) => return vec![None; total_rows],
1012 };
1013
1014 let mut result = Vec::with_capacity(total_rows);
1015 for i in 0..outer_list.len() {
1016 let rings = outer_list.get_as_series(i).and_then(|ring_series| {
1017 let inner_list = ring_series.list().ok()?;
1018 let mut rings = Vec::new();
1019 for j in 0..inner_list.len() {
1020 if let Some(coords_series) = inner_list.get_as_series(j)
1021 && let Ok(f32_ca) = coords_series.f32()
1022 {
1023 let coords: Vec<f32> = f32_ca.into_iter().map(|v| v.unwrap_or(0.0)).collect();
1024 let points: Vec<(f32, f32)> = coords
1026 .chunks(2)
1027 .filter(|c| c.len() == 2)
1028 .map(|c| (c[0], c[1]))
1029 .collect();
1030 if !points.is_empty() {
1031 rings.push(points);
1032 }
1033 }
1034 }
1035 if rings.is_empty() { None } else { Some(rings) }
1036 });
1037 result.push(rings);
1038 }
1039 result
1040}
1041
1042fn extract_all_binary_masks(col: &Column, total_rows: usize) -> Vec<Option<Vec<u8>>> {
1044 let binary_ca = match col.binary() {
1045 Ok(b) => b,
1046 Err(_) => return vec![None; total_rows],
1047 };
1048
1049 (0..binary_ca.len())
1050 .map(|i| binary_ca.get(i).map(|bytes| bytes.to_vec()))
1051 .collect()
1052}
1053
1054fn extract_f32_column(df: &DataFrame, name: &str, total_rows: usize) -> Vec<Option<f32>> {
1056 df.column(name)
1057 .ok()
1058 .and_then(|c| c.f32().ok())
1059 .map(|ca| ca.into_iter().collect())
1060 .unwrap_or_else(|| vec![None; total_rows])
1061}
1062
1063fn png_to_rle_segmentation(png_bytes: &[u8], row_index: usize) -> Option<CocoSegmentation> {
1069 if png_bytes.is_empty() {
1070 return None;
1071 }
1072
1073 let mask_data = match crate::MaskData::from_png_checked(png_bytes.to_vec()) {
1074 Ok(m) => m,
1075 Err(e) => {
1076 log::warn!("Skipping invalid PNG mask at row {}: {}", row_index, e);
1077 return None;
1078 }
1079 };
1080
1081 let mw = mask_data.width();
1082 let mh = mask_data.height();
1083 let bit_depth = mask_data.bit_depth();
1084
1085 let decoded = match mask_data.decode() {
1086 Ok(d) => d,
1087 Err(e) => {
1088 log::warn!("Failed to decode PNG mask at row {}: {}", row_index, e);
1089 return None;
1090 }
1091 };
1092
1093 let binary_mask = match bit_depth {
1094 1 => decoded,
1095 8 => {
1096 log::warn!(
1097 "Binarizing 8-bit mask for row {} — score data is lost",
1098 row_index
1099 );
1100 decoded
1101 .iter()
1102 .map(|&v| if v >= 128 { 1 } else { 0 })
1103 .collect()
1104 }
1105 16 => {
1106 log::warn!(
1107 "Binarizing 16-bit mask for row {} — score data is lost",
1108 row_index
1109 );
1110 decoded
1111 .chunks(2)
1112 .map(|pair| {
1113 let val = if pair.len() == 2 {
1114 u16::from_be_bytes([pair[0], pair[1]])
1115 } else {
1116 0
1117 };
1118 if val >= 32768 { 1u8 } else { 0u8 }
1119 })
1120 .collect()
1121 }
1122 _ => decoded,
1123 };
1124
1125 match super::convert::encode_rle(&binary_mask, mw, mh) {
1126 Ok(rle) => Some(CocoSegmentation::Rle(rle)),
1127 Err(e) => {
1128 log::warn!("Failed to encode RLE for row {}: {}", row_index, e);
1129 None
1130 }
1131 }
1132}
1133
1134#[cfg(test)]
1135mod tests {
1136 use super::*;
1137 use crate::coco::{CocoAnnotation, CocoCategory, CocoDataset};
1138 use tempfile::TempDir;
1139
1140 #[test]
1145 fn test_unflatten_polygon_coords_empty() {
1146 let coords: Vec<f32> = vec![];
1147 let result = crate::unflatten_polygon_coordinates(&coords);
1148 assert!(result.is_empty());
1149 }
1150
1151 #[test]
1152 fn test_unflatten_polygon_coords_single_polygon() {
1153 let coords = vec![0.1, 0.2, 0.3, 0.2, 0.3, 0.4, 0.1, 0.4];
1155 let result = crate::unflatten_polygon_coordinates(&coords);
1156
1157 assert_eq!(result.len(), 1);
1158 assert_eq!(result[0].len(), 4);
1159 assert_eq!(result[0][0], (0.1, 0.2));
1160 assert_eq!(result[0][3], (0.1, 0.4));
1161 }
1162
1163 #[test]
1164 fn test_unflatten_polygon_coords_multiple_polygons() {
1165 let coords = vec![
1167 0.1,
1168 0.1,
1169 0.2,
1170 0.1,
1171 0.15,
1172 0.2, f32::NAN, 0.5,
1175 0.5,
1176 0.6,
1177 0.5,
1178 0.55,
1179 0.6, ];
1181 let result = crate::unflatten_polygon_coordinates(&coords);
1182
1183 assert_eq!(result.len(), 2);
1184 assert_eq!(result[0].len(), 3);
1185 assert_eq!(result[1].len(), 3);
1186 assert_eq!(result[0][0], (0.1, 0.1));
1187 assert_eq!(result[1][0], (0.5, 0.5));
1188 }
1189
1190 #[test]
1191 fn test_unflatten_polygon_coords_leading_nan() {
1192 let coords = vec![f32::NAN, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6];
1194 let result = crate::unflatten_polygon_coordinates(&coords);
1195
1196 assert_eq!(result.len(), 1);
1197 assert_eq!(result[0].len(), 3);
1198 }
1199
1200 #[test]
1201 fn test_unflatten_polygon_coords_trailing_nan() {
1202 let coords = vec![0.1, 0.2, 0.3, 0.4, f32::NAN];
1204 let result = crate::unflatten_polygon_coordinates(&coords);
1205
1206 assert_eq!(result.len(), 1);
1207 assert_eq!(result[0].len(), 2);
1208 }
1209
1210 #[test]
1211 fn test_unflatten_polygon_coords_consecutive_nans() {
1212 let coords = vec![0.1, 0.2, f32::NAN, f32::NAN, 0.3, 0.4];
1214 let result = crate::unflatten_polygon_coordinates(&coords);
1215
1216 assert_eq!(result.len(), 2);
1217 assert_eq!(result[0].len(), 1);
1218 assert_eq!(result[1].len(), 1);
1219 }
1220
1221 #[test]
1222 fn test_unflatten_polygon_coords_odd_values() {
1223 let coords = vec![0.1, 0.2, 0.3, 0.4, 0.5];
1225 let result = crate::unflatten_polygon_coordinates(&coords);
1226
1227 assert_eq!(result.len(), 1);
1228 assert_eq!(result[0].len(), 2); }
1230
1231 #[test]
1236 fn test_convert_image_annotations_basic() {
1237 let image = CocoImage {
1238 id: 1,
1239 width: 640,
1240 height: 480,
1241 file_name: "test_image.jpg".to_string(),
1242 ..Default::default()
1243 };
1244
1245 let dataset = CocoDataset {
1246 images: vec![image.clone()],
1247 categories: vec![CocoCategory {
1248 id: 1,
1249 name: "cat".to_string(),
1250 supercategory: Some("animal".to_string()),
1251 ..Default::default()
1252 }],
1253 annotations: vec![CocoAnnotation {
1254 id: 1,
1255 image_id: 1,
1256 category_id: 1,
1257 bbox: [100.0, 100.0, 200.0, 200.0],
1258 area: 40000.0,
1259 iscrowd: 0,
1260 segmentation: None,
1261 score: None,
1262 }],
1263 ..Default::default()
1264 };
1265
1266 let index = CocoIndex::from_dataset(&dataset);
1267 let samples = convert_image_annotations(&image, &index, true, Some("train"));
1268
1269 assert_eq!(samples.len(), 1);
1270 assert_eq!(samples[0].image_name, Some("test_image".to_string()));
1271 assert_eq!(samples[0].group, Some("train".to_string()));
1272 assert_eq!(samples[0].annotations.len(), 1);
1273 assert_eq!(samples[0].annotations[0].label(), Some(&"cat".to_string()));
1274 }
1275
1276 #[test]
1277 fn test_convert_image_annotations_with_mask() {
1278 let image = CocoImage {
1279 id: 1,
1280 width: 100,
1281 height: 100,
1282 file_name: "masked.jpg".to_string(),
1283 ..Default::default()
1284 };
1285
1286 let dataset = CocoDataset {
1287 images: vec![image.clone()],
1288 categories: vec![CocoCategory {
1289 id: 1,
1290 name: "object".to_string(),
1291 supercategory: None,
1292 ..Default::default()
1293 }],
1294 annotations: vec![CocoAnnotation {
1295 id: 1,
1296 image_id: 1,
1297 category_id: 1,
1298 bbox: [10.0, 10.0, 50.0, 50.0],
1299 area: 2500.0,
1300 iscrowd: 0,
1301 segmentation: Some(CocoSegmentation::Polygon(vec![vec![
1302 10.0, 10.0, 60.0, 10.0, 60.0, 60.0, 10.0, 60.0,
1303 ]])),
1304 score: None,
1305 }],
1306 ..Default::default()
1307 };
1308
1309 let index = CocoIndex::from_dataset(&dataset);
1310
1311 let samples_with_mask = convert_image_annotations(&image, &index, true, None);
1313 assert!(samples_with_mask[0].annotations[0].polygon().is_some());
1314
1315 let samples_no_mask = convert_image_annotations(&image, &index, false, None);
1317 assert!(samples_no_mask[0].annotations[0].polygon().is_none());
1318 }
1319
1320 #[test]
1321 fn test_convert_image_annotations_no_annotations() {
1322 let image = CocoImage {
1323 id: 1,
1324 width: 640,
1325 height: 480,
1326 file_name: "empty.jpg".to_string(),
1327 ..Default::default()
1328 };
1329
1330 let dataset = CocoDataset {
1331 images: vec![image.clone()],
1332 categories: vec![],
1333 annotations: vec![],
1334 ..Default::default()
1335 };
1336
1337 let index = CocoIndex::from_dataset(&dataset);
1338 let samples = convert_image_annotations(&image, &index, true, None);
1339
1340 assert!(samples.is_empty());
1341 }
1342
/// Table-driven check of `sample_name_from_filename` for a plain file name,
/// a file name with a directory prefix, and a name without an extension.
#[test]
fn test_sample_name_from_filename() {
    // (input file name, expected sample name)
    let cases = [
        ("000000397133.jpg", "000000397133"),
        ("train2017/image.jpg", "image"),
        ("test", "test"),
    ];

    for (file_name, expected) in cases {
        assert_eq!(sample_name_from_filename(file_name), expected);
    }
}
1356
/// A file name nested several directories deep yields only the final
/// component, without its extension.
#[test]
fn test_sample_name_from_filename_nested_path() {
    let name = sample_name_from_filename("a/b/c/deep_image.png");
    assert_eq!(name, "deep_image");
}
1364
/// A name without any extension is returned unchanged.
#[test]
fn test_sample_name_from_filename_no_extension() {
    let name = sample_name_from_filename("no_extension");
    assert_eq!(name, "no_extension");
}
1369
/// Verifies the field values produced by `CocoToArrowOptions::default()`.
///
/// The worker count is compared against `max_workers()` rather than a fixed
/// lower bound: `max_workers()` returns the `MAX_COCO_WORKERS` environment
/// variable verbatim when it parses, so a bound such as `>= 2` would fail
/// spuriously whenever the variable is set to a smaller value in the
/// environment the tests run in.
#[test]
fn test_coco_to_arrow_options_default() {
    let options = CocoToArrowOptions::default();
    assert!(options.include_masks);
    assert!(options.group.is_none());
    assert_eq!(options.max_workers, max_workers());
}
1381
/// Verifies the field values produced by `ArrowToCocoOptions::default()`:
/// no group filter, masks included, and no dataset info.
#[test]
fn test_arrow_to_coco_options_default() {
    let ArrowToCocoOptions {
        groups,
        include_masks,
        info,
    } = ArrowToCocoOptions::default();

    assert!(groups.is_empty());
    assert!(include_masks);
    assert!(info.is_none());
}
1389
/// Verifies the worker count returned by `max_workers()`.
///
/// When `MAX_COCO_WORKERS` holds a parseable value, `max_workers()` returns
/// it verbatim, so the `2..=8` range only applies to the CPU-derived
/// fallback. Asserting the range unconditionally would make this test fail
/// in any environment that sets the override outside that range.
#[test]
fn test_max_workers() {
    let workers = max_workers();

    // Mirror the override lookup: an unset or unparseable variable means the
    // clamped CPU-based fallback is in effect.
    let env_override = std::env::var("MAX_COCO_WORKERS")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok());

    match env_override {
        Some(expected) => assert_eq!(workers, expected),
        None => {
            assert!(workers >= 2);
            assert!(workers <= 8);
        }
    }
}
1396
1397 #[tokio::test]
1398 async fn test_coco_to_arrow_minimal() {
1399 let temp_dir = TempDir::new().unwrap();
1400
1401 let coco_json = r#"{
1403 "images": [
1404 {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
1405 ],
1406 "annotations": [
1407 {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0}
1408 ],
1409 "categories": [
1410 {"id": 1, "name": "person", "supercategory": "human"}
1411 ]
1412 }"#;
1413
1414 let coco_path = temp_dir.path().join("test.json");
1415 std::fs::write(&coco_path, coco_json).unwrap();
1416
1417 let arrow_path = temp_dir.path().join("output.arrow");
1418
1419 let options = CocoToArrowOptions::default();
1420 let count = coco_to_arrow(&coco_path, &arrow_path, &options, None)
1421 .await
1422 .unwrap();
1423
1424 assert_eq!(count, 1);
1425 assert!(arrow_path.exists());
1426
1427 let mut file = std::fs::File::open(&arrow_path).unwrap();
1429 let df = IpcReader::new(&mut file).finish().unwrap();
1430 assert_eq!(df.height(), 1);
1431 }
1432
1433 #[tokio::test]
1434 async fn test_arrow_to_coco_roundtrip() {
1435 let temp_dir = TempDir::new().unwrap();
1436
1437 let original = CocoDataset {
1439 images: vec![CocoImage {
1440 id: 1,
1441 width: 640,
1442 height: 480,
1443 file_name: "test.jpg".to_string(),
1444 ..Default::default()
1445 }],
1446 annotations: vec![CocoAnnotation {
1447 id: 1,
1448 image_id: 1,
1449 category_id: 1,
1450 bbox: [100.0, 50.0, 200.0, 150.0],
1451 area: 30000.0,
1452 iscrowd: 0,
1453 segmentation: Some(CocoSegmentation::Polygon(vec![vec![
1454 100.0, 50.0, 300.0, 50.0, 300.0, 200.0, 100.0, 200.0,
1455 ]])),
1456 score: None,
1457 }],
1458 categories: vec![CocoCategory {
1459 id: 1,
1460 name: "person".to_string(),
1461 supercategory: Some("human".to_string()),
1462 ..Default::default()
1463 }],
1464 ..Default::default()
1465 };
1466
1467 let coco_path = temp_dir.path().join("original.json");
1469 let writer = CocoWriter::new();
1470 writer.write_json(&original, &coco_path).unwrap();
1471
1472 let arrow_path = temp_dir.path().join("converted.arrow");
1474 let options = CocoToArrowOptions::default();
1475 coco_to_arrow(&coco_path, &arrow_path, &options, None)
1476 .await
1477 .unwrap();
1478
1479 let restored_path = temp_dir.path().join("restored.json");
1481 let options = ArrowToCocoOptions::default();
1482 arrow_to_coco(&arrow_path, &restored_path, &options, None)
1483 .await
1484 .unwrap();
1485
1486 let reader = CocoReader::new();
1488 let restored = reader.read_json(&restored_path).unwrap();
1489
1490 assert_eq!(restored.images.len(), 1);
1491 assert_eq!(restored.annotations.len(), 1);
1492 assert_eq!(restored.categories.len(), 1);
1493
1494 assert_eq!(restored.categories[0].name, "person");
1496 }
1497
1498 #[tokio::test]
1503 async fn test_coco_to_arrow_schema_version_metadata() {
1504 let temp_dir = TempDir::new().unwrap();
1505
1506 let coco_json = r#"{
1508 "images": [
1509 {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
1510 ],
1511 "annotations": [
1512 {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0}
1513 ],
1514 "categories": [
1515 {"id": 1, "name": "person", "supercategory": "human"}
1516 ]
1517 }"#;
1518
1519 let coco_path = temp_dir.path().join("test.json");
1520 std::fs::write(&coco_path, coco_json).unwrap();
1521
1522 let arrow_path = temp_dir.path().join("output.arrow");
1523 let options = CocoToArrowOptions::default();
1524 coco_to_arrow(&coco_path, &arrow_path, &options, None)
1525 .await
1526 .unwrap();
1527
1528 let mut file = std::fs::File::open(&arrow_path).unwrap();
1530 let mut reader = IpcReader::new(&mut file);
1531 let custom_meta = reader.custom_metadata().unwrap();
1532 assert!(custom_meta.is_some(), "custom metadata should be present");
1533
1534 let meta = custom_meta.unwrap();
1535 assert_eq!(
1536 meta.get(&PlSmallStr::from("schema_version")),
1537 Some(&PlSmallStr::from(SCHEMA_VERSION)),
1538 "schema_version metadata should be '2026.04'"
1539 );
1540
1541 assert!(
1543 meta.contains_key(&PlSmallStr::from("category_metadata")),
1544 "category_metadata should be present even without LVIS fields"
1545 );
1546 }
1547
1548 #[tokio::test]
1549 async fn test_coco_to_arrow_category_metadata_lvis() {
1550 let temp_dir = TempDir::new().unwrap();
1551
1552 let coco_json = r#"{
1554 "images": [
1555 {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
1556 ],
1557 "annotations": [
1558 {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0},
1559 {"id": 2, "image_id": 1, "category_id": 2, "bbox": [50, 60, 80, 40], "area": 3200, "iscrowd": 0}
1560 ],
1561 "categories": [
1562 {
1563 "id": 1,
1564 "name": "aerosol_can",
1565 "synset": "aerosol.n.02",
1566 "synonyms": ["aerosol_can", "spray_can"],
1567 "def": "a dispenser that holds a substance under pressure"
1568 },
1569 {
1570 "id": 2,
1571 "name": "person",
1572 "supercategory": "human"
1573 }
1574 ]
1575 }"#;
1576
1577 let coco_path = temp_dir.path().join("lvis.json");
1578 std::fs::write(&coco_path, coco_json).unwrap();
1579
1580 let arrow_path = temp_dir.path().join("lvis_output.arrow");
1581 let options = CocoToArrowOptions::default();
1582 coco_to_arrow(&coco_path, &arrow_path, &options, None)
1583 .await
1584 .unwrap();
1585
1586 let mut file = std::fs::File::open(&arrow_path).unwrap();
1588 let mut reader = IpcReader::new(&mut file);
1589 let custom_meta = reader.custom_metadata().unwrap();
1590 assert!(custom_meta.is_some(), "custom metadata should be present");
1591
1592 let meta = custom_meta.unwrap();
1593
1594 assert_eq!(
1596 meta.get(&PlSmallStr::from("schema_version")),
1597 Some(&PlSmallStr::from(SCHEMA_VERSION)),
1598 );
1599
1600 let cat_meta_str = meta
1602 .get(&PlSmallStr::from("category_metadata"))
1603 .expect("category_metadata should be present for LVIS data");
1604
1605 let cat_meta: HashMap<String, serde_json::Value> =
1606 serde_json::from_str(cat_meta_str.as_str()).unwrap();
1607
1608 assert!(
1610 cat_meta.contains_key("aerosol_can"),
1611 "aerosol_can should be in category_metadata"
1612 );
1613 assert!(
1614 cat_meta.contains_key("person"),
1615 "person should also be in category_metadata"
1616 );
1617
1618 let aerosol = cat_meta.get("aerosol_can").unwrap();
1620 assert_eq!(
1621 aerosol.get("synset").and_then(|v| v.as_str()),
1622 Some("aerosol.n.02")
1623 );
1624 assert_eq!(
1625 aerosol.get("definition").and_then(|v| v.as_str()),
1626 Some("a dispenser that holds a substance under pressure")
1627 );
1628 let synonyms = aerosol.get("synonyms").and_then(|v| v.as_array()).unwrap();
1629 assert_eq!(synonyms.len(), 2);
1630 assert_eq!(synonyms[0].as_str(), Some("aerosol_can"));
1631 assert_eq!(synonyms[1].as_str(), Some("spray_can"));
1632 }
1633
1634 #[tokio::test]
1639 async fn test_coco_arrow_roundtrip_lvis_supercategory() {
1640 let temp_dir = TempDir::new().unwrap();
1641
1642 let coco_json = r#"{
1644 "images": [
1645 {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
1646 ],
1647 "annotations": [
1648 {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0}
1649 ],
1650 "categories": [
1651 {"id": 1, "name": "person", "supercategory": "human"}
1652 ]
1653 }"#;
1654
1655 let coco_path = temp_dir.path().join("original.json");
1656 std::fs::write(&coco_path, coco_json).unwrap();
1657
1658 let arrow_path = temp_dir.path().join("converted.arrow");
1660 let options = CocoToArrowOptions::default();
1661 coco_to_arrow(&coco_path, &arrow_path, &options, None)
1662 .await
1663 .unwrap();
1664
1665 let restored_path = temp_dir.path().join("restored.json");
1667 let options = ArrowToCocoOptions::default();
1668 arrow_to_coco(&arrow_path, &restored_path, &options, None)
1669 .await
1670 .unwrap();
1671
1672 let reader = CocoReader::new();
1674 let restored = reader.read_json(&restored_path).unwrap();
1675
1676 assert_eq!(restored.categories.len(), 1);
1677 assert_eq!(restored.categories[0].name, "person");
1678 assert_eq!(
1679 restored.categories[0].supercategory,
1680 Some("human".to_string()),
1681 "supercategory should survive COCO→Arrow→COCO round-trip"
1682 );
1683 }
1684
1685 #[tokio::test]
1686 async fn test_coco_arrow_roundtrip_neg_categories_no_annotations() {
1687 let temp_dir = TempDir::new().unwrap();
1688
1689 let coco_json = r#"{
1691 "images": [
1692 {
1693 "id": 1,
1694 "width": 640,
1695 "height": 480,
1696 "file_name": "empty.jpg",
1697 "neg_category_ids": [1, 2]
1698 }
1699 ],
1700 "annotations": [],
1701 "categories": [
1702 {"id": 1, "name": "cat", "supercategory": "animal"},
1703 {"id": 2, "name": "dog", "supercategory": "animal"}
1704 ]
1705 }"#;
1706
1707 let coco_path = temp_dir.path().join("original.json");
1708 std::fs::write(&coco_path, coco_json).unwrap();
1709
1710 let arrow_path = temp_dir.path().join("converted.arrow");
1712 let options = CocoToArrowOptions::default();
1713 let sample_count = coco_to_arrow(&coco_path, &arrow_path, &options, None)
1714 .await
1715 .unwrap();
1716
1717 assert_eq!(
1719 sample_count, 1,
1720 "sentinel row should be emitted for image with neg data"
1721 );
1722
1723 let restored_path = temp_dir.path().join("restored.json");
1725 let options = ArrowToCocoOptions::default();
1726 arrow_to_coco(&arrow_path, &restored_path, &options, None)
1727 .await
1728 .unwrap();
1729
1730 let reader = CocoReader::new();
1732 let restored = reader.read_json(&restored_path).unwrap();
1733
1734 assert_eq!(restored.images.len(), 1);
1735 assert_eq!(restored.annotations.len(), 0, "no annotations expected");
1736 assert_eq!(restored.categories.len(), 2, "both categories should exist");
1737
1738 let neg = restored.images[0].neg_category_ids.as_ref();
1739 assert!(
1740 neg.is_some(),
1741 "neg_category_ids should survive round-trip for zero-annotation image"
1742 );
1743 let neg_ids = neg.unwrap();
1744 assert_eq!(neg_ids.len(), 2, "should have 2 neg categories");
1745 assert!(neg_ids.contains(&1), "neg_category_ids should contain 1");
1746 assert!(neg_ids.contains(&2), "neg_category_ids should contain 2");
1747
1748 for cat in &restored.categories {
1750 assert_eq!(
1751 cat.supercategory,
1752 Some("animal".to_string()),
1753 "supercategory should survive round-trip for annotation-free category '{}'",
1754 cat.name
1755 );
1756 }
1757 }
1758
/// An image without annotations but with `neg_category_ids` must yield
/// exactly one annotation-free sample that keeps the negative label indices.
#[test]
fn test_convert_image_annotations_neg_only_no_annotations() {
    // Both categories share the same supercategory.
    let animal = |id, name: &str| CocoCategory {
        id,
        name: name.to_string(),
        supercategory: Some("animal".to_string()),
        ..Default::default()
    };

    let image = CocoImage {
        file_name: String::from("neg_only.jpg"),
        neg_category_ids: Some(vec![1, 2]),
        id: 1,
        width: 640,
        height: 480,
        ..Default::default()
    };
    let dataset = CocoDataset {
        images: vec![image.clone()],
        categories: vec![animal(1, "cat"), animal(2, "dog")],
        annotations: Vec::new(),
        ..Default::default()
    };
    let index = CocoIndex::from_dataset(&dataset);

    let samples = convert_image_annotations(&image, &index, true, None);
    assert_eq!(
        samples.len(),
        1,
        "sentinel row should be emitted for neg-only image"
    );

    let sentinel = &samples[0];
    assert_eq!(sentinel.image_name.as_deref(), Some("neg_only"));
    assert!(
        sentinel.annotations.is_empty(),
        "sentinel should have no annotations"
    );

    let neg_indices = sentinel
        .neg_label_indices
        .as_ref()
        .expect("sentinel should preserve neg_label_indices");
    assert_eq!(neg_indices.len(), 2);
}
1810}