Skip to main content

gam_problem/
block_spec.rs

1//! Data model for the blockwise carrier (subset moved down to `gam-problem`):
2//! parameter-block specs, the effective-Jacobian and channel-Hessian
3//! abstractions, per-block working sets and states, and the block geometry
4//! directional derivative.
5//!
6//! The coefficient-group/label/prior types, `custom_family_block_role`, and the
7//! blockspec validators remain in the family crate because they depend on
8//! `CoefficientGroupPrior`/`RhoPrior`/`BlockRole`/`CustomFamilyError`.
9
10use std::ops::Range;
11use std::sync::Arc;
12
13use ndarray::{Array1, Array2};
14
15use crate::PenaltyMatrix;
16use gam_linalg::matrix::{DesignMatrix, SymmetricMatrix};
17
18/// Per-subject channel Hessian provider for multi-output families.
19///
20/// The Fisher information decomposition for multi-output families is
21///
22/// ```text
23/// I(β) = Σ_i  J_iᵀ W_i J_i
24/// ```
25///
26/// where `J_i` is the channel-stacked Jacobian (shape `n_outputs × p` for
27/// subject `i`) and `W_i` is the `n_outputs × n_outputs` per-subject channel
28/// Hessian of the row negative log-likelihood (the second-derivative block of
29/// `−log L_i(u_i)` at a pilot β, PSD-clamped).
30///
31/// For single-output families this is the scalar IRLS weight; for multi-output
32/// families (survival marginal-slope: `n_outputs = 4`; location-scale:
33/// `n_outputs = 2`) it carries full cross-channel curvature.
34///
35/// The identifiability canonicalisation step uses the `n_outputs`-channel
36/// weighted joint design `W_joint = Σ_i sqrt(W_i) ⊗ J_i` to detect
37/// block-against-block aliasing.  When this trait is present on
38/// `ParameterBlockSpec::channel_hessian`, `canonicalize_for_identifiability`
39/// routes through `audit_identifiability_channel_aware`; when absent it falls
40/// back to the scalar-weight flat audit.
41///
42/// # W-metric rank theorem
43///
44/// The canonicalisation computes `rank(J^T W J)` where `W_blkdiag =
45/// block-diagonal of per-subject W_i`.  This rank equals
46///
47/// ```text
48/// rank(J) − dim(range(J) ∩ ker(W_blkdiag))
49/// ```
50///
51/// i.e. columns of `J` that lie in the kernel of `W_blkdiag` (flat directions
52/// in the curvature landscape at the pilot β) are correctly identified as
53/// curvature-redundant and may be dropped.
54pub trait FamilyChannelHessian: Send + Sync {
55    /// Number of output channels `n_outputs` (= K in the row Jacobian).
56    fn n_outputs(&self) -> usize;
57
58    /// Number of subjects (rows).
59    fn n_subjects(&self) -> usize;
60
61    /// Fill the `n_outputs × n_outputs` per-subject channel Hessian `W_i`
62    /// into `out` (row-major, length `n_outputs * n_outputs`) for subject `i`.
63    /// Negative eigenvalues must be clamped to zero (PSD projection) before
64    /// or inside this call.
65    fn fill_subject(&self, i: usize, out: &mut [f64]);
66
67    /// Materialise the full `(n_subjects × n_outputs × n_outputs)` tensor.
68    /// Default implementation calls `fill_subject` for each row.
69    fn evaluate_full(&self) -> ndarray::Array3<f64> {
70        let n = self.n_subjects();
71        let k = self.n_outputs();
72        let mut out = ndarray::Array3::<f64>::zeros((n, k, k));
73        let mut buf = vec![0.0_f64; k * k];
74        for i in 0..n {
75            self.fill_subject(i, &mut buf);
76            for a in 0..k {
77                for b in 0..k {
78                    out[[i, a, b]] = buf[a * k + b];
79                }
80            }
81        }
82        out
83    }
84
85    /// Return a refreshed W evaluated at `beta` using `family_scalars` when
86    /// those scalars carry the per-row primary state at the current β.
87    ///
88    /// # Fisher information identity
89    ///
90    /// `I(β) = J(β)^T W(β) J(β)`. T8 originally froze W at β=0; T34 refreshes
91    /// both J and W at the current β so the audit's rank verdict reflects the
92    /// actual local identifiability.
93    ///
94    /// # Default implementation (β-independent W)
95    ///
96    /// Families whose W is β-independent (e.g. Gaussian-identity where
97    /// `W = prior_w`) return a clone of their frozen W by delegating to
98    /// `evaluate_full()`. No recomputation is performed. `beta` and
99    /// `family_scalars` are ignored.
100    ///
101    /// # Override (β-dependent W)
102    ///
103    /// Families with β-dependent W (e.g. survival marginal-slope where
104    /// `W_i(β)` depends on `(q0_i, q1_i, qd1_i, g_i)`) must override this
105    /// method and recompute W from the current primary state.
106    ///
107    /// When `beta` is non-zero in a way that affects W (i.e. `g_i != 0`),
108    /// `family_scalars` MUST be `Some(..)`. Return `Err` if scalars are
109    /// missing in that case (same error-message style as T26's contract).
110    fn channel_hessian_at(
111        &self,
112        beta: &[f64],
113        family_scalars: Option<&std::sync::Arc<dyn std::any::Any + Send + Sync>>,
114    ) -> Result<Arc<dyn FamilyChannelHessian>, String> {
115        // Default: W is β-independent — return a snapshot of the frozen W
116        // wrapped in a simple tensor-backed implementation. β and
117        // family_scalars are validated (NaN-guard, presence flag) so callers
118        // that pass garbage state still see an Err rather than a silently-stale
119        // W. The default impl does not require family_scalars; family-specific
120        // overrides may.
121        if beta.iter().any(|v| v.is_nan()) {
122            return Err("channel_hessian_at: beta contains NaN".to_string());
123        }
124        // Acknowledge family_scalars without binding it to a discarded name.
125        if family_scalars.is_some() && beta.is_empty() {
126            return Err(
127                "channel_hessian_at: family_scalars supplied but beta is empty".to_string(),
128            );
129        }
130        let tensor = self.evaluate_full();
131        Ok(Arc::new(TensorChannelHessian { h: tensor }))
132    }
133}
134
135/// A [`FamilyChannelHessian`] backed directly by a pre-computed
136/// `(n × K × K)` tensor. Used by the default `channel_hessian_at`
137/// implementation and by tests.
138///
139/// This is the β-independent path: `fill_subject` reads from the frozen
140/// tensor without any recomputation.
141pub struct TensorChannelHessian {
142    pub h: ndarray::Array3<f64>,
143}
144
145impl FamilyChannelHessian for TensorChannelHessian {
146    fn n_outputs(&self) -> usize {
147        self.h.shape()[1]
148    }
149
150    fn n_subjects(&self) -> usize {
151        self.h.shape()[0]
152    }
153
154    fn fill_subject(&self, i: usize, out: &mut [f64]) {
155        let k = self.h.shape()[1];
156        assert_eq!(out.len(), k * k);
157        for a in 0..k {
158            for b in 0..k {
159                out[a * k + b] = self.h[[i, a, b]];
160            }
161        }
162    }
163
164    fn evaluate_full(&self) -> ndarray::Array3<f64> {
165        self.h.clone()
166    }
167}
168
/// β-linearization state passed to [`BlockEffectiveJacobian::effective_jacobian_at`]
/// (and to the row-chunked [`BlockEffectiveJacobian::effective_jacobian_rows`]).
///
/// At pre-fit initialization, pass `beta = &[]` / zeros and `family_scalars = None`.
/// Families that need β-dependent scalars (e.g. survival marginal-slope's q0, q1,
/// g, c, z) store them in `family_scalars` as a concrete type behind
/// `Arc<dyn Any + Send + Sync>` and downcast inside their impl.
pub struct FamilyLinearizationState<'a> {
    /// Coefficients at the linearization point, borrowed from the caller.
    /// May be empty (or all zeros) at pre-fit initialization, as described in
    /// the struct-level documentation.
    pub beta: &'a [f64],
    /// Optional family-shared scalars at this β linearization.
    /// Downcast via `state.family_scalars.as_ref().and_then(|a| a.downcast_ref::<T>())`.
    pub family_scalars: Option<Arc<dyn std::any::Any + Send + Sync>>,
    /// Optional per-subject channel Hessian for multi-output families.
    /// When `Some`, the identifiability canonicalisation step and the Gram
    /// builder use the channel-stacked Fisher information instead of the
    /// scalar-weight approximation.  Single-output families leave this `None`.
    pub channel_hessian: Option<Arc<dyn FamilyChannelHessian>>,
    /// Probit frailty scale factor `s_f = 1/√(1+σ²)`.
    ///
    /// For survival marginal-slope families the logslope η contribution is
    /// `s_f · g · z`, so any Jacobian callback that depends on g or z must
    /// read `s_f` from here rather than from a captured-at-construction value.
    /// When σ = 0 (no frailty) or for non-frailty families, set this to 1.0.
    ///
    /// Since σ is always **fixed** (not jointly optimised with β) in the
    /// survival family, `s_f` is a static scalar for the entire inner fit;
    /// `∂s_f/∂σ` never appears in the β-Jacobian.  The field is nonetheless
    /// carried through state so that Jacobian callbacks are not required to
    /// capture `s_f` at spec-construction time — they can read it at
    /// evaluation time and thus stay correct across outer-loop σ updates.
    pub probit_frailty_scale: f64,
}
200
201/// β-dependent Jacobian callback for a parameter block.
202///
203/// Principled long-term contract for expressing how a block contributes to
204/// the stacked linear predictor at a given β:
205///
206/// ```text
207/// J(β) ∈ ℝ^{n_rows · n_outputs × p_block}
208/// ```
209///
210/// - Single-output linear block: returns `design.clone()`.
211/// - Row-scaled block (`RowScaledJacobian`): returns `diag(eta_scaling) · design` (still linear in β).
212/// - Multi-output block (e.g. survival marginal-slope with η0, η1, ad1):
213///   stacks `∂eta_r/∂β_k` for `r ∈ 0..n_outputs`, row-major ordering.
214///
215/// The default impl on [`ParameterBlockSpec::effective_jacobian_at`] is:
216/// - `jacobian_callback = None` → `design.clone()`.
217/// - `jacobian_callback = Some(cb)` → delegates to `cb.effective_jacobian_at`.
218pub trait BlockEffectiveJacobian: Send + Sync {
219    /// Stacked multi-output Jacobian for a contiguous observation row range.
220    ///
221    /// Shape: `(rows.len() * n_outputs, p_block)`, with the same channel-major
222    /// layout as [`Self::effective_jacobian_at`]: row
223    /// `channel * rows.len() + local_row` is `rows.start + local_row` in that
224    /// output channel. Implementations should keep this as the single source of
225    /// row math so large construction-time audits can stream chunks instead of
226    /// materialising all `n * p * K` entries at once.
227    fn effective_jacobian_rows(
228        &self,
229        state: &FamilyLinearizationState<'_>,
230        rows: Range<usize>,
231    ) -> Result<Array2<f64>, String>;
232
233    /// Stacked multi-output Jacobian at the current β.
234    ///
235    /// Shape: `(n_rows * n_outputs, p_block)`, **channel-major**: rows
236    /// `r * n_rows .. (r + 1) * n_rows` carry output channel `r`'s row
237    /// Jacobian, so `stacked[r * n_rows + i, j]` is observation `i`'s row at
238    /// output `r` and coefficient column `j`.  Every consumer that destacks
239    /// this matrix (audit, canonicaliser, fit) relies on this layout — see
240    /// `BlockJacobianAsRowOp::from_callback` for the destacking transpose.
241    /// For `n_outputs = 1` this is identical to the `(n_rows, p_block)` effective
242    /// design used by the flat identifiability audit.
243    fn effective_jacobian_at(
244        &self,
245        state: &FamilyLinearizationState<'_>,
246    ) -> Result<Array2<f64>, String> {
247        let full = self.effective_jacobian_rows(state, 0..usize::MAX)?;
248        Ok(full)
249    }
250
251    /// Number of stacked output channels. 1 for most blocks.
252    fn n_outputs(&self) -> usize {
253        1
254    }
255
256    /// Returns the per-row scaling vector when this callback is a simple
257    /// diagonal-scaling block (`RowScaledJacobian`).  Used by the
258    /// identifiability audit's skewness-aware bias correction (T25).
259    ///
260    /// Returns `None` for all blocks except `RowScaledJacobian`.
261    fn eta_row_scaling_for_skewness(&self) -> Option<Arc<[f64]>> {
262        None
263    }
264
265    /// Whether the identifiability canonicaliser must keep this block at its
266    /// full raw column width instead of column-reducing it.
267    ///
268    /// The `#933` reduction path wraps a `jacobian_callback` block in a
269    /// gauge-composed Jacobian so the family fits in a reduced section — sound
270    /// only when the family's effective geometry is DERIVED from the callback
271    /// (multinomial softmax, marginal-slope logslope). It is NOT sound for a
272    /// callback whose effective Jacobian is a **fixed nonlinear functional
273    /// basis** recomputed at the raw coefficient width on every evaluation
274    /// (the survival marginal-slope monotone time-wiggle time block): its
275    /// downstream likelihood reads raw-width internal designs and asserts
276    /// `beta.len() == p_raw`, so a reduced β desynchronises that layout — the
277    /// same failure mode the competing-risks dead-column veto already guards.
278    /// Such a block returns `true` so the canonicaliser keeps it at raw width
279    /// (its own penalty nullspace regularises the weak directions instead).
280    ///
281    /// Defaults to `false`: every existing callback reduces safely.
282    fn locks_raw_width_reduction(&self) -> bool {
283        false
284    }
285}
286
287/// A [`BlockEffectiveJacobian`] for any block that contributes linearly to
288/// exactly one output of a multi-output family.
289///
290/// `own_output` is the zero-based output index that this block drives.
291/// `n_family_outputs` is the total number of outputs (e.g. 2 for location-scale).
292/// `design` is the block's effective design matrix (n × p_block).
293///
294/// The returned Jacobian has shape `(n_family_outputs * n, p_block)`:
295/// rows `own_output * n .. (own_output + 1) * n` contain `design`,
296/// all other rows are zero.
297pub struct AdditiveBlockJacobian {
298    pub design: Array2<f64>,
299    pub own_output: usize,
300    pub n_family_outputs: usize,
301}
302
303impl BlockEffectiveJacobian for AdditiveBlockJacobian {
304    fn effective_jacobian_rows(
305        &self,
306        state: &FamilyLinearizationState<'_>,
307        rows: Range<usize>,
308    ) -> Result<Array2<f64>, String> {
309        let n = self.design.nrows();
310        let p = self.design.ncols();
311        let rows = clamp_jacobian_rows(rows, n);
312        // Additive (linear) block: Jacobian is β-independent — design does
313        // not depend on state.beta. Verify beta contains no NaN when provided.
314        if !state.beta.is_empty() && state.beta.iter().any(|v| v.is_nan()) {
315            return Err(
316                "AdditiveBlockJacobian::effective_jacobian_at: beta contains NaN".to_string(),
317            );
318        }
319        let chunk = rows.end - rows.start;
320        let total_rows = self.n_family_outputs * chunk;
321        let mut jac = Array2::<f64>::zeros((total_rows, p));
322        let row_start = self.own_output * chunk;
323        jac.slice_mut(ndarray::s![row_start..row_start + chunk, ..])
324            .assign(&self.design.slice(ndarray::s![rows.start..rows.end, ..]));
325        Ok(jac)
326    }
327
328    fn n_outputs(&self) -> usize {
329        self.n_family_outputs
330    }
331}
332
333/// A [`BlockEffectiveJacobian`] for a single-output block whose contribution
334/// to the linear predictor is `diag(eta_scaling) · design` (row-wise scaling).
335///
336/// This is the canonical replacement for the former `eta_row_scaling` field on
337/// [`ParameterBlockSpec`].  The identifiability audit's skewness-aware bias
338/// correction can recover the scaling vector via
339/// [`BlockEffectiveJacobian::eta_row_scaling_for_skewness`].
340pub struct RowScaledJacobian {
341    pub design: Arc<Array2<f64>>,
342    pub eta_scaling: Arc<[f64]>,
343}
344
345impl BlockEffectiveJacobian for RowScaledJacobian {
346    fn effective_jacobian_rows(
347        &self,
348        state: &FamilyLinearizationState<'_>,
349        rows: Range<usize>,
350    ) -> Result<Array2<f64>, String> {
351        let n = self.design.nrows();
352        let rows = clamp_jacobian_rows(rows, n);
353        if self.eta_scaling.len() != n {
354            return Err(format!(
355                "RowScaledJacobian: eta_scaling length {} != design nrows {}",
356                self.eta_scaling.len(),
357                n,
358            ));
359        }
360        // Row-scaled blocks are β-linear; verify the linearization point
361        // contains no NaN when β is provided (sanity check on caller state).
362        if !state.beta.is_empty() && state.beta.iter().any(|v| v.is_nan()) {
363            return Err(
364                "RowScaledJacobian::effective_jacobian_at: state.beta contains NaN".to_string(),
365            );
366        }
367        let mut scaled = self
368            .design
369            .slice(ndarray::s![rows.start..rows.end, ..])
370            .to_owned();
371        for local_i in 0..scaled.nrows() {
372            let s = self.eta_scaling[rows.start + local_i];
373            for j in 0..scaled.ncols() {
374                scaled[[local_i, j]] *= s;
375            }
376        }
377        Ok(scaled)
378    }
379
380    fn eta_row_scaling_for_skewness(&self) -> Option<Arc<[f64]>> {
381        Some(Arc::clone(&self.eta_scaling))
382    }
383}
384
/// Clamps a requested row range to a matrix with `n` rows.
///
/// Both bounds are capped at `n`, and an inverted range collapses to the empty
/// range starting at the capped `start`. The result always satisfies
/// `start <= end <= n`, so it can be used for slicing without panicking.
pub(crate) fn clamp_jacobian_rows(rows: Range<usize>, n: usize) -> Range<usize> {
    let lo = if rows.start > n { n } else { rows.start };
    // `lo <= n` holds by construction, so the clamp bounds are well ordered.
    let hi = rows.end.clamp(lo, n);
    Range { start: lo, end: hi }
}
390
391/// A [`BlockEffectiveJacobian`] that composes an inner callback's raw-width
392/// effective Jacobian with a fixed reduced→raw block transform `T_b`
393/// (`p_raw × r_reduced`), so the family sees the **reduced** coordinates by
394/// construction (#933).
395///
396/// The inner callback emits its row Jacobian in the raw coordinate system
397/// (`(rows · k) × p_raw`), the layout every `BlockEffectiveJacobian` impl
398/// produces — channel-major rows, raw columns. Post-multiplying each row by
399/// `T_b` rotates those raw columns into the reduced section: the effective
400/// reduced Jacobian is `J_raw · T_b`, with `r_reduced` columns. On the model
401/// `η = J_raw · β_raw = (J_raw · T_b) · θ` this is the exact reduced operator
402/// for the reduced coefficient θ, and the family lifts θ back to β_raw through
403/// the SAME `T_b` via the one [`Gauge`](crate::Gauge).
404///
405/// This is the inversion #933 calls for: instead of forwarding a raw-width
406/// callback alongside a column-selection `T_i` (which leaves the family
407/// asserting raw column counts on a reduced spec and panicking), the callback
408/// is wrapped so its output already has the reduced width — the family captures
409/// the reduced design and its row-Hessian column-count assertions hold by
410/// construction. A column-selection `T_b` (zero/one entries) makes this exactly
411/// the audit's drop; a general orthonormal `T_b` makes it any gauge section.
412pub struct GaugeComposedJacobian {
413    inner: Arc<dyn BlockEffectiveJacobian>,
414    /// Reduced→raw block transform `T_b`, shape `(p_raw × r_reduced)`.
415    t_block: Arc<Array2<f64>>,
416}
417
418impl GaugeComposedJacobian {
419    /// Wrap `inner` so its effective Jacobian is post-multiplied by `t_block`
420    /// (`p_raw × r_reduced`). `t_block.nrows()` must equal the inner callback's
421    /// raw column count.
422    pub fn new(inner: Arc<dyn BlockEffectiveJacobian>, t_block: Arc<Array2<f64>>) -> Self {
423        Self { inner, t_block }
424    }
425}
426
427impl BlockEffectiveJacobian for GaugeComposedJacobian {
428    fn effective_jacobian_rows(
429        &self,
430        state: &FamilyLinearizationState<'_>,
431        rows: Range<usize>,
432    ) -> Result<Array2<f64>, String> {
433        let raw_width = self.t_block.nrows();
434        let reduced_width = self.t_block.ncols();
435        let lifted_beta;
436        let lifted_state;
437        let zero_raw_beta;
438        let delegate_state = if state.beta.len() == raw_width {
439            state
440        } else if state.beta.len() == reduced_width {
441            lifted_beta = self.t_block.dot(&ndarray::ArrayView1::from(state.beta));
442            lifted_state = FamilyLinearizationState {
443                beta: lifted_beta
444                    .as_slice()
445                    .expect("GaugeComposedJacobian lifted beta is contiguous"),
446                family_scalars: state.family_scalars.clone(),
447                channel_hessian: state.channel_hessian.clone(),
448                probit_frailty_scale: state.probit_frailty_scale,
449            };
450            &lifted_state
451        } else if state.beta.is_empty() {
452            zero_raw_beta = ndarray::Array1::<f64>::zeros(raw_width);
453            lifted_state = FamilyLinearizationState {
454                beta: zero_raw_beta
455                    .as_slice()
456                    .expect("GaugeComposedJacobian zero raw beta is contiguous"),
457                family_scalars: state.family_scalars.clone(),
458                channel_hessian: state.channel_hessian.clone(),
459                probit_frailty_scale: state.probit_frailty_scale,
460            };
461            &lifted_state
462        } else {
463            return Err(format!(
464                "GaugeComposedJacobian: beta has length {}, expected raw width {} \
465                 or reduced width {}; this wrapper cannot infer a block slice from a joint \
466                 coefficient vector",
467                state.beta.len(),
468                raw_width,
469                reduced_width,
470            ));
471        };
472        let j_raw = self.inner.effective_jacobian_rows(delegate_state, rows)?;
473        if j_raw.ncols() != self.t_block.nrows() {
474            return Err(format!(
475                "GaugeComposedJacobian: inner Jacobian has {} columns but T_b has {} rows",
476                j_raw.ncols(),
477                self.t_block.nrows(),
478            ));
479        }
480        // (rows·k × p_raw) · (p_raw × r_reduced) = (rows·k × r_reduced).
481        Ok(j_raw.dot(self.t_block.as_ref()))
482    }
483
484    fn n_outputs(&self) -> usize {
485        self.inner.n_outputs()
486    }
487
488    // Skewness scaling is a raw-row property; reducing the column space does not
489    // change the per-row scaling, so it is forwarded unchanged when present.
490    fn eta_row_scaling_for_skewness(&self) -> Option<Arc<[f64]>> {
491        self.inner.eta_row_scaling_for_skewness()
492    }
493}
494
#[cfg(test)]
mod gauge_composed_jacobian_tests {
    use super::*;
    use ndarray::array;

