1use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13#[derive(Debug, Clone, Default, Serialize, Deserialize)]
18pub struct CocoDataset {
19 #[serde(default)]
21 pub info: CocoInfo,
22 #[serde(default)]
24 pub licenses: Vec<CocoLicense>,
25 pub images: Vec<CocoImage>,
27 #[serde(default)]
29 pub annotations: Vec<CocoAnnotation>,
30 #[serde(default)]
32 pub categories: Vec<CocoCategory>,
33}
34
35#[derive(Debug, Clone, Default, Serialize, Deserialize)]
37pub struct CocoInfo {
38 #[serde(default)]
40 pub year: Option<u32>,
41 #[serde(default)]
43 pub version: Option<String>,
44 #[serde(default)]
46 pub description: Option<String>,
47 #[serde(default)]
49 pub contributor: Option<String>,
50 #[serde(default)]
52 pub url: Option<String>,
53 #[serde(default)]
55 pub date_created: Option<String>,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct CocoLicense {
61 pub id: u32,
63 pub name: String,
65 #[serde(default)]
67 pub url: Option<String>,
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize, Default)]
74pub struct CocoImage {
75 pub id: u64,
77 pub width: u32,
79 pub height: u32,
81 pub file_name: String,
83 #[serde(default)]
85 pub license: Option<u32>,
86 #[serde(default)]
88 pub flickr_url: Option<String>,
89 #[serde(default)]
91 pub coco_url: Option<String>,
92 #[serde(default)]
94 pub date_captured: Option<String>,
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize, Default)]
101pub struct CocoCategory {
102 pub id: u32,
104 pub name: String,
106 #[serde(default)]
108 pub supercategory: Option<String>,
109}
110
111#[derive(Debug, Clone, Default, Serialize, Deserialize)]
117pub struct CocoAnnotation {
118 pub id: u64,
120 pub image_id: u64,
122 pub category_id: u32,
124 pub bbox: [f64; 4],
126 #[serde(default)]
128 pub area: f64,
129 #[serde(default)]
131 pub iscrowd: u8,
132 #[serde(default, skip_serializing_if = "Option::is_none")]
134 pub segmentation: Option<CocoSegmentation>,
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
144#[serde(untagged)]
145pub enum CocoSegmentation {
146 Polygon(Vec<Vec<f64>>),
150 Rle(CocoRle),
152 CompressedRle(CocoCompressedRle),
154}
155
156#[derive(Debug, Clone, Serialize, Deserialize)]
161pub struct CocoRle {
162 pub counts: Vec<u32>,
164 pub size: [u32; 2],
166}
167
168#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct CocoCompressedRle {
173 pub counts: String,
175 pub size: [u32; 2],
177}
178
179#[derive(Debug, Clone)]
183pub struct CocoIndex {
184 pub images: HashMap<u64, CocoImage>,
186 pub categories: HashMap<u32, CocoCategory>,
188 pub label_indices: HashMap<u32, u64>,
190 pub annotations_by_image: HashMap<u64, Vec<CocoAnnotation>>,
192}
193
194impl CocoIndex {
195 pub fn from_dataset(dataset: &CocoDataset) -> Self {
200 let images: HashMap<_, _> = dataset
201 .images
202 .iter()
203 .map(|img| (img.id, img.clone()))
204 .collect();
205
206 let categories: HashMap<_, _> = dataset
207 .categories
208 .iter()
209 .map(|cat| (cat.id, cat.clone()))
210 .collect();
211
212 let mut category_names: Vec<_> = dataset
214 .categories
215 .iter()
216 .map(|c| (c.id, c.name.clone()))
217 .collect();
218 category_names.sort_by(|a, b| a.1.cmp(&b.1));
219 let label_indices: HashMap<_, _> = category_names
220 .iter()
221 .enumerate()
222 .map(|(idx, (cat_id, _))| (*cat_id, idx as u64))
223 .collect();
224
225 let mut annotations_by_image: HashMap<u64, Vec<CocoAnnotation>> = HashMap::new();
226 for ann in &dataset.annotations {
227 annotations_by_image
228 .entry(ann.image_id)
229 .or_default()
230 .push(ann.clone());
231 }
232
233 Self {
234 images,
235 categories,
236 label_indices,
237 annotations_by_image,
238 }
239 }
240
241 pub fn label_name(&self, category_id: u32) -> Option<&str> {
243 self.categories.get(&category_id).map(|c| c.name.as_str())
244 }
245
246 pub fn label_index(&self, category_id: u32) -> Option<u64> {
248 self.label_indices.get(&category_id).copied()
249 }
250
251 pub fn annotations_for_image(&self, image_id: u64) -> &[CocoAnnotation] {
253 self.annotations_by_image
254 .get(&image_id)
255 .map(|v| v.as_slice())
256 .unwrap_or(&[])
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// A default dataset has no images, annotations or categories.
    #[test]
    fn test_coco_dataset_default() {
        let dataset = CocoDataset::default();
        assert!(dataset.images.is_empty());
        assert!(dataset.annotations.is_empty());
        assert!(dataset.categories.is_empty());
    }

    /// The index exposes images, categories, label indices and per-image annotations.
    #[test]
    fn test_coco_index_from_dataset() {
        let dataset = CocoDataset {
            images: vec![
                CocoImage {
                    id: 1,
                    width: 640,
                    height: 480,
                    file_name: "image1.jpg".to_string(),
                    ..Default::default()
                },
                CocoImage {
                    id: 2,
                    width: 800,
                    height: 600,
                    file_name: "image2.jpg".to_string(),
                    ..Default::default()
                },
            ],
            categories: vec![
                CocoCategory {
                    id: 1,
                    name: "person".to_string(),
                    supercategory: Some("human".to_string()),
                },
                CocoCategory {
                    id: 2,
                    name: "car".to_string(),
                    supercategory: Some("vehicle".to_string()),
                },
            ],
            annotations: vec![
                CocoAnnotation {
                    id: 100,
                    image_id: 1,
                    category_id: 1,
                    bbox: [10.0, 20.0, 100.0, 200.0],
                    area: 20000.0,
                    iscrowd: 0,
                    segmentation: None,
                },
                CocoAnnotation {
                    id: 101,
                    image_id: 1,
                    category_id: 2,
                    bbox: [50.0, 60.0, 150.0, 100.0],
                    area: 15000.0,
                    iscrowd: 0,
                    segmentation: None,
                },
            ],
            ..Default::default()
        };

        let index = CocoIndex::from_dataset(&dataset);

        // Image lookup.
        assert_eq!(index.images.len(), 2);
        assert_eq!(index.images.get(&1).unwrap().file_name, "image1.jpg");

        // Category lookup.
        assert_eq!(index.categories.len(), 2);
        assert_eq!(index.label_name(1), Some("person"));
        assert_eq!(index.label_name(2), Some("car"));

        // Label indices follow the ascending order of names: "car" before "person".
        assert_eq!(index.label_index(2), Some(0));
        assert_eq!(index.label_index(1), Some(1));

        // Both annotations belong to image 1.
        let anns = index.annotations_for_image(1);
        assert_eq!(anns.len(), 2);

        // Image 2 has no annotations.
        let anns = index.annotations_for_image(2);
        assert!(anns.is_empty());
    }

    /// A JSON array of arrays deserializes as a polygon segmentation.
    #[test]
    fn test_coco_segmentation_polygon_deserialize() {
        let json = r#"[[100.0, 200.0, 150.0, 250.0, 100.0, 250.0]]"#;
        let seg: CocoSegmentation = serde_json::from_str(json).unwrap();

        match seg {
            CocoSegmentation::Polygon(polys) => {
                assert_eq!(polys.len(), 1);
                assert_eq!(polys[0].len(), 6);
            }
            _ => panic!("Expected polygon segmentation"),
        }
    }

    /// An object whose `counts` is a list of integers deserializes as plain RLE.
    #[test]
    fn test_coco_segmentation_rle_deserialize() {
        let json = r#"{"counts": [10, 20, 30, 40], "size": [100, 200]}"#;
        let seg: CocoSegmentation = serde_json::from_str(json).unwrap();

        match seg {
            CocoSegmentation::Rle(rle) => {
                assert_eq!(rle.counts, vec![10, 20, 30, 40]);
                assert_eq!(rle.size, [100, 200]);
            }
            _ => panic!("Expected RLE segmentation"),
        }
    }

    /// An object whose `counts` is a string deserializes as compressed RLE.
    #[test]
    fn test_coco_segmentation_compressed_rle_deserialize() {
        let json = r#"{"counts": "PPYo1", "size": [100, 200]}"#;
        let seg: CocoSegmentation = serde_json::from_str(json).unwrap();

        match seg {
            CocoSegmentation::CompressedRle(rle) => {
                assert_eq!(rle.counts, "PPYo1");
                assert_eq!(rle.size, [100, 200]);
            }
            _ => panic!("Expected compressed RLE segmentation"),
        }
    }

    /// Serializing an annotation and reading it back preserves its fields.
    #[test]
    fn test_coco_annotation_roundtrip() {
        let ann = CocoAnnotation {
            id: 12345,
            image_id: 67890,
            category_id: 1,
            bbox: [100.5, 200.5, 50.0, 80.0],
            area: 4000.0,
            iscrowd: 0,
            segmentation: Some(CocoSegmentation::Polygon(vec![vec![
                100.0, 200.0, 150.0, 200.0, 150.0, 280.0, 100.0, 280.0,
            ]])),
        };

        let json = serde_json::to_string(&ann).unwrap();
        let restored: CocoAnnotation = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.id, ann.id);
        assert_eq!(restored.image_id, ann.image_id);
        assert_eq!(restored.category_id, ann.category_id);
        assert_eq!(restored.bbox, ann.bbox);
        assert_eq!(restored.area, ann.area);
        assert_eq!(restored.iscrowd, ann.iscrowd);

        match (&restored.segmentation, &ann.segmentation) {
            (
                Some(CocoSegmentation::Polygon(got)),
                Some(CocoSegmentation::Polygon(want)),
            ) => assert_eq!(got, want),
            _ => panic!("Expected polygon segmentation after roundtrip"),
        }
    }
}