use axonml_tensor::Tensor;
use serde::Deserialize;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
/// Top-level shape of a COCO annotation file, restricted to the three
/// sections this loader reads. Any other keys in the JSON object are ignored
/// by serde's default behaviour.
#[derive(Deserialize)]
struct CocoJson {
    /// One record per image (id, file name, pixel dimensions).
    images: Vec<CocoImage>,
    /// Flat list of object annotations; each points at an image via `image_id`.
    annotations: Vec<CocoAnno>,
    /// Category table whose ids are remapped to dense indices by the loader.
    categories: Vec<CocoCategory>,
}
/// A single record of the `images` array in the annotation file.
#[derive(Deserialize)]
struct CocoImage {
    /// Identifier that annotations refer to through their `image_id`.
    id: u64,
    /// File name, joined onto the dataset's image directory by the loader.
    file_name: String,
    /// Image width; used as the divisor when normalising box x-coordinates.
    width: u32,
    /// Image height; used as the divisor when normalising box y-coordinates.
    height: u32,
}
#[derive(Deserialize)]
struct CocoAnno {
image_id: u64,
category_id: u64,
bbox: [f32; 4], #[allow(dead_code)]
area: f32,
iscrowd: u32,
}
/// A single record of the `categories` array in the annotation file.
#[derive(Deserialize)]
struct CocoCategory {
    /// Raw category id; ids may be sparse and are remapped by the loader.
    id: u64,
    // Kept so the struct mirrors the file schema; the loader only needs `id`,
    // hence the lint suppression.
    #[allow(dead_code)]
    name: String,
}
/// One object annotation as exposed to users of the dataset.
#[derive(Debug, Clone)]
pub struct CocoAnnotation {
    /// Corner-format box `[x1, y1, x2, y2]`, each coordinate divided by the
    /// source image's width (x) or height (y).
    pub bbox: [f32; 4],
    /// Dense class index in `0..num_classes`, not the raw id from the file.
    pub category_id: usize,
}
/// A COCO-format detection dataset: the annotation index is held in memory,
/// while image pixels are read from disk on each access.
pub struct CocoDataset {
    /// Directory the image file names were resolved against. Stored but not
    /// read by any method visible in this file (hence the underscore).
    _image_dir: PathBuf,
    /// One entry per image that kept at least one usable annotation.
    entries: Vec<CocoEntry>,
    /// Size handed to the image loader when an item is fetched.
    target_size: (usize, usize),
    /// Number of categories declared in the annotation file.
    num_classes: usize,
}
/// Per-image record stored inside the dataset.
struct CocoEntry {
    /// Image directory joined with the file name from the annotation file.
    image_path: PathBuf,
    /// Original image width as declared in the annotation file (currently unused).
    _orig_w: u32,
    /// Original image height as declared in the annotation file (currently unused).
    _orig_h: u32,
    /// Normalised annotations for this image; never empty for a stored entry.
    annotations: Vec<CocoAnnotation>,
}
impl CocoDataset {
    /// Builds a dataset from a COCO-format annotation file.
    ///
    /// * `image_dir` - directory that each image's `file_name` is joined onto.
    ///   The image files themselves are not opened here; they are read
    ///   lazily by [`CocoDataset::get`].
    /// * `annotation_json` - path to the annotation JSON file.
    /// * `target_size` - size forwarded to the image loader on every `get`.
    ///
    /// Processing rules:
    /// * Category ids are sorted, deduplicated and mapped to the dense range
    ///   `0..num_classes`.
    /// * Annotations with `iscrowd != 0`, a non-positive width/height, or a
    ///   category id missing from the category table are dropped.
    /// * Boxes are converted from `[x, y, w, h]` to `[x1, y1, x2, y2]` and
    ///   divided by the image width/height.
    /// * Images that declare a zero width or height, or that end up with no
    ///   annotations, are left out of the dataset.
    ///
    /// # Errors
    ///
    /// Returns a descriptive `String` if the annotation file cannot be read
    /// or is not valid JSON of the expected shape.
    pub fn new<P: AsRef<Path>>(
        image_dir: P,
        annotation_json: P,
        target_size: (usize, usize),
    ) -> Result<Self, String> {
        let image_dir = image_dir.as_ref().to_path_buf();
        let json_str = std::fs::read_to_string(annotation_json.as_ref())
            .map_err(|e| format!("Failed to read COCO annotations: {e}"))?;
        let coco: CocoJson = serde_json::from_str(&json_str)
            .map_err(|e| format!("Failed to parse COCO JSON: {e}"))?;

        // Raw category ids may be sparse, so map them onto 0..num_classes in
        // ascending id order. Deduplicating first keeps a repeated id from
        // inflating `num_classes` and leaving an unused index in the range.
        let mut sorted_cats: Vec<u64> = coco.categories.iter().map(|c| c.id).collect();
        sorted_cats.sort_unstable();
        sorted_cats.dedup();
        let num_classes = sorted_cats.len();
        let cat_remap: HashMap<u64, usize> = sorted_cats
            .iter()
            .enumerate()
            .map(|(new_id, &old_id)| (old_id, new_id))
            .collect();

        // Group the non-crowd annotations by the image they belong to.
        let mut anno_map: HashMap<u64, Vec<&CocoAnno>> = HashMap::new();
        for anno in coco.annotations.iter().filter(|a| a.iscrowd == 0) {
            anno_map.entry(anno.image_id).or_default().push(anno);
        }

        let mut entries = Vec::new();
        for img in &coco.images {
            // A zero dimension would turn the divisions below into inf/NaN
            // coordinates, so such an image cannot yield usable boxes.
            if img.width == 0 || img.height == 0 {
                continue;
            }
            // There is no `From<u32>` for `f32`; the cast is exact for any
            // dimension up to 2^24.
            let (img_w, img_h) = (img.width as f32, img.height as f32);

            let annotations = anno_map
                .get(&img.id)
                .map(|img_annos| {
                    img_annos
                        .iter()
                        // Skip degenerate boxes (this also rejects NaN sizes).
                        .filter(|anno| anno.bbox[2] > 0.0 && anno.bbox[3] > 0.0)
                        .filter_map(|anno| {
                            // Annotations naming an undeclared category are skipped.
                            let category_id = *cat_remap.get(&anno.category_id)?;
                            let [x, y, w, h] = anno.bbox;
                            Some(CocoAnnotation {
                                bbox: [x / img_w, y / img_h, (x + w) / img_w, (y + h) / img_h],
                                category_id,
                            })
                        })
                        .collect::<Vec<_>>()
                })
                .unwrap_or_default();

            if annotations.is_empty() {
                continue;
            }
            entries.push(CocoEntry {
                image_path: image_dir.join(&img.file_name),
                _orig_w: img.width,
                _orig_h: img.height,
                annotations,
            });
        }

        Ok(Self {
            _image_dir: image_dir,
            entries,
            target_size,
            num_classes,
        })
    }

