1use super::{
14 convert::{
15 box2d_to_coco_bbox, coco_bbox_to_box2d, coco_segmentation_to_polygon,
16 polygon_to_coco_polygon,
17 },
18 reader::{CocoReadOptions, CocoReader, read_coco_directory},
19 types::{CocoDataset, CocoImage, CocoIndex, CocoInfo, CocoSegmentation},
20 writer::{CocoDatasetBuilder, CocoWriteOptions, CocoWriter},
21};
22use crate::{
23 Annotation, AnnotationSetID, Client, DatasetID, Error, FileType, Progress, Sample, SampleFile,
24};
25use std::{
26 collections::HashSet,
27 path::{Path, PathBuf},
28};
29use tokio::sync::mpsc::Sender;
30
/// Summary returned by [`import_coco_to_studio`].
#[derive(Debug, Clone)]
pub struct CocoImportResult {
    /// Number of images present in the COCO dataset that was read.
    pub total_images: usize,
    /// Images that passed the group filter but were not uploaded because a
    /// sample with the same name already exists (always 0 when not resuming).
    pub skipped: usize,
    /// Images uploaded by this call.
    pub imported: usize,
}
41
/// Settings that control [`import_coco_to_studio`].
#[derive(Debug, Clone)]
pub struct CocoImportOptions {
    /// Convert COCO segmentations into polygons; when `false` annotations are
    /// uploaded without a polygon.
    pub include_masks: bool,
    /// Attach the image file found on disk to each sample and check up front
    /// that the images are available.
    pub include_images: bool,
    /// Only import images whose group (inferred from the folder part of the
    /// file name) equals this value; `None` imports every group.
    pub group: Option<String>,
    /// Number of images converted and uploaded per batch.
    pub batch_size: usize,
    /// Concurrency value handed to the client when populating samples.
    pub concurrency: usize,
    /// Skip images whose sample name already exists in the target dataset.
    pub resume: bool,
}

