/// Column-oriented batch of intervention records.
///
/// Every per-record column holds one entry per record, in the same order as
/// `row_id`. The one exception is `dose`, which stores `d_dose` values per
/// record. [`InterventionShard::validate`] checks these invariants.
#[derive(Debug, Clone)]
pub struct InterventionShard {
    /// Record identifiers; the length of this column defines the record count.
    pub row_id: Vec<i64>,
    /// Per-record atom identifier (one entry per record).
    // NOTE(review): what an "atom" denotes is not established in this file.
    pub atom: Vec<i64>,
    /// Applied dose, flattened record-major: record `i` owns
    /// `dose[i * d_dose..(i + 1) * d_dose]`. Validation requires finite values.
    pub dose: Vec<f64>,
    /// Number of dose components per record; validation requires `>= 1`.
    pub d_dose: usize,
    /// First per-record estimate; validation requires finite, non-negative values.
    pub nu_hat_1: Vec<f64>,
    /// Optional second per-record estimate, held to the same length and value
    /// constraints as `nu_hat_1` when present.
    pub nu_hat_2: Option<Vec<f64>>,
    /// Measured per-record value; validation requires finite, non-negative values.
    pub nu_measured: Vec<f64>,
    /// Group label per record; `eval_forever_split` assigns whole groups to
    /// either the train or the eval side.
    pub group: Vec<i64>,
    /// Control flag per record; validation requires it to be `true` exactly
    /// when every dose component of that record is zero.
    pub is_control: Vec<bool>,
    /// Layer index attached to the shard.
    // NOTE(review): not read by any method visible in this file.
    pub layer: i64,
    /// Seed stored with the shard. `eval_forever_split` takes its seed as an
    /// argument and does not read this field.
    pub seed: u64,
}
/// Result of partitioning a shard's distinct group labels into two disjoint
/// sets, as produced by `InterventionShard::eval_forever_split`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EvalForeverSplit {
    /// Group labels assigned to the training side.
    pub train_groups: Vec<i64>,
    /// Group labels assigned to the evaluation side.
    pub eval_groups: Vec<i64>,
}
/// Stateless 64-bit mixer: adds a fixed increment to `x`, then applies two
/// xor-shift/multiply rounds and a final xor-shift.
///
/// All arithmetic wraps, so the result is fully determined by `x` on every
/// platform.
#[inline]
fn splitmix64(x: u64) -> u64 {
    const INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;
    const MULTIPLIER_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MULTIPLIER_B: u64 = 0x94D0_49BB_1331_11EB;

    let advanced = x.wrapping_add(INCREMENT);
    let round_a = (advanced ^ (advanced >> 30)).wrapping_mul(MULTIPLIER_A);
    let round_b = (round_a ^ (round_a >> 27)).wrapping_mul(MULTIPLIER_B);
    round_b ^ (round_b >> 31)
}
impl InterventionShard {
    /// Checks the structural and numerical invariants of the shard.
    ///
    /// Checks run in this order and stop at the first failure:
    /// 1. `d_dose >= 1`;
    /// 2. `atom`, `nu_hat_1`, `nu_measured`, `group` and `is_control` have one
    ///    entry per `row_id`;
    /// 3. `dose` holds exactly `row_id.len() * d_dose` values (the product is
    ///    computed with overflow checking);
    /// 4. `nu_hat_2`, when present, has one entry per `row_id`;
    /// 5. per record: every dose component is finite, `is_control` is `true`
    ///    exactly when all dose components are zero, and `nu_hat_1`,
    ///    `nu_measured` and (when present) `nu_hat_2` are finite and `>= 0`.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message naming the first violated invariant.
    pub fn validate(&self) -> Result<(), String> {
        let m = self.row_id.len();
        if self.d_dose == 0 {
            return Err("InterventionShard: d_dose must be >= 1".to_string());
        }

        let checks: [(&str, usize); 5] = [
            ("atom", self.atom.len()),
            ("nu_hat_1", self.nu_hat_1.len()),
            ("nu_measured", self.nu_measured.len()),
            ("group", self.group.len()),
            ("is_control", self.is_control.len()),
        ];
        for (name, len) in checks {
            if len != m {
                return Err(format!(
                    "InterventionShard: {name} has {len} records but row_id has {m}"
                ));
            }
        }

        // `d_dose` is a plain public field, so the product may not fit in a
        // usize. An unchecked multiply would panic in debug builds, or wrap in
        // release builds and let an inconsistent shard reach the slicing below.
        let expected_dose_len = m.checked_mul(self.d_dose).ok_or_else(|| {
            format!(
                "InterventionShard: expected dose length m*d = {}*{} overflows usize",
                m, self.d_dose
            )
        })?;
        if self.dose.len() != expected_dose_len {
            return Err(format!(
                "InterventionShard: dose has {} entries; expected m*d = {}*{} = {}",
                self.dose.len(),
                m,
                self.d_dose,
                expected_dose_len
            ));
        }

        if let Some(nu2) = &self.nu_hat_2 {
            if nu2.len() != m {
                return Err(format!(
                    "InterventionShard: nu_hat_2 has {} records but row_id has {m}",
                    nu2.len()
                ));
            }
        }

        // `dose.len() == m * d_dose` with `d_dose >= 1`, so this yields exactly
        // one chunk per record and no remainder.
        for (i, d_row) in self.dose.chunks_exact(self.d_dose).enumerate() {
            if !d_row.iter().all(|v| v.is_finite()) {
                return Err(format!("InterventionShard: record {i}: non-finite dose"));
            }

            let zero_dose = d_row.iter().all(|&v| v == 0.0);
            if zero_dose != self.is_control[i] {
                return Err(format!(
                    "InterventionShard: record {i}: is_control={} but dose is {}zero \
                     (the G3 null is defined by the applied dose)",
                    self.is_control[i],
                    if zero_dose { "" } else { "non-" }
                ));
            }

            // The optional column is checked last, matching the order in which
            // errors have always been reported.
            let rates = [
                ("nu_hat_1", Some(self.nu_hat_1[i])),
                ("nu_measured", Some(self.nu_measured[i])),
                ("nu_hat_2", self.nu_hat_2.as_ref().map(|nu2| nu2[i])),
            ];
            for (name, value) in rates {
                if let Some(v) = value {
                    if !(v.is_finite() && v >= 0.0) {
                        return Err(format!(
                            "InterventionShard: record {i}: {name} must be finite and >= 0; got {v}"
                        ));
                    }
                }
            }
        }
        Ok(())
    }

