use std::collections::BTreeSet;
/// Origin of the samples fed to [`evaluate`].
///
/// Every provenance is scored, but only [`DataProvenance::Measured`] passes the
/// provenance gate and can therefore back a released claim.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataProvenance {
    /// The only claimable provenance.
    Measured,
    /// Scored, never claimable.
    Synthetic,
    /// Scored, never claimable.
    Mock,
}

impl DataProvenance {
    /// Returns `true` for [`DataProvenance::Measured`] and `false` for every other variant.
    pub fn is_claimable(self) -> bool {
        self == Self::Measured
    }

    /// Lowercase label stored in a report's `provenance_tag`.
    pub fn tag(self) -> &'static str {
        match self {
            Self::Measured => "measured",
            Self::Synthetic => "synthetic",
            Self::Mock => "mock",
        }
    }
}
/// Claim text that `evaluate` stores in `BenchmarkReport::released_claim` whenever any
/// release gate fails, in place of the formatted metric summary.
pub const NO_CLAIM: &str = "research use only — not claimable (non-measured data, leaky split, or unmet thresholds)";
/// A presence flag plus a person count, used both as ground truth and as prediction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Occupancy {
    /// Presence flag; scored by the presence accuracy / F1 metrics.
    pub present: bool,
    /// Person count; scored by the count MAE / exact-match metrics.
    /// Nothing enforces agreement with `present`.
    pub person_count: u32,
}

impl Occupancy {
    /// Builds an `Occupancy` from its two fields as given, with no consistency check.
    pub fn new(present: bool, person_count: u32) -> Self {
        Occupancy {
            person_count,
            present,
        }
    }
}
/// One evaluated example: ground truth next to the model's prediction, tagged with the
/// subject and environment identifiers that `EvalSplit::validate` uses to detect leakage.
#[derive(Debug, Clone)]
pub struct LabeledSample {
    /// Subject identifier; a test sample's subject must not occur in the train split.
    pub subject_id: String,
    /// Environment identifier; a test sample's environment must not occur in the train split.
    pub environment_id: String,
    /// Ground-truth occupancy.
    pub truth: Occupancy,
    /// Predicted occupancy, compared against `truth`.
    pub predicted: Occupancy,
}
/// Train/test partition of a sample slice, expressed as indices into that slice.
#[derive(Debug, Clone)]
pub struct EvalSplit {
    /// Training indices. In this module they are only read for their subject and
    /// environment IDs, to check for leakage; they are never scored.
    pub train_idx: Vec<usize>,
    /// Held-out indices that `evaluate` scores.
    pub test_idx: Vec<usize>,
}
/// Reason an [`EvalSplit`] is rejected by [`EvalSplit::validate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SplitError {
    /// A test sample's subject also appears in the train split (carries that subject ID).
    SubjectLeakage(String),
    /// A test sample's environment also appears in the train split (carries that environment ID).
    EnvironmentLeakage(String),
    /// A train or test index is `>=` the number of samples (carries the offending index).
    IndexOutOfRange(usize),
    /// The test split contains no indices.
    EmptyTest,
}

impl std::fmt::Display for SplitError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::SubjectLeakage(id) => {
                write!(f, "subject `{id}` appears in both train and test splits")
            }
            Self::EnvironmentLeakage(id) => {
                write!(f, "environment `{id}` appears in both train and test splits")
            }
            Self::IndexOutOfRange(idx) => write!(f, "sample index {idx} is out of range"),
            Self::EmptyTest => f.write_str("test split is empty"),
        }
    }
}

