Skip to main content

gam_models/gamlss/gaussian/
log_link.rs

// Real concern-organized submodule of the gamlss family stack.
// Cross-module items are re-exported flat through the parent (`gamlss.rs`),
// so `use super::*;` makes the sibling-concern symbols this module references
// resolve through the parent namespace.
5use super::*;
6
7pub struct PoissonLogFamily {
8    pub y: Array1<f64>,
9    pub weights: Array1<f64>,
10}
11
12impl PoissonLogFamily {
13    pub const BLOCK_ETA: usize = 0;
14
15    pub fn parameternames() -> &'static [&'static str] {
16        &["eta"]
17    }
18
19    pub fn parameter_links() -> &'static [ParameterLink] {
20        &[ParameterLink::Log]
21    }
22
23    pub fn metadata() -> FamilyMetadata {
24        FamilyMetadata {
25            name: "poisson_log",
26            parameternames: Self::parameternames(),
27            parameter_links: Self::parameter_links(),
28        }
29    }
30}
31
/// Per-row IRLS contribution produced by a single-parameter log-link family.
///
/// The shared driver `evaluate_log_link_diagonal_irls` consumes these rows and
/// assembles the full `FamilyEvaluation`. As a result, the pieces of code that
/// previously lived inside each family exist in exactly one place:
/// - the size validation;
/// - the per-row loop (invoking the family's y-validation hook, η clamping,
///   saturated `exp`);
/// - the active-clamp / zero-weight w/z guard;
/// - the final return.
#[derive(Debug, Clone, Copy)]
pub(crate) struct DiagonalIrlsRow {
    /// Prior-weighted contribution to the log-likelihood ℓ at this row
    /// (families may drop η-independent constants).
    pub(crate) log_lik_increment: f64,
    /// Unfloored observed Hessian weight; the driver applies the `MIN_WEIGHT`
    /// floor before storing it as the working weight.
    pub(crate) observed_weight: f64,
    /// Per-row Newton step on the working response: `z = e + working_step`,
    /// where `e` is the clamped linear predictor.
    /// Each family computes this with its own (score, denominator); the
    /// driver only handles the active-clamp / zero-weight guard.
    pub(crate) working_step: f64,
}
48
49/// Trait implemented by single-block log-link families that share the
50/// diagonal IRLS structure (Poisson, Gamma). Each impl is responsible only
51/// for the family-specific math: validating `y[i]` and producing the
52/// per-row triple `(ℓ_increment, observed_weight, working_step)`.
53trait LogLinkDiagonalIrlsFamily {
54    /// Short, human-readable name used in size-mismatch errors.
55    fn family_label(&self) -> &'static str;
56
57    /// Read access to the shared (y, prior weights) buffers.
58    fn y(&self) -> &Array1<f64>;
59    fn prior_weights(&self) -> &Array1<f64>;
60
61    /// Optional pre-loop validation hook for parameters outside the
62    /// (y, weights, eta) triple (e.g. Gamma shape > 0).
63    fn validate_self(&self) -> Result<(), String> {
64        Ok(())
65    }
66
67    /// Validate `y[i]` and return an error message if rejected. Default
68    /// implementation enforces only finiteness; concrete families override
69    /// to add domain constraints.
70    fn validate_yi(&self, yi: f64, idx: usize) -> Result<(), String>;
71
72    /// Family-specific per-row math; `m = saturated_exp_eta(eta_clamped)`
73    /// is computed by the driver and handed in.
74    fn row_kernel(&self, yi: f64, e_clamped: f64, m: f64, prior_w: f64) -> DiagonalIrlsRow;
75}
76
77/// Shared IRLS driver for [`LogLinkDiagonalIrlsFamily`]. Centralises the
78/// size-check, η-clamp, saturated-exp, active-clamp guard, ll accumulation,
79/// and `FamilyEvaluation` assembly so all log-link families with the diagonal
80/// structure (Poisson, Gamma) cannot drift apart numerically.
81fn evaluate_log_link_diagonal_irls<F: LogLinkDiagonalIrlsFamily + ?Sized>(
82    family: &F,
83    block_states: &[ParameterBlockState],
84) -> Result<FamilyEvaluation, String> {
85    let label = family.family_label();
86    let eta = &expect_single_block(block_states, label)?.eta;
87    let y = family.y();
88    let prior_weights = family.prior_weights();
89    let n = y.len();
90    if eta.len() != n || prior_weights.len() != n {
91        return Err(GamlssError::DimensionMismatch {
92            reason: format!("{label} input size mismatch"),
93        }
94        .into());
95    }
96    family.validate_self()?;
97
98    let mut ll = 0.0;
99    let mut z = Array1::<f64>::zeros(n);
100    let mut w = Array1::<f64>::zeros(n);
101
102    for i in 0..n {
103        let yi = y[i];
104        family.validate_yi(yi, i)?;
105        let e_raw = eta[i];
106        let e = e_raw.clamp(-ETA_HARD_CLAMP, ETA_HARD_CLAMP);
107        let active_clamp = e != e_raw;
108        let m = saturated_exp_eta(e_raw);
109        let prior_w = prior_weights[i];
110        let row = family.row_kernel(yi, e, m, prior_w);
111        ll += row.log_lik_increment;
112        if prior_w == 0.0 || active_clamp {
113            w[i] = 0.0;
114            z[i] = e_raw;
115        } else {
116            w[i] = floor_positiveweight(row.observed_weight, MIN_WEIGHT);
117            z[i] = e + row.working_step;
118        }
119    }
120
121    Ok(FamilyEvaluation {
122        log_likelihood: ll,
123        blockworking_sets: vec![BlockWorkingSet::diagonal_checked(z, w)?],
124    })
125}
126
127impl LogLinkDiagonalIrlsFamily for PoissonLogFamily {
128    fn family_label(&self) -> &'static str {
129        "PoissonLogFamily"
130    }
131    fn y(&self) -> &Array1<f64> {
132        &self.y
133    }
134    fn prior_weights(&self) -> &Array1<f64> {
135        &self.weights
136    }
137    fn validate_yi(&self, yi: f64, idx: usize) -> Result<(), String> {
138        if !yi.is_finite() || yi < 0.0 {
139            return Err(GamlssError::InvalidInput {
140                reason: format!(
141                    "PoissonLogFamily requires non-negative finite y; found y[{idx}]={yi}"
142                ),
143            }
144            .into());
145        }
146        Ok::<(), _>(())
147    }
148    #[inline]
149    fn row_kernel(&self, yi: f64, e_clamped: f64, m: f64, prior_w: f64) -> DiagonalIrlsRow {
150        // Drop log(y!) constant in objective.
151        let log_lik_increment = prior_w * (yi * e_clamped - m);
152        let dmu = m.max(MIN_DERIV);
153        let var = m.max(MIN_PROB);
154        DiagonalIrlsRow {
155            log_lik_increment,
156            observed_weight: prior_w * (dmu * dmu / var),
157            // (yi - m)/dmu, identical to the previous direct expression.
158            working_step: (yi - m) / signedwith_floor(dmu, MIN_DERIV),
159        }
160    }
161}
162
163impl CustomFamily for PoissonLogFamily {
164    // Preserve the pre-gam#1395 behavior: the trait default flipped to OFF (the
165    // flat-prior exact-Newton objective carries no Jeffreys term), so families
166    // that historically armed the term by default opt back in explicitly.
167    fn joint_jeffreys_term_required(&self) -> bool {
168        true
169    }
170
171    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
172        evaluate_log_link_diagonal_irls(self, block_states)
173    }
174}
175
176impl CustomFamilyGenerative for PoissonLogFamily {
177    fn generativespec(
178        &self,
179        block_states: &[ParameterBlockState],
180    ) -> Result<GenerativeSpec, String> {
181        let eta = &expect_single_block(block_states, "PoissonLogFamily")?.eta;
182        let mean = gamlss_rowwise_map(eta.len(), |i| saturated_exp_eta(eta[i]));
183        Ok(GenerativeSpec {
184            mean,
185            noise: NoiseModel::Poisson,
186        })
187    }
188}
189
190/// Built-in Gamma log-link family (single parameter block, fixed shape).
191#[derive(Clone)]
192pub struct GammaLogFamily {
193    pub y: Array1<f64>,
194    pub weights: Array1<f64>,
195    pub shape: f64,
196}
197
198impl GammaLogFamily {
199    pub const BLOCK_ETA: usize = 0;
200
201    pub fn parameternames() -> &'static [&'static str] {
202        &["eta"]
203    }
204
205    pub fn parameter_links() -> &'static [ParameterLink] {
206        &[ParameterLink::Log]
207    }
208
209    pub fn metadata() -> FamilyMetadata {
210        FamilyMetadata {
211            name: "gamma_log",
212            parameternames: Self::parameternames(),
213            parameter_links: Self::parameter_links(),
214        }
215    }
216}
217
218impl LogLinkDiagonalIrlsFamily for GammaLogFamily {
219    fn family_label(&self) -> &'static str {
220        "GammaLogFamily"
221    }
222    fn y(&self) -> &Array1<f64> {
223        &self.y
224    }
225    fn prior_weights(&self) -> &Array1<f64> {
226        &self.weights
227    }
228    fn validate_self(&self) -> Result<(), String> {
229        if !self.shape.is_finite() || self.shape <= 0.0 {
230            return Err(GamlssError::NonFinite {
231                reason: "GammaLogFamily shape must be finite and > 0".to_string(),
232            }
233            .into());
234        }
235        Ok(())
236    }
237    fn validate_yi(&self, yi: f64, idx: usize) -> Result<(), String> {
238        if !yi.is_finite() || yi <= 0.0 {
239            return Err(GamlssError::InvalidInput {
240                reason: format!("GammaLogFamily requires positive finite y; found y[{idx}]={yi}"),
241            }
242            .into());
243        }
244        Ok::<(), _>(())
245    }
246    #[inline]
247    fn row_kernel(&self, yi: f64, e_clamped: f64, m: f64, prior_w: f64) -> DiagonalIrlsRow {
248        assert!(e_clamped.is_finite());
249        assert!((e_clamped.exp() - m).abs() <= 1.0e-8 * m.abs().max(1.0));
250        // Gamma(shape=k, scale=mu/k), dropping eta-independent constants.
251        let log_lik_increment = prior_w * (-self.shape * (yi / m + m.ln()));
252        // Gamma with log mean is non-canonical. Use the exact observed
253        // η-space curvature -d²ℓ/dη² = prior_w * shape * y / μ, not the
254        // Fisher weight prior_w * shape, so diagonal REML/LAML Hessians
255        // use the true Laplace curvature instead of a PQL/Fisher surrogate.
256        let observed_weight = prior_w * self.shape * yi / m;
257        let score = prior_w * self.shape * (yi / m - 1.0);
258        // Mirror the pre-extraction formula z = e + score / w_floored exactly;
259        // the driver applies MIN_WEIGHT *before* writing w[i], but the old
260        // code divided by the already-floored w[i] for non-degenerate rows,
261        // and the floor only activates on the degenerate `observed_weight <=
262        // MIN_WEIGHT` tail. Reproduce that branch here to preserve bitwise
263        // step shape on every row that used to hit the floor.
264        let w_floored = observed_weight.max(MIN_WEIGHT);
265        DiagonalIrlsRow {
266            log_lik_increment,
267            observed_weight,
268            working_step: score / w_floored,
269        }
270    }
271}
272
273impl CustomFamily for GammaLogFamily {
274    // Preserve the pre-gam#1395 behavior: the trait default flipped to OFF (the
275    // flat-prior exact-Newton objective carries no Jeffreys term), so families
276    // that historically armed the term by default opt back in explicitly.
277    fn joint_jeffreys_term_required(&self) -> bool {
278        true
279    }
280
281    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
282        evaluate_log_link_diagonal_irls(self, block_states)
283    }
284
285    fn diagonalworking_weights_directional_derivative(
286        &self,
287        block_states: &[ParameterBlockState],
288        block_idx: usize,
289        d_eta: &Array1<f64>,
290    ) -> Result<Option<Array1<f64>>, String> {
291        if block_idx != Self::BLOCK_ETA {
292            return Ok(None);
293        }
294        let eta = &expect_single_block(block_states, "GammaLogFamily")?.eta;
295        let n = self.y.len();
296        if eta.len() != n || self.weights.len() != n || d_eta.len() != n {
297            return Err(GamlssError::DimensionMismatch {
298                reason: "GammaLogFamily input size mismatch".to_string(),
299            }
300            .into());
301        }
302        if !self.shape.is_finite() || self.shape <= 0.0 {
303            return Err(GamlssError::NonFinite {
304                reason: "GammaLogFamily shape must be finite and > 0".to_string(),
305            }
306            .into());
307        }
308
309        let mut dw = Array1::<f64>::zeros(n);
310        for i in 0..n {
311            let yi = self.y[i];
312            if !yi.is_finite() || yi <= 0.0 {
313                return Err(GamlssError::InvalidInput {
314                    reason: format!("GammaLogFamily requires positive finite y; found y[{i}]={yi}"),
315                }
316                .into());
317            }
318            let e_raw = eta[i];
319            let e = e_raw.clamp(-ETA_HARD_CLAMP, ETA_HARD_CLAMP);
320            if self.weights[i] == 0.0 || e != e_raw {
321                dw[i] = 0.0;
322                continue;
323            }
324            let m = safe_exp(e).max(MIN_WEIGHT);
325            let observed_weight = self.weights[i] * self.shape * yi / m;
326            // d/dη [prior_weight * shape * y / exp(η)] = -W_obs.
327            // If the positive floor is active, match the evaluated local piece.
328            if observed_weight <= MIN_WEIGHT {
329                dw[i] = 0.0;
330            } else {
331                dw[i] = -observed_weight * d_eta[i];
332            }
333        }
334        Ok(Some(dw))
335    }
336}
337
338impl CustomFamilyGenerative for GammaLogFamily {
339    fn generativespec(
340        &self,
341        block_states: &[ParameterBlockState],
342    ) -> Result<GenerativeSpec, String> {
343        let eta = &expect_single_block(block_states, "GammaLogFamily")?.eta;
344        let mean = gamlss_rowwise_map(eta.len(), |i| saturated_exp_eta(eta[i]));
345        let shape = ndarray::Array1::from_elem(mean.len(), self.shape);
346        Ok(GenerativeSpec {
347            mean,
348            noise: NoiseModel::Gamma { shape },
349        })
350    }
351}