use super::types::{
CocoAnnotation, CocoCategory, CocoDataset, CocoImage, CocoInfo, CocoSegmentation,
};
use crate::Error;
use std::{
fs::File,
io::{BufWriter, Write},
path::Path,
};
use zip::{CompressionMethod, write::SimpleFileOptions};
/// Output settings consulted by the writer when producing JSON files and ZIP archives.
#[derive(Debug, Clone)]
pub struct CocoWriteOptions {
    /// When `true`, ZIP entries are written with Deflate; when `false`, they are stored
    /// uncompressed.
    pub compress: bool,
    /// When `true`, JSON is emitted in serde_json's pretty-printed form; otherwise compact.
    pub pretty: bool,
}

impl Default for CocoWriteOptions {
    /// Returns the default settings: compressed archives and compact JSON.
    fn default() -> Self {
        CocoWriteOptions {
            pretty: false,
            compress: true,
        }
    }
}
/// Writes `CocoDataset` values to JSON files, ZIP archives, and per-group splits.
///
/// `Debug` and `Clone` are derived so the writer can be logged and duplicated just like the
/// options it wraps.
#[derive(Debug, Clone)]
pub struct CocoWriter {
    /// Settings applied to every write performed through this writer.
    options: CocoWriteOptions,
}
impl CocoWriter {
pub fn new() -> Self {
Self {
options: CocoWriteOptions::default(),
}
}
pub fn with_options(options: CocoWriteOptions) -> Self {
Self { options }
}
pub fn write_json<P: AsRef<Path>>(&self, dataset: &CocoDataset, path: P) -> Result<(), Error> {
if let Some(parent) = path.as_ref().parent()
&& !parent.as_os_str().is_empty()
{
std::fs::create_dir_all(parent)?;
}
let file = File::create(path.as_ref())?;
let writer = BufWriter::with_capacity(64 * 1024, file);
if self.options.pretty {
serde_json::to_writer_pretty(writer, dataset)?;
} else {
serde_json::to_writer(writer, dataset)?;
}
Ok(())
}
/// Writes `dataset` and the supplied image payloads into a ZIP archive at `path`.
///
/// The dataset JSON is stored as `annotations/instances.json`; every `(name, bytes)` pair
/// yielded by `images` becomes its own entry named `name`. Entries are deflated when
/// `options.compress` is set and stored otherwise. Missing parent directories of `path`
/// are created first.
///
/// # Errors
///
/// Returns an error if the directory or archive file cannot be created, if JSON
/// serialization fails, or if any ZIP operation fails.
pub fn write_zip<P: AsRef<Path>>(
    &self,
    dataset: &CocoDataset,
    images: impl Iterator<Item = (String, Vec<u8>)>,
    path: P,
) -> Result<(), Error> {
    let archive_path = path.as_ref();
    if let Some(parent) = archive_path.parent()
        && !parent.as_os_str().is_empty()
    {
        std::fs::create_dir_all(parent)?;
    }

    let mut archive = zip::ZipWriter::new(File::create(archive_path)?);

    let method = if self.options.compress {
        CompressionMethod::Deflated
    } else {
        CompressionMethod::Stored
    };
    let entry_options = SimpleFileOptions::default().compression_method(method);

    // Annotations always come first in the archive.
    archive.start_file("annotations/instances.json", entry_options)?;
    let annotations_json = if self.options.pretty {
        serde_json::to_string_pretty(dataset)?
    } else {
        serde_json::to_string(dataset)?
    };
    archive.write_all(annotations_json.as_bytes())?;

    for (entry_name, bytes) in images {
        archive.start_file(&entry_name, entry_options)?;
        archive.write_all(&bytes)?;
    }

    archive.finish()?;
    Ok(())
}
/// Writes `dataset` to a ZIP archive at `path`, pulling image bytes from `images_dir`.
///
/// Each image listed in `dataset.images` is read from `images_dir/<file_name>` and stored
/// in the archive as `images/<file_name>`. Images that cannot be read are left out of the
/// archive without raising an error.
///
/// # Errors
///
/// Propagates any error returned by `write_zip`.
pub fn write_zip_from_dir<P: AsRef<Path>>(
    &self,
    dataset: &CocoDataset,
    images_dir: P,
    path: P,
) -> Result<(), Error> {
    let source_root = images_dir.as_ref();
    let entries = dataset.images.iter().filter_map(|image| {
        // A failed read (missing file, permissions, ...) just drops this entry.
        let bytes = std::fs::read(source_root.join(&image.file_name)).ok()?;
        Some((format!("images/{}", image.file_name), bytes))
    });
    self.write_zip(dataset, entries, path)
}
/// Splits `dataset` into one directory per group beneath `output_dir`.
///
/// `group_assignments[i]` names the group of `dataset.images[i]`. For every distinct group
/// the following layout is produced:
///
/// * `<output_dir>/<group>/annotations/instances_<group>.json` — the group's images, the
///   annotations whose `image_id` belongs to them, and the full `info`, `licenses` and
///   `categories` of the original dataset;
/// * `<output_dir>/<group>/images/` — copies of the group's image files when
///   `images_source` is given. Images that do not exist in the source directory are
///   skipped.
///
/// Returns the number of images assigned to each group.
///
/// # Errors
///
/// Returns `Error::CocoError` when the number of images and group assignments differ, and
/// propagates any I/O or serialization error encountered while writing.
pub fn write_split_by_group<P: AsRef<Path>>(
    &self,
    dataset: &CocoDataset,
    group_assignments: &[String],
    images_source: Option<&Path>,
    output_dir: P,
) -> Result<std::collections::HashMap<String, usize>, Error> {
    use std::collections::{HashMap, HashSet};
    let output_dir = output_dir.as_ref();
    if dataset.images.len() != group_assignments.len() {
        return Err(Error::CocoError(format!(
            "Image count ({}) does not match group assignment count ({})",
            dataset.images.len(),
            group_assignments.len()
        )));
    }
    // Borrow the group names as keys so the name is not cloned once per image.
    let mut groups: HashMap<&str, Vec<usize>> = HashMap::new();
    for (idx, group) in group_assignments.iter().enumerate() {
        groups.entry(group.as_str()).or_default().push(idx);
    }
    let mut result = HashMap::with_capacity(groups.len());
    for (group_name, image_indices) in groups {
        let group_dir = output_dir.join(group_name);
        let annotations_dir = group_dir.join("annotations");
        let images_dir = group_dir.join("images");
        std::fs::create_dir_all(&annotations_dir)?;
        std::fs::create_dir_all(&images_dir)?;
        let image_ids: HashSet<u64> = image_indices
            .iter()
            .map(|&idx| dataset.images[idx].id)
            .collect();
        let subset = CocoDataset {
            info: dataset.info.clone(),
            licenses: dataset.licenses.clone(),
            images: image_indices
                .iter()
                .map(|&idx| dataset.images[idx].clone())
                .collect(),
            annotations: dataset
                .annotations
                .iter()
                .filter(|ann| image_ids.contains(&ann.image_id))
                .cloned()
                .collect(),
            categories: dataset.categories.clone(),
        };
        let ann_file = annotations_dir.join(format!("instances_{}.json", group_name));
        self.write_json(&subset, &ann_file)?;
        if let Some(source) = images_source {
            for &idx in &image_indices {
                let image = &dataset.images[idx];
                let src_path = source.join(&image.file_name);
                // Best-effort copy: images missing from the source directory are skipped.
                if !src_path.exists() {
                    continue;
                }
                let dst_path = images_dir.join(&image.file_name);
                // `file_name` may contain sub-directories (e.g. "train2017/0001.jpg"), so
                // make sure the destination directory exists or the copy would fail.
                if let Some(parent) = dst_path.parent() {
                    std::fs::create_dir_all(parent)?;
                }
                std::fs::copy(&src_path, &dst_path)?;
            }
        }
        result.insert(group_name.to_string(), image_indices.len());
    }
    Ok(result)
}
/// Splits `dataset` into one ZIP archive per group, written as `<output_dir>/<group>.zip`.
///
/// `group_assignments[i]` names the group of `dataset.images[i]`. Each archive contains the
/// group's annotations (via `write_zip`) and, when `images_source` is given, the group's
/// image files stored as `images/<file_name>`. Images that cannot be read from the source
/// directory are skipped.
///
/// Returns the number of images assigned to each group.
///
/// # Errors
///
/// Returns `Error::CocoError` when the number of images and group assignments differ, and
/// propagates any I/O, serialization or ZIP error encountered while writing.
pub fn write_split_by_group_zip<P: AsRef<Path>>(
    &self,
    dataset: &CocoDataset,
    group_assignments: &[String],
    images_source: Option<&Path>,
    output_dir: P,
) -> Result<std::collections::HashMap<String, usize>, Error> {
    use std::collections::{HashMap, HashSet};
    let output_dir = output_dir.as_ref();
    // Validate before touching the filesystem so invalid input leaves no directory behind.
    if dataset.images.len() != group_assignments.len() {
        return Err(Error::CocoError(format!(
            "Image count ({}) does not match group assignment count ({})",
            dataset.images.len(),
            group_assignments.len()
        )));
    }
    std::fs::create_dir_all(output_dir)?;
    // Borrow the group names as keys so the name is not cloned once per image.
    let mut groups: HashMap<&str, Vec<usize>> = HashMap::new();
    for (idx, group) in group_assignments.iter().enumerate() {
        groups.entry(group.as_str()).or_default().push(idx);
    }
    let mut result = HashMap::with_capacity(groups.len());
    for (group_name, image_indices) in groups {
        let image_ids: HashSet<u64> = image_indices
            .iter()
            .map(|&idx| dataset.images[idx].id)
            .collect();
        let subset = CocoDataset {
            info: dataset.info.clone(),
            licenses: dataset.licenses.clone(),
            images: image_indices
                .iter()
                .map(|&idx| dataset.images[idx].clone())
                .collect(),
            annotations: dataset
                .annotations
                .iter()
                .filter(|ann| image_ids.contains(&ann.image_id))
                .cloned()
                .collect(),
            categories: dataset.categories.clone(),
        };
        // Read images lazily while the archive is being written, so only one image is held
        // in memory at a time instead of the whole group. With no `images_source` the
        // iterator is simply empty.
        let images = images_source.into_iter().flat_map(|source| {
            image_indices.iter().filter_map(move |&idx| {
                let image = &dataset.images[idx];
                let data = std::fs::read(source.join(&image.file_name)).ok()?;
                Some((format!("images/{}", image.file_name), data))
            })
        });
        let zip_path = output_dir.join(format!("{}.zip", group_name));
        self.write_zip(&subset, images, &zip_path)?;
        result.insert(group_name.to_string(), image_indices.len());
    }
    Ok(result)
}
}
impl Default for CocoWriter {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Default)]
pub struct CocoDatasetBuilder {
dataset: CocoDataset,
next_image_id: u64,
next_annotation_id: u64,
next_category_id: u32,
}
impl CocoDatasetBuilder {
/// Creates an empty builder whose image, annotation and category ids all start at 1.
pub fn new() -> Self {
    let dataset = CocoDataset::default();
    Self {
        dataset,
        next_category_id: 1,
        next_annotation_id: 1,
        next_image_id: 1,
    }
}
/// Replaces the dataset's `info` section and returns the builder for chaining.
pub fn info(mut self, info: CocoInfo) -> Self {
    let dataset = &mut self.dataset;
    dataset.info = info;
    self
}
/// Registers a category called `name` and returns its id.
///
/// If a category with the same name already exists, its id is returned and nothing is
/// changed (the `supercategory` argument is ignored in that case). Otherwise the category
/// receives the next free sequential id.
pub fn add_category(&mut self, name: &str, supercategory: Option<&str>) -> u32 {
    let categories = &mut self.dataset.categories;
    if let Some(existing) = categories.iter().find(|cat| cat.name == name) {
        return existing.id;
    }

    let id = self.next_category_id;
    self.next_category_id += 1;
    categories.push(CocoCategory {
        id,
        name: name.to_owned(),
        supercategory: supercategory.map(str::to_owned),
        ..Default::default()
    });
    id
}
/// Registers a category with an explicit `id` and returns the id actually in use.
///
/// If any existing category already has this `id` or this `name`, that category's id is
/// returned and nothing is added. Otherwise the category is stored under `id`, and the
/// automatic id counter is advanced past `id` so later `add_category` calls do not reuse
/// it.
pub fn add_category_with_id(
    &mut self,
    id: u32,
    name: &str,
    supercategory: Option<&str>,
) -> u32 {
    let clash = self
        .dataset
        .categories
        .iter()
        .find(|cat| cat.id == id || cat.name == name);
    if let Some(existing) = clash {
        return existing.id;
    }

    self.dataset.categories.push(CocoCategory {
        id,
        name: name.to_owned(),
        supercategory: supercategory.map(str::to_owned),
        ..Default::default()
    });
    // Keep the counter strictly above every explicitly assigned id.
    self.next_category_id = self.next_category_id.max(id + 1);
    id
}
/// Appends an image record with the given file name and dimensions and returns its
/// newly assigned sequential id. All other image fields take their default values.
pub fn add_image(&mut self, file_name: &str, width: u32, height: u32) -> u64 {
    let id = self.next_image_id;
    self.next_image_id = id + 1;

    let image = CocoImage {
        id,
        width,
        height,
        file_name: file_name.to_owned(),
        ..Default::default()
    };
    self.dataset.images.push(image);
    id
}
/// Appends an annotation with `iscrowd` set to 0 and returns its id.
///
/// Shorthand for `add_annotation_with_iscrowd(.., 0)`.
pub fn add_annotation(
    &mut self,
    image_id: u64,
    category_id: u32,
    bbox: [f64; 4],
    segmentation: Option<CocoSegmentation>,
) -> u64 {
    const NOT_CROWD: u8 = 0;
    self.add_annotation_with_iscrowd(image_id, category_id, bbox, segmentation, NOT_CROWD)
}
/// Appends an annotation with an explicit `iscrowd` flag and returns its newly assigned
/// sequential id.
///
/// `area` is computed as `bbox[2] * bbox[3]` (presumably width × height, following COCO's
/// `[x, y, width, height]` box layout). `score` is left unset. Neither `image_id` nor
/// `category_id` is checked against the images and categories added so far.
pub fn add_annotation_with_iscrowd(
    &mut self,
    image_id: u64,
    category_id: u32,
    bbox: [f64; 4],
    segmentation: Option<CocoSegmentation>,
    iscrowd: u8,
) -> u64 {
    let [_, _, extent_a, extent_b] = bbox;

    let id = self.next_annotation_id;
    self.next_annotation_id += 1;

    let annotation = CocoAnnotation {
        id,
        image_id,
        category_id,
        bbox,
        area: extent_a * extent_b,
        iscrowd,
        segmentation,
        score: None,
    };
    self.dataset.annotations.push(annotation);
    id
}
/// Sets the `score` of the first annotation whose id equals `annotation_id`.
///
/// Does nothing when no annotation has that id.
pub fn set_annotation_score(&mut self, annotation_id: u64, score: f64) {
    for annotation in &mut self.dataset.annotations {
        if annotation.id == annotation_id {
            annotation.score = Some(score);
            return;
        }
    }
}
/// Overwrites the `neg_category_ids` and `not_exhaustive_category_ids` lists of the first
/// image whose id equals `image_id`.
///
/// Both fields are replaced unconditionally, so passing `None` clears a previously set
/// list. Does nothing when no image has that id.
pub fn set_image_neg_categories(
    &mut self,
    image_id: u64,
    neg_category_ids: Option<Vec<u32>>,
    not_exhaustive_category_ids: Option<Vec<u32>>,
) {
    let target = self
        .dataset
        .images
        .iter_mut()
        .find(|image| image.id == image_id);
    let Some(image) = target else {
        return;
    };
    image.neg_category_ids = neg_category_ids;
    image.not_exhaustive_category_ids = not_exhaustive_category_ids;
}
/// Updates the optional metadata of the first category named `name`.
///
/// Only arguments that are `Some` overwrite the stored value; a `None` argument leaves
/// the corresponding field as it was. Does nothing when no category has that name.
pub fn set_category_metadata(
    &mut self,
    name: &str,
    synset: Option<String>,
    frequency: Option<String>,
    synonyms: Option<Vec<String>>,
    def: Option<String>,
) {
    let target = self
        .dataset
        .categories
        .iter_mut()
        .find(|category| category.name == name);
    let Some(category) = target else {
        return;
    };

    if let Some(value) = synset {
        category.synset = Some(value);
    }
    if let Some(value) = frequency {
        category.frequency = Some(value);
    }
    if let Some(value) = synonyms {
        category.synonyms = Some(value);
    }
    if let Some(value) = def {
        category.def = Some(value);
    }
}
/// Sets the supercategory of the first category named `name`.
///
/// Does nothing when no category has that name.
pub fn set_category_supercategory(&mut self, name: &str, supercategory: &str) {
    for category in &mut self.dataset.categories {
        if category.name == name {
            category.supercategory = Some(supercategory.to_owned());
            return;
        }
    }
}
pub fn build(self) -> CocoDataset {
self.dataset
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
/// A writer built with `new()` compresses archives and emits compact JSON.
#[test]
fn test_writer_default() {
    let CocoWriteOptions { compress, pretty } = CocoWriter::new().options;
    assert!(compress);
    assert!(!pretty);
}
/// Writing a small dataset as JSON produces a file that deserializes back to the same
/// number of images and annotations.
#[test]
fn test_write_json() {
    let image = CocoImage {
        id: 1,
        width: 640,
        height: 480,
        file_name: "test.jpg".to_string(),
        ..Default::default()
    };
    let category = CocoCategory {
        id: 1,
        name: "person".to_string(),
        supercategory: None,
        ..Default::default()
    };
    let annotation = CocoAnnotation {
        id: 1,
        image_id: 1,
        category_id: 1,
        bbox: [10.0, 20.0, 100.0, 80.0],
        area: 8000.0,
        iscrowd: 0,
        segmentation: None,
        score: None,
    };
    let dataset = CocoDataset {
        images: vec![image],
        categories: vec![category],
        annotations: vec![annotation],
        ..Default::default()
    };

    let temp_dir = TempDir::new().unwrap();
    let output_path = temp_dir.path().join("test.json");
    CocoWriter::new()
        .write_json(&dataset, &output_path)
        .unwrap();
    assert!(output_path.exists());

    // Round-trip through the file to check the content is valid COCO JSON.
    let contents = std::fs::read_to_string(&output_path).unwrap();
    let restored: CocoDataset = serde_json::from_str(&contents).unwrap();
    assert_eq!(restored.images.len(), 1);
    assert_eq!(restored.annotations.len(), 1);
}
/// With `pretty` enabled the JSON output spans multiple lines.
#[test]
fn test_write_json_pretty() {
    let options = CocoWriteOptions {
        pretty: true,
        compress: false,
    };
    let temp_dir = TempDir::new().unwrap();
    let output_path = temp_dir.path().join("test_pretty.json");

    CocoWriter::with_options(options)
        .write_json(&CocoDataset::default(), &output_path)
        .unwrap();

    let contents = std::fs::read_to_string(&output_path).unwrap();
    assert!(contents.contains('\n'));
}
/// The builder hands out sequential ids starting at 1 and de-duplicates categories by name.
#[test]
fn test_dataset_builder() {
    let mut builder = CocoDatasetBuilder::new();

    // Categories: ids are sequential, and re-adding a name returns the existing id.
    let person_id = builder.add_category("person", Some("human"));
    let car_id = builder.add_category("car", Some("vehicle"));
    assert_eq!((person_id, car_id), (1, 2));
    assert_eq!(builder.add_category("person", None), 1);

    // Images.
    let img1 = builder.add_image("image1.jpg", 640, 480);
    let img2 = builder.add_image("image2.jpg", 800, 600);
    assert_eq!((img1, img2), (1, 2));

    // Annotations.
    let ann1 = builder.add_annotation(img1, person_id, [10.0, 20.0, 100.0, 80.0], None);
    let ann2 = builder.add_annotation(img1, car_id, [50.0, 60.0, 150.0, 100.0], None);
    assert_eq!((ann1, ann2), (1, 2));

    let dataset = builder.build();
    assert_eq!(dataset.categories.len(), 2);
    assert_eq!(dataset.images.len(), 2);
    assert_eq!(dataset.annotations.len(), 2);
}
/// A written archive contains both the annotations entry and the supplied image entry.
#[test]
fn test_write_zip() {
    let image = CocoImage {
        id: 1,
        width: 100,
        height: 100,
        file_name: "test.jpg".to_string(),
        ..Default::default()
    };
    let dataset = CocoDataset {
        images: vec![image],
        ..Default::default()
    };
    let payloads = vec![("images/test.jpg".to_string(), vec![0xFF, 0xD8, 0xFF])];

    let temp_dir = TempDir::new().unwrap();
    let output_path = temp_dir.path().join("test.zip");
    CocoWriter::new()
        .write_zip(&dataset, payloads.into_iter(), &output_path)
        .unwrap();
    assert!(output_path.exists());

    let mut archive = zip::ZipArchive::new(std::fs::File::open(&output_path).unwrap()).unwrap();
    for entry in ["annotations/instances.json", "images/test.jpg"] {
        assert!(archive.by_name(entry).is_ok());
    }
}
/// Splitting three images into "train" (2) and "val" (1) writes one annotation file per
/// group, each holding only that group's images and annotations.
#[test]
fn test_write_split_by_group() {
    let image = |id: u64, width: u32, height: u32, file_name: &str| CocoImage {
        id,
        width,
        height,
        file_name: file_name.to_string(),
        ..Default::default()
    };
    let annotation = |id: u64, image_id: u64, x: f64, y: f64| CocoAnnotation {
        id,
        image_id,
        category_id: 1,
        bbox: [x, y, 100.0, 80.0],
        ..Default::default()
    };
    let dataset = CocoDataset {
        images: vec![
            image(1, 640, 480, "train1.jpg"),
            image(2, 640, 480, "train2.jpg"),
            image(3, 800, 600, "val1.jpg"),
        ],
        categories: vec![CocoCategory {
            id: 1,
            name: "person".to_string(),
            supercategory: None,
            ..Default::default()
        }],
        annotations: vec![
            annotation(1, 1, 10.0, 20.0),
            annotation(2, 2, 20.0, 30.0),
            annotation(3, 3, 30.0, 40.0),
        ],
        ..Default::default()
    };
    let groups = vec!["train".to_string(), "train".to_string(), "val".to_string()];

    let temp_dir = TempDir::new().unwrap();
    let output_dir = temp_dir.path().join("split_output");
    let counts = CocoWriter::new()
        .write_split_by_group(&dataset, &groups, None, &output_dir)
        .unwrap();
    assert_eq!(counts.get("train"), Some(&2));
    assert_eq!(counts.get("val"), Some(&1));

    let train_file = output_dir.join("train/annotations/instances_train.json");
    let val_file = output_dir.join("val/annotations/instances_val.json");
    assert!(train_file.exists());
    assert!(val_file.exists());

    let load = |file: &std::path::Path| -> CocoDataset {
        let json = std::fs::read_to_string(file).unwrap();
        serde_json::from_str(&json).unwrap()
    };

    let train_data = load(&train_file);
    assert_eq!(train_data.images.len(), 2);
    assert_eq!(train_data.annotations.len(), 2);

    let val_data = load(&val_file);
    assert_eq!(val_data.images.len(), 1);
    assert_eq!(val_data.annotations.len(), 1);
}
/// A mismatch between image count and group-assignment count is rejected with
/// `Error::CocoError` before anything is written.
#[test]
fn test_write_split_by_group_mismatch() {
    // Use a temporary directory instead of a hard-coded `/tmp` path so the test is portable
    // and cannot leave residue behind.
    let temp_dir = TempDir::new().unwrap();
    let output_dir = temp_dir.path().join("mismatch_output");

    let dataset = CocoDataset {
        images: vec![CocoImage {
            id: 1,
            ..Default::default()
        }],
        ..Default::default()
    };
    // Two assignments for a single image.
    let groups = vec!["train".to_string(), "val".to_string()];

    let result = CocoWriter::new().write_split_by_group(&dataset, &groups, None, &output_dir);

    assert!(matches!(result, Err(Error::CocoError(_))));
    assert!(!output_dir.exists());
}
}