1use super::weighted_design_products::{mirror_upper_to_lower, xt_diag_x_design, xt_diag_y_design};
7use super::{
10 BlockwiseTermFitResult, GamlssLambdaLayout, LOCATION_SCALE_N_OUTPUTS,
11 LocationScaleFamilyBuilder, build_location_scale_block, fit_location_scale_terms,
12 identity_penalty, solve_penalizedweighted_projection,
13};
14use crate::block_layout::block_count::validate_block_count;
15use crate::custom_family::{
16 BlockWorkingSet, BlockwiseFitOptions, CustomFamily, CustomFamilyBlockPsiDerivative,
17 FamilyEvaluation, ParameterBlockSpec, ParameterBlockState, PenaltyMatrix,
18};
19use crate::gamlss::GamlssError;
20use crate::model_types::UnifiedFitResult;
21use gam_linalg::matrix::LinearOperator;
22use gam_math::jet_scalar::JetScalar;
23use gam_terms::smooth::{
24 SpatialLengthScaleOptimizationOptions, TermCollectionDesign, TermCollectionSpec,
25};
26use ndarray::{Array1, Array2, s};
27use statrs::function::gamma::ln_gamma;
28
/// Response families supported by the dispersion location-scale model.
///
/// The variant selects the family tag, the link used for the mean predictor,
/// and which row-level likelihood kernel in this file is evaluated.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DispersionFamilyKind {
    /// Negative binomial response; the mean uses the log link.
    NegativeBinomial,
    /// Gamma response; the mean uses the log link.
    Gamma,
    /// Beta response; the only variant whose mean uses the logit link.
    Beta,
    /// Tweedie response with variance power `p`; the mean uses the log link.
    Tweedie { p: f64 },
}
95
96impl DispersionFamilyKind {
97 pub const fn family_tag(self) -> &'static str {
98 match self {
99 DispersionFamilyKind::NegativeBinomial => FAMILY_NEGBIN_LOCATION_SCALE,
100 DispersionFamilyKind::Gamma => FAMILY_GAMMA_LOCATION_SCALE,
101 DispersionFamilyKind::Beta => FAMILY_BETA_LOCATION_SCALE,
102 DispersionFamilyKind::Tweedie { .. } => FAMILY_TWEEDIE_LOCATION_SCALE,
103 }
104 }
105
106 pub(crate) const fn mean_is_logit(self) -> bool {
108 matches!(self, DispersionFamilyKind::Beta)
109 }
110
111 pub fn base_link(self) -> gam_problem::InverseLink {
116 use gam_problem::{InverseLink, StandardLink};
117 if self.mean_is_logit() {
118 InverseLink::Standard(StandardLink::Logit)
119 } else {
120 InverseLink::Standard(StandardLink::Log)
121 }
122 }
123
124 pub fn likelihood_spec(self) -> gam_problem::LikelihoodSpec {
132 use gam_problem::{InverseLink, LikelihoodSpec, ResponseFamily, StandardLink};
133 let response = match self {
134 DispersionFamilyKind::NegativeBinomial => ResponseFamily::NegativeBinomial {
135 theta: 1.0,
136 theta_fixed: false,
137 },
138 DispersionFamilyKind::Gamma => ResponseFamily::Gamma,
139 DispersionFamilyKind::Beta => ResponseFamily::Beta { phi: 1.0 },
140 DispersionFamilyKind::Tweedie { p } => ResponseFamily::Tweedie { p },
141 };
142 let link = if self.mean_is_logit() {
143 InverseLink::Standard(StandardLink::Logit)
144 } else {
145 InverseLink::Standard(StandardLink::Log)
146 };
147 LikelihoodSpec::new(response, link)
148 }
149}
150
/// Tag returned by `DispersionFamilyKind::family_tag` for the negative-binomial kind.
pub const FAMILY_NEGBIN_LOCATION_SCALE: &str = "negbin-location-scale";
/// Tag returned by `DispersionFamilyKind::family_tag` for the Gamma kind.
pub const FAMILY_GAMMA_LOCATION_SCALE: &str = "gamma-location-scale";
/// Tag returned by `DispersionFamilyKind::family_tag` for the Beta kind.
pub const FAMILY_BETA_LOCATION_SCALE: &str = "beta-location-scale";
/// Tag returned by `DispersionFamilyKind::family_tag` for the Tweedie kind.
pub const FAMILY_TWEEDIE_LOCATION_SCALE: &str = "tweedie-location-scale";

/// Symmetric bound on both linear predictors: the row-level kernels clamp
/// `eta_mu` and `eta_d` to `[-30, 30]` before exponentiating them.
pub(super) const DISPERSION_ETA_CLAMP: f64 = 30.0;
/// Floor applied to per-row information/curvature terms in
/// `dispersion_row_kernel` before they are used as divisors.
pub(super) const DISPERSION_MIN_CURVATURE: f64 = 1e-12;

/// Row count above which the per-row loops of the family take the rayon
/// path, provided `rayon::current_thread_index()` is `None` for the caller.
const DISPERSION_PARALLEL_ROW_THRESHOLD: usize = 1024;
172
/// Per-row output of `dispersion_row_kernel`: the likelihood contribution and
/// the diagonal working weight/response pair of each predictor.
pub(super) struct DispersionRowKernel {
    /// Prior-weighted log-likelihood contribution of the row.
    pub(super) loglik: f64,
    /// Working weight of the mean predictor.
    pub(super) mean_weight: f64,
    /// Working response of the mean predictor.
    pub(super) mean_response: f64,
    /// Working weight of the dispersion predictor.
    pub(super) disp_weight: f64,
    /// Working response of the dispersion predictor.
    pub(super) disp_response: f64,
}
181
#[cfg(test)]
mod test_support {
    use super::*;

    /// Negative-binomial weighted negative log-likelihood as a two-variable
    /// jet: variable 0 is `mu`, variable 1 is `theta`.
    #[inline]
    pub(super) fn dispersion_nb_nll_generic<S: gam_math::jet_scalar::JetScalar<2>>(
        yi: f64,
        mu_value: f64,
        theta_value: f64,
        wi: f64,
    ) -> S {
        let mu = S::variable(mu_value, 0);
        let theta = S::variable(theta_value, 1);
        let tpm = theta.add(&mu);

        // Individual log-density terms, combined below in a fixed order.
        let lgamma_theta_plus_y = theta.add(&S::constant(yi)).ln_gamma();
        let lgamma_theta = theta.ln_gamma();
        let lgamma_y_plus_one = S::constant(ln_gamma(yi + 1.0));
        let theta_ln_theta = theta.mul(&theta.ln());
        let theta_ln_tpm = theta.mul(&tpm.ln());
        let y_ln_mu = mu.ln().scale(yi);
        let y_ln_tpm = tpm.ln().scale(yi);

        lgamma_theta_plus_y
            .sub(&lgamma_theta)
            .sub(&lgamma_y_plus_one)
            .add(&theta_ln_theta)
            .sub(&theta_ln_tpm)
            .add(&y_ln_mu)
            .sub(&y_ln_tpm)
            .scale(-wi)
    }

    /// Gamma weighted negative log-likelihood as a two-variable jet:
    /// variable 0 is `mu`, variable 1 is `nu`. `y_pos` is the value whose
    /// logarithm enters the density; `yi` enters the `y / mu` term.
    #[inline]
    pub(super) fn dispersion_gamma_nll_generic<S: gam_math::jet_scalar::JetScalar<2>>(
        yi: f64,
        y_pos: f64,
        mu_value: f64,
        nu_value: f64,
        wi: f64,
    ) -> S {
        let mu = S::variable(mu_value, 0);
        let nu = S::variable(nu_value, 1);

        let nu_ln_nu = nu.mul(&nu.ln());
        let nu_ln_mu = nu.mul(&mu.ln());
        let lgamma_nu = nu.ln_gamma();
        let nu_minus_one_ln_y = nu.sub(&S::constant(1.0)).scale(y_pos.ln());
        let nu_y_over_mu = nu.mul(&mu.recip().scale(yi));

        nu_ln_nu
            .sub(&nu_ln_mu)
            .sub(&lgamma_nu)
            .add(&nu_minus_one_ln_y)
            .sub(&nu_y_over_mu)
            .scale(-wi)
    }

