// yolo-set 0.1.0
//
// A CLI tool for managing YOLO datasets — merge, deduplicate, remap labels, and more.
use std::collections::HashMap;
use std::fs;
use std::io::{self, BufRead, BufReader, Write};
use std::path::Path;

/// Replaces the leading class index of a label line according to `mapping`.
///
/// The run of ASCII digits at the very start of `line` is parsed as an index
/// into `mapping`. When that slot exists, the digits are swapped for the
/// mapped value and the rest of the line is kept verbatim.
///
/// The line is returned unchanged when it does not start with a digit, when
/// the digits do not fit in a `usize`, or when the index is outside `mapping`.
pub(crate) fn remap_label_prefix(line: &str, mapping: &[usize]) -> String {
	// Byte offset of the first non-digit. ASCII digits are one byte each, so
	// the offset is always a valid char boundary for `split_at`.
	let digits_end = line
		.find(|ch: char| !ch.is_ascii_digit())
		.unwrap_or(line.len());
	let (old_label, tail) = line.split_at(digits_end);

	// An empty prefix fails to parse, which covers the "no leading digit" case.
	let remapped = old_label
		.parse::<usize>()
		.ok()
		.and_then(|old_idx| mapping.get(old_idx));

	match remapped {
		Some(new_idx) => format!("{new_idx}{tail}"),
		None => line.to_owned(),
	}
}

/// Rewrites the file at `path` in place, passing every line through
/// [`remap_label_prefix`] with the given `mapping`.
///
/// The whole file is read and remapped in memory first, then the file is
/// truncated and the new lines are written back, each terminated by `\n`.
///
/// # Errors
///
/// Returns the underlying `io::Error` when the file cannot be opened, a line
/// cannot be read (for example invalid UTF-8), the file cannot be re-created,
/// or writing/flushing fails. A short diagnostic is printed to stderr before
/// the error is returned.
pub(crate) fn remap_labels_in_file(path: &Path, mapping: &[usize]) -> io::Result<()> {
	let file = fs::File::open(path).map_err(|e| {
		eprintln!("remap_labels_in_file 打开文件{path:?}失败");
		e
	})?;
	// Finish reading before re-opening for writing: `File::create` truncates,
	// so any read error has to surface while the original content still exists.
	let lines = BufReader::new(file)
		.lines()
		.map(|l| l.map(|line| remap_label_prefix(&line, mapping)))
		.collect::<io::Result<Vec<String>>>()
		.map_err(|e| {
			eprintln!("remap_labels_in_file 转换行失败");
			e
		})?;

	let out_file = fs::File::create(path).map_err(|e| {
		eprintln!("remap_labels_in_file 创建路径{path:?}失败");
		e
	})?;
	// Buffer the output so each line does not turn into its own write syscalls.
	let mut writer = io::BufWriter::new(out_file);
	for line in lines {
		writeln!(writer, "{}", line).map_err(|e| {
			eprintln!("remap_labels_in_file 写入失败");
			e
		})?;
	}
	// Dropping a `BufWriter` discards flush errors, so flush explicitly and
	// report a failure the same way as any other write failure.
	writer.flush().map_err(|e| {
		eprintln!("remap_labels_in_file 写入失败");
		e
	})
}

/// Applies `mapping` to every `.txt` file directly inside `folder`.
///
/// Each directory entry whose extension is exactly `txt` is rewritten through
/// `remap_labels_in_file`; all other entries are left untouched. The directory
/// is not traversed recursively.
///
/// # Errors
///
/// Stops at the first failure (listing the directory, reading an entry, or
/// rewriting a file), prints `update_label_txt` to stderr and returns the
/// underlying `io::Error`.
pub(crate) fn update_label_txt(folder: &Path, mapping: &[usize]) -> io::Result<()> {
	// Every failure path prints the same tag before handing the error back.
	let tag_error = |err: io::Error| {
		eprintln!("update_label_txt");
		err
	};

	let listing = fs::read_dir(folder).map_err(tag_error)?;
	for item in listing {
		let candidate = item.map_err(tag_error)?.path();
		let is_txt = candidate.extension().and_then(|ext| ext.to_str()) == Some("txt");
		if is_txt {
			remap_labels_in_file(&candidate, mapping).map_err(tag_error)?;
		}
	}
	Ok(())
}

/// Consumes `map` and returns its keys ordered by ascending value.
///
/// Only an ordering on the values is required; they are compared in place and
/// never copied.
///
/// Keys whose values compare equal come out in an unspecified relative order,
/// because the iteration order of a `HashMap` is itself unspecified.
pub fn key_list_sorted_by_value<T, P>(map: HashMap<T, P>) -> Vec<T>
where
	P: Ord,
{
	let mut kvlist: Vec<(T, P)> = map.into_iter().collect();
	// The input order is arbitrary already, so a stable sort buys nothing;
	// the unstable variant sorts in place without a scratch allocation.
	kvlist.sort_unstable_by(|a, b| a.1.cmp(&b.1));
	kvlist.into_iter().map(|(k, _)| k).collect()
}

#[cfg(test)]
mod tests {
	use super::*;
	use std::collections::HashMap;
	use std::fs;
	use std::path::{Path, PathBuf};
	use std::sync::atomic::{AtomicUsize, Ordering};

	/// Gives every scratch directory a distinct suffix, so tests running in
	/// parallel inside one process never share a directory.
	static COUNTER: AtomicUsize = AtomicUsize::new(0);

	/// Scratch directory that is deleted when the guard is dropped.
	///
	/// Cleanup happens during unwinding too, so a failing assertion no longer
	/// leaves files behind in the system temp directory.
	struct TmpDir(PathBuf);