// Lets callers propagate a rejected split with `?` into `Box<dyn Error>` / error enums.
impl std::error::Error for SplitError {}
impl EvalSplit {
    /// Checks that this split is a usable, leakage-free held-out evaluation of `samples`.
    ///
    /// The checks run in this order, and the first failure is returned:
    ///
    /// 1. the test split must be non-empty;
    /// 2. every train index, then every test index, must be `< samples.len()`;
    /// 3. walking the test indices in order, no test sample may share its subject
    ///    (checked first) or its environment with any training sample.
    ///
    /// # Errors
    ///
    /// - [`SplitError::EmptyTest`] if the test split is empty.
    /// - [`SplitError::IndexOutOfRange`] carrying the first bad index.
    /// - [`SplitError::SubjectLeakage`] or [`SplitError::EnvironmentLeakage`] carrying the
    ///   shared identifier.
    pub fn validate(&self, samples: &[LabeledSample]) -> Result<(), SplitError> {
        if self.test_idx.is_empty() {
            return Err(SplitError::EmptyTest);
        }

        let len = samples.len();
        let first_bad = self
            .train_idx
            .iter()
            .chain(self.test_idx.iter())
            .copied()
            .find(|&idx| idx >= len);
        if let Some(idx) = first_bad {
            return Err(SplitError::IndexOutOfRange(idx));
        }

        // Every index is now known to be in range, so direct indexing cannot panic.
        let mut seen_subjects = BTreeSet::new();
        let mut seen_envs = BTreeSet::new();
        for &idx in &self.train_idx {
            seen_subjects.insert(samples[idx].subject_id.as_str());
            seen_envs.insert(samples[idx].environment_id.as_str());
        }

        self.test_idx
            .iter()
            .map(|&idx| &samples[idx])
            .try_for_each(|held_out| {
                if seen_subjects.contains(held_out.subject_id.as_str()) {
                    Err(SplitError::SubjectLeakage(held_out.subject_id.clone()))
                } else if seen_envs.contains(held_out.environment_id.as_str()) {
                    Err(SplitError::EnvironmentLeakage(held_out.environment_id.clone()))
                } else {
                    Ok(())
                }
            })
    }
}
/// Thresholds and bootstrap settings that decide whether `evaluate` releases a claim.
#[derive(Debug, Clone, Copy)]
pub struct BenchmarkCriteria {
    /// Minimum presence F1, compared against the lower bound of the bootstrap
    /// confidence interval (not the point estimate).
    pub min_presence_f1: f64,
    /// Maximum mean absolute error of the predicted person count.
    pub max_count_mae: f64,
    /// Minimum number of scored test samples.
    pub min_test_samples: usize,
    /// Minimum fraction of test samples whose ground truth is `present`.
    pub min_positive_rate: f64,
    /// Number of bootstrap resamples for the F1 confidence interval.
    pub bootstrap_iters: usize,
    /// Seed of the deterministic bootstrap RNG.
    pub bootstrap_seed: u64,
}

impl Default for BenchmarkCriteria {
    /// Defaults:
    ///
    /// - F1 CI lower bound ≥ 0.9
    /// - count MAE ≤ 0.5
    /// - at least 30 test samples
    /// - at least 10 % positives
    /// - 1000 bootstrap resamples, seeded with 42
    fn default() -> Self {
        BenchmarkCriteria {
            bootstrap_seed: 42,
            bootstrap_iters: 1000,
            min_positive_rate: 0.1,
            min_test_samples: 30,
            max_count_mae: 0.5,
            min_presence_f1: 0.9,
        }
    }
}
/// Result of `evaluate`: every metric, every individual release gate, and the claim text.
#[derive(Debug, Clone, PartialEq)]
pub struct BenchmarkReport {
    /// `DataProvenance::tag()` of the evaluated data.
    pub provenance_tag: &'static str,
    /// Number of test samples actually scored.
    pub n_test: usize,
    /// Fraction of scored samples whose presence prediction is correct.
    pub presence_accuracy: f64,
    /// Presence F1 point estimate.
    pub presence_f1: f64,
    /// Bootstrap percentile interval `(lower, upper)` for presence F1.
    pub presence_f1_ci: (f64, f64),
    /// Fraction of scored samples whose predicted person count is exactly right.
    pub count_exact_match: f64,
    /// Mean absolute error of the predicted person count.
    pub count_mae: f64,
    /// Gate: the data provenance is claimable.
    pub provenance_pass: bool,
    /// Gate: `EvalSplit::validate` succeeded.
    pub split_pass: bool,
    /// Gate: the F1 CI lower bound meets the threshold.
    pub presence_pass: bool,
    /// Gate: the count MAE is within the threshold.
    pub count_pass: bool,
    /// Gate: enough test samples were scored.
    pub sample_size_pass: bool,
    /// Gate: the test set has enough positives and at least one negative.
    pub class_balance_pass: bool,
    /// Conjunction of all six gates.
    pub overall_pass: bool,
    /// Publishable summary, or the `NO_CLAIM` text when any gate fails.
    pub released_claim: String,
}

