Skip to main content

gam_problem/
block_spec.rs

1//! Data model for the blockwise carrier (subset moved down to `gam-problem`):
2//! parameter-block specs, the effective-Jacobian and channel-Hessian
3//! abstractions, per-block working sets and states, and the block geometry
4//! directional derivative.
5//!
6//! The coefficient-group/label/prior types, `custom_family_block_role`, and the
7//! blockspec validators remain in the family crate because they depend on
8//! `CoefficientGroupPrior`/`RhoPrior`/`BlockRole`/`CustomFamilyError`.
9
10use std::ops::Range;
11use std::sync::Arc;
12
13use ndarray::{Array1, Array2};
14
15use crate::PenaltyMatrix;
16use gam_linalg::matrix::{DesignMatrix, SymmetricMatrix};
17
18/// Per-subject channel Hessian provider for multi-output families.
19///
20/// The Fisher information decomposition for multi-output families is
21///
22/// ```text
23/// I(β) = Σ_i  J_iᵀ W_i J_i
24/// ```
25///
26/// where `J_i` is the channel-stacked Jacobian (shape `n_outputs × p` for
27/// subject `i`) and `W_i` is the `n_outputs × n_outputs` per-subject channel
28/// Hessian of the row negative log-likelihood (the second-derivative block of
29/// `−log L_i(u_i)` at a pilot β, PSD-clamped).
30///
31/// For single-output families this is the scalar IRLS weight; for multi-output
32/// families (survival marginal-slope: `n_outputs = 4`; location-scale:
33/// `n_outputs = 2`) it carries full cross-channel curvature.
34///
35/// The identifiability canonicalisation step uses the `n_outputs`-channel
36/// weighted joint design `W_joint = Σ_i sqrt(W_i) ⊗ J_i` to detect
37/// block-against-block aliasing.  When this trait is present on
38/// `ParameterBlockSpec::channel_hessian`, `canonicalize_for_identifiability`
39/// routes through `audit_identifiability_channel_aware`; when absent it falls
40/// back to the scalar-weight flat audit.
41///
42/// # W-metric rank theorem
43///
44/// The canonicalisation computes `rank(J^T W J)` where `W_blkdiag =
45/// block-diagonal of per-subject W_i`.  This rank equals
46///
47/// ```text
48/// rank(J) − dim(range(J) ∩ ker(W_blkdiag))
49/// ```
50///
51/// i.e. columns of `J` that lie in the kernel of `W_blkdiag` (flat directions
52/// in the curvature landscape at the pilot β) are correctly identified as
53/// curvature-redundant and may be dropped.
54pub trait FamilyChannelHessian: Send + Sync {
55    /// Number of output channels `n_outputs` (= K in the row Jacobian).
56    fn n_outputs(&self) -> usize;
57
58    /// Number of subjects (rows).
59    fn n_subjects(&self) -> usize;
60
61    /// Fill the `n_outputs × n_outputs` per-subject channel Hessian `W_i`
62    /// into `out` (row-major, length `n_outputs * n_outputs`) for subject `i`.
63    /// Negative eigenvalues must be clamped to zero (PSD projection) before
64    /// or inside this call.
65    fn fill_subject(&self, i: usize, out: &mut [f64]);
66
67    /// Materialise the full `(n_subjects × n_outputs × n_outputs)` tensor.
68    /// Default implementation calls `fill_subject` for each row.
69    fn evaluate_full(&self) -> ndarray::Array3<f64> {
70        let n = self.n_subjects();
71        let k = self.n_outputs();
72        let mut out = ndarray::Array3::<f64>::zeros((n, k, k));
73        let mut buf = vec![0.0_f64; k * k];
74        for i in 0..n {
75            self.fill_subject(i, &mut buf);
76            for a in 0..k {
77                for b in 0..k {
78                    out[[i, a, b]] = buf[a * k + b];
79                }
80            }
81        }
82        out
83    }
84
85    /// Return a refreshed W evaluated at `beta` using `family_scalars` when
86    /// those scalars carry the per-row primary state at the current β.
87    ///
88    /// # Fisher information identity
89    ///
90    /// `I(β) = J(β)^T W(β) J(β)`. T8 originally froze W at β=0; T34 refreshes
91    /// both J and W at the current β so the audit's rank verdict reflects the
92    /// actual local identifiability.
93    ///
94    /// # Default implementation (β-independent W)
95    ///
96    /// Families whose W is β-independent (e.g. Gaussian-identity where
97    /// `W = prior_w`) return a clone of their frozen W by delegating to
98    /// `evaluate_full()`. No recomputation is performed. `beta` and
99    /// `family_scalars` are ignored.
100    ///
101    /// # Override (β-dependent W)
102    ///
103    /// Families with β-dependent W (e.g. survival marginal-slope where
104    /// `W_i(β)` depends on `(q0_i, q1_i, qd1_i, g_i)`) must override this
105    /// method and recompute W from the current primary state.
106    ///
107    /// When `beta` is non-zero in a way that affects W (i.e. `g_i != 0`),
108    /// `family_scalars` MUST be `Some(..)`. Return `Err` if scalars are
109    /// missing in that case (same error-message style as T26's contract).
110    fn channel_hessian_at(
111        &self,
112        beta: &[f64],
113        family_scalars: Option<&std::sync::Arc<dyn std::any::Any + Send + Sync>>,
114    ) -> Result<Arc<dyn FamilyChannelHessian>, String> {
115        // Default: W is β-independent — return a snapshot of the frozen W
116        // wrapped in a simple tensor-backed implementation. β and
117        // family_scalars are validated (NaN-guard, presence flag) so callers
118        // that pass garbage state still see an Err rather than a silently-stale
119        // W. The default impl does not require family_scalars; family-specific
120        // overrides may.
121        if beta.iter().any(|v| v.is_nan()) {
122            return Err("channel_hessian_at: beta contains NaN".to_string());
123        }
124        // Acknowledge family_scalars without binding it to a discarded name.
125        if family_scalars.is_some() && beta.is_empty() {
126            return Err(
127                "channel_hessian_at: family_scalars supplied but beta is empty".to_string(),
128            );
129        }
130        let tensor = self.evaluate_full();
131        Ok(Arc::new(TensorChannelHessian { h: tensor }))
132    }
133}
134
135/// A [`FamilyChannelHessian`] backed directly by a pre-computed
136/// `(n × K × K)` tensor. Used by the default `channel_hessian_at`
137/// implementation and by tests.
138///
139/// This is the β-independent path: `fill_subject` reads from the frozen
140/// tensor without any recomputation.
141pub struct TensorChannelHessian {
142    pub h: ndarray::Array3<f64>,
143}
144
145impl FamilyChannelHessian for TensorChannelHessian {
146    fn n_outputs(&self) -> usize {
147        self.h.shape()[1]
148    }
149
150    fn n_subjects(&self) -> usize {
151        self.h.shape()[0]
152    }
153
154    fn fill_subject(&self, i: usize, out: &mut [f64]) {
155        let k = self.h.shape()[1];
156        assert_eq!(out.len(), k * k);
157        for a in 0..k {
158            for b in 0..k {
159                out[a * k + b] = self.h[[i, a, b]];
160            }
161        }
162    }
163
164    fn evaluate_full(&self) -> ndarray::Array3<f64> {
165        self.h.clone()
166    }
167}
168
169/// β-linearization state passed to [`BlockEffectiveJacobian::effective_jacobian_at`].
170///
171/// At pre-fit initialization, pass `beta = &[]` / zeros and `family_scalars = None`.
172/// Families that need β-dependent scalars (e.g. survival marginal-slope's q0, q1,
173/// g, c, z) store them in `family_scalars` as a concrete type behind
174/// `Arc<dyn Any + Send + Sync>` and downcast inside their impl.
pub struct FamilyLinearizationState<'a> {
    /// Coefficient vector at the linearization point, borrowed from the caller.
    /// May be empty when no linearization point is available yet (see the
    /// pre-fit convention in the type-level docs).
    pub beta: &'a [f64],
    /// Optional family-shared scalars at this β linearization.
    /// Downcast via `state.family_scalars.as_ref().and_then(|a| a.downcast_ref::<T>())`.
    pub family_scalars: Option<Arc<dyn std::any::Any + Send + Sync>>,
    /// Optional per-subject channel Hessian for multi-output families.
    /// When `Some`, the identifiability canonicalisation step and the Gram
    /// builder use the channel-stacked Fisher information instead of the
    /// scalar-weight approximation.  Single-output families leave this `None`.
    pub channel_hessian: Option<Arc<dyn FamilyChannelHessian>>,
    /// Probit frailty scale factor `s_f = 1/√(1+σ²)`.
    ///
    /// For survival marginal-slope families the logslope η contribution is
    /// `s_f · g · z`, so any Jacobian callback that depends on g or z must
    /// read `s_f` from here rather than from a captured-at-construction value.
    /// When σ = 0 (no frailty) or for non-frailty families, set this to 1.0.
    ///
    /// Since σ is always **fixed** (not jointly optimised with β) in the
    /// survival family, `s_f` is a static scalar for the entire inner fit;
    /// `∂s_f/∂σ` never appears in the β-Jacobian.  The field is nonetheless
    /// carried through state so that Jacobian callbacks are not required to
    /// capture `s_f` at spec-construction time — they can read it at
    /// evaluation time and thus stay correct across outer-loop σ updates.
    pub probit_frailty_scale: f64,
}
200
201/// β-dependent Jacobian callback for a parameter block.
202///
203/// Principled long-term contract for expressing how a block contributes to
204/// the stacked linear predictor at a given β:
205///
206/// ```text
207/// J(β) ∈ ℝ^{n_rows · n_outputs × p_block}
208/// ```
209///
210/// - Single-output linear block: returns `design.clone()`.
211/// - Row-scaled block (`RowScaledJacobian`): returns `diag(eta_scaling) · design` (still linear in β).
212/// - Multi-output block (e.g. survival marginal-slope with η0, η1, ad1):
213///   stacks `∂eta_r/∂β_k` for `r ∈ 0..n_outputs`, row-major ordering.
214///
215/// The default impl on [`ParameterBlockSpec::effective_jacobian_at`] is:
216/// - `jacobian_callback = None` → `design.clone()`.
217/// - `jacobian_callback = Some(cb)` → delegates to `cb.effective_jacobian_at`.
218pub trait BlockEffectiveJacobian: Send + Sync {
219    /// Stacked multi-output Jacobian for a contiguous observation row range.
220    ///
221    /// Shape: `(rows.len() * n_outputs, p_block)`, with the same channel-major
222    /// layout as [`Self::effective_jacobian_at`]: row
223    /// `channel * rows.len() + local_row` is `rows.start + local_row` in that
224    /// output channel. Implementations should keep this as the single source of
225    /// row math so large construction-time audits can stream chunks instead of
226    /// materialising all `n * p * K` entries at once.
227    fn effective_jacobian_rows(
228        &self,
229        state: &FamilyLinearizationState<'_>,
230        rows: Range<usize>,
231    ) -> Result<Array2<f64>, String>;
232
233    /// Stacked multi-output Jacobian at the current β.
234    ///
235    /// Shape: `(n_rows * n_outputs, p_block)`, **channel-major**: rows
236    /// `r * n_rows .. (r + 1) * n_rows` carry output channel `r`'s row
237    /// Jacobian, so `stacked[r * n_rows + i, j]` is observation `i`'s row at
238    /// output `r` and coefficient column `j`.  Every consumer that destacks
239    /// this matrix (audit, canonicaliser, fit) relies on this layout — see
240    /// `BlockJacobianAsRowOp::from_callback` for the destacking transpose.
241    /// For `n_outputs = 1` this is identical to the `(n_rows, p_block)` effective
242    /// design used by the flat identifiability audit.
243    fn effective_jacobian_at(
244        &self,
245        state: &FamilyLinearizationState<'_>,
246    ) -> Result<Array2<f64>, String> {
247        let full = self.effective_jacobian_rows(state, 0..usize::MAX)?;
248        Ok(full)
249    }
250
251    /// Number of stacked output channels. 1 for most blocks.
252    fn n_outputs(&self) -> usize {
253        1
254    }
255
256    /// Returns the per-row scaling vector when this callback is a simple
257    /// diagonal-scaling block (`RowScaledJacobian`).  Used by the
258    /// identifiability audit's skewness-aware bias correction (T25).
259    ///
260    /// Returns `None` for all blocks except `RowScaledJacobian`.
261    fn eta_row_scaling_for_skewness(&self) -> Option<Arc<[f64]>> {
262        None
263    }
264}
265
266/// A [`BlockEffectiveJacobian`] for any block that contributes linearly to
267/// exactly one output of a multi-output family.
268///
269/// `own_output` is the zero-based output index that this block drives.
270/// `n_family_outputs` is the total number of outputs (e.g. 2 for location-scale).
271/// `design` is the block's effective design matrix (n × p_block).
272///
273/// The returned Jacobian has shape `(n_family_outputs * n, p_block)`:
274/// rows `own_output * n .. (own_output + 1) * n` contain `design`,
275/// all other rows are zero.
276pub struct AdditiveBlockJacobian {
277    pub design: Array2<f64>,
278    pub own_output: usize,
279    pub n_family_outputs: usize,
280}
281
282impl BlockEffectiveJacobian for AdditiveBlockJacobian {
283    fn effective_jacobian_rows(
284        &self,
285        state: &FamilyLinearizationState<'_>,
286        rows: Range<usize>,
287    ) -> Result<Array2<f64>, String> {
288        let n = self.design.nrows();
289        let p = self.design.ncols();
290        let rows = clamp_jacobian_rows(rows, n);
291        // Additive (linear) block: Jacobian is β-independent — design does
292        // not depend on state.beta. Verify beta contains no NaN when provided.
293        if !state.beta.is_empty() && state.beta.iter().any(|v| v.is_nan()) {
294            return Err(
295                "AdditiveBlockJacobian::effective_jacobian_at: beta contains NaN".to_string(),
296            );
297        }
298        let chunk = rows.end - rows.start;
299        let total_rows = self.n_family_outputs * chunk;
300        let mut jac = Array2::<f64>::zeros((total_rows, p));
301        let row_start = self.own_output * chunk;
302        jac.slice_mut(ndarray::s![row_start..row_start + chunk, ..])
303            .assign(&self.design.slice(ndarray::s![rows.start..rows.end, ..]));
304        Ok(jac)
305    }
306
307    fn n_outputs(&self) -> usize {
308        self.n_family_outputs
309    }
310}
311
312/// A [`BlockEffectiveJacobian`] for a single-output block whose contribution
313/// to the linear predictor is `diag(eta_scaling) · design` (row-wise scaling).
314///
315/// This is the canonical replacement for the former `eta_row_scaling` field on
316/// [`ParameterBlockSpec`].  The identifiability audit's skewness-aware bias
317/// correction can recover the scaling vector via
318/// [`BlockEffectiveJacobian::eta_row_scaling_for_skewness`].
319pub struct RowScaledJacobian {
320    pub design: Arc<Array2<f64>>,
321    pub eta_scaling: Arc<[f64]>,
322}
323
324impl BlockEffectiveJacobian for RowScaledJacobian {
325    fn effective_jacobian_rows(
326        &self,
327        state: &FamilyLinearizationState<'_>,
328        rows: Range<usize>,
329    ) -> Result<Array2<f64>, String> {
330        let n = self.design.nrows();
331        let rows = clamp_jacobian_rows(rows, n);
332        if self.eta_scaling.len() != n {
333            return Err(format!(
334                "RowScaledJacobian: eta_scaling length {} != design nrows {}",
335                self.eta_scaling.len(),
336                n,
337            ));
338        }
339        // Row-scaled blocks are β-linear; verify the linearization point
340        // contains no NaN when β is provided (sanity check on caller state).
341        if !state.beta.is_empty() && state.beta.iter().any(|v| v.is_nan()) {
342            return Err(
343                "RowScaledJacobian::effective_jacobian_at: state.beta contains NaN".to_string(),
344            );
345        }
346        let mut scaled = self
347            .design
348            .slice(ndarray::s![rows.start..rows.end, ..])
349            .to_owned();
350        for local_i in 0..scaled.nrows() {
351            let s = self.eta_scaling[rows.start + local_i];
352            for j in 0..scaled.ncols() {
353                scaled[[local_i, j]] *= s;
354            }
355        }
356        Ok(scaled)
357    }
358
359    fn eta_row_scaling_for_skewness(&self) -> Option<Arc<[f64]>> {
360        Some(Arc::clone(&self.eta_scaling))
361    }
362}
363
/// Clamp a requested row range to `0..n`.
///
/// Both endpoints are capped at `n`. If the capped end falls before the capped
/// start, the result is the empty range positioned at the start, so callers
/// can always compute `end - start` without underflow.
pub(crate) fn clamp_jacobian_rows(rows: Range<usize>, n: usize) -> Range<usize> {
    let lo = usize::min(rows.start, n);
    let hi = usize::min(rows.end, n);
    if hi < lo {
        lo..lo
    } else {
        lo..hi
    }
}
369
370/// A [`BlockEffectiveJacobian`] that composes an inner callback's raw-width
371/// effective Jacobian with a fixed reduced→raw block transform `T_b`
372/// (`p_raw × r_reduced`), so the family sees the **reduced** coordinates by
373/// construction (#933).
374///
375/// The inner callback emits its row Jacobian in the raw coordinate system
376/// (`(rows · k) × p_raw`), the layout every `BlockEffectiveJacobian` impl
377/// produces — channel-major rows, raw columns. Post-multiplying each row by
378/// `T_b` rotates those raw columns into the reduced section: the effective
379/// reduced Jacobian is `J_raw · T_b`, with `r_reduced` columns. On the model
380/// `η = J_raw · β_raw = (J_raw · T_b) · θ` this is the exact reduced operator
381/// for the reduced coefficient θ, and the family lifts θ back to β_raw through
382/// the SAME `T_b` via the one [`Gauge`](crate::Gauge).
383///
384/// This is the inversion #933 calls for: instead of forwarding a raw-width
385/// callback alongside a column-selection `T_i` (which leaves the family
386/// asserting raw column counts on a reduced spec and panicking), the callback
387/// is wrapped so its output already has the reduced width — the family captures
388/// the reduced design and its row-Hessian column-count assertions hold by
389/// construction. A column-selection `T_b` (zero/one entries) makes this exactly
390/// the audit's drop; a general orthonormal `T_b` makes it any gauge section.
391pub struct GaugeComposedJacobian {
392    inner: Arc<dyn BlockEffectiveJacobian>,
393    /// Reduced→raw block transform `T_b`, shape `(p_raw × r_reduced)`.
394    t_block: Arc<Array2<f64>>,
395}
396
397impl GaugeComposedJacobian {
398    /// Wrap `inner` so its effective Jacobian is post-multiplied by `t_block`
399    /// (`p_raw × r_reduced`). `t_block.nrows()` must equal the inner callback's
400    /// raw column count.
401    pub fn new(inner: Arc<dyn BlockEffectiveJacobian>, t_block: Arc<Array2<f64>>) -> Self {
402        Self { inner, t_block }
403    }
404}
405
406impl BlockEffectiveJacobian for GaugeComposedJacobian {
407    fn effective_jacobian_rows(
408        &self,
409        state: &FamilyLinearizationState<'_>,
410        rows: Range<usize>,
411    ) -> Result<Array2<f64>, String> {
412        let raw_width = self.t_block.nrows();
413        let reduced_width = self.t_block.ncols();
414        let lifted_beta;
415        let lifted_state;
416        let zero_raw_beta;
417        let delegate_state = if state.beta.len() == raw_width {
418            state
419        } else if state.beta.len() == reduced_width {
420            lifted_beta = self.t_block.dot(&ndarray::ArrayView1::from(state.beta));
421            lifted_state = FamilyLinearizationState {
422                beta: lifted_beta
423                    .as_slice()
424                    .expect("GaugeComposedJacobian lifted beta is contiguous"),
425                family_scalars: state.family_scalars.clone(),
426                channel_hessian: state.channel_hessian.clone(),
427                probit_frailty_scale: state.probit_frailty_scale,
428            };
429            &lifted_state
430        } else if state.beta.is_empty() {
431            zero_raw_beta = ndarray::Array1::<f64>::zeros(raw_width);
432            lifted_state = FamilyLinearizationState {
433                beta: zero_raw_beta
434                    .as_slice()
435                    .expect("GaugeComposedJacobian zero raw beta is contiguous"),
436                family_scalars: state.family_scalars.clone(),
437                channel_hessian: state.channel_hessian.clone(),
438                probit_frailty_scale: state.probit_frailty_scale,
439            };
440            &lifted_state
441        } else {
442            return Err(format!(
443                "GaugeComposedJacobian: beta has length {}, expected raw width {} \
444                 or reduced width {}; this wrapper cannot infer a block slice from a joint \
445                 coefficient vector",
446                state.beta.len(),
447                raw_width,
448                reduced_width,
449            ));
450        };
451        let j_raw = self.inner.effective_jacobian_rows(delegate_state, rows)?;
452        if j_raw.ncols() != self.t_block.nrows() {
453            return Err(format!(
454                "GaugeComposedJacobian: inner Jacobian has {} columns but T_b has {} rows",
455                j_raw.ncols(),
456                self.t_block.nrows(),
457            ));
458        }
459        // (rows·k × p_raw) · (p_raw × r_reduced) = (rows·k × r_reduced).
460        Ok(j_raw.dot(self.t_block.as_ref()))
461    }
462
463    fn n_outputs(&self) -> usize {
464        self.inner.n_outputs()
465    }
466
467    // Skewness scaling is a raw-row property; reducing the column space does not
468    // change the per-row scaling, so it is forwarded unchanged when present.
469    fn eta_row_scaling_for_skewness(&self) -> Option<Arc<[f64]>> {
470        self.inner.eta_row_scaling_for_skewness()
471    }
472}
473
#[cfg(test)]
mod gauge_composed_jacobian_tests {
    use super::*;
    use ndarray::array;

