edgefirst_client/coco/
arrow.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.

//! COCO to EdgeFirst Arrow format conversion.
//!
//! Provides high-performance conversion between COCO JSON and EdgeFirst Arrow
//! format, supporting async operations and progress tracking.

9use super::{
10    convert::{
11        box2d_to_coco_bbox, coco_bbox_to_box2d, coco_segmentation_to_mask, mask_to_coco_polygon,
12    },
13    reader::CocoReader,
14    types::{CocoImage, CocoIndex, CocoInfo, CocoSegmentation},
15    writer::{CocoDatasetBuilder, CocoWriter},
16};
17use crate::{Annotation, Box2d, Error, Mask, Progress, Sample};
18use polars::prelude::*;
19use std::{
20    collections::HashMap,
21    path::Path,
22    sync::{
23        Arc,
24        atomic::{AtomicUsize, Ordering},
25    },
26};
27use tokio::sync::{Semaphore, mpsc::Sender};
28
/// Split a flat, NaN-separated coordinate buffer back into polygons.
///
/// The Arrow mask column stores all polygons of one annotation back to back,
/// e.g. `[x1, y1, x2, y2, NaN, x3, y3, ...]`, where a **single** `NaN` ends
/// one polygon and starts the next. Each NaN-free run is read as consecutive
/// `(x, y)` pairs starting at the run's first value:
///
/// * an unpaired value at the end of a run is dropped;
/// * runs that yield no complete pair (leading, trailing or repeated `NaN`s,
///   or a lone value) produce no polygon.
///
/// For the example above the result is `[[(x1, y1), (x2, y2)], [(x3, y3), ...]]`.
fn unflatten_polygon_coords(coords: &[f32]) -> Vec<Vec<(f32, f32)>> {
    coords
        .split(|value| value.is_nan())
        .map(|run| {
            run.chunks_exact(2)
                .map(|pair| (pair[0], pair[1]))
                .collect::<Vec<_>>()
        })
        .filter(|polygon| !polygon.is_empty())
        .collect()
}

70/// Options for COCO to Arrow conversion.
71#[derive(Debug, Clone)]
72pub struct CocoToArrowOptions {
73    /// Include segmentation masks in output.
74    pub include_masks: bool,
75    /// Group name for all samples (e.g., "train", "val").
76    pub group: Option<String>,
77    /// Maximum number of parallel workers.
78    pub max_workers: usize,
79}
80
81impl Default for CocoToArrowOptions {
82    fn default() -> Self {
83        Self {
84            include_masks: true,
85            group: None,
86            max_workers: max_workers(),
87        }
88    }
89}
90
91/// Options for Arrow to COCO conversion.
92#[derive(Debug, Clone)]
93pub struct ArrowToCocoOptions {
94    /// Filter by group names (empty = all).
95    pub groups: Vec<String>,
96    /// Include segmentation masks in output.
97    pub include_masks: bool,
98    /// COCO info section.
99    pub info: Option<CocoInfo>,
100}
101
102impl Default for ArrowToCocoOptions {
103    fn default() -> Self {
104        Self {
105            groups: vec![],
106            include_masks: true,
107            info: None,
108        }
109    }
110}
111
/// Determine the default number of images [`coco_to_arrow`] converts at once.
///
/// A positive integer in the `MAX_COCO_WORKERS` environment variable takes
/// precedence. Otherwise the result is half the available parallelism, clamped
/// to `2..=8`; 4 CPUs are assumed if the parallelism cannot be queried. The
/// fallback applies when the variable is unset, unparsable, or `0`.
///
/// `0` is rejected on purpose. The worker count sizes a semaphore, and a
/// semaphore with no permits would leave every conversion task waiting forever.
fn max_workers() -> usize {
    std::env::var("MAX_COCO_WORKERS")
        .ok()
        .and_then(|v| v.parse::<usize>().ok())
        .filter(|&n| n > 0)
        .unwrap_or_else(|| {
            let cpus = std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4);
            (cpus / 2).clamp(2, 8)
        })
}