impl Default for CocoImportOptions {
    /// Masks and images on, no group filter, batches of 100, concurrency 64,
    /// resume enabled.
    fn default() -> Self {
        CocoImportOptions {
            include_masks: true,
            include_images: true,
            group: None,
            batch_size: 100,
            concurrency: 64,
            resume: true,
        }
    }
}
72
73#[derive(Debug, Clone)]
75pub struct CocoExportOptions {
76 pub groups: Vec<String>,
78 pub include_masks: bool,
80 pub include_images: bool,
82 pub output_zip: bool,
84 pub pretty_json: bool,
86 pub info: Option<CocoInfo>,
88}
89
90impl Default for CocoExportOptions {
91 fn default() -> Self {
92 Self {
93 groups: vec![],
94 include_masks: true,
95 include_images: false,
96 output_zip: false,
97 pretty_json: false,
98 info: None,
99 }
100 }
101}
102
103pub async fn import_coco_to_studio(
144 client: &Client,
145 coco_path: impl AsRef<Path>,
146 dataset_id: DatasetID,
147 annotation_set_id: AnnotationSetID,
148 options: &CocoImportOptions,
149 progress: Option<Sender<Progress>>,
150) -> Result<CocoImportResult, Error> {
151 let coco_path = coco_path.as_ref();
152
153 let (dataset, images_dir) = read_coco_from_path(coco_path)?;
155
156 let total_images = dataset.images.len();
157 if total_images == 0 {
158 return Err(Error::MissingAnnotations(
159 "No images found in COCO dataset".to_string(),
160 ));
161 }
162
163 if options.include_images {
165 validate_images_extracted(&dataset, &images_dir)?;
166 }
167
168 let existing_names = fetch_existing_sample_names(client, &dataset_id, options.resume).await?;
170
171 let group_filter = options.group.as_deref();
173 let (images_to_import, skipped, filtered_by_group) =
174 filter_images_for_import(&dataset.images, group_filter, &existing_names);
175
176 log_import_filter_info(group_filter, filtered_by_group, total_images);
178
179 let to_import = images_to_import.len();
180
181 if to_import == 0 {
183 log_nothing_to_import(skipped);
184 return Ok(CocoImportResult {
185 total_images,
186 skipped,
187 imported: 0,
188 });
189 }
190
191 if skipped > 0 {
192 log::info!(
193 "Resuming import: {} of {} images already imported, {} remaining",
194 skipped,
195 total_images,
196 to_import
197 );
198 }
199
200 let index = CocoIndex::from_dataset(&dataset);
202 send_progress(&progress, 0, to_import).await;
203
204 let upload_ctx = UploadContext {
205 client,
206 dataset_id: &dataset_id,
207 annotation_set_id: &annotation_set_id,
208 options,
209 progress: &progress,
210 };
211 let imported =
212 upload_images_in_batches(&upload_ctx, &images_to_import, &index, &images_dir).await?;
213
214 Ok(CocoImportResult {
215 total_images,
216 skipped,
217 imported,
218 })
219}
220
221async fn fetch_existing_sample_names(
223 client: &Client,
224 dataset_id: &DatasetID,
225 resume: bool,
226) -> Result<HashSet<String>, Error> {
227 if !resume {
228 return Ok(HashSet::new());
229 }
230
231 log::info!("Checking for existing samples in dataset {}...", dataset_id);
232 let names = client.sample_names(*dataset_id, &[], None).await?;
233 log::info!("Found {} existing samples in dataset", names.len());
234
235 if !names.is_empty() {
236 let samples: Vec<_> = names.iter().take(3).collect();
237 log::debug!("Sample names from server: {:?}", samples);
238 }
239
240 Ok(names)
241}
242
243fn log_import_filter_info(group_filter: Option<&str>, filtered_by_group: usize, total: usize) {
245 if filtered_by_group > 0 {
246 log::info!(
247 "Group filter '{}': {} images excluded, {} matching",
248 group_filter.unwrap_or(""),
249 filtered_by_group,
250 total - filtered_by_group
251 );
252 }
253}
254
255fn log_nothing_to_import(skipped: usize) {
257 if skipped > 0 {
258 log::info!(
259 "All {} matching images already imported, nothing to do",
260 skipped
261 );
262 } else {
263 log::info!("No images to import");
264 }
265}
266
267async fn send_progress(progress: &Option<Sender<Progress>>, current: usize, total: usize) {
269 if let Some(p) = progress {
270 let _ = p
271 .send(Progress {
272 current,
273 total,
274 status: None,
275 })
276 .await;
277 }
278}
279
280struct UploadContext<'a> {
282 client: &'a Client,
283 dataset_id: &'a DatasetID,
284 annotation_set_id: &'a AnnotationSetID,
285 options: &'a CocoImportOptions,
286 progress: &'a Option<Sender<Progress>>,
287}
288
289async fn upload_images_in_batches<'a>(
291 ctx: &UploadContext<'a>,
292 images: &[&CocoImage],
293 index: &CocoIndex,
294 images_dir: &Path,
295) -> Result<usize, Error> {
296 let mut imported = 0;
297 let to_import = images.len();
298
299 for batch in images.chunks(ctx.options.batch_size) {
300 let samples = convert_batch_to_samples(batch, index, images_dir, ctx.options)?;
301
302 ctx.client
303 .populate_samples_with_concurrency(
304 *ctx.dataset_id,
305 Some(*ctx.annotation_set_id),
306 samples,
307 None,
308 Some(ctx.options.concurrency),
309 )
310 .await?;
311
312 imported += batch.len();
313 send_progress(ctx.progress, imported, to_import).await;
314 }
315
316 Ok(imported)
317}
318
319fn convert_batch_to_samples(
321 batch: &[&CocoImage],
322 index: &CocoIndex,
323 images_dir: &Path,
324 options: &CocoImportOptions,
325) -> Result<Vec<Sample>, Error> {
326 let mut samples = Vec::with_capacity(batch.len());
327
328 for image in batch {
329 let image_group = super::reader::infer_group_from_folder(&image.file_name);
330 let sample = convert_coco_image_to_sample(
331 image,
332 index,
333 images_dir,
334 options.include_masks,
335 options.include_images,
336 image_group.as_deref(),
337 )?;
338 samples.push(sample);
339 }
340
341 Ok(samples)
342}
343
344fn validate_images_extracted(dataset: &CocoDataset, images_dir: &Path) -> Result<(), Error> {
346 let sample_size = std::cmp::min(5, dataset.images.len());
348 let mut missing = Vec::new();
349
350 for image in dataset.images.iter().take(sample_size) {
351 if find_image_file(images_dir, &image.file_name).is_none() {
352 missing.push(image.file_name.clone());
353 }
354 }
355
356 if !missing.is_empty() {
357 let examples: Vec<_> = missing.iter().take(3).cloned().collect();
358 return Err(Error::MissingImages(format!(
359 "Images must be extracted before import.\n\
360 Cannot find: {}\n\n\
361 Searched in: {}\n\
362 Expected subdirectories: train2017/, val2017/, images/\n\n\
363 Please extract your COCO image archives first:\n\
364 $ cd {} && unzip train2017.zip && unzip val2017.zip",
365 examples.join(", "),
366 images_dir.display(),
367 images_dir.display()
368 )));
369 }
370
371 Ok(())
372}
373
/// Locates `file_name` under `base_dir`.
///
/// The file is looked up directly under `base_dir` first, then inside the
/// `images`, `train2017`, `val2017`, `test2017`, `train2014` and `val2014`
/// subdirectories, in that order. Returns the first path that exists.
fn find_image_file(base_dir: &Path, file_name: &str) -> Option<PathBuf> {
    const SEARCH_SUBDIRS: [&str; 6] = [
        "images",
        "train2017",
        "val2017",
        "test2017",
        "train2014",
        "val2014",
    ];

    let direct = base_dir.join(file_name);
    if direct.exists() {
        return Some(direct);
    }

    SEARCH_SUBDIRS
        .iter()
        .map(|subdir| base_dir.join(subdir).join(file_name))
        .find(|candidate| candidate.exists())
}
387
/// Guesses a group name from an annotation file's name.
///
/// * `instances_<group><digits>` yields `<group>` (trailing numeric
///   characters are stripped), e.g. `instances_val2017.json` -> `val`.
/// * Otherwise a stem starting with `train`, `val` or `test` yields that
///   prefix.
/// * Anything else (or a non-UTF-8 / missing stem) yields `None`.
fn infer_group_from_filename(path: &Path) -> Option<String> {
    let stem = path.file_stem().and_then(|s| s.to_str())?;

    let from_instances = stem
        .strip_prefix("instances_")
        .map(|rest| rest.trim_end_matches(char::is_numeric))
        .filter(|group| !group.is_empty());
    if let Some(group) = from_instances {
        return Some(group.to_owned());
    }

    // NOTE: "validation" can never be returned here because any stem that
    // starts with it also starts with "val", which is tested first.
    ["train", "val", "test", "validation"]
        .into_iter()
        .find(|prefix| stem.starts_with(*prefix))
        .map(String::from)
}
416
417fn read_coco_from_path(coco_path: &Path) -> Result<(CocoDataset, PathBuf), Error> {
421 if coco_path.is_dir() {
422 let datasets = read_coco_directory(coco_path, &CocoReadOptions::default())?;
424 log::info!("Found {} annotation files in directory", datasets.len());
425
426 let mut merged = CocoDataset::default();
428 for (mut ds, group) in datasets {
429 log::info!(
430 " - {} group: {} images, {} annotations",
431 group,
432 ds.images.len(),
433 ds.annotations.len()
434 );
435 for image in &mut ds.images {
438 if !image.file_name.contains('/') {
439 image.file_name = format!("{}2017/{}", group, image.file_name);
440 }
441 }
442 merge_coco_datasets(&mut merged, ds);
443 }
444 Ok((merged, coco_path.to_path_buf()))
445 } else if coco_path.extension().is_some_and(|e| e == "json") {
446 let reader = CocoReader::new();
448 let dataset = reader.read_json(coco_path)?;
449 let parent = coco_path
450 .parent()
451 .and_then(|p| p.parent()) .unwrap_or(Path::new("."));
453 Ok((dataset, parent.to_path_buf()))
454 } else {
455 Err(Error::InvalidParameters(
456 "COCO import requires a JSON annotation file or directory. \
457 ZIP archives must be extracted first."
458 .to_string(),
459 ))
460 }
461}
462
463fn filter_images_for_import<'a>(
468 images: &'a [CocoImage],
469 group_filter: Option<&str>,
470 existing_names: &HashSet<String>,
471) -> (Vec<&'a CocoImage>, usize, usize) {
472 let total = images.len();
473
474 let images_to_import: Vec<_> = images
476 .iter()
477 .filter(|img| {
478 if let Some(filter) = group_filter {
480 let inferred = super::reader::infer_group_from_folder(&img.file_name);
481 if inferred.as_deref() != Some(filter) {
482 return false;
483 }
484 }
485 let sample_name = extract_sample_name(&img.file_name);
487 !existing_names.contains(&sample_name)
488 })
489 .collect();
490
491 let filtered_by_group = if group_filter.is_some() {
493 images
494 .iter()
495 .filter(|img| {
496 let inferred = super::reader::infer_group_from_folder(&img.file_name);
497 inferred.as_deref() != group_filter
498 })
499 .count()
500 } else {
501 0
502 };
503
504 let skipped = total - filtered_by_group - images_to_import.len();
505 (images_to_import, skipped, filtered_by_group)
506}
507
/// Derives a sample name from an image file name: the final path component
/// without its last extension. Falls back to the input unchanged when no
/// UTF-8 file stem can be extracted.
fn extract_sample_name(file_name: &str) -> String {
    match Path::new(file_name).file_stem().and_then(|stem| stem.to_str()) {
        Some(stem) => stem.to_owned(),
        None => file_name.to_owned(),
    }
}
516
517fn merge_coco_datasets(target: &mut CocoDataset, source: CocoDataset) {
522 let existing_image_ids: HashSet<_> = target.images.iter().map(|i| i.id).collect();
524 for image in source.images {
525 if !existing_image_ids.contains(&image.id) {
526 target.images.push(image);
527 }
528 }
529
530 let existing_cat_ids: HashSet<_> = target.categories.iter().map(|c| c.id).collect();
532 for cat in source.categories {
533 if !existing_cat_ids.contains(&cat.id) {
534 target.categories.push(cat);
535 }
536 }
537
538 target.annotations.extend(source.annotations);
540
541 let existing_license_ids: HashSet<_> = target.licenses.iter().map(|l| l.id).collect();
543 for license in source.licenses {
544 if !existing_license_ids.contains(&license.id) {
545 target.licenses.push(license);
546 }
547 }
548
549 if target.info.description.is_none() && source.info.description.is_some() {
551 target.info = source.info;
552 }
553}
554
555fn convert_coco_image_to_sample(
565 image: &CocoImage,
566 index: &CocoIndex,
567 images_dir: &Path,
568 include_masks: bool,
569 include_images: bool,
570 group: Option<&str>,
571) -> Result<Sample, Error> {
572 let sample_name = Path::new(&image.file_name)
573 .file_stem()
574 .and_then(|s| s.to_str())
575 .map(String::from)
576 .unwrap_or_else(|| image.file_name.clone());
577
578 let annotations = index
580 .annotations_for_image(image.id)
581 .iter()
582 .filter_map(|coco_ann| {
583 let label = index.label_name(coco_ann.category_id)?;
584 let label_index = index.label_index(coco_ann.category_id);
585
586 let box2d = coco_bbox_to_box2d(&coco_ann.bbox, image.width, image.height);
587
588 let polygon = if include_masks {
589 coco_ann.segmentation.as_ref().and_then(|seg| {
590 coco_segmentation_to_polygon(seg, image.width, image.height).ok()
591 })
592 } else {
593 None
594 };
595
596 {
597 let mut ann = Annotation::new();
598 ann.set_name(Some(sample_name.clone()));
599 ann.set_label(Some(label.to_string()));
600 ann.set_label_index(label_index);
601 ann.set_box2d(Some(box2d));
602 ann.set_polygon(polygon);
603 ann.set_group(group.map(String::from));
604 ann.set_iscrowd(Some(coco_ann.iscrowd != 0));
605 ann.set_category_frequency(index.frequency(coco_ann.category_id).map(String::from));
606 Some(ann)
607 }
608 })
609 .collect();
610
611 let neg_label_indices = image.neg_category_ids.as_ref().map(|ids| {
613 ids.iter()
614 .filter_map(|&id| index.label_index(id).map(|idx| idx as u32))
615 .collect::<Vec<u32>>()
616 });
617 let not_exhaustive_label_indices = image.not_exhaustive_category_ids.as_ref().map(|ids| {
618 ids.iter()
619 .filter_map(|&id| index.label_index(id).map(|idx| idx as u32))
620 .collect::<Vec<u32>>()
621 });
622
623 let mut files = Vec::new();
625 if include_images && let Some(image_path) = find_image_file(images_dir, &image.file_name) {
626 files.push(SampleFile::with_filename(
627 FileType::Image.to_string(),
628 image_path.to_string_lossy().to_string(),
629 ));
630 }
631
632 Ok(Sample {
633 image_name: Some(sample_name),
634 width: Some(image.width),
635 height: Some(image.height),
636 group: group.map(String::from),
637 neg_label_indices,
638 not_exhaustive_label_indices,
639 files,
640 annotations,
641 ..Default::default()
642 })
643}
644
645pub async fn export_studio_to_coco(
660 client: &Client,
661 dataset_id: DatasetID,
662 annotation_set_id: AnnotationSetID,
663 output_path: impl AsRef<Path>,
664 options: &CocoExportOptions,
665 progress: Option<Sender<Progress>>,
666) -> Result<usize, Error> {
667 let output_path = output_path.as_ref();
668
669 let groups: Vec<String> = options.groups.clone();
671 let annotation_types = [crate::AnnotationType::Box2d, crate::AnnotationType::Polygon];
672
673 let all_samples = client
675 .samples(
676 dataset_id,
677 Some(annotation_set_id),
678 &annotation_types,
679 &groups,
680 &[],
681 progress.clone(),
682 )
683 .await?;
684
685 let mut builder = CocoDatasetBuilder::new();
687
688 if let Some(info) = &options.info {
689 builder = builder.info(info.clone());
690 }
691
692 for sample in &all_samples {
693 let image_name = sample.image_name.as_deref().unwrap_or("unknown");
694 let width = sample.width.unwrap_or(0);
695 let height = sample.height.unwrap_or(0);
696
697 let file_name = if image_name.contains('.') {
699 image_name.to_string()
700 } else {
701 format!("{}.jpg", image_name)
702 };
703 let image_id = builder.add_image(&file_name, width, height);
704
705 for ann in &sample.annotations {
706 let bbox = if let Some(box2d) = ann.box2d() {
708 Some(box2d_to_coco_bbox(box2d, width, height))
709 } else if let Some(polygon) = ann.polygon() {
710 compute_bbox_from_polygon(polygon, width, height)
711 } else {
712 None
713 };
714
715 if let Some(bbox) = bbox {
716 let label = ann.label().map(|s| s.as_str()).unwrap_or("unknown");
717 let category_id = builder.add_category(label, None);
718
719 let segmentation = if options.include_masks {
720 ann.polygon().map(|polygon| {
721 let coco_poly = polygon_to_coco_polygon(polygon, width, height);
722 CocoSegmentation::Polygon(coco_poly)
723 })
724 } else {
725 None
726 };
727
728 builder.add_annotation(image_id, category_id, bbox, segmentation);
729 }
730 }
731 }
732
733 let dataset = builder.build();
734 let annotation_count = dataset.annotations.len();
735
736 let writer = CocoWriter::with_options(CocoWriteOptions {
738 compress: true,
739 pretty: options.pretty_json,
740 });
741
742 if options.output_zip {
743 let images = if options.include_images {
745 download_images(client, &all_samples, progress.clone()).await?
746 } else {
747 vec![]
748 };
749
750 writer.write_zip(&dataset, images.into_iter(), output_path)?;
751 } else {
752 writer.write_json(&dataset, output_path)?;
753 }
754
755 Ok(annotation_count)
756}
757
758async fn download_images(
763 client: &Client,
764 samples: &[Sample],
765 progress: Option<Sender<Progress>>,
766) -> Result<Vec<(String, Vec<u8>)>, Error> {
767 let mut result = Vec::with_capacity(samples.len());
768 let total = samples.len();
769
770 for (i, sample) in samples.iter().enumerate() {
771 let image_url = sample.files.iter().find_map(|f| {
773 if f.file_type() == "image" {
774 f.url()
775 } else {
776 None
777 }
778 });
779
780 if let Some(url) = image_url {
781 match client.download(url).await {
783 Ok(data) => {
784 let name = sample.image_name.as_deref().unwrap_or("unknown");
786 let filename = if name.contains('.') {
787 format!("images/{}", name)
788 } else {
789 format!("images/{}.jpg", name)
790 };
791 result.push((filename, data));
792 }
793 Err(e) => {
794 log::warn!(
796 "Failed to download image for sample {:?}: {}",
797 sample.image_name,
798 e
799 );
800 }
801 }
802 }
803
804 if let Some(ref p) = progress {
806 let _ = p
807 .send(Progress {
808 current: i + 1,
809 total,
810 status: None,
811 })
812 .await;
813 }
814 }
815
816 Ok(result)
817}
818
/// Settings for [`verify_coco_import`].
#[derive(Debug, Clone)]
pub struct CocoVerifyOptions {
    /// Whether masks take part in verification.
    // NOTE(review): the code consuming this flag is outside this excerpt;
    // confirm the exact semantics against `verify_coco_import`.
    pub verify_masks: bool,
    /// Group to verify; when `None`, a group inferred from the annotation
    /// file name (if any) is used instead.
    pub group: Option<String>,
}

