1use serde::{Deserialize, Serialize};
4use std::collections::BTreeMap;
5
/// Describes one genome field: its kind plus the parameters that constrain it.
///
/// Serialized by serde as an internally tagged enum: the variant is carried in a
/// `type` key whose value is the snake_case variant name (`string`, `float`,
/// `integer`, `categorical`, `set`), alongside the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum FieldSchema {
    /// A string-valued field.
    String {
        /// The kind of string mutator configured for this field; see
        /// [`StringMutatorKind`].
        mutator: StringMutatorKind,
    },
    /// A floating-point field.
    Float {
        /// `(lo, hi)` bounds. [`GenomeSchema::validate`] rejects the schema when
        /// `lo > hi`; equal bounds are accepted.
        range: (f64, f64),
        /// NOTE(review): presumably the spread (standard deviation) of the
        /// perturbation applied on mutation — confirm against the mutation code.
        /// Not checked by [`GenomeSchema::validate`].
        sigma: f64,
    },
    /// An integer field.
    Integer {
        /// `(lo, hi)` bounds. [`GenomeSchema::validate`] rejects the schema when
        /// `lo > hi`; equal bounds are accepted.
        range: (i64, i64),
        /// Floating-point even though the field itself is an integer.
        /// NOTE(review): presumably the mutation spread, as for `Float` —
        /// confirm. Not checked by [`GenomeSchema::validate`].
        sigma: f64,
    },
    /// A field whose options are listed in `choices`.
    Categorical {
        /// The listed options. [`GenomeSchema::validate`] rejects an empty list.
        choices: Vec<String>,
    },
    /// A field described by a `pool` of strings and an optional `max`.
    Set {
        /// The pool of strings. [`GenomeSchema::validate`] rejects an empty pool.
        pool: Vec<String>,
        /// NOTE(review): presumably an upper bound on the number of selected
        /// elements, with `None` meaning no limit — confirm. Not checked by
        /// [`GenomeSchema::validate`].
        max: Option<usize>,
    },
}
43
/// Selects the kind of mutator configured for a [`FieldSchema::String`] field.
///
/// Variant names are serialized in snake_case (`llm_rewrite`, `template_slot`).
/// The mutation behaviour itself is not defined in this file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum StringMutatorKind {
    /// Carries no configuration.
    /// NOTE(review): by its name, delegates rewriting of the string to an LLM —
    /// confirm against the mutator implementation.
    LlmRewrite,
    /// Template-based mutation configured by a table of slots.
    TemplateSlot {
        /// Map from a string key to a list of strings.
        /// NOTE(review): presumably slot name -> candidate fill-in values —
        /// confirm against the mutator implementation.
        slots: BTreeMap<String, Vec<String>>,
    },
}
57
/// Complete schema for a genome: its named fields plus per-field mutation rates.
///
/// Both maps are `BTreeMap`s, so they iterate in sorted key order. Call
/// [`GenomeSchema::validate`] to check the schema for internal consistency;
/// constructing or deserializing a value performs no such checks.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GenomeSchema {
    /// Field name -> schema describing that field.
    pub fields: BTreeMap<String, FieldSchema>,
    /// Mutation rate keyed by field name. [`GenomeSchema::validate`] requires
    /// every key to also exist in `fields` and every rate to lie in `[0.0, 1.0]`.
    /// A field need not have an entry here; how such fields are treated is not
    /// determined in this file.
    /// NOTE(review): presumably the probability that the field is mutated — confirm.
    pub mutation_rates: BTreeMap<String, f64>,
}
67
/// Reasons [`GenomeSchema::validate`] can reject a schema.
///
/// In every variant the first element is the map key (field name) the problem
/// was found under; the `Display` text comes from the `#[error]` attributes.
#[derive(thiserror::Error, Debug, PartialEq)]
pub enum SchemaError {
    /// A float field's lower bound exceeds its upper bound.
    /// Payload: field name, lower bound, upper bound.
    #[error("field {0}: float range inverted ({1} > {2})")]
    InvertedFloatRange(String, f64, f64),
    /// An integer field's lower bound exceeds its upper bound.
    /// Payload: field name, lower bound, upper bound.
    #[error("field {0}: integer range inverted ({1} > {2})")]
    InvertedIntegerRange(String, i64, i64),
    /// A categorical field lists no choices. Payload: field name.
    #[error("field {0}: categorical has no choices")]
    EmptyCategorical(String),
    /// A set field has an empty pool. Payload: field name.
    #[error("field {0}: set pool is empty")]
    EmptySetPool(String),
    /// `mutation_rates` contains a key that is not present in `fields`.
    /// Payload: the unmatched key.
    #[error("mutation rate references unknown field {0}")]
    UnknownMutationRateField(String),
    /// A mutation rate is outside `[0.0, 1.0]`. Payload: field name, the rate.
    #[error("field {0}: mutation rate {1} not in [0.0, 1.0]")]
    MutationRateOutOfRange(String, f64),
}
90
91impl GenomeSchema {
92 pub fn validate(&self) -> Result<(), SchemaError> {
100 for (name, field) in &self.fields {
101 match field {
102 FieldSchema::Float {
103 range: (lo, hi), ..
104 } if lo > hi => {
105 return Err(SchemaError::InvertedFloatRange(name.clone(), *lo, *hi));
106 }
107 FieldSchema::Integer {
108 range: (lo, hi), ..
109 } if lo > hi => {
110 return Err(SchemaError::InvertedIntegerRange(name.clone(), *lo, *hi));
111 }
112 FieldSchema::Categorical { choices } if choices.is_empty() => {
113 return Err(SchemaError::EmptyCategorical(name.clone()));
114 }
115 FieldSchema::Set { pool, .. } if pool.is_empty() => {
116 return Err(SchemaError::EmptySetPool(name.clone()));
117 }
118 _ => {}
119 }
120 }
121 for (name, rate) in &self.mutation_rates {
122 if !self.fields.contains_key(name) {
123 return Err(SchemaError::UnknownMutationRateField(name.clone()));
124 }
125 if !(0.0..=1.0).contains(rate) {
126 return Err(SchemaError::MutationRateOutOfRange(name.clone(), *rate));
127 }
128 }
129 Ok(())
130 }
131}
132
#[cfg(test)]
mod tests {
    //! Unit tests for [`GenomeSchema`] serialization and validation.
    //!
    //! Rejection tests compare against the exact [`SchemaError`] value rather
    //! than just `is_err()`, so a schema that fails for the wrong reason does
    //! not pass.

