1use super::family::clamp_bernoulli_link_probability;
2use super::*;
3use gam_linalg::matrix::{LinearOperator, SignedWeightsView};
4use gam_math::jet_tower::Tower4;
5
6pub(crate) fn standardize_latent_z_with_policy(
7 z: &Array1<f64>,
8 weights: &Array1<f64>,
9 context: &str,
10 policy: &LatentZPolicy,
11) -> Result<(Array1<f64>, LatentZNormalization), String> {
12 if z.len() != weights.len() {
13 return Err(format!(
14 "{context} latent-score normalization length mismatch: z={}, weights={}",
15 z.len(),
16 weights.len()
17 ));
18 }
19 let weight_sum = weights.iter().copied().sum::<f64>();
20 let weight_sq_sum = weights.iter().map(|&w| w * w).sum::<f64>();
21 if !(weight_sum.is_finite()
22 && weight_sum > 0.0
23 && weight_sq_sum.is_finite()
24 && weight_sq_sum > 0.0)
25 {
26 return Err(format!("{context} requires positive finite total weight"));
27 }
28 let effective_n = weight_sum * weight_sum / weight_sq_sum;
29 if !(effective_n.is_finite() && effective_n > 1.0) {
30 return Err(format!(
31 "{context} requires at least two effective observations for latent-score normalization"
32 ));
33 }
34 let mean = z
35 .iter()
36 .zip(weights.iter())
37 .map(|(&zi, &wi)| wi * zi)
38 .sum::<f64>()
39 / weight_sum;
40 let var = z
41 .iter()
42 .zip(weights.iter())
43 .map(|(&zi, &wi)| wi * (zi - mean) * (zi - mean))
44 .sum::<f64>()
45 / weight_sum;
46 let sd = var.sqrt();
47 if !(sd.is_finite() && sd > BMS_VARIANCE_FLOOR) {
48 return Err(format!(
49 "{context} requires z with positive finite weighted standard deviation"
50 ));
51 }
52 let target_norm = match policy.normalization {
53 LatentZNormalizationMode::None => LatentZNormalization { mean: 0.0, sd: 1.0 },
54 LatentZNormalizationMode::FitWeighted => LatentZNormalization { mean, sd },
55 LatentZNormalizationMode::Frozen {
56 mean: frozen_mean,
57 sd: frozen_sd,
58 } => LatentZNormalization {
59 mean: frozen_mean,
60 sd: frozen_sd,
61 },
62 };
63 let mean_tol = policy.mean_tol_multiplier / effective_n.sqrt();
64 let sd_tol = policy.sd_tol_multiplier / (2.0 * (effective_n - 1.0).max(1.0)).sqrt();
65 let check_msg = || {
66 format!(
67 "{context} requires z to already be approximately latent N(0,1) before identification normalization; got mean={mean:.6e}, sd={sd:.6e}, effective_n={effective_n:.1}, allowed_mean={mean_tol:.3e}, allowed_sd={sd_tol:.3e}"
68 )
69 };
70 if mean.abs() > mean_tol || (sd - 1.0).abs() > sd_tol {
71 match policy.check_mode {
72 LatentZCheckMode::Strict => return Err(check_msg()),
73 LatentZCheckMode::WarnOnly => log::warn!("{}", check_msg()),
74 LatentZCheckMode::Off => {}
75 }
76 }
77
78 let normalization = target_norm;
79 let z_std = normalization.apply(z, context)?;
80 let skew = z_std
81 .iter()
82 .zip(weights.iter())
83 .map(|(&zi, &wi)| wi * zi.powi(3))
84 .sum::<f64>()
85 / weight_sum;
86 let kurt = z_std
87 .iter()
88 .zip(weights.iter())
89 .map(|(&zi, &wi)| wi * zi.powi(4))
90 .sum::<f64>()
91 / weight_sum
92 - 3.0;
93 if skew.abs() > policy.max_abs_skew || kurt.abs() > policy.max_abs_excess_kurtosis {
94 let msg = format!(
95 "{context} requires z to be approximately Gaussian after identification normalization; got skewness={skew:.3}, excess_kurtosis={kurt:.3}"
96 );
97 match policy.check_mode {
98 LatentZCheckMode::Strict => return Err(msg),
99 LatentZCheckMode::WarnOnly => log::warn!("{}", msg),
100 LatentZCheckMode::Off => {}
101 }
102 }
103 if skew.abs() > 0.75 || kurt.abs() > 2.0 {
104 log::warn!(
105 "{context}: z has skewness={skew:.3} and excess kurtosis={kurt:.3}; latent-measure auto-selection will use empirical calibration unless stricter diagnostics pass"
106 );
107 }
108 Ok((z_std, normalization))
109}
110
111pub fn padded_deviation_seed(seed: &Array1<f64>, min_iqr: f64, pad_fraction: f64) -> Array1<f64> {
112 let mut sorted = seed.to_vec();
113 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
114
115 if sorted.len() < 4 {
116 return seed.clone();
117 }
118
119 let n = sorted.len();
120 let q1 = sorted[n / 4];
121 let q3 = sorted[3 * n / 4];
122 let iqr = (q3 - q1).max(min_iqr);
123 let pad = pad_fraction * iqr;
124
125 let mut out = seed.to_vec();
126 out.push(sorted[0] - pad);
127 out.push(sorted[n - 1] + pad);
128 Array1::from_vec(out)
129}
130
/// Upper bound on outer Newton iterations in `pooled_probit_baseline`.
const POOLED_PILOT_MAX_NEWTON_ITERS: usize = 50;
/// Initial ridge added to both diagonal entries of the 2x2 pilot Hessian
/// before each solve.
pub(crate) const POOLED_PILOT_RIDGE_INIT: f64 = 1e-8;
/// Minimum absolute determinant of the ridged 2x2 Hessian for a solve to be
/// attempted; smaller determinants trigger a larger ridge.
pub(crate) const POOLED_PILOT_DET_FLOOR: f64 = 1e-18;
/// Multiplicative factor applied to the ridge after a rejected solve.
pub(crate) const POOLED_PILOT_RIDGE_GROWTH: f64 = 10.0;
/// Once the grown ridge exceeds this value the Hessian solve is reported as
/// failed.
pub(crate) const POOLED_PILOT_RIDGE_MAX: f64 = 1e6;
/// Maximum number of step-size reductions tried per Newton iteration.
const POOLED_PILOT_MAX_BACKTRACKS: usize = 25;
/// Factor by which the step scale shrinks after a rejected line-search
/// candidate.
pub(crate) const POOLED_PILOT_BACKTRACK_SHRINK: f64 = 0.5;
/// Objective-difference threshold below which a failed line search is treated
/// as a stall (loop exit) rather than an error.
pub(crate) const POOLED_PILOT_STALL_TOL: f64 = 1e-10;
/// Smallest slope magnitude `pooled_probit_baseline` will return; smaller
/// fitted slopes are replaced by this value with the fitted sign.
pub(crate) const POOLED_PILOT_MIN_ABS_SLOPE: f64 = 1e-6;
162
/// Fits a two-parameter pooled pilot model with linear predictor
/// `eta = beta0 + beta1 * z` by minimizing the weighted negative log
/// likelihood `-Σ w_i * logcdf((2 y_i - 1) * eta_i)` with a ridge-stabilized
/// Newton iteration and backtracking line search.
///
/// `signed_probit_logcdf_and_mills_ratio` is defined elsewhere; from the way
/// its two outputs are used here it presumably returns
/// `(log Φ(margin), φ(margin) / Φ(margin))` — confirm against its definition.
///
/// # Returns
/// `(beta0 / sqrt(1 + b²), b)` where `b` is the fitted slope, replaced by
/// `±POOLED_PILOT_MIN_ABS_SLOPE` (keeping the fitted sign) when its magnitude
/// is below that floor.
///
/// # Errors
/// * mismatched input lengths;
/// * total weight that is not positive and finite;
/// * failure of `standard_normal_quantile` on the clamped prevalence;
/// * non-finite objective or gradient;
/// * no acceptable 2x2 Hessian solve before the ridge exceeds
///   `POOLED_PILOT_RIDGE_MAX`;
/// * line-search failure that is not classified as a stall.
pub(super) fn pooled_probit_baseline(
    y: &Array1<f64>,
    z: &Array1<f64>,
    weights: &Array1<f64>,
) -> Result<(f64, f64), String> {
    if y.len() != z.len() || y.len() != weights.len() {
        return Err(format!(
            "pooled bernoulli-marginal-slope pilot length mismatch: y={}, z={}, weights={}",
            y.len(),
            z.len(),
            weights.len()
        ));
    }
    let weight_sum = weights.iter().copied().sum::<f64>();
    if !weight_sum.is_finite() || weight_sum <= 0.0 {
        return Err(
            "pooled bernoulli-marginal-slope pilot requires positive finite total weight"
                .to_string(),
        );
    }
    // Weighted moments used only to seed the Newton iteration.
    let prevalence = y
        .iter()
        .zip(weights.iter())
        .map(|(&yi, &wi)| yi * wi)
        .sum::<f64>()
        / weight_sum;
    // Keep the prevalence strictly inside (0, 1) before taking its quantile.
    let prevalence = prevalence.clamp(1e-6, 1.0 - 1e-6);
    let z_mean = z
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| zi * wi)
        .sum::<f64>()
        / weight_sum;
    let z_var = z
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| wi * (zi - z_mean) * (zi - z_mean))
        .sum::<f64>()
        / weight_sum;
    let yz_cov = y
        .iter()
        .zip(z.iter())
        .zip(weights.iter())
        .map(|((&yi, &zi), &wi)| wi * (yi - prevalence) * (zi - z_mean))
        .sum::<f64>()
        / weight_sum;
    // Starting point: intercept from the prevalence quantile, slope from the
    // weighted least-squares ratio cov(y, z) / var(z) (zero if z is ~constant).
    let mut beta0 = standard_normal_quantile(prevalence).map_err(|e| {
        format!("failed to initialize pooled bernoulli-marginal-slope pilot intercept: {e}")
    })?;
    let mut beta1 = if z_var > BMS_VARIANCE_FLOOR {
        yz_cov / z_var
    } else {
        0.0
    };

    // Returns (objective, d/d beta0, d/d beta1, H00, H01, H11) at the given
    // coefficients. Rows with exactly zero weight are skipped.
    let objective_grad_hess =
        |intercept: f64, slope: f64| -> Result<(f64, f64, f64, f64, f64, f64), String> {
            let mut obj = 0.0;
            let mut g0 = 0.0;
            let mut g1 = 0.0;
            let mut h00 = 0.0;
            let mut h01 = 0.0;
            let mut h11 = 0.0;
            for ((&yi, &zi), &wi) in y.iter().zip(z.iter()).zip(weights.iter()) {
                if wi == 0.0 {
                    continue;
                }
                let eta = intercept + slope * zi;
                // Maps y in {0, 1} to a sign in {-1, +1}.
                let s = 2.0 * yi - 1.0;
                let margin = s * eta;
                let (logcdf, lambda) = signed_probit_logcdf_and_mills_ratio(margin);
                // First and second derivatives of the row loss w.r.t. eta.
                let g_eta = -wi * s * lambda;
                let h_eta = wi * lambda * (margin + lambda);
                obj -= wi * logcdf;
                g0 += g_eta;
                g1 += g_eta * zi;
                h00 += h_eta;
                h01 += h_eta * zi;
                h11 += h_eta * zi * zi;
            }
            Ok((obj, g0, g1, h00, h01, h11))
        };

    // NOTE(review): `obj_prev` is only ever set to the objective of the most
    // recently accepted candidate, which is the point the next iteration
    // evaluates as `obj`. The stall test below therefore looks like it passes
    // for any line-search failure after the first accepted step — confirm
    // whether that is intended.
    let mut obj_prev = f64::INFINITY;
    for _ in 0..POOLED_PILOT_MAX_NEWTON_ITERS {
        let (obj, g0, g1, h00, h01, h11) = objective_grad_hess(beta0, beta1)?;
        if !obj.is_finite() || !g0.is_finite() || !g1.is_finite() {
            return Err(
                "pooled bernoulli-marginal-slope pilot produced non-finite objective or gradient"
                    .to_string(),
            );
        }
        // Converged when the largest gradient component is below tolerance.
        let grad_max = g0.abs().max(g1.abs());
        if grad_max < BMS_DERIV_TOL {
            break;
        }
        // Solve the ridged 2x2 system (H + ridge * I) step = g in closed form,
        // growing the ridge until the determinant and the step are usable.
        let mut ridge = POOLED_PILOT_RIDGE_INIT;
        let (step0, step1) = loop {
            let h00_r = h00 + ridge;
            let h11_r = h11 + ridge;
            let det = h00_r * h11_r - h01 * h01;
            if det.is_finite() && det.abs() > POOLED_PILOT_DET_FLOOR {
                let s0 = (h11_r * g0 - h01 * g1) / det;
                let s1 = (-h01 * g0 + h00_r * g1) / det;
                if s0.is_finite() && s1.is_finite() {
                    break (s0, s1);
                }
            }
            ridge *= POOLED_PILOT_RIDGE_GROWTH;
            if ridge > POOLED_PILOT_RIDGE_MAX {
                return Err(
                    "pooled bernoulli-marginal-slope pilot Hessian solve failed".to_string()
                );
            }
        };
        // Backtracking: accept the first scaled step that does not increase
        // the objective.
        let mut accepted = false;
        let mut step_scale = 1.0;
        for _ in 0..POOLED_PILOT_MAX_BACKTRACKS {
            let cand0 = beta0 - step_scale * step0;
            let cand1 = beta1 - step_scale * step1;
            let (cand_obj, _, _, _, _, _) = objective_grad_hess(cand0, cand1)?;
            if cand_obj.is_finite() && cand_obj <= obj {
                beta0 = cand0;
                beta1 = cand1;
                obj_prev = cand_obj;
                accepted = true;
                break;
            }
            step_scale *= POOLED_PILOT_BACKTRACK_SHRINK;
        }
        if !accepted {
            if (obj_prev - obj).abs() < POOLED_PILOT_STALL_TOL {
                break;
            }
            return Err("pooled bernoulli-marginal-slope pilot line search failed".to_string());
        }
    }
    let a = beta0;
    // Keep the returned slope away from zero while preserving its sign
    // (a slope of +0.0 maps to the positive floor, -0.0 to the negative one).
    let b = if beta1.abs() < POOLED_PILOT_MIN_ABS_SLOPE {
        if beta1.is_sign_negative() {
            -POOLED_PILOT_MIN_ABS_SLOPE
        } else {
            POOLED_PILOT_MIN_ABS_SLOPE
        }
    } else {
        beta1
    };
    // The intercept is divided by sqrt(1 + b²) before being returned.
    Ok((a / (1.0 + b * b).sqrt(), b))
}
313
314pub(super) fn pilot_irls_hessian_row_metric_at_eta(
354 eta_pilot: &Array1<f64>,
355 sample_weights: &Array1<f64>,
356) -> Array1<f64> {
357 let n = eta_pilot.len();
358 let mut w = Array1::<f64>::zeros(n);
359 for i in 0..n {
360 let eta = eta_pilot[i];
361 let mu = clamp_bernoulli_link_probability(normal_cdf(eta));
362 let phi = normal_pdf(eta).max(1e-300);
363 let var = (mu * (1.0 - mu)).max(1e-300);
364 w[i] = sample_weights[i] * (phi * phi) / var;
365 }
366 w
367}
368
369pub(super) fn rigid_pooled_probit_pilot_eta(
376 base_link: &InverseLink,
377 z: &Array1<f64>,
378 marginal_offset: &Array1<f64>,
379 logslope_offset: &Array1<f64>,
380 baseline_marginal: f64,
381 baseline_logslope: f64,
382 probit_scale: f64,
383) -> Result<Array1<f64>, String> {
384 let n = z.len();
385 let mut out = Array1::<f64>::zeros(n);
386 for i in 0..n {
387 let a_pre = baseline_marginal + marginal_offset[i];
388 let b_pre = baseline_logslope + logslope_offset[i];
389 let q_marg = bernoulli_marginal_link_map(base_link, a_pre)
390 .map_err(|e| format!("rigid_pooled_probit_pilot_eta marginal link map: {e}"))?
391 .q;
392 out[i] = rigid_observed_eta(q_marg, b_pre, z[i], probit_scale);
393 }
394 Ok(out)
395}
396
/// Fraction of the (floored) mean diagonal of `XᵀWX` that
/// `pilot_eta_for_link_dev_orthogonalisation` adds to every diagonal entry
/// before the Cholesky factorization.
pub(crate) const PILOT_RIDGE_DIAG_FRACTION: f64 = 1e-6;
/// Lower bound applied to the mean diagonal of `XᵀWX` before it is scaled by
/// `PILOT_RIDGE_DIAG_FRACTION`.
pub(crate) const PILOT_RIDGE_DIAG_FLOOR: f64 = 1e-12;
406
407pub(super) fn pilot_eta_for_link_dev_orthogonalisation(
408 base_link: &InverseLink,
409 y: &Array1<f64>,
410 z: &Array1<f64>,
411 weights: &Array1<f64>,
412 marginal_design: &DesignMatrix,
413 marginal_offset: &Array1<f64>,
414 logslope_offset: &Array1<f64>,
415 baseline_marginal: f64,
416 baseline_logslope: f64,
417 probit_scale: f64,
418) -> Result<Array1<f64>, String> {
419 use gam_linalg::faer_ndarray::FaerCholesky;
420
421 let n = y.len();
422 if marginal_design.nrows() != n {
423 return Err(format!(
424 "pilot_eta_for_link_dev_orthogonalisation: marginal design has {} rows, expected {}",
425 marginal_design.nrows(),
426 n,
427 ));
428 }
429 let mut working_eta = Array1::<f64>::zeros(n);
430 let mut w_irls = Array1::<f64>::zeros(n);
431 let mut residual = Array1::<f64>::zeros(n);
432 for i in 0..n {
433 let a_pre = baseline_marginal + marginal_offset[i];
434 let b_pre = baseline_logslope + logslope_offset[i];
435 let q_marg = bernoulli_marginal_link_map(base_link, a_pre)
436 .map_err(|e| {
437 format!("pilot_eta_for_link_dev_orthogonalisation marginal link map: {e}")
438 })?
439 .q;
440 let eta = rigid_observed_eta(q_marg, b_pre, z[i], probit_scale);
441 working_eta[i] = eta;
442 let mu = clamp_bernoulli_link_probability(normal_cdf(eta));
443 let phi = normal_pdf(eta).max(1e-300);
444 let var = (mu * (1.0 - mu)).max(1e-300);
445 w_irls[i] = weights[i] * (phi * phi) / var;
446 residual[i] = (y[i] - mu) / phi;
447 }
448 let p_marg = marginal_design.ncols();
449 if p_marg == 0 {
450 return Ok(working_eta);
451 }
452 let xtwr = marginal_design.compute_xtwy(&w_irls, &residual)?;
453 let mut xtwx = marginal_design.xt_diag_x_signed_op(SignedWeightsView::from_array(&w_irls))?;
454 let trace_diag: f64 = (0..p_marg).map(|i| xtwx[[i, i]]).sum();
455 let ridge =
456 (trace_diag / p_marg as f64).max(PILOT_RIDGE_DIAG_FLOOR) * PILOT_RIDGE_DIAG_FRACTION;
457 for i in 0..p_marg {
458 xtwx[[i, i]] += ridge;
459 }
460 let factor = xtwx
461 .cholesky(faer::Side::Lower)
462 .map_err(|e| format!("pilot_eta_for_link_dev_orthogonalisation Cholesky failed: {e}"))?;
463 let delta_beta_marg = factor.solvevec(&xtwr);
464 let marg_contrib = marginal_design.dot(&delta_beta_marg);
465 Ok(&working_eta + &marg_contrib)
466}
467
468pub(super) fn joint_setup(
469 data: ArrayView2<'_, f64>,
470 marginalspec: &TermCollectionSpec,
471 logslopespec: &TermCollectionSpec,
472 marginal_penalties: usize,
473 logslope_penalties: usize,
474 extra_rho0: &[f64],
475 kappa_options: &SpatialLengthScaleOptimizationOptions,
476) -> ExactJointHyperSetup {
477 let marginal_terms = spatial_length_scale_term_indices(marginalspec);
478 let logslope_terms = spatial_length_scale_term_indices(logslopespec);
479 let rho_dim = marginal_penalties + logslope_penalties + extra_rho0.len();
480 let mut rho0vec = Array1::<f64>::zeros(rho_dim);
481 for (idx, &value) in extra_rho0.iter().enumerate() {
482 rho0vec[marginal_penalties + logslope_penalties + idx] = value;
483 }
484 let rho_lower = Array1::<f64>::from_elem(rho_dim, -12.0);
485 let rho_upper = Array1::<f64>::from_elem(rho_dim, 12.0);
486 let marginal_kappa = SpatialLogKappaCoords::from_length_scales_aniso(
487 marginalspec,
488 &marginal_terms,
489 kappa_options,
490 )
491 .reseed_from_data(data, marginalspec, &marginal_terms, kappa_options);
492 let logslope_kappa = SpatialLogKappaCoords::from_length_scales_aniso(
493 logslopespec,
494 &logslope_terms,
495 kappa_options,
496 )
497 .reseed_from_data(data, logslopespec, &logslope_terms, kappa_options);
498 let mut values = marginal_kappa.as_array().to_vec();
499 values.extend(logslope_kappa.as_array().iter());
500 let marginal_dims = marginal_kappa.dims_per_term().to_vec();
501 let logslope_dims = logslope_kappa.dims_per_term().to_vec();
502 let mut dims = marginal_dims.clone();
503 dims.extend(logslope_dims.iter().copied());
504 let log_kappa0 = SpatialLogKappaCoords::new_with_dims(Array1::from_vec(values), dims.clone());
505 let marginal_lower = SpatialLogKappaCoords::lower_bounds_aniso_from_data(
507 data,
508 marginalspec,
509 &marginal_terms,
510 &marginal_dims,
511 kappa_options,
512 );
513 let logslope_lower = SpatialLogKappaCoords::lower_bounds_aniso_from_data(
514 data,
515 logslopespec,
516 &logslope_terms,
517 &logslope_dims,
518 kappa_options,
519 );
520 let mut lower_vals = marginal_lower.as_array().to_vec();
521 lower_vals.extend(logslope_lower.as_array().iter());
522 let log_kappa_lower =
523 SpatialLogKappaCoords::new_with_dims(Array1::from_vec(lower_vals), dims.clone());
524 let marginal_upper = SpatialLogKappaCoords::upper_bounds_aniso_from_data(
525 data,
526 marginalspec,
527 &marginal_terms,
528 &marginal_dims,
529 kappa_options,
530 );
531 let logslope_upper = SpatialLogKappaCoords::upper_bounds_aniso_from_data(
532 data,
533 logslopespec,
534 &logslope_terms,
535 &logslope_dims,
536 kappa_options,
537 );
538 let mut upper_vals = marginal_upper.as_array().to_vec();
539 upper_vals.extend(logslope_upper.as_array().iter());
540 let log_kappa_upper = SpatialLogKappaCoords::new_with_dims(Array1::from_vec(upper_vals), dims);
541 let log_kappa0 = log_kappa0.clamp_to_bounds(&log_kappa_lower, &log_kappa_upper);
544 ExactJointHyperSetup::new(
545 rho0vec,
546 rho_lower,
547 rho_upper,
548 log_kappa0,
549 log_kappa_lower,
550 log_kappa_upper,
551 )
552}
553
554#[inline]
555pub(crate) fn signed_probit_neglog_derivatives_up_to_fourth_numeric(
556 signed_margin: f64,
557 weight: f64,
558) -> (f64, f64, f64, f64) {
559 if weight == 0.0 || signed_margin == f64::INFINITY {
560 return (0.0, 0.0, 0.0, 0.0);
561 }
562 if signed_margin == f64::NEG_INFINITY {
563 return (f64::NEG_INFINITY, weight, 0.0, 0.0);
564 }
565 if signed_margin.is_nan() {
566 return (f64::NAN, f64::NAN, f64::NAN, f64::NAN);
567 }
568 let (_, lambda) = signed_probit_logcdf_and_mills_ratio(signed_margin);
569 let k1 = -lambda;
570 let k2 = lambda * (signed_margin + lambda);
571 let k3 = lambda
572 * (1.0
573 - signed_margin * signed_margin
574 - 3.0 * signed_margin * lambda
575 - 2.0 * lambda * lambda);
576 let k4 = lambda
577 * ((signed_margin.powi(3) - 3.0 * signed_margin)
578 + (7.0 * signed_margin * signed_margin - 4.0) * lambda
579 + 12.0 * signed_margin * lambda * lambda
580 + 6.0 * lambda.powi(3));
581 (weight * k1, weight * k2, weight * k3, weight * k4)
582}
583
584pub(crate) fn signed_probit_neglog_derivatives_up_to_fourth(
592 signed_margin: f64,
593 weight: f64,
594) -> Result<(f64, f64, f64, f64), String> {
595 if weight == 0.0 || signed_margin == f64::INFINITY {
596 return Ok((0.0, 0.0, 0.0, 0.0));
597 }
598 if !signed_margin.is_finite() {
599 return Err(format!(
600 "non-finite signed margin in exact probit derivative helper: {signed_margin}"
601 ));
602 }
603 Ok(signed_probit_neglog_derivatives_up_to_fourth_numeric(
604 signed_margin,
605 weight,
606 ))
607}
608
609#[inline]
635pub(crate) fn signed_probit_neglog_unary_stack(signed_margin: f64, weight: f64) -> [f64; 5] {
636 if weight == 0.0 || signed_margin == f64::INFINITY {
637 return [0.0; 5];
638 }
639 if signed_margin == f64::NEG_INFINITY {
640 return [f64::INFINITY, f64::NEG_INFINITY, weight, 0.0, 0.0];
643 }
644 if signed_margin.is_nan() {
645 return [f64::NAN; 5];
646 }
647 let (logcdf, lambda) = signed_probit_logcdf_and_mills_ratio(signed_margin);
650 let m = signed_margin;
651 let k1 = -lambda;
652 let k2 = lambda * (m + lambda);
653 let k3 = lambda * (1.0 - m * m - 3.0 * m * lambda - 2.0 * lambda * lambda);
654 let k4 = lambda
655 * ((m * m * m - 3.0 * m)
656 + (7.0 * m * m - 4.0) * lambda
657 + 12.0 * m * lambda * lambda
658 + 6.0 * lambda * lambda * lambda);
659 [
660 -weight * logcdf,
661 weight * k1,
662 weight * k2,
663 weight * k3,
664 weight * k4,
665 ]
666}
667
668#[inline]
669pub(super) fn rigid_observed_logslope(logslope: f64, probit_scale: f64) -> f64 {
670 probit_scale * logslope
671}
672
673#[inline]
674pub(super) fn rigid_observed_scale(logslope: f64, probit_scale: f64) -> f64 {
675 let observed_logslope = rigid_observed_logslope(logslope, probit_scale);
676 (1.0 + observed_logslope * observed_logslope).sqrt()
677}
678
679#[inline]
680pub(super) fn rigid_intercept_from_marginal(
681 marginal_eta: f64,
682 logslope: f64,
683 probit_scale: f64,
684) -> f64 {
685 marginal_eta * rigid_observed_scale(logslope, probit_scale)
686}
687
688#[inline]
689pub(super) fn rigid_prescale_intercept_from_marginal(
690 marginal_eta: f64,
691 logslope: f64,
692 probit_scale: f64,
693) -> f64 {
694 rigid_intercept_from_marginal(marginal_eta, logslope, probit_scale) / probit_scale
695}
696
697#[inline]
698pub(super) fn rigid_prescale_intercept_derivative_abs(
699 marginal_eta: f64,
700 logslope: f64,
701 probit_scale: f64,
702) -> f64 {
703 let c = rigid_observed_scale(logslope, probit_scale);
704 probit_scale * normal_pdf(marginal_eta) / c
705}
706
707#[inline]
708pub(super) fn rigid_observed_eta(
709 marginal_eta: f64,
710 logslope: f64,
711 z: f64,
712 probit_scale: f64,
713) -> f64 {
714 marginal_slope_standard_normal_scalar_eta(marginal_eta, logslope, z, probit_scale)
715}
716
717#[inline]
718pub(super) fn marginal_slope_standard_normal_scalar_eta(
719 q: f64,
720 slope: f64,
721 z: f64,
722 probit_scale: f64,
723) -> f64 {
724 let observed_slope = rigid_observed_logslope(slope, probit_scale);
725 q * (1.0 + observed_slope * observed_slope).sqrt() + observed_slope * z
726}
727
728pub(super) fn unary_derivatives_normal_cdf(x: f64) -> [f64; 5] {
729 let pdf = normal_pdf(x);
730 [
731 normal_cdf(x),
732 pdf,
733 -x * pdf,
734 (x * x - 1.0) * pdf,
735 (-x.powi(3) + 3.0 * x) * pdf,
736 ]
737}
738
739pub(super) fn unary_derivatives_normal_pdf(x: f64) -> [f64; 5] {
740 let pdf = normal_pdf(x);
741 [
742 pdf,
743 -x * pdf,
744 (x * x - 1.0) * pdf,
745 (-x.powi(3) + 3.0 * x) * pdf,
746 (x.powi(4) - 6.0 * x * x + 3.0) * pdf,
747 ]
748}
749
750#[inline]
757pub(super) fn lse_accumulate(log_max: &mut f64, sum: &mut f64, log_term: f64) {
758 if !log_term.is_finite() {
759 return;
760 }
761 if log_term > *log_max {
762 if log_max.is_finite() {
763 *sum = *sum * (*log_max - log_term).exp() + 1.0;
764 } else {
765 *sum = 1.0;
766 }
767 *log_max = log_term;
768 } else {
769 *sum += (log_term - *log_max).exp();
770 }
771}
772
/// Storage layout of a [`MarginalSlopeCovariance`], without the payload.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MarginalSlopeCovarianceShape {
    /// Only the diagonal entries are stored.
    Diagonal,
    /// A full square matrix is stored.
    Full,
    /// A factor matrix is stored.
    LowRank,
}
779
/// Covariance of the score vector used by the marginal-slope model, in one of
/// three storage layouts.
#[derive(Clone, Debug, PartialEq)]
pub enum MarginalSlopeCovariance {
    /// Diagonal entries only; each must be finite and non-negative to pass
    /// `validate`.
    Diagonal(Array1<f64>),
    /// Full square matrix; `validate` requires it to be non-empty, finite and
    /// symmetric up to a relative tolerance.
    Full(Array2<f64>),
    /// Factor `F` with one row per score dimension; `quadratic_form` evaluates
    /// `vᵀ F Fᵀ v` by summing squared column projections.
    LowRank(Array2<f64>),
}
787
/// Smallest quadratic-form value that `MarginalSlopeCovariance::quadratic_form`
/// still accepts. Values between this slightly negative tolerance and zero are
/// clamped to zero; anything lower is reported as an error.
pub(crate) const COVARIANCE_QUADRATIC_FORM_PSD_TOL: f64 = -1e-10;
793
794impl MarginalSlopeCovariance {
795 pub fn shape(&self) -> MarginalSlopeCovarianceShape {
796 match self {
797 Self::Diagonal(_) => MarginalSlopeCovarianceShape::Diagonal,
798 Self::Full(_) => MarginalSlopeCovarianceShape::Full,
799 Self::LowRank(_) => MarginalSlopeCovarianceShape::LowRank,
800 }
801 }
802
803 pub fn dim(&self) -> usize {
804 match self {
805 Self::Diagonal(diag) => diag.len(),
806 Self::Full(cov) => cov.nrows(),
807 Self::LowRank(factor) => factor.nrows(),
808 }
809 }
810
811 pub fn validate(&self, context: &str) -> Result<(), String> {
812 match self {
813 Self::Diagonal(diag) => {
814 if diag.is_empty() {
815 return Err(format!("{context} diagonal covariance is empty"));
816 }
817 for (idx, &value) in diag.iter().enumerate() {
818 if !(value.is_finite() && value >= 0.0) {
819 return Err(format!(
820 "{context} diagonal covariance entry {idx} must be finite and non-negative, got {value}"
821 ));
822 }
823 }
824 }
825 Self::Full(cov) => {
826 if cov.nrows() == 0 || cov.nrows() != cov.ncols() {
827 return Err(format!(
828 "{context} full covariance must be non-empty and square, got {}x{}",
829 cov.nrows(),
830 cov.ncols()
831 ));
832 }
833 for i in 0..cov.nrows() {
834 for j in 0..cov.ncols() {
835 let value = cov[[i, j]];
836 if !value.is_finite() {
837 return Err(format!(
838 "{context} full covariance entry ({i},{j}) is non-finite"
839 ));
840 }
841 if (value - cov[[j, i]]).abs()
842 > 1e-10 * (1.0 + value.abs().max(cov[[j, i]].abs()))
843 {
844 return Err(format!(
845 "{context} full covariance must be symmetric at ({i},{j})"
846 ));
847 }
848 }
849 }
850 }
851 Self::LowRank(factor) => {
852 if factor.nrows() == 0 {
853 return Err(format!(
854 "{context} low-rank covariance factor has zero rows"
855 ));
856 }
857 for ((i, j), &value) in factor.indexed_iter() {
858 if !value.is_finite() {
859 return Err(format!(
860 "{context} low-rank covariance factor entry ({i},{j}) is non-finite"
861 ));
862 }
863 }
864 }
865 }
866 Ok(())
867 }
868
869 pub fn quadratic_form(&self, vector: &[f64]) -> Result<f64, String> {
870 self.validate("marginal-slope covariance")?;
871 if vector.len() != self.dim() {
872 return Err(format!(
873 "marginal-slope covariance dimension mismatch: vector={}, covariance={}",
874 vector.len(),
875 self.dim()
876 ));
877 }
878 if vector.iter().any(|value| !value.is_finite()) {
879 return Err("marginal-slope covariance vector contains non-finite values".to_string());
880 }
881 let value = match self {
882 Self::Diagonal(diag) => vector
883 .iter()
884 .zip(diag.iter())
885 .map(|(&v, &sigma)| v * v * sigma)
886 .sum::<f64>(),
887 Self::Full(cov) => {
888 let mut total = 0.0;
889 for i in 0..cov.nrows() {
890 let mut row_dot = 0.0;
891 for j in 0..cov.ncols() {
892 row_dot += cov[[i, j]] * vector[j];
893 }
894 total += vector[i] * row_dot;
895 }
896 total
897 }
898 Self::LowRank(factor) => {
899 let mut total = 0.0;
905 for r in 0..factor.ncols() {
906 let mut projection = 0.0;
907 for k in 0..factor.nrows() {
908 projection += factor[[k, r]] * vector[k];
909 }
910 total += projection * projection;
911 }
912 total
913 }
914 };
915 if value.is_finite() && value >= COVARIANCE_QUADRATIC_FORM_PSD_TOL {
916 Ok(value.max(0.0))
917 } else {
918 Err(format!(
919 "marginal-slope covariance quadratic form must be non-negative, got {value}"
920 ))
921 }
922 }
923}
924
925pub fn marginal_slope_covariance_from_scores(
952 scores: ArrayView2<'_, f64>,
953 weights: &Array1<f64>,
954) -> Result<MarginalSlopeCovariance, String> {
955 let (n, k) = scores.dim();
956 if k == 0 {
957 return Err("marginal-slope score matrix must have at least one column".to_string());
958 }
959 if weights.len() != n {
960 return Err(format!(
961 "marginal-slope covariance weight length mismatch: weights={}, rows={n}",
962 weights.len()
963 ));
964 }
965 let total_weight = weights.iter().copied().sum::<f64>();
966 if !(total_weight.is_finite() && total_weight > 0.0) {
967 return Err("marginal-slope covariance needs positive finite total weight".to_string());
968 }
969 let mut mean = Array1::<f64>::zeros(k);
970 for i in 0..n {
971 let weight = weights[i];
972 if !(weight.is_finite() && weight >= 0.0) {
973 return Err(format!(
974 "marginal-slope covariance weight {i} must be finite and non-negative, got {weight}"
975 ));
976 }
977 for j in 0..k {
978 let score = scores[[i, j]];
979 if !score.is_finite() {
980 return Err(format!(
981 "marginal-slope covariance score ({i},{j}) is non-finite"
982 ));
983 }
984 mean[j] += weight * score;
985 }
986 }
987 mean.mapv_inplace(|value| value / total_weight);
988
989 let mut cov = Array2::<f64>::zeros((k, k));
990 for i in 0..n {
991 let weight = weights[i];
992 for a in 0..k {
993 let da = scores[[i, a]] - mean[a];
994 for b in 0..=a {
995 let value = weight * da * (scores[[i, b]] - mean[b]) / total_weight;
996 cov[[a, b]] += value;
997 if a != b {
998 cov[[b, a]] += value;
999 }
1000 }
1001 }
1002 }
1003
1004 if k == 1 {
1029 return Ok(MarginalSlopeCovariance::Diagonal(cov.diag().to_owned()));
1030 }
1031
1032 let diag: Vec<f64> = (0..k).map(|i| cov[[i, i]]).collect();
1033 let diag_max = diag.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
1034 let numerical_floor = 1e-10 * (1.0 + diag_max);
1035
1036 let mut is_strict_diagonal = true;
1037 'strict: for a in 0..k {
1038 for b in (a + 1)..k {
1039 if cov[[a, b]].abs() > numerical_floor {
1040 is_strict_diagonal = false;
1041 break 'strict;
1042 }
1043 }
1044 }
1045 if is_strict_diagonal {
1046 return Ok(MarginalSlopeCovariance::Diagonal(cov.diag().to_owned()));
1047 }
1048
1049 use gam_linalg::faer_ndarray::FaerEigh;
1050 let (evals, evecs) = cov
1051 .eigh(faer::Side::Lower)
1052 .map_err(|err| format!("marginal-slope covariance eigendecomposition failed: {err}"))?;
1053 let max_eval = evals
1054 .iter()
1055 .fold(0.0_f64, |acc, &value| acc.max(value.abs()));
1056 let rank_tol = 1e-10 * max_eval.max(1.0);
1057 let positive: Vec<(usize, f64)> = evals
1058 .iter()
1059 .enumerate()
1060 .filter_map(|(idx, &value)| (value > rank_tol).then_some((idx, value)))
1061 .collect();
1062
1063 if positive.len() < k {
1064 let mut factor = Array2::<f64>::zeros((k, positive.len()));
1067 for (col, (idx, value)) in positive.iter().enumerate() {
1068 let scale = value.sqrt();
1069 for row in 0..k {
1070 factor[[row, col]] = evecs[[row, *idx]] * scale;
1071 }
1072 }
1073 return Ok(MarginalSlopeCovariance::LowRank(factor));
1074 }
1075
1076 let sum_w_sq = weights.iter().map(|&w| w * w).sum::<f64>();
1078 let n_eff = if sum_w_sq > 0.0 {
1079 (total_weight * total_weight) / sum_w_sq
1080 } else {
1081 1.0
1082 };
1083 const OFFDIAG_Z_THRESHOLD: f64 = 4.0;
1084 let mut is_stat_diagonal = true;
1085 'stat: for a in 0..k {
1086 for b in (a + 1)..k {
1087 let stat_se = (diag[a].max(0.0) * diag[b].max(0.0) / n_eff)
1088 .max(0.0)
1089 .sqrt();
1090 let threshold = numerical_floor.max(OFFDIAG_Z_THRESHOLD * stat_se);
1091 if cov[[a, b]].abs() > threshold {
1092 is_stat_diagonal = false;
1093 break 'stat;
1094 }
1095 }
1096 }
1097 if is_stat_diagonal {
1098 Ok(MarginalSlopeCovariance::Diagonal(cov.diag().to_owned()))
1099 } else {
1100 Ok(MarginalSlopeCovariance::Full(cov))
1101 }
1102}
1103
1104pub fn marginal_slope_preserving_scale(
1105 slopes: &[f64],
1106 covariance: &MarginalSlopeCovariance,
1107 probit_scale: f64,
1108) -> Result<f64, String> {
1109 if !probit_scale.is_finite() {
1110 return Err(format!(
1111 "marginal-slope probit scale must be finite, got {probit_scale}"
1112 ));
1113 }
1114 let observed_slopes = slopes
1115 .iter()
1116 .map(|&slope| probit_scale * slope)
1117 .collect::<Vec<_>>();
1118 let variance = covariance.quadratic_form(&observed_slopes)?;
1119 Ok((1.0 + variance).sqrt())
1120}
1121
1122pub fn marginal_slope_probit_eta(
1123 q: f64,
1124 z: &[f64],
1125 slopes: &[f64],
1126 covariance: &MarginalSlopeCovariance,
1127 probit_scale: f64,
1128) -> Result<f64, String> {
1129 if z.len() != slopes.len() {
1130 return Err(format!(
1131 "marginal-slope score/slope dimension mismatch: z={}, slopes={}",
1132 z.len(),
1133 slopes.len()
1134 ));
1135 }
1136 if slopes.len() != covariance.dim() {
1137 return Err(format!(
1138 "marginal-slope covariance dimension mismatch: slopes={}, covariance={}",
1139 slopes.len(),
1140 covariance.dim()
1141 ));
1142 }
1143 if !q.is_finite() || z.iter().any(|value| !value.is_finite()) {
1144 return Err("marginal-slope probit eta inputs must be finite".to_string());
1145 }
1146 let scale = marginal_slope_preserving_scale(slopes, covariance, probit_scale)?;
1147 let linear = z
1148 .iter()
1149 .zip(slopes.iter())
1150 .map(|(&score, &slope)| probit_scale * slope * score)
1151 .sum::<f64>();
1152 Ok(q * scale + linear)
1153}
1154
/// Evaluates the empirical calibration residual and its first two derivatives
/// with respect to the intercept, working in log space.
///
/// With `eta_k = intercept + s * node_k`, where `s` is the observed slope from
/// `rigid_observed_logslope`, and sums taken over nodes whose weight is finite
/// and positive:
///
/// * `f   = log Σ w_k Φ(eta_k) - log_target_mu`
/// * `f'  = Σ w_k φ(eta_k) / Σ w_k Φ(eta_k)`
/// * `f'' = -Σ w_k eta_k φ(eta_k) / Σ w_k Φ(eta_k) - f'²`
///
/// `normal_logcdf` is defined elsewhere and is presumably `log Φ` — confirm.
/// `nodes` and `weights` are paired with `zip`, so any surplus entries in the
/// longer slice are ignored.
///
/// # Returns
/// `(f, f', f'')`.
///
/// # Errors
/// Returns an error for a non-finite intercept, a non-finite `eta` at any used
/// node, accumulated sums that are not positive and finite (including the case
/// where no node has a usable weight), or a non-finite or non-positive final
/// state.
pub(super) fn empirical_rigid_calibration_eval(
    intercept: f64,
    log_target_mu: f64,
    slope: f64,
    probit_scale: f64,
    nodes: &[f64],
    weights: &[f64],
) -> Result<(f64, f64, f64), String> {
    if !intercept.is_finite() {
        return Err(format!(
            "empirical latent calibration: non-finite intercept {intercept}"
        ));
    }
    let observed_slope = rigid_observed_logslope(slope, probit_scale);
    // ln(2π) / 2, the normalizing constant of the standard normal log density.
    const HALF_LOG_2PI: f64 = 0.918_938_533_204_672_8;
    // Running log-sum-exp states (max, rescaled sum) for Σ w φ and Σ w Φ.
    let mut log_max_phi = f64::NEG_INFINITY;
    let mut sum_phi = 0.0_f64;
    let mut log_max_cdf = f64::NEG_INFINITY;
    let mut sum_cdf = 0.0_f64;

    // Σ w |eta| φ is tracked separately for eta > 0 and eta < 0 so the signed
    // sum can be formed after leaving log space.
    let mut log_max_pos = f64::NEG_INFINITY;
    let mut sum_pos = 0.0_f64;
    let mut log_max_neg = f64::NEG_INFINITY;
    let mut sum_neg = 0.0_f64;

    for (&node, &weight) in nodes.iter().zip(weights.iter()) {
        // Nodes without a usable weight contribute nothing.
        if !(weight.is_finite() && weight > 0.0) {
            continue;
        }
        let eta = intercept + observed_slope * node;
        if !eta.is_finite() {
            return Err(format!(
                "empirical latent calibration: non-finite η at intercept={intercept}, slope={slope}, node={node}"
            ));
        }
        let log_w = weight.ln();
        let log_phi = -0.5 * eta * eta - HALF_LOG_2PI;
        let log_term_phi = log_w + log_phi;
        let log_term_cdf = log_w + normal_logcdf(eta);

        lse_accumulate(&mut log_max_phi, &mut sum_phi, log_term_phi);
        lse_accumulate(&mut log_max_cdf, &mut sum_cdf, log_term_cdf);

        // eta == 0 contributes nothing to Σ w eta φ and has no logarithm.
        if eta != 0.0 {
            let log_term_eta_phi = log_term_phi + eta.abs().ln();
            if eta > 0.0 {
                lse_accumulate(&mut log_max_pos, &mut sum_pos, log_term_eta_phi);
            } else {
                lse_accumulate(&mut log_max_neg, &mut sum_neg, log_term_eta_phi);
            }
        }
    }

    if !(sum_phi.is_finite() && sum_cdf.is_finite() && sum_phi > 0.0 && sum_cdf > 0.0) {
        return Err(format!(
            "empirical latent calibration: log-space accumulation failed (sum_phi={sum_phi}, sum_cdf={sum_cdf}, intercept={intercept})"
        ));
    }

    let log_s_phi = log_max_phi + sum_phi.ln();
    let log_s_cdf = log_max_cdf + sum_cdf.ln();

    let f = log_s_cdf - log_target_mu;
    let log_f_prime = log_s_phi - log_s_cdf;
    // Below about exp(-740) the exponential underflows; substitute the smallest
    // positive normal value so the derivative stays strictly positive.
    let f_prime = if log_f_prime > -740.0 {
        log_f_prime.exp()
    } else {
        f64::MIN_POSITIVE
    };

    // Same underflow guard, but flushing to zero.
    let exp_safe = |log_x: f64| -> f64 { if log_x > -740.0 { log_x.exp() } else { 0.0 } };
    let pos_over_cdf = if sum_pos > 0.0 {
        exp_safe(log_max_pos + sum_pos.ln() - log_s_cdf)
    } else {
        0.0
    };
    let neg_over_cdf = if sum_neg > 0.0 {
        exp_safe(log_max_neg + sum_neg.ln() - log_s_cdf)
    } else {
        0.0
    };
    // Signed ratio Σ w eta φ / Σ w Φ recovered from the two one-sided sums.
    let s_etaphi_over_s_cdf = pos_over_cdf - neg_over_cdf;
    let f_double_prime = -s_etaphi_over_s_cdf - f_prime * f_prime;

    if !(f.is_finite() && f_prime.is_finite() && f_prime > 0.0 && f_double_prime.is_finite()) {
        return Err(format!(
            "empirical latent calibration: non-finite log-space state f={f}, f'={f_prime}, f''={f_double_prime} at intercept={intercept}"
        ));
    }
    Ok((f, f_prime, f_double_prime))
}
1299
1300pub(crate) fn empirical_intercept_from_marginal(
1301 target_mu: f64,
1302 target_q: f64,
1303 slope: f64,
1304 probit_scale: f64,
1305 nodes: &[f64],
1306 weights: &[f64],
1307 initial: Option<f64>,
1308) -> Result<f64, String> {
1309 if !(target_mu.is_finite() && target_mu > 0.0 && target_mu < 1.0) {
1310 return Err(format!(
1311 "empirical latent calibration requires target mu in (0,1), got {target_mu}"
1312 ));
1313 }
1314 let log_target_mu = target_mu.ln();
1315 let closed_form_seed = rigid_intercept_from_marginal(target_q, slope, probit_scale);
1316 let seed = initial.unwrap_or(closed_form_seed);
1317 let eval = |a: f64| {
1318 empirical_rigid_calibration_eval(a, log_target_mu, slope, probit_scale, nodes, weights)
1319 };
1320 let abs_tol = 1e-13_f64.max(4.0 * f64::EPSILON);
1327 let solve_from = |s: f64| {
1328 crate::monotone_root::solve_monotone_root(
1329 eval,
1330 s,
1331 "empirical latent intercept",
1332 abs_tol,
1333 64,
1334 48,
1335 )
1336 .map_err(|e| e.to_string())
1339 };
1340 let (root, _, f_best) = match solve_from(seed) {
1351 Ok(v) => v,
1352 Err(first_err) => {
1353 if seed == closed_form_seed {
1354 return Err(first_err);
1355 }
1356 solve_from(closed_form_seed).map_err(|retry_err| {
1357 format!("{first_err}; closed-form retry from a={closed_form_seed:.6}: {retry_err}")
1358 })?
1359 }
1360 };
1361 if f_best.abs() > abs_tol {
1362 return Err(format!(
1363 "empirical latent intercept solve failed: log-residual={f_best:.3e} at a={root:.6}, target mu={target_mu:.6}"
1364 ));
1365 }
1366 Ok(root)
1367}
1368
1369#[inline]
1370pub(super) fn rigid_standard_normal_neglog_only(
1371 q: f64,
1372 g: f64,
1373 z: f64,
1374 y: f64,
1375 w: f64,
1376 probit_scale: f64,
1377) -> Result<f64, String> {
1378 let s = 2.0 * y - 1.0;
1379 let eta = marginal_slope_standard_normal_scalar_eta(q, g, z, probit_scale);
1380 let m = s * eta;
1381 let (logcdf, _) = signed_probit_logcdf_and_mills_ratio(m);
1382 if !logcdf.is_finite() {
1383 return Err(format!(
1384 "rigid probit neglog_only: non-finite log Φ at q={q}, g={g}, z={z}, y={y}"
1385 ));
1386 }
1387 Ok(-w * logcdf)
1388}
1389
1390#[inline]
1420pub(crate) fn rigid_standard_normal_row_nll_generic<S: gam_math::jet_scalar::JetScalar<2>>(
1421 p: &[S; 2],
1422 marginal: BernoulliMarginalLinkMap,
1423 z: f64,
1424 y: f64,
1425 w: f64,
1426 probit_scale: f64,
1427) -> Result<S, String> {
1428 let signed = rigid_standard_normal_signed_margin(p, marginal, z, y, probit_scale);
1432 let m = signed.value();
1435 if !(m.is_finite() || m == f64::INFINITY) {
1436 return Err(format!(
1437 "non-finite signed margin in rigid probit row NLL: {m}"
1438 ));
1439 }
1440 Ok(signed.compose_unary(signed_probit_neglog_unary_stack(m, w)))
1442}
1443
1444#[inline]
1454pub(crate) fn rigid_standard_normal_signed_margin<S: gam_math::jet_scalar::JetScalar<2>>(
1455 p: &[S; 2],
1456 marginal: BernoulliMarginalLinkMap,
1457 z: f64,
1458 y: f64,
1459 probit_scale: f64,
1460) -> S {
1461 let q = p[0].compose_unary([
1463 marginal.q,
1464 marginal.q1,
1465 marginal.q2,
1466 marginal.q3,
1467 marginal.q4,
1468 ]);
1469 let slope = p[1];
1470 let observed_slope = slope.scale(probit_scale);
1472 let b2 = observed_slope.mul(&observed_slope);
1473 let c = b2.add(&S::constant(1.0)).sqrt();
1474 let eta = q.mul(&c).add(&observed_slope.scale(z));
1476 eta.scale(2.0 * y - 1.0)
1477}
1478
/// Inputs of one row of the rigid standard-normal probit negative log-likelihood,
/// packaged so the row can be evaluated as a single-row program by the
/// `gam_math::jet_tower` generic drivers.
pub(crate) struct RigidStandardNormalRow {
    /// Link map for this row; supplies the stack `q, q1..q4` and, through
    /// `eta_value()`, the first primary.
    pub(crate) marginal: BernoulliMarginalLinkMap,
    /// Value of the second primary (the slope coordinate).
    pub(crate) g: f64,
    /// Latent score that multiplies the scaled slope in the predictor.
    pub(crate) z: f64,
    /// Response; it enters the margin only through the sign `2y - 1`.
    pub(crate) y: f64,
    /// Weight forwarded to `signed_probit_neglog_unary_stack`.
    pub(crate) w: f64,
    /// Factor applied to the slope coordinate before it enters the predictor.
    pub(crate) probit_scale: f64,
}
1499
1500impl gam_math::jet_tower::RowNllProgramGeneric<2> for RigidStandardNormalRow {
1501 fn n_rows(&self) -> usize {
1502 1
1503 }
1504
1505 fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
1506 if row != 0 {
1507 return Err(format!("RigidStandardNormalRow: row {row} out of range"));
1508 }
1509 Ok([self.marginal.eta_value(), self.g])
1510 }
1511
1512 fn row_nll_generic<S: gam_math::jet_scalar::JetScalar<2>>(
1513 &self,
1514 row: usize,
1515 p: &[S; 2],
1516 ) -> Result<S, String> {
1517 if row != 0 {
1518 return Err(format!("RigidStandardNormalRow: row {row} out of range"));
1519 }
1520 rigid_standard_normal_row_nll_generic(
1521 p,
1522 self.marginal,
1523 self.z,
1524 self.y,
1525 self.w,
1526 self.probit_scale,
1527 )
1528 }
1529}
1530
1531#[inline]
1532pub(crate) fn rigid_standard_normal_tower(
1533 marginal: BernoulliMarginalLinkMap,
1534 g: f64,
1535 z: f64,
1536 y: f64,
1537 w: f64,
1538 probit_scale: f64,
1539) -> Result<Tower4<2>, String> {
1540 let program = RigidStandardNormalRow {
1547 marginal,
1548 g,
1549 z,
1550 y,
1551 w,
1552 probit_scale,
1553 };
1554 gam_math::jet_tower::generic_full_tower(&program, 0)
1555}
1556
1557#[inline]
1570fn rigid_standard_normal_signed_jet(
1571 marginal: BernoulliMarginalLinkMap,
1572 g: f64,
1573 z: f64,
1574 y: f64,
1575 probit_scale: f64,
1576) -> Tower4<2> {
1577 let p = [
1580 Tower4::<2>::variable(marginal.eta_value(), 0),
1581 Tower4::<2>::variable(g, 1),
1582 ];
1583 rigid_standard_normal_signed_margin(&p, marginal, z, y, probit_scale)
1584}
1585
1586#[inline]
1614pub(super) fn rigid_standard_normal_towers_batch<T>(
1615 marginals: &[BernoulliMarginalLinkMap],
1616 slopes: &[f64],
1617 zs: &[f64],
1618 ys: &[f64],
1619 weights: &[f64],
1620 probit_scale: f64,
1621 out: &mut [T],
1622 mut fill: impl FnMut(&Tower4<2>) -> Result<T, String>,
1623) -> Result<(), String> {
1624 let chunk = marginals.len();
1625 if slopes.len() != chunk
1626 || zs.len() != chunk
1627 || ys.len() != chunk
1628 || weights.len() != chunk
1629 || out.len() != chunk
1630 {
1631 return Err(format!(
1632 "rigid_standard_normal_towers_batch length mismatch: marginals={chunk}, \
1633 slopes={}, zs={}, ys={}, weights={}, out={}",
1634 slopes.len(),
1635 zs.len(),
1636 ys.len(),
1637 weights.len(),
1638 out.len()
1639 ));
1640 }
1641
1642 let mut signed: Vec<Tower4<2>> = Vec::with_capacity(chunk);
1644 let mut margins: Vec<f64> = Vec::with_capacity(chunk);
1645 for i in 0..chunk {
1646 let jet =
1647 rigid_standard_normal_signed_jet(marginals[i], slopes[i], zs[i], ys[i], probit_scale);
1648 margins.push(jet.v);
1649 signed.push(jet);
1650 }
1651
1652 let mut stacks: Vec<[f64; 5]> = Vec::with_capacity(chunk);
1656 for i in 0..chunk {
1657 let m = margins[i];
1658 if !(m.is_finite() || m == f64::INFINITY) {
1659 return Err(format!(
1660 "non-finite signed margin in rigid probit tower batch: {m}"
1661 ));
1662 }
1663 stacks.push(signed_probit_neglog_unary_stack(m, weights[i]));
1664 }
1665
1666 for i in 0..chunk {
1668 let tower = signed[i].compose_unary(stacks[i]);
1669 out[i] = fill(&tower)?;
1670 }
1671 Ok(())
1672}
1673
1674#[inline]
1675pub(super) fn rigid_standard_normal_row_kernel(
1676 marginal: BernoulliMarginalLinkMap,
1677 g: f64,
1678 z: f64,
1679 y: f64,
1680 w: f64,
1681 probit_scale: f64,
1682) -> Result<(f64, [f64; 2], [[f64; 2]; 2]), String> {
1683 let program = RigidStandardNormalRow {
1692 marginal,
1693 g,
1694 z,
1695 y,
1696 w,
1697 probit_scale,
1698 };
1699 gam_math::jet_tower::generic_row_kernel(&program, 0)
1700}
1701
1702#[inline]
1740pub(super) fn rigid_standard_normal_mixed_z_sensitivity(
1741 marginal: BernoulliMarginalLinkMap,
1742 g: f64,
1743 z: f64,
1744 y: f64,
1745 w: f64,
1746 probit_scale: f64,
1747) -> Result<[f64; 2], String> {
1748 use gam_math::jet_tower::Tower2;
1759 let mut q = Tower2::<3>::constant(marginal.q);
1760 q.g[0] = marginal.q1;
1761 q.h[0][0] = marginal.q2;
1762 let slope = Tower2::<3>::variable(g, 1);
1763 let z_var = Tower2::<3>::variable(z, 2);
1764 let observed_logslope = slope * probit_scale;
1765 let c = (observed_logslope * observed_logslope + 1.0).sqrt();
1766 let eta = q * c + slope * (z_var * probit_scale);
1770 let signed = eta * (2.0 * y - 1.0);
1771 if !(signed.v.is_finite() || signed.v == f64::INFINITY) {
1773 return Err(format!(
1774 "rigid probit mixed-z sensitivity: non-finite signed margin {} at q={}, g={g}, z={z}, y={y}",
1775 signed.v, marginal.q
1776 ));
1777 }
1778 let stack = signed_probit_neglog_unary_stack(signed.v, w);
1779 if !stack[0].is_finite() {
1780 return Err(format!(
1781 "rigid probit mixed-z sensitivity: non-finite log Φ at q={}, g={g}, z={z}, y={y}",
1782 marginal.q
1783 ));
1784 }
1785 let tower = signed.compose_unary([stack[0], stack[1], stack[2]]);
1788 let s_q = -tower.h[0][2];
1795 let s_g = -tower.h[1][2];
1796 if !(s_q.is_finite() && s_g.is_finite()) {
1797 return Err(format!(
1798 "rigid probit mixed-z sensitivity: non-finite ∂²(log L)/∂(q,g)∂z = [{s_q}, {s_g}] at q={}, g={g}, z={z}",
1799 marginal.q
1800 ));
1801 }
1802 Ok([s_q, s_g])
1803}
1804
1805pub(super) fn rigid_standard_normal_score_zeta_sensitivity(
1832 base_link: &InverseLink,
1833 marginal_eta: &Array1<f64>,
1834 slope_eta: &Array1<f64>,
1835 z: &Array1<f64>,
1836 y: &Array1<f64>,
1837 weights: &Array1<f64>,
1838 probit_scale: f64,
1839 marginal_design: ArrayView2<'_, f64>,
1840 logslope_design: ArrayView2<'_, f64>,
1841 p_beta: usize,
1842) -> Result<Array2<f64>, String> {
1843 let n = marginal_eta.len();
1844 let p_m = marginal_design.ncols();
1845 let r = logslope_design.ncols();
1846 if slope_eta.len() != n
1847 || z.len() != n
1848 || y.len() != n
1849 || weights.len() != n
1850 || marginal_design.nrows() != n
1851 || logslope_design.nrows() != n
1852 {
1853 return Err(format!(
1854 "score_zeta_sensitivity row mismatch: marginal_eta={n}, slope_eta={}, z={}, y={}, \
1855 weights={}, marginal_design rows={}, logslope_design rows={}",
1856 slope_eta.len(),
1857 z.len(),
1858 y.len(),
1859 weights.len(),
1860 marginal_design.nrows(),
1861 logslope_design.nrows()
1862 ));
1863 }
1864 if p_m + r > p_beta {
1865 return Err(format!(
1866 "score_zeta_sensitivity width overflow: marginal({p_m}) + logslope({r}) > p_beta({p_beta})"
1867 ));
1868 }
1869 let mut s = Array2::<f64>::zeros((n, p_beta));
1870 for i in 0..n {
1871 let marginal = bernoulli_marginal_link_map(base_link, marginal_eta[i])?;
1872 let [s_q, s_g] = rigid_standard_normal_mixed_z_sensitivity(
1873 marginal,
1874 slope_eta[i],
1875 z[i],
1876 y[i],
1877 weights[i],
1878 probit_scale,
1879 )?;
1880 if s_q != 0.0 {
1883 let m_row = marginal_design.row(i);
1884 for (j, &mij) in m_row.iter().enumerate() {
1885 s[[i, j]] = s_q * mij;
1886 }
1887 }
1888 if s_g != 0.0 {
1889 let g_row = logslope_design.row(i);
1890 for (j, &gij) in g_row.iter().enumerate() {
1891 s[[i, p_m + j]] = s_g * gij;
1892 }
1893 }
1894 }
1895 Ok(s)
1896}
1897
1898#[inline]
1899pub(super) fn rigid_standard_normal_third_full(
1900 marginal: BernoulliMarginalLinkMap,
1901 g: f64,
1902 z: f64,
1903 y: f64,
1904 w: f64,
1905 probit_scale: f64,
1906) -> Result<[[[f64; 2]; 2]; 2], String> {
1907 Ok(rigid_standard_normal_tower(marginal, g, z, y, w, probit_scale)?.t3)
1908}
1909
1910#[inline]
1915pub(super) fn contract_third_full(t: &[[[f64; 2]; 2]; 2], d_eta: f64, d_g: f64) -> [[f64; 2]; 2] {
1916 [
1917 [
1918 t[0][0][0] * d_eta + t[0][0][1] * d_g,
1919 t[0][1][0] * d_eta + t[0][1][1] * d_g,
1920 ],
1921 [
1922 t[1][0][0] * d_eta + t[1][0][1] * d_g,
1923 t[1][1][0] * d_eta + t[1][1][1] * d_g,
1924 ],
1925 ]
1926}
1927
1928#[inline]
1929pub(super) fn rigid_standard_normal_fourth_full(
1930 marginal: BernoulliMarginalLinkMap,
1931 g: f64,
1932 z: f64,
1933 y: f64,
1934 w: f64,
1935 probit_scale: f64,
1936) -> Result<[[[[f64; 2]; 2]; 2]; 2], String> {
1937 Ok(rigid_standard_normal_tower(marginal, g, z, y, w, probit_scale)?.t4)
1951}
1952
1953#[inline]
1972pub(super) fn contract_fourth_full(
1973 t: &[[[[f64; 2]; 2]; 2]; 2],
1974 u_eta: f64,
1975 u_g: f64,
1976 v_eta: f64,
1977 v_g: f64,
1978) -> [[f64; 2]; 2] {
1979 let mut out = [[0.0; 2]; 2];
1980 for a in 0..2 {
1981 for b in 0..2 {
1982 let mut sum = 0.0;
1983 sum += t[a][b][0][0] * u_eta * v_eta;
1984 sum += t[a][b][0][1] * u_eta * v_g;
1985 sum += t[a][b][1][0] * u_g * v_eta;
1986 sum += t[a][b][1][1] * u_g * v_g;
1987 out[a][b] = sum;
1988 }
1989 }
1990 out
1991}
1992
1993pub(super) fn ensure_finite_third_full_cache_row(
1994 t: &[[[f64; 2]; 2]; 2],
1995 context: &str,
1996) -> Result<(), String> {
1997 if t.iter().flatten().flatten().all(|value| value.is_finite()) {
1998 Ok(())
1999 } else {
2000 Err(format!(
2001 "{context}: warmed third-derivative cache row contains a non-finite value"
2002 ))
2003 }
2004}
2005
2006pub(super) fn ensure_finite_fourth_full_cache_row(
2007 t: &[[[[f64; 2]; 2]; 2]; 2],
2008 context: &str,
2009) -> Result<(), String> {
2010 if t.iter()
2011 .flatten()
2012 .flatten()
2013 .flatten()
2014 .all(|value| value.is_finite())
2015 {
2016 Ok(())
2017 } else {
2018 Err(format!(
2019 "{context}: warmed fourth-derivative cache row contains a non-finite value"
2020 ))
2021 }
2022}
2023
/// Value and first four derivatives of `sqrt(x)`, as `[f, f', f'', f''', f'''']`.
///
/// The argument is floored at `1e-300` before evaluation, so zero and negative
/// inputs are evaluated at the floor instead of producing NaN or a division by
/// zero. Near the floor the third and fourth derivatives still overflow to
/// `±inf`, because `x^2` and `x^3` underflow to zero.
///
/// A NaN input yields an all-NaN stack. `f64::max` returns its non-NaN operand,
/// so without the explicit check a NaN would be silently replaced by the floor
/// and come back as an ordinary finite-looking stack.
pub(crate) fn unary_derivatives_sqrt(x: f64) -> [f64; 5] {
    if x.is_nan() {
        return [f64::NAN; 5];
    }
    let x1 = x.max(1e-300);
    let s = x1.sqrt();
    let x2 = x1 * x1;
    let x3 = x2 * x1;
    [
        s,
        0.5 / s,
        -0.25 / (x1 * s),
        3.0 / (8.0 * x2 * s),
        -15.0 / (16.0 * x3 * s),
    ]
}
/// Five-entry unary stack for the weighted probit negative log term.
///
/// This is a plain forwarder: both arguments are passed unchanged to
/// `signed_probit_neglog_unary_stack`, and its result is returned as is.
pub(crate) fn unary_derivatives_neglog_phi(x: f64, weight: f64) -> [f64; 5] {
    signed_probit_neglog_unary_stack(x, weight)
}
2044
/// Value and first four derivatives of `ln(x)`:
/// `[ln x, 1/x, -1/x^2, 2/x^3, -6/x^4]`.
///
/// No domain guard is applied; the results for `x <= 0` are whatever IEEE
/// arithmetic gives for these expressions.
pub(crate) fn unary_derivatives_log(x: f64) -> [f64; 5] {
    let squared = x * x;
    let cubed = squared * x;
    let fourth_power = cubed * x;
    let value = x.ln();
    [
        value,
        1.0 / x,
        -1.0 / squared,
        2.0 / cubed,
        -6.0 / fourth_power,
    ]
}
2068
/// Value and first four derivatives of the standard-normal log-density
/// `-x^2/2 - ln(2π)/2`: `[log φ(x), -x, -1, 0, 0]`.
pub(crate) fn unary_derivatives_log_normal_pdf(x: f64) -> [f64; 5] {
    // TAU is 2π, so this is the normalising constant ln(2π)/2.
    let half_log_two_pi = 0.5 * std::f64::consts::TAU.ln();
    let log_density = -0.5 * x * x - half_log_two_pi;
    [log_density, -x, -1.0, 0.0, 0.0]
}
2074
2075#[cfg(test)]
2076mod jet_tower_oracle_tests {
2077 use super::*;
2099
2100 fn rigid_standard_normal_third_and_fourth_full(
2108 marginal: BernoulliMarginalLinkMap,
2109 g: f64,
2110 z: f64,
2111 y: f64,
2112 w: f64,
2113 probit_scale: f64,
2114 ) -> Result<([[[f64; 2]; 2]; 2], [[[[f64; 2]; 2]; 2]; 2]), String> {
2115 let tower = rigid_standard_normal_tower(marginal, g, z, y, w, probit_scale)?;
2116 Ok((tower.t3, tower.t4))
2117 }
2118 use gam_math::jet_tower::{
2119 KernelChannels, RowNllProgram, evaluate_program, verify_kernel_channels,
2120 };
2121
    /// Multi-row test program that re-derives the rigid Bernoulli row NLL directly
    /// on `Tower4<2>` arithmetic, used as the reference the production kernels are
    /// compared against.
    struct BernoulliRigidStandardNormalNllProgram {
        /// Per-row primaries; the tests fill these as `[eta, g]`.
        primaries: Vec<[f64; 2]>,
        /// Per-row latent score.
        z: Vec<f64>,
        /// Per-row response; used only through the sign `2y - 1`.
        y: Vec<f64>,
        /// Per-row weight passed to `unary_derivatives_neglog_phi`.
        w: Vec<f64>,
        /// Factor applied to the slope primary.
        probit_scale: f64,
    }
2133
    impl RowNllProgram<2> for BernoulliRigidStandardNormalNllProgram {
        /// One row per stored primaries entry.
        fn n_rows(&self) -> usize {
            self.primaries.len()
        }

        /// Primaries of `row`, or an error when `row` is out of range.
        fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
            self.primaries
                .get(row)
                .copied()
                .ok_or_else(|| format!("bernoulli rigid nll program: row {row} out of range"))
        }

        /// Row NLL written out with `Tower4<2>` operators.
        ///
        /// Note that `z`, `y` and `w` are indexed directly, so an out-of-range
        /// `row` panics here rather than returning `Err`.
        fn row_nll(&self, row: usize, p: &[Tower4<2>; 2]) -> Result<Tower4<2>, String> {
            let z = self.z[row];
            let y = self.y[row];
            let w = self.w[row];
            let s = self.probit_scale;
            let eta_marginal = p[0];
            // The link map is re-evaluated at the tower's value and composed onto
            // the marginal coordinate.
            let link = bernoulli_marginal_link_map(
                &InverseLink::Standard(gam_problem::StandardLink::Probit),
                eta_marginal.v,
            )?;
            let q = eta_marginal.compose_unary([link.q, link.q1, link.q2, link.q3, link.q4]);
            let g = p[1];
            let observed_slope = g * s;
            // sqrt(1 + b^2) through the hand-written sqrt derivative stack.
            let c = (observed_slope * observed_slope + 1.0).compose_unary(unary_derivatives_sqrt(
                observed_slope.v * observed_slope.v + 1.0,
            ));
            let eta = q * c + observed_slope * z;
            let signed = eta * (2.0 * y - 1.0);
            Ok(signed.compose_unary(unary_derivatives_neglog_phi(signed.v, w)))
        }
    }