impl Default for CocoVerifyOptions {
    /// Mask verification on, no explicit group.
    fn default() -> Self {
        CocoVerifyOptions {
            verify_masks: true,
            group: None,
        }
    }
}
836
/// Summary returned by [`update_coco_annotations`].
#[derive(Debug, Clone)]
pub struct CocoUpdateResult {
    /// Number of images present in the COCO dataset that was read.
    pub total_images: usize,
    /// Number of COCO images matched to an existing Studio sample and
    /// therefore updated.
    pub updated: usize,
    /// Number of COCO images with no matching Studio sample.
    pub not_found: usize,
}
847
/// Settings that control [`update_coco_annotations`].
#[derive(Debug, Clone)]
pub struct CocoUpdateOptions {
    /// Convert segmentations into polygons and also replace existing `seg`
    /// annotations; when `false` only `box` annotations are replaced.
    pub include_masks: bool,
    /// Group restriction.
    // NOTE(review): `update_coco_annotations` does not read this field in the
    // code visible here — confirm whether it is used elsewhere.
    pub group: Option<String>,
    /// Number of sample ids / annotations sent per bulk request.
    pub batch_size: usize,
    /// Requested concurrency.
    // NOTE(review): not read by `update_coco_annotations` in the code visible
    // here — confirm whether it is used elsewhere.
    pub concurrency: usize,
}

