Skip to main content

datasynth_eval/calibration/
safety.rs

1//! C3 Piece 5 — safety rails for the calibration loop.
2//!
//! Small helpers the loop consults between steps. Some are pure
//! inspectors over the loop's step history; others carry their own
//! state (per-knob clip counters, a start time). A positive signal
//! can drive a damping reduction, an emit-warn, or an outright abort.
7//!
8//! Rails shipped in this piece:
9//!
10//! - [`OscillationDetector`] — flags when the same knob has its Δ
11//!   sign-alternate across `window` recent steps, signalling the
12//!   optimizer is bouncing between two values.
13//! - [`KnobClipDiagnostics`] — tracks per-knob clip counts so the
14//!   loop's reporter can surface knobs that hit their bounds
15//!   often (= likely the bounds need widening, OR the optimizer
16//!   wants to escape its current basin).
17//! - [`WallClockBudget`] — wraps a clock so the loop can stop
18//!   after a configured elapsed time. Useful for overnight runs
19//!   that need a hard deadline.
20//!
21//! Overfit detection (calibration vs held-out validation seeds) is
22//! a follow-up — it requires a second evaluator stream and is
23//! tracked as a separate task in the design doc.
24
25use std::collections::BTreeMap;
26use std::time::{Duration, Instant};
27
28use super::knob::KnobValue;
29use super::loop_runner::StepReport;
30
/// Detect oscillation on a specific knob across recent steps.
///
/// "Oscillation" = the knob's `(after - before)` delta has been
/// sign-alternating across the configured `window` of the most
/// recent steps that actually moved this knob. When that
/// happens, the optimizer is bouncing between two values and the
/// loop should reduce damping (or warn) rather than keep
/// proposing the same step direction.
///
/// The detector only holds configuration; both fields are public so
/// callers can tune them with struct-update syntax:
/// `OscillationDetector { window: 7, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OscillationDetector {
    /// How many recent non-zero deltas of the knob to scan. Default 5.
    pub window: usize,
    /// Minimum number of adjacent sign changes within the window
    /// needed to flag oscillation. Default 3 (a window of 5 deltas
    /// has 4 adjacent pairs, so 3 of those 4 must flip).
    pub min_alternations: usize,
}

impl Default for OscillationDetector {
    /// `window = 5`, `min_alternations = 3`.
    fn default() -> Self {
        Self {
            window: 5,
            min_alternations: 3,
        }
    }
}
55
56impl OscillationDetector {
57    /// Return true if `knob_path` is oscillating per the
58    /// configured window + alternation threshold. Skips steps
59    /// whose patch didn't touch `knob_path` and steps that didn't
60    /// apply a patch at all. Returns false when fewer than
61    /// `window` qualifying steps exist.
62    pub fn check(&self, history: &[StepReport], knob_path: &str) -> bool {
63        // Collect signs of Δ for this knob's recent steps.
64        let mut signs: Vec<i32> = Vec::new();
65        for step in history.iter().rev() {
66            if step.proposed_patch.is_none() {
67                continue;
68            }
69            // Did this step modify our knob? `knob_values` records
70            // the AFTER state; we don't have a direct "before"
71            // pointer, but the proposed_value tells us the
72            // direction the proposer wanted to go, so we use that
73            // as the oscillation signal (sign of proposed - prior
74            // step's value for the same knob).
75            let path = step.knob_values.keys().find(|p| *p == knob_path).cloned();
76            if path.is_none() {
77                continue;
78            }
79
80            // Look at the previous step's recorded value for this
81            // knob to compute Δ direction.
82            let cur = step.knob_values.get(knob_path).and_then(value_as_f64);
83            let prev = find_prior_value(history, step.iter, knob_path);
84            if let (Some(c), Some(p)) = (cur, prev) {
85                let delta = c - p;
86                if delta.abs() > f64::EPSILON {
87                    signs.push(if delta > 0.0 { 1 } else { -1 });
88                }
89            }
90            // Bail once we've reached the configured window.
91            if signs.len() >= self.window {
92                break;
93            }
94        }
95
96        if signs.len() < self.window {
97            return false;
98        }
99        // Count sign alternations.
100        let alternations = signs.windows(2).filter(|w| w[0] != w[1]).count();
101        alternations >= self.min_alternations
102    }
103}
104
105/// Look up the most recent value of `knob_path` recorded BEFORE
106/// the step at `iter`. Returns `None` if no prior step has a
107/// value for this knob (= we're at the start of the history).
108fn find_prior_value(history: &[StepReport], iter: usize, knob_path: &str) -> Option<f64> {
109    for step in history.iter().rev() {
110        if step.iter >= iter {
111            continue;
112        }
113        if let Some(v) = step.knob_values.get(knob_path) {
114            return value_as_f64(v);
115        }
116    }
117    None
118}
119
120fn value_as_f64(v: &KnobValue) -> Option<f64> {
121    Some(v.as_f64())
122}
123
/// Per-knob diagnostics — currently just clip counts.
///
/// Updated by the loop after each `knob.apply()` call; surfaced in
/// the final run report so a user can see which knobs hit their
/// bounds often (= bounds probably too tight, OR the optimizer
/// found a knob it can't move further).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct KnobClipDiagnostics {
    /// `knob_path → ClipCounts`: one tally per knob, iterated in
    /// sorted key order because the map is a `BTreeMap`.
    pub counts: BTreeMap<String, ClipCounts>,
}

