use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
/// Schema for one genome field: which kind of value it holds and the
/// parameters attached to that kind.
///
/// Serialized with an internal `"type"` tag whose value is the variant name in
/// `snake_case` (`"string"`, `"float"`, `"integer"`, `"categorical"`, `"set"`);
/// the variant's own fields sit next to the tag in the same object.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum FieldSchema {
    /// Text-valued field.
    String {
        /// Which string mutation strategy applies to this field.
        mutator: StringMutatorKind,
    },
    /// Floating-point field.
    Float {
        /// `(lo, hi)` bounds; `GenomeSchema::validate` rejects `lo > hi`.
        range: (f64, f64),
        /// Mutation scale parameter.
        /// NOTE(review): presumably a standard deviation — confirm against the mutator.
        sigma: f64,
    },
    /// Integer-valued field.
    Integer {
        /// `(lo, hi)` bounds; `GenomeSchema::validate` rejects `lo > hi`.
        range: (i64, i64),
        /// Mutation scale parameter (kept as `f64` even though the field is integral).
        /// NOTE(review): presumably a standard deviation — confirm against the mutator.
        sigma: f64,
    },
    /// Field drawn from a fixed list of string choices.
    Categorical {
        /// Candidate values; an empty list fails `GenomeSchema::validate`.
        choices: Vec<String>,
    },
    /// Field built from a pool of string elements.
    Set {
        /// Candidate elements; an empty pool fails `GenomeSchema::validate`.
        pool: Vec<String>,
        /// Optional limit.
        /// NOTE(review): presumably the maximum number of selected elements — confirm.
        max: Option<usize>,
    },
}
/// Mutation strategy selector for [`FieldSchema::String`] fields.
///
/// Uses serde's default (externally tagged) enum representation with
/// `snake_case` variant names: `"llm_rewrite"` and `"template_slot"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum StringMutatorKind {
    /// Parameterless strategy.
    /// NOTE(review): presumably delegates rewriting to an LLM — the behavior lives
    /// in the consumer of this schema, not here.
    LlmRewrite,
    /// Strategy parameterized by named slots.
    TemplateSlot {
        /// Slot name mapped to a list of strings.
        /// NOTE(review): presumably the candidate fillers for that slot — confirm.
        slots: BTreeMap<String, Vec<String>>,
    },
}
/// A complete genome description: the named fields plus per-field mutation rates.
///
/// Call `GenomeSchema::validate` to check that the two maps are consistent
/// before relying on the schema.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GenomeSchema {
    /// Field name mapped to that field's schema.
    pub fields: BTreeMap<String, FieldSchema>,
    /// Field name mapped to its mutation rate. To pass validation every key
    /// must also be a key of `fields` and every rate must lie in `[0.0, 1.0]`.
    pub mutation_rates: BTreeMap<String, f64>,
}
/// Reasons a [`GenomeSchema`] can be rejected by `GenomeSchema::validate`.
///
/// In every variant the leading `String` is the name of the offending field
/// (or, for [`SchemaError::UnknownMutationRateField`], the unmatched key).
#[derive(thiserror::Error, Debug, PartialEq)]
pub enum SchemaError {
    /// A float field's range has its first bound above its second.
    /// Payload: field name, first bound, second bound.
    #[error("field {0}: float range inverted ({1} > {2})")]
    InvertedFloatRange(String, f64, f64),

    /// An integer field's range has its first bound above its second.
    /// Payload: field name, first bound, second bound.
    #[error("field {0}: integer range inverted ({1} > {2})")]
    InvertedIntegerRange(String, i64, i64),

    /// A categorical field lists no choices.
    #[error("field {0}: categorical has no choices")]
    EmptyCategorical(String),

    /// A set field has an empty pool.
    #[error("field {0}: set pool is empty")]
    EmptySetPool(String),

    /// A mutation rate is keyed by a name that is not a declared field.
    #[error("mutation rate references unknown field {0}")]
    UnknownMutationRateField(String),