impl Default for CocoUpdateOptions {
    /// Masks on, no group, batches of 100, concurrency 64.
    fn default() -> Self {
        CocoUpdateOptions {
            include_masks: true,
            group: None,
            batch_size: 100,
            concurrency: 64,
        }
    }
}
871
872fn read_coco_dataset_for_update(coco_path: &Path) -> Result<CocoDataset, Error> {
874 if coco_path.is_dir() {
875 let datasets = read_coco_directory(coco_path, &CocoReadOptions::default())?;
877 log::info!("Found {} annotation files in directory", datasets.len());
878
879 let mut merged = CocoDataset::default();
881 for (mut ds, group) in datasets {
882 log::info!(
883 " - {} group: {} images, {} annotations",
884 group,
885 ds.images.len(),
886 ds.annotations.len()
887 );
888 for image in &mut ds.images {
890 if !image.file_name.contains('/') {
891 image.file_name = format!("{}2017/{}", group, image.file_name);
892 }
893 }
894 merge_coco_datasets(&mut merged, ds);
895 }
896 Ok(merged)
897 } else if coco_path.extension().is_some_and(|e| e == "json") {
898 let reader = CocoReader::new();
899 reader.read_json(coco_path)
900 } else {
901 Err(Error::InvalidParameters(
902 "COCO update requires a JSON annotation file or directory.".to_string(),
903 ))
904 }
905}
906
907fn build_sample_info_map(
910 samples: &[Sample],
911) -> std::collections::HashMap<String, (crate::SampleID, u32, u32, Option<String>)> {
912 use std::collections::HashMap;
913 let mut sample_info = HashMap::new();
914 for sample in samples {
915 if let (Some(name), Some(id), Some(w), Some(h)) =
916 (sample.name(), sample.id(), sample.width, sample.height)
917 {
918 sample_info.insert(name, (id, w, h, sample.group.clone()));
919 }
920 }
921 sample_info
922}
923
924async fn ensure_labels_exist(
926 client: &Client,
927 dataset_id: &DatasetID,
928 categories: &[crate::coco::CocoCategory],
929) -> Result<std::collections::HashMap<String, u64>, Error> {
930 use std::collections::{HashMap, HashSet};
931
932 let existing_labels = client.labels(*dataset_id).await?;
934 let existing_label_names: HashSet<String> = existing_labels
935 .iter()
936 .map(|l| l.name().to_string())
937 .collect();
938
939 let missing_labels: Vec<String> = categories
941 .iter()
942 .filter(|c| !existing_label_names.contains(&c.name))
943 .map(|c| c.name.clone())
944 .collect();
945
946 if !missing_labels.is_empty() {
948 log::info!(
949 "Creating {} missing labels in Studio...",
950 missing_labels.len()
951 );
952 for label_name in &missing_labels {
953 client.add_label(*dataset_id, label_name).await?;
954 }
955 }
956
957 let labels = client.labels(*dataset_id).await?;
959 let label_map: HashMap<String, u64> = labels
960 .iter()
961 .map(|l| (l.name().to_string(), l.id()))
962 .collect();
963
964 log::info!(
965 "Label map has {} entries for {} COCO categories",
966 label_map.len(),
967 categories.len()
968 );
969
970 Ok(label_map)
971}
972
973fn convert_coco_annotation_to_server(
977 coco_ann: &super::types::CocoAnnotation,
978 coco_index: &CocoIndex,
979 label_map: &std::collections::HashMap<String, u64>,
980 image_id: u64,
981 annotation_set_id: u64,
982 dims: (u32, u32),
983 include_masks: bool,
984) -> (crate::api::ServerAnnotation, bool) {
985 let (width, height) = dims;
986
987 let category_name = coco_index
989 .categories
990 .get(&coco_ann.category_id)
991 .map(|c| c.name.as_str())
992 .unwrap_or("unknown");
993
994 let label_id = label_map.get(category_name).copied();
995 let missing_label = label_id.is_none();
996
997 let box2d = coco_bbox_to_box2d(&coco_ann.bbox, width, height);
999
1000 let polygon = if include_masks {
1002 coco_ann
1003 .segmentation
1004 .as_ref()
1005 .and_then(|seg| coco_segmentation_to_polygon(seg, width, height).ok())
1006 .map(|p| polygon_to_polygon_string(&p))
1007 .unwrap_or_default()
1008 } else {
1009 String::new()
1010 };
1011
1012 let annotation_type = if polygon.is_empty() { "box" } else { "seg" }.to_string();
1013
1014 let server_ann = crate::api::ServerAnnotation {
1015 label_id,
1016 label_index: None,
1017 label_name: Some(category_name.to_string()),
1018 annotation_type,
1019 x: box2d.left() as f64,
1020 y: box2d.top() as f64,
1021 w: box2d.width() as f64,
1022 h: box2d.height() as f64,
1023 score: 1.0,
1024 polygon,
1025 image_id,
1026 annotation_set_id,
1027 object_reference: None,
1028 };
1029
1030 (server_ann, missing_label)
1031}
1032
1033fn process_image_for_update(
1036 coco_image: &CocoImage,
1037 sample_info: &std::collections::HashMap<String, (crate::SampleID, u32, u32, Option<String>)>,
1038 coco_index: &CocoIndex,
1039 label_map: &std::collections::HashMap<String, u64>,
1040 annotation_set_id: u64,
1041 include_masks: bool,
1042) -> Option<(
1043 crate::SampleID,
1044 Vec<crate::api::ServerAnnotation>,
1045 Option<String>,
1046 usize,
1047)> {
1048 let sample_name = extract_sample_name(&coco_image.file_name);
1049 let expected_group = super::reader::infer_group_from_folder(&coco_image.file_name);
1050
1051 let (sample_id, width, height, current_group) = sample_info.get(&sample_name)?;
1052 let (sample_id, width, height) = (*sample_id, *width, *height);
1053 let image_id: u64 = sample_id.into();
1054
1055 let group_update = expected_group.as_ref().and_then(|expected| {
1057 if Some(expected) != current_group.as_ref() {
1058 Some(expected.clone())
1059 } else {
1060 None
1061 }
1062 });
1063
1064 let mut annotations = Vec::new();
1066 let mut missing_label_count = 0;
1067
1068 for coco_ann in coco_index.annotations_for_image(coco_image.id) {
1069 let (server_ann, missing) = convert_coco_annotation_to_server(
1070 coco_ann,
1071 coco_index,
1072 label_map,
1073 image_id,
1074 annotation_set_id,
1075 (width, height),
1076 include_masks,
1077 );
1078 if missing {
1079 missing_label_count += 1;
1080 }
1081 annotations.push(server_ann);
1082 }
1083
1084 Some((sample_id, annotations, group_update, missing_label_count))
1085}
1086
1087async fn update_sample_groups(
1089 client: &Client,
1090 dataset_id: &DatasetID,
1091 samples_needing_group_update: &[(crate::SampleID, String)],
1092) -> usize {
1093 use std::collections::{HashMap, HashSet};
1094
1095 if samples_needing_group_update.is_empty() {
1096 return 0;
1097 }
1098
1099 log::info!(
1100 "Updating groups for {} samples...",
1101 samples_needing_group_update.len()
1102 );
1103
1104 let unique_groups: HashSet<String> = samples_needing_group_update
1106 .iter()
1107 .map(|(_, group)| group.clone())
1108 .collect();
1109
1110 let mut group_id_map: HashMap<String, u64> = HashMap::new();
1111 for group_name in unique_groups {
1112 match client.get_or_create_group(*dataset_id, &group_name).await {
1113 Ok(group_id) => {
1114 group_id_map.insert(group_name, group_id);
1115 }
1116 Err(e) => {
1117 log::warn!("Failed to get/create group '{}': {}", group_name, e);
1118 }
1119 }
1120 }
1121
1122 let mut updated_count = 0;
1124 let mut failed_count = 0;
1125 for (sample_id, group_name) in samples_needing_group_update {
1126 if let Some(&group_id) = group_id_map.get(group_name) {
1127 match client.set_sample_group_id(*sample_id, group_id).await {
1128 Ok(_) => {
1129 updated_count += 1;
1130 if updated_count % 1000 == 0 {
1131 log::debug!("Updated groups for {} samples so far", updated_count);
1132 }
1133 }
1134 Err(e) => {
1135 failed_count += 1;
1136 if failed_count <= 5 {
1137 log::warn!("Failed to update group for sample {:?}: {}", sample_id, e);
1138 }
1139 }
1140 }
1141 }
1142 }
1143
1144 if failed_count > 5 {
1145 log::warn!("... and {} more group update failures", failed_count - 5);
1146 }
1147 log::info!(
1148 "Updated groups for {} samples ({} failed)",
1149 updated_count,
1150 failed_count
1151 );
1152
1153 updated_count
1154}
1155
1156pub async fn update_coco_annotations(
1176 client: &Client,
1177 coco_path: impl AsRef<Path>,
1178 dataset_id: DatasetID,
1179 annotation_set_id: AnnotationSetID,
1180 options: &CocoUpdateOptions,
1181 progress: Option<Sender<Progress>>,
1182) -> Result<CocoUpdateResult, Error> {
1183 use crate::{SampleID, api::ServerAnnotation};
1184
1185 let coco_path = coco_path.as_ref();
1186
1187 let dataset = read_coco_dataset_for_update(coco_path)?;
1189 let total_images = dataset.images.len();
1190
1191 if total_images == 0 {
1192 return Err(Error::MissingAnnotations(
1193 "No images found in COCO dataset".to_string(),
1194 ));
1195 }
1196
1197 log::info!(
1198 "COCO dataset: {} images, {} annotations, {} categories",
1199 total_images,
1200 dataset.annotations.len(),
1201 dataset.categories.len()
1202 );
1203
1204 log::info!("Fetching existing samples from Studio...");
1206 let existing_samples = client
1207 .samples(
1208 dataset_id,
1209 Some(annotation_set_id),
1210 &[],
1211 &[],
1212 &[],
1213 progress.clone(),
1214 )
1215 .await?;
1216
1217 let sample_info = build_sample_info_map(&existing_samples);
1218 log::info!(
1219 "Found {} existing samples in Studio with IDs and dimensions",
1220 sample_info.len()
1221 );
1222
1223 let coco_index = CocoIndex::from_dataset(&dataset);
1225
1226 let label_map = ensure_labels_exist(client, &dataset_id, &dataset.categories).await?;
1228
1229 let annotation_set_id_u64: u64 = annotation_set_id.into();
1231 let mut sample_ids_to_update: Vec<SampleID> = Vec::new();
1232 let mut server_annotations: Vec<ServerAnnotation> = Vec::new();
1233 let mut samples_needing_group_update: Vec<(SampleID, String)> = Vec::new();
1234 let mut not_found = 0;
1235 let mut missing_label_count = 0;
1236
1237 for coco_image in &dataset.images {
1238 match process_image_for_update(
1239 coco_image,
1240 &sample_info,
1241 &coco_index,
1242 &label_map,
1243 annotation_set_id_u64,
1244 options.include_masks,
1245 ) {
1246 Some((sample_id, annotations, group_update, missing_labels)) => {
1247 sample_ids_to_update.push(sample_id);
1248 server_annotations.extend(annotations);
1249 missing_label_count += missing_labels;
1250 if let Some(group) = group_update {
1251 samples_needing_group_update.push((sample_id, group));
1252 }
1253 }
1254 None => {
1255 not_found += 1;
1256 log::debug!(
1257 "Sample not found in Studio: {}",
1258 extract_sample_name(&coco_image.file_name)
1259 );
1260 }
1261 }
1262 }
1263
1264 let to_update = sample_ids_to_update.len();
1265 log::info!(
1266 "Updating {} samples ({} not found in Studio), {} annotations",
1267 to_update,
1268 not_found,
1269 server_annotations.len()
1270 );
1271
1272 if missing_label_count > 0 {
1273 log::warn!(
1274 "{} annotations have missing label_id (category not found in label map)",
1275 missing_label_count
1276 );
1277 }
1278
1279 if to_update == 0 {
1280 return Ok(CocoUpdateResult {
1281 total_images,
1282 updated: 0,
1283 not_found,
1284 });
1285 }
1286
1287 if let Some(ref tx) = progress {
1289 let _ = tx
1290 .send(Progress {
1291 current: 0,
1292 total: to_update,
1293 status: None,
1294 })
1295 .await;
1296 }
1297
1298 log::info!(
1300 "Deleting existing annotations for {} samples...",
1301 sample_ids_to_update.len()
1302 );
1303 let annotation_types = if options.include_masks {
1304 vec!["box".to_string(), "seg".to_string()]
1305 } else {
1306 vec!["box".to_string()]
1307 };
1308
1309 for batch in sample_ids_to_update.chunks(options.batch_size) {
1311 client
1312 .delete_annotations_bulk(annotation_set_id, &annotation_types, batch)
1313 .await?;
1314 }
1315
1316 if let Some(ref tx) = progress {
1318 let _ = tx
1319 .send(Progress {
1320 current: to_update / 2,
1321 total: to_update,
1322 status: None,
1323 })
1324 .await;
1325 }
1326
1327 log::info!("Adding {} new annotations...", server_annotations.len());
1329 let mut added = 0;
1330 for batch in server_annotations.chunks(options.batch_size) {
1331 client
1332 .add_annotations_bulk(annotation_set_id, batch.to_vec())
1333 .await?;
1334 added += batch.len();
1335 log::debug!("Added {} annotations so far", added);
1336 }
1337
1338 if let Some(ref tx) = progress {
1340 let _ = tx
1341 .send(Progress {
1342 current: to_update,
1343 total: to_update,
1344 status: None,
1345 })
1346 .await;
1347 }
1348
1349 let groups_updated =
1351 update_sample_groups(client, &dataset_id, &samples_needing_group_update).await;
1352
1353 log::info!(
1354 "Update complete: {} samples updated, {} not found, {} annotations added, {} groups updated",
1355 to_update,
1356 not_found,
1357 added,
1358 groups_updated
1359 );
1360
1361 Ok(CocoUpdateResult {
1362 total_images,
1363 updated: to_update,
1364 not_found,
1365 })
1366}
1367
1368fn polygon_to_polygon_string(polygon: &crate::Polygon) -> String {
1381 let rings: Vec<Vec<[f32; 2]>> = polygon
1385 .rings
1386 .iter()
1387 .map(|ring| {
1388 ring.iter()
1389 .filter(|(x, y)| x.is_finite() && y.is_finite())
1390 .map(|&(x, y)| [x, y])
1391 .collect()
1392 })
1393 .filter(|ring: &Vec<[f32; 2]>| ring.len() >= 3) .collect();
1395
1396 serde_json::to_string(&rings).unwrap_or_default()
1397}
1398
1399fn compute_bbox_from_polygon(
1404 polygon: &crate::Polygon,
1405 width: u32,
1406 height: u32,
1407) -> Option<[f64; 4]> {
1408 if polygon.rings.is_empty() {
1409 return None;
1410 }
1411
1412 let mut min_x = f32::MAX;
1413 let mut min_y = f32::MAX;
1414 let mut max_x = f32::MIN;
1415 let mut max_y = f32::MIN;
1416
1417 for ring in &polygon.rings {
1418 for &(x, y) in ring {
1419 if x.is_finite() && y.is_finite() {
1420 min_x = min_x.min(x);
1421 min_y = min_y.min(y);
1422 max_x = max_x.max(x);
1423 max_y = max_y.max(y);
1424 }
1425 }
1426 }
1427
1428 if min_x == f32::MAX || min_y == f32::MAX {
1429 return None;
1430 }
1431
1432 let x = (min_x * width as f32) as f64;
1434 let y = (min_y * height as f32) as f64;
1435 let w = ((max_x - min_x) * width as f32) as f64;
1436 let h = ((max_y - min_y) * height as f32) as f64;
1437
1438 if w > 0.0 && h > 0.0 {
1439 Some([x, y, w, h])
1440 } else {
1441 None
1442 }
1443}
1444
1445pub async fn verify_coco_import(
1466 client: &Client,
1467 coco_path: impl AsRef<Path>,
1468 dataset_id: DatasetID,
1469 annotation_set_id: AnnotationSetID,
1470 options: &CocoVerifyOptions,
1471 progress: Option<Sender<Progress>>,
1472) -> Result<super::verify::VerificationResult, Error> {
1473 use super::{verify::*, writer::CocoDatasetBuilder};
1474
1475 let coco_path = coco_path.as_ref();
1476
1477 log::info!("Reading local COCO dataset from {:?}", coco_path);
1479 let (coco_dataset, inferred_group) = if coco_path.is_dir() {
1480 let datasets = read_coco_directory(coco_path, &CocoReadOptions::default())?;
1482 log::info!("Found {} annotation files in directory", datasets.len());
1483
1484 let mut merged = CocoDataset::default();
1485 for (ds, group) in datasets {
1486 log::info!(
1487 " - {} group: {} images, {} annotations",
1488 group,
1489 ds.images.len(),
1490 ds.annotations.len()
1491 );
1492 merge_coco_datasets(&mut merged, ds);
1493 }
1494 (merged, None)
1496 } else if coco_path.extension().is_some_and(|e| e == "json") {
1497 let reader = CocoReader::new();
1498 let dataset = reader.read_json(coco_path)?;
1499 let group = infer_group_from_filename(coco_path);
1500 (dataset, group)
1501 } else {
1502 return Err(Error::InvalidParameters(
1503 "COCO verification requires a JSON annotation file or directory.".to_string(),
1504 ));
1505 };
1506
1507 let effective_group = options.group.clone().or(inferred_group);
1509 let groups: Vec<String> = effective_group
1510 .as_ref()
1511 .map(|g| vec![g.clone()])
1512 .unwrap_or_default();
1513
1514 log::info!(
1515 "Local COCO: {} images, {} annotations",
1516 coco_dataset.images.len(),
1517 coco_dataset.annotations.len()
1518 );
1519
1520 log::info!("Fetching samples from Studio dataset {}...", dataset_id);
1522 let annotation_types = [crate::AnnotationType::Box2d, crate::AnnotationType::Polygon];
1523
1524 let studio_samples = client
1525 .samples(
1526 dataset_id,
1527 Some(annotation_set_id),
1528 &annotation_types,
1529 &groups,
1530 &[],
1531 progress.clone(),
1532 )
1533 .await?;
1534
1535 let total_annotations: usize = studio_samples.iter().map(|s| s.annotations.len()).sum();
1536 log::info!(
1537 "Studio: {} samples, {} total annotations",
1538 studio_samples.len(),
1539 total_annotations
1540 );
1541
1542 let mut builder = CocoDatasetBuilder::new();
1544
1545 for sample in &studio_samples {
1546 let image_name = sample.image_name.as_deref().unwrap_or("unknown");
1547 let width = sample.width.unwrap_or(0);
1548 let height = sample.height.unwrap_or(0);
1549
1550 let file_name = if image_name.contains('.') {
1552 image_name.to_string()
1553 } else {
1554 format!("{}.jpg", image_name)
1555 };
1556 let image_id = builder.add_image(&file_name, width, height);
1557
1558 for ann in &sample.annotations {
1559 let bbox = if let Some(box2d) = ann.box2d() {
1561 Some(box2d_to_coco_bbox(box2d, width, height))
1562 } else if let Some(polygon) = ann.polygon() {
1563 compute_bbox_from_polygon(polygon, width, height)
1565 } else {
1566 None
1567 };
1568
1569 if let Some(bbox) = bbox {
1570 let label = ann.label().map(|s| s.as_str()).unwrap_or("unknown");
1571 let category_id = builder.add_category(label, None);
1572
1573 let segmentation = if options.verify_masks {
1574 ann.polygon().map(|polygon| {
1575 let coco_poly = polygon_to_coco_polygon(polygon, width, height);
1576 CocoSegmentation::Polygon(coco_poly)
1577 })
1578 } else {
1579 None
1580 };
1581
1582 builder.add_annotation(image_id, category_id, bbox, segmentation);
1583 }
1584 }
1585 }
1586
1587 let studio_dataset = builder.build();
1588
1589 let coco_names: HashSet<String> = coco_dataset
1591 .images
1592 .iter()
1593 .map(|img| {
1594 Path::new(&img.file_name)
1595 .file_stem()
1596 .and_then(|s| s.to_str())
1597 .map(String::from)
1598 .unwrap_or_else(|| img.file_name.clone())
1599 })
1600 .collect();
1601
1602 let studio_names: HashSet<String> = studio_samples.iter().filter_map(|s| s.name()).collect();
1603
1604 let missing_images: Vec<String> = coco_names.difference(&studio_names).cloned().collect();
1605 let extra_images: Vec<String> = studio_names.difference(&coco_names).cloned().collect();
1606
1607 log::info!("Validating bounding boxes...");
1609 let bbox_validation = validate_bboxes(&coco_dataset, &studio_dataset);
1610
1611 log::info!("Validating segmentation masks...");
1613 let mask_validation = if options.verify_masks {
1614 validate_masks(&coco_dataset, &studio_dataset)
1615 } else {
1616 MaskValidationResult::new()
1617 };
1618
1619 let category_validation = validate_categories(&coco_dataset, &studio_dataset);
1621
1622 Ok(VerificationResult {
1623 coco_image_count: coco_dataset.images.len(),
1624 studio_image_count: studio_samples.len(),
1625 missing_images,
1626 extra_images,
1627 coco_annotation_count: coco_dataset.annotations.len(),
1628 studio_annotation_count: studio_dataset.annotations.len(),
1629 bbox_validation,
1630 mask_validation,
1631 category_validation,
1632 })
1633}
1634
#[cfg(test)]
mod tests {
    use super::*;
    use crate::coco::{CocoAnnotation, CocoCategory};

