gam_models/gamlss/gaussian/
log_link.rs1use super::*;
6
7pub struct PoissonLogFamily {
8 pub y: Array1<f64>,
9 pub weights: Array1<f64>,
10}
11
12impl PoissonLogFamily {
13 pub const BLOCK_ETA: usize = 0;
14
15 pub fn parameternames() -> &'static [&'static str] {
16 &["eta"]
17 }
18
19 pub fn parameter_links() -> &'static [ParameterLink] {
20 &[ParameterLink::Log]
21 }
22
23 pub fn metadata() -> FamilyMetadata {
24 FamilyMetadata {
25 name: "poisson_log",
26 parameternames: Self::parameternames(),
27 parameter_links: Self::parameter_links(),
28 }
29 }
30}
31
/// One observation's contribution to a diagonal IRLS step, as returned by
/// `LogLinkDiagonalIrlsFamily::row_kernel`.
///
/// It is a small plain-`f64` value, so it is `Copy` and comparable. This makes
/// it cheap to pass around and easy to inspect in tests and debug output.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct DiagonalIrlsRow {
    /// Log-likelihood contribution of this row (the in-file kernels already
    /// scale it by the prior weight).
    pub(crate) log_lik_increment: f64,
    /// Working weight before the evaluator applies its `MIN_WEIGHT` floor.
    pub(crate) observed_weight: f64,
    /// Step in eta-space; the evaluator adds it to the clamped eta to form
    /// the working response.
    pub(crate) working_step: f64,
}
48
49trait LogLinkDiagonalIrlsFamily {
54 fn family_label(&self) -> &'static str;
56
57 fn y(&self) -> &Array1<f64>;
59 fn prior_weights(&self) -> &Array1<f64>;
60
61 fn validate_self(&self) -> Result<(), String> {
64 Ok(())
65 }
66
67 fn validate_yi(&self, yi: f64, idx: usize) -> Result<(), String>;
71
72 fn row_kernel(&self, yi: f64, e_clamped: f64, m: f64, prior_w: f64) -> DiagonalIrlsRow;
75}
76
77fn evaluate_log_link_diagonal_irls<F: LogLinkDiagonalIrlsFamily + ?Sized>(
82 family: &F,
83 block_states: &[ParameterBlockState],
84) -> Result<FamilyEvaluation, String> {
85 let label = family.family_label();
86 let eta = &expect_single_block(block_states, label)?.eta;
87 let y = family.y();
88 let prior_weights = family.prior_weights();
89 let n = y.len();
90 if eta.len() != n || prior_weights.len() != n {
91 return Err(GamlssError::DimensionMismatch {
92 reason: format!("{label} input size mismatch"),
93 }
94 .into());
95 }
96 family.validate_self()?;
97
98 let mut ll = 0.0;
99 let mut z = Array1::<f64>::zeros(n);
100 let mut w = Array1::<f64>::zeros(n);
101
102 for i in 0..n {
103 let yi = y[i];
104 family.validate_yi(yi, i)?;
105 let e_raw = eta[i];
106 let e = e_raw.clamp(-ETA_HARD_CLAMP, ETA_HARD_CLAMP);
107 let active_clamp = e != e_raw;
108 let m = saturated_exp_eta(e_raw);
109 let prior_w = prior_weights[i];
110 let row = family.row_kernel(yi, e, m, prior_w);
111 ll += row.log_lik_increment;
112 if prior_w == 0.0 || active_clamp {
113 w[i] = 0.0;
114 z[i] = e_raw;
115 } else {
116 w[i] = floor_positiveweight(row.observed_weight, MIN_WEIGHT);
117 z[i] = e + row.working_step;
118 }
119 }
120
121 Ok(FamilyEvaluation {
122 log_likelihood: ll,
123 blockworking_sets: vec![BlockWorkingSet::diagonal_checked(z, w)?],
124 })
125}
126
127impl LogLinkDiagonalIrlsFamily for PoissonLogFamily {
128 fn family_label(&self) -> &'static str {
129 "PoissonLogFamily"
130 }
131 fn y(&self) -> &Array1<f64> {
132 &self.y
133 }
134 fn prior_weights(&self) -> &Array1<f64> {
135 &self.weights
136 }
137 fn validate_yi(&self, yi: f64, idx: usize) -> Result<(), String> {
138 if !yi.is_finite() || yi < 0.0 {
139 return Err(GamlssError::InvalidInput {
140 reason: format!(
141 "PoissonLogFamily requires non-negative finite y; found y[{idx}]={yi}"
142 ),
143 }
144 .into());
145 }
146 Ok::<(), _>(())
147 }
148 #[inline]
149 fn row_kernel(&self, yi: f64, e_clamped: f64, m: f64, prior_w: f64) -> DiagonalIrlsRow {
150 let log_lik_increment = prior_w * (yi * e_clamped - m);
152 let dmu = m.max(MIN_DERIV);
153 let var = m.max(MIN_PROB);
154 DiagonalIrlsRow {
155 log_lik_increment,
156 observed_weight: prior_w * (dmu * dmu / var),
157 working_step: (yi - m) / signedwith_floor(dmu, MIN_DERIV),
159 }
160 }
161}
162
163impl CustomFamily for PoissonLogFamily {
164 fn joint_jeffreys_term_required(&self) -> bool {
168 true
169 }
170
171 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
172 evaluate_log_link_diagonal_irls(self, block_states)
173 }
174}
175
176impl CustomFamilyGenerative for PoissonLogFamily {
177 fn generativespec(
178 &self,
179 block_states: &[ParameterBlockState],
180 ) -> Result<GenerativeSpec, String> {
181 let eta = &expect_single_block(block_states, "PoissonLogFamily")?.eta;
182 let mean = gamlss_rowwise_map(eta.len(), |i| saturated_exp_eta(eta[i]));
183 Ok(GenerativeSpec {
184 mean,
185 noise: NoiseModel::Poisson,
186 })
187 }
188}
189
/// Gamma response with a log link on a single linear-predictor block (`eta`).
///
/// `shape` is a plain field of the family, not a fitted block.
#[derive(Clone)]
pub struct GammaLogFamily {
    /// Observed responses; evaluation rejects non-finite or non-positive entries.
    pub y: Array1<f64>,
    /// Prior (case) weights, one per observation; must match `y` in length.
    pub weights: Array1<f64>,
    /// Gamma shape parameter; evaluation requires it to be finite and > 0.
    pub shape: f64,
}
197
198impl GammaLogFamily {
199 pub const BLOCK_ETA: usize = 0;
200
201 pub fn parameternames() -> &'static [&'static str] {
202 &["eta"]
203 }
204
205 pub fn parameter_links() -> &'static [ParameterLink] {
206 &[ParameterLink::Log]
207 }
208
209 pub fn metadata() -> FamilyMetadata {
210 FamilyMetadata {
211 name: "gamma_log",
212 parameternames: Self::parameternames(),
213 parameter_links: Self::parameter_links(),
214 }
215 }
216}
217
218impl LogLinkDiagonalIrlsFamily for GammaLogFamily {
219 fn family_label(&self) -> &'static str {
220 "GammaLogFamily"
221 }
222 fn y(&self) -> &Array1<f64> {
223 &self.y
224 }
225 fn prior_weights(&self) -> &Array1<f64> {
226 &self.weights
227 }
228 fn validate_self(&self) -> Result<(), String> {
229 if !self.shape.is_finite() || self.shape <= 0.0 {
230 return Err(GamlssError::NonFinite {
231 reason: "GammaLogFamily shape must be finite and > 0".to_string(),
232 }
233 .into());
234 }
235 Ok(())
236 }
237 fn validate_yi(&self, yi: f64, idx: usize) -> Result<(), String> {
238 if !yi.is_finite() || yi <= 0.0 {
239 return Err(GamlssError::InvalidInput {
240 reason: format!("GammaLogFamily requires positive finite y; found y[{idx}]={yi}"),
241 }
242 .into());
243 }
244 Ok::<(), _>(())
245 }
246 #[inline]
247 fn row_kernel(&self, yi: f64, e_clamped: f64, m: f64, prior_w: f64) -> DiagonalIrlsRow {
248 assert!(e_clamped.is_finite());
249 assert!((e_clamped.exp() - m).abs() <= 1.0e-8 * m.abs().max(1.0));
250 let log_lik_increment = prior_w * (-self.shape * (yi / m + m.ln()));
252 let observed_weight = prior_w * self.shape * yi / m;
257 let score = prior_w * self.shape * (yi / m - 1.0);
258 let w_floored = observed_weight.max(MIN_WEIGHT);
265 DiagonalIrlsRow {
266 log_lik_increment,
267 observed_weight,
268 working_step: score / w_floored,
269 }
270 }
271}
272
273impl CustomFamily for GammaLogFamily {
274 fn joint_jeffreys_term_required(&self) -> bool {
278 true
279 }
280
281 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
282 evaluate_log_link_diagonal_irls(self, block_states)
283 }
284
285 fn diagonalworking_weights_directional_derivative(
286 &self,
287 block_states: &[ParameterBlockState],
288 block_idx: usize,
289 d_eta: &Array1<f64>,
290 ) -> Result<Option<Array1<f64>>, String> {
291 if block_idx != Self::BLOCK_ETA {
292 return Ok(None);
293 }
294 let eta = &expect_single_block(block_states, "GammaLogFamily")?.eta;
295 let n = self.y.len();
296 if eta.len() != n || self.weights.len() != n || d_eta.len() != n {
297 return Err(GamlssError::DimensionMismatch {
298 reason: "GammaLogFamily input size mismatch".to_string(),
299 }
300 .into());
301 }
302 if !self.shape.is_finite() || self.shape <= 0.0 {
303 return Err(GamlssError::NonFinite {
304 reason: "GammaLogFamily shape must be finite and > 0".to_string(),
305 }
306 .into());
307 }
308
309 let mut dw = Array1::<f64>::zeros(n);
310 for i in 0..n {
311 let yi = self.y[i];
312 if !yi.is_finite() || yi <= 0.0 {
313 return Err(GamlssError::InvalidInput {
314 reason: format!("GammaLogFamily requires positive finite y; found y[{i}]={yi}"),
315 }
316 .into());
317 }
318 let e_raw = eta[i];
319 let e = e_raw.clamp(-ETA_HARD_CLAMP, ETA_HARD_CLAMP);
320 if self.weights[i] == 0.0 || e != e_raw {
321 dw[i] = 0.0;
322 continue;
323 }
324 let m = safe_exp(e).max(MIN_WEIGHT);
325 let observed_weight = self.weights[i] * self.shape * yi / m;
326 if observed_weight <= MIN_WEIGHT {
329 dw[i] = 0.0;
330 } else {
331 dw[i] = -observed_weight * d_eta[i];
332 }
333 }
334 Ok(Some(dw))
335 }
336}
337
338impl CustomFamilyGenerative for GammaLogFamily {
339 fn generativespec(
340 &self,
341 block_states: &[ParameterBlockState],
342 ) -> Result<GenerativeSpec, String> {
343 let eta = &expect_single_block(block_states, "GammaLogFamily")?.eta;
344 let mean = gamlss_rowwise_map(eta.len(), |i| saturated_exp_eta(eta[i]));
345 let shape = ndarray::Array1::from_elem(mean.len(), self.shape);
346 Ok(GenerativeSpec {
347 mean,
348 noise: NoiseModel::Gamma { shape },
349 })
350 }
351}