datasynth_eval/calibration/
history.rs1use std::collections::BTreeMap;
26use std::path::Path;
27
28use serde::{Deserialize, Serialize};
29
30use super::knob::KnobValue;
31use super::loop_runner::{CalibrationLoop, StepReport};
32
/// Schema version stamped into every [`CalibrationHistory`] built by
/// [`CalibrationHistory::from_loop`].
///
/// [`CalibrationHistory::load`] rejects any file whose `schema_version`
/// field is not exactly equal to this string.
pub const HISTORY_SCHEMA_VERSION: &str = "1.0";
37
/// Errors produced while saving, loading, or applying a [`CalibrationHistory`].
#[derive(Debug)]
pub enum HistoryError {
    /// Filesystem failure; built from a [`std::io::Error`] via `From`.
    Io(std::io::Error),
    /// JSON serialization or deserialization failure; built from a
    /// [`serde_json::Error`] via `From`.
    Parse(serde_json::Error),
    /// A value declared by the history does not match what the runtime expects.
    ///
    /// Returned by `CalibrationHistory::load` when the file's schema version
    /// differs from [`HISTORY_SCHEMA_VERSION`], and by
    /// `CalibrationHistory::apply_to` when the recorded objective metric name
    /// differs from the target loop's.
    SchemaMismatch {
        /// The value found in the history.
        found: String,
        /// The value the runtime requires.
        expected: &'static str,
    },
}
48
49impl std::fmt::Display for HistoryError {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 match self {
52 Self::Io(e) => write!(f, "history IO: {e}"),
53 Self::Parse(e) => write!(f, "history JSON parse: {e}"),
54 Self::SchemaMismatch { found, expected } => write!(
55 f,
56 "history schema mismatch: file declares {found}, runtime expects {expected}"
57 ),
58 }
59 }
60}
61
62impl std::error::Error for HistoryError {}
63
64impl From<std::io::Error> for HistoryError {
65 fn from(e: std::io::Error) -> Self {
66 Self::Io(e)
67 }
68}
69
70impl From<serde_json::Error> for HistoryError {
71 fn from(e: serde_json::Error) -> Self {
72 Self::Parse(e)
73 }
74}
75
/// Serializable snapshot of a calibration run: the record written by `save`,
/// read back by `load`, and restored into a loop by `apply_to`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalibrationHistory {
    /// Schema version of this record; `load` requires it to equal
    /// [`HISTORY_SCHEMA_VERSION`].
    pub schema_version: String,
    /// Name of the objective metric of the loop this record was taken from;
    /// `apply_to` refuses a loop whose objective metric name differs.
    pub objective_metric: String,
    /// Step reports copied from the loop's `history`.
    pub steps: Vec<StepReport>,
    /// First component of the loop's `best_loss` pair, or `None` if the loop
    /// had no best loss.
    pub best_loss_mean: Option<f64>,
    /// Second component of the loop's `best_loss` pair, or `None` if the loop
    /// had no best loss.
    pub best_loss_std: Option<f64>,
    /// Copy of the loop's `best_knob_values` map, keyed by knob path.
    pub best_knob_values: BTreeMap<String, KnobValue>,
}
99
100impl CalibrationHistory {
101 pub fn from_loop(loop_: &CalibrationLoop) -> Self {
105 Self {
106 schema_version: HISTORY_SCHEMA_VERSION.to_string(),
107 objective_metric: loop_.objective.metric.name().to_string(),
108 steps: loop_.history.clone(),
109 best_loss_mean: loop_.best_loss.map(|(m, _)| m),
110 best_loss_std: loop_.best_loss.map(|(_, s)| s),
111 best_knob_values: loop_.best_knob_values.clone(),
112 }
113 }
114
115 pub fn save(&self, path: &Path) -> Result<(), HistoryError> {
119 let tmp = path.with_extension("tmp");
120 let json = serde_json::to_string_pretty(self)?;
121 std::fs::write(&tmp, json)?;
122 std::fs::rename(&tmp, path)?;
123 Ok(())
124 }
125
126 pub fn load(path: &Path) -> Result<Self, HistoryError> {
130 let bytes = std::fs::read(path)?;
131 let parsed: Self = serde_json::from_slice(&bytes)?;
132 if parsed.schema_version != HISTORY_SCHEMA_VERSION {
133 return Err(HistoryError::SchemaMismatch {
134 found: parsed.schema_version,
135 expected: HISTORY_SCHEMA_VERSION,
136 });
137 }
138 Ok(parsed)
139 }
140
141 pub fn apply_to(&self, loop_: &mut CalibrationLoop) -> Result<(), HistoryError> {
151 if self.objective_metric != loop_.objective.metric.name() {
153 return Err(HistoryError::SchemaMismatch {
154 found: self.objective_metric.clone(),
155 expected: loop_.objective.metric.name(),
156 });
157 }
158 loop_.history = self.steps.clone();
159 loop_.best_loss = match (self.best_loss_mean, self.best_loss_std) {
160 (Some(m), Some(s)) => Some((m, s)),
161 (Some(m), None) => Some((m, 0.0)),
162 _ => None,
163 };
164 loop_.best_knob_values = self.best_knob_values.clone();
165 if let Some(last) = self.steps.last() {
169 for knob in &mut loop_.knobs {
170 if let Some(v) = last.knob_values.get(&knob.path) {
171 knob.current = *v;
172 }
173 }
174 }
175 Ok(())
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::calibration::knob::CalibrationKnob;
    use crate::calibration::loop_runner::{CalibrationConfig, StepOutcome};
    use crate::calibration::objective::CalibrationObjective;
    use std::collections::BTreeMap;
    use tempfile::TempDir;

    /// Fresh loop using the `bf_composite` objective, a single `test.rate`
    /// knob, and the default config.
    fn empty_loop() -> CalibrationLoop {
        let objective = CalibrationObjective::bf_composite();
        let knobs = vec![CalibrationKnob::new_f64("test.rate", 0.10, 0.0, 1.0, 0.02)];
        CalibrationLoop::new(objective, knobs, CalibrationConfig::default())
    }

    /// Synthetic `Improved` step whose only recorded knob is `test.rate`.
    fn fake_step(iter: usize, before: f64, after: f64, knob_value: f64) -> StepReport {
        let knob_values = BTreeMap::from([("test.rate".to_string(), KnobValue::F64(knob_value))]);
        StepReport {
            iter,
            loss_before_mean: before,
            loss_before_std: 1.0,
            proposed_patch: None,
            loss_after_mean: Some(after),
            loss_after_std: Some(1.0),
            knob_values,
            outcome: StepOutcome::Improved,
        }
    }

    /// A saved history reads back with the same version, metric, steps,
    /// best loss mean, and best knob values.
    #[test]
    fn save_and_load_round_trips() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("calibration_history.json");

        let mut source = empty_loop();
        source.history.push(fake_step(0, 50.0, 45.0, 0.08));
        source.history.push(fake_step(1, 45.0, 40.0, 0.06));
        source.best_loss = Some((40.0, 1.0));
        source
            .best_knob_values
            .insert("test.rate".into(), KnobValue::F64(0.06));

        CalibrationHistory::from_loop(&source).save(&file).unwrap();
        let reloaded = CalibrationHistory::load(&file).unwrap();

        assert_eq!(reloaded.schema_version, HISTORY_SCHEMA_VERSION);
        assert_eq!(reloaded.objective_metric, "bf_composite");
        assert_eq!(reloaded.steps.len(), 2);
        assert_eq!(reloaded.best_loss_mean, Some(40.0));
        assert_eq!(
            reloaded.best_knob_values.get("test.rate"),
            Some(&KnobValue::F64(0.06))
        );
    }

    /// A file declaring a foreign schema version is rejected by `load`.
    #[test]
    fn schema_mismatch_rejected_on_load() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("history.json");
        let bad = r#"{
            "schema_version": "99.99",
            "objective_metric": "bf_composite",
            "steps": [],
            "best_loss_mean": null,
            "best_loss_std": null,
            "best_knob_values": {}
        }"#;
        std::fs::write(&file, bad).unwrap();

        let outcome = CalibrationHistory::load(&file);
        let err = outcome.expect_err("schema must mismatch");
        assert!(
            matches!(err, HistoryError::SchemaMismatch { .. }),
            "expected SchemaMismatch, got {err:?}"
        );
    }

    /// `apply_to` restores history, best loss, and the knob's current value
    /// into a fresh loop.
    #[test]
    fn apply_to_restores_knob_state_and_history() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("h.json");

        // Source loop: one recorded step that moved the knob to 0.06.
        let mut src = empty_loop();
        src.history.push(fake_step(0, 50.0, 45.0, 0.06));
        src.best_loss = Some((45.0, 1.0));
        src.best_knob_values
            .insert("test.rate".into(), KnobValue::F64(0.06));
        src.knobs[0].current = KnobValue::F64(0.06);

        let snapshot = CalibrationHistory::from_loop(&src);
        snapshot.save(&file).unwrap();

        // Destination loop starts from the untouched defaults.
        let mut dst = empty_loop();
        assert_eq!(dst.knobs[0].current.as_f64(), 0.10);
        assert!(dst.history.is_empty());

        let restored = CalibrationHistory::load(&file).unwrap();
        restored.apply_to(&mut dst).unwrap();

        assert_eq!(dst.history.len(), 1);
        assert_eq!(dst.best_loss, Some((45.0, 1.0)));
        assert!(
            (dst.knobs[0].current.as_f64() - 0.06).abs() < 1e-9,
            "knob should resume to last-step value: got {}",
            dst.knobs[0].current
        );
    }

    /// `apply_to` refuses a loop whose objective metric differs from the
    /// one recorded in the history.
    #[test]
    fn apply_to_rejects_objective_mismatch() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("h.json");

        let src = empty_loop();
        CalibrationHistory::from_loop(&src).save(&file).unwrap();

        let other_objective = CalibrationObjective::default()
            .with_metric(crate::calibration::ObjectiveMetric::BfCompositeMedian);
        let mut dst = CalibrationLoop::new(
            other_objective,
            vec![CalibrationKnob::new_f64("test.rate", 0.10, 0.0, 1.0, 0.02)],
            CalibrationConfig::default(),
        );

        let loaded = CalibrationHistory::load(&file).unwrap();
        let err = loaded
            .apply_to(&mut dst)
            .expect_err("objective mismatch must reject");
        assert!(matches!(err, HistoryError::SchemaMismatch { .. }));
    }

    /// After a successful save the target exists and the `.tmp` staging
    /// file is gone.
    #[test]
    fn save_uses_atomic_rename() {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("history.json");
        let staging = target.with_extension("tmp");

        let source = empty_loop();
        CalibrationHistory::from_loop(&source).save(&target).unwrap();

        assert!(target.exists(), "target file should exist after save");
        assert!(
            !staging.exists(),
            "tmp staging file should be renamed away after save"
        );
    }
}