/// Tally of `apply` outcomes observed for a single knob. All counters
/// start at zero via `Default`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ClipCounts {
    /// Times the outcome was "clipped low".
    pub low: usize,
    /// Times the outcome was "clipped high".
    pub high: usize,
    /// Times the outcome was "in range".
    pub in_range: usize,
    /// Times the outcome was "type mismatch".
    pub type_mismatch: usize,
}
143
144impl KnobClipDiagnostics {
145    /// Record one `apply` outcome.
146    pub fn record(&mut self, knob_path: &str, result: super::knob::KnobClipResult) {
147        let entry = self.counts.entry(knob_path.to_string()).or_default();
148        use super::knob::KnobClipResult::*;
149        match result {
150            InRange => entry.in_range += 1,
151            ClippedLow => entry.low += 1,
152            ClippedHigh => entry.high += 1,
153            TypeMismatch => entry.type_mismatch += 1,
154        }
155    }
156
157    /// Knobs whose clip count exceeds `threshold` (low + high
158    /// combined). Useful in the run-report summary as a flag for
159    /// the human reviewer.
160    pub fn frequently_clipped(&self, threshold: usize) -> Vec<&String> {
161        self.counts
162            .iter()
163            .filter(|(_, c)| c.low + c.high >= threshold)
164            .map(|(k, _)| k)
165            .collect()
166    }
167}
168
/// Hard deadline for a run, expressed as "start instant + allowed
/// duration". The loop polls `.expired()` between steps and stops
/// once it reports true.
#[derive(Debug, Clone)]
pub struct WallClockBudget {
    /// Moment the budget was created.
    start: Instant,
    /// Total time allowed, measured from `start`.
    budget: Duration,
}

impl WallClockBudget {
    /// Begin timing now, with `budget` still available.
    pub fn new(budget: Duration) -> Self {
        let start = Instant::now();
        Self { start, budget }
    }

    /// True once at least `budget` has passed since construction.
    pub fn expired(&self) -> bool {
        let spent = self.start.elapsed();
        self.budget <= spent
    }

