use std::sync::Arc;
use ndarray::{Array1, Array2, Array3};
use crate::families::custom_family::FamilyChannelHessian;
use crate::families::identifiability_compiler::{
AnchorRowEvaluator, BlockOrder, RowHessian, RowJacobianOperator,
};
use crate::linalg::faer_ndarray::fast_ab;
/// Standard normal probability density evaluated at `x`.
///
/// Computes `exp(-x² / 2) / sqrt(2π)`, taking `2π` from
/// `std::f64::consts::TAU`.
#[inline]
fn phi(x: f64) -> f64 {
    let exponent = -0.5 * x * x;
    let normaliser = std::f64::consts::TAU.sqrt();
    exponent.exp() / normaliser
}
/// Local shorthand that forwards `x` to
/// `crate::inference::probability::normal_cdf` and returns its result
/// (presumably the standard normal CDF — confirm against that module).
#[inline]
fn cdf(x: f64) -> f64 {
    use crate::inference::probability::normal_cdf;
    normal_cdf(x)
}
fn probit_irls_weight(eta: f64, sample_weight: f64) -> f64 {
let p = cdf(eta).clamp(f64::MIN_POSITIVE, 1.0 - f64::MIN_POSITIVE);
let one_m = (1.0 - p).max(f64::MIN_POSITIVE);
let phi_eta = phi(eta);
let denom = (p * one_m).max(f64::MIN_POSITIVE);
sample_weight * phi_eta * phi_eta / denom
}
/// Holds one scalar weight per row. The trait implementations in this file
/// expose it with a single channel (`K = 1`).
pub struct BernoulliRowHessian {
    /// Row weights, one entry per row.
    w: Array1<f64>,
}

impl BernoulliRowHessian {
    /// Builds the row weights by evaluating `probit_irls_weight` on every
    /// `(eta_pilot[i], sample_weights[i])` pair.
    ///
    /// # Panics
    ///
    /// Panics when `eta_pilot` and `sample_weights` differ in length.
    pub fn from_eta_pilot(eta_pilot: &Array1<f64>, sample_weights: &Array1<f64>) -> Self {
        let (n_eta, n_weights) = (eta_pilot.len(), sample_weights.len());
        assert_eq!(
            n_eta, n_weights,
            "BernoulliRowHessian: eta_pilot length {} must match sample_weights length {}",
            n_eta, n_weights,
        );
        let per_row: Array1<f64> = eta_pilot
            .iter()
            .zip(sample_weights.iter())
            .map(|(&eta, &sample_weight)| probit_irls_weight(eta, sample_weight))
            .collect();
        BernoulliRowHessian { w: per_row }
    }

    /// Wraps already-computed row weights without modifying them.
    pub fn from_row_weights(w: Array1<f64>) -> Self {
        BernoulliRowHessian { w }
    }

    /// Borrows the stored row weights.
    pub fn row_weights(&self) -> &Array1<f64> {
        &self.w
    }
}
impl RowHessian for BernoulliRowHessian {
    /// Always one: each row carries a single scalar.
    fn k(&self) -> usize {
        1
    }

    /// Number of stored row weights.
    fn nrows(&self) -> usize {
        self.w.len()
    }

    /// Writes the weight of `row` into the one-element buffer `out`.
    ///
    /// Panics if `out` does not have length 1 or `row` is out of range.
    fn fill_row(&self, row: usize, out: &mut [f64]) {
        assert_eq!(out.len(), 1, "BernoulliRowHessian::fill_row expects K=1");
        let weight = self.w[row];
        out[0] = weight;
    }

    /// Materialises every row weight as an `(nrows, 1, 1)` array.
    fn evaluate_full(&self) -> Array3<f64> {
        Array3::from_shape_fn((self.w.len(), 1, 1), |(row, _, _)| self.w[row])
    }
}
impl FamilyChannelHessian for BernoulliRowHessian {
    /// Always one: a single scalar per subject.
    fn n_outputs(&self) -> usize {
        1
    }

    /// Number of stored weights.
    fn n_subjects(&self) -> usize {
        self.w.len()
    }

    /// Writes the weight of subject `i` into the one-element buffer `out`.
    ///
    /// Panics if `out` does not have length 1 or `i` is out of range.
    fn fill_subject(&self, i: usize, out: &mut [f64]) {
        assert_eq!(
            out.len(),
            1,
            "BernoulliRowHessian::fill_subject expects K=1"
        );
        let weight = self.w[i];
        out[0] = weight;
    }

