use super::*;
/// State bundle for the Bernoulli marginal-slope family.
///
/// The `Arc`-wrapped members are shared, not copied, when the derived `Clone`
/// runs, so clones of the family observe the same caches and counters.
///
/// NOTE(review): the per-field notes below are inferred from field names and
/// types only; the code that consumes them is outside this section. Confirm
/// against the callers before relying on them.
#[derive(Clone)]
pub(super) struct BernoulliMarginalSlopeFamily {
/// Presumably the binary response, one entry per observation — TODO confirm.
pub(super) y: Arc<Array1<f64>>,
/// Presumably per-observation weights — TODO confirm.
pub(super) weights: Arc<Array1<f64>>,
/// Presumably the per-observation latent score the slope acts on — TODO confirm.
pub(super) z: Arc<Array1<f64>>,
/// Which latent measure is in use (see `LatentMeasureKind`).
pub(super) latent_measure: LatentMeasureKind,
/// Optional Gaussian frailty standard deviation; `None` presumably disables frailty — TODO confirm.
pub(super) gaussian_frailty_sd: Option<f64>,
/// Base inverse link. The helpers in this file that take a base link reject
/// anything other than probit (see `require_probit_marginal_slope_link`).
pub(super) base_link: InverseLink,
/// Design matrix of the marginal block (by name).
pub(super) marginal_design: DesignMatrix,
/// Design matrix of the log-slope block (by name).
pub(super) logslope_design: DesignMatrix,
/// Optional deviation runtime for the score warp.
pub(super) score_warp: Option<DeviationRuntime>,
/// Optional deviation runtime for the link deviation.
pub(super) link_dev: Option<DeviationRuntime>,
/// Resource policy handed to the family.
pub(super) policy: crate::resource::ResourcePolicy,
/// Shared LRU cache of cell moments (by type name).
pub(super) cell_moment_lru: Arc<exact_kernel::CellMomentLruCache>,
/// Shared statistics for the cell-moment cache (by type name).
pub(super) cell_moment_cache_stats: Arc<exact_kernel::CellMomentCacheStats>,
/// Optional shared per-row intercept warm-start cache.
pub(super) intercept_warm_starts: Option<Arc<BernoulliInterceptWarmStartCache>>,
/// Shared atomic counter; presumably tracks the auto-subsample phase — TODO confirm.
pub(super) auto_subsample_phase_counter: Arc<std::sync::atomic::AtomicUsize>,
/// Shared, mutex-guarded optional vector; presumably the last `rho` seen by auto-subsampling — TODO confirm.
pub(super) auto_subsample_last_rho: Arc<Mutex<Option<Array1<f64>>>>,
}
/// First-order predictor state for one row's intercept.
///
/// `BernoulliInterceptWarmStartCache::predictor_seed` extrapolates a new
/// intercept seed as
/// `intercept + sum_i intercept_primary_deriv[i] * (current[i] - primary_point[i])`.
#[derive(Clone)]
pub(super) struct BernoulliInterceptPredictorWarmStart {
/// Intercept value recorded at `primary_point`.
pub(super) intercept: f64,
/// Point at which `intercept` and `intercept_primary_deriv` were recorded.
pub(super) primary_point: Vec<f64>,
/// Coefficients multiplying the per-coordinate displacement from
/// `primary_point`; must have the same length as `primary_point` for the
/// predictor to be used.
pub(super) intercept_primary_deriv: Vec<f64>,
}
/// Per-row warm-start storage for intercepts, indexed by row.
///
/// The three vectors are built with the same length by
/// `new_intercept_warm_start_cache`.
pub(super) struct BernoulliInterceptWarmStartCache {
/// Bit pattern (`f64::to_bits`) of the cached intercept for each row;
/// initialised to the bits of `f64::NAN`.
pub(super) intercept_value: Vec<AtomicU64>,
/// Tag identifying which coefficient state the cached value belongs to.
/// The value `0` is reserved for "unseeded / write in progress".
pub(super) intercept_tag: Vec<AtomicU64>,
/// Optional first-order predictor per row, guarded by its own mutex.
pub(super) predictors: Vec<Mutex<Option<BernoulliInterceptPredictorWarmStart>>>,
}
impl BernoulliInterceptWarmStartCache {
    /// Number of row slots, i.e. the length of `intercept_value`.
    #[inline]
    pub(super) fn len(&self) -> usize {
        self.intercept_value.len()
    }

