1use ndarray::Array2;
36
37use super::object_store::designed_sampling_mandatory;
38use super::shard_reader::CorpusRowSource;
39use crate::inference::harvest::TieredHarvest;
40use gam_solve::row_sampling_measure::{MeasureProvenance, RowSamplingMeasure};
41
/// Row budget used for designed subsampling once [`designed_sampling_mandatory`]
/// reports that a corpus is too large to collect in full (two million rows).
pub const DESIGNED_SAMPLE_DEFAULT_BUDGET_ROWS: usize = 2 * 1_000_000;
49
50pub fn auto_designed_budget(total_rows: u64) -> usize {
55 if designed_sampling_mandatory(total_rows) {
56 DESIGNED_SAMPLE_DEFAULT_BUDGET_ROWS
57 } else {
58 total_rows as usize
59 }
60}
61
62#[derive(Debug, Clone)]
65pub struct DesignedCorpusTarget {
66 pub target: Array2<f64>,
69 pub row_ids: Vec<u64>,
73 pub likelihood_weights: Vec<f64>,
78 pub provenance: MeasureProvenance,
80 pub corpus_rows: u64,
82}
83
84impl DesignedCorpusTarget {
85 pub fn len(&self) -> usize {
87 self.row_ids.len()
88 }
89
90 pub fn is_empty(&self) -> bool {
91 self.row_ids.is_empty()
92 }
93
94 pub fn is_designed_subsample(&self) -> bool {
97 (self.len() as u64) < self.corpus_rows
98 }
99}
100
101pub fn collect_designed_target(
109 source: &mut dyn CorpusRowSource,
110 measure: Option<&RowSamplingMeasure>,
111 budget: usize,
112 seed: u64,
113) -> Result<DesignedCorpusTarget, String> {
114 let corpus_rows = source.total_rows();
115 let p = source.width();
116 let n = usize::try_from(corpus_rows)
117 .map_err(|_| "collect_designed_target: corpus row count exceeds usize".to_string())?;
118 let uniform;
119 let measure = match measure {
120 Some(m) => {
121 if m.n_rows() != n {
122 return Err(format!(
123 "collect_designed_target: measure covers {} rows but the corpus has {n}",
124 m.n_rows()
125 ));
126 }
127 m
128 }
129 None => {
130 uniform = RowSamplingMeasure::uniform(n);
131 &uniform
132 }
133 };
134 let sample = measure.designed_subsample(budget, seed);
135 let n_sel = sample.rows.len();
136 let mut target = Array2::<f64>::zeros((n_sel, p));
137 let mut row_ids = Vec::with_capacity(n_sel);
138
139 source.reset();
140 let mut next_sel = 0usize;
143 while next_sel < n_sel {
144 let Some(batch) = source
145 .next_batch()
146 .map_err(|e| format!("collect_designed_target: shard read failed: {e}"))?
147 else {
148 break;
149 };
150 for (k, &rid) in batch.row_ids.iter().enumerate() {
151 if next_sel >= n_sel {
152 break;
153 }
154 if rid == sample.rows[next_sel] as u64 {
155 target.row_mut(next_sel).assign(&batch.rows.row(k));
156 row_ids.push(rid);
157 next_sel += 1;
158 }
159 }
160 }
161 if next_sel != n_sel {
162 return Err(format!(
163 "collect_designed_target: stream ended after matching {next_sel} of {n_sel} \
164 designed rows (corpus declared {corpus_rows} rows)"
165 ));
166 }
167 Ok(DesignedCorpusTarget {
168 target,
169 row_ids,
170 likelihood_weights: sample.likelihood_weights,
171 provenance: sample.provenance,
172 corpus_rows,
173 })
174}
175
176pub fn collect_designed_target_auto(
180 source: &mut dyn CorpusRowSource,
181 seed: u64,
182) -> Result<DesignedCorpusTarget, String> {
183 let budget = auto_designed_budget(source.total_rows());
184 collect_designed_target(source, None, budget, seed)
185}
186
187pub fn collect_designed_target_from_harvest(
192 source: &mut dyn CorpusRowSource,
193 harvest: &TieredHarvest,
194 budget: usize,
195 seed: u64,
196) -> Result<DesignedCorpusTarget, String> {
197 let measure = harvest.corpus_measure();
198 collect_designed_target(source, Some(&measure), budget, seed)
199}
200
#[cfg(test)]
mod tests {
    use super::super::shard_reader::{MmapShardSource, encode_shard_bytes};
    use super::*;
    use ndarray::Array2 as NdArray2;
    use std::io::Write;
    use std::path::PathBuf;

    /// Deterministic, irregular cell values derived from a sin-fract hash of
    /// `(i, j)`. This keeps fixtures reproducible without an RNG dependency.
    fn planted_rows(n: usize, p: usize) -> NdArray2<f64> {
        NdArray2::from_shape_fn((n, p), |(i, j)| {
            let x = (i as f64 + 1.0) * 0.7390851 + (j as f64 + 1.0) * 1.6180339;
            (x.sin() * 43_758.547).fract() * 2.0 - 1.0
        })
    }

    /// Temporary shard directory that is removed when dropped.
    ///
    /// Cleanup therefore also happens when an assertion panics and unwinds. Tests
    /// declare this guard before the shard source, so the source (and any
    /// mapping it holds) is dropped before the directory is deleted.
    struct ShardDir(PathBuf);

