use super::types::CocoDataset;
use crate::Error;
use std::{
collections::HashSet,
fs::File,
io::{BufReader, Read},
path::Path,
};
/// Options that control how a [`CocoReader`] loads and post-processes a dataset.
///
/// The derived `Default` yields `validate = false`, `max_images = 0` and an
/// empty `category_filter`.
#[derive(Debug, Clone, Default)]
pub struct CocoReadOptions {
/// When `true`, the reader runs `validate_dataset` on the loaded data and
/// returns its error if the check fails.
pub validate: bool,
/// Upper bound on the number of images kept by the reader's filters.
/// `0` disables the limit.
pub max_images: usize,
/// Category names to keep. An empty list disables category filtering.
pub category_filter: Vec<String>,
}
/// Reader for COCO-style annotation files and image listings.
///
/// Holds the [`CocoReadOptions`] that its read methods consult.
pub struct CocoReader {
// Options applied by the read methods (validation and filtering).
options: CocoReadOptions,
}
impl CocoReader {
    /// Creates a reader that uses `CocoReadOptions::default()`.
    pub fn new() -> Self {
        Self {
            options: CocoReadOptions::default(),
        }
    }

    /// Creates a reader that uses the supplied `options`.
    pub fn with_options(options: CocoReadOptions) -> Self {
        Self { options }
    }

    /// Reads a single COCO JSON annotation file.
    ///
    /// After parsing, images with an empty `file_name` get one derived from
    /// their `coco_url` where possible. The dataset is then validated (only
    /// when `options.validate` is set) and finally passed through the
    /// configured filters.
    ///
    /// # Errors
    ///
    /// Propagates failures from opening the file, from JSON deserialization
    /// and from validation.
    pub fn read_json<P: AsRef<Path>>(&self, path: P) -> Result<CocoDataset, Error> {
        let file = File::open(path.as_ref())?;
        let reader = BufReader::with_capacity(64 * 1024, file);
        let mut dataset: CocoDataset = serde_json::from_reader(reader)?;
        fill_missing_file_names(&mut dataset);
        if self.options.validate {
            validate_dataset(&dataset)?;
        }
        Ok(self.apply_filters(dataset))
    }

    /// Reads every `*.json` entry whose name contains `instances` from a zip
    /// archive and merges them into one dataset.
    ///
    /// Each entry gets the same `file_name` back-filling as [`Self::read_json`].
    /// Validation (when enabled) and filtering run once, on the merged result.
    ///
    /// # Errors
    ///
    /// Propagates failures from opening the archive, reading an entry, JSON
    /// deserialization and validation.
    pub fn read_annotations_zip<P: AsRef<Path>>(&self, path: P) -> Result<CocoDataset, Error> {
        let file = File::open(path.as_ref())?;
        let mut archive = zip::ZipArchive::new(file)?;
        let mut merged = CocoDataset::default();
        for i in 0..archive.len() {
            let mut entry = archive.by_index(i)?;
            let name = entry.name().to_string();
            if name.ends_with(".json") && name.contains("instances") {
                let mut contents = String::new();
                entry.read_to_string(&mut contents)?;
                let mut dataset: CocoDataset = serde_json::from_str(&contents)?;
                fill_missing_file_names(&mut dataset);
                merge_datasets(&mut merged, dataset);
            }
        }
        if self.options.validate {
            validate_dataset(&merged)?;
        }
        Ok(self.apply_filters(merged))
    }

