use super::{
convert::{
box2d_to_coco_bbox, coco_bbox_to_box2d, coco_segmentation_to_polygon,
polygon_to_coco_polygon,
},
reader::{CocoReadOptions, CocoReader, read_coco_directory},
types::{CocoDataset, CocoImage, CocoIndex, CocoInfo, CocoSegmentation},
writer::{CocoDatasetBuilder, CocoWriteOptions, CocoWriter},
};
use crate::{
Annotation, AnnotationSetID, Client, DatasetID, Error, FileType, Progress, Sample, SampleFile,
};
use std::{
collections::HashSet,
path::{Path, PathBuf},
};
use tokio::sync::mpsc::Sender;
/// Summary returned by `import_coco_to_studio`.
#[derive(Debug, Clone)]
pub struct CocoImportResult {
/// Number of images in the COCO dataset that was read from disk.
pub total_images: usize,
/// Images that passed the group filter but were not uploaded because a
/// sample with the same name was already present in the dataset.
pub skipped: usize,
/// Number of images uploaded by this call.
pub imported: usize,
}
/// Settings that control `import_coco_to_studio`.
#[derive(Debug, Clone)]
pub struct CocoImportOptions {
    /// Convert COCO segmentations into polygons on the uploaded annotations.
    pub include_masks: bool,
    /// Attach the image files found on disk to the uploaded samples; also
    /// enables the up-front check that images have been extracted.
    pub include_images: bool,
    /// When set, only images whose inferred group equals this value are imported.
    pub group: Option<String>,
    /// Number of images converted and uploaded per request.
    pub batch_size: usize,
    /// Concurrency value forwarded to the client's upload call.
    pub concurrency: usize,
    /// When `true`, existing sample names are fetched first and matching
    /// images are skipped instead of being uploaded again.
    pub resume: bool,
}

impl Default for CocoImportOptions {
    /// Masks and images included, no group filter, batches of 100,
    /// concurrency of 64, resume enabled.
    fn default() -> Self {
        CocoImportOptions {
            group: None,
            resume: true,
            include_images: true,
            include_masks: true,
            concurrency: 64,
            batch_size: 100,
        }
    }
}
/// Settings that control `export_studio_to_coco`.
#[derive(Debug, Clone)]
pub struct CocoExportOptions {
    /// Group names forwarded to the sample query.
    pub groups: Vec<String>,
    /// Emit polygon segmentations alongside the bounding boxes.
    pub include_masks: bool,
    /// Download sample images; only consulted when `output_zip` is set.
    pub include_images: bool,
    /// Write a ZIP archive instead of a bare JSON file.
    pub output_zip: bool,
    /// Ask the writer for pretty-printed JSON.
    pub pretty_json: bool,
    /// Optional COCO `info` section copied into the exported dataset.
    pub info: Option<CocoInfo>,
}

