use std::collections::BTreeMap;
use std::io::Write;
use std::path::PathBuf;
use std::process::Command;
use std::sync::atomic::{AtomicU64, Ordering};
/// Values a reference script emitted, keyed by the name passed to its `emit` helper.
pub struct ReferenceResult {
    values: BTreeMap<String, Vec<f64>>,
}

impl ReferenceResult {
    /// Returns the single value emitted under `key`.
    ///
    /// # Panics
    /// Panics if `key` was never emitted or carried anything other than exactly one value.
    pub fn scalar(&self, key: &str) -> f64 {
        let v = self.vector(key);
        assert_eq!(
            v.len(),
            1,
            "reference key {key:?} carried {} values, expected a scalar",
            v.len()
        );
        v[0]
    }

    /// Returns every value emitted under `key`, in emission order.
    ///
    /// # Panics
    /// Panics if `key` was never emitted; the message lists the keys that were.
    pub fn vector(&self, key: &str) -> &[f64] {
        // Build the (allocating) diagnostic only on the failure path.
        match self.values.get(key) {
            Some(v) => v.as_slice(),
            None => panic!(
                "reference did not emit key {key:?}; emitted keys: {:?}",
                self.keys()
            ),
        }
    }

    /// All emitted keys in sorted order.
    pub fn keys(&self) -> Vec<&str> {
        self.values.keys().map(String::as_str).collect()
    }
}
/// One named numeric column handed to a reference script through the data CSV.
pub struct Column<'a> {
    /// Header written for this column.
    pub name: &'a str,
    /// Column values; columns shorter than the longest one are written with `NA` padding.
    pub data: &'a [f64],
}

impl<'a> Column<'a> {
    /// Borrows `data` under the header `name`.
    pub fn new(name: &'a str, data: &'a [f64]) -> Self {
        Column { name, data }
    }
}
/// Creates and returns an empty scratch directory under the system temp dir.
///
/// The name combines `tag`, the process id and a process-wide counter, so calls
/// within one process never collide. A directory with the same name can still be
/// left over from an earlier process that had the same PID (failed reference runs
/// leave their scratch directory behind). Any such leftover is removed first,
/// because the R `emit` helper appends to `out.txt` and stale lines would
/// otherwise leak into the new run's results.
///
/// # Panics
/// Panics if the directory cannot be created.
fn unique_scratch_dir(tag: &str) -> PathBuf {
    static COUNTER: AtomicU64 = AtomicU64::new(0);
    // Relaxed suffices: the counter only has to hand out distinct values.
    let n = COUNTER.fetch_add(1, Ordering::Relaxed);
    let dir = std::env::temp_dir().join(format!(
        "gam_reference_{tag}_{}_{n}",
        std::process::id()
    ));
    // Best effort: in the normal case the directory does not exist yet.
    let _ = std::fs::remove_dir_all(&dir);
    std::fs::create_dir_all(&dir).expect("create reference scratch dir");
    dir
}
/// Writes `columns` to `path` as a CSV with a header row.
///
/// Values use scientific notation with 17 fractional digits, enough to
/// round-trip any `f64`. Columns shorter than the longest one are padded with
/// `NA`, which both R's `read.csv` and pandas read as missing.
///
/// # Panics
/// Panics if `columns` is empty or every column is empty. Also panics if a
/// name contains a comma, double quote or line break (it would split or shift
/// the header), or appears twice (R and pandas silently rename duplicates, so
/// the reference body would read the wrong column). Panics if the file cannot
/// be written.
fn write_columns_csv(path: &std::path::Path, columns: &[Column<'_>]) {
    use std::fmt::Write as _;

    assert!(
        !columns.is_empty(),
        "reference run needs at least one column"
    );
    for (i, c) in columns.iter().enumerate() {
        assert!(
            !c.name.contains(|ch| matches!(ch, ',' | '"' | '\n' | '\r')),
            "reference column name {:?} contains a CSV delimiter, quote, or line break",
            c.name
        );
        assert!(
            columns[..i].iter().all(|prev| prev.name != c.name),
            "reference column name {:?} is used more than once",
            c.name
        );
    }
    let nrows = columns.iter().map(|c| c.data.len()).max().unwrap_or(0);
    assert!(
        nrows > 0,
        "reference run needs at least one non-empty column"
    );

    let mut s = columns.iter().map(|c| c.name).collect::<Vec<_>>().join(",");
    s.push('\n');
    for row in 0..nrows {
        for (j, c) in columns.iter().enumerate() {
            if j > 0 {
                s.push(',');
            }
            match c.data.get(row) {
                // Writing into a `String` cannot fail.
                Some(value) => {
                    let _ = write!(s, "{value:.17e}");
                }
                None => s.push_str("NA"),
            }
        }
        s.push('\n');
    }
    std::fs::write(path, s).expect("write reference data csv");
}
/// Parses the `key: v1 v2 ...` lines written by a reference script's `emit` helper.
///
/// Blank lines, lines without a `:`, and lines whose key is empty are skipped.
/// Values are whitespace-separated. `NA`/`na`/`NaN`/`nan` become NaN,
/// `Inf`/`inf` become +infinity and `-Inf`/`-inf` become -infinity; this covers
/// the R `format` and Python `repr` spellings. Anything else must parse as `f64`.
/// If a key repeats, the values from its last occurrence win.
///
/// # Panics
/// Panics on an unparsable token, naming both the token and its key.
fn parse_emitted(text: &str) -> BTreeMap<String, Vec<f64>> {
    fn parse_token(key: &str, tok: &str) -> f64 {
        match tok {
            "NA" | "na" | "NaN" | "nan" => f64::NAN,
            "Inf" | "inf" => f64::INFINITY,
            "-Inf" | "-inf" => f64::NEG_INFINITY,
            other => other.parse::<f64>().unwrap_or_else(|_| {
                panic!(
                    "reference emitted an unparsable numeric token {other:?} for key {key:?}"
                )
            }),
        }
    }

    let mut out: BTreeMap<String, Vec<f64>> = BTreeMap::new();
    for line in text.lines().map(str::trim).filter(|l| !l.is_empty()) {
        let Some((key, rest)) = line.split_once(':') else {
            continue;
        };
        let key = key.trim();
        if key.is_empty() {
            continue;
        }
        let values = rest
            .split_whitespace()
            .map(|tok| parse_token(key, tok))
            .collect();
        out.insert(key.to_string(), values);
    }
    out
}
/// Static description of one reference language: how to launch its
/// interpreter and the boilerplate wrapped around every test body.
struct ReferenceKind {
    /// Short tag embedded in scratch-directory names.
    tag: &'static str,
    /// File name of the assembled script inside the scratch directory.
    script_name: &'static str,
    /// Interpreter executable, resolved through `PATH`.
    program: &'static str,
    /// Interpreter flags placed before the script path.
    program_args: &'static [&'static str],
    /// Code prepended to the body. It defines `df` from the data CSV (the
    /// script's first argument) and an `emit(key, x)` helper that records
    /// values into the output file (the second argument).
    preamble: &'static str,
    /// Code appended after the body; Python uses it to write the collected lines.
    epilogue: &'static str,
    /// Panic message used when the interpreter cannot be spawned.
    spawn_expect: &'static str,
    /// Panic message used when the script file cannot be written.
    write_expect: &'static str,
    /// Human-readable language name for failure messages.
    display: &'static str,
}