125/// Convert COCO annotations to EdgeFirst Arrow format.
126///
127/// This is a high-performance async conversion that uses parallel workers
128/// for parsing and transforming annotations.
129///
130/// # Arguments
131/// * `coco_path` - Path to COCO annotation JSON file or ZIP archive
132/// * `output_path` - Output Arrow file path
133/// * `options` - Conversion options
134/// * `progress` - Optional progress channel
135///
136/// # Returns
137/// Number of samples converted
138pub async fn coco_to_arrow<P: AsRef<Path>>(
139    coco_path: P,
140    output_path: P,
141    options: &CocoToArrowOptions,
142    progress: Option<Sender<Progress>>,
143) -> Result<usize, Error> {
144    let coco_path = coco_path.as_ref();
145    let output_path = output_path.as_ref();
146
147    // Read COCO dataset
148    let reader = CocoReader::new();
149    let dataset = if coco_path.extension().is_some_and(|e| e == "zip") {
150        reader.read_annotations_zip(coco_path)?
151    } else {
152        reader.read_json(coco_path)?
153    };
154
155    // Build index for efficient lookups
156    let index = Arc::new(CocoIndex::from_dataset(&dataset));
157    let total_images = dataset.images.len();
158
159    // Send initial progress
160    if let Some(ref p) = progress {
161        let _ = p
162            .send(Progress {
163                current: 0,
164                total: total_images,
165            })
166            .await;
167    }
168
169    // Process images in parallel
170    let sem = Arc::new(Semaphore::new(options.max_workers));
171    let current = Arc::new(AtomicUsize::new(0));
172    let include_masks = options.include_masks;
173    let group = options.group.clone();
174
175    let mut tasks = Vec::with_capacity(total_images);
176
177    for image in dataset.images {
178        let sem = sem.clone();
179        let index = index.clone();
180        let current = current.clone();
181        let progress = progress.clone();
182        let total = total_images;
183        let group = group.clone();
184
185        let task = tokio::spawn(async move {
186            let _permit = sem.acquire().await.map_err(Error::SemaphoreError)?;
187
188            // Convert this image's annotations to EdgeFirst samples
189            let samples =
190                convert_image_annotations(&image, &index, include_masks, group.as_deref());
191
192            // Update progress
193            let c = current.fetch_add(1, Ordering::SeqCst) + 1;
194            if let Some(ref p) = progress {
195                let _ = p.send(Progress { current: c, total }).await;
196            }
197
198            Ok::<Vec<Sample>, Error>(samples)
199        });
200
201        tasks.push(task);
202    }
203
204    // Collect all samples
205    let mut all_samples = Vec::with_capacity(total_images);
206    for task in tasks {
207        let samples = task.await??;
208        all_samples.extend(samples);
209    }
210
211    // Convert to DataFrame
212    let df = crate::samples_dataframe(&all_samples)?;
213
214    // Write Arrow file
215    if let Some(parent) = output_path.parent()
216        && !parent.as_os_str().is_empty()
217    {
218        std::fs::create_dir_all(parent)?;
219    }
220    let mut file = std::fs::File::create(output_path)?;
221    IpcWriter::new(&mut file).finish(&mut df.clone())?;
222
223    Ok(all_samples.len())
224}
225
226/// Convert a single image's annotations to EdgeFirst samples.
227fn convert_image_annotations(
228    image: &CocoImage,
229    index: &CocoIndex,
230    include_masks: bool,
231    group: Option<&str>,
232) -> Vec<Sample> {
233    let annotations = index.annotations_for_image(image.id);
234    let sample_name = sample_name_from_filename(&image.file_name);
235
236    annotations
237        .iter()
238        .filter_map(|ann| {
239            let label = index.label_name(ann.category_id)?;
240            let label_index = index.label_index(ann.category_id);
241
242            // Convert bbox
243            let box2d = coco_bbox_to_box2d(&ann.bbox, image.width, image.height);
244
245            // Convert mask if present and requested
246            let mask = if include_masks {
247                ann.segmentation
248                    .as_ref()
249                    .and_then(|seg| coco_segmentation_to_mask(seg, image.width, image.height).ok())
250            } else {
251                None
252            };
253
254            let mut annotation = Annotation::new();
255            annotation.set_name(Some(sample_name.clone()));
256            annotation.set_label(Some(label.to_string()));
257            annotation.set_label_index(label_index);
258            annotation.set_box2d(Some(box2d));
259            annotation.set_mask(mask);
260            annotation.set_group(group.map(String::from));
261
262            let sample = Sample {
263                image_name: Some(sample_name.clone()),
264                width: Some(image.width),
265                height: Some(image.height),
266                group: group.map(String::from),
267                annotations: vec![annotation],
268                ..Default::default()
269            };
270
271            Some(sample)
272        })
273        .collect()
274}
275
/// Derive a sample name from a COCO `file_name`.
///
/// Any directories and the final extension are dropped, so
/// `"train2017/000123.jpg"` becomes `"000123"`. If the path has no file stem
/// (for example `""` or `".."`), the `file_name` is returned unchanged.
fn sample_name_from_filename(filename: &str) -> String {
    match Path::new(filename).file_stem().and_then(|stem| stem.to_str()) {
        Some(stem) => stem.to_owned(),
        None => filename.to_owned(),
    }
}

