datasynth_eval/calibration/
safety.rs1use std::collections::BTreeMap;
26use std::time::{Duration, Instant};
27
28use super::knob::KnobValue;
29use super::loop_runner::StepReport;
30
/// Configuration for spotting a knob whose value keeps flipping direction.
///
/// Used by [`OscillationDetector::check`], which looks at the most recent
/// non-zero moves of a single knob and counts how often the direction of
/// movement reverses.
#[derive(Debug, Clone)]
pub struct OscillationDetector {
    /// Number of most-recent non-zero knob moves that must be available
    /// before a verdict is given; with fewer moves `check` reports `false`.
    pub window: usize,
    /// Minimum number of direction reversals between adjacent moves inside
    /// the window for the knob to be reported as oscillating.
    pub min_alternations: usize,
}

impl Default for OscillationDetector {
    /// Five moves are inspected and at least three reversals are required.
    fn default() -> Self {
        Self {
            window: 5,
            min_alternations: 3,
        }
    }
}
55
56impl OscillationDetector {
57 pub fn check(&self, history: &[StepReport], knob_path: &str) -> bool {
63 let mut signs: Vec<i32> = Vec::new();
65 for step in history.iter().rev() {
66 if step.proposed_patch.is_none() {
67 continue;
68 }
69 let path = step.knob_values.keys().find(|p| *p == knob_path).cloned();
76 if path.is_none() {
77 continue;
78 }
79
80 let cur = step.knob_values.get(knob_path).and_then(value_as_f64);
83 let prev = find_prior_value(history, step.iter, knob_path);
84 if let (Some(c), Some(p)) = (cur, prev) {
85 let delta = c - p;
86 if delta.abs() > f64::EPSILON {
87 signs.push(if delta > 0.0 { 1 } else { -1 });
88 }
89 }
90 if signs.len() >= self.window {
92 break;
93 }
94 }
95
96 if signs.len() < self.window {
97 return false;
98 }
99 let alternations = signs.windows(2).filter(|w| w[0] != w[1]).count();
101 alternations >= self.min_alternations
102 }
103}
104
105fn find_prior_value(history: &[StepReport], iter: usize, knob_path: &str) -> Option<f64> {
109 for step in history.iter().rev() {
110 if step.iter >= iter {
111 continue;
112 }
113 if let Some(v) = step.knob_values.get(knob_path) {
114 return value_as_f64(v);
115 }
116 }
117 None
118}
119
120fn value_as_f64(v: &KnobValue) -> Option<f64> {
121 Some(v.as_f64())
122}
123
124#[derive(Debug, Clone, Default)]
131pub struct KnobClipDiagnostics {
132 pub counts: BTreeMap<String, ClipCounts>,
134}
135
/// Counters for each clip outcome recorded against one knob.
///
/// All counters start at zero via `Default`.
#[derive(Debug, Clone, Default)]
pub struct ClipCounts {
    /// Number of `ClippedLow` outcomes recorded.
    pub low: usize,
    /// Number of `ClippedHigh` outcomes recorded.
    pub high: usize,
    /// Number of `InRange` outcomes recorded.
    pub in_range: usize,
    /// Number of `TypeMismatch` outcomes recorded.
    pub type_mismatch: usize,
}
143
144impl KnobClipDiagnostics {
145 pub fn record(&mut self, knob_path: &str, result: super::knob::KnobClipResult) {
147 let entry = self.counts.entry(knob_path.to_string()).or_default();
148 use super::knob::KnobClipResult::*;
149 match result {
150 InRange => entry.in_range += 1,
151 ClippedLow => entry.low += 1,
152 ClippedHigh => entry.high += 1,
153 TypeMismatch => entry.type_mismatch += 1,
154 }
155 }
156
157 pub fn frequently_clipped(&self, threshold: usize) -> Vec<&String> {
161 self.counts
162 .iter()
163 .filter(|(_, c)| c.low + c.high >= threshold)
164 .map(|(k, _)| k)
165 .collect()
166 }
167}
168
/// A wall-clock time allowance that starts counting when it is constructed.
///
/// Cloning copies the start instant, so a clone shares the same deadline.
#[derive(Debug, Clone)]
pub struct WallClockBudget {
    /// Moment captured in [`WallClockBudget::new`]; elapsed time is measured from it.
    start: Instant,
    /// Total time allowed after `start`.
    budget: Duration,
}

impl WallClockBudget {
    /// Starts the clock now with an allowance of `budget`.
    pub fn new(budget: Duration) -> Self {
        let start = Instant::now();
        Self { start, budget }
    }