impl ReferenceKind {
    /// R via `Rscript --vanilla`; `emit` appends one line per call to the output file.
    fn r() -> Self {
        ReferenceKind {
            tag: "r",
            script_name: "script.R",
            program: "Rscript",
            program_args: &["--vanilla"],
            preamble: "\
                args <- commandArgs(trailingOnly = TRUE)\n\
                df <- read.csv(args[1])\n\
                .OUT <- args[2]\n\
                emit <- function(key, x) {\n\
                cat(sprintf('%s:%s\\n', key, paste(format(as.numeric(x), digits = 17, scientific = TRUE), collapse = ' ')),\n\
                file = .OUT, append = TRUE)\n\
                }\n",
            epilogue: "",
            spawn_expect: "spawn Rscript (install R to run reference-comparison tests)",
            write_expect: "write reference R script",
            display: "R",
        }
    }

    /// Python 3 via `python3`. `df` is a pandas DataFrame when pandas imports,
    /// otherwise a dict of float numpy arrays read with the `csv` module.
    /// `emit` buffers lines, and the epilogue writes them all at the end.
    fn python() -> Self {
        ReferenceKind {
            tag: "py",
            script_name: "script.py",
            program: "python3",
            program_args: &[],
            preamble: concat!(
                "import sys\n",
                "import numpy as np\n",
                "_data_csv, _out = sys.argv[1], sys.argv[2]\n",
                "try:\n",
                "    import pandas as pd\n",
                "    df = pd.read_csv(_data_csv)\n",
                "except Exception:\n",
                "    import csv as _csv\n",
                "    with open(_data_csv) as _fh:\n",
                "        _r = _csv.DictReader(_fh)\n",
                "        _cols = {k: [] for k in _r.fieldnames}\n",
                "        for _row in _r:\n",
                "            for _k, _v in _row.items():\n",
                // Short columns are padded with `NA` in the CSV; `float('NA')` raises.
                "                _cols[_k].append(float('nan') if _v in ('', 'NA', None) else float(_v))\n",
                "    df = {k: np.asarray(v, dtype=float) for k, v in _cols.items()}\n",
                "_lines = []\n",
                "def emit(key, x):\n",
                "    arr = np.asarray(x, dtype=float).reshape(-1)\n",
                "    _lines.append(str(key) + ':' + ' '.join(repr(float(v)) for v in arr))\n",
            ),
            // `with` closes (and so flushes) the output file deterministically.
            epilogue: "\nwith open(_out, 'w') as _fh:\n    _fh.write('\\n'.join(_lines) + '\\n')\n",
            spawn_expect: "spawn python3 (install python3 to run reference-comparison tests)",
            write_expect: "write reference python script",
            display: "Python",
        }
    }