impl BenchmarkReport {
    /// Text that may be published for this run.
    ///
    /// As produced by `evaluate`, this is `NO_CLAIM` unless every gate passed.
    pub fn claim(&self) -> &str {
        self.released_claim.as_str()
    }
}
/// Release rule for a benchmark claim: every one of the six gates must pass.
///
/// A single `false` among the arguments denies the claim.
#[inline]
pub fn claim_allowed(
    provenance_pass: bool,
    split_pass: bool,
    sample_size_pass: bool,
    class_balance_pass: bool,
    presence_pass: bool,
    count_pass: bool,
) -> bool {
    let gates = [
        provenance_pass,
        split_pass,
        sample_size_pass,
        class_balance_pass,
        presence_pass,
        count_pass,
    ];
    gates.iter().all(|&gate| gate)
}
pub fn evaluate(
samples: &[LabeledSample],
provenance: DataProvenance,
split: &EvalSplit,
criteria: &BenchmarkCriteria,
) -> BenchmarkReport {
let split_pass = split.validate(samples).is_ok();
let test: Vec<&LabeledSample> = split
.test_idx
.iter()
.filter(|&&i| i < samples.len())
.map(|&i| &samples[i])
.collect();
let n_test = test.len();
let (mut tp, mut fp, mut tn, mut fn_) = (0u64, 0u64, 0u64, 0u64);
let mut count_abs_err_sum = 0.0;
let mut count_exact = 0u64;
let mut truth_present = 0u64;
for s in &test {
if s.truth.present {
truth_present += 1;
}
match (s.predicted.present, s.truth.present) {
(true, true) => tp += 1,
(true, false) => fp += 1,
(false, false) => tn += 1,
(false, true) => fn_ += 1,
}
count_abs_err_sum +=
(s.predicted.person_count as f64 - s.truth.person_count as f64).abs();
if s.predicted.person_count == s.truth.person_count {
count_exact += 1;
}
}
let presence_accuracy = if n_test > 0 {
(tp + tn) as f64 / n_test as f64
} else {
0.0
};
let presence_f1 = f1_from_confusion(tp, fp, fn_);
let count_mae = if n_test > 0 {
count_abs_err_sum / n_test as f64
} else {
f64::INFINITY
};
let count_exact_match = if n_test > 0 {
count_exact as f64 / n_test as f64
} else {
0.0
};
let presence_f1_ci = bootstrap_f1_ci(&test, criteria.bootstrap_iters, criteria.bootstrap_seed);
let provenance_pass = provenance.is_claimable();
let sample_size_pass = n_test >= criteria.min_test_samples;
let positive_rate = if n_test > 0 {
truth_present as f64 / n_test as f64
} else {
0.0
};
let class_balance_pass =
n_test > 0 && positive_rate >= criteria.min_positive_rate && truth_present < n_test as u64;
let presence_pass = presence_f1_ci.0 >= criteria.min_presence_f1;
let count_pass = count_mae <= criteria.max_count_mae;
let overall_pass = claim_allowed(
provenance_pass,
split_pass,
sample_size_pass,
class_balance_pass,
presence_pass,
count_pass,
);
let released_claim = if overall_pass {
format!(
"presence F1 {:.3} (95% CI {:.3}-{:.3}), count MAE {:.3} on {} held-out measured samples",
presence_f1, presence_f1_ci.0, presence_f1_ci.1, count_mae, n_test
)
} else {
NO_CLAIM.to_string()
};
BenchmarkReport {
provenance_tag: provenance.tag(),
n_test,
presence_accuracy,
presence_f1,
presence_f1_ci,
count_exact_match,
count_mae,
provenance_pass,
split_pass,
presence_pass,
count_pass,
sample_size_pass,
class_balance_pass,
overall_pass,
released_claim,
}
}
/// Binary F1 score from confusion-matrix counts: `2·TP / (2·TP + FP + FN)`.
///
/// True negatives do not enter the formula. When the denominator is zero (no positive
/// predicted and none present), returns `0.0` instead of NaN.
fn f1_from_confusion(tp: u64, fp: u64, fn_: u64) -> f64 {
    let doubled_tp = tp * 2;
    match doubled_tp + fp + fn_ {
        0 => 0.0,
        denominator => doubled_tp as f64 / denominator as f64,
    }
}
/// Percentile-bootstrap interval for presence F1 over `test`.
///
/// Draws `iters` resamples of size `test.len()` with replacement and scores each with
/// [`f1_from_confusion`]. Returns the 2.5th and 97.5th percentiles, taken nearest-rank
/// from the sorted resample F1 values.
///
/// Index draws come from a SplitMix64 stream seeded with `seed`, so the result is fully
/// deterministic for a given `test` order and seed. Returns `(0.0, 0.0)` when `test` is
/// empty or `iters` is zero.
fn bootstrap_f1_ci(test: &[&LabeledSample], iters: usize, seed: u64) -> (f64, f64) {
    let n = test.len();
    if n == 0 || iters == 0 {
        return (0.0, 0.0);
    }

    // SplitMix64 step: golden-ratio increment followed by the standard finalizer.
    let mut state = seed;
    let mut next = || {
        state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = state;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    };

    let mut f1s = Vec::with_capacity(iters);
    for _ in 0..iters {
        let (mut tp, mut fp, mut fn_) = (0u64, 0u64, 0u64);
        for _ in 0..n {
            // `% n` carries a modulo bias of at most n / 2^64, which is negligible here.
            let s = test[(next() % n as u64) as usize];
            match (s.predicted.present, s.truth.present) {
                (true, true) => tp += 1,
                (true, false) => fp += 1,
                (false, true) => fn_ += 1,
                // True negatives do not enter F1.
                (false, false) => {}
            }
        }
        f1s.push(f1_from_confusion(tp, fp, fn_));
    }

    // `total_cmp` is a genuine total order on f64, which the sort contract requires.
    // `partial_cmp` with an `Equal` fallback is not a total order.
    // Values that compare equal are bit-identical, so an unstable sort yields the same slice.
    f1s.sort_unstable_by(f64::total_cmp);

    // Nearest-rank percentile on the sorted resample F1 values.
    let last = f1s.len() - 1;
    let percentile = |q: f64| f1s[((q * last as f64).round() as usize).min(last)];
    (percentile(0.025), percentile(0.975))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a sample from `(present, person_count)` tuples for truth `t` and prediction `p`.
    fn sample(subj: &str, env: &str, t: (bool, u32), p: (bool, u32)) -> LabeledSample {
        LabeledSample {
            subject_id: subj.into(),
            environment_id: env.into(),
            truth: Occupancy::new(t.0, t.1),
            predicted: Occupancy::new(p.0, p.1),
        }
    }

    /// `n` train samples followed by `n` test samples, all predicted perfectly.
    ///
    /// Subjects and environments are disjoint across the split. Presence alternates, so
    /// both classes appear in the test set whenever `n >= 2`.
    fn perfect_measured(n: usize) -> (Vec<LabeledSample>, EvalSplit) {
        let mut samples = Vec::new();
        for i in 0..n {
            samples.push(sample(
                &format!("train-s{i}"),
                &format!("train-e{i}"),
                (i % 2 == 0, (i % 3) as u32),
                (i % 2 == 0, (i % 3) as u32),
            ));
        }
        for i in 0..n {
            samples.push(sample(
                &format!("test-s{i}"),
                &format!("test-e{i}"),
                (i % 2 == 0, (i % 3) as u32),
                (i % 2 == 0, (i % 3) as u32),
            ));
        }
        let split = EvalSplit {
            train_idx: (0..n).collect(),
            test_idx: (n..2 * n).collect(),
        };
        (samples, split)
    }

    #[test]
    fn perfect_measured_releases_claim() {
        let (samples, split) = perfect_measured(40);
        let r = evaluate(&samples, DataProvenance::Measured, &split, &BenchmarkCriteria::default());
        assert!(r.overall_pass);
        assert!((r.presence_f1 - 1.0).abs() < 1e-9);
        assert_eq!(r.count_mae, 0.0);
        assert!(r.released_claim.contains("F1"));
        assert!(!r.released_claim.contains("research use only"));
    }

    #[test]
    fn synthetic_data_is_scored_but_never_claimed() {
        let (samples, split) = perfect_measured(40);
        let r = evaluate(&samples, DataProvenance::Synthetic, &split, &BenchmarkCriteria::default());
        assert!((r.presence_f1 - 1.0).abs() < 1e-9);
        assert!(!r.provenance_pass);
        assert!(!r.overall_pass);
        assert_eq!(r.claim(), NO_CLAIM);
    }

    #[test]
    fn mock_data_is_never_claimed() {
        let (samples, split) = perfect_measured(40);
        let r = evaluate(&samples, DataProvenance::Mock, &split, &BenchmarkCriteria::default());
        assert!(!r.provenance_pass);
        assert_eq!(r.claim(), NO_CLAIM);
    }

    #[test]
    fn subject_leakage_is_rejected() {
        let samples = vec![
            sample("shared", "e0", (true, 1), (true, 1)),
            sample("shared", "e1", (true, 1), (true, 1)),
        ];
        let split = EvalSplit { train_idx: vec![0], test_idx: vec![1] };
        assert_eq!(
            split.validate(&samples),
            Err(SplitError::SubjectLeakage("shared".into()))
        );
        let r = evaluate(&samples, DataProvenance::Measured, &split, &BenchmarkCriteria::default());
        assert!(!r.split_pass);
        assert!(!r.overall_pass);
        assert_eq!(r.claim(), NO_CLAIM);
    }

    #[test]
    fn environment_leakage_is_rejected() {
        let samples = vec![
            sample("s0", "shared-room", (true, 1), (true, 1)),
            sample("s1", "shared-room", (true, 1), (true, 1)),
        ];
        let split = EvalSplit { train_idx: vec![0], test_idx: vec![1] };
        assert_eq!(
            split.validate(&samples),
            Err(SplitError::EnvironmentLeakage("shared-room".into()))
        );
    }

    #[test]
    fn out_of_range_index_is_rejected_without_panicking() {
        let samples = vec![
            sample("s0", "e0", (true, 1), (true, 1)),
            sample("s1", "e1", (false, 0), (false, 0)),
        ];
        let split = EvalSplit { train_idx: vec![0], test_idx: vec![1, 5] };
        assert_eq!(split.validate(&samples), Err(SplitError::IndexOutOfRange(5)));
        let r = evaluate(&samples, DataProvenance::Measured, &split, &BenchmarkCriteria::default());
        assert!(!r.split_pass);
        assert_eq!(r.n_test, 1, "the out-of-range index is skipped, not scored");
        assert!(!r.overall_pass);
        assert_eq!(r.claim(), NO_CLAIM);
    }

    #[test]
    fn empty_test_split_is_rejected() {
        let samples = vec![sample("s0", "e0", (true, 1), (true, 1))];
        let split = EvalSplit { train_idx: vec![0], test_idx: vec![] };
        assert_eq!(split.validate(&samples), Err(SplitError::EmptyTest));
        let r = evaluate(&samples, DataProvenance::Measured, &split, &BenchmarkCriteria::default());
        assert_eq!(r.n_test, 0);
        assert!(!r.split_pass);
        assert!(!r.sample_size_pass);
        assert!(r.count_mae.is_infinite());
        assert_eq!(r.presence_f1_ci, (0.0, 0.0));
        assert_eq!(r.claim(), NO_CLAIM);
    }

    #[test]
    fn small_sample_is_withheld_even_if_perfect() {
        let (samples, split) = perfect_measured(5);
        let r = evaluate(&samples, DataProvenance::Measured, &split, &BenchmarkCriteria::default());
        assert!(!r.sample_size_pass);
        assert!(!r.overall_pass);
    }

    #[test]
    fn gate_uses_ci_lower_bound_not_point_estimate() {
        let mut samples = Vec::new();
        for i in 0..40 {
            samples.push(sample(
                &format!("train-{i}"),
                &format!("te-{i}"),
                (i % 2 == 0, 1),
                (i % 2 == 0, 1),
            ));
        }
        for i in 0..40 {
            let truth_present = i < 20;
            // Miss the first three positives: point F1 = 34/37 ≈ 0.919, just above 0.9,
            // while the bootstrap lower bound falls below it.
            let predicted_present = truth_present && i >= 3;
            samples.push(sample(
                &format!("test-{i}"),
                &format!("tn-{i}"),
                (truth_present, u32::from(truth_present)),
                (predicted_present, u32::from(truth_present)),
            ));
        }
        let split = EvalSplit { train_idx: (0..40).collect(), test_idx: (40..80).collect() };
        let criteria = BenchmarkCriteria::default();
        let r = evaluate(&samples, DataProvenance::Measured, &split, &criteria);
        assert!(
            r.presence_f1 >= criteria.min_presence_f1,
            "fixture must put the point estimate ({:.3}) above the threshold",
            r.presence_f1
        );
        assert!(
            r.presence_f1_ci.0 < criteria.min_presence_f1,
            "fixture must put the CI lower bound ({:.3}) below the threshold",
            r.presence_f1_ci.0
        );
        assert!(!r.presence_pass);
        assert!(!r.overall_pass);
        assert_eq!(r.claim(), NO_CLAIM);
        // Every other gate passes, so the presence gate alone withholds the claim.
        assert!(r.provenance_pass && r.split_pass && r.sample_size_pass);
        assert!(r.class_balance_pass && r.count_pass);
    }

    #[test]
    fn all_absent_test_set_is_degenerate_and_withheld() {
        let mut samples = Vec::new();
        for i in 0..40 {
            samples.push(sample(&format!("tr-{i}"), &format!("te-{i}"), (true, 1), (true, 1)));
        }
        for i in 0..40 {
            samples.push(sample(&format!("ts-{i}"), &format!("ev-{i}"), (false, 0), (false, 0)));
        }
        let split = EvalSplit { train_idx: (0..40).collect(), test_idx: (40..80).collect() };
        let r = evaluate(&samples, DataProvenance::Measured, &split, &BenchmarkCriteria::default());
        assert_eq!(r.presence_f1, 0.0);
        assert_eq!(r.presence_f1_ci, (0.0, 0.0));
        assert!(!r.class_balance_pass);
        assert!(!r.overall_pass);
        assert_eq!(r.claim(), NO_CLAIM);
    }

    #[test]
    fn all_present_test_set_is_degenerate_and_withheld() {
        let mut samples = Vec::new();
        for i in 0..40 {
            samples.push(sample(
                &format!("tr-{i}"),
                &format!("te-{i}"),
                (i % 2 == 0, 1),
                (i % 2 == 0, 1),
            ));
        }
        for i in 0..40 {
            samples.push(sample(&format!("ts-{i}"), &format!("ev-{i}"), (true, 1), (true, 1)));
        }
        let split = EvalSplit { train_idx: (0..40).collect(), test_idx: (40..80).collect() };
        let r = evaluate(&samples, DataProvenance::Measured, &split, &BenchmarkCriteria::default());
        assert!((r.presence_f1 - 1.0).abs() < 1e-9, "metric still computed");
        assert!(!r.class_balance_pass, "single-class test set is degenerate");
        assert!(!r.overall_pass);
        assert_eq!(r.claim(), NO_CLAIM);
    }

    #[test]
    fn bootstrap_ci_is_deterministic() {
        let (samples, split) = perfect_measured(40);
        let a = evaluate(&samples, DataProvenance::Measured, &split, &BenchmarkCriteria::default());
        let b = evaluate(&samples, DataProvenance::Measured, &split, &BenchmarkCriteria::default());
        assert_eq!(a.presence_f1_ci, b.presence_f1_ci);
    }

    #[test]
    fn count_mae_failure_withholds_claim() {
        let mut samples = Vec::new();
        for i in 0..40 {
            samples.push(sample(&format!("tr-{i}"), &format!("te-{i}"), (true, 1), (true, 1)));
        }
        for i in 0..40 {
            let present = i % 2 == 0;
            let truth_count = u32::from(present);
            // Presence is perfect, but every count is off by 2 (MAE = 2.0).
            samples.push(sample(
                &format!("ts-{i}"),
                &format!("ev-{i}"),
                (present, truth_count),
                (present, truth_count + 2),
            ));
        }
        let split = EvalSplit { train_idx: (0..40).collect(), test_idx: (40..80).collect() };
        let r = evaluate(&samples, DataProvenance::Measured, &split, &BenchmarkCriteria::default());
        assert!(r.presence_pass);
        assert!(r.class_balance_pass);
        assert!(!r.count_pass);
        assert!(!r.overall_pass);
    }

    #[test]
    fn claim_invariant_requires_all_six() {
        assert!(claim_allowed(true, true, true, true, true, true));
        for i in 0..6 {
            let v: Vec<bool> = (0..6).map(|j| j != i).collect();
            assert!(
                !claim_allowed(v[0], v[1], v[2], v[3], v[4], v[5]),
                "criterion {i} false must deny the claim"
            );
        }
    }
}