use std::collections::BTreeMap;
use std::io::Write;
use std::path::PathBuf;
use std::process::Command;
use std::sync::atomic::{AtomicU64, Ordering};
/// Values emitted by a reference script, keyed by the name passed to its
/// `emit(key, x)` helper. Each key maps to the (possibly empty) list of numbers
/// emitted for it.
#[derive(Debug, Clone)]
pub struct ReferenceResult {
    values: BTreeMap<String, Vec<f64>>,
}

impl ReferenceResult {
    /// Returns the single value emitted under `key`.
    ///
    /// # Panics
    /// If `key` was never emitted, or was emitted with a number of values other
    /// than exactly one.
    pub fn scalar(&self, key: &str) -> f64 {
        let v = self.vector(key);
        assert_eq!(
            v.len(),
            1,
            "reference key {key:?} carried {} values, expected a scalar",
            v.len()
        );
        v[0]
    }

    /// Returns every value emitted under `key`, in emission order.
    ///
    /// # Panics
    /// If `key` was never emitted; the message lists the keys that were.
    pub fn vector(&self, key: &str) -> &[f64] {
        // Format the diagnostic only on the failure path so successful lookups
        // neither allocate nor collect the key list.
        let Some(v) = self.values.get(key) else {
            panic!(
                "reference did not emit key {key:?}; emitted keys: {:?}",
                self.keys()
            );
        };
        v.as_slice()
    }

    /// Returns all emitted keys in sorted order.
    pub fn keys(&self) -> Vec<&str> {
        self.values.keys().map(String::as_str).collect()
    }
}
/// A named column of `f64` data, borrowed for the duration of a reference run.
pub struct Column<'a> {
    /// Name written to the header row of the CSV handed to the reference script.
    pub name: &'a str,
    /// The column's values, one per row.
    pub data: &'a [f64],
}