2173
2174 fn scalar_nll(eta_marginal: f64, g: f64, z: f64, y: f64, w: f64, s: f64) -> f64 {
2177 let link = bernoulli_marginal_link_map(
2178 &InverseLink::Standard(gam_problem::StandardLink::Probit),
2179 eta_marginal,
2180 )
2181 .unwrap();
2182 let observed_slope = g * s;
2183 let c = (observed_slope * observed_slope + 1.0).sqrt();
2184 let eta = link.q * c + observed_slope * z;
2185 let signed = (2.0 * y - 1.0) * eta;
2186 let cdf = 0.5 * libm::erfc(-signed / std::f64::consts::SQRT_2);
2187 -w * cdf.max(1e-300).ln()
2188 }
2189
    /// Checks the production row kernel and the third/fourth tensors against the
    /// `Tower4` reference program on seven fixed rows and two probit scales, then
    /// cross-checks the tower's value and gradient against a scalar
    /// finite-difference witness.
    #[test]
    fn rigid_bernoulli_row_kernel_agrees_with_jet_tower_program_all_channels() {
        let eta = [0.3_f64, -0.7, 0.05, 0.9, -1.2, 2.1, -2.4];
        let g = [0.2_f64, -0.5, 0.35, -0.15, 0.6, 0.45, -0.55];
        let z = [0.4_f64, -1.1, 0.0, 0.7, -0.3, 1.6, -1.4];
        let y = [1.0_f64, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0];
        let w = [1.0_f64, 0.8, 1.3, 0.9, 1.1, 0.7, 1.4];
        let n = eta.len();

        // Contraction directions for the third/fourth derivative channels.
        let dirs: [[f64; 2]; 3] = [[0.7, -1.3], [-0.4, 0.6], [1.2, 0.2]];

        for &probit_scale in &[1.0_f64, 0.8] {
            let program = BernoulliRigidStandardNormalNllProgram {
                primaries: (0..n).map(|r| [eta[r], g[r]]).collect(),
                z: z.to_vec(),
                y: y.to_vec(),
                w: w.to_vec(),
                probit_scale,
            };

            for row in 0..n {
                // Reference tower from the test program.
                let tower = evaluate_program(&program, row).expect("tower evaluation");

                let marginal = bernoulli_marginal_link_map(
                    &InverseLink::Standard(gam_problem::StandardLink::Probit),
                    eta[row],
                )
                .expect("link map");
                let (value, gradient, hessian) = rigid_standard_normal_row_kernel(
                    marginal,
                    g[row],
                    z[row],
                    y[row],
                    w[row],
                    probit_scale,
                )
                .expect("production row kernel");

                let (third_full, fourth_full) = rigid_standard_normal_third_and_fourth_full(
                    marginal,
                    g[row],
                    z[row],
                    y[row],
                    w[row],
                    probit_scale,
                )
                .expect("production third+fourth");
                // One contracted third-derivative claim per direction.
                let third: Vec<([f64; 2], [[f64; 2]; 2])> = dirs
                    .iter()
                    .map(|d| (*d, contract_third_full(&third_full, d[0], d[1])))
                    .collect();

                // Fourth-derivative claims pair each direction with the next one
                // (cyclically).
                let fourth: Vec<([f64; 2], [f64; 2], [[f64; 2]; 2])> = dirs
                    .iter()
                    .enumerate()
                    .map(|(i, u)| {
                        let v = dirs[(i + 1) % dirs.len()];
                        (
                            *u,
                            v,
                            contract_fourth_full(&fourth_full, u[0], u[1], v[0], v[1]),
                        )
                    })
                    .collect();

                let claims = KernelChannels {
                    value,
                    gradient,
                    hessian,
                    third,
                    fourth,
                };

                verify_kernel_channels(&tower, &claims, 1e-9).unwrap_or_else(|e| {
                    panic!(
                        "probit_scale {probit_scale} row {row}: production rigid Bernoulli \
                         RowKernel disagrees with #932 jet-tower truth: {e}"
                    )
                });

                // Independent witness: scalar NLL plus five-point central
                // differences for the two first derivatives.
                let h = 1e-3;
                let f = |de: f64, dg: f64| {
                    scalar_nll(
                        eta[row] + de,
                        g[row] + dg,
                        z[row],
                        y[row],
                        w[row],
                        probit_scale,
                    )
                };
                let f0 = f(0.0, 0.0);
                assert!(
                    (f0 - tower.v).abs() <= 1e-9 * f0.abs().max(1.0),
                    "row {row}: independent scalar NLL {f0:+.12e} != tower value {:+.12e}",
                    tower.v
                );
                // (f(-2h) - 8 f(-h) + 8 f(h) - f(2h)) / (12 h)
                let g_eta = (f(-2.0 * h, 0.0) - 8.0 * f(-h, 0.0) + 8.0 * f(h, 0.0)
                    - f(2.0 * h, 0.0))
                    / (12.0 * h);
                let g_g = (f(0.0, -2.0 * h) - 8.0 * f(0.0, -h) + 8.0 * f(0.0, h) - f(0.0, 2.0 * h))
                    / (12.0 * h);
                for (label, fd, ad) in [("∂η", g_eta, tower.g[0]), ("∂g", g_g, tower.g[1])] {
                    assert!(
                        (fd - ad).abs() <= 1e-5 * ad.abs().max(1.0),
                        "row {row} {label}: FD witness {fd:+.6e} != tower grad {ad:+.6e}"
                    );
                }
            }
        }
    }
