1use std::collections::BTreeMap;
28use std::io::Write;
29use std::path::PathBuf;
30use std::process::Command;
31use std::sync::atomic::{AtomicU64, Ordering};
32
/// Named numeric outputs collected from one reference-script run.
///
/// Each key maps to the vector of values that was emitted for it.
#[derive(Debug, Clone)]
pub struct ReferenceResult {
    values: BTreeMap<String, Vec<f64>>,
}

impl ReferenceResult {
    /// Returns the single value stored under `key`.
    ///
    /// # Panics
    /// Panics when `key` is absent, or when it carries anything other than
    /// exactly one value.
    pub fn scalar(&self, key: &str) -> f64 {
        let v = self.vector(key);
        assert_eq!(
            v.len(),
            1,
            "reference key {key:?} carried {} values, expected a scalar",
            v.len()
        );
        v[0]
    }

    /// Returns every value stored under `key`.
    ///
    /// # Panics
    /// Panics when `key` is absent; the message lists the keys that are present.
    pub fn vector(&self, key: &str) -> &[f64] {
        match self.values.get(key) {
            Some(found) => found.as_slice(),
            // The message (and the key list it needs) is only built on the
            // failure path, so successful lookups allocate nothing.
            None => panic!(
                "reference did not emit key {key:?}; emitted keys: {:?}",
                self.keys()
            ),
        }
    }

    /// Lists the stored keys in the map's (sorted) iteration order.
    pub fn keys(&self) -> Vec<&str> {
        self.values.keys().map(String::as_str).collect()
    }
}
67
/// A named, borrowed column of numeric data handed to a reference run.
///
/// The type only holds two shared borrows, so it is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub struct Column<'a> {
    /// Column header written to the data CSV.
    pub name: &'a str,
    /// Column values, one per row.
    pub data: &'a [f64],
}