    /// Number of images that have at least one usable annotation.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when the dataset holds no images.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Number of distinct categories declared in the annotation file.
    pub fn num_classes(&self) -> usize {
        self.num_classes
    }

    /// Loads the image at `index` through `image_io::load_image_resized` and
    /// returns it together with a copy of its annotations.
    ///
    /// Returns `None` when `index` is out of range or when the image cannot
    /// be loaded; the underlying load error is discarded.
    pub fn get(&self, index: usize) -> Option<(Tensor<f32>, Vec<CocoAnnotation>)> {
        let entry = self.entries.get(index)?;
        // `target_size` is forwarded in stored order; the names suggest
        // (height, width) — NOTE(review): confirm against `image_io`.
        let (th, tw) = self.target_size;
        let img = crate::image_io::load_image_resized(&entry.image_path, th, tw).ok()?;
        Some((img, entry.annotations.clone()))
    }

    /// Borrows the annotations of the image at `index` without touching the
    /// image file. Returns `None` when `index` is out of range.
    pub fn get_annotations(&self, index: usize) -> Option<&[CocoAnnotation]> {
        self.entries.get(index).map(|e| e.annotations.as_slice())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal annotation document deserialises into the raw structs, and
    /// crowd annotations can be told apart through `iscrowd`.
    #[test]
    fn test_coco_json_parse() {
        let json = r#"{
"images": [
{"id": 1, "file_name": "img1.jpg", "width": 640, "height": 480},
{"id": 2, "file_name": "img2.jpg", "width": 320, "height": 240}
],
"annotations": [
{"image_id": 1, "category_id": 1, "bbox": [10.0, 20.0, 100.0, 200.0], "area": 20000.0, "iscrowd": 0},
{"image_id": 1, "category_id": 3, "bbox": [50.0, 60.0, 30.0, 40.0], "area": 1200.0, "iscrowd": 0},
{"image_id": 2, "category_id": 1, "bbox": [5.0, 5.0, 50.0, 50.0], "area": 2500.0, "iscrowd": 0},
{"image_id": 2, "category_id": 2, "bbox": [0.0, 0.0, 10.0, 10.0], "area": 100.0, "iscrowd": 1}
],
"categories": [
{"id": 1, "name": "person"},
{"id": 2, "name": "car"},
{"id": 3, "name": "dog"}
]
}"#;
        let parsed: CocoJson = serde_json::from_str(json).unwrap();

        assert_eq!(parsed.images.len(), 2);
        assert_eq!(parsed.annotations.len(), 4);
        assert_eq!(parsed.categories.len(), 3);

        // Exactly one of the four annotations is flagged as a crowd region.
        let regular = parsed
            .annotations
            .iter()
            .filter(|a| a.iscrowd == 0)
            .count();
        assert_eq!(regular, 3);
    }

    /// Sparse category ids map onto consecutive indices in list order.
    #[test]
    fn test_coco_category_remapping() {
        let sparse_ids: [u64; 4] = [1, 3, 5, 10];
        let remap: HashMap<u64, usize> = sparse_ids
            .iter()
            .enumerate()
            .map(|(dense, &sparse)| (sparse, dense))
            .collect();

        assert_eq!(remap[&1], 0);
        assert_eq!(remap[&3], 1);
        assert_eq!(remap[&5], 2);
        assert_eq!(remap[&10], 3);
    }

    /// An `[x, y, w, h]` box becomes corner coordinates scaled by image size.
    #[test]
    fn test_bbox_normalization() {
        let (x, y, w, h) = (10.0f32, 20.0f32, 100.0f32, 200.0f32);
        let (img_w, img_h) = (640.0f32, 480.0f32);

        let normalized = [x / img_w, y / img_h, (x + w) / img_w, (y + h) / img_h];

        assert!((normalized[0] - 10.0 / 640.0).abs() < 1e-5);
        assert!((normalized[1] - 20.0 / 480.0).abs() < 1e-5);
        assert!((normalized[2] - 110.0 / 640.0).abs() < 1e-5);
        assert!((normalized[3] - 220.0 / 480.0).abs() < 1e-5);
        assert!(normalized.iter().all(|v| (0.0f32..=1.0).contains(v)));
    }
}