1use super::{
14 convert::{
15 box2d_to_coco_bbox, coco_bbox_to_box2d, coco_segmentation_to_mask, mask_to_coco_polygon,
16 },
17 reader::{CocoReadOptions, CocoReader, read_coco_directory},
18 types::{CocoDataset, CocoImage, CocoIndex, CocoInfo, CocoSegmentation},
19 writer::{CocoDatasetBuilder, CocoWriteOptions, CocoWriter},
20};
21use crate::{
22 Annotation, AnnotationSetID, Client, DatasetID, Error, FileType, Progress, Sample, SampleFile,
23};
24use std::{
25 collections::HashSet,
26 path::{Path, PathBuf},
27};
28use tokio::sync::mpsc::Sender;
29
/// Outcome of a COCO import run, as returned by `import_coco_to_studio`.
#[derive(Debug, Clone)]
pub struct CocoImportResult {
    /// Number of images present in the COCO dataset that was read from disk.
    pub total_images: usize,
    /// Number of images that passed the group filter but were not uploaded
    /// because a sample with the same name already exists in the dataset.
    pub skipped: usize,
    /// Number of images uploaded during this run.
    pub imported: usize,
}
40
/// Tunables for `import_coco_to_studio`.
#[derive(Debug, Clone)]
pub struct CocoImportOptions {
    /// Convert COCO segmentations into masks on the uploaded annotations.
    pub include_masks: bool,
    /// Attach the local image file to each sample. When set, the import first
    /// checks that a handful of images can actually be found on disk.
    pub include_images: bool,
    /// Only import images whose group (inferred from the image folder) equals
    /// this value; `None` imports every image.
    pub group: Option<String>,
    /// Number of images converted and uploaded per batch.
    pub batch_size: usize,
    /// Concurrency value forwarded to `Client::populate_samples_with_concurrency`.
    pub concurrency: usize,
    /// When true, sample names already present in the dataset are fetched
    /// first and images with a matching name are skipped.
    pub resume: bool,
}
58
/// Defaults: masks and images included, no group filter, batches of 100
/// images, concurrency of 64, and resume enabled.
impl Default for CocoImportOptions {
    fn default() -> Self {
        Self {
            include_masks: true,
            include_images: true,
            group: None,
            batch_size: 100,
            concurrency: 64,
            resume: true,
        }
    }
}
71
/// Tunables for `export_studio_to_coco`.
#[derive(Debug, Clone)]
pub struct CocoExportOptions {
    /// Group names forwarded to `Client::samples` as the group filter.
    /// NOTE(review): an empty list presumably means "all groups" - confirm
    /// against `Client::samples`.
    pub groups: Vec<String>,
    /// Emit polygon segmentations for annotations that carry a mask.
    pub include_masks: bool,
    /// Download sample images and add them to the archive. Only consulted
    /// when `output_zip` is set.
    pub include_images: bool,
    /// Write a ZIP archive instead of a bare JSON file.
    pub output_zip: bool,
    /// Forwarded to the writer as its `pretty` option.
    pub pretty_json: bool,
    /// Optional COCO `info` section to embed in the exported dataset.
    pub info: Option<CocoInfo>,
}
88
/// Defaults: no group filter, masks included, no images, plain (non-ZIP)
/// compact JSON output, and no `info` section.
impl Default for CocoExportOptions {
    fn default() -> Self {
        Self {
            groups: vec![],
            include_masks: true,
            include_images: false,
            output_zip: false,
            pretty_json: false,
            info: None,
        }
    }
}
101
102pub async fn import_coco_to_studio(
143 client: &Client,
144 coco_path: impl AsRef<Path>,
145 dataset_id: DatasetID,
146 annotation_set_id: AnnotationSetID,
147 options: &CocoImportOptions,
148 progress: Option<Sender<Progress>>,
149) -> Result<CocoImportResult, Error> {
150 let coco_path = coco_path.as_ref();
151
152 let (dataset, images_dir) = read_coco_from_path(coco_path)?;
154
155 let total_images = dataset.images.len();
156 if total_images == 0 {
157 return Err(Error::MissingAnnotations(
158 "No images found in COCO dataset".to_string(),
159 ));
160 }
161
162 if options.include_images {
164 validate_images_extracted(&dataset, &images_dir)?;
165 }
166
167 let existing_names = fetch_existing_sample_names(client, &dataset_id, options.resume).await?;
169
170 let group_filter = options.group.as_deref();
172 let (images_to_import, skipped, filtered_by_group) =
173 filter_images_for_import(&dataset.images, group_filter, &existing_names);
174
175 log_import_filter_info(group_filter, filtered_by_group, total_images);
177
178 let to_import = images_to_import.len();
179
180 if to_import == 0 {
182 log_nothing_to_import(skipped);
183 return Ok(CocoImportResult {
184 total_images,
185 skipped,
186 imported: 0,
187 });
188 }
189
190 if skipped > 0 {
191 log::info!(
192 "Resuming import: {} of {} images already imported, {} remaining",
193 skipped,
194 total_images,
195 to_import
196 );
197 }
198
199 let index = CocoIndex::from_dataset(&dataset);
201 send_progress(&progress, 0, to_import).await;
202
203 let upload_ctx = UploadContext {
204 client,
205 dataset_id: &dataset_id,
206 annotation_set_id: &annotation_set_id,
207 options,
208 progress: &progress,
209 };
210 let imported =
211 upload_images_in_batches(&upload_ctx, &images_to_import, &index, &images_dir).await?;
212
213 Ok(CocoImportResult {
214 total_images,
215 skipped,
216 imported,
217 })
218}
219
220async fn fetch_existing_sample_names(
222 client: &Client,
223 dataset_id: &DatasetID,
224 resume: bool,
225) -> Result<HashSet<String>, Error> {
226 if !resume {
227 return Ok(HashSet::new());
228 }
229
230 log::info!("Checking for existing samples in dataset {}...", dataset_id);
231 let names = client.sample_names(*dataset_id, &[], None).await?;
232 log::info!("Found {} existing samples in dataset", names.len());
233
234 if !names.is_empty() {
235 let samples: Vec<_> = names.iter().take(3).collect();
236 log::debug!("Sample names from server: {:?}", samples);
237 }
238
239 Ok(names)
240}
241
242fn log_import_filter_info(group_filter: Option<&str>, filtered_by_group: usize, total: usize) {
244 if filtered_by_group > 0 {
245 log::info!(
246 "Group filter '{}': {} images excluded, {} matching",
247 group_filter.unwrap_or(""),
248 filtered_by_group,
249 total - filtered_by_group
250 );
251 }
252}
253
254fn log_nothing_to_import(skipped: usize) {
256 if skipped > 0 {
257 log::info!(
258 "All {} matching images already imported, nothing to do",
259 skipped
260 );
261 } else {
262 log::info!("No images to import");
263 }
264}
265
266async fn send_progress(progress: &Option<Sender<Progress>>, current: usize, total: usize) {
268 if let Some(p) = progress {
269 let _ = p.send(Progress { current, total }).await;
270 }
271}
272
/// Borrowed bundle of everything the batch uploader needs, so that
/// `upload_images_in_batches` does not take a long parameter list.
struct UploadContext<'a> {
    /// Client used to upload samples.
    client: &'a Client,
    /// Destination dataset.
    dataset_id: &'a DatasetID,
    /// Annotation set the uploaded annotations are attached to.
    annotation_set_id: &'a AnnotationSetID,
    /// Import options (batch size, concurrency, mask/image inclusion).
    options: &'a CocoImportOptions,
    /// Optional progress channel, updated after every batch.
    progress: &'a Option<Sender<Progress>>,
}
281
282async fn upload_images_in_batches<'a>(
284 ctx: &UploadContext<'a>,
285 images: &[&CocoImage],
286 index: &CocoIndex,
287 images_dir: &Path,
288) -> Result<usize, Error> {
289 let mut imported = 0;
290 let to_import = images.len();
291
292 for batch in images.chunks(ctx.options.batch_size) {
293 let samples = convert_batch_to_samples(batch, index, images_dir, ctx.options)?;
294
295 ctx.client
296 .populate_samples_with_concurrency(
297 *ctx.dataset_id,
298 Some(*ctx.annotation_set_id),
299 samples,
300 None,
301 Some(ctx.options.concurrency),
302 )
303 .await?;
304
305 imported += batch.len();
306 send_progress(ctx.progress, imported, to_import).await;
307 }
308
309 Ok(imported)
310}
311
312fn convert_batch_to_samples(
314 batch: &[&CocoImage],
315 index: &CocoIndex,
316 images_dir: &Path,
317 options: &CocoImportOptions,
318) -> Result<Vec<Sample>, Error> {
319 let mut samples = Vec::with_capacity(batch.len());
320
321 for image in batch {
322 let image_group = super::reader::infer_group_from_folder(&image.file_name);
323 let sample = convert_coco_image_to_sample(
324 image,
325 index,
326 images_dir,
327 options.include_masks,
328 options.include_images,
329 image_group.as_deref(),
330 )?;
331 samples.push(sample);
332 }
333
334 Ok(samples)
335}
336
337fn validate_images_extracted(dataset: &CocoDataset, images_dir: &Path) -> Result<(), Error> {
339 let sample_size = std::cmp::min(5, dataset.images.len());
341 let mut missing = Vec::new();
342
343 for image in dataset.images.iter().take(sample_size) {
344 if find_image_file(images_dir, &image.file_name).is_none() {
345 missing.push(image.file_name.clone());
346 }
347 }
348
349 if !missing.is_empty() {
350 let examples: Vec<_> = missing.iter().take(3).cloned().collect();
351 return Err(Error::MissingImages(format!(
352 "Images must be extracted before import.\n\
353 Cannot find: {}\n\n\
354 Searched in: {}\n\
355 Expected subdirectories: train2017/, val2017/, images/\n\n\
356 Please extract your COCO image archives first:\n\
357 $ cd {} && unzip train2017.zip && unzip val2017.zip",
358 examples.join(", "),
359 images_dir.display(),
360 images_dir.display()
361 )));
362 }
363
364 Ok(())
365}
366
/// Locates `file_name` under `base_dir`.
///
/// `base_dir/file_name` is tried first, then the same name inside each of the
/// conventional COCO folders, in a fixed order. Returns the first path that
/// exists, or `None` if none does.
fn find_image_file(base_dir: &Path, file_name: &str) -> Option<PathBuf> {
    // Probed in this order after the base directory itself.
    const KNOWN_SUBDIRS: [&str; 6] = [
        "images",
        "train2017",
        "val2017",
        "test2017",
        "train2014",
        "val2014",
    ];

    std::iter::once(base_dir.join(file_name))
        .chain(
            KNOWN_SUBDIRS
                .iter()
                .map(|subdir| base_dir.join(subdir).join(file_name)),
        )
        .find(|candidate| candidate.exists())
}
380
/// Guesses a dataset group from an annotation file name.
///
/// * `instances_<group><digits>` yields `<group>` with the trailing digits
///   removed (e.g. `instances_train2017.json` gives `train`).
/// * Otherwise a stem starting with `train`, `val` or `test` yields that prefix.
/// * Anything else yields `None`.
///
/// NOTE(review): `"validation"` can never be returned because `"val"` is
/// tested first and matches the same stems; confirm whether `validation*`
/// files are meant to map to `val`.
fn infer_group_from_filename(path: &Path) -> Option<String> {
    let stem = path.file_stem().and_then(|s| s.to_str())?;

    let from_instances = stem
        .strip_prefix("instances_")
        .map(|rest| rest.trim_end_matches(char::is_numeric))
        .filter(|group| !group.is_empty());
    if let Some(group) = from_instances {
        return Some(group.to_string());
    }

    ["train", "val", "test", "validation"]
        .into_iter()
        .find(|prefix| stem.starts_with(*prefix))
        .map(|prefix| prefix.to_string())
}
409
410fn read_coco_from_path(coco_path: &Path) -> Result<(CocoDataset, PathBuf), Error> {
414 if coco_path.is_dir() {
415 let datasets = read_coco_directory(coco_path, &CocoReadOptions::default())?;
417 log::info!("Found {} annotation files in directory", datasets.len());
418
419 let mut merged = CocoDataset::default();
421 for (mut ds, group) in datasets {
422 log::info!(
423 " - {} group: {} images, {} annotations",
424 group,
425 ds.images.len(),
426 ds.annotations.len()
427 );
428 for image in &mut ds.images {
431 if !image.file_name.contains('/') {
432 image.file_name = format!("{}2017/{}", group, image.file_name);
433 }
434 }
435 merge_coco_datasets(&mut merged, ds);
436 }
437 Ok((merged, coco_path.to_path_buf()))
438 } else if coco_path.extension().is_some_and(|e| e == "json") {
439 let reader = CocoReader::new();
441 let dataset = reader.read_json(coco_path)?;
442 let parent = coco_path
443 .parent()
444 .and_then(|p| p.parent()) .unwrap_or(Path::new("."));
446 Ok((dataset, parent.to_path_buf()))
447 } else {
448 Err(Error::InvalidParameters(
449 "COCO import requires a JSON annotation file or directory. \
450 ZIP archives must be extracted first."
451 .to_string(),
452 ))
453 }
454}
455
456fn filter_images_for_import<'a>(
461 images: &'a [CocoImage],
462 group_filter: Option<&str>,
463 existing_names: &HashSet<String>,
464) -> (Vec<&'a CocoImage>, usize, usize) {
465 let total = images.len();
466
467 let images_to_import: Vec<_> = images
469 .iter()
470 .filter(|img| {
471 if let Some(filter) = group_filter {
473 let inferred = super::reader::infer_group_from_folder(&img.file_name);
474 if inferred.as_deref() != Some(filter) {
475 return false;
476 }
477 }
478 let sample_name = extract_sample_name(&img.file_name);
480 !existing_names.contains(&sample_name)
481 })
482 .collect();
483
484 let filtered_by_group = if group_filter.is_some() {
486 images
487 .iter()
488 .filter(|img| {
489 let inferred = super::reader::infer_group_from_folder(&img.file_name);
490 inferred.as_deref() != group_filter
491 })
492 .count()
493 } else {
494 0
495 };
496
497 let skipped = total - filtered_by_group - images_to_import.len();
498 (images_to_import, skipped, filtered_by_group)
499}
500
/// Derives a sample name from a COCO `file_name`: the file stem (final path
/// component without its last extension).
///
/// Falls back to the unchanged input when no UTF-8 stem can be extracted.
fn extract_sample_name(file_name: &str) -> String {
    match Path::new(file_name).file_stem().and_then(|stem| stem.to_str()) {
        Some(stem) => stem.to_owned(),
        None => file_name.to_owned(),
    }
}
509
510fn merge_coco_datasets(target: &mut CocoDataset, source: CocoDataset) {
515 let existing_image_ids: HashSet<_> = target.images.iter().map(|i| i.id).collect();
517 for image in source.images {
518 if !existing_image_ids.contains(&image.id) {
519 target.images.push(image);
520 }
521 }
522
523 let existing_cat_ids: HashSet<_> = target.categories.iter().map(|c| c.id).collect();
525 for cat in source.categories {
526 if !existing_cat_ids.contains(&cat.id) {
527 target.categories.push(cat);
528 }
529 }
530
531 target.annotations.extend(source.annotations);
533
534 let existing_license_ids: HashSet<_> = target.licenses.iter().map(|l| l.id).collect();
536 for license in source.licenses {
537 if !existing_license_ids.contains(&license.id) {
538 target.licenses.push(license);
539 }
540 }
541
542 if target.info.description.is_none() && source.info.description.is_some() {
544 target.info = source.info;
545 }
546}
547
548fn convert_coco_image_to_sample(
550 image: &CocoImage,
551 index: &CocoIndex,
552 images_dir: &Path,
553 include_masks: bool,
554 include_images: bool,
555 group: Option<&str>,
556) -> Result<Sample, Error> {
557 let sample_name = Path::new(&image.file_name)
558 .file_stem()
559 .and_then(|s| s.to_str())
560 .map(String::from)
561 .unwrap_or_else(|| image.file_name.clone());
562
563 let annotations = index
565 .annotations_for_image(image.id)
566 .iter()
567 .filter_map(|ann| {
568 let label = index.label_name(ann.category_id)?;
569 let label_index = index.label_index(ann.category_id);
570
571 let box2d = coco_bbox_to_box2d(&ann.bbox, image.width, image.height);
572
573 let mask = if include_masks {
574 ann.segmentation
575 .as_ref()
576 .and_then(|seg| coco_segmentation_to_mask(seg, image.width, image.height).ok())
577 } else {
578 None
579 };
580
581 {
582 let mut ann = Annotation::new();
583 ann.set_name(Some(sample_name.clone()));
584 ann.set_label(Some(label.to_string()));
585 ann.set_label_index(label_index);
586 ann.set_box2d(Some(box2d));
587 ann.set_mask(mask);
588 ann.set_group(group.map(String::from));
589 Some(ann)
590 }
591 })
592 .collect();
593
594 let mut files = Vec::new();
596 if include_images && let Some(image_path) = find_image_file(images_dir, &image.file_name) {
597 files.push(SampleFile::with_filename(
598 FileType::Image.to_string(),
599 image_path.to_string_lossy().to_string(),
600 ));
601 }
602
603 Ok(Sample {
604 image_name: Some(sample_name),
605 width: Some(image.width),
606 height: Some(image.height),
607 group: group.map(String::from),
608 files,
609 annotations,
610 ..Default::default()
611 })
612}
613
614pub async fn export_studio_to_coco(
629 client: &Client,
630 dataset_id: DatasetID,
631 annotation_set_id: AnnotationSetID,
632 output_path: impl AsRef<Path>,
633 options: &CocoExportOptions,
634 progress: Option<Sender<Progress>>,
635) -> Result<usize, Error> {
636 let output_path = output_path.as_ref();
637
638 let groups: Vec<String> = options.groups.clone();
640 let annotation_types = [crate::AnnotationType::Box2d, crate::AnnotationType::Mask];
641
642 let all_samples = client
644 .samples(
645 dataset_id,
646 Some(annotation_set_id),
647 &annotation_types,
648 &groups,
649 &[],
650 progress.clone(),
651 )
652 .await?;
653
654 let mut builder = CocoDatasetBuilder::new();
656
657 if let Some(info) = &options.info {
658 builder = builder.info(info.clone());
659 }
660
661 for sample in &all_samples {
662 let image_name = sample.image_name.as_deref().unwrap_or("unknown");
663 let width = sample.width.unwrap_or(0);
664 let height = sample.height.unwrap_or(0);
665
666 let file_name = if image_name.contains('.') {
668 image_name.to_string()
669 } else {
670 format!("{}.jpg", image_name)
671 };
672 let image_id = builder.add_image(&file_name, width, height);
673
674 for ann in &sample.annotations {
675 if let Some(box2d) = ann.box2d() {
676 let label = ann.label().map(|s| s.as_str()).unwrap_or("unknown");
677 let category_id = builder.add_category(label, None);
678
679 let bbox = box2d_to_coco_bbox(box2d, width, height);
680
681 let segmentation = if options.include_masks {
682 ann.mask().map(|mask| {
683 let coco_poly = mask_to_coco_polygon(mask, width, height);
684 CocoSegmentation::Polygon(coco_poly)
685 })
686 } else {
687 None
688 };
689
690 builder.add_annotation(image_id, category_id, bbox, segmentation);
691 }
692 }
693 }
694
695 let dataset = builder.build();
696 let annotation_count = dataset.annotations.len();
697
698 let writer = CocoWriter::with_options(CocoWriteOptions {
700 compress: true,
701 pretty: options.pretty_json,
702 });
703
704 if options.output_zip {
705 let images = if options.include_images {
707 download_images(client, &all_samples, progress.clone()).await?
708 } else {
709 vec![]
710 };
711
712 writer.write_zip(&dataset, images.into_iter(), output_path)?;
713 } else {
714 writer.write_json(&dataset, output_path)?;
715 }
716
717 Ok(annotation_count)
718}
719
720async fn download_images(
725 client: &Client,
726 samples: &[Sample],
727 progress: Option<Sender<Progress>>,
728) -> Result<Vec<(String, Vec<u8>)>, Error> {
729 let mut result = Vec::with_capacity(samples.len());
730 let total = samples.len();
731
732 for (i, sample) in samples.iter().enumerate() {
733 let image_url = sample.files.iter().find_map(|f| {
735 if f.file_type() == "image" {
736 f.url()
737 } else {
738 None
739 }
740 });
741
742 if let Some(url) = image_url {
743 match client.download(url).await {
745 Ok(data) => {
746 let name = sample.image_name.as_deref().unwrap_or("unknown");
748 let filename = if name.contains('.') {
749 format!("images/{}", name)
750 } else {
751 format!("images/{}.jpg", name)
752 };
753 result.push((filename, data));
754 }
755 Err(e) => {
756 log::warn!(
758 "Failed to download image for sample {:?}: {}",
759 sample.image_name,
760 e
761 );
762 }
763 }
764 }
765
766 if let Some(ref p) = progress {
768 let _ = p
769 .send(Progress {
770 current: i + 1,
771 total,
772 })
773 .await;
774 }
775 }
776
777 Ok(result)
778}
779
/// Tunables for `verify_coco_import`.
#[derive(Debug, Clone)]
pub struct CocoVerifyOptions {
    /// Include mask-derived segmentations when rebuilding the Studio side of
    /// the comparison.
    pub verify_masks: bool,
    /// Group used to restrict the Studio samples that are fetched. When
    /// `None`, a group inferred from the annotation file name is used instead,
    /// if one can be inferred.
    pub group: Option<String>,
}
788
/// Defaults: masks are verified and no explicit group is set.
impl Default for CocoVerifyOptions {
    fn default() -> Self {
        Self {
            verify_masks: true,
            group: None,
        }
    }
}
797
/// Outcome of `update_coco_annotations`.
#[derive(Debug, Clone)]
pub struct CocoUpdateResult {
    /// Number of images in the COCO dataset that was read.
    pub total_images: usize,
    /// Number of COCO images matched to an existing Studio sample, whose
    /// annotations were therefore replaced.
    pub updated: usize,
    /// Number of COCO images with no matching Studio sample.
    pub not_found: usize,
}
808
/// Tunables for `update_coco_annotations`.
#[derive(Debug, Clone)]
pub struct CocoUpdateOptions {
    /// Convert COCO segmentations to polygons. Also decides whether existing
    /// `"seg"` annotations are deleted along with `"box"` ones.
    pub include_masks: bool,
    /// NOTE(review): not read by the update code visible in this part of the
    /// file - confirm whether a group filter is applied elsewhere.
    pub group: Option<String>,
    /// Chunk size used for both the bulk delete (in samples) and the bulk add
    /// (in annotations).
    pub batch_size: usize,
    /// NOTE(review): not read by the update code visible in this part of the
    /// file - confirm where it is consumed.
    pub concurrency: usize,
}
821
/// Defaults: masks included, no group, batches of 100, concurrency of 64.
impl Default for CocoUpdateOptions {
    fn default() -> Self {
        Self {
            include_masks: true,
            group: None,
            batch_size: 100,
            concurrency: 64,
        }
    }
}
832
833fn read_coco_dataset_for_update(coco_path: &Path) -> Result<CocoDataset, Error> {
835 if coco_path.is_dir() {
836 let datasets = read_coco_directory(coco_path, &CocoReadOptions::default())?;
838 log::info!("Found {} annotation files in directory", datasets.len());
839
840 let mut merged = CocoDataset::default();
842 for (mut ds, group) in datasets {
843 log::info!(
844 " - {} group: {} images, {} annotations",
845 group,
846 ds.images.len(),
847 ds.annotations.len()
848 );
849 for image in &mut ds.images {
851 if !image.file_name.contains('/') {
852 image.file_name = format!("{}2017/{}", group, image.file_name);
853 }
854 }
855 merge_coco_datasets(&mut merged, ds);
856 }
857 Ok(merged)
858 } else if coco_path.extension().is_some_and(|e| e == "json") {
859 let reader = CocoReader::new();
860 reader.read_json(coco_path)
861 } else {
862 Err(Error::InvalidParameters(
863 "COCO update requires a JSON annotation file or directory.".to_string(),
864 ))
865 }
866}
867
868fn build_sample_info_map(
871 samples: &[Sample],
872) -> std::collections::HashMap<String, (crate::SampleID, u32, u32, Option<String>)> {
873 use std::collections::HashMap;
874 let mut sample_info = HashMap::new();
875 for sample in samples {
876 if let (Some(name), Some(id), Some(w), Some(h)) =
877 (sample.name(), sample.id(), sample.width, sample.height)
878 {
879 sample_info.insert(name, (id, w, h, sample.group.clone()));
880 }
881 }
882 sample_info
883}
884
885async fn ensure_labels_exist(
887 client: &Client,
888 dataset_id: &DatasetID,
889 categories: &[crate::coco::CocoCategory],
890) -> Result<std::collections::HashMap<String, u64>, Error> {
891 use std::collections::{HashMap, HashSet};
892
893 let existing_labels = client.labels(*dataset_id).await?;
895 let existing_label_names: HashSet<String> = existing_labels
896 .iter()
897 .map(|l| l.name().to_string())
898 .collect();
899
900 let missing_labels: Vec<String> = categories
902 .iter()
903 .filter(|c| !existing_label_names.contains(&c.name))
904 .map(|c| c.name.clone())
905 .collect();
906
907 if !missing_labels.is_empty() {
909 log::info!(
910 "Creating {} missing labels in Studio...",
911 missing_labels.len()
912 );
913 for label_name in &missing_labels {
914 client.add_label(*dataset_id, label_name).await?;
915 }
916 }
917
918 let labels = client.labels(*dataset_id).await?;
920 let label_map: HashMap<String, u64> = labels
921 .iter()
922 .map(|l| (l.name().to_string(), l.id()))
923 .collect();
924
925 log::info!(
926 "Label map has {} entries for {} COCO categories",
927 label_map.len(),
928 categories.len()
929 );
930
931 Ok(label_map)
932}
933
934fn convert_coco_annotation_to_server(
938 coco_ann: &super::types::CocoAnnotation,
939 coco_index: &CocoIndex,
940 label_map: &std::collections::HashMap<String, u64>,
941 image_id: u64,
942 annotation_set_id: u64,
943 dims: (u32, u32),
944 include_masks: bool,
945) -> (crate::api::ServerAnnotation, bool) {
946 let (width, height) = dims;
947
948 let category_name = coco_index
950 .categories
951 .get(&coco_ann.category_id)
952 .map(|c| c.name.as_str())
953 .unwrap_or("unknown");
954
955 let label_id = label_map.get(category_name).copied();
956 let missing_label = label_id.is_none();
957
958 let box2d = coco_bbox_to_box2d(&coco_ann.bbox, width, height);
960
961 let polygon = if include_masks {
963 coco_ann
964 .segmentation
965 .as_ref()
966 .and_then(|seg| coco_segmentation_to_mask(seg, width, height).ok())
967 .map(|mask| mask_to_polygon_string(&mask))
968 .unwrap_or_default()
969 } else {
970 String::new()
971 };
972
973 let annotation_type = if polygon.is_empty() { "box" } else { "seg" }.to_string();
974
975 let server_ann = crate::api::ServerAnnotation {
976 label_id,
977 label_index: None,
978 label_name: Some(category_name.to_string()),
979 annotation_type,
980 x: box2d.left() as f64,
981 y: box2d.top() as f64,
982 w: box2d.width() as f64,
983 h: box2d.height() as f64,
984 score: 1.0,
985 polygon,
986 image_id,
987 annotation_set_id,
988 object_reference: None,
989 };
990
991 (server_ann, missing_label)
992}
993
994fn process_image_for_update(
997 coco_image: &CocoImage,
998 sample_info: &std::collections::HashMap<String, (crate::SampleID, u32, u32, Option<String>)>,
999 coco_index: &CocoIndex,
1000 label_map: &std::collections::HashMap<String, u64>,
1001 annotation_set_id: u64,
1002 include_masks: bool,
1003) -> Option<(
1004 crate::SampleID,
1005 Vec<crate::api::ServerAnnotation>,
1006 Option<String>,
1007 usize,
1008)> {
1009 let sample_name = extract_sample_name(&coco_image.file_name);
1010 let expected_group = super::reader::infer_group_from_folder(&coco_image.file_name);
1011
1012 let (sample_id, width, height, current_group) = sample_info.get(&sample_name)?;
1013 let (sample_id, width, height) = (*sample_id, *width, *height);
1014 let image_id: u64 = sample_id.into();
1015
1016 let group_update = expected_group.as_ref().and_then(|expected| {
1018 if Some(expected) != current_group.as_ref() {
1019 Some(expected.clone())
1020 } else {
1021 None
1022 }
1023 });
1024
1025 let mut annotations = Vec::new();
1027 let mut missing_label_count = 0;
1028
1029 for coco_ann in coco_index.annotations_for_image(coco_image.id) {
1030 let (server_ann, missing) = convert_coco_annotation_to_server(
1031 coco_ann,
1032 coco_index,
1033 label_map,
1034 image_id,
1035 annotation_set_id,
1036 (width, height),
1037 include_masks,
1038 );
1039 if missing {
1040 missing_label_count += 1;
1041 }
1042 annotations.push(server_ann);
1043 }
1044
1045 Some((sample_id, annotations, group_update, missing_label_count))
1046}
1047
1048async fn update_sample_groups(
1050 client: &Client,
1051 dataset_id: &DatasetID,
1052 samples_needing_group_update: &[(crate::SampleID, String)],
1053) -> usize {
1054 use std::collections::{HashMap, HashSet};
1055
1056 if samples_needing_group_update.is_empty() {
1057 return 0;
1058 }
1059
1060 log::info!(
1061 "Updating groups for {} samples...",
1062 samples_needing_group_update.len()
1063 );
1064
1065 let unique_groups: HashSet<String> = samples_needing_group_update
1067 .iter()
1068 .map(|(_, group)| group.clone())
1069 .collect();
1070
1071 let mut group_id_map: HashMap<String, u64> = HashMap::new();
1072 for group_name in unique_groups {
1073 match client.get_or_create_group(*dataset_id, &group_name).await {
1074 Ok(group_id) => {
1075 group_id_map.insert(group_name, group_id);
1076 }
1077 Err(e) => {
1078 log::warn!("Failed to get/create group '{}': {}", group_name, e);
1079 }
1080 }
1081 }
1082
1083 let mut updated_count = 0;
1085 let mut failed_count = 0;
1086 for (sample_id, group_name) in samples_needing_group_update {
1087 if let Some(&group_id) = group_id_map.get(group_name) {
1088 match client.set_sample_group_id(*sample_id, group_id).await {
1089 Ok(_) => {
1090 updated_count += 1;
1091 if updated_count % 1000 == 0 {
1092 log::debug!("Updated groups for {} samples so far", updated_count);
1093 }
1094 }
1095 Err(e) => {
1096 failed_count += 1;
1097 if failed_count <= 5 {
1098 log::warn!("Failed to update group for sample {:?}: {}", sample_id, e);
1099 }
1100 }
1101 }
1102 }
1103 }
1104
1105 if failed_count > 5 {
1106 log::warn!("... and {} more group update failures", failed_count - 5);
1107 }
1108 log::info!(
1109 "Updated groups for {} samples ({} failed)",
1110 updated_count,
1111 failed_count
1112 );
1113
1114 updated_count
1115}
1116
1117pub async fn update_coco_annotations(
1137 client: &Client,
1138 coco_path: impl AsRef<Path>,
1139 dataset_id: DatasetID,
1140 annotation_set_id: AnnotationSetID,
1141 options: &CocoUpdateOptions,
1142 progress: Option<Sender<Progress>>,
1143) -> Result<CocoUpdateResult, Error> {
1144 use crate::{SampleID, api::ServerAnnotation};
1145
1146 let coco_path = coco_path.as_ref();
1147
1148 let dataset = read_coco_dataset_for_update(coco_path)?;
1150 let total_images = dataset.images.len();
1151
1152 if total_images == 0 {
1153 return Err(Error::MissingAnnotations(
1154 "No images found in COCO dataset".to_string(),
1155 ));
1156 }
1157
1158 log::info!(
1159 "COCO dataset: {} images, {} annotations, {} categories",
1160 total_images,
1161 dataset.annotations.len(),
1162 dataset.categories.len()
1163 );
1164
1165 log::info!("Fetching existing samples from Studio...");
1167 let existing_samples = client
1168 .samples(
1169 dataset_id,
1170 Some(annotation_set_id),
1171 &[],
1172 &[],
1173 &[],
1174 progress.clone(),
1175 )
1176 .await?;
1177
1178 let sample_info = build_sample_info_map(&existing_samples);
1179 log::info!(
1180 "Found {} existing samples in Studio with IDs and dimensions",
1181 sample_info.len()
1182 );
1183
1184 let coco_index = CocoIndex::from_dataset(&dataset);
1186
1187 let label_map = ensure_labels_exist(client, &dataset_id, &dataset.categories).await?;
1189
1190 let annotation_set_id_u64: u64 = annotation_set_id.into();
1192 let mut sample_ids_to_update: Vec<SampleID> = Vec::new();
1193 let mut server_annotations: Vec<ServerAnnotation> = Vec::new();
1194 let mut samples_needing_group_update: Vec<(SampleID, String)> = Vec::new();
1195 let mut not_found = 0;
1196 let mut missing_label_count = 0;
1197
1198 for coco_image in &dataset.images {
1199 match process_image_for_update(
1200 coco_image,
1201 &sample_info,
1202 &coco_index,
1203 &label_map,
1204 annotation_set_id_u64,
1205 options.include_masks,
1206 ) {
1207 Some((sample_id, annotations, group_update, missing_labels)) => {
1208 sample_ids_to_update.push(sample_id);
1209 server_annotations.extend(annotations);
1210 missing_label_count += missing_labels;
1211 if let Some(group) = group_update {
1212 samples_needing_group_update.push((sample_id, group));
1213 }
1214 }
1215 None => {
1216 not_found += 1;
1217 log::debug!(
1218 "Sample not found in Studio: {}",
1219 extract_sample_name(&coco_image.file_name)
1220 );
1221 }
1222 }
1223 }
1224
1225 let to_update = sample_ids_to_update.len();
1226 log::info!(
1227 "Updating {} samples ({} not found in Studio), {} annotations",
1228 to_update,
1229 not_found,
1230 server_annotations.len()
1231 );
1232
1233 if missing_label_count > 0 {
1234 log::warn!(
1235 "{} annotations have missing label_id (category not found in label map)",
1236 missing_label_count
1237 );
1238 }
1239
1240 if to_update == 0 {
1241 return Ok(CocoUpdateResult {
1242 total_images,
1243 updated: 0,
1244 not_found,
1245 });
1246 }
1247
1248 if let Some(ref tx) = progress {
1250 let _ = tx
1251 .send(Progress {
1252 current: 0,
1253 total: to_update,
1254 })
1255 .await;
1256 }
1257
1258 log::info!(
1260 "Deleting existing annotations for {} samples...",
1261 sample_ids_to_update.len()
1262 );
1263 let annotation_types = if options.include_masks {
1264 vec!["box".to_string(), "seg".to_string()]
1265 } else {
1266 vec!["box".to_string()]
1267 };
1268
1269 for batch in sample_ids_to_update.chunks(options.batch_size) {
1271 client
1272 .delete_annotations_bulk(annotation_set_id, &annotation_types, batch)
1273 .await?;
1274 }
1275
1276 if let Some(ref tx) = progress {
1278 let _ = tx
1279 .send(Progress {
1280 current: to_update / 2,
1281 total: to_update,
1282 })
1283 .await;
1284 }
1285
1286 log::info!("Adding {} new annotations...", server_annotations.len());
1288 let mut added = 0;
1289 for batch in server_annotations.chunks(options.batch_size) {
1290 client
1291 .add_annotations_bulk(annotation_set_id, batch.to_vec())
1292 .await?;
1293 added += batch.len();
1294 log::debug!("Added {} annotations so far", added);
1295 }
1296
1297 if let Some(ref tx) = progress {
1299 let _ = tx
1300 .send(Progress {
1301 current: to_update,
1302 total: to_update,
1303 })
1304 .await;
1305 }
1306
1307 let groups_updated =
1309 update_sample_groups(client, &dataset_id, &samples_needing_group_update).await;
1310
1311 log::info!(
1312 "Update complete: {} samples updated, {} not found, {} annotations added, {} groups updated",
1313 to_update,
1314 not_found,
1315 added,
1316 groups_updated
1317 );
1318
1319 Ok(CocoUpdateResult {
1320 total_images,
1321 updated: to_update,
1322 not_found,
1323 })
1324}
1325
1326fn mask_to_polygon_string(mask: &crate::Mask) -> String {
1339 let polygons: Vec<Vec<[f32; 2]>> = mask
1343 .polygon
1344 .iter()
1345 .map(|ring| {
1346 ring.iter()
1347 .filter(|(x, y)| x.is_finite() && y.is_finite())
1348 .map(|&(x, y)| [x, y])
1349 .collect()
1350 })
1351 .filter(|ring: &Vec<[f32; 2]>| ring.len() >= 3) .collect();
1353
1354 serde_json::to_string(&polygons).unwrap_or_default()
1355}
1356
1357fn compute_bbox_from_mask(mask: &crate::Mask, width: u32, height: u32) -> Option<[f64; 4]> {
1362 if mask.polygon.is_empty() {
1363 return None;
1364 }
1365
1366 let mut min_x = f32::MAX;
1367 let mut min_y = f32::MAX;
1368 let mut max_x = f32::MIN;
1369 let mut max_y = f32::MIN;
1370
1371 for ring in &mask.polygon {
1372 for &(x, y) in ring {
1373 if x.is_finite() && y.is_finite() {
1374 min_x = min_x.min(x);
1375 min_y = min_y.min(y);
1376 max_x = max_x.max(x);
1377 max_y = max_y.max(y);
1378 }
1379 }
1380 }
1381
1382 if min_x == f32::MAX || min_y == f32::MAX {
1383 return None;
1384 }
1385
1386 let x = (min_x * width as f32) as f64;
1388 let y = (min_y * height as f32) as f64;
1389 let w = ((max_x - min_x) * width as f32) as f64;
1390 let h = ((max_y - min_y) * height as f32) as f64;
1391
1392 if w > 0.0 && h > 0.0 {
1393 Some([x, y, w, h])
1394 } else {
1395 None
1396 }
1397}
1398
1399pub async fn verify_coco_import(
1420 client: &Client,
1421 coco_path: impl AsRef<Path>,
1422 dataset_id: DatasetID,
1423 annotation_set_id: AnnotationSetID,
1424 options: &CocoVerifyOptions,
1425 progress: Option<Sender<Progress>>,
1426) -> Result<super::verify::VerificationResult, Error> {
1427 use super::{verify::*, writer::CocoDatasetBuilder};
1428
1429 let coco_path = coco_path.as_ref();
1430
1431 log::info!("Reading local COCO dataset from {:?}", coco_path);
1433 let (coco_dataset, inferred_group) = if coco_path.is_dir() {
1434 let datasets = read_coco_directory(coco_path, &CocoReadOptions::default())?;
1436 log::info!("Found {} annotation files in directory", datasets.len());
1437
1438 let mut merged = CocoDataset::default();
1439 for (ds, group) in datasets {
1440 log::info!(
1441 " - {} group: {} images, {} annotations",
1442 group,
1443 ds.images.len(),
1444 ds.annotations.len()
1445 );
1446 merge_coco_datasets(&mut merged, ds);
1447 }
1448 (merged, None)
1450 } else if coco_path.extension().is_some_and(|e| e == "json") {
1451 let reader = CocoReader::new();
1452 let dataset = reader.read_json(coco_path)?;
1453 let group = infer_group_from_filename(coco_path);
1454 (dataset, group)
1455 } else {
1456 return Err(Error::InvalidParameters(
1457 "COCO verification requires a JSON annotation file or directory.".to_string(),
1458 ));
1459 };
1460
1461 let effective_group = options.group.clone().or(inferred_group);
1463 let groups: Vec<String> = effective_group
1464 .as_ref()
1465 .map(|g| vec![g.clone()])
1466 .unwrap_or_default();
1467
1468 log::info!(
1469 "Local COCO: {} images, {} annotations",
1470 coco_dataset.images.len(),
1471 coco_dataset.annotations.len()
1472 );
1473
1474 log::info!("Fetching samples from Studio dataset {}...", dataset_id);
1476 let annotation_types = [crate::AnnotationType::Box2d, crate::AnnotationType::Mask];
1477
1478 let studio_samples = client
1479 .samples(
1480 dataset_id,
1481 Some(annotation_set_id),
1482 &annotation_types,
1483 &groups,
1484 &[],
1485 progress.clone(),
1486 )
1487 .await?;
1488
1489 let total_annotations: usize = studio_samples.iter().map(|s| s.annotations.len()).sum();
1490 log::info!(
1491 "Studio: {} samples, {} total annotations",
1492 studio_samples.len(),
1493 total_annotations
1494 );
1495
1496 let mut builder = CocoDatasetBuilder::new();
1498
1499 for sample in &studio_samples {
1500 let image_name = sample.image_name.as_deref().unwrap_or("unknown");
1501 let width = sample.width.unwrap_or(0);
1502 let height = sample.height.unwrap_or(0);
1503
1504 let file_name = if image_name.contains('.') {
1506 image_name.to_string()
1507 } else {
1508 format!("{}.jpg", image_name)
1509 };
1510 let image_id = builder.add_image(&file_name, width, height);
1511
1512 for ann in &sample.annotations {
1513 let bbox = if let Some(box2d) = ann.box2d() {
1515 Some(box2d_to_coco_bbox(box2d, width, height))
1516 } else if let Some(mask) = ann.mask() {
1517 compute_bbox_from_mask(mask, width, height)
1519 } else {
1520 None
1521 };
1522
1523 if let Some(bbox) = bbox {
1524 let label = ann.label().map(|s| s.as_str()).unwrap_or("unknown");
1525 let category_id = builder.add_category(label, None);
1526
1527 let segmentation = if options.verify_masks {
1528 ann.mask().map(|mask| {
1529 let coco_poly = mask_to_coco_polygon(mask, width, height);
1530 CocoSegmentation::Polygon(coco_poly)
1531 })
1532 } else {
1533 None
1534 };
1535
1536 builder.add_annotation(image_id, category_id, bbox, segmentation);
1537 }
1538 }
1539 }
1540
1541 let studio_dataset = builder.build();
1542
1543 let coco_names: HashSet<String> = coco_dataset
1545 .images
1546 .iter()
1547 .map(|img| {
1548 Path::new(&img.file_name)
1549 .file_stem()
1550 .and_then(|s| s.to_str())
1551 .map(String::from)
1552 .unwrap_or_else(|| img.file_name.clone())
1553 })
1554 .collect();
1555
1556 let studio_names: HashSet<String> = studio_samples.iter().filter_map(|s| s.name()).collect();
1557
1558 let missing_images: Vec<String> = coco_names.difference(&studio_names).cloned().collect();
1559 let extra_images: Vec<String> = studio_names.difference(&coco_names).cloned().collect();
1560
1561 log::info!("Validating bounding boxes...");
1563 let bbox_validation = validate_bboxes(&coco_dataset, &studio_dataset);
1564
1565 log::info!("Validating segmentation masks...");
1567 let mask_validation = if options.verify_masks {
1568 validate_masks(&coco_dataset, &studio_dataset)
1569 } else {
1570 MaskValidationResult::new()
1571 };
1572
1573 let category_validation = validate_categories(&coco_dataset, &studio_dataset);
1575
1576 Ok(VerificationResult {
1577 coco_image_count: coco_dataset.images.len(),
1578 studio_image_count: studio_samples.len(),
1579 missing_images,
1580 extra_images,
1581 coco_annotation_count: coco_dataset.annotations.len(),
1582 studio_annotation_count: studio_dataset.annotations.len(),
1583 bbox_validation,
1584 mask_validation,
1585 category_validation,
1586 })
1587}
1588
#[cfg(test)]
mod tests {
    use super::*;
    use crate::coco::{CocoAnnotation, CocoCategory};

