use std::path::Path;
use serde::de::DeserializeOwned;
use crate::error::{KernelError, Result};
/// A single problem recovered from a failed TOML deserialization.
///
/// The three parts are plain text extracted from the deserializer's error
/// message, so they are descriptive rather than machine-precise.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FieldError {
    /// Key the problem refers to, or `"."` when it could not be determined.
    pub path: String,
    /// What the deserializer wanted (e.g. a type name), or `"unknown"`.
    pub expected: String,
    /// What was actually found, or a marker such as `"missing"`.
    pub value: String,
}

impl std::fmt::Display for FieldError {
    /// Renders the error as `field '<path>': expected <expected>, got <value>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            path,
            expected,
            value,
        } = self;
        write!(f, "field '{}': expected {}, got {}", path, expected, value)
    }
}

/// Allows a `FieldError` to be boxed or propagated as a standard error;
/// it has no underlying `source`.
impl std::error::Error for FieldError {}
/// Converts the rendered text of a `toml`/serde deserialization error into a
/// list of [`FieldError`]s.
///
/// This is a best-effort text parser. The `toml` crate renders an error as a
/// `TOML parse error ...` header, a source snippet (`N | ...` lines and `|`
/// gutter lines), and the serde message itself. Header and snippet lines are
/// skipped, and each remaining non-empty line is mapped as follows:
///
/// * `invalid type: X, expected Y` / `invalid value: X, expected Y` produce
///   `path = "."`, `expected = Y`, `value = X` (a leading `found ` is dropped).
/// * A "missing field" message produces `path = <field>`, `value = "missing"`.
/// * An "unknown field" message produces `path = <field>`,
///   `expected = <allowed names>` (or `"no fields"`), `value = "unknown field"`.
/// * Any other line, including one whose prefix is recognised but whose
///   remainder cannot be parsed, is kept verbatim as `value` with `path = "."`.
///
/// The result is never empty: if no line survives filtering, the whole
/// message becomes a single catch-all entry.
fn parse_serde_errors(msg: &str) -> Vec<FieldError> {
    let errors: Vec<FieldError> = msg
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty() && !is_snippet_line(line))
        .map(|line| parse_error_line(line).unwrap_or_else(|| unclassified(line)))
        .collect();
    if errors.is_empty() {
        vec![unclassified(msg)]
    } else {
        errors
    }
}

/// True for the header and source-snippet lines `toml` places around the message.
fn is_snippet_line(line: &str) -> bool {
    line.starts_with("TOML parse error")
        || line.starts_with('|')
        || line.starts_with("+-")
        // Numbered source lines look like "12 | key = value".
        || (line.starts_with(|c: char| c.is_ascii_digit()) && line.contains('|'))
}

/// Builds the catch-all entry used when a message cannot be classified.
fn unclassified(text: &str) -> FieldError {
    FieldError {
        path: ".".into(),
        expected: "unknown".into(),
        value: text.to_string(),
    }
}