285/// Convert EdgeFirst Arrow format to COCO annotations.
286///
287/// Reads an Arrow file and produces COCO JSON output.
288///
289/// # Arguments
290/// * `arrow_path` - Path to EdgeFirst Arrow file
291/// * `output_path` - Output COCO JSON file path
292/// * `options` - Conversion options
293/// * `progress` - Optional progress channel
294///
295/// # Returns
296/// Number of annotations converted
297pub async fn arrow_to_coco<P: AsRef<Path>>(
298    arrow_path: P,
299    output_path: P,
300    options: &ArrowToCocoOptions,
301    progress: Option<Sender<Progress>>,
302) -> Result<usize, Error> {
303    let arrow_path = arrow_path.as_ref();
304    let output_path = output_path.as_ref();
305
306    // Read Arrow file
307    let mut file = std::fs::File::open(arrow_path)?;
308    let df = IpcReader::new(&mut file).finish()?;
309
310    // Get group column for filtering
311    let groups_to_filter: std::collections::HashSet<_> = options.groups.iter().cloned().collect();
312
313    let total_rows = df.height();
314
315    if let Some(ref p) = progress {
316        let _ = p
317            .send(Progress {
318                current: 0,
319                total: total_rows,
320            })
321            .await;
322    }
323
324    // Extract columns - all at once for O(n) instead of O(n²) per-row access
325    let names: Vec<String> = df
326        .column("name")?
327        .str()?
328        .into_iter()
329        .map(|s| s.unwrap_or_default().to_string())
330        .collect();
331
332    let labels: Vec<String> = df
333        .column("label")?
334        .cast(&DataType::String)?
335        .str()?
336        .into_iter()
337        .map(|s| s.unwrap_or_default().to_string())
338        .collect();
339
340    // Get group column for filtering
341    let groups: Vec<String> = df
342        .column("group")
343        .ok()
344        .and_then(|c| c.cast(&DataType::String).ok())
345        .map(|c| {
346            c.str()
347                .ok()
348                .map(|s| {
349                    s.into_iter()
350                        .map(|v| v.unwrap_or_default().to_string())
351                        .collect()
352                })
353                .unwrap_or_default()
354        })
355        .unwrap_or_else(|| vec!["".to_string(); total_rows]);
356
357    // Extract all box2d values upfront (O(n) instead of O(n²))
358    let box2ds = extract_all_box2ds(df.column("box2d")?)?;
359
360    // Extract all masks upfront if present
361    let masks = if options.include_masks {
362        df.column("mask").ok().map(extract_all_masks).transpose()?
363    } else {
364        None
365    };
366
367    // Extract all sizes upfront if present
368    let sizes = df
369        .column("size")
370        .ok()
371        .and_then(|c| extract_all_sizes(c).ok());
372
373    // Build COCO dataset
374    let mut builder = CocoDatasetBuilder::new();
375
376    if let Some(info) = &options.info {
377        builder = builder.info(info.clone());
378    }
379
380    // Track unique images and categories
381    let mut image_dimensions: HashMap<String, (u32, u32)> = HashMap::new();
382    let mut image_ids: HashMap<String, u64> = HashMap::new();
383    let mut category_ids: HashMap<String, u32> = HashMap::new();
384
385    // First pass: collect unique images and categories
386    for i in 0..total_rows {
387        // Skip if group filtering is active and this row doesn't match
388        if !groups_to_filter.is_empty() && !groups_to_filter.contains(&groups[i]) {
389            continue;
390        }
391
392        let name = &names[i];
393        let label = &labels[i];
394
395        // Get or estimate image dimensions
396        if !image_ids.contains_key(name) {
397            let (width, height) = sizes
398                .as_ref()
399                .and_then(|s| s.get(i).copied())
400                .unwrap_or((0, 0));
401
402            let id = builder.add_image(name, width, height);
403            image_ids.insert(name.clone(), id);
404            image_dimensions.insert(name.clone(), (width, height));
405        }
406
407        if !label.is_empty() && !category_ids.contains_key(label) {
408            let id = builder.add_category(label, None);
409            category_ids.insert(label.clone(), id);
410        }
411    }
412
413    // Second pass: create annotations
414    let mut last_progress_update = 0;
415    for i in 0..total_rows {
416        // Skip if group filtering is active and this row doesn't match
417        if !groups_to_filter.is_empty() && !groups_to_filter.contains(&groups[i]) {
418            continue;
419        }
420
421        let name = &names[i];
422        let label = &labels[i];
423
424        let image_id = *image_ids.get(name).unwrap_or(&0);
425        let category_id = *category_ids.get(label).unwrap_or(&0);
426        let (width, height) = *image_dimensions.get(name).unwrap_or(&(1, 1));
427
428        // Convert box2d from Arrow center-normalized [cx, cy, w, h] to COCO format
429        // Arrow stores center-point, Box2d expects top-left
430        let bbox = box2ds.get(i).map(|box2d| {
431            let cx = box2d[0];
432            let cy = box2d[1];
433            let w = box2d[2];
434            let h = box2d[3];
435            // Convert from center-point to top-left format
436            let left = cx - w / 2.0;
437            let top = cy - h / 2.0;
438            let ef_box2d = Box2d::new(left, top, w, h);
439            box2d_to_coco_bbox(&ef_box2d, width, height)
440        });
441
442        // Convert mask if present
443        let segmentation = if options.include_masks {
444            masks.as_ref().and_then(|m| {
445                m.get(i).and_then(|coords| {
446                    if coords.is_empty() {
447                        None
448                    } else {
449                        let polygons = unflatten_polygon_coords(coords);
450                        let mask = Mask::new(polygons);
451                        let coco_poly = mask_to_coco_polygon(&mask, width, height);
452                        if coco_poly.is_empty() {
453                            None
454                        } else {
455                            Some(CocoSegmentation::Polygon(coco_poly))
456                        }
457                    }
458                })
459            })
460        } else {
461            None
462        };
463
464        if let Some(bbox) = bbox {
465            builder.add_annotation(image_id, category_id, bbox, segmentation);
466        }
467
468        // Update progress every 1000 rows to reduce overhead
469        if let Some(ref p) = progress
470            && (i - last_progress_update >= 1000 || i == total_rows - 1)
471        {
472            let _ = p
473                .send(Progress {
474                    current: i + 1,
475                    total: total_rows,
476                })
477                .await;
478            last_progress_update = i;
479        }
480    }
481
482    let dataset = builder.build();
483    let annotation_count = dataset.annotations.len();
484
485    // Write output
486    let writer = CocoWriter::new();
487    writer.write_json(&dataset, output_path)?;
488
489    Ok(annotation_count)
490}
491
492/// Extract all box2d values from a column at once (O(n) instead of O(n²)).
493fn extract_all_box2ds(col: &Column) -> Result<Vec<[f32; 4]>, Error> {
494    let arr = col.array()?;
495    let mut result = Vec::with_capacity(arr.len());
496
497    for inner in arr.amortized_iter() {
498        let values = if let Some(inner) = inner {
499            let series = inner.as_ref();
500            let vals: Vec<f32> = series
501                .f32()
502                .map_err(|e| Error::CocoError(format!("box2d cast error: {}", e)))?
503                .into_iter()
504                .map(|v| v.unwrap_or(0.0))
505                .collect();
506
507            if vals.len() == 4 {
508                [vals[0], vals[1], vals[2], vals[3]]
509            } else {
510                [0.0, 0.0, 0.0, 0.0]
511            }
512        } else {
513            [0.0, 0.0, 0.0, 0.0]
514        };
515        result.push(values);
516    }
517
518    Ok(result)
519}
520
521/// Extract all mask coordinates from a column at once (O(n) instead of O(n²)).
522fn extract_all_masks(col: &Column) -> Result<Vec<Vec<f32>>, Error> {
523    let list = col.list()?;
524    let mut result = Vec::with_capacity(list.len());
525
526    for i in 0..list.len() {
527        let coords = match list.get_as_series(i) {
528            Some(series) => series
529                .f32()
530                .map_err(|e| Error::CocoError(format!("mask cast error: {}", e)))?
531                .into_iter()
532                .map(|v| v.unwrap_or(f32::NAN))
533                .collect(),
534            None => vec![],
535        };
536        result.push(coords);
537    }
538
539    Ok(result)
540}
541
542/// Extract all image sizes from a column at once.
543fn extract_all_sizes(col: &Column) -> Result<Vec<(u32, u32)>, Error> {
544    let arr = col.array()?;
545    let mut result = Vec::with_capacity(arr.len());
546
547    for inner in arr.amortized_iter() {
548        let size = if let Some(inner) = inner {
549            let series = inner.as_ref();
550            let values: Vec<u32> = series
551                .u32()
552                .map_err(|e| Error::CocoError(format!("size cast error: {}", e)))?
553                .into_iter()
554                .map(|v| v.unwrap_or(0))
555                .collect();
556
557            if values.len() >= 2 {
558                (values[0], values[1])
559            } else {
560                (0, 0)
561            }
562        } else {
563            (0, 0)
564        };
565        result.push(size);
566    }
567
568    Ok(result)
569}
570
#[cfg(test)]
mod tests {
    use super::*;
    use crate::coco::{CocoAnnotation, CocoCategory, CocoDataset};
    use tempfile::TempDir;