impl Default for CocoExportOptions {
    /// No group restriction, masks included, no images, plain compact JSON,
    /// no `info` section.
    fn default() -> Self {
        CocoExportOptions {
            info: None,
            groups: Vec::new(),
            include_masks: true,
            include_images: false,
            output_zip: false,
            pretty_json: false,
        }
    }
}
/// Imports a COCO dataset from disk into a Studio dataset.
///
/// `coco_path` may be a directory of annotation files or a single `.json`
/// file. Images are filtered by `options.group` and, when `options.resume`
/// is set, by the sample names already present in `dataset_id`; the
/// remainder is uploaded in batches.
///
/// # Returns
/// Counts of total, skipped and imported images.
///
/// # Errors
/// * `Error::MissingAnnotations` when the COCO dataset contains no images.
/// * Any error from reading the dataset, validating extracted images,
///   querying existing samples, or uploading.
pub async fn import_coco_to_studio(
    client: &Client,
    coco_path: impl AsRef<Path>,
    dataset_id: DatasetID,
    annotation_set_id: AnnotationSetID,
    options: &CocoImportOptions,
    progress: Option<Sender<Progress>>,
) -> Result<CocoImportResult, Error> {
    let (dataset, images_dir) = read_coco_from_path(coco_path.as_ref())?;

    let total_images = dataset.images.len();
    if total_images == 0 {
        return Err(Error::MissingAnnotations(
            "No images found in COCO dataset".to_string(),
        ));
    }

    // Fail early with an actionable message if the image archives were never unpacked.
    if options.include_images {
        validate_images_extracted(&dataset, &images_dir)?;
    }

    let already_uploaded =
        fetch_existing_sample_names(client, &dataset_id, options.resume).await?;
    let wanted_group = options.group.as_deref();
    let (pending, skipped, excluded_by_group) =
        filter_images_for_import(&dataset.images, wanted_group, &already_uploaded);
    log_import_filter_info(wanted_group, excluded_by_group, total_images);

    let pending_count = pending.len();
    if pending_count == 0 {
        log_nothing_to_import(skipped);
        return Ok(CocoImportResult {
            total_images,
            skipped,
            imported: 0,
        });
    }

    if skipped > 0 {
        log::info!(
            "Resuming import: {} of {} images already imported, {} remaining",
            skipped,
            total_images,
            pending_count
        );
    }

    let index = CocoIndex::from_dataset(&dataset);
    send_progress(&progress, 0, pending_count).await;

    let context = UploadContext {
        client,
        dataset_id: &dataset_id,
        annotation_set_id: &annotation_set_id,
        options,
        progress: &progress,
    };
    let imported = upload_images_in_batches(&context, &pending, &index, &images_dir).await?;

    Ok(CocoImportResult {
        total_images,
        skipped,
        imported,
    })
}
/// Returns the names of samples already stored in `dataset_id`.
///
/// When `resume` is `false` no request is made and an empty set is returned,
/// so every image is treated as new.
///
/// # Errors
/// Propagates the error from `Client::sample_names`.
async fn fetch_existing_sample_names(
    client: &Client,
    dataset_id: &DatasetID,
    resume: bool,
) -> Result<HashSet<String>, Error> {
    if resume {
        log::info!("Checking for existing samples in dataset {}...", dataset_id);
        let names = client.sample_names(*dataset_id, &[], None).await?;
        log::info!("Found {} existing samples in dataset", names.len());

        // Log a small preview to help diagnose name mismatches.
        if !names.is_empty() {
            let samples = names.iter().take(3).collect::<Vec<_>>();
            log::debug!("Sample names from server: {:?}", samples);
        }
        Ok(names)
    } else {
        Ok(HashSet::new())
    }
}
/// Logs how many images the group filter excluded; silent when it excluded none.
///
/// `total` is the image count before filtering, so `total - filtered_by_group`
/// is the number of images that matched the filter.
fn log_import_filter_info(group_filter: Option<&str>, filtered_by_group: usize, total: usize) {
    if filtered_by_group == 0 {
        return;
    }
    let filter_name = group_filter.unwrap_or("");
    let matching = total - filtered_by_group;
    log::info!(
        "Group filter '{}': {} images excluded, {} matching",
        filter_name,
        filtered_by_group,
        matching
    );
}
/// Logs why an import has no work to do: either everything matching was
/// already imported (`skipped > 0`) or no image matched at all.
fn log_nothing_to_import(skipped: usize) {
    match skipped {
        0 => log::info!("No images to import"),
        already_imported => log::info!(
            "All {} matching images already imported, nothing to do",
            already_imported
        ),
    }
}
/// Sends a `current`/`total` progress update when a channel was supplied.
///
/// A failed send is ignored, so it never interrupts the operation being reported.
async fn send_progress(progress: &Option<Sender<Progress>>, current: usize, total: usize) {
    let Some(channel) = progress.as_ref() else {
        return;
    };
    let update = Progress {
        current,
        total,
        status: None,
    };
    let _ = channel.send(update).await;
}
/// Borrowed state shared by every batch upload during an import.
struct UploadContext<'a> {
/// Client used to issue the upload requests.
client: &'a Client,
/// Dataset that receives the samples.
dataset_id: &'a DatasetID,
/// Annotation set passed along with each upload.
annotation_set_id: &'a AnnotationSetID,
/// Import options; supplies `batch_size`, `concurrency` and the include flags.
options: &'a CocoImportOptions,
/// Optional channel that receives progress updates after each batch.
progress: &'a Option<Sender<Progress>>,
}
/// Converts `images` to samples and uploads them in chunks of
/// `ctx.options.batch_size`, reporting progress after every chunk.
///
/// A `batch_size` of zero is treated as one, because `slice::chunks`
/// panics on a zero chunk size.
///
/// # Returns
/// The number of images uploaded.
///
/// # Errors
/// Stops at and returns the first conversion or upload error; batches
/// uploaded before the failure are not rolled back.
async fn upload_images_in_batches<'a>(
    ctx: &UploadContext<'a>,
    images: &[&CocoImage],
    index: &CocoIndex,
    images_dir: &Path,
) -> Result<usize, Error> {
    // Guard against a caller-supplied 0, which would make `chunks` panic.
    let batch_size = ctx.options.batch_size.max(1);
    let to_import = images.len();
    let mut imported = 0;

    for batch in images.chunks(batch_size) {
        let samples = convert_batch_to_samples(batch, index, images_dir, ctx.options)?;
        ctx.client
            .populate_samples_with_concurrency(
                *ctx.dataset_id,
                Some(*ctx.annotation_set_id),
                samples,
                None,
                Some(ctx.options.concurrency),
            )
            .await?;
        imported += batch.len();
        send_progress(ctx.progress, imported, to_import).await;
    }

    Ok(imported)
}
/// Converts one batch of COCO images into uploadable samples.
///
/// Each image's group is inferred from the folder part of its file name.
///
/// # Errors
/// Returns the first conversion error; later images in the batch are not
/// converted.
fn convert_batch_to_samples(
    batch: &[&CocoImage],
    index: &CocoIndex,
    images_dir: &Path,
    options: &CocoImportOptions,
) -> Result<Vec<Sample>, Error> {
    batch
        .iter()
        .map(|image| {
            let inferred_group = super::reader::infer_group_from_folder(&image.file_name);
            convert_coco_image_to_sample(
                image,
                index,
                images_dir,
                options.include_masks,
                options.include_images,
                inferred_group.as_deref(),
            )
        })
        .collect()
}
fn validate_images_extracted(dataset: &CocoDataset, images_dir: &Path) -> Result<(), Error> {
let sample_size = std::cmp::min(5, dataset.images.len());
let mut missing = Vec::new();
for image in dataset.images.iter().take(sample_size) {
if find_image_file(images_dir, &image.file_name).is_none() {
missing.push(image.file_name.clone());
}
}
if !missing.is_empty() {
let examples: Vec<_> = missing.iter().take(3).cloned().collect();
return Err(Error::MissingImages(format!(
"Images must be extracted before import.\n\
Cannot find: {}\n\n\
Searched in: {}\n\
Expected subdirectories: train2017/, val2017/, images/\n\n\
Please extract your COCO image archives first:\n\
$ cd {} && unzip train2017.zip && unzip val2017.zip",
examples.join(", "),
images_dir.display(),
images_dir.display()
)));
}
Ok(())
}
/// Locates `file_name` under `base_dir`.
///
/// The file is looked up directly in `base_dir` first, then in the
/// subdirectories `images`, `train2017`, `val2017`, `test2017`, `train2014`
/// and `val2014`, in that order. Returns the first path that exists.
fn find_image_file(base_dir: &Path, file_name: &str) -> Option<PathBuf> {
    const SUBDIRECTORIES: [&str; 6] = [
        "images",
        "train2017",
        "val2017",
        "test2017",
        "train2014",
        "val2014",
    ];

    std::iter::once(base_dir.join(file_name))
        .chain(
            SUBDIRECTORIES
                .iter()
                .map(|subdir| base_dir.join(subdir).join(file_name)),
        )
        .find(|candidate| candidate.exists())
}
/// Guesses a dataset group from an annotation file name.
///
/// * `instances_<group><digits>` yields `<group>` (e.g. `instances_val2017`
///   gives `val`), provided something remains after stripping the digits.
/// * Otherwise the first of `train`, `val`, `test`, `validation` that the
///   stem starts with is returned. Because `val` is tested before
///   `validation`, a `validation...` stem yields `val`.
///
/// Returns `None` when there is no UTF-8 file stem or nothing matches.
fn infer_group_from_filename(path: &Path) -> Option<String> {
    let stem = path.file_stem().and_then(|s| s.to_str())?;

    let from_instances = stem
        .strip_prefix("instances_")
        .map(|rest| rest.trim_end_matches(char::is_numeric))
        .filter(|group| !group.is_empty());
    if let Some(group) = from_instances {
        return Some(group.to_owned());
    }

    ["train", "val", "test", "validation"]
        .into_iter()
        .find(|prefix| stem.starts_with(*prefix))
        .map(str::to_owned)
}
fn read_coco_from_path(coco_path: &Path) -> Result<(CocoDataset, PathBuf), Error> {
if coco_path.is_dir() {
let datasets = read_coco_directory(coco_path, &CocoReadOptions::default())?;
log::info!("Found {} annotation files in directory", datasets.len());
let mut merged = CocoDataset::default();
for (mut ds, group) in datasets {
log::info!(
" - {} group: {} images, {} annotations",
group,
ds.images.len(),
ds.annotations.len()
);
for image in &mut ds.images {
if !image.file_name.contains('/') {
image.file_name = format!("{}2017/{}", group, image.file_name);
}
}
merge_coco_datasets(&mut merged, ds);
}
Ok((merged, coco_path.to_path_buf()))
} else if coco_path.extension().is_some_and(|e| e == "json") {
let reader = CocoReader::new();
let dataset = reader.read_json(coco_path)?;
let parent = coco_path
.parent()
.and_then(|p| p.parent()) .unwrap_or(Path::new("."));
Ok((dataset, parent.to_path_buf()))
} else {
Err(Error::InvalidParameters(
"COCO import requires a JSON annotation file or directory. \
ZIP archives must be extracted first."
.to_string(),
))
}
}
/// Splits `images` into those to import and counts of those left out.
///
/// An image is excluded by group when `group_filter` is set and the group
/// inferred from the image's folder differs from it. Of the remaining
/// images, those whose sample name is in `existing_names` are counted as
/// skipped.
///
/// # Returns
/// `(images_to_import, skipped, filtered_by_group)`.
fn filter_images_for_import<'a>(
    images: &'a [CocoImage],
    group_filter: Option<&str>,
    existing_names: &HashSet<String>,
) -> (Vec<&'a CocoImage>, usize, usize) {
    let mut pending = Vec::new();
    let mut skipped = 0usize;
    let mut excluded_by_group = 0usize;

    for image in images {
        if let Some(wanted) = group_filter {
            let inferred = super::reader::infer_group_from_folder(&image.file_name);
            if inferred.as_deref() != Some(wanted) {
                excluded_by_group += 1;
                continue;
            }
        }

        if existing_names.contains(&extract_sample_name(&image.file_name)) {
            skipped += 1;
        } else {
            pending.push(image);
        }
    }

    (pending, skipped, excluded_by_group)
}
/// Derives a sample name from an image file name: the final path component
/// without its last extension. Falls back to the whole input when the path
/// has no UTF-8 file stem.
fn extract_sample_name(file_name: &str) -> String {
    match Path::new(file_name).file_stem().and_then(|stem| stem.to_str()) {
        Some(stem) => stem.to_owned(),
        None => file_name.to_owned(),
    }
}
/// Folds `source` into `target`.
///
/// Images, categories and licenses are appended only when their id is not
/// already present in `target` (ids are compared against the state of
/// `target` before this call). Annotations are always appended. The `info`
/// section is taken from `source` only when `target` has no description and
/// `source` has one.
fn merge_coco_datasets(target: &mut CocoDataset, source: CocoDataset) {
    let known_images: HashSet<_> = target.images.iter().map(|i| i.id).collect();
    target.images.extend(
        source
            .images
            .into_iter()
            .filter(|image| !known_images.contains(&image.id)),
    );

    let known_categories: HashSet<_> = target.categories.iter().map(|c| c.id).collect();
    target.categories.extend(
        source
            .categories
            .into_iter()
            .filter(|category| !known_categories.contains(&category.id)),
    );

    target.annotations.extend(source.annotations);

    let known_licenses: HashSet<_> = target.licenses.iter().map(|l| l.id).collect();
    target.licenses.extend(
        source
            .licenses
            .into_iter()
            .filter(|license| !known_licenses.contains(&license.id)),
    );

    let target_lacks_description = target.info.description.is_none();
    if target_lacks_description && source.info.description.is_some() {
        target.info = source.info;
    }
}
/// Builds a `Sample` (with its annotations) from one COCO image.
///
/// * The sample name is the image file stem.
/// * Annotations whose category has no label name in `index` are dropped.
/// * Polygons are produced only when `include_masks` is set; a segmentation
///   that fails to convert leaves the annotation without a polygon.
/// * The image file is attached only when `include_images` is set and the
///   file is found under `images_dir`.
///
/// The function currently always returns `Ok`.
fn convert_coco_image_to_sample(
    image: &CocoImage,
    index: &CocoIndex,
    images_dir: &Path,
    include_masks: bool,
    include_images: bool,
    group: Option<&str>,
) -> Result<Sample, Error> {
    let sample_name = match Path::new(&image.file_name)
        .file_stem()
        .and_then(|stem| stem.to_str())
    {
        Some(stem) => stem.to_owned(),
        None => image.file_name.clone(),
    };

    let annotations = index
        .annotations_for_image(image.id)
        .iter()
        .filter_map(|coco_ann| {
            // Skip annotations that reference an unknown category.
            let label = index.label_name(coco_ann.category_id)?;
            let label_index = index.label_index(coco_ann.category_id);
            let box2d = coco_bbox_to_box2d(&coco_ann.bbox, image.width, image.height);

            let polygon = match coco_ann.segmentation.as_ref() {
                Some(seg) if include_masks => {
                    coco_segmentation_to_polygon(seg, image.width, image.height).ok()
                }
                _ => None,
            };

            let mut annotation = Annotation::new();
            annotation.set_name(Some(sample_name.clone()));
            annotation.set_label(Some(label.to_string()));
            annotation.set_label_index(label_index);
            annotation.set_box2d(Some(box2d));
            annotation.set_polygon(polygon);
            annotation.set_group(group.map(String::from));
            annotation.set_iscrowd(Some(coco_ann.iscrowd != 0));
            annotation
                .set_category_frequency(index.frequency(coco_ann.category_id).map(String::from));
            Some(annotation)
        })
        .collect();

    // Category ids that have no label index are dropped from both lists.
    let neg_label_indices = image.neg_category_ids.as_ref().map(|category_ids| {
        category_ids
            .iter()
            .filter_map(|&category_id| index.label_index(category_id).map(|idx| idx as u32))
            .collect::<Vec<u32>>()
    });
    let not_exhaustive_label_indices =
        image.not_exhaustive_category_ids.as_ref().map(|category_ids| {
            category_ids
                .iter()
                .filter_map(|&category_id| index.label_index(category_id).map(|idx| idx as u32))
                .collect::<Vec<u32>>()
        });

    let mut files = Vec::new();
    if include_images {
        if let Some(image_path) = find_image_file(images_dir, &image.file_name) {
            files.push(SampleFile::with_filename(
                FileType::Image.to_string(),
                image_path.to_string_lossy().to_string(),
            ));
        }
    }

    Ok(Sample {
        image_name: Some(sample_name),
        width: Some(image.width),
        height: Some(image.height),
        group: group.map(String::from),
        neg_label_indices,
        not_exhaustive_label_indices,
        files,
        annotations,
        ..Default::default()
    })
}
/// Exports the annotations of a Studio dataset to COCO format.
///
/// Samples are fetched from `client` with box and polygon annotations,
/// passing `options.groups` as the group argument, then rebuilt as a COCO
/// dataset with `CocoDatasetBuilder`. The result is written to `output_path`
/// as a ZIP archive when `options.output_zip` is set, otherwise as JSON.
///
/// # Returns
/// The number of annotations in the built COCO dataset.
///
/// # Errors
/// Propagates errors from fetching samples, downloading images and writing
/// the output.
pub async fn export_studio_to_coco(
client: &Client,
dataset_id: DatasetID,
annotation_set_id: AnnotationSetID,
output_path: impl AsRef<Path>,
options: &CocoExportOptions,
progress: Option<Sender<Progress>>,
) -> Result<usize, Error> {
let output_path = output_path.as_ref();
let groups: Vec<String> = options.groups.clone();
let annotation_types = [crate::AnnotationType::Box2d, crate::AnnotationType::Polygon];
let all_samples = client
.samples(
dataset_id,
Some(annotation_set_id),
&annotation_types,
&groups,
&[],
progress.clone(),
)
.await?;
let mut builder = CocoDatasetBuilder::new();
if let Some(info) = &options.info {
builder = builder.info(info.clone());
}
for sample in &all_samples {
// Samples without a name or dimensions fall back to "unknown" and 0.
let image_name = sample.image_name.as_deref().unwrap_or("unknown");
let width = sample.width.unwrap_or(0);
let height = sample.height.unwrap_or(0);
// A name containing no '.' is treated as having no extension and gets ".jpg".
let file_name = if image_name.contains('.') {
image_name.to_string()
} else {
format!("{}.jpg", image_name)
};
let image_id = builder.add_image(&file_name, width, height);
for ann in &sample.annotations {
// Prefer the stored box; otherwise derive one from the polygon.
// Annotations with neither (or a degenerate polygon) are not exported.
let bbox = if let Some(box2d) = ann.box2d() {
Some(box2d_to_coco_bbox(box2d, width, height))
} else if let Some(polygon) = ann.polygon() {
compute_bbox_from_polygon(polygon, width, height)
} else {
None
};
if let Some(bbox) = bbox {
// Unlabelled annotations are exported under the "unknown" category.
let label = ann.label().map(|s| s.as_str()).unwrap_or("unknown");
let category_id = builder.add_category(label, None);
let segmentation = if options.include_masks {
ann.polygon().map(|polygon| {
let coco_poly = polygon_to_coco_polygon(polygon, width, height);
CocoSegmentation::Polygon(coco_poly)
})
} else {
None
};
builder.add_annotation(image_id, category_id, bbox, segmentation);
}
}
}
let dataset = builder.build();
let annotation_count = dataset.annotations.len();
let writer = CocoWriter::with_options(CocoWriteOptions {
compress: true,
pretty: options.pretty_json,
});
if options.output_zip {
// Images are only downloaded for ZIP output; `include_images` has no
// effect on the JSON-only path.
let images = if options.include_images {
download_images(client, &all_samples, progress.clone()).await?
} else {
vec![]
};
writer.write_zip(&dataset, images.into_iter(), output_path)?;
} else {
writer.write_json(&dataset, output_path)?;
}
Ok(annotation_count)
}
async fn download_images(
client: &Client,
samples: &[Sample],
progress: Option<Sender<Progress>>,
) -> Result<Vec<(String, Vec<u8>)>, Error> {
let mut result = Vec::with_capacity(samples.len());
let total = samples.len();
for (i, sample) in samples.iter().enumerate() {
let image_url = sample.files.iter().find_map(|f| {
if f.file_type() == "image" {
f.url()
} else {
None
}
});
if let Some(url) = image_url {
match client.download(url).await {
Ok(data) => {
let name = sample.image_name.as_deref().unwrap_or("unknown");
let filename = if name.contains('.') {
format!("images/{}", name)
} else {
format!("images/{}.jpg", name)
};
result.push((filename, data));
}
Err(e) => {
log::warn!(
"Failed to download image for sample {:?}: {}",
sample.image_name,
e
);
}
}
}
if let Some(ref p) = progress {
let _ = p
.send(Progress {
current: i + 1,
total,
status: None,
})
.await;
}
}
Ok(result)
}
/// Settings that control `verify_coco_import`.
#[derive(Debug, Clone)]
pub struct CocoVerifyOptions {
    /// Compare segmentation masks in addition to bounding boxes.
    pub verify_masks: bool,
    /// Group used for the Studio sample query; takes precedence over the
    /// group inferred from the annotation file name.
    pub group: Option<String>,
}

