1use super::{
10 convert::{
11 box2d_to_coco_bbox, coco_bbox_to_box2d, coco_segmentation_to_mask, mask_to_coco_polygon,
12 },
13 reader::CocoReader,
14 types::{CocoImage, CocoIndex, CocoInfo, CocoSegmentation},
15 writer::{CocoDatasetBuilder, CocoWriter},
16};
17use crate::{Annotation, Box2d, Error, Mask, Progress, Sample};
18use polars::prelude::*;
19use std::{
20 collections::HashMap,
21 path::Path,
22 sync::{
23 Arc,
24 atomic::{AtomicUsize, Ordering},
25 },
26};
27use tokio::sync::{Semaphore, mpsc::Sender};
28
/// Splits a flat `[x0, y0, x1, y1, ...]` coordinate list into polygons.
///
/// `NaN` values separate consecutive polygons. Inside each separated run the values
/// are paired up into `(x, y)` points; a trailing value without a partner before the
/// next separator (or the end of the input) is dropped, and runs that yield no points
/// do not produce a polygon.
fn unflatten_polygon_coords(coords: &[f32]) -> Vec<Vec<(f32, f32)>> {
    coords
        .split(|value| value.is_nan())
        .map(|run| {
            run.chunks_exact(2)
                .map(|pair| (pair[0], pair[1]))
                .collect::<Vec<_>>()
        })
        .filter(|points| !points.is_empty())
        .collect()
}
69
/// Options controlling [`coco_to_arrow`].
#[derive(Debug, Clone)]
pub struct CocoToArrowOptions {
    /// Decode COCO segmentations into sample masks; when `false`, no masks are produced.
    pub include_masks: bool,
    /// Group name attached to every produced sample and annotation, if any.
    pub group: Option<String>,
    /// Number of semaphore permits bounding how many image-conversion tasks run at once.
    /// Should be at least 1.
    pub max_workers: usize,
}
80
81impl Default for CocoToArrowOptions {
82 fn default() -> Self {
83 Self {
84 include_masks: true,
85 group: None,
86 max_workers: max_workers(),
87 }
88 }
89}
90
/// Options controlling [`arrow_to_coco`].
#[derive(Debug, Clone)]
pub struct ArrowToCocoOptions {
    /// Only rows whose `group` value is listed here are exported; an empty list exports all rows.
    pub groups: Vec<String>,
    /// Export `mask` polygons as COCO polygon segmentations.
    pub include_masks: bool,
    /// Optional COCO `info` section handed to the dataset builder.
    pub info: Option<CocoInfo>,
}
101
102impl Default for ArrowToCocoOptions {
103 fn default() -> Self {
104 Self {
105 groups: vec![],
106 include_masks: true,
107 info: None,
108 }
109 }
110}
111
/// Default number of concurrent COCO conversion workers.
///
/// A positive integer in the `MAX_COCO_WORKERS` environment variable (surrounding
/// whitespace allowed) takes precedence. Otherwise (the variable is unset,
/// unparsable, or `0`) half of the available parallelism is used, clamped to
/// `2..=8`, assuming 4 CPUs when parallelism cannot be queried.
fn max_workers() -> usize {
    std::env::var("MAX_COCO_WORKERS")
        .ok()
        .and_then(|raw| raw.trim().parse::<usize>().ok())
        // Zero workers would give the conversion semaphore no permits and stall every task.
        .filter(|&n| n > 0)
        .unwrap_or_else(|| {
            let cpus = std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4);
            (cpus / 2).clamp(2, 8)
        })
}
124
125pub async fn coco_to_arrow<P: AsRef<Path>>(
139 coco_path: P,
140 output_path: P,
141 options: &CocoToArrowOptions,
142 progress: Option<Sender<Progress>>,
143) -> Result<usize, Error> {
144 let coco_path = coco_path.as_ref();
145 let output_path = output_path.as_ref();
146
147 let reader = CocoReader::new();
149 let dataset = if coco_path.extension().is_some_and(|e| e == "zip") {
150 reader.read_annotations_zip(coco_path)?
151 } else {
152 reader.read_json(coco_path)?
153 };
154
155 let index = Arc::new(CocoIndex::from_dataset(&dataset));
157 let total_images = dataset.images.len();
158
159 if let Some(ref p) = progress {
161 let _ = p
162 .send(Progress {
163 current: 0,
164 total: total_images,
165 })
166 .await;
167 }
168
169 let sem = Arc::new(Semaphore::new(options.max_workers));
171 let current = Arc::new(AtomicUsize::new(0));
172 let include_masks = options.include_masks;
173 let group = options.group.clone();
174
175 let mut tasks = Vec::with_capacity(total_images);
176
177 for image in dataset.images {
178 let sem = sem.clone();
179 let index = index.clone();
180 let current = current.clone();
181 let progress = progress.clone();
182 let total = total_images;
183 let group = group.clone();
184
185 let task = tokio::spawn(async move {
186 let _permit = sem.acquire().await.map_err(Error::SemaphoreError)?;
187
188 let samples =
190 convert_image_annotations(&image, &index, include_masks, group.as_deref());
191
192 let c = current.fetch_add(1, Ordering::SeqCst) + 1;
194 if let Some(ref p) = progress {
195 let _ = p.send(Progress { current: c, total }).await;
196 }
197
198 Ok::<Vec<Sample>, Error>(samples)
199 });
200
201 tasks.push(task);
202 }
203
204 let mut all_samples = Vec::with_capacity(total_images);
206 for task in tasks {
207 let samples = task.await??;
208 all_samples.extend(samples);
209 }
210
211 let df = crate::samples_dataframe(&all_samples)?;
213
214 if let Some(parent) = output_path.parent()
216 && !parent.as_os_str().is_empty()
217 {
218 std::fs::create_dir_all(parent)?;
219 }
220 let mut file = std::fs::File::create(output_path)?;
221 IpcWriter::new(&mut file).finish(&mut df.clone())?;
222
223 Ok(all_samples.len())
224}
225
226fn convert_image_annotations(
228 image: &CocoImage,
229 index: &CocoIndex,
230 include_masks: bool,
231 group: Option<&str>,
232) -> Vec<Sample> {
233 let annotations = index.annotations_for_image(image.id);
234 let sample_name = sample_name_from_filename(&image.file_name);
235
236 annotations
237 .iter()
238 .filter_map(|ann| {
239 let label = index.label_name(ann.category_id)?;
240 let label_index = index.label_index(ann.category_id);
241
242 let box2d = coco_bbox_to_box2d(&ann.bbox, image.width, image.height);
244
245 let mask = if include_masks {
247 ann.segmentation
248 .as_ref()
249 .and_then(|seg| coco_segmentation_to_mask(seg, image.width, image.height).ok())
250 } else {
251 None
252 };
253
254 let mut annotation = Annotation::new();
255 annotation.set_name(Some(sample_name.clone()));
256 annotation.set_label(Some(label.to_string()));
257 annotation.set_label_index(label_index);
258 annotation.set_box2d(Some(box2d));
259 annotation.set_mask(mask);
260 annotation.set_group(group.map(String::from));
261
262 let sample = Sample {
263 image_name: Some(sample_name.clone()),
264 width: Some(image.width),
265 height: Some(image.height),
266 group: group.map(String::from),
267 annotations: vec![annotation],
268 ..Default::default()
269 };
270
271 Some(sample)
272 })
273 .collect()
274}
275
/// Derives a sample name from a COCO `file_name`: the final path component without
/// its last extension, or the whole input when no usable stem exists.
fn sample_name_from_filename(filename: &str) -> String {
    match Path::new(filename).file_stem().and_then(|stem| stem.to_str()) {
        Some(stem) => stem.to_owned(),
        None => filename.to_owned(),
    }
}
284
285pub async fn arrow_to_coco<P: AsRef<Path>>(
298 arrow_path: P,
299 output_path: P,
300 options: &ArrowToCocoOptions,
301 progress: Option<Sender<Progress>>,
302) -> Result<usize, Error> {
303 let arrow_path = arrow_path.as_ref();
304 let output_path = output_path.as_ref();
305
306 let mut file = std::fs::File::open(arrow_path)?;
308 let df = IpcReader::new(&mut file).finish()?;
309
310 let groups_to_filter: std::collections::HashSet<_> = options.groups.iter().cloned().collect();
312
313 let total_rows = df.height();
314
315 if let Some(ref p) = progress {
316 let _ = p
317 .send(Progress {
318 current: 0,
319 total: total_rows,
320 })
321 .await;
322 }
323
324 let names: Vec<String> = df
326 .column("name")?
327 .str()?
328 .into_iter()
329 .map(|s| s.unwrap_or_default().to_string())
330 .collect();
331
332 let labels: Vec<String> = df
333 .column("label")?
334 .cast(&DataType::String)?
335 .str()?
336 .into_iter()
337 .map(|s| s.unwrap_or_default().to_string())
338 .collect();
339
340 let groups: Vec<String> = df
342 .column("group")
343 .ok()
344 .and_then(|c| c.cast(&DataType::String).ok())
345 .map(|c| {
346 c.str()
347 .ok()
348 .map(|s| {
349 s.into_iter()
350 .map(|v| v.unwrap_or_default().to_string())
351 .collect()
352 })
353 .unwrap_or_default()
354 })
355 .unwrap_or_else(|| vec!["".to_string(); total_rows]);
356
357 let box2ds = extract_all_box2ds(df.column("box2d")?)?;
359
360 let masks = if options.include_masks {
362 df.column("mask").ok().map(extract_all_masks).transpose()?
363 } else {
364 None
365 };
366
367 let sizes = df
369 .column("size")
370 .ok()
371 .and_then(|c| extract_all_sizes(c).ok());
372
373 let mut builder = CocoDatasetBuilder::new();
375
376 if let Some(info) = &options.info {
377 builder = builder.info(info.clone());
378 }
379
380 let mut image_dimensions: HashMap<String, (u32, u32)> = HashMap::new();
382 let mut image_ids: HashMap<String, u64> = HashMap::new();
383 let mut category_ids: HashMap<String, u32> = HashMap::new();
384
385 for i in 0..total_rows {
387 if !groups_to_filter.is_empty() && !groups_to_filter.contains(&groups[i]) {
389 continue;
390 }
391
392 let name = &names[i];
393 let label = &labels[i];
394
395 if !image_ids.contains_key(name) {
397 let (width, height) = sizes
398 .as_ref()
399 .and_then(|s| s.get(i).copied())
400 .unwrap_or((0, 0));
401
402 let id = builder.add_image(name, width, height);
403 image_ids.insert(name.clone(), id);
404 image_dimensions.insert(name.clone(), (width, height));
405 }
406
407 if !label.is_empty() && !category_ids.contains_key(label) {
408 let id = builder.add_category(label, None);
409 category_ids.insert(label.clone(), id);
410 }
411 }
412
413 let mut last_progress_update = 0;
415 for i in 0..total_rows {
416 if !groups_to_filter.is_empty() && !groups_to_filter.contains(&groups[i]) {
418 continue;
419 }
420
421 let name = &names[i];
422 let label = &labels[i];
423
424 let image_id = *image_ids.get(name).unwrap_or(&0);
425 let category_id = *category_ids.get(label).unwrap_or(&0);
426 let (width, height) = *image_dimensions.get(name).unwrap_or(&(1, 1));
427
428 let bbox = box2ds.get(i).map(|box2d| {
431 let cx = box2d[0];
432 let cy = box2d[1];
433 let w = box2d[2];
434 let h = box2d[3];
435 let left = cx - w / 2.0;
437 let top = cy - h / 2.0;
438 let ef_box2d = Box2d::new(left, top, w, h);
439 box2d_to_coco_bbox(&ef_box2d, width, height)
440 });
441
442 let segmentation = if options.include_masks {
444 masks.as_ref().and_then(|m| {
445 m.get(i).and_then(|coords| {
446 if coords.is_empty() {
447 None
448 } else {
449 let polygons = unflatten_polygon_coords(coords);
450 let mask = Mask::new(polygons);
451 let coco_poly = mask_to_coco_polygon(&mask, width, height);
452 if coco_poly.is_empty() {
453 None
454 } else {
455 Some(CocoSegmentation::Polygon(coco_poly))
456 }
457 }
458 })
459 })
460 } else {
461 None
462 };
463
464 if let Some(bbox) = bbox {
465 builder.add_annotation(image_id, category_id, bbox, segmentation);
466 }
467
468 if let Some(ref p) = progress
470 && (i - last_progress_update >= 1000 || i == total_rows - 1)
471 {
472 let _ = p
473 .send(Progress {
474 current: i + 1,
475 total: total_rows,
476 })
477 .await;
478 last_progress_update = i;
479 }
480 }
481
482 let dataset = builder.build();
483 let annotation_count = dataset.annotations.len();
484
485 let writer = CocoWriter::new();
487 writer.write_json(&dataset, output_path)?;
488
489 Ok(annotation_count)
490}
491
492fn extract_all_box2ds(col: &Column) -> Result<Vec<[f32; 4]>, Error> {
494 let arr = col.array()?;
495 let mut result = Vec::with_capacity(arr.len());
496
497 for inner in arr.amortized_iter() {
498 let values = if let Some(inner) = inner {
499 let series = inner.as_ref();
500 let vals: Vec<f32> = series
501 .f32()
502 .map_err(|e| Error::CocoError(format!("box2d cast error: {}", e)))?
503 .into_iter()
504 .map(|v| v.unwrap_or(0.0))
505 .collect();
506
507 if vals.len() == 4 {
508 [vals[0], vals[1], vals[2], vals[3]]
509 } else {
510 [0.0, 0.0, 0.0, 0.0]
511 }
512 } else {
513 [0.0, 0.0, 0.0, 0.0]
514 };
515 result.push(values);
516 }
517
518 Ok(result)
519}
520
521fn extract_all_masks(col: &Column) -> Result<Vec<Vec<f32>>, Error> {
523 let list = col.list()?;
524 let mut result = Vec::with_capacity(list.len());
525
526 for i in 0..list.len() {
527 let coords = match list.get_as_series(i) {
528 Some(series) => series
529 .f32()
530 .map_err(|e| Error::CocoError(format!("mask cast error: {}", e)))?
531 .into_iter()
532 .map(|v| v.unwrap_or(f32::NAN))
533 .collect(),
534 None => vec![],
535 };
536 result.push(coords);
537 }
538
539 Ok(result)
540}
541
542fn extract_all_sizes(col: &Column) -> Result<Vec<(u32, u32)>, Error> {
544 let arr = col.array()?;
545 let mut result = Vec::with_capacity(arr.len());
546
547 for inner in arr.amortized_iter() {
548 let size = if let Some(inner) = inner {
549 let series = inner.as_ref();
550 let values: Vec<u32> = series
551 .u32()
552 .map_err(|e| Error::CocoError(format!("size cast error: {}", e)))?
553 .into_iter()
554 .map(|v| v.unwrap_or(0))
555 .collect();
556
557 if values.len() >= 2 {
558 (values[0], values[1])
559 } else {
560 (0, 0)
561 }
562 } else {
563 (0, 0)
564 };
565 result.push(size);
566 }
567
568 Ok(result)
569}
570
#[cfg(test)]
mod tests {
    use super::*;
    use crate::coco::{CocoAnnotation, CocoCategory, CocoDataset};
    use tempfile::TempDir;

