use crate::families::bms::{
EmpiricalZGrid, LatentMeasureKind, LatentZConditionalCalibration, LatentZRankIntCalibration,
bernoulli_marginal_link_map, empirical_intercept_from_marginal,
};
use crate::families::marginal_slope_shared::{
ObservedDenestedCellPartials, eval_coeff4_at,
probit_frailty_scale as marginal_slope_probit_frailty_scale, scale_coeff4,
};
use crate::families::survival::lognormal_kernel::FrailtySpec;
use crate::inference::model::{SavedCompiledFlexBlock, SavedLatentZNormalization};
use crate::matrix::DesignMatrix;
use crate::probability::{normal_cdf, normal_pdf};
use crate::solver::estimate::{EstimationError, UnifiedFitResult};
use crate::types::{InverseLink, LikelihoodSpec};
use ndarray::{Array1, Array2, ArrayView1};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
/// Per-row prediction output.
///
/// NOTE(review): not constructed in this chunk; presumably `eta` is the linear
/// predictor and `mean` its inverse-link transform — confirm against callers.
pub struct PredictResult {
/// Linear predictor, one entry per prediction row (presumably).
pub eta: Array1<f64>,
/// Mean-scale prediction, one entry per prediction row (presumably).
pub mean: Array1<f64>,
}
/// Row-aligned prediction inputs.
///
/// Field roles below describe how the Bernoulli marginal-slope predictor in
/// this file consumes them; other predictors may interpret the generic
/// `noise` / `auxiliary` slots differently.
pub struct PredictInput {
/// Primary design; used as the marginal (location) design.
pub design: DesignMatrix,
/// Offset added to the primary linear predictor; must have one entry per row.
pub offset: Array1<f64>,
/// Secondary design; the marginal-slope predictor requires it as the logslope design.
pub design_noise: Option<DesignMatrix>,
/// Offset for the secondary linear predictor; treated as zeros when absent.
pub offset_noise: Option<Array1<f64>>,
/// Per-row scalar covariate; the marginal-slope predictor reads the raw latent `z` from it.
pub auxiliary_scalar: Option<Array1<f64>>,
/// Per-row conditioning matrix; required by the local-empirical latent measure.
pub auxiliary_matrix: Option<Array2<f64>>,
}
/// Saved-model predictor for the Bernoulli marginal-slope family with a probit link.
///
/// Built by `from_unified` from a fitted `UnifiedFitResult`. Coefficient blocks
/// are stored in the fixed order marginal, logslope, optional score-warp,
/// optional link-deviation.
pub struct BernoulliMarginalSlopePredictor {
/// Marginal (location) coefficients — unified block 0.
pub beta_marginal: Array1<f64>,
/// Log-slope coefficients — unified block 1.
pub beta_logslope: Array1<f64>,
/// Score-warp coefficients; present exactly when `score_warp_runtime` is.
pub beta_score_warp: Option<Array1<f64>>,
/// Link-deviation coefficients; present exactly when `link_deviation_runtime` is.
pub beta_link_dev: Option<Array1<f64>>,
/// Base inverse link; `from_unified` only accepts the standard probit link.
pub base_link: InverseLink,
/// Name of the auxiliary latent `z` column (used in error messages).
pub z_column: String,
/// Normalisation applied to the raw `z` values before calibration.
pub latent_z_normalization: SavedLatentZNormalization,
/// Latent measure over `z`: standard normal, a global empirical grid, or local empirical grids.
pub latent_measure: LatentMeasureKind,
/// Constant added to the marginal linear predictor.
pub baseline_marginal: f64,
/// Constant added to the logslope linear predictor.
pub baseline_logslope: f64,
/// Coefficient covariance copied from the fit, when the fit provided one.
pub covariance: Option<Array2<f64>>,
/// Saved score-warp flex runtime, if the model has one.
pub score_warp_runtime: Option<SavedCompiledFlexBlock>,
/// Saved link-deviation flex runtime, if the model has one.
pub link_deviation_runtime: Option<SavedCompiledFlexBlock>,
/// Fixed Gaussian frailty standard deviation (from `FrailtySpec::GaussianShift`).
pub gaussian_frailty_sd: Option<f64>,
/// Optional per-value rank/inverse-normal recalibration of `z`.
pub latent_z_calibration: Option<LatentZRankIntCalibration>,
/// Optional recalibration of `z` conditional on the marginal design.
pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
}
/// Number of prediction rows to process per work chunk.
///
/// Each row is budgeted four `f64` slots per `(parameter, local)` pair, and
/// enough rows are packed to fill roughly a 2 MiB working set. The result is
/// clamped to `[16, 4096]` and never exceeds `n_rows`. An empty input still
/// reports a chunk size of 1, so callers can divide by it safely.
fn prediction_chunk_rows(parameter_dim: usize, local_dim: usize, n_rows: usize) -> usize {
    const PREDICTION_TARGET_WORK_BYTES: usize = 2 * 1024 * 1024;
    const PREDICTION_MIN_CHUNK_ROWS: usize = 16;
    const PREDICTION_MAX_CHUNK_ROWS: usize = 4096;
    if n_rows == 0 {
        return 1;
    }
    // Every factor is at least 1, so the saturating product is never zero and
    // the division below is always defined.
    let per_row_bytes = [
        parameter_dim.max(1),
        local_dim.max(1),
        std::mem::size_of::<f64>(),
        4,
    ]
    .iter()
    .copied()
    .fold(1usize, usize::saturating_mul);
    let rows_for_budget = PREDICTION_TARGET_WORK_BYTES / per_row_bytes;
    rows_for_budget
        .clamp(PREDICTION_MIN_CHUNK_ROWS, PREDICTION_MAX_CHUNK_ROWS)
        .min(n_rows)
}
/// Per-row anchor-correction data for the flex runtimes at predict time.
///
/// All fields are `None` (the `Default`) when no runtime carries a saved
/// anchor correction.
#[derive(Default)]
struct BmsAnchorCorrections {
/// Anchor input rows `[marginal design | logslope design]`, one row per prediction row.
score_warp_anchor_rows: Option<Array2<f64>>,
/// Link-deviation anchor input rows: the parametric rows, optionally followed by
/// a score-warp basis tail.
link_dev_anchor_rows: Option<Array2<f64>>,
/// Score-warp correction matrix; row `i` is the correction for prediction row `i`.
score_warp: Option<Array2<f64>>,
/// Link-deviation correction matrix; row `i` is the correction for prediction row `i`.
link_dev: Option<Array2<f64>>,
}
impl BmsAnchorCorrections {
    /// Score-warp correction row for prediction row `row`, if a correction matrix exists.
    ///
    /// # Panics
    /// Panics (via `ndarray::ArrayBase::row`) if `row` is out of bounds.
    fn score_warp_row(&self, row: usize) -> Option<ndarray::ArrayView1<'_, f64>> {
        let matrix = self.score_warp.as_ref()?;
        Some(matrix.row(row))
    }

    /// Link-deviation correction row for prediction row `row`, if a correction matrix exists.
    ///
    /// # Panics
    /// Panics (via `ndarray::ArrayBase::row`) if `row` is out of bounds.
    fn link_dev_row(&self, row: usize) -> Option<ndarray::ArrayView1<'_, f64>> {
        let matrix = self.link_dev.as_ref()?;
        Some(matrix.row(row))
    }

    /// Borrowed view of the score-warp anchor input rows, if materialised.
    fn score_warp_anchor_rows_view(&self) -> Option<ndarray::ArrayView2<'_, f64>> {
        let rows = self.score_warp_anchor_rows.as_ref()?;
        Some(rows.view())
    }

    /// Borrowed view of the link-deviation anchor input rows, if materialised.
    fn link_dev_anchor_rows_view(&self) -> Option<ndarray::ArrayView2<'_, f64>> {
        let rows = self.link_dev_anchor_rows.as_ref()?;
        Some(rows.view())
    }
}
impl BernoulliMarginalSlopePredictor {
/// Materialise the per-row anchor-correction matrices needed at predict time.
///
/// A flex runtime (score-warp or link-deviation) that carries a saved
/// `anchor_correction` needs one correction row per prediction row. Each row is
/// obtained by evaluating the saved anchor against that row's anchor inputs.
/// The anchor inputs are the dense marginal design row followed by the dense
/// logslope design row. For the link-deviation runtime they may also be
/// followed by a single `FlexEvaluation` tail: the score-warp basis evaluated
/// at `z`.
///
/// When neither runtime carries an anchor correction, the dense
/// materialisation is skipped and an all-`None` `BmsAnchorCorrections` is
/// returned.
///
/// # Errors
/// Returns an error in any of these cases:
/// - a design cannot be densified;
/// - the row counts of the marginal design, logslope design and `z` disagree;
/// - the saved anchor component layout differs from the one fit-time stacking
///   emits;
/// - a runtime fails to evaluate its basis or correction.
fn build_anchor_correction_matrices(
    &self,
    input: &PredictInput,
    design_logslope: &DesignMatrix,
    z: &Array1<f64>,
) -> Result<BmsAnchorCorrections, EstimationError> {
    use gam::inference::model::SavedAnchorKind;
    // Only runtimes that actually carry a saved anchor residual need correction rows.
    let anchored_score_runtime = self
        .score_warp_runtime
        .as_ref()
        .filter(|runtime| runtime.anchor_correction.is_some());
    let anchored_link_runtime = self
        .link_deviation_runtime
        .as_ref()
        .filter(|runtime| runtime.anchor_correction.is_some());
    if anchored_score_runtime.is_none() && anchored_link_runtime.is_none() {
        return Ok(BmsAnchorCorrections::default());
    }
    let marginal_dense = input
        .design
        .try_to_dense_arc(
            "bernoulli marginal-slope predict-time marginal anchor materialisation",
        )
        .map_err(EstimationError::InvalidInput)?;
    let logslope_dense = design_logslope
        .try_to_dense_arc(
            "bernoulli marginal-slope predict-time logslope anchor materialisation",
        )
        .map_err(EstimationError::InvalidInput)?;
    let n_rows = marginal_dense.nrows();
    if logslope_dense.nrows() != n_rows {
        return Err(EstimationError::InvalidInput(format!(
            "bernoulli marginal-slope predict anchor materialisation row mismatch: marginal {} vs logslope {}",
            n_rows,
            logslope_dense.nrows()
        )));
    }
    if z.len() != n_rows {
        return Err(EstimationError::InvalidInput(format!(
            "bernoulli marginal-slope predict anchor materialisation: z has {} entries, expected {}",
            z.len(),
            n_rows
        )));
    }
    // Parametric anchor inputs: [marginal design | logslope design], row-aligned.
    let p_marginal = marginal_dense.ncols();
    let p_logslope = logslope_dense.ncols();
    let d_parametric = p_marginal + p_logslope;
    let mut parametric_rows = Array2::<f64>::zeros((n_rows, d_parametric));
    parametric_rows
        .slice_mut(ndarray::s![.., 0..p_marginal])
        .assign(&marginal_dense.view());
    parametric_rows
        .slice_mut(ndarray::s![.., p_marginal..d_parametric])
        .assign(&logslope_dense.view());
    let score_warp = match anchored_score_runtime {
        Some(runtime) => {
            // The score-warp anchor must be purely parametric.
            self.validate_runtime_anchor_layout_parametric_only(runtime, "score_warp")?;
            runtime
                .anchor_correction_matrix(parametric_rows.view())
                .map_err(EstimationError::from)?
        }
        None => None,
    };
    let (link_dev_anchor_rows, link_dev) = match anchored_link_runtime {
        Some(runtime) => {
            // Fit-time stacking emits parametric components first, optionally
            // followed by exactly one FlexEvaluation tail (the score-warp basis).
            let mut saw_flex_tail = false;
            let mut flex_tail_ncols: usize = 0;
            for (idx, component) in runtime.anchor_components.iter().enumerate() {
                match &component.kind {
                    SavedAnchorKind::Parametric { .. } => {
                        if saw_flex_tail {
                            return Err(EstimationError::InvalidInput(format!(
                                "bernoulli marginal-slope link-deviation saved anchor components \
                                are out of order: parametric component at index {idx} follows \
                                a FlexEvaluation tail",
                            )));
                        }
                    }
                    SavedAnchorKind::FlexEvaluation { ncols } => {
                        if saw_flex_tail {
                            return Err(EstimationError::InvalidInput(
                                "bernoulli marginal-slope link-deviation saved anchor components \
                                carry more than one FlexEvaluation tail; fit-time stacking emits \
                                at most one (score-warp)"
                                    .to_string(),
                            ));
                        }
                        saw_flex_tail = true;
                        flex_tail_ncols = *ncols;
                    }
                }
            }
            let rows = if saw_flex_tail {
                let score_runtime = self.score_warp_runtime.as_ref().ok_or_else(|| {
                    EstimationError::InvalidInput(
                        "bernoulli marginal-slope link-deviation saved anchor includes a \
                        FlexEvaluation tail but the saved score-warp runtime is missing"
                            .to_string(),
                    )
                })?;
                // The tail is the score-warp basis at z. It is anchor-corrected
                // when the score-warp runtime itself carries a correction.
                let score_basis = if score_runtime.anchor_correction.is_some() {
                    score_runtime
                        .design_with_anchor_rows(z, parametric_rows.view())
                        .map_err(EstimationError::from)?
                } else {
                    score_runtime.design(z).map_err(EstimationError::from)?
                };
                if score_basis.ncols() != flex_tail_ncols {
                    return Err(EstimationError::InvalidInput(format!(
                        "bernoulli marginal-slope link-deviation FlexEvaluation tail expects \
                        {} score-warp basis columns at predict rows, got {}",
                        flex_tail_ncols,
                        score_basis.ncols()
                    )));
                }
                let mut combined =
                    Array2::<f64>::zeros((n_rows, d_parametric + flex_tail_ncols));
                combined
                    .slice_mut(ndarray::s![.., 0..d_parametric])
                    .assign(&parametric_rows.view());
                combined
                    .slice_mut(ndarray::s![.., d_parametric..])
                    .assign(&score_basis.view());
                combined
            } else {
                parametric_rows.clone()
            };
            let corr = runtime
                .anchor_correction_matrix(rows.view())
                .map_err(EstimationError::from)?;
            (Some(rows), corr)
        }
        None => (None, None),
    };
    Ok(BmsAnchorCorrections {
        score_warp_anchor_rows: Some(parametric_rows),
        link_dev_anchor_rows,
        score_warp,
        link_dev,
    })
}
/// Reject a flex runtime whose saved anchor contains any `FlexEvaluation` component.
///
/// `runtime_label` is used only to name the runtime in the error message.
///
/// # Errors
/// Returns `EstimationError::InvalidInput` naming the index of the first
/// `FlexEvaluation` component found.
fn validate_runtime_anchor_layout_parametric_only(
    &self,
    runtime: &SavedCompiledFlexBlock,
    runtime_label: &str,
) -> Result<(), EstimationError> {
    use gam::inference::model::SavedAnchorKind;
    let first_flex = runtime
        .anchor_components
        .iter()
        .position(|component| match &component.kind {
            SavedAnchorKind::Parametric { .. } => false,
            SavedAnchorKind::FlexEvaluation { .. } => true,
        });
    match first_flex {
        None => Ok(()),
        Some(idx) => Err(EstimationError::InvalidInput(format!(
            "bernoulli marginal-slope {runtime_label} saved anchor component at \
            index {idx} is FlexEvaluation; only Parametric components are \
            expected for this runtime",
        ))),
    }
}
/// Likelihood family of this predictor: binomial response with a probit link.
///
/// Constant regardless of saved state; `from_unified` rejects non-probit base links.
pub fn likelihood_family(&self) -> LikelihoodSpec {
LikelihoodSpec::binomial_probit()
}
pub fn mean_from_eta(&self, eta: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
Ok(eta.mapv(normal_cdf))
}
pub fn mean_derivative_from_eta(
&self,
eta: &Array1<f64>,
) -> Result<Array1<f64>, EstimationError> {
Ok(eta.mapv(normal_pdf))
}
/// Multiplicative probit scale implied by the saved Gaussian frailty SD.
///
/// Delegates to the shared marginal-slope helper.
pub(crate) fn probit_frailty_scale(&self) -> f64 {
    let frailty_sd = self.gaussian_frailty_sd;
    marginal_slope_probit_frailty_scale(frailty_sd)
}
fn apply_latent_z_calibration(&self, z: &Array1<f64>) -> Array1<f64> {
match &self.latent_z_calibration {
Some(cal) => Array1::from_iter(z.iter().map(|&zi| cal.apply_at_predict(zi))),
None => z.clone(),
}
}
/// Apply the saved design-conditional latent-z calibration, if any.
///
/// The calibration is evaluated against the dense marginal design (`input.design`).
/// Returns a copy of `z` unchanged when no conditional calibration was saved.
///
/// # Errors
/// Propagates the calibration's error message as `EstimationError::InvalidInput`.
fn apply_latent_z_conditional_calibration(
    &self,
    z: &Array1<f64>,
    input: &PredictInput,
) -> Result<Array1<f64>, EstimationError> {
    match self.latent_z_conditional_calibration.as_ref() {
        None => Ok(z.clone()),
        Some(calibration) => {
            let dense_design = input.design.to_dense();
            calibration
                .apply(z.view(), dense_design.view())
                .map_err(EstimationError::InvalidInput)
        }
    }
}
/// Closed-form intercept `a` satisfying `s·a / sqrt(1 + (s·slope)²) = marginal_eta`.
///
/// Here `s` is the probit frailty scale. Presumably this is the Gaussian
/// marginalisation identity for a rigid (no flex) probit model under a
/// standard-normal latent. Callers use it as an exact solution or as a warm
/// start.
fn rigid_intercept_from_marginal(&self, marginal_eta: f64, slope: f64) -> f64 {
    let scale = self.probit_frailty_scale();
    let attenuation = (1.0 + (scale * slope).powi(2)).sqrt();
    marginal_eta * attenuation / scale
}
/// Rigid (no flex) intercept under an empirical latent grid, plus its gradient.
///
/// Returns `(a, da/d marginal_eta, da/d slope)`. The derivatives come from
/// implicit differentiation of
/// `F(a, b) = sum_j w_j * Phi(a + s*b*z_j) - mu(marginal_eta) = 0`:
/// `da/d eta_m = mu'(eta_m) / F_a` and `da/db = -F_b / F_a`, where `s` is
/// the probit frailty scale.
///
/// NOTE(review): here the predictor is `a + s*b*z`, so the intercept is NOT
/// multiplied by `s`. `rigid_intercept_from_marginal` and the de-nested cells
/// instead scale the whole predictor by `s`. Presumably
/// `empirical_intercept_from_marginal` returns an intercept in this
/// convention — confirm.
///
/// # Errors
/// `InvalidInput` if the marginal link map or the intercept solve fails, or if
/// `F_a` is not finite and positive or `F_b` is not finite.
fn empirical_rigid_intercept_and_gradient(
&self,
marginal_eta: f64,
slope: f64,
nodes: &[f64],
weights: &[f64],
) -> Result<(f64, f64, f64), EstimationError> {
let marginal = bernoulli_marginal_link_map(&self.base_link, marginal_eta)
.map_err(EstimationError::InvalidInput)?;
let scale = self.probit_frailty_scale();
let intercept = empirical_intercept_from_marginal(
marginal.mu,
marginal.q,
slope,
scale,
nodes,
weights,
None,
)
.map_err(EstimationError::InvalidInput)?;
let observed_slope = scale * slope;
// Accumulate dF/da and dF/db over the empirical grid.
let mut f_a = 0.0;
let mut f_b = 0.0;
for (&node, &weight) in nodes.iter().zip(weights.iter()) {
let eta = intercept + observed_slope * node;
let pdf = normal_pdf(eta);
f_a += weight * pdf;
f_b += weight * pdf * scale * node;
}
// Implicit differentiation divides by F_a, so it must be strictly positive.
if !(f_a.is_finite() && f_a > 0.0 && f_b.is_finite()) {
return Err(EstimationError::InvalidInput(format!(
"empirical latent prediction calibration derivative is invalid: F_a={f_a}, F_b={f_b}"
)));
}
// `marginal.mu1` is presumably d mu / d marginal_eta — confirm in bms.
let a_marginal_eta = marginal.mu1 / f_a;
let a_slope = -f_b / f_a;
Ok((intercept, a_marginal_eta, a_slope))
}
/// Select the `top_k` centers nearest to `point` and return their Gaussian-kernel
/// mixture weights, normalised to sum to one.
///
/// Returns `(center_index, weight)` pairs ordered from nearest to farthest.
/// The unnormalised kernel is `exp(-0.5 * d² / bandwidth²)`. It is evaluated
/// relative to the nearest center's squared distance, i.e. as
/// `exp(-0.5 * (d² - d²_min) / bandwidth²)`. The shift cancels in the
/// normalisation, so the weights are unchanged whenever nothing underflows.
/// It also keeps a query far from every center (relative to `bandwidth`) from
/// underflowing all kernels to the same floor and collapsing into a uniform,
/// distance-blind mixture.
///
/// # Errors
/// `InvalidInput` if any of the following holds:
/// - `centers` is empty or `top_k == 0`;
/// - `bandwidth` is not finite and positive;
/// - a center's dimension differs from `point`;
/// - a squared distance is non-finite;
/// - the total weight is not finite and positive.
fn local_empirical_mixture_for_point(
    point: &[f64],
    centers: &[Vec<f64>],
    top_k: usize,
    bandwidth: f64,
) -> Result<Vec<(usize, f64)>, EstimationError> {
    if centers.is_empty() {
        return Err(EstimationError::InvalidInput(
            "local empirical latent prediction has no centers".to_string(),
        ));
    }
    if top_k == 0 {
        return Err(EstimationError::InvalidInput(
            "local empirical latent prediction top_k must be positive".to_string(),
        ));
    }
    if !(bandwidth.is_finite() && bandwidth > 0.0) {
        return Err(EstimationError::InvalidInput(format!(
            "local empirical latent prediction bandwidth must be finite and positive, got {bandwidth}"
        )));
    }
    let bw2 = bandwidth * bandwidth;
    let mut distances = Vec::<(usize, f64)>::with_capacity(centers.len());
    for (idx, center) in centers.iter().enumerate() {
        if center.len() != point.len() {
            return Err(EstimationError::InvalidInput(format!(
                "local empirical latent prediction center {idx} dimension mismatch: center={}, point={}",
                center.len(),
                point.len()
            )));
        }
        let d2 = center
            .iter()
            .zip(point.iter())
            .map(|(&c, &x)| {
                let delta = x - c;
                delta * delta
            })
            .sum::<f64>();
        if !d2.is_finite() {
            return Err(EstimationError::InvalidInput(
                "local empirical latent prediction distance is non-finite".to_string(),
            ));
        }
        distances.push((idx, d2));
    }
    // All distances are finite here; a stable sort keeps index order on ties.
    distances.sort_by(|left, right| left.1.total_cmp(&right.1));
    let k = top_k.min(distances.len());
    // `distances` is non-empty (centers checked above) and sorted ascending, so
    // the first entry is the minimum; shifting by it makes the nearest weight 1.
    let d2_min = distances[0].1;
    let mut mixture = Vec::with_capacity(k);
    let mut total = 0.0;
    for &(idx, d2) in distances.iter().take(k) {
        let weight = (-0.5 * (d2 - d2_min) / bw2).exp().max(1e-300);
        mixture.push((idx, weight));
        total += weight;
    }
    if !(total.is_finite() && total > 0.0) {
        return Err(EstimationError::InvalidInput(
            "local empirical latent prediction mixture has non-positive total weight"
                .to_string(),
        ));
    }
    for (_, weight) in &mut mixture {
        *weight /= total;
    }
    Ok(mixture)
}
/// Merge the per-center empirical grids selected by `mixture` into one grid.
///
/// Each selected grid's weights are multiplied by its mixture weight. The
/// nodes are concatenated in mixture order, and the combined weights are
/// renormalised to sum to one.
///
/// # Errors
/// `InvalidInput` if any of the following holds:
/// - a mixture weight is negative or non-finite;
/// - a grid index is out of bounds;
/// - a grid is empty or has mismatched node/weight lengths;
/// - a node or combined weight is invalid;
/// - the total weight is not finite and positive.
fn combine_empirical_grids(
    grids: &[EmpiricalZGrid],
    mixture: &[(usize, f64)],
) -> Result<EmpiricalZGrid, EstimationError> {
    // Capacity hint only: unknown indices count as zero here and are rejected below.
    let capacity: usize = mixture
        .iter()
        .filter_map(|&(idx, _)| grids.get(idx))
        .map(|grid| grid.nodes.len())
        .sum();
    let mut merged_nodes = Vec::with_capacity(capacity);
    let mut merged_weights = Vec::with_capacity(capacity);
    let mut mass = 0.0;
    for &(grid_idx, grid_weight) in mixture {
        let weight_ok = grid_weight.is_finite() && grid_weight >= 0.0;
        if !weight_ok {
            return Err(EstimationError::InvalidInput(format!(
                "local empirical latent prediction mixture weight must be finite and non-negative, got {grid_weight}"
            )));
        }
        let Some(grid) = grids.get(grid_idx) else {
            return Err(EstimationError::InvalidInput(format!(
                "local empirical latent prediction grid index {grid_idx} is out of bounds for {} grids",
                grids.len()
            )));
        };
        let malformed = grid.nodes.len() != grid.weights.len() || grid.nodes.is_empty();
        if malformed {
            return Err(EstimationError::InvalidInput(format!(
                "local empirical latent prediction grid {grid_idx} is invalid: nodes={}, weights={}",
                grid.nodes.len(),
                grid.weights.len()
            )));
        }
        for (node, weight) in grid.pairs() {
            let scaled = grid_weight * weight;
            let valid = node.is_finite() && scaled.is_finite() && scaled >= 0.0;
            if !valid {
                return Err(EstimationError::InvalidInput(
                    "local empirical latent prediction grid contains invalid node/weight"
                        .to_string(),
                ));
            }
            merged_nodes.push(node);
            merged_weights.push(scaled);
            mass += scaled;
        }
    }
    if !(mass.is_finite() && mass > 0.0) {
        return Err(EstimationError::InvalidInput(
            "local empirical latent prediction combined grid has non-positive total weight"
                .to_string(),
        ));
    }
    merged_weights.iter_mut().for_each(|w| *w /= mass);
    Ok(EmpiricalZGrid {
        nodes: merged_nodes,
        weights: merged_weights,
    })
}
/// Latent-z integration grid to use for prediction row `row`.
///
/// Returns `None` for a standard-normal latent measure and the shared grid
/// for a global empirical measure. For a local empirical measure, the row's
/// conditioning vector (from `input.auxiliary_matrix`) selects a kernel
/// mixture of the nearest centers' grids.
///
/// # Errors
/// `InvalidInput` if the local measure lacks a conditioning matrix, `row` is
/// out of bounds, the conditioning dimension does not match the centers, or
/// the mixture/merge step fails.
fn empirical_grid_for_prediction_row(
    &self,
    input: &PredictInput,
    row: usize,
) -> Result<Option<EmpiricalZGrid>, EstimationError> {
    // Non-local measures do not depend on the row.
    let (centers, grids, top_k, bandwidth) = match &self.latent_measure {
        LatentMeasureKind::StandardNormal => return Ok(None),
        LatentMeasureKind::GlobalEmpirical { grid } => return Ok(Some(grid.clone())),
        LatentMeasureKind::LocalEmpirical {
            centers,
            grids,
            top_k,
            bandwidth,
            ..
        } => (centers, grids, *top_k, *bandwidth),
    };
    let Some(conditioning) = input.auxiliary_matrix.as_ref() else {
        return Err(EstimationError::InvalidInput(
            "bernoulli marginal-slope local empirical prediction requires auxiliary conditioning matrix"
                .to_string(),
        ));
    };
    let n_conditioning_rows = conditioning.nrows();
    if row >= n_conditioning_rows {
        return Err(EstimationError::InvalidInput(format!(
            "local empirical latent prediction row {row} is out of bounds for {} conditioning rows",
            n_conditioning_rows
        )));
    }
    let expected_dim = centers.first().map_or(0, Vec::len);
    let got_dim = conditioning.ncols();
    if got_dim != expected_dim {
        return Err(EstimationError::InvalidInput(format!(
            "local empirical latent prediction conditioning dimension mismatch: got {}, expected {expected_dim}",
            got_dim
        )));
    }
    let point: Vec<f64> = conditioning.row(row).to_vec();
    let mixture = Self::local_empirical_mixture_for_point(&point, centers, top_k, bandwidth)?;
    let combined = Self::combine_empirical_grids(grids, &mixture)?;
    Ok(Some(combined))
}
/// Convert an internal-scale eta (and optional gradient) to the base-link scale.
///
/// Currently an identity passthrough: both inputs are returned unchanged and
/// no error is produced.
fn transform_internal_eta_to_base_scale(
&self,
internal_eta: Array1<f64>,
internal_grad: Option<Array2<f64>>,
) -> Result<(Array1<f64>, Option<Array2<f64>>), EstimationError> {
Ok((internal_eta, internal_grad))
}
/// Evaluate the link-deviation transform `l(eta0) = eta0 + B(eta0)·beta - c·beta`
/// and its first derivative at each entry of `eta0`.
///
/// `c` is the optional per-row anchor correction. Without a link-deviation
/// runtime or coefficients, the transform is the identity (value `eta0`,
/// derivative 1).
///
/// # Errors
/// Propagates runtime basis/derivative failures. Returns `InvalidInput` when
/// the runtime carries an anchor correction but no per-row correction was
/// supplied.
fn link_terms_value_d1(
    &self,
    eta0: &Array1<f64>,
    beta_link_dev: Option<&Array1<f64>>,
    link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
) -> Result<(Array1<f64>, Array1<f64>), EstimationError> {
    let (Some(runtime), Some(beta)) = (&self.link_deviation_runtime, beta_link_dev) else {
        return Ok((eta0.clone(), Array1::ones(eta0.len())));
    };
    let basis = runtime
        .design_uncorrected(eta0)
        .map_err(EstimationError::from)?;
    let mut value = &basis.dot(beta) + eta0;
    match link_dev_correction_for_row {
        Some(correction) => {
            let shift = correction.dot(beta);
            value.mapv_inplace(|v| v - shift);
        }
        None if runtime.anchor_correction.is_some() => {
            return Err(EstimationError::InvalidInput(
                "bernoulli marginal-slope link-deviation runtime has an anchor residual but \
                no per-row correction was supplied to link_terms_value_d1"
                    .to_string(),
            ));
        }
        None => {}
    }
    let derivative_basis = runtime
        .first_derivative_design(eta0)
        .map_err(EstimationError::from)?;
    let slope = derivative_basis.dot(beta) + 1.0;
    Ok((value, slope))
}
/// Build the de-nested partition cells for intercept `a` and slope `b`.
///
/// The real line of `z` is split using the score-warp breakpoints and the
/// link-deviation breakpoints; see `build_denested_partition_cells_with_tails`
/// for how they combine. The score-warp runtime supplies the local cubic span
/// at each `z`, and the link-deviation runtime the span at each `u`.
/// NOTE(review): `u` is presumably `a + b*z` — confirm in cubic_cell_kernel.
///
/// Each span's constant term is reduced by `correction·beta` when a per-row
/// anchor correction is given. A missing runtime or coefficient vector yields
/// an all-zero span. Finally every cell cubic coefficient is multiplied by the
/// probit frailty scale.
///
/// # Errors
/// Propagates breakpoint extraction failures and kernel failures (as `InvalidInput`).
fn denested_partition_cells(
&self,
a: f64,
b: f64,
beta_score_warp: Option<&Array1<f64>>,
beta_link_dev: Option<&Array1<f64>>,
score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
) -> Result<Vec<gam::families::cubic_cell_kernel::DenestedPartitionCell>, EstimationError> {
let score_breaks = if let Some(runtime) = self.score_warp_runtime.as_ref() {
runtime.breakpoints().map_err(EstimationError::from)?
} else {
Vec::new()
};
let link_breaks = if let Some(runtime) = self.link_deviation_runtime.as_ref() {
runtime.breakpoints().map_err(EstimationError::from)?
} else {
Vec::new()
};
let mut cells =
gam::families::cubic_cell_kernel::build_denested_partition_cells_with_tails(
a,
b,
&score_breaks,
&link_breaks,
// Score-warp span at z; zero span when the score-warp block is absent.
|z| {
if let (Some(runtime), Some(beta)) =
(self.score_warp_runtime.as_ref(), beta_score_warp)
{
let mut span = runtime.local_cubic_at(beta, z)?;
// Anchor correction is a per-row constant shift.
if let Some(corr) = score_warp_correction_for_row {
span.c0 -= corr.dot(beta);
}
Ok(span)
} else {
Ok(gam::families::cubic_cell_kernel::LocalSpanCubic {
left: 0.0,
right: 1.0,
c0: 0.0,
c1: 0.0,
c2: 0.0,
c3: 0.0,
})
}
},
// Link-deviation span at u; zero span when the link-deviation block is absent.
|u| {
if let (Some(runtime), Some(beta)) =
(self.link_deviation_runtime.as_ref(), beta_link_dev)
{
let mut span = runtime.local_cubic_at(beta, u)?;
if let Some(corr) = link_dev_correction_for_row {
span.c0 -= corr.dot(beta);
}
Ok(span)
} else {
Ok(gam::families::cubic_cell_kernel::LocalSpanCubic {
left: 0.0,
right: 1.0,
c0: 0.0,
c1: 0.0,
c2: 0.0,
c3: 0.0,
})
}
},
)
.map_err(EstimationError::InvalidInput)?;
// Apply the probit frailty scale to every cell cubic (skipped when it is exactly 1).
let scale = self.probit_frailty_scale();
if scale != 1.0 {
for partition_cell in &mut cells {
partition_cell.cell.c0 *= scale;
partition_cell.cell.c1 *= scale;
partition_cell.cell.c2 *= scale;
partition_cell.cell.c3 *= scale;
}
}
Ok(cells)
}
/// Calibration residual and its first two `a`-derivatives under the
/// standard-normal latent measure, evaluated exactly over de-nested cells.
///
/// Returns `(f, f_a, f_aa)`. Here `f = sum(cell values) - mu(marginal_eta)`,
/// and `f_a` and `f_aa` are the derivatives of that sum with respect to the
/// intercept `a`.
///
/// NOTE(review): each cell value is presumably that cell's contribution to the
/// marginal probability — confirm in cubic_cell_kernel.
///
/// # Errors
/// Propagates link-map, cell-construction and moment-evaluation failures.
fn evaluate_denested_calibration(
&self,
a: f64,
marginal_eta: f64,
slope: f64,
beta_score_warp: Option<&Array1<f64>>,
beta_link_dev: Option<&Array1<f64>>,
score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
) -> Result<(f64, f64, f64), EstimationError> {
let marginal = bernoulli_marginal_link_map(&self.base_link, marginal_eta)
.map_err(EstimationError::InvalidInput)?;
let cells = self.denested_partition_cells(
a,
slope,
beta_score_warp,
beta_link_dev,
score_warp_correction_for_row,
link_dev_correction_for_row,
)?;
let scale = self.probit_frailty_scale();
// Residual starts at -mu so that f == 0 at the target marginal mean.
let mut f = -marginal.mu;
let mut f_a = 0.0;
let mut f_aa = 0.0;
for partition_cell in cells {
let cell = partition_cell.cell;
// Coefficient partials are computed on the raw spans, then scaled to
// match the already-scaled cell coefficients.
let (dc_da_raw, _) =
gam::families::cubic_cell_kernel::denested_cell_coefficient_partials(
partition_cell.score_span,
partition_cell.link_span,
a,
slope,
);
let (d2c_da2_raw, _, _) =
gam::families::cubic_cell_kernel::denested_cell_second_partials(
partition_cell.score_span,
partition_cell.link_span,
a,
slope,
);
let dc_da = scale_coeff4(dc_da_raw, scale);
let d2c_da2 = scale_coeff4(d2c_da2_raw, scale);
// Only compute as many moments as the second-derivative formula needs.
let max_degree =
gam::families::cubic_cell_kernel::cell_second_derivative_required_max_degree(
&dc_da, &dc_da, &d2c_da2,
);
let state = gam::families::cubic_cell_kernel::evaluate_cell_moments(cell, max_degree)
.map_err(EstimationError::InvalidInput)?;
f += state.value;
f_a += gam::families::cubic_cell_kernel::cell_first_derivative_from_moments(
&dc_da,
&state.moments,
)
.map_err(EstimationError::InvalidInput)?;
f_aa += gam::families::cubic_cell_kernel::cell_second_derivative_from_moments(
cell,
&dc_da,
&dc_da,
&d2c_da2,
&state.moments,
)
.map_err(EstimationError::InvalidInput)?;
}
Ok((f, f_a, f_aa))
}
/// De-nested cell coefficients and their `(a, b)` partials up to third order,
/// evaluated at a single latent value `z_value`.
///
/// The score-warp span is taken at `z_value` and the link-deviation span at
/// `u = a + b * z_value`. Each span's constant term is reduced by
/// `correction·beta` when a per-row anchor correction is given. A missing
/// runtime or coefficient vector yields an all-zero span. Every returned
/// coefficient set is multiplied by the probit frailty scale.
///
/// # Errors
/// Propagates runtime `local_cubic_at` failures.
fn observed_denested_cell_partials_at_z(
&self,
z_value: f64,
a: f64,
b: f64,
beta_score_warp: Option<&Array1<f64>>,
beta_link_dev: Option<&Array1<f64>>,
score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
) -> Result<ObservedDenestedCellPartials, EstimationError> {
use gam::families::cubic_cell_kernel as exact;
let zero_span = exact::LocalSpanCubic {
left: 0.0,
right: 1.0,
c0: 0.0,
c1: 0.0,
c2: 0.0,
c3: 0.0,
};
// Argument of the link-deviation block at this z.
let u_value = a + b * z_value;
let score_span = if let (Some(runtime), Some(beta)) =
(self.score_warp_runtime.as_ref(), beta_score_warp)
{
let mut span = runtime
.local_cubic_at(beta, z_value)
.map_err(EstimationError::from)?;
if let Some(corr) = score_warp_correction_for_row {
span.c0 -= corr.dot(beta);
}
span
} else {
zero_span
};
let link_span = if let (Some(runtime), Some(beta)) =
(self.link_deviation_runtime.as_ref(), beta_link_dev)
{
let mut span = runtime
.local_cubic_at(beta, u_value)
.map_err(EstimationError::from)?;
if let Some(corr) = link_dev_correction_for_row {
span.c0 -= corr.dot(beta);
}
span
} else {
zero_span
};
let scale = self.probit_frailty_scale();
let coeff = scale_coeff4(
exact::denested_cell_coefficients(score_span, link_span, a, b),
scale,
);
let (dc_da_raw, dc_db_raw) =
exact::denested_cell_coefficient_partials(score_span, link_span, a, b);
let (dc_daa_raw, dc_dab_raw, dc_dbb_raw) =
exact::denested_cell_second_partials(score_span, link_span, a, b);
// Third partials depend only on the link-deviation span.
let (dc_daaa, dc_daab, dc_dabb, dc_dbbb) = exact::denested_cell_third_partials(link_span);
Ok(ObservedDenestedCellPartials {
coeff,
dc_da: scale_coeff4(dc_da_raw, scale),
dc_db: scale_coeff4(dc_db_raw, scale),
dc_daa: scale_coeff4(dc_daa_raw, scale),
dc_dab: scale_coeff4(dc_dab_raw, scale),
dc_dbb: scale_coeff4(dc_dbb_raw, scale),
dc_daaa: scale_coeff4(dc_daaa, scale),
dc_daab: scale_coeff4(dc_daab, scale),
dc_dabb: scale_coeff4(dc_dabb, scale),
dc_dbbb: scale_coeff4(dc_dbbb, scale),
})
}
/// Calibration residual and its first two `a`-derivatives under an empirical
/// latent grid.
///
/// With `eta_j` the de-nested predictor at grid node `z_j`, this returns
/// `f = sum_j w_j Phi(eta_j) - mu(marginal_eta)`,
/// `f_a = sum_j w_j phi(eta_j) eta_a` and
/// `f_aa = sum_j w_j phi(eta_j) (eta_aa - eta_j eta_a^2)`,
/// using `phi'(x) = -x phi(x)`.
///
/// # Errors
/// Propagates link-map and per-node partial-evaluation failures.
fn evaluate_empirical_denested_calibration(
    &self,
    a: f64,
    marginal_eta: f64,
    slope: f64,
    beta_score_warp: Option<&Array1<f64>>,
    beta_link_dev: Option<&Array1<f64>>,
    grid: &EmpiricalZGrid,
    score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
    link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
) -> Result<(f64, f64, f64), EstimationError> {
    let target_mu = bernoulli_marginal_link_map(&self.base_link, marginal_eta)
        .map_err(EstimationError::InvalidInput)?
        .mu;
    let (mut residual, mut d_residual, mut d2_residual) = (-target_mu, 0.0, 0.0);
    for (node, weight) in grid.pairs() {
        let partials = self.observed_denested_cell_partials_at_z(
            node,
            a,
            slope,
            beta_score_warp,
            beta_link_dev,
            score_warp_correction_for_row,
            link_dev_correction_for_row,
        )?;
        let eta = eval_coeff4_at(&partials.coeff, node);
        let eta_a = eval_coeff4_at(&partials.dc_da, node);
        let eta_aa = eval_coeff4_at(&partials.dc_daa, node);
        let density = normal_pdf(eta);
        residual += weight * normal_cdf(eta);
        d_residual += weight * density * eta_a;
        d2_residual += weight * density * (eta_aa - eta * eta_a * eta_a);
    }
    Ok((residual, d_residual, d2_residual))
}
/// Dispatch the calibration evaluation by latent measure.
///
/// With an empirical grid, the grid-quadrature evaluator is used. Otherwise
/// the exact standard-normal de-nested evaluator is used. Both return
/// `(f, f_a, f_aa)` in the intercept `a`.
fn evaluate_prediction_calibration(
    &self,
    a: f64,
    marginal_eta: f64,
    slope: f64,
    beta_score_warp: Option<&Array1<f64>>,
    beta_link_dev: Option<&Array1<f64>>,
    empirical_grid: Option<&EmpiricalZGrid>,
    score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
    link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
) -> Result<(f64, f64, f64), EstimationError> {
    match empirical_grid {
        Some(grid) => self.evaluate_empirical_denested_calibration(
            a,
            marginal_eta,
            slope,
            beta_score_warp,
            beta_link_dev,
            grid,
            score_warp_correction_for_row,
            link_dev_correction_for_row,
        ),
        None => self.evaluate_denested_calibration(
            a,
            marginal_eta,
            slope,
            beta_score_warp,
            beta_link_dev,
            score_warp_correction_for_row,
            link_dev_correction_for_row,
        ),
    }
}
/// Build a predictor from a fitted `UnifiedFitResult` plus saved replay state.
///
/// Coefficient blocks are read in the order marginal (0), logslope (1),
/// score-warp (next, if `score_warp_runtime` is `Some`), then link-deviation
/// (next, if `link_deviation_runtime` is `Some`).
///
/// # Errors
/// Returns a message string, checking in this order:
/// - frailty is a `GaussianShift` without a fixed sigma, or a `HazardMultiplier`;
/// - `base_link` is not standard probit;
/// - a flex runtime fails its exact-replay contract;
/// - the latent-z normalisation or the latent measure fails validation;
/// - the number of coefficient blocks differs from the number implied by the
///   supplied runtimes.
pub fn from_unified(
unified: &UnifiedFitResult,
z_column: String,
latent_z_normalization: SavedLatentZNormalization,
latent_measure: LatentMeasureKind,
baseline_marginal: f64,
baseline_logslope: f64,
base_link: InverseLink,
frailty: FrailtySpec,
score_warp_runtime: Option<SavedCompiledFlexBlock>,
link_deviation_runtime: Option<SavedCompiledFlexBlock>,
latent_z_calibration: Option<gam::families::bms::LatentZRankIntCalibration>,
latent_z_conditional_calibration: Option<gam::families::bms::LatentZConditionalCalibration>,
) -> Result<Self, String> {
// Only "no frailty" or a fixed-sigma Gaussian shift can be replayed at predict time.
let gaussian_frailty_sd = match frailty {
FrailtySpec::None => None,
FrailtySpec::GaussianShift {
sigma_fixed: Some(sigma),
} => Some(sigma),
FrailtySpec::GaussianShift { sigma_fixed: None } => {
return Err(
"bernoulli marginal-slope predictor requires a fixed GaussianShift sigma"
.to_string(),
);
}
FrailtySpec::HazardMultiplier { .. } => {
return Err(
"bernoulli marginal-slope predictor does not support HazardMultiplier frailty"
.to_string(),
);
}
};
if !matches!(
base_link,
InverseLink::Standard(gam::types::StandardLink::Probit)
) {
return Err(
"bernoulli marginal-slope predictor requires link(type=probit); saved non-probit marginal-slope models must be refit"
.to_string(),
);
}
if let Some(runtime) = score_warp_runtime.as_ref() {
runtime.validate_exact_replay_contract().map_err(|e| {
format!("bernoulli marginal-slope score-warp runtime is invalid: {e}")
})?;
}
if let Some(runtime) = link_deviation_runtime.as_ref() {
runtime.validate_exact_replay_contract().map_err(|e| {
format!("bernoulli marginal-slope link-deviation runtime is invalid: {e}")
})?;
}
latent_z_normalization
.validate("bernoulli marginal-slope predictor")
.map_err(|e| {
format!("bernoulli marginal-slope predictor latent z normalization is invalid: {e}")
})?;
latent_measure
.validate("bernoulli marginal-slope predictor latent measure")
.map_err(|e| {
format!("bernoulli marginal-slope predictor latent measure is invalid: {e}")
})?;
// Two mandatory blocks plus one per supplied flex runtime.
let blocks = &unified.blocks;
let expected_blocks = 2
+ usize::from(score_warp_runtime.is_some())
+ usize::from(link_deviation_runtime.is_some());
if blocks.len() != expected_blocks {
return Err(format!(
"bernoulli marginal-slope predictor requires exactly {expected_blocks} coefficient blocks under the current exact de-nested semantics, got {}",
blocks.len()
));
}
// Optional blocks follow the two mandatory ones, score-warp before link-deviation.
let mut cursor = 2usize;
let beta_score_warp = if score_warp_runtime.is_some() {
let beta = blocks
.get(cursor)
.ok_or_else(|| "missing score-warp coefficient block".to_string())?
.beta
.clone();
cursor += 1;
Some(beta)
} else {
None
};
let beta_link_dev = if link_deviation_runtime.is_some() {
Some(
blocks
.get(cursor)
.ok_or_else(|| "missing link-deviation coefficient block".to_string())?
.beta
.clone(),
)
} else {
None
};
// Indexing blocks[0] and blocks[1] is safe: the count check above guarantees >= 2 blocks.
Ok(Self {
beta_marginal: blocks[0].beta.clone(),
beta_logslope: blocks[1].beta.clone(),
beta_score_warp,
beta_link_dev,
base_link,
z_column,
latent_z_normalization,
latent_measure,
baseline_marginal,
baseline_logslope,
covariance: unified.beta_covariance().cloned(),
score_warp_runtime,
link_deviation_runtime,
gaussian_frailty_sd,
latent_z_calibration,
latent_z_conditional_calibration,
})
}
pub fn theta(&self) -> Array1<f64> {
let total = self.beta_marginal.len()
+ self.beta_logslope.len()
+ self.beta_score_warp.as_ref().map_or(0, |b| b.len())
+ self.beta_link_dev.as_ref().map_or(0, |b| b.len());
let mut theta = Array1::<f64>::zeros(total);
let mut cursor = 0usize;
theta
.slice_mut(ndarray::s![cursor..cursor + self.beta_marginal.len()])
.assign(&self.beta_marginal);
cursor += self.beta_marginal.len();
theta
.slice_mut(ndarray::s![cursor..cursor + self.beta_logslope.len()])
.assign(&self.beta_logslope);
cursor += self.beta_logslope.len();
if let Some(beta) = self.beta_score_warp.as_ref() {
theta
.slice_mut(ndarray::s![cursor..cursor + beta.len()])
.assign(beta);
cursor += beta.len();
}
if let Some(beta) = self.beta_link_dev.as_ref() {
theta
.slice_mut(ndarray::s![cursor..cursor + beta.len()])
.assign(beta);
}
theta
}
fn split_theta<'a>(
&'a self,
theta: &'a Array1<f64>,
) -> Result<
(
ArrayView1<'a, f64>,
ArrayView1<'a, f64>,
Option<ArrayView1<'a, f64>>,
Option<ArrayView1<'a, f64>>,
),
EstimationError,
> {
let expected = self.theta().len();
if theta.len() != expected {
return Err(EstimationError::InvalidInput(format!(
"bernoulli marginal-slope theta length mismatch: expected {expected}, got {}",
theta.len()
)));
}
let mut cursor = 0usize;
let marginal = theta.slice(ndarray::s![cursor..cursor + self.beta_marginal.len()]);
cursor += self.beta_marginal.len();
let logslope = theta.slice(ndarray::s![cursor..cursor + self.beta_logslope.len()]);
cursor += self.beta_logslope.len();
let score_warp = self.beta_score_warp.as_ref().map(|beta| {
let view = theta.slice(ndarray::s![cursor..cursor + beta.len()]);
cursor += beta.len();
view
});
let link_dev = self
.beta_link_dev
.as_ref()
.map(|beta| theta.slice(ndarray::s![cursor..cursor + beta.len()]));
Ok((marginal, logslope, score_warp, link_dev))
}
/// Solve for the intercept `a` that reproduces the marginal mean for one row.
///
/// Finds the root in `a` of the residual `f` from
/// `evaluate_prediction_calibration`, i.e. the `a` at which the calibrated
/// model's marginal mean equals `mu(marginal_eta)`. The warm start is the
/// rigid closed-form intercept. With link-deviation coefficients, it is
/// refined by linearising the link transform around that point
/// (`l(u) ~ ell0 + ell1*u`) when the local derivative `ell1` exceeds 1e-8.
///
/// Note the parameter order: `link_dev_beta` comes before `score_warp_beta`,
/// unlike the evaluation helpers.
///
/// `warm_start_buf[0]` is overwritten with the rigid intercept, and the call
/// panics if the buffer is empty.
///
/// # Errors
/// Propagates link-map, link-term and root-solver failures. Returns
/// `InvalidInput` if the best residual found still exceeds the absolute
/// tolerance `max(1e-8, 1e-4 * |mu|)`.
fn solve_intercept_scalar(
&self,
marginal_eta: f64,
slope: f64,
link_dev_beta: Option<&Array1<f64>>,
score_warp_beta: Option<&Array1<f64>>,
empirical_grid: Option<&EmpiricalZGrid>,
warm_start_buf: &mut Array1<f64>,
score_warp_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
link_dev_correction_for_row: Option<ndarray::ArrayView1<'_, f64>>,
) -> Result<f64, EstimationError> {
let marginal = bernoulli_marginal_link_map(&self.base_link, marginal_eta)
.map_err(EstimationError::InvalidInput)?;
// Residual and its first two a-derivatives; errors are stringified for the root solver.
let eval = |a: f64| -> Result<(f64, f64, f64), String> {
self.evaluate_prediction_calibration(
a,
marginal_eta,
slope,
score_warp_beta,
link_dev_beta,
empirical_grid,
score_warp_correction_for_row,
link_dev_correction_for_row,
)
.map_err(|err| err.to_string())
};
let probit_scale = self.probit_frailty_scale();
let a_rigid = self.rigid_intercept_from_marginal(marginal.q, slope);
let mut intercept = a_rigid;
if let (Some(_), Some(beta)) = (self.link_deviation_runtime.as_ref(), link_dev_beta) {
warm_start_buf[0] = a_rigid;
let one_pt = warm_start_buf.slice(ndarray::s![0..1]).to_owned();
let (l_val, l_d1) =
self.link_terms_value_d1(&one_pt, Some(beta), link_dev_correction_for_row)?;
let ell1 = l_d1[0];
// Only refine when the linearised link is increasing; otherwise keep the rigid start.
if ell1 > 1e-8 {
let ell0 = l_val[0] - ell1 * a_rigid;
let observed_logslope = probit_scale * ell1 * slope;
intercept = (marginal.q * (1.0 + observed_logslope * observed_logslope).sqrt()
/ probit_scale
- ell0)
/ ell1;
}
}
let target = marginal.mu;
let abs_tol = 1e-8_f64.max(1e-4 * target.abs());
// NOTE(review): 64 and 48 are presumably iteration/bracketing limits of
// solve_monotone_root — confirm in monotone_root.
let (root, _, f_best) = gam::families::monotone_root::solve_monotone_root(
eval,
intercept,
"saved bernoulli intercept",
abs_tol,
64,
48,
)?;
if f_best.abs() > abs_tol {
return Err(EstimationError::InvalidInput(format!(
"saved bernoulli marginal-slope intercept solve failed: residual={f_best:.3e} at a={root:.6}, target mu={target:.6}"
)));
}
Ok(root)
}
/// Evaluates the final linear predictor for every prediction row at the
/// coefficient vector `theta`, optionally together with its Jacobian
/// `d eta / d theta` (one row per observation, one column per coefficient).
///
/// `theta` is split with `split_theta` into marginal, log-slope and the optional
/// score-warp / link-deviation blocks; Jacobian columns follow the layout
/// `[marginal | logslope | score-warp | link-dev]`.
///
/// Two regimes are handled:
/// * rigid (neither a score-warp nor a link-deviation runtime is saved): the
///   intercept is closed-form for a standard-normal latent, or comes from
///   `empirical_rigid_intercept_and_gradient` for empirical latent grids;
/// * flexible: a per-row intercept is solved with `solve_intercept_scalar`, and
///   the Jacobian adds implicit-function terms of the form `-F_theta / M_a`.
///
/// The internal-scale predictor (and Jacobian, when requested) is finally passed
/// through `transform_internal_eta_to_base_scale`.
///
/// # Errors
/// Returns `EstimationError::InvalidInput` when the auxiliary z column or the
/// log-slope design is missing, an offset has the wrong length, the saved flex
/// runtimes and coefficient blocks disagree, or any per-row solve / evaluation
/// fails.
pub fn final_eta_and_gradient_from_theta(
    &self,
    input: &PredictInput,
    theta: &Array1<f64>,
    need_gradient: bool,
) -> Result<(Array1<f64>, Option<Array2<f64>>), EstimationError> {
    let z_raw = input.auxiliary_scalar.as_ref().ok_or_else(|| {
        EstimationError::InvalidInput(format!(
            "bernoulli marginal-slope prediction requires auxiliary z column '{}'",
            self.z_column
        ))
    })?;
    let z_normalized = self
        .latent_z_normalization
        .apply(z_raw, "bernoulli marginal-slope prediction")
        .map_err(EstimationError::from)?;
    let z = self.apply_latent_z_calibration(&z_normalized);
    let z = self.apply_latent_z_conditional_calibration(&z, input)?;
    let design_logslope = input.design_noise.as_ref().ok_or_else(|| {
        EstimationError::InvalidInput(
            "bernoulli marginal-slope prediction requires logslope design".to_string(),
        )
    })?;
    let (beta_marginal, beta_logslope, beta_score_warp, beta_link_dev) =
        self.split_theta(theta)?;
    if self.score_warp_runtime.is_some() != beta_score_warp.is_some() {
        return Err(EstimationError::InvalidInput(
            "bernoulli marginal-slope saved score-warp runtime/coefficients are inconsistent"
                .to_string(),
        ));
    }
    if self.link_deviation_runtime.is_some() != beta_link_dev.is_some() {
        return Err(EstimationError::InvalidInput(
            "bernoulli marginal-slope saved link-deviation runtime/coefficients are inconsistent"
                .to_string(),
        ));
    }
    let n = z.len();
    if input.offset.len() != n {
        return Err(EstimationError::InvalidInput(format!(
            "bernoulli marginal-slope prediction primary offset length mismatch: rows={n}, offset={}",
            input.offset.len()
        )));
    }
    let logslope_offset = input
        .offset_noise
        .as_ref()
        .map_or_else(|| Array1::zeros(n), Clone::clone);
    if logslope_offset.len() != n {
        return Err(EstimationError::InvalidInput(format!(
            "bernoulli marginal-slope prediction logslope offset length mismatch: rows={n}, offset_noise={}",
            logslope_offset.len()
        )));
    }
    let marginal_eta = input
        .design
        .dot(&beta_marginal.to_owned())
        .mapv(|v| v + self.baseline_marginal)
        + &input.offset;
    let logslope_eta = design_logslope
        .dot(&beta_logslope.to_owned())
        .mapv(|v| v + self.baseline_logslope)
        + &logslope_offset;
    let flex_active =
        self.score_warp_runtime.is_some() || self.link_deviation_runtime.is_some();
    // Column layout of theta / the Jacobian: [marginal | logslope | score-warp | link-dev].
    let marginal_dim = self.beta_marginal.len();
    let logslope_dim = self.beta_logslope.len();
    let score_warp_dim = self.beta_score_warp.as_ref().map_or(0, Array1::len);
    let link_dev_dim = self.beta_link_dev.as_ref().map_or(0, Array1::len);
    let logslope_col_offset = marginal_dim;
    let score_warp_col_offset = logslope_col_offset + logslope_dim;
    let link_dev_col_offset = score_warp_col_offset + score_warp_dim;
    // `div_ceil` and `axis_chunks_iter_mut` both panic on a zero chunk size, so
    // never let one through (e.g. for an empty prediction frame).
    let chunk_size = prediction_chunk_rows(theta.len(), 1, n).max(1);
    let num_chunks = n.div_ceil(chunk_size);
    let scale = self.probit_frailty_scale();
    if !flex_active {
        // Rigid model: eta = intercept(q, b) + scale * b * z. Only the marginal and
        // log-slope Jacobian columns are non-zero.
        let (final_eta_internal, marginal_scales, logslope_scales) = match &self.latent_measure
        {
            LatentMeasureKind::StandardNormal => {
                let sb_vec = logslope_eta.mapv(|b| scale * b);
                let c_vec = sb_vec.mapv(|sb| (1.0 + sb * sb).sqrt());
                let final_eta_internal = Array1::from_iter(
                    (0..n).map(|i| c_vec[i] * marginal_eta[i] + sb_vec[i] * z[i]),
                );
                let marginal_scales = c_vec;
                // d/db [sqrt(1 + s^2 b^2) q + s b z] = s^2 b q / sqrt(1 + s^2 b^2) + s z
                let logslope_scales = Array1::from_iter((0..n).map(|i| {
                    marginal_eta[i] * (scale * scale) * logslope_eta[i] / marginal_scales[i]
                        + scale * z[i]
                }));
                (final_eta_internal, marginal_scales, logslope_scales)
            }
            LatentMeasureKind::GlobalEmpirical { grid } => {
                let mut final_eta = Array1::<f64>::zeros(n);
                let mut marginal_scales = Array1::<f64>::zeros(n);
                let mut logslope_scales = Array1::<f64>::zeros(n);
                for i in 0..n {
                    let (intercept, a_marginal, a_slope) = self
                        .empirical_rigid_intercept_and_gradient(
                            marginal_eta[i],
                            logslope_eta[i],
                            &grid.nodes,
                            &grid.weights,
                        )?;
                    final_eta[i] = intercept + scale * logslope_eta[i] * z[i];
                    marginal_scales[i] = a_marginal;
                    logslope_scales[i] = a_slope + scale * z[i];
                }
                (final_eta, marginal_scales, logslope_scales)
            }
            LatentMeasureKind::LocalEmpirical { .. } => {
                let mut final_eta = Array1::<f64>::zeros(n);
                let mut marginal_scales = Array1::<f64>::zeros(n);
                let mut logslope_scales = Array1::<f64>::zeros(n);
                for i in 0..n {
                    let grid = self
                        .empirical_grid_for_prediction_row(input, i)?
                        .ok_or_else(|| {
                            EstimationError::InvalidInput(
                                "local empirical latent prediction did not produce a row grid"
                                    .to_string(),
                            )
                        })?;
                    let (intercept, a_marginal, a_slope) = self
                        .empirical_rigid_intercept_and_gradient(
                            marginal_eta[i],
                            logslope_eta[i],
                            &grid.nodes,
                            &grid.weights,
                        )?;
                    final_eta[i] = intercept + scale * logslope_eta[i] * z[i];
                    marginal_scales[i] = a_marginal;
                    logslope_scales[i] = a_slope + scale * z[i];
                }
                (final_eta, marginal_scales, logslope_scales)
            }
        };
        if !need_gradient {
            return self.transform_internal_eta_to_base_scale(final_eta_internal, None);
        }
        let mut grad_internal = Array2::<f64>::zeros((n, theta.len()));
        let mut start = 0usize;
        while start < n {
            let end = (start + chunk_size).min(n);
            let mc = input
                .design
                .try_row_chunk(start..end)
                .map_err(|e| EstimationError::InvalidInput(e.to_string()))?;
            let lc = design_logslope
                .try_row_chunk(start..end)
                .map_err(|e| EstimationError::InvalidInput(e.to_string()))?;
            for li in 0..(end - start) {
                let i = start + li;
                let c = marginal_scales[i];
                let g_scale = logslope_scales[i];
                let mut row = grad_internal.row_mut(i);
                for j in 0..marginal_dim {
                    row[j] = c * mc[[li, j]];
                }
                for j in 0..logslope_dim {
                    row[logslope_col_offset + j] = g_scale * lc[[li, j]];
                }
            }
            start = end;
        }
        return self
            .transform_internal_eta_to_base_scale(final_eta_internal, Some(grad_internal));
    }
    // Everything below is the flexible path; only it needs the anchor corrections
    // and the per-row marginal link derivatives.
    let anchor_corrections =
        self.build_anchor_correction_matrices(input, design_logslope, &z)?;
    let marginal_map = marginal_eta
        .iter()
        .map(|&eta| {
            bernoulli_marginal_link_map(&self.base_link, eta)
                .map_err(EstimationError::InvalidInput)
        })
        .collect::<Result<Vec<_>, _>>()?;
    let score_warp_obs_design = self
        .score_warp_runtime
        .as_ref()
        .map(|runtime| {
            if runtime.anchor_correction.is_some() {
                let anchor_rows = anchor_corrections
                    .score_warp_anchor_rows_view()
                    .ok_or_else(|| {
                        EstimationError::InvalidInput(
                            "bernoulli marginal-slope score-warp anchor residual present but \
anchor_corrections bundle is missing the parametric anchor rows"
                                .to_string(),
                        )
                    })?;
                runtime
                    .design_with_anchor_rows(&z, anchor_rows)
                    .map_err(EstimationError::from)
            } else {
                runtime.design(&z).map_err(EstimationError::from)
            }
        })
        .transpose()?;
    let score_dev_obs =
        if let (Some(design), Some(beta)) = (score_warp_obs_design.as_ref(), beta_score_warp) {
            design.dot(&beta.to_owned())
        } else {
            Array1::zeros(n)
        };
    let score_warp_beta_owned = beta_score_warp.as_ref().map(|v| v.to_owned());
    let link_dev_beta_owned = beta_link_dev.as_ref().map(|v| v.to_owned());
    let mut intercepts = Array1::<f64>::zeros(n);
    // Implicit-function sensitivities of the intercept; only allocated for gradients.
    let mut a_q_vec = need_gradient.then(|| Array1::<f64>::zeros(n));
    let mut a_b_vec = need_gradient.then(|| Array1::<f64>::zeros(n));
    let mut a_h_rows = if need_gradient && score_warp_dim > 0 {
        Some(Array2::<f64>::zeros((n, score_warp_dim)))
    } else {
        None
    };
    let mut a_w_rows = if need_gradient && link_dev_dim > 0 {
        Some(Array2::<f64>::zeros((n, link_dev_dim)))
    } else {
        None
    };
    let solve_result: Result<(), EstimationError> = {
        use ndarray::Axis;
        use rayon::iter::IndexedParallelIterator;
        // Split every output buffer into the same row chunks so each parallel task
        // owns disjoint mutable views.
        let intercepts_chunks: Vec<ndarray::ArrayViewMut1<f64>> = intercepts
            .axis_chunks_iter_mut(Axis(0), chunk_size)
            .collect();
        let a_q_chunks: Option<Vec<ndarray::ArrayViewMut1<f64>>> = a_q_vec
            .as_mut()
            .map(|a| a.axis_chunks_iter_mut(Axis(0), chunk_size).collect());
        let a_b_chunks: Option<Vec<ndarray::ArrayViewMut1<f64>>> = a_b_vec
            .as_mut()
            .map(|a| a.axis_chunks_iter_mut(Axis(0), chunk_size).collect());
        let a_h_chunks: Option<Vec<ndarray::ArrayViewMut2<f64>>> = a_h_rows
            .as_mut()
            .map(|a| a.axis_chunks_iter_mut(Axis(0), chunk_size).collect());
        let a_w_chunks: Option<Vec<ndarray::ArrayViewMut2<f64>>> = a_w_rows
            .as_mut()
            .map(|a| a.axis_chunks_iter_mut(Axis(0), chunk_size).collect());
        struct FlexSolveSink<'a> {
            intercepts: ndarray::ArrayViewMut1<'a, f64>,
            a_q: Option<ndarray::ArrayViewMut1<'a, f64>>,
            a_b: Option<ndarray::ArrayViewMut1<'a, f64>>,
            a_h: Option<ndarray::ArrayViewMut2<'a, f64>>,
            a_w: Option<ndarray::ArrayViewMut2<'a, f64>>,
        }
        let mut sinks: Vec<FlexSolveSink<'_>> = Vec::with_capacity(num_chunks);
        let mut intercepts_iter = intercepts_chunks.into_iter();
        let mut a_q_iter = a_q_chunks.map(|v| v.into_iter());
        let mut a_b_iter = a_b_chunks.map(|v| v.into_iter());
        let mut a_h_iter = a_h_chunks.map(|v| v.into_iter());
        let mut a_w_iter = a_w_chunks.map(|v| v.into_iter());
        for _ in 0..num_chunks {
            sinks.push(FlexSolveSink {
                intercepts: intercepts_iter.next().expect("chunk count matches"),
                a_q: a_q_iter
                    .as_mut()
                    .map(|it| it.next().expect("chunk count matches")),
                a_b: a_b_iter
                    .as_mut()
                    .map(|it| it.next().expect("chunk count matches")),
                a_h: a_h_iter
                    .as_mut()
                    .map(|it| it.next().expect("chunk count matches")),
                a_w: a_w_iter
                    .as_mut()
                    .map(|it| it.next().expect("chunk count matches")),
            });
        }
        // Per-node score-warp basis spans for the shared global grid. Only the
        // gradient pass reads them, so skip building the table otherwise.
        let global_score_basis_table: Option<
            Vec<Vec<gam::families::cubic_cell_kernel::LocalSpanCubic>>,
        > = match (
            need_gradient,
            &self.latent_measure,
            self.score_warp_runtime.as_ref(),
        ) {
            (true, LatentMeasureKind::GlobalEmpirical { grid }, Some(runtime)) => {
                let mut table = Vec::with_capacity(score_warp_dim);
                for j in 0..score_warp_dim {
                    let mut row = Vec::with_capacity(grid.nodes.len());
                    for &node in &grid.nodes {
                        row.push(
                            runtime
                                .basis_cubic_at(j, node)
                                .map_err(EstimationError::from)?,
                        );
                    }
                    table.push(row);
                }
                Some(table)
            }
            _ => None,
        };
        let global_score_basis_table = global_score_basis_table.as_ref();
        sinks
            .into_par_iter()
            .enumerate()
            .try_for_each(|(chunk_idx, mut sink)| -> Result<(), EstimationError> {
                let start = chunk_idx * chunk_size;
                let end = (start + chunk_size).min(n);
                let rows = end - start;
                let intercepts_view = &mut sink.intercepts;
                let mut a_q = sink.a_q.as_mut();
                let mut a_b = sink.a_b.as_mut();
                let mut a_h = sink.a_h.as_mut();
                let mut a_w = sink.a_w.as_mut();
                let mut warm_start_buf = Array1::<f64>::zeros(1);
                let mut f_h_row = vec![0.0; score_warp_dim];
                let mut f_w_row = vec![0.0; link_dev_dim];
                for local_row in 0..rows {
                    let i = start + local_row;
                    let slope = logslope_eta[i];
                    let q = marginal_eta[i];
                    let empirical_grid = self.empirical_grid_for_prediction_row(input, i)?;
                    let score_corr_row = anchor_corrections.score_warp_row(i);
                    let link_corr_row = anchor_corrections.link_dev_row(i);
                    intercepts_view[local_row] = self.solve_intercept_scalar(
                        q,
                        slope,
                        link_dev_beta_owned.as_ref(),
                        score_warp_beta_owned.as_ref(),
                        empirical_grid.as_ref(),
                        &mut warm_start_buf,
                        score_corr_row,
                        link_corr_row,
                    )?;
                    if !need_gradient {
                        continue;
                    }
                    let intercept = intercepts_view[local_row];
                    let (_, m_a_raw, _) = self.evaluate_prediction_calibration(
                        intercept,
                        q,
                        slope,
                        score_warp_beta_owned.as_ref(),
                        link_dev_beta_owned.as_ref(),
                        empirical_grid.as_ref(),
                        score_corr_row,
                        link_corr_row,
                    )?;
                    // Floor the intercept derivative so the implicit-function
                    // divisions below stay finite.
                    let m_a = m_a_raw.max(1e-12);
                    a_q.as_mut().expect("a_q allocated when need_gradient")[local_row] =
                        marginal_map[i].mu1 / m_a;
                    let mut f_b = 0.0;
                    f_h_row.fill(0.0);
                    f_w_row.fill(0.0);
                    if let Some(grid) = empirical_grid.as_ref() {
                        // Empirical latent: weighted sum over grid nodes.
                        for (node_idx, (node, weight)) in grid.pairs().enumerate() {
                            let obs = self.observed_denested_cell_partials_at_z(
                                node,
                                intercept,
                                slope,
                                score_warp_beta_owned.as_ref(),
                                link_dev_beta_owned.as_ref(),
                                score_corr_row,
                                link_corr_row,
                            )?;
                            let eta = eval_coeff4_at(&obs.coeff, node);
                            let pdf = normal_pdf(eta);
                            f_b += weight * pdf * eval_coeff4_at(&obs.dc_db, node);
                            if let Some(runtime) = self.score_warp_runtime.as_ref() {
                                for j in 0..score_warp_dim {
                                    let mut basis_span = if let Some(table) =
                                        global_score_basis_table
                                    {
                                        table[j][node_idx]
                                    } else {
                                        runtime
                                            .basis_cubic_at(j, node)
                                            .map_err(EstimationError::from)?
                                    };
                                    if let Some(corr) = score_corr_row {
                                        basis_span.c0 -= corr[j];
                                    }
                                    let coeffs = gam::families::cubic_cell_kernel::score_basis_cell_coefficients(
                                        basis_span,
                                        slope,
                                    );
                                    let coeffs = scale_coeff4(coeffs, scale);
                                    f_h_row[j] += weight * pdf * eval_coeff4_at(&coeffs, node);
                                }
                            }
                            if let Some(runtime) = self.link_deviation_runtime.as_ref() {
                                for j in 0..link_dev_dim {
                                    let mut basis_span = runtime
                                        .basis_cubic_at(j, intercept + slope * node)
                                        .map_err(EstimationError::from)?;
                                    if let Some(corr) = link_corr_row {
                                        basis_span.c0 -= corr[j];
                                    }
                                    let coeffs = gam::families::cubic_cell_kernel::link_basis_cell_coefficients(
                                        basis_span,
                                        intercept,
                                        slope,
                                    );
                                    let coeffs = scale_coeff4(coeffs, scale);
                                    f_w_row[j] += weight * pdf * eval_coeff4_at(&coeffs, node);
                                }
                            }
                        }
                    } else {
                        // No row grid: integrate cell by cell via the de-nested partition.
                        let cells = self.denested_partition_cells(
                            intercept,
                            slope,
                            score_warp_beta_owned.as_ref(),
                            link_dev_beta_owned.as_ref(),
                            score_corr_row,
                            link_corr_row,
                        )?;
                        for partition_cell in cells {
                            let cell = partition_cell.cell;
                            let state =
                                gam::families::cubic_cell_kernel::evaluate_cell_moments(
                                    cell, 9,
                                )
                                .map_err(EstimationError::InvalidInput)?;
                            let (_, dc_db_raw) = gam::families::cubic_cell_kernel::denested_cell_coefficient_partials(
                                partition_cell.score_span,
                                partition_cell.link_span,
                                intercept,
                                slope,
                            );
                            let dc_db = scale_coeff4(dc_db_raw, scale);
                            f_b += gam::families::cubic_cell_kernel::cell_first_derivative_from_moments(
                                &dc_db,
                                &state.moments,
                            )
                            .map_err(EstimationError::InvalidInput)?;
                            // The midpoint selects which local cubic span of each basis
                            // function applies inside this cell.
                            let mid = 0.5 * (cell.left + cell.right);
                            if let Some(runtime) = self.score_warp_runtime.as_ref() {
                                for j in 0..score_warp_dim {
                                    let mut basis_span = runtime
                                        .basis_cubic_at(j, mid)
                                        .map_err(EstimationError::from)?;
                                    if let Some(corr) = score_corr_row {
                                        basis_span.c0 -= corr[j];
                                    }
                                    let coeffs = gam::families::cubic_cell_kernel::score_basis_cell_coefficients(
                                        basis_span, slope,
                                    );
                                    let coeffs = scale_coeff4(coeffs, scale);
                                    f_h_row[j] += gam::families::cubic_cell_kernel::cell_first_derivative_from_moments(
                                        &coeffs,
                                        &state.moments,
                                    )
                                    .map_err(EstimationError::InvalidInput)?;
                                }
                            }
                            if let Some(runtime) = self.link_deviation_runtime.as_ref() {
                                for j in 0..link_dev_dim {
                                    let mut basis_span = runtime
                                        .basis_cubic_at(j, intercept + slope * mid)
                                        .map_err(EstimationError::from)?;
                                    if let Some(corr) = link_corr_row {
                                        basis_span.c0 -= corr[j];
                                    }
                                    let coeffs = gam::families::cubic_cell_kernel::link_basis_cell_coefficients(
                                        basis_span,
                                        intercept,
                                        slope,
                                    );
                                    let coeffs = scale_coeff4(coeffs, scale);
                                    f_w_row[j] += gam::families::cubic_cell_kernel::cell_first_derivative_from_moments(
                                        &coeffs,
                                        &state.moments,
                                    )
                                    .map_err(EstimationError::InvalidInput)?;
                                }
                            }
                        }
                    }
                    if let Some(a_h_view) = a_h.as_mut() {
                        let factor = -1.0 / m_a;
                        for j in 0..score_warp_dim {
                            a_h_view[[local_row, j]] = factor * f_h_row[j];
                        }
                    }
                    if let Some(a_w_view) = a_w.as_mut() {
                        let factor = -1.0 / m_a;
                        for j in 0..link_dev_dim {
                            a_w_view[[local_row, j]] = factor * f_w_row[j];
                        }
                    }
                    a_b.as_mut().expect("a_b allocated when need_gradient")[local_row] =
                        -f_b / m_a;
                }
                Ok(())
            })
    };
    solve_result?;
    let eta_base = &intercepts + &(&logslope_eta * &z);
    let mut link_c_obs: Option<Array1<f64>> = None;
    let mut link_basis_obs: Option<Array2<f64>> = None;
    let link_dev_obs = if let (Some(runtime), Some(beta_owned)) = (
        self.link_deviation_runtime.as_ref(),
        link_dev_beta_owned.as_ref(),
    ) {
        let basis = if runtime.anchor_correction.is_some() {
            let anchor_rows =
                anchor_corrections
                    .link_dev_anchor_rows_view()
                    .ok_or_else(|| {
                        EstimationError::InvalidInput(
                            "bernoulli marginal-slope link-deviation anchor residual present but \
anchor_corrections bundle is missing the parametric anchor rows"
                                .to_string(),
                        )
                    })?;
            runtime
                .design_with_anchor_rows(&eta_base, anchor_rows)
                .map_err(EstimationError::from)?
        } else {
            runtime.design(&eta_base).map_err(EstimationError::from)?
        };
        let dev = basis.dot(beta_owned);
        if need_gradient {
            // Chain-rule multiplier 1 + w'(eta_base) for terms flowing through eta_base.
            let d1 = runtime
                .first_derivative_design(&eta_base)
                .map_err(EstimationError::from)?;
            let mut c_obs = d1.dot(beta_owned);
            c_obs.mapv_inplace(|v| v + 1.0);
            link_c_obs = Some(c_obs);
            link_basis_obs = Some(basis);
        }
        dev
    } else {
        Array1::zeros(n)
    };
    let final_eta_internal =
        (&eta_base + &(&logslope_eta * &score_dev_obs) + &link_dev_obs).mapv(|v| scale * v);
    if !need_gradient {
        return self.transform_internal_eta_to_base_scale(final_eta_internal, None);
    }
    let a_q_vec = a_q_vec.unwrap();
    let a_b_vec = a_b_vec.unwrap();
    let mut grad = Array2::<f64>::zeros((n, theta.len()));
    {
        use ndarray::Axis;
        use rayon::iter::IndexedParallelIterator;
        let grad_result: Result<(), String> = grad
            .axis_chunks_iter_mut(Axis(0), chunk_size)
            .into_par_iter()
            .enumerate()
            .try_for_each(|(chunk_idx, mut grad_chunk)| -> Result<(), String> {
                let start = chunk_idx * chunk_size;
                let end = (start + chunk_size).min(n);
                let mc = input
                    .design
                    .try_row_chunk(start..end)
                    .map_err(|e| e.to_string())?;
                let lc = design_logslope
                    .try_row_chunk(start..end)
                    .map_err(|e| e.to_string())?;
                let rows = end - start;
                for li in 0..rows {
                    let i = start + li;
                    let mut row = grad_chunk.row_mut(li);
                    let a_q = a_q_vec[i];
                    for j in 0..marginal_dim {
                        row[j] = a_q * mc[[li, j]];
                    }
                    let base_multiplier = link_c_obs.as_ref().map_or(1.0, |c| c[i]);
                    let g_scale = base_multiplier * (a_b_vec[i] + z[i]) + score_dev_obs[i];
                    for j in 0..logslope_dim {
                        row[logslope_col_offset + j] = g_scale * lc[[li, j]];
                    }
                    if let (Some(a_h_rows), Some(obs_design)) =
                        (a_h_rows.as_ref(), score_warp_obs_design.as_ref())
                    {
                        let slope = logslope_eta[i];
                        for j in 0..score_warp_dim {
                            row[score_warp_col_offset + j] =
                                base_multiplier * a_h_rows[[i, j]] + slope * obs_design[[i, j]];
                        }
                    }
                    if let Some(a_w_rows) = a_w_rows.as_ref() {
                        for j in 0..link_dev_dim {
                            row[link_dev_col_offset + j] = a_w_rows[[i, j]];
                        }
                    }
                    // Marginal and link-deviation columns still need the link chain-rule
                    // factor; link columns also gain the direct basis contribution.
                    if let (Some(link_c), Some(link_basis)) =
                        (link_c_obs.as_ref(), link_basis_obs.as_ref())
                    {
                        let c = link_c[i];
                        for j in 0..marginal_dim {
                            row[j] *= c;
                        }
                        for j in 0..link_dev_dim {
                            row[link_dev_col_offset + j] =
                                c * row[link_dev_col_offset + j] + link_basis[[i, j]];
                        }
                    }
                }
                Ok(())
            });
        grad_result.map_err(EstimationError::InvalidInput)?;
    }
    if scale != 1.0 {
        grad.mapv_inplace(|v| scale * v);
    }
    self.transform_internal_eta_to_base_scale(final_eta_internal, Some(grad))
}
/// Returns only the final linear predictor at `theta`; no Jacobian is computed.
///
/// Thin convenience over `final_eta_and_gradient_from_theta` with
/// `need_gradient = false`, discarding the (absent) gradient.
///
/// # Errors
/// Propagates any error from `final_eta_and_gradient_from_theta`.
pub fn final_eta_from_theta(
    &self,
    input: &PredictInput,
    theta: &Array1<f64>,
) -> Result<Array1<f64>, EstimationError> {
    self.final_eta_and_gradient_from_theta(input, theta, false)
        .map(|(eta, _gradient)| eta)
}
pub fn theta_len(&self) -> usize {
self.beta_marginal.len()
+ self.beta_logslope.len()
+ self.beta_score_warp.as_ref().map_or(0, Array1::len)
+ self.beta_link_dev.as_ref().map_or(0, Array1::len)
}
pub fn predict_eta_and_q_chain(
&self,
input: &PredictInput,
) -> Result<(Array1<f64>, Array1<f64>), EstimationError> {
let z_raw = input.auxiliary_scalar.as_ref().ok_or_else(|| {
EstimationError::InvalidInput(format!(
"bernoulli marginal-slope prediction requires auxiliary z column '{}'",
self.z_column
))
})?;
let z_normalized = self
.latent_z_normalization
.apply(z_raw, "bernoulli marginal-slope prediction")
.map_err(EstimationError::from)?;
let z = self.apply_latent_z_calibration(&z_normalized);
let z = self.apply_latent_z_conditional_calibration(&z, input)?;
let design_logslope = input.design_noise.as_ref().ok_or_else(|| {
EstimationError::InvalidInput(
"bernoulli marginal-slope prediction requires logslope design".to_string(),
)
})?;
let n = z.len();
if input.offset.len() != n {
return Err(EstimationError::InvalidInput(format!(
"bernoulli marginal-slope prediction primary offset length mismatch: rows={n}, offset={}",
input.offset.len()
)));
}
let logslope_offset = input
.offset_noise
.as_ref()
.map_or_else(|| Array1::zeros(n), Clone::clone);
if logslope_offset.len() != n {
return Err(EstimationError::InvalidInput(format!(
"bernoulli marginal-slope prediction logslope offset length mismatch: rows={n}, offset_noise={}",
logslope_offset.len()
)));
}
let marginal_eta = input
.design
.dot(&self.beta_marginal)
.mapv(|v| v + self.baseline_marginal)
+ &input.offset;
let logslope_eta = design_logslope
.dot(&self.beta_logslope)
.mapv(|v| v + self.baseline_logslope)
+ &logslope_offset;
let scale = self.probit_frailty_scale();
let flex_active =
self.score_warp_runtime.is_some() || self.link_deviation_runtime.is_some();
if !flex_active {
match &self.latent_measure {
LatentMeasureKind::StandardNormal => {
let sb = logslope_eta.mapv(|x| scale * x);
let deta_dq = sb.mapv(|s| (1.0 + s * s).sqrt());
let eta = &deta_dq * marginal_eta + &sb * z;
return Ok((eta, deta_dq));
}
_ => {
let mut eta = Array1::<f64>::zeros(n);
let mut deta_dq = Array1::<f64>::zeros(n);
for i in 0..n {
let grid = self
.empirical_grid_for_prediction_row(input, i)?
.ok_or_else(|| {
EstimationError::InvalidInput(
"empirical latent prediction did not produce a row grid"
.to_string(),
)
})?;
let (intercept, a_marginal, _) = self
.empirical_rigid_intercept_and_gradient(
marginal_eta[i],
logslope_eta[i],
&grid.nodes,
&grid.weights,
)?;
eta[i] = intercept + scale * logslope_eta[i] * z[i];
deta_dq[i] = a_marginal;
}
return Ok((eta, deta_dq));
}
}
}
let marginal_map = marginal_eta
.iter()
.map(|&eta_marg| {
bernoulli_marginal_link_map(&self.base_link, eta_marg)
.map_err(EstimationError::InvalidInput)
})
.collect::<Result<Vec<_>, _>>()?;
let anchor_corrections =
self.build_anchor_correction_matrices(input, design_logslope, &z)?;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
let pairs: Result<Vec<(f64, f64)>, EstimationError> = (0..n)
.into_par_iter()
.map_init(
|| Array1::<f64>::zeros(1),
|warm_start_buf, i| {
let q = marginal_eta[i];
let slope = logslope_eta[i];
let empirical_grid = self.empirical_grid_for_prediction_row(input, i)?;
let score_corr_row = anchor_corrections.score_warp_row(i);
let link_corr_row = anchor_corrections.link_dev_row(i);
let intercept = self.solve_intercept_scalar(
q,
slope,
self.beta_link_dev.as_ref(),
self.beta_score_warp.as_ref(),
empirical_grid.as_ref(),
warm_start_buf,
score_corr_row,
link_corr_row,
)?;
let (_, m_a_raw, _) = self.evaluate_prediction_calibration(
intercept,
q,
slope,
self.beta_score_warp.as_ref(),
self.beta_link_dev.as_ref(),
empirical_grid.as_ref(),
score_corr_row,
link_corr_row,
)?;
let m_a = m_a_raw.max(1e-12);
Ok((intercept, marginal_map[i].mu1 / m_a))
},
)
.collect();
let pairs = pairs?;
let mut intercepts = Array1::<f64>::zeros(n);
let mut a_q = Array1::<f64>::zeros(n);
for (i, (intercept, a)) in pairs.into_iter().enumerate() {
intercepts[i] = intercept;
a_q[i] = a;
}
let score_dev_obs = if let (Some(runtime), Some(beta)) = (
self.score_warp_runtime.as_ref(),
self.beta_score_warp.as_ref(),
) {
let design = if runtime.anchor_correction.is_some() {
let anchor_rows = anchor_corrections
.score_warp_anchor_rows_view()
.ok_or_else(|| {
EstimationError::InvalidInput(
"bernoulli marginal-slope score-warp anchor residual present but \
anchor_corrections bundle is missing the parametric anchor rows"
.to_string(),
)
})?;
runtime
.design_with_anchor_rows(&z, anchor_rows)
.map_err(EstimationError::from)?
} else {
runtime.design(&z).map_err(EstimationError::from)?
};
design.dot(beta)
} else {
Array1::zeros(n)
};
let eta_base = &intercepts + &(&logslope_eta * &z);
let (link_dev_obs, link_c_obs) = if let (Some(runtime), Some(beta)) = (
self.link_deviation_runtime.as_ref(),
self.beta_link_dev.as_ref(),
) {
let basis = if runtime.anchor_correction.is_some() {
let anchor_rows =
anchor_corrections
.link_dev_anchor_rows_view()
.ok_or_else(|| {
EstimationError::InvalidInput(
"bernoulli marginal-slope link-deviation anchor residual present but \
anchor_corrections bundle is missing the parametric anchor rows"
.to_string(),
)
})?;
runtime
.design_with_anchor_rows(&eta_base, anchor_rows)
.map_err(EstimationError::from)?
} else {
runtime.design(&eta_base).map_err(EstimationError::from)?
};
let dev = basis.dot(beta);
let d1 = runtime
.first_derivative_design(&eta_base)
.map_err(EstimationError::from)?;
let mut c_obs = d1.dot(beta);
c_obs.mapv_inplace(|v| v + 1.0);
(dev, c_obs)
} else {
(Array1::zeros(n), Array1::ones(n))
};
let final_eta_internal =
(&eta_base + &(&logslope_eta * &score_dev_obs) + &link_dev_obs).mapv(|v| scale * v);
let deta_dq = (&link_c_obs * &a_q).mapv(|v| scale * v);
Ok((final_eta_internal, deta_dq))
}
}