    /// How much of the budget is left; never negative — it bottoms
    /// out at `Duration::ZERO` after expiry.
    pub fn remaining(&self) -> Duration {
        let spent = self.start.elapsed();
        self.budget.checked_sub(spent).unwrap_or(Duration::ZERO)
    }
}
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::calibration::knob::{KnobClipResult, KnobValue};
    use crate::calibration::loop_runner::{ProposedPatch, StepOutcome, StepReport};

    /// Build a step numbered `iter` that applied a patch and records
    /// `value` for the single knob `knob_path`. Loss fields are filler.
    fn make_step_with_knob(iter: usize, knob_path: &str, value: f64) -> StepReport {
        let mut kv = BTreeMap::new();
        kv.insert(knob_path.to_string(), KnobValue::F64(value));
        StepReport {
            iter,
            loss_before_mean: 0.0,
            loss_before_std: 0.0,
            proposed_patch: Some(ProposedPatch {
                knob_index: 0,
                proposed_value: KnobValue::F64(value),
                rationale: "test".into(),
            }),
            loss_after_mean: Some(0.0),
            loss_after_std: Some(0.0),
            knob_values: kv,
            outcome: StepOutcome::Improved,
        }
    }

    #[test]
    fn oscillation_detector_flags_alternating_deltas() {
        // Knob bouncing 0.05 → 0.07 → 0.05 → 0.07 → 0.05 → 0.07 (deltas: +,-,+,-,+).
        let history = vec![
            make_step_with_knob(0, "k", 0.05),
            make_step_with_knob(1, "k", 0.07),
            make_step_with_knob(2, "k", 0.05),
            make_step_with_knob(3, "k", 0.07),
            make_step_with_knob(4, "k", 0.05),
            make_step_with_knob(5, "k", 0.07),
        ];
        let detector = OscillationDetector::default();
        assert!(
            detector.check(&history, "k"),
            "alternating deltas across 5 same-knob steps should flag"
        );
    }

    #[test]
    fn oscillation_detector_quiet_on_monotonic_progress() {
        // Knob walking monotonically: 0.05 → 0.06 → 0.07 → 0.08 → 0.09 → 0.10.
        let history: Vec<_> = (0..6)
            .map(|i| make_step_with_knob(i, "k", 0.05 + 0.01 * (i as f64)))
            .collect();
        let detector = OscillationDetector::default();
        assert!(
            !detector.check(&history, "k"),
            "monotonic walk should not flag oscillation"
        );
    }

    #[test]
    fn oscillation_detector_needs_full_window() {
        // Only 3 steps — below default window=5 — should not fire.
        let history = vec![
            make_step_with_knob(0, "k", 0.05),
            make_step_with_knob(1, "k", 0.07),
            make_step_with_knob(2, "k", 0.05),
        ];
        let detector = OscillationDetector::default();
        assert!(!detector.check(&history, "k"));
    }

    #[test]
    fn knob_clip_diagnostics_count_each_result() {
        let mut diag = KnobClipDiagnostics::default();
        diag.record("fraud.fraud_rate", KnobClipResult::InRange);
        diag.record("fraud.fraud_rate", KnobClipResult::InRange);
        diag.record("fraud.fraud_rate", KnobClipResult::ClippedHigh);
        diag.record("fraud.fraud_rate", KnobClipResult::ClippedLow);
        diag.record("pool.size", KnobClipResult::TypeMismatch);

        let fraud = diag.counts.get("fraud.fraud_rate").unwrap();
        assert_eq!(fraud.in_range, 2);
        assert_eq!(fraud.high, 1);
        assert_eq!(fraud.low, 1);
        let pool = diag.counts.get("pool.size").unwrap();
        assert_eq!(pool.type_mismatch, 1);
    }

    #[test]
    fn frequently_clipped_filters_by_threshold() {
        let mut diag = KnobClipDiagnostics::default();
        for _ in 0..5 {
            diag.record("often.clipped", KnobClipResult::ClippedHigh);
        }
        for _ in 0..2 {
            diag.record("rarely.clipped", KnobClipResult::ClippedLow);
        }
        diag.record("never.clipped", KnobClipResult::InRange);

        let flagged = diag.frequently_clipped(3);
        assert_eq!(flagged.len(), 1);
        assert_eq!(flagged[0], "often.clipped");

        // The comparison is inclusive: a knob clipped exactly
        // `threshold` times is still reported, one more and it is not.
        let at_boundary = diag.frequently_clipped(5);
        assert_eq!(at_boundary.len(), 1);
        assert_eq!(at_boundary[0], "often.clipped");
        assert!(diag.frequently_clipped(6).is_empty());
    }

    #[test]
    fn wall_clock_budget_expires() {
        // "Not yet expired" is checked against a budget far longer
        // than the test can take, so a slow or preempted runner
        // cannot make this assertion flake.
        let fresh = WallClockBudget::new(Duration::from_secs(3600));
        assert!(!fresh.expired());
        assert!(fresh.remaining() > Duration::ZERO);

        // `sleep` waits at least the requested time, so sleeping past
        // a short budget guarantees expiry.
        let short = WallClockBudget::new(Duration::from_millis(5));
        std::thread::sleep(Duration::from_millis(20));
        assert!(short.expired());
        assert_eq!(short.remaining(), Duration::ZERO);
    }
}