    // ------------------------------------------------------------------
    // Option defaults
    // ------------------------------------------------------------------

    /// Import defaults: masks, images and resume on; no group; batch 100; concurrency 64.
    #[test]
    fn test_coco_import_options_default() {
        let opts = CocoImportOptions::default();
        assert!(opts.include_masks);
        assert!(opts.include_images);
        assert!(opts.group.is_none());
        assert_eq!(opts.batch_size, 100);
        assert_eq!(opts.concurrency, 64);
        assert!(opts.resume);
    }

    /// Export defaults: masks on, everything else off or empty.
    #[test]
    fn test_coco_export_options_default() {
        let opts = CocoExportOptions::default();
        assert!(opts.groups.is_empty());
        assert!(opts.include_masks);
        assert!(!opts.include_images);
        assert!(!opts.output_zip);
        assert!(!opts.pretty_json);
        assert!(opts.info.is_none());
    }

    /// Update defaults: masks on; no group; batch 100; concurrency 64.
    #[test]
    fn test_coco_update_options_default() {
        let opts = CocoUpdateOptions::default();
        assert!(opts.include_masks);
        assert!(opts.group.is_none());
        assert_eq!(opts.batch_size, 100);
        assert_eq!(opts.concurrency, 64);
    }

    /// Verify defaults: mask verification on; no group.
    #[test]
    fn test_coco_verify_options_default() {
        let opts = CocoVerifyOptions::default();
        assert!(opts.verify_masks);
        assert!(opts.group.is_none());
    }

