// File: blr_active/active_learning/acquisition.rs
/// One grid point proposed for the next round of sampling.
#[derive(Debug, Clone, PartialEq)]
pub struct RecommendedSample {
    /// Grid input value to evaluate next.
    pub input_value: f64,
    /// Posterior standard deviation associated with `input_value`.
    pub expected_std: f64,
    /// 1-based position in the recommendation list.
    pub rank: usize,
}

21pub fn recommend_next_samples(
41 grid: &[f64],
42 posterior_std: &[f64],
43 existing_samples: &[f64],
44 k: usize,
45 exclusion_radius: f64,
46) -> Vec<RecommendedSample> {
47 debug_assert_eq!(
48 grid.len(),
49 posterior_std.len(),
50 "grid and posterior_std must have equal length"
51 );
52
53 if grid.is_empty() || k == 0 {
54 return Vec::new();
55 }
56
57 let mut scored: Vec<(f64, f64)> = grid
59 .iter()
60 .copied()
61 .zip(posterior_std.iter().copied())
62 .collect();
63 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
64
65 let mut candidates: Vec<(f64, f64)> = scored
67 .iter()
68 .filter(|(x, _)| {
69 existing_samples
70 .iter()
71 .all(|s| (x - s).abs() > exclusion_radius)
72 })
73 .copied()
74 .collect();
75
76 if candidates.is_empty() {
78 eprintln!(
79 "[active_learning::acquisition] Warning: all {} grid points excluded by \
80 exclusion_radius={:.4}; falling back to highest-variance point.",
81 grid.len(),
82 exclusion_radius
83 );
84 candidates = scored.clone();
85 } else if candidates.len() < k {
86 eprintln!(
87 "[active_learning::acquisition] Warning: only {} candidates remain after \
88 exclusion (requested k={}).",
89 candidates.len(),
90 k
91 );
92 }
93
94 candidates
95 .into_iter()
96 .take(k)
97 .enumerate()
98 .map(|(idx, (input_value, expected_std))| RecommendedSample {
99 input_value,
100 expected_std,
101 rank: idx + 1,
102 })
103 .collect()
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the grid `0.0, 1.0, ..., (n - 1) as f64` with std
    /// `1 - (x - 5)^2 / 10`, a downward parabola peaking at `x = 5.0`.
    fn make_grid(n: usize) -> (Vec<f64>, Vec<f64>) {
        let grid: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let stds: Vec<f64> = grid
            .iter()
            .map(|&x| -(x - 5.0).powi(2) / 10.0 + 1.0)
            .collect();
        (grid, stds)
    }

    /// Recommended input values, in rank order.
    fn input_values(recs: &[RecommendedSample]) -> Vec<f64> {
        recs.iter().map(|r| r.input_value).collect()
    }

    #[test]
    fn test_acquisition_maxes_identified() {
        let (grid, stds) = make_grid(10);
        let recs = recommend_next_samples(&grid, &stds, &[], 3, 0.0);
        assert_eq!(recs.len(), 3);
        assert!(
            recs.windows(2)
                .all(|w| w[0].expected_std >= w[1].expected_std),
            "recommendations must be sorted descending"
        );
        // The peak (x = 5) comes first, then its equal-std neighbours in grid order.
        assert_eq!(input_values(&recs), vec![5.0, 4.0, 6.0]);
    }

    #[test]
    fn test_exclusion_radius_respected() {
        let grid = [0.0, 1.0, 2.0, 3.0, 4.0];
        let stds = [0.9, 0.8, 0.7, 0.6, 0.5];
        let existing = [0.5];
        let recs = recommend_next_samples(&grid, &stds, &existing, 5, 1.0);
        // 0.0 and 1.0 lie within 1.0 of 0.5; guard against a vacuous pass on
        // an empty result by pinning the survivors exactly.
        assert_eq!(input_values(&recs), vec![2.0, 3.0, 4.0]);
        for r in &recs {
            for s in &existing {
                assert!(
                    (r.input_value - s).abs() > 1.0,
                    "recommended point {:.2} is within exclusion radius of {:.2}",
                    r.input_value,
                    s
                );
            }
        }
    }

    #[test]
    fn test_ranking_deterministic() {
        let (grid, stds) = make_grid(10);
        let recs1 = recommend_next_samples(&grid, &stds, &[], 5, 0.0);
        let recs2 = recommend_next_samples(&grid, &stds, &[], 5, 0.0);
        assert_eq!(recs1.len(), 5);
        // Compare whole vectors so a length mismatch cannot slip through `zip`.
        assert_eq!(recs1, recs2, "output not deterministic");
        for (i, r) in recs1.iter().enumerate() {
            assert_eq!(r.rank, i + 1, "rank must be 1-indexed sequential");
        }
    }

    #[test]
    fn test_edge_case_all_excluded_no_panic() {
        let grid = [1.0, 2.0, 3.0];
        let stds = [0.5, 0.9, 0.3];
        let existing = [2.0];
        let recs = recommend_next_samples(&grid, &stds, &existing, 2, 5.0);
        assert!(
            !recs.is_empty(),
            "must return fallback even when all excluded"
        );
        assert_eq!(
            recs[0].input_value, 2.0,
            "fallback should be highest-variance point"
        );
    }

    #[test]
    fn test_k_zero() {
        let grid = [1.0, 2.0];
        let stds = [0.5, 0.9];
        let recs = recommend_next_samples(&grid, &stds, &[], 0, 0.0);
        assert!(recs.is_empty());
    }

    #[test]
    fn test_empty_grid() {
        let recs = recommend_next_samples(&[], &[], &[], 3, 0.1);
        assert!(recs.is_empty());
    }
}