    /// Beta weighted negative log-likelihood as a two-variable jet:
    /// variable 0 is `mu`, variable 1 is `phi`. The response is clamped to
    /// `[1e-12, 1 - 1e-12]` before its logarithms are taken.
    #[inline]
    pub(super) fn dispersion_beta_nll_generic<S: gam_math::jet_scalar::JetScalar<2>>(
        yi: f64,
        mu_value: f64,
        phi_value: f64,
        wi: f64,
    ) -> S {
        let mu = S::variable(mu_value, 0);
        let phi = S::variable(phi_value, 1);
        let one_minus_mu = S::constant(1.0).sub(&mu);
        let yc = yi.clamp(1e-12, 1.0 - 1e-12);
        let a = mu.mul(&phi);
        let b = one_minus_mu.mul(&phi);

        let lgamma_phi = phi.ln_gamma();
        let lgamma_a = a.ln_gamma();
        let lgamma_b = b.ln_gamma();
        let a_minus_one_ln_y = a.sub(&S::constant(1.0)).scale(yc.ln());
        let b_minus_one_ln_1my = b.sub(&S::constant(1.0)).scale((1.0 - yc).ln());

        lgamma_phi
            .sub(&lgamma_a)
            .sub(&lgamma_b)
            .add(&a_minus_one_ln_y)
            .add(&b_minus_one_ln_1my)
            .scale(-wi)
    }

    /// `Order2<2>` counterpart of `dispersion_nb_nll_generic`, routing every
    /// log-gamma through `order2_ln_gamma`.
    #[inline]
    pub(super) fn dispersion_nb_nll_order2(
        yi: f64,
        mu_value: f64,
        theta_value: f64,
        wi: f64,
    ) -> gam_math::jet_scalar::Order2<2> {
        type O2 = gam_math::jet_scalar::Order2<2>;

        let mu = O2::variable(mu_value, 0);
        let theta = O2::variable(theta_value, 1);
        let tpm = theta.add(&mu);
        let theta_plus_y = theta.add(&O2::constant(yi));

        let lgamma_theta_plus_y = order2_ln_gamma(&theta_plus_y);
        let lgamma_theta = order2_ln_gamma(&theta);
        let lgamma_y_plus_one = O2::constant(ln_gamma(yi + 1.0));
        let theta_ln_theta = theta.mul(&theta.ln());
        let theta_ln_tpm = theta.mul(&tpm.ln());
        let y_ln_mu = mu.ln().scale(yi);
        let y_ln_tpm = tpm.ln().scale(yi);

        lgamma_theta_plus_y
            .sub(&lgamma_theta)
            .sub(&lgamma_y_plus_one)
            .add(&theta_ln_theta)
            .sub(&theta_ln_tpm)
            .add(&y_ln_mu)
            .sub(&y_ln_tpm)
            .scale(-wi)
    }