    /// Lock-free read of the cached intercept for `row` under `beta_tag`.
    ///
    /// Returns `None` when `row` is out of range, when the slot's tag differs
    /// from `beta_tag` either before or after the value is read, or when the
    /// stored value is not finite (which includes the initial NaN).
    ///
    /// The tag is validated on both sides of the value load so that a value
    /// written by a concurrent [`Self::store_tagged`] for a different tag is
    /// rejected.
    #[inline]
    pub(super) fn load_tagged(&self, row: usize, beta_tag: u64) -> Option<f64> {
        let value_slot = self.intercept_value.get(row)?;
        let tag_slot = self.intercept_tag.get(row)?;
        if tag_slot.load(Ordering::Acquire) != beta_tag {
            return None;
        }
        let bits = value_slot.load(Ordering::Relaxed);
        // An acquire *load* only keeps later accesses from moving above it; it
        // does not stop the relaxed value load above from being satisfied
        // after the tag re-check. The acquire fence pins the value load before
        // the re-check and pairs with the release fence in `store_tagged`, so
        // observing a newly written value implies the re-check also observes
        // that writer's tag invalidation (or something later).
        std::sync::atomic::fence(Ordering::Acquire);
        if tag_slot.load(Ordering::Acquire) != beta_tag {
            return None;
        }
        let value = f64::from_bits(bits);
        value.is_finite().then_some(value)
    }

    /// Unconditionally overwrites the cached intercept for `row`.
    ///
    /// The tag is first reset to `0` (the reserved "unseeded" value), then the
    /// value is written, then `beta_tag` is published. Out-of-range rows are
    /// ignored.
    #[inline]
    pub(super) fn store_tagged(&self, row: usize, value: f64, beta_tag: u64) {
        if let (Some(value_slot), Some(tag_slot)) =
            (self.intercept_value.get(row), self.intercept_tag.get(row))
        {
            tag_slot.store(0, Ordering::Release);
            // A release *store* only orders earlier accesses before itself;
            // without this fence the relaxed value store below could become
            // visible before the invalidation above, letting a reader validate
            // the new value against the old tag.
            std::sync::atomic::fence(Ordering::Release);
            value_slot.store(value.to_bits(), Ordering::Relaxed);
            tag_slot.store(beta_tag, Ordering::Release);
        }
    }

    /// Seeds `row` only if its tag is still `0` (unseeded).
    ///
    /// Returns `Ok(())` when this call claimed the slot and wrote `value`.
    /// Returns `Err(tag)` with the tag that was already present when the slot
    /// was taken; that tag is never `0`, so `Err(0)` unambiguously means `row`
    /// is out of range.
    ///
    /// NOTE(review): the tag is published before the value is written, so a
    /// concurrent `load_tagged` may briefly see the new tag with the previous
    /// value bits. For a never-written slot those bits are NaN and the reader
    /// returns `None`; callers appear to treat this cache as best-effort.
    #[inline]
    pub(super) fn compare_exchange_unseeded(
        &self,
        row: usize,
        value: f64,
        beta_tag: u64,
    ) -> Result<(), u64> {
        let value_slot = self.intercept_value.get(row).ok_or(0u64)?;
        let tag_slot = self.intercept_tag.get(row).ok_or(0u64)?;
        match tag_slot.compare_exchange(0, beta_tag, Ordering::AcqRel, Ordering::Acquire) {
            Ok(_) => {
                value_slot.store(value.to_bits(), Ordering::Relaxed);
                Ok(())
            }
            Err(prev) => Err(prev),
        }
    }

