//! yolo-set 0.1.0
//!
//! A CLI tool for managing YOLO datasets — merge, deduplicate, remap labels, and more.
use std::collections::HashMap;
use std::fs;
use std::io;
use std::path::{Path, PathBuf};

use crate::model::data_yaml::DataYaml;

/// Name of the file whose presence marks a directory as a dataset.
pub const DATASET_DEF_FILE: &str = "data.yaml";

/// Locates dataset directories at or directly beneath `root`.
///
/// If `root` itself contains a [`DATASET_DEF_FILE`] file, `root` is returned
/// as the only result. Otherwise every immediate subdirectory of `root` that
/// contains a [`DATASET_DEF_FILE`] file is returned. The search goes only one
/// level deep.
///
/// The returned paths are sorted. [`fs::read_dir`] yields entries in an
/// unspecified, platform-dependent order, and label ids handed out by
/// `register_dataset_label` depend on the order datasets are registered.
///
/// # Errors
///
/// Returns the underlying [`io::Error`] if `root` cannot be read as a
/// directory. Individual directory entries that fail to read are skipped.
pub fn find_dataset(root: &Path) -> Result<Vec<PathBuf>, io::Error> {
	/// True when `dir` holds a regular (or symlinked-to) definition file.
	/// A directory that happens to be named `data.yaml` does not count.
	fn has_def_file(dir: &Path) -> bool {
		dir.join(DATASET_DEF_FILE).is_file()
	}

	if has_def_file(root) {
		return Ok(vec![root.to_path_buf()]);
	}

	let dir_list = fs::read_dir(root).map_err(|e| {
		// Message: "find_dataset: failed to read directory!"
		eprintln!("find_dataset 读取文件夹失败!");
		e
	})?;

	let mut result: Vec<PathBuf> = dir_list
		.filter_map(Result::ok)
		.map(|entry| entry.path())
		.filter(|path| path.is_dir() && has_def_file(path))
		.collect();
	// Make the output independent of read_dir's unspecified ordering.
	result.sort();
	Ok(result)
}

/// Registers one dataset's class names in a shared, case-insensitive label map.
///
/// Reads `data.yaml` from `dataset` and walks its `names` list in order. Each
/// name is lower-cased and looked up in `label_map`:
/// - a name already present reuses its existing id;
/// - a new name receives the next unused id and is inserted.
///
/// Returns one id per entry of `names`: element `i` is the shared id assigned
/// to `names[i]`.
///
/// If `data.yaml` cannot be read, the dataset is treated as having no labels.
/// An empty vector is returned and `label_map` is left unchanged.
pub fn register_dataset_label(
	dataset: &Path,
	label_map: &mut HashMap<String, usize>,
) -> Vec<usize> {
	// Same file name as `DATASET_DEF_FILE`. A missing or unreadable file is
	// deliberately treated as an empty dataset rather than an error.
	let data_yaml = DataYaml::read_from(dataset.join("data.yaml").as_path())
		.unwrap_or_else(|_| DataYaml::new());

	// Next free id is one past the largest id already in use. For a map built
	// only by this function (ids 0..len) this equals `label_map.len()`. It
	// also avoids handing out a duplicate id if the caller pre-seeded the map
	// with non-contiguous ids.
	let mut next_id = label_map.values().max().map_or(0, |&max_id| max_id + 1);

	data_yaml
		.names
		.iter()
		.map(|label| {
			// Single hash lookup: reuse the existing id or allocate a new one.
			*label_map.entry(label.to_lowercase()).or_insert_with(|| {
				let id = next_id;
				next_id += 1;
				id
			})
		})
		.collect()
}

#[cfg(test)]
mod tests {
	use super::*;
	use std::collections::HashMap;
	use std::fs;
	use std::sync::atomic::{AtomicUsize, Ordering};

	/// Per-process counter so concurrently running tests get distinct dirs.
	static COUNTER: AtomicUsize = AtomicUsize::new(0);

	/// Unique scratch directory under the system temp dir, deleted on drop.
	///
	/// Cleanup happens in `Drop`, so the directory is removed even when an
	/// assertion fails and the test panics.
	struct TmpDir(PathBuf);

	impl TmpDir {
		fn new() -> Self {
			let n = COUNTER.fetch_add(1, Ordering::Relaxed);
			let dir = std::env::temp_dir().join(format!("ds_test_{}_{}", std::process::id(), n));
			// Clear leftovers from an earlier run that reused this pid.
			let _ = fs::remove_dir_all(&dir);
			fs::create_dir_all(&dir).unwrap();
			TmpDir(dir)
		}

		fn path(&self) -> &Path {
			&self.0
		}
	}

	impl Drop for TmpDir {
		fn drop(&mut self) {
			// Best-effort: a failed cleanup must not mask the test outcome.
			let _ = fs::remove_dir_all(&self.0);
		}
	}

