blr_active/active_learning/
orchestration.rs1use crate::active_learning::acquisition::{recommend_next_samples, RecommendedSample};
20use crate::active_learning::noise_floor::{detect_noise_floor, NoiseFloorConfig};
21use crate::active_learning::precision::{assess_precision, PrecisionStatus};
22use crate::active_learning::variance::posterior_std_grid;
23use blr_core::{fit, ArdConfig, BLRError};
24
/// One calibration measurement recorded in a session.
#[derive(Debug, Clone)]
pub struct SampleRecord {
    /// Input value at which the measurement was taken (before featurization).
    pub raw_input: f64,
    /// Observed output for `raw_input`; used as the regression target when fitting.
    pub measured_output: f64,
    /// Value of the session's iteration counter at the moment the sample was added.
    pub added_at_iteration: usize,
}
37
/// Posterior-uncertainty summary captured after one completed iteration.
#[derive(Debug, Clone)]
pub struct PrecisionRecord {
    /// Iteration index this summary belongs to.
    pub iteration: usize,
    /// Number of samples the model was fitted on for this iteration.
    pub sample_count: usize,
    /// Mean posterior standard deviation across the evaluation grid.
    pub mean_posterior_std: f64,
    /// Largest posterior standard deviation across the evaluation grid.
    pub max_posterior_std: f64,
    /// 95th-percentile posterior standard deviation across the evaluation grid.
    pub percentile_95_std: f64,
    /// Whether the precision assessment reported the target as met.
    pub goal_met: bool,
}
54
55#[derive(Debug, Clone)]
57pub enum IterationOutcome {
58 PrecisionMet,
60 NoiseFloorHit(f64),
63 RecommendNext(Vec<RecommendedSample>),
65 MaxIterationsReached,
67 FitError(String),
69}
70
71#[derive(Debug, Clone)]
73pub struct SessionConfig {
74 pub target_precision: f64,
76 pub max_iterations: usize,
78 pub top_k: usize,
80 pub exclusion_radius: Option<f64>,
83 pub grid_resolution: usize,
85 pub input_range: (f64, f64),
87 pub noise_floor_config: NoiseFloorConfig,
89 pub ard_config: ArdConfig,
91}
92
93impl Default for SessionConfig {
94 fn default() -> Self {
95 Self {
96 target_precision: 0.01,
97 max_iterations: 50,
98 top_k: 3,
99 exclusion_radius: None,
100 grid_resolution: 100,
101 input_range: (0.0, 1.0),
102 noise_floor_config: NoiseFloorConfig::default(),
103 ard_config: ArdConfig::default(),
104 }
105 }
106}
107
108pub struct CalibrationSession {
110 pub config: SessionConfig,
112 pub samples: Vec<SampleRecord>,
114 pub precision_history: Vec<PrecisionRecord>,
116 noise_floor_history: Vec<(usize, f64)>,
118 feature_fn: Box<dyn Fn(f64) -> Vec<f64>>,
121 feature_dim: usize,
123 pub iteration: usize,
125}
126
127impl CalibrationSession {
128 pub fn new(
135 config: SessionConfig,
136 feature_fn: impl Fn(f64) -> Vec<f64> + 'static,
137 feature_dim: usize,
138 ) -> Self {
139 Self {
140 config,
141 samples: Vec::new(),
142 precision_history: Vec::new(),
143 noise_floor_history: Vec::new(),
144 feature_fn: Box::new(feature_fn),
145 feature_dim,
146 iteration: 0,
147 }
148 }
149
150 pub fn add_measurement(&mut self, raw_input: f64, measured_output: f64) {
152 self.samples.push(SampleRecord {
153 raw_input,
154 measured_output,
155 added_at_iteration: self.iteration,
156 });
157 }
158
159 pub fn add_measurements(&mut self, inputs: &[f64], outputs: &[f64]) {
161 for (&x, &y) in inputs.iter().zip(outputs.iter()) {
162 self.add_measurement(x, y);
163 }
164 }
165
166 pub fn sample_count(&self) -> usize {
168 self.samples.len()
169 }
170
171 pub fn next_iteration(&mut self) -> IterationOutcome {
177 if self.iteration >= self.config.max_iterations {
178 return IterationOutcome::MaxIterationsReached;
179 }
180 if self.samples.is_empty() {
181 return IterationOutcome::FitError(
182 "No samples available — add measurements before iterating".into(),
183 );
184 }
185
186 let n = self.samples.len();
188 let d = self.feature_dim;
189
190 let mut phi = Vec::with_capacity(n * d);
191 let mut y = Vec::with_capacity(n);
192
193 for s in &self.samples {
194 let feats = (self.feature_fn)(s.raw_input);
195 let actual = feats.len().min(d);
197 phi.extend_from_slice(&feats[..actual]);
198 if actual < d {
199 phi.extend(std::iter::repeat_n(0.0, d - actual));
200 }
201 y.push(s.measured_output);
202 }
203
204 let fitted = match fit(&phi, &y, n, d, &self.config.ard_config) {
205 Ok(m) => m,
206 Err(BLRError::SingularMatrix) => {
207 return IterationOutcome::FitError(
208 "BLR fit failed: singular matrix — add more diverse samples".into(),
209 );
210 }
211 Err(e) => return IterationOutcome::FitError(format!("BLR fit failed: {e}")),
212 };
213
214 let (input_min, input_max) = self.config.input_range;
216 let (grid, stds) = posterior_std_grid(
217 fitted.beta,
218 &fitted.posterior.cov,
219 d,
220 input_min,
221 input_max,
222 self.config.grid_resolution,
223 self.feature_fn.as_ref(),
224 );
225
226 let max_std = stds.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
227 let mean_std = stds.iter().sum::<f64>() / stds.len() as f64;
228 let p95 = crate::active_learning::precision::percentile(&stds, 0.95);
229
230 let assessment = assess_precision(&stds, self.config.target_precision);
232 let goal_met = assessment.status == PrecisionStatus::MetGoal;
233
234 self.precision_history.push(PrecisionRecord {
235 iteration: self.iteration,
236 sample_count: n,
237 mean_posterior_std: mean_std,
238 max_posterior_std: max_std,
239 percentile_95_std: p95,
240 goal_met,
241 });
242 self.noise_floor_history.push((n, max_std));
243 self.iteration += 1;
244
245 if goal_met {
247 return IterationOutcome::PrecisionMet;
248 }
249
250 let nf_diag =
252 detect_noise_floor(&self.noise_floor_history, &self.config.noise_floor_config);
253 if nf_diag.likely_at_floor {
254 return IterationOutcome::NoiseFloorHit(nf_diag.predicted_noise_floor);
255 }
256
257 let existing: Vec<f64> = self.samples.iter().map(|s| s.raw_input).collect();
259 let radius = self
260 .config
261 .exclusion_radius
262 .unwrap_or((input_max - input_min) * 0.07);
263
264 let recs = recommend_next_samples(&grid, &stds, &existing, self.config.top_k, radius);
265 IterationOutcome::RecommendNext(recs)
266 }
267
268 pub fn export_history_json(&self) -> String {
288 let precision_entries: Vec<String> = self
289 .precision_history
290 .iter()
291 .map(|r| {
292 format!(
293 r#"{{"iteration":{},"sample_count":{},"mean_std":{:.6e},"max_std":{:.6e},"p95_std":{:.6e},"goal_met":{}}}"#,
294 r.iteration, r.sample_count, r.mean_posterior_std,
295 r.max_posterior_std, r.percentile_95_std, r.goal_met
296 )
297 })
298 .collect();
299
300 let sample_entries: Vec<String> = self
301 .samples
302 .iter()
303 .map(|s| {
304 format!(
305 r#"{{"raw_input":{:.6e},"measured_output":{:.6e},"added_at_iteration":{}}}"#,
306 s.raw_input, s.measured_output, s.added_at_iteration
307 )
308 })
309 .collect();
310
311 format!(
312 r#"{{"iteration_count":{},"sample_count":{},"target_precision":{:.6e},"precision_history":[{}],"samples":[{}]}}"#,
313 self.iteration,
314 self.samples.len(),
315 self.config.target_precision,
316 precision_entries.join(","),
317 sample_entries.join(","),
318 )
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Intercept-plus-slope feature map: `x -> [1, x]`.
    fn linear_feature(x: f64) -> Vec<f64> {
        vec![1.0, x]
    }

    /// Session over `[0, 10]` with the 2-column linear feature map and the
    /// given precision target.
    fn make_linear_session(target: f64) -> CalibrationSession {
        CalibrationSession::new(
            SessionConfig {
                target_precision: target,
                max_iterations: 20,
                top_k: 3,
                grid_resolution: 50,
                input_range: (0.0, 10.0),
                ..Default::default()
            },
            linear_feature,
            2,
        )
    }

    #[test]
    fn test_empty_session_no_panic() {
        // Iterating without any samples must report an error, not panic.
        let outcome = make_linear_session(0.01).next_iteration();
        assert!(matches!(outcome, IterationOutcome::FitError(_)));
    }

    #[test]
    fn test_calibration_reaches_goal() {
        let mut session = make_linear_session(0.5);
        for x in (0_u8..20).map(|i| f64::from(i) * 0.5) {
            // Nearly linear response with a tiny deterministic ripple.
            let y = 2.0 * x + 1.0 + (x * 3.7).sin() * 0.001;
            session.add_measurement(x, y);
        }

        let outcome = session.next_iteration();
        assert!(
            matches!(
                outcome,
                IterationOutcome::PrecisionMet | IterationOutcome::RecommendNext(_)
            ),
            "unexpected outcome: {outcome:?}"
        );
    }

    #[test]
    fn test_max_iterations_respected() {
        let config = SessionConfig {
            max_iterations: 2,
            target_precision: 0.001,
            grid_resolution: 20,
            input_range: (0.0, 5.0),
            ..Default::default()
        };
        let mut session = CalibrationSession::new(config, linear_feature, 2);
        for (x, y) in [(0.0, 0.0), (5.0, 10.0)] {
            session.add_measurement(x, y);
        }

        // Use up both permitted iterations, adding a fresh point after each.
        for (x, y) in [(2.5, 5.0), (1.0, 2.0)] {
            let _ = session.next_iteration();
            session.add_measurement(x, y);
        }

        let outcome = session.next_iteration();
        assert!(matches!(outcome, IterationOutcome::MaxIterationsReached));
    }

    #[test]
    fn test_history_export_json_schema() {
        let mut session = make_linear_session(0.01);
        for (x, y) in [(1.0, 3.0), (5.0, 11.0), (9.0, 19.0)] {
            session.add_measurement(x, y);
        }
        let _ = session.next_iteration();

        let json = session.export_history_json();
        for key in [
            "iteration_count",
            "sample_count",
            "target_precision",
            "precision_history",
            "samples",
        ] {
            assert!(json.contains(&format!("\"{key}\"")), "missing {key}");
        }
    }
}