use std::{fs, io};
use std::io::Write;
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
/// Dataset description that is (de)serialized as YAML.
///
/// `nc` is intended to equal `names.len()`. [`DataYaml::set_names`] keeps the
/// two in sync, but values read via [`DataYaml::read_from`] are not validated.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DataYaml {
    /// Path to the training images (`../train/images` in [`DataYaml::new`]).
    pub train: PathBuf,
    /// Path to the validation images (`../valid/images` in [`DataYaml::new`]).
    pub val: PathBuf,
    /// Path to the test images (`../test/images` in [`DataYaml::new`]).
    pub test: PathBuf,
    /// Number of class names; see `names`.
    pub nc: usize,
    /// Class names, in order.
    pub names: Vec<String>,
}
impl DataYaml {
    /// Creates a configuration with default relative image paths
    /// (`../train/images`, `../valid/images`, `../test/images`) and no class names.
    #[must_use]
    pub fn new() -> Self {
        Self {
            train: PathBuf::from("../train/images"),
            val: PathBuf::from("../valid/images"),
            test: PathBuf::from("../test/images"),
            nc: 0,
            names: Vec::new(),
        }
    }

    /// Replaces the class names and sets `nc` to their count.
    ///
    /// Returns `&mut Self` so calls can be chained.
    pub fn set_names(&mut self, names: Vec<String>) -> &mut Self {
        self.names = names;
        self.nc = self.names.len();
        self
    }

    /// Reads the file at `yaml` and deserializes it into a `DataYaml`.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if the file cannot be read as UTF-8
    /// text, or an error of kind [`io::ErrorKind::InvalidData`] (wrapping the
    /// YAML error) if the contents do not deserialize into a `DataYaml`.
    pub fn read_from(yaml: &Path) -> io::Result<DataYaml> {
        let yaml_str = fs::read_to_string(yaml)?;
        serde_yaml::from_str(&yaml_str)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
    }

    /// Serializes `self` as YAML and writes it to `file_name`.
    ///
    /// Missing parent directories are created first, and an existing file is
    /// truncated and overwritten.
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`io::ErrorKind::InvalidData`] if serialization
    /// fails, or the underlying I/O error if a directory or the file cannot be
    /// created or written.
    pub fn write_to(&self, file_name: &Path) -> io::Result<()> {
        let yaml = serde_yaml::to_string(self)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
        // For a bare file name `parent()` is `Some("")`, which
        // `create_dir_all` treats as a no-op.
        if let Some(parent) = file_name.parent() {
            fs::create_dir_all(parent)?;
        }
        let mut file = fs::File::create(file_name)?;
        file.write_all(yaml.as_bytes())?;
        Ok(())
    }
}

impl Default for DataYaml {
    /// Same as [`DataYaml::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Per-process scratch directory under the system temp dir.
    ///
    /// It is removed on drop, so a failing assertion does not leave files behind.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        /// Names (but does not create) `<temp>/<prefix>_<pid>`.
        fn new(prefix: &str) -> Self {
            Self(std::env::temp_dir().join(format!("{}_{}", prefix, std::process::id())))
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best-effort cleanup: the directory may never have been created.
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn new_has_sensible_defaults() {
        let dy = DataYaml::new();
        assert_eq!(dy.nc, 0);
        assert!(dy.names.is_empty());
        assert_eq!(dy.train, PathBuf::from("../train/images"));
        assert_eq!(dy.val, PathBuf::from("../valid/images"));
        assert_eq!(dy.test, PathBuf::from("../test/images"));
    }

    #[test]
    fn set_names_updates_and_returns_self() {
        let mut dy = DataYaml::new();
        let names: Vec<String> = vec!["cat".into(), "dog".into()];
        let returned: *const DataYaml = dy.set_names(names.clone());
        // The returned reference must be the receiver itself.
        assert!(std::ptr::eq(returned, &dy));
        assert_eq!(dy.names, names);
        assert_eq!(dy.nc, 2);

        // Chained calls: the last one wins and `nc` follows it.
        dy.set_names(vec!["a".into()]).set_names(Vec::new());
        assert!(dy.names.is_empty());
        assert_eq!(dy.nc, 0);
    }

    #[test]
    fn write_and_read_roundtrip() {
        let tmp = ScratchDir::new("dy_test");
        fs::create_dir_all(tmp.path()).unwrap();
        let mut original = DataYaml::new();
        original.set_names(vec!["cat".into(), "dog".into(), "bird".into()]);
        let yaml_path = tmp.path().join("data.yaml");
        original.write_to(&yaml_path).unwrap();
        let restored = DataYaml::read_from(&yaml_path).unwrap();
        assert_eq!(restored.nc, 3);
        assert_eq!(restored.names, vec!["cat", "dog", "bird"]);
        assert_eq!(restored.train, original.train);
        assert_eq!(restored.val, original.val);
        assert_eq!(restored.test, original.test);
    }

    #[test]
    fn read_from_invalid_yaml() {
        let tmp = ScratchDir::new("dy_bad");
        fs::create_dir_all(tmp.path()).unwrap();
        let bad = tmp.path().join("bad.yaml");
        fs::write(&bad, "this: [ is not valid yaml {{").unwrap();
        let err = DataYaml::read_from(&bad)
            .err()
            .expect("malformed YAML must be rejected");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn read_from_nonexistent_file() {
        // The scratch directory is never created, so nothing exists beneath it.
        let tmp = ScratchDir::new("dy_missing");
        let err = DataYaml::read_from(&tmp.path().join("data.yaml"))
            .err()
            .expect("reading a missing file must fail");
        assert_eq!(err.kind(), io::ErrorKind::NotFound);
    }

    #[test]
    fn write_to_creates_parent_dirs() {
        let tmp = ScratchDir::new("dy_parent");
        let deep = tmp.path().join("a/b/c/data.yaml");
        DataYaml::new().write_to(&deep).unwrap();
        assert!(deep.is_file());
    }
}
#[cfg(test)]
mod serde_tests {
    use super::*;

    #[test]
    fn deserialize_minimal_yaml() {
        let yaml = "train: ./train\nval: ./val\ntest: ./test\nnc: 0\nnames: []\n";
        let dy: DataYaml = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(dy.train, PathBuf::from("./train"));
        assert_eq!(dy.val, PathBuf::from("./val"));
        assert_eq!(dy.test, PathBuf::from("./test"));
        assert_eq!(dy.nc, 0);
        assert!(dy.names.is_empty());
    }

    #[test]
    fn deserialize_with_names() {
        let yaml = "train: ./train\nval: ./val\ntest: ./test\nnc: 2\nnames:\n - person\n - car\n";
        let dy: DataYaml = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(dy.train, PathBuf::from("./train"));
        assert_eq!(dy.val, PathBuf::from("./val"));
        assert_eq!(dy.test, PathBuf::from("./test"));
        assert_eq!(dy.nc, 2);
        assert_eq!(dy.names, vec!["person", "car"]);
    }

    #[test]
    fn deserialize_rejects_missing_field() {
        // The struct has no serde defaults, so every field is required;
        // omitting `test` must fail.
        let yaml = "train: ./train\nval: ./val\nnc: 0\nnames: []\n";
        assert!(serde_yaml::from_str::<DataYaml>(yaml).is_err());
    }
}