Skip to main content

gam_test_support/
reference.rs

1//! End-to-end quality comparison against mature, standard statistical tools.
2//!
3//! The harness lets a `cargo test` integration test fit the *same* data with a
4//! trusted reference implementation and assert that gam's fitted function,
5//! coefficients, effective degrees of freedom, predictions, or uncertainty
6//! agree with what practitioners already trust. It is deliberately
7//! tool-agnostic: a test supplies an arbitrary R or Python body and the harness
8//! handles all of the data plumbing and result parsing.
9//!
10//! Reference toolchains supported today:
11//!   * **R** via `Rscript` — `mgcv`, `gamlss`, `survival`, and any package the
12//!     body chooses to `library()`.
13//!   * **Python** via `python3` — `scikit-learn`, `scipy`, `statsmodels`,
14//!     `lifelines`, `scikit-survival`, and anything else importable.
15//!
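//! There is **no general skip path**. If the interpreter or a required package
//! is not installed, `run_r`/`run_python` fail loudly and the test fails — a
//! missing reference dependency is a real failure, not a silent pass. CI is
//! expected to provision the reference stack. The only exceptions are narrow,
//! documented environmental gates: genuine hardware gates (e.g. CUDA, which
//! live in `tests/common/gpu_gate.rs`, not here), `r_package_available` for the
//! few best-effort native R references such as R-INLA, and the `available`
//! flag of `dml_partial_linear_reference` for the optional DoubleML/EconML stack.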
21//!
22//! Wire protocol (kept dependency-free on purpose — no JSON crate on the R/
23//! Python side): the test body calls `emit("key", numeric_vector)` for every
24//! quantity it wants to return. The harness reads these back as
25//! `key: v1 v2 v3 ...` lines and exposes them as `f64` scalars / vectors.
26
27use std::collections::BTreeMap;
28use std::io::Write;
29use std::path::PathBuf;
30use std::process::Command;
31use std::sync::atomic::{AtomicU64, Ordering};
32
/// Parsed results emitted by a reference-tool body via `emit(key, values)`.
///
/// Each emitted key maps to the numeric vector the body wrote under it. The
/// panicking accessors fail the test loudly (listing every emitted key) instead
/// of returning defaults, so a misspelled key can never silently pass.
pub struct ReferenceResult {
    values: BTreeMap<String, Vec<f64>>,
}

impl ReferenceResult {
    /// Fetch a single scalar emitted under `key`.
    ///
    /// # Panics
    /// Panics when the key is missing or did not carry exactly one value.
    pub fn scalar(&self, key: &str) -> f64 {
        let v = self.vector(key);
        assert_eq!(
            v.len(),
            1,
            "reference key {key:?} carried {} values, expected a scalar",
            v.len()
        );
        v[0]
    }

    /// Look up the vector emitted under `key`, or `None` when the body did not
    /// emit it. Use this for genuinely optional outputs; prefer [`Self::vector`]
    /// for required ones so a missing key fails loudly.
    pub fn get(&self, key: &str) -> Option<&[f64]> {
        self.values.get(key).map(Vec::as_slice)
    }

    /// Fetch the vector emitted under `key`.
    ///
    /// # Panics
    /// Panics when the key is missing; the message lists the emitted keys.
    pub fn vector(&self, key: &str) -> &[f64] {
        // The diagnostic (which allocates the key list) is only built on the
        // failure path, not on every successful lookup.
        self.get(key).unwrap_or_else(|| {
            panic!(
                "reference did not emit key {key:?}; emitted keys: {:?}",
                self.keys()
            )
        })
    }

    /// Keys the reference body emitted, in sorted order, for diagnostics.
    pub fn keys(&self) -> Vec<&str> {
        self.values.keys().map(String::as_str).collect()
    }
}
67
/// A named numeric column handed to the reference body as a `data.frame` column
/// (R) or `df["name"]` (Python: a pandas column, or a NumPy array when pandas is
/// unavailable).
///
/// Columns in one run may have different lengths. Shorter columns are padded
/// with missing values (NaN on the reference side) up to the longest column, so
/// the reference body must slice each column to its own semantic length or mask
/// with `is.finite(...)` / `np.isfinite(...)`.
#[derive(Clone, Copy, Debug)]
pub struct Column<'a> {
    /// Column header, referenced verbatim inside the reference body.
    pub name: &'a str,
    /// Column values, one per row (rows past this length read as missing).
    pub data: &'a [f64],
}

impl<'a> Column<'a> {
    /// Convenience constructor.
    pub fn new(name: &'a str, data: &'a [f64]) -> Self {
        Self { name, data }
    }
}
83
/// Create and return a fresh, empty scratch directory under the system temp dir.
///
/// The name combines `tag`, the current process id and a per-process counter,
/// so concurrent runs inside one test binary never collide. A directory with
/// the same name can still exist: a failed run keeps its scratch dir, and a
/// later process may reuse that PID. Such a leftover is removed first, because
/// the R `emit` helper *appends* to `out.txt` and stale output would otherwise
/// leak into this run's results.
///
/// # Panics
/// Panics when the directory cannot be created.
fn unique_scratch_dir(tag: &str) -> PathBuf {
    static COUNTER: AtomicU64 = AtomicU64::new(0);
    let n = COUNTER.fetch_add(1, Ordering::Relaxed);
    let dir = std::env::temp_dir().join(format!(
        "gam_reference_{tag}_{}_{n}",
        std::process::id()
    ));
    // Best effort: usually there is nothing to remove. A genuine removal problem
    // resurfaces as a create failure just below.
    let _ = std::fs::remove_dir_all(&dir);
    std::fs::create_dir_all(&dir).expect("create reference scratch dir");
    dir
}
97
98fn write_columns_csv(path: &std::path::Path, columns: &[Column<'_>]) {
99    assert!(
100        !columns.is_empty(),
101        "reference run needs at least one column"
102    );
103    // Ragged columns are first-class: tests routinely ship a training column
104    // (n rows), a grid column (grid_n rows), and a scalar option (1 row) in
105    // the same data.frame, and the reference bodies dereference the surplus
106    // tail with `is.finite(...)` / NaN filters on their side. The CSV row
107    // grid runs from row 0 to `nrows = max column length`; shorter columns
108    // emit `NaN` past their own length so every column appears at its
109    // natural width to the reference interpreter.
110    let nrows = columns.iter().map(|c| c.data.len()).max().unwrap_or(0);
111    assert!(
112        nrows > 0,
113        "reference run needs at least one non-empty column"
114    );
115    let mut s = String::new();
116    s.push_str(
117        &columns
118            .iter()
119            .map(|c| c.name.to_string())
120            .collect::<Vec<_>>()
121            .join(","),
122    );
123    s.push('\n');
124    for row in 0..nrows {
125        let line = columns
126            .iter()
127            .map(|c| match c.data.get(row) {
128                // `NA` is the missing-value token recognised by R
129                // `read.csv` (its default `na.strings`) AND by pandas
130                // `read_csv` (its default `na_values` list). Both
131                // produce IEEE NaN downstream, so the reference body's
132                // `is.finite(...)` / `np.isfinite(...)` mask filters
133                // out the surplus tail of a short column cleanly.
134                Some(value) => format!("{value:.17e}"),
135                None => "NA".to_string(),
136            })
137            .collect::<Vec<_>>()
138            .join(",");
139        s.push_str(&line);
140        s.push('\n');
141    }
142    let mut f = std::fs::File::create(path).expect("write reference data csv");
143    f.write_all(s.as_bytes()).expect("flush reference data csv");
144}
145
/// Parse the `key: v1 v2 ...` line protocol written by the reference `emit`
/// helpers.
///
/// Blank lines, lines without a `:` separator, and lines whose key is empty are
/// skipped. Values are whitespace-separated, so R's width-padded `format`
/// output parses cleanly. A key emitted on more than one line keeps only its
/// last line.
///
/// # Panics
/// Panics, naming the key and the offending token, when a value cannot be
/// parsed as a number.
fn parse_emitted(text: &str) -> BTreeMap<String, Vec<f64>> {
    let mut out: BTreeMap<String, Vec<f64>> = BTreeMap::new();
    for line in text.lines() {
        let Some((key, rest)) = line.trim().split_once(':') else {
            continue;
        };
        let key = key.trim();
        if key.is_empty() {
            continue;
        }
        let values = rest
            .split_whitespace()
            .map(|tok| parse_emitted_token(key, tok))
            .collect();
        out.insert(key.to_string(), values);
    }
    out
}

/// Parse one value token emitted under `key`. Accepts R's (`NA`, `NaN`, `Inf`,
/// `-Inf`) and Python's (`nan`, `inf`, `-inf`) spellings of non-finite values
/// in addition to anything `f64::from_str` understands.
fn parse_emitted_token(key: &str, tok: &str) -> f64 {
    match tok {
        "NA" | "na" | "NaN" | "nan" => f64::NAN,
        "Inf" | "inf" => f64::INFINITY,
        "-Inf" | "-inf" => f64::NEG_INFINITY,
        other => other.parse::<f64>().unwrap_or_else(|err| {
            panic!(
                "reference emitted an unparsable numeric token {other:?} under key {key:?}: {err}"
            )
        }),
    }
}
175
/// Per-language specifics for a reference subprocess run: scratch-dir tag,
/// script filename, interpreter program and its fixed leading arguments, the
/// script preamble/epilogue wrapped around the test body, and the strings used
/// in failure messages.
struct ReferenceKind {
    /// Short tag for the scratch directory (`"r"`, `"py"`).
    tag: &'static str,
    /// Script filename written into the scratch directory.
    script_name: &'static str,
    /// Interpreter executable, looked up on `PATH` by [`Command::new`].
    program: &'static str,
    /// Fixed arguments placed before the script path (e.g. R's `--vanilla`).
    leading_args: &'static [&'static str],
    /// Code prepended before the test body (exposes `df`, output path, `emit`).
    preamble: &'static str,
    /// Code appended after the test body (e.g. Python's flush-to-file step).
    epilogue: &'static str,
    /// `.expect` message text for the spawn failure.
    spawn_expect: &'static str,
    /// `.expect` message text for the script-write failure.
    write_expect: &'static str,
    /// Human-readable language name used in the non-zero-exit assertion.
    display: &'static str,
}

impl ReferenceKind {
    /// R via `Rscript --vanilla`. The preamble reads the data CSV into `df`
    /// and defines `emit`, which appends one protocol line per call to the
    /// output file.
    fn r() -> Self {
        ReferenceKind {
            tag: "r",
            script_name: "script.R",
            program: "Rscript",
            leading_args: &["--vanilla"],
            preamble: "\
args <- commandArgs(trailingOnly = TRUE)\n\
df <- read.csv(args[1])\n\
.OUT <- args[2]\n\
emit <- function(key, x) {\n\
  cat(sprintf('%s:%s\\n', key, paste(format(as.numeric(x), digits = 17, scientific = TRUE), collapse = ' ')),\n\
      file = .OUT, append = TRUE)\n\
}\n",
            epilogue: "",
            spawn_expect: "spawn Rscript (install R to run reference-comparison tests)",
            write_expect: "write reference R script",
            display: "R",
        }
    }

    /// Python via `python3`. `df` is a pandas frame when pandas imports, or a
    /// dict of NumPy arrays read with the `csv` module otherwise. `emit`
    /// buffers lines that the epilogue writes to the output file in one go.
    fn python() -> Self {
        ReferenceKind {
            tag: "py",
            script_name: "script.py",
            program: "python3",
            leading_args: &[],
            // NOTE: built with concat! of per-line literals, NOT the `"\<newline>`
            // continuation idiom. Rust's `\<newline>` continuation strips the
            // leading whitespace of the following source line, which silently
            // destroys Python's significant indentation and makes every Python
            // reference die with `IndentationError`. Keeping the indentation
            // INSIDE each literal makes it immune to that.
            //
            // The csv fallback must map the `NA` token (ragged-column padding
            // written by `write_columns_csv`) and empty cells to NaN itself;
            // plain `float('NA')` raises `ValueError`.
            preamble: concat!(
                "import sys\n",
                "import numpy as np\n",
                "_data_csv, _out = sys.argv[1], sys.argv[2]\n",
                "try:\n",
                "    import pandas as pd\n",
                "    df = pd.read_csv(_data_csv)\n",
                "except Exception:\n",
                "    import csv as _csv\n",
                "    def _cell(_v):\n",
                "        return float('nan') if _v.strip() in ('', 'NA') else float(_v)\n",
                "    with open(_data_csv) as _fh:\n",
                "        _r = _csv.DictReader(_fh)\n",
                "        _cols = {k: [] for k in _r.fieldnames}\n",
                "        for _row in _r:\n",
                "            for _k, _v in _row.items():\n",
                "                _cols[_k].append(_cell(_v))\n",
                "    df = {k: np.asarray(v, dtype=float) for k, v in _cols.items()}\n",
                "_lines = []\n",
                "def emit(key, x):\n",
                "    arr = np.asarray(x, dtype=float).reshape(-1)\n",
                "    _lines.append(str(key) + ':' + ' '.join(repr(float(v)) for v in arr))\n",
            ),
            // `with` closes (and so flushes) the output file deterministically.
            epilogue: concat!(
                "\n",
                "with open(_out, 'w') as _out_fh:\n",
                "    _out_fh.write('\\n'.join(_lines) + '\\n')\n",
            ),
            spawn_expect: "spawn python3 (install python3 to run reference-comparison tests)",
            write_expect: "write reference python script",
            display: "Python",
        }
    }

    /// Build the interpreter command with its fixed leading arguments (before
    /// the script path / data CSV / output path are appended by the runner).
    fn command(&self) -> Command {
        let mut cmd = Command::new(self.program);
        cmd.args(self.leading_args);
        cmd
    }
}
269
270/// Run a reference body in the interpreter described by `kind`. The columns are
271/// written to a CSV the script reads; the wrapped script exposes `df`, the
272/// output path, and `emit("key", values)`, runs `body`, and emits results back
273/// over the line protocol. Fails the test with captured stderr/stdout when the
274/// interpreter exits non-zero — a broken or unavailable reference run (missing
275/// interpreter, missing package, runtime error) is a hard failure, never a
276/// silent skip.
277fn run_subprocess(kind: &ReferenceKind, columns: &[Column<'_>], body: &str) -> ReferenceResult {
278    let dir = unique_scratch_dir(kind.tag);
279    let data_csv = dir.join("data.csv");
280    let out_txt = dir.join("out.txt");
281    let script = dir.join(kind.script_name);
282    write_columns_csv(&data_csv, columns);
283
284    let preamble = kind.preamble;
285    let epilogue = kind.epilogue;
286    let full = format!("{preamble}\n{body}\n{epilogue}");
287    std::fs::write(&script, full).expect(kind.write_expect);
288
289    let output = kind
290        .command()
291        .arg(&script)
292        .arg(&data_csv)
293        .arg(&out_txt)
294        .output()
295        .expect(kind.spawn_expect);
296
297    let stderr = String::from_utf8_lossy(&output.stderr);
298    let stdout = String::from_utf8_lossy(&output.stdout);
299    assert!(
300        output.status.success(),
301        "reference {} body failed (status {:?})\n--- stderr ---\n{stderr}\n--- stdout ---\n{stdout}",
302        kind.display,
303        output.status.code()
304    );
305
306    let emitted = std::fs::read_to_string(&out_txt).unwrap_or_default();
307    let parsed = parse_emitted(&emitted);
308    std::fs::remove_dir_all(&dir).ok();
309    ReferenceResult { values: parsed }
310}
311
312/// Run an R reference body. The columns are exposed as a `data.frame` named
313/// `df`; the body calls `emit("key", numeric_vector)` to return results. The
314/// harness prepends the `df`, output path, and `emit` helper. Fails the test
315/// with the captured stderr when R exits non-zero — a broken or unavailable
316/// reference run (missing `Rscript`, missing package, R error) is a hard test
317/// failure, never a silent skip.
318pub fn run_r(columns: &[Column<'_>], body: &str) -> ReferenceResult {
319    run_subprocess(&ReferenceKind::r(), columns, body)
320}
321
322/// Probe whether an R package can actually be **loaded** (namespace + any native
323/// `dyn.load`) in the reference interpreter, without raising. Returns `true`
324/// only when `requireNamespace` reports the package is usable.
325///
326/// This is the narrow, documented environmental-gate escape hatch — the same
327/// category as the CUDA hardware gate and the DoubleML/EconML `available` flag,
328/// NOT a general skip path. It exists for the handful of references the
329/// reference-quality CI job provisions only *best-effort* because they are large
330/// and/or native and not reliably installable on a bare runner (notably R-INLA,
331/// whose bundled native binaries `dyn.load` per-OS). A test that gates on this
332/// MUST still assert its tool-free, absolute quality bars unconditionally and
333/// skip only the *match-or-beat-vs-this-tool* arm when the tool is genuinely
334/// absent — never the gam-side claim. Every other reference dependency remains a
335/// hard failure via [`run_r`]/[`run_python`].
336pub fn r_package_available(pkg: &str) -> bool {
337    if ReferenceKind::r()
338        .command()
339        .arg("--version")
340        .output()
341        .is_err()
342    {
343        return false;
344    }
345    // `requireNamespace` is contractually non-throwing (returns FALSE and warns
346    // on a failed load), so a present probe interpreter exits zero and this is
347    // never itself a hard failure. Treat a missing `Rscript` binary the same as
348    // a missing package for this narrowly-scoped environmental gate: callers
349    // that use the gate still assert their tool-free gam quality bars and skip
350    // only the external-reference arm.
351    let script =
352        format!("cat(if (requireNamespace(\"{pkg}\", quietly = TRUE)) \"1\\n\" else \"0\\n\")");
353    let output = match Command::new("Rscript")
354        .arg("--vanilla")
355        .arg("-e")
356        .arg(script)
357        .output()
358    {
359        Ok(output) => output,
360        Err(_) => return false,
361    };
362    output.status.success() && String::from_utf8_lossy(&output.stdout).trim() == "1"
363}
364
365/// Run a Python reference body. The columns are exposed as a pandas `df` (or,
366/// when pandas is unavailable, a dict of NumPy arrays). The body calls
367/// `emit("key", iterable)` to return results. Fails the test with captured
368/// stderr when Python exits non-zero (missing `python3`, missing module, or a
369/// raised exception).
370pub fn run_python(columns: &[Column<'_>], body: &str) -> ReferenceResult {
371    run_subprocess(&ReferenceKind::python(), columns, body)
372}
373
/// Scale-free distance between a fitted function `a` and a reference `b`
/// evaluated on the same grid: `sqrt(Σ(a−b)² / max(Σb², 1e-300))`, i.e.
/// `‖a − b‖ / ‖b‖` with a tiny floor that keeps an all-zero reference finite.
///
/// # Panics
/// Panics when the two slices differ in length.
pub fn relative_l2(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "relative_l2 length mismatch");
    let squared_gap = a
        .iter()
        .zip(b)
        .map(|(x, y)| {
            let d = x - y;
            d * d
        })
        .sum::<f64>();
    let squared_reference = b.iter().map(|y| y * y).sum::<f64>();
    let ratio = squared_gap / squared_reference.max(1e-300);
    ratio.sqrt()
}
383
/// Root-mean-square difference between two equal-length vectors. An empty pair
/// yields `0.0`, because the divisor is floored at one.
///
/// # Panics
/// Panics when the two slices differ in length.
pub fn rmse(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "rmse length mismatch");
    let count = a.len().max(1) as f64;
    let total: f64 = a
        .iter()
        .zip(b.iter())
        .map(|(&x, &y)| (x - y) * (x - y))
        .sum();
    (total / count).sqrt()
}
390
/// Coefficient of determination `1 − SS_res / SS_tot`, with the mean of `truth`
/// as the baseline predictor. `SS_tot` is floored at `1e-300`, so a constant
/// `truth` gives a (typically huge negative) finite score instead of NaN.
///
/// # Panics
/// Panics when the two slices differ in length.
pub fn r2(pred: &[f64], truth: &[f64]) -> f64 {
    assert_eq!(pred.len(), truth.len(), "r2 length mismatch");
    let baseline = truth.iter().sum::<f64>() / truth.len() as f64;
    let residual: f64 = truth
        .iter()
        .zip(pred)
        .map(|(t, p)| (t - p) * (t - p))
        .sum();
    let spread: f64 = truth
        .iter()
        .map(|t| (t - baseline) * (t - baseline))
        .sum();
    1.0 - residual / spread.max(1e-300)
}
400
/// Out-of-sample coefficient of determination.
///
/// Computed exactly like [`r2`]. The baseline is the mean of the `truth` slice
/// passed in (the held-out responses themselves), not a training-set mean.
/// Callers are expected to pass held-out predictions and responses.
pub fn held_out_r2(pred: &[f64], truth: &[f64]) -> f64 {
    r2(pred, truth)
}
405
/// Extend `v` to `len` entries by repeating its last value (or `0.0` when `v`
/// is empty).
///
/// This helper only grows. It lifts a short column up to the common wire width
/// of a ragged reference frame, and the padded tail is never read by a
/// correctly-sliced reference body. Shrinking would silently drop real data, so
/// a target shorter than the source is rejected as a caller bug. The usual
/// culprit is padding a full-data column to a train-split width. Instead, pad
/// every column to a single `n = max(len)` and slice each one by its own
/// semantic length inside the reference body.
///
/// # Panics
/// Panics when `len < v.len()`.
pub fn pad_to(v: &[f64], len: usize) -> Vec<f64> {
    assert!(
        v.len() <= len,
        "pad_to cannot shrink: source has {} rows but the pad target is {len} \
         (a shorter target would drop real data). Pad every column to a common \
         n = max(column length) and slice by semantic length in the reference body.",
        v.len()
    );
    let filler = match v.last() {
        Some(&tail) => tail,
        None => 0.0,
    };
    let mut padded = Vec::with_capacity(len);
    padded.extend_from_slice(v);
    padded.extend(std::iter::repeat(filler).take(len - v.len()));
    padded
}
429
/// Maximum absolute difference between two equal-length vectors.
///
/// Entries that compare equal (including identical infinities) contribute 0.
/// Any NaN entry makes the result NaN. `f64::max` ignores NaN operands, so a
/// plain max-fold would report a NaN prediction as a perfect match and let a
/// tolerance check pass. An empty pair yields `0.0`.
///
/// # Panics
/// Panics when the two slices differ in length.
pub fn max_abs_diff(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "max_abs_diff length mismatch");
    let mut worst = 0.0_f64;
    for (&x, &y) in a.iter().zip(b) {
        // `x == y` is false whenever either side is NaN, and true for matching
        // infinities (whose difference would otherwise be NaN).
        let gap = if x == y { 0.0 } else { (x - y).abs() };
        if gap.is_nan() {
            return f64::NAN;
        }
        worst = worst.max(gap);
    }
    worst
}
438
/// Compact, reusable diagnostics for truth/reference quality tests.
///
/// Built by [`QualityDiagnostics::from_standard_fit`], optionally enriched with
/// `with_truth_rmse` / `with_reference_gap`, and rendered as a single log line
/// by `report`.
#[derive(Clone, Debug)]
pub struct QualityDiagnostics {
    /// Free-form label identifying the fit in the report line.
    pub label: String,
    /// RMSE of gam's prediction against the ground truth (set by `with_truth_rmse`).
    pub rmse_vs_truth: Option<f64>,
    /// RMSE between gam's prediction and the reference tool's prediction.
    pub rmse_vs_reference: Option<f64>,
    /// RMSE of the reference tool's prediction against the ground truth, when
    /// a truth vector was supplied to `with_reference_gap`.
    pub reference_rmse_vs_truth: Option<f64>,
    /// Total effective degrees of freedom; `None` when the fit has no inference output.
    pub edf_total: Option<f64>,
    /// Log smoothing parameters, copied from the fit's `log_lambdas`.
    pub rho: Vec<f64>,
    /// Smoothing parameters λ, copied from the fit's `lambdas`.
    pub lambda: Vec<f64>,
    /// Singular-value summary of the design matrix; `None` when it could not be computed.
    pub design: Option<DesignDiagnostics>,
    /// Eigenvalue/trace summary for each penalty block, in penalty order.
    pub penalties: Vec<PenaltyDiagnostics>,
    /// Summary statistics of the prediction vector passed to `with_truth_rmse`.
    pub prediction: Option<PredictionFingerprint>,
}
453
/// Singular-value summary of a design matrix, produced by `design_diagnostics`.
#[derive(Clone, Debug)]
pub struct DesignDiagnostics {
    /// Number of rows (observations) in the design.
    pub nrows: usize,
    /// Number of columns (coefficients) in the design.
    pub ncols: usize,
    /// Numerical rank: singular values above
    /// `max(nrows, ncols) · ε · max(sigma_max, 1)`.
    pub rank: usize,
    /// `sigma_max / sigma_min`, or `∞` when no singular value clears the rank tolerance.
    pub condition: f64,
    /// Smallest singular value above the rank tolerance (0 when none does).
    pub sigma_min: f64,
    /// Largest singular value.
    pub sigma_max: f64,
}
463
/// Eigenvalue/trace summary of one penalty block, produced by `penalty_diagnostics`.
#[derive(Clone, Debug)]
pub struct PenaltyDiagnostics {
    /// Position of the penalty in the penalty list (also used to index the λ vector).
    pub index: usize,
    /// Start of the coefficient range the penalty's local block applies to.
    pub col_start: usize,
    /// End of that coefficient range (presumably exclusive, as a Rust `Range` — confirm).
    pub col_end: usize,
    /// Count of eigenvalues above `1e-10 · max(1, max |eigenvalue|)`.
    pub rank: usize,
    /// Smoothing parameter for this penalty; `None` when the λ vector is shorter.
    pub lambda: Option<f64>,
    /// Smallest eigenvalue of the local penalty block.
    pub eig_min: f64,
    /// Largest eigenvalue of the local penalty block.
    pub eig_max: f64,
    /// Trace (sum of the diagonal) of the local penalty block.
    pub trace: f64,
}
475
/// Cheap summary of a prediction vector, produced by `prediction_fingerprint`,
/// used to eyeball whether two runs predicted the same thing.
#[derive(Clone, Debug)]
pub struct PredictionFingerprint {
    /// Number of predictions.
    pub n: usize,
    /// Arithmetic mean (0 for an empty vector).
    pub mean: f64,
    /// Population standard deviation (variance divided by `n`, not `n − 1`).
    pub sd: f64,
    /// Minimum value (`+∞` for an empty vector).
    pub min: f64,
    /// Maximum value (`−∞` for an empty vector).
    pub max: f64,
    /// First prediction (NaN for an empty vector).
    pub first: f64,
    /// Last prediction (NaN for an empty vector).
    pub last: f64,
}
486
487impl QualityDiagnostics {
488    pub fn from_standard_fit(
489        label: impl Into<String>,
490        fit: &gam_models::fit_orchestration::StandardFitResult,
491    ) -> Self {
492        let design = design_diagnostics(&fit.design.design).ok();
493        let penalties = penalty_diagnostics(
494            &fit.design.penalties,
495            fit.fit.lambdas.as_slice().unwrap_or(&[]),
496        );
497        Self {
498            label: label.into(),
499            rmse_vs_truth: None,
500            rmse_vs_reference: None,
501            reference_rmse_vs_truth: None,
502            edf_total: fit.fit.inference.as_ref().map(|i| i.edf_total),
503            rho: fit.fit.log_lambdas.to_vec(),
504            lambda: fit.fit.lambdas.to_vec(),
505            design,
506            penalties,
507            prediction: None,
508        }
509    }
510    pub fn with_truth_rmse(mut self, pred: &[f64], truth: &[f64]) -> Self {
511        self.rmse_vs_truth = Some(rmse(pred, truth));
512        self.prediction = Some(prediction_fingerprint(pred));
513        self
514    }
515    pub fn with_reference_gap(
516        mut self,
517        pred: &[f64],
518        reference: &[f64],
519        truth: Option<&[f64]>,
520    ) -> Self {
521        self.rmse_vs_reference = Some(rmse(pred, reference));
522        if let Some(truth) = truth {
523            self.reference_rmse_vs_truth = Some(rmse(reference, truth));
524        }
525        self
526    }
527    pub fn report(&self) -> String {
528        let mut out = format!("[quality-diagnostics] label={}", self.label);
529        if let Some(v) = self.rmse_vs_truth {
530            out.push_str(&format!(" rmse_truth={v:.6}"));
531        }
532        if let Some(v) = self.rmse_vs_reference {
533            out.push_str(&format!(" rmse_reference_gap={v:.6}"));
534        }
535        if let Some(v) = self.reference_rmse_vs_truth {
536            out.push_str(&format!(" reference_rmse_truth={v:.6}"));
537        }
538        if let Some(v) = self.edf_total {
539            out.push_str(&format!(" edf_total={v:.3}"));
540        }
541        if !self.rho.is_empty() {
542            out.push_str(&format!(" rho={:?}", Rounded(&self.rho)));
543        }
544        if !self.lambda.is_empty() {
545            out.push_str(&format!(" lambda={:?}", Rounded(&self.lambda)));
546        }
547        if let Some(d) = &self.design {
548            out.push_str(&format!(
549                " design={}x{} rank={} cond={:.3e} sigma=[{:.3e},{:.3e}]",
550                d.nrows, d.ncols, d.rank, d.condition, d.sigma_min, d.sigma_max
551            ));
552        }
553        if let Some(p) = &self.prediction {
554            out.push_str(&format!(
555                " pred[n={} mean={:.4} sd={:.4} range=[{:.4},{:.4}] edge=[{:.4},{:.4}]]",
556                p.n, p.mean, p.sd, p.min, p.max, p.first, p.last
557            ));
558        }
559        if !self.penalties.is_empty() {
560            out.push_str(" penalties=");
561            for p in &self.penalties {
562                out.push_str(&format!(
563                    " #{} cols={}..{} rank={} lambda={} eig=[{:.3e},{:.3e}] tr={:.3e};",
564                    p.index,
565                    p.col_start,
566                    p.col_end,
567                    p.rank,
568                    p.lambda
569                        .map(|v| format!("{v:.3e}"))
570                        .unwrap_or_else(|| "NA".into()),
571                    p.eig_min,
572                    p.eig_max,
573                    p.trace
574                ));
575            }
576        }
577        out
578    }
579}
580
/// Debug adapter that prints a slice of floats as `[a, b, ...]`, each value in
/// scientific notation with three fractional digits, to keep report lines short.
struct Rounded<'a>(&'a [f64]);

impl std::fmt::Debug for Rounded<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "[")?;
        let mut separator = "";
        for value in self.0 {
            write!(f, "{separator}{value:.3e}")?;
            separator = ", ";
        }
        write!(f, "]")
    }
}
594
595pub fn prediction_fingerprint(values: &[f64]) -> PredictionFingerprint {
596    let n = values.len();
597    let mean = values.iter().sum::<f64>() / n.max(1) as f64;
598    let var = values.iter().map(|v| (v - mean) * (v - mean)).sum::<f64>() / n.max(1) as f64;
599    PredictionFingerprint {
600        n,
601        mean,
602        sd: var.sqrt(),
603        min: values.iter().copied().fold(f64::INFINITY, f64::min),
604        max: values.iter().copied().fold(f64::NEG_INFINITY, f64::max),
605        first: values.first().copied().unwrap_or(f64::NAN),
606        last: values.last().copied().unwrap_or(f64::NAN),
607    }
608}
609
610pub fn design_diagnostics(
611    design: &gam_linalg::matrix::DesignMatrix,
612) -> Result<DesignDiagnostics, String> {
613    use gam_linalg::faer_ndarray::FaerSvd;
614    let dense = design
615        .try_to_dense_by_chunks_budgeted("quality diagnostics design SVD", 256 * 1024 * 1024)?;
616    let (_u, s, _vt) = dense.svd(false, false).map_err(|e| e.to_string())?;
617    let sigma_max = s.iter().copied().fold(0.0, f64::max);
618    let tol = (design.nrows().max(design.ncols()) as f64) * f64::EPSILON * sigma_max.max(1.0);
619    let rank = s.iter().filter(|&&v| v > tol).count();
620    let sigma_min = s
621        .iter()
622        .copied()
623        .filter(|v| *v > tol)
624        .fold(0.0_f64, |a, v| if a == 0.0 { v } else { a.min(v) });
625    Ok(DesignDiagnostics {
626        nrows: design.nrows(),
627        ncols: design.ncols(),
628        rank,
629        condition: if sigma_min > 0.0 {
630            sigma_max / sigma_min
631        } else {
632            f64::INFINITY
633        },
634        sigma_min,
635        sigma_max,
636    })
637}
638
639pub fn penalty_diagnostics(
640    penalties: &[gam_terms::smooth::BlockwisePenalty],
641    lambdas: &[f64],
642) -> Vec<PenaltyDiagnostics> {
643    use gam_linalg::faer_ndarray::FaerEigh;
644    use faer::Side;
645    penalties
646        .iter()
647        .enumerate()
648        .map(|(index, p)| {
649            let evals = p
650                .local
651                .eigh(Side::Lower)
652                .map(|(e, _)| e)
653                .unwrap_or_else(|_| ndarray::Array1::from_vec(vec![f64::NAN]));
654            let scale = evals
655                .iter()
656                .copied()
657                .map(f64::abs)
658                .fold(0.0, f64::max)
659                .max(1.0);
660            let tol = scale * 1.0e-10;
661            let rank = evals.iter().filter(|&&v| v > tol).count();
662            PenaltyDiagnostics {
663                index,
664                col_start: p.col_range.start,
665                col_end: p.col_range.end,
666                rank,
667                lambda: lambdas.get(index).copied(),
668                eig_min: evals.iter().copied().fold(f64::INFINITY, f64::min),
669                eig_max: evals.iter().copied().fold(f64::NEG_INFINITY, f64::max),
670                trace: p.local.diag().sum(),
671            }
672        })
673        .collect()
674}
675
/// A Double Machine Learning (DML) reference estimate of the average linear
/// effect `θ = E[∂E(Y|D,X)/∂D]` of a treatment/dose `D` on outcome `Y` after
/// partialling out confounders `X`, computed by a mature Python DML library
/// (DoubleML's partially-linear model, with EconML's `LinearDML` as fallback).
///
/// This is the Neyman-orthogonal scalar-target baseline used by #461's Sim C.
/// The cross-fitted DML estimator is, by construction, first-order insensitive
/// to first-stage nuisance estimation error. Its `theta`/`se` are therefore the
/// reference bias/coverage that gam's orthogonalized marginal-slope target
/// `θ = E_x[β(x)]` must match-or-beat under x-dependent Stage-1 miscalibration.
pub struct DmlPartialLinearReference {
    /// Whether a DML library was importable in the reference interpreter. When
    /// `false`, `theta`/`se`/`ci_lo`/`ci_hi` are `NaN` and the caller should
    /// emit a clear skip message rather than asserting against them. DoubleML/
    /// EconML are heavier optional dependencies than scipy/mgcv, so their
    /// absence is treated as a genuine environmental gate (mirroring the
    /// CUDA-only skip in `tests/common/gpu_gate.rs`) rather than the hard
    /// failure that a missing scipy/R would be.
    ///
    /// NOTE(review): the Python body in `dml_partial_linear_reference` also
    /// reports "unavailable" when a backend imports but *fitting* raises (its
    /// `except Exception` resets everything to NaN), so `false` does not
    /// strictly mean "not installed".
    pub available: bool,
    /// Which backend produced the estimate: "doubleml", "econml", or "none".
    ///
    /// The Python body emits this as a numeric code (1 = DoubleML, 2 = EconML,
    /// 0 = none); presumably it is mapped to these names when the struct is
    /// built — confirm.
    pub backend: String,
    /// Point estimate of the average linear treatment effect `θ`.
    pub theta: f64,
    /// Standard error of `θ̂` reported by the DML library.
    pub se: f64,
    /// Lower end of the library's 95% confidence interval for `θ`.
    pub ci_lo: f64,
    /// Upper end of the library's 95% confidence interval for `θ`.
    pub ci_hi: f64,
}
706
707/// Fit a partially-linear DML model `Y = θ·D + g(X) + ε`, `D = m(X) + ν` with a
708/// mature Python DML library and return its orthogonal estimate of `θ`.
709///
710/// `y`, `d`, and the columns of `x` must share a common length. `n_folds` sets
711/// the cross-fitting fold count (DML's sample-splitting ingredient). The
712/// reference uses gradient-boosted nuisance learners so the partialling-out is
713/// genuinely nonparametric, exercising the orthogonality the estimator claims.
714///
715/// When neither DoubleML nor EconML is importable, the returned struct has
716/// `available == false`; the interpreter itself still exits zero (the import
717/// probe is guarded), so this is *not* a hard failure — the caller decides
718/// whether to skip. A missing `python3`/`numpy`/`scikit-learn`, by contrast, is
719/// still a loud failure via the underlying [`run_python`] contract.
720pub fn dml_partial_linear_reference(
721    y: &[f64],
722    d: &[f64],
723    x: &[Column<'_>],
724    n_folds: usize,
725) -> DmlPartialLinearReference {
726    assert!(
727        !x.is_empty(),
728        "DML reference needs at least one confounder X"
729    );
730    assert_eq!(y.len(), d.len(), "DML reference y/d length mismatch");
731    let x_names: Vec<String> = x.iter().map(|c| format!("{:?}", c.name)).collect();
732    let x_list = x_names.join(", ");
733    let mut columns: Vec<Column<'_>> = Vec::with_capacity(x.len() + 2);
734    columns.push(Column::new("y", y));
735    columns.push(Column::new("d", d));
736    columns.extend(x.iter().map(|c| Column::new(c.name, c.data)));
737
738    // The body first probes for an importable DML backend; if none is present it
739    // emits `available=0` and returns cleanly (interpreter exits zero), so the
740    // Rust side can skip-with-message instead of failing. When a backend exists
741    // the estimate is real and emitted under `theta`/`se`/`ci_lo`/`ci_hi`.
742    let body = format!(
743        r#"
744import numpy as np
745_xcols = [{x_list}]
746Y = np.asarray(df["y"], dtype=float).reshape(-1)
747D = np.asarray(df["d"], dtype=float).reshape(-1)
748X = np.column_stack([np.asarray(df[c], dtype=float).reshape(-1) for c in _xcols])
749n_folds = {n_folds}
750
751def _have(mod):
752    import importlib.util
753    return importlib.util.find_spec(mod) is not None
754
755theta = float("nan"); se = float("nan"); backend = 0.0; avail = 0.0
756ci_lo = float("nan"); ci_hi = float("nan")
757
758try:
759    if _have("doubleml") and _have("sklearn"):
760        import doubleml as dml
761        from doubleml import DoubleMLData, DoubleMLPLR
762        from sklearn.ensemble import GradientBoostingRegressor
763        data = DoubleMLData.from_arrays(X, Y, D)
764        ml_l = GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=0)
765        ml_m = GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=1)
766        plr = DoubleMLPLR(data, ml_l=ml_l, ml_m=ml_m, n_folds=n_folds)
767        plr.fit()
768        theta = float(np.asarray(plr.coef).reshape(-1)[0])
769        se = float(np.asarray(plr.se).reshape(-1)[0])
770        cis = np.asarray(plr.confint(level=0.95))
771        ci_lo = float(cis.reshape(-1)[0]); ci_hi = float(cis.reshape(-1)[1])
772        backend = 1.0; avail = 1.0
773    elif _have("econml") and _have("sklearn"):
774        from econml.dml import LinearDML
775        from sklearn.ensemble import GradientBoostingRegressor
776        est = LinearDML(
777            model_y=GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=0),
778            model_t=GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=1),
779            cv=n_folds, random_state=0,
780        )
781        est.fit(Y, D, X=None, W=X)
782        theta = float(np.asarray(est.coef_).reshape(-1)[0]) if np.asarray(est.coef_).size else float(est.intercept_)
783        inf = est.coef__inference() if hasattr(est, "coef__inference") else est.intercept__inference()
784        se = float(np.asarray(inf.stderr).reshape(-1)[0])
785        lohi = inf.conf_int(alpha=0.05)
786        ci_lo = float(np.asarray(lohi[0]).reshape(-1)[0]); ci_hi = float(np.asarray(lohi[1]).reshape(-1)[0])
787        backend = 2.0; avail = 1.0
788except Exception as _e:
789    avail = 0.0; backend = 0.0
790    theta = float("nan"); se = float("nan")
791    ci_lo = float("nan"); ci_hi = float("nan")
792
793emit("available", [avail])
794emit("backend", [backend])
795emit("theta", [theta])
796emit("se", [se])
797emit("ci_lo", [ci_lo])
798emit("ci_hi", [ci_hi])
799"#
800    );
801
802    let r = run_python(&columns, &body);
803    let available = r.scalar("available") > 0.5;
804    let backend = match r.scalar("backend") as i64 {
805        1 => "doubleml",
806        2 => "econml",
807        _ => "none",
808    }
809    .to_string();
810    DmlPartialLinearReference {
811        available,
812        backend,
813        theta: r.scalar("theta"),
814        se: r.scalar("se"),
815        ci_lo: r.scalar("ci_lo"),
816        ci_hi: r.scalar("ci_hi"),
817    }
818}
819
/// Pearson product-moment correlation of two equal-length samples.
///
/// The denominator is floored at `1e-300`, so an exactly-zero variance
/// (including empty input) yields `0.0` instead of the NaN a `0/0` would give.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
pub fn pearson(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "pearson length mismatch");
    let count = a.len() as f64;
    let mean = |v: &[f64]| v.iter().sum::<f64>() / count;
    let (mean_a, mean_b) = (mean(a), mean(b));

    // One pass accumulating the cross-product and both sums of squares.
    let (cross, var_a, var_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(cross, var_a, var_b), (x, y)| {
            let dx = x - mean_a;
            let dy = y - mean_b;
            (cross + dx * dy, var_a + dx * dx, var_b + dy * dy)
        },
    );

    let denom = var_a.sqrt() * var_b.sqrt();
    cross / denom.max(1e-300)
}
838
#[cfg(test)]
mod pad_to_tests {
    use super::pad_to;

    /// Regression for #1084.
    ///
    /// A full-data column (654 rows) padded down to a train-split width (490
    /// rows) once died with the cryptic "pad target 490 shorter than source
    /// 654". Shrinking must stay a hard error, because quietly discarding 164
    /// real rows is never right. The panic now explains the cause and the fix.
    #[test]
    #[should_panic(expected = "pad_to cannot shrink")]
    fn shrink_to_train_split_is_a_clear_error() {
        let source_rows = 654;
        let target_rows = 490;
        let full_column = vec![1.0; source_rows];
        let _ = pad_to(&full_column, target_rows);
    }

    /// The documented fix for #1084: pad the full-data column and the
    /// train-split column to a shared `n = max(len)`. Both come out the same
    /// length with their real-data prefixes intact, so a reference body can
    /// slice each one by its own semantic length. The prostate test now
    /// follows this consistent-split path.
    #[test]
    fn pad_full_and_train_to_common_n_is_consistent() {
        let (n, n_train) = (654_usize, 490_usize);
        let full: Vec<f64> = (0..n).map(|row| row as f64).collect();
        let train: Vec<f64> = (1000..1000 + n_train).map(|v| v as f64).collect();

        let padded_full = pad_to(&full, n);
        let padded_train = pad_to(&train, n);
        assert_eq!(padded_full.len(), n);
        assert_eq!(padded_train.len(), n);

        // Real-data prefixes come through unchanged.
        assert_eq!(&padded_full[..n], &full[..]);
        let (train_prefix, train_tail) = padded_train.split_at(n_train);
        assert_eq!(train_prefix, &train[..]);

        // The tail only repeats the last real training value. A body slicing
        // by `n_train` never reads it, and no real data leaks past the prefix.
        let last_real = train[n_train - 1];
        assert_eq!(train_tail.first(), Some(&last_real));
        assert_eq!(train_tail.last(), Some(&last_real));
    }
}
880}