/// Converts a schema key token into the expression used as a field or section name.
///
/// A bare identifier is turned into its textual form with `stringify!`
/// (`host` becomes `"host"`); a literal is emitted unchanged, which lets callers
/// write keys that are not valid identifiers, such as `"log-level"`.
///
/// Arms are tried top to bottom, so an identifier is always stringified.
#[macro_export]
#[doc(hidden)]
macro_rules! schema_key_to_str {
    // `host` -> "host"
    ($name:ident) => {
        stringify!($name)
    };
    // "log-level" -> "log-level"
    ($text:literal) => {
        $text
    };
}
/// Applies one named constraint to a field-builder expression.
///
/// `$fb` is any expression exposing the builder methods called below; the
/// macro expands to a single method call on it and yields that call's result.
///
/// Supported constraint spellings:
///
/// - `non_empty` -> `.non_empty()`
/// - `required` -> `.required()`
/// - `range(lo, hi)` -> `.range_i64(lo, hi)`
/// - `range_f64(lo, hi)` -> `.range_f64(lo, hi)`
/// - `one_of(a, b, ...)` -> `.one_of(&[a, b, ...])`
/// - `predicate(code, msg, test)` -> `.predicate(code, msg, test)`
///
/// A trailing comma is accepted inside every argument list.
#[macro_export]
#[doc(hidden)]
macro_rules! schema_field_constraint {
    ($fb:expr, non_empty) => {
        $fb.non_empty()
    };
    ($fb:expr, required) => {
        $fb.required()
    };
    ($fb:expr, range($lo:expr, $hi:expr $(,)?)) => {
        $fb.range_i64($lo, $hi)
    };
    ($fb:expr, range_f64($lo:expr, $hi:expr $(,)?)) => {
        $fb.range_f64($lo, $hi)
    };
    ($fb:expr, one_of($($allowed:expr),* $(,)?)) => {
        $fb.one_of(&[$($allowed),*])
    };
    // The builder offers `predicate`; expose it here so every builder
    // constraint can also be written in the declarative syntax.
    ($fb:expr, predicate($code:expr, $msg:expr, $test:expr $(,)?)) => {
        $fb.predicate($code, $msg, $test)
    };
}
/// Applies a comma-separated list of constraints to a field-builder expression.
///
/// The list is consumed one constraint at a time: the head is applied through
/// `schema_field_constraint!`, the result is bound to a local, and the macro
/// recurses on the remaining tokens. An empty list yields the builder unchanged.
/// A trailing comma is accepted because the "head, rest" arms recurse into the
/// empty-list arm when nothing follows the comma.
#[macro_export]
#[doc(hidden)]
macro_rules! schema_field_constraints {
    // Nothing left to apply.
    ($builder:expr, ) => {
        $builder
    };
    // `name(args), rest...`
    ($builder:expr, $kind:ident ($($params:tt)*) , $($tail:tt)*) => {{
        let applied = $crate::schema_field_constraint!($builder, $kind ($($params)*));
        $crate::schema_field_constraints!(applied, $($tail)*)
    }};
    // `name(args)` as the final item.
    ($builder:expr, $kind:ident ($($params:tt)*)) => {
        $crate::schema_field_constraint!($builder, $kind ($($params)*))
    };
    // `name, rest...`
    ($builder:expr, $kind:ident , $($tail:tt)*) => {{
        let applied = $crate::schema_field_constraint!($builder, $kind);
        $crate::schema_field_constraints!(applied, $($tail)*)
    }};
    // `name` as the final item.
    ($builder:expr, $kind:ident) => {
        $crate::schema_field_constraint!($builder, $kind)
    };
}
/// Builds a `Schema` from a declarative list of `key: spec` entries.
///
/// Entries are separated by commas; a trailing comma is allowed. `key` is an
/// identifier or a literal (see `schema_key_to_str!`). `spec` is one of:
///
/// - `{ ... }` — a nested section whose body is again `schema!` syntax;
/// - `[c1, c2(args), ...]` — a field with a list of constraints;
/// - `name(args)` or `name` — a field with exactly one constraint.
///
/// Constraint spellings are those understood by `schema_field_constraint!`.
///
/// ```text
/// schema! {
///     name: non_empty,
///     workers: range(1, 1024),
///     server: { host: [required, non_empty] },
/// }
/// ```
///
/// The `@entry` rules are internal: each one expands to the statements that
/// register a single entry on the schema variable named by the caller.
#[macro_export]
macro_rules! schema {
    // Nested section: build the inner schema recursively and attach it.
    ( @entry $schema:ident, $name:tt, { $($body:tt)* } ) => {
        let section_name = $crate::schema_key_to_str!($name);
        let section_schema = $crate::schema! { $($body)* };
        $schema = $schema.section(section_name, section_schema);
    };
    // Field with a bracketed constraint list.
    ( @entry $schema:ident, $name:tt, [ $($rules:tt)* ] ) => {
        let field_name = $crate::schema_key_to_str!($name);
        let builder = $schema.field(field_name);
        let builder = $crate::schema_field_constraints!(builder, $($rules)*);
        $schema = builder.done();
    };
    // Field with a single constraint that takes arguments.
    ( @entry $schema:ident, $name:tt, $kind:ident ($($params:tt)*) ) => {
        let field_name = $crate::schema_key_to_str!($name);
        let builder = $schema.field(field_name);
        let builder = $crate::schema_field_constraint!(builder, $kind ($($params)*));
        $schema = builder.done();
    };
    // Field with a single argument-less constraint.
    ( @entry $schema:ident, $name:tt, $kind:ident ) => {
        let field_name = $crate::schema_key_to_str!($name);
        let builder = $schema.field(field_name);
        let builder = $crate::schema_field_constraint!(builder, $kind);
        $schema = builder.done();
    };
    // Public entry point: start from an empty schema and add every entry in order.
    ( $($name:tt : $spec:tt $(($($params:tt)*))? ),* $(,)? ) => {
        {
            let mut built = $crate::Schema::new();
            $(
                $crate::schema!(@entry built, $name, $spec $(($($params)*))?);
            )*
            built
        }
    };
}
use toml::Value;
use crate::validation::{ErrorKind, Loc, LocSegment, Severity, ValidationError, ValidationErrors};
/// A single validation rule attached to a schema field.
#[derive(Debug, Clone)]
enum Constraint {
    /// The value must be a string containing at least one character.
    NonEmpty,
    /// The value must be an integer within `lo..=hi`.
    RangeI64 { lo: i64, hi: i64 },
    /// The value must be a number within `lo..=hi`; NaN never passes.
    RangeF64 { lo: f64, hi: f64 },
    /// The value must be a string equal to one of `allowed`.
    OneOf { allowed: Vec<String> },
    /// The key must be present.
    Required,
    /// The value must satisfy `test`; a failure is reported with `code` and `msg`.
    Predicate { code: &'static str, msg: String, test: fn(&Value) -> bool },
}

impl Constraint {
    /// Evaluates this constraint for `field` and returns the failure, if any.
    ///
    /// `value` is `None` when the key is absent. Only `Required` reports absence
    /// as such; the other constraints evaluate an absent key with a stand-in
    /// (`""` for strings, `0` for numbers, `false` for predicates), so an absent
    /// key fails only when that stand-in violates the rule.
    ///
    /// A value that is present but has the wrong type always fails the typed
    /// constraints, and the error's `input` shows what was actually written.
    /// TOML integers are accepted by `RangeF64` and compared numerically.
    ///
    /// `checks_run` is incremented once per call, pass or fail.
    fn check(
        &self,
        field: &str,
        value: Option<&Value>,
        checks_run: &mut usize,
    ) -> Option<ValidationError> {
        *checks_run += 1;
        match self {
            Self::Required => value
                .is_none()
                .then(|| make_err(field, ErrorKind::Missing, None, "field is required")),
            Self::NonEmpty => {
                // `None` here means "present, but not a string".
                let text = match value {
                    None => Some(""),
                    Some(present) => present.as_str(),
                };
                let input = match text {
                    Some(s) if !s.is_empty() => return None,
                    Some(_) => "\"\"".to_string(),
                    None => Self::render_input(value),
                };
                Some(make_err(field, ErrorKind::Empty, Some(input), "must not be empty"))
            }
            Self::RangeI64 { lo, hi } => {
                // `None` here means "present, but not an integer".
                let number = match value {
                    None => Some(0),
                    Some(present) => present.as_integer(),
                };
                let input = match number {
                    Some(n) if (*lo..=*hi).contains(&n) => return None,
                    Some(n) => n.to_string(),
                    None => Self::render_input(value),
                };
                Some(make_err(
                    field,
                    ErrorKind::OutOfRange {
                        lower: Some(lo.to_string()),
                        upper: Some(hi.to_string()),
                    },
                    Some(input),
                    format!("input must be in range {lo}..={hi}"),
                ))
            }
            Self::RangeF64 { lo, hi } => {
                // `None` here means "present, but not a number".
                let number = match value {
                    None => Some(0.0),
                    Some(Value::Float(f)) => Some(*f),
                    // `ratio = 2` parses as an integer; compare it as a number
                    // instead of discarding it. Precision loss only affects
                    // magnitudes beyond 2^53.
                    Some(Value::Integer(i)) => Some(*i as f64),
                    Some(_) => None,
                };
                let input = match number {
                    Some(n) if !(n.is_nan() || n < *lo || n > *hi) => return None,
                    Some(n) => n.to_string(),
                    None => Self::render_input(value),
                };
                Some(make_err(
                    field,
                    ErrorKind::OutOfRange {
                        lower: Some(lo.to_string()),
                        upper: Some(hi.to_string()),
                    },
                    Some(input),
                    format!("input must be in range {lo}..={hi}"),
                ))
            }
            Self::OneOf { allowed } => {
                // `None` here means "present, but not a string".
                let text = match value {
                    None => Some(""),
                    Some(present) => present.as_str(),
                };
                let input = match text {
                    Some(s) if allowed.iter().any(|a| a == s) => return None,
                    Some(s) => s.to_string(),
                    None => Self::render_input(value),
                };
                let msg = format!("must be one of: {}", allowed.join(", "));
                Some(make_err(
                    field,
                    ErrorKind::NotOneOf { allowed: allowed.clone() },
                    Some(input),
                    msg,
                ))
            }
            Self::Predicate { code, msg, test } => {
                // An absent key is presented to the predicate as `false`.
                if !test(value.unwrap_or(&Value::Boolean(false))) {
                    return Some(make_err(field, ErrorKind::Predicate { code }, None, msg.clone()));
                }
                None
            }
        }
    }

