annotation_format_converter/
lib.rs

1use quick_xml::de;
2use serde::{Deserialize, Serialize};
3use std::error::Error;
4use std::fs;
5
6#[derive(Deserialize, Debug)]
7pub struct Voc {
8    pub filename: String,
9    pub size: Size,
10    #[serde(rename = "object")]
11    pub objects: Vec<Object>,
12}
13
14#[derive(Deserialize, Debug)]
15pub struct Size {
16    width: u32,
17    height: u32,
18}
19
20#[derive(Deserialize, Debug)]
21pub struct Object {
22    name: String,
23    bndbox: Bndbox,
24}
25
26#[derive(Deserialize, Debug)]
27pub struct Bndbox {
28    xmin: f32,
29    ymin: f32,
30    xmax: f32,
31    ymax: f32,
32}
33
34#[derive(Serialize, Debug)]
35pub struct Yolo {
36    class: usize,
37    xcr: f32,
38    ycr: f32,
39    wr: f32,
40    hr: f32,
41}
42
/// Converts Pascal VOC annotations into YOLO label files.
///
/// The class list is ordered: a class's position in it is the numeric id
/// written to YOLO label lines.
#[derive(Debug, Clone)]
pub struct Converter {
    /// Ordered class names; index = YOLO class id.
    classes: Vec<String>,
}
46
47impl Converter {
48    pub fn new(classes: Vec<String>) -> Converter {
49        Converter { classes }
50    }
51    pub fn voc_to_yolo(&self, xml: &str, dest: &str) -> Result<(), Box<dyn Error>> {
52        let content = fs::read_to_string(xml)?;
53        let voc = de::from_str::<Voc>(content.as_str())?;
54
55        let w = voc.size.width;
56        let h = voc.size.height;
57
58        let mut lines: Vec<String> = Vec::new();
59        for object in voc.objects {
60            let bndbox = object.bndbox;
61            let xcr = (bndbox.xmin + bndbox.xmax) / 2.0 / w as f32;
62            let ycr = (bndbox.ymin + bndbox.ymax) / 2.0 / h as f32;
63            let wr = (bndbox.xmax - bndbox.xmin) / w as f32;
64            let hr = (bndbox.ymax - bndbox.ymin) / h as f32;
65
66            if !self.classes.contains(&object.name) {
67                return Ok(());
68            }
69
70            let mut index = 0;
71            for item in self.classes.iter().enumerate() {
72                if item.1.eq(&object.name) {
73                    index = item.0
74                }
75            }
76            let line = format!("{} {} {} {} {}", index, xcr, ycr, wr, hr);
77
78            lines.push(line);
79        }
80
81
82        fs::write(dest, lines.join("\n"))?;
83        println!("{:#?}", lines);
84        Ok(())
85    }
86}