    /// Linearization state carrying only `beta`: no family scalars, no channel
    /// Hessian, unit frailty scale.
    fn state_at(beta: &[f64]) -> FamilyLinearizationState<'_> {
        FamilyLinearizationState {
            beta,
            family_scalars: None,
            channel_hessian: None,
            probit_frailty_scale: 1.0,
        }
    }

    /// Reduced→raw transform that keeps only the second raw column.
    fn select_second_column() -> Arc<Array2<f64>> {
        Arc::new(array![[0.0], [1.0]])
    }

    /// Callback whose column `j` is the design column scaled by `1 + beta[j]`
    /// (a missing β entry counts as zero), making the output β-dependent.
    struct BetaScaledJacobian {
        design: Array2<f64>,
    }

    impl BlockEffectiveJacobian for BetaScaledJacobian {
        fn effective_jacobian_rows(
            &self,
            state: &FamilyLinearizationState<'_>,
            rows: Range<usize>,
        ) -> Result<Array2<f64>, String> {
            let row_count = self.design.nrows();
            let window = rows.start.min(row_count)..rows.end.min(row_count);
            let mut scaled = self.design.slice(ndarray::s![window, ..]).to_owned();
            for j in 0..scaled.ncols() {
                let factor = 1.0 + state.beta.get(j).copied().unwrap_or(0.0);
                scaled.column_mut(j).mapv_inplace(|entry| entry * factor);
            }
            Ok(scaled)
        }

        fn n_outputs(&self) -> usize {
            1
        }
    }

    #[test]
    fn gauge_composed_jacobian_lifts_reduced_block_beta_before_delegating() {
        let inner: Arc<dyn BlockEffectiveJacobian> = Arc::new(BetaScaledJacobian {
            design: array![[2.0, 3.0], [5.0, 7.0]],
        });
        let wrapped = GaugeComposedJacobian::new(inner, select_second_column());

        // θ = [4] lifts to β_raw = T·θ = [0, 4]; both spellings must agree.
        let reduced = wrapped
            .effective_jacobian_rows(&state_at(&[4.0]), 0..2)
            .expect("reduced beta should be lifted through T before inner callback");
        let raw = wrapped
            .effective_jacobian_rows(&state_at(&[0.0, 4.0]), 0..2)
            .expect("raw beta state remains valid");

        assert_eq!(reduced, raw);
        assert_eq!(reduced, array![[15.0], [35.0]]);
    }

    /// Callback that refuses any β whose length is not the raw column count.
    struct StrictRawWidthJacobian {
        design: Array2<f64>,
    }

    impl BlockEffectiveJacobian for StrictRawWidthJacobian {
        fn effective_jacobian_rows(
            &self,
            state: &FamilyLinearizationState<'_>,
            rows: Range<usize>,
        ) -> Result<Array2<f64>, String> {
            let raw_width = self.design.ncols();
            if state.beta.len() == raw_width {
                Ok(self.design.slice(ndarray::s![rows, ..]).to_owned())
            } else {
                Err(format!(
                    "StrictRawWidthJacobian expected raw beta len {}, got {}",
                    raw_width,
                    state.beta.len(),
                ))
            }
        }
    }

    #[test]
    fn gauge_composed_jacobian_lifts_zero_reduced_beta_before_delegating() {
        let inner: Arc<dyn BlockEffectiveJacobian> = Arc::new(StrictRawWidthJacobian {
            design: array![[2.0, 3.0], [5.0, 7.0]],
        });
        let wrapped = GaugeComposedJacobian::new(inner, select_second_column());

        let reduced = wrapped
            .effective_jacobian_rows(&state_at(&[0.0]), 0..2)
            .expect("zero reduced beta must still be lifted to raw width");

        assert_eq!(reduced, array![[3.0], [7.0]]);
    }

    #[test]
    fn gauge_composed_jacobian_rejects_nonzero_unknown_beta_layout() {
        let inner: Arc<dyn BlockEffectiveJacobian> = Arc::new(BetaScaledJacobian {
            design: array![[2.0, 3.0]],
        });
        let wrapped = GaugeComposedJacobian::new(inner, select_second_column());

        // Length 3 matches neither the raw width (2) nor the reduced width (1).
        let err = wrapped
            .effective_jacobian_rows(&state_at(&[1.0, 0.0, 0.0]), 0..1)
            .expect_err("nonzero joint-layout beta cannot be inferred from one block T");
        assert!(
            err.contains("cannot infer a block slice"),
            "unexpected error: {err}"
        );
    }
}
604
605/// Static specification for one parameter block in a custom family.
606///
607/// `design` and `stacked_design` are two structurally distinct operators:
608///
609/// * `design` is the **canonical, single-channel, n-observation operator**.
610///   `design.nrows()` ALWAYS equals `n_obs` (one row per training
611///   observation).  This is the matrix the identifiability audit, the
612///   shape policy, and every "what shape is this block?" reader inspect.
613///   For most blocks `design` is also the eta-producing operator used by
614///   the solver — see [`Self::solver_design`].
615/// * `stacked_design`, when `Some`, is the **multi-channel eta-producing
616///   operator** used by the solver.  Survival time-varying blocks stack
617///   `[exit; entry; deriv]` into a `(3·n × p)` operator here so the
618///   solver can produce a `3·n`-long `eta` in one mat-vec; the audit
619///   never sees this matrix.  When `None`, the solver uses `design` (the
620///   single-channel default).
621///
622/// The single contract that downstream code can rely on:
623/// `design.nrows() == n_obs`.  No more dual semantics on `design`.
624///
625/// Read access:
626/// * Audit / canonicalize / "n_obs is the row count" code → `&spec.design`.
627/// * Eta-producing solver code → [`Self::solver_design`].
628#[derive(Clone)]
629pub struct ParameterBlockSpec {
630    pub name: String,
631    pub design: DesignMatrix,
632    pub offset: Array1<f64>,
633    /// Block-local penalty matrices (all p_block x p_block).
634    pub penalties: Vec<PenaltyMatrix>,
635    /// Structural nullspace dimension of each penalty matrix (same length as `penalties`).
636    /// Used by the penalty pseudo-logdet to determine rank without numerical thresholds.
637    /// If empty, falls back to eigenvalue-based rank detection.
638    pub nullspace_dims: Vec<usize>,
639    /// Initial log-smoothing parameters for this block (same length as `penalties`).
640    pub initial_log_lambdas: Array1<f64>,
641    /// Optional initial coefficients (defaults to zeros if omitted).
642    pub initial_beta: Option<Array1<f64>>,
643    /// Gauge ownership priority. Higher = more likely to retain a
644    /// redundant direction during canonical-gauge reparameterisation.
645    /// Defaults to 100. Set higher for blocks that should "own" shared
646    /// affine/null-space directions (e.g. baseline time in survival).
647    pub gauge_priority: u8,
648    /// Full β-dependent Jacobian callback.  When `Some`, this is the
649    /// authoritative source for `effective_jacobian_at`.  For simple
650    /// single-output row-scaled blocks use [`RowScaledJacobian`].
651    pub jacobian_callback: Option<Arc<dyn BlockEffectiveJacobian>>,
652    /// Optional multi-channel eta-producing operator used by the solver.
653    ///
654    /// When `Some`, the solver consumes this matrix (typically
655    /// `(k·n × p)` for `k` stacked channels — e.g. survival
656    /// `[exit; entry; deriv]` with `k = 3`) to evaluate `eta = stacked · β + stacked_offset`.
657    /// The audit and shape policy NEVER read this field; they only ever
658    /// inspect `design` (which always has `n_obs` rows).
659    ///
660    /// When `None`, the solver falls back to `design` — the correct
661    /// behavior for every single-channel block (i.e. all non-survival
662    /// time-varying blocks).
663    ///
664    /// Read this field via [`Self::solver_design`], never directly.
665    ///
666    /// Invariant: when `stacked_design = Some(_)`, `stacked_offset` MUST
667    /// also be `Some(_)` and its length MUST equal `stacked_design.nrows()`.
668    pub stacked_design: Option<DesignMatrix>,
669    /// Optional offset paired with [`Self::stacked_design`]. Same Option
670    /// state as `stacked_design` (both `Some` or both `None`).
671    /// Read via [`Self::solver_offset`].
672    pub stacked_offset: Option<Array1<f64>>,
673}
674
675impl std::fmt::Debug for ParameterBlockSpec {
676    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
677        f.debug_struct("ParameterBlockSpec")
678            .field("name", &self.name)
679            .field("design", &self.design)
680            .field("offset", &self.offset)
681            .field("penalties", &self.penalties)
682            .field("nullspace_dims", &self.nullspace_dims)
683            .field("initial_log_lambdas", &self.initial_log_lambdas)
684            .field("initial_beta", &self.initial_beta)
685            .field("gauge_priority", &self.gauge_priority)
686            .field(
687                "jacobian_callback",
688                &self
689                    .jacobian_callback
690                    .as_ref()
691                    .map(|_| "<BlockEffectiveJacobian>"),
692            )
693            .finish()
694    }
695}
696
697impl ParameterBlockSpec {
698    /// Returns a ParameterBlockSpec with sensible defaults for all optional
699    /// fields. Callers using struct literal syntax can use
700    /// `..ParameterBlockSpec::defaults()` to fill in any fields added after
701    /// the literal was written.
702    pub fn defaults() -> Self {
703        Self {
704            name: String::new(),
705            design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
706                ndarray::Array2::<f64>::zeros((0, 0)),
707            )),
708            offset: ndarray::Array1::<f64>::zeros(0),
709            penalties: Vec::new(),
710            nullspace_dims: Vec::new(),
711            initial_log_lambdas: ndarray::Array1::<f64>::zeros(0),
712            initial_beta: None,
713            gauge_priority: 100,
714            jacobian_callback: None,
715            stacked_design: None,
716            stacked_offset: None,
717        }
718    }
719
720    /// Returns the eta-producing operator used by the solver.
721    ///
722    /// Resolution order:
723    ///   1. `stacked_design = Some(d)` → return `d` (multi-channel
724    ///      operator, e.g. `(3n × p)` for survival time-varying blocks).
725    ///   2. otherwise → return `&self.design` (the single-channel default).
726    ///
727    /// Solver code that needs `eta = D · β` MUST call this accessor;
728    /// reading `&self.design` directly silently breaks multi-channel
729    /// (survival LS time-varying) blocks because `self.design.nrows()`
730    /// always equals `n_obs`, never `3·n_obs`.
731    pub fn solver_design(&self) -> &DesignMatrix {
732        self.stacked_design.as_ref().unwrap_or(&self.design)
733    }
734
735    /// Returns the offset paired with [`Self::solver_design`]. When
736    /// `stacked_offset = Some(o)` this returns `&o`; otherwise it falls
737    /// back to `&self.offset`.
738    pub fn solver_offset(&self) -> &Array1<f64> {
739        self.stacked_offset.as_ref().unwrap_or(&self.offset)
740    }
741
742    /// Returns the effective design `D_eff` for this block at β = 0 with no
743    /// family scalars — a convenience wrapper around [`Self::effective_jacobian_at`]
744    /// for the single-output (n_outputs = 1) case.
745    ///
746    /// Callers that need multi-output Jacobians or β-dependent scalars should
747    /// call `effective_jacobian_at` directly with the appropriate state.
748    ///
749    /// Returns `Err` if the design cannot be densified.
750    pub fn effective_design(&self, caller: &str) -> Result<ndarray::Array2<f64>, String> {
751        let p = self.design.ncols();
752        let zeros = vec![0.0f64; p];
753        let state = FamilyLinearizationState {
754            beta: &zeros,
755            family_scalars: None,
756            channel_hessian: None,
757            probit_frailty_scale: 1.0,
758        };
759        self.effective_jacobian_at(caller, &state)
760    }
761
762    /// Returns the β-dependent stacked Jacobian `J(β)` for this block.
763    ///
764    /// Shape: `(n_rows * n_outputs, p_block)`.  For most blocks `n_outputs = 1`
765    /// and the result is the familiar `(n_rows, p_block)` effective design.
766    ///
767    /// Dispatch order:
768    ///   1. `jacobian_callback = Some(cb)` → `cb.effective_jacobian_at(state)`.
769    ///   2. `jacobian_callback = None` → `design.clone()` (ignores `beta` and `family_scalars`).
770    ///
771    /// Returns `Err` if the design cannot be densified.
772    pub fn effective_jacobian_at(
773        &self,
774        caller: &str,
775        state: &FamilyLinearizationState<'_>,
776    ) -> Result<ndarray::Array2<f64>, String> {
777        if let Some(cb) = self.jacobian_callback.as_ref() {
778            return cb.effective_jacobian_at(state);
779        }
780        self.design
781            .try_to_dense_arc(&format!(
782                "{caller}::effective_jacobian_at block '{}'",
783                self.name
784            ))
785            .map(|arc| arc.as_ref().clone())
786    }
787}
788
/// Current state for a parameter block.
///
/// Bundles the block's `beta` vector with the `eta` vector that accompanies
/// it. Both fields are public and the type imposes no invariant between them.
#[derive(Clone, Debug)]
pub struct ParameterBlockState {
    /// Coefficient vector (β) of this block.
    pub beta: Array1<f64>,
    /// Linear predictor of this block.
    ///
    /// NOTE(review): presumably `D · beta` plus the block offset, following
    /// the `eta = D · β` convention described on
    /// `ParameterBlockSpec::solver_design` — confirm against the code that
    /// fills this in.
    pub eta: Array1<f64>,
}
795
796#[derive(Clone)]
797pub struct BlockGeometryDirectionalDerivative {
798    /// Directional derivative of the block design matrix along a coefficient-space direction.
799    pub d_design: Option<Array2<f64>>,
800    /// Directional derivative of the block offset along the same direction.
801    pub d_offset: Array1<f64>,
802}
803
/// Working quantities supplied by a custom family for one block.
///
/// A family returns exactly one of two representations per block: a diagonal
/// IRLS working set in eta space ([`BlockWorkingSet::Diagonal`]) or an exact
/// Newton gradient/Hessian pair in coefficient space
/// ([`BlockWorkingSet::ExactNewton`]).
///
/// # Observed vs expected information (see response.md Section 3)
///
/// For the outer REML/LAML criterion, the Hessian used in log|H| and trace terms
/// must be the **observed** (actual) Hessian at the mode, not the expected Fisher.
///
/// - `ExactNewton`: provides -nabla^2 log L directly, which is the observed Hessian
///   by construction. This is always correct.
///
/// - `Diagonal`: provides IRLS working weights W such that the per-block Hessian
///   is X'WX. For canonical links (logit-Binomial, log-Poisson), W_obs = W_Fisher.
///   For supported non-canonical diagonal links, W must be the observed weight
///   W_obs = W_Fisher - (y-mu)*B so the outer REML uses the exact Laplace
///   Hessian. The matching `CustomFamily::diagonalworking_weights_directional_derivative`
///   callback must differentiate the same observed W surface; silently using Fisher
///   weights or zero `dW` would change the criterion into a PQL-type surrogate.
#[derive(Clone, Debug)]
pub enum BlockWorkingSet {
    /// Standard IRLS/GLM-style diagonal working set for eta-space updates.
    ///
    /// The two arrays are expected to have equal length; the enum itself does
    /// not enforce this, but `BlockWorkingSet::diagonal_checked` does.
    Diagonal {
        /// IRLS pseudo-response for this block's linear predictor.
        working_response: Array1<f64>,
        /// IRLS working weights for this block (non-negative, length n).
        ///
        /// For the inner solver, Fisher or observed weights both find the same mode.
        /// For the outer REML/LAML log|H| term, observed weights are the correct
        /// Laplace choice (see response.md Section 3). Canonical-link families need
        /// no correction since observed = Fisher.
        working_weights: Array1<f64>,
    },
    /// Exact Newton block update in coefficient space.
    ///
    /// `gradient` is nabla log L wrt block coefficients.
    /// `hessian` is -nabla^2 log L wrt block coefficients (positive semidefinite near optimum).
    ///
    /// This is the observed Hessian by construction (actual second derivative of the
    /// log-likelihood), which is the correct quantity for the outer REML Laplace
    /// approximation.
    ExactNewton {
        /// Gradient of the log-likelihood with respect to this block's coefficients.
        gradient: Array1<f64>,
        /// Negative Hessian of the log-likelihood with respect to this block's
        /// coefficients (note the sign: `-nabla^2 log L`, not `nabla^2 log L`).
        hessian: SymmetricMatrix,
    },
}
848
849impl BlockWorkingSet {
850    /// Construct a `Diagonal` working set with the length invariant
851    /// (`working_response.len() == working_weights.len()`) enforced at the
852    /// type boundary. Use this from any new code path that produces a
853    /// diagonal IRLS block; the legacy struct-literal form is preserved for
854    /// existing call sites pending a full migration.
855    #[inline]
856    pub fn diagonal_checked(
857        working_response: Array1<f64>,
858        working_weights: Array1<f64>,
859    ) -> Result<Self, String> {
860        if working_response.len() != working_weights.len() {
861            return Err(format!(
862                "BlockWorkingSet::Diagonal length mismatch: working_response={}, working_weights={}",
863                working_response.len(),
864                working_weights.len(),
865            ));
866        }
867        Ok(Self::Diagonal {
868            working_response,
869            working_weights,
870        })
871    }
872}