    /// Runs group inference on a file name given as a plain string.
    fn inferred_group(file: &str) -> Option<String> {
        infer_group_from_filename(Path::new(file))
    }

    /// Returns `true` when `actual` differs from `expected` by less than 1.0.
    fn within_one(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1.0
    }

    // ---------------------------------------------------------------------
    // Default option values
    // ---------------------------------------------------------------------

    /// Import defaults: masks and images on, no group, batch 100,
    /// concurrency 64, resume on.
    #[test]
    fn test_coco_import_options_default() {
        let CocoImportOptions {
            include_masks,
            include_images,
            group,
            batch_size,
            concurrency,
            resume,
        } = CocoImportOptions::default();

        assert!(include_masks);
        assert!(include_images);
        assert!(group.is_none());
        assert_eq!(batch_size, 100);
        assert_eq!(concurrency, 64);
        assert!(resume);
    }

    /// Export defaults: no group filter, masks on, everything else off.
    #[test]
    fn test_coco_export_options_default() {
        let CocoExportOptions {
            groups,
            include_masks,
            include_images,
            output_zip,
            pretty_json,
            info,
        } = CocoExportOptions::default();

        assert!(groups.is_empty());
        assert!(include_masks);
        assert!(!include_images);
        assert!(!output_zip);
        assert!(!pretty_json);
        assert!(info.is_none());
    }

