// gam_identifiability/families/bernoulli.rs
use ndarray::{Array1, Array2, Array3};
16
17use crate::families::compiler::{RowHessian, RowJacobianOperator, scale_jacobian_by_sqrt_h_with};
18use gam_problem::FamilyChannelHessian;
19
/// Standard normal probability density: `exp(-x² / 2) / sqrt(2π)`.
#[inline]
fn phi(x: f64) -> f64 {
    use std::f64::consts::TAU;

    // TAU = 2π, so its square root is the Gaussian normalising constant.
    let exponent = -0.5 * x * x;
    let normalizer = TAU.sqrt();
    exponent.exp() / normalizer
}
25
26#[inline]
28fn cdf(x: f64) -> f64 {
29 gam_math::probability::normal_cdf(x)
30}
31
32fn probit_irls_weight(eta: f64, sample_weight: f64) -> f64 {
36 let p = cdf(eta).clamp(f64::MIN_POSITIVE, 1.0 - f64::MIN_POSITIVE);
37 let one_m = (1.0 - p).max(f64::MIN_POSITIVE);
38 let phi_eta = phi(eta);
39 let denom = (p * one_m).max(f64::MIN_POSITIVE);
40 sample_weight * phi_eta * phi_eta / denom
41}
42
43pub struct BernoulliRowHessian {
46 w: Array1<f64>,
47}
48
49impl BernoulliRowHessian {
50 pub fn from_eta_pilot(eta_pilot: &Array1<f64>, sample_weights: &Array1<f64>) -> Self {
51 assert_eq!(
52 eta_pilot.len(),
53 sample_weights.len(),
54 "BernoulliRowHessian: eta_pilot length {} must match sample_weights length {}",
55 eta_pilot.len(),
56 sample_weights.len(),
57 );
58 let w = Array1::from_iter(
59 eta_pilot
60 .iter()
61 .zip(sample_weights.iter())
62 .map(|(&eta, &w)| probit_irls_weight(eta, w)),
63 );
64 Self { w }
65 }
66
67 pub fn from_row_weights(w: Array1<f64>) -> Self {
70 Self { w }
71 }
72
73 pub fn row_weights(&self) -> &Array1<f64> {
75 &self.w
76 }
77}
78
79impl RowHessian for BernoulliRowHessian {
80 fn k(&self) -> usize {
81 1
82 }
83 fn nrows(&self) -> usize {
84 self.w.len()
85 }
86 fn fill_row(&self, row: usize, out: &mut [f64]) {
87 assert_eq!(out.len(), 1, "BernoulliRowHessian::fill_row expects K=1");
88 out[0] = self.w[row];
89 }
90 fn evaluate_full(&self) -> Array3<f64> {
91 let n = self.w.len();
92 let mut out = Array3::<f64>::zeros((n, 1, 1));
93 for i in 0..n {
94 out[[i, 0, 0]] = self.w[i];
95 }
96 out
97 }
98}
99
100impl FamilyChannelHessian for BernoulliRowHessian {
114 fn n_outputs(&self) -> usize {
115 1
116 }
117
118 fn n_subjects(&self) -> usize {
119 self.w.len()
120 }
121
122 fn fill_subject(&self, i: usize, out: &mut [f64]) {
123 assert_eq!(
124 out.len(),
125 1,
126 "BernoulliRowHessian::fill_subject expects K=1"
127 );
128 out[0] = self.w[i];
129 }
130
131 fn evaluate_full(&self) -> ndarray::Array3<f64> {
132 let n = self.w.len();
133 let mut out = ndarray::Array3::<f64>::zeros((n, 1, 1));
134 for i in 0..n {
135 out[[i, 0, 0]] = self.w[i];
136 }
137 out
138 }
139}
140
141pub struct BernoulliDenseDesignOperator {
145 design: Array2<f64>,
146}
147
148impl BernoulliDenseDesignOperator {
149 pub fn new(design: Array2<f64>) -> Self {
150 Self { design }
151 }
152}
153
154impl RowJacobianOperator for BernoulliDenseDesignOperator {
155 fn k(&self) -> usize {
156 1
157 }
158 fn ncols(&self) -> usize {
159 self.design.ncols()
160 }
161 fn nrows(&self) -> usize {
162 self.design.nrows()
163 }
164 fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
165 assert_eq!(out.len(), 1);
166 assert_eq!(delta_beta.len(), self.design.ncols());
167 let mut acc = 0.0;
168 for (j, &b) in delta_beta.iter().enumerate() {
169 acc += self.design[[row, j]] * b;
170 }
171 out[0] = acc;
172 }
173 fn evaluate_full(&self) -> Array3<f64> {
174 let n = self.design.nrows();
175 let p = self.design.ncols();
176 let mut out = Array3::<f64>::zeros((n, p, 1));
177 for i in 0..n {
178 for j in 0..p {
179 out[[i, j, 0]] = self.design[[i, j]];
180 }
181 }
182 out
183 }
184 fn scaled_design_by_sqrt_h(&self, h_full: &Array3<f64>) -> Array2<f64> {
185 let n = self.design.nrows();
189 let p = self.design.ncols();
190 scale_jacobian_by_sqrt_h_with(n, p, 1, h_full, |i, a, c| {
191 assert_eq!(c, 0, "K=1 operator has only channel 0");
192 self.design[[i, a]]
193 })
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// `from_eta_pilot` with unit sample weights must reproduce
    /// `probit_irls_weight` row by row.
    #[test]
    fn bernoulli_row_hessian_matches_probit_irls_weight() {
        let eta = Array1::from(vec![-1.5_f64, 0.0, 0.75, 2.0]);
        let unit_weights = Array1::from(vec![1.0_f64; 4]);
        let hess = BernoulliRowHessian::from_eta_pilot(&eta, &unit_weights);
        for (i, &eta_i) in eta.iter().enumerate() {
            let want = probit_irls_weight(eta_i, 1.0);
            let got = hess.row_weights()[i];
            assert!(
                (got - want).abs() < 1e-14,
                "row {i}: got {got}, want {want}"
            );
        }
    }

    /// `evaluate_full` must return an `(n, p, 1)` array that mirrors the
    /// design matrix entry by entry.
    #[test]
    fn dense_design_operator_evaluate_full_shape() {
        let design = Array2::from_shape_fn((5, 3), |(i, j)| (i as f64) * 0.1 + (j as f64));
        let op = BernoulliDenseDesignOperator::new(design.clone());
        let full = op.evaluate_full();
        assert_eq!(full.shape(), &[5, 3, 1]);
        for ((i, j), &expected) in design.indexed_iter() {
            assert_eq!(full[[i, j, 0]], expected);
        }
    }
}