use std::collections::HashSet;
use std::error::Error;
use std::fs;
use std::path::Path;
use std::{collections::HashMap, path::PathBuf};
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
use serde::{Deserialize, Serialize};
use crate::errors::{self, LoadingError, MissingIdError};
use crate::utils::load_img;
use crate::visualize::draw;
/// Top-level contents of a COCO annotation file, in its on-disk JSON layout.
///
/// `info` and `licenses` are `#[serde(default)]`: files without those keys still
/// deserialize, yielding `Info::default()` and an empty list respectively.
/// `images`, `annotations` and `categories` have no default, so they must be present.
#[derive(Clone, Debug, Default, PartialEq, Deserialize, Serialize)]
pub struct Dataset {
/// Dataset-level metadata (defaulted when absent).
#[serde(default)]
pub info: Info,
pub images: Vec<Image>,
pub annotations: Vec<Annotation>,
pub categories: Vec<Category>,
/// License entries (empty when absent).
#[serde(default)]
pub licenses: Vec<License>,
}
#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
pub struct Info {
pub year: u32,
pub version: String,
pub description: String,
pub contributor: String,
pub url: String,
pub date_created: String,
}
/// One entry of the dataset's `licenses` list.
///
/// All fields are required when deserializing (no serde defaults).
/// NOTE(review): presumably referenced by `Image::license` through `id` — confirm.
#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
pub struct License {
pub id: u32,
pub name: String,
pub url: String,
}
/// One entry of the dataset's `images` list.
///
/// `id`, `width`, `height` and `file_name` are required. The remaining fields are
/// `#[serde(default)]` and become `0` or an empty string when missing.
/// When the `pyo3` feature is enabled, this is exposed to Python as a class in
/// `rpycocotools.anns` with read/write access to every field.
#[cfg_attr(
feature = "pyo3",
pyclass(get_all, set_all, module = "rpycocotools.anns")
)]
#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
pub struct Image {
/// Identifier used by `Annotation::image_id` to reference this image.
pub id: u64,
pub width: u32,
pub height: u32,
/// File name, joined onto `HashmapDataset::image_folder` when the image is loaded for drawing.
pub file_name: String,
#[serde(default)]
pub license: u32,
#[serde(default)]
pub flickr_url: String,
#[serde(default)]
pub coco_url: String,
#[serde(default)]
pub date_captured: String,
}
/// One entry of the dataset's `annotations` list.
///
/// All fields are required when deserializing. With the `pyo3` feature, the Python class
/// is declared `subclass`, so Python code may derive from it.
#[cfg_attr(
feature = "pyo3",
pyclass(subclass, get_all, set_all, module = "rpycocotools.anns")
)]
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct Annotation {
pub id: u64,
/// Id of the `Image` this annotation belongs to.
pub image_id: u64,
/// Id of the `Category` of the annotated object.
pub category_id: u32,
/// Mask of the object. `HashmapDataset::from_dataset` replaces `Segmentation::Polygons`
/// with `Segmentation::PolygonsRS`.
pub segmentation: Segmentation,
pub area: f64,
pub bbox: Bbox,
/// NOTE(review): presumably COCO's 0/1 "crowd" flag — confirm.
pub iscrowd: u32,
}
/// Segmentation mask of an `Annotation`, in one of several encodings.
///
/// `#[serde(untagged)]`: when deserializing, serde tries the variants in declaration order
/// (`Rle`, then `CocoRle`, then `Polygons`) and keeps the first one that fits. Do not
/// reorder the variants without checking that the match order is still correct.
#[cfg_attr(feature = "pyo3", derive(FromPyObject))]
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum Segmentation {
/// Run-length encoding with integer `counts`.
Rle(Rle),
/// Run-length encoding with `counts` stored as a string (presumably COCO's compressed RLE).
CocoRle(CocoRle),
/// Raw polygon lists as they appear in the annotation file.
Polygons(Polygons),
/// Polygons paired with their image size; produced by `HashmapDataset::from_dataset`.
/// `#[serde(skip)]`: never deserialized, and serializing this variant returns an error.
#[serde(skip)]
PolygonsRS(PolygonsRS),
}
/// List of polygons; each inner `Vec` is presumably a flat `[x0, y0, x1, y1, ...]` coordinate list.
pub type Polygons = Vec<Vec<f64>>;
/// Polygons together with the size of the image they belong to.
///
/// `HashmapDataset::from_dataset` builds this from `Segmentation::Polygons`, setting
/// `size` to the owning image's `[height, width]`. `PartialEq` is implemented by hand
/// rather than derived.
#[cfg_attr(
feature = "pyo3",
pyclass(get_all, set_all, module = "rpycocotools.anns")
)]
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct PolygonsRS {
/// Image size, `[height, width]` when produced by `HashmapDataset::from_dataset`.
pub size: Vec<u32>,
/// The polygons, in the same layout as `Polygons`.
pub counts: Vec<Vec<f64>>,
}
/// Run-length encoded mask with integer run lengths (exposed to Python as `RLE`).
#[cfg_attr(
feature = "pyo3",
pyclass(get_all, set_all, name = "RLE", module = "rpycocotools.anns")
)]
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct Rle {
/// Mask size; NOTE(review): presumably `[height, width]` as in COCO — confirm.
pub size: Vec<u32>,
/// Run lengths.
pub counts: Vec<u32>,
}
/// Run-length encoded mask whose runs are stored as a string (exposed to Python as `COCO_RLE`).
#[cfg_attr(
feature = "pyo3",
pyclass(get_all, set_all, name = "COCO_RLE", module = "rpycocotools.anns")
)]
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct CocoRle {
/// Mask size; NOTE(review): presumably `[height, width]` as in COCO — confirm.
pub size: Vec<u32>,
/// Encoded run lengths; NOTE(review): presumably COCO's compressed RLE string — confirm.
pub counts: String,
}
#[cfg_attr(
feature = "pyo3",
pyclass(
sequence,
get_all,
set_all,
name = "BBox",
module = "rpycocotools.anns"
)
)]
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct Bbox {
pub left: f64,
pub top: f64,
pub width: f64,
pub height: f64,
}
/// One entry of the dataset's `categories` list; all fields are required when deserializing.
#[cfg_attr(
feature = "pyo3",
pyclass(get_all, set_all, module = "rpycocotools.anns")
)]
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct Category {
/// Identifier referenced by `Annotation::category_id`.
pub id: u32,
pub name: String,
pub supercategory: String,
}
/// COCO dataset indexed by id for fast lookups, built by `HashmapDataset::from_dataset`.
///
/// The derived `Serialize`/`Deserialize` operate on these internal maps, not on the COCO
/// file layout; `json`/`save_to` convert back to a `Dataset` first.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct HashmapDataset {
/// Annotations keyed by annotation id.
pub(crate) anns: HashMap<u64, Annotation>,
/// Categories keyed by category id.
cats: HashMap<u32, Category>,
/// Images keyed by image id.
imgs: HashMap<u64, Image>,
/// Image id -> ids of the annotations that reference it.
img_to_anns: HashMap<u64, HashSet<u64>>,
/// Folder that `Image::file_name` is resolved against when loading images.
pub image_folder: PathBuf,
}
impl HashmapDataset {
    /// Reads a COCO annotation file and indexes it into id-keyed lookup tables.
    ///
    /// `image_folder` is stored as-is; image file names are resolved relative to it when
    /// drawing. Both arguments share the same generic type `P`.
    ///
    /// # Errors
    ///
    /// - [`LoadingError::Read`] if the annotation file cannot be read.
    /// - [`LoadingError::Deserialize`] if its content does not deserialize into a [`Dataset`].
    /// - Any error returned by [`HashmapDataset::from_dataset`].
    pub fn new<P: AsRef<Path>>(annotations_path: P, image_folder: P) -> Result<Self, LoadingError> {
        let annotations_path = annotations_path.as_ref().to_path_buf();
        let annotations_file_content = fs::read_to_string(&annotations_path)
            .map_err(|err| LoadingError::Read(err, annotations_path.clone()))?;
        // Last use of the path: move it into the error instead of cloning it again.
        let dataset: Dataset = serde_json::from_str(&annotations_file_content)
            .map_err(|err| LoadingError::Deserialize(err, annotations_path))?;
        Self::from_dataset(dataset, image_folder)
    }