    /// Update defaults: masks on, no group, batch 100, concurrency 64.
    #[test]
    fn test_coco_update_options_default() {
        let defaults = CocoUpdateOptions::default();

        assert!(defaults.include_masks);
        assert!(defaults.group.is_none());
        assert_eq!(defaults.batch_size, 100);
        assert_eq!(defaults.concurrency, 64);
    }

    /// Verify defaults: mask verification on, no group.
    #[test]
    fn test_coco_verify_options_default() {
        let defaults = CocoVerifyOptions::default();

        assert!(defaults.verify_masks);
        assert!(defaults.group.is_none());
    }

    // ---------------------------------------------------------------------
    // find_image_file
    // ---------------------------------------------------------------------

    /// Nothing is found below a directory that does not exist.
    #[test]
    fn test_find_image_file_nonexistent() {
        let missing_root = Path::new("/nonexistent");
        assert!(find_image_file(missing_root, "test.jpg").is_none());
    }

    /// A file name carrying a sub-directory component is still not found when
    /// the base directory does not exist.
    #[test]
    fn test_find_image_file_with_subdirectory_in_name() {
        let missing_root = Path::new("/nonexistent");
        assert!(find_image_file(missing_root, "train2017/image.jpg").is_none());
    }

    // ---------------------------------------------------------------------
    // infer_group_from_filename
    // ---------------------------------------------------------------------