    /// Renders a mistyped value for the `input` field of an error.
    ///
    /// Scalars are shown as written (strings quoted); other values are
    /// summarised by their TOML type name. An absent value renders as an empty
    /// string, though callers only reach this for values that are present.
    fn render_input(value: Option<&Value>) -> String {
        match value {
            None => String::new(),
            Some(Value::String(s)) => format!("{s:?}"),
            Some(Value::Integer(n)) => n.to_string(),
            Some(Value::Float(f)) => f.to_string(),
            Some(Value::Boolean(b)) => b.to_string(),
            Some(other) => other.type_str().to_string(),
        }
    }
}
fn make_err(
field: &str,
kind: ErrorKind,
input: Option<String>,
msg: impl Into<String>,
) -> ValidationError {
ValidationError {
loc: Loc(vec![LocSegment::Key(field.to_string())]),
kind,
severity: Severity::Error,
input,
msg: msg.into(),
}
}
pub struct FieldBuilder<'a> {
schema: &'a mut Schema,
name: String,
}
impl<'a> FieldBuilder<'a> {
pub fn non_empty(self) -> Self {
self.add(Constraint::NonEmpty)
}
pub fn range_i64(self, lo: i64, hi: i64) -> Self {
self.add(Constraint::RangeI64 { lo, hi })
}
pub fn range_f64(self, lo: f64, hi: f64) -> Self {
self.add(Constraint::RangeF64 { lo, hi })
}
pub fn one_of(self, allowed: &[&str]) -> Self {
let allowed = allowed.iter().map(|s| (*s).to_string()).collect();
self.add(Constraint::OneOf { allowed })
}
pub fn required(self) -> Self {
self.add(Constraint::Required)
}
pub fn predicate(
self,
code: &'static str,
msg: impl Into<String>,
test: fn(&Value) -> bool,
) -> Self {
self.add(Constraint::Predicate { code, msg: msg.into(), test })
}
fn add(self, c: Constraint) -> Self {
let idx = self.schema.fields.iter().position(|(n, _)| n == &self.name);
match idx {
Some(i) => self.schema.fields[i].1.push(c),
None => self.schema.fields.push((self.name.clone(), vec![c])),
}
self
}
pub fn done(self) -> Schema {
Schema { fields: self.schema.fields.clone(), sections: self.schema.sections.clone() }
}
}
/// A declarative description of the keys expected in a TOML document.
///
/// A schema holds per-field constraint lists and named sub-schemas for nested
/// tables. Validation collects every failure instead of stopping at the first.
#[derive(Debug, Clone, Default)]
pub struct Schema {
    /// Field name paired with its constraints, in declaration order.
    fields: Vec<(String, Vec<Constraint>)>,
    /// Section name paired with the schema of that nested table.
    sections: Vec<(String, Schema)>,
}

