//! yolo-set 0.1.0
//!
//! A CLI tool for managing YOLO datasets — merge, deduplicate, remap labels, and more.
use std::{fs, io};
use std::io::Write;
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};

/// Contents of a dataset `data.yaml` file.
///
/// The struct maps one-to-one onto the YAML keys `train`, `val`, `test`,
/// `nc` and `names`. None of the fields is optional or defaulted, so all
/// five keys must be present for deserialization to succeed.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DataYaml {
	/// Path stored under the `train` key.
	pub train: PathBuf,
	/// Path stored under the `val` key.
	pub val: PathBuf,
	/// Path stored under the `test` key.
	pub test: PathBuf,
	/// Number of entries expected in `names`; `DataYaml::set_names` keeps
	/// the two in sync, direct field writes do not.
	pub nc: usize,
	/// Label names, in the order they appear in the YAML list.
	pub names: Vec<String>,
}

impl DataYaml {
	/// Creates a descriptor with the default split paths
	/// (`../train/images`, `../valid/images`, `../test/images`),
	/// `nc == 0` and an empty `names` list.
	pub fn new() -> Self {
		Self {
			train: PathBuf::from("../train/images"),
			val: PathBuf::from("../valid/images"),
			test: PathBuf::from("../test/images"),
			nc: 0,
			names: Vec::new(),
		}
	}

	/// Replaces the name list and updates `nc` to its length.
	///
	/// Returns `self` so that further calls can be chained.
	pub fn set_names(&mut self, names: Vec<String>) -> &mut Self {
		self.names = names;
		self.nc = self.names.len();
		self
	}

	/// Reads and parses the YAML file at `yaml`.
	///
	/// # Errors
	///
	/// Propagates any error from reading the file. A file that was read but
	/// cannot be deserialized into a [`DataYaml`] is reported as
	/// [`io::ErrorKind::InvalidData`] wrapping the `serde_yaml` error.
	pub fn read_from(yaml: &Path) -> io::Result<DataYaml> {
		let yaml_str = fs::read_to_string(yaml)?;
		serde_yaml::from_str(&yaml_str)
			.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
	}

	/// Serializes `self` to YAML and writes it to `file_name`.
	///
	/// Missing parent directories are created first. An existing file at
	/// `file_name` is truncated and overwritten.
	///
	/// # Errors
	///
	/// Returns [`io::ErrorKind::InvalidData`] if serialization fails, and
	/// propagates any error from creating the directories or the file, or
	/// from writing to it.
	pub fn write_to(&self, file_name: &Path) -> io::Result<()> {
		// Serialize before touching the filesystem so a serialization failure
		// leaves no directories or empty file behind.
		let yaml = serde_yaml::to_string(self)
			.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
		if let Some(parent) = file_name.parent() {
			fs::create_dir_all(parent)?;
		}
		let mut file = fs::File::create(file_name)?;
		file.write_all(yaml.as_bytes())
	}
}

impl Default for DataYaml {
	/// Same value as [`DataYaml::new`]. Implemented by hand because a derived
	/// `Default` would produce empty paths instead of the conventional ones.
	fn default() -> Self {
		Self::new()
	}
}

#[cfg(test)]
mod tests {
	use super::*;
	use std::fs;

	/// Scratch directory under the system temp dir that is removed on drop.
	///
	/// Cleaning up in `Drop` means the directory is also removed when an
	/// assertion panics halfway through a test. The directory itself is NOT
	/// created here, so tests can check that `write_to` creates it.
	struct ScratchDir(PathBuf);

	impl ScratchDir {
		/// Builds the path `<temp>/<tag>_<pid>`; each test uses its own `tag`
		/// so tests running in the same process do not share a directory.
		fn new(tag: &str) -> Self {
			Self(std::env::temp_dir().join(format!("{}_{}", tag, std::process::id())))
		}

		/// Returns the scratch directory path.
		fn path(&self) -> &Path {
			&self.0
		}