    #[test]
    fn test_infer_group_from_filename_instances_train() {
        let group = inferred_group("annotations/instances_train2017.json");
        assert_eq!(group.as_deref(), Some("train"));
    }

    #[test]
    fn test_infer_group_from_filename_instances_val() {
        let group = inferred_group("annotations/instances_val2017.json");
        assert_eq!(group.as_deref(), Some("val"));
    }

    #[test]
    fn test_infer_group_from_filename_instances_test() {
        let group = inferred_group("instances_test2017.json");
        assert_eq!(group.as_deref(), Some("test"));
    }

    #[test]
    fn test_infer_group_from_filename_train_prefix() {
        let group = inferred_group("train_annotations.json");
        assert_eq!(group.as_deref(), Some("train"));
    }

    #[test]
    fn test_infer_group_from_filename_val_prefix() {
        let group = inferred_group("val_data.json");
        assert_eq!(group.as_deref(), Some("val"));
    }

    /// A "validation" prefix is reported as the "val" group.
    #[test]
    fn test_infer_group_from_filename_validation_prefix() {
        let group = inferred_group("validation_set.json");
        assert_eq!(group.as_deref(), Some("val"));
    }

    /// File names without a recognised split marker yield no group.
    #[test]
    fn test_infer_group_from_filename_custom() {
        assert!(inferred_group("my_custom_annotations.json").is_none());
    }