impl<'a> Column<'a> {
    /// Pairs `name` with `data`; neither is copied.
    pub fn new(name: &'a str, data: &'a [f64]) -> Self {
        Column { data, name }
    }
}
/// Creates a fresh, empty scratch directory under the system temp directory.
///
/// The name combines `tag`, this process's id and a process-wide counter, so
/// calls within one process never collide. A directory with the same name can
/// still exist from an earlier process: PIDs are reused, and a failed reference
/// run leaves its directory behind. Any such directory is wiped first so stale
/// files cannot leak into this run. This matters because the R `emit` helper
/// appends to `out.txt`.
///
/// # Panics
/// If the directory cannot be created.
fn unique_scratch_dir(tag: &str) -> PathBuf {
    static COUNTER: AtomicU64 = AtomicU64::new(0);
    // Only uniqueness is needed, which `fetch_add` guarantees under any ordering.
    let n = COUNTER.fetch_add(1, Ordering::Relaxed);
    let dir = std::env::temp_dir().join(format!(
        "gam_reference_{}_{}_{}",
        tag,
        std::process::id(),
        n
    ));
    // Best effort: usually the directory does not exist (NotFound is expected).
    let _ = std::fs::remove_dir_all(&dir);
    std::fs::create_dir_all(&dir).expect("create reference scratch dir");
    dir
}
/// Writes `columns` to `path` as CSV: a header row of column names, then one
/// row per observation.
///
/// Values are written with `{:.17e}` (18 significant digits), more than an
/// `f64` needs to round-trip exactly. Non-finite values use Rust's
/// `NaN` / `inf` / `-inf` spellings.
///
/// # Panics
/// - If `columns` is empty or the columns differ in length.
/// - If a column name contains a comma, double quote or line break. The header
///   is written unquoted, so such a name would corrupt the file.
/// - If the file cannot be created or written.
fn write_columns_csv(path: &std::path::Path, columns: &[Column<'_>]) {
    assert!(
        !columns.is_empty(),
        "reference run needs at least one column"
    );
    let nrows = columns[0].data.len();
    for c in columns {
        assert_eq!(
            c.data.len(),
            nrows,
            "reference column {:?} has {} rows, expected {}",
            c.name,
            c.data.len(),
            nrows
        );
        assert!(
            !c.name.chars().any(|ch| matches!(ch, ',' | '"' | '\n' | '\r')),
            "reference column name {:?} contains a comma, quote or line break",
            c.name
        );
    }
    let file = std::fs::File::create(path).expect("create reference data csv");
    // Many small per-cell writes: buffer them instead of building a String per cell.
    let mut out = std::io::BufWriter::new(file);
    let header: Vec<&str> = columns.iter().map(|c| c.name).collect();
    writeln!(out, "{}", header.join(",")).expect("write reference data csv");
    for row in 0..nrows {
        for (i, c) in columns.iter().enumerate() {
            if i > 0 {
                out.write_all(b",").expect("write reference data csv");
            }
            write!(out, "{:.17e}", c.data[row]).expect("write reference data csv");
        }
        writeln!(out).expect("write reference data csv");
    }
    // Flush explicitly: BufWriter's Drop would silently discard a write error.
    out.flush().expect("flush reference data csv");
}
/// Parses the `key: v1 v2 ...` lines written by a reference script's `emit` helper.
///
/// Each line is trimmed and split at its first `:`. The following lines are
/// skipped:
/// - blank lines,
/// - lines without a `:`,
/// - lines whose key is empty.
///
/// Values are whitespace-separated. Besides anything `f64::from_str` accepts,
/// these spellings are mapped explicitly, covering R's and Python's output:
/// - `NA`, `na`, `NaN`, `nan` map to NaN;
/// - `Inf`/`inf` and `-Inf`/`-inf` map to ±infinity.
///
/// A key that appears on several lines keeps only its last list.
///
/// # Panics
/// If a value token cannot be parsed; the message names the token and its key.
fn parse_emitted(text: &str) -> BTreeMap<String, Vec<f64>> {
    let mut out: BTreeMap<String, Vec<f64>> = BTreeMap::new();
    for line in text.lines() {
        let Some((key, rest)) = line.trim().split_once(':') else {
            continue;
        };
        let key = key.trim();
        if key.is_empty() {
            continue;
        }
        let values: Vec<f64> = rest
            .split_whitespace()
            .map(|tok| match tok {
                "NA" | "na" | "NaN" | "nan" => f64::NAN,
                "Inf" | "inf" => f64::INFINITY,
                "-Inf" | "-inf" => f64::NEG_INFINITY,
                other => other.parse::<f64>().unwrap_or_else(|e| {
                    panic!(
                        "reference emitted an unparsable numeric token {other:?} for key {key:?}: {e}"
                    )
                }),
            })
            .collect();
        out.insert(key.to_string(), values);
    }
    out
}
/// How to run one reference language: the interpreter, the script file name,
/// and the code wrapped around the caller's body.
///
/// The generated script receives two arguments: the data CSV path and the
/// output path. The preamble loads the CSV into `df` and defines `emit(key, x)`,
/// which records `key:v1 v2 ...` lines for the output file.
struct ReferenceKind {
    /// Short label used in the scratch directory name.
    tag: &'static str,
    /// File name of the generated script inside the scratch directory.
    script_name: &'static str,
    /// Code placed before the caller's body.
    preamble: &'static str,
    /// Code placed after the caller's body.
    epilogue: &'static str,
    /// Panic message when the interpreter cannot be spawned.
    spawn_expect: &'static str,
    /// Panic message when the script file cannot be written.
    write_expect: &'static str,
    /// Human-readable language name used in failure messages.
    display: &'static str,
}

impl ReferenceKind {
    /// R via `Rscript`.
    ///
    /// `emit` appends each line to the output file immediately, so the
    /// epilogue is empty.
    fn r() -> Self {
        ReferenceKind {
            tag: "r",
            script_name: "script.R",
            preamble: concat!(
                "args <- commandArgs(trailingOnly = TRUE)\n",
                "df <- read.csv(args[1])\n",
                ".OUT <- args[2]\n",
                "emit <- function(key, x) {\n",
                "  cat(sprintf('%s:%s\\n', key, paste(format(as.numeric(x), digits = 17, scientific = TRUE), collapse = ' ')),\n",
                "      file = .OUT, append = TRUE)\n",
                "}\n",
            ),
            epilogue: "",
            spawn_expect: "spawn Rscript (install R to run reference-comparison tests)",
            write_expect: "write reference R script",
            display: "R",
        }
    }

    /// Python via `python3`; numpy must be importable.
    ///
    /// `df` is a pandas DataFrame when pandas can read the CSV. Otherwise it
    /// falls back to a dict of float arrays built with the `csv` module.
    /// `emit` buffers lines in memory; the epilogue writes them out once the
    /// body has finished.
    fn python() -> Self {
        // Indentation is spelled out explicitly: it is Python syntax, and the
        // nested `with`/`for` blocks below must be indented to parse at all.
        ReferenceKind {
            tag: "py",
            script_name: "script.py",
            preamble: concat!(
                "import sys\n",
                "import numpy as np\n",
                "_data_csv, _out = sys.argv[1], sys.argv[2]\n",
                "try:\n",
                "    import pandas as pd\n",
                "    df = pd.read_csv(_data_csv)\n",
                "except Exception:\n",
                "    import csv as _csv\n",
                "    with open(_data_csv, newline='') as _fh:\n",
                "        _r = _csv.DictReader(_fh)\n",
                "        _cols = {k: [] for k in _r.fieldnames}\n",
                "        for _row in _r:\n",
                "            for _k, _v in _row.items():\n",
                "                _cols[_k].append(float(_v))\n",
                "    df = {k: np.asarray(v, dtype=float) for k, v in _cols.items()}\n",
                "_lines = []\n",
                "def emit(key, x):\n",
                "    arr = np.asarray(x, dtype=float).reshape(-1)\n",
                "    _lines.append(str(key) + ':' + ' '.join(repr(float(v)) for v in arr))\n",
            ),
            // `with` closes (and so flushes) the output file deterministically.
            epilogue: concat!(
                "\n",
                "with open(_out, 'w') as _fh:\n",
                "    _fh.write('\\n'.join(_lines) + '\\n')\n",
            ),
            spawn_expect: "spawn python3 (install python3 to run reference-comparison tests)",
            write_expect: "write reference python script",
            display: "Python",
        }
    }

    /// Builds the interpreter command. The caller appends the script, data and
    /// output paths. Any tag other than `"r"` runs `python3`.
    fn command(&self) -> Command {
        match self.tag {
            "r" => {
                let mut cmd = Command::new("Rscript");
                cmd.arg("--vanilla");
                cmd
            }
            _ => Command::new("python3"),
        }
    }
}
/// Runs `body` through the interpreter described by `kind` and collects what the
/// script passed to `emit`.
///
/// A fresh scratch directory receives `data.csv` (holding `columns`), the full
/// script (preamble, body, epilogue) and the script's `out.txt`.
/// - After a successful run the directory is removed.
/// - After a failed run it is kept, and its path appears in the panic message so
///   the script can be inspected or rerun by hand.
///
/// # Panics
/// - If the scratch files cannot be written or the interpreter cannot be spawned.
/// - If the script exits unsuccessfully.
/// - If `out.txt` exists but cannot be read, or contains an unparsable value.
fn run_subprocess(kind: &ReferenceKind, columns: &[Column<'_>], body: &str) -> ReferenceResult {
    let dir = unique_scratch_dir(kind.tag);
    let data_csv = dir.join("data.csv");
    let out_txt = dir.join("out.txt");
    let script = dir.join(kind.script_name);
    write_columns_csv(&data_csv, columns);
    let full = format!("{}\n{body}\n{}", kind.preamble, kind.epilogue);
    std::fs::write(&script, full).expect(kind.write_expect);
    let output = kind
        .command()
        .arg(&script)
        .arg(&data_csv)
        .arg(&out_txt)
        .output()
        .expect(kind.spawn_expect);
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        let stdout = String::from_utf8_lossy(&output.stdout);
        panic!(
            "reference {} body failed (status {:?}); script and data kept in {}\n--- stderr ---\n{stderr}\n--- stdout ---\n{stdout}",
            kind.display,
            output.status.code(),
            dir.display()
        );
    }
    // A script that never calls `emit` may legitimately produce no file. Any
    // other read failure (e.g. invalid UTF-8) is a real error, not "no output".
    let emitted = match std::fs::read_to_string(&out_txt) {
        Ok(text) => text,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => String::new(),
        Err(e) => panic!("read reference output {}: {e}", out_txt.display()),
    };
    let values = parse_emitted(&emitted);
    std::fs::remove_dir_all(&dir).ok();
    ReferenceResult { values }
}
/// Runs `body` as an R script (`Rscript --vanilla`) against `columns`. Returns
/// whatever the script passed to `emit(key, x)`.
///
/// The preamble reads the data with `read.csv` into the data frame `df`.
///
/// # Panics
/// If `Rscript` cannot be spawned or the script exits unsuccessfully.
pub fn run_r(columns: &[Column<'_>], body: &str) -> ReferenceResult {
    let kind = ReferenceKind::r();
    run_subprocess(&kind, columns, body)
}
/// Runs `body` as a `python3` script against `columns`. Returns whatever the
/// script passed to `emit(key, x)`.
///
/// The preamble imports numpy, so numpy must be installed. The data is exposed
/// as `df`: a pandas DataFrame when pandas is usable, otherwise a dict of NumPy
/// arrays keyed by column name.
///
/// # Panics
/// If `python3` cannot be spawned or the script exits unsuccessfully.
pub fn run_python(columns: &[Column<'_>], body: &str) -> ReferenceResult {
    let kind = ReferenceKind::python();
    run_subprocess(&kind, columns, body)
}
/// Relative L2 error `‖a − b‖₂ / ‖b‖₂`, treating `b` as the reference.
///
/// The squared reference norm is floored at `1e-300`, so an all-zero `b` does
/// not divide by zero. Empty inputs give `0`.
///
/// # Panics
/// If `a` and `b` have different lengths.
pub fn relative_l2(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "relative_l2 length mismatch");
    let squared_error: f64 = a
        .iter()
        .zip(b)
        .map(|(p, q)| {
            let delta = p - q;
            delta * delta
        })
        .sum();
    let squared_norm: f64 = b.iter().map(|q| q * q).sum();
    let ratio = squared_error / f64::max(squared_norm, 1e-300);
    ratio.sqrt()
}
/// Root-mean-square difference between `a` and `b`.
///
/// The mean divides by `max(len, 1)`, so empty inputs give `0`. A NaN anywhere
/// propagates to the result.
///
/// # Panics
/// If `a` and `b` have different lengths.
pub fn rmse(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "rmse length mismatch");
    let count = std::cmp::max(a.len(), 1) as f64;
    let sum_sq: f64 = a
        .iter()
        .zip(b)
        .map(|(x, y)| {
            let resid = x - y;
            resid * resid
        })
        .sum();
    (sum_sq / count).sqrt()
}
/// Largest absolute element-wise difference `max_i |a_i − b_i|`.
///
/// Empty inputs give `0.0`. If any difference is NaN (a NaN on either side, or
/// `inf - inf`), the result is NaN. Plain `f64::max` would silently drop the
/// NaN, letting a check like `max_abs_diff(a, b) < tol` pass on broken output.
/// This matches how `rmse` and `relative_l2` propagate NaN through their sums.
///
/// # Panics
/// If `a` and `b` have different lengths.
pub fn max_abs_diff(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "max_abs_diff length mismatch");
    a.iter()
        .zip(b)
        .map(|(x, y)| (x - y).abs())
        .fold(0.0, |acc: f64, d| {
            if acc.is_nan() || d.is_nan() {
                f64::NAN
            } else {
                acc.max(d)
            }
        })
}
/// Result of `dml_partial_linear_reference`: the effect of the treatment `d`
/// on the outcome `y`, as estimated by an external Python DML library.
#[derive(Debug, Clone)]
pub struct DmlPartialLinearReference {
    /// `true` when one of the Python backends ran to completion.
    pub available: bool,
    /// Backend that produced the numbers: `"doubleml"`, `"econml"`, or `"none"`.
    pub backend: String,
    /// Point estimate of the treatment coefficient; NaN when unavailable.
    pub theta: f64,
    /// Standard error of `theta`; NaN when unavailable.
    pub se: f64,
    /// Lower end of the 95% confidence interval; NaN when unavailable.
    pub ci_lo: f64,
    /// Upper end of the 95% confidence interval; NaN when unavailable.
    pub ci_hi: f64,
}
/// Estimates the partially linear model with double/debiased ML in Python, to
/// serve as a reference for our own estimator.
///
/// DoubleML's `DoubleMLPLR` is tried first, then econml's `LinearDML` (with `x`
/// passed as controls `W`). Either backend requires scikit-learn. Both use
/// gradient-boosted regressors (100 trees, depth 3) for the nuisance models and
/// `n_folds`-fold cross-fitting. If no backend is installed, or the backend
/// raises, the result has `available == false`, `backend == "none"` and NaN
/// estimates.
///
/// `y` and `d` are sent as CSV columns named `"y"` and `"d"`; each confounder
/// keeps its own column name.
///
/// # Panics
/// - If `x` is empty, `y` and `d` differ in length, a confounder is named `"y"`
///   or `"d"`, or the column lengths disagree.
/// - If `python3` cannot be spawned or the script fails. The preamble imports
///   numpy, so a missing numpy panics here rather than reporting unavailable.
pub fn dml_partial_linear_reference(
    y: &[f64],
    d: &[f64],
    x: &[Column<'_>],
    n_folds: usize,
) -> DmlPartialLinearReference {
    assert!(
        !x.is_empty(),
        "DML reference needs at least one confounder X"
    );
    assert_eq!(y.len(), d.len(), "DML reference y/d length mismatch");
    for c in x {
        assert!(
            c.name != "y" && c.name != "d",
            "DML reference confounder {:?} collides with the reserved outcome/treatment columns \"y\"/\"d\"",
            c.name
        );
    }
    // `{:?}` on a &str yields a double-quoted literal, which Python also accepts
    // for plain ASCII names.
    let x_list = x
        .iter()
        .map(|c| format!("{:?}", c.name))
        .collect::<Vec<_>>()
        .join(", ");
    let mut columns: Vec<Column<'_>> = Vec::with_capacity(x.len() + 2);
    columns.push(Column::new("y", y));
    columns.push(Column::new("d", d));
    columns.extend(x.iter().map(|c| Column::new(c.name, c.data)));
    // The Python below is indentation-sensitive, so its lines start at column 0
    // of the raw string. `{x_list}` and `{n_folds}` are the only format
    // placeholders; the script contains no literal braces.
    let body = format!(
        r#"
import traceback
import numpy as np
_xcols = [{x_list}]
Y = np.asarray(df["y"], dtype=float).reshape(-1)
D = np.asarray(df["d"], dtype=float).reshape(-1)
X = np.column_stack([np.asarray(df[c], dtype=float).reshape(-1) for c in _xcols])
n_folds = {n_folds}
def _have(mod):
    import importlib.util
    return importlib.util.find_spec(mod) is not None
theta = float("nan"); se = float("nan"); backend = 0.0; avail = 0.0
ci_lo = float("nan"); ci_hi = float("nan")
try:
    if _have("doubleml") and _have("sklearn"):
        from doubleml import DoubleMLData, DoubleMLPLR
        from sklearn.ensemble import GradientBoostingRegressor
        # DoubleML draws its cross-fitting split from NumPy's global RNG;
        # seed it so repeated runs produce the same reference.
        np.random.seed(0)
        data = DoubleMLData.from_arrays(X, Y, D)
        ml_l = GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=0)
        ml_m = GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=1)
        plr = DoubleMLPLR(data, ml_l=ml_l, ml_m=ml_m, n_folds=n_folds)
        plr.fit()
        theta = float(np.asarray(plr.coef).reshape(-1)[0])
        se = float(np.asarray(plr.se).reshape(-1)[0])
        cis = np.asarray(plr.confint(level=0.95)).reshape(-1)
        ci_lo = float(cis[0]); ci_hi = float(cis[1])
        backend = 1.0; avail = 1.0
    elif _have("econml") and _have("sklearn"):
        from econml.dml import LinearDML
        from sklearn.ensemble import GradientBoostingRegressor
        est = LinearDML(
            model_y=GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=0),
            model_t=GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=1),
            cv=n_folds, random_state=0,
        )
        est.fit(Y, D, X=None, W=X)
        # With X=None the constant effect lives in the intercept. Take theta
        # and its inference from the same place.
        if np.asarray(est.coef_).size:
            theta = float(np.asarray(est.coef_).reshape(-1)[0])
            inf = est.coef__inference()
        else:
            theta = float(np.asarray(est.intercept_).reshape(-1)[0])
            inf = est.intercept__inference()
        se = float(np.asarray(inf.stderr).reshape(-1)[0])
        lohi = inf.conf_int(alpha=0.05)
        ci_lo = float(np.asarray(lohi[0]).reshape(-1)[0]); ci_hi = float(np.asarray(lohi[1]).reshape(-1)[0])
        backend = 2.0; avail = 1.0