    /// Builds the id-indexed lookup tables from an already deserialized [`Dataset`].
    ///
    /// Annotations whose segmentation is [`Segmentation::Polygons`] are converted to
    /// [`Segmentation::PolygonsRS`], with `size` set to the owning image's
    /// `[height, width]`. Other segmentation kinds are stored unchanged.
    ///
    /// NOTE(review): only polygon annotations are checked against the image table. A
    /// non-polygon annotation referencing an unknown image id is still indexed under that
    /// id in `img_to_anns`.
    ///
    /// # Errors
    ///
    /// [`LoadingError::Parsing`] wrapping [`MissingIdError::Image`] when a polygon
    /// annotation references an image id absent from `dataset.images`.
    pub fn from_dataset<P: AsRef<Path>>(
        dataset: Dataset,
        image_folder: P,
    ) -> Result<Self, LoadingError> {
        let cats = dataset
            .categories
            .into_iter()
            .map(|category| (category.id, category))
            .collect();

        // Single pass over the images: fill the image table and give every image an
        // (initially empty) annotation set, without cloning the whole image list.
        let mut imgs: HashMap<u64, Image> = HashMap::with_capacity(dataset.images.len());
        let mut img_to_anns: HashMap<u64, HashSet<u64>> =
            HashMap::with_capacity(dataset.images.len());
        for image in dataset.images {
            img_to_anns.insert(image.id, HashSet::new());
            imgs.insert(image.id, image);
        }

        let mut anns: HashMap<u64, Annotation> =
            HashMap::with_capacity(dataset.annotations.len());
        for mut annotation in dataset.annotations {
            let ann_id = annotation.id;
            let img_id = annotation.image_id;
            if let Segmentation::Polygons(counts) = annotation.segmentation {
                let img = imgs
                    .get(&img_id)
                    .ok_or(LoadingError::Parsing(MissingIdError::Image(img_id)))?;
                annotation.segmentation = Segmentation::PolygonsRS(PolygonsRS {
                    size: vec![img.height, img.width],
                    counts,
                });
            }
            anns.insert(ann_id, annotation);
            img_to_anns.entry(img_id).or_default().insert(ann_id);
        }

        Ok(Self {
            anns,
            cats,
            imgs,
            img_to_anns,
            image_folder: image_folder.as_ref().to_path_buf(),
        })
    }