impl<'a> Column<'a> {
    /// Pairs a column name with its values.
    pub fn new(name: &'a str, data: &'a [f64]) -> Self {
        Self { name, data }
    }
}
83
/// Creates and returns a fresh, empty scratch directory under the system temp dir.
///
/// The directory name combines `tag`, the current process id and a
/// process-wide counter, so concurrent callers inside one process never share
/// a directory.
///
/// # Panics
/// Panics when a stale directory at the chosen path cannot be cleared, or when
/// the directory cannot be created.
fn unique_scratch_dir(tag: &str) -> PathBuf {
    static COUNTER: AtomicU64 = AtomicU64::new(0);
    let n = COUNTER.fetch_add(1, Ordering::Relaxed);
    let dir = std::env::temp_dir().join(format!(
        "gam_reference_{}_{}_{}",
        tag,
        std::process::id(),
        n
    ));
    // A crashed earlier process with a recycled pid can leave this exact path
    // behind. `create_dir_all` would silently reuse it together with its old
    // contents, so clear it first.
    match std::fs::remove_dir_all(&dir) {
        Ok(()) => {}
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {}
        Err(err) => panic!(
            "clear stale reference scratch dir {}: {err}",
            dir.display()
        ),
    }
    std::fs::create_dir_all(&dir).expect("create reference scratch dir");
    dir
}
97
98fn write_columns_csv(path: &std::path::Path, columns: &[Column<'_>]) {
99 assert!(
100 !columns.is_empty(),
101 "reference run needs at least one column"
102 );
103 let nrows = columns.iter().map(|c| c.data.len()).max().unwrap_or(0);
111 assert!(
112 nrows > 0,
113 "reference run needs at least one non-empty column"
114 );
115 let mut s = String::new();
116 s.push_str(
117 &columns
118 .iter()
119 .map(|c| c.name.to_string())
120 .collect::<Vec<_>>()
121 .join(","),
122 );
123 s.push('\n');
124 for row in 0..nrows {
125 let line = columns
126 .iter()
127 .map(|c| match c.data.get(row) {
128 Some(value) => format!("{value:.17e}"),
135 None => "NA".to_string(),
136 })
137 .collect::<Vec<_>>()
138 .join(",");
139 s.push_str(&line);
140 s.push('\n');
141 }
142 let mut f = std::fs::File::create(path).expect("write reference data csv");
143 f.write_all(s.as_bytes()).expect("flush reference data csv");
144}
145
/// Parses `key:v1 v2 ...` lines into a map from key to values.
///
/// Lines without a `:` and lines whose key is empty are skipped. Values are
/// whitespace-separated; `NA`/`NaN` spellings become NaN and `Inf`/`-Inf`
/// spellings become the matching infinity. When a key appears more than once
/// the last occurrence wins.
///
/// # Panics
/// Panics when a value token is not a number; the message names the offending
/// token and the key it was found under.
fn parse_emitted(text: &str) -> BTreeMap<String, Vec<f64>> {
    let mut out: BTreeMap<String, Vec<f64>> = BTreeMap::new();
    for line in text.lines() {
        // Blank lines carry no ':' either, so they fall out here too.
        let Some((key, rest)) = line.trim().split_once(':') else {
            continue;
        };
        let key = key.trim();
        if key.is_empty() {
            continue;
        }
        let values: Vec<f64> = rest
            .split_whitespace()
            .map(|tok| match tok {
                "NA" | "na" | "NaN" | "nan" => f64::NAN,
                "Inf" | "inf" => f64::INFINITY,
                "-Inf" | "-inf" => f64::NEG_INFINITY,
                other => other.parse::<f64>().unwrap_or_else(|err| {
                    panic!(
                        "reference emitted an unparsable numeric token {other:?} \
                         for key {key:?}: {err}"
                    )
                }),
            })
            .collect();
        out.insert(key.to_string(), values);
    }
    out
}
175
/// Everything that differs between the supported reference interpreters.
struct ReferenceKind {
    /// Short identifier; selects the interpreter and names the scratch dir.
    tag: &'static str,
    /// File name the generated script is saved under.
    script_name: &'static str,
    /// Code placed before the caller's body (loads `df`, defines `emit`).
    preamble: &'static str,
    /// Code placed after the caller's body.
    epilogue: &'static str,
    /// Panic message used when the interpreter cannot be spawned.
    spawn_expect: &'static str,
    /// Panic message used when the script file cannot be written.
    write_expect: &'static str,
    /// Human-readable interpreter name for failure messages.
    display: &'static str,
}

impl ReferenceKind {
    /// Settings for running a body through `Rscript`.
    ///
    /// The preamble reads the data CSV into `df` and defines `emit(key, x)`,
    /// which appends one `key:values` line to the output file per call.
    fn r() -> Self {
        ReferenceKind {
            tag: "r",
            script_name: "script.R",
            // Leading whitespace after a `\` line continuation is dropped by
            // the compiler, so the indentation below is for readers only.
            preamble: "\
args <- commandArgs(trailingOnly = TRUE)\n\
df <- read.csv(args[1])\n\
.OUT <- args[2]\n\
emit <- function(key, x) {\n\
  cat(sprintf('%s:%s\\n', key, paste(format(as.numeric(x), digits = 17, scientific = TRUE), collapse = ' ')),\n\
      file = .OUT, append = TRUE)\n\
}\n",
            epilogue: "",
            spawn_expect: "spawn Rscript (install R to run reference-comparison tests)",
            write_expect: "write reference R script",
            display: "R",
        }
    }

    /// Settings for running a body through `python3`.
    ///
    /// The preamble loads the data CSV into `df` with pandas, falling back to
    /// the `csv` module (a dict of numpy arrays) when that fails. The fallback
    /// maps `NA` and empty cells to NaN; a bare `float("NA")` would raise on
    /// the `NA` placeholder used in the data file for missing cells.
    /// `emit(key, x)` buffers lines that the epilogue writes out in one go.
    fn python() -> Self {
        ReferenceKind {
            tag: "py",
            script_name: "script.py",
            preamble: concat!(
                "import sys\n",
                "import numpy as np\n",
                "_data_csv, _out = sys.argv[1], sys.argv[2]\n",
                "try:\n",
                "    import pandas as pd\n",
                "    df = pd.read_csv(_data_csv)\n",
                "except Exception:\n",
                "    import csv as _csv\n",
                "    with open(_data_csv) as _fh:\n",
                "        _r = _csv.DictReader(_fh)\n",
                "        _cols = {k: [] for k in _r.fieldnames}\n",
                "        for _row in _r:\n",
                "            for _k, _v in _row.items():\n",
                "                _cols[_k].append(float('nan') if _v in ('', 'NA') else float(_v))\n",
                "    df = {k: np.asarray(v, dtype=float) for k, v in _cols.items()}\n",
                "_lines = []\n",
                "def emit(key, x):\n",
                "    arr = np.asarray(x, dtype=float).reshape(-1)\n",
                "    _lines.append(str(key) + ':' + ' '.join(repr(float(v)) for v in arr))\n",
            ),
            epilogue: "\nopen(_out, 'w').write('\\n'.join(_lines) + '\\n')\n",
            spawn_expect: "spawn python3 (install python3 to run reference-comparison tests)",
            write_expect: "write reference python script",
            display: "Python",
        }
    }

    /// Builds the interpreter command, without the script or its arguments.
    ///
    /// Tag `"r"` yields `Rscript --vanilla`; every other tag yields `python3`.
    fn command(&self) -> Command {
        match self.tag {
            "r" => {
                let mut cmd = Command::new("Rscript");
                cmd.arg("--vanilla");
                cmd
            }
            _ => Command::new("python3"),
        }
    }
}
269
270fn run_subprocess(kind: &ReferenceKind, columns: &[Column<'_>], body: &str) -> ReferenceResult {
278 let dir = unique_scratch_dir(kind.tag);
279 let data_csv = dir.join("data.csv");
280 let out_txt = dir.join("out.txt");
281 let script = dir.join(kind.script_name);
282 write_columns_csv(&data_csv, columns);
283
284 let preamble = kind.preamble;
285 let epilogue = kind.epilogue;
286 let full = format!("{preamble}\n{body}\n{epilogue}");
287 std::fs::write(&script, full).expect(kind.write_expect);
288
289 let output = kind
290 .command()
291 .arg(&script)
292 .arg(&data_csv)
293 .arg(&out_txt)
294 .output()
295 .expect(kind.spawn_expect);
296
297 let stderr = String::from_utf8_lossy(&output.stderr);
298 let stdout = String::from_utf8_lossy(&output.stdout);
299 assert!(
300 output.status.success(),
301 "reference {} body failed (status {:?})\n--- stderr ---\n{stderr}\n--- stdout ---\n{stdout}",
302 kind.display,
303 output.status.code()
304 );
305
306 let emitted = std::fs::read_to_string(&out_txt).unwrap_or_default();
307 let parsed = parse_emitted(&emitted);
308 std::fs::remove_dir_all(&dir).ok();
309 ReferenceResult { values: parsed }
310}
311
312pub fn run_r(columns: &[Column<'_>], body: &str) -> ReferenceResult {
319 run_subprocess(&ReferenceKind::r(), columns, body)
320}
321
322pub fn r_package_available(pkg: &str) -> bool {
337 if ReferenceKind::r()
338 .command()
339 .arg("--version")
340 .output()
341 .is_err()
342 {
343 return false;
344 }
345 let script =
352 format!("cat(if (requireNamespace(\"{pkg}\", quietly = TRUE)) \"1\\n\" else \"0\\n\")");
353 let output = match Command::new("Rscript")
354 .arg("--vanilla")
355 .arg("-e")
356 .arg(script)
357 .output()
358 {
359 Ok(output) => output,
360 Err(_) => return false,
361 };
362 output.status.success() && String::from_utf8_lossy(&output.stdout).trim() == "1"
363}
364
365pub fn run_python(columns: &[Column<'_>], body: &str) -> ReferenceResult {
371 run_subprocess(&ReferenceKind::python(), columns, body)
372}
373
/// Relative L2 error of `a` against `b`: `sqrt(sum((a - b)^2) / sum(b^2))`.
///
/// The denominator is floored at `1e-300` so an all-zero `b` does not divide
/// by zero.
///
/// # Panics
/// Panics when the slices differ in length.
pub fn relative_l2(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "relative_l2 length mismatch");
    let squared_gap = |(x, y): (&f64, &f64)| {
        let gap = x - y;
        gap * gap
    };
    let numerator: f64 = a.iter().zip(b).map(squared_gap).sum();
    let denominator: f64 = b.iter().map(|y| y * y).sum();
    let ratio = numerator / denominator.max(1e-300);
    ratio.sqrt()
}
383
/// Root-mean-square difference between `a` and `b`.
///
/// The divisor is `max(len, 1)`, so empty input yields zero rather than NaN.
///
/// # Panics
/// Panics when the slices differ in length.
pub fn rmse(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "rmse length mismatch");
    let count = a.len().max(1) as f64;
    let total: f64 = a
        .iter()
        .zip(b)
        .map(|(x, y)| {
            let gap = x - y;
            gap * gap
        })
        .sum();
    (total / count).sqrt()
}
390
/// Coefficient of determination of `pred` against `truth`:
/// `1 - SS_res / SS_tot`, with `SS_tot` taken about the mean of `truth`.
///
/// `SS_tot` is floored at `1e-300` so a constant `truth` does not divide by zero.
///
/// # Panics
/// Panics when the slices differ in length.
pub fn r2(pred: &[f64], truth: &[f64]) -> f64 {
    assert_eq!(pred.len(), truth.len(), "r2 length mismatch");
    let count = truth.len() as f64;
    let centre = truth.iter().sum::<f64>() / count;
    let residual: f64 = pred
        .iter()
        .zip(truth)
        .map(|(p, t)| {
            let miss = t - p;
            miss * miss
        })
        .sum();
    let total: f64 = truth
        .iter()
        .map(|t| {
            let spread = t - centre;
            spread * spread
        })
        .sum();
    1.0 - residual / total.max(1e-300)
}
400
401pub fn held_out_r2(pred: &[f64], truth: &[f64]) -> f64 {
403 r2(pred, truth)
404}
405
/// Returns a copy of `v` extended to exactly `len` entries.
///
/// The extra slots repeat the last element of `v`, or `0.0` when `v` is empty.
///
/// # Panics
/// Panics when `len` is smaller than `v.len()`: padding never drops data.
pub fn pad_to(v: &[f64], len: usize) -> Vec<f64> {
    assert!(
        v.len() <= len,
        "pad_to cannot shrink: source has {} rows but the pad target is {len} \
         (a shorter target would drop real data). Pad every column to a common \
         n = max(column length) and slice by semantic length in the reference body.",
        v.len()
    );
    let filler = match v.last() {
        Some(&tail) => tail,
        None => 0.0,
    };
    let missing = len - v.len();
    let mut padded = Vec::with_capacity(len);
    padded.extend_from_slice(v);
    padded.extend(std::iter::repeat(filler).take(missing));
    padded
}
429
/// Largest absolute element-wise difference between `a` and `b`.
///
/// Returns `0.0` for empty input. Equal elements, including two infinities of
/// the same sign, contribute a gap of zero. Any other NaN gap makes the result
/// NaN. `f64::max` discards a NaN operand, so folding with it would report a
/// NaN-filled comparison as a perfect match.
///
/// # Panics
/// Panics when the slices differ in length.
pub fn max_abs_diff(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "max_abs_diff length mismatch");
    let mut worst = 0.0_f64;
    for (x, y) in a.iter().zip(b) {
        // `inf - inf` is NaN even though the two values agree exactly.
        let gap = if x == y { 0.0 } else { (x - y).abs() };
        if gap.is_nan() {
            return f64::NAN;
        }
        worst = worst.max(gap);
    }
    worst
}
438
439#[derive(Clone, Debug)]
441pub struct QualityDiagnostics {
442 pub label: String,
443 pub rmse_vs_truth: Option<f64>,
444 pub rmse_vs_reference: Option<f64>,
445 pub reference_rmse_vs_truth: Option<f64>,
446 pub edf_total: Option<f64>,
447 pub rho: Vec<f64>,
448 pub lambda: Vec<f64>,
449 pub design: Option<DesignDiagnostics>,
450 pub penalties: Vec<PenaltyDiagnostics>,
451 pub prediction: Option<PredictionFingerprint>,
452}
453
/// Shape and singular-value summary of a design matrix.
#[derive(Clone, Debug)]
pub struct DesignDiagnostics {
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
    /// Count of singular values above the numerical tolerance.
    pub rank: usize,
    /// `sigma_max / sigma_min`, or infinity when no singular value is significant.
    pub condition: f64,
    /// Smallest singular value above the tolerance (zero when there is none).
    pub sigma_min: f64,
    /// Largest singular value.
    pub sigma_max: f64,
}
463
/// Spectrum summary of one penalty block.
#[derive(Clone, Debug)]
pub struct PenaltyDiagnostics {
    /// Position of the block in the penalty list.
    pub index: usize,
    /// Start of the block's column range.
    pub col_start: usize,
    /// End of the block's column range.
    pub col_end: usize,
    /// Count of eigenvalues above the numerical tolerance.
    pub rank: usize,
    /// Smoothing parameter at the same position, when one exists.
    pub lambda: Option<f64>,
    /// Smallest eigenvalue of the block's local matrix.
    pub eig_min: f64,
    /// Largest eigenvalue of the block's local matrix.
    pub eig_max: f64,
    /// Sum of the local matrix's diagonal.
    pub trace: f64,
}
475
/// Compact summary statistics of a prediction vector.
#[derive(Clone, Debug)]
pub struct PredictionFingerprint {
    /// Number of values.
    pub n: usize,
    /// Arithmetic mean.
    pub mean: f64,
    /// Population standard deviation (divisor `n`).
    pub sd: f64,
    /// Smallest value.
    pub min: f64,
    /// Largest value.
    pub max: f64,
    /// First value.
    pub first: f64,
    /// Last value.
    pub last: f64,
}
486
487impl QualityDiagnostics {
488 pub fn from_standard_fit(
489 label: impl Into<String>,
490 fit: &gam_models::fit_orchestration::StandardFitResult,
491 ) -> Self {
492 let design = design_diagnostics(&fit.design.design).ok();
493 let penalties = penalty_diagnostics(
494 &fit.design.penalties,
495 fit.fit.lambdas.as_slice().unwrap_or(&[]),
496 );
497 Self {
498 label: label.into(),
499 rmse_vs_truth: None,
500 rmse_vs_reference: None,
501 reference_rmse_vs_truth: None,
502 edf_total: fit.fit.inference.as_ref().map(|i| i.edf_total),
503 rho: fit.fit.log_lambdas.to_vec(),
504 lambda: fit.fit.lambdas.to_vec(),
505 design,
506 penalties,
507 prediction: None,
508 }
509 }
510 pub fn with_truth_rmse(mut self, pred: &[f64], truth: &[f64]) -> Self {
511 self.rmse_vs_truth = Some(rmse(pred, truth));
512 self.prediction = Some(prediction_fingerprint(pred));
513 self
514 }
515 pub fn with_reference_gap(
516 mut self,
517 pred: &[f64],
518 reference: &[f64],
519 truth: Option<&[f64]>,
520 ) -> Self {
521 self.rmse_vs_reference = Some(rmse(pred, reference));
522 if let Some(truth) = truth {
523 self.reference_rmse_vs_truth = Some(rmse(reference, truth));
524 }
525 self
526 }
527 pub fn report(&self) -> String {
528 let mut out = format!("[quality-diagnostics] label={}", self.label);
529 if let Some(v) = self.rmse_vs_truth {
530 out.push_str(&format!(" rmse_truth={v:.6}"));
531 }
532 if let Some(v) = self.rmse_vs_reference {
533 out.push_str(&format!(" rmse_reference_gap={v:.6}"));
534 }
535 if let Some(v) = self.reference_rmse_vs_truth {
536 out.push_str(&format!(" reference_rmse_truth={v:.6}"));
537 }
538 if let Some(v) = self.edf_total {
539 out.push_str(&format!(" edf_total={v:.3}"));
540 }
541 if !self.rho.is_empty() {
542 out.push_str(&format!(" rho={:?}", Rounded(&self.rho)));
543 }
544 if !self.lambda.is_empty() {
545 out.push_str(&format!(" lambda={:?}", Rounded(&self.lambda)));
546 }
547 if let Some(d) = &self.design {
548 out.push_str(&format!(
549 " design={}x{} rank={} cond={:.3e} sigma=[{:.3e},{:.3e}]",
550 d.nrows, d.ncols, d.rank, d.condition, d.sigma_min, d.sigma_max
551 ));
552 }
553 if let Some(p) = &self.prediction {
554 out.push_str(&format!(
555 " pred[n={} mean={:.4} sd={:.4} range=[{:.4},{:.4}] edge=[{:.4},{:.4}]]",
556 p.n, p.mean, p.sd, p.min, p.max, p.first, p.last
557 ));
558 }
559 if !self.penalties.is_empty() {
560 out.push_str(" penalties=");
561 for p in &self.penalties {
562 out.push_str(&format!(
563 " #{} cols={}..{} rank={} lambda={} eig=[{:.3e},{:.3e}] tr={:.3e};",
564 p.index,
565 p.col_start,
566 p.col_end,
567 p.rank,
568 p.lambda
569 .map(|v| format!("{v:.3e}"))
570 .unwrap_or_else(|| "NA".into()),
571 p.eig_min,
572 p.eig_max,
573 p.trace
574 ));
575 }
576 }
577 out
578 }
579}
580
/// Debug adapter that prints a float slice as `[a, b, ...]`, each entry in
/// scientific notation with three decimals.
struct Rounded<'a>(&'a [f64]);

