use std::collections::HashMap;
use std::fs;
use std::io::{BufRead, BufReader, Write};
use std::path::Path;
use crate::geometry::{corners_to_obb, obb_to_aabb, obb_to_corners};
use crate::types::{Annotation, Category, Dataset, Image};
use super::{ConvertError, anns_by_image, file_stem};
/// Counters reported by [`coco_to_dota`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct DotaStats {
    /// Number of images for which a label file was created (including empty ones).
    pub images: usize,
    /// Number of annotation lines written across all label files.
    pub annotations: usize,
    /// Number of annotations that were skipped because they carry no oriented box.
    pub skipped_no_obb: usize,
}
/// Writes `dataset` as DOTA label files, one `<stem>.txt` per image, into `output_dir`.
///
/// `output_dir` is created if missing. Every image gets a file, even when it has no
/// annotations (the file is then empty). Each annotation with an oriented box becomes one line:
///
/// `x1 y1 x2 y2 x3 y3 x4 y4 <category> <difficulty>`
///
/// Coordinates are printed with one decimal place. `<category>` is the category name made
/// whitespace-free (see [`dota_label`]), or `unknown` when the category id is not in
/// `dataset.categories`. `<difficulty>` is `1` when `iscrowd` is set, else `0`.
/// Annotations without an `obb` are not written and are counted in
/// [`DotaStats::skipped_no_obb`].
///
/// # Errors
///
/// Returns [`ConvertError`] if the directory or any label file cannot be created, written
/// or flushed.
pub fn coco_to_dota(dataset: &Dataset, output_dir: &Path) -> Result<DotaStats, ConvertError> {
    fs::create_dir_all(output_dir)?;
    // Resolve each category to its on-disk token once instead of once per annotation.
    let cat_map: HashMap<u64, String> = dataset
        .categories
        .iter()
        .map(|c| (c.id, dota_label(&c.name)))
        .collect();
    let grouped = anns_by_image(dataset);
    let mut stats = DotaStats::default();
    for img in &dataset.images {
        let stem = file_stem(&img.file_name);
        let path = output_dir.join(format!("{stem}.txt"));
        // A bare `File` turns every formatted fragment into its own write syscall.
        let mut file = std::io::BufWriter::new(fs::File::create(&path)?);
        if let Some(anns) = grouped.get(&img.id) {
            for ann in anns {
                let Some(obb) = &ann.obb else {
                    stats.skipped_no_obb += 1;
                    continue;
                };
                let cat_name = cat_map
                    .get(&ann.category_id)
                    .map_or("unknown", String::as_str);
                let difficulty = i32::from(ann.iscrowd);
                let corners = obb_to_corners(obb);
                writeln!(
                    file,
                    "{:.1} {:.1} {:.1} {:.1} {:.1} {:.1} {:.1} {:.1} {} {}",
                    corners[0].0,
                    corners[0].1,
                    corners[1].0,
                    corners[1].1,
                    corners[2].0,
                    corners[2].1,
                    corners[3].0,
                    corners[3].1,
                    cat_name,
                    difficulty,
                )?;
                stats.annotations += 1;
            }
        }
        // Flush explicitly: dropping a `BufWriter` silently discards flush errors.
        file.flush()?;
        stats.images += 1;
    }
    Ok(stats)
}