    /// True when `MAX_COCO_WORKERS` is unset, i.e. when `max_workers()` reflects the
    /// built-in heuristic (clamped to 2..=8) rather than a user override.
    fn default_worker_heuristic_active() -> bool {
        std::env::var_os("MAX_COCO_WORKERS").is_none()
    }

    #[test]
    fn test_unflatten_polygon_coords_empty() {
        let coords: Vec<f32> = vec![];
        let result = unflatten_polygon_coords(&coords);
        assert!(result.is_empty());
    }

    #[test]
    fn test_unflatten_polygon_coords_single_polygon() {
        let coords = vec![0.1, 0.2, 0.3, 0.2, 0.3, 0.4, 0.1, 0.4];
        let result = unflatten_polygon_coords(&coords);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 4);
        assert_eq!(result[0][0], (0.1, 0.2));
        assert_eq!(result[0][3], (0.1, 0.4));
    }

    #[test]
    fn test_unflatten_polygon_coords_multiple_polygons() {
        let coords = vec![
            0.1, 0.1, 0.2, 0.1, 0.15, 0.2, // first triangle
            f32::NAN, // separator
            0.5, 0.5, 0.6, 0.5, 0.55, 0.6, // second triangle
        ];
        let result = unflatten_polygon_coords(&coords);

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].len(), 3);
        assert_eq!(result[1].len(), 3);
        assert_eq!(result[0][0], (0.1, 0.1));
        assert_eq!(result[1][0], (0.5, 0.5));
    }

    #[test]
    fn test_unflatten_polygon_coords_leading_nan() {
        let coords = vec![f32::NAN, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6];
        let result = unflatten_polygon_coords(&coords);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 3);
    }

    #[test]
    fn test_unflatten_polygon_coords_trailing_nan() {
        let coords = vec![0.1, 0.2, 0.3, 0.4, f32::NAN];
        let result = unflatten_polygon_coords(&coords);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 2);
    }

    #[test]
    fn test_unflatten_polygon_coords_consecutive_nans() {
        let coords = vec![0.1, 0.2, f32::NAN, f32::NAN, 0.3, 0.4];
        let result = unflatten_polygon_coords(&coords);

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].len(), 1);
        assert_eq!(result[1].len(), 1);
    }

    #[test]
    fn test_unflatten_polygon_coords_odd_values() {
        let coords = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        let result = unflatten_polygon_coords(&coords);

        // The unpaired trailing value is dropped.
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 2);
    }

    #[test]
    fn test_convert_image_annotations_basic() {
        let image = CocoImage {
            id: 1,
            width: 640,
            height: 480,
            file_name: "test_image.jpg".to_string(),
            ..Default::default()
        };

        let dataset = CocoDataset {
            images: vec![image.clone()],
            categories: vec![CocoCategory {
                id: 1,
                name: "cat".to_string(),
                supercategory: Some("animal".to_string()),
            }],
            annotations: vec![CocoAnnotation {
                id: 1,
                image_id: 1,
                category_id: 1,
                bbox: [100.0, 100.0, 200.0, 200.0],
                area: 40000.0,
                iscrowd: 0,
                segmentation: None,
            }],
            ..Default::default()
        };

        let index = CocoIndex::from_dataset(&dataset);
        let samples = convert_image_annotations(&image, &index, true, Some("train"));

        assert_eq!(samples.len(), 1);
        assert_eq!(samples[0].image_name, Some("test_image".to_string()));
        assert_eq!(samples[0].group, Some("train".to_string()));
        assert_eq!(samples[0].annotations.len(), 1);
        assert_eq!(samples[0].annotations[0].label(), Some(&"cat".to_string()));
    }

    #[test]
    fn test_convert_image_annotations_with_mask() {
        let image = CocoImage {
            id: 1,
            width: 100,
            height: 100,
            file_name: "masked.jpg".to_string(),
            ..Default::default()
        };

        let dataset = CocoDataset {
            images: vec![image.clone()],
            categories: vec![CocoCategory {
                id: 1,
                name: "object".to_string(),
                supercategory: None,
            }],
            annotations: vec![CocoAnnotation {
                id: 1,
                image_id: 1,
                category_id: 1,
                bbox: [10.0, 10.0, 50.0, 50.0],
                area: 2500.0,
                iscrowd: 0,
                segmentation: Some(CocoSegmentation::Polygon(vec![vec![
                    10.0, 10.0, 60.0, 10.0, 60.0, 60.0, 10.0, 60.0,
                ]])),
            }],
            ..Default::default()
        };

        let index = CocoIndex::from_dataset(&dataset);

        let samples_with_mask = convert_image_annotations(&image, &index, true, None);
        assert!(samples_with_mask[0].annotations[0].mask().is_some());

        let samples_no_mask = convert_image_annotations(&image, &index, false, None);
        assert!(samples_no_mask[0].annotations[0].mask().is_none());
    }

    #[test]
    fn test_convert_image_annotations_no_annotations() {
        let image = CocoImage {
            id: 1,
            width: 640,
            height: 480,
            file_name: "empty.jpg".to_string(),
            ..Default::default()
        };

        let dataset = CocoDataset {
            images: vec![image.clone()],
            categories: vec![],
            annotations: vec![],
            ..Default::default()
        };

        let index = CocoIndex::from_dataset(&dataset);
        let samples = convert_image_annotations(&image, &index, true, None);

        assert!(samples.is_empty());
    }

    #[test]
    fn test_sample_name_from_filename() {
        assert_eq!(
            sample_name_from_filename("000000397133.jpg"),
            "000000397133"
        );
        assert_eq!(sample_name_from_filename("train2017/image.jpg"), "image");
        assert_eq!(sample_name_from_filename("test"), "test");
    }

    #[test]
    fn test_sample_name_from_filename_nested_path() {
        assert_eq!(
            sample_name_from_filename("a/b/c/deep_image.png"),
            "deep_image"
        );
    }

    #[test]
    fn test_sample_name_from_filename_no_extension() {
        assert_eq!(sample_name_from_filename("no_extension"), "no_extension");
    }

    #[test]
    fn test_coco_to_arrow_options_default() {
        let options = CocoToArrowOptions::default();
        assert!(options.include_masks);
        assert!(options.group.is_none());
        // MAX_COCO_WORKERS overrides the heuristic, so its bounds only hold when unset.
        if default_worker_heuristic_active() {
            assert!(options.max_workers >= 2);
        }
    }

    #[test]
    fn test_arrow_to_coco_options_default() {
        let options = ArrowToCocoOptions::default();
        assert!(options.groups.is_empty());
        assert!(options.include_masks);
        assert!(options.info.is_none());
    }

    #[test]
    fn test_max_workers() {
        // The 2..=8 clamp only applies to the built-in heuristic, not to an override.
        if !default_worker_heuristic_active() {
            return;
        }
        let workers = max_workers();
        assert!(workers >= 2);
        assert!(workers <= 8);
    }

    #[tokio::test]
    async fn test_coco_to_arrow_minimal() {
        let temp_dir = TempDir::new().unwrap();

        let coco_json = r#"{
            "images": [
                {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
            ],
            "annotations": [
                {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0}
            ],
            "categories": [
                {"id": 1, "name": "person", "supercategory": "human"}
            ]
        }"#;

        let coco_path = temp_dir.path().join("test.json");
        std::fs::write(&coco_path, coco_json).unwrap();

        let arrow_path = temp_dir.path().join("output.arrow");

        let options = CocoToArrowOptions::default();
        let count = coco_to_arrow(&coco_path, &arrow_path, &options, None)
            .await
            .unwrap();

        assert_eq!(count, 1);
        assert!(arrow_path.exists());

        let mut file = std::fs::File::open(&arrow_path).unwrap();
        let df = IpcReader::new(&mut file).finish().unwrap();
        assert_eq!(df.height(), 1);
    }

    #[tokio::test]
    async fn test_arrow_to_coco_roundtrip() {
        let temp_dir = TempDir::new().unwrap();

        let original = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                width: 640,
                height: 480,
                file_name: "test.jpg".to_string(),
                ..Default::default()
            }],
            annotations: vec![CocoAnnotation {
                id: 1,
                image_id: 1,
                category_id: 1,
                bbox: [100.0, 50.0, 200.0, 150.0],
                area: 30000.0,
                iscrowd: 0,
                segmentation: Some(CocoSegmentation::Polygon(vec![vec![
                    100.0, 50.0, 300.0, 50.0, 300.0, 200.0, 100.0, 200.0,
                ]])),
            }],
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                supercategory: Some("human".to_string()),
            }],
            ..Default::default()
        };

        // COCO -> JSON on disk.
        let coco_path = temp_dir.path().join("original.json");
        let writer = CocoWriter::new();
        writer.write_json(&original, &coco_path).unwrap();

        // JSON -> Arrow.
        let arrow_path = temp_dir.path().join("converted.arrow");
        let options = CocoToArrowOptions::default();
        coco_to_arrow(&coco_path, &arrow_path, &options, None)
            .await
            .unwrap();

        // Arrow -> JSON.
        let restored_path = temp_dir.path().join("restored.json");
        let options = ArrowToCocoOptions::default();
        arrow_to_coco(&arrow_path, &restored_path, &options, None)
            .await
            .unwrap();

        let reader = CocoReader::new();
        let restored = reader.read_json(&restored_path).unwrap();

        assert_eq!(restored.images.len(), 1);
        assert_eq!(restored.annotations.len(), 1);
        assert_eq!(restored.categories.len(), 1);

        assert_eq!(restored.categories[0].name, "person");
    }
}