    /// Lists `.jpg` / `.jpeg` / `.png` images (case-insensitive) under `path`.
    ///
    /// * If `path` is a directory it is walked with `walkdir`; each hit is
    ///   returned as `(path relative to the directory, full path)`.
    ///   Directory entries that cannot be read are skipped.
    /// * If `path` has a `zip` extension (matched case-insensitively, like the
    ///   image extensions) the archive is scanned; each hit is returned as
    ///   `(entry name, archive path joined with the entry name)`.
    /// * Otherwise an empty list is returned.
    ///
    /// # Errors
    ///
    /// Propagates failures from opening or reading the zip archive.
    pub fn list_images<P: AsRef<Path>>(
        &self,
        path: P,
    ) -> Result<Vec<(String, std::path::PathBuf)>, Error> {
        let path = path.as_ref();
        let mut images = Vec::new();
        if path.is_dir() {
            // Best effort: unreadable entries are dropped instead of aborting the walk.
            let files = walkdir::WalkDir::new(path)
                .into_iter()
                .filter_map(|e| e.ok())
                .filter(|e| e.file_type().is_file());
            for entry in files {
                if !Self::has_image_extension(&entry.file_name().to_string_lossy()) {
                    continue;
                }
                let rel_path = entry
                    .path()
                    .strip_prefix(path)
                    .unwrap_or(entry.path())
                    .to_string_lossy()
                    .to_string();
                images.push((rel_path, entry.path().to_path_buf()));
            }
        } else if path
            .extension()
            .is_some_and(|e| e.eq_ignore_ascii_case("zip"))
        {
            let file = File::open(path)?;
            let mut archive = zip::ZipArchive::new(file)?;
            for i in 0..archive.len() {
                let entry = archive.by_index(i)?;
                if entry.is_dir() || !Self::has_image_extension(entry.name()) {
                    continue;
                }
                let name = entry.name().to_string();
                let location = path.join(&name);
                images.push((name, location));
            }
        }
        Ok(images)
    }

    /// Reads the raw bytes of the entry `image_name` from the zip archive at
    /// `zip_path`.
    ///
    /// # Errors
    ///
    /// Propagates failures from opening the archive, looking up the entry and
    /// reading it.
    pub fn read_image_from_zip<P: AsRef<Path>>(
        &self,
        zip_path: P,
        image_name: &str,
    ) -> Result<Vec<u8>, Error> {
        // The declared size comes from archive metadata, which may be wrong or
        // hostile. Treat it only as a bounded capacity hint so a bogus header
        // cannot force a huge up-front allocation; `read_to_end` grows the
        // buffer as needed for genuinely large entries.
        const MAX_PREALLOC_BYTES: u64 = 16 * 1024 * 1024;

        let file = File::open(zip_path.as_ref())?;
        let mut archive = zip::ZipArchive::new(file)?;
        let mut entry = archive.by_name(image_name)?;
        // Lossless cast: the value is capped at 16 MiB first.
        let capacity_hint = entry.size().min(MAX_PREALLOC_BYTES) as usize;
        let mut buffer = Vec::with_capacity(capacity_hint);
        entry.read_to_end(&mut buffer)?;
        Ok(buffer)
    }

    /// Returns `true` when `name` ends in `.jpg`, `.jpeg` or `.png`, ignoring case.
    fn has_image_extension(name: &str) -> bool {
        let lower = name.to_lowercase();
        [".jpg", ".jpeg", ".png"]
            .iter()
            .any(|ext| lower.ends_with(*ext))
    }