    // ------------------------------------------------------------------
    // find_image_file
    // ------------------------------------------------------------------

    /// A base directory that does not exist yields no match.
    #[test]
    fn test_find_image_file_nonexistent() {
        let found = find_image_file(Path::new("/nonexistent"), "test.jpg");
        assert!(found.is_none());
    }

    /// A file name carrying a subdirectory component also yields no match
    /// when the base directory does not exist.
    #[test]
    fn test_find_image_file_with_subdirectory_in_name() {
        let found = find_image_file(Path::new("/nonexistent"), "train2017/image.jpg");
        assert!(found.is_none());
    }

    // ------------------------------------------------------------------
    // infer_group_from_filename
    // ------------------------------------------------------------------

    #[test]
    fn test_infer_group_from_filename_instances_train() {
        let annotation_file = Path::new("annotations/instances_train2017.json");
        let inferred = infer_group_from_filename(annotation_file);
        assert_eq!(inferred.as_deref(), Some("train"));
    }

    #[test]
    fn test_infer_group_from_filename_instances_val() {
        let annotation_file = Path::new("annotations/instances_val2017.json");
        let inferred = infer_group_from_filename(annotation_file);
        assert_eq!(inferred.as_deref(), Some("val"));
    }

    #[test]
    fn test_infer_group_from_filename_instances_test() {
        let annotation_file = Path::new("instances_test2017.json");
        let inferred = infer_group_from_filename(annotation_file);
        assert_eq!(inferred.as_deref(), Some("test"));
    }

    #[test]
    fn test_infer_group_from_filename_train_prefix() {
        let annotation_file = Path::new("train_annotations.json");
        let inferred = infer_group_from_filename(annotation_file);
        assert_eq!(inferred.as_deref(), Some("train"));
    }

    #[test]
    fn test_infer_group_from_filename_val_prefix() {
        let annotation_file = Path::new("val_data.json");
        let inferred = infer_group_from_filename(annotation_file);
        assert_eq!(inferred.as_deref(), Some("val"));
    }

    /// A `validation` prefix is reported as the `val` group.
    #[test]
    fn test_infer_group_from_filename_validation_prefix() {
        let annotation_file = Path::new("validation_set.json");
        let inferred = infer_group_from_filename(annotation_file);
        assert_eq!(inferred.as_deref(), Some("val"));
    }

    /// File names without a recognised group marker infer nothing.
    #[test]
    fn test_infer_group_from_filename_custom() {
        let annotation_file = Path::new("my_custom_annotations.json");
        let inferred = infer_group_from_filename(annotation_file);
        assert_eq!(inferred, None);
    }

    #[test]
    fn test_infer_group_from_filename_instances_2014() {
        let annotation_file = Path::new("instances_val2014.json");
        let inferred = infer_group_from_filename(annotation_file);
        assert_eq!(inferred.as_deref(), Some("val"));
    }

    // ------------------------------------------------------------------
    // merge_coco_datasets
    // ------------------------------------------------------------------

    /// Merging two empty datasets leaves the target empty.
    #[test]
    fn test_merge_coco_datasets_empty() {
        let mut merged = CocoDataset::default();
        merge_coco_datasets(&mut merged, CocoDataset::default());

        assert!(merged.images.is_empty());
        assert!(merged.annotations.is_empty());
        assert!(merged.categories.is_empty());
    }

    /// Disjoint ids: every image, category and annotation is kept.
    #[test]
    fn test_merge_coco_datasets_basic() {
        let first_image = CocoImage {
            id: 1,
            file_name: "img1.jpg".to_string(),
            ..Default::default()
        };
        let first_category = CocoCategory {
            id: 1,
            name: "cat".to_string(),
            supercategory: None,
            ..Default::default()
        };
        let first_annotation = CocoAnnotation {
            id: 1,
            image_id: 1,
            category_id: 1,
            ..Default::default()
        };
        let mut merged = CocoDataset {
            images: vec![first_image],
            categories: vec![first_category],
            annotations: vec![first_annotation],
            ..Default::default()
        };

        let second_image = CocoImage {
            id: 2,
            file_name: "img2.jpg".to_string(),
            ..Default::default()
        };
        let second_category = CocoCategory {
            id: 2,
            name: "dog".to_string(),
            supercategory: None,
            ..Default::default()
        };
        let second_annotation = CocoAnnotation {
            id: 2,
            image_id: 2,
            category_id: 2,
            ..Default::default()
        };
        let incoming = CocoDataset {
            images: vec![second_image],
            categories: vec![second_category],
            annotations: vec![second_annotation],
            ..Default::default()
        };

        merge_coco_datasets(&mut merged, incoming);

        assert_eq!(merged.images.len(), 2);
        assert_eq!(merged.categories.len(), 2);
        assert_eq!(merged.annotations.len(), 2);
    }

    /// An incoming image that reuses an existing id is dropped; the target's
    /// original entry stays first.
    #[test]
    fn test_merge_coco_datasets_deduplicates_images() {
        let mut merged = CocoDataset {
            images: vec![CocoImage {
                id: 1,
                file_name: "img1.jpg".to_string(),
                ..Default::default()
            }],
            ..Default::default()
        };

        let duplicate = CocoImage {
            id: 1,
            file_name: "img1_dup.jpg".to_string(),
            ..Default::default()
        };
        let fresh = CocoImage {
            id: 2,
            file_name: "img2.jpg".to_string(),
            ..Default::default()
        };
        let incoming = CocoDataset {
            images: vec![duplicate, fresh],
            ..Default::default()
        };

        merge_coco_datasets(&mut merged, incoming);

        assert_eq!(merged.images.len(), 2);
        assert_eq!(merged.images[0].file_name, "img1.jpg");
    }