impl std::fmt::Debug for Rounded<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered: Vec<String> = self.0.iter().map(|v| format!("{v:.3e}")).collect();
        write!(f, "[{}]", rendered.join(", "))
    }
}
594
595pub fn prediction_fingerprint(values: &[f64]) -> PredictionFingerprint {
596 let n = values.len();
597 let mean = values.iter().sum::<f64>() / n.max(1) as f64;
598 let var = values.iter().map(|v| (v - mean) * (v - mean)).sum::<f64>() / n.max(1) as f64;
599 PredictionFingerprint {
600 n,
601 mean,
602 sd: var.sqrt(),
603 min: values.iter().copied().fold(f64::INFINITY, f64::min),
604 max: values.iter().copied().fold(f64::NEG_INFINITY, f64::max),
605 first: values.first().copied().unwrap_or(f64::NAN),
606 last: values.last().copied().unwrap_or(f64::NAN),
607 }
608}
609
610pub fn design_diagnostics(
611 design: &gam_linalg::matrix::DesignMatrix,
612) -> Result<DesignDiagnostics, String> {
613 use gam_linalg::faer_ndarray::FaerSvd;
614 let dense = design
615 .try_to_dense_by_chunks_budgeted("quality diagnostics design SVD", 256 * 1024 * 1024)?;
616 let (_u, s, _vt) = dense.svd(false, false).map_err(|e| e.to_string())?;
617 let sigma_max = s.iter().copied().fold(0.0, f64::max);
618 let tol = (design.nrows().max(design.ncols()) as f64) * f64::EPSILON * sigma_max.max(1.0);
619 let rank = s.iter().filter(|&&v| v > tol).count();
620 let sigma_min = s
621 .iter()
622 .copied()
623 .filter(|v| *v > tol)
624 .fold(0.0_f64, |a, v| if a == 0.0 { v } else { a.min(v) });
625 Ok(DesignDiagnostics {
626 nrows: design.nrows(),
627 ncols: design.ncols(),
628 rank,
629 condition: if sigma_min > 0.0 {
630 sigma_max / sigma_min
631 } else {
632 f64::INFINITY
633 },
634 sigma_min,
635 sigma_max,
636 })
637}
638
639pub fn penalty_diagnostics(
640 penalties: &[gam_terms::smooth::BlockwisePenalty],
641 lambdas: &[f64],
642) -> Vec<PenaltyDiagnostics> {
643 use gam_linalg::faer_ndarray::FaerEigh;
644 use faer::Side;
645 penalties
646 .iter()
647 .enumerate()
648 .map(|(index, p)| {
649 let evals = p
650 .local
651 .eigh(Side::Lower)
652 .map(|(e, _)| e)
653 .unwrap_or_else(|_| ndarray::Array1::from_vec(vec![f64::NAN]));
654 let scale = evals
655 .iter()
656 .copied()
657 .map(f64::abs)
658 .fold(0.0, f64::max)
659 .max(1.0);
660 let tol = scale * 1.0e-10;
661 let rank = evals.iter().filter(|&&v| v > tol).count();
662 PenaltyDiagnostics {
663 index,
664 col_start: p.col_range.start,
665 col_end: p.col_range.end,
666 rank,
667 lambda: lambdas.get(index).copied(),
668 eig_min: evals.iter().copied().fold(f64::INFINITY, f64::min),
669 eig_max: evals.iter().copied().fold(f64::NEG_INFINITY, f64::max),
670 trace: p.local.diag().sum(),
671 }
672 })
673 .collect()
674}
675
/// Outcome of a partially-linear double-ML reference run.
///
/// When `available` is `false` no backend produced an estimate and every
/// numeric field is NaN.
#[derive(Debug, Clone)]
pub struct DmlPartialLinearReference {
    /// Whether a backend ran and produced an estimate.
    pub available: bool,
    /// Backend that produced the estimate: `"doubleml"`, `"econml"` or `"none"`.
    pub backend: String,
    /// Estimated treatment coefficient.
    pub theta: f64,
    /// Standard error of `theta`.
    pub se: f64,
    /// Lower end of the 95% confidence interval.
    pub ci_lo: f64,
    /// Upper end of the 95% confidence interval.
    pub ci_hi: f64,
}
706
707pub fn dml_partial_linear_reference(
721 y: &[f64],
722 d: &[f64],
723 x: &[Column<'_>],
724 n_folds: usize,
725) -> DmlPartialLinearReference {
726 assert!(
727 !x.is_empty(),
728 "DML reference needs at least one confounder X"
729 );
730 assert_eq!(y.len(), d.len(), "DML reference y/d length mismatch");
731 let x_names: Vec<String> = x.iter().map(|c| format!("{:?}", c.name)).collect();
732 let x_list = x_names.join(", ");
733 let mut columns: Vec<Column<'_>> = Vec::with_capacity(x.len() + 2);
734 columns.push(Column::new("y", y));
735 columns.push(Column::new("d", d));
736 columns.extend(x.iter().map(|c| Column::new(c.name, c.data)));
737
738 let body = format!(
743 r#"
744import numpy as np
745_xcols = [{x_list}]
746Y = np.asarray(df["y"], dtype=float).reshape(-1)
747D = np.asarray(df["d"], dtype=float).reshape(-1)
748X = np.column_stack([np.asarray(df[c], dtype=float).reshape(-1) for c in _xcols])
749n_folds = {n_folds}
750
751def _have(mod):
752 import importlib.util
753 return importlib.util.find_spec(mod) is not None
754
755theta = float("nan"); se = float("nan"); backend = 0.0; avail = 0.0
756ci_lo = float("nan"); ci_hi = float("nan")
757
758try:
759 if _have("doubleml") and _have("sklearn"):
760 import doubleml as dml
761 from doubleml import DoubleMLData, DoubleMLPLR
762 from sklearn.ensemble import GradientBoostingRegressor
763 data = DoubleMLData.from_arrays(X, Y, D)
764 ml_l = GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=0)
765 ml_m = GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=1)
766 plr = DoubleMLPLR(data, ml_l=ml_l, ml_m=ml_m, n_folds=n_folds)
767 plr.fit()
768 theta = float(np.asarray(plr.coef).reshape(-1)[0])
769 se = float(np.asarray(plr.se).reshape(-1)[0])
770 cis = np.asarray(plr.confint(level=0.95))
771 ci_lo = float(cis.reshape(-1)[0]); ci_hi = float(cis.reshape(-1)[1])
772 backend = 1.0; avail = 1.0
773 elif _have("econml") and _have("sklearn"):
774 from econml.dml import LinearDML
775 from sklearn.ensemble import GradientBoostingRegressor
776 est = LinearDML(
777 model_y=GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=0),
778 model_t=GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=1),
779 cv=n_folds, random_state=0,
780 )
781 est.fit(Y, D, X=None, W=X)
782 theta = float(np.asarray(est.coef_).reshape(-1)[0]) if np.asarray(est.coef_).size else float(est.intercept_)
783 inf = est.coef__inference() if hasattr(est, "coef__inference") else est.intercept__inference()
784 se = float(np.asarray(inf.stderr).reshape(-1)[0])
785 lohi = inf.conf_int(alpha=0.05)
786 ci_lo = float(np.asarray(lohi[0]).reshape(-1)[0]); ci_hi = float(np.asarray(lohi[1]).reshape(-1)[0])
787 backend = 2.0; avail = 1.0
788except Exception as _e:
789 avail = 0.0; backend = 0.0
790 theta = float("nan"); se = float("nan")
791 ci_lo = float("nan"); ci_hi = float("nan")
792
793emit("available", [avail])
794emit("backend", [backend])
795emit("theta", [theta])
796emit("se", [se])
797emit("ci_lo", [ci_lo])
798emit("ci_hi", [ci_hi])
799"#
800 );
801
802 let r = run_python(&columns, &body);
803 let available = r.scalar("available") > 0.5;
804 let backend = match r.scalar("backend") as i64 {
805 1 => "doubleml",
806 2 => "econml",
807 _ => "none",
808 }
809 .to_string();
810 DmlPartialLinearReference {
811 available,
812 backend,
813 theta: r.scalar("theta"),
814 se: r.scalar("se"),
815 ci_lo: r.scalar("ci_lo"),
816 ci_hi: r.scalar("ci_hi"),
817 }
818}
819
/// Pearson correlation coefficient between `a` and `b`.
///
/// The denominator is floored at `1e-300`, so a constant input yields `0`
/// instead of dividing by zero.
///
/// # Panics
/// Panics when the slices differ in length.
pub fn pearson(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "pearson length mismatch");
    let count = a.len() as f64;
    let mean_a = a.iter().sum::<f64>() / count;
    let mean_b = b.iter().sum::<f64>() / count;
    // One pass accumulating the cross product and both squared deviations.
    let (cross, var_a, var_b) = a.iter().zip(b).fold(
        (0.0, 0.0, 0.0),
        |(cross, var_a, var_b), (x, y)| {
            let dev_a = x - mean_a;
            let dev_b = y - mean_b;
            (
                cross + dev_a * dev_b,
                var_a + dev_a * dev_a,
                var_b + dev_b * dev_b,
            )
        },
    );
    cross / (var_a.sqrt() * var_b.sqrt()).max(1e-300)
}
838
#[cfg(test)]
mod pad_to_tests {
    use super::pad_to;