    #[test]
    fn test_infer_group_from_filename_instances_2014() {
        let group = inferred_group("instances_val2014.json");
        assert_eq!(group.as_deref(), Some("val"));
    }

    // ---------------------------------------------------------------------
    // merge_coco_datasets
    // ---------------------------------------------------------------------

    /// Merging two empty datasets leaves the target empty.
    #[test]
    fn test_merge_coco_datasets_empty() {
        let mut merged = CocoDataset::default();
        merge_coco_datasets(&mut merged, CocoDataset::default());

        assert!(merged.images.is_empty());
        assert!(merged.annotations.is_empty());
        assert!(merged.categories.is_empty());
    }

    /// Disjoint datasets are simply concatenated.
    #[test]
    fn test_merge_coco_datasets_basic() {
        let image = |id, name: &str| CocoImage {
            id,
            file_name: name.to_string(),
            ..Default::default()
        };
        let category = |id, name: &str| CocoCategory {
            id,
            name: name.to_string(),
            supercategory: None,
        };
        let annotation = |id, image_id, category_id| CocoAnnotation {
            id,
            image_id,
            category_id,
            ..Default::default()
        };

        let mut target = CocoDataset {
            images: vec![image(1, "img1.jpg")],
            categories: vec![category(1, "cat")],
            annotations: vec![annotation(1, 1, 1)],
            ..Default::default()
        };
        let source = CocoDataset {
            images: vec![image(2, "img2.jpg")],
            categories: vec![category(2, "dog")],
            annotations: vec![annotation(2, 2, 2)],
            ..Default::default()
        };

        merge_coco_datasets(&mut target, source);

        assert_eq!(target.images.len(), 2);
        assert_eq!(target.categories.len(), 2);
        assert_eq!(target.annotations.len(), 2);
    }