    /// A fresh interpreter command with its flags; the caller appends the script
    /// path and its arguments.
    fn command(&self) -> Command {
        let mut cmd = Command::new(self.program);
        cmd.args(self.program_args);
        cmd
    }
}
/// Runs `body` in the reference language described by `kind` and collects what it emits.
///
/// Writes `columns` to `data.csv` in a fresh scratch directory. Assembles
/// `preamble + body + epilogue` into the script, then invokes the interpreter
/// as `<program> <args> script data.csv out.txt`. On success the emitted lines
/// are parsed and the scratch directory is removed. On a non-zero exit the
/// directory is kept so the script and data can be inspected, and the panic
/// message names it.
///
/// A script that never calls `emit` produces no output file. That case yields
/// an empty result, so a later key lookup reports the missing key.
///
/// # Panics
/// Panics if the script cannot be written, the interpreter cannot be spawned,
/// or the script exits unsuccessfully.
fn run_subprocess(kind: &ReferenceKind, columns: &[Column<'_>], body: &str) -> ReferenceResult {
    let dir = unique_scratch_dir(kind.tag);
    let data_csv = dir.join("data.csv");
    let out_txt = dir.join("out.txt");
    let script = dir.join(kind.script_name);

    write_columns_csv(&data_csv, columns);
    let full = format!("{}\n{body}\n{}", kind.preamble, kind.epilogue);
    std::fs::write(&script, full).expect(kind.write_expect);

    let spawned = kind
        .command()
        .arg(&script)
        .arg(&data_csv)
        .arg(&out_txt)
        .output();
    let output = match spawned {
        Ok(output) => output,
        Err(err) => {
            // Nothing ran, so the scratch files have no diagnostic value; don't
            // leak a directory per test when the interpreter is not installed.
            std::fs::remove_dir_all(&dir).ok();
            panic!("{}: {err:?}", kind.spawn_expect);
        }
    };

    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        let stdout = String::from_utf8_lossy(&output.stdout);
        panic!(
            "reference {} body failed (status {:?}); scratch files kept in {}\n--- stderr ---\n{stderr}\n--- stdout ---\n{stdout}",
            kind.display,
            output.status.code(),
            dir.display()
        );
    }

    let emitted = std::fs::read_to_string(&out_txt).unwrap_or_default();
    let parsed = parse_emitted(&emitted);
    std::fs::remove_dir_all(&dir).ok();
    ReferenceResult { values: parsed }
}
/// Runs an R reference `body` against `columns` and returns the values it `emit`s.
///
/// The body sees the data as `df` (read with `read.csv`) and can call
/// `emit(key, x)` to report numbers back.
///
/// # Panics
/// Panics if `Rscript` cannot be spawned or the script exits unsuccessfully.
pub fn run_r(columns: &[Column<'_>], body: &str) -> ReferenceResult {
    let kind = ReferenceKind::r();
    run_subprocess(&kind, columns, body)
}
pub fn r_package_available(pkg: &str) -> bool {
if ReferenceKind::r()
.command()
.arg("--version")
.output()
.is_err()
{
return false;
}
let script =
format!("cat(if (requireNamespace(\"{pkg}\", quietly = TRUE)) \"1\\n\" else \"0\\n\")");
let output = match Command::new("Rscript")
.arg("--vanilla")
.arg("-e")
.arg(script)
.output()
{
Ok(output) => output,
Err(_) => return false,
};
output.status.success() && String::from_utf8_lossy(&output.stdout).trim() == "1"
}
/// Runs a Python 3 reference `body` against `columns` and returns the values it `emit`s.
///
/// The body sees the data as `df` (a pandas DataFrame, or a dict of numpy
/// arrays when pandas is missing) and can call `emit(key, x)`.
///
/// # Panics
/// Panics if `python3` cannot be spawned or the script exits unsuccessfully.
pub fn run_python(columns: &[Column<'_>], body: &str) -> ReferenceResult {
    let kind = ReferenceKind::python();
    run_subprocess(&kind, columns, body)
}
/// Relative L2 error `||a - b||₂ / ||b||₂`, with `b` as the reference.
///
/// The squared reference norm is floored at `1e-300` to avoid dividing by zero.
///
/// # Panics
/// Panics if the slices differ in length.
pub fn relative_l2(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "relative_l2 length mismatch");
    let sq = |v: f64| v * v;
    let residual_sq: f64 = a.iter().zip(b.iter()).map(|(&x, &y)| sq(x - y)).sum();
    let reference_sq: f64 = b.iter().map(|&y| sq(y)).sum();
    let ratio = residual_sq / reference_sq.max(1e-300);
    ratio.sqrt()
}
/// Root-mean-square difference between `a` and `b`; `0.0` for empty inputs.
///
/// # Panics
/// Panics if the slices differ in length.
pub fn rmse(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "rmse length mismatch");
    let count = a.len().max(1) as f64;
    let sse: f64 = a
        .iter()
        .zip(b)
        .map(|(x, y)| {
            let diff = x - y;
            diff * diff
        })
        .sum();
    (sse / count).sqrt()
}
/// Coefficient of determination `1 - SS_res / SS_tot` of `pred` against `truth`.
///
/// `SS_tot` is floored at `1e-300`, so a constant `truth` gives a very large
/// negative score instead of dividing by zero. Empty inputs return NaN: without
/// this guard the NaN mean made `SS_tot` NaN, the floor then turned it into
/// `1e-300`, and an empty split scored a perfect `1.0`.
///
/// # Panics
/// Panics if the slices differ in length.
pub fn r2(pred: &[f64], truth: &[f64]) -> f64 {
    assert_eq!(pred.len(), truth.len(), "r2 length mismatch");
    if truth.is_empty() {
        return f64::NAN;
    }
    let n = truth.len() as f64;
    let mean = truth.iter().sum::<f64>() / n;
    let ss_res: f64 = pred.iter().zip(truth).map(|(p, t)| (t - p) * (t - p)).sum();
    let ss_tot: f64 = truth.iter().map(|t| (t - mean) * (t - mean)).sum();
    1.0 - ss_res / ss_tot.max(1e-300)
}
pub fn held_out_r2(pred: &[f64], truth: &[f64]) -> f64 {
r2(pred, truth)
}
/// Extends `v` to exactly `len` entries by repeating its last value, or `0.0`
/// when `v` is empty, so columns of different lengths can share one data CSV.
///
/// # Panics
/// Panics when `len < v.len()`, because shrinking would silently drop real rows.
pub fn pad_to(v: &[f64], len: usize) -> Vec<f64> {
    let source_len = v.len();
    assert!(
        source_len <= len,
        "pad_to cannot shrink: source has {} rows but the pad target is {len} \
         (a shorter target would drop real data). Pad every column to a common \
         n = max(column length) and slice by semantic length in the reference body.",
        source_len
    );
    let fill = match v.last() {
        Some(&last) => last,
        None => 0.0,
    };
    let mut padded = Vec::with_capacity(len);
    padded.extend_from_slice(v);
    padded.resize(len, fill);
    padded
}
/// Largest elementwise `|a[i] - b[i]|`; `0.0` for empty inputs.
///
/// Positions where both sides are NaN, or both are the same infinity, count as
/// a zero difference. A NaN on only one side makes the result NaN, so
/// tolerance checks such as `max_abs_diff(..) < tol` fail. A plain
/// `fold(0.0, f64::max)` would silently drop such NaNs, because `f64::max`
/// ignores NaN operands.
///
/// # Panics
/// Panics if the slices differ in length.
pub fn max_abs_diff(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "max_abs_diff length mismatch");
    a.iter()
        .zip(b)
        .map(|(&x, &y)| {
            if x == y || (x.is_nan() && y.is_nan()) {
                0.0
            } else {
                (x - y).abs()
            }
        })
        .fold(0.0, |acc: f64, d: f64| {
            if acc.is_nan() || d.is_nan() {
                f64::NAN
            } else {
                acc.max(d)
            }
        })
}
/// Diagnostic snapshot of one fitted model, rendered as a single log line by
/// `report`.
#[derive(Clone, Debug)]
pub struct QualityDiagnostics {
    /// Free-form identifier printed first in the report.
    pub label: String,
    /// RMSE of this model's predictions against ground truth, once attached.
    pub rmse_vs_truth: Option<f64>,
    /// RMSE between this model's predictions and a reference implementation's.
    pub rmse_vs_reference: Option<f64>,
    /// RMSE of the reference implementation's predictions against ground truth.
    pub reference_rmse_vs_truth: Option<f64>,
    /// Total effective degrees of freedom, when the fit carried inference results.
    pub edf_total: Option<f64>,
    /// Log smoothing parameters, copied from the fit's `log_lambdas`.
    pub rho: Vec<f64>,
    /// Smoothing parameters, copied from the fit's `lambdas`.
    pub lambda: Vec<f64>,
    /// Singular-value summary of the design matrix; `None` if it could not be computed.
    pub design: Option<DesignDiagnostics>,
    /// One eigenvalue summary per penalty block.
    pub penalties: Vec<PenaltyDiagnostics>,
    /// Summary statistics of the predictions passed to `with_truth_rmse`.
    pub prediction: Option<PredictionFingerprint>,
}
/// Singular-value summary of a densified design matrix (see `design_diagnostics`).
#[derive(Clone, Debug)]
pub struct DesignDiagnostics {
    /// Number of design rows.
    pub nrows: usize,
    /// Number of design columns.
    pub ncols: usize,
    /// Count of singular values above `max(nrows, ncols) * eps * max(sigma_max, 1)`.
    pub rank: usize,
    /// `sigma_max / sigma_min`, or infinity when no singular value clears the tolerance.
    pub condition: f64,
    /// Smallest singular value above the rank tolerance, or `0.0` if none clears it.
    pub sigma_min: f64,
    /// Largest singular value.
    pub sigma_max: f64,
}
/// Eigenvalue summary of one penalty block (see `penalty_diagnostics`).
#[derive(Clone, Debug)]
pub struct PenaltyDiagnostics {
    /// Position of the penalty in the penalty list.
    pub index: usize,
    /// Start of the penalty's column range (from `col_range.start`).
    pub col_start: usize,
    /// End of the penalty's column range (from `col_range.end`).
    pub col_end: usize,
    /// Count of eigenvalues above `1e-10 * max(max |eigenvalue|, 1)`.
    pub rank: usize,
    /// Smoothing parameter at the same index, or `None` if there is none.
    pub lambda: Option<f64>,
    /// Smallest eigenvalue of the local penalty matrix.
    pub eig_min: f64,
    /// Largest eigenvalue of the local penalty matrix.
    pub eig_max: f64,
    /// Sum of the local penalty matrix's diagonal.
    pub trace: f64,
}
/// Compact summary of a prediction vector (see `prediction_fingerprint`).
#[derive(Clone, Debug)]
pub struct PredictionFingerprint {
    /// Number of predictions.
    pub n: usize,
    /// Arithmetic mean (`0.0` for no predictions).
    pub mean: f64,
    /// Population standard deviation (divides by `n`).
    pub sd: f64,
    /// Minimum value (`+inf` for no predictions).
    pub min: f64,
    /// Maximum value (`-inf` for no predictions).
    pub max: f64,
    /// First prediction, NaN when empty.
    pub first: f64,
    /// Last prediction, NaN when empty.
    pub last: f64,
}
impl QualityDiagnostics {
    /// Builds diagnostics from a fitted standard model: a design SVD summary,
    /// per-penalty eigen summaries, smoothing parameters and total EDF.
    ///
    /// Error metrics and the prediction fingerprint start as `None`; attach them
    /// with [`Self::with_truth_rmse`] and [`Self::with_reference_gap`]. A failed
    /// design SVD is recorded as `design: None` instead of an error.
    pub fn from_standard_fit(label: impl Into<String>, fit: &crate::StandardFitResult) -> Self {
        // Copy the lambdas once and reuse the copy for the per-penalty lookup.
        // `as_slice()` returns `None` for non-contiguous storage, which would
        // turn every penalty's lambda into `None`.
        let lambda = fit.fit.lambdas.to_vec();
        let design = design_diagnostics(&fit.design.design).ok();
        let penalties = penalty_diagnostics(&fit.design.penalties, &lambda);
        Self {
            label: label.into(),
            rmse_vs_truth: None,
            rmse_vs_reference: None,
            reference_rmse_vs_truth: None,
            edf_total: fit.fit.inference.as_ref().map(|i| i.edf_total),
            rho: fit.fit.log_lambdas.to_vec(),
            lambda,
            design,
            penalties,
            prediction: None,
        }
    }

