Skip to main content

gam_identifiability/families/
bernoulli.rs

//! Bernoulli marginal-slope (BMS) concrete impls for the family-agnostic
//! identifiability compiler (`crate::families::compiler`).
3//!
4//! Bernoulli's row primary state is the scalar linear predictor `η_i`, so
5//! `K = 1` throughout. Every block's row Jacobian is the row of its dense
6//! design matrix; the row Hessian is the standard probit IRLS weight
7//!     `W_i = w_i · φ(η_i)² / (Φ(η_i) · Φ(−η_i))`
8//! evaluated at the pilot η.
9//!
10//! These concrete impls (`BernoulliRowHessian`, `BernoulliDenseDesignOperator`)
11//! feed the live BMS fit driver via `bms::install_flex`, whose
12//! `install_compiled_flex_block_into_runtime` is the entry point that
13//! residualises each flex block against the compiled parametric anchors.
14
15use ndarray::{Array1, Array2, Array3};
16
17use crate::families::compiler::{RowHessian, RowJacobianOperator, scale_jacobian_by_sqrt_h_with};
18use gam_problem::FamilyChannelHessian;
19
/// Density of the standard normal distribution, `φ(x) = exp(−x²/2) / √(2π)`.
#[inline]
fn phi(x: f64) -> f64 {
    let sqrt_two_pi = std::f64::consts::TAU.sqrt();
    let kernel = (-0.5 * x * x).exp();
    kernel / sqrt_two_pi
}
25
26/// Standard normal cdf. Wrapper for the codebase's `normal_cdf`.
27#[inline]
28fn cdf(x: f64) -> f64 {
29    gam_math::probability::normal_cdf(x)
30}
31
32/// Probit IRLS row weight `w_i · φ(η)² / (Φ(η) · Φ(−η))`. Clamped strictly
33/// positive: the residualised Gram must remain PSD even at extreme η pilots
34/// (probit saturation collapses both `Φ(η)` and `Φ(−η)` toward zero).
35fn probit_irls_weight(eta: f64, sample_weight: f64) -> f64 {
36    let p = cdf(eta).clamp(f64::MIN_POSITIVE, 1.0 - f64::MIN_POSITIVE);
37    let one_m = (1.0 - p).max(f64::MIN_POSITIVE);
38    let phi_eta = phi(eta);
39    let denom = (p * one_m).max(f64::MIN_POSITIVE);
40    sample_weight * phi_eta * phi_eta / denom
41}
42
43/// Row Hessian for Bernoulli's K=1 row primary state. The "Hessian" is the
44/// scalar IRLS weight per row at the pilot η.
45pub struct BernoulliRowHessian {
46    w: Array1<f64>,
47}
48
49impl BernoulliRowHessian {
50    pub fn from_eta_pilot(eta_pilot: &Array1<f64>, sample_weights: &Array1<f64>) -> Self {
51        assert_eq!(
52            eta_pilot.len(),
53            sample_weights.len(),
54            "BernoulliRowHessian: eta_pilot length {} must match sample_weights length {}",
55            eta_pilot.len(),
56            sample_weights.len(),
57        );
58        let w = Array1::from_iter(
59            eta_pilot
60                .iter()
61                .zip(sample_weights.iter())
62                .map(|(&eta, &w)| probit_irls_weight(eta, w)),
63        );
64        Self { w }
65    }
66
67    /// Construct directly from a pre-computed row-weight vector (e.g. the
68    /// existing `pilot_irls_hessian_row_metric_at_eta` output).
69    pub fn from_row_weights(w: Array1<f64>) -> Self {
70        Self { w }
71    }
72
73    /// Borrow the underlying per-row weight vector.
74    pub fn row_weights(&self) -> &Array1<f64> {
75        &self.w
76    }
77}
78
79impl RowHessian for BernoulliRowHessian {
80    fn k(&self) -> usize {
81        1
82    }
83    fn nrows(&self) -> usize {
84        self.w.len()
85    }
86    fn fill_row(&self, row: usize, out: &mut [f64]) {
87        assert_eq!(out.len(), 1, "BernoulliRowHessian::fill_row expects K=1");
88        out[0] = self.w[row];
89    }
90    fn evaluate_full(&self) -> Array3<f64> {
91        let n = self.w.len();
92        let mut out = Array3::<f64>::zeros((n, 1, 1));
93        for i in 0..n {
94            out[[i, 0, 0]] = self.w[i];
95        }
96        out
97    }
98}
99
100/// `FamilyChannelHessian` for Bernoulli marginal-slope.
101///
102/// BMS has a single output channel (K=1). The per-subject channel Hessian
103/// W_i is the scalar probit IRLS weight:
104///
105/// ```text
106/// W_i = w_i · φ(η_i)² / (Φ(η_i) · (1 − Φ(η_i)))
107/// ```
108///
109/// This is exactly the 1×1 scalar stored in `BernoulliRowHessian::w`.
110/// Since K=1, the scalar fast path is used and cross-channel curvature
111/// is vacuous. Families that genuinely have a single output channel
112/// (Gaussian, Binomial, Poisson, etc.) all use this 1×1 identity path.
113impl FamilyChannelHessian for BernoulliRowHessian {
114    fn n_outputs(&self) -> usize {
115        1
116    }
117
118    fn n_subjects(&self) -> usize {
119        self.w.len()
120    }
121
122    fn fill_subject(&self, i: usize, out: &mut [f64]) {
123        assert_eq!(
124            out.len(),
125            1,
126            "BernoulliRowHessian::fill_subject expects K=1"
127        );
128        out[0] = self.w[i];
129    }
130
131    fn evaluate_full(&self) -> ndarray::Array3<f64> {
132        let n = self.w.len();
133        let mut out = ndarray::Array3::<f64>::zeros((n, 1, 1));
134        for i in 0..n {
135            out[[i, 0, 0]] = self.w[i];
136        }
137        out
138    }
139}
140
141/// Row Jacobian operator backed by a dense design matrix. K=1 — the only
142/// channel is `δη = design.row(i) · δβ`. Covers BMS's marginal, logslope,
143/// score-warp, and link-deviation blocks uniformly.
144pub struct BernoulliDenseDesignOperator {
145    design: Array2<f64>,
146}
147
148impl BernoulliDenseDesignOperator {
149    pub fn new(design: Array2<f64>) -> Self {
150        Self { design }
151    }
152}
153
154impl RowJacobianOperator for BernoulliDenseDesignOperator {
155    fn k(&self) -> usize {
156        1
157    }
158    fn ncols(&self) -> usize {
159        self.design.ncols()
160    }
161    fn nrows(&self) -> usize {
162        self.design.nrows()
163    }
164    fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
165        assert_eq!(out.len(), 1);
166        assert_eq!(delta_beta.len(), self.design.ncols());
167        let mut acc = 0.0;
168        for (j, &b) in delta_beta.iter().enumerate() {
169            acc += self.design[[row, j]] * b;
170        }
171        out[0] = acc;
172    }
173    fn evaluate_full(&self) -> Array3<f64> {
174        let n = self.design.nrows();
175        let p = self.design.ncols();
176        let mut out = Array3::<f64>::zeros((n, p, 1));
177        for i in 0..n {
178            for j in 0..p {
179                out[[i, j, 0]] = self.design[[i, j]];
180            }
181        }
182        out
183    }
184    fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
185        // K=1: the only channel is `δη = design.row(i)·δβ`. Scale straight from
186        // the stored `(n, p)` design rather than reshaping it into a `(n, p, 1)`
187        // tensor first. (#738: a capability is not a representation.)
188        let n = self.design.nrows();
189        let p = self.design.ncols();
190        scale_jacobian_by_sqrt_h_with(n, p, 1, h_full, |i, a, c| {
191            assert_eq!(c, 0, "K=1 operator has only channel 0");
192            self.design[[i, a]]
193        })
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Building from a pilot with unit sample weights must reproduce
    /// `probit_irls_weight(eta, 1.0)` row for row.
    #[test]
    fn bernoulli_row_hessian_matches_probit_irls_weight() {
        let eta = Array1::from(vec![-1.5_f64, 0.0, 0.75, 2.0]);
        let unit = Array1::from(vec![1.0_f64; 4]);
        let hess = BernoulliRowHessian::from_eta_pilot(&eta, &unit);
        let weights = hess.row_weights();
        for (i, &e) in eta.iter().enumerate() {
            let want = probit_irls_weight(e, 1.0);
            let got = weights[i];
            assert!(
                (got - want).abs() < 1e-14,
                "row {i}: got {got}, want {want}"
            );
        }
    }

    /// `evaluate_full` must lay the `(n, p)` design out as an `(n, p, 1)`
    /// tensor, entry for entry.
    #[test]
    fn dense_design_operator_evaluate_full_shape() {
        let design = Array2::from_shape_fn((5, 3), |(i, j)| (i as f64) * 0.1 + (j as f64));
        let op = BernoulliDenseDesignOperator::new(design.clone());
        let full = op.evaluate_full();
        assert_eq!(full.shape(), &[5, 3, 1]);
        for ((i, j), &expected) in design.indexed_iter() {
            assert_eq!(full[[i, j, 0]], expected);
        }
    }
}