2316
    /// Checks that the separate third/fourth accessors and the combined helper
    /// return exactly equal tensors (`assert_eq!` on every entry) for seven fixed
    /// rows and two probit scales.
    #[test]
    fn rigid_third_and_fourth_full_shares_one_tower_bit_identical() {
        let eta = [0.3_f64, -0.7, 0.05, 0.9, -1.2, 2.1, -2.4];
        let g = [0.2_f64, -0.5, 0.35, -0.15, 0.6, 0.45, -0.55];
        let z = [0.4_f64, -1.1, 0.0, 0.7, -0.3, 1.6, -1.4];
        let y = [1.0_f64, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0];
        let w = [1.0_f64, 0.8, 1.3, 0.9, 1.1, 0.7, 1.4];
        for &probit_scale in &[1.0_f64, 0.8] {
            for r in 0..eta.len() {
                let marginal = bernoulli_marginal_link_map(
                    &InverseLink::Standard(gam_problem::StandardLink::Probit),
                    eta[r],
                )
                .expect("link map");
                let t3_sep = rigid_standard_normal_third_full(
                    marginal,
                    g[r],
                    z[r],
                    y[r],
                    w[r],
                    probit_scale,
                )
                .expect("separate third");
                let t4_sep = rigid_standard_normal_fourth_full(
                    marginal,
                    g[r],
                    z[r],
                    y[r],
                    w[r],
                    probit_scale,
                )
                .expect("separate fourth");
                let (t3_comb, t4_comb) = rigid_standard_normal_third_and_fourth_full(
                    marginal,
                    g[r],
                    z[r],
                    y[r],
                    w[r],
                    probit_scale,
                )
                .expect("combined third+fourth");
                // Entry-by-entry exact comparison of both tensors.
                for a in 0..2 {
                    for b in 0..2 {
                        for c in 0..2 {
                            assert_eq!(
                                t3_comb[a][b][c], t3_sep[a][b][c],
                                "t3[{a}][{b}][{c}] row {r} scale {probit_scale} not bit-identical"
                            );
                            for d in 0..2 {
                                assert_eq!(
                                    t4_comb[a][b][c][d], t4_sep[a][b][c][d],
                                    "t4[{a}][{b}][{c}][{d}] row {r} scale {probit_scale} not bit-identical"
                                );
                            }
                        }
                    }
                }
            }
        }
    }