    /// Asking for a target shorter than the source must fail loudly instead of
    /// truncating.
    #[test]
    #[should_panic(expected = "pad_to cannot shrink")]
    fn shrink_to_train_split_is_a_clear_error() {
        let full = [1.0_f64; 654];
        let _ = pad_to(&full, 490);
    }

    /// Padding a full-length and a shorter column to the same length keeps the
    /// real data in front and repeats the last real value behind it.
    #[test]
    fn pad_full_and_train_to_common_n_is_consistent() {
        const N: usize = 654;
        const N_TRAIN: usize = 490;
        let full: Vec<f64> = (0..N).map(|i| i as f64).collect();
        let train: Vec<f64> = (0..N_TRAIN).map(|i| (1000 + i) as f64).collect();

        let full_wire = pad_to(&full, N);
        let train_wire = pad_to(&train, N);
        assert_eq!(full_wire.len(), N);
        assert_eq!(train_wire.len(), N);

        // An already full-length column comes back unchanged.
        assert_eq!(&full_wire[..N], &full[..]);
        // The shorter column keeps its data, then repeats its last value.
        let last_real = train[N_TRAIN - 1];
        assert_eq!(&train_wire[..N_TRAIN], &train[..]);
        assert_eq!(train_wire[N_TRAIN], last_real);
        assert_eq!(train_wire[N - 1], last_real);
    }
}
880}