except Exception:
    # Any backend failure degrades to "unavailable". Print the traceback to
    # stderr so the cause is visible when the script is run by hand.
    traceback.print_exc()
    avail = 0.0; backend = 0.0
    theta = float("nan"); se = float("nan")
    ci_lo = float("nan"); ci_hi = float("nan")
emit("available", [avail])
emit("backend", [backend])
emit("theta", [theta])
emit("se", [se])
emit("ci_lo", [ci_lo])
emit("ci_hi", [ci_hi])
"#
    );
    let r = run_python(&columns, &body);
    let available = r.scalar("available") > 0.5;
    // `as i64` saturates and maps NaN to 0, so odd values fall through to "none".
    let backend = match r.scalar("backend") as i64 {
        1 => "doubleml",
        2 => "econml",
        _ => "none",
    }
    .to_string();
    DmlPartialLinearReference {
        available,
        backend,
        theta: r.scalar("theta"),
        se: r.scalar("se"),
        ci_lo: r.scalar("ci_lo"),
        ci_hi: r.scalar("ci_hi"),
    }
}
/// Pearson correlation coefficient between `a` and `b`.
///
/// The denominator `sqrt(Σdx²)·sqrt(Σdy²)` is floored at `1e-300`. As a
/// result, constant or empty inputs yield `0` rather than dividing by zero.
///
/// # Panics
/// If `a` and `b` have different lengths.
pub fn pearson(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "pearson length mismatch");
    let len = a.len() as f64;
    let mean_a = a.iter().sum::<f64>() / len;
    let mean_b = b.iter().sum::<f64>() / len;
    let (cov, var_a, var_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(cov, var_a, var_b), (x, y)| {
            let dx = x - mean_a;
            let dy = y - mean_b;
            (cov + dx * dy, var_a + dx * dx, var_b + dy * dy)
        },
    );
    let scale = var_a.sqrt() * var_b.sqrt();
    cov / scale.max(1e-300)
}