    use super::*;

    /// Builds a schema containing exactly one field and no mutation rates.
    fn single_field(name: &str, field: FieldSchema) -> GenomeSchema {
        GenomeSchema {
            fields: BTreeMap::from([(name.to_string(), field)]),
            mutation_rates: BTreeMap::new(),
        }
    }

    /// Returns `schema` with one mutation rate added under `name`.
    fn with_rate(mut schema: GenomeSchema, name: &str, rate: f64) -> GenomeSchema {
        schema.mutation_rates.insert(name.to_string(), rate);
        schema
    }

    #[test]
    fn schema_roundtrips_through_json() {
        let schema = with_rate(
            single_field(
                "temperature",
                FieldSchema::Float {
                    range: (0.0, 2.0),
                    sigma: 0.2,
                },
            ),
            "temperature",
            0.1,
        );
        let json = serde_json::to_string(&schema).unwrap();
        let back: GenomeSchema = serde_json::from_str(&json).unwrap();
        assert_eq!(schema, back);
    }

    #[test]
    fn validate_rejects_inverted_float_range() {
        let bad = single_field(
            "t",
            FieldSchema::Float {
                range: (2.0, 0.0),
                sigma: 0.1,
            },
        );
        assert_eq!(
            bad.validate(),
            Err(SchemaError::InvertedFloatRange("t".to_string(), 2.0, 0.0))
        );
    }

    #[test]
    fn validate_rejects_inverted_integer_range() {
        let bad = single_field(
            "n",
            FieldSchema::Integer {
                range: (10, 1),
                sigma: 1.0,
            },
        );
        assert_eq!(
            bad.validate(),
            Err(SchemaError::InvertedIntegerRange("n".to_string(), 10, 1))
        );
    }

    #[test]
    fn validate_rejects_empty_categorical() {
        let bad = single_field("model", FieldSchema::Categorical { choices: vec![] });
        assert_eq!(
            bad.validate(),
            Err(SchemaError::EmptyCategorical("model".to_string()))
        );
    }

    #[test]
    fn validate_rejects_empty_set_pool() {
        let bad = single_field(
            "tools",
            FieldSchema::Set {
                pool: vec![],
                max: None,
            },
        );
        assert_eq!(
            bad.validate(),
            Err(SchemaError::EmptySetPool("tools".to_string()))
        );
    }

    #[test]
    fn validate_rejects_mutation_rate_for_unknown_field() {
        let bad = GenomeSchema {
            fields: BTreeMap::new(),
            mutation_rates: BTreeMap::from([("nope".to_string(), 0.1)]),
        };
        assert_eq!(
            bad.validate(),
            Err(SchemaError::UnknownMutationRateField("nope".to_string()))
        );
    }

    #[test]
    fn validate_rejects_mutation_rate_out_of_range() {
        let bad = with_rate(
            single_field(
                "t",
                FieldSchema::Float {
                    range: (0.0, 1.0),
                    sigma: 0.1,
                },
            ),
            "t",
            1.5,
        );
        assert_eq!(
            bad.validate(),
            Err(SchemaError::MutationRateOutOfRange("t".to_string(), 1.5))
        );
    }

    #[test]
    fn validate_accepts_well_formed_schema() {
        let good = with_rate(
            single_field(
                "t",
                FieldSchema::Float {
                    range: (0.0, 2.0),
                    sigma: 0.2,
                },
            ),
            "t",
            0.1,
        );
        assert_eq!(good.validate(), Ok(()));
    }

    #[test]
    fn validate_accepts_boundary_values() {
        // Equal bounds and the closed endpoints of the rate interval are valid.
        let field = FieldSchema::Float {
            range: (1.0, 1.0),
            sigma: 0.1,
        };
        for rate in [0.0, 1.0] {
            let schema = with_rate(single_field("t", field.clone()), "t", rate);
            assert_eq!(schema.validate(), Ok(()));
        }
    }
}