    /// Records the RMSE of `pred` against `truth` and a fingerprint of `pred`.
    ///
    /// # Panics
    /// Panics if the slices differ in length.
    pub fn with_truth_rmse(mut self, pred: &[f64], truth: &[f64]) -> Self {
        self.rmse_vs_truth = Some(rmse(pred, truth));
        self.prediction = Some(prediction_fingerprint(pred));
        self
    }

    /// Records the RMSE gap between `pred` and the reference predictions. When
    /// `truth` is given, also records the reference's own RMSE against it.
    ///
    /// # Panics
    /// Panics if any pair of slices differs in length.
    pub fn with_reference_gap(
        mut self,
        pred: &[f64],
        reference: &[f64],
        truth: Option<&[f64]>,
    ) -> Self {
        self.rmse_vs_reference = Some(rmse(pred, reference));
        if let Some(truth) = truth {
            self.reference_rmse_vs_truth = Some(rmse(reference, truth));
        }
        self
    }

    /// Renders every populated field as one `[quality-diagnostics] ...` line.
    pub fn report(&self) -> String {
        use std::fmt::Write as _;
        // Writing into a `String` cannot fail, so the `fmt::Result`s are ignored.
        let mut out = format!("[quality-diagnostics] label={}", self.label);
        if let Some(v) = self.rmse_vs_truth {
            let _ = write!(out, " rmse_truth={v:.6}");
        }
        if let Some(v) = self.rmse_vs_reference {
            let _ = write!(out, " rmse_reference_gap={v:.6}");
        }
        if let Some(v) = self.reference_rmse_vs_truth {
            let _ = write!(out, " reference_rmse_truth={v:.6}");
        }
        if let Some(v) = self.edf_total {
            let _ = write!(out, " edf_total={v:.3}");
        }
        if !self.rho.is_empty() {
            let _ = write!(out, " rho={:?}", Rounded(&self.rho));
        }
        if !self.lambda.is_empty() {
            let _ = write!(out, " lambda={:?}", Rounded(&self.lambda));
        }
        if let Some(d) = &self.design {
            let _ = write!(
                out,
                " design={}x{} rank={} cond={:.3e} sigma=[{:.3e},{:.3e}]",
                d.nrows, d.ncols, d.rank, d.condition, d.sigma_min, d.sigma_max
            );
        }
        if let Some(p) = &self.prediction {
            let _ = write!(
                out,
                " pred[n={} mean={:.4} sd={:.4} range=[{:.4},{:.4}] edge=[{:.4},{:.4}]]",
                p.n, p.mean, p.sd, p.min, p.max, p.first, p.last
            );
        }
        if !self.penalties.is_empty() {
            out.push_str(" penalties=");
            for p in &self.penalties {
                let lambda = p
                    .lambda
                    .map_or_else(|| "NA".to_string(), |v| format!("{v:.3e}"));
                let _ = write!(
                    out,
                    " #{} cols={}..{} rank={} lambda={} eig=[{:.3e},{:.3e}] tr={:.3e};",
                    p.index, p.col_start, p.col_end, p.rank, lambda, p.eig_min, p.eig_max, p.trace
                );
            }
        }
        out
    }
}
/// Debug adapter that prints a float slice as `[a, b, ...]`, each value in
/// scientific notation with three fractional digits.
struct Rounded<'a>(&'a [f64]);

