use ndarray::Array2;
use super::object_store::designed_sampling_mandatory;
use super::shard_reader::CorpusRowSource;
use crate::inference::harvest::TieredHarvest;
use gam_solve::row_sampling_measure::{MeasureProvenance, RowSamplingMeasure};
/// Row budget used for designed sampling when [`designed_sampling_mandatory`]
/// says the corpus must not be collected in full.
pub const DESIGNED_SAMPLE_DEFAULT_BUDGET_ROWS: usize = 2_000_000;

/// Choose a designed-sampling row budget for a corpus of `total_rows` rows.
///
/// If [`designed_sampling_mandatory`] returns `true` for `total_rows`, the budget
/// is [`DESIGNED_SAMPLE_DEFAULT_BUDGET_ROWS`]. Otherwise the budget is the whole
/// corpus, so every row is collected.
///
/// The whole-corpus branch converts `u64` to `usize` with a checked conversion.
/// It saturates to `usize::MAX` instead of silently truncating on targets where
/// `usize` is narrower than `u64`.
pub fn auto_designed_budget(total_rows: u64) -> usize {
    if designed_sampling_mandatory(total_rows) {
        DESIGNED_SAMPLE_DEFAULT_BUDGET_ROWS
    } else {
        usize::try_from(total_rows).unwrap_or(usize::MAX)
    }
}
/// Rows gathered from a corpus by a designed (possibly subsampled) row plan.
///
/// `target` row `i` holds the values of corpus row `row_ids[i]`.
#[derive(Debug, Clone)]
pub struct DesignedCorpusTarget {
    /// Collected row values, one matrix row per selected corpus row.
    pub target: Array2<f64>,
    /// Corpus row id for each row of `target`, index-aligned with it.
    pub row_ids: Vec<u64>,
    /// Per-row likelihood weights reported by the designed subsample.
    pub likelihood_weights: Vec<f64>,
    /// Provenance of the sampling measure the rows were drawn from.
    pub provenance: MeasureProvenance,
    /// Total row count declared by the corpus source.
    pub corpus_rows: u64,
}

impl DesignedCorpusTarget {
    /// Number of rows that were collected.
    pub fn len(&self) -> usize {
        let collected = self.row_ids.len();
        collected
    }

    /// `true` when no rows were collected.
    pub fn is_empty(&self) -> bool {
        self.row_ids.is_empty()
    }

