use crate::error::{JudgyError, Result};
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
/// Parses a comma-separated list of binary labels into a vector of `0`/`1` bytes.
///
/// Every comma-separated item is trimmed before it is parsed, so input such as
/// `" 1 , 0 , 1 "` is accepted. Input that is empty or consists only of
/// whitespace produces an empty vector rather than an error.
///
/// # Errors
///
/// Returns an error built with `JudgyError::parse` for the first item that
/// either cannot be parsed as a `u8` (for example `"a"`, `"1.0"` or an empty
/// item between two commas) or parses to a value other than 0 or 1.
pub fn parse_binary_string(input: &str) -> Result<Vec<u8>> {
    if input.trim().is_empty() {
        return Ok(Vec::new());
    }

    // Collecting into a `Result` stops at the first failing item, which keeps
    // the "report the first bad value" behaviour of an explicit loop.
    input
        .split(',')
        .map(|item| {
            let token = item.trim();
            match token.parse::<u8>() {
                Ok(bit) if bit <= 1 => Ok(bit),
                Ok(other) => Err(JudgyError::parse(format!(
                    "Invalid value '{}': must be 0 or 1",
                    other
                ))),
                Err(_) => Err(JudgyError::parse(format!(
                    "Could not parse '{}' as integer",
                    token
                ))),
            }
        })
        .collect()
}
pub fn load_binary_from_csv<P: AsRef<Path>>(path: P) -> Result<Vec<u8>> {
let file = File::open(&path).map_err(|e| {
JudgyError::Io(std::io::Error::new(
e.kind(),
format!("Failed to open file '{}': {}", path.as_ref().display(), e),
))
})?;
let reader = BufReader::new(file);
let mut result = Vec::new();
for (line_num, line) in reader.lines().enumerate() {
let line = line?;
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
if trimmed.contains(char::is_alphabetic) {
continue;
}
match trimmed.parse::<u8>() {
Ok(0) => result.push(0),
Ok(1) => result.push(1),
Ok(other) => {
return Err(JudgyError::parse(format!(
"Invalid value '{}' on line {}: must be 0 or 1",
other,
line_num + 1
)));
}
Err(_) => {
return Err(JudgyError::parse(format!(
"Could not parse '{}' on line {} as integer",
trimmed,
line_num + 1
)));
}
}
}
if result.is_empty() {
return Err(JudgyError::parse(format!(
"No valid data found in file '{}'",
path.as_ref().display()
)));
}
Ok(result)
}
/// Formats a float with a fixed number of digits after the decimal point.
///
/// Non-finite values get dedicated representations instead of the default
/// float formatting:
///
/// * NaN becomes `"NaN"`;
/// * positive infinity becomes `"∞"`;
/// * negative infinity becomes `"-∞"`.
///
/// `precision` is the number of fractional digits used for finite values.
pub fn format_float(value: f64, precision: usize) -> String {
    if value.is_nan() {
        return "NaN".to_string();
    }
    if value.is_infinite() {
        let symbol = if value.is_sign_positive() { "∞" } else { "-∞" };
        return symbol.to_string();
    }
    // `.*` takes the precision from the argument list, ahead of the value itself.
    format!("{:.*}", precision, value)
}
/// Formats a fraction as a percentage with a fixed number of fractional digits.
///
/// The value is multiplied by 100 and suffixed with `%`, so `0.5` with a
/// precision of 0 becomes `"50%"`.
///
/// Non-finite results get dedicated representations that mirror the ones used
/// by `format_float`:
///
/// * NaN becomes `"NaN%"`;
/// * positive infinity becomes `"∞%"`;
/// * negative infinity becomes `"-∞%"`.
///
/// `precision` is the number of digits after the decimal point.
pub fn format_percentage(value: f64, precision: usize) -> String {
    // Classify the scaled value rather than the input so that a finite input
    // whose product overflows to infinity is rendered the same way as an
    // infinite input (instead of the default "inf").
    let scaled = value * 100.0;
    if scaled.is_nan() {
        "NaN%".to_string()
    } else if scaled.is_infinite() {
        if scaled.is_sign_positive() {
            "∞%".to_string()
        } else {
            "-∞%".to_string()
        }
    } else {
        format!("{:.precision$}%", scaled, precision = precision)
    }
}
/// Checks that `value` is a valid probability, i.e. lies in the closed
/// interval `[0, 1]`.
///
/// `name` is only used to label the value in error messages.
///
/// # Errors
///
/// Returns an error built with `JudgyError::input_validation` when `value` is
/// NaN (reported with its own message) or falls outside `[0, 1]`.
pub fn validate_probability(value: f64, name: &str) -> Result<()> {
    // NaN is reported separately so the message names the actual problem
    // instead of claiming the value is merely out of range.
    if value.is_nan() {
        let message = format!("{} cannot be NaN", name);
        return Err(JudgyError::input_validation(message));
    }

    if (0.0..=1.0).contains(&value) {
        return Ok(());
    }

    let message = format!("{} must be between 0 and 1, got {}", name, value);
    Err(JudgyError::input_validation(message))
}
/// Parses a `"min,max"` string into a pair of floats with `min < max`.
///
/// The input must contain exactly one comma. Both sides are trimmed before
/// being parsed as `f64`, so `" 0.1 , 0.8 "` is accepted.
///
/// # Errors
///
/// Returns an error built with `JudgyError::parse` when:
///
/// * the input does not split into exactly two comma-separated parts;
/// * either part cannot be parsed as an `f64`;
/// * either bound is NaN (the text `"NaN"` parses successfully as an `f64`,
///   but every ordering comparison against NaN is false, so it would otherwise
///   slip past the `min < max` check);
/// * `min` is greater than or equal to `max`.
pub fn parse_range(input: &str) -> Result<(f64, f64)> {
    // Pull at most three fields: exactly two must be present.
    let mut fields = input.split(',');
    let (raw_min, raw_max) = match (fields.next(), fields.next(), fields.next()) {
        (Some(lo), Some(hi), None) => (lo.trim(), hi.trim()),
        _ => {
            return Err(JudgyError::parse(format!(
                "Range must be in format 'min,max', got '{}'",
                input
            )));
        }
    };

    let parse_bound = |text: &str| {
        text.parse::<f64>()
            .map_err(|_| JudgyError::parse(format!("Could not parse '{}' as number", text)))
    };
    let min = parse_bound(raw_min)?;
    let max = parse_bound(raw_max)?;

    // Reject NaN explicitly: `min >= max` is false whenever a bound is NaN.
    if min.is_nan() || max.is_nan() {
        return Err(JudgyError::parse(format!(
            "Range bounds must not be NaN, got '{}'",
            input
        )));
    }
    if min >= max {
        return Err(JudgyError::parse(format!(
            "Range minimum ({}) must be less than maximum ({})",
            min, max
        )));
    }
    Ok((min, max))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temporary file containing `rows`, one per line, and flushes it
    /// so the contents are visible when the file is reopened by path.
    fn csv_fixture(rows: &[&str]) -> std::io::Result<NamedTempFile> {
        let mut file = NamedTempFile::new()?;
        for row in rows {
            writeln!(file, "{}", row)?;
        }
        file.flush()?;
        Ok(file)
    }

    #[test]
    fn test_parse_binary_string_valid() {
        assert_eq!(
            parse_binary_string("1,0,1,1,0").unwrap(),
            vec![1, 0, 1, 1, 0]
        );
        assert_eq!(parse_binary_string("0").unwrap(), vec![0]);
        assert_eq!(parse_binary_string("1").unwrap(), vec![1]);
        // Whitespace around items is ignored.
        assert_eq!(parse_binary_string(" 1 , 0 , 1 ").unwrap(), vec![1, 0, 1]);
    }

    #[test]
    fn test_parse_binary_string_empty() {
        // Both empty and whitespace-only input yield an empty vector.
        for blank in ["", " "].iter() {
            assert_eq!(parse_binary_string(blank).unwrap(), Vec::<u8>::new());
        }
    }

    #[test]
    fn test_parse_binary_string_invalid() {
        // Out-of-range value, non-numeric item, empty item, non-integer item.
        for bad in ["1,2,0", "1,a,0", "1,,0", "1.0,0"].iter() {
            assert!(parse_binary_string(bad).is_err());
        }
    }

    #[test]
    fn test_load_binary_from_csv() -> Result<()> {
        let fixture = csv_fixture(&["1", "0", "1", "1", "0"])?;
        let loaded = load_binary_from_csv(fixture.path())?;
        assert_eq!(loaded, vec![1, 0, 1, 1, 0]);
        Ok(())
    }

    #[test]
    fn test_load_binary_from_csv_with_header() -> Result<()> {
        // The alphabetic first line is skipped.
        let fixture = csv_fixture(&["label", "1", "0", "1"])?;
        let loaded = load_binary_from_csv(fixture.path())?;
        assert_eq!(loaded, vec![1, 0, 1]);
        Ok(())
    }

    #[test]
    fn test_load_binary_from_csv_with_empty_lines() -> Result<()> {
        // Empty and whitespace-only lines are skipped.
        let fixture = csv_fixture(&["1", "", "0", " ", "1"])?;
        let loaded = load_binary_from_csv(fixture.path())?;
        assert_eq!(loaded, vec![1, 0, 1]);
        Ok(())
    }

    #[test]
    fn test_load_binary_from_csv_invalid_data() {
        // "2" is numeric but not a binary value.
        let fixture = csv_fixture(&["1", "2"]).unwrap();
        assert!(load_binary_from_csv(fixture.path()).is_err());
    }

    #[test]
    fn test_load_binary_from_csv_nonexistent_file() {
        assert!(load_binary_from_csv("nonexistent_file.csv").is_err());
    }

    #[test]
    fn test_format_float() {
        assert_eq!(format_float(0.12345, 3), "0.123");
        assert_eq!(format_float(1.0, 2), "1.00");
        assert_eq!(format_float(f64::NAN, 2), "NaN");
        assert_eq!(format_float(f64::INFINITY, 2), "∞");
        assert_eq!(format_float(f64::NEG_INFINITY, 2), "-∞");
    }

    #[test]
    fn test_format_percentage() {
        assert_eq!(format_percentage(0.12345, 1), "12.3%");
        assert_eq!(format_percentage(0.5, 0), "50%");
        assert_eq!(format_percentage(f64::NAN, 2), "NaN%");
    }

    #[test]
    fn test_validate_probability() {
        // Inside the interval, including both endpoints.
        for ok in [0.5, 0.0, 1.0].iter() {
            assert!(validate_probability(*ok, "test").is_ok());
        }
        // Below, above, and NaN.
        for bad in [-0.1, 1.1, f64::NAN].iter() {
            assert!(validate_probability(*bad, "test").is_err());
        }
    }

    #[test]
    fn test_parse_range() {
        let (min, max) = parse_range("0.5,0.9").unwrap();
        assert_eq!(min, 0.5);
        assert_eq!(max, 0.9);

        // Whitespace around the bounds is ignored.
        let (min, max) = parse_range(" 0.1 , 0.8 ").unwrap();
        assert_eq!(min, 0.1);
        assert_eq!(max, 0.8);

        // Missing bound, inverted bounds, and non-numeric bounds are rejected.
        for bad in ["0.5", "0.5,0.4", "a,0.5", "0.5,b"].iter() {
            assert!(parse_range(bad).is_err());
        }
    }
}