    /// A mutation rate lies outside the inclusive interval `[0.0, 1.0]`.
    /// Payload: field name, offending rate.
    #[error("field {0}: mutation rate {1} not in [0.0, 1.0]")]
    MutationRateOutOfRange(String, f64),
}
impl GenomeSchema {
pub fn validate(&self) -> Result<(), SchemaError> {
for (name, field) in &self.fields {
match field {
FieldSchema::Float {
range: (lo, hi), ..
} if lo > hi => {
return Err(SchemaError::InvertedFloatRange(name.clone(), *lo, *hi));
}
FieldSchema::Integer {
range: (lo, hi), ..
} if lo > hi => {
return Err(SchemaError::InvertedIntegerRange(name.clone(), *lo, *hi));
}
FieldSchema::Categorical { choices } if choices.is_empty() => {
return Err(SchemaError::EmptyCategorical(name.clone()));
}
FieldSchema::Set { pool, .. } if pool.is_empty() => {
return Err(SchemaError::EmptySetPool(name.clone()));
}
_ => {}
}
}
for (name, rate) in &self.mutation_rates {
if !self.fields.contains_key(name) {
return Err(SchemaError::UnknownMutationRateField(name.clone()));
}
if !(0.0..=1.0).contains(rate) {
return Err(SchemaError::MutationRateOutOfRange(name.clone(), *rate));
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a schema from `(name, field)` and `(name, rate)` pairs so each
    /// test states only the data it cares about.
    fn schema_of(fields: Vec<(&str, FieldSchema)>, rates: Vec<(&str, f64)>) -> GenomeSchema {
        GenomeSchema {
            fields: fields
                .into_iter()
                .map(|(name, field)| (name.to_string(), field))
                .collect(),
            mutation_rates: rates
                .into_iter()
                .map(|(name, rate)| (name.to_string(), rate))
                .collect(),
        }
    }

    /// A float field that passes validation on its own.
    fn valid_float() -> FieldSchema {
        FieldSchema::Float {
            range: (0.0, 2.0),
            sigma: 0.2,
        }
    }

    #[test]
    fn schema_roundtrips_through_json() {
        let schema = schema_of(
            vec![("temperature", valid_float())],
            vec![("temperature", 0.1)],
        );
        let json = serde_json::to_string(&schema).unwrap();
        let back: GenomeSchema = serde_json::from_str(&json).unwrap();
        assert_eq!(schema, back);
    }

    // The rejection tests compare against the exact error value rather than
    // `is_err()`, so a schema rejected for the wrong reason fails the test.

    #[test]
    fn validate_rejects_inverted_float_range() {
        let bad = schema_of(
            vec![(
                "t",
                FieldSchema::Float {
                    range: (2.0, 0.0),
                    sigma: 0.1,
                },
            )],
            vec![],
        );
        assert_eq!(
            bad.validate(),
            Err(SchemaError::InvertedFloatRange("t".to_string(), 2.0, 0.0))
        );
    }

    #[test]
    fn validate_rejects_inverted_integer_range() {
        let bad = schema_of(
            vec![(
                "n",
                FieldSchema::Integer {
                    range: (10, -10),
                    sigma: 1.0,
                },
            )],
            vec![],
        );
        assert_eq!(
            bad.validate(),
            Err(SchemaError::InvertedIntegerRange("n".to_string(), 10, -10))
        );
    }

    #[test]
    fn validate_rejects_empty_categorical() {
        let bad = schema_of(
            vec![("model", FieldSchema::Categorical { choices: vec![] })],
            vec![],
        );
        assert_eq!(
            bad.validate(),
            Err(SchemaError::EmptyCategorical("model".to_string()))
        );
    }

    #[test]
    fn validate_rejects_empty_set_pool() {
        let bad = schema_of(
            vec![(
                "tools",
                FieldSchema::Set {
                    pool: vec![],
                    max: None,
                },
            )],
            vec![],
        );
        assert_eq!(
            bad.validate(),
            Err(SchemaError::EmptySetPool("tools".to_string()))
        );
    }

    #[test]
    fn validate_rejects_mutation_rate_for_unknown_field() {
        let bad = schema_of(vec![], vec![("nope", 0.1)]);
        assert_eq!(
            bad.validate(),
            Err(SchemaError::UnknownMutationRateField("nope".to_string()))
        );
    }

    #[test]
    fn validate_rejects_mutation_rate_out_of_range() {
        let field = FieldSchema::Float {
            range: (0.0, 1.0),
            sigma: 0.1,
        };

        let too_high = schema_of(vec![("t", field.clone())], vec![("t", 1.5)]);
        assert_eq!(
            too_high.validate(),
            Err(SchemaError::MutationRateOutOfRange("t".to_string(), 1.5))
        );

        let negative = schema_of(vec![("t", field)], vec![("t", -0.1)]);
        assert_eq!(
            negative.validate(),
            Err(SchemaError::MutationRateOutOfRange("t".to_string(), -0.1))
        );
    }

    #[test]
    fn validate_accepts_well_formed_schema() {
        let good = schema_of(vec![("t", valid_float())], vec![("t", 0.1)]);
        assert_eq!(good.validate(), Ok(()));
    }

    #[test]
    fn validate_accepts_boundary_mutation_rates() {
        // The allowed interval is inclusive at both ends.
        let good = schema_of(
            vec![("a", valid_float()), ("b", valid_float())],
            vec![("a", 0.0), ("b", 1.0)],
        );
        assert_eq!(good.validate(), Ok(()));
    }
}