impl Default for CocoVerifyOptions {
    /// Mask verification enabled, no explicit group.
    fn default() -> Self {
        CocoVerifyOptions {
            group: None,
            verify_masks: true,
        }
    }
}
/// Summary returned by `update_coco_annotations`.
#[derive(Debug, Clone)]
pub struct CocoUpdateResult {
/// Number of images in the COCO dataset that was read from disk.
pub total_images: usize,
/// Number of COCO images matched to an existing Studio sample by name;
/// these are the samples whose annotations were replaced.
pub updated: usize,
/// Number of COCO images with no matching Studio sample.
pub not_found: usize,
}
/// Settings that control `update_coco_annotations`.
#[derive(Debug, Clone)]
pub struct CocoUpdateOptions {
    /// Convert segmentations to polygons and replace existing `seg`
    /// annotations as well as `box` annotations.
    pub include_masks: bool,
    /// Optional group name.
    /// NOTE(review): not read by the update code visible in this file — confirm intent.
    pub group: Option<String>,
    /// Number of items per bulk delete / bulk add request.
    pub batch_size: usize,
    /// Requested concurrency.
    /// NOTE(review): not read by the update code visible in this file — confirm intent.
    pub concurrency: usize,
}

impl Default for CocoUpdateOptions {
    /// Masks included, no group, batches of 100, concurrency of 64.
    fn default() -> Self {
        CocoUpdateOptions {
            group: None,
            include_masks: true,
            concurrency: 64,
            batch_size: 100,
        }
    }
}
fn read_coco_dataset_for_update(coco_path: &Path) -> Result<CocoDataset, Error> {
if coco_path.is_dir() {
let datasets = read_coco_directory(coco_path, &CocoReadOptions::default())?;
log::info!("Found {} annotation files in directory", datasets.len());
let mut merged = CocoDataset::default();
for (mut ds, group) in datasets {
log::info!(
" - {} group: {} images, {} annotations",
group,
ds.images.len(),
ds.annotations.len()
);
for image in &mut ds.images {
if !image.file_name.contains('/') {
image.file_name = format!("{}2017/{}", group, image.file_name);
}
}
merge_coco_datasets(&mut merged, ds);
}
Ok(merged)
} else if coco_path.extension().is_some_and(|e| e == "json") {
let reader = CocoReader::new();
reader.read_json(coco_path)
} else {
Err(Error::InvalidParameters(
"COCO update requires a JSON annotation file or directory.".to_string(),
))
}
}
/// Indexes Studio samples by name.
///
/// Each value is `(sample id, width, height, group)`. Samples missing a
/// name, id, width or height are left out. When two samples share a name the
/// later one wins.
fn build_sample_info_map(
    samples: &[Sample],
) -> std::collections::HashMap<String, (crate::SampleID, u32, u32, Option<String>)> {
    samples
        .iter()
        .filter_map(|sample| {
            let name = sample.name()?;
            let id = sample.id()?;
            let width = sample.width?;
            let height = sample.height?;
            Some((name, (id, width, height, sample.group.clone())))
        })
        .collect()
}
/// Makes sure every COCO category has a label in the Studio dataset.
///
/// Labels missing from the dataset are created one at a time, then the
/// label list is fetched again to build the name → id map.
///
/// # Errors
/// Propagates errors from listing or creating labels.
async fn ensure_labels_exist(
    client: &Client,
    dataset_id: &DatasetID,
    categories: &[crate::coco::CocoCategory],
) -> Result<std::collections::HashMap<String, u64>, Error> {
    use std::collections::{HashMap, HashSet};

    let labels_before = client.labels(*dataset_id).await?;
    let mut known_names: HashSet<String> = HashSet::new();
    for label in labels_before.iter() {
        known_names.insert(label.name().to_string());
    }

    let mut missing_labels: Vec<String> = Vec::new();
    for category in categories {
        if !known_names.contains(&category.name) {
            missing_labels.push(category.name.clone());
        }
    }

    if !missing_labels.is_empty() {
        log::info!(
            "Creating {} missing labels in Studio...",
            missing_labels.len()
        );
        for label_name in &missing_labels {
            client.add_label(*dataset_id, label_name).await?;
        }
    }

    // Re-read so that freshly created labels are included with their ids.
    let labels_after = client.labels(*dataset_id).await?;
    let mut label_map: HashMap<String, u64> = HashMap::new();
    for label in labels_after.iter() {
        label_map.insert(label.name().to_string(), label.id());
    }

    log::info!(
        "Label map has {} entries for {} COCO categories",
        label_map.len(),
        categories.len()
    );
    Ok(label_map)
}
/// Converts one COCO annotation into the server's annotation record.
///
/// * The category name is looked up in `coco_index`; an unknown category
///   becomes `"unknown"`.
/// * `label_id` is taken from `label_map` by category name.
/// * The polygon string is produced only when `include_masks` is set and the
///   segmentation converts successfully. If every ring is discarded during
///   encoding (the encoder then yields `"[]"`), the annotation is treated as
///   having no mask, so it is sent as a `"box"` rather than a `"seg"` with no
///   geometry.
///
/// `dims` is `(width, height)` of the target sample.
///
/// # Returns
/// The server annotation and a flag that is `true` when no label id was
/// found for the category.
fn convert_coco_annotation_to_server(
    coco_ann: &super::types::CocoAnnotation,
    coco_index: &CocoIndex,
    label_map: &std::collections::HashMap<String, u64>,
    image_id: u64,
    annotation_set_id: u64,
    dims: (u32, u32),
    include_masks: bool,
) -> (crate::api::ServerAnnotation, bool) {
    let (width, height) = dims;

    let category_name = coco_index
        .categories
        .get(&coco_ann.category_id)
        .map(|c| c.name.as_str())
        .unwrap_or("unknown");
    let label_id = label_map.get(category_name).copied();
    let missing_label = label_id.is_none();

    let box2d = coco_bbox_to_box2d(&coco_ann.bbox, width, height);

    let polygon = if include_masks {
        coco_ann
            .segmentation
            .as_ref()
            .and_then(|seg| coco_segmentation_to_polygon(seg, width, height).ok())
            .map(|p| polygon_to_polygon_string(&p))
            // An encoded polygon with zero usable rings is "[]", which is not
            // an empty string; without this filter it would be typed "seg".
            .filter(|encoded| encoded.as_str() != "[]")
            .unwrap_or_default()
    } else {
        String::new()
    };
    let annotation_type = if polygon.is_empty() { "box" } else { "seg" }.to_string();

    let server_ann = crate::api::ServerAnnotation {
        label_id,
        label_index: None,
        label_name: Some(category_name.to_string()),
        annotation_type,
        x: box2d.left() as f64,
        y: box2d.top() as f64,
        w: box2d.width() as f64,
        h: box2d.height() as f64,
        score: 1.0,
        polygon,
        image_id,
        annotation_set_id,
        object_reference: None,
    };
    (server_ann, missing_label)
}
/// Prepares the update payload for one COCO image.
///
/// The image is matched to a Studio sample by sample name; `None` is
/// returned when there is no such sample. The Studio sample's stored
/// dimensions are used for coordinate conversion.
///
/// # Returns
/// `(sample id, server annotations, group to assign, missing-label count)`.
/// The group is `Some` only when a group can be inferred from the image's
/// folder and it differs from the sample's current group.
fn process_image_for_update(
    coco_image: &CocoImage,
    sample_info: &std::collections::HashMap<String, (crate::SampleID, u32, u32, Option<String>)>,
    coco_index: &CocoIndex,
    label_map: &std::collections::HashMap<String, u64>,
    annotation_set_id: u64,
    include_masks: bool,
) -> Option<(
    crate::SampleID,
    Vec<crate::api::ServerAnnotation>,
    Option<String>,
    usize,
)> {
    let sample_name = extract_sample_name(&coco_image.file_name);
    let expected_group = super::reader::infer_group_from_folder(&coco_image.file_name);

    let (stored_id, stored_width, stored_height, current_group) =
        sample_info.get(&sample_name)?;
    let sample_id = *stored_id;
    let dims = (*stored_width, *stored_height);
    let image_id: u64 = sample_id.into();

    // Only request a group change when the inferred group differs.
    let group_update = expected_group.filter(|expected| Some(expected) != current_group.as_ref());

    let mut annotations = Vec::new();
    let mut missing_label_count = 0;
    for coco_ann in coco_index.annotations_for_image(coco_image.id) {
        let (server_ann, missing) = convert_coco_annotation_to_server(
            coco_ann,
            coco_index,
            label_map,
            image_id,
            annotation_set_id,
            dims,
            include_masks,
        );
        missing_label_count += usize::from(missing);
        annotations.push(server_ann);
    }

    Some((sample_id, annotations, group_update, missing_label_count))
}
/// Assigns groups to samples, creating the groups when needed.
///
/// Each distinct group name is resolved to an id once. Samples whose group
/// could not be resolved are silently left unchanged. Individual failures
/// are logged (the first five in detail) and do not abort the run.
///
/// # Returns
/// The number of samples whose group was updated successfully.
async fn update_sample_groups(
    client: &Client,
    dataset_id: &DatasetID,
    samples_needing_group_update: &[(crate::SampleID, String)],
) -> usize {
    use std::collections::{HashMap, HashSet};

    if samples_needing_group_update.is_empty() {
        return 0;
    }
    log::info!(
        "Updating groups for {} samples...",
        samples_needing_group_update.len()
    );

    let distinct_groups: HashSet<String> = samples_needing_group_update
        .iter()
        .map(|(_, group)| group.clone())
        .collect();

    let mut group_ids: HashMap<String, u64> = HashMap::new();
    for group_name in distinct_groups {
        match client.get_or_create_group(*dataset_id, &group_name).await {
            Err(e) => {
                log::warn!("Failed to get/create group '{}': {}", group_name, e);
            }
            Ok(group_id) => {
                group_ids.insert(group_name, group_id);
            }
        }
    }

    let mut updated_count = 0;
    let mut failed_count = 0;
    for (sample_id, group_name) in samples_needing_group_update {
        let Some(&group_id) = group_ids.get(group_name) else {
            continue;
        };
        if let Err(e) = client.set_sample_group_id(*sample_id, group_id).await {
            failed_count += 1;
            // Cap per-sample warnings; the remainder is summarised below.
            if failed_count <= 5 {
                log::warn!("Failed to update group for sample {:?}: {}", sample_id, e);
            }
            continue;
        }
        updated_count += 1;
        if updated_count % 1000 == 0 {
            log::debug!("Updated groups for {} samples so far", updated_count);
        }
    }

    if failed_count > 5 {
        log::warn!("... and {} more group update failures", failed_count - 5);
    }
    log::info!(
        "Updated groups for {} samples ({} failed)",
        updated_count,
        failed_count
    );
    updated_count
}
pub async fn update_coco_annotations(
client: &Client,
coco_path: impl AsRef<Path>,
dataset_id: DatasetID,
annotation_set_id: AnnotationSetID,
options: &CocoUpdateOptions,
progress: Option<Sender<Progress>>,
) -> Result<CocoUpdateResult, Error> {
use crate::{SampleID, api::ServerAnnotation};
let coco_path = coco_path.as_ref();
let dataset = read_coco_dataset_for_update(coco_path)?;
let total_images = dataset.images.len();
if total_images == 0 {
return Err(Error::MissingAnnotations(
"No images found in COCO dataset".to_string(),
));
}
log::info!(
"COCO dataset: {} images, {} annotations, {} categories",
total_images,
dataset.annotations.len(),
dataset.categories.len()
);
log::info!("Fetching existing samples from Studio...");
let existing_samples = client
.samples(
dataset_id,
Some(annotation_set_id),
&[],
&[],
&[],
progress.clone(),
)
.await?;
let sample_info = build_sample_info_map(&existing_samples);
log::info!(
"Found {} existing samples in Studio with IDs and dimensions",
sample_info.len()
);
let coco_index = CocoIndex::from_dataset(&dataset);
let label_map = ensure_labels_exist(client, &dataset_id, &dataset.categories).await?;
let annotation_set_id_u64: u64 = annotation_set_id.into();
let mut sample_ids_to_update: Vec<SampleID> = Vec::new();
let mut server_annotations: Vec<ServerAnnotation> = Vec::new();
let mut samples_needing_group_update: Vec<(SampleID, String)> = Vec::new();
let mut not_found = 0;
let mut missing_label_count = 0;
for coco_image in &dataset.images {
match process_image_for_update(
coco_image,
&sample_info,
&coco_index,
&label_map,
annotation_set_id_u64,
options.include_masks,
) {
Some((sample_id, annotations, group_update, missing_labels)) => {
sample_ids_to_update.push(sample_id);
server_annotations.extend(annotations);
missing_label_count += missing_labels;
if let Some(group) = group_update {
samples_needing_group_update.push((sample_id, group));
}
}
None => {
not_found += 1;
log::debug!(
"Sample not found in Studio: {}",
extract_sample_name(&coco_image.file_name)
);
}
}
}
let to_update = sample_ids_to_update.len();
log::info!(
"Updating {} samples ({} not found in Studio), {} annotations",
to_update,
not_found,
server_annotations.len()
);
if missing_label_count > 0 {
log::warn!(
"{} annotations have missing label_id (category not found in label map)",
missing_label_count
);
}
if to_update == 0 {
return Ok(CocoUpdateResult {
total_images,
updated: 0,
not_found,
});
}
if let Some(ref tx) = progress {
let _ = tx
.send(Progress {
current: 0,
total: to_update,
status: None,
})
.await;
}
log::info!(
"Deleting existing annotations for {} samples...",
sample_ids_to_update.len()
);
let annotation_types = if options.include_masks {
vec!["box".to_string(), "seg".to_string()]
} else {
vec!["box".to_string()]
};
for batch in sample_ids_to_update.chunks(options.batch_size) {
client
.delete_annotations_bulk(annotation_set_id, &annotation_types, batch)
.await?;
}
if let Some(ref tx) = progress {
let _ = tx
.send(Progress {
current: to_update / 2,
total: to_update,
status: None,
})
.await;
}
log::info!("Adding {} new annotations...", server_annotations.len());
let mut added = 0;
for batch in server_annotations.chunks(options.batch_size) {
client
.add_annotations_bulk(annotation_set_id, batch.to_vec())
.await?;
added += batch.len();
log::debug!("Added {} annotations so far", added);
}
if let Some(ref tx) = progress {
let _ = tx
.send(Progress {
current: to_update,
total: to_update,
status: None,
})
.await;
}
let groups_updated =
update_sample_groups(client, &dataset_id, &samples_needing_group_update).await;
log::info!(
"Update complete: {} samples updated, {} not found, {} annotations added, {} groups updated",
to_update,
not_found,
added,
groups_updated
);
Ok(CocoUpdateResult {
total_images,
updated: to_update,
not_found,
})
}
/// Serialises a polygon as a JSON array of rings, each ring an array of
/// `[x, y]` pairs.
///
/// Points with a non-finite coordinate are dropped, and rings left with
/// fewer than three points are dropped entirely. A serialisation failure
/// yields an empty string.
fn polygon_to_polygon_string(polygon: &crate::Polygon) -> String {
    let mut rings: Vec<Vec<[f32; 2]>> = Vec::new();
    for ring in polygon.rings.iter() {
        let points: Vec<[f32; 2]> = ring
            .iter()
            .filter(|(x, y)| x.is_finite() && y.is_finite())
            .map(|&(x, y)| [x, y])
            .collect();
        // Fewer than three points cannot form a ring.
        if points.len() >= 3 {
            rings.push(points);
        }
    }
    serde_json::to_string(&rings).unwrap_or_default()
}
/// Computes a `[x, y, w, h]` box around all finite points of a polygon.
///
/// Polygon coordinates are multiplied by `width` / `height` to obtain the
/// box. Returns `None` when the polygon has no rings, has no finite points,
/// or the resulting box has zero (or negative) width or height.
fn compute_bbox_from_polygon(
    polygon: &crate::Polygon,
    width: u32,
    height: u32,
) -> Option<[f64; 4]> {
    if polygon.rings.is_empty() {
        return None;
    }

    // (min_x, min_y, max_x, max_y), seeded so that any finite point replaces them.
    let seed = (f32::MAX, f32::MAX, f32::MIN, f32::MIN);
    let (min_x, min_y, max_x, max_y) = polygon
        .rings
        .iter()
        .flatten()
        .filter(|(x, y)| x.is_finite() && y.is_finite())
        .fold(seed, |(lo_x, lo_y, hi_x, hi_y), &(x, y)| {
            (lo_x.min(x), lo_y.min(y), hi_x.max(x), hi_y.max(y))
        });

    if min_x == f32::MAX || min_y == f32::MAX {
        return None;
    }

    let scale_x = width as f32;
    let scale_y = height as f32;
    let x = (min_x * scale_x) as f64;
    let y = (min_y * scale_y) as f64;
    let w = ((max_x - min_x) * scale_x) as f64;
    let h = ((max_y - min_y) * scale_y) as f64;

    (w > 0.0 && h > 0.0).then_some([x, y, w, h])
}
/// Compares a local COCO dataset with what is stored in a Studio dataset.
///
/// The local dataset is read from `coco_path` (a directory of annotation
/// files, merged, or a single `.json` file). Studio samples are fetched with
/// box and polygon annotations and rebuilt as a COCO dataset so that both
/// sides can be passed to the validators from `super::verify`.
///
/// The group passed to the Studio query is `options.group` when set,
/// otherwise the group inferred from the JSON file name (never inferred for
/// directories).
///
/// # Returns
/// A `VerificationResult` with image/annotation counts, the image names
/// present on only one side, and the bbox, mask and category validation
/// results.
///
/// # Errors
/// `Error::InvalidParameters` when `coco_path` is neither a directory nor a
/// `.json` file, plus any reader or client error.
pub async fn verify_coco_import(
client: &Client,
coco_path: impl AsRef<Path>,
dataset_id: DatasetID,
annotation_set_id: AnnotationSetID,
options: &CocoVerifyOptions,
progress: Option<Sender<Progress>>,
) -> Result<super::verify::VerificationResult, Error> {
use super::{verify::*, writer::CocoDatasetBuilder};
let coco_path = coco_path.as_ref();
log::info!("Reading local COCO dataset from {:?}", coco_path);
let (coco_dataset, inferred_group) = if coco_path.is_dir() {
let datasets = read_coco_directory(coco_path, &CocoReadOptions::default())?;
log::info!("Found {} annotation files in directory", datasets.len());
let mut merged = CocoDataset::default();
for (ds, group) in datasets {
log::info!(
" - {} group: {} images, {} annotations",
group,
ds.images.len(),
ds.annotations.len()
);
merge_coco_datasets(&mut merged, ds);
}
// A merged directory spans several groups, so none is inferred.
(merged, None)
} else if coco_path.extension().is_some_and(|e| e == "json") {
let reader = CocoReader::new();
let dataset = reader.read_json(coco_path)?;
let group = infer_group_from_filename(coco_path);
(dataset, group)
} else {
return Err(Error::InvalidParameters(
"COCO verification requires a JSON annotation file or directory.".to_string(),
));
};
// An explicit group option takes precedence over the inferred one.
let effective_group = options.group.clone().or(inferred_group);
let groups: Vec<String> = effective_group
.as_ref()
.map(|g| vec![g.clone()])
.unwrap_or_default();
log::info!(
"Local COCO: {} images, {} annotations",
coco_dataset.images.len(),
coco_dataset.annotations.len()
);
log::info!("Fetching samples from Studio dataset {}...", dataset_id);
let annotation_types = [crate::AnnotationType::Box2d, crate::AnnotationType::Polygon];
let studio_samples = client
.samples(
dataset_id,
Some(annotation_set_id),
&annotation_types,
&groups,
&[],
progress.clone(),
)
.await?;
let total_annotations: usize = studio_samples.iter().map(|s| s.annotations.len()).sum();
log::info!(
"Studio: {} samples, {} total annotations",
studio_samples.len(),
total_annotations
);
// Rebuild the Studio samples as a COCO dataset for comparison.
let mut builder = CocoDatasetBuilder::new();
for sample in &studio_samples {
// Samples without a name or dimensions fall back to "unknown" and 0.
let image_name = sample.image_name.as_deref().unwrap_or("unknown");
let width = sample.width.unwrap_or(0);
let height = sample.height.unwrap_or(0);
// A name containing no '.' is treated as having no extension and gets ".jpg".
let file_name = if image_name.contains('.') {
image_name.to_string()
} else {
format!("{}.jpg", image_name)
};
let image_id = builder.add_image(&file_name, width, height);
for ann in &sample.annotations {
// Prefer the stored box; otherwise derive one from the polygon.
// Annotations with neither are left out of the comparison dataset.
let bbox = if let Some(box2d) = ann.box2d() {
Some(box2d_to_coco_bbox(box2d, width, height))
} else if let Some(polygon) = ann.polygon() {
compute_bbox_from_polygon(polygon, width, height)
} else {
None
};
if let Some(bbox) = bbox {
let label = ann.label().map(|s| s.as_str()).unwrap_or("unknown");
let category_id = builder.add_category(label, None);
let segmentation = if options.verify_masks {
ann.polygon().map(|polygon| {
let coco_poly = polygon_to_coco_polygon(polygon, width, height);
CocoSegmentation::Polygon(coco_poly)
})
} else {
None
};
builder.add_annotation(image_id, category_id, bbox, segmentation);
}
}
}
let studio_dataset = builder.build();
// Images are matched across the two sides by file stem / sample name.
let coco_names: HashSet<String> = coco_dataset
.images
.iter()
.map(|img| {
Path::new(&img.file_name)
.file_stem()
.and_then(|s| s.to_str())
.map(String::from)
.unwrap_or_else(|| img.file_name.clone())
})
.collect();
let studio_names: HashSet<String> = studio_samples.iter().filter_map(|s| s.name()).collect();
let missing_images: Vec<String> = coco_names.difference(&studio_names).cloned().collect();
let extra_images: Vec<String> = studio_names.difference(&coco_names).cloned().collect();
log::info!("Validating bounding boxes...");
let bbox_validation = validate_bboxes(&coco_dataset, &studio_dataset);
// NOTE(review): this message is logged even when `verify_masks` is false
// and mask validation is skipped.
log::info!("Validating segmentation masks...");
let mask_validation = if options.verify_masks {
validate_masks(&coco_dataset, &studio_dataset)
} else {
MaskValidationResult::new()
};
let category_validation = validate_categories(&coco_dataset, &studio_dataset);
Ok(VerificationResult {
coco_image_count: coco_dataset.images.len(),
studio_image_count: studio_samples.len(),
missing_images,
extra_images,
coco_annotation_count: coco_dataset.annotations.len(),
studio_annotation_count: studio_dataset.annotations.len(),
bbox_validation,
mask_validation,
category_validation,
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::coco::{CocoAnnotation, CocoCategory};
#[test]
fn test_coco_import_options_default() {
let options = CocoImportOptions::default();
assert!(options.include_masks);
assert!(options.include_images);
assert!(options.group.is_none());
assert_eq!(options.batch_size, 100);
assert_eq!(options.concurrency, 64);
assert!(options.resume);
}
#[test]
fn test_coco_export_options_default() {
let options = CocoExportOptions::default();
assert!(options.groups.is_empty());
assert!(options.include_masks);
assert!(!options.include_images);
assert!(!options.output_zip);
assert!(!options.pretty_json);
assert!(options.info.is_none());
}
#[test]
fn test_coco_update_options_default() {
let options = CocoUpdateOptions::default();
assert!(options.include_masks);
assert!(options.group.is_none());
assert_eq!(options.batch_size, 100);
assert_eq!(options.concurrency, 64);
}
/// Default verification options enable mask checks and select no group.
#[test]
fn test_coco_verify_options_default() {
    let CocoVerifyOptions {
        verify_masks,
        group,
        ..
    } = CocoVerifyOptions::default();

    assert!(verify_masks);
    assert!(group.is_none());
}
/// Looking for an image under a directory that does not exist yields `None`.
#[test]
fn test_find_image_file_nonexistent() {
    let missing_dir = Path::new("/nonexistent");
    assert!(find_image_file(missing_dir, "test.jpg").is_none());
}
/// A file name that itself contains a sub-directory component still resolves to `None`
/// when the base directory does not exist.
#[test]
fn test_find_image_file_with_subdirectory_in_name() {
    let missing_dir = Path::new("/nonexistent");
    let nested_name = "train2017/image.jpg";
    assert!(find_image_file(missing_dir, nested_name).is_none());
}
/// `instances_train2017.json` maps to the "train" group.
#[test]
fn test_infer_group_from_filename_instances_train() {
    let inferred = infer_group_from_filename(Path::new("annotations/instances_train2017.json"));
    assert_eq!(inferred.as_deref(), Some("train"));
}
/// `instances_val2017.json` maps to the "val" group.
#[test]
fn test_infer_group_from_filename_instances_val() {
    let inferred = infer_group_from_filename(Path::new("annotations/instances_val2017.json"));
    assert_eq!(inferred.as_deref(), Some("val"));
}
/// `instances_test2017.json` (no parent directory) maps to the "test" group.
#[test]
fn test_infer_group_from_filename_instances_test() {
    let inferred = infer_group_from_filename(Path::new("instances_test2017.json"));
    assert_eq!(inferred.as_deref(), Some("test"));
}
/// A file name starting with `train_` maps to the "train" group.
#[test]
fn test_infer_group_from_filename_train_prefix() {
    let inferred = infer_group_from_filename(Path::new("train_annotations.json"));
    assert_eq!(inferred.as_deref(), Some("train"));
}
/// A file name starting with `val_` maps to the "val" group.
#[test]
fn test_infer_group_from_filename_val_prefix() {
    let inferred = infer_group_from_filename(Path::new("val_data.json"));
    assert_eq!(inferred.as_deref(), Some("val"));
}
/// A file name starting with `validation_` is normalised to the "val" group.
#[test]
fn test_infer_group_from_filename_validation_prefix() {
    let inferred = infer_group_from_filename(Path::new("validation_set.json"));
    assert_eq!(inferred.as_deref(), Some("val"));
}
/// A file name with no recognisable split marker yields no group.
#[test]
fn test_infer_group_from_filename_custom() {
    let inferred = infer_group_from_filename(Path::new("my_custom_annotations.json"));
    assert!(inferred.is_none());
}
/// The year suffix does not matter: `instances_val2014.json` still maps to "val".
#[test]
fn test_infer_group_from_filename_instances_2014() {
    let inferred = infer_group_from_filename(Path::new("instances_val2014.json"));
    assert_eq!(inferred.as_deref(), Some("val"));
}
/// Merging one empty dataset into another leaves the target empty.
#[test]
fn test_merge_coco_datasets_empty() {
    let mut merged = CocoDataset::default();
    merge_coco_datasets(&mut merged, CocoDataset::default());

    let sizes = (
        merged.images.len(),
        merged.annotations.len(),
        merged.categories.len(),
    );
    assert_eq!(sizes, (0, 0, 0));
}
/// Merging two datasets with disjoint IDs leaves the target holding two images,
/// two categories and two annotations.
#[test]
fn test_merge_coco_datasets_basic() {
    // Target side: one "cat" image with a single annotation.
    let cat_image = CocoImage {
        id: 1,
        file_name: "img1.jpg".to_string(),
        ..Default::default()
    };
    let cat_category = CocoCategory {
        id: 1,
        name: "cat".to_string(),
        supercategory: None,
        ..Default::default()
    };
    let cat_annotation = CocoAnnotation {
        id: 1,
        image_id: 1,
        category_id: 1,
        ..Default::default()
    };
    let mut target = CocoDataset {
        images: vec![cat_image],
        categories: vec![cat_category],
        annotations: vec![cat_annotation],
        ..Default::default()
    };

    // Source side: one "dog" image with a single annotation, all IDs distinct from the target.
    let dog_image = CocoImage {
        id: 2,
        file_name: "img2.jpg".to_string(),
        ..Default::default()
    };
    let dog_category = CocoCategory {
        id: 2,
        name: "dog".to_string(),
        supercategory: None,
        ..Default::default()
    };
    let dog_annotation = CocoAnnotation {
        id: 2,
        image_id: 2,
        category_id: 2,
        ..Default::default()
    };
    let source = CocoDataset {
        images: vec![dog_image],
        categories: vec![dog_category],
        annotations: vec![dog_annotation],
        ..Default::default()
    };

    merge_coco_datasets(&mut target, source);

    let counts = (
        target.images.len(),
        target.categories.len(),
        target.annotations.len(),
    );
    assert_eq!(counts, (2, 2, 2));
}
/// A source image whose ID collides with one already in the target must be dropped,
/// while the non-colliding source image is still merged in.
///
/// Besides the length and first-element checks, this also asserts which entries
/// survived, so a merge that kept the duplicate and lost the new image cannot pass.
#[test]
fn test_merge_coco_datasets_deduplicates_images() {
    let mut target = CocoDataset {
        images: vec![CocoImage {
            id: 1,
            file_name: "img1.jpg".to_string(),
            ..Default::default()
        }],
        ..Default::default()
    };
    let source = CocoDataset {
        images: vec![
            // Same ID as the target's image but a different file name.
            CocoImage {
                id: 1,
                file_name: "img1_dup.jpg".to_string(),
                ..Default::default()
            },
            CocoImage {
                id: 2,
                file_name: "img2.jpg".to_string(),
                ..Default::default()
            },
        ],
        ..Default::default()
    };

    merge_coco_datasets(&mut target, source);

    assert_eq!(target.images.len(), 2);
    // The entry that was already in the target wins over the colliding one.
    assert_eq!(target.images[0].file_name, "img1.jpg");
    assert!(
        target
            .images
            .iter()
            .all(|img| img.file_name != "img1_dup.jpg")
    );
    // The genuinely new image made it across (order-independent check).
    assert!(
        target
            .images
            .iter()
            .any(|img| img.id == 2 && img.file_name == "img2.jpg")
    );
}
/// A source category whose ID collides with one already in the target must be dropped,
/// while the non-colliding source category is still merged in.
///
/// Besides the length and first-element checks, this also asserts which entries
/// survived, so a merge that kept the duplicate and lost the new category cannot pass.
#[test]
fn test_merge_coco_datasets_deduplicates_categories() {
    let mut target = CocoDataset {
        categories: vec![CocoCategory {
            id: 1,
            name: "person".to_string(),
            supercategory: None,
            ..Default::default()
        }],
        ..Default::default()
    };
    let source = CocoDataset {
        categories: vec![
            // Same ID as the target's category but a different name.
            CocoCategory {
                id: 1,
                name: "person_dup".to_string(),
                supercategory: None,
                ..Default::default()
            },
            CocoCategory {
                id: 2,
                name: "car".to_string(),
                supercategory: None,
                ..Default::default()
            },
        ],
        ..Default::default()
    };

    merge_coco_datasets(&mut target, source);

    assert_eq!(target.categories.len(), 2);
    // The entry that was already in the target wins over the colliding one.
    assert_eq!(target.categories[0].name, "person");
    assert!(
        target
            .categories
            .iter()
            .all(|cat| cat.name != "person_dup")
    );
    // The genuinely new category made it across (order-independent check).
    assert!(
        target
            .categories
            .iter()
            .any(|cat| cat.id == 2 && cat.name == "car")
    );
}
/// After merging into a default target, the target's info carries the source's description.
#[test]
fn test_merge_coco_datasets_info_preserved() {
    let described_info = CocoInfo {
        description: Some("Test dataset".to_string()),
        ..Default::default()
    };
    let source = CocoDataset {
        info: described_info,
        ..Default::default()
    };
    let mut target = CocoDataset::default();

    merge_coco_datasets(&mut target, source);

    assert_eq!(target.info.description.as_deref(), Some("Test dataset"));
}
/// Converts a 640x480 image carrying one "person" box and checks the resulting sample's
/// name, dimensions, group and annotation label.
#[test]
fn test_convert_coco_image_to_sample() {
    let dataset = CocoDataset {
        images: vec![CocoImage {
            id: 1,
            width: 640,
            height: 480,
            file_name: "test.jpg".to_string(),
            ..Default::default()
        }],
        categories: vec![CocoCategory {
            id: 1,
            name: "person".to_string(),
            supercategory: None,
            ..Default::default()
        }],
        annotations: vec![CocoAnnotation {
            id: 1,
            image_id: 1,
            category_id: 1,
            bbox: [100.0, 50.0, 200.0, 150.0],
            area: 30000.0,
            iscrowd: 0,
            segmentation: None,
            score: None,
        }],
        ..Default::default()
    };
    let index = CocoIndex::from_dataset(&dataset);
    let image = &dataset.images[0];

    // NOTE(review): the two booleans are presumably (include_masks, include_images) —
    // confirm against the function signature.
    let sample = convert_coco_image_to_sample(
        image,
        &index,
        Path::new("/tmp"),
        true,
        false,
        Some("train"),
    )
    .unwrap();

    // The sample name is the file name without its extension.
    assert_eq!(sample.image_name.as_deref(), Some("test"));
    assert_eq!((sample.width, sample.height), (Some(640), Some(480)));
    assert_eq!(sample.group.as_deref(), Some("train"));
    assert_eq!(sample.annotations.len(), 1);
    assert_eq!(
        sample.annotations[0].label().map(String::as_str),
        Some("person")
    );
}
/// An image without any annotations converts to a sample with an empty annotation list.
#[test]
fn test_convert_coco_image_to_sample_no_annotations() {
    let bare_image = CocoImage {
        id: 1,
        width: 640,
        height: 480,
        file_name: "empty.jpg".to_string(),
        ..Default::default()
    };
    let dataset = CocoDataset {
        images: vec![bare_image.clone()],
        categories: Vec::new(),
        annotations: Vec::new(),
        ..Default::default()
    };
    let index = CocoIndex::from_dataset(&dataset);

    let converted = convert_coco_image_to_sample(
        &bare_image,
        &index,
        Path::new("/tmp"),
        true,
        false,
        None,
    );
    let sample = converted.unwrap();

    assert_eq!(sample.image_name.as_deref(), Some("empty"));
    assert_eq!(sample.annotations.len(), 0);
}
/// The fourth argument of `convert_coco_image_to_sample` decides whether the polygon
/// segmentation is carried over onto the converted annotation.
#[test]
fn test_convert_coco_image_to_sample_with_mask() {
    let image = CocoImage {
        id: 1,
        width: 100,
        height: 100,
        file_name: "masked.jpg".to_string(),
        ..Default::default()
    };
    // A square outline matching the bounding box below.
    let square_outline = vec![10.0, 10.0, 60.0, 10.0, 60.0, 60.0, 10.0, 60.0];
    let dataset = CocoDataset {
        images: vec![image.clone()],
        categories: vec![CocoCategory {
            id: 1,
            name: "object".to_string(),
            supercategory: None,
            ..Default::default()
        }],
        annotations: vec![CocoAnnotation {
            id: 1,
            image_id: 1,
            category_id: 1,
            bbox: [10.0, 10.0, 50.0, 50.0],
            area: 2500.0,
            iscrowd: 0,
            segmentation: Some(CocoSegmentation::Polygon(vec![square_outline])),
            score: None,
        }],
        ..Default::default()
    };
    let index = CocoIndex::from_dataset(&dataset);

    // Same conversion both times; only the mask flag differs.
    let convert = |with_masks: bool| {
        convert_coco_image_to_sample(&image, &index, Path::new("/tmp"), with_masks, false, None)
            .unwrap()
    };

    let sample_with = convert(true);
    assert!(sample_with.annotations[0].polygon().is_some());

    let sample_without = convert(false);
    assert!(sample_without.annotations[0].polygon().is_none());
}
/// A normalised square from 0.1 to 0.5 on a 100x100 image gives roughly `[10, 10, 40, 40]`.
#[test]
fn test_compute_bbox_from_polygon_simple() {
    let square = vec![(0.1, 0.1), (0.5, 0.1), (0.5, 0.5), (0.1, 0.5)];
    let mask = crate::Polygon::new(vec![square]);

    let bbox = compute_bbox_from_polygon(&mask, 100, 100);
    assert!(bbox.is_some());

    let [left, top, width, height] = bbox.unwrap();
    // Each component must land within one pixel of the expected value.
    for delta in [left - 10.0, top - 10.0, width - 40.0, height - 40.0] {
        assert!(delta.abs() < 1.0);
    }
}
/// A polygon with no rings has no bounding box.
#[test]
fn test_compute_bbox_from_polygon_empty() {
    let ringless = crate::Polygon::new(Vec::new());
    assert!(compute_bbox_from_polygon(&ringless, 100, 100).is_none());
}
/// A ring made up solely of NaN coordinates has no bounding box.
#[test]
fn test_compute_bbox_from_polygon_with_nan() {
    let nan_point = (f32::NAN, f32::NAN);
    let all_nan = crate::Polygon::new(vec![vec![nan_point, nan_point]]);
    assert!(compute_bbox_from_polygon(&all_nan, 100, 100).is_none());
}
/// With two separate rings the bounding box spans both of them: roughly `[10, 10, 80, 80]`
/// on a 100x100 image.
#[test]
fn test_compute_bbox_from_polygon_multiple_rings() {
    let near_origin = vec![(0.1, 0.1), (0.2, 0.1), (0.2, 0.2), (0.1, 0.2)];
    let far_corner = vec![(0.8, 0.8), (0.9, 0.8), (0.9, 0.9), (0.8, 0.9)];
    let mask = crate::Polygon::new(vec![near_origin, far_corner]);

    let bbox = compute_bbox_from_polygon(&mask, 100, 100);
    assert!(bbox.is_some());

    let [left, top, width, height] = bbox.unwrap();
    // Each component must land within one pixel of the expected value.
    for delta in [left - 10.0, top - 10.0, width - 80.0, height - 80.0] {
        assert!(delta.abs() < 1.0);
    }
}
/// A single three-point ring is rendered as one nested list of `[x,y]` pairs.
#[test]
fn test_polygon_to_polygon_string() {
    let triangle = vec![(0.1, 0.2), (0.3, 0.4), (0.5, 0.6)];
    let mask = crate::Polygon::new(vec![triangle]);

    let rendered = polygon_to_polygon_string(&mask);

    assert_eq!(rendered, "[[[0.1,0.2],[0.3,0.4],[0.5,0.6]]]");
}
/// Two rings are rendered as two sibling lists, in their original order.
#[test]
fn test_polygon_to_polygon_string_multiple_rings() {
    let first_ring = vec![(0.1, 0.1), (0.2, 0.1), (0.15, 0.2)];
    let second_ring = vec![(0.5, 0.5), (0.6, 0.5), (0.55, 0.6)];
    let mask = crate::Polygon::new(vec![first_ring, second_ring]);

    let rendered = polygon_to_polygon_string(&mask);

    let expected = "[[[0.1,0.1],[0.2,0.1],[0.15,0.2]],[[0.5,0.5],[0.6,0.5],[0.55,0.6]]]";
    assert_eq!(rendered, expected);
}
/// A point with a NaN coordinate is dropped from the output instead of being emitted
/// as `null`.
#[test]
fn test_polygon_to_polygon_string_filters_nan_values() {
    // The second point is the invalid one.
    let ring = vec![(0.1, 0.2), (f32::NAN, 0.4), (0.3, 0.4), (0.5, 0.6)];
    let mask = crate::Polygon::new(vec![ring]);

    let rendered = polygon_to_polygon_string(&mask);

    assert!(
        !rendered.contains("null"),
        "NaN values should be filtered out, got: {}",
        rendered
    );
    assert_eq!(rendered, "[[[0.1,0.2],[0.3,0.4],[0.5,0.6]]]");
}
/// A point with a positive-infinity coordinate is dropped from the output.
#[test]
fn test_polygon_to_polygon_string_filters_infinity() {
    // The second point is the invalid one.
    let ring = vec![(0.1, 0.2), (f32::INFINITY, 0.4), (0.3, 0.4), (0.5, 0.6)];
    let mask = crate::Polygon::new(vec![ring]);

    let rendered = polygon_to_polygon_string(&mask);

    assert!(
        !rendered.contains("null"),
        "Infinity values should be filtered out"
    );
    assert_eq!(rendered, "[[[0.1,0.2],[0.3,0.4],[0.5,0.6]]]");
}
/// When filtering leaves a ring with a single valid point, the output is the empty list `[]`.
#[test]
fn test_polygon_to_polygon_string_too_few_points_after_filter() {
    // Only the first point is fully finite.
    let ring = vec![(0.1, 0.2), (f32::NAN, 0.4), (f32::NAN, f32::NAN)];
    let mask = crate::Polygon::new(vec![ring]);

    let rendered = polygon_to_polygon_string(&mask);

    assert_eq!(rendered, "[]");
}
/// A point with a negative-infinity coordinate is dropped from the output.
#[test]
fn test_polygon_to_polygon_string_negative_infinity() {
    // The second point is the invalid one.
    let ring = vec![
        (0.1, 0.2),
        (f32::NEG_INFINITY, 0.4),
        (0.3, 0.4),
        (0.5, 0.6),
    ];
    let mask = crate::Polygon::new(vec![ring]);

    let rendered = polygon_to_polygon_string(&mask);

    assert_eq!(rendered, "[[[0.1,0.2],[0.3,0.4],[0.5,0.6]]]");
}
/// `CocoImportResult` stores the three counters it was built with.
#[test]
fn test_coco_import_result() {
    let summary = CocoImportResult {
        total_images: 100,
        skipped: 30,
        imported: 70,
    };

    assert_eq!(
        (summary.total_images, summary.skipped, summary.imported),
        (100, 30, 70)
    );
}
/// `CocoUpdateResult` stores the three counters it was built with.
#[test]
fn test_coco_update_result() {
    let summary = CocoUpdateResult {
        total_images: 500,
        updated: 450,
        not_found: 50,
    };

    assert_eq!(
        (summary.total_images, summary.updated, summary.not_found),
        (500, 450, 50)
    );
}
/// A path with a non-JSON extension is rejected with a message that names the requirement.
#[test]
fn test_read_coco_dataset_for_update_invalid_extension() {
    let outcome = read_coco_dataset_for_update(Path::new("/tmp/file.txt"));
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("COCO update requires a JSON annotation file"));
}
/// A `.json` path that does not exist produces an error.
#[test]
fn test_read_coco_dataset_for_update_nonexistent_json() {
    let missing_json = Path::new("/nonexistent/file.json");
    assert!(read_coco_dataset_for_update(missing_json).is_err());
}
/// An extension-less path that does not exist produces an error.
#[test]
fn test_read_coco_dataset_for_update_nonexistent_directory() {
    let missing_dir = Path::new("/nonexistent_dir");
    assert!(read_coco_dataset_for_update(missing_dir).is_err());
}
/// No samples in, empty map out.
#[test]
fn test_build_sample_info_map_empty() {
    let no_samples = Vec::<crate::Sample>::new();
    let lookup = build_sample_info_map(&no_samples);
    assert!(lookup.is_empty());
}
/// Fully populated samples are keyed by image name and map to
/// `(id, width, height, group)` tuples.
#[test]
fn test_build_sample_info_map_with_samples() {
    use crate::{Sample, SampleID};

    let samples = vec![
        Sample {
            image_name: Some("sample1".to_string()),
            id: Some(SampleID::from(1)),
            width: Some(640),
            height: Some(480),
            group: Some("train".to_string()),
            ..Default::default()
        },
        Sample {
            image_name: Some("sample2".to_string()),
            id: Some(SampleID::from(2)),
            width: Some(1280),
            height: Some(720),
            group: None,
            ..Default::default()
        },
    ];

    let lookup = build_sample_info_map(&samples);

    assert_eq!(lookup.len(), 2);
    for name in ["sample1", "sample2"] {
        assert!(lookup.contains_key(name));
    }

    // First sample: carries a group.
    let (first_id, first_w, first_h, first_group) = lookup.get("sample1").unwrap();
    assert_eq!(*first_id, SampleID::from(1));
    assert_eq!((*first_w, *first_h), (640, 480));
    assert_eq!(first_group.as_deref(), Some("train"));

    // Second sample: no group assigned.
    let (second_id, second_w, second_h, second_group) = lookup.get("sample2").unwrap();
    assert_eq!(*second_id, SampleID::from(2));
    assert_eq!((*second_w, *second_h), (1280, 720));
    assert!(second_group.is_none());
}
/// Samples lacking an ID, a name, or dimensions are all left out of the map.
#[test]
fn test_build_sample_info_map_skips_incomplete_samples() {
    use crate::Sample;

    let incomplete = vec![
        // Missing ID.
        Sample {
            image_name: Some("no_id".to_string()),
            width: Some(640),
            height: Some(480),
            ..Default::default()
        },
        // Missing image name.
        Sample {
            id: Some(crate::SampleID::from(1)),
            width: Some(640),
            height: Some(480),
            ..Default::default()
        },
        // Missing width and height.
        Sample {
            image_name: Some("no_dims".to_string()),
            id: Some(crate::SampleID::from(2)),
            ..Default::default()
        },
    ];

    let lookup = build_sample_info_map(&incomplete);

    assert!(lookup.is_empty());
}
/// A clone of the default import options agrees with the original on batch size,
/// concurrency and the mask flag.
#[test]
fn test_coco_import_options_clone() {
    let original = CocoImportOptions::default();
    let copy = original.clone();

    assert_eq!(
        (original.batch_size, original.concurrency, original.include_masks),
        (copy.batch_size, copy.concurrency, copy.include_masks)
    );
}
/// Explicitly supplied import options are stored as given.
#[test]
fn test_coco_import_options_custom() {
    let custom = CocoImportOptions {
        include_masks: false,
        include_images: false,
        group: Some("test".to_string()),
        batch_size: 50,
        concurrency: 32,
        resume: false,
    };

    // All three boolean switches were turned off.
    assert_eq!(
        [custom.include_masks, custom.include_images, custom.resume],
        [false, false, false]
    );
    assert_eq!(custom.group, Some("test".to_string()));
    assert_eq!((custom.batch_size, custom.concurrency), (50, 32));
}
/// Explicitly supplied update options are stored as given.
#[test]
fn test_coco_update_options_custom() {
    let custom = CocoUpdateOptions {
        include_masks: false,
        group: Some("val".to_string()),
        batch_size: 25,
        concurrency: 16,
    };

    assert!(!custom.include_masks);
    assert_eq!(custom.group, Some("val".to_string()));
    assert_eq!((custom.batch_size, custom.concurrency), (25, 16));
}
/// A plain file name loses its extension.
#[test]
fn test_extract_sample_name_simple() {
    let name = extract_sample_name("image.jpg");
    assert_eq!(name, "image");
}
/// A leading directory component is stripped along with the extension.
#[test]
fn test_extract_sample_name_with_path() {
    let name = extract_sample_name("train2017/000001.jpg");
    assert_eq!(name, "000001");
}
/// A name without an extension is returned unchanged.
#[test]
fn test_extract_sample_name_no_extension() {
    let name = extract_sample_name("image");
    assert_eq!(name, "image");
}
/// Only the final extension is removed; earlier dots stay part of the name.
#[test]
fn test_extract_sample_name_multiple_dots() {
    let name = extract_sample_name("image.v2.final.jpg");
    assert_eq!(name, "image.v2.final");
}
}