	impl TmpDir {
		/// Path of the scratch directory itself.
		fn path(&self) -> &Path {
			&self.0
		}

		/// Path of `name` inside the scratch directory.
		fn join(&self, name: &str) -> PathBuf {
			self.0.join(name)
		}
	}

	impl Drop for TmpDir {
		fn drop(&mut self) {
			// Best effort: a cleanup failure must not mask the test result.
			let _ = fs::remove_dir_all(&self.0);
		}
	}

	/// Creates a fresh, empty scratch directory unique to this process and call.
	fn tmp_dir() -> TmpDir {
		let n = COUNTER.fetch_add(1, Ordering::Relaxed);
		let dir = std::env::temp_dir().join(format!("label_test_{}_{}", std::process::id(), n));
		// Clear leftovers from an earlier process that happened to reuse this pid.
		let _ = fs::remove_dir_all(&dir);
		fs::create_dir_all(&dir).unwrap();
		TmpDir(dir)
	}

	// ── remap_label_prefix ──

	#[test]
	fn remap_prefix_normal() {
		let mapping = [1, 0, 3, 2];
		assert_eq!(remap_label_prefix("0 0.5 0.5 0.3 0.3", &mapping), "1 0.5 0.5 0.3 0.3");
		assert_eq!(remap_label_prefix("2 1.0 0.0 0.2 0.2", &mapping), "3 1.0 0.0 0.2 0.2");
	}

	#[test]
	fn remap_prefix_multi_digit() {
		// The replacement may be wider than the digits it replaces.
		let mapping = [100, 200];
		assert_eq!(remap_label_prefix("0 data", &mapping), "100 data");
		assert_eq!(remap_label_prefix("1 data", &mapping), "200 data");
	}

	#[test]
	fn remap_prefix_no_digit() {
		let mapping = [5];
		assert_eq!(remap_label_prefix("hello world", &mapping), "hello world");
		assert_eq!(remap_label_prefix("", &mapping), "");
	}

	#[test]
	fn remap_prefix_non_num_start() {
		// Digits that are not at the very start of the line are ignored.
		let mapping = [1, 2];
		assert_eq!(remap_label_prefix("abc123", &mapping), "abc123");
	}

	#[test]
	fn remap_prefix_index_oob() {
		let mapping = [0, 1];
		assert_eq!(remap_label_prefix("5 0.5 0.5 0.3 0.3", &mapping), "5 0.5 0.5 0.3 0.3");
	}

	#[test]
	fn remap_prefix_only_digits() {
		let mapping = [7, 8, 9];
		assert_eq!(remap_label_prefix("2", &mapping), "9");
		assert_eq!(remap_label_prefix("0", &mapping), "7");
	}

	#[test]
	fn remap_prefix_empty_mapping() {
		let mapping: [usize; 0] = [];
		assert_eq!(remap_label_prefix("0 0.5", &mapping), "0 0.5");
	}

	// ── remap_labels_in_file ──

	#[test]
	fn remap_file_content() {
		let tmp = tmp_dir();
		let fp = tmp.join("labels.txt");
		fs::write(&fp, "0 0.5 0.5 0.3 0.3\n1 0.1 0.2 0.3 0.4\nskip me\n").unwrap();
		let mapping = [2, 0];
		remap_labels_in_file(&fp, &mapping).unwrap();
		let result = fs::read_to_string(&fp).unwrap();
		assert_eq!(result, "2 0.5 0.5 0.3 0.3\n0 0.1 0.2 0.3 0.4\nskip me\n");
	}

	#[test]
	fn remap_file_nonexistent() {
		assert!(remap_labels_in_file(Path::new("/nonex_yolo_t"), &[0]).is_err());
	}

	// ── update_label_txt ──

	#[test]
	fn update_txt_only_txt_files() {
		let tmp = tmp_dir();
		fs::write(tmp.join("labels.txt"), "0 0.5\n").unwrap();
		fs::write(tmp.join("image.jpg"), "0 should stay").unwrap();
		fs::write(tmp.join("notes.md"), "0 notes").unwrap();
		update_label_txt(tmp.path(), &[9]).unwrap();
		assert_eq!(fs::read_to_string(tmp.join("labels.txt")).unwrap(), "9 0.5\n");
		assert_eq!(fs::read_to_string(tmp.join("image.jpg")).unwrap(), "0 should stay");
		assert_eq!(fs::read_to_string(tmp.join("notes.md")).unwrap(), "0 notes");
	}

	#[test]
	fn update_txt_all_txt_files() {
		let tmp = tmp_dir();
		fs::write(tmp.join("a.txt"), "0 cat\n1 dog\n").unwrap();
		fs::write(tmp.join("b.txt"), "2 fish\nno label\n").unwrap();
		update_label_txt(tmp.path(), &[10, 11, 12]).unwrap();
		assert_eq!(fs::read_to_string(tmp.join("a.txt")).unwrap(), "10 cat\n11 dog\n");
		assert_eq!(fs::read_to_string(tmp.join("b.txt")).unwrap(), "12 fish\nno label\n");
	}

	// ── key_list_sorted_by_value ──

	#[test]
	fn key_list_empty() {
		let map: HashMap<String, usize> = HashMap::new();
		assert!(key_list_sorted_by_value(map).is_empty());
	}

	#[test]
	fn key_list_sorted() {
		let mut map = HashMap::new();
		map.insert("b", 2);
		map.insert("a", 0);
		map.insert("c", 1);
		assert_eq!(key_list_sorted_by_value(map), vec!["a", "c", "b"]);
	}

	#[test]
	fn key_list_single() {
		let mut map = HashMap::new();
		map.insert("only", 42);
		assert_eq!(key_list_sorted_by_value(map), vec!["only"]);
	}
}