2387
    /// Compares the generic `RigidStandardNormalRow` program against the
    /// `Tower4`-only test program on every channel: full tower, order-2 kernel,
    /// contracted third and contracted fourth derivatives.
    #[test]
    fn rigid_bernoulli_generic_program_matches_tower4_program_all_channels() {
        use gam_math::jet_tower::{
            generic_fourth_contracted, generic_full_tower, generic_row_kernel,
            generic_third_contracted,
        };

        let eta = [0.3_f64, -0.7, 0.05, 0.9, -1.2, 2.1, -2.4];
        let g = [0.2_f64, -0.5, 0.35, -0.15, 0.6, 0.45, -0.55];
        let z = [0.4_f64, -1.1, 0.0, 0.7, -0.3, 1.6, -1.4];
        let y = [1.0_f64, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0];
        let w = [1.0_f64, 0.8, 1.3, 0.9, 1.1, 0.7, 1.4];
        let n = eta.len();
        let dirs: [[f64; 2]; 3] = [[0.7, -1.3], [-0.4, 0.6], [1.2, 0.2]];

        // Mixed absolute/relative tolerance of 1e-12.
        let close = |a: f64, b: f64, label: &str| {
            let band = 1e-12 + 1e-12 * a.abs().max(b.abs());
            assert!(
                (a - b).abs() <= band,
                "{label}: generic {a:+.15e} vs Tower4-program {b:+.15e} (band {band:.3e})"
            );
        };

        for &probit_scale in &[1.0_f64, 0.8] {
            let tower_program = BernoulliRigidStandardNormalNllProgram {
                primaries: (0..n).map(|r| [eta[r], g[r]]).collect(),
                z: z.to_vec(),
                y: y.to_vec(),
                w: w.to_vec(),
                probit_scale,
            };

            for row in 0..n {
                let truth = evaluate_program(&tower_program, row).expect("Tower4 program tower");

                let marginal = bernoulli_marginal_link_map(
                    &InverseLink::Standard(gam_problem::StandardLink::Probit),
                    eta[row],
                )
                .expect("link map");
                // The generic program is single-row, so it is always queried at 0.
                let program = RigidStandardNormalRow {
                    marginal,
                    g: g[row],
                    z: z[row],
                    y: y[row],
                    w: w[row],
                    probit_scale,
                };

                // Channel 1: full fourth-order tower.
                let full = generic_full_tower(&program, 0).expect("generic full tower");
                close(full.v, truth.v, "full value");
                for a in 0..2 {
                    close(full.g[a], truth.g[a], "full grad");
                    for b in 0..2 {
                        close(full.h[a][b], truth.h[a][b], "full hess");
                        for c in 0..2 {
                            close(full.t3[a][b][c], truth.t3[a][b][c], "full t3");
                            for d in 0..2 {
                                close(full.t4[a][b][c][d], truth.t4[a][b][c][d], "full t4");
                            }
                        }
                    }
                }

                // Channel 2: value / gradient / Hessian kernel.
                let (val, grad, hess) =
                    generic_row_kernel(&program, 0).expect("generic row kernel");
                close(val, truth.v, "order2 value");
                for a in 0..2 {
                    close(grad[a], truth.g[a], "order2 grad");
                    for b in 0..2 {
                        close(hess[a][b], truth.h[a][b], "order2 hess");
                    }
                }

                // Channel 3: third derivative contracted with each direction.
                for dir in &dirs {
                    let third = generic_third_contracted(&program, 0, dir)
                        .expect("generic third contracted");
                    let truth3 = truth.third_contracted(dir);
                    for a in 0..2 {
                        for b in 0..2 {
                            close(third[a][b], truth3[a][b], "third contracted");
                        }
                    }
                }

                // Channel 4: fourth derivative contracted with each direction and
                // its cyclic successor.
                for (i, u) in dirs.iter().enumerate() {
                    let v = dirs[(i + 1) % dirs.len()];
                    let fourth = generic_fourth_contracted(&program, 0, u, &v)
                        .expect("generic fourth contracted");
                    let truth4 = truth.fourth_contracted(u, &v);
                    for a in 0..2 {
                        for b in 0..2 {
                            close(fourth[a][b], truth4[a][b], "fourth contracted");
                        }
                    }
                }
            }
        }
    }