    /// `true` when strictly fewer rows were collected than the corpus declares,
    /// i.e. this target is a subsample rather than the full corpus.
    pub fn is_designed_subsample(&self) -> bool {
        let collected = self.len() as u64;
        self.corpus_rows > collected
    }
}
/// Collect the rows chosen by a designed subsample of the corpus behind `source`.
///
/// The subsample is `measure.designed_subsample(budget, seed)`. When `measure` is
/// `None`, a `RowSamplingMeasure::uniform` measure over all corpus rows is used.
/// The source is rewound with `reset()` and batches are read until every selected
/// row has been matched or the stream ends.
///
/// Matching is strictly in order: a streamed row id is accepted only if it equals
/// the next not-yet-matched entry of the subsample. The selected ids must therefore
/// appear in the stream in the same order as the subsample lists them (presumably
/// ascending; TODO confirm against `designed_subsample`).
///
/// # Errors
///
/// Returns `Err` when:
/// - the corpus row count does not fit in `usize`;
/// - `measure` covers a different number of rows than the corpus declares;
/// - reading a batch from `source` fails;
/// - a selected row's width differs from `source.width()`;
/// - the stream ends before every selected row has been matched.
pub fn collect_designed_target(
    source: &mut dyn CorpusRowSource,
    measure: Option<&RowSamplingMeasure>,
    budget: usize,
    seed: u64,
) -> Result<DesignedCorpusTarget, String> {
    let corpus_rows = source.total_rows();
    let p = source.width();
    let n = usize::try_from(corpus_rows)
        .map_err(|_| "collect_designed_target: corpus row count exceeds usize".to_string())?;

    // Declared up front so the fallback measure outlives the borrow in `measure`.
    let uniform;
    let measure = match measure {
        Some(m) => {
            if m.n_rows() != n {
                return Err(format!(
                    "collect_designed_target: measure covers {} rows but the corpus has {n}",
                    m.n_rows()
                ));
            }
            m
        }
        None => {
            uniform = RowSamplingMeasure::uniform(n);
            &uniform
        }
    };

    let sample = measure.designed_subsample(budget, seed);
    let n_sel = sample.rows.len();
    let mut target = Array2::<f64>::zeros((n_sel, p));
    let mut row_ids = Vec::with_capacity(n_sel);

    source.reset();
    let mut next_sel = 0usize;
    while next_sel < n_sel {
        let Some(batch) = source
            .next_batch()
            .map_err(|e| format!("collect_designed_target: shard read failed: {e}"))?
        else {
            break;
        };
        for (k, &rid) in batch.row_ids.iter().enumerate() {
            if next_sel >= n_sel {
                break;
            }
            if rid != sample.rows[next_sel] as u64 {
                continue;
            }
            let src_row = batch.rows.row(k);
            // ndarray's `assign` broadcasts its right-hand side. A length-1 row
            // would be silently copied into every column, and any other mismatch
            // panics. Reject width mismatches explicitly instead.
            if src_row.len() != p {
                return Err(format!(
                    "collect_designed_target: row {rid} has {} columns but the corpus declares {p}",
                    src_row.len()
                ));
            }
            target.row_mut(next_sel).assign(&src_row);
            row_ids.push(rid);
            next_sel += 1;
        }
    }

    if next_sel != n_sel {
        return Err(format!(
            "collect_designed_target: stream ended after matching {next_sel} of {n_sel} \
             designed rows (corpus declared {corpus_rows} rows)"
        ));
    }

    Ok(DesignedCorpusTarget {
        target,
        row_ids,
        likelihood_weights: sample.likelihood_weights,
        provenance: sample.provenance,
        corpus_rows,
    })
}
/// Collect a designed target with the default (uniform) measure and an automatic budget.
///
/// The budget is [`auto_designed_budget`] of the source's declared row count.
/// Errors are those of [`collect_designed_target`].
pub fn collect_designed_target_auto(
    source: &mut dyn CorpusRowSource,
    seed: u64,
) -> Result<DesignedCorpusTarget, String> {
    let declared_rows = source.total_rows();
    collect_designed_target(source, None, auto_designed_budget(declared_rows), seed)
}
/// Collect a designed target whose sampling measure comes from a tiered harvest.
///
/// The measure is `harvest.corpus_measure()`. It must cover exactly as many rows
/// as `source` declares, otherwise [`collect_designed_target`] returns `Err`.
pub fn collect_designed_target_from_harvest(
    source: &mut dyn CorpusRowSource,
    harvest: &TieredHarvest,
    budget: usize,
    seed: u64,
) -> Result<DesignedCorpusTarget, String> {
    collect_designed_target(source, Some(&harvest.corpus_measure()), budget, seed)
}
// Unit tests for designed-target collection. They write small two-shard corpora
// to a temp directory, read them back through `MmapShardSource`, and check which
// rows, ids and weights `collect_designed_target*` returns.
#[cfg(test)]
mod tests {
use super::super::shard_reader::{MmapShardSource, encode_shard_bytes};
use super::*;
use ndarray::Array2 as NdArray2;
use std::io::Write;
use std::path::PathBuf;
/// Build an `n x p` matrix of deterministic, pseudo-random-looking values using a
/// sine-based hash of the (row, column) index, so every cell is reproducible.
fn planted_rows(n: usize, p: usize) -> NdArray2<f64> {
NdArray2::from_shape_fn((n, p), |(i, j)| {
let x = (i as f64 + 1.0) * 0.7390851 + (j as f64 + 1.0) * 1.6180339;
(x.sin() * 43_758.547).fract() * 2.0 - 1.0
})
}
/// Write `rows` as two shards in a per-process temp directory named after `name`.
/// `a.shard` holds rows `..split_at` and `b.shard` holds rows `split_at..`.
/// Returns the directory path.
fn temp_shard_dir(name: &str, rows: &NdArray2<f64>, split_at: usize) -> PathBuf {
let mut dir = std::env::temp_dir();
dir.push(format!(
"gam-designed-target-test-{}-{}",
std::process::id(),
name
));
std::fs::create_dir_all(&dir).expect("create dir");
let parts = [
("a.shard", rows.slice(ndarray::s![..split_at, ..])),
("b.shard", rows.slice(ndarray::s![split_at.., ..])),
];
for (key, part) in parts {
let bytes = encode_shard_bytes(part);
let mut f = std::fs::File::create(dir.join(key)).expect("create shard");
f.write_all(&bytes).expect("write shard");
f.sync_all().expect("sync");
}
dir
}
// NOTE(review): each test removes its temp dir only on the success path. A failing
// assertion panics before `remove_dir_all`, leaving the directory behind.
/// With a corpus below the mandatory-sampling threshold, the automatic budget
/// collects every row, in id order, with unit weights.
#[test]
fn full_budget_collects_every_row_bit_for_bit_with_unit_weights() {
let n = 137;
let p = 5;
let rows = planted_rows(n, p);
let dir = temp_shard_dir("full", &rows, 60);
let mut src = MmapShardSource::open_dir(&dir).expect("open");
let collected = collect_designed_target_auto(&mut src, 7).expect("collect");
assert!(!collected.is_designed_subsample());
assert_eq!(collected.row_ids, (0..n as u64).collect::<Vec<_>>());
assert!(collected.likelihood_weights.iter().all(|&w| w == 1.0));
// Expected values are the originals rounded through f32. This presumably
// mirrors the shard encoding's precision; confirm against `encode_shard_bytes`.
let stored = rows.mapv(|v| f64::from(v as f32));
for (a, b) in collected.target.iter().zip(stored.iter()) {
assert_eq!(a.to_bits(), b.to_bits());
}
std::fs::remove_dir_all(&dir).ok();
}
/// With an explicit budget smaller than the corpus, the collected ids and weights
/// equal an independently drawn subsample with the same budget and seed.
/// Each collected row must match the stored row for its id. A second collection
/// on the same source must be identical.
#[test]
fn designed_budget_collects_exactly_the_designed_rows_with_their_weights() {
let n = 200;
let p = 3;
let rows = planted_rows(n, p);
let dir = temp_shard_dir("designed", &rows, 90);
let mut src = MmapShardSource::open_dir(&dir).expect("open");
let budget = 40usize;
let seed = 17u64;
let collected = collect_designed_target(&mut src, None, budget, seed).expect("collect");
assert!(collected.is_designed_subsample());
let sample = RowSamplingMeasure::uniform(n).designed_subsample(budget, seed);
assert_eq!(
collected.row_ids,
sample.rows.iter().map(|&r| r as u64).collect::<Vec<_>>()
);
assert_eq!(collected.likelihood_weights, sample.likelihood_weights);
let stored = rows.mapv(|v| f64::from(v as f32));
for (k, &rid) in collected.row_ids.iter().enumerate() {
for c in 0..p {
assert_eq!(
collected.target[[k, c]].to_bits(),
stored[[rid as usize, c]].to_bits(),
"row {rid} col {c}"
);
}
}
// Reusing the already-consumed source exercises the `reset()` at the start of
// collection and checks that the result is deterministic for a fixed seed.
let again = collect_designed_target(&mut src, None, budget, seed).expect("collect again");
assert_eq!(again.row_ids, collected.row_ids);
for (a, b) in again.target.iter().zip(collected.target.iter()) {
assert_eq!(a.to_bits(), b.to_bits());
}
std::fs::remove_dir_all(&dir).ok();
}
/// A measure whose row count differs from the corpus is rejected with an error
/// that names the measure's row count.
#[test]
fn measure_dimension_mismatch_is_rejected() {
let rows = planted_rows(20, 2);
let dir = temp_shard_dir("mismatch", &rows, 10);
let mut src = MmapShardSource::open_dir(&dir).expect("open");
let wrong = RowSamplingMeasure::uniform(7);
let err = collect_designed_target(&mut src, Some(&wrong), 5, 1)
.expect_err("mismatched measure must be rejected");
assert!(err.contains("covers 7 rows"), "got: {err}");
std::fs::remove_dir_all(&dir).ok();
}
/// The automatic budget equals the corpus size below the mandatory-sampling
/// threshold and equals the default budget at or above it.
// NOTE(review): assumes `designed_sampling_mandatory` switches on at exactly
// 100_000_000 rows. Update these literals if that threshold changes.
#[test]
fn auto_budget_is_exact_below_threshold_and_bounded_above_it() {
assert_eq!(auto_designed_budget(1_000), 1_000);
assert_eq!(
auto_designed_budget(99_999_999),
99_999_999,
"below the mandatory threshold the budget is the whole corpus"
);
assert_eq!(
auto_designed_budget(100_000_000),
DESIGNED_SAMPLE_DEFAULT_BUDGET_ROWS
);
assert_eq!(
auto_designed_budget(u64::MAX),
DESIGNED_SAMPLE_DEFAULT_BUDGET_ROWS
);
}
}