Skip to main content

blr_active/active_learning/
orchestration.rs

//! Algorithm 5 / T2.2: Calibration Orchestration Loop
//!
//! Implements a **synchronous** state machine for iterative sensor calibration
//! (per CLARIFY-3: sync model, no async/await). The host drives the loop by
//! calling `next_iteration()`, feeding in new measurements via
//! `add_measurement()`, and then calling `next_iteration()` again.
//!
//! **State machine logic per iteration:**
//! Before any work, the iteration returns MaxIterationsReached if
//! `max_iterations` iterations have already completed, or FitError if no
//! samples have been added. Otherwise:
//! 1. Fit BLR+ARD on current sample data — a failed fit → FitError.
//! 2. Assess precision (Algorithm 3) — if MetGoal → PrecisionMet.
//! 3. Check noise floor (Algorithm 4) — if at floor → NoiseFloorHit.
//! 4. Generate recommendations (Algorithm 2) — return RecommendNext.
//!
//! **History tracking:**
//! - `PrecisionRecord` pushed after each iteration whose fit succeeds.
//! - `SampleRecord` pushed for each measurement added.
//! - Full history available via `export_history_json()`.
18
19use crate::active_learning::acquisition::{recommend_next_samples, RecommendedSample};
20use crate::active_learning::noise_floor::{detect_noise_floor, NoiseFloorConfig};
21use crate::active_learning::precision::{assess_precision, PrecisionStatus};
22use crate::active_learning::variance::posterior_std_grid;
23use blr_core::{fit, ArdConfig, BLRError};
24
25// ─── Public types ─────────────────────────────────────────────────────────────
26
/// One measurement added to the calibration session.
///
/// Records are appended by `CalibrationSession::add_measurement` in the order
/// the measurements are supplied.
#[derive(Debug, Clone, PartialEq)]
pub struct SampleRecord {
    /// Raw sensor input value (e.g., B-field in mT).
    pub raw_input: f64,
    /// Measured sensor output value (e.g., voltage in V).
    pub measured_output: f64,
    /// Value of the session's iteration counter when this sample was added,
    /// i.e. the number of iterations completed at that moment.
    pub added_at_iteration: usize,
}
37
/// Precision assessment snapshot recorded after each calibration iteration
/// whose BLR+ARD fit succeeded.
#[derive(Debug, Clone, PartialEq)]
pub struct PrecisionRecord {
    /// Iteration index (0-based).
    pub iteration: usize,
    /// Number of samples in the model at this iteration.
    pub sample_count: usize,
    /// Mean posterior std over the evaluation grid.
    pub mean_posterior_std: f64,
    /// Max posterior std over the evaluation grid.
    pub max_posterior_std: f64,
    /// 95th-percentile posterior std (primary precision metric).
    pub percentile_95_std: f64,
    /// Whether the precision assessment reported the goal as met at this iteration.
    pub goal_met: bool,
}
54
/// Outcome of one call to `CalibrationSession::next_iteration()`.
#[derive(Debug, Clone)]
pub enum IterationOutcome {
    /// Precision goal met — calibration is complete.
    PrecisionMet,
    /// Noise floor detected — further measurements unlikely to help.
    /// Contains the noise floor std predicted by the noise-floor detector.
    NoiseFloorHit(f64),
    /// Calibration loop recommends collecting more samples at these locations.
    RecommendNext(Vec<RecommendedSample>),
    /// The session's iteration counter has reached `max_iterations`.
    /// Checked before anything else; no fit is performed and no history is recorded.
    MaxIterationsReached,
    /// The iteration could not run: either no samples have been added yet, or
    /// the BLR fit failed. Carries a human-readable message.
    FitError(String),
}
70
/// Configuration for a calibration session.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// User precision requirement (e.g., 0.01 V std). Passed, together with the
    /// posterior stds over the evaluation grid, to the precision assessment.
    pub target_precision: f64,
    /// Maximum number of calibration iterations. Only iterations whose fit
    /// succeeds advance the counter; once it reaches this value,
    /// `next_iteration` returns `MaxIterationsReached`.
    pub max_iterations: usize,
    /// Number of recommendations requested from the acquisition step per iteration.
    pub top_k: usize,
    /// Exclusion radius for acquisition, as an absolute distance in input units
    /// (like the `None` default). The value is passed unchanged to the
    /// acquisition step; it is NOT a fraction of the input range.
    /// If None, defaults to 7% of (input_max - input_min).
    pub exclusion_radius: Option<f64>,
    /// Grid resolution for posterior std evaluation over `input_range`
    /// (presumably the number of grid points — confirm against `posterior_std_grid`).
    pub grid_resolution: usize,
    /// Input range (min, max) for the evaluation grid; also used to derive the
    /// default exclusion radius.
    pub input_range: (f64, f64),
    /// Noise floor detection configuration.
    pub noise_floor_config: NoiseFloorConfig,
    /// BLR+ARD fitting configuration.
    pub ard_config: ArdConfig,
}
92
93impl Default for SessionConfig {
94    fn default() -> Self {
95        Self {
96            target_precision: 0.01,
97            max_iterations: 50,
98            top_k: 3,
99            exclusion_radius: None,
100            grid_resolution: 100,
101            input_range: (0.0, 1.0),
102            noise_floor_config: NoiseFloorConfig::default(),
103            ard_config: ArdConfig::default(),
104        }
105    }
106}
107
108/// An active calibration session tracking model, samples, and history.
109pub struct CalibrationSession {
110    /// Session configuration.
111    pub config: SessionConfig,
112    /// All measurements collected so far.
113    pub samples: Vec<SampleRecord>,
114    /// Precision snapshots from each completed iteration.
115    pub precision_history: Vec<PrecisionRecord>,
116    /// (n_samples, max_posterior_std) pairs for noise floor detection.
117    noise_floor_history: Vec<(usize, f64)>,
118    /// Feature function: maps a scalar input to a D-dimensional feature vector.
119    /// Must match the feature dimensionality used during fitting.
120    feature_fn: Box<dyn Fn(f64) -> Vec<f64>>,
121    /// Feature dimension D (must match feature_fn output length).
122    feature_dim: usize,
123    /// Current iteration count (0-based).
124    pub iteration: usize,
125}
126
127impl CalibrationSession {
128    /// Create a new calibration session.
129    ///
130    /// # Arguments
131    /// - `config`: session configuration
132    /// - `feature_fn`: maps scalar input → feature vector (must have length `feature_dim`)
133    /// - `feature_dim`: number of features D
134    pub fn new(
135        config: SessionConfig,
136        feature_fn: impl Fn(f64) -> Vec<f64> + 'static,
137        feature_dim: usize,
138    ) -> Self {
139        Self {
140            config,
141            samples: Vec::new(),
142            precision_history: Vec::new(),
143            noise_floor_history: Vec::new(),
144            feature_fn: Box::new(feature_fn),
145            feature_dim,
146            iteration: 0,
147        }
148    }
149
150    /// Add a new measurement from the user/sensor.
151    pub fn add_measurement(&mut self, raw_input: f64, measured_output: f64) {
152        self.samples.push(SampleRecord {
153            raw_input,
154            measured_output,
155            added_at_iteration: self.iteration,
156        });
157    }
158
159    /// Add multiple measurements at once.
160    pub fn add_measurements(&mut self, inputs: &[f64], outputs: &[f64]) {
161        for (&x, &y) in inputs.iter().zip(outputs.iter()) {
162            self.add_measurement(x, y);
163        }
164    }
165
166    /// Number of samples currently in the session.
167    pub fn sample_count(&self) -> usize {
168        self.samples.len()
169    }
170
171    /// Run one iteration of the calibration state machine.
172    ///
173    /// Returns an [`IterationOutcome`] indicating what to do next.
174    /// The caller should add new measurements (at the recommended locations)
175    /// before calling this again.
176    pub fn next_iteration(&mut self) -> IterationOutcome {
177        if self.iteration >= self.config.max_iterations {
178            return IterationOutcome::MaxIterationsReached;
179        }
180        if self.samples.is_empty() {
181            return IterationOutcome::FitError(
182                "No samples available — add measurements before iterating".into(),
183            );
184        }
185
186        // ── Step 1: Fit BLR+ARD on current data ──────────────────────────
187        let n = self.samples.len();
188        let d = self.feature_dim;
189
190        let mut phi = Vec::with_capacity(n * d);
191        let mut y = Vec::with_capacity(n);
192
193        for s in &self.samples {
194            let feats = (self.feature_fn)(s.raw_input);
195            // Pad or truncate to d features
196            let actual = feats.len().min(d);
197            phi.extend_from_slice(&feats[..actual]);
198            if actual < d {
199                phi.extend(std::iter::repeat_n(0.0, d - actual));
200            }
201            y.push(s.measured_output);
202        }
203
204        let fitted = match fit(&phi, &y, n, d, &self.config.ard_config) {
205            Ok(m) => m,
206            Err(BLRError::SingularMatrix) => {
207                return IterationOutcome::FitError(
208                    "BLR fit failed: singular matrix — add more diverse samples".into(),
209                );
210            }
211            Err(e) => return IterationOutcome::FitError(format!("BLR fit failed: {e}")),
212        };
213
214        // ── Step 2: Evaluate posterior std over evaluation grid ───────────
215        let (input_min, input_max) = self.config.input_range;
216        let (grid, stds) = posterior_std_grid(
217            fitted.beta,
218            &fitted.posterior.cov,
219            d,
220            input_min,
221            input_max,
222            self.config.grid_resolution,
223            self.feature_fn.as_ref(),
224        );
225
226        let max_std = stds.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
227        let mean_std = stds.iter().sum::<f64>() / stds.len() as f64;
228        let p95 = crate::active_learning::precision::percentile(&stds, 0.95);
229
230        // ── Step 3: Record precision snapshot ────────────────────────────
231        let assessment = assess_precision(&stds, self.config.target_precision);
232        let goal_met = assessment.status == PrecisionStatus::MetGoal;
233
234        self.precision_history.push(PrecisionRecord {
235            iteration: self.iteration,
236            sample_count: n,
237            mean_posterior_std: mean_std,
238            max_posterior_std: max_std,
239            percentile_95_std: p95,
240            goal_met,
241        });
242        self.noise_floor_history.push((n, max_std));
243        self.iteration += 1;
244
245        // ── Step 4: Check precision goal ──────────────────────────────────
246        if goal_met {
247            return IterationOutcome::PrecisionMet;
248        }
249
250        // ── Step 5: Check noise floor ─────────────────────────────────────
251        let nf_diag =
252            detect_noise_floor(&self.noise_floor_history, &self.config.noise_floor_config);
253        if nf_diag.likely_at_floor {
254            return IterationOutcome::NoiseFloorHit(nf_diag.predicted_noise_floor);
255        }
256
257        // ── Step 6: Generate recommendations (Algorithm 2) ────────────────
258        let existing: Vec<f64> = self.samples.iter().map(|s| s.raw_input).collect();
259        let radius = self
260            .config
261            .exclusion_radius
262            .unwrap_or((input_max - input_min) * 0.07);
263
264        let recs = recommend_next_samples(&grid, &stds, &existing, self.config.top_k, radius);
265        IterationOutcome::RecommendNext(recs)
266    }
267
268    /// Export the calibration history as a JSON string.
269    ///
270    /// Schema:
271    /// ```json
272    /// {
273    ///   "iteration_count": <usize>,
274    ///   "sample_count": <usize>,
275    ///   "target_precision": <f64>,
276    ///   "precision_history": [
277    ///     { "iteration": .., "sample_count": .., "mean_std": .., "max_std": ..,
278    ///       "p95_std": .., "goal_met": .. },
279    ///     ...
280    ///   ],
281    ///   "samples": [
282    ///     { "raw_input": .., "measured_output": .., "added_at_iteration": .. },
283    ///     ...
284    ///   ]
285    /// }
286    /// ```
287    pub fn export_history_json(&self) -> String {
288        let precision_entries: Vec<String> = self
289            .precision_history
290            .iter()
291            .map(|r| {
292                format!(
293                    r#"{{"iteration":{},"sample_count":{},"mean_std":{:.6e},"max_std":{:.6e},"p95_std":{:.6e},"goal_met":{}}}"#,
294                    r.iteration, r.sample_count, r.mean_posterior_std,
295                    r.max_posterior_std, r.percentile_95_std, r.goal_met
296                )
297            })
298            .collect();
299
300        let sample_entries: Vec<String> = self
301            .samples
302            .iter()
303            .map(|s| {
304                format!(
305                    r#"{{"raw_input":{:.6e},"measured_output":{:.6e},"added_at_iteration":{}}}"#,
306                    s.raw_input, s.measured_output, s.added_at_iteration
307                )
308            })
309            .collect();
310
311        format!(
312            r#"{{"iteration_count":{},"sample_count":{},"target_precision":{:.6e},"precision_history":[{}],"samples":[{}]}}"#,
313            self.iteration,
314            self.samples.len(),
315            self.config.target_precision,
316            precision_entries.join(","),
317            sample_entries.join(","),
318        )
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Simple linear feature function (1-D polynomial degree 1: [1, x])
    fn linear_feature(x: f64) -> Vec<f64> {
        vec![1.0, x]
    }

    /// Session over inputs in [0, 10] with linear features and the given target std.
    fn make_linear_session(target: f64) -> CalibrationSession {
        let config = SessionConfig {
            target_precision: target,
            max_iterations: 20,
            top_k: 3,
            grid_resolution: 50,
            input_range: (0.0, 10.0),
            ..Default::default()
        };
        CalibrationSession::new(config, linear_feature, 2)
    }

    /// Session without samples should return FitError, not panic
    #[test]
    fn test_empty_session_no_panic() {
        let mut session = make_linear_session(0.01);
        let outcome = session.next_iteration();
        assert!(matches!(outcome, IterationOutcome::FitError(_)));
        // A rejected iteration must not advance the counter or record history.
        assert_eq!(session.iteration, 0);
        assert!(session.precision_history.is_empty());
    }

    /// `add_measurements` pairs inputs with outputs and ignores the surplus
    /// elements of the longer slice
    #[test]
    fn test_add_measurements_pairs_and_truncates() {
        let mut session = make_linear_session(0.01);
        session.add_measurements(&[1.0, 2.0, 3.0], &[10.0, 20.0]);
        assert_eq!(session.sample_count(), 2);
        assert_eq!(session.samples[1].raw_input, 2.0);
        assert_eq!(session.samples[1].measured_output, 20.0);
        assert!(session.samples.iter().all(|s| s.added_at_iteration == 0));
    }

    /// With enough clean samples, precision goal should eventually be met
    #[test]
    fn test_calibration_reaches_goal() {
        let mut session = make_linear_session(0.5); // generous target

        // Add clean data: y = 2x + 1 with very small deterministic "noise".
        let xs: Vec<f64> = (0..20).map(|i| i as f64 * 0.5).collect();
        for &x in &xs {
            let y = 2.0 * x + 1.0 + ((x * 3.7).sin() * 0.001);
            session.add_measurement(x, y);
        }
        let outcome = session.next_iteration();
        // With 20 well-spread samples and generous target, expect PrecisionMet or RecommendNext
        assert!(
            matches!(
                outcome,
                IterationOutcome::PrecisionMet | IterationOutcome::RecommendNext(_)
            ),
            "unexpected outcome: {:?}",
            outcome
        );
    }

    /// Max iterations respected
    #[test]
    fn test_max_iterations_respected() {
        let config = SessionConfig {
            max_iterations: 2,
            target_precision: 0.001, // very tight — won't be met
            grid_resolution: 20,
            input_range: (0.0, 5.0),
            ..Default::default()
        };
        let mut session = CalibrationSession::new(config, linear_feature, 2);
        session.add_measurement(0.0, 0.0);
        session.add_measurement(5.0, 10.0);

        // Both permitted iterations must actually run (fit succeeds).
        let first = session.next_iteration();
        assert!(
            !matches!(
                first,
                IterationOutcome::MaxIterationsReached | IterationOutcome::FitError(_)
            ),
            "unexpected first outcome: {:?}",
            first
        );
        session.add_measurement(2.5, 5.0);
        let second = session.next_iteration();
        assert!(
            !matches!(
                second,
                IterationOutcome::MaxIterationsReached | IterationOutcome::FitError(_)
            ),
            "unexpected second outcome: {:?}",
            second
        );
        session.add_measurement(1.0, 2.0);
        assert_eq!(session.iteration, 2);
        assert_eq!(session.precision_history.len(), 2);

        let outcome = session.next_iteration();
        assert!(matches!(outcome, IterationOutcome::MaxIterationsReached));
        // The rejected call records no further history.
        assert_eq!(session.precision_history.len(), 2);
    }

    /// JSON export has required keys
    #[test]
    fn test_history_export_json_schema() {
        let mut session = make_linear_session(0.01);
        session.add_measurement(1.0, 3.0);
        session.add_measurement(5.0, 11.0);
        session.add_measurement(9.0, 19.0);
        let _ = session.next_iteration();

        let json = session.export_history_json();
        assert!(
            json.starts_with('{') && json.ends_with('}'),
            "export is not a JSON object: {json}"
        );
        assert!(
            json.contains("\"iteration_count\""),
            "missing iteration_count"
        );
        assert!(json.contains("\"sample_count\""), "missing sample_count");
        assert!(
            json.contains("\"target_precision\""),
            "missing target_precision"
        );
        assert!(
            json.contains("\"precision_history\""),
            "missing precision_history"
        );
        assert!(json.contains("\"samples\""), "missing samples");
    }
}