    /// Returns `true` once the time elapsed since construction has reached
    /// or passed the budget.
    pub fn expired(&self) -> bool {
        self.budget <= self.start.elapsed()
    }

    /// Returns the time left in the budget, or `Duration::ZERO` once it is
    /// used up.
    pub fn remaining(&self) -> Duration {
        self.budget
            .checked_sub(self.start.elapsed())
            .unwrap_or(Duration::ZERO)
    }
}
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::calibration::knob::{KnobClipResult, KnobValue};
    use crate::calibration::loop_runner::{ProposedPatch, StepOutcome, StepReport};

    /// Builds a step for iteration `iter` that both proposed and recorded
    /// `value` for `knob_path`.
    fn make_step_with_knob(iter: usize, knob_path: &str, value: f64) -> StepReport {
        let patch = ProposedPatch {
            knob_index: 0,
            proposed_value: KnobValue::F64(value),
            rationale: "test".into(),
        };
        let knob_values: BTreeMap<_, _> =
            std::iter::once((knob_path.to_string(), KnobValue::F64(value))).collect();

        StepReport {
            iter,
            loss_before_mean: 0.0,
            loss_before_std: 0.0,
            proposed_patch: Some(patch),
            loss_after_mean: Some(0.0),
            loss_after_std: Some(0.0),
            knob_values,
            outcome: StepOutcome::Improved,
        }
    }

    /// Six steps bouncing between 0.05 and 0.07 give five alternating moves.
    #[test]
    fn oscillation_detector_flags_alternating_deltas() {
        let history: Vec<StepReport> = (0..6)
            .map(|i| make_step_with_knob(i, "k", if i % 2 == 0 { 0.05 } else { 0.07 }))
            .collect();

        let flagged = OscillationDetector::default().check(&history, "k");
        assert!(
            flagged,
            "alternating deltas across 5 same-knob steps should flag"
        );
    }

    /// A steadily increasing knob never reverses direction.
    #[test]
    fn oscillation_detector_quiet_on_monotonic_progress() {
        let mut history = Vec::new();
        for i in 0..6 {
            history.push(make_step_with_knob(i, "k", 0.05 + 0.01 * (i as f64)));
        }

        let flagged = OscillationDetector::default().check(&history, "k");
        assert!(!flagged, "monotonic walk should not flag oscillation");
    }

    /// Three steps yield only two moves, short of the default window of five.
    #[test]
    fn oscillation_detector_needs_full_window() {
        let history: Vec<StepReport> = [0.05, 0.07, 0.05]
            .iter()
            .enumerate()
            .map(|(i, &value)| make_step_with_knob(i, "k", value))
            .collect();

        let detector = OscillationDetector::default();
        assert!(!detector.check(&history, "k"));
    }

    /// Each outcome lands in its own counter, per knob.
    #[test]
    fn knob_clip_diagnostics_count_each_result() {
        let mut diag = KnobClipDiagnostics::default();
        let fraud_outcomes = vec![
            KnobClipResult::InRange,
            KnobClipResult::InRange,
            KnobClipResult::ClippedHigh,
            KnobClipResult::ClippedLow,
        ];
        for outcome in fraud_outcomes {
            diag.record("fraud.fraud_rate", outcome);
        }
        diag.record("pool.size", KnobClipResult::TypeMismatch);

        let fraud = diag.counts.get("fraud.fraud_rate").unwrap();
        assert_eq!(fraud.in_range, 2);
        assert_eq!(fraud.high, 1);
        assert_eq!(fraud.low, 1);

        let pool = diag.counts.get("pool.size").unwrap();
        assert_eq!(pool.type_mismatch, 1);
    }

    /// Only knobs with at least `threshold` low/high clips are reported.
    #[test]
    fn frequently_clipped_filters_by_threshold() {
        let mut diag = KnobClipDiagnostics::default();
        (0..5).for_each(|_| diag.record("often.clipped", KnobClipResult::ClippedHigh));
        (0..2).for_each(|_| diag.record("rarely.clipped", KnobClipResult::ClippedLow));
        diag.record("never.clipped", KnobClipResult::InRange);

        let flagged = diag.frequently_clipped(3);
        assert_eq!(flagged.len(), 1);
        assert_eq!(flagged.first().map(|path| path.as_str()), Some("often.clipped"));
    }

    /// A 50 ms budget is live at creation and spent after an 80 ms sleep.
    #[test]
    fn wall_clock_budget_expires() {
        let limit = Duration::from_millis(50);
        let budget = WallClockBudget::new(limit);
        assert!(!budget.expired());

        std::thread::sleep(Duration::from_millis(80));

        assert!(budget.expired());
        assert_eq!(budget.remaining(), Duration::ZERO);
    }
}