    /// Returns the annotation with id `ann_id`.
    ///
    /// # Errors
    ///
    /// [`MissingIdError::Annotation`] if no annotation has that id.
    pub fn get_ann(&self, ann_id: u64) -> Result<&Annotation, MissingIdError> {
        self.anns
            .get(&ann_id)
            .ok_or(MissingIdError::Annotation(ann_id))
    }

    /// Returns all annotations, in arbitrary (hash map) order.
    #[must_use]
    pub fn get_anns(&self) -> Vec<&Annotation> {
        self.anns.values().collect()
    }

    /// Returns the category with id `cat_id`.
    ///
    /// # Errors
    ///
    /// [`MissingIdError::Category`] if no category has that id.
    pub fn get_cat(&self, cat_id: u32) -> Result<&Category, MissingIdError> {
        self.cats
            .get(&cat_id)
            .ok_or(MissingIdError::Category(cat_id))
    }

    /// Returns all categories, in arbitrary (hash map) order.
    #[must_use]
    pub fn get_cats(&self) -> Vec<&Category> {
        self.cats.values().collect()
    }

    /// Returns the image with id `img_id`.
    ///
    /// # Errors
    ///
    /// [`MissingIdError::Image`] if no image has that id.
    pub fn get_img(&self, img_id: u64) -> Result<&Image, MissingIdError> {
        self.imgs.get(&img_id).ok_or(MissingIdError::Image(img_id))
    }

    /// Returns all images, in arbitrary (hash map) order.
    #[must_use]
    pub fn get_imgs(&self) -> Vec<&Image> {
        self.imgs.values().collect()
    }

    /// Returns the annotations of image `img_id`, sorted by annotation id.
    ///
    /// # Errors
    ///
    /// - [`MissingIdError::Image`] if the image id is unknown.
    /// - [`MissingIdError::Annotation`] if an indexed annotation id cannot be found.
    pub fn get_img_anns(&self, img_id: u64) -> Result<Vec<&Annotation>, MissingIdError> {
        let ann_ids = self
            .img_to_anns
            .get(&img_id)
            .ok_or(MissingIdError::Image(img_id))?;
        let mut anns = ann_ids
            .iter()
            .map(|&ann_id| self.get_ann(ann_id))
            .collect::<Result<Vec<_>, _>>()?;
        // `HashSet` iteration order is arbitrary; sort so callers (e.g. drawing) are deterministic.
        anns.sort_unstable_by_key(|ann| ann.id);
        Ok(anns)
    }