    /// Materialises every weight as an `(n_subjects, 1, 1)` array.
    fn evaluate_full(&self) -> ndarray::Array3<f64> {
        let mut full = ndarray::Array3::<f64>::zeros((self.w.len(), 1, 1));
        // The trailing axes have length one, so logical iteration order of
        // `full` lines up element-for-element with `self.w`.
        for (slot, &weight) in full.iter_mut().zip(self.w.iter()) {
            *slot = weight;
        }
        full
    }
}
/// Owns a dense two-dimensional design matrix. Its `RowJacobianOperator`
/// implementation in this file serves the matrix row by row with `K = 1`.
pub struct BernoulliDenseDesignOperator {
    /// The wrapped matrix, stored as supplied.
    design: Array2<f64>,
}

impl BernoulliDenseDesignOperator {
    /// Takes ownership of `design` without copying or validating it.
    pub fn new(design: Array2<f64>) -> Self {
        BernoulliDenseDesignOperator { design }
    }
}
impl RowJacobianOperator for BernoulliDenseDesignOperator {
    /// Always one: each row produces a single scalar.
    fn k(&self) -> usize {
        1
    }

    /// Column count of the wrapped design matrix.
    fn ncols(&self) -> usize {
        self.design.ncols()
    }

    /// Row count of the wrapped design matrix.
    fn nrows(&self) -> usize {
        self.design.nrows()
    }

    /// Stores the dot product of design row `row` with `delta_beta` in
    /// `out[0]`.
    ///
    /// Panics if `out` is not of length 1, if `delta_beta` does not have one
    /// entry per design column, or if an indexed element is out of range.
    fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
        assert_eq!(out.len(), 1);
        assert_eq!(delta_beta.len(), self.design.ncols());
        // Left-to-right accumulation from 0.0, one column at a time.
        let dot = delta_beta
            .iter()
            .enumerate()
            .fold(0.0_f64, |acc, (j, &b)| acc + self.design[[row, j]] * b);
        out[0] = dot;
    }

    /// Copies the design matrix into an `(nrows, ncols, 1)` array.
    fn evaluate_full(&self) -> Array3<f64> {
        let (n_rows, n_cols) = self.design.dim();
        Array3::from_shape_fn((n_rows, n_cols, 1), |(i, j, _)| self.design[[i, j]])
    }
}
/// Anchor evaluator backed by a stored design matrix. Its
/// `AnchorRowEvaluator` implementation in this file hands back a copy of
/// that matrix.
pub struct ParametricAnchorEvaluator {
    /// The stored matrix, kept exactly as supplied.
    design: Array2<f64>,
}

impl ParametricAnchorEvaluator {
    /// Takes ownership of `design` without copying or validating it.
    pub fn new(design: Array2<f64>) -> Self {
        ParametricAnchorEvaluator { design }
    }
}
impl AnchorRowEvaluator for ParametricAnchorEvaluator {
    /// Returns a clone of the stored design matrix.
    ///
    /// Only the length of `predict_arg` is inspected; its values are not
    /// read.
    ///
    /// # Errors
    ///
    /// Returns a message when `predict_arg.len()` differs from the number of
    /// rows in the stored design.
    fn anchor_rows(&self, predict_arg: &Array1<f64>) -> Result<Array2<f64>, String> {
        let supplied = predict_arg.len();
        let expected = self.design.nrows();
        if supplied == expected {
            Ok(self.design.clone())
        } else {
            Err(format!(
                "ParametricAnchorEvaluator: predict_arg length {} must match \
                 materialised design rows {}",
                supplied, expected
            ))
        }
    }

    /// Column count of the stored design matrix.
    fn ncols(&self) -> usize {
        self.design.ncols()
    }
}
/// Anchor evaluator assembled from a basis callback, a right-hand matrix
/// and an optional correction taken from a parent evaluator.
pub struct CompiledFlexAnchorEvaluator {
    /// Callback producing the raw basis rows for a prediction argument.
    raw_basis: Arc<dyn Fn(&Array1<f64>) -> Result<Array2<f64>, String> + Send + Sync>,
    /// Matrix passed as the right-hand operand to `fast_ab` together with
    /// the raw basis rows.
    t_lw: Array2<f64>,
    /// Right-hand operand combined with the parent's anchor rows; only used
    /// when `parent` is also present.
    anchor_correction: Option<Array2<f64>>,
    /// Evaluator supplying the anchor rows for the correction; only used
    /// when `anchor_correction` is also present.
    parent: Option<Arc<dyn AnchorRowEvaluator>>,
}