    /// `Order2<2>` counterpart of `dispersion_gamma_nll_generic`, routing the
    /// log-gamma through `order2_ln_gamma`.
    #[inline]
    pub(super) fn dispersion_gamma_nll_order2(
        yi: f64,
        y_pos: f64,
        mu_value: f64,
        nu_value: f64,
        wi: f64,
    ) -> gam_math::jet_scalar::Order2<2> {
        type O2 = gam_math::jet_scalar::Order2<2>;

        let mu = O2::variable(mu_value, 0);
        let nu = O2::variable(nu_value, 1);

        let nu_ln_nu = nu.mul(&nu.ln());
        let nu_ln_mu = nu.mul(&mu.ln());
        let lgamma_nu = order2_ln_gamma(&nu);
        let nu_minus_one_ln_y = nu.sub(&O2::constant(1.0)).scale(y_pos.ln());
        let nu_y_over_mu = nu.mul(&mu.recip().scale(yi));

        nu_ln_nu
            .sub(&nu_ln_mu)
            .sub(&lgamma_nu)
            .add(&nu_minus_one_ln_y)
            .sub(&nu_y_over_mu)
            .scale(-wi)
    }
}
316
317#[inline]
320pub(crate) fn dispersion_beta_nll_order2(
321 yi: f64,
322 mu_value: f64,
323 phi_value: f64,
324 wi: f64,
325) -> gam_math::jet_scalar::Order2<2> {
326 type O2 = gam_math::jet_scalar::Order2<2>;
327
328 let mu = O2::variable(mu_value, 0);
329 let phi = O2::variable(phi_value, 1);
330 let one_minus_mu = O2::constant(1.0).sub(&mu);
331 let yc = yi.clamp(1e-12, 1.0 - 1e-12);
332 let a = mu.mul(&phi);
333 let b = one_minus_mu.mul(&phi);
334 let loglik = order2_ln_gamma(&phi)
335 .sub(&order2_ln_gamma(&a))
336 .sub(&order2_ln_gamma(&b))
337 .add(&a.sub(&O2::constant(1.0)).scale(yc.ln()))
338 .add(&b.sub(&O2::constant(1.0)).scale((1.0 - yc).ln()));
339 loglik.scale(-wi)
340}
341
342#[inline]
343fn order2_ln_gamma<const K: usize>(
344 x: &gam_math::jet_scalar::Order2<K>,
345) -> gam_math::jet_scalar::Order2<K> {
346 gam_math::jet_scalar::Order2(
347 x.0.compose_unary(gam_math::jet_tower::ln_gamma_derivative_stack_order2(x.0.v)),
348 )
349}
350
351#[inline]
373pub(crate) fn dispersion_nb_disp_order2(
374 yi: f64,
375 mu_value: f64,
376 theta_value: f64,
377 wi: f64,
378) -> gam_math::jet_scalar::Order2<1> {
379 type O1 = gam_math::jet_scalar::Order2<1>;
380
381 let mu = O1::constant(mu_value);
382 let theta = O1::variable(theta_value, 0);
383 let tpm = theta.add(&mu);
384 let theta_plus_y = theta.add(&O1::constant(yi));
385 let loglik = order2_ln_gamma(&theta_plus_y)
386 .sub(&order2_ln_gamma(&theta))
387 .sub(&O1::constant(ln_gamma(yi + 1.0)))
388 .add(&theta.mul(&theta.ln()))
389 .sub(&theta.mul(&tpm.ln()))
390 .add(&mu.ln().scale(yi))
391 .sub(&tpm.ln().scale(yi));
392 loglik.scale(-wi)
393}
394
395#[inline]
399pub(crate) fn dispersion_gamma_disp_order2(
400 yi: f64,
401 y_pos: f64,
402 mu_value: f64,
403 nu_value: f64,
404 wi: f64,
405) -> gam_math::jet_scalar::Order2<1> {
406 type O1 = gam_math::jet_scalar::Order2<1>;
407
408 let mu = O1::constant(mu_value);
409 let nu = O1::variable(nu_value, 0);
410 let loglik = nu
411 .mul(&nu.ln())
412 .sub(&nu.mul(&mu.ln()))
413 .sub(&order2_ln_gamma(&nu))
414 .add(&nu.sub(&O1::constant(1.0)).scale(y_pos.ln()))
415 .sub(&nu.mul(&mu.recip().scale(yi)));
416 loglik.scale(-wi)
417}
418
419#[inline]
425pub(crate) fn dispersion_tweedie_disp_order2(
426 yi: f64,
427 eta_mu: f64,
428 eta_d: f64,
429 p: f64,
430 wi: f64,
431) -> gam_math::jet_scalar::Order2<1> {
432 type O1 = gam_math::jet_scalar::Order2<1>;
433
434 let one_minus_p = 1.0 - p;
435 let two_minus_p = 2.0 - p;
436 let mu = O1::constant(eta_mu).exp();
437 let phi = O1::variable(eta_d, 0).scale(-1.0).exp();
438 if yi > 0.0 {
439 let dev = mu
440 .powf(two_minus_p)
441 .scale(1.0 / two_minus_p)
442 .sub(&mu.powf(one_minus_p).scale(yi / one_minus_p))
443 .add(&O1::constant(
444 yi.powf(two_minus_p) / (one_minus_p * two_minus_p),
445 ))
446 .scale(2.0);
447 let loglik = dev
448 .mul(&phi.recip().scale(-0.5))
449 .sub(&phi.scale(2.0 * std::f64::consts::PI).ln().scale(0.5))
450 .sub(&O1::constant(0.5 * p * yi.ln()));
451 loglik.scale(-wi)
452 } else {
453 let c = mu.powf(two_minus_p).scale(1.0 / two_minus_p);
454 let loglik = c.mul(&phi.recip()).scale(-1.0);
455 loglik.scale(-wi)
456 }
457}
458
459#[inline]
475fn dispersion_nb_neg_loglik(yi: f64, mu: f64, theta: f64, wi: f64) -> f64 {
476 let tpm = theta + mu;
477 let s = ln_gamma(theta + yi) - ln_gamma(theta) - ln_gamma(yi + 1.0) + theta * theta.ln()
478 - theta * tpm.ln()
479 + mu.ln() * yi
480 - tpm.ln() * yi;
481 -(s * -wi)
482}
483
484#[inline]
487fn dispersion_gamma_neg_loglik(yi: f64, y_pos: f64, mu: f64, nu: f64, wi: f64) -> f64 {
488 let s = nu * nu.ln() - nu * mu.ln() - ln_gamma(nu) + (nu - 1.0) * y_pos.ln()
492 - nu * ((1.0 / mu) * yi);
493 -(s * -wi)
494}
495
496#[inline]
499fn dispersion_beta_neg_loglik(yi: f64, mu: f64, phi: f64, wi: f64) -> f64 {
500 let one_minus_mu = 1.0 - mu;
501 let yc = yi.clamp(1e-12, 1.0 - 1e-12);
502 let a = mu * phi;
503 let b = one_minus_mu * phi;
504 let s = ln_gamma(phi) - ln_gamma(a) - ln_gamma(b)
505 + (a - 1.0) * yc.ln()
506 + (b - 1.0) * (1.0 - yc).ln();
507 -(s * -wi)
508}
509
/// Scalar Tweedie row likelihood on the predictor scale.
///
/// Uses `mu = exp(eta_mu)` and `phi = exp(-eta_d)`. Rows with `yi > 0` use the
/// deviance-based density expression; all other rows use
/// `-mu^(2-p) / ((2-p) * phi)`. Despite the `neg_loglik` name, the double
/// negation at the end returns the log-density multiplied by `wi`.
#[inline]
fn dispersion_tweedie_neg_loglik(yi: f64, eta_mu: f64, eta_d: f64, p: f64, wi: f64) -> f64 {
    let one_minus_p = 1.0 - p;
    let two_minus_p = 2.0 - p;
    let mu = eta_mu.exp();
    let phi = (-eta_d).exp();
    // mu^(2-p) / (2-p): shared by both branches.
    let mu_power_term = mu.powf(two_minus_p) * (1.0 / two_minus_p);
    let inv_phi = 1.0 / phi;

    let log_density = if yi > 0.0 {
        let cross_term = mu.powf(one_minus_p) * (yi / one_minus_p);
        let response_term = yi.powf(two_minus_p) / (one_minus_p * two_minus_p);
        let dev = (mu_power_term - cross_term + response_term) * 2.0;
        let scaled_deviance = dev * (inv_phi * -0.5);
        let half_log_scale = (phi * (2.0 * std::f64::consts::PI)).ln() * 0.5;
        let response_jacobian = 0.5 * p * yi.ln();
        scaled_deviance - half_log_scale - response_jacobian
    } else {
        (mu_power_term * inv_phi) * -1.0
    };
    -(log_density * -wi)
}
532
533#[inline]
539pub(crate) fn dispersion_row_loglik(
540 kind: DispersionFamilyKind,
541 yi: f64,
542 eta_mu: f64,
543 eta_d: f64,
544 prior_weight: f64,
545) -> f64 {
546 let wi = prior_weight.max(0.0);
547 let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
548 let ed = eta_d.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
549 match kind {
550 DispersionFamilyKind::NegativeBinomial => {
551 let mu = em.exp().max(1e-300);
552 let theta = ed.exp().max(1e-12);
553 dispersion_nb_neg_loglik(yi, mu, theta, wi)
554 }
555 DispersionFamilyKind::Gamma => {
556 let mu = em.exp().max(1e-300);
557 let nu = ed.exp().max(1e-12);
558 let y_pos = yi.max(1e-300);
559 dispersion_gamma_neg_loglik(yi, y_pos, mu, nu, wi)
560 }
561 DispersionFamilyKind::Beta => {
562 let mu = (1.0 / (1.0 + (-em).exp())).clamp(1e-12, 1.0 - 1e-12);
563 let phi = ed.exp().max(1e-12);
564 dispersion_beta_neg_loglik(yi, mu, phi, wi)
565 }
566 DispersionFamilyKind::Tweedie { p } => dispersion_tweedie_neg_loglik(yi, em, ed, p, wi),
567 }
568}
569
570#[inline]
571pub(crate) fn beta_observed_cross_weight_eta(yi: f64, mu: f64, phi: f64, wi: f64) -> f64 {
572 let q = (mu * (1.0 - mu)).max(1e-12);
573 let tower = dispersion_beta_nll_order2(yi, mu, phi, wi);
574 q * phi * tower.h()[0][1]
575}
576
577#[inline]
578pub(crate) fn dispersion_row_cross_weight(
579 kind: DispersionFamilyKind,
580 yi: f64,
581 eta_mu: f64,
582 eta_d: f64,
583 prior_weight: f64,
584) -> f64 {
585 let wi = prior_weight.max(0.0);
586 if wi == 0.0 {
587 return 0.0;
588 }
589 let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
590 let ed = eta_d.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
591 match kind {
592 DispersionFamilyKind::Beta => {
593 let mu = (1.0 / (1.0 + (-em).exp())).clamp(1e-12, 1.0 - 1e-12);
594 let phi = ed.exp().max(1e-12);
595 beta_observed_cross_weight_eta(yi, mu, phi, wi)
596 }
597 DispersionFamilyKind::NegativeBinomial
598 | DispersionFamilyKind::Gamma
599 | DispersionFamilyKind::Tweedie { .. } => 0.0,
600 }
601}
602
603#[inline]
604pub(crate) fn tower_score_info<const K: usize>(
605 tower: &gam_math::jet_scalar::Order2<K>,
606 idx: usize,
607 wi: f64,
608) -> (f64, f64) {
609 if wi == 0.0 {
610 (0.0, 0.0)
611 } else {
612 (-tower.g()[idx] / wi, tower.h()[idx][idx] / wi)
613 }
614}
615
/// Computes one row's log-likelihood contribution together with the diagonal
/// working weight and working response of both predictors.
///
/// `prior_weight` is floored at zero and both predictors are clamped to
/// `±DISPERSION_ETA_CLAMP` before use. Each information term is floored at
/// `DISPERSION_MIN_CURVATURE` before it is used as a divisor. In every branch
/// the working response has the form `eta + score / (scale * information)`
/// and the working weight the form `wi * scale^2 * information`, where
/// `scale` is the factor written out in that branch.
pub(super) fn dispersion_row_kernel(
    kind: DispersionFamilyKind,
    yi: f64,
    eta_mu: f64,
    eta_d: f64,
    prior_weight: f64,
) -> DispersionRowKernel {
    let wi = prior_weight.max(0.0);
    let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
    let ed = eta_d.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
    match kind {
        DispersionFamilyKind::NegativeBinomial => {
            let mu = em.exp().max(1e-300);
            let theta = ed.exp().max(1e-12);
            let tpm = theta + mu;
            // One-variable tower in theta; only its value and score are used here.
            let tower = dispersion_nb_disp_order2(yi, mu, theta, wi);
            let (s_theta, _info_theta_observed) = tower_score_info(&tower, 0, wi);
            // The tower holds the weighted negative log-likelihood, hence the sign flip.
            let loglik = -tower.value();
            // Closed-form information for mu; zero-weight rows use the floor directly.
            let info_mu = if wi == 0.0 {
                DISPERSION_MIN_CURVATURE
            } else {
                (theta / (mu * tpm)).max(DISPERSION_MIN_CURVATURE)
            };
            let score_mu = theta * (yi - mu) / (mu * tpm);
            let mean_weight = wi * mu * mu * info_mu;
            let mean_response = em + score_mu / (mu * info_mu);
            // NOTE(review): entry 0 of the stack is presumably the trigamma value itself — confirm.
            let trigamma_theta = gam_math::jet_tower::trigamma_derivative_stack(theta)[0];
            let trigamma_tpm = gam_math::jet_tower::trigamma_derivative_stack(tpm)[0];
            // Closed-form curvature for theta, used instead of the tower's observed value.
            let info_theta_fisher = trigamma_theta - trigamma_tpm - 1.0 / theta + 1.0 / tpm;
            let info_pos = info_theta_fisher.max(DISPERSION_MIN_CURVATURE);
            let disp_weight = wi * theta * theta * info_pos;
            let disp_response = ed + s_theta / (theta * info_pos);
            DispersionRowKernel {
                loglik,
                mean_weight,
                mean_response,
                disp_weight,
                disp_response,
            }
        }
        DispersionFamilyKind::Gamma => {
            let mu = em.exp().max(1e-300);
            let nu = ed.exp().max(1e-12);
            // Floored copy of the response, used only where its logarithm is taken.
            let y_pos = yi.max(1e-300);
            let tower = dispersion_gamma_disp_order2(yi, y_pos, mu, nu, wi);
            let (s_nu, info_nu_raw) = tower_score_info(&tower, 0, wi);
            let loglik = -tower.value();
            // Closed-form information for mu; zero-weight rows use the floor directly.
            let info_mu = if wi == 0.0 {
                DISPERSION_MIN_CURVATURE
            } else {
                (nu / (mu * mu)).max(DISPERSION_MIN_CURVATURE)
            };
            let score_mu = nu * (yi - mu) / (mu * mu);
            let mean_weight = wi * mu * mu * info_mu;
            let mean_response = em + score_mu / (mu * info_mu);
            // The dispersion curvature comes from the tower's Hessian entry, floored.
            let info_nu = info_nu_raw.max(DISPERSION_MIN_CURVATURE);
            let disp_weight = wi * nu * nu * info_nu;
            let disp_response = ed + s_nu / (nu * info_nu);
            DispersionRowKernel {
                loglik,
                mean_weight,
                mean_response,
                disp_weight,
                disp_response,
            }
        }
        DispersionFamilyKind::Beta => {
            // Inverse logit, kept strictly inside (0, 1).
            let mu = (1.0 / (1.0 + (-em).exp())).clamp(1e-12, 1.0 - 1e-12);
            let phi = ed.exp().max(1e-12);
            // mu * (1 - mu): scale factor for the mean predictor.
            let q = (mu * (1.0 - mu)).max(1e-12);
            // Two-variable tower: index 0 is mu, index 1 is phi.
            let tower = dispersion_beta_nll_order2(yi, mu, phi, wi);
            let (score_mu, info_mu_raw) = tower_score_info(&tower, 0, wi);
            let (s_phi, info_phi_raw) = tower_score_info(&tower, 1, wi);
            let loglik = -tower.value();
            let info_mu = info_mu_raw.max(DISPERSION_MIN_CURVATURE);
            let mean_weight = wi * q * q * info_mu;
            let mean_response = em + score_mu / (q * info_mu);
            let info_phi = info_phi_raw.max(DISPERSION_MIN_CURVATURE);
            let disp_weight = wi * phi * phi * info_phi;
            let disp_response = ed + s_phi / (phi * info_phi);
            DispersionRowKernel {
                loglik,
                mean_weight,
                mean_response,
                disp_weight,
                disp_response,
            }
        }
        DispersionFamilyKind::Tweedie { p } => {
            let mu = em.exp().max(1e-300);
            // The dispersion predictor enters with a negative sign: phi = exp(-eta_d).
            let phi = (-ed).exp().max(1e-12);
            let two_minus_p = 2.0 - p;
            let mean_weight = wi * mu.powf(two_minus_p) / phi;
            let mean_response = em + (yi - mu) / mu;
            // The tower is differentiated with respect to eta_d itself, so the
            // dispersion weight/response below need no extra scale factor.
            let tower = dispersion_tweedie_disp_order2(yi, em, ed, p, wi);
            let loglik = -tower.value();
            let (s_eta, info_eta_raw) = tower_score_info(&tower, 0, wi);
            let curvature_eta = if wi == 0.0 {
                DISPERSION_MIN_CURVATURE
            } else {
                info_eta_raw.max(DISPERSION_MIN_CURVATURE)
            };
            let disp_weight = wi * curvature_eta;
            let disp_response = ed + s_eta / curvature_eta;
            DispersionRowKernel {
                loglik,
                mean_weight,
                mean_response,
                disp_weight,
                disp_response,
            }
        }
    }
}
791
/// Two-block family (mean and dispersion predictors) for the supported
/// dispersion response kinds.
#[derive(Clone)]
pub(crate) struct DispersionGlmLocationScaleFamily {
    /// Response family that selects the row-level kernels.
    pub(crate) kind: DispersionFamilyKind,
    /// Observed response, one entry per row.
    pub(crate) y: Array1<f64>,
    /// Prior weight of each row; the kernels floor it at zero.
    pub(crate) weights: Array1<f64>,
}
799
impl DispersionGlmLocationScaleFamily {
    /// Position of the mean block in the block-state and block-spec slices.
    pub(crate) const BLOCK_MEAN: usize = 0;
    /// Position of the dispersion block in the block-state and block-spec slices.
    pub(crate) const BLOCK_DISP: usize = 1;
}
804
805impl CustomFamily for DispersionGlmLocationScaleFamily {
806 fn joint_jeffreys_term_required(&self) -> bool {
810 true
811 }
812
813 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
814 validate_block_count::<GamlssError>(self.kind.family_tag(), 2, block_states.len())?;
815 let eta_mu = &block_states[Self::BLOCK_MEAN].eta;
816 let eta_d = &block_states[Self::BLOCK_DISP].eta;
817 let n = self.y.len();
818 if eta_mu.len() != n || eta_d.len() != n || self.weights.len() != n {
819 return Err(format!(
820 "{} row-count mismatch: y={n}, eta_mu={}, eta_d={}, weights={}",
821 self.kind.family_tag(),
822 eta_mu.len(),
823 eta_d.len(),
824 self.weights.len()
825 ));
826 }
827 let mut mean_weights = Array1::<f64>::zeros(n);
828 let mut mean_response = Array1::<f64>::zeros(n);
829 let mut disp_weights = Array1::<f64>::zeros(n);
830 let mut disp_response = Array1::<f64>::zeros(n);
831
832 let kernels: Vec<DispersionRowKernel> = if rayon::current_thread_index().is_none()
843 && n > DISPERSION_PARALLEL_ROW_THRESHOLD
844 {
845 use rayon::iter::{IntoParallelIterator, ParallelIterator};
846 (0..n)
847 .into_par_iter()
848 .map(|i| {
849 dispersion_row_kernel(self.kind, self.y[i], eta_mu[i], eta_d[i], self.weights[i])
850 })
851 .collect()
852 } else {
853 (0..n)
854 .map(|i| {
855 dispersion_row_kernel(self.kind, self.y[i], eta_mu[i], eta_d[i], self.weights[i])
856 })
857 .collect()
858 };
859
860 let mut log_likelihood = 0.0;
861 for (i, row) in kernels.into_iter().enumerate() {
862 if row.loglik.is_finite() {
863 log_likelihood += row.loglik;
864 }
865 mean_weights[i] = row.mean_weight.max(0.0);
866 mean_response[i] = row.mean_response;
867 disp_weights[i] = row.disp_weight.max(0.0);
868 disp_response[i] = row.disp_response;
869 }
870 Ok(FamilyEvaluation {
871 log_likelihood,
872 blockworking_sets: vec![
873 BlockWorkingSet::diagonal_checked(mean_response, mean_weights)?,
874 BlockWorkingSet::diagonal_checked(disp_response, disp_weights)?,
875 ],
876 })
877 }
878
879 fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
880 validate_block_count::<GamlssError>(self.kind.family_tag(), 2, block_states.len())?;
881 let eta_mu = &block_states[Self::BLOCK_MEAN].eta;
882 let eta_d = &block_states[Self::BLOCK_DISP].eta;
883 let n = self.y.len();
884 let per_row: Vec<f64> = if rayon::current_thread_index().is_none()
893 && n > DISPERSION_PARALLEL_ROW_THRESHOLD
894 {
895 use rayon::iter::{IntoParallelIterator, ParallelIterator};
896 (0..n)
897 .into_par_iter()
898 .map(|i| {
899 dispersion_row_loglik(self.kind, self.y[i], eta_mu[i], eta_d[i], self.weights[i])
900 })
901 .collect()
902 } else {
903 (0..n)
904 .map(|i| {
905 dispersion_row_loglik(self.kind, self.y[i], eta_mu[i], eta_d[i], self.weights[i])
906 })
907 .collect()
908 };
909 let mut ll = 0.0;
910 for loglik in per_row {
911 if loglik.is_finite() {
912 ll += loglik;
913 }
914 }
915 Ok(ll)
916 }
917
918 fn coefficient_hessian_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
919 crate::location_scale_engine::location_scale_coefficient_hessian_cost(
920 self.y.len() as u64,
921 specs,
922 )
923 }
924
925 fn exact_newton_joint_hessian_with_specs(
944 &self,
945 block_states: &[ParameterBlockState],
946 specs: &[ParameterBlockSpec],
947 ) -> Result<Option<Array2<f64>>, String> {
948 validate_block_count::<GamlssError>(self.kind.family_tag(), 2, block_states.len())?;
949 if specs.len() != 2 {
950 return Err(format!(
951 "{} exact joint Hessian expects 2 specs, got {}",
952 self.kind.family_tag(),
953 specs.len()
954 ));
955 }
956 let eta_mu = &block_states[Self::BLOCK_MEAN].eta;
957 let eta_d = &block_states[Self::BLOCK_DISP].eta;
958 let n = self.y.len();
959 if eta_mu.len() != n || eta_d.len() != n || self.weights.len() != n {
960 return Err(format!(
961 "{} exact joint Hessian row-count mismatch: y={n}, eta_mu={}, eta_d={}, weights={}",
962 self.kind.family_tag(),
963 eta_mu.len(),
964 eta_d.len(),
965 self.weights.len()
966 ));
967 }
968
969 let eval = self.evaluate(block_states)?;
970 let BlockWorkingSet::Diagonal {
971 working_weights: mean_weights,
972 ..
973 } = &eval.blockworking_sets[Self::BLOCK_MEAN]
974 else {
975 return Err(format!(
976 "{} dispersion mean block did not return diagonal weights",
977 self.kind.family_tag()
978 ));
979 };
980 let BlockWorkingSet::Diagonal {
981 working_weights: disp_weights,
982 ..
983 } = &eval.blockworking_sets[Self::BLOCK_DISP]
984 else {
985 return Err(format!(
986 "{} dispersion precision block did not return diagonal weights",
987 self.kind.family_tag()
988 ));
989 };
990
991 let cross_weights = if rayon::current_thread_index().is_none()
997 && n > DISPERSION_PARALLEL_ROW_THRESHOLD
998 {
999 use rayon::iter::{IntoParallelIterator, ParallelIterator};
1000 Array1::from_vec(
1001 (0..n)
1002 .into_par_iter()
1003 .map(|i| {
1004 dispersion_row_cross_weight(
1005 self.kind,
1006 self.y[i],
1007 eta_mu[i],
1008 eta_d[i],
1009 self.weights[i],
1010 )
1011 })
1012 .collect::<Vec<f64>>(),
1013 )
1014 } else {
1015 Array1::from_shape_fn(n, |i| {
1016 dispersion_row_cross_weight(
1017 self.kind,
1018 self.y[i],
1019 eta_mu[i],
1020 eta_d[i],
1021 self.weights[i],
1022 )
1023 })
1024 };
1025 let mean_spec = &specs[Self::BLOCK_MEAN];
1026 let disp_spec = &specs[Self::BLOCK_DISP];
1027 if mean_spec.design.nrows() != n || disp_spec.design.nrows() != n {
1028 return Err(format!(
1029 "{} exact joint Hessian design row mismatch: y={n}, mean rows={}, precision rows={}",
1030 self.kind.family_tag(),
1031 mean_spec.design.nrows(),
1032 disp_spec.design.nrows()
1033 ));
1034 }
1035 let p_mean = mean_spec.design.ncols();
1036 let p_disp = disp_spec.design.ncols();
1037 if block_states[Self::BLOCK_MEAN].beta.len() != p_mean
1038 || block_states[Self::BLOCK_DISP].beta.len() != p_disp
1039 {
1040 return Err(format!(
1041 "{} exact joint Hessian beta/design mismatch: mean beta {} vs cols {}, precision beta {} vs cols {}",
1042 self.kind.family_tag(),
1043 block_states[Self::BLOCK_MEAN].beta.len(),
1044 p_mean,
1045 block_states[Self::BLOCK_DISP].beta.len(),
1046 p_disp
1047 ));
1048 }
1049
1050 let h_mean = xt_diag_x_design(&mean_spec.design, mean_weights)?;
1051 let h_cross = xt_diag_y_design(&mean_spec.design, &cross_weights, &disp_spec.design)?;
1052 let h_disp = xt_diag_x_design(&disp_spec.design, disp_weights)?;
1053 let total = p_mean + p_disp;
1054 let mut h = Array2::<f64>::zeros((total, total));
1055 h.slice_mut(s![0..p_mean, 0..p_mean]).assign(&h_mean);
1056 h.slice_mut(s![0..p_mean, p_mean..total]).assign(&h_cross);
1057 h.slice_mut(s![p_mean..total, p_mean..total])
1058 .assign(&h_disp);
1059 mirror_upper_to_lower(&mut h);
1060 Ok(Some(h))
1061 }
1062
1063 fn likelihood_blocks_uncoupled(&self) -> bool {
1077 !matches!(self.kind, DispersionFamilyKind::Beta)
1078 }
1079
1080 fn outer_hyper_hessian_dense_available(&self, specs: &[ParameterBlockSpec]) -> bool {
1091 assert!(
1092 crate::custom_family::validate_blockspec_consistency(specs).is_ok(),
1093 "DispersionGlmLocationScale outer hyper-Hessian dense availability: \
1094 inconsistent parameter block specs"
1095 );
1096 specs.len() < 2
1097 }
1098}
1099
/// Public input bundle for `fit_dispersion_glm_location_scale_terms`.
pub struct DispersionGlmLocationScaleTermSpec {
    /// Response family to fit.
    pub kind: DispersionFamilyKind,
    /// Observed response, one entry per row.
    pub y: Array1<f64>,
    /// Prior weight of each row.
    pub weights: Array1<f64>,
    /// Term collection describing the mean predictor.
    pub meanspec: TermCollectionSpec,
    /// Term collection describing the dispersion predictor; becomes the
    /// builder's `noisespec`.
    pub log_dispspec: TermCollectionSpec,
    /// Offset of the mean predictor.
    pub mean_offset: Array1<f64>,
    /// Offset of the dispersion predictor; becomes the builder's `noise_offset`.
    pub log_disp_offset: Array1<f64>,
}
1112
/// Crate-internal builder that implements `LocationScaleFamilyBuilder` for
/// the dispersion families.
pub(crate) struct DispersionGlmLocationScaleTermBuilder {
    /// Response family to fit.
    pub(crate) kind: DispersionFamilyKind,
    /// Observed response, one entry per row.
    pub(crate) y: Array1<f64>,
    /// Prior weight of each row.
    pub(crate) weights: Array1<f64>,
    /// Term collection describing the mean predictor.
    pub(crate) meanspec: TermCollectionSpec,
    /// Term collection describing the dispersion ("noise") predictor.
    pub(crate) noisespec: TermCollectionSpec,
    /// Offset of the mean block.
    pub(crate) mean_offset: Array1<f64>,
    /// Offset of the dispersion block.
    pub(crate) noise_offset: Array1<f64>,
}
1122
1123pub(crate) fn dispersion_location_scale_warm_start(
1127 kind: DispersionFamilyKind,
1128 y: &Array1<f64>,
1129 weights: &Array1<f64>,
1130 mean_block: &ParameterBlockSpec,
1131 disp_block: &ParameterBlockSpec,
1132 mean_beta_hint: Option<&Array1<f64>>,
1133 disp_beta_hint: Option<&Array1<f64>>,
1134) -> Result<(Array1<f64>, Array1<f64>), String> {
1135 let ridge_floor = 1e-10;
1136 let mean_beta = if let Some(beta) = mean_beta_hint {
1137 beta.clone()
1138 } else {
1139 let target = Array1::from_shape_fn(y.len(), |i| {
1140 if kind.mean_is_logit() {
1141 let yi = y[i].clamp(1e-3, 1.0 - 1e-3);
1142 (yi / (1.0 - yi)).ln()
1143 } else {
1144 (y[i].max(0.0) + 0.1).ln()
1146 }
1147 });
1148 solve_penalizedweighted_projection(
1149 &mean_block.design,
1150 &mean_block.offset,
1151 &target,
1152 weights,
1153 &mean_block.penalties,
1154 &mean_block.initial_log_lambdas,
1155 ridge_floor,
1156 )?
1157 };
1158 let disp_beta = if let Some(beta) = disp_beta_hint {
1159 beta.clone()
1160 } else {
1161 let mean_eta = mean_block.design.apply(&mean_beta) + &mean_block.offset;
1176 let target = Array1::from_shape_fn(y.len(), |i| {
1177 dispersion_moment_log_precision_seed(kind, y[i], mean_eta[i])
1178 });
1179 solve_penalizedweighted_projection(
1180 &disp_block.design,
1181 &disp_block.offset,
1182 &target,
1183 weights,
1184 &disp_block.penalties,
1185 &disp_block.initial_log_lambdas,
1186 ridge_floor,
1187 )?
1188 };
1189 Ok((mean_beta, disp_beta))
1190}
1191
1192#[inline]
1193fn dispersion_moment_log_precision_seed(kind: DispersionFamilyKind, yi: f64, eta_mu: f64) -> f64 {
1194 const LOG_PRECISION_FLOOR: f64 = -10.0;
1195 const LOG_PRECISION_CEILING: f64 = 10.0;
1196 let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
1197 let raw = match kind {
1198 DispersionFamilyKind::Beta => {
1199 0.0
1208 }
1209 DispersionFamilyKind::Gamma => {
1210 let mu = em.exp().max(1e-12);
1211 let e2 = (yi - mu).powi(2).max(1e-8 * mu * mu);
1212 (mu * mu / e2).max(1e-6).ln()
1213 }
1214 DispersionFamilyKind::NegativeBinomial => {
1215 let mu = em.exp().max(1e-12);
1216 let e2 = (yi - mu).powi(2);
1217 let excess = (e2 - mu).max(1e-6 * (mu + mu * mu));
1218 (mu * mu / excess).max(1e-6).ln()
1219 }
1220 DispersionFamilyKind::Tweedie { p } => {
1221 let mu = em.exp().max(1e-12);
1222 let e2 = (yi - mu).powi(2).max(1e-8 * mu.powf(p));
1223 (mu.powf(p) / e2).max(1e-6).ln()
1224 }
1225 };
1226 raw.clamp(LOG_PRECISION_FLOOR, LOG_PRECISION_CEILING)
1227}
1228
1229impl LocationScaleFamilyBuilder for DispersionGlmLocationScaleTermBuilder {
1230 type Family = DispersionGlmLocationScaleFamily;
1231
1232 fn meanspec(&self) -> &TermCollectionSpec {
1233 &self.meanspec
1234 }
1235
1236 fn noisespec(&self) -> &TermCollectionSpec {
1237 &self.noisespec
1238 }
1239
1240 fn noise_penalty_count(&self, noise_design: &TermCollectionDesign) -> usize {
1241 noise_design.penalties.len() + 1
1245 }
1246
1247 fn build_blocks(
1248 &self,
1249 theta: &Array1<f64>,
1250 mean_design: &TermCollectionDesign,
1251 noise_design: &TermCollectionDesign,
1252 mean_beta_hint: Option<Array1<f64>>,
1253 noise_beta_hint: Option<Array1<f64>>,
1254 ) -> Result<Vec<ParameterBlockSpec>, String> {
1255 let layout = GamlssLambdaLayout::two_block(
1256 mean_design.penalties.len(),
1257 self.noise_penalty_count(noise_design),
1258 );
1259 layout.validate_theta_len(theta.len(), "dispersion location-scale")?;
1260
1261 let mut meanspec = build_location_scale_block(
1262 "mu",
1263 mean_design.design.clone(),
1264 self.mean_offset.clone(),
1265 mean_design.penalties_as_penalty_matrix(),
1266 mean_design.nullspace_dims.clone(),
1267 layout.mean_from(theta),
1268 mean_beta_hint,
1269 0,
1270 LOCATION_SCALE_N_OUTPUTS,
1271 "DispersionLocationScale::build_blocks: mu",
1272 )?;
1273
1274 let p_disp = noise_design.design.ncols();
1275 let mut disp_penalties = noise_design.penalties_as_penalty_matrix();
1276 disp_penalties.push(PenaltyMatrix::Dense(identity_penalty(p_disp)));
1277 let mut disp_nullspace = noise_design.nullspace_dims.clone();
1278 disp_nullspace.push(0);
1279 let mut dispspec = build_location_scale_block(
1280 "log_precision",
1281 noise_design.design.clone(),
1282 self.noise_offset.clone(),
1283 disp_penalties,
1284 disp_nullspace,
1285 layout.noise_from(theta),
1286 noise_beta_hint,
1287 1,
1288 LOCATION_SCALE_N_OUTPUTS,
1289 "DispersionLocationScale::build_blocks: log_precision",
1290 )?;
1291
1292 if meanspec.initial_beta.is_none() || dispspec.initial_beta.is_none() {
1293 let (mean_beta0, disp_beta0) = dispersion_location_scale_warm_start(
1294 self.kind,
1295 &self.y,
1296 &self.weights,
1297 &meanspec,
1298 &dispspec,
1299 meanspec.initial_beta.as_ref(),
1300 dispspec.initial_beta.as_ref(),
1301 )?;
1302 if meanspec.initial_beta.is_none() {
1303 meanspec.initial_beta = Some(mean_beta0);
1304 }
1305 if dispspec.initial_beta.is_none() {
1306 dispspec.initial_beta = Some(disp_beta0);
1307 }
1308 }
1309
1310 Ok(vec![meanspec, dispspec])
1311 }
1312
1313 fn build_family(
1314 &self,
1315 mean_design: &TermCollectionDesign,
1316 noise_design: &TermCollectionDesign,
1317 ) -> Self::Family {
1318 assert_eq!(
1325 mean_design.design.nrows(),
1326 self.y.len(),
1327 "DispersionGlmLocationScale::build_family: mean design row count must match y"
1328 );
1329 assert_eq!(
1330 noise_design.design.nrows(),
1331 self.y.len(),
1332 "DispersionGlmLocationScale::build_family: noise design row count must match y"
1333 );
1334 DispersionGlmLocationScaleFamily {
1335 kind: self.kind,
1336 y: self.y.clone(),
1337 weights: self.weights.clone(),
1338 }
1339 }
1340
1341 fn extract_primary_betas(
1342 &self,
1343 fit: &UnifiedFitResult,
1344 ) -> Result<(Array1<f64>, Array1<f64>), String> {
1345 let mean_beta = fit
1346 .block_states
1347 .get(DispersionGlmLocationScaleFamily::BLOCK_MEAN)
1348 .ok_or_else(|| "missing dispersion mean block state".to_string())?
1349 .beta
1350 .clone();
1351 let disp_beta = fit
1352 .block_states
1353 .get(DispersionGlmLocationScaleFamily::BLOCK_DISP)
1354 .ok_or_else(|| "missing dispersion log-precision block state".to_string())?
1355 .beta
1356 .clone();
1357 Ok((mean_beta, disp_beta))
1358 }
1359
1360 fn build_psiderivative_blocks(
1361 &self,
1362 data: ndarray::ArrayView2<'_, f64>,
1363 meanspec: &TermCollectionSpec,
1364 noisespec: &TermCollectionSpec,
1365 mean_design: &TermCollectionDesign,
1366 noise_design: &TermCollectionDesign,
1367 ) -> Result<Vec<Vec<CustomFamilyBlockPsiDerivative>>, String> {
1368 Err(format!(
1376 "dispersion location-scale ({:?}) does not implement analytic spatial \
1377 psi derivatives; the κ/ψ joint optimizer must be disabled before \
1378 this builder is consulted. Called with data {n_rows}×{n_cols}, mean \
1379 spec (linear={mean_lin}, random={mean_re}, smooth={mean_sm}), noise \
1380 spec (linear={noise_lin}, random={noise_re}, smooth={noise_sm}), \
1381 mean design cols={mean_p}, noise design cols={noise_p}",
1382 self.kind,
1383 n_rows = data.nrows(),
1384 n_cols = data.ncols(),
1385 mean_lin = meanspec.linear_terms.len(),
1386 mean_re = meanspec.random_effect_terms.len(),
1387 mean_sm = meanspec.smooth_terms.len(),
1388 noise_lin = noisespec.linear_terms.len(),
1389 noise_re = noisespec.random_effect_terms.len(),
1390 noise_sm = noisespec.smooth_terms.len(),
1391 mean_p = mean_design.design.ncols(),
1392 noise_p = noise_design.design.ncols(),
1393 ))
1394 }
1395}
1396
1397pub fn fit_dispersion_glm_location_scale_terms(
1401 data: ndarray::ArrayView2<'_, f64>,
1402 spec: DispersionGlmLocationScaleTermSpec,
1403 options: &BlockwiseFitOptions,
1404 kappa_options: &SpatialLengthScaleOptimizationOptions,
1405) -> Result<BlockwiseTermFitResult, String> {
1406 if let DispersionFamilyKind::Tweedie { p } = spec.kind {
1407 if !(p.is_finite() && p > 1.0 && p < 2.0) {
1408 return Err(format!(
1409 "Tweedie location-scale requires a variance power strictly in (1, 2); got p={p}"
1410 ));
1411 }
1412 }
1413 let mut kappa = kappa_options.clone();
1418 kappa.enabled = false;
1419 let mut options = options.clone();
1437 options.compute_covariance = true;
1438 fit_location_scale_terms(
1439 data,
1440 DispersionGlmLocationScaleTermBuilder {
1441 kind: spec.kind,
1442 y: spec.y,
1443 weights: spec.weights,
1444 meanspec: spec.meanspec,
1445 noisespec: spec.log_dispspec,
1446 mean_offset: spec.mean_offset,
1447 noise_offset: spec.log_disp_offset,
1448 },
1449 &options,
1450 &kappa,
1451 )
1452}
1453
1454#[cfg(test)]
1455mod tests {
1456 use super::*;
1457 use super::test_support::{dispersion_gamma_nll_order2, dispersion_nb_nll_order2};
1458 use crate::gamlss::test_support::dispersion_tweedie_nll_generic;
1459
1460 pub(crate) fn beta_fisher_cross_info_mu_phi(mu: f64, phi: f64) -> f64 {
1461 let a = mu * phi;
1462 let b = (1.0 - mu) * phi;
1463 phi * (mu * gam_math::jet_tower::trigamma_derivative_stack(a)[0]
1464 - (1.0 - mu) * gam_math::jet_tower::trigamma_derivative_stack(b)[0])
1465 }
1466
/// Asserts that `got` lies within absolute tolerance `tol` of `want`.
///
/// # Panics
/// Panics with a message containing `label`, both values and the absolute
/// difference when `|got - want| > tol` (or when the difference is NaN).
pub(crate) fn assert_close(label: &str, got: f64, want: f64, tol: f64) {
    let abs_diff = (got - want).abs();
    assert!(
        abs_diff <= tol,
        "{label}: got {got:.12e}, want {want:.12e}, |diff|={:.3e}",
        abs_diff
    );
}
1474
/// Checks the Beta tower's mixed `[0][1]` Hessian entry against the
/// closed-form cross-information expression at `mu = 0.1`, `phi = 10`, and
/// checks that the eta-scale cross weight equals that value scaled by
/// `mu * (1 - mu) * phi`.
#[test]
pub(crate) fn beta_tower_mixed_channel_matches_cross_information_formula() {
    let digamma0 = |x: f64| gam_math::jet_tower::digamma_derivative_stack(x)[0];
    let trigamma0 = |x: f64| gam_math::jet_tower::trigamma_derivative_stack(x)[0];

    let mu = 0.1;
    let phi = 10.0;
    let shape_a = mu * phi;
    let shape_b = (1.0 - mu) * phi;

    // Response at which the logistic of the digamma gap is evaluated.
    let digamma_at_a = digamma0(shape_a);
    let digamma_at_b = digamma0(shape_b);
    let score_neutral_y = 1.0 / (1.0 + (-(digamma_at_a - digamma_at_b)).exp());

    // Closed-form reference; pi^2 / 6 stands in for the trigamma value at
    // `shape_a`.
    let trigamma_at_a = std::f64::consts::PI * std::f64::consts::PI / 6.0;
    let trigamma_at_b = trigamma0(shape_b);
    let analytic = phi * (mu * trigamma_at_a - (1.0 - mu) * trigamma_at_b);

    let tower = dispersion_beta_nll_order2(score_neutral_y, mu, phi, 1.0);
    let helper = beta_fisher_cross_info_mu_phi(mu, phi);

    assert!(
        analytic > 0.58,
        "audit example should have visibly nonzero cross information, got {analytic}"
    );
    assert_close("helper cross information", helper, analytic, 1e-12);
    assert_close("tower mixed channel", tower.h()[0][1], analytic, 1e-8);

    let q = mu * (1.0 - mu);
    let expected_eta_cross = q * phi * analytic;
    let eta_cross = beta_observed_cross_weight_eta(score_neutral_y, mu, phi, 1.0);
    assert_close("eta-scale cross weight", eta_cross, expected_eta_cross, 1e-8);
}
1507
/// Compares the `Order2<2>` evaluation of each family's negative
/// log-likelihood against the `Tower4<2>` evaluation of the generic kernel:
/// value, both gradient entries and all four Hessian entries must agree
/// within a mixed absolute/relative band of `1e-9`.
#[test]
pub(crate) fn order2_matches_dense_tower_all_channels() {
    use gam_math::jet_scalar::{JetScalar, Order2};
    use gam_math::jet_tower::Tower4;

    fn check_o2_vs_tower4(label: &str, o2: Order2<2>, t4: Tower4<2>) {
        // Tolerance: 1e-9 absolute plus 1e-9 relative to the larger magnitude.
        fn band(a: f64, b: f64) -> f64 {
            1e-9 + 1e-9 * a.abs().max(b.abs())
        }

        let (value_o2, value_t4) = (o2.value(), t4.v);
        assert!(
            (value_o2 - value_t4).abs() <= band(value_o2, value_t4),
            "{label} value: {} vs {}",
            value_o2,
            value_t4
        );
        for a in 0..2 {
            let (grad_o2, grad_t4) = (o2.g()[a], t4.g[a]);
            assert!(
                (grad_o2 - grad_t4).abs() <= band(grad_o2, grad_t4),
                "{label} grad[{a}]: {} vs {}",
                grad_o2,
                grad_t4
            );
            for b in 0..2 {
                let (hess_o2, hess_t4) = (o2.h()[a][b], t4.h[a][b]);
                assert!(
                    (hess_o2 - hess_t4).abs() <= band(hess_o2, hess_t4),
                    "{label} hess[{a}][{b}]: {} vs {}",
                    hess_o2,
                    hess_t4
                );
            }
        }
    }

    let wi = 1.7_f64;

    // Negative binomial: (y, mu, theta).
    for (yi, mu, theta) in [(0.0, 1.2, 3.0), (4.0, 2.5, 0.7), (10.0, 0.6, 5.0)] {
        let order2 = dispersion_nb_nll_order2(yi, mu, theta, wi);
        let dense = test_support::dispersion_nb_nll_generic::<Tower4<2>>(yi, mu, theta, wi);
        check_o2_vs_tower4("nb", order2, dense);
    }

    // Gamma: (y, mu, nu); the kernels also receive y floored at 1e-300.
    for (yi, mu, nu) in [
        (0.5_f64, 1.1_f64, 2.0_f64),
        (3.0, 4.0, 0.9),
        (1.0, 0.3, 6.0),
    ] {
        let y_pos = yi.max(1e-300);
        let order2 = dispersion_gamma_nll_order2(yi, y_pos, mu, nu, wi);
        let dense =
            test_support::dispersion_gamma_nll_generic::<Tower4<2>>(yi, y_pos, mu, nu, wi);
        check_o2_vs_tower4("gamma", order2, dense);
    }

    // Beta: (y, mu, phi).
    for (yi, mu, phi) in [(0.3, 0.4, 5.0), (0.9, 0.6, 12.0), (0.01, 0.2, 3.0)] {
        let order2 = dispersion_beta_nll_order2(yi, mu, phi, wi);
        let dense = test_support::dispersion_beta_nll_generic::<Tower4<2>>(yi, mu, phi, wi);
        check_o2_vs_tower4("beta", order2, dense);
    }

    // Tweedie: (y, eta_mu, eta_d, p); both sides use the generic kernel.
    for (yi, eta_mu, eta_d, p) in [
        (0.0, 0.4, -0.3, 1.5),
        (2.5, -0.2, 0.5, 1.3),
        (0.0, 1.0, 0.1, 1.7),
        (5.0, 0.7, -0.6, 1.6),
    ] {
        let order2 = dispersion_tweedie_nll_generic::<Order2<2>>(yi, eta_mu, eta_d, p, wi);
        let dense = dispersion_tweedie_nll_generic::<Tower4<2>>(yi, eta_mu, eta_d, p, wi);
        check_o2_vs_tower4("tweedie", order2, dense);
    }
}
1586
/// Checks, bit for bit, that the dispersion-only ("pruned") one-variable
/// towers reproduce the value and the index-1 gradient/Hessian entries of
/// the full two-variable `Order2` towers, and that each family's
/// value-only `*_neg_loglik` helper equals the negated tower value.
///
/// Inputs come from a fixed-seed generator, so the run is deterministic.
/// The order of `next()` calls determines every input; do not reorder them.
#[test]
pub(crate) fn pruned_disp_towers_bit_identical_to_full_order2() {
    use gam_math::jet_scalar::{JetScalar, Order2};

    // 64-bit linear congruential generator with a fixed seed; each draw
    // keeps the top 53 bits of the state and scales them into [0, 1).
    let mut state: u64 = 0x9E3779B97F4A7C15;
    let mut next = || {
        state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
        ((state >> 11) as f64) / ((1u64 << 53) as f64)
    };
    // Compare raw IEEE-754 bit patterns so the equality is exact.
    let bits = |x: f64| x.to_bits();

    // Number of random draws of (weight, count) shared by all families.
    let n_per = 600;
    for _ in 0..n_per {
        let wi = 0.25 + 3.0 * next();
        let yi_count = (next() * 12.0).floor();

        // Negative binomial: full tower vs dispersion-only tower.
        {
            let mu = (0.05 + 4.0 * next()).max(1e-300);
            let theta = (0.05 + 6.0 * next()).max(1e-12);
            let full = dispersion_nb_nll_order2(yi_count, mu, theta, wi);
            let prn = dispersion_nb_disp_order2(yi_count, mu, theta, wi);
            assert_eq!(bits(full.value()), bits(prn.value()), "nb value");
            // Index 1 of the full tower corresponds to index 0 of the pruned one.
            assert_eq!(bits(full.g()[1]), bits(prn.g()[0]), "nb grad");
            assert_eq!(bits(full.h()[1][1]), bits(prn.h()[0][0]), "nb hess");
            assert_eq!(
                bits(dispersion_nb_neg_loglik(yi_count, mu, theta, wi)),
                bits(-prn.value()),
                "nb value-only"
            );
        }
        // Gamma: same comparison, with a strictly positive response.
        {
            let mu = (0.05 + 4.0 * next()).max(1e-300);
            let nu = (0.05 + 6.0 * next()).max(1e-12);
            let yi = 0.01 + 8.0 * next();
            let y_pos = yi.max(1e-300);
            let full = dispersion_gamma_nll_order2(yi, y_pos, mu, nu, wi);
            let prn = dispersion_gamma_disp_order2(yi, y_pos, mu, nu, wi);
            assert_eq!(bits(full.value()), bits(prn.value()), "gamma value");
            assert_eq!(bits(full.g()[1]), bits(prn.g()[0]), "gamma grad");
            assert_eq!(bits(full.h()[1][1]), bits(prn.h()[0][0]), "gamma hess");
            assert_eq!(
                bits(dispersion_gamma_neg_loglik(yi, y_pos, mu, nu, wi)),
                bits(-prn.value()),
                "gamma value-only"
            );
        }
        // Beta: only the value-only helper is compared against the full tower.
        {
            let mu = (1e-6 + (1.0 - 2e-6) * next()).clamp(1e-12, 1.0 - 1e-12);
            let phi = (0.05 + 20.0 * next()).max(1e-12);
            let yi = next();
            let full = dispersion_beta_nll_order2(yi, mu, phi, wi);
            assert_eq!(
                bits(dispersion_beta_neg_loglik(yi, mu, phi, wi)),
                bits(-full.value()),
                "beta value-only"
            );
        }
        // Tweedie: a zero response, a positive response, and a case whose
        // linear predictors start beyond +/-DISPERSION_ETA_CLAMP so that the
        // clamp below lands exactly on the boundary.
        for &(yi, eta_mu, eta_d, p) in &[
            (0.0_f64, -4.0 + 8.0 * next(), -4.0 + 8.0 * next(), 1.1 + 0.8 * next()),
            (0.01 + 9.0 * next(), -4.0 + 8.0 * next(), -4.0 + 8.0 * next(), 1.1 + 0.8 * next()),
            (3.0, -DISPERSION_ETA_CLAMP - 5.0, DISPERSION_ETA_CLAMP + 5.0, 1.5),
        ] {
            // Both evaluations receive the already-clamped predictors.
            let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
            let ed = eta_d.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
            let full = dispersion_tweedie_nll_generic::<Order2<2>>(yi, em, ed, p, wi);
            let prn = dispersion_tweedie_disp_order2(yi, em, ed, p, wi);
            assert_eq!(bits(full.value()), bits(prn.value()), "tweedie value");
            assert_eq!(bits(full.g()[1]), bits(prn.g()[0]), "tweedie grad");
            assert_eq!(bits(full.h()[1][1]), bits(prn.h()[0][0]), "tweedie hess");
            assert_eq!(
                bits(dispersion_tweedie_neg_loglik(yi, em, ed, p, wi)),
                bits(-prn.value()),
                "tweedie value-only"
            );
        }
    }
}
1676
/// The negative binomial, gamma and Tweedie kinds must report a cross
/// weight of zero (within `1e-12`) from `dispersion_row_cross_weight` at a
/// fixed sample point.
#[test]
pub(crate) fn orthogonal_dispersion_families_report_zero_cross_weight() {
    for kind in [
        DispersionFamilyKind::NegativeBinomial,
        DispersionFamilyKind::Gamma,
        DispersionFamilyKind::Tweedie { p: 1.5 },
    ] {
        let cross_weight = dispersion_row_cross_weight(kind, 1.25, 0.2, -0.3, 2.0);
        assert_close(kind.family_tag(), cross_weight, 0.0, 1e-12);
    }
}
1689
/// For each family kind, checks that `evaluate` and `log_likelihood_only`
/// on a large input agree (within `1e-9`) with a row-by-row reference built
/// directly from `dispersion_row_kernel`.
///
/// Inputs come from a fixed-seed generator; the order of `next()` draws
/// determines every input, so the statements must not be reordered.
#[test]
pub(crate) fn parallel_evaluate_matches_serial_reference() {
    // Row count is a multiple of DISPERSION_PARALLEL_ROW_THRESHOLD plus a
    // remainder; presumably this forces the parallel path with an uneven
    // final chunk — confirm against `evaluate`.
    let n = DISPERSION_PARALLEL_ROW_THRESHOLD * 3 + 7;
    // 64-bit linear congruential generator with a fixed seed; each draw
    // keeps the top 53 bits of the state and scales them into [0, 1).
    let mut state: u64 = 0xD1B5_4A32_D192_ED03;
    let mut next = || {
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((state >> 11) as f64) / ((1u64 << 53) as f64)
    };

    for kind in [
        DispersionFamilyKind::NegativeBinomial,
        DispersionFamilyKind::Gamma,
        DispersionFamilyKind::Beta,
        DispersionFamilyKind::Tweedie { p: 1.5 },
    ] {
        // Responses per kind: Beta strictly inside (0, 1), negative binomial
        // integer-valued, everything else strictly positive.
        let y = Array1::from_shape_fn(n, |_| match kind {
            DispersionFamilyKind::Beta => 1e-3 + (1.0 - 2e-3) * next(),
            DispersionFamilyKind::NegativeBinomial => (next() * 12.0).floor(),
            _ => 0.05 + 8.0 * next(),
        });
        let weights = Array1::from_shape_fn(n, |_| 0.25 + 2.0 * next());
        // Linear predictors drawn from [-1, 1).
        let eta_mu = Array1::from_shape_fn(n, |_| -1.0 + 2.0 * next());
        let eta_d = Array1::from_shape_fn(n, |_| -1.0 + 2.0 * next());

        let family = DispersionGlmLocationScaleFamily {
            kind,
            y: y.clone(),
            weights: weights.clone(),
        };
        // Block states carry empty coefficient vectors; only `eta` is filled.
        let states = vec![
            ParameterBlockState {
                beta: Array1::zeros(0),
                eta: eta_mu.clone(),
            },
            ParameterBlockState {
                beta: Array1::zeros(0),
                eta: eta_d.clone(),
            },
        ];

        // Serial reference: accumulate only finite row log-likelihoods and
        // floor the working weights at zero.
        let mut ll_ref = 0.0;
        let mut mw_ref = Array1::<f64>::zeros(n);
        let mut mr_ref = Array1::<f64>::zeros(n);
        let mut dw_ref = Array1::<f64>::zeros(n);
        let mut dr_ref = Array1::<f64>::zeros(n);
        for i in 0..n {
            let row = dispersion_row_kernel(kind, y[i], eta_mu[i], eta_d[i], weights[i]);
            if row.loglik.is_finite() {
                ll_ref += row.loglik;
            }
            mw_ref[i] = row.mean_weight.max(0.0);
            mr_ref[i] = row.mean_response;
            dw_ref[i] = row.disp_weight.max(0.0);
            dr_ref[i] = row.disp_response;
        }

        let eval = family.evaluate(&states).expect("parallel evaluate");
        assert_close(
            &format!("{kind:?} evaluate log-likelihood"),
            eval.log_likelihood,
            ll_ref,
            1e-9,
        );

        // Both blocks are expected to be returned as diagonal working sets:
        // index 0 is the mean block, index 1 the dispersion block.
        let BlockWorkingSet::Diagonal {
            working_response: mr,
            working_weights: mw,
        } = &eval.blockworking_sets[0]
        else {
            panic!("mean block not diagonal");
        };
        let BlockWorkingSet::Diagonal {
            working_response: dr,
            working_weights: dw,
        } = &eval.blockworking_sets[1]
        else {
            panic!("dispersion block not diagonal");
        };
        // Row-by-row comparison of working weights and responses.
        for i in 0..n {
            assert_close("mean weight", mw[i], mw_ref[i], 1e-9);
            assert_close("mean response", mr[i], mr_ref[i], 1e-9);
            assert_close("disp weight", dw[i], dw_ref[i], 1e-9);
            assert_close("disp response", dr[i], dr_ref[i], 1e-9);
        }

        // The value-only entry point must match the same reference sum.
        let ll_only = family
            .log_likelihood_only(&states)
            .expect("parallel log_likelihood_only");
        assert_close(
            &format!("{kind:?} log_likelihood_only"),
            ll_only,
            ll_ref,
            1e-9,
        );
    }
}
1801}