    /// An incoming category that reuses an existing id is dropped; the
    /// target's original entry stays first.
    #[test]
    fn test_merge_coco_datasets_deduplicates_categories() {
        let mut merged = CocoDataset {
            categories: vec![CocoCategory {
                id: 1,
                name: "person".to_string(),
                supercategory: None,
                ..Default::default()
            }],
            ..Default::default()
        };

        let duplicate = CocoCategory {
            id: 1,
            name: "person_dup".to_string(),
            supercategory: None,
            ..Default::default()
        };
        let fresh = CocoCategory {
            id: 2,
            name: "car".to_string(),
            supercategory: None,
            ..Default::default()
        };
        let incoming = CocoDataset {
            categories: vec![duplicate, fresh],
            ..Default::default()
        };

        merge_coco_datasets(&mut merged, incoming);

        assert_eq!(merged.categories.len(), 2);
        assert_eq!(merged.categories[0].name, "person");
    }

    /// The source's `info.description` ends up on an initially empty target.
    #[test]
    fn test_merge_coco_datasets_info_preserved() {
        let mut merged = CocoDataset::default();

        let incoming = CocoDataset {
            info: CocoInfo {
                description: Some("Test dataset".to_string()),
                ..Default::default()
            },
            ..Default::default()
        };

        merge_coco_datasets(&mut merged, incoming);

        assert_eq!(merged.info.description, Some("Test dataset".to_string()));
    }

    // ------------------------------------------------------------------
    // convert_coco_image_to_sample
    // ------------------------------------------------------------------

    /// One image with one box annotation converts to a sample that carries the
    /// image stem, dimensions, group and a labelled annotation.
    #[test]
    fn test_convert_coco_image_to_sample() {
        let image = CocoImage {
            id: 1,
            width: 640,
            height: 480,
            file_name: "test.jpg".to_string(),
            ..Default::default()
        };

        let person = CocoCategory {
            id: 1,
            name: "person".to_string(),
            supercategory: None,
            ..Default::default()
        };
        let annotation = CocoAnnotation {
            id: 1,
            image_id: 1,
            category_id: 1,
            bbox: [100.0, 50.0, 200.0, 150.0],
            area: 30000.0,
            iscrowd: 0,
            segmentation: None,
            score: None,
        };
        let dataset = CocoDataset {
            images: vec![image.clone()],
            categories: vec![person],
            annotations: vec![annotation],
            ..Default::default()
        };

        let index = CocoIndex::from_dataset(&dataset);

        // Flags: masks enabled (4th argument); the 5th flag is presumably
        // "include images" — confirm against the function signature.
        let sample = convert_coco_image_to_sample(
            &image,
            &index,
            Path::new("/tmp"),
            true,
            false,
            Some("train"),
        )
        .unwrap();

        assert_eq!(sample.image_name, Some("test".to_string()));
        assert_eq!(sample.width, Some(640));
        assert_eq!(sample.height, Some(480));
        assert_eq!(sample.group, Some("train".to_string()));
        assert_eq!(sample.annotations.len(), 1);
        assert_eq!(sample.annotations[0].label(), Some(&"person".to_string()));
    }

    /// An image without annotations still converts, with an empty annotation list.
    #[test]
    fn test_convert_coco_image_to_sample_no_annotations() {
        let image = CocoImage {
            id: 1,
            width: 640,
            height: 480,
            file_name: "empty.jpg".to_string(),
            ..Default::default()
        };

        let dataset = CocoDataset {
            images: vec![image.clone()],
            categories: vec![],
            annotations: vec![],
            ..Default::default()
        };

        let index = CocoIndex::from_dataset(&dataset);
        let base_dir = Path::new("/tmp");

        let sample =
            convert_coco_image_to_sample(&image, &index, base_dir, true, false, None).unwrap();

        assert_eq!(sample.image_name, Some("empty".to_string()));
        assert!(sample.annotations.is_empty());
    }

    /// The 4th argument controls whether the segmentation polygon is carried
    /// over to the annotation.
    #[test]
    fn test_convert_coco_image_to_sample_with_mask() {
        let image = CocoImage {
            id: 1,
            width: 100,
            height: 100,
            file_name: "masked.jpg".to_string(),
            ..Default::default()
        };

        let square = CocoSegmentation::Polygon(vec![vec![
            10.0, 10.0, 60.0, 10.0, 60.0, 60.0, 10.0, 60.0,
        ]]);
        let dataset = CocoDataset {
            images: vec![image.clone()],
            categories: vec![CocoCategory {
                id: 1,
                name: "object".to_string(),
                supercategory: None,
                ..Default::default()
            }],
            annotations: vec![CocoAnnotation {
                id: 1,
                image_id: 1,
                category_id: 1,
                bbox: [10.0, 10.0, 50.0, 50.0],
                area: 2500.0,
                iscrowd: 0,
                segmentation: Some(square),
                score: None,
            }],
            ..Default::default()
        };

        let index = CocoIndex::from_dataset(&dataset);
        let base_dir = Path::new("/tmp");

        // Masks requested: the polygon is present.
        let with_masks =
            convert_coco_image_to_sample(&image, &index, base_dir, true, false, None).unwrap();
        assert!(with_masks.annotations[0].polygon().is_some());

        // Masks not requested: the polygon is absent.
        let without_masks =
            convert_coco_image_to_sample(&image, &index, base_dir, false, false, None).unwrap();
        assert!(without_masks.annotations[0].polygon().is_none());
    }

    // ------------------------------------------------------------------
    // compute_bbox_from_polygon
    // ------------------------------------------------------------------

    /// A square from 0.1 to 0.5 on a 100x100 image gives roughly [10, 10, 40, 40].
    #[test]
    fn test_compute_bbox_from_polygon_simple() {
        let polygon =
            crate::Polygon::new(vec![vec![(0.1, 0.1), (0.5, 0.1), (0.5, 0.5), (0.1, 0.5)]]);

        let bbox = compute_bbox_from_polygon(&polygon, 100, 100);

        assert!(bbox.is_some());
        let [x, y, w, h] = bbox.unwrap();
        assert!((x - 10.0).abs() < 1.0);
        assert!((y - 10.0).abs() < 1.0);
        assert!((w - 40.0).abs() < 1.0);
        assert!((h - 40.0).abs() < 1.0);
    }

    /// A polygon without rings has no bounding box.
    #[test]
    fn test_compute_bbox_from_polygon_empty() {
        let polygon = crate::Polygon::new(vec![]);
        assert!(compute_bbox_from_polygon(&polygon, 100, 100).is_none());
    }

    /// A polygon made only of NaN points has no bounding box.
    #[test]
    fn test_compute_bbox_from_polygon_with_nan() {
        let polygon =
            crate::Polygon::new(vec![vec![(f32::NAN, f32::NAN), (f32::NAN, f32::NAN)]]);
        assert!(compute_bbox_from_polygon(&polygon, 100, 100).is_none());
    }

    /// The box spans every ring: two small squares give roughly [10, 10, 80, 80].
    #[test]
    fn test_compute_bbox_from_polygon_multiple_rings() {
        let polygon = crate::Polygon::new(vec![
            vec![(0.1, 0.1), (0.2, 0.1), (0.2, 0.2), (0.1, 0.2)],
            vec![(0.8, 0.8), (0.9, 0.8), (0.9, 0.9), (0.8, 0.9)],
        ]);

        let bbox = compute_bbox_from_polygon(&polygon, 100, 100);

        assert!(bbox.is_some());
        let [x, y, w, h] = bbox.unwrap();
        assert!((x - 10.0).abs() < 1.0);
        assert!((y - 10.0).abs() < 1.0);
        assert!((w - 80.0).abs() < 1.0);
        assert!((h - 80.0).abs() < 1.0);
    }

    // ------------------------------------------------------------------
    // polygon_to_polygon_string
    // ------------------------------------------------------------------

    /// A single triangle serializes as one ring of `[x, y]` pairs.
    #[test]
    fn test_polygon_to_polygon_string() {
        let polygon = crate::Polygon::new(vec![vec![(0.1, 0.2), (0.3, 0.4), (0.5, 0.6)]]);

        let json = polygon_to_polygon_string(&polygon);

        assert_eq!(json, "[[[0.1,0.2],[0.3,0.4],[0.5,0.6]]]");
    }

    /// Each ring becomes its own nested array, in order.
    #[test]
    fn test_polygon_to_polygon_string_multiple_rings() {
        let polygon = crate::Polygon::new(vec![
            vec![(0.1, 0.1), (0.2, 0.1), (0.15, 0.2)],
            vec![(0.5, 0.5), (0.6, 0.5), (0.55, 0.6)],
        ]);

        let json = polygon_to_polygon_string(&polygon);

        assert_eq!(
            json,
            "[[[0.1,0.1],[0.2,0.1],[0.15,0.2]],[[0.5,0.5],[0.6,0.5],[0.55,0.6]]]"
        );
    }

