Skip to main content

gam_models/block_layout/
block_jacobian.rs

1//! Shared additive-plus-wiggle block-Jacobian dispatcher.
2//!
3//! Several multi-output custom families (survival location-scale, the Binomial
4//! location-scale-wiggle family, …) build the [`BlockEffectiveJacobian`] for one
5//! parameter block with the identical structure:
6//!
7//! * a set of **additive blocks**, each of which drives exactly one family
8//!   output (`own_output == block_idx`) via its effective design matrix; and
9//! * an optional **wiggle / link block**, which modulates the inverse link
10//!   nonlinearly and therefore contributes an all-zero effective linear
11//!   Jacobian of shape `(n × p_wiggle)` (anchored at `own_output = 0`).
12//!
13//! Only the family name, the number of family outputs, the additive block ids
14//! and the wiggle block id vary between callers. This module holds the single
15//! canonical implementation so the dispatch is not re-typed by hand.
16
17use crate::custom_family::{AdditiveBlockJacobian, BlockEffectiveJacobian, ParameterBlockSpec};
18
/// Static description of an additive-plus-wiggle family's block layout, used to
/// build the per-block [`BlockEffectiveJacobian`] in one place.
///
/// `additive_blocks` lists the block ids that contribute a linear additive term
/// to a single family output; each such block drives output `block_idx`
/// (i.e. `own_output == block_idx`). `wiggle_block` is the optional nonlinear
/// link-modulation block whose effective linear Jacobian is all zeros; its row
/// count is taken from the first additive block's design.
///
/// The layout only borrows its data (`&str`, `&[usize]`) and is therefore cheap
/// to copy, compare and debug-print.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AdditiveWiggleBlockLayout<'a> {
    /// Family name used as the message prefix and `effective_design` context,
    /// e.g. `"SurvivalLocationScaleFamily"`.
    pub family: &'a str,
    /// Total number of stacked family outputs (e.g. 3 for survival
    /// location-scale, 2 for Binomial location-scale).
    pub n_outputs: usize,
    /// Block ids that contribute a linear additive term to their own output.
    pub additive_blocks: &'a [usize],
    /// Optional nonlinear link-modulation block id (all-zero linear Jacobian).
    pub wiggle_block: Option<usize>,
}
39
40impl AdditiveWiggleBlockLayout<'_> {
41    /// Build the [`BlockEffectiveJacobian`] for `block_idx` under this layout.
42    ///
43    /// * An additive block returns an [`AdditiveBlockJacobian`] over its
44    ///   effective design with `own_output = block_idx` and
45    ///   `n_family_outputs = n_outputs`.
46    /// * The wiggle block returns an [`AdditiveBlockJacobian`] over an
47    ///   `(n × p_wiggle)` zero design anchored at `own_output = 0`, with `n`
48    ///   read from the first additive block's design.
49    pub fn block_effective_jacobian(
50        &self,
51        specs: &[ParameterBlockSpec],
52        block_idx: usize,
53    ) -> Result<Box<dyn BlockEffectiveJacobian>, String> {
54        if block_idx >= specs.len() {
55            return Err(format!(
56                "{}::block_effective_jacobian: block_idx {} out of range ({})",
57                self.family,
58                block_idx,
59                specs.len()
60            ));
61        }
62        if self.additive_blocks.contains(&block_idx) {
63            let context = format!("{}::block_effective_jacobian", self.family);
64            let design = specs[block_idx].effective_design(&context)?;
65            return Ok(Box::new(AdditiveBlockJacobian {
66                design,
67                own_output: block_idx,
68                n_family_outputs: self.n_outputs,
69            }));
70        }
71        if self.wiggle_block == Some(block_idx) {
72            let first_additive = self.additive_blocks[0];
73            let n = specs[first_additive].design.nrows();
74            let p = specs[block_idx].design.ncols();
75            return Ok(Box::new(AdditiveBlockJacobian {
76                design: ndarray::Array2::<f64>::zeros((n, p)),
77                own_output: 0,
78                n_family_outputs: self.n_outputs,
79            }));
80        }
81        Err(format!(
82            "{}::block_effective_jacobian: unknown block_idx {}",
83            self.family, block_idx
84        ))
85    }
86}