Skip to main content

gam_models/gamlss/binomial/
wiggle_workspace.rs

// A real, concern-organized submodule of the gamlss family stack.
// Cross-module items are re-exported flat through the parent (`gamlss.rs`),
// so `use super::*;` lets every sibling-concern symbol this module references
// resolve through the parent namespace.
5use super::*;
6
7/// Matrix-free joint-Hessian operator for the 3-block binomial
8/// location-scale wiggle family. See `BinomialLocationScaleWiggleHessianRowPieces`
9/// for the per-row weight structure.
10pub(crate) struct BinomialLocationScaleWiggleHessianWorkspace {
11    pub(crate) family: BinomialLocationScaleWiggleFamily,
12    pub(crate) block_states: Vec<ParameterBlockState>,
13    pub(crate) x_t: Arc<Array2<f64>>,
14    pub(crate) x_ls: Arc<Array2<f64>>,
15    pub(crate) pieces: BinomialLocationScaleWiggleHessianRowPieces,
16}
17
18impl BinomialLocationScaleWiggleHessianWorkspace {
19    pub(crate) fn new(
20        family: BinomialLocationScaleWiggleFamily,
21        block_states: Vec<ParameterBlockState>,
22        x_t: Array2<f64>,
23        x_ls: Array2<f64>,
24    ) -> Result<Self, String> {
25        let pieces = family.wiggle_hessian_row_pieces(&block_states)?;
26        Ok(Self {
27            family,
28            block_states,
29            x_t: Arc::new(x_t),
30            x_ls: Arc::new(x_ls),
31            pieces,
32        })
33    }
34
35    /// Apply a Horvitz–Thompson outer-row subsample mask to the precomputed
36    /// per-row coefficient arrays in place.
37    ///
38    /// Each sampled row's `coeff_*[i]` is multiplied by its
39    /// `WeightedOuterRow.weight` (the HT inverse-inclusion factor 1/π_i —
40    /// uniform or stratified sampling both supported). All non-sampled rows
41    /// are zeroed. Because every downstream assembly (`hessian_dense`,
42    /// `hessian_matvec`, `hessian_diagonal`) is row-linear in these arrays
43    /// via `Xᵀ diag(W) Y`, the resulting joint-Hessian is an unbiased
44    /// estimator of the full-data joint Hessian. The `b0`/`d0` basis matrices
45    /// are independent of the per-row weights and remain unchanged.
46    pub(crate) fn apply_outer_subsample(
47        &mut self,
48        rows: &[crate::outer_subsample::WeightedOuterRow],
49    ) {
50        let n = self.pieces.coeff_tt.len();
51        let mut mask_tt = Array1::<f64>::zeros(n);
52        let mut mask_tl = Array1::<f64>::zeros(n);
53        let mut mask_ll = Array1::<f64>::zeros(n);
54        let mut mask_tw_b = Array1::<f64>::zeros(n);
55        let mut mask_tw_d = Array1::<f64>::zeros(n);
56        let mut mask_lw_b = Array1::<f64>::zeros(n);
57        let mut mask_lw_d = Array1::<f64>::zeros(n);
58        let mut maskww = Array1::<f64>::zeros(n);
59        for r in rows {
60            let i = r.index;
61            let w = r.weight;
62            mask_tt[i] = self.pieces.coeff_tt[i] * w;
63            mask_tl[i] = self.pieces.coeff_tl[i] * w;
64            mask_ll[i] = self.pieces.coeff_ll[i] * w;
65            mask_tw_b[i] = self.pieces.coeff_tw_b[i] * w;
66            mask_tw_d[i] = self.pieces.coeff_tw_d[i] * w;
67            mask_lw_b[i] = self.pieces.coeff_lw_b[i] * w;
68            mask_lw_d[i] = self.pieces.coeff_lw_d[i] * w;
69            maskww[i] = self.pieces.coeffww[i] * w;
70        }
71        self.pieces.coeff_tt = mask_tt;
72        self.pieces.coeff_tl = mask_tl;
73        self.pieces.coeff_ll = mask_ll;
74        self.pieces.coeff_tw_b = mask_tw_b;
75        self.pieces.coeff_tw_d = mask_tw_d;
76        self.pieces.coeff_lw_b = mask_lw_b;
77        self.pieces.coeff_lw_d = mask_lw_d;
78        self.pieces.coeffww = maskww;
79    }
80}
81
82impl ExactNewtonJointHessianWorkspace for BinomialLocationScaleWiggleHessianWorkspace {
83    fn hessian_dense(&self) -> Result<Option<Array2<f64>>, String> {
84        // Same Hv structure as `hessian_matvec`, but routed through the
85        // already-existing `assemble_dense` row-pieces helper (eight GEMMs
86        // covering h_tt, h_tl, h_ll, h_tw_b, h_tw_d, h_lw_b, h_lw_d, h_ww).
87        // Avoids `total` canonical-basis HVPs in
88        // `MatrixFreeSpdOperator::materialize_dense_operator`, which at
89        // large scale (n≈320k, p_total≈82) costs ~568s per κ-iter versus
90        // ~1s for the dense build.
91        let dense = self
92            .pieces
93            .assemble_dense(self.x_t.as_ref(), self.x_ls.as_ref())?;
94        Ok(Some(dense))
95    }
96
97    fn hessian_matvec_available(&self) -> bool {
98        true
99    }
100
101    fn hessian_matvec(&self, v: &Array1<f64>) -> Result<Option<Array1<f64>>, String> {
102        let pt = self.x_t.ncols();
103        let pls = self.x_ls.ncols();
104        let pw = self.pieces.b0.ncols();
105        let total = pt + pls + pw;
106        if v.len() != total {
107            return Err(GamlssError::DimensionMismatch {
108                reason: format!(
109                    "BinomialLocationScaleWiggle matvec dimension mismatch: got {}, expected {}",
110                    v.len(),
111                    total
112                ),
113            }
114            .into());
115        }
116        let v_t = v.slice(s![0..pt]);
117        let v_ls = v.slice(s![pt..pt + pls]);
118        let v_w = v.slice(s![pt + pls..total]);
119
120        let u_t = self.x_t.dot(&v_t);
121        let u_ls = self.x_ls.dot(&v_ls);
122        let u_b = self.pieces.b0.dot(&v_w);
123        let u_d = self.pieces.d0.dot(&v_w);
124
125        let r_t = &self.pieces.coeff_tt * &u_t
126            + &self.pieces.coeff_tl * &u_ls
127            + &self.pieces.coeff_tw_b * &u_b
128            + &self.pieces.coeff_tw_d * &u_d;
129        let r_ls = &self.pieces.coeff_tl * &u_t
130            + &self.pieces.coeff_ll * &u_ls
131            + &self.pieces.coeff_lw_b * &u_b
132            + &self.pieces.coeff_lw_d * &u_d;
133        let r_b = &self.pieces.coeff_tw_b * &u_t
134            + &self.pieces.coeff_lw_b * &u_ls
135            + &self.pieces.coeffww * &u_b;
136        let r_d = &self.pieces.coeff_tw_d * &u_t + &self.pieces.coeff_lw_d * &u_ls;
137
138        let out_t = fast_atv(self.x_t.as_ref(), &r_t);
139        let out_ls = fast_atv(self.x_ls.as_ref(), &r_ls);
140        let out_w = fast_atv(&self.pieces.b0, &r_b) + &fast_atv(&self.pieces.d0, &r_d);
141
142        let mut out = Array1::<f64>::zeros(total);
143        out.slice_mut(s![0..pt]).assign(&out_t);
144        out.slice_mut(s![pt..pt + pls]).assign(&out_ls);
145        out.slice_mut(s![pt + pls..total]).assign(&out_w);
146        Ok(Some(out))
147    }
148
149    fn hessian_diagonal(&self) -> Result<Option<Array1<f64>>, String> {
150        let pt = self.x_t.ncols();
151        let pls = self.x_ls.ncols();
152        let pw = self.pieces.b0.ncols();
153        let total = pt + pls + pw;
154        let mut diag = Array1::<f64>::zeros(total);
155        let n = self.pieces.coeff_tt.len();
156        for j in 0..pt {
157            let col = self.x_t.column(j);
158            let mut acc = 0.0;
159            for i in 0..n {
160                let v = col[i];
161                acc += self.pieces.coeff_tt[i] * v * v;
162            }
163            diag[j] = acc;
164        }
165        for j in 0..pls {
166            let col = self.x_ls.column(j);
167            let mut acc = 0.0;
168            for i in 0..n {
169                let v = col[i];
170                acc += self.pieces.coeff_ll[i] * v * v;
171            }
172            diag[pt + j] = acc;
173        }
174        for j in 0..pw {
175            let col = self.pieces.b0.column(j);
176            let mut acc = 0.0;
177            for i in 0..n {
178                let v = col[i];
179                acc += self.pieces.coeffww[i] * v * v;
180            }
181            diag[pt + pls + j] = acc;
182        }
183        Ok(Some(diag))
184    }
185
186    fn directional_derivative(
187        &self,
188        d_beta_flat: &Array1<f64>,
189    ) -> Result<Option<Array2<f64>>, String> {
190        self.family
191            .exact_newton_joint_hessian_directional_derivative(&self.block_states, d_beta_flat)
192    }
193
194    fn directional_derivative_operator(
195        &self,
196        d_beta_flat: &Array1<f64>,
197    ) -> Result<Option<Arc<dyn gam_problem::HyperOperator>>, String> {
198        self.family.bls_wiggle_directional_operator(
199            &self.block_states,
200            self.x_t.clone(),
201            self.x_ls.clone(),
202            d_beta_flat,
203        )
204    }
205
206    fn second_directional_derivative(
207        &self,
208        d_beta_u_flat: &Array1<f64>,
209        d_beta_v_flat: &Array1<f64>,
210    ) -> Result<Option<Array2<f64>>, String> {
211        self.family
212            .exact_newton_joint_hessiansecond_directional_derivative(
213                &self.block_states,
214                d_beta_u_flat,
215                d_beta_v_flat,
216            )
217    }
218
219    fn second_directional_derivative_operator(
220        &self,
221        d_beta_u: &Array1<f64>,
222        d_beta_v: &Array1<f64>,
223    ) -> Result<Option<Arc<dyn gam_problem::HyperOperator>>, String> {
224        self.family.bls_wiggle_second_directional_operator(
225            &self.block_states,
226            self.x_t.clone(),
227            self.x_ls.clone(),
228            d_beta_u,
229            d_beta_v,
230        )
231    }
232}
233
234impl CustomFamilyGenerative for BinomialLocationScaleWiggleFamily {
235    fn generativespec(
236        &self,
237        block_states: &[ParameterBlockState],
238    ) -> Result<GenerativeSpec, String> {
239        validate_block_count::<GamlssError>(
240            "BinomialLocationScaleWiggleFamily",
241            3,
242            block_states.len(),
243        )?;
244        let eta_t = &block_states[Self::BLOCK_T].eta;
245        let eta_ls = &block_states[Self::BLOCK_LOG_SIGMA].eta;
246        let etaw = &block_states[Self::BLOCK_WIGGLE].eta;
247        if eta_t.len() != self.y.len() || eta_ls.len() != self.y.len() || etaw.len() != self.y.len()
248        {
249            return Err(GamlssError::DimensionMismatch {
250                reason: "BinomialLocationScaleWiggleFamily generative size mismatch".to_string(),
251            }
252            .into());
253        }
254        let mean = gamlss_rowwise_map_result(self.y.len(), |i| {
255            let sigma = exp_sigma_from_eta_scalar(eta_ls[i]);
256            let q0 = binomial_location_scale_q0(eta_t[i], sigma);
257            let jet = inverse_link_jet_for_inverse_link(&self.link_kind, q0 + etaw[i])
258                .map_err(|e| format!("location-scale inverse-link evaluation failed: {e}"))?;
259            Ok(jet.mu)
260        })?;
261        Ok(GenerativeSpec {
262            mean,
263            noise: NoiseModel::Bernoulli,
264        })
265    }
266}