    impl Drop for ShardDir {
        fn drop(&mut self) {
            // Best effort: a cleanup failure must not mask the test's own outcome.
            std::fs::remove_dir_all(&self.0).ok();
        }
    }

    /// Writes `rows` as two shards into a fresh per-process temp directory
    /// keyed by `name`: `a.shard` holds `rows[..split_at]` and `b.shard` the rest.
    fn temp_shard_dir(name: &str, rows: &NdArray2<f64>, split_at: usize) -> ShardDir {
        let dir = std::env::temp_dir().join(format!(
            "gam-designed-target-test-{}-{}",
            std::process::id(),
            name
        ));
        // A recycled pid could leave shards from an earlier run here; start clean.
        std::fs::remove_dir_all(&dir).ok();
        std::fs::create_dir_all(&dir).expect("create dir");
        let guard = ShardDir(dir);
        let parts = [
            ("a.shard", rows.slice(ndarray::s![..split_at, ..])),
            ("b.shard", rows.slice(ndarray::s![split_at.., ..])),
        ];
        for (key, part) in parts {
            let bytes = encode_shard_bytes(part);
            let mut f = std::fs::File::create(guard.0.join(key)).expect("create shard");
            f.write_all(&bytes).expect("write shard");
            f.sync_all().expect("sync");
        }
        guard
    }

    #[test]
    fn full_budget_collects_every_row_bit_for_bit_with_unit_weights() {
        let n = 137;
        let p = 5;
        let rows = planted_rows(n, p);
        let dir = temp_shard_dir("full", &rows, 60);
        let mut src = MmapShardSource::open_dir(&dir.0).expect("open");
        let collected = collect_designed_target_auto(&mut src, 7).expect("collect");

        assert!(!collected.is_designed_subsample());
        assert_eq!(collected.row_ids, (0..n as u64).collect::<Vec<_>>());
        assert!(collected.likelihood_weights.iter().all(|&w| w == 1.0));
        // Expected values are the f32-rounded inputs, matching how the fixture
        // is compared. NOTE(review): presumably shards store f32.
        let stored = rows.mapv(|v| f64::from(v as f32));
        // `zip` below would silently truncate on a shape mismatch; pin it first.
        assert_eq!(collected.target.dim(), stored.dim());
        for (a, b) in collected.target.iter().zip(stored.iter()) {
            assert_eq!(a.to_bits(), b.to_bits());
        }
    }

    #[test]
    fn designed_budget_collects_exactly_the_designed_rows_with_their_weights() {
        let n = 200;
        let p = 3;
        let rows = planted_rows(n, p);
        let dir = temp_shard_dir("designed", &rows, 90);
        let mut src = MmapShardSource::open_dir(&dir.0).expect("open");

        let budget = 40usize;
        let seed = 17u64;
        let collected = collect_designed_target(&mut src, None, budget, seed).expect("collect");
        assert!(collected.is_designed_subsample());

        // Re-derive the same design independently and check it was honored.
        let sample = RowSamplingMeasure::uniform(n).designed_subsample(budget, seed);
        assert_eq!(
            collected.row_ids,
            sample.rows.iter().map(|&r| r as u64).collect::<Vec<_>>()
        );
        assert_eq!(collected.likelihood_weights, sample.likelihood_weights);

        let stored = rows.mapv(|v| f64::from(v as f32));
        assert_eq!(collected.target.dim(), (collected.row_ids.len(), p));
        for (k, &rid) in collected.row_ids.iter().enumerate() {
            for c in 0..p {
                assert_eq!(
                    collected.target[[k, c]].to_bits(),
                    stored[[rid as usize, c]].to_bits(),
                    "row {rid} col {c}"
                );
            }
        }

        // Collecting again from the same source must rewind and reproduce the
        // identical selection and values.
        let again = collect_designed_target(&mut src, None, budget, seed).expect("collect again");
        assert_eq!(again.row_ids, collected.row_ids);
        assert_eq!(again.target.dim(), collected.target.dim());
        for (a, b) in again.target.iter().zip(collected.target.iter()) {
            assert_eq!(a.to_bits(), b.to_bits());
        }
    }

    #[test]
    fn measure_dimension_mismatch_is_rejected() {
        let rows = planted_rows(20, 2);
        let dir = temp_shard_dir("mismatch", &rows, 10);
        let mut src = MmapShardSource::open_dir(&dir.0).expect("open");
        let wrong = RowSamplingMeasure::uniform(7);
        let err = collect_designed_target(&mut src, Some(&wrong), 5, 1)
            .expect_err("mismatched measure must be rejected");
        assert!(err.contains("covers 7 rows"), "got: {err}");
    }

    #[test]
    fn auto_budget_is_exact_below_threshold_and_bounded_above_it() {
        assert_eq!(auto_designed_budget(1_000), 1_000);
        assert_eq!(
            auto_designed_budget(99_999_999),
            99_999_999,
            "below the mandatory threshold the budget is the whole corpus"
        );
        assert_eq!(
            auto_designed_budget(100_000_000),
            DESIGNED_SAMPLE_DEFAULT_BUDGET_ROWS
        );
        assert_eq!(
            auto_designed_budget(u64::MAX),
            DESIGNED_SAMPLE_DEFAULT_BUDGET_ROWS
        );
    }
}