impl Schema {
    /// Creates a schema with no fields and no sections.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Starts configuring the field `name`; finish the chain with `done`.
    pub fn field(&mut self, name: &str) -> FieldBuilder<'_> {
        FieldBuilder { schema: self, name: name.to_owned() }
    }

    /// Adds the nested table `name`, validated by `sub`.
    pub fn section(mut self, name: &str, sub: Schema) -> Self {
        self.sections.push((name.to_owned(), sub));
        self
    }

    /// Validates an already parsed TOML value.
    ///
    /// # Errors
    ///
    /// Returns every constraint failure found, together with the number of
    /// checks that were executed.
    pub fn validate_value(&self, value: &Value) -> Result<(), ValidationErrors> {
        let mut errors = Vec::new();
        let mut checks_run = 0usize;
        self.check_value(value, &[], &mut errors, &mut checks_run);
        if errors.is_empty() {
            return Ok(());
        }
        Err(ValidationErrors { errors, title: None, checks_run })
    }

    /// Parses `toml_str` and validates the result.
    ///
    /// # Errors
    ///
    /// A document that fails to parse yields a single fatal error with code
    /// `parse_error` and no checks run; otherwise the errors are those of
    /// `validate_value`.
    pub fn validate_str(&self, toml_str: &str) -> Result<(), ValidationErrors> {
        let parsed = toml::from_str::<Value>(toml_str).map_err(|parse_err| {
            let fatal = ValidationError {
                loc: Loc::default(),
                kind: ErrorKind::Predicate { code: "parse_error" },
                severity: Severity::Fatal,
                input: None,
                msg: parse_err.to_string(),
            };
            ValidationErrors { errors: vec![fatal], title: Some("TOML".into()), checks_run: 0 }
        })?;
        self.validate_value(&parsed)
    }

    /// Total number of constraints in this schema, nested sections included.
    #[must_use]
    pub fn constraint_count(&self) -> usize {
        let own = self.fields.iter().map(|(_, rules)| rules.len());
        let nested = self.sections.iter().map(|(_, sub)| sub.constraint_count());
        own.chain(nested).sum()
    }

    /// Checks `value` against this schema, appending failures to `errors` with
    /// `prefix` prepended to each location. Fields are checked before sections.
    fn check_value(
        &self,
        value: &Value,
        prefix: &[LocSegment],
        errors: &mut Vec<ValidationError>,
        checks_run: &mut usize,
    ) {
        self.check_fields_at(Some(value), prefix, errors, checks_run);
        for (section_name, sub_schema) in &self.sections {
            let mut sub_prefix = prefix.to_vec();
            sub_prefix.push(LocSegment::Key(section_name.clone()));
            if let Some(present) = value.get(section_name.as_str()) {
                sub_schema.check_value(present, &sub_prefix, errors, checks_run);
            } else {
                sub_schema.report_section_missing(
                    section_name,
                    &sub_prefix,
                    errors,
                    checks_run,
                );
            }
        }
    }

    /// Runs every constraint of this schema as if all of its keys were absent,
    /// recursing into nested sections. Used when a whole section is missing.
    fn report_section_missing(
        &self,
        _section: &str,
        prefix: &[LocSegment],
        errors: &mut Vec<ValidationError>,
        checks_run: &mut usize,
    ) {
        self.check_fields_at(None, prefix, errors, checks_run);
        for (sub_section_name, sub_schema) in &self.sections {
            let mut sub_prefix = prefix.to_vec();
            sub_prefix.push(LocSegment::Key(sub_section_name.clone()));
            sub_schema.report_section_missing(sub_section_name, &sub_prefix, errors, checks_run);
        }
    }

    /// Runs the constraints of this schema's own fields. Each field is looked
    /// up in `parent`; with `parent == None` every field counts as absent.
    /// Failures are pushed onto `errors` with `prefix` prepended to their
    /// location.
    fn check_fields_at(
        &self,
        parent: Option<&Value>,
        prefix: &[LocSegment],
        errors: &mut Vec<ValidationError>,
        checks_run: &mut usize,
    ) {
        for (name, constraints) in &self.fields {
            let child = parent.and_then(|table| table.get(name.as_str()));
            for rule in constraints {
                if let Some(mut failure) = rule.check(name, child, checks_run) {
                    let mut segments = prefix.to_vec();
                    segments.append(&mut failure.loc.0);
                    failure.loc = Loc(segments);
                    errors.push(failure);
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Schema for the `[server]` table: non-empty `host`, `port` in 1..=65535.
    fn server_schema() -> Schema {
        let mut schema = Schema::new();
        schema = schema.field("host").non_empty().done();
        schema.field("port").range_i64(1, 65535).done()
    }

    /// Top-level application schema with a nested `[server]` section.
    fn app_schema() -> Schema {
        let mut schema = Schema::new();
        schema = schema.field("name").non_empty().done();
        schema = schema.field("workers").range_i64(1, 1024).done();
        schema = schema
            .field("log_level")
            .one_of(&["trace", "debug", "info", "warn", "error"])
            .done();
        schema.section("server", server_schema())
    }

    #[test]
    fn valid_config_passes() {
        let document = r#"
name = "demo"
workers = 8
log_level = "info"
[server]
host = "localhost"
port = 8080
"#;
        let outcome = app_schema().validate_str(document);
        assert!(outcome.is_ok());
    }

    #[test]
    fn collects_all_errors() {
        // Every one of the five constrained keys is invalid.
        let document = r#"
name = ""
workers = 0
log_level = "verbose"
[server]
host = ""
port = 0
"#;
        let failures = app_schema().validate_str(document).unwrap_err();
        assert_eq!(failures.len(), 5);
    }

    #[test]
    fn nested_section_paths_are_prefixed() {
        let document = "name = \"ok\"\nworkers = 4\nlog_level = \"info\"\n[server]\nhost = \"\"\nport = 8080\n";
        let failures = app_schema().validate_str(document).unwrap_err();
        let has_nested_path =
            failures.errors().iter().any(|e| e.loc.to_string() == "server.host");
        assert!(has_nested_path);
    }

    #[test]
    fn fitness_reflects_partial_pass() {
        let document = "name = \"ok\"\nworkers = 0\nlog_level = \"info\"\n[server]\nhost = \"h\"\nport = 8080\n";
        let failures = app_schema().validate_str(document).unwrap_err();
        let fitness = failures.fitness();
        assert!(fitness > 0.0 && fitness < 1.0);
    }

    #[test]
    fn variant_id_stable_across_equal_error_patterns() {
        let document =
            "name = \"\"\nworkers = 1\nlog_level = \"info\"\n[server]\nhost = \"h\"\nport = 80\n";
        let first = app_schema().validate_str(document).unwrap_err().variant_id();
        let second = app_schema().validate_str(document).unwrap_err().variant_id();
        assert_eq!(first, second);
    }

    #[test]
    fn parse_error_produces_fatal_error() {
        let failures = Schema::new().validate_str("not valid toml :::").unwrap_err();
        let first = &failures.errors()[0];
        assert!(first.is_fatal());
        assert_eq!(first.code(), "parse_error");
    }

    #[test]
    fn one_of_constraint() {
        let schema = Schema::new().field("level").one_of(&["info", "warn", "error"]).done();
        assert!(schema.validate_str("level = \"info\"").is_ok());

        let failures = schema.validate_str("level = \"verbose\"").unwrap_err();
        assert_eq!(failures.errors()[0].code(), "not_one_of");
    }

    #[test]
    fn range_f64_constraint() {
        let schema = Schema::new().field("ratio").range_f64(0.0, 1.0).done();
        assert!(schema.validate_str("ratio = 0.5").is_ok());

        let failures = schema.validate_str("ratio = 2.0").unwrap_err();
        assert_eq!(failures.errors()[0].code(), "out_of_range");
    }

    #[test]
    fn predicate_constraint() {
        let schema = Schema::new()
            .field("port")
            .predicate("no_well_known", "prefer ports above 1024", |v| {
                v.as_integer().map_or(true, |n| n > 1024)
            })
            .done();
        assert!(schema.validate_str("port = 8080").is_ok());

        let failures = schema.validate_str("port = 80").unwrap_err();
        assert_eq!(failures.errors()[0].code(), "no_well_known");
    }

    #[test]
    fn by_section_grouping_works_on_schema_errors() {
        let document =
            "name = \"\"\nworkers = 0\nlog_level = \"info\"\n[server]\nhost = \"\"\nport = 8080\n";
        let failures = app_schema().validate_str(document).unwrap_err();
        let grouped = failures.by_section();
        for key in ["name", "workers", "server"] {
            assert!(grouped.contains_key(key));
        }
    }
}