use std::collections::HashMap;
use std::fs;
use std::io;
use std::path::{Path, PathBuf};
use crate::model::data_yaml::DataYaml;
pub const DATASET_DEF_FILE: &str = "data.yaml";
pub fn find_dataset(root: &Path) -> Result<Vec<PathBuf>, io::Error> {
if root.join(DATASET_DEF_FILE).exists() {
return Ok(vec![root.to_path_buf()]);
}
let dir_list = fs::read_dir(root).map_err(|e| {
eprintln!("find_dataset 读取文件夹失败!");
e
})?;
let mut result: Vec<PathBuf> = Vec::new();
for entry in dir_list.filter_map(Result::ok) {
if !entry.path().is_dir() {
continue;
}
if !entry.path().join(DATASET_DEF_FILE).exists() {
continue;
}
result.push(entry.path());
}
Ok(result)
}
/// Registers a dataset's class names into a shared, case-insensitive label map.
///
/// Reads `<dataset>/data.yaml`. For each entry of its `names` list, in order,
/// the name is lowercased and looked up in `label_map`. Names not yet present
/// are inserted with the next free index, `label_map.len()`.
///
/// NOTE(review): using `len()` as the next index assumes `label_map` values are
/// exactly `0..label_map.len()` (e.g. the map was only grown by this function);
/// a map pre-populated with other values could receive colliding indices.
///
/// Returns one global index per local class of this dataset, so that
/// `result[local_index]` is the index of that class in `label_map`. Names that
/// differ only in case share one index.
///
/// If the definition file cannot be read, a warning is printed to stderr and
/// the dataset is treated as having no labels. In that case an empty vector is
/// returned and `label_map` is left unchanged.
pub fn register_dataset_label(
    dataset: &Path,
    label_map: &mut HashMap<String, usize>,
) -> Vec<usize> {
    // Must match DATASET_DEF_FILE; kept literal so this function stands alone.
    let def_path = dataset.join("data.yaml");
    let data_yaml = DataYaml::read_from(def_path.as_path()).unwrap_or_else(|_| {
        // Best effort: an unreadable definition means "no labels", but say so
        // instead of silently producing an empty mapping.
        eprintln!(
            "register_dataset_label: failed to read {}, treating dataset as unlabeled",
            def_path.display()
        );
        DataYaml::new()
    });
    data_yaml
        .names
        .iter()
        .map(|label| {
            // Compute the candidate index before `entry` takes the mutable borrow.
            let next_index = label_map.len();
            *label_map.entry(label.to_lowercase()).or_insert(next_index)
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::fs;
    use std::path::{Path, PathBuf};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Per-process counter so tests running in parallel get distinct directories.
    static COUNTER: AtomicUsize = AtomicUsize::new(0);

    /// A fresh, empty scratch directory that is removed when dropped.
    ///
    /// Cleanup lives in `Drop`, so the directory is also removed when an
    /// assertion fails and the test unwinds. Trailing `remove_dir_all` calls
    /// would be skipped in that case.
    struct TmpDir(PathBuf);

    impl TmpDir {
        fn new() -> Self {
            let n = COUNTER.fetch_add(1, Ordering::Relaxed);
            let dir = std::env::temp_dir().join(format!("ds_test_{}_{}", std::process::id(), n));
            // A stale directory from an earlier crashed run may exist; start clean.
            let _ = fs::remove_dir_all(&dir);
            fs::create_dir_all(&dir).unwrap();
            TmpDir(dir)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TmpDir {
        fn drop(&mut self) {
            // Best effort: failing to clean up must not mask the test outcome.
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    /// Creates `dir` (if needed) and writes `yaml` as its `data.yaml`.
    fn write_data_yaml(dir: &Path, yaml: &str) {
        fs::create_dir_all(dir).unwrap();
        fs::write(dir.join("data.yaml"), yaml).unwrap();
    }

    #[test]
    fn find_dataset_empty_dir() {
        let root = TmpDir::new();
        let result = find_dataset(root.path()).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn find_dataset_with_data_yaml() {
        let root = TmpDir::new();
        let ds = root.path().join("my_dataset");
        fs::create_dir_all(&ds).unwrap();
        fs::File::create(ds.join("data.yaml")).unwrap();
        let result = find_dataset(root.path()).unwrap();
        assert_eq!(result, vec![ds]);
    }

    #[test]
    fn find_dataset_root_is_dataset() {
        let root = TmpDir::new();
        fs::File::create(root.path().join("data.yaml")).unwrap();
        let result = find_dataset(root.path()).unwrap();
        assert_eq!(result, vec![root.path().to_path_buf()]);
    }

    #[test]
    fn find_dataset_skips_without_yaml() {
        let root = TmpDir::new();
        let with_yaml = root.path().join("has_yaml");
        fs::create_dir_all(&with_yaml).unwrap();
        fs::File::create(with_yaml.join("data.yaml")).unwrap();
        fs::create_dir_all(root.path().join("no_yaml")).unwrap();
        let result = find_dataset(root.path()).unwrap();
        // Check *which* directory was found, not just how many.
        assert_eq!(result, vec![with_yaml]);
    }

    #[test]
    fn find_dataset_skips_files() {
        let root = TmpDir::new();
        fs::File::create(root.path().join("not_a_dir.yaml")).unwrap();
        let result = find_dataset(root.path()).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn find_dataset_nonexistent_path() {
        // A child of a fresh temp dir is guaranteed not to exist, unlike a
        // hard-coded absolute path.
        let root = TmpDir::new();
        let missing = root.path().join("does_not_exist");
        assert!(find_dataset(&missing).is_err());
    }

    #[test]
    fn register_new_labels_sequential() {
        let tmp = TmpDir::new();
        let ds = tmp.path().join("ds");
        let yaml = concat!(
            "names:\n",
            " - cat\n",
            " - dog\n",
            " - bird\n",
            "train: ../train/images\n",
            "val: ../valid/images\n",
            "test: ../test/images\n",
            "nc: 3\n",
        );
        write_data_yaml(&ds, yaml);
        let mut map: HashMap<String, usize> = HashMap::new();
        let ids = register_dataset_label(&ds, &mut map);
        assert_eq!(ids, vec![0, 1, 2]);
        assert_eq!(map["cat"], 0);
        assert_eq!(map["dog"], 1);
        assert_eq!(map["bird"], 2);
    }

    #[test]
    fn register_labels_with_existing_map() {
        let tmp = TmpDir::new();
        let ds1 = tmp.path().join("ds1");
        write_data_yaml(
            &ds1,
            concat!(
                "names:\n - cat\n - dog\n",
                "train: ../train/images\nval: ../valid/images\ntest: ../test/images\nnc: 2\n",
            ),
        );
        let ds2 = tmp.path().join("ds2");
        write_data_yaml(
            &ds2,
            concat!(
                "names:\n - dog\n - fish\n",
                "train: ../train/images\nval: ../valid/images\ntest: ../test/images\nnc: 2\n",
            ),
        );
        let mut map: HashMap<String, usize> = HashMap::new();
        let ids1 = register_dataset_label(&ds1, &mut map);
        assert_eq!(ids1, vec![0, 1]);
        let ids2 = register_dataset_label(&ds2, &mut map);
        assert_eq!(ids2, vec![1, 2]);
        assert_eq!(map.len(), 3);
    }

    #[test]
    fn register_labels_case_insensitive() {
        let tmp = TmpDir::new();
        let ds = tmp.path().join("ds");
        write_data_yaml(
            &ds,
            concat!(
                "names:\n - Cat\n - DOG\n",
                "train: ../train/images\nval: ../valid/images\ntest: ../test/images\nnc: 2\n",
            ),
        );
        let mut map: HashMap<String, usize> = HashMap::new();
        let ids = register_dataset_label(&ds, &mut map);
        assert_eq!(ids, vec![0, 1]);
        assert!(map.contains_key("cat"));
        assert!(map.contains_key("dog"));
    }

    #[test]
    fn register_labels_duplicate_names_share_index() {
        let tmp = TmpDir::new();
        let ds = tmp.path().join("ds");
        write_data_yaml(
            &ds,
            concat!(
                "names:\n - cat\n - Cat\n",
                "train: ../train/images\nval: ../valid/images\ntest: ../test/images\nnc: 2\n",
            ),
        );
        let mut map: HashMap<String, usize> = HashMap::new();
        let ids = register_dataset_label(&ds, &mut map);
        assert_eq!(ids, vec![0, 0]);
        assert_eq!(map.len(), 1);
    }

    #[test]
    fn register_labels_missing_yaml_defaults_empty() {
        let tmp = TmpDir::new();
        let ds = tmp.path().join("no_yaml_ds");
        fs::create_dir_all(&ds).unwrap();
        let mut map: HashMap<String, usize> = HashMap::new();
        let ids = register_dataset_label(&ds, &mut map);
        assert!(ids.is_empty());
        assert!(map.is_empty());
    }
}