2505}
2506
2507#[cfg(test)]
2508mod flex_primary_hessian_oracle_tests {
2509 use super::*;
2533 use super::family::*;
2540 use gam_linalg::matrix::DenseDesignMatrix;
2541 use ndarray::Array1;
2542 use ndarray::Array2;
2543 use std::sync::Arc;
2544 use std::sync::Mutex;
2545
    /// Builds a deterministic `n`-row flexible family fixture (score-warp and
    /// link-deviation blocks enabled) together with the four parameter block
    /// states the tests evaluate it at.
    ///
    /// Block order in the returned states: marginal, log-slope, score-warp
    /// coefficients, link-deviation coefficients.
    fn make_flex_oracle_family(
        n: usize,
    ) -> (BernoulliMarginalSlopeFamily, Vec<ParameterBlockState>) {
        // The deviation blocks are built from at least six seed points.
        let score_seed = Array1::linspace(-2.0, 2.0, n.max(6));
        let link_seed = Array1::linspace(-1.8, 1.8, n.max(6));
        let cfg = DeviationBlockConfig {
            num_internal_knots: 3,
            ..DeviationBlockConfig::default()
        };
        let score_prepared = build_score_warp_deviation_block_from_seed(&score_seed, &cfg)
            .expect("build score warp block");
        let link_prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
            &link_seed, &link_seed, &cfg,
        )
        .expect("build link deviation block");

        // Responses, weights, scores and designs are fixed arithmetic patterns of
        // the row index, so the fixture is fully reproducible.
        let y: Array1<f64> =
            Array1::from_iter((0..n).map(|i| if (i * 17 + 3) % 7 >= 4 { 1.0 } else { 0.0 }));
        let weights: Array1<f64> =
            Array1::from_iter((0..n).map(|i| 0.75 + ((i * 11 + 5) % 5) as f64 * 0.05));
        let z: Array1<f64> =
            Array1::from_iter((0..n).map(|i| -1.7 + 3.4 * (i as f64 + 0.5) / n as f64));
        // Both designs have an intercept column followed by one covariate column.
        let marginal_x = Array2::from_shape_fn((n, 2), |(i, j)| {
            if j == 0 {
                1.0
            } else {
                -0.4 + 0.8 * ((i * 19 + 7) % n) as f64 / n as f64
            }
        });
        let logslope_x = Array2::from_shape_fn((n, 2), |(i, j)| {
            if j == 0 {
                1.0
            } else {
                0.3 - 0.6 * ((i * 23 + 11) % n) as f64 / n as f64
            }
        });

        let family = BernoulliMarginalSlopeFamily {
            y: Arc::new(y),
            weights: Arc::new(weights),
            z: Arc::new(z.clone()),
            latent_measure: LatentMeasureKind::StandardNormal,
            gaussian_frailty_sd: Some(0.15),
            base_link: InverseLink::Standard(gam_problem::StandardLink::Probit),
            marginal_design: DesignMatrix::Dense(DenseDesignMatrix::from(marginal_x.clone())),
            logslope_design: DesignMatrix::Dense(DenseDesignMatrix::from(logslope_x.clone())),
            score_warp: Some(score_prepared.runtime.clone()),
            link_dev: Some(link_prepared.runtime.clone()),
            policy: gam_runtime::resource::ResourcePolicy::default_library(),
            cell_moment_lru: Arc::new(exact_kernel::CellMomentLruCache::new(1024)),
            cell_moment_cache_stats: Arc::new(exact_kernel::CellMomentCacheStats::default()),
            intercept_warm_starts: None,
            auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
            auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        };

        // Small, non-zero coefficients for every block.
        let beta_m = Array1::from_vec(vec![0.12, -0.04]);
        let beta_g = Array1::from_vec(vec![0.35, 0.03]);
        let beta_h = Array1::from_iter(
            (0..score_prepared.runtime.basis_dim()).map(|idx| 0.0015 * (idx as f64 + 1.0)),
        );
        let beta_w = Array1::from_iter(
            (0..link_prepared.runtime.basis_dim()).map(|idx| -0.001 * (idx as f64 + 1.0)),
        );
        let states = vec![
            ParameterBlockState {
                eta: marginal_x.dot(&beta_m),
                beta: beta_m,
            },
            ParameterBlockState {
                eta: logslope_x.dot(&beta_g),
                beta: beta_g,
            },
            // The two deviation blocks carry a zero `eta` of length `n`.
            ParameterBlockState {
                beta: beta_h,
                eta: Array1::zeros(z.len()),
            },
            ParameterBlockState {
                beta: beta_w,
                eta: Array1::zeros(z.len()),
            },
        ];
        (family, states)
    }