    /// Linearization state at `beta` with no family scalars, no channel
    /// Hessian and a unit frailty scale.
    fn linearized_at(beta: &[f64]) -> FamilyLinearizationState<'_> {
        FamilyLinearizationState {
            beta,
            family_scalars: None,
            channel_hessian: None,
            probit_frailty_scale: 1.0,
        }
    }

    /// Column-selection transform that keeps only the second raw column.
    fn keep_second_column() -> Arc<Array2<f64>> {
        Arc::new(array![[0.0], [1.0]])
    }

    /// Test callback whose column `j` is the design column scaled by
    /// `1 + beta[j]` (a missing β entry counts as zero).
    struct BetaScaledJacobian {
        design: Array2<f64>,
    }

    impl BlockEffectiveJacobian for BetaScaledJacobian {
        fn effective_jacobian_rows(
            &self,
            state: &FamilyLinearizationState<'_>,
            rows: Range<usize>,
        ) -> Result<Array2<f64>, String> {
            let n = self.design.nrows();
            let window = rows.start.min(n)..rows.end.min(n);
            let mut out = self.design.slice(ndarray::s![window, ..]).to_owned();
            for (col, mut column) in out.axis_iter_mut(ndarray::Axis(1)).enumerate() {
                let scale = 1.0 + state.beta.get(col).copied().unwrap_or(0.0);
                column.mapv_inplace(|v| v * scale);
            }
            Ok(out)
        }

        fn n_outputs(&self) -> usize {
            1
        }
    }

    #[test]
    fn gauge_composed_jacobian_lifts_reduced_block_beta_before_delegating() {
        let inner: Arc<dyn BlockEffectiveJacobian> = Arc::new(BetaScaledJacobian {
            design: array![[2.0, 3.0], [5.0, 7.0]],
        });
        let wrapped = GaugeComposedJacobian::new(inner, keep_second_column());

        // Reduced θ = [4] lifts to β_raw = [0, 4] through T.
        let theta = [4.0];
        let reduced = wrapped
            .effective_jacobian_rows(&linearized_at(&theta), 0..2)
            .expect("reduced beta should be lifted through T before inner callback");

        let raw_beta = [0.0, 4.0];
        let raw = wrapped
            .effective_jacobian_rows(&linearized_at(&raw_beta), 0..2)
            .expect("raw beta state remains valid");

        assert_eq!(reduced, raw);
        assert_eq!(reduced, array![[15.0], [35.0]]);
    }

    /// Test callback that refuses any β whose length is not the raw width.
    struct StrictRawWidthJacobian {
        design: Array2<f64>,
    }

    impl BlockEffectiveJacobian for StrictRawWidthJacobian {
        fn effective_jacobian_rows(
            &self,
            state: &FamilyLinearizationState<'_>,
            rows: Range<usize>,
        ) -> Result<Array2<f64>, String> {
            let raw_width = self.design.ncols();
            let got = state.beta.len();
            if got != raw_width {
                return Err(format!(
                    "StrictRawWidthJacobian expected raw beta len {}, got {}",
                    raw_width, got,
                ));
            }
            Ok(self.design.slice(ndarray::s![rows, ..]).to_owned())
        }
    }

    #[test]
    fn gauge_composed_jacobian_lifts_zero_reduced_beta_before_delegating() {
        let inner: Arc<dyn BlockEffectiveJacobian> = Arc::new(StrictRawWidthJacobian {
            design: array![[2.0, 3.0], [5.0, 7.0]],
        });
        let wrapped = GaugeComposedJacobian::new(inner, keep_second_column());

        let theta = [0.0];
        let reduced = wrapped
            .effective_jacobian_rows(&linearized_at(&theta), 0..2)
            .expect("zero reduced beta must still be lifted to raw width");

        assert_eq!(reduced, array![[3.0], [7.0]]);
    }

    #[test]
    fn gauge_composed_jacobian_rejects_nonzero_unknown_beta_layout() {
        let inner: Arc<dyn BlockEffectiveJacobian> = Arc::new(BetaScaledJacobian {
            design: array![[2.0, 3.0]],
        });
        let wrapped = GaugeComposedJacobian::new(inner, keep_second_column());

        // Length 3 is neither the raw width (2) nor the reduced width (1).
        let joint_like_beta = [1.0, 0.0, 0.0];
        let err = wrapped
            .effective_jacobian_rows(&linearized_at(&joint_like_beta), 0..1)
            .expect_err("nonzero joint-layout beta cannot be inferred from one block T");
        assert!(
            err.contains("cannot infer a block slice"),
            "unexpected error: {err}"
        );
    }
}
625
/// Static specification for one parameter block in a custom family.
///
/// `design` and `stacked_design` are two structurally distinct operators:
///
/// * `design` is the **canonical, single-channel, n-observation operator**.
///   `design.nrows()` ALWAYS equals `n_obs` (one row per training
///   observation).  This is the matrix the identifiability audit, the
///   shape policy, and every "what shape is this block?" reader inspect.
///   For most blocks `design` is also the eta-producing operator used by
///   the solver — see [`Self::solver_design`].
/// * `stacked_design`, when `Some`, is the **multi-channel eta-producing
///   operator** used by the solver.  Survival time-varying blocks stack
///   `[exit; entry; deriv]` into a `(3·n × p)` operator here so the
///   solver can produce a `3·n`-long `eta` in one mat-vec; the audit
///   never sees this matrix.  When `None`, the solver uses `design` (the
///   single-channel default).
///
/// The single contract that downstream code can rely on:
/// `design.nrows() == n_obs`.  No more dual semantics on `design`.
///
/// Read access:
/// * Audit / canonicalize / "n_obs is the row count" code → `&spec.design`.
/// * Eta-producing solver code → [`Self::solver_design`].
#[derive(Clone)]
pub struct ParameterBlockSpec {
    /// Name of the block. `defaults()` leaves it empty.
    pub name: String,
    /// Canonical single-channel operator with one row per training
    /// observation (`design.nrows() == n_obs`); see the struct-level docs.
    pub design: DesignMatrix,
    /// Offset belonging to `design`.
    /// NOTE(review): presumably one entry per row of `design` — confirm
    /// against the blockspec validators in the family crate.
    pub offset: Array1<f64>,
    /// Block-local penalty matrices (all p_block x p_block).
    pub penalties: Vec<PenaltyMatrix>,
    /// Structural nullspace dimension of each penalty matrix (same length as `penalties`).
    /// Used by the penalty pseudo-logdet to determine rank without numerical thresholds.
    /// If empty, falls back to eigenvalue-based rank detection.
    pub nullspace_dims: Vec<usize>,
    /// Initial log-smoothing parameters for this block (same length as `penalties`).
    pub initial_log_lambdas: Array1<f64>,
    /// Optional initial coefficients (defaults to zeros if omitted).
    pub initial_beta: Option<Array1<f64>>,
    /// Gauge ownership priority. Higher = more likely to retain a
    /// redundant direction during canonical-gauge reparameterisation.
    /// Defaults to 100. Set higher for blocks that should "own" shared
    /// affine/null-space directions (e.g. baseline time in survival).
    pub gauge_priority: u8,
    /// Full β-dependent Jacobian callback.  When `Some`, this is the
    /// authoritative source for `effective_jacobian_at`.  For simple
    /// single-output row-scaled blocks use [`RowScaledJacobian`].
    pub jacobian_callback: Option<Arc<dyn BlockEffectiveJacobian>>,
    /// Optional multi-channel eta-producing operator used by the solver.
    ///
    /// When `Some`, the solver consumes this matrix (typically
    /// `(k·n × p)` for `k` stacked channels — e.g. survival
    /// `[exit; entry; deriv]` with `k = 3`) to evaluate `eta = stacked · β + stacked_offset`.
    /// The audit and shape policy NEVER read this field; they only ever
    /// inspect `design` (which always has `n_obs` rows).
    ///
    /// When `None`, the solver falls back to `design` — the correct
    /// behavior for every single-channel block (i.e. all non-survival
    /// time-varying blocks).
    ///
    /// Read this field via [`Self::solver_design`], never directly.
    ///
    /// Invariant: when `stacked_design = Some(_)`, `stacked_offset` MUST
    /// also be `Some(_)` and its length MUST equal `stacked_design.nrows()`.
    pub stacked_design: Option<DesignMatrix>,
    /// Optional offset paired with [`Self::stacked_design`]. Same Option
    /// state as `stacked_design` (both `Some` or both `None`).
    /// Read via [`Self::solver_offset`].
    pub stacked_offset: Option<Array1<f64>>,
}
695
696impl std::fmt::Debug for ParameterBlockSpec {
697    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
698        f.debug_struct("ParameterBlockSpec")
699            .field("name", &self.name)
700            .field("design", &self.design)
701            .field("offset", &self.offset)
702            .field("penalties", &self.penalties)
703            .field("nullspace_dims", &self.nullspace_dims)
704            .field("initial_log_lambdas", &self.initial_log_lambdas)
705            .field("initial_beta", &self.initial_beta)
706            .field("gauge_priority", &self.gauge_priority)
707            .field(
708                "jacobian_callback",
709                &self
710                    .jacobian_callback
711                    .as_ref()
712                    .map(|_| "<BlockEffectiveJacobian>"),
713            )
714            .finish()
715    }
716}
717
718impl ParameterBlockSpec {
719    /// Returns a ParameterBlockSpec with sensible defaults for all optional
720    /// fields. Callers using struct literal syntax can use
721    /// `..ParameterBlockSpec::defaults()` to fill in any fields added after
722    /// the literal was written.
723    pub fn defaults() -> Self {
724        Self {
725            name: String::new(),
726            design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
727                ndarray::Array2::<f64>::zeros((0, 0)),
728            )),
729            offset: ndarray::Array1::<f64>::zeros(0),
730            penalties: Vec::new(),
731            nullspace_dims: Vec::new(),
732            initial_log_lambdas: ndarray::Array1::<f64>::zeros(0),
733            initial_beta: None,
734            gauge_priority: 100,
735            jacobian_callback: None,
736            stacked_design: None,
737            stacked_offset: None,
738        }
739    }
740
741    /// Returns the eta-producing operator used by the solver.
742    ///
743    /// Resolution order:
744    ///   1. `stacked_design = Some(d)` → return `d` (multi-channel
745    ///      operator, e.g. `(3n × p)` for survival time-varying blocks).
746    ///   2. otherwise → return `&self.design` (the single-channel default).
747    ///
748    /// Solver code that needs `eta = D · β` MUST call this accessor;
749    /// reading `&self.design` directly silently breaks multi-channel
750    /// (survival LS time-varying) blocks because `self.design.nrows()`
751    /// always equals `n_obs`, never `3·n_obs`.
752    pub fn solver_design(&self) -> &DesignMatrix {
753        self.stacked_design.as_ref().unwrap_or(&self.design)
754    }
755
756    /// Returns the offset paired with [`Self::solver_design`]. When
757    /// `stacked_offset = Some(o)` this returns `&o`; otherwise it falls
758    /// back to `&self.offset`.
759    pub fn solver_offset(&self) -> &Array1<f64> {
760        self.stacked_offset.as_ref().unwrap_or(&self.offset)
761    }
762
763    /// Returns the effective design `D_eff` for this block at β = 0 with no
764    /// family scalars — a convenience wrapper around [`Self::effective_jacobian_at`]
765    /// for the single-output (n_outputs = 1) case.
766    ///
767    /// Callers that need multi-output Jacobians or β-dependent scalars should
768    /// call `effective_jacobian_at` directly with the appropriate state.
769    ///
770    /// Returns `Err` if the design cannot be densified.
771    pub fn effective_design(&self, caller: &str) -> Result<ndarray::Array2<f64>, String> {
772        let p = self.design.ncols();
773        let zeros = vec![0.0f64; p];
774        let state = FamilyLinearizationState {
775            beta: &zeros,
776            family_scalars: None,
777            channel_hessian: None,
778            probit_frailty_scale: 1.0,
779        };
780        self.effective_jacobian_at(caller, &state)
781    }
782
783    /// Returns the β-dependent stacked Jacobian `J(β)` for this block.
784    ///
785    /// Shape: `(n_rows * n_outputs, p_block)`.  For most blocks `n_outputs = 1`
786    /// and the result is the familiar `(n_rows, p_block)` effective design.
787    ///
788    /// Dispatch order:
789    ///   1. `jacobian_callback = Some(cb)` → `cb.effective_jacobian_at(state)`.
790    ///   2. `jacobian_callback = None` → `design.clone()` (ignores `beta` and `family_scalars`).
791    ///
792    /// Returns `Err` if the design cannot be densified.
793    pub fn effective_jacobian_at(
794        &self,
795        caller: &str,
796        state: &FamilyLinearizationState<'_>,
797    ) -> Result<ndarray::Array2<f64>, String> {
798        if let Some(cb) = self.jacobian_callback.as_ref() {
799            return cb.effective_jacobian_at(state);
800        }
801        self.design
802            .try_to_dense_arc(&format!(
803                "{caller}::effective_jacobian_at block '{}'",
804                self.name
805            ))
806            .map(|arc| arc.as_ref().clone())
807    }
808}
809
810/// Current state for a parameter block.
811#[derive(Clone, Debug)]
812pub struct ParameterBlockState {
813    pub beta: Array1<f64>,
814    pub eta: Array1<f64>,
815}
816
817#[derive(Clone)]
818pub struct BlockGeometryDirectionalDerivative {
819    /// Directional derivative of the block design matrix along a coefficient-space direction.
820    pub d_design: Option<Array2<f64>>,
821    /// Directional derivative of the block offset along the same direction.
822    pub d_offset: Array1<f64>,
823}
824
825/// Working quantities supplied by a custom family for one block.
826///
827/// # Observed vs expected information (see response.md Section 3)
828///
829/// For the outer REML/LAML criterion, the Hessian used in log|H| and trace terms
830/// must be the **observed** (actual) Hessian at the mode, not the expected Fisher.
831///
832/// - `ExactNewton`: provides -nabla^2 log L directly, which is the observed Hessian
833///   by construction. This is always correct.
834///
835/// - `Diagonal`: provides IRLS working weights W such that the per-block Hessian
836///   is X'WX. For canonical links (logit-Binomial, log-Poisson), W_obs = W_Fisher.
837///   For supported non-canonical diagonal links, W must be the observed weight
838///   W_obs = W_Fisher - (y-mu)*B so the outer REML uses the exact Laplace
839///   Hessian. The matching `CustomFamily::diagonalworking_weights_directional_derivative`
840///   callback must differentiate the same observed W surface; silently using Fisher
841///   weights or zero `dW` would change the criterion into a PQL-type surrogate.
842#[derive(Clone, Debug)]
843pub enum BlockWorkingSet {
844    /// Standard IRLS/GLM-style diagonal working set for eta-space updates.
845    Diagonal {
846        /// IRLS pseudo-response for this block's linear predictor.
847        working_response: Array1<f64>,
848        /// IRLS working weights for this block (non-negative, length n).
849        ///
850        /// For the inner solver, Fisher or observed weights both find the same mode.
851        /// For the outer REML/LAML log|H| term, observed weights are the correct
852        /// Laplace choice (see response.md Section 3). Canonical-link families need
853        /// no correction since observed = Fisher.
854        working_weights: Array1<f64>,
855    },
856    /// Exact Newton block update in coefficient space.
857    ///
858    /// `gradient` is nabla log L wrt block coefficients.
859    /// `hessian` is -nabla^2 log L wrt block coefficients (positive semidefinite near optimum).
860    ///
861    /// This is the observed Hessian by construction (actual second derivative of the
862    /// log-likelihood), which is the correct quantity for the outer REML Laplace
863    /// approximation.
864    ExactNewton {
865        gradient: Array1<f64>,
866        hessian: SymmetricMatrix,
867    },
868}
869
870impl BlockWorkingSet {
871    /// Construct a `Diagonal` working set with the length invariant
872    /// (`working_response.len() == working_weights.len()`) enforced at the
873    /// type boundary. Use this from any new code path that produces a
874    /// diagonal IRLS block; the legacy struct-literal form is preserved for
875    /// existing call sites pending a full migration.
876    #[inline]
877    pub fn diagonal_checked(
878        working_response: Array1<f64>,
879        working_weights: Array1<f64>,
880    ) -> Result<Self, String> {
881        if working_response.len() != working_weights.len() {
882            return Err(format!(
883                "BlockWorkingSet::Diagonal length mismatch: working_response={}, working_weights={}",
884                working_response.len(),
885                working_weights.len(),
886            ));
887        }
888        Ok(Self::Diagonal {
889            working_response,
890            working_weights,
891        })
892    }
893}