impl std::fmt::Debug for Rounded<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Rounded(values) = self;
        write!(f, "[")?;
        let mut sep = "";
        for value in values.iter() {
            write!(f, "{sep}{value:.3e}")?;
            sep = ", ";
        }
        write!(f, "]")
    }
}
/// Summarizes `values` as count, mean, population standard deviation, range and
/// endpoints.
///
/// For empty input the mean and sd are `0.0`, the range is `[+inf, -inf]`, and
/// `first`/`last` are NaN.
pub fn prediction_fingerprint(values: &[f64]) -> PredictionFingerprint {
    let n = values.len();
    let denom = n.max(1) as f64;
    let mean = values.iter().sum::<f64>() / denom;
    let var = values
        .iter()
        .map(|&v| {
            let dev = v - mean;
            dev * dev
        })
        .sum::<f64>()
        / denom;
    let (min, max) = values
        .iter()
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
            (lo.min(v), hi.max(v))
        });
    let first = values.first().copied().unwrap_or(f64::NAN);
    let last = values.last().copied().unwrap_or(f64::NAN);
    PredictionFingerprint {
        n,
        mean,
        sd: var.sqrt(),
        min,
        max,
        first,
        last,
    }
}
/// Densifies `design` and summarizes its singular values.
///
/// Singular values above `max(nrows, ncols) * eps * max(sigma_max, 1)` count
/// toward the rank. `sigma_min` is the smallest such value, or `0.0` if none
/// qualifies, in which case the condition number is infinite.
///
/// # Errors
/// Returns the error string if densification fails (presumably including when
/// the `256 * 1024 * 1024` budget is exceeded; unit not visible here) or if the
/// SVD fails.
pub fn design_diagnostics(
    design: &crate::matrix::DesignMatrix,
) -> Result<DesignDiagnostics, String> {
    use crate::faer_ndarray::FaerSvd;
    let dense = design
        .try_to_dense_by_chunks_budgeted("quality diagnostics design SVD", 256 * 1024 * 1024)?;
    let (_, singular_values, _) = dense.svd(false, false).map_err(|e| e.to_string())?;

    let (nrows, ncols) = (design.nrows(), design.ncols());
    let sigma_max = singular_values.iter().copied().fold(0.0, f64::max);
    let tol = (nrows.max(ncols) as f64) * f64::EPSILON * sigma_max.max(1.0);
    let rank = singular_values.iter().filter(|&&v| v > tol).count();
    // Every value that passes the filter is strictly positive, so the minimum
    // is only `0.0` when nothing passes.
    let sigma_min = singular_values
        .iter()
        .copied()
        .filter(|&v| v > tol)
        .reduce(f64::min)
        .unwrap_or(0.0);
    let condition = if sigma_min > 0.0 {
        sigma_max / sigma_min
    } else {
        f64::INFINITY
    };

    Ok(DesignDiagnostics {
        nrows,
        ncols,
        rank,
        condition,
        sigma_min,
        sigma_max,
    })
}
/// Eigen-summarizes each penalty's local matrix, using its lower triangle.
///
/// `lambdas[i]` is paired with `penalties[i]`; a missing entry yields
/// `lambda: None`. Rank counts eigenvalues above `1e-10 * max(max |eig|, 1)`.
/// If the eigendecomposition fails, rank is `0` and `eig_min`/`eig_max` are
/// NaN. A NaN sentinel must not be folded through `f64::min`/`f64::max`,
/// because those drop NaN and would report a misleading `[inf, -inf]`.
pub fn penalty_diagnostics(
    penalties: &[crate::terms::smooth::BlockwisePenalty],
    lambdas: &[f64],
) -> Vec<PenaltyDiagnostics> {
    use crate::faer_ndarray::FaerEigh;
    use faer::Side;
    penalties
        .iter()
        .enumerate()
        .map(|(index, p)| {
            let (rank, eig_min, eig_max) = match p.local.eigh(Side::Lower) {
                Ok((evals, _)) => {
                    let scale = evals
                        .iter()
                        .copied()
                        .map(f64::abs)
                        .fold(0.0, f64::max)
                        .max(1.0);
                    let tol = scale * 1.0e-10;
                    (
                        evals.iter().filter(|&&v| v > tol).count(),
                        evals.iter().copied().fold(f64::INFINITY, f64::min),
                        evals.iter().copied().fold(f64::NEG_INFINITY, f64::max),
                    )
                }
                Err(_) => (0, f64::NAN, f64::NAN),
            };
            PenaltyDiagnostics {
                index,
                col_start: p.col_range.start,
                col_end: p.col_range.end,
                rank,
                lambda: lambdas.get(index).copied(),
                eig_min,
                eig_max,
                trace: p.local.diag().sum(),
            }
        })
        .collect()
}
/// Result of a Python double machine learning reference fit
/// (see `dml_partial_linear_reference`).
pub struct DmlPartialLinearReference {
    /// Whether a backend was found and its fit completed without raising.
    pub available: bool,
    /// `"doubleml"`, `"econml"`, or `"none"`.
    pub backend: String,
    /// Estimated coefficient on `d`; NaN when unavailable.
    pub theta: f64,
    /// Standard error of `theta`; NaN when unavailable.
    pub se: f64,
    /// Lower bound of the 95% confidence interval; NaN when unavailable.
    pub ci_lo: f64,
    /// Upper bound of the 95% confidence interval; NaN when unavailable.
    pub ci_hi: f64,
}
/// Fits a partially linear DML reference in Python: the coefficient `theta` on
/// treatment `d` for outcome `y`, adjusting for the confounder columns `x`.
///
/// Tries `doubleml` (`DoubleMLPLR`) first, then econml's `LinearDML`. Each uses
/// gradient-boosted nuisance models and `n_folds` cross-fitting folds. If
/// neither package is installed, or the fit raises, the result has
/// `available == false`, `backend == "none"` and NaN estimates.
///
/// # Panics
/// Panics if:
/// - `x` is empty;
/// - `y`, `d` and any confounder differ in length;
/// - `n_folds < 2`;
/// - a confounder is named `y` or `d`;
/// - `python3` cannot be spawned or the script exits unsuccessfully.
///
/// Each of these would otherwise surface only as a silent "unavailable" or as
/// the wrong column being read.
pub fn dml_partial_linear_reference(
    y: &[f64],
    d: &[f64],
    x: &[Column<'_>],
    n_folds: usize,
) -> DmlPartialLinearReference {
    assert!(
        !x.is_empty(),
        "DML reference needs at least one confounder X"
    );
    assert_eq!(y.len(), d.len(), "DML reference y/d length mismatch");
    assert!(
        n_folds >= 2,
        "DML reference needs n_folds >= 2 for cross-fitting, got {n_folds}"
    );
    for c in x {
        // A shorter column would be NA-padded in the CSV, and the NaNs would
        // make the fit raise inside the Python `try`.
        assert_eq!(
            c.data.len(),
            y.len(),
            "DML reference confounder {:?} length does not match y",
            c.name
        );
        // The outcome and treatment travel as CSV columns `y` and `d`.
        assert!(
            c.name != "y" && c.name != "d",
            "DML reference confounder name {:?} collides with the y/d columns",
            c.name
        );
    }
    let x_list = x
        .iter()
        .map(|c| format!("{:?}", c.name))
        .collect::<Vec<_>>()
        .join(", ");
    let mut columns: Vec<Column<'_>> = Vec::with_capacity(x.len() + 2);
    columns.push(Column::new("y", y));
    columns.push(Column::new("d", d));
    columns.extend(x.iter().map(|c| Column::new(c.name, c.data)));
    // Python is indentation-sensitive: top-level statements must stay at column 0.
    let body = format!(
        r#"
import numpy as np
_xcols = [{x_list}]
Y = np.asarray(df["y"], dtype=float).reshape(-1)
D = np.asarray(df["d"], dtype=float).reshape(-1)
X = np.column_stack([np.asarray(df[c], dtype=float).reshape(-1) for c in _xcols])
n_folds = {n_folds}
def _have(mod):
    import importlib.util
    return importlib.util.find_spec(mod) is not None
theta = float("nan"); se = float("nan"); backend = 0.0; avail = 0.0
ci_lo = float("nan"); ci_hi = float("nan")
try:
    if _have("doubleml") and _have("sklearn"):
        from doubleml import DoubleMLData, DoubleMLPLR
        from sklearn.ensemble import GradientBoostingRegressor
        # Seed numpy's global RNG so the shuffled fold assignment is reproducible.
        np.random.seed(0)
        data = DoubleMLData.from_arrays(X, Y, D)
        ml_l = GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=0)
        ml_m = GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=1)
        plr = DoubleMLPLR(data, ml_l=ml_l, ml_m=ml_m, n_folds=n_folds)
        plr.fit()
        theta = float(np.asarray(plr.coef).reshape(-1)[0])
        se = float(np.asarray(plr.se).reshape(-1)[0])
        cis = np.asarray(plr.confint(level=0.95))
        ci_lo = float(cis.reshape(-1)[0]); ci_hi = float(cis.reshape(-1)[1])
        backend = 1.0; avail = 1.0
    elif _have("econml") and _have("sklearn"):
        from econml.dml import LinearDML
        from sklearn.ensemble import GradientBoostingRegressor
        est = LinearDML(
            model_y=GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=0),
            model_t=GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=1),
            cv=n_folds, random_state=0,
        )
        est.fit(Y, D, X=None, W=X)
        # With X=None the effect is expected to be constant (empty coef_, theta in
        # intercept_); take inference from whichever parameter theta came from.
        _coef = np.asarray(est.coef_).reshape(-1)
        if _coef.size:
            theta = float(_coef[0])
            inf = est.coef__inference()
        else:
            theta = float(np.asarray(est.intercept_).reshape(-1)[0])
            inf = est.intercept__inference()
        se = float(np.asarray(inf.stderr).reshape(-1)[0])
        lohi = inf.conf_int(alpha=0.05)
        ci_lo = float(np.asarray(lohi[0]).reshape(-1)[0]); ci_hi = float(np.asarray(lohi[1]).reshape(-1)[0])
        backend = 2.0; avail = 1.0