    /// First-order intercept seed for `row` at `current_point`.
    ///
    /// Computes `intercept + sum_i deriv[i] * (current_point[i] - primary_point[i])`
    /// from the stored predictor. Returns `None` when the row is out of range,
    /// the mutex is poisoned, no predictor is stored, the stored vectors do not
    /// match `current_point` in length, or the stored intercept or the
    /// resulting seed is not finite.
    pub(super) fn predictor_seed(&self, row: usize, current_point: &[f64]) -> Option<f64> {
        // Borrow through the guard instead of cloning both vectors: the dot
        // product below is all that is needed from the stored predictor.
        let guard = self.predictors.get(row)?.lock().ok()?;
        let warm = guard.as_ref()?;
        if warm.primary_point.len() != current_point.len()
            || warm.intercept_primary_deriv.len() != current_point.len()
            || !warm.intercept.is_finite()
        {
            return None;
        }
        let correction = warm
            .intercept_primary_deriv
            .iter()
            .zip(current_point.iter().zip(warm.primary_point.iter()))
            .map(|(a_u, (new, old))| a_u * (new - old))
            .sum::<f64>();
        let seed = warm.intercept + correction;
        seed.is_finite().then_some(seed)
    }

    /// Replaces the stored predictor for `row`.
    ///
    /// The update is silently skipped when any input is non-finite, when `row`
    /// is out of range, or when the row's mutex is poisoned.
    pub(super) fn store_predictor(
        &self,
        row: usize,
        intercept: f64,
        primary_point: Vec<f64>,
        intercept_primary_deriv: Vec<f64>,
    ) {
        if !intercept.is_finite()
            || primary_point.iter().any(|value| !value.is_finite())
            || intercept_primary_deriv
                .iter()
                .any(|value| !value.is_finite())
        {
            return;
        }
        let Some(slot) = self.predictors.get(row) else {
            return;
        };
        if let Ok(mut guard) = slot.lock() {
            *guard = Some(BernoulliInterceptPredictorWarmStart {
                intercept,
                primary_point,
                intercept_primary_deriv,
            });
        }
    }
}
/// Builds an empty warm-start cache with `n` row slots.
///
/// Every slot starts with NaN value bits, tag `0` (unseeded) and no predictor.
pub(super) fn new_intercept_warm_start_cache(n: usize) -> Arc<BernoulliInterceptWarmStartCache> {
    let unseeded_bits = f64::NAN.to_bits();
    let mut intercept_value = Vec::with_capacity(n);
    let mut intercept_tag = Vec::with_capacity(n);
    let mut predictors = Vec::with_capacity(n);
    for _ in 0..n {
        intercept_value.push(AtomicU64::new(unseeded_bits));
        intercept_tag.push(AtomicU64::new(0));
        predictors.push(Mutex::new(None));
    }
    Arc::new(BernoulliInterceptWarmStartCache {
        intercept_value,
        intercept_tag,
        predictors,
    })
}
/// FNV-1a style 64-bit key over `(marginal_q, slope)` for the rigid path.
///
/// The stream hashed is the domain byte `0xb1` followed by the little-endian
/// bit patterns of the two floats, with `-0.0` folded onto `+0.0`. The result
/// is never `0`: a zero digest is remapped to `1`.
#[inline]
pub(super) fn hash_intercept_warm_start_key_rigid(marginal_q: f64, slope: f64) -> u64 {
    const FNV_OFFSET: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;
    const RIGID_DOMAIN_BYTE: u8 = 0xb1;

    // Both signed zeros compare equal to 0.0 and share one bit pattern here.
    let canonical_bits = |x: f64| if x == 0.0 { 0u64 } else { x.to_bits() };
    let absorb = |state: u64, byte: u8| (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME);

    let seeded = absorb(FNV_OFFSET, RIGID_DOMAIN_BYTE);
    let digest = [marginal_q, slope]
        .iter()
        .flat_map(|&x| canonical_bits(x).to_le_bytes())
        .fold(seeded, absorb);

    match digest {
        0 => 1,
        nonzero => nonzero,
    }
}
/// FNV-1a style 64-bit key for the flexible path.
///
/// The stream hashed is, in order: the domain byte `0xb2`; the little-endian
/// bit patterns of `marginal_eta` and `slope`; then one section for `beta_h`
/// (marker `0xc1`) and one for `beta_w` (marker `0xc2`). A section is its
/// marker followed by either the single byte `0xff` for `None`, or the vector
/// length as a little-endian `u64` and every element's bit pattern. `-0.0` is
/// folded onto `+0.0` throughout. The result is never `0`: a zero digest is
/// remapped to `1`.
#[inline]
pub(super) fn hash_intercept_warm_start_key_flex(
    marginal_eta: f64,
    slope: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
) -> u64 {
    const FNV_OFFSET: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;

    /// One FNV-1a step: xor the byte in, then multiply by the prime.
    fn absorb_byte(state: u64, byte: u8) -> u64 {
        (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    }

    /// Absorbs a 64-bit word as eight little-endian bytes.
    fn absorb_word(state: u64, word: u64) -> u64 {
        word.to_le_bytes()
            .iter()
            .fold(state, |acc, &byte| absorb_byte(acc, byte))
    }

    /// Absorbs a float's bit pattern, treating both signed zeros as `0`.
    fn absorb_float(state: u64, x: f64) -> u64 {
        let bits = if x == 0.0 { 0u64 } else { x.to_bits() };
        absorb_word(state, bits)
    }

    /// Absorbs one optional coefficient section introduced by `marker`.
    fn absorb_section(state: u64, beta: Option<&Array1<f64>>, marker: u8) -> u64 {
        let state = absorb_byte(state, marker);
        match beta {
            None => absorb_byte(state, 0xffu8),
            Some(coeffs) => {
                let state = absorb_word(state, coeffs.len() as u64);
                coeffs.iter().fold(state, |acc, &x| absorb_float(acc, x))
            }
        }
    }

    let mut digest = absorb_byte(FNV_OFFSET, 0xb2);
    digest = absorb_float(digest, marginal_eta);
    digest = absorb_float(digest, slope);
    digest = absorb_section(digest, beta_h, 0xc1);
    digest = absorb_section(digest, beta_w, 0xc2);

    match digest {
        0 => 1,
        nonzero => nonzero,
    }
}
/// Optional coefficient vectors, one per named block.
///
/// The derived `Default` leaves every entry `None`.
///
/// NOTE(review): by name these look like starting hints for the marginal,
/// log-slope, score-warp and link-deviation coefficient blocks; how they are
/// consumed is not visible in this section — confirm against the callers.
#[derive(Clone, Default)]
pub(super) struct ThetaHints {
/// Hint for the marginal block coefficients, if any.
pub(super) marginal_beta: Option<Array1<f64>>,
/// Hint for the log-slope block coefficients, if any.
pub(super) logslope_beta: Option<Array1<f64>>,
/// Hint for the score-warp block coefficients, if any.
pub(super) score_warp_beta: Option<Array1<f64>>,
/// Hint for the link-deviation block coefficients, if any.
pub(super) link_dev_beta: Option<Array1<f64>>,
}
/// Builds the score-warp deviation block from a single seed.
///
/// The same `seed` is used both to place the knots and to evaluate the design;
/// all validation and errors come from
/// `build_deviation_block_from_knots_and_design_seed`.
pub(crate) fn build_score_warp_deviation_block_from_seed(
    seed: &Array1<f64>,
    cfg: &DeviationBlockConfig,
) -> Result<DeviationPrepared, String> {
    let knot_seed = seed;
    let design_seed = seed;
    let prepared = build_deviation_block_from_knots_and_design_seed(knot_seed, design_seed, cfg)?;
    Ok(prepared)
}
/// Values and derivatives produced by `bernoulli_marginal_link_map` at one `eta`.
///
/// `mu*` describe the clamped probability `mu(eta)` and `q*` the internal
/// quantile `q(eta) = standard_normal_quantile(mu)`; numeric suffixes are
/// derivative orders with respect to `eta`. When the probability clamp is
/// active every derivative and power field is `0.0`.
#[derive(Clone, Copy, Debug)]
pub(crate) struct BernoulliMarginalLinkMap {
/// `normal_cdf(eta)` clamped by `clamp_bernoulli_link_probability`.
pub mu: f64,
/// First derivative of `mu`: `phi(eta)`.
pub mu1: f64,
/// Second derivative of `mu`: `-eta * phi(eta)`.
pub mu2: f64,
/// Third derivative of `mu`: `(eta^2 - 1) * phi(eta)`.
pub mu3: f64,
/// Fourth derivative of `mu`: `-(eta^3 - 3 eta) * phi(eta)`.
pub mu4: f64,
/// Standard normal quantile of `mu`.
pub q: f64,
/// First derivative of `q`: `mu1 / phi(q)`.
pub q1: f64,
/// Second derivative of `q`.
pub q2: f64,
/// Third derivative of `q`.
pub q3: f64,
/// Fourth derivative of `q`.
pub q4: f64,
/// `q1` squared.
pub q1_sq: f64,
/// `q1` cubed.
pub q1_cu: f64,
/// `q1` to the fourth power (computed as `q1_sq * q1_sq`).
pub q1_q: f64,
}
/// Clamps `probability` into
/// `[BERNOULLI_LINK_PROBABILITY_EPS, 1 - BERNOULLI_LINK_PROBABILITY_EPS]`.
#[inline]
pub(super) fn clamp_bernoulli_link_probability(probability: f64) -> f64 {
    let floor = BERNOULLI_LINK_PROBABILITY_EPS;
    let ceiling = 1.0 - BERNOULLI_LINK_PROBABILITY_EPS;
    f64::clamp(probability, floor, ceiling)
}
/// Inverts the probit link: maps a probability to its linear predictor.
///
/// The probability is clamped with `clamp_bernoulli_link_probability` before
/// the standard normal quantile is taken.
///
/// # Errors
/// Returns an error when `base_link` is not probit, or when the quantile
/// computation fails; `context` prefixes both messages.
pub(crate) fn bernoulli_marginal_slope_eta_from_probability(
    base_link: &InverseLink,
    probability: f64,
    context: &str,
) -> Result<f64, String> {
    require_probit_marginal_slope_link(base_link, context)?;
    let target = clamp_bernoulli_link_probability(probability);
    match standard_normal_quantile(target) {
        Ok(eta) => Ok(eta),
        Err(e) => Err(format!(
            "{context} failed to invert probit probability {target}: {e}"
        )),
    }
}
/// Evaluates the probit marginal link map and its derivatives at `eta`.
///
/// `mu` is the clamped `normal_cdf(eta)` and `q` is the standard normal
/// quantile of `mu`. Derivatives of `q` with respect to `eta` (up to fourth
/// order) follow from differentiating `Phi(q(eta)) = mu(eta)` repeatedly.
/// When the unclamped probability lies at or beyond either clamp bound, `mu`
/// is locally constant and every derivative is reported as zero.
///
/// # Errors
/// Returns an error when `base_link` is not probit, when the quantile
/// inversion fails, or when `normal_pdf(q)` is not a positive finite number.
pub(crate) fn bernoulli_marginal_link_map(
    base_link: &InverseLink,
    eta: f64,
) -> Result<BernoulliMarginalLinkMap, String> {
    require_probit_marginal_slope_link(base_link, "bernoulli marginal-slope")?;

    let raw_mu = normal_cdf(eta);
    let mu = clamp_bernoulli_link_probability(raw_mu);
    let q = match standard_normal_quantile(mu) {
        Ok(quantile) => quantile,
        Err(e) => {
            return Err(format!(
                "bernoulli marginal-slope probit target inversion failed at mu={mu}: {e}"
            ));
        }
    };

    // Saturated region: the clamp is active, so the map is flat in eta.
    let lower = BERNOULLI_LINK_PROBABILITY_EPS;
    let upper = 1.0 - BERNOULLI_LINK_PROBABILITY_EPS;
    let clamp_active = raw_mu <= lower || raw_mu >= upper;
    if clamp_active {
        let flat = 0.0;
        return Ok(BernoulliMarginalLinkMap {
            mu,
            mu1: flat,
            mu2: flat,
            mu3: flat,
            mu4: flat,
            q,
            q1: flat,
            q2: flat,
            q3: flat,
            q4: flat,
            q1_sq: flat,
            q1_cu: flat,
            q1_q: flat,
        });
    }

    let phi_eta = normal_pdf(eta);
    let phi_q = normal_pdf(q);
    if !(phi_q.is_finite() && phi_q > 0.0) {
        return Err(format!(
            "bernoulli marginal-slope internal probit density must be positive, got phi(q)={phi_q} at eta={eta}, q={q}"
        ));
    }

    // Derivatives of mu(eta) = Phi(eta).
    let eta_sq = eta * eta;
    let mu1 = phi_eta;
    let mu2 = -eta * phi_eta;
    let mu3 = (eta_sq - 1.0) * phi_eta;
    let mu4 = -(eta.powi(3) - 3.0 * eta) * phi_eta;

    // Derivatives of q(eta), each solved from the matching derivative of
    // Phi(q(eta)) = mu(eta).
    let q_sq = q * q;
    let q1 = mu1 / phi_q;
    let q1_sq = q1 * q1;
    let q1_cu = q1_sq * q1;
    let q1_q = q1_sq * q1_sq;
    let q2 = mu2 / phi_q + q * q1_sq;
    let q3 = mu3 / phi_q + 3.0 * q * q1 * q2 - (q_sq - 1.0) * q1_cu;
    let q4 = {
        // Terms are summed left to right in this exact order.
        let forcing = mu4 / phi_q;
        let quartic_in_q1 = (q.powi(3) - 3.0 * q) * q1_q;
        let q1_times_q3 = 4.0 * q * q1 * q3;
        let q2_squared = 3.0 * q * q2 * q2;
        let q1_sq_times_q2 = 6.0 * (q_sq - 1.0) * q1_sq * q2;
        forcing + quartic_in_q1 + q1_times_q3 + q2_squared - q1_sq_times_q2
    };

    Ok(BernoulliMarginalLinkMap {
        mu,
        mu1,
        mu2,
        mu3,
        mu4,
        q,
        q1,
        q2,
        q3,
        q4,
        q1_sq,
        q1_cu,
        q1_q,
    })
}
/// Accepts only the standard probit link.
///
/// # Errors
/// Any other `base_link` yields an error message prefixed with `context`.
pub(super) fn require_probit_marginal_slope_link(
    base_link: &InverseLink,
    context: &str,
) -> Result<(), String> {
    match base_link {
        InverseLink::Standard(StandardLink::Probit) => Ok(()),
        _ => Err(format!(
            "{context} requires link(type=probit); non-probit marginal-slope base links are not supported by the calibrated de-nested probit kernel"
        )),
    }
}
/// Builds the link-deviation block from separate knot and design seeds.
///
/// This is a direct delegation to
/// `build_deviation_block_from_knots_and_design_seed`. Despite the name, the
/// function takes no weights argument.
pub(crate) fn build_link_deviation_block_from_knots_design_seed_and_weights(
    knot_seed: &Array1<f64>,
    design_seed: &Array1<f64>,
    cfg: &DeviationBlockConfig,
) -> Result<DeviationPrepared, String> {
    let prepared = build_deviation_block_from_knots_and_design_seed(knot_seed, design_seed, cfg)?;
    Ok(prepared)
}
/// Assembles a cubic deviation block: runtime, dense design and penalties.
///
/// Knots are initialised from `knot_seed`; the design matrix is evaluated at
/// `design_seed`. One function penalty is appended per resolved derivative
/// order, followed by an order-0 penalty when `cfg.double_penalty` is set.
/// The block starts with a zero offset and zero initial coefficients.
///
/// # Errors
/// Fails when `cfg.degree` is not 3, when no positive penalty order can be
/// resolved, when the resulting basis has no columns, or when any of the
/// knot, runtime, design or penalty helpers fails.
pub(super) fn build_deviation_block_from_knots_and_design_seed(
    knot_seed: &Array1<f64>,
    design_seed: &Array1<f64>,
    cfg: &DeviationBlockConfig,
) -> Result<DeviationPrepared, String> {
    if cfg.degree != 3 {
        return Err(format!(
            "structural deviation runtime is cubic; degree must be 3, got {}",
            cfg.degree
        ));
    }

    let penalty_orders = resolve_deviation_operator_orders(cfg)?;
    let knots = initialize_monotone_wiggle_knots_from_seed(
        knot_seed.view(),
        cfg.degree,
        cfg.num_internal_knots,
    )?;
    let Some(max_penalty_order) = penalty_orders.iter().copied().max() else {
        return Err(
            "deviation block requires at least one positive function-penalty derivative order"
                .to_string(),
        );
    };

    let runtime = DeviationRuntime::try_new(knots, cfg.monotonicity_eps, max_penalty_order)?;
    let design = runtime.design(design_seed)?;
    let p = design.ncols();
    if p == 0 {
        return Err("structural deviation basis has no free derivative controls".to_string());
    }

    let mut block = ParameterBlockInput {
        design: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(design)),
        offset: Array1::zeros(design_seed.len()),
        penalties: Vec::new(),
        nullspace_dims: Vec::new(),
        initial_log_lambdas: None,
        initial_beta: Some(Array1::zeros(p)),
    };

    // Requested derivative orders first, then the optional order-0 penalty.
    let extra_order_zero = cfg.double_penalty.then_some(0);
    for order in penalty_orders.into_iter().chain(extra_order_zero) {
        append_deviation_function_penalty(&mut block, &runtime, order)?;
    }

    Ok(DeviationPrepared { block, runtime })
}
/// Resolves the derivative orders to penalise for a deviation block.
///
/// Uses `cfg.penalty_orders` when it is non-empty and the single
/// `cfg.penalty_order` otherwise. Zero entries are skipped and duplicates are
/// dropped, keeping first-seen order.
///
/// # Errors
/// Fails on the first order that exceeds `cfg.degree`, or when no positive
/// order remains.
pub(super) fn resolve_deviation_operator_orders(
    cfg: &DeviationBlockConfig,
) -> Result<Vec<usize>, String> {
    let requested = match cfg.penalty_orders.as_slice() {
        [] => std::slice::from_ref(&cfg.penalty_order),
        explicit => explicit,
    };

    let mut orders: Vec<usize> = Vec::new();
    for &order in requested.iter().filter(|&&candidate| candidate != 0) {
        if order > cfg.degree {
            return Err(format!(
                "deviation function penalty derivative order {order} exceeds basis degree {}",
                cfg.degree
            ));
        }
        let already_kept = orders.iter().any(|&kept| kept == order);
        if !already_kept {
            orders.push(order);
        }
    }

    if orders.is_empty() {
        Err(
            "deviation block requires at least one positive function-penalty derivative order"
                .to_string(),
        )
    } else {
        Ok(orders)
    }
}
pub(super) fn append_deviation_function_penalty(
block: &mut ParameterBlockInput,
runtime: &DeviationRuntime,
derivative_order: usize,
) -> Result<(), String> {
let (penalty, nullity) =
runtime.integrated_derivative_penalty_with_nullity(derivative_order)?;
block
.penalties
.push(crate::solver::estimate::PenaltySpec::Dense(penalty));
block.nullspace_dims.push(nullity);
Ok(())
}