    /// Loads image `img_id` from `image_folder` and draws all of its annotations on it
    /// (bounding boxes too when `draw_bbox` is set).
    ///
    /// # Errors
    ///
    /// Fails if the image id is unknown, or if loading or drawing the image fails.
    pub fn draw_img_anns(
        &self,
        img_id: u64,
        draw_bbox: bool,
    ) -> Result<image::ImageBuffer<image::Rgb<u8>, Vec<u8>>, errors::CocoError> {
        let img_path = self.image_folder.join(&self.get_img(img_id)?.file_name);
        let mut img = load_img(&img_path)?;
        draw::anns(&mut img, &self.get_img_anns(img_id)?, draw_bbox)?;
        Ok(img)
    }

    /// Loads the image that `ann` belongs to and draws `ann` on it (plus its bounding box
    /// when `draw_bbox` is set).
    ///
    /// NOTE(review): the annotated image is dropped at the end of the call and `Ok(())` is
    /// returned, so callers only observe failures. It probably should return the image
    /// like [`HashmapDataset::draw_img_anns`]; left unchanged because that alters the
    /// signature callers rely on.
    ///
    /// # Errors
    ///
    /// Fails if `ann.image_id` is unknown, or if loading or drawing the image fails.
    pub fn draw_ann(
        &self,
        ann: &Annotation,
        draw_bbox: bool,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let img_path = self
            .image_folder
            .join(&self.get_img(ann.image_id)?.file_name);
        let mut img = load_img(&img_path)?;
        draw::anns(&mut img, &vec![ann], draw_bbox)?;
        Ok(())
    }

    /// Writes the dataset as pretty-printed COCO JSON to `output_path`, creating or
    /// truncating the file.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be created or written, or if serialization fails.
    pub fn save_to<P: AsRef<Path>>(&self, output_path: P) -> Result<(), Box<dyn Error>> {
        use std::io::{BufWriter, Write};

        let dataset = Dataset::from(self);
        // serde_json issues many small writes; buffer them instead of one syscall each.
        let mut writer = BufWriter::new(fs::File::create(output_path)?);
        serde_json::to_writer_pretty(&mut writer, &dataset)?;
        // Flush explicitly: dropping a `BufWriter` silently discards write errors.
        writer.flush()?;
        Ok(())
    }