except Exception as _e:
    avail = 0.0; backend = 0.0
    theta = float("nan"); se = float("nan")
    ci_lo = float("nan"); ci_hi = float("nan")
emit("available", [avail])
emit("backend", [backend])
emit("theta", [theta])
emit("se", [se])
emit("ci_lo", [ci_lo])
emit("ci_hi", [ci_hi])
"#
    );
    let r = run_python(&columns, &body);
    let available = r.scalar("available") > 0.5;
    // NaN casts to 0 and so maps to "none".
    let backend = match r.scalar("backend") as i64 {
        1 => "doubleml",
        2 => "econml",
        _ => "none",
    }
    .to_string();
    DmlPartialLinearReference {
        available,
        backend,
        theta: r.scalar("theta"),
        se: r.scalar("se"),
        ci_lo: r.scalar("ci_lo"),
        ci_hi: r.scalar("ci_hi"),
    }
}
/// Pearson correlation coefficient of `a` and `b`.
///
/// The denominator is floored at `1e-300`, so a constant input yields `0.0`
/// instead of NaN.
///
/// # Panics
/// Panics if the slices differ in length.
pub fn pearson(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "pearson length mismatch");
    let n = a.len() as f64;
    let mean_a = a.iter().sum::<f64>() / n;
    let mean_b = b.iter().sum::<f64>() / n;
    let (cross, var_a, var_b) =
        a.iter()
            .zip(b)
            .fold((0.0, 0.0, 0.0), |(cross, var_a, var_b), (x, y)| {
                let dx = x - mean_a;
                let dy = y - mean_b;
                (cross + dx * dy, var_a + dx * dx, var_b + dy * dy)
            });
    let denom = var_a.sqrt() * var_b.sqrt();
    cross / denom.max(1e-300)
}
#[cfg(test)]
mod pad_to_tests {
    use super::pad_to;

