/// A single grid point proposed for the next round of sampling.
///
/// Values of this type are produced by [`recommend_next_samples`].
#[derive(Debug, Clone, PartialEq)]
pub struct RecommendedSample {
    /// Grid coordinate proposed for sampling.
    pub input_value: f64,
    /// Posterior standard deviation reported for `input_value`.
    pub expected_std: f64,
    /// 1-based position in the returned list (`1` is the first recommendation).
    pub rank: usize,
}

/// Recommends up to `k` grid points to sample next, highest posterior
/// standard deviation first.
///
/// # Arguments
///
/// * `grid` - candidate input locations.
/// * `posterior_std` - posterior standard deviation at each point of `grid`
///   (must have the same length as `grid`).
/// * `existing_samples` - inputs that have already been sampled.
/// * `k` - maximum number of recommendations to return.
/// * `exclusion_radius` - a grid point is skipped when its distance to any
///   existing sample is `<= exclusion_radius`. `0.0` skips only exact
///   duplicates; a negative radius disables exclusion. A NaN distance also
///   counts as excluded.
///
/// # Returns
///
/// At most `k` samples with sequential 1-based `rank`s. Points with equal
/// std keep their grid order. Points whose std is NaN are ranked after all
/// other points.
///
/// If every grid point is excluded, a warning is written to stderr and the
/// ranking falls back to the full, unfiltered grid (still truncated to `k`).
/// If fewer than `k` candidates survive exclusion, a warning is written and
/// all survivors are returned. An empty `grid` or `k == 0` yields an empty
/// vector.
///
/// # Panics
///
/// In debug builds, panics if `grid` and `posterior_std` differ in length.
/// In release builds that check is skipped, and the extra tail of the longer
/// slice is silently ignored.
pub fn recommend_next_samples(
    grid: &[f64],
    posterior_std: &[f64],
    existing_samples: &[f64],
    k: usize,
    exclusion_radius: f64,
) -> Vec<RecommendedSample> {
    debug_assert_eq!(
        grid.len(),
        posterior_std.len(),
        "grid and posterior_std must have equal length"
    );
    if grid.is_empty() || k == 0 {
        return Vec::new();
    }

    let mut scored: Vec<(f64, f64)> = grid
        .iter()
        .copied()
        .zip(posterior_std.iter().copied())
        .collect();
    // Stable sort: ties keep grid order, so the output is deterministic.
    scored.sort_by(|a, b| cmp_std_descending_nan_last(a.1, b.1));

    // Written as `all(d > r)` rather than `!any(d <= r)` so that a NaN
    // distance counts as "too close" and the point is excluded.
    let mut candidates: Vec<(f64, f64)> = scored
        .iter()
        .copied()
        .filter(|&(x, _)| {
            existing_samples
                .iter()
                .all(|&s| (x - s).abs() > exclusion_radius)
        })
        .collect();

    if candidates.is_empty() {
        eprintln!(
            "[active_learning::acquisition] Warning: all {} grid points excluded by \
             exclusion_radius={:.4}; falling back to highest-variance point.",
            scored.len(),
            exclusion_radius
        );
        // `scored` is not needed afterwards, so move it instead of cloning.
        candidates = scored;
    } else if candidates.len() < k {
        eprintln!(
            "[active_learning::acquisition] Warning: only {} candidates remain after \
             exclusion (requested k={}).",
            candidates.len(),
            k
        );
    }

    candidates
        .into_iter()
        .take(k)
        .enumerate()
        .map(|(idx, (input_value, expected_std))| RecommendedSample {
            input_value,
            expected_std,
            rank: idx + 1,
        })
        .collect()
}