    // =========================================================================
    // unflatten_polygon_coords tests
    // =========================================================================

    #[test]
    fn test_unflatten_polygon_coords_empty() {
        let coords: Vec<f32> = vec![];
        let result = unflatten_polygon_coords(&coords);
        assert!(result.is_empty());
    }

    #[test]
    fn test_unflatten_polygon_coords_single_polygon() {
        // Simple rectangle: 4 points
        let coords = vec![0.1, 0.2, 0.3, 0.2, 0.3, 0.4, 0.1, 0.4];
        let result = unflatten_polygon_coords(&coords);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 4);
        assert_eq!(result[0][0], (0.1, 0.2));
        assert_eq!(result[0][3], (0.1, 0.4));
    }

    #[test]
    fn test_unflatten_polygon_coords_multiple_polygons() {
        // Two triangles separated by NaN
        let coords = vec![
            0.1,
            0.1,
            0.2,
            0.1,
            0.15,
            0.2,      // First triangle
            f32::NAN, // Separator
            0.5,
            0.5,
            0.6,
            0.5,
            0.55,
            0.6, // Second triangle
        ];
        let result = unflatten_polygon_coords(&coords);

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].len(), 3);
        assert_eq!(result[1].len(), 3);
        assert_eq!(result[0][0], (0.1, 0.1));
        assert_eq!(result[1][0], (0.5, 0.5));
    }

    #[test]
    fn test_unflatten_polygon_coords_leading_nan() {
        // NaN at the start should be handled gracefully
        let coords = vec![f32::NAN, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6];
        let result = unflatten_polygon_coords(&coords);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 3);
    }

    #[test]
    fn test_unflatten_polygon_coords_trailing_nan() {
        // NaN at the end
        let coords = vec![0.1, 0.2, 0.3, 0.4, f32::NAN];
        let result = unflatten_polygon_coords(&coords);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 2);
    }

    #[test]
    fn test_unflatten_polygon_coords_consecutive_nans() {
        // Multiple NaNs in a row
        let coords = vec![0.1, 0.2, f32::NAN, f32::NAN, 0.3, 0.4];
        let result = unflatten_polygon_coords(&coords);

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].len(), 1);
        assert_eq!(result[1].len(), 1);
    }

    #[test]
    fn test_unflatten_polygon_coords_odd_values() {
        // Odd number of coordinates (trailing x without y)
        let coords = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        let result = unflatten_polygon_coords(&coords);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 2); // Only complete pairs
    }

    // =========================================================================
    // convert_image_annotations tests
    // =========================================================================

    #[test]
    fn test_convert_image_annotations_basic() {
        let image = CocoImage {
            id: 1,
            width: 640,
            height: 480,
            file_name: "test_image.jpg".to_string(),
            ..Default::default()
        };

        let dataset = CocoDataset {
            images: vec![image.clone()],
            categories: vec![CocoCategory {
                id: 1,
                name: "cat".to_string(),
                supercategory: Some("animal".to_string()),
            }],
            annotations: vec![CocoAnnotation {
                id: 1,
                image_id: 1,
                category_id: 1,
                bbox: [100.0, 100.0, 200.0, 200.0],
                area: 40000.0,
                iscrowd: 0,
                segmentation: None,
            }],
            ..Default::default()
        };

        let index = CocoIndex::from_dataset(&dataset);
        let samples = convert_image_annotations(&image, &index, true, Some("train"));

        assert_eq!(samples.len(), 1);
        assert_eq!(samples[0].image_name, Some("test_image".to_string()));
        assert_eq!(samples[0].group, Some("train".to_string()));
        assert_eq!(samples[0].annotations.len(), 1);
        assert_eq!(samples[0].annotations[0].label(), Some(&"cat".to_string()));
    }

    #[test]
    fn test_convert_image_annotations_with_mask() {
        let image = CocoImage {
            id: 1,
            width: 100,
            height: 100,
            file_name: "masked.jpg".to_string(),
            ..Default::default()
        };

        let dataset = CocoDataset {
            images: vec![image.clone()],
            categories: vec![CocoCategory {
                id: 1,
                name: "object".to_string(),
                supercategory: None,
            }],
            annotations: vec![CocoAnnotation {
                id: 1,
                image_id: 1,
                category_id: 1,
                bbox: [10.0, 10.0, 50.0, 50.0],
                area: 2500.0,
                iscrowd: 0,
                segmentation: Some(CocoSegmentation::Polygon(vec![vec![
                    10.0, 10.0, 60.0, 10.0, 60.0, 60.0, 10.0, 60.0,
                ]])),
            }],
            ..Default::default()
        };

        let index = CocoIndex::from_dataset(&dataset);

        // With masks enabled
        let samples_with_mask = convert_image_annotations(&image, &index, true, None);
        assert!(samples_with_mask[0].annotations[0].mask().is_some());

        // With masks disabled
        let samples_no_mask = convert_image_annotations(&image, &index, false, None);
        assert!(samples_no_mask[0].annotations[0].mask().is_none());
    }

    #[test]
    fn test_convert_image_annotations_no_annotations() {
        let image = CocoImage {
            id: 1,
            width: 640,
            height: 480,
            file_name: "empty.jpg".to_string(),
            ..Default::default()
        };

        let dataset = CocoDataset {
            images: vec![image.clone()],
            categories: vec![],
            annotations: vec![],
            ..Default::default()
        };

        let index = CocoIndex::from_dataset(&dataset);
        let samples = convert_image_annotations(&image, &index, true, None);

        assert!(samples.is_empty());
    }

    // =========================================================================
    // sample_name_from_filename tests
    // =========================================================================

    #[test]
    fn test_sample_name_from_filename() {
        assert_eq!(
            sample_name_from_filename("000000397133.jpg"),
            "000000397133"
        );
        assert_eq!(sample_name_from_filename("train2017/image.jpg"), "image");
        assert_eq!(sample_name_from_filename("test"), "test");
    }

    #[test]
    fn test_sample_name_from_filename_nested_path() {
        assert_eq!(
            sample_name_from_filename("a/b/c/deep_image.png"),
            "deep_image"
        );
    }

    #[test]
    fn test_sample_name_from_filename_no_extension() {
        assert_eq!(sample_name_from_filename("no_extension"), "no_extension");
    }

    // =========================================================================
    // Options tests
    // =========================================================================

    #[test]
    fn test_coco_to_arrow_options_default() {
        let options = CocoToArrowOptions::default();
        assert!(options.include_masks);
        assert!(options.group.is_none());
        // The exact count depends on MAX_COCO_WORKERS and the host CPU count,
        // so compare against the helper rather than a fixed range.
        assert_eq!(options.max_workers, max_workers());
    }

    #[test]
    fn test_arrow_to_coco_options_default() {
        let options = ArrowToCocoOptions::default();
        assert!(options.groups.is_empty());
        assert!(options.include_masks);
        assert!(options.info.is_none());
    }

    #[test]
    fn test_max_workers() {
        // Without an override the count is derived from the CPU count and
        // clamped to 2..=8. An override in the environment can legitimately
        // fall outside that range, so the bound is only checked when unset.
        if std::env::var_os("MAX_COCO_WORKERS").is_none() {
            let workers = max_workers();
            assert!(workers >= 2);
            assert!(workers <= 8);
        }
    }

    #[tokio::test]
    async fn test_coco_to_arrow_minimal() {
        let temp_dir = TempDir::new().unwrap();

        // Create minimal COCO JSON
        let coco_json = r#"{
            "images": [
                {"id": 1, "width": 640, "height": 480, "file_name": "test.jpg"}
            ],
            "annotations": [
                {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0}
            ],
            "categories": [
                {"id": 1, "name": "person", "supercategory": "human"}
            ]
        }"#;

        let coco_path = temp_dir.path().join("test.json");
        std::fs::write(&coco_path, coco_json).unwrap();

        let arrow_path = temp_dir.path().join("output.arrow");

        let options = CocoToArrowOptions::default();
        let count = coco_to_arrow(&coco_path, &arrow_path, &options, None)
            .await
            .unwrap();

        assert_eq!(count, 1);
        assert!(arrow_path.exists());

        // Verify Arrow contents
        let mut file = std::fs::File::open(&arrow_path).unwrap();
        let df = IpcReader::new(&mut file).finish().unwrap();
        assert_eq!(df.height(), 1);
    }

    #[tokio::test]
    async fn test_arrow_to_coco_roundtrip() {
        let temp_dir = TempDir::new().unwrap();

        // Create COCO JSON
        let original = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                width: 640,
                height: 480,
                file_name: "test.jpg".to_string(),
                ..Default::default()
            }],
            annotations: vec![CocoAnnotation {
                id: 1,
                image_id: 1,
                category_id: 1,
                bbox: [100.0, 50.0, 200.0, 150.0],
                area: 30000.0,
                iscrowd: 0,
                segmentation: Some(CocoSegmentation::Polygon(vec![vec![
                    100.0, 50.0, 300.0, 50.0, 300.0, 200.0, 100.0, 200.0,
                ]])),
            }],
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                supercategory: Some("human".to_string()),
            }],
            ..Default::default()
        };

        // Write original COCO
        let coco_path = temp_dir.path().join("original.json");
        let writer = CocoWriter::new();
        writer.write_json(&original, &coco_path).unwrap();

        // Convert to Arrow
        let arrow_path = temp_dir.path().join("converted.arrow");
        let options = CocoToArrowOptions::default();
        coco_to_arrow(&coco_path, &arrow_path, &options, None)
            .await
            .unwrap();

        // Convert back to COCO
        let restored_path = temp_dir.path().join("restored.json");
        let options = ArrowToCocoOptions::default();
        arrow_to_coco(&arrow_path, &restored_path, &options, None)
            .await
            .unwrap();

        // Verify restored data
        let reader = CocoReader::new();
        let restored = reader.read_json(&restored_path).unwrap();

        assert_eq!(restored.images.len(), 1);
        assert_eq!(restored.annotations.len(), 1);
        assert_eq!(restored.categories.len(), 1);

        // Check category name preserved
        assert_eq!(restored.categories[0].name, "person");
    }
}