/// Turns a category name into a single whitespace-free DOTA token.
///
/// DOTA label lines are split on whitespace, so a name such as `"small vehicle"` would
/// otherwise become two fields and be mis-read on load. Each run of whitespace becomes one
/// `_`, and leading or trailing whitespace is dropped. A name that is empty after this
/// becomes `"unknown"`.
fn dota_label(name: &str) -> String {
    let token = name.split_whitespace().collect::<Vec<_>>().join("_");
    if token.is_empty() {
        "unknown".to_owned()
    } else {
        token
    }
}
/// Builds a [`Dataset`] from the DOTA label files (`*.txt`) in `label_dir`.
///
/// Label files are processed in file-name order. Each file becomes one image with
/// `file_name = "<stem>.png"`. Its width and height come from
/// `super::lookup_image_dims(image_dims, stem)`; what it returns for a missing stem is
/// decided by that helper. Image ids and annotation ids are assigned sequentially, starting
/// at 1.
///
/// Categories: if `categories` is given, its names get ids `1..=n` in the given order. Any
/// name first seen in a label file is appended with the next free id.
///
/// Each usable line yields one annotation:
/// - `obb` comes from `corners_to_obb` on the 8 corner coordinates;
/// - `bbox` comes from `obb_to_aabb`;
/// - `area` is `obb[2] * obb[3]`;
/// - `iscrowd` is `difficulty > 0`, where a missing or non-numeric difficulty counts as 0.
///
/// Blank lines, `#` comments, lines with fewer than 9 fields and lines with non-numeric
/// coordinates are skipped.
///
/// # Errors
///
/// Returns [`ConvertError`] if `label_dir` cannot be listed, or if a label file cannot be
/// opened or read. Reading includes lines that are not valid UTF-8. Directory entries that
/// cannot be inspected are skipped.
pub fn dota_to_coco(
    label_dir: &Path,
    image_dims: &HashMap<String, (u32, u32)>,
    categories: Option<Vec<String>>,
) -> Result<Dataset, ConvertError> {
    let mut cat_name_to_id: HashMap<String, u64> = HashMap::new();
    let mut cat_list: Vec<Category> = Vec::new();
    // Caller-supplied names get positional ids (index + 1).
    for (i, name) in categories.into_iter().flatten().enumerate() {
        let id = (i + 1) as u64;
        cat_name_to_id.insert(name.clone(), id);
        cat_list.push(bare_category(id, name));
    }

    // Only regular `.txt` files are label files. A directory named `*.txt` would otherwise
    // be opened and abort the whole conversion when read.
    let mut entries: Vec<_> = fs::read_dir(label_dir)?
        .filter_map(std::result::Result::ok)
        .filter(|e| {
            let p = e.path();
            p.is_file() && p.extension().is_some_and(|ext| ext == "txt")
        })
        .collect();
    entries.sort_by_key(std::fs::DirEntry::file_name);

    let mut images: Vec<Image> = Vec::with_capacity(entries.len());
    let mut annotations: Vec<Annotation> = Vec::new();
    let mut ann_id: u64 = 1;

    for (img_id, entry) in (1u64..).zip(entries) {
        let path = entry.path();
        // Lossy conversion keeps non-UTF-8 names distinct instead of collapsing them to "".
        let os_name = entry.file_name();
        let name = os_name.to_string_lossy();
        let stem = file_stem(&*name);
        let (width, height) = super::lookup_image_dims(image_dims, stem);
        images.push(Image {
            id: img_id,
            file_name: format!("{stem}.png"),
            width,
            height,
            license: None,
            coco_url: None,
            flickr_url: None,
            date_captured: None,
            neg_category_ids: vec![],
            not_exhaustive_category_ids: vec![],
        });

        let reader = BufReader::new(fs::File::open(&path)?);
        for line in reader.lines() {
            let line = line?;
            let Some((coords, cat_name, difficulty)) = parse_dota_line(&line) else {
                continue;
            };
            let category_id = match cat_name_to_id.get(cat_name) {
                Some(&id) => id,
                None => {
                    let id = (cat_list.len() + 1) as u64;
                    cat_name_to_id.insert(cat_name.to_owned(), id);
                    cat_list.push(bare_category(id, cat_name.to_owned()));
                    id
                }
            };
            let obb = corners_to_obb(&coords);
            annotations.push(Annotation {
                id: ann_id,
                image_id: img_id,
                category_id,
                bbox: Some(obb_to_aabb(&obb)),
                area: Some(obb[2] * obb[3]),
                segmentation: None,
                iscrowd: difficulty > 0,
                keypoints: None,
                num_keypoints: None,
                obb: Some(obb),
                score: None,
                is_group_of: None,
            });
            ann_id += 1;
        }
    }
    Ok(Dataset {
        info: None,
        images,
        annotations,
        categories: cat_list,
        licenses: vec![],
    })
}