    /// An incoming image whose id already exists is dropped; the target's
    /// original entry is kept.
    #[test]
    fn test_merge_coco_datasets_deduplicates_images() {
        let image = |id, name: &str| CocoImage {
            id,
            file_name: name.to_string(),
            ..Default::default()
        };

        let mut target = CocoDataset {
            images: vec![image(1, "img1.jpg")],
            ..Default::default()
        };
        let source = CocoDataset {
            // The first entry reuses id 1 and must not replace the original.
            images: vec![image(1, "img1_dup.jpg"), image(2, "img2.jpg")],
            ..Default::default()
        };

        merge_coco_datasets(&mut target, source);

        assert_eq!(target.images.len(), 2);
        assert_eq!(target.images[0].file_name, "img1.jpg");
    }

    /// An incoming category whose id already exists is dropped; the target's
    /// original entry is kept.
    #[test]
    fn test_merge_coco_datasets_deduplicates_categories() {
        let category = |id, name: &str| CocoCategory {
            id,
            name: name.to_string(),
            supercategory: None,
        };

        let mut target = CocoDataset {
            categories: vec![category(1, "person")],
            ..Default::default()
        };
        let source = CocoDataset {
            // The first entry reuses id 1 and must not replace the original.
            categories: vec![category(1, "person_dup"), category(2, "car")],
            ..Default::default()
        };

        merge_coco_datasets(&mut target, source);

        assert_eq!(target.categories.len(), 2);
        assert_eq!(target.categories[0].name, "person");
    }