/// Orders std values from largest to smallest, with NaN placed last.
///
/// `sort_by` requires a total order. A comparator that reports NaN as equal
/// to everything is not transitive. That yields an unspecified order and
/// may panic on newer standard libraries.
fn cmp_std_descending_nan_last(a: f64, b: f64) -> std::cmp::Ordering {
    use std::cmp::Ordering;
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // Neither value is NaN, so `partial_cmp` always returns `Some`.
        (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the grid `0, 1, ..., n-1` with a std profile shaped like a
    /// downward parabola peaking at `x = 5` (std `1.0`). Far points get
    /// negative values, which is fine for ranking tests.
    fn make_grid(n: usize) -> (Vec<f64>, Vec<f64>) {
        let grid: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let stds: Vec<f64> = grid
            .iter()
            .map(|&x| -(x - 5.0).powi(2) / 10.0 + 1.0)
            .collect();
        (grid, stds)
    }

    /// Collects the recommended inputs in rank order.
    fn inputs(recs: &[RecommendedSample]) -> Vec<f64> {
        recs.iter().map(|r| r.input_value).collect()
    }

    #[test]
    fn test_acquisition_maxes_identified() {
        let (grid, stds) = make_grid(10);
        let recs = recommend_next_samples(&grid, &stds, &[], 3, 0.0);
        assert_eq!(recs.len(), 3);
        assert!(
            recs[0].expected_std >= recs[1].expected_std,
            "recommendations must be sorted descending"
        );
        assert!(
            recs[1].expected_std >= recs[2].expected_std,
            "recommendations must be sorted descending"
        );
        // The parabola peaks at x = 5; x = 4 and x = 6 tie for second place.
        assert_eq!(recs[0].input_value, 5.0, "peak must be ranked first");
        let mut runners_up = [recs[1].input_value, recs[2].input_value];
        runners_up.sort_by(f64::total_cmp);
        assert_eq!(runners_up, [4.0, 6.0]);
    }

    #[test]
    fn test_exclusion_radius_respected() {
        let grid = [0.0, 1.0, 2.0, 3.0, 4.0];
        let stds = [0.9, 0.8, 0.7, 0.6, 0.5];
        let existing = [0.5];
        let recs = recommend_next_samples(&grid, &stds, &existing, 5, 1.0);
        // Guard against a vacuous pass: the loop below checks nothing if empty.
        assert_eq!(inputs(&recs), vec![2.0, 3.0, 4.0]);
        for r in &recs {
            for s in &existing {
                assert!(
                    (r.input_value - s).abs() > 1.0,
                    "recommended point {:.2} is within exclusion radius of {:.2}",
                    r.input_value,
                    s
                );
            }
        }
    }

    #[test]
    fn test_ranking_deterministic() {
        let (grid, stds) = make_grid(10);
        let recs1 = recommend_next_samples(&grid, &stds, &[], 5, 0.0);
        let recs2 = recommend_next_samples(&grid, &stds, &[], 5, 0.0);
        // `zip` below would silently pass on two empty vectors.
        assert_eq!(recs1.len(), 5);
        for (r1, r2) in recs1.iter().zip(recs2.iter()) {
            assert_eq!(r1.input_value, r2.input_value, "output not deterministic");
        }
        for (i, r) in recs1.iter().enumerate() {
            assert_eq!(r.rank, i + 1, "rank must be 1-indexed sequential");
        }
    }

    #[test]
    fn test_edge_case_all_excluded_no_panic() {
        let grid = [1.0, 2.0, 3.0];
        let stds = [0.5, 0.9, 0.3];
        let existing = [2.0];
        let recs = recommend_next_samples(&grid, &stds, &existing, 2, 5.0);
        assert!(
            !recs.is_empty(),
            "must return fallback even when all excluded"
        );
        assert_eq!(
            recs[0].input_value, 2.0,
            "fallback should be highest-variance point"
        );
        assert_eq!(recs[0].rank, 1);
    }

    #[test]
    fn test_k_zero() {
        let grid = [1.0, 2.0];
        let stds = [0.5, 0.9];
        let recs = recommend_next_samples(&grid, &stds, &[], 0, 0.0);
        assert!(recs.is_empty());
    }

    #[test]
    fn test_empty_grid() {
        let recs = recommend_next_samples(&[], &[], &[], 3, 0.1);
        assert!(recs.is_empty());
    }
}