/// Splits one DOTA label line into `(8 corner coordinates, category name, difficulty)`.
///
/// Returns `None` for blank lines, `#` comments, lines with fewer than 9 fields, and lines
/// whose first 8 fields are not all numbers. A missing or non-numeric 10th field gives a
/// difficulty of 0.
fn parse_dota_line(line: &str) -> Option<(Vec<f64>, &str, u8)> {
    let line = line.trim();
    if line.is_empty() || line.starts_with('#') {
        return None;
    }
    let parts: Vec<&str> = line.split_whitespace().collect();
    if parts.len() < 9 {
        return None;
    }
    let coords = parts[..8]
        .iter()
        .map(|s| s.parse::<f64>())
        .collect::<Result<Vec<f64>, _>>()
        .ok()?;
    let difficulty: u8 = parts.get(9).and_then(|s| s.parse().ok()).unwrap_or(0);
    Some((coords, parts[8], difficulty))
}

/// Builds a [`Category`] with only `id` and `name` set; all optional metadata is `None`.
fn bare_category(id: u64, name: String) -> Category {
    Category {
        id,
        name,
        supercategory: None,
        skeleton: None,
        keypoints: None,
        frequency: None,
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Round-trip tolerance: DOTA coordinates are written with a single decimal place.
    const EPS: f64 = 0.2;

    /// Asserts `actual` is within [`EPS`] of `expected`, reporting `label` and the value.
    fn assert_close(actual: f64, expected: f64, label: &str) {
        assert!((actual - expected).abs() < EPS, "{label}: {actual}");
    }

    /// One 800x600 image holding a single axis-aligned 40x20 "vehicle" box centred at (30, 20).
    fn single_vehicle_dataset() -> Dataset {
        let image = Image {
            id: 1,
            file_name: "test.png".into(),
            width: 800,
            height: 600,
            license: None,
            coco_url: None,
            flickr_url: None,
            date_captured: None,
            neg_category_ids: vec![],
            not_exhaustive_category_ids: vec![],
        };
        let annotation = Annotation {
            id: 1,
            image_id: 1,
            category_id: 1,
            bbox: Some([10.0, 10.0, 40.0, 20.0]),
            area: Some(800.0),
            segmentation: None,
            iscrowd: false,
            keypoints: None,
            num_keypoints: None,
            obb: Some([30.0, 20.0, 40.0, 20.0, 0.0]),
            score: None,
            is_group_of: None,
        };
        let category = Category {
            id: 1,
            name: "vehicle".into(),
            supercategory: None,
            skeleton: None,
            keypoints: None,
            frequency: None,
        };
        Dataset {
            info: None,
            images: vec![image],
            annotations: vec![annotation],
            categories: vec![category],
            licenses: vec![],
        }
    }

    #[test]
    fn test_corners_to_obb_axis_aligned() {
        // 4x3 rectangle with one corner at the origin, listed counter-clockwise in image space.
        let corners = [0.0, 0.0, 4.0, 0.0, 4.0, 3.0, 0.0, 3.0];
        let got = corners_to_obb(&corners);
        let expected = [(2.0, "cx"), (1.5, "cy"), (4.0, "w"), (3.0, "h")];
        for (i, &(want, label)) in expected.iter().enumerate() {
            assert!((got[i] - want).abs() < 1e-9, "{label}");
        }
        assert!(got[4].abs() < 1e-9, "angle should be 0");
    }

    #[test]
    fn test_dota_round_trip() {
        let source = single_vehicle_dataset();
        let workdir = TempDir::new().unwrap();
        let labels = workdir.path().join("labels");

        let written = coco_to_dota(&source, &labels).unwrap();
        assert_eq!(written.images, 1);
        assert_eq!(written.annotations, 1);

        let dims: HashMap<String, (u32, u32)> =
            std::iter::once(("test".into(), (800, 600))).collect();
        let loaded = dota_to_coco(&labels, &dims, None).unwrap();
        assert_eq!(loaded.annotations.len(), 1);

        let box5 = loaded.annotations[0].obb.unwrap();
        assert_close(box5[0], 30.0, "cx");
        assert_close(box5[1], 20.0, "cy");
        assert_close(box5[2], 40.0, "w");
        assert_close(box5[3], 20.0, "h");
        assert_close(box5[4], 0.0, "angle");

        assert_eq!(loaded.categories.len(), 1);
        assert_eq!(loaded.categories[0].name, "vehicle");
    }
}