    /// The source's info description ends up on a target that had none.
    #[test]
    fn test_merge_coco_datasets_info_preserved() {
        let info = CocoInfo {
            description: Some("Test dataset".to_string()),
            ..Default::default()
        };
        let source = CocoDataset {
            info,
            ..Default::default()
        };
        let mut target = CocoDataset::default();

        merge_coco_datasets(&mut target, source);

        assert_eq!(target.info.description.as_deref(), Some("Test dataset"));
    }

    // ---------------------------------------------------------------------
    // convert_coco_image_to_sample
    // ---------------------------------------------------------------------

    /// One boxed annotation becomes one labelled sample annotation, and the
    /// sample carries the image stem, dimensions and the requested group.
    #[test]
    fn test_convert_coco_image_to_sample() {
        let image = CocoImage {
            file_name: "test.jpg".to_string(),
            id: 1,
            width: 640,
            height: 480,
            ..Default::default()
        };
        let person = CocoCategory {
            id: 1,
            name: "person".to_string(),
            supercategory: None,
        };
        let person_box = CocoAnnotation {
            id: 1,
            image_id: 1,
            category_id: 1,
            bbox: [100.0, 50.0, 200.0, 150.0],
            area: 30000.0,
            iscrowd: 0,
            segmentation: None,
        };
        let dataset = CocoDataset {
            images: vec![image.clone()],
            categories: vec![person],
            annotations: vec![person_box],
            ..Default::default()
        };
        let index = CocoIndex::from_dataset(&dataset);

        let sample = convert_coco_image_to_sample(
            &image,
            &index,
            Path::new("/tmp"),
            true,
            false,
            Some("train"),
        )
        .unwrap();

        assert_eq!(sample.image_name.as_deref(), Some("test"));
        assert_eq!((sample.width, sample.height), (Some(640), Some(480)));
        assert_eq!(sample.group.as_deref(), Some("train"));
        assert_eq!(sample.annotations.len(), 1);
        assert_eq!(
            sample.annotations[0].label().map(String::as_str),
            Some("person")
        );
    }

    /// An image without annotations converts to a sample with none.
    #[test]
    fn test_convert_coco_image_to_sample_no_annotations() {
        let image = CocoImage {
            file_name: "empty.jpg".to_string(),
            id: 1,
            width: 640,
            height: 480,
            ..Default::default()
        };
        let dataset = CocoDataset {
            images: vec![image.clone()],
            categories: Vec::new(),
            annotations: Vec::new(),
            ..Default::default()
        };
        let index = CocoIndex::from_dataset(&dataset);

        let converted =
            convert_coco_image_to_sample(&image, &index, Path::new("/tmp"), true, false, None)
                .unwrap();

        assert_eq!(converted.image_name.as_deref(), Some("empty"));
        assert!(converted.annotations.is_empty());
    }

    /// The mask flag decides whether the polygon segmentation is carried over.
    #[test]
    fn test_convert_coco_image_to_sample_with_mask() {
        let image = CocoImage {
            file_name: "masked.jpg".to_string(),
            id: 1,
            width: 100,
            height: 100,
            ..Default::default()
        };
        let square = vec![10.0, 10.0, 60.0, 10.0, 60.0, 60.0, 10.0, 60.0];
        let dataset = CocoDataset {
            images: vec![image.clone()],
            categories: vec![CocoCategory {
                id: 1,
                name: "object".to_string(),
                supercategory: None,
            }],
            annotations: vec![CocoAnnotation {
                id: 1,
                image_id: 1,
                category_id: 1,
                bbox: [10.0, 10.0, 50.0, 50.0],
                area: 2500.0,
                iscrowd: 0,
                segmentation: Some(CocoSegmentation::Polygon(vec![square])),
            }],
            ..Default::default()
        };
        let index = CocoIndex::from_dataset(&dataset);

        let convert = |include_masks: bool| {
            convert_coco_image_to_sample(
                &image,
                &index,
                Path::new("/tmp"),
                include_masks,
                false,
                None,
            )
            .unwrap()
        };

        let with_masks = convert(true);
        assert!(with_masks.annotations[0].mask().is_some());

        let without_masks = convert(false);
        assert!(without_masks.annotations[0].mask().is_none());
    }

    // ---------------------------------------------------------------------
    // compute_bbox_from_mask
    // ---------------------------------------------------------------------

    /// A single square ring yields its own extent scaled to pixels.
    #[test]
    fn test_compute_bbox_from_mask_simple() {
        let square = vec![(0.1, 0.1), (0.5, 0.1), (0.5, 0.5), (0.1, 0.5)];
        let mask = crate::Mask::new(vec![square]);

        let [x, y, w, h] = compute_bbox_from_mask(&mask, 100, 100).unwrap();

        assert!(within_one(x, 10.0));
        assert!(within_one(y, 10.0));
        assert!(within_one(w, 40.0));
        assert!(within_one(h, 40.0));
    }

    /// A mask without rings has no bounding box.
    #[test]
    fn test_compute_bbox_from_mask_empty() {
        let no_rings = crate::Mask::new(Vec::new());
        assert!(compute_bbox_from_mask(&no_rings, 100, 100).is_none());
    }

    /// A mask made only of NaN vertices has no bounding box.
    #[test]
    fn test_compute_bbox_from_mask_with_nan() {
        let all_nan = vec![(f32::NAN, f32::NAN), (f32::NAN, f32::NAN)];
        let mask = crate::Mask::new(vec![all_nan]);
        assert!(compute_bbox_from_mask(&mask, 100, 100).is_none());
    }

    /// The box spans every ring of the mask, not just the first one.
    #[test]
    fn test_compute_bbox_from_mask_multiple_rings() {
        let near_origin = vec![(0.1, 0.1), (0.2, 0.1), (0.2, 0.2), (0.1, 0.2)];
        let far_corner = vec![(0.8, 0.8), (0.9, 0.8), (0.9, 0.9), (0.8, 0.9)];
        let mask = crate::Mask::new(vec![near_origin, far_corner]);

        let [x, y, w, h] = compute_bbox_from_mask(&mask, 100, 100).unwrap();

        assert!(within_one(x, 10.0));
        assert!(within_one(y, 10.0));
        assert!(within_one(w, 80.0));
        assert!(within_one(h, 80.0));
    }

    // ---------------------------------------------------------------------
    // mask_to_polygon_string
    // ---------------------------------------------------------------------

    /// One ring serialises as a triple-nested array of coordinate pairs.
    #[test]
    fn test_mask_to_polygon_string() {
        let triangle = vec![(0.1, 0.2), (0.3, 0.4), (0.5, 0.6)];
        let mask = crate::Mask::new(vec![triangle]);

        let serialized = mask_to_polygon_string(&mask);

        assert_eq!(serialized, "[[[0.1,0.2],[0.3,0.4],[0.5,0.6]]]");
    }

    /// Each ring becomes its own inner array, in order.
    #[test]
    fn test_mask_to_polygon_string_multiple_rings() {
        let first = vec![(0.1, 0.1), (0.2, 0.1), (0.15, 0.2)];
        let second = vec![(0.5, 0.5), (0.6, 0.5), (0.55, 0.6)];
        let mask = crate::Mask::new(vec![first, second]);

        let serialized = mask_to_polygon_string(&mask);

        assert_eq!(
            serialized,
            "[[[0.1,0.1],[0.2,0.1],[0.15,0.2]],[[0.5,0.5],[0.6,0.5],[0.55,0.6]]]"
        );
    }

    /// A vertex with a NaN coordinate is dropped rather than emitted as null.
    #[test]
    fn test_mask_to_polygon_string_filters_nan_values() {
        let ring = vec![(0.1, 0.2), (f32::NAN, 0.4), (0.3, 0.4), (0.5, 0.6)];
        let mask = crate::Mask::new(vec![ring]);

        let result = mask_to_polygon_string(&mask);

        assert!(
            !result.contains("null"),
            "NaN values should be filtered out, got: {}",
            result
        );
        assert_eq!(result, "[[[0.1,0.2],[0.3,0.4],[0.5,0.6]]]");
    }