	/// Creates `dir` and writes a minimal `data.yaml` listing `names`.
	fn write_dataset(dir: &Path, names: &[&str]) {
		fs::create_dir_all(dir).unwrap();
		let mut yaml = String::from("names:\n");
		for name in names {
			yaml.push_str(&format!("  - {}\n", name));
		}
		yaml.push_str("train: ../train/images\nval: ../valid/images\ntest: ../test/images\n");
		yaml.push_str(&format!("nc: {}\n", names.len()));
		fs::write(dir.join(DATASET_DEF_FILE), yaml).unwrap();
	}

	// ── find_dataset ──

	#[test]
	fn find_dataset_empty_dir() {
		let root = TmpDir::new();
		assert!(find_dataset(root.path()).unwrap().is_empty());
	}

	#[test]
	fn find_dataset_with_data_yaml() {
		let root = TmpDir::new();
		let ds = root.path().join("my_dataset");
		fs::create_dir_all(&ds).unwrap();
		fs::File::create(ds.join(DATASET_DEF_FILE)).unwrap();
		let result = find_dataset(root.path()).unwrap();
		assert_eq!(result.len(), 1);
		assert_eq!(result[0].file_name().unwrap(), "my_dataset");
	}

	#[test]
	fn find_dataset_root_is_dataset() {
		let root = TmpDir::new();
		fs::File::create(root.path().join(DATASET_DEF_FILE)).unwrap();
		let result = find_dataset(root.path()).unwrap();
		assert_eq!(result.len(), 1);
		assert_eq!(result[0].as_path(), root.path());
	}

	#[test]
	fn find_dataset_skips_without_yaml() {
		let root = TmpDir::new();
		let has_yaml = root.path().join("has_yaml");
		fs::create_dir_all(&has_yaml).unwrap();
		fs::File::create(has_yaml.join(DATASET_DEF_FILE)).unwrap();
		fs::create_dir_all(root.path().join("no_yaml")).unwrap();
		let result = find_dataset(root.path()).unwrap();
		assert_eq!(result.len(), 1);
		assert_eq!(result[0].file_name().unwrap(), "has_yaml");
	}

	#[test]
	fn find_dataset_skips_files() {
		let root = TmpDir::new();
		fs::File::create(root.path().join("not_a_dir.yaml")).unwrap();
		assert!(find_dataset(root.path()).unwrap().is_empty());
	}

	#[test]
	fn find_dataset_nonexistent_path() {
		// Derive a path that cannot exist instead of hard-coding an absolute one.
		let root = TmpDir::new();
		assert!(find_dataset(&root.path().join("does_not_exist")).is_err());
	}

	// ── register_dataset_label ──

	#[test]
	fn register_new_labels_sequential() {
		let tmp = TmpDir::new();
		let ds = tmp.path().join("ds");
		write_dataset(&ds, &["cat", "dog", "bird"]);

		let mut map: HashMap<String, usize> = HashMap::new();
		let ids = register_dataset_label(&ds, &mut map);
		assert_eq!(ids, vec![0, 1, 2]);
		assert_eq!(map["cat"], 0);
		assert_eq!(map["dog"], 1);
		assert_eq!(map["bird"], 2);
	}

	#[test]
	fn register_labels_with_existing_map() {
		let tmp = TmpDir::new();
		let ds1 = tmp.path().join("ds1");
		write_dataset(&ds1, &["cat", "dog"]);
		let ds2 = tmp.path().join("ds2");
		write_dataset(&ds2, &["dog", "fish"]);

		let mut map: HashMap<String, usize> = HashMap::new();
		assert_eq!(register_dataset_label(&ds1, &mut map), vec![0, 1]);
		// "dog" reuses its id from ds1; only "fish" gets a new id.
		assert_eq!(register_dataset_label(&ds2, &mut map), vec![1, 2]);
		assert_eq!(map.len(), 3);
	}

	#[test]
	fn register_labels_case_insensitive() {
		let tmp = TmpDir::new();
		let ds = tmp.path().join("ds");
		write_dataset(&ds, &["Cat", "DOG"]);

		let mut map: HashMap<String, usize> = HashMap::new();
		assert_eq!(register_dataset_label(&ds, &mut map), vec![0, 1]);
		assert!(map.contains_key("cat"));
		assert!(map.contains_key("dog"));
	}

	#[test]
	fn register_labels_missing_yaml_defaults_empty() {
		let tmp = TmpDir::new();
		let ds = tmp.path().join("no_yaml_ds");
		fs::create_dir_all(&ds).unwrap();
		let mut map: HashMap<String, usize> = HashMap::new();
		assert!(register_dataset_label(&ds, &mut map).is_empty());
		assert!(map.is_empty());
	}
}