/// Parses one serde message line; `None` when it matches no known shape.
fn parse_error_line(line: &str) -> Option<FieldError> {
    if let Some(rest) = line
        .strip_prefix("invalid type: ")
        .or_else(|| line.strip_prefix("invalid value: "))
    {
        // The found part can embed user data (e.g. a string value) that may
        // itself contain ", expected ", while the expected part comes from the
        // `Deserialize` impl, so split on the LAST separator.
        let (found, expected) = rest.rsplit_once(", expected ")?;
        let value = found.strip_prefix("found ").unwrap_or(found).trim();
        return Some(FieldError {
            path: ".".into(),
            expected: expected.trim().to_string(),
            value: value.to_string(),
        });
    }
    if let Some(rest) = line.strip_prefix("missing field `") {
        let field = rest.strip_suffix('`')?;
        return Some(FieldError {
            path: field.to_string(),
            expected: "unknown".into(),
            value: "missing".into(),
        });
    }
    if let Some(rest) = line.strip_prefix("unknown field `") {
        // serde ends this message with either "`, expected <names>" or
        // "`, there are no fields". Keys come from user input and may contain
        // backticks, so anchor on those fixed tails, not the first backtick.
        let (field, expected) = match rest.strip_suffix("`, there are no fields") {
            Some(field) => (field, "no fields"),
            None => rest.rsplit_once("`, expected ")?,
        };
        return Some(FieldError {
            path: field.to_string(),
            expected: expected.to_string(),
            value: "unknown field".into(),
        });
    }
    None
}
/// Deserializes `toml_str` into `T`, reporting failures as structured field errors.
///
/// On success the parsed value is returned unchanged. On failure the `toml`
/// error is rendered to text and broken down by `parse_serde_errors`, which
/// always yields at least one [`FieldError`].
pub fn validate_config<T: DeserializeOwned>(
    toml_str: &str,
) -> std::result::Result<T, Vec<FieldError>> {
    toml::from_str::<T>(toml_str).map_err(|err| {
        let rendered = err.to_string();
        parse_serde_errors(&rendered)
    })
}
/// Loads and deserializes the TOML file at `path`.
///
/// When `template` is `Some` and `path.exists()` reports `false`, the template
/// text is first written to `path` (creating parent directories as needed).
/// The file on disk is then read and parsed, whether it was just created or
/// already present.
///
/// # Errors
///
/// I/O failures from creating directories, writing the template, or reading
/// the file are propagated with `?`. A deserialization failure becomes
/// `KernelError::Config` whose message is prefixed with the file path.
pub fn load_toml_config<T: DeserializeOwned>(path: &Path, template: Option<&str>) -> Result<T> {
    // Seed a missing file from the template, if one was supplied.
    if let Some(initial) = template.filter(|_| !path.exists()) {
        if let Some(dir) = path.parent() {
            std::fs::create_dir_all(dir)?;
        }
        std::fs::write(path, initial)?;
    }

    let text = std::fs::read_to_string(path)?;
    toml::from_str(&text)
        .map_err(|err| KernelError::Config(format!("{}: {}", path.display(), err)))
}
/// Deserializes an in-memory TOML document into `T`.
///
/// # Errors
///
/// Returns `KernelError::Config` carrying the `toml` error's rendered message
/// when `content` is not valid TOML or does not match `T`.
pub fn parse_toml_config<T: DeserializeOwned>(content: &str) -> Result<T> {
    match toml::from_str::<T>(content) {
        Ok(parsed) => Ok(parsed),
        Err(err) => Err(KernelError::Config(err.to_string())),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use tempfile::TempDir;

    /// Lenient config: undeclared keys are ignored and `port` defaults to 8080.
    #[derive(Debug, Deserialize, PartialEq)]
    struct TestConfig {
        name: String,
        #[serde(default = "default_port")]
        port: u16,
    }

    /// Same shape as `TestConfig` but rejects undeclared keys, so the
    /// unknown-field error path can be exercised.
    #[derive(Debug, Deserialize, PartialEq)]
    #[serde(deny_unknown_fields)]
    struct StrictConfig {
        name: String,
        #[serde(default = "default_port")]
        port: u16,
    }

    fn default_port() -> u16 {
        8080
    }

    #[test]
    fn test_load_existing_config() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.toml");
        std::fs::write(&path, "name = \"test\"\nport = 3000\n").unwrap();
        let config: TestConfig = load_toml_config(&path, None).unwrap();
        assert_eq!(config.name, "test");
        assert_eq!(config.port, 3000);
    }

    #[test]
    fn test_load_creates_from_template() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("new.toml");
        let template = "name = \"default\"\nport = 9090\n";
        let config: TestConfig = load_toml_config(&path, Some(template)).unwrap();
        assert_eq!(config.name, "default");
        assert_eq!(config.port, 9090);
        assert!(path.exists());
    }

    #[test]
    fn test_load_does_not_overwrite_existing_file_with_template() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("existing.toml");
        std::fs::write(&path, "name = \"kept\"\n").unwrap();
        let config: TestConfig =
            load_toml_config(&path, Some("name = \"template\"\n")).unwrap();
        assert_eq!(config.name, "kept");
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "name = \"kept\"\n");
    }

    #[test]
    fn test_load_template_creates_missing_parent_dirs() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("nested").join("deeper").join("cfg.toml");
        let config: TestConfig = load_toml_config(&path, Some("name = \"nested\"\n")).unwrap();
        assert_eq!(config.name, "nested");
        assert!(path.exists());
    }

    #[test]
    fn test_load_missing_file_without_template_fails() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("absent.toml");
        assert!(load_toml_config::<TestConfig>(&path, None).is_err());
        assert!(!path.exists());
    }

    #[test]
    fn test_parse_toml_config() {
        let config: TestConfig = parse_toml_config("name = \"hello\"").unwrap();
        assert_eq!(config.name, "hello");
        assert_eq!(config.port, 8080);
    }

    #[test]
    fn validate_config_success() {
        let result = validate_config::<TestConfig>("name = \"test\"\nport = 3000");
        assert!(result.is_ok());
        let config = result.unwrap();
        assert_eq!(config.port, 3000);
    }

    #[test]
    fn validate_config_wrong_type() {
        let result = validate_config::<TestConfig>("name = \"test\"\nport = \"not_a_number\"");
        assert!(result.is_err());
        let errors = result.unwrap_err();
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].expected, "u16");
        assert!(errors[0].value.contains("not_a_number"));
    }

    #[test]
    fn validate_config_out_of_range_value() {
        let errors = validate_config::<TestConfig>("name = \"test\"\nport = 70000").unwrap_err();
        assert!(!errors.is_empty());
        assert!(errors
            .iter()
            .any(|e| e.expected == "u16" && e.value.contains("70000")));
    }

    #[test]
    fn validate_config_missing_field() {
        let result = validate_config::<TestConfig>("port = 3000");
        assert!(result.is_err());
        let errors = result.unwrap_err();
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].path, "name");
        assert_eq!(errors[0].value, "missing");
    }

    /// `TestConfig` does not deny unknown fields, so extra keys are accepted.
    #[test]
    fn validate_config_unknown_field() {
        let result = validate_config::<TestConfig>("name = \"test\"\nextra = true");
        assert!(result.is_ok());
    }

    /// With `deny_unknown_fields`, the offending key is reported by name.
    #[test]
    fn validate_config_unknown_field_denied() {
        let errors =
            validate_config::<StrictConfig>("name = \"test\"\nextra = true").unwrap_err();
        assert!(!errors.is_empty());
        assert!(errors
            .iter()
            .any(|e| e.path == "extra" && e.value == "unknown field"));
    }
}