Skip to main content

blr_active/active_learning/
acquisition.rs

//! Algorithm 2: Variance-Maximizing Acquisition Function
//!
//! Identifies the highest-uncertainty regions in the input space, excluding
//! regions already sampled. Returns top-K recommendations for the user.
//!
//! **Design decision:** Exclusion radius prevents redundant sampling. Points
//! within `exclusion_radius` of any existing sample are filtered before ranking.
//! If fewer than K candidates remain, returns what is available (no panic).
9
/// One suggested location to measure next, together with the model's current
/// uncertainty at that location.
#[derive(Debug, Clone, PartialEq)]
pub struct RecommendedSample {
    /// Input-space coordinate at which the next sample should be taken.
    pub input_value: f64,
    /// Posterior standard deviation reported for this location.
    pub expected_std: f64,
    /// 1-based position in the recommendation list; rank 1 carries the
    /// largest uncertainty.
    pub rank: usize,
}
20
21/// Find the top-K highest-variance grid points, excluding near-existing samples.
22///
23/// # Algorithm
24/// 1. Pair each grid point with its posterior std.
25/// 2. Filter out points within `exclusion_radius` of any existing sample.
26/// 3. Sort the remaining points by posterior std (descending).
27/// 4. Return the top `k` (or fewer if not enough candidates remain).
28///
29/// # Arguments
30/// - `grid`: input values of grid points
31/// - `posterior_std`: posterior standard deviations at each grid point (same length as `grid`)
32/// - `existing_samples`: locations of previously collected measurements
33/// - `k`: number of recommendations to return
34/// - `exclusion_radius`: minimum distance from an existing sample to be eligible
35///
36/// # Edge cases
37/// - If all grid points are excluded, returns the single highest-variance point
38///   regardless of exclusion (sensible fallback — never return nothing).
39/// - If k > available candidates, returns all available candidates.
40pub fn recommend_next_samples(
41    grid: &[f64],
42    posterior_std: &[f64],
43    existing_samples: &[f64],
44    k: usize,
45    exclusion_radius: f64,
46) -> Vec<RecommendedSample> {
47    debug_assert_eq!(
48        grid.len(),
49        posterior_std.len(),
50        "grid and posterior_std must have equal length"
51    );
52
53    if grid.is_empty() || k == 0 {
54        return Vec::new();
55    }
56
57    // Build (input_value, std) pairs sorted by std descending (stable sort → deterministic)
58    let mut scored: Vec<(f64, f64)> = grid
59        .iter()
60        .copied()
61        .zip(posterior_std.iter().copied())
62        .collect();
63    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
64
65    // Filter out points too close to existing samples
66    let mut candidates: Vec<(f64, f64)> = scored
67        .iter()
68        .filter(|(x, _)| {
69            existing_samples
70                .iter()
71                .all(|s| (x - s).abs() > exclusion_radius)
72        })
73        .copied()
74        .collect();
75
76    // Fallback: if all excluded, use the sorted list without exclusion filter
77    if candidates.is_empty() {
78        eprintln!(
79            "[active_learning::acquisition] Warning: all {} grid points excluded by \
80             exclusion_radius={:.4}; falling back to highest-variance point.",
81            grid.len(),
82            exclusion_radius
83        );
84        candidates = scored.clone();
85    } else if candidates.len() < k {
86        eprintln!(
87            "[active_learning::acquisition] Warning: only {} candidates remain after \
88             exclusion (requested k={}).",
89            candidates.len(),
90            k
91        );
92    }
93
94    candidates
95        .into_iter()
96        .take(k)
97        .enumerate()
98        .map(|(idx, (input_value, expected_std))| RecommendedSample {
99            input_value,
100            expected_std,
101            rank: idx + 1,
102        })
103        .collect()
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the grid `0, 1, …, n-1` together with a synthetic std profile
    /// `1 - (x - 5)^2 / 10`, an inverted parabola whose maximum sits at x = 5.
    fn make_grid(n: usize) -> (Vec<f64>, Vec<f64>) {
        let mut grid = Vec::with_capacity(n);
        let mut stds = Vec::with_capacity(n);
        for i in 0..n {
            let x = i as f64;
            grid.push(x);
            stds.push(-(x - 5.0).powi(2) / 10.0 + 1.0);
        }
        (grid, stds)
    }

    /// With no exclusion, the three returned points must come back ordered by
    /// non-increasing std.
    #[test]
    fn test_acquisition_maxes_identified() {
        let (grid, stds) = make_grid(10);
        let recs = recommend_next_samples(&grid, &stds, &[], 3, 0.0);
        assert_eq!(recs.len(), 3);
        // Every adjacent pair must be in descending order.
        for pair in recs.windows(2) {
            assert!(
                pair[0].expected_std >= pair[1].expected_std,
                "recommendations must be sorted descending"
            );
        }
    }

    /// No recommendation may lie within the exclusion radius of an existing sample.
    #[test]
    fn test_exclusion_radius_respected() {
        let grid = [0.0, 1.0, 2.0, 3.0, 4.0];
        let stds = [0.9, 0.8, 0.7, 0.6, 0.5];
        // A sample at 0.5 with radius 1.0 rules out the grid points 0.0 and 1.0.
        let existing = [0.5];
        let recs = recommend_next_samples(&grid, &stds, &existing, 5, 1.0);
        for rec in &recs {
            for sample in &existing {
                assert!(
                    (rec.input_value - sample).abs() > 1.0,
                    "recommended point {:.2} is within exclusion radius of {:.2}",
                    rec.input_value,
                    sample
                );
            }
        }
    }

    /// Two identical calls must agree, and ranks must count up from 1.
    #[test]
    fn test_ranking_deterministic() {
        let (grid, stds) = make_grid(10);
        let run = || recommend_next_samples(&grid, &stds, &[], 5, 0.0);
        let first = run();
        let second = run();
        for (a, b) in first.iter().zip(second.iter()) {
            assert_eq!(a.input_value, b.input_value, "output not deterministic");
        }
        for (position, rec) in first.iter().enumerate() {
            assert_eq!(rec.rank, position + 1, "rank must be 1-indexed sequential");
        }
    }

    /// If the exclusion radius swallows the whole grid, a fallback is still
    /// returned and it starts with the highest-std point.
    #[test]
    fn test_edge_case_all_excluded_no_panic() {
        let grid = [1.0, 2.0, 3.0];
        let stds = [0.5, 0.9, 0.3];
        // Radius 5 around 2.0 covers every grid point.
        let existing = [2.0];
        let recs = recommend_next_samples(&grid, &stds, &existing, 2, 5.0);
        assert!(
            !recs.is_empty(),
            "must return fallback even when all excluded"
        );
        assert_eq!(
            recs[0].input_value, 2.0,
            "fallback should be highest-variance point"
        );
    }

    /// Asking for zero recommendations yields an empty vec.
    #[test]
    fn test_k_zero() {
        let grid = [1.0, 2.0];
        let stds = [0.5, 0.9];
        let recs = recommend_next_samples(&grid, &stds, &[], 0, 0.0);
        assert!(recs.is_empty());
    }

    /// An empty grid yields an empty vec.
    #[test]
    fn test_empty_grid() {
        let recs = recommend_next_samples(&[], &[], &[], 3, 0.1);
        assert!(recs.is_empty());
    }
}
203}