    /// Applies the `max_images` and `category_filter` options to `dataset`.
    ///
    /// * `max_images > 0`: keeps the first `max_images` images and drops
    ///   annotations that point at a removed image.
    /// * non-empty `category_filter`: keeps only categories whose name is
    ///   listed and drops annotations of any other category.
    fn apply_filters(&self, mut dataset: CocoDataset) -> CocoDataset {
        let limit = self.options.max_images;
        if limit > 0 && dataset.images.len() > limit {
            dataset.images.truncate(limit);
            let kept_images: HashSet<_> = dataset.images.iter().map(|i| i.id).collect();
            dataset
                .annotations
                .retain(|a| kept_images.contains(&a.image_id));
        }
        if !self.options.category_filter.is_empty() {
            // Hash the wanted names once instead of scanning the Vec per category.
            let wanted: HashSet<&str> = self
                .options
                .category_filter
                .iter()
                .map(String::as_str)
                .collect();
            dataset
                .categories
                .retain(|c| wanted.contains(c.name.as_str()));
            let category_ids: HashSet<_> = dataset.categories.iter().map(|c| c.id).collect();
            dataset
                .annotations
                .retain(|a| category_ids.contains(&a.category_id));
        }
        dataset
    }
}
impl Default for CocoReader {
fn default() -> Self {
Self::new()
}
}
fn validate_dataset(dataset: &CocoDataset) -> Result<(), Error> {
let image_ids: HashSet<_> = dataset.images.iter().map(|i| i.id).collect();
let category_ids: HashSet<_> = dataset.categories.iter().map(|c| c.id).collect();
for ann in &dataset.annotations {
if !image_ids.contains(&ann.image_id) {
return Err(Error::CocoError(format!(
"Annotation {} references non-existent image_id {}",
ann.id, ann.image_id
)));
}
if !category_ids.contains(&ann.category_id) {
return Err(Error::CocoError(format!(
"Annotation {} references non-existent category_id {}",
ann.id, ann.category_id
)));
}
if ann.bbox[2] <= 0.0 || ann.bbox[3] <= 0.0 {
return Err(Error::CocoError(format!(
"Annotation {} has invalid bbox dimensions",
ann.id
)));
}
}
Ok(())
}
/// Derives a relative file name from a COCO image URL.
///
/// The scheme (everything up to and including `://`, if present) and the host
/// (everything up to the first `/`) are dropped; the remainder is returned
/// unchanged, e.g. `http://host/val2017/0001.jpg` -> `val2017/0001.jpg`.
/// Query strings and fragments are not stripped.
///
/// Returns `None` when no safe relative path can be produced:
/// * there is no `/` after the host, or nothing follows it;
/// * the path contains `\` or `:` (Windows separators / drive prefixes);
/// * any `/`-separated segment is empty, `.` or `..`. This covers absolute
///   paths (`//etc`), traversal, repeated separators and trailing slashes.
///
/// Segments are validated on the raw string rather than through
/// `Path::components`, because the latter silently normalizes away interior
/// `.` segments, repeated separators and trailing slashes, which would let
/// such paths through unmodified.
fn derive_file_name_from_coco_url(url: &str) -> Option<String> {
    let after_scheme = url.split_once("://").map_or(url, |(_, rest)| rest);
    let (_host, path) = after_scheme.split_once('/')?;

    if path.contains('\\') || path.contains(':') {
        return None;
    }

    // An empty `path` yields a single empty segment and is rejected here too.
    let every_segment_is_plain = path
        .split('/')
        .all(|segment| !segment.is_empty() && segment != "." && segment != "..");

    every_segment_is_plain.then(|| path.to_string())
}
/// Back-fills empty `file_name` fields from each image's `coco_url`.
///
/// Images that already have a non-empty `file_name` are left untouched, and
/// so are images whose `coco_url` is absent or cannot be turned into a file
/// name by `derive_file_name_from_coco_url`.
fn fill_missing_file_names(dataset: &mut CocoDataset) {
    let unnamed = dataset
        .images
        .iter_mut()
        .filter(|img| img.file_name.is_empty());

    for img in unnamed {
        let derived = img
            .coco_url
            .as_deref()
            .and_then(derive_file_name_from_coco_url);
        if let Some(name) = derived {
            img.file_name = name;
        }
    }
}
/// Infers a dataset group (split) name from an annotation file name.
///
/// If the file stem starts with one of the known COCO prefixes
/// (`instances_`, `person_keypoints_`, `captions_`, `panoptic_`), the group is
/// the rest of the stem with trailing numeric characters removed, e.g.
/// `instances_train2017.json` -> `train`. When that leaves nothing, or no
/// prefix matches, the lower-cased file name is searched for `train`, `val`
/// and `test`, in that order.
///
/// Returns `None` when the name has no usable stem or nothing matches.
pub fn infer_group_from_filename(filename: &str) -> Option<String> {
    const KNOWN_PREFIXES: [&str; 4] = [
        "instances_",
        "person_keypoints_",
        "captions_",
        "panoptic_",
    ];
    const SPLIT_KEYWORDS: [&str; 3] = ["train", "val", "test"];

    let stem = Path::new(filename).file_stem()?.to_str()?;

    let from_prefix = KNOWN_PREFIXES
        .iter()
        .filter_map(|&prefix| stem.strip_prefix(prefix))
        .map(|rest| rest.trim_end_matches(char::is_numeric))
        .find(|group| !group.is_empty());
    if let Some(group) = from_prefix {
        return Some(group.to_string());
    }

    let lowered = filename.to_lowercase();
    SPLIT_KEYWORDS
        .iter()
        .find(|&&keyword| lowered.contains(keyword))
        .map(|&keyword| keyword.to_string())
}
/// Infers a dataset group name from the folder that directly contains an image.
///
/// The parent folder's name is taken and trailing numeric characters are
/// removed (`train2017/0001.jpg` -> `train`). If the folder name consists
/// only of numeric characters it is returned unchanged.
///
/// Returns `None` when the path has no named parent folder or the folder
/// name is not valid UTF-8.
pub fn infer_group_from_folder(image_path: &str) -> Option<String> {
    let folder = Path::new(image_path)
        .parent()
        .and_then(|dir| dir.file_name())
        .and_then(|name| name.to_str())?;
    if folder.is_empty() {
        return None;
    }

    let trimmed = folder.trim_end_matches(char::is_numeric);
    let group = if trimmed.is_empty() { folder } else { trimmed };
    Some(group.to_string())
}
/// Loads every COCO `instances` annotation file found in a dataset directory.
///
/// The `annotations` sub-directory (when it exists) is searched first, then
/// `path` itself; sub-directories are not descended into. A file qualifies
/// when its name ends with `.json` and contains `instances`. Each file is
/// read with a [`CocoReader`] built from `options` and paired with the group
/// name inferred from its file name (`"default"` when none can be inferred).
///
/// Within each searched directory, files are processed in sorted path order
/// so the result does not depend on the platform-specific order in which
/// `read_dir` yields entries.
///
/// # Errors
///
/// * Propagates I/O errors from listing a directory and any error from
///   reading an annotation file.
/// * Returns `Error::MissingAnnotations` when no qualifying file is found.
pub fn read_coco_directory<P: AsRef<Path>>(
    path: P,
    options: &CocoReadOptions,
) -> Result<Vec<(CocoDataset, String)>, Error> {
    let path = path.as_ref();
    let annotations_dir = path.join("annotations");
    let search_dirs: Vec<&Path> = if annotations_dir.is_dir() {
        vec![annotations_dir.as_path(), path]
    } else {
        vec![path]
    };

    // One reader serves every file; the options do not change between files.
    let reader = CocoReader::with_options(options.clone());
    let mut results = Vec::new();

    for search_dir in search_dirs {
        if !search_dir.is_dir() {
            continue;
        }

        let mut candidates = Vec::new();
        for entry in std::fs::read_dir(search_dir)? {
            let file_path = entry?.path();
            if !file_path.is_file() {
                continue;
            }
            let is_instances_json = file_path
                .file_name()
                .and_then(|s| s.to_str())
                .is_some_and(|name| name.ends_with(".json") && name.contains("instances"));
            if is_instances_json {
                candidates.push(file_path);
            }
        }
        // `read_dir` makes no ordering guarantee; sort for reproducible output.
        candidates.sort();

        for file_path in candidates {
            let filename = file_path.file_name().and_then(|s| s.to_str()).unwrap_or("");
            let group =
                infer_group_from_filename(filename).unwrap_or_else(|| "default".to_string());
            let dataset = reader.read_json(&file_path)?;
            results.push((dataset, group));
        }
    }

    if results.is_empty() {
        return Err(Error::MissingAnnotations(format!(
            "No COCO annotation files found in {}",
            path.display()
        )));
    }
    Ok(results)
}
/// Merges `source` into `target`.
///
/// * `info` is taken from `source` only while `target.info.description` is
///   still `None`.
/// * Images, categories and licenses are appended unless an entry with the
///   same `id` was already present in `target` before this call.
/// * Annotations are always appended.
fn merge_datasets(target: &mut CocoDataset, source: CocoDataset) {
    if target.info.description.is_none() {
        target.info = source.info;
    }

    let known_images: HashSet<_> = target.images.iter().map(|img| img.id).collect();
    target.images.extend(
        source
            .images
            .into_iter()
            .filter(|img| !known_images.contains(&img.id)),
    );

    let known_categories: HashSet<_> = target.categories.iter().map(|cat| cat.id).collect();
    target.categories.extend(
        source
            .categories
            .into_iter()
            .filter(|cat| !known_categories.contains(&cat.id)),
    );

    target.annotations.extend(source.annotations);

    let known_licenses: HashSet<_> = target.licenses.iter().map(|lic| lic.id).collect();
    target.licenses.extend(
        source
            .licenses
            .into_iter()
            .filter(|lic| !known_licenses.contains(&lic.id)),
    );
}
// Unit tests for the reader, its private helpers and the group-inference functions.
#[cfg(test)]
mod tests {
use super::*;
use crate::coco::{CocoAnnotation, CocoCategory, CocoImage};
// `CocoReader::new` uses the default options: no validation, no image limit, no category filter.
#[test]
fn test_reader_default() {
let reader = CocoReader::new();
assert!(!reader.options.validate);
assert_eq!(reader.options.max_images, 0);
assert!(reader.options.category_filter.is_empty());
}
// `with_options` stores the options it is given.
#[test]
fn test_reader_with_options() {
let options = CocoReadOptions {
validate: true,
max_images: 100,
category_filter: vec!["person".to_string()],
};
let reader = CocoReader::with_options(options.clone());
assert!(reader.options.validate);
assert_eq!(reader.options.max_images, 100);
}
// Scheme and host are stripped; URLs without a path component yield `None`.
#[test]
fn test_derive_file_name_from_coco_url() {
assert_eq!(
derive_file_name_from_coco_url(
"http://images.cocodataset.org/val2017/000000397133.jpg"
),
Some("val2017/000000397133.jpg".to_string())
);
assert_eq!(
derive_file_name_from_coco_url(
"https://images.cocodataset.org/train2017/000000000009.jpg"
),
Some("train2017/000000000009.jpg".to_string())
);
assert_eq!(derive_file_name_from_coco_url("host-only"), None);
assert_eq!(derive_file_name_from_coco_url("http://host/"), None);
}
// Paths containing `..` or starting with `./` are rejected.
#[test]
fn test_derive_file_name_from_coco_url_rejects_traversal() {
assert_eq!(
derive_file_name_from_coco_url("http://host/../etc/passwd"),
None
);
assert_eq!(
derive_file_name_from_coco_url("http://host/val2017/../../etc/passwd"),
None
);
assert_eq!(derive_file_name_from_coco_url("http://host/./foo.jpg"), None);
}
// Absolute paths, backslash separators and drive-letter prefixes are rejected.
#[test]
fn test_derive_file_name_from_coco_url_rejects_absolute_and_windows() {
assert_eq!(
derive_file_name_from_coco_url("http://host//etc/passwd"),
None
);
assert_eq!(
derive_file_name_from_coco_url("http://host/val2017\\..\\..\\etc"),
None
);
assert_eq!(
derive_file_name_from_coco_url("http://host/C:/Windows/System32"),
None
);
}
// An image entry without a `file_name` key deserializes with an empty name,
// which is then filled in from `coco_url`; the LVIS-specific
// `neg_category_ids` field is preserved.
#[test]
fn test_fill_missing_file_names_from_lvis_json() {
let json = r#"{
"images": [
{
"id": 397133,
"width": 640,
"height": 427,
"coco_url": "http://images.cocodataset.org/val2017/000000397133.jpg",
"neg_category_ids": [279, 899],
"not_exhaustive_category_ids": [914]
}
],
"annotations": [],
"categories": []
}"#;
let mut dataset: CocoDataset = serde_json::from_str(json).unwrap();
assert_eq!(dataset.images[0].file_name, "");
fill_missing_file_names(&mut dataset);
assert_eq!(dataset.images[0].file_name, "val2017/000000397133.jpg");
assert_eq!(
dataset.images[0].neg_category_ids.as_deref(),
Some(&[279u32, 899][..])
);
}
// A non-empty `file_name` is never overwritten, even when `coco_url` is present.
#[test]
fn test_fill_missing_file_names_preserves_existing() {
let mut dataset = CocoDataset {
images: vec![CocoImage {
id: 1,
width: 640,
height: 480,
file_name: "custom/path.jpg".to_string(),
coco_url: Some("http://images.cocodataset.org/val2017/foo.jpg".to_string()),
..Default::default()
}],
..Default::default()
};
fill_missing_file_names(&mut dataset);
assert_eq!(dataset.images[0].file_name, "custom/path.jpg");
}
// A dataset whose annotation references an existing image and category and
// has a positive-size bbox passes validation.
#[test]
fn test_validate_dataset_valid() {
let dataset = CocoDataset {
images: vec![CocoImage {
id: 1,
width: 640,
height: 480,
file_name: "test.jpg".to_string(),
..Default::default()
}],
categories: vec![CocoCategory {
id: 1,
name: "person".to_string(),
supercategory: None,
..Default::default()
}],
annotations: vec![CocoAnnotation {
id: 1,
image_id: 1,
category_id: 1,
bbox: [10.0, 20.0, 100.0, 80.0],
area: 8000.0,
iscrowd: 0,
segmentation: None,
score: None,
}],
..Default::default()
};
assert!(validate_dataset(&dataset).is_ok());
}
// An annotation whose `image_id` matches no image fails validation.
#[test]
fn test_validate_dataset_missing_image() {
let dataset = CocoDataset {
images: vec![],
categories: vec![CocoCategory {
id: 1,
name: "person".to_string(),
supercategory: None,
..Default::default()
}],
annotations: vec![CocoAnnotation {
id: 1,
image_id: 999, category_id: 1,
bbox: [10.0, 20.0, 100.0, 80.0],
..Default::default()
}],
..Default::default()
};
assert!(validate_dataset(&dataset).is_err());
}
// Merging skips the image whose id already exists in the target (id 1) and
// appends the new image (id 2) and the new category (id 2).
#[test]
fn test_merge_datasets() {
let mut target = CocoDataset {
images: vec![CocoImage {
id: 1,
width: 640,
height: 480,
file_name: "img1.jpg".to_string(),
..Default::default()
}],
categories: vec![CocoCategory {
id: 1,
name: "person".to_string(),
supercategory: None,
..Default::default()
}],
annotations: vec![],
..Default::default()
};
let source = CocoDataset {
images: vec![
CocoImage {
id: 1, width: 640,
height: 480,
file_name: "img1.jpg".to_string(),
..Default::default()
},
CocoImage {
id: 2, width: 800,
height: 600,
file_name: "img2.jpg".to_string(),
..Default::default()
},
],
categories: vec![CocoCategory {
id: 2,
name: "car".to_string(),
supercategory: None,
..Default::default()
}],
annotations: vec![],
..Default::default()
};
merge_datasets(&mut target, source);
assert_eq!(target.images.len(), 2);
assert_eq!(target.categories.len(), 2);
}
// With `max_images = 2`, only the first two images and their annotations survive.
#[test]
fn test_apply_max_images_filter() {
let reader = CocoReader::with_options(CocoReadOptions {
max_images: 2,
..Default::default()
});
let dataset = CocoDataset {
images: vec![
CocoImage {
id: 1,
..Default::default()
},
CocoImage {
id: 2,
..Default::default()
},
CocoImage {
id: 3,
..Default::default()
},
],
annotations: vec![
CocoAnnotation {
id: 1,
image_id: 1,
..Default::default()
},
CocoAnnotation {
id: 2,
image_id: 2,
..Default::default()
},
CocoAnnotation {
id: 3,
image_id: 3,
..Default::default()
},
],
..Default::default()
};
let filtered = reader.apply_filters(dataset);
assert_eq!(filtered.images.len(), 2);
assert_eq!(filtered.annotations.len(), 2);
}
// `instances_<split><year>.json` yields the split name without the year.
#[test]
fn test_infer_group_from_filename_instances() {
assert_eq!(
infer_group_from_filename("instances_train2017.json"),
Some("train".to_string())
);
assert_eq!(
infer_group_from_filename("instances_val2017.json"),
Some("val".to_string())
);
assert_eq!(
infer_group_from_filename("instances_test2017.json"),
Some("test".to_string())
);
}
// Same rule for the `person_keypoints_` prefix.
#[test]
fn test_infer_group_from_filename_keypoints() {
assert_eq!(
infer_group_from_filename("person_keypoints_train2017.json"),
Some("train".to_string())
);
assert_eq!(
infer_group_from_filename("person_keypoints_val2017.json"),
Some("val".to_string())
);
}
// Same rule for the `captions_` prefix.
#[test]
fn test_infer_group_from_filename_captions() {
assert_eq!(
infer_group_from_filename("captions_train2017.json"),
Some("train".to_string())
);
assert_eq!(
infer_group_from_filename("captions_val2017.json"),
Some("val".to_string())
);
}
// Same rule for the `panoptic_` prefix.
#[test]
fn test_infer_group_from_filename_panoptic() {
assert_eq!(
infer_group_from_filename("panoptic_train2017.json"),
Some("train".to_string())
);
assert_eq!(
infer_group_from_filename("panoptic_val2017.json"),
Some("val".to_string())
);
}
// Without a known prefix, the name is searched for the `train` / `val` substrings.
#[test]
fn test_infer_group_from_filename_fallback() {
assert_eq!(
infer_group_from_filename("my_custom_train_annotations.json"),
Some("train".to_string())
);
assert_eq!(
infer_group_from_filename("validation_data.json"),
Some("val".to_string())
);
}
// Names with neither a known prefix nor a split keyword yield `None`.
#[test]
fn test_infer_group_from_filename_no_match() {
assert_eq!(infer_group_from_filename("annotations.json"), None);
assert_eq!(infer_group_from_filename("data.json"), None);
}
}