		/// Returns `rel` resolved inside the scratch directory.
		fn join(&self, rel: &str) -> PathBuf {
			self.0.join(rel)
		}
	}

	impl Drop for ScratchDir {
		fn drop(&mut self) {
			// Best-effort cleanup; the directory may never have been created.
			let _ = fs::remove_dir_all(&self.0);
		}
	}

	#[test]
	fn new_has_sensible_defaults() {
		let dy = DataYaml::new();
		assert_eq!(dy.nc, 0);
		assert!(dy.names.is_empty());
		assert_eq!(dy.train, PathBuf::from("../train/images"));
		assert_eq!(dy.val, PathBuf::from("../valid/images"));
		assert_eq!(dy.test, PathBuf::from("../test/images"));
	}

	#[test]
	fn set_names_updates_and_returns_self() {
		let mut dy = DataYaml::new();
		let names: Vec<String> = vec!["cat".into(), "dog".into()];

		// Keep only the address of the returned reference so the mutable
		// borrow ends before `dy` is inspected below.
		let returned: *const DataYaml = dy.set_names(names.clone());
		assert!(std::ptr::eq(returned, &dy));

		assert_eq!(dy.names, names);
		assert_eq!(dy.nc, 2);
	}

	#[test]
	fn write_and_read_roundtrip() {
		let scratch = ScratchDir::new("dy_test");
		fs::create_dir_all(scratch.path()).unwrap();

		let mut original = DataYaml::new();
		original.nc = 3;
		original.names = vec!["cat".into(), "dog".into(), "bird".into()];

		let yaml_path = scratch.join("data.yaml");
		original.write_to(&yaml_path).unwrap();

		let restored = DataYaml::read_from(&yaml_path).unwrap();
		assert_eq!(restored.nc, 3);
		assert_eq!(restored.names, vec!["cat", "dog", "bird"]);
		assert_eq!(restored.train, original.train);
		assert_eq!(restored.val, original.val);
		assert_eq!(restored.test, original.test);
	}

	#[test]
	fn read_from_invalid_yaml() {
		let scratch = ScratchDir::new("dy_bad");
		fs::create_dir_all(scratch.path()).unwrap();

		let bad = scratch.join("bad.yaml");
		fs::write(&bad, "this: [ is not valid yaml {{").unwrap();

		// The file is readable, so the failure must come from parsing, which
		// `read_from` reports as `InvalidData`.
		match DataYaml::read_from(&bad) {
			Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::InvalidData),
			Ok(_) => panic!("malformed YAML was accepted"),
		}
	}

	#[test]
	fn read_from_nonexistent_file() {
		assert!(DataYaml::read_from(Path::new("/nonex_dy_test")).is_err());
	}

	#[test]
	fn write_to_creates_parent_dirs() {
		let scratch = ScratchDir::new("dy_parent");
		let deep = scratch.join("a/b/c/data.yaml");
		DataYaml::new().write_to(&deep).unwrap();
		assert!(deep.exists());
	}
}

#[cfg(test)]
mod serde_tests {
	use super::*;

	/// Deserializes `src` into a [`DataYaml`], panicking if parsing fails.
	fn parse(src: &str) -> DataYaml {
		serde_yaml::from_str::<DataYaml>(src).unwrap()
	}

	/// A document with all five keys and an empty name list parses.
	#[test]
	fn deserialize_minimal_yaml() {
		let parsed = parse("train: ./train\nval: ./val\ntest: ./test\nnc: 0\nnames: []\n");
		assert_eq!(parsed.train.as_path(), Path::new("./train"));
		assert!(parsed.names.is_empty());
	}

	/// A block-style `names` list is read in document order.
	#[test]
	fn deserialize_with_names() {
		let parsed = parse(
			"train: ./train\nval: ./val\ntest: ./test\nnc: 2\nnames:\n  - person\n  - car\n",
		);
		assert_eq!(parsed.nc, 2);
		assert_eq!(parsed.names, ["person", "car"]);
	}
}