    /// A vertex with a +infinity coordinate is dropped as well.
    #[test]
    fn test_mask_to_polygon_string_filters_infinity() {
        let ring = vec![(0.1, 0.2), (f32::INFINITY, 0.4), (0.3, 0.4), (0.5, 0.6)];
        let mask = crate::Mask::new(vec![ring]);

        let result = mask_to_polygon_string(&mask);

        assert!(
            !result.contains("null"),
            "Infinity values should be filtered out"
        );
        assert_eq!(result, "[[[0.1,0.2],[0.3,0.4],[0.5,0.6]]]");
    }

    /// A ring left with a single valid vertex after filtering is omitted.
    #[test]
    fn test_mask_to_polygon_string_too_few_points_after_filter() {
        let ring = vec![(0.1, 0.2), (f32::NAN, 0.4), (f32::NAN, f32::NAN)];
        let mask = crate::Mask::new(vec![ring]);

        let serialized = mask_to_polygon_string(&mask);

        assert_eq!(serialized, "[]");
    }

    /// A vertex with a -infinity coordinate is dropped as well.
    #[test]
    fn test_mask_to_polygon_string_negative_infinity() {
        let ring = vec![
            (0.1, 0.2),
            (f32::NEG_INFINITY, 0.4),
            (0.3, 0.4),
            (0.5, 0.6),
        ];
        let mask = crate::Mask::new(vec![ring]);

        let serialized = mask_to_polygon_string(&mask);

        assert_eq!(serialized, "[[[0.1,0.2],[0.3,0.4],[0.5,0.6]]]");
    }

    // ---------------------------------------------------------------------
    // Result structs
    // ---------------------------------------------------------------------

    /// The import result stores its three counters as given.
    #[test]
    fn test_coco_import_result() {
        let CocoImportResult {
            total_images,
            skipped,
            imported,
        } = CocoImportResult {
            total_images: 100,
            skipped: 30,
            imported: 70,
        };

        assert_eq!(total_images, 100);
        assert_eq!(skipped, 30);
        assert_eq!(imported, 70);
    }

    /// The update result stores its three counters as given.
    #[test]
    fn test_coco_update_result() {
        let CocoUpdateResult {
            total_images,
            updated,
            not_found,
        } = CocoUpdateResult {
            total_images: 500,
            updated: 450,
            not_found: 50,
        };

        assert_eq!(total_images, 500);
        assert_eq!(updated, 450);
        assert_eq!(not_found, 50);
    }

    // ---------------------------------------------------------------------
    // read_coco_dataset_for_update
    // ---------------------------------------------------------------------

    /// A non-JSON file is rejected with the dedicated error message.
    #[test]
    fn test_read_coco_dataset_for_update_invalid_extension() {
        let err = read_coco_dataset_for_update(Path::new("/tmp/file.txt")).unwrap_err();
        let message = err.to_string();

        assert!(message.contains("COCO update requires a JSON annotation file"));
    }

    /// A JSON path that does not exist is an error.
    #[test]
    fn test_read_coco_dataset_for_update_nonexistent_json() {
        let outcome = read_coco_dataset_for_update(Path::new("/nonexistent/file.json"));
        assert!(outcome.is_err());
    }

    /// A path that does not exist and has no extension is an error.
    #[test]
    fn test_read_coco_dataset_for_update_nonexistent_directory() {
        let outcome = read_coco_dataset_for_update(Path::new("/nonexistent_dir"));
        assert!(outcome.is_err());
    }

    // ---------------------------------------------------------------------
    // build_sample_info_map
    // ---------------------------------------------------------------------

    /// No samples, no entries.
    #[test]
    fn test_build_sample_info_map_empty() {
        let samples: Vec<crate::Sample> = Vec::new();
        assert!(build_sample_info_map(&samples).is_empty());
    }

    /// Complete samples are keyed by image name and keep id, size and group.
    #[test]
    fn test_build_sample_info_map_with_samples() {
        use crate::{Sample, SampleID};

        let make_sample = |name: &str, id: SampleID, width, height, group: Option<&str>| {
            let mut sample = Sample::default();
            sample.image_name = Some(name.to_string());
            sample.id = Some(id);
            sample.width = Some(width);
            sample.height = Some(height);
            sample.group = group.map(str::to_string);
            sample
        };

        let samples = vec![
            make_sample("sample1", SampleID::from(1), 640, 480, Some("train")),
            make_sample("sample2", SampleID::from(2), 1280, 720, None),
        ];
        let map = build_sample_info_map(&samples);

        assert_eq!(map.len(), 2);
        assert!(map.contains_key("sample1"));
        assert!(map.contains_key("sample2"));

        let first = map.get("sample1").unwrap();
        assert_eq!(first.0, SampleID::from(1));
        assert_eq!(first.1, 640);
        assert_eq!(first.2, 480);
        assert_eq!(first.3.as_deref(), Some("train"));

        let second = map.get("sample2").unwrap();
        assert_eq!(second.0, SampleID::from(2));
        assert_eq!(second.1, 1280);
        assert_eq!(second.2, 720);
        assert!(second.3.is_none());
    }

    /// Samples lacking an id, a name, or dimensions are left out of the map.
    #[test]
    fn test_build_sample_info_map_skips_incomplete_samples() {
        use crate::Sample;

        // Has a name and dimensions, but no id.
        let mut missing_id = Sample::default();
        missing_id.image_name = Some("no_id".to_string());
        missing_id.width = Some(640);
        missing_id.height = Some(480);

        // Has an id and dimensions, but no name.
        let mut missing_name = Sample::default();
        missing_name.id = Some(crate::SampleID::from(1));
        missing_name.width = Some(640);
        missing_name.height = Some(480);

        // Has a name and an id, but no dimensions.
        let mut missing_dims = Sample::default();
        missing_dims.image_name = Some("no_dims".to_string());
        missing_dims.id = Some(crate::SampleID::from(2));

        let samples = vec![missing_id, missing_name, missing_dims];
        let map = build_sample_info_map(&samples);

        assert!(map.is_empty());
    }

    // ---------------------------------------------------------------------
    // Option structs: clone and custom values
    // ---------------------------------------------------------------------

    /// A clone of the import options carries the same values.
    #[test]
    fn test_coco_import_options_clone() {
        let original = CocoImportOptions::default();
        let copy = original.clone();

        assert_eq!(original.batch_size, copy.batch_size);
        assert_eq!(original.concurrency, copy.concurrency);
        assert_eq!(original.include_masks, copy.include_masks);
    }

    /// Every import option can be overridden and is stored as given.
    #[test]
    fn test_coco_import_options_custom() {
        let custom = CocoImportOptions {
            include_masks: false,
            include_images: false,
            group: Some("test".to_string()),
            batch_size: 50,
            concurrency: 32,
            resume: false,
        };

        assert!(!custom.include_masks);
        assert!(!custom.include_images);
        assert_eq!(custom.group.as_deref(), Some("test"));
        assert_eq!(custom.batch_size, 50);
        assert_eq!(custom.concurrency, 32);
        assert!(!custom.resume);
    }

    /// Every update option can be overridden and is stored as given.
    #[test]
    fn test_coco_update_options_custom() {
        let custom = CocoUpdateOptions {
            include_masks: false,
            group: Some("val".to_string()),
            batch_size: 25,
            concurrency: 16,
        };

        assert!(!custom.include_masks);
        assert_eq!(custom.group.as_deref(), Some("val"));
        assert_eq!(custom.batch_size, 25);
        assert_eq!(custom.concurrency, 16);
    }

    // ---------------------------------------------------------------------
    // extract_sample_name
    // ---------------------------------------------------------------------

    /// The extension is removed from a bare file name.
    #[test]
    fn test_extract_sample_name_simple() {
        let name = extract_sample_name("image.jpg");
        assert_eq!(name, "image");
    }

    /// Leading directory components are removed along with the extension.
    #[test]
    fn test_extract_sample_name_with_path() {
        let name = extract_sample_name("train2017/000001.jpg");
        assert_eq!(name, "000001");
    }

    /// A name without an extension is returned unchanged.
    #[test]
    fn test_extract_sample_name_no_extension() {
        let name = extract_sample_name("image");
        assert_eq!(name, "image");
    }

    /// Only the last extension is removed when the name has several dots.
    #[test]
    fn test_extract_sample_name_multiple_dots() {
        let name = extract_sample_name("image.v2.final.jpg");
        assert_eq!(name, "image.v2.final");
    }
}