impl CompiledFlexAnchorEvaluator {
    /// Stores the four components as given; nothing is validated here.
    pub fn new(
        raw_basis: Arc<dyn Fn(&Array1<f64>) -> Result<Array2<f64>, String> + Send + Sync>,
        t_lw: Array2<f64>,
        anchor_correction: Option<Array2<f64>>,
        parent: Option<Arc<dyn AnchorRowEvaluator>>,
    ) -> Self {
        CompiledFlexAnchorEvaluator {
            raw_basis,
            t_lw,
            anchor_correction,
            parent,
        }
    }
}
impl AnchorRowEvaluator for CompiledFlexAnchorEvaluator {
    /// Evaluates the raw basis at `predict_arg` and combines it with `t_lw`
    /// through `fast_ab` (presumably the matrix product `raw · t_lw`).
    ///
    /// When both `anchor_correction` and `parent` are set, the parent's
    /// anchor rows are combined with the correction matrix the same way and
    /// the result is subtracted. If either one is missing, the uncorrected
    /// rows are returned.
    ///
    /// # Errors
    ///
    /// Propagates any error from the basis callback or from the parent
    /// evaluator.
    fn anchor_rows(&self, predict_arg: &Array1<f64>) -> Result<Array2<f64>, String> {
        let rotated = {
            let raw = (self.raw_basis)(predict_arg)?;
            fast_ab(&raw, &self.t_lw)
        };
        if let (Some(m), Some(parent)) = (&self.anchor_correction, &self.parent) {
            let anchor = parent.anchor_rows(predict_arg)?;
            let correction = fast_ab(&anchor, m);
            return Ok(&rotated - &correction);
        }
        Ok(rotated)
    }

    /// Column count of `t_lw`.
    fn ncols(&self) -> usize {
        self.t_lw.ncols()
    }
}
/// Wraps each supplied design matrix in a `BernoulliDenseDesignOperator`
/// and pairs it with its `BlockOrder` tag.
///
/// The two returned vectors are index-aligned. Entries appear in the fixed
/// order marginal, logslope, score-warp, link-dev; the last two are present
/// only when the corresponding argument is `Some`.
pub fn build_bernoulli_compiler_inputs(
    marginal_design: Array2<f64>,
    logslope_design: Array2<f64>,
    score_warp_design: Option<Array2<f64>>,
    link_dev_design: Option<Array2<f64>>,
) -> (Vec<Arc<dyn RowJacobianOperator>>, Vec<BlockOrder>) {
    let mut ops: Vec<Arc<dyn RowJacobianOperator>> = Vec::with_capacity(4);
    let mut order: Vec<BlockOrder> = Vec::with_capacity(4);

    // Appends one operator together with its tag so both vectors stay aligned.
    let mut register = |design: Array2<f64>, tag: BlockOrder| {
        ops.push(Arc::new(BernoulliDenseDesignOperator::new(design)));
        order.push(tag);
    };

    register(marginal_design, BlockOrder::Marginal);
    register(logslope_design, BlockOrder::Logslope);
    if let Some(design) = score_warp_design {
        register(design, BlockOrder::ScoreWarp);
    }
    if let Some(design) = link_dev_design {
        register(design, BlockOrder::LinkDev);
    }

    (ops, order)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With unit sample weights, `from_eta_pilot` must reproduce
    /// `probit_irls_weight` row by row.
    #[test]
    fn bernoulli_row_hessian_matches_probit_irls_weight() {
        let eta = Array1::from(vec![-1.5_f64, 0.0, 0.75, 2.0]);
        let unit_weights = Array1::from_elem(4, 1.0_f64);
        let hess = BernoulliRowHessian::from_eta_pilot(&eta, &unit_weights);
        for (i, &eta_i) in eta.iter().enumerate() {
            let want = probit_irls_weight(eta_i, 1.0);
            let got = hess.row_weights()[i];
            assert!(
                (got - want).abs() < 1e-14,
                "row {i}: got {got}, want {want}"
            );
        }
    }

    /// `evaluate_full` must return an `(n, p, 1)` array holding the design
    /// entries unchanged.
    #[test]
    fn dense_design_operator_evaluate_full_shape() {
        let design = Array2::from_shape_fn((5, 3), |(i, j)| (i as f64) * 0.1 + (j as f64));
        let full = BernoulliDenseDesignOperator::new(design.clone()).evaluate_full();
        assert_eq!(full.shape(), &[5, 3, 1]);
        for ((i, j), &want) in design.indexed_iter() {
            assert_eq!(full[[i, j, 0]], want);
        }
    }

    /// A length-matched prediction argument must yield the stored design
    /// matrix unchanged.
    #[test]
    fn parametric_anchor_evaluator_returns_design_verbatim() {
        let design = Array2::from_shape_fn((4, 2), |(i, j)| (i + j) as f64);
        let evaluator = ParametricAnchorEvaluator::new(design.clone());
        let predict_arg = Array1::<f64>::zeros(4);
        let rows = evaluator.anchor_rows(&predict_arg).expect("anchor_rows ok");
        assert_eq!(rows, design);
    }
}