    /// Number of records in the shard, i.e. the length of `row_id`.
    pub fn n_records(&self) -> usize {
        self.row_id.len()
    }

    /// Deterministically assigns every distinct group label to train or eval.
    ///
    /// A group goes to the eval side when the low bit of
    /// `splitmix64(group ^ splitmix64(seed))` is 1, otherwise to train. The
    /// decision depends only on the label and `seed`, never on which other
    /// groups are present, so adding records cannot move an existing group.
    /// Both output lists are in ascending label order.
    ///
    /// Note that the `seed` argument is used; the shard's own `seed` field is
    /// not read.
    pub fn eval_forever_split(&self, seed: u64) -> EvalForeverSplit {
        let mut groups: Vec<i64> = self.group.clone();
        groups.sort_unstable();
        groups.dedup();

        let seed_mix = splitmix64(seed);
        // `g as u64` reinterprets the bits of a possibly negative label; that
        // is intentional, the value only feeds the hash.
        let (eval_groups, train_groups): (Vec<i64>, Vec<i64>) = groups
            .into_iter()
            .partition(|&g| splitmix64((g as u64) ^ seed_mix) & 1 == 1);

        EvalForeverSplit {
            train_groups,
            eval_groups,
        }
    }

    /// Returns the `q`-quantile of `nu_measured` over the control records.
    ///
    /// The quantile is linearly interpolated between the two order statistics
    /// that bracket position `q * (n - 1)`, where `n` is the number of control
    /// records.
    ///
    /// # Errors
    ///
    /// Returns an error when
    /// - `q` is not strictly between 0 and 1 (NaN included),
    /// - `nu_measured` and `is_control` differ in length,
    /// - a control record has a non-finite `nu_measured`, or
    /// - the shard has no control records.
    ///
    /// This method does not require [`InterventionShard::validate`] to have
    /// been called; malformed input is reported as an error, not a panic.
    pub fn control_floor_nats(&self, q: f64) -> Result<f64, String> {
        if !(q > 0.0 && q < 1.0) {
            return Err(format!(
                "control_floor_nats: quantile must be in (0, 1); got {q}"
            ));
        }
        // Zipping columns of different length would silently drop records.
        if self.nu_measured.len() != self.is_control.len() {
            return Err(format!(
                "control_floor_nats: nu_measured has {} records but is_control has {}",
                self.nu_measured.len(),
                self.is_control.len()
            ));
        }

        let mut nulls: Vec<f64> = Vec::new();
        for (i, (&v, &is_control)) in self
            .nu_measured
            .iter()
            .zip(self.is_control.iter())
            .enumerate()
        {
            if !is_control {
                continue;
            }
            if !v.is_finite() {
                return Err(format!(
                    "control_floor_nats: control record {i}: nu_measured must be finite; got {v}"
                ));
            }
            nulls.push(v);
        }
        if nulls.is_empty() {
            return Err(
                "control_floor_nats: shard has no Δt = 0 control records; the G3 floor \
                 must be estimated from controls, never assumed"
                    .to_string(),
            );
        }

        // Every value is finite here, so the total order agrees with the
        // numeric order and the comparison cannot fail.
        nulls.sort_unstable_by(f64::total_cmp);

        let h = q * (nulls.len() as f64 - 1.0);
        let lo = h.floor() as usize;
        let hi = h.ceil() as usize;
        let frac = h - lo as f64;
        Ok(nulls[lo] * (1.0 - frac) + nulls[hi] * frac)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Four-record fixture: two groups (10 and 20), each holding one dosed
    /// record followed by its zero-dose control.
    fn tiny_shard() -> InterventionShard {
        InterventionShard {
            layer: 17,
            seed: 0,
            d_dose: 1,
            row_id: vec![0, 1, 2, 3],
            atom: vec![0, 0, 1, 1],
            group: vec![10, 10, 20, 20],
            dose: vec![0.1, 0.0, -0.2, 0.0],
            is_control: vec![false, true, false, true],
            nu_hat_1: vec![0.5, 0.0, 0.8, 0.0],
            nu_hat_2: None,
            nu_measured: vec![0.45, 1e-6, 0.7, 2e-6],
        }
    }

    #[test]
    fn valid_shard_passes() {
        assert_eq!(tiny_shard().validate(), Ok(()));
    }

    #[test]
    fn mislabeled_control_is_a_hard_error() {
        // Record 0 carries a non-zero dose, so flagging it as control must fail.
        let shard = InterventionShard {
            is_control: vec![true, true, false, true],
            ..tiny_shard()
        };
        let err = shard.validate().unwrap_err();
        assert!(err.contains("is_control"));
    }

    #[test]
    fn zero_dose_without_control_flag_is_a_hard_error() {
        // Record 1 has zero dose, so clearing its control flag must fail.
        let shard = InterventionShard {
            is_control: vec![false, false, false, true],
            ..tiny_shard()
        };
        let err = shard.validate().unwrap_err();
        assert!(err.contains("is_control"));
    }

    #[test]
    fn negative_measured_kl_rejected() {
        let shard = InterventionShard {
            nu_measured: vec![-0.1, 1e-6, 0.7, 2e-6],
            ..tiny_shard()
        };
        let err = shard.validate().unwrap_err();
        assert!(err.contains("nu_measured"));
    }

    #[test]
    fn split_is_deterministic_and_partitions_groups() {
        let shard = tiny_shard();
        let first = shard.eval_forever_split(7);
        let second = shard.eval_forever_split(7);
        assert_eq!(first, second);

        // Train and eval together must cover every group exactly once.
        let mut union: Vec<i64> = first.train_groups.clone();
        union.extend(first.eval_groups.iter().copied());
        union.sort_unstable();
        assert_eq!(union, vec![10, 20]);
    }

    #[test]
    fn split_is_per_group_stable_under_shard_growth() {
        let original = tiny_shard();
        let before = original.eval_forever_split(3);

        // Append a third group (30): one dosed record plus its control.
        let mut grown = original.clone();
        grown.row_id.extend_from_slice(&[4, 5]);
        grown.atom.extend_from_slice(&[2, 2]);
        grown.dose.extend_from_slice(&[0.3, 0.0]);
        grown.nu_hat_1.extend_from_slice(&[0.2, 0.0]);
        grown.nu_measured.extend_from_slice(&[0.15, 1e-6]);
        grown.group.extend_from_slice(&[30, 30]);
        grown.is_control.extend_from_slice(&[false, true]);
        grown.validate().unwrap();

        let after = grown.eval_forever_split(3);
        before.train_groups.iter().for_each(|g| {
            assert!(after.train_groups.contains(g), "group {g} left train");
        });
        before.eval_groups.iter().for_each(|g| {
            assert!(after.eval_groups.contains(g), "group {g} left eval");
        });
    }

    #[test]
    fn control_floor_is_a_control_quantile() {
        // Controls measure 1e-6 and 2e-6, so the median interpolates to 1.5e-6.
        let floor = tiny_shard().control_floor_nats(0.5).unwrap();
        let deviation = (floor - 1.5e-6).abs();
        assert!(deviation < 1e-12);
    }

    #[test]
    fn control_floor_requires_controls() {
        let shard = InterventionShard {
            is_control: vec![false; 4],
            dose: vec![0.1, 0.2, -0.2, 0.4],
            ..tiny_shard()
        };
        shard.validate().unwrap();
        let err = shard.control_floor_nats(0.5).unwrap_err();
        assert!(err.contains("control"));
    }

    #[test]
    fn splitmix_reference_values_pin_the_cross_language_contract() {
        let reference: [(u64, u64); 2] = [
            (0, 0xE220_A839_7B1D_CDAF),
            (1, 0x910A_2DEC_8902_5CC1),
        ];
        for (input, expected) in reference {
            assert_eq!(splitmix64(input), expected);
        }
    }
}