use std::collections::HashMap;
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
use crate::types::{Annotation, Category, Dataset, Image};
use super::ConvertError;
/// Counters reported by [`coco_to_yolo`] after a conversion run.
#[derive(Debug, Clone)]
pub struct YoloStats {
/// Number of images in the input dataset (`dataset.images.len()`).
pub images: usize,
/// Number of annotation lines actually written to the label files.
pub annotations: usize,
/// Annotations skipped because their `iscrowd` flag was set.
pub skipped_crowd: usize,
/// Non-crowd annotations skipped because they carried no bounding box.
pub missing_bbox: usize,
}
/// Writes `dataset` to `output_dir` as YOLO label files plus a `data.yaml`.
///
/// For every image a `<stem>.txt` file is created (also when the image has no
/// usable annotations). Each written line has the form `class cx cy w h`,
/// where the box centre and size are divided by the image width/height and
/// printed with six decimals. The class index of a category is its position
/// after sorting all categories by id; the same order is used for the `names`
/// list in `data.yaml`.
///
/// Annotations flagged `iscrowd` and annotations without a bbox are skipped
/// and counted in the returned [`YoloStats`]. Annotations whose `category_id`
/// does not appear in `dataset.categories` are skipped without being counted.
///
/// # Errors
///
/// Returns `ConvertError::MissingImageDimensions` for the first image whose
/// width or height is zero (label files created for earlier images are left
/// in place). I/O failures while creating or writing files are propagated.
pub fn coco_to_yolo(dataset: &Dataset, output_dir: &Path) -> Result<YoloStats, ConvertError> {
    fs::create_dir_all(output_dir)?;

    // Class index = rank of the category id among all category ids.
    let mut sorted_cats: Vec<&Category> = dataset.categories.iter().collect();
    sorted_cats.sort_by_key(|c| c.id);
    let cat_id_to_idx: HashMap<u64, usize> = sorted_cats
        .iter()
        .enumerate()
        .map(|(i, c)| (c.id, i))
        .collect();

    let anns_by_image = super::anns_by_image(dataset);

    let mut total_annotations = 0usize;
    let mut skipped_crowd = 0usize;
    let mut missing_bbox = 0usize;

    for img in &dataset.images {
        // Normalisation divides by these, so zero dimensions cannot be converted.
        if img.width == 0 || img.height == 0 {
            return Err(ConvertError::MissingImageDimensions(img.id));
        }

        let stem = super::file_stem(&img.file_name);
        let txt_path = output_dir.join(format!("{stem}.txt"));
        // Buffered so that each annotation line does not cost its own write syscall.
        let mut file = std::io::BufWriter::new(fs::File::create(&txt_path)?);

        let w = img.width as f64;
        let h = img.height as f64;

        if let Some(anns) = anns_by_image.get(&img.id) {
            for ann in anns {
                if ann.iscrowd {
                    skipped_crowd += 1;
                    continue;
                }
                let bbox = match ann.bbox {
                    Some(b) => b,
                    None => {
                        missing_bbox += 1;
                        continue;
                    }
                };
                // Annotations pointing at an unknown category are dropped silently.
                let class_idx = match cat_id_to_idx.get(&ann.category_id) {
                    Some(&idx) => idx,
                    None => continue,
                };

                // Top-left/size box -> centre/size box, scaled by the image size.
                let [x, y, bw, bh] = bbox;
                let cx = (x + bw / 2.0) / w;
                let cy = (y + bh / 2.0) / h;
                let nw = bw / w;
                let nh = bh / h;
                writeln!(file, "{class_idx} {cx:.6} {cy:.6} {nw:.6} {nh:.6}")?;
                total_annotations += 1;
            }
        }

        // Flush explicitly: dropping a BufWriter would swallow a failed write.
        file.flush()?;
    }

    let yaml_path = output_dir.join("data.yaml");
    let mut yaml_file = std::io::BufWriter::new(fs::File::create(&yaml_path)?);
    writeln!(yaml_file, "nc: {}", sorted_cats.len())?;
    let names_csv: Vec<&str> = sorted_cats.iter().map(|c| c.name.as_str()).collect();
    writeln!(yaml_file, "names: [{}]", names_csv.join(", "))?;
    yaml_file.flush()?;

    Ok(YoloStats {
        images: dataset.images.len(),
        annotations: total_annotations,
        skipped_crowd,
        missing_bbox,
    })
}
/// Reads a YOLO label directory and assembles it into a COCO-style [`Dataset`].
///
/// `yolo_dir` must contain a `data.yaml`; its `names` list becomes the
/// categories, with ids starting at 1 in list order. Every `*.txt` file in the
/// directory is treated as one image's label file, visited in sorted path
/// order and given image ids starting at 1. The image's `file_name` is the
/// label file's stem, and its dimensions come from
/// `super::lookup_image_dims(image_dims, stem)`.
///
/// Each non-blank label line must hold `class cx cy w h`; the box is scaled by
/// the image dimensions and converted from centre/size to top-left/size form,
/// and `area` is set to the product of the scaled width and height.
///
/// # Errors
///
/// * `ConvertError::MissingDataYaml` when `data.yaml` does not exist.
/// * `ConvertError::ParseError` for a line without exactly five fields, a
///   field that does not parse as a number, or a class index that is not
///   smaller than the number of names.
/// * I/O failures while reading the directory or its files are propagated.
pub fn yolo_to_coco(
    yolo_dir: &Path,
    image_dims: &HashMap<String, (u32, u32)>,
) -> Result<Dataset, ConvertError> {
    // Parses one numeric field; the error text is only built when parsing fails.
    fn parse_field<T: std::str::FromStr>(
        raw: &str,
        describe: impl FnOnce() -> String,
    ) -> Result<T, ConvertError> {
        raw.parse()
            .map_err(|_| ConvertError::ParseError(describe()))
    }

    let yaml_path = yolo_dir.join("data.yaml");
    if !yaml_path.exists() {
        return Err(ConvertError::MissingDataYaml);
    }
    let names = parse_data_yaml(&fs::read_to_string(&yaml_path)?)?;

    // Category ids are the 1-based positions within the `names` list.
    let categories: Vec<Category> = (1u64..)
        .zip(names.iter())
        .map(|(id, name)| Category {
            id,
            name: name.clone(),
            supercategory: None,
            skeleton: None,
            keypoints: None,
            frequency: None,
        })
        .collect();

    // Gather the label files; directory entries that fail to read are ignored.
    let mut label_files: Vec<PathBuf> = Vec::new();
    for entry in fs::read_dir(yolo_dir)?.flatten() {
        let path = entry.path();
        if path.extension().and_then(|ext| ext.to_str()) == Some("txt") {
            label_files.push(path);
        }
    }
    label_files.sort();

    let mut images: Vec<Image> = Vec::new();
    let mut annotations: Vec<Annotation> = Vec::new();

    for (file_idx, label_path) in label_files.iter().enumerate() {
        // Image ids follow the sorted file order, starting at 1.
        let img_id = file_idx as u64 + 1;
        let stem = label_path
            .file_stem()
            .and_then(|s| s.to_str())
            .unwrap_or("")
            .to_string();
        let (width, height) = super::lookup_image_dims(image_dims, &stem);
        let w = width as f64;
        let h = height as f64;

        images.push(Image {
            id: img_id,
            file_name: stem,
            width,
            height,
            license: None,
            coco_url: None,
            flickr_url: None,
            date_captured: None,
            neg_category_ids: vec![],
            not_exhaustive_category_ids: vec![],
        });

        let content = fs::read_to_string(label_path)?;
        for line in content.lines().map(str::trim).filter(|l| !l.is_empty()) {
            let parts: Vec<&str> = line.split_whitespace().collect();
            let (raw_class, raw_cx, raw_cy, raw_bw, raw_bh) = match parts.as_slice() {
                [a, b, c, d, e] => (*a, *b, *c, *d, *e),
                _ => {
                    return Err(ConvertError::ParseError(format!(
                        "expected 5 fields, got {} in: {line}",
                        parts.len()
                    )));
                }
            };

            // All five fields are parsed before the class index is range-checked.
            let class_idx: usize =
                parse_field(raw_class, || format!("invalid class_idx: {}", raw_class))?;
            let cx: f64 = parse_field(raw_cx, || format!("invalid cx: {}", raw_cx))?;
            let cy: f64 = parse_field(raw_cy, || format!("invalid cy: {}", raw_cy))?;
            let bw: f64 = parse_field(raw_bw, || format!("invalid width: {}", raw_bw))?;
            let bh: f64 = parse_field(raw_bh, || format!("invalid height: {}", raw_bh))?;

            if class_idx >= categories.len() {
                return Err(ConvertError::ParseError(format!(
                    "class_idx {class_idx} out of range (nc={})",
                    categories.len()
                )));
            }

            // Centre/size box scaled by the image size -> top-left/size box.
            let px = (cx - bw / 2.0) * w;
            let py = (cy - bh / 2.0) * h;
            let pw = bw * w;
            let ph = bh * h;

            // Annotation ids run 1, 2, 3, ... across all files.
            let ann_id = annotations.len() as u64 + 1;
            annotations.push(Annotation {
                id: ann_id,
                image_id: img_id,
                category_id: (class_idx + 1) as u64,
                bbox: Some([px, py, pw, ph]),
                area: Some(pw * ph),
                segmentation: None,
                iscrowd: false,
                keypoints: None,
                num_keypoints: None,
                obb: None,
                score: None,
                is_group_of: None,
            });
        }
    }

    Ok(Dataset {
        info: None,
        images,
        annotations,
        categories,
        licenses: vec![],
    })
}
/// Extracts the class names from the text of a YOLO `data.yaml`.
///
/// This is a small line-based reader, not a YAML parser. Three layouts of the
/// `names` key are recognised:
///
/// * inline list: `names: [cat, 'dog']` — split on commas, empty items dropped;
/// * block list: a bare `names:` line followed by `- cat` items (empty items
///   dropped);
/// * index mapping: a bare `names:` line followed by `0: cat` items, returned
///   in index order.
///
/// Surrounding whitespace and quote characters are stripped from each name.
/// The inline form is looked for first across the whole text, so input that
/// contains it yields the same result regardless of any block-style `names:`
/// section.
///
/// # Errors
///
/// Returns `ConvertError::ParseError` when no `names` entry in a recognised
/// layout is found, or when an index mapping does not use exactly the indices
/// `0..n` (positions in the returned list are the class indices, so gaps or
/// duplicates cannot be represented).
fn parse_data_yaml(content: &str) -> Result<Vec<String>, ConvertError> {
    // Name clean-up shared by every layout.
    fn unquote(raw: &str) -> String {
        raw.trim().trim_matches('"').trim_matches('\'').to_string()
    }

    // Pass 1: inline flow list, `names: [a, b]`.
    for line in content.lines() {
        if let Some(rest) = line.trim().strip_prefix("names:") {
            let bracketed = rest
                .trim()
                .strip_prefix('[')
                .and_then(|r| r.strip_suffix(']'));
            if let Some(inner) = bracketed {
                return Ok(inner
                    .split(',')
                    .map(unquote)
                    .filter(|s| !s.is_empty())
                    .collect());
            }
        }
    }

    // Pass 2: a bare `names:` key followed by list items or `index: name` pairs.
    let mut lines = content.lines();
    if lines.any(|line| line.trim() == "names:") {
        let mut listed: Vec<String> = Vec::new();
        let mut indexed: Vec<(usize, String)> = Vec::new();

        for line in lines {
            let item = line.trim();
            if item.is_empty() || item.starts_with('#') {
                continue;
            }
            let seq_value = if item == "-" {
                Some("")
            } else {
                item.strip_prefix("- ")
            };
            if let Some(value) = seq_value {
                listed.push(unquote(value));
            } else if let Some((key, value)) = item.split_once(':') {
                match key.trim().parse::<usize>() {
                    Ok(idx) => indexed.push((idx, unquote(value))),
                    // A non-numeric key is the next setting, so the section is over.
                    Err(_) => break,
                }
            } else {
                break;
            }
        }

        if indexed.is_empty() {
            let names: Vec<String> = listed.into_iter().filter(|s| !s.is_empty()).collect();
            if !names.is_empty() {
                return Ok(names);
            }
        } else if listed.is_empty() {
            indexed.sort_by_key(|(idx, _)| *idx);
            let contiguous = indexed
                .iter()
                .enumerate()
                .all(|(pos, (idx, _))| pos == *idx);
            if !contiguous {
                return Err(ConvertError::ParseError(
                    "'names' mapping in data.yaml must use each class index 0..n exactly once"
                        .into(),
                ));
            }
            return Ok(indexed.into_iter().map(|(_, name)| name).collect());
        }
    }

    Err(ConvertError::ParseError(
        "no 'names' field found in data.yaml".into(),
    ))
}