use crate::active_learning::acquisition::{recommend_next_samples, RecommendedSample};
use crate::active_learning::noise_floor::{detect_noise_floor, NoiseFloorConfig};
use crate::active_learning::precision::{assess_precision, PrecisionStatus};
use crate::active_learning::variance::posterior_std_grid;
use blr_core::{fit, ArdConfig, BLRError};
/// One calibration measurement held by a session.
///
/// Derives `PartialEq` so records can be compared directly (for example in
/// `assert_eq!`). Comparison is field-wise, so a record containing a NaN value
/// never compares equal, not even to itself.
#[derive(Debug, Clone, PartialEq)]
pub struct SampleRecord {
    /// Raw input value; the session passes it to its feature function when fitting.
    pub raw_input: f64,
    /// Measured output for `raw_input`; used as the regression target when fitting.
    pub measured_output: f64,
    /// Value of the session's iteration counter at the time the sample was added.
    pub added_at_iteration: usize,
}
/// Summary of the posterior standard deviations observed in one iteration.
///
/// Derives `PartialEq` so history entries can be compared directly (for example
/// in `assert_eq!`). Comparison is field-wise, so a record containing a NaN
/// value never compares equal, not even to itself.
#[derive(Debug, Clone, PartialEq)]
pub struct PrecisionRecord {
    /// Iteration counter value when this record was produced.
    pub iteration: usize,
    /// Number of samples that were used for the fit.
    pub sample_count: usize,
    /// Arithmetic mean of the posterior standard deviations over the grid.
    pub mean_posterior_std: f64,
    /// Largest posterior standard deviation over the grid.
    pub max_posterior_std: f64,
    /// 95th-percentile posterior standard deviation over the grid.
    pub percentile_95_std: f64,
    /// Whether the precision assessment reported the goal as met.
    pub goal_met: bool,
}
/// Result of a single `CalibrationSession::next_iteration` call.
#[derive(Debug, Clone)]
pub enum IterationOutcome {
/// The precision assessment for this iteration reported `PrecisionStatus::MetGoal`.
PrecisionMet,
/// Noise-floor detection reported `likely_at_floor`; the payload is its
/// `predicted_noise_floor` value.
NoiseFloorHit(f64),
/// Neither stop condition fired; the payload is the list returned by
/// `recommend_next_samples` for this iteration.
RecommendNext(Vec<RecommendedSample>),
/// The session's iteration counter has reached `SessionConfig::max_iterations`.
MaxIterationsReached,
/// The iteration could not be completed; the payload is a human-readable message.
/// Returned when no samples have been added or when the BLR fit returns an error.
FitError(String),
}
/// Tunable parameters for a `CalibrationSession`.
#[derive(Debug, Clone)]
pub struct SessionConfig {
/// Precision target handed to `assess_precision` together with the grid of
/// posterior standard deviations.
pub target_precision: f64,
/// `next_iteration` returns `MaxIterationsReached` once the session's iteration
/// counter is greater than or equal to this value.
pub max_iterations: usize,
/// Passed to `recommend_next_samples`; presumably the number of recommendations
/// to return (confirm against that function).
pub top_k: usize,
/// Exclusion radius passed to `recommend_next_samples`. When `None`, the session
/// uses 7% of the width of `input_range`.
pub exclusion_radius: Option<f64>,
/// Resolution passed to `posterior_std_grid` when evaluating the posterior.
pub grid_resolution: usize,
/// `(min, max)` bounds of the raw input over which the posterior is evaluated.
pub input_range: (f64, f64),
/// Settings forwarded to `detect_noise_floor`.
pub noise_floor_config: NoiseFloorConfig,
/// Settings forwarded to the BLR `fit` call.
pub ard_config: ArdConfig,
}
impl Default for SessionConfig {
fn default() -> Self {
Self {
target_precision: 0.01,
max_iterations: 50,
top_k: 3,
exclusion_radius: None,
grid_resolution: 100,
input_range: (0.0, 1.0),
noise_floor_config: NoiseFloorConfig::default(),
ard_config: ArdConfig::default(),
}
}
}
/// State of an iterative calibration run: the collected samples, the per-iteration
/// precision history, and the feature map used to fit the model.
pub struct CalibrationSession {
/// Parameters controlling fitting, assessment and recommendation.
pub config: SessionConfig,
/// Every measurement added so far, in insertion order.
pub samples: Vec<SampleRecord>,
/// One entry per successfully completed iteration.
pub precision_history: Vec<PrecisionRecord>,
// `(sample count, max posterior std)` pairs, one per completed iteration;
// this is the input handed to `detect_noise_floor`.
noise_floor_history: Vec<(usize, f64)>,
// Maps a raw input to its feature vector.
feature_fn: Box<dyn Fn(f64) -> Vec<f64>>,
// Number of feature columns per sample used when building the design matrix.
feature_dim: usize,
/// Number of iterations completed so far; starts at 0.
pub iteration: usize,
}
impl CalibrationSession {
pub fn new(
config: SessionConfig,
feature_fn: impl Fn(f64) -> Vec<f64> + 'static,
feature_dim: usize,
) -> Self {
Self {
config,
samples: Vec::new(),
precision_history: Vec::new(),
noise_floor_history: Vec::new(),
feature_fn: Box::new(feature_fn),
feature_dim,
iteration: 0,
}
}
/// Appends one measurement, tagging it with the current iteration counter.
///
/// No validation is performed on either value.
pub fn add_measurement(&mut self, raw_input: f64, measured_output: f64) {
    let added_at_iteration = self.iteration;
    let record = SampleRecord {
        raw_input,
        measured_output,
        added_at_iteration,
    };
    self.samples.push(record);
}
/// Appends a batch of measurements by pairing `inputs[i]` with `outputs[i]`.
///
/// The slices are zipped, so if their lengths differ the surplus elements of
/// the longer slice are ignored.
pub fn add_measurements(&mut self, inputs: &[f64], outputs: &[f64]) {
    let pairs = inputs.iter().copied().zip(outputs.iter().copied());
    for (raw_input, measured_output) in pairs {
        self.add_measurement(raw_input, measured_output);
    }
}
/// Returns how many measurements the session currently holds.
pub fn sample_count(&self) -> usize {
    let recorded = &self.samples;
    recorded.len()
}
/// Runs one fit → assess → recommend cycle over all samples recorded so far.
///
/// In order:
/// 1. Returns [`IterationOutcome::MaxIterationsReached`] when the iteration counter
///    has reached `config.max_iterations`, and [`IterationOutcome::FitError`] when
///    no samples have been added.
/// 2. Builds the design matrix sample by sample (each feature vector is truncated
///    or zero-padded to `feature_dim` columns) and fits the BLR model. A fit
///    failure is returned as [`IterationOutcome::FitError`].
/// 3. Evaluates the posterior standard deviation over `config.input_range`. An
///    empty grid is reported as [`IterationOutcome::FitError`] rather than being
///    summarised, because the mean of zero values is NaN.
/// 4. Records a [`PrecisionRecord`] and a noise-floor history entry, then
///    increments the iteration counter.
/// 5. Returns [`IterationOutcome::PrecisionMet`] if the goal is met, otherwise
///    [`IterationOutcome::NoiseFloorHit`] if the noise-floor detector fires,
///    otherwise [`IterationOutcome::RecommendNext`] with the recommended samples.
///
/// The histories and the iteration counter are only updated when step 4 is reached.
pub fn next_iteration(&mut self) -> IterationOutcome {
    if self.iteration >= self.config.max_iterations {
        return IterationOutcome::MaxIterationsReached;
    }
    if self.samples.is_empty() {
        return IterationOutcome::FitError(
            "No samples available — add measurements before iterating".into(),
        );
    }

    let n = self.samples.len();
    let d = self.feature_dim;

    // Flattened design matrix: `d` consecutive entries per sample.
    let mut phi = Vec::with_capacity(n * d);
    let mut y = Vec::with_capacity(n);
    for s in &self.samples {
        let mut feats = (self.feature_fn)(s.raw_input);
        // Force exactly `d` columns: drop extras, pad missing ones with zeros.
        feats.resize(d, 0.0);
        phi.append(&mut feats);
        y.push(s.measured_output);
    }

    let fitted = match fit(&phi, &y, n, d, &self.config.ard_config) {
        Ok(m) => m,
        Err(BLRError::SingularMatrix) => {
            return IterationOutcome::FitError(
                "BLR fit failed: singular matrix — add more diverse samples".into(),
            );
        }
        Err(e) => return IterationOutcome::FitError(format!("BLR fit failed: {e}")),
    };

    let (input_min, input_max) = self.config.input_range;
    let (grid, stds) = posterior_std_grid(
        fitted.beta,
        &fitted.posterior.cov,
        d,
        input_min,
        input_max,
        self.config.grid_resolution,
        self.feature_fn.as_ref(),
    );

    // Without any grid points the summary below would record a NaN mean
    // (0.0 / 0) and a -inf maximum, so stop here instead.
    if stds.is_empty() {
        return IterationOutcome::FitError(
            "Posterior grid is empty — check grid_resolution and input_range".into(),
        );
    }

    let max_std = stds.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mean_std = stds.iter().sum::<f64>() / stds.len() as f64;
    let p95 = crate::active_learning::precision::percentile(&stds, 0.95);
    let assessment = assess_precision(&stds, self.config.target_precision);
    let goal_met = assessment.status == PrecisionStatus::MetGoal;

    self.precision_history.push(PrecisionRecord {
        iteration: self.iteration,
        sample_count: n,
        mean_posterior_std: mean_std,
        max_posterior_std: max_std,
        percentile_95_std: p95,
        goal_met,
    });
    self.noise_floor_history.push((n, max_std));
    self.iteration += 1;

    if goal_met {
        return IterationOutcome::PrecisionMet;
    }

    let nf_diag = detect_noise_floor(&self.noise_floor_history, &self.config.noise_floor_config);
    if nf_diag.likely_at_floor {
        return IterationOutcome::NoiseFloorHit(nf_diag.predicted_noise_floor);
    }

    let existing: Vec<f64> = self.samples.iter().map(|s| s.raw_input).collect();
    // Default exclusion radius: 7% of the input range width.
    let radius = self
        .config
        .exclusion_radius
        .unwrap_or((input_max - input_min) * 0.07);
    let recs = recommend_next_samples(&grid, &stds, &existing, self.config.top_k, radius);
    IterationOutcome::RecommendNext(recs)
}
/// Serialises the session history to a compact JSON object.
///
/// The object has the keys `iteration_count`, `sample_count`, `target_precision`,
/// `precision_history` (array of per-iteration summaries) and `samples` (array of
/// recorded measurements). Floating-point values are written in scientific
/// notation with six fractional digits. Non-finite values (NaN, ±infinity) have
/// no JSON number representation and are written as `null`.
pub fn export_history_json(&self) -> String {
    // Formats an f64 as a JSON value: `{:.6e}` for finite numbers, `null` otherwise.
    fn json_number(value: f64) -> String {
        if value.is_finite() {
            format!("{value:.6e}")
        } else {
            "null".to_owned()
        }
    }

    let precision_entries: Vec<String> = self
        .precision_history
        .iter()
        .map(|r| {
            format!(
                r#"{{"iteration":{},"sample_count":{},"mean_std":{},"max_std":{},"p95_std":{},"goal_met":{}}}"#,
                r.iteration,
                r.sample_count,
                json_number(r.mean_posterior_std),
                json_number(r.max_posterior_std),
                json_number(r.percentile_95_std),
                r.goal_met
            )
        })
        .collect();

    let sample_entries: Vec<String> = self
        .samples
        .iter()
        .map(|s| {
            format!(
                r#"{{"raw_input":{},"measured_output":{},"added_at_iteration":{}}}"#,
                json_number(s.raw_input),
                json_number(s.measured_output),
                s.added_at_iteration
            )
        })
        .collect();

    format!(
        r#"{{"iteration_count":{},"sample_count":{},"target_precision":{},"precision_history":[{}],"samples":[{}]}}"#,
        self.iteration,
        self.samples.len(),
        json_number(self.config.target_precision),
        precision_entries.join(","),
        sample_entries.join(","),
    )
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feature map for a straight line: a constant term followed by the raw input.
    fn linear_feature(x: f64) -> Vec<f64> {
        vec![1.0, x]
    }

    /// Builds a two-feature session over the input range `[0, 10]` with the
    /// given precision target.
    fn make_linear_session(target: f64) -> CalibrationSession {
        CalibrationSession::new(
            SessionConfig {
                target_precision: target,
                max_iterations: 20,
                top_k: 3,
                grid_resolution: 50,
                input_range: (0.0, 10.0),
                ..Default::default()
            },
            linear_feature,
            2,
        )
    }

    /// Iterating without any samples must report an error instead of panicking.
    #[test]
    fn test_empty_session_no_panic() {
        let mut session = make_linear_session(0.01);
        let outcome = session.next_iteration();
        assert!(matches!(outcome, IterationOutcome::FitError(_)));
    }

    /// Twenty near-linear samples should either meet the goal or yield recommendations.
    #[test]
    fn test_calibration_reaches_goal() {
        let mut session = make_linear_session(0.5);
        for step in 0..20 {
            let x = step as f64 * 0.5;
            // Straight line with a tiny deterministic perturbation.
            let y = 2.0 * x + 1.0 + ((x * 3.7).sin() * 0.001);
            session.add_measurement(x, y);
        }

        let outcome = session.next_iteration();
        let acceptable = matches!(
            outcome,
            IterationOutcome::PrecisionMet | IterationOutcome::RecommendNext(_)
        );
        assert!(acceptable, "unexpected outcome: {:?}", outcome);
    }

    /// With `max_iterations` of 2, the third call is expected to be refused.
    #[test]
    fn test_max_iterations_respected() {
        let mut session = CalibrationSession::new(
            SessionConfig {
                max_iterations: 2,
                target_precision: 0.001,
                grid_resolution: 20,
                input_range: (0.0, 5.0),
                ..Default::default()
            },
            linear_feature,
            2,
        );

        session.add_measurement(0.0, 0.0);
        session.add_measurement(5.0, 10.0);
        let _ = session.next_iteration();

        session.add_measurement(2.5, 5.0);
        let _ = session.next_iteration();

        session.add_measurement(1.0, 2.0);
        let third = session.next_iteration();
        assert!(matches!(third, IterationOutcome::MaxIterationsReached));
    }

    /// The exported JSON must contain every top-level key.
    #[test]
    fn test_history_export_json_schema() {
        let mut session = make_linear_session(0.01);
        for (x, y) in [(1.0, 3.0), (5.0, 11.0), (9.0, 19.0)] {
            session.add_measurement(x, y);
        }
        let _ = session.next_iteration();

        let json = session.export_history_json();
        let required_keys = [
            "iteration_count",
            "sample_count",
            "target_precision",
            "precision_history",
            "samples",
        ];
        for key in required_keys {
            let quoted = format!("\"{key}\"");
            assert!(json.contains(&quoted), "missing {key}");
        }
    }
}