    /// Asking for fewer rows than the source has must fail loudly.
    #[test]
    #[should_panic(expected = "pad_to cannot shrink")]
    fn shrink_to_train_split_is_a_clear_error() {
        let full = vec![1.0; 654];
        drop(pad_to(&full, 490));
    }

    /// Full and train-only columns padded to a common `n` keep their real rows
    /// and repeat the last train value in the tail.
    #[test]
    fn pad_full_and_train_to_common_n_is_consistent() {
        let n = 654usize;
        let n_train = 490usize;
        let full: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let train: Vec<f64> = (0..n_train).map(|i| (1000 + i) as f64).collect();
        let full_wire = pad_to(&full, n);
        let train_wire = pad_to(&train, n);
        assert_eq!(full_wire.len(), n);
        assert_eq!(train_wire.len(), n);
        assert_eq!(&full_wire[..n], &full[..]);
        assert_eq!(&train_wire[..n_train], &train[..]);
        assert_eq!(train_wire[n_train], train[n_train - 1]);
        assert_eq!(train_wire[n - 1], train[n_train - 1]);
    }

    /// With no last value to repeat, padding falls back to zeros.
    #[test]
    fn pad_empty_source_fills_with_zero() {
        assert_eq!(pad_to(&[], 3), vec![0.0; 3]);
        assert!(pad_to(&[], 0).is_empty());
    }

    /// Padding to the current length is a plain copy.
    #[test]
    fn pad_to_same_length_is_identity() {
        let v = [1.5, -2.0, 3.25];
        assert_eq!(pad_to(&v, v.len()), v.to_vec());
    }
}