    /// A point with a NaN coordinate is removed rather than emitted as `null`.
    #[test]
    fn test_polygon_to_polygon_string_filters_nan_values() {
        let polygon = crate::Polygon::new(vec![vec![
            (0.1, 0.2),
            (f32::NAN, 0.4),
            (0.3, 0.4),
            (0.5, 0.6),
        ]]);

        let json = polygon_to_polygon_string(&polygon);

        assert!(
            !json.contains("null"),
            "NaN values should be filtered out, got: {}",
            json
        );
        assert_eq!(json, "[[[0.1,0.2],[0.3,0.4],[0.5,0.6]]]");
    }

    /// A point with a +infinity coordinate is removed.
    #[test]
    fn test_polygon_to_polygon_string_filters_infinity() {
        let polygon = crate::Polygon::new(vec![vec![
            (0.1, 0.2),
            (f32::INFINITY, 0.4),
            (0.3, 0.4),
            (0.5, 0.6),
        ]]);

        let json = polygon_to_polygon_string(&polygon);

        assert!(
            !json.contains("null"),
            "Infinity values should be filtered out"
        );
        assert_eq!(json, "[[[0.1,0.2],[0.3,0.4],[0.5,0.6]]]");
    }

    /// A ring left with fewer than three points after filtering is dropped entirely.
    #[test]
    fn test_polygon_to_polygon_string_too_few_points_after_filter() {
        let polygon = crate::Polygon::new(vec![vec![
            (0.1, 0.2),
            (f32::NAN, 0.4),
            (f32::NAN, f32::NAN),
        ]]);

        let json = polygon_to_polygon_string(&polygon);

        assert_eq!(json, "[]");
    }

    /// A point with a -infinity coordinate is removed.
    #[test]
    fn test_polygon_to_polygon_string_negative_infinity() {
        let polygon = crate::Polygon::new(vec![vec![
            (0.1, 0.2),
            (f32::NEG_INFINITY, 0.4),
            (0.3, 0.4),
            (0.5, 0.6),
        ]]);

        let json = polygon_to_polygon_string(&polygon);
        assert_eq!(json, "[[[0.1,0.2],[0.3,0.4],[0.5,0.6]]]");
    }

    // ------------------------------------------------------------------
    // Result structs
    // ------------------------------------------------------------------

    /// `CocoImportResult` exposes the counters it was built with.
    #[test]
    fn test_coco_import_result() {
        let summary = CocoImportResult {
            total_images: 100,
            skipped: 30,
            imported: 70,
        };

        assert_eq!(summary.total_images, 100);
        assert_eq!(summary.skipped, 30);
        assert_eq!(summary.imported, 70);
    }

    /// `CocoUpdateResult` exposes the counters it was built with.
    #[test]
    fn test_coco_update_result() {
        let summary = CocoUpdateResult {
            total_images: 500,
            updated: 450,
            not_found: 50,
        };

        assert_eq!(summary.total_images, 500);
        assert_eq!(summary.updated, 450);
        assert_eq!(summary.not_found, 50);
    }

    // ------------------------------------------------------------------
    // read_coco_dataset_for_update
    // ------------------------------------------------------------------

    /// A non-JSON file is rejected with a descriptive error.
    #[test]
    fn test_read_coco_dataset_for_update_invalid_extension() {
        let outcome = read_coco_dataset_for_update(Path::new("/tmp/file.txt"));
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("COCO update requires a JSON annotation file"));
    }

    /// A JSON path that does not exist is an error.
    #[test]
    fn test_read_coco_dataset_for_update_nonexistent_json() {
        let outcome = read_coco_dataset_for_update(Path::new("/nonexistent/file.json"));
        assert!(outcome.is_err());
    }

    /// A path that does not exist and has no `.json` extension is an error.
    #[test]
    fn test_read_coco_dataset_for_update_nonexistent_directory() {
        let outcome = read_coco_dataset_for_update(Path::new("/nonexistent_dir"));
        assert!(outcome.is_err());
    }

    // ------------------------------------------------------------------
    // build_sample_info_map
    // ------------------------------------------------------------------

    /// No samples, no entries.
    #[test]
    fn test_build_sample_info_map_empty() {
        let samples: Vec<crate::Sample> = Vec::new();
        assert!(build_sample_info_map(&samples).is_empty());
    }

    /// Complete samples are keyed by image name and map to (id, width, height, group).
    #[test]
    fn test_build_sample_info_map_with_samples() {
        use crate::{Sample, SampleID};

        let grouped = Sample {
            image_name: Some("sample1".to_string()),
            id: Some(SampleID::from(1)),
            width: Some(640),
            height: Some(480),
            group: Some("train".to_string()),
            ..Default::default()
        };

        let ungrouped = Sample {
            image_name: Some("sample2".to_string()),
            id: Some(SampleID::from(2)),
            width: Some(1280),
            height: Some(720),
            group: None,
            ..Default::default()
        };

        let samples = vec![grouped, ungrouped];
        let info = build_sample_info_map(&samples);

        assert_eq!(info.len(), 2);
        assert!(info.contains_key("sample1"));
        assert!(info.contains_key("sample2"));

        let (first_id, first_w, first_h, first_group) = info.get("sample1").unwrap();
        assert_eq!(*first_id, SampleID::from(1));
        assert_eq!(*first_w, 640);
        assert_eq!(*first_h, 480);
        assert_eq!(first_group.as_deref(), Some("train"));

        let (second_id, second_w, second_h, second_group) = info.get("sample2").unwrap();
        assert_eq!(*second_id, SampleID::from(2));
        assert_eq!(*second_w, 1280);
        assert_eq!(*second_h, 720);
        assert!(second_group.is_none());
    }

    /// Samples missing an id, a name or dimensions are all left out of the map.
    #[test]
    fn test_build_sample_info_map_skips_incomplete_samples() {
        use crate::Sample;

        let missing_id = Sample {
            image_name: Some("no_id".to_string()),
            width: Some(640),
            height: Some(480),
            ..Default::default()
        };

        let missing_name = Sample {
            id: Some(crate::SampleID::from(1)),
            width: Some(640),
            height: Some(480),
            ..Default::default()
        };

        let missing_dims = Sample {
            image_name: Some("no_dims".to_string()),
            id: Some(crate::SampleID::from(2)),
            ..Default::default()
        };

        let samples = vec![missing_id, missing_name, missing_dims];
        let info = build_sample_info_map(&samples);

        assert!(info.is_empty());
    }

    // ------------------------------------------------------------------
    // Option cloning and custom values
    // ------------------------------------------------------------------

    /// A clone carries the same field values as the original.
    #[test]
    fn test_coco_import_options_clone() {
        let original = CocoImportOptions::default();
        let copy = original.clone();

        assert_eq!(original.batch_size, copy.batch_size);
        assert_eq!(original.concurrency, copy.concurrency);
        assert_eq!(original.include_masks, copy.include_masks);
    }

    /// Every import option can be overridden.
    #[test]
    fn test_coco_import_options_custom() {
        let opts = CocoImportOptions {
            include_masks: false,
            include_images: false,
            group: Some("test".to_string()),
            batch_size: 50,
            concurrency: 32,
            resume: false,
        };

        assert!(!opts.include_masks);
        assert!(!opts.include_images);
        assert_eq!(opts.group.as_deref(), Some("test"));
        assert_eq!(opts.batch_size, 50);
        assert_eq!(opts.concurrency, 32);
        assert!(!opts.resume);
    }

    /// Every update option can be overridden.
    #[test]
    fn test_coco_update_options_custom() {
        let opts = CocoUpdateOptions {
            include_masks: false,
            group: Some("val".to_string()),
            batch_size: 25,
            concurrency: 16,
        };

        assert!(!opts.include_masks);
        assert_eq!(opts.group.as_deref(), Some("val"));
        assert_eq!(opts.batch_size, 25);
        assert_eq!(opts.concurrency, 16);
    }

    // ------------------------------------------------------------------
    // extract_sample_name
    // ------------------------------------------------------------------

    /// The extension is stripped from a plain file name.
    #[test]
    fn test_extract_sample_name_simple() {
        assert_eq!(extract_sample_name("image.jpg"), "image");
    }

    /// Leading directory components are stripped as well.
    #[test]
    fn test_extract_sample_name_with_path() {
        assert_eq!(extract_sample_name("train2017/000001.jpg"), "000001");
    }

    /// A name without an extension is returned unchanged.
    #[test]
    fn test_extract_sample_name_no_extension() {
        assert_eq!(extract_sample_name("image"), "image");
    }

    /// Only the final extension is removed when the name contains several dots.
    #[test]
    fn test_extract_sample_name_multiple_dots() {
        assert_eq!(extract_sample_name("image.v2.final.jpg"), "image.v2.final");
    }
}