2635
    /// Row-primary gradient of `row` after shifting primary coordinate `u` by
    /// `delta`.
    ///
    /// The shift is applied to a copy of `states`: the `q` and log-slope
    /// coordinates perturb that row's `eta` in blocks 0 and 1, while coordinates in
    /// the `h` / `w` ranges perturb the matching `beta` entry of blocks 2 and 3.
    /// Panics if `u` belongs to none of these.
    fn flex_gradient_at_perturbed(
        family: &BernoulliMarginalSlopeFamily,
        states: &[ParameterBlockState],
        primary: &super::super::hessian_paths::PrimarySlices,
        row: usize,
        u: usize,
        delta: f64,
    ) -> Array1<f64> {
        let mut states = states.to_vec();
        if u == primary.q {
            states[0].eta[row] += delta;
        } else if u == primary.logslope {
            states[1].eta[row] += delta;
        } else if let Some(h_range) = primary.h.as_ref()
            && h_range.contains(&u)
        {
            states[2].beta[u - h_range.start] += delta;
        } else if let Some(w_range) = primary.w.as_ref()
            && w_range.contains(&u)
        {
            states[3].beta[u - w_range.start] += delta;
        } else {
            panic!("primary coordinate {u} out of range for flex oracle");
        }
        // A fresh row context is built from the perturbed states.
        let row_ctx = family
            .build_row_exact_context_with_stats_and_cell_cache(row, &states, None, false)
            .expect("perturbed row context");
        let (_neglog, grad, _hess) = family
            .compute_row_primary_gradient_hessian(row, &states, primary, &row_ctx)
            .expect("perturbed gradient");
        grad
    }