    /// Serializes the dataset to a compact COCO JSON string.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if serialization fails.
    pub fn json(&self) -> Result<String, serde_json::Error> {
        let dataset = Dataset::from(self);
        serde_json::to_string(&dataset)
    }
}
impl From<&HashmapDataset> for Dataset {
fn from(dataset: &HashmapDataset) -> Self {
Self {
images: dataset.get_imgs().into_iter().cloned().collect(),
annotations: dataset.get_anns().into_iter().cloned().collect(),
categories: dataset.get_cats().into_iter().cloned().collect(),
..Default::default()
}
}
}
impl PartialEq for PolygonsRS {
    /// Compares two polygon sets irrespective of polygon order, starting vertex and winding.
    ///
    /// Both sides must have the same `size` and the same number of polygons. Each polygon
    /// is a flat `[x0, y0, x1, y1, ...]` list. Two polygons match when one's vertex sequence
    /// is a cyclic rotation of the other's, in the same or the reverse direction.
    ///
    /// Every polygon of `other` can be matched at most once. Duplicated polygons must
    /// therefore occur equally often on both sides, which keeps the relation symmetric.
    fn eq(&self, other: &Self) -> bool {
        // `true` if `a` and `b` describe the same ring of `(x, y)` vertices.
        fn same_ring(a: &[f64], b: &[f64]) -> bool {
            if a.len() != b.len() {
                return false;
            }
            if a == b {
                return true;
            }
            // An odd number of values is not a list of (x, y) pairs: only the exact match
            // handled above is meaningful, and pairing values would index past the end.
            if a.len() % 2 != 0 {
                return false;
            }
            let n_vertices = a.len() / 2;
            let vertex = |poly: &[f64], i: usize| (poly[2 * i], poly[2 * i + 1]);
            // Shift by whole vertices (two values), never by a single coordinate, so that
            // x and y stay paired.
            let same_direction = |shift: usize| {
                (0..n_vertices).all(|k| vertex(a, k) == vertex(b, (k + shift) % n_vertices))
            };
            let opposite_direction = |shift: usize| {
                (0..n_vertices)
                    .all(|k| vertex(a, k) == vertex(b, (shift + n_vertices - k) % n_vertices))
            };
            (0..n_vertices).any(|shift| same_direction(shift) || opposite_direction(shift))
        }

        if self.size != other.size || self.counts.len() != other.counts.len() {
            return false;
        }
        let mut matched = vec![false; other.counts.len()];
        for self_poly in &self.counts {
            let candidate = (0..other.counts.len())
                .find(|&idx| !matched[idx] && same_ring(self_poly, &other.counts[idx]));
            match candidate {
                Some(idx) => matched[idx] = true,
                None => return false,
            }
        }
        true
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::PolygonsRS;
    use rstest::rstest;

    /// Wraps flat `[x0, y0, x1, y1, ...]` polygons in a `PolygonsRS` for a 20x20 image.
    fn on_canvas(counts: Vec<Vec<f64>>) -> PolygonsRS {
        PolygonsRS {
            size: vec![20, 20],
            counts,
        }
    }

    /// Pairs of polygon sets that must compare equal.
    #[rstest]
    #[case::single_polygon(
        on_canvas(vec![vec![1.1, 2.1, 3.2, 4.2, 5.3, 6.3]]),
        on_canvas(vec![vec![1.1, 2.1, 3.2, 4.2, 5.3, 6.3]]),
    )]
    #[case::two_polygons(
        on_canvas(vec![
            vec![1.1, 2.1, 3.2, 4.2, 5.3, 6.3],
            vec![7.4, 8.4, 9.5, 10.5, 11.6, 12.6],
        ]),
        on_canvas(vec![
            vec![1.1, 2.1, 3.2, 4.2, 5.3, 6.3],
            vec![7.4, 8.4, 9.5, 10.5, 11.6, 12.6],
        ]),
    )]
    #[case::two_polygons_different_order(
        on_canvas(vec![
            vec![1.1, 2.1, 3.2, 4.2, 5.3, 6.3],
            vec![7.4, 8.4, 9.5, 10.5, 11.6, 12.6],
        ]),
        on_canvas(vec![
            vec![7.4, 8.4, 9.5, 10.5, 11.6, 12.6],
            vec![1.1, 2.1, 3.2, 4.2, 5.3, 6.3],
        ]),
    )]
    #[case::different_start_point(
        on_canvas(vec![
            vec![1.1, 2.1, 3.2, 4.2, 5.3, 6.3],
            vec![7.4, 8.4, 9.5, 10.5, 11.6, 12.6],
        ]),
        on_canvas(vec![
            vec![11.6, 12.6, 7.4, 8.4, 9.5, 10.5],
            vec![3.2, 4.2, 5.3, 6.3, 1.1, 2.1],
        ]),
    )]
    #[case::reversed_order(
        on_canvas(vec![vec![1.1, 2.1, 3.2, 4.2, 5.3, 6.3, 7.4, 8.4]]),
        on_canvas(vec![vec![7.4, 8.4, 5.3, 6.3, 3.2, 4.2, 1.1, 2.1]]),
    )]
    fn polygon_equality(#[case] poly1: PolygonsRS, #[case] poly2: PolygonsRS) {
        assert_eq!(poly1, poly2);
    }

    /// Pairs of polygon sets that must compare unequal.
    #[rstest]
    #[case::different_length(
        on_canvas(vec![vec![1.1, 2.1, 3.2, 4.2, 5.3, 6.3, 7.4, 8.4]]),
        on_canvas(vec![vec![1.1, 2.1, 3.2, 4.2, 5.3, 6.3]]),
    )]
    #[case::different_digit(
        on_canvas(vec![vec![1.1, 2.1, 3.2, 4.2, 5.3, 6.3]]),
        on_canvas(vec![vec![2.1, 2.1, 3.2, 4.2, 5.3, 6.3]]),
    )]
    #[case::different_number_of_polygons(
        on_canvas(vec![vec![1.1, 2.1, 3.2, 4.2, 5.3, 6.3, 7.4, 8.4]]),
        on_canvas(vec![
            vec![1.1, 2.1, 3.2, 4.2, 5.3, 6.3, 7.4, 8.4],
            vec![7.4, 8.4, 9.5, 10.5, 11.6, 12.6],
        ]),
    )]
    #[case::x_y_inverted(
        on_canvas(vec![vec![1.1, 2.1, 3.2, 4.2, 5.3, 6.3]]),
        on_canvas(vec![vec![2.1, 2.1, 4.2, 3.2, 6.3, 5.3]]),
    )]
    fn polygon_inequality(#[case] poly1: PolygonsRS, #[case] poly2: PolygonsRS) {
        assert_ne!(poly1, poly2);
    }
}