2680
    /// Checks every entry of the analytic row-primary Hessian against a central
    /// finite difference of the row-primary gradient, on three rows of a 12-row
    /// fixture.
    #[test]
    fn flex_primary_hessian_matches_central_fd_of_gradient() {
        let n = 12usize;
        let (family, states) = make_flex_oracle_family(n);
        let cache = family
            .build_exact_eval_cache(&states)
            .expect("flex exact eval cache");
        let primary = &cache.primary;
        let r = primary.total;
        assert!(
            r >= 4,
            "flex fixture must carry q + logslope + deviation blocks"
        );

        // Central-difference step and running worst-case scaled error.
        let h = 1e-4;
        let mut max_rel = 0.0_f64;

        for &row in &[2usize, 5, 8] {
            let row_ctx = BernoulliMarginalSlopeFamily::row_ctx(&cache, row);
            let (_neglog, _grad, analytic_hess) = family
                .compute_row_primary_gradient_hessian(row, &states, primary, row_ctx)
                .expect("analytic flex gradient + hessian");

            for u in 0..r {
                let grad_plus = flex_gradient_at_perturbed(&family, &states, primary, row, u, h);
                let grad_minus = flex_gradient_at_perturbed(&family, &states, primary, row, u, -h);
                for v in 0..r {
                    // d grad[v] / d u compared with the Hessian entry at [v, u].
                    let fd = (grad_plus[v] - grad_minus[v]) / (2.0 * h);
                    let analytic = analytic_hess[[v, u]];
                    let denom = 1.0 + analytic.abs().max(fd.abs());
                    let rel = (analytic - fd).abs() / denom;
                    max_rel = max_rel.max(rel);
                    assert!(
                        rel <= 1e-6,
                        "flex hand Hessian H[{v}][{u}] = {analytic:.6e} disagrees with central \
                         FD of the gradient {fd:.6e} at row {row} (rel {rel:.3e}); a product-rule \
                         term is dropped or mis-signed"
                    );
                }
            }
        }
        assert!(
            max_rel <= 1e-6,
            "flex Hessian FD oracle max rel {max_rel:.3e}"
        );
    }
2735
    /// Step-halving study of the `(q, q)` Hessian entry on row 2.
    ///
    /// Central differences are taken at `h`, `h/2` and `h/4`, and the first two are
    /// combined as `(4 * fd(h/2) - fd(h)) / 3` (Richardson extrapolation). Only the
    /// extrapolated residual is asserted; the gaps and their ratios appear in the
    /// failure message.
    #[test]
    fn arbiter_flex_hessian_h00_fd_step_scaling() {
        let n = 12usize;
        let (family, states) = make_flex_oracle_family(n);
        let cache = family
            .build_exact_eval_cache(&states)
            .expect("flex exact eval cache");
        let primary = &cache.primary;
        let row = 2usize;
        let u = primary.q;
        let v = primary.q;

        let row_ctx = BernoulliMarginalSlopeFamily::row_ctx(&cache, row);
        let (_neglog, _grad, analytic_hess) = family
            .compute_row_primary_gradient_hessian(row, &states, primary, row_ctx)
            .expect("analytic flex gradient + hessian");
        let analytic = analytic_hess[[v, u]];

        // Central difference of gradient component `v` along coordinate `u`.
        let fd_at = |h: f64| -> f64 {
            let gp = flex_gradient_at_perturbed(&family, &states, primary, row, u, h);
            let gm = flex_gradient_at_perturbed(&family, &states, primary, row, u, -h);
            (gp[v] - gm[v]) / (2.0 * h)
        };

        let h = 1e-3_f64;
        let fd_h = fd_at(h);
        let fd_half = fd_at(h * 0.5);
        let fd_quarter = fd_at(h * 0.25);
        let gap_h = (analytic - fd_h).abs();
        let gap_half = (analytic - fd_half).abs();
        let gap_quarter = (analytic - fd_quarter).abs();
        // Richardson combination of the h and h/2 estimates.
        let rich = (4.0 * fd_half - fd_h) / 3.0;
        let rich_gap = (analytic - rich).abs();
        let denom = analytic.abs().max(1.0);

        // Diagnostic record; denominators are floored to avoid dividing by zero.
        let record = format!(
            "FLEX H[0][0] ARBITER row 2: analytic={analytic:+.12e} \
             fd(h)={fd_h:+.12e} fd(h/2)={fd_half:+.12e} fd(h/4)={fd_quarter:+.12e} \
             gap(h)={gap_h:.3e} gap(h/2)={gap_half:.3e} gap(h/4)={gap_quarter:.3e} \
             ratio_h_over_half={:.3} ratio_half_over_quarter={:.3} \
             richardson={rich:+.12e} richardson_gap={rich_gap:.3e} (rich_rel={:.3e})",
            gap_h / gap_half.max(f64::MIN_POSITIVE),
            gap_half / gap_quarter.max(f64::MIN_POSITIVE),
            rich_gap / denom,
        );

        assert!(
            rich_gap / denom <= 1e-7,
            "{record}\nVERDICT: Richardson residual exceeds the FD-truncation floor — \
             the hand H[0][0] genuinely diverges (real dropped/mis-signed term), NOT FD noise"
        );
    }
2808}