1use super::weighted_design_products::{mirror_upper_to_lower, xt_diag_x_design, xt_diag_y_design};
7use super::{
10 BlockwiseTermFitResult, GamlssLambdaLayout, LOCATION_SCALE_N_OUTPUTS,
11 LocationScaleFamilyBuilder, build_location_scale_block, fit_location_scale_terms,
12 identity_penalty, solve_penalizedweighted_projection,
13};
14use crate::block_layout::block_count::validate_block_count;
15use crate::custom_family::{
16 BlockWorkingSet, BlockwiseFitOptions, CustomFamily, CustomFamilyBlockPsiDerivative,
17 FamilyEvaluation, ParameterBlockSpec, ParameterBlockState, PenaltyMatrix,
18};
19use crate::gamlss::GamlssError;
20use crate::model_types::UnifiedFitResult;
21use gam_linalg::matrix::LinearOperator;
22use gam_math::jet_scalar::JetScalar;
23use gam_terms::smooth::{
24 SpatialLengthScaleOptimizationOptions, TermCollectionDesign, TermCollectionSpec,
25};
26use ndarray::{Array1, Array2, s};
27use statrs::function::gamma::ln_gamma;
28
/// Response family of the two-predictor (mean + precision) dispersion GLM.
///
/// The second linear predictor `eta_d` lives on a log-precision scale: it is
/// `log(theta)` for the negative binomial, `log(nu)` for the gamma shape,
/// `log(phi)` for the beta precision and `-log(phi)` for the Tweedie dispersion.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DispersionFamilyKind {
    /// Negative-binomial counts with a log mean link.
    NegativeBinomial,
    /// Gamma responses with a log mean link.
    Gamma,
    /// Beta responses on the open unit interval with a logit mean link.
    Beta,
    /// Tweedie responses with a log mean link.
    Tweedie {
        /// Variance power; the fitting entry point requires `1 < p < 2`.
        p: f64,
    },
}
95
96impl DispersionFamilyKind {
97 pub const fn family_tag(self) -> &'static str {
98 match self {
99 DispersionFamilyKind::NegativeBinomial => FAMILY_NEGBIN_LOCATION_SCALE,
100 DispersionFamilyKind::Gamma => FAMILY_GAMMA_LOCATION_SCALE,
101 DispersionFamilyKind::Beta => FAMILY_BETA_LOCATION_SCALE,
102 DispersionFamilyKind::Tweedie { .. } => FAMILY_TWEEDIE_LOCATION_SCALE,
103 }
104 }
105
106 pub(crate) const fn mean_is_logit(self) -> bool {
108 matches!(self, DispersionFamilyKind::Beta)
109 }
110
111 pub fn base_link(self) -> gam_problem::InverseLink {
116 use gam_problem::{InverseLink, StandardLink};
117 if self.mean_is_logit() {
118 InverseLink::Standard(StandardLink::Logit)
119 } else {
120 InverseLink::Standard(StandardLink::Log)
121 }
122 }
123
124 pub fn likelihood_spec(self) -> gam_problem::LikelihoodSpec {
132 use gam_problem::{InverseLink, LikelihoodSpec, ResponseFamily, StandardLink};
133 let response = match self {
134 DispersionFamilyKind::NegativeBinomial => ResponseFamily::NegativeBinomial {
135 theta: 1.0,
136 theta_fixed: false,
137 },
138 DispersionFamilyKind::Gamma => ResponseFamily::Gamma,
139 DispersionFamilyKind::Beta => ResponseFamily::Beta { phi: 1.0 },
140 DispersionFamilyKind::Tweedie { p } => ResponseFamily::Tweedie { p },
141 };
142 let link = if self.mean_is_logit() {
143 InverseLink::Standard(StandardLink::Logit)
144 } else {
145 InverseLink::Standard(StandardLink::Log)
146 };
147 LikelihoodSpec::new(response, link)
148 }
149}
150
/// Family tag returned by `DispersionFamilyKind::family_tag` for the negative binomial.
pub const FAMILY_NEGBIN_LOCATION_SCALE: &str = "negbin-location-scale";
/// Family tag returned by `DispersionFamilyKind::family_tag` for the gamma family.
pub const FAMILY_GAMMA_LOCATION_SCALE: &str = "gamma-location-scale";
/// Family tag returned by `DispersionFamilyKind::family_tag` for the beta family.
pub const FAMILY_BETA_LOCATION_SCALE: &str = "beta-location-scale";
/// Family tag returned by `DispersionFamilyKind::family_tag` for the Tweedie family.
pub const FAMILY_TWEEDIE_LOCATION_SCALE: &str = "tweedie-location-scale";

/// Symmetric bound applied to both linear predictors (`eta_mu`, `eta_d`) before
/// they are exponentiated, keeping `exp(eta)` within about `[e^-30, e^30]`.
pub(super) const DISPERSION_ETA_CLAMP: f64 = 30.0;
/// Floor for per-row information/curvature values, so working weights stay
/// positive and working responses never divide by zero.
pub(super) const DISPERSION_MIN_CURVATURE: f64 = 1e-12;

/// Row count above which per-row kernels are evaluated with rayon (only when the
/// caller is not itself running on a rayon worker thread).
const DISPERSION_PARALLEL_ROW_THRESHOLD: usize = 1024;
172
/// Per-row output of `dispersion_row_kernel`: the row's prior-weighted
/// log-likelihood plus IRLS working weight/response pairs for both blocks.
pub(super) struct DispersionRowKernel {
    /// Prior-weighted log-likelihood contribution of the row.
    pub(super) loglik: f64,
    /// Working weight of the mean block (on the `eta_mu` scale).
    pub(super) mean_weight: f64,
    /// Working (pseudo-)response of the mean block.
    pub(super) mean_response: f64,
    /// Working weight of the dispersion/precision block (on the `eta_d` scale).
    pub(super) disp_weight: f64,
    /// Working (pseudo-)response of the dispersion/precision block.
    pub(super) disp_response: f64,
}
181
/// Test-only reference kernels for the per-row negative log-likelihoods.
///
/// Each function returns the jet of `-wi * loglik` with variable index 0 for the
/// mean and index 1 for the dispersion/precision parameter. The `*_generic`
/// versions use the `JetScalar` trait's own `ln_gamma`; the `*_order2` versions
/// use the concrete `Order2` jet together with the crate's `order2_ln_gamma`.
#[cfg(test)]
mod test_support {
    use super::*;

    /// Negative-binomial NLL jet in `(mu, theta)`:
    /// `ll = lnΓ(θ+y) − lnΓ(θ) − lnΓ(y+1) + θ ln θ − θ ln(θ+μ) + y ln μ − y ln(θ+μ)`.
    #[inline]
    pub(super) fn dispersion_nb_nll_generic<S: gam_math::jet_scalar::JetScalar<2>>(
        yi: f64,
        mu_value: f64,
        theta_value: f64,
        wi: f64,
    ) -> S {
        let mu = S::variable(mu_value, 0);
        let theta = S::variable(theta_value, 1);
        let tpm = theta.add(&mu);
        let loglik = theta
            .add(&S::constant(yi))
            .ln_gamma()
            .sub(&theta.ln_gamma())
            .sub(&S::constant(ln_gamma(yi + 1.0)))
            .add(&theta.mul(&theta.ln()))
            .sub(&theta.mul(&tpm.ln()))
            .add(&mu.ln().scale(yi))
            .sub(&tpm.ln().scale(yi));
        loglik.scale(-wi)
    }

    /// Gamma NLL jet in `(mu, nu)` (mean `mu`, shape `nu`):
    /// `ll = ν ln ν − ν ln μ − lnΓ(ν) + (ν−1) ln y⁺ − ν y / μ`.
    /// `y_pos` is the caller-supplied positive surrogate of `yi` used inside the log.
    #[inline]
    pub(super) fn dispersion_gamma_nll_generic<S: gam_math::jet_scalar::JetScalar<2>>(
        yi: f64,
        y_pos: f64,
        mu_value: f64,
        nu_value: f64,
        wi: f64,
    ) -> S {
        let mu = S::variable(mu_value, 0);
        let nu = S::variable(nu_value, 1);
        let loglik = nu
            .mul(&nu.ln())
            .sub(&nu.mul(&mu.ln()))
            .sub(&nu.ln_gamma())
            .add(&nu.sub(&S::constant(1.0)).scale(y_pos.ln()))
            .sub(&nu.mul(&mu.recip().scale(yi)));
        loglik.scale(-wi)
    }

    /// Beta NLL jet in `(mu, phi)` with shapes `a = μφ`, `b = (1−μ)φ`; the
    /// response is clamped to `[1e-12, 1 - 1e-12]` to keep both logs finite.
    #[inline]
    pub(super) fn dispersion_beta_nll_generic<S: gam_math::jet_scalar::JetScalar<2>>(
        yi: f64,
        mu_value: f64,
        phi_value: f64,
        wi: f64,
    ) -> S {
        let mu = S::variable(mu_value, 0);
        let phi = S::variable(phi_value, 1);
        let one_minus_mu = S::constant(1.0).sub(&mu);
        let yc = yi.clamp(1e-12, 1.0 - 1e-12);
        let a = mu.mul(&phi);
        let b = one_minus_mu.mul(&phi);
        let loglik = phi
            .ln_gamma()
            .sub(&a.ln_gamma())
            .sub(&b.ln_gamma())
            .add(&a.sub(&S::constant(1.0)).scale(yc.ln()))
            .add(&b.sub(&S::constant(1.0)).scale((1.0 - yc).ln()));
        loglik.scale(-wi)
    }

    /// Same formula as `dispersion_nb_nll_generic`, on a concrete `Order2<2>` jet
    /// with `order2_ln_gamma` supplying the log-gamma derivatives.
    #[inline]
    pub(super) fn dispersion_nb_nll_order2(
        yi: f64,
        mu_value: f64,
        theta_value: f64,
        wi: f64,
    ) -> gam_math::jet_scalar::Order2<2> {
        type O2 = gam_math::jet_scalar::Order2<2>;

        let mu = O2::variable(mu_value, 0);
        let theta = O2::variable(theta_value, 1);
        let tpm = theta.add(&mu);
        let theta_plus_y = theta.add(&O2::constant(yi));
        let loglik = order2_ln_gamma(&theta_plus_y)
            .sub(&order2_ln_gamma(&theta))
            .sub(&O2::constant(ln_gamma(yi + 1.0)))
            .add(&theta.mul(&theta.ln()))
            .sub(&theta.mul(&tpm.ln()))
            .add(&mu.ln().scale(yi))
            .sub(&tpm.ln().scale(yi));
        loglik.scale(-wi)
    }

    /// Same formula as `dispersion_gamma_nll_generic`, on a concrete `Order2<2>`
    /// jet with `order2_ln_gamma` supplying the log-gamma derivatives.
    #[inline]
    pub(super) fn dispersion_gamma_nll_order2(
        yi: f64,
        y_pos: f64,
        mu_value: f64,
        nu_value: f64,
        wi: f64,
    ) -> gam_math::jet_scalar::Order2<2> {
        type O2 = gam_math::jet_scalar::Order2<2>;

        let mu = O2::variable(mu_value, 0);
        let nu = O2::variable(nu_value, 1);
        let loglik = nu
            .mul(&nu.ln())
            .sub(&nu.mul(&mu.ln()))
            .sub(&order2_ln_gamma(&nu))
            .add(&nu.sub(&O2::constant(1.0)).scale(y_pos.ln()))
            .sub(&nu.mul(&mu.recip().scale(yi)));
        loglik.scale(-wi)
    }
}
316
317#[inline]
320pub(crate) fn dispersion_beta_nll_order2(
321 yi: f64,
322 mu_value: f64,
323 phi_value: f64,
324 wi: f64,
325) -> gam_math::jet_scalar::Order2<2> {
326 type O2 = gam_math::jet_scalar::Order2<2>;
327
328 let mu = O2::variable(mu_value, 0);
329 let phi = O2::variable(phi_value, 1);
330 let one_minus_mu = O2::constant(1.0).sub(&mu);
331 let yc = yi.clamp(1e-12, 1.0 - 1e-12);
332 let a = mu.mul(&phi);
333 let b = one_minus_mu.mul(&phi);
334 let loglik = order2_ln_gamma(&phi)
335 .sub(&order2_ln_gamma(&a))
336 .sub(&order2_ln_gamma(&b))
337 .add(&a.sub(&O2::constant(1.0)).scale(yc.ln()))
338 .add(&b.sub(&O2::constant(1.0)).scale((1.0 - yc).ln()));
339 loglik.scale(-wi)
340}
341
342#[inline]
343fn order2_ln_gamma<const K: usize>(
344 x: &gam_math::jet_scalar::Order2<K>,
345) -> gam_math::jet_scalar::Order2<K> {
346 gam_math::jet_scalar::Order2(
347 x.0.compose_unary(gam_math::jet_tower::ln_gamma_derivative_stack_order2(x.0.v)),
348 )
349}
350
351#[inline]
376fn order1_ln_gamma<const K: usize>(
377 x: &gam_math::jet_scalar::Order1<K>,
378) -> gam_math::jet_scalar::Order1<K> {
379 x.compose_unary([ln_gamma(x.v), gam_math::jet_tower::digamma(x.v), 0.0, 0.0, 0.0])
380}
381
382#[inline]
391pub(crate) fn dispersion_nb_disp_order1(
392 yi: f64,
393 mu_value: f64,
394 theta_value: f64,
395 wi: f64,
396) -> gam_math::jet_scalar::Order1<1> {
397 type O1 = gam_math::jet_scalar::Order1<1>;
398
399 let mu = O1::constant(mu_value);
400 let theta = O1::variable(theta_value, 0);
401 let tpm = theta.add(&mu);
402 let theta_plus_y = theta.add(&O1::constant(yi));
403 let loglik = order1_ln_gamma(&theta_plus_y)
404 .sub(&order1_ln_gamma(&theta))
405 .sub(&O1::constant(ln_gamma(yi + 1.0)))
406 .add(&theta.mul(&theta.ln()))
407 .sub(&theta.mul(&tpm.ln()))
408 .add(&mu.ln().scale(yi))
409 .sub(&tpm.ln().scale(yi));
410 loglik.scale(-wi)
411}
412
413#[inline]
422pub(crate) fn dispersion_gamma_disp_order2(
423 yi: f64,
424 y_pos: f64,
425 mu_value: f64,
426 nu_value: f64,
427 wi: f64,
428) -> gam_math::jet_scalar::Order2<1> {
429 type O1 = gam_math::jet_scalar::Order2<1>;
430
431 let mu = O1::constant(mu_value);
432 let nu = O1::variable(nu_value, 0);
433 let loglik = nu
434 .mul(&nu.ln())
435 .sub(&nu.mul(&mu.ln()))
436 .sub(&order2_ln_gamma(&nu))
437 .add(&nu.sub(&O1::constant(1.0)).scale(y_pos.ln()))
438 .sub(&nu.mul(&mu.recip().scale(yi)));
439 loglik.scale(-wi)
440}
441
442#[inline]
448pub(crate) fn dispersion_tweedie_disp_order2(
449 yi: f64,
450 eta_mu: f64,
451 eta_d: f64,
452 p: f64,
453 wi: f64,
454) -> gam_math::jet_scalar::Order2<1> {
455 type O1 = gam_math::jet_scalar::Order2<1>;
456
457 let one_minus_p = 1.0 - p;
458 let two_minus_p = 2.0 - p;
459 let mu = O1::constant(eta_mu).exp();
460 let phi = O1::variable(eta_d, 0).scale(-1.0).exp();
461 if yi > 0.0 {
462 let dev = mu
463 .powf(two_minus_p)
464 .scale(1.0 / two_minus_p)
465 .sub(&mu.powf(one_minus_p).scale(yi / one_minus_p))
466 .add(&O1::constant(
467 yi.powf(two_minus_p) / (one_minus_p * two_minus_p),
468 ))
469 .scale(2.0);
470 let loglik = dev
471 .mul(&phi.recip().scale(-0.5))
472 .sub(&phi.scale(2.0 * std::f64::consts::PI).ln().scale(0.5))
473 .sub(&O1::constant(0.5 * p * yi.ln()));
474 loglik.scale(-wi)
475 } else {
476 let c = mu.powf(two_minus_p).scale(1.0 / two_minus_p);
477 let loglik = c.mul(&phi.recip()).scale(-1.0);
478 loglik.scale(-wi)
479 }
480}
481
482#[inline]
498fn dispersion_nb_neg_loglik(yi: f64, mu: f64, theta: f64, wi: f64) -> f64 {
499 let tpm = theta + mu;
500 let s = ln_gamma(theta + yi) - ln_gamma(theta) - ln_gamma(yi + 1.0) + theta * theta.ln()
501 - theta * tpm.ln()
502 + mu.ln() * yi
503 - tpm.ln() * yi;
504 -(s * -wi)
505}
506
507#[inline]
510fn dispersion_gamma_neg_loglik(yi: f64, y_pos: f64, mu: f64, nu: f64, wi: f64) -> f64 {
511 let s = nu * nu.ln() - nu * mu.ln() - ln_gamma(nu) + (nu - 1.0) * y_pos.ln()
515 - nu * ((1.0 / mu) * yi);
516 -(s * -wi)
517}
518
519#[inline]
522fn dispersion_beta_neg_loglik(yi: f64, mu: f64, phi: f64, wi: f64) -> f64 {
523 let one_minus_mu = 1.0 - mu;
524 let yc = yi.clamp(1e-12, 1.0 - 1e-12);
525 let a = mu * phi;
526 let b = one_minus_mu * phi;
527 let s = ln_gamma(phi) - ln_gamma(a) - ln_gamma(b)
528 + (a - 1.0) * yc.ln()
529 + (b - 1.0) * (1.0 - yc).ln();
530 -(s * -wi)
531}
532
/// Prior-weighted Tweedie log-likelihood of one row on the predictor scales.
///
/// `mu = exp(eta_mu)` and `phi = exp(-eta_d)`. For `yi > 0` the saddle-point
/// approximation `-d(y, μ) / (2φ) − ½ ln(2πφ) − ½ p ln y` is used; for `yi <= 0`
/// the exact zero mass `−μ^(2−p) / ((2−p) φ)`. Despite the `neg_loglik` name, the
/// value returned is `wi * ll`.
#[inline]
fn dispersion_tweedie_neg_loglik(yi: f64, eta_mu: f64, eta_d: f64, p: f64, wi: f64) -> f64 {
    let one_minus_p = 1.0 - p;
    let two_minus_p = 2.0 - p;
    let mu = eta_mu.exp();
    let phi = (-eta_d).exp();
    // μ^(2−p) / (2−p): first deviance term and also the zero-mass exponent.
    let mu_term = mu.powf(two_minus_p) * (1.0 / two_minus_p);
    let ll = if yi > 0.0 {
        let unit_deviance = (mu_term - mu.powf(one_minus_p) * (yi / one_minus_p)
            + yi.powf(two_minus_p) / (one_minus_p * two_minus_p))
            * 2.0;
        unit_deviance * ((1.0 / phi) * -0.5)
            - (phi * (2.0 * std::f64::consts::PI)).ln() * 0.5
            - 0.5 * p * yi.ln()
    } else {
        (mu_term * (1.0 / phi)) * -1.0
    };
    wi * ll
}
555
556#[inline]
562pub(crate) fn dispersion_row_loglik(
563 kind: DispersionFamilyKind,
564 yi: f64,
565 eta_mu: f64,
566 eta_d: f64,
567 prior_weight: f64,
568) -> f64 {
569 let wi = prior_weight.max(0.0);
570 let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
571 let ed = eta_d.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
572 match kind {
573 DispersionFamilyKind::NegativeBinomial => {
574 let mu = em.exp().max(1e-300);
575 let theta = ed.exp().max(1e-12);
576 dispersion_nb_neg_loglik(yi, mu, theta, wi)
577 }
578 DispersionFamilyKind::Gamma => {
579 let mu = em.exp().max(1e-300);
580 let nu = ed.exp().max(1e-12);
581 let y_pos = yi.max(1e-300);
582 dispersion_gamma_neg_loglik(yi, y_pos, mu, nu, wi)
583 }
584 DispersionFamilyKind::Beta => {
585 let mu = (1.0 / (1.0 + (-em).exp())).clamp(1e-12, 1.0 - 1e-12);
586 let phi = ed.exp().max(1e-12);
587 dispersion_beta_neg_loglik(yi, mu, phi, wi)
588 }
589 DispersionFamilyKind::Tweedie { p } => dispersion_tweedie_neg_loglik(yi, em, ed, p, wi),
590 }
591}
592
593#[inline]
594pub(crate) fn beta_observed_cross_weight_eta(yi: f64, mu: f64, phi: f64, wi: f64) -> f64 {
595 let q = (mu * (1.0 - mu)).max(1e-12);
596 let tower = dispersion_beta_nll_order2(yi, mu, phi, wi);
597 q * phi * tower.h()[0][1]
598}
599
600#[inline]
601pub(crate) fn dispersion_row_cross_weight(
602 kind: DispersionFamilyKind,
603 yi: f64,
604 eta_mu: f64,
605 eta_d: f64,
606 prior_weight: f64,
607) -> f64 {
608 let wi = prior_weight.max(0.0);
609 if wi == 0.0 {
610 return 0.0;
611 }
612 let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
613 let ed = eta_d.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
614 match kind {
615 DispersionFamilyKind::Beta => {
616 let mu = (1.0 / (1.0 + (-em).exp())).clamp(1e-12, 1.0 - 1e-12);
617 let phi = ed.exp().max(1e-12);
618 beta_observed_cross_weight_eta(yi, mu, phi, wi)
619 }
620 DispersionFamilyKind::NegativeBinomial
621 | DispersionFamilyKind::Gamma
622 | DispersionFamilyKind::Tweedie { .. } => 0.0,
623 }
624}
625
626#[inline]
627pub(crate) fn tower_score_info<const K: usize>(
628 tower: &gam_math::jet_scalar::Order2<K>,
629 idx: usize,
630 wi: f64,
631) -> (f64, f64) {
632 if wi == 0.0 {
633 (0.0, 0.0)
634 } else {
635 (-tower.g()[idx] / wi, tower.h()[idx][idx] / wi)
636 }
637}
638
639pub(super) fn dispersion_row_kernel(
644 kind: DispersionFamilyKind,
645 yi: f64,
646 eta_mu: f64,
647 eta_d: f64,
648 prior_weight: f64,
649) -> DispersionRowKernel {
650 let wi = prior_weight.max(0.0);
651 let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
652 let ed = eta_d.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
653 match kind {
654 DispersionFamilyKind::NegativeBinomial => {
655 let mu = em.exp().max(1e-300);
656 let theta = ed.exp().max(1e-12); let tpm = theta + mu;
658 let tower = dispersion_nb_disp_order1(yi, mu, theta, wi);
665 let s_theta = if wi == 0.0 { 0.0 } else { -tower.g()[0] / wi };
666 let loglik = -tower.value();
667 let info_mu = if wi == 0.0 {
668 DISPERSION_MIN_CURVATURE
669 } else {
670 (theta / (mu * tpm)).max(DISPERSION_MIN_CURVATURE)
671 };
672 let score_mu = theta * (yi - mu) / (mu * tpm);
673 let mean_weight = wi * mu * mu * info_mu;
674 let mean_response = em + score_mu / (mu * info_mu);
675 let trigamma_theta = gam_math::jet_tower::trigamma(theta);
712 let trigamma_tpm = gam_math::jet_tower::trigamma(tpm);
713 let info_theta_fisher = trigamma_theta - trigamma_tpm - 1.0 / theta + 1.0 / tpm;
714 let info_pos = info_theta_fisher.max(DISPERSION_MIN_CURVATURE);
715 let disp_weight = wi * theta * theta * info_pos;
716 let disp_response = ed + s_theta / (theta * info_pos);
717 DispersionRowKernel {
718 loglik,
719 mean_weight,
720 mean_response,
721 disp_weight,
722 disp_response,
723 }
724 }
725 DispersionFamilyKind::Gamma => {
726 let mu = em.exp().max(1e-300);
727 let nu = ed.exp().max(1e-12); let y_pos = yi.max(1e-300);
729 let tower = dispersion_gamma_disp_order2(yi, y_pos, mu, nu, wi);
730 let (s_nu, info_nu_raw) = tower_score_info(&tower, 0, wi);
731 let loglik = -tower.value();
732 let info_mu = if wi == 0.0 {
733 DISPERSION_MIN_CURVATURE
734 } else {
735 (nu / (mu * mu)).max(DISPERSION_MIN_CURVATURE)
736 };
737 let score_mu = nu * (yi - mu) / (mu * mu);
738 let mean_weight = wi * mu * mu * info_mu;
739 let mean_response = em + score_mu / (mu * info_mu);
740 let info_nu = info_nu_raw.max(DISPERSION_MIN_CURVATURE);
741 let disp_weight = wi * nu * nu * info_nu;
742 let disp_response = ed + s_nu / (nu * info_nu);
743 DispersionRowKernel {
744 loglik,
745 mean_weight,
746 mean_response,
747 disp_weight,
748 disp_response,
749 }
750 }
751 DispersionFamilyKind::Beta => {
752 let mu = (1.0 / (1.0 + (-em).exp())).clamp(1e-12, 1.0 - 1e-12);
754 let phi = ed.exp().max(1e-12); let q = (mu * (1.0 - mu)).max(1e-12); let tower = dispersion_beta_nll_order2(yi, mu, phi, wi);
757 let (score_mu, info_mu_raw) = tower_score_info(&tower, 0, wi);
758 let (s_phi, info_phi_raw) = tower_score_info(&tower, 1, wi);
759 let loglik = -tower.value();
760 let info_mu = info_mu_raw.max(DISPERSION_MIN_CURVATURE);
761 let mean_weight = wi * q * q * info_mu;
762 let mean_response = em + score_mu / (q * info_mu);
763 let info_phi = info_phi_raw.max(DISPERSION_MIN_CURVATURE);
764 let disp_weight = wi * phi * phi * info_phi;
765 let disp_response = ed + s_phi / (phi * info_phi);
766 DispersionRowKernel {
767 loglik,
768 mean_weight,
769 mean_response,
770 disp_weight,
771 disp_response,
772 }
773 }
774 DispersionFamilyKind::Tweedie { p } => {
775 let mu = em.exp().max(1e-300);
776 let phi = (-ed).exp().max(1e-12);
778 let two_minus_p = 2.0 - p;
779 let mean_weight = wi * mu.powf(two_minus_p) / phi;
785 let mean_response = em + (yi - mu) / mu;
786 let tower = dispersion_tweedie_disp_order2(yi, em, ed, p, wi);
795 let loglik = -tower.value();
796 let (s_eta, info_eta_raw) = tower_score_info(&tower, 0, wi);
800 let curvature_eta = if wi == 0.0 {
801 DISPERSION_MIN_CURVATURE
802 } else {
803 info_eta_raw.max(DISPERSION_MIN_CURVATURE)
804 };
805 let disp_weight = wi * curvature_eta;
809 let disp_response = ed + s_eta / curvature_eta;
810 DispersionRowKernel {
811 loglik,
812 mean_weight,
813 mean_response,
814 disp_weight,
815 disp_response,
816 }
817 }
818 }
819}
820
/// Custom-family likelihood for the two-block (mean, log-precision) dispersion GLM.
///
/// Holds the responses and prior weights; the per-row linear predictors arrive
/// through the parameter block states at evaluation time.
#[derive(Clone)]
pub(crate) struct DispersionGlmLocationScaleFamily {
    /// Response family and mean link.
    pub(crate) kind: DispersionFamilyKind,
    /// Observed responses, one per row.
    pub(crate) y: Array1<f64>,
    /// Prior weights, one per row (negative values are floored to zero by the row kernels).
    pub(crate) weights: Array1<f64>,
}
828
impl DispersionGlmLocationScaleFamily {
    /// Index of the mean block in block-state and block-spec slices.
    pub(crate) const BLOCK_MEAN: usize = 0;
    /// Index of the dispersion (log-precision) block in block-state and block-spec slices.
    pub(crate) const BLOCK_DISP: usize = 1;
}
833
834impl CustomFamily for DispersionGlmLocationScaleFamily {
835 fn joint_jeffreys_term_required(&self) -> bool {
839 true
840 }
841
842 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
843 validate_block_count::<GamlssError>(self.kind.family_tag(), 2, block_states.len())?;
844 let eta_mu = &block_states[Self::BLOCK_MEAN].eta;
845 let eta_d = &block_states[Self::BLOCK_DISP].eta;
846 let n = self.y.len();
847 if eta_mu.len() != n || eta_d.len() != n || self.weights.len() != n {
848 return Err(format!(
849 "{} row-count mismatch: y={n}, eta_mu={}, eta_d={}, weights={}",
850 self.kind.family_tag(),
851 eta_mu.len(),
852 eta_d.len(),
853 self.weights.len()
854 ));
855 }
856 let mut mean_weights = Array1::<f64>::zeros(n);
857 let mut mean_response = Array1::<f64>::zeros(n);
858 let mut disp_weights = Array1::<f64>::zeros(n);
859 let mut disp_response = Array1::<f64>::zeros(n);
860
861 let kernels: Vec<DispersionRowKernel> = if rayon::current_thread_index().is_none()
872 && n > DISPERSION_PARALLEL_ROW_THRESHOLD
873 {
874 use rayon::iter::{IntoParallelIterator, ParallelIterator};
875 (0..n)
876 .into_par_iter()
877 .map(|i| {
878 dispersion_row_kernel(self.kind, self.y[i], eta_mu[i], eta_d[i], self.weights[i])
879 })
880 .collect()
881 } else {
882 (0..n)
883 .map(|i| {
884 dispersion_row_kernel(self.kind, self.y[i], eta_mu[i], eta_d[i], self.weights[i])
885 })
886 .collect()
887 };
888
889 let mut log_likelihood = 0.0;
890 for (i, row) in kernels.into_iter().enumerate() {
891 if row.loglik.is_finite() {
892 log_likelihood += row.loglik;
893 }
894 mean_weights[i] = row.mean_weight.max(0.0);
895 mean_response[i] = row.mean_response;
896 disp_weights[i] = row.disp_weight.max(0.0);
897 disp_response[i] = row.disp_response;
898 }
899 Ok(FamilyEvaluation {
900 log_likelihood,
901 blockworking_sets: vec![
902 BlockWorkingSet::diagonal_checked(mean_response, mean_weights)?,
903 BlockWorkingSet::diagonal_checked(disp_response, disp_weights)?,
904 ],
905 })
906 }
907
908 fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
909 validate_block_count::<GamlssError>(self.kind.family_tag(), 2, block_states.len())?;
910 let eta_mu = &block_states[Self::BLOCK_MEAN].eta;
911 let eta_d = &block_states[Self::BLOCK_DISP].eta;
912 let n = self.y.len();
913 let per_row: Vec<f64> = if rayon::current_thread_index().is_none()
922 && n > DISPERSION_PARALLEL_ROW_THRESHOLD
923 {
924 use rayon::iter::{IntoParallelIterator, ParallelIterator};
925 (0..n)
926 .into_par_iter()
927 .map(|i| {
928 dispersion_row_loglik(self.kind, self.y[i], eta_mu[i], eta_d[i], self.weights[i])
929 })
930 .collect()
931 } else {
932 (0..n)
933 .map(|i| {
934 dispersion_row_loglik(self.kind, self.y[i], eta_mu[i], eta_d[i], self.weights[i])
935 })
936 .collect()
937 };
938 let mut ll = 0.0;
939 for loglik in per_row {
940 if loglik.is_finite() {
941 ll += loglik;
942 }
943 }
944 Ok(ll)
945 }
946
947 fn coefficient_hessian_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
948 crate::location_scale_engine::location_scale_coefficient_hessian_cost(
949 self.y.len() as u64,
950 specs,
951 )
952 }
953
954 fn exact_newton_joint_hessian_with_specs(
973 &self,
974 block_states: &[ParameterBlockState],
975 specs: &[ParameterBlockSpec],
976 ) -> Result<Option<Array2<f64>>, String> {
977 validate_block_count::<GamlssError>(self.kind.family_tag(), 2, block_states.len())?;
978 if specs.len() != 2 {
979 return Err(format!(
980 "{} exact joint Hessian expects 2 specs, got {}",
981 self.kind.family_tag(),
982 specs.len()
983 ));
984 }
985 let eta_mu = &block_states[Self::BLOCK_MEAN].eta;
986 let eta_d = &block_states[Self::BLOCK_DISP].eta;
987 let n = self.y.len();
988 if eta_mu.len() != n || eta_d.len() != n || self.weights.len() != n {
989 return Err(format!(
990 "{} exact joint Hessian row-count mismatch: y={n}, eta_mu={}, eta_d={}, weights={}",
991 self.kind.family_tag(),
992 eta_mu.len(),
993 eta_d.len(),
994 self.weights.len()
995 ));
996 }
997
998 let eval = self.evaluate(block_states)?;
999 let BlockWorkingSet::Diagonal {
1000 working_weights: mean_weights,
1001 ..
1002 } = &eval.blockworking_sets[Self::BLOCK_MEAN]
1003 else {
1004 return Err(format!(
1005 "{} dispersion mean block did not return diagonal weights",
1006 self.kind.family_tag()
1007 ));
1008 };
1009 let BlockWorkingSet::Diagonal {
1010 working_weights: disp_weights,
1011 ..
1012 } = &eval.blockworking_sets[Self::BLOCK_DISP]
1013 else {
1014 return Err(format!(
1015 "{} dispersion precision block did not return diagonal weights",
1016 self.kind.family_tag()
1017 ));
1018 };
1019
1020 let cross_weights = if rayon::current_thread_index().is_none()
1026 && n > DISPERSION_PARALLEL_ROW_THRESHOLD
1027 {
1028 use rayon::iter::{IntoParallelIterator, ParallelIterator};
1029 Array1::from_vec(
1030 (0..n)
1031 .into_par_iter()
1032 .map(|i| {
1033 dispersion_row_cross_weight(
1034 self.kind,
1035 self.y[i],
1036 eta_mu[i],
1037 eta_d[i],
1038 self.weights[i],
1039 )
1040 })
1041 .collect::<Vec<f64>>(),
1042 )
1043 } else {
1044 Array1::from_shape_fn(n, |i| {
1045 dispersion_row_cross_weight(
1046 self.kind,
1047 self.y[i],
1048 eta_mu[i],
1049 eta_d[i],
1050 self.weights[i],
1051 )
1052 })
1053 };
1054 let mean_spec = &specs[Self::BLOCK_MEAN];
1055 let disp_spec = &specs[Self::BLOCK_DISP];
1056 if mean_spec.design.nrows() != n || disp_spec.design.nrows() != n {
1057 return Err(format!(
1058 "{} exact joint Hessian design row mismatch: y={n}, mean rows={}, precision rows={}",
1059 self.kind.family_tag(),
1060 mean_spec.design.nrows(),
1061 disp_spec.design.nrows()
1062 ));
1063 }
1064 let p_mean = mean_spec.design.ncols();
1065 let p_disp = disp_spec.design.ncols();
1066 if block_states[Self::BLOCK_MEAN].beta.len() != p_mean
1067 || block_states[Self::BLOCK_DISP].beta.len() != p_disp
1068 {
1069 return Err(format!(
1070 "{} exact joint Hessian beta/design mismatch: mean beta {} vs cols {}, precision beta {} vs cols {}",
1071 self.kind.family_tag(),
1072 block_states[Self::BLOCK_MEAN].beta.len(),
1073 p_mean,
1074 block_states[Self::BLOCK_DISP].beta.len(),
1075 p_disp
1076 ));
1077 }
1078
1079 let h_mean = xt_diag_x_design(&mean_spec.design, mean_weights)?;
1080 let h_cross = xt_diag_y_design(&mean_spec.design, &cross_weights, &disp_spec.design)?;
1081 let h_disp = xt_diag_x_design(&disp_spec.design, disp_weights)?;
1082 let total = p_mean + p_disp;
1083 let mut h = Array2::<f64>::zeros((total, total));
1084 h.slice_mut(s![0..p_mean, 0..p_mean]).assign(&h_mean);
1085 h.slice_mut(s![0..p_mean, p_mean..total]).assign(&h_cross);
1086 h.slice_mut(s![p_mean..total, p_mean..total])
1087 .assign(&h_disp);
1088 mirror_upper_to_lower(&mut h);
1089 Ok(Some(h))
1090 }
1091
1092 fn likelihood_blocks_uncoupled(&self) -> bool {
1106 !matches!(self.kind, DispersionFamilyKind::Beta)
1107 }
1108
1109 fn outer_hyper_hessian_dense_available(&self, specs: &[ParameterBlockSpec]) -> bool {
1120 assert!(
1121 crate::custom_family::validate_blockspec_consistency(specs).is_ok(),
1122 "DispersionGlmLocationScale outer hyper-Hessian dense availability: \
1123 inconsistent parameter block specs"
1124 );
1125 specs.len() < 2
1126 }
1127}
1128
/// User-facing inputs for fitting a dispersion location-scale GLM from term specs.
pub struct DispersionGlmLocationScaleTermSpec {
    /// Response family and mean link.
    pub kind: DispersionFamilyKind,
    /// Observed responses, one per row.
    pub y: Array1<f64>,
    /// Prior weights, one per row.
    pub weights: Array1<f64>,
    /// Term specification of the mean predictor.
    pub meanspec: TermCollectionSpec,
    /// Term specification of the second predictor (it feeds the `log_precision` block).
    pub log_dispspec: TermCollectionSpec,
    /// Offset of the mean predictor.
    pub mean_offset: Array1<f64>,
    /// Offset of the second (log-precision) predictor.
    pub log_disp_offset: Array1<f64>,
}
1141
/// Crate-internal builder implementing `LocationScaleFamilyBuilder` for the
/// dispersion GLM; the "noise" block is the log-precision predictor.
pub(crate) struct DispersionGlmLocationScaleTermBuilder {
    /// Response family and mean link.
    pub(crate) kind: DispersionFamilyKind,
    /// Observed responses, one per row.
    pub(crate) y: Array1<f64>,
    /// Prior weights, one per row.
    pub(crate) weights: Array1<f64>,
    /// Term specification of the mean predictor.
    pub(crate) meanspec: TermCollectionSpec,
    /// Term specification of the log-precision predictor.
    pub(crate) noisespec: TermCollectionSpec,
    /// Offset of the mean predictor.
    pub(crate) mean_offset: Array1<f64>,
    /// Offset of the log-precision predictor.
    pub(crate) noise_offset: Array1<f64>,
}
1151
1152pub(crate) fn dispersion_location_scale_warm_start(
1156 kind: DispersionFamilyKind,
1157 y: &Array1<f64>,
1158 weights: &Array1<f64>,
1159 mean_block: &ParameterBlockSpec,
1160 disp_block: &ParameterBlockSpec,
1161 mean_beta_hint: Option<&Array1<f64>>,
1162 disp_beta_hint: Option<&Array1<f64>>,
1163) -> Result<(Array1<f64>, Array1<f64>), String> {
1164 let ridge_floor = 1e-10;
1165 let mean_beta = if let Some(beta) = mean_beta_hint {
1166 beta.clone()
1167 } else {
1168 let target = Array1::from_shape_fn(y.len(), |i| {
1169 if kind.mean_is_logit() {
1170 let yi = y[i].clamp(1e-3, 1.0 - 1e-3);
1171 (yi / (1.0 - yi)).ln()
1172 } else {
1173 (y[i].max(0.0) + 0.1).ln()
1175 }
1176 });
1177 solve_penalizedweighted_projection(
1178 &mean_block.design,
1179 &mean_block.offset,
1180 &target,
1181 weights,
1182 &mean_block.penalties,
1183 &mean_block.initial_log_lambdas,
1184 ridge_floor,
1185 )?
1186 };
1187 let disp_beta = if let Some(beta) = disp_beta_hint {
1188 beta.clone()
1189 } else {
1190 let mean_eta = mean_block.design.apply(&mean_beta) + &mean_block.offset;
1205 let target = Array1::from_shape_fn(y.len(), |i| {
1206 dispersion_moment_log_precision_seed(kind, y[i], mean_eta[i])
1207 });
1208 solve_penalizedweighted_projection(
1209 &disp_block.design,
1210 &disp_block.offset,
1211 &target,
1212 weights,
1213 &disp_block.penalties,
1214 &disp_block.initial_log_lambdas,
1215 ridge_floor,
1216 )?
1217 };
1218 Ok((mean_beta, disp_beta))
1219}
1220
1221#[inline]
1222fn dispersion_moment_log_precision_seed(kind: DispersionFamilyKind, yi: f64, eta_mu: f64) -> f64 {
1223 const LOG_PRECISION_FLOOR: f64 = -10.0;
1224 const LOG_PRECISION_CEILING: f64 = 10.0;
1225 let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
1226 let raw = match kind {
1227 DispersionFamilyKind::Beta => {
1228 0.0
1237 }
1238 DispersionFamilyKind::Gamma => {
1239 let mu = em.exp().max(1e-12);
1240 let e2 = (yi - mu).powi(2).max(1e-8 * mu * mu);
1241 (mu * mu / e2).max(1e-6).ln()
1242 }
1243 DispersionFamilyKind::NegativeBinomial => {
1244 let mu = em.exp().max(1e-12);
1245 let e2 = (yi - mu).powi(2);
1246 let excess = (e2 - mu).max(1e-6 * (mu + mu * mu));
1247 (mu * mu / excess).max(1e-6).ln()
1248 }
1249 DispersionFamilyKind::Tweedie { p } => {
1250 let mu = em.exp().max(1e-12);
1251 let e2 = (yi - mu).powi(2).max(1e-8 * mu.powf(p));
1252 (mu.powf(p) / e2).max(1e-6).ln()
1253 }
1254 };
1255 raw.clamp(LOG_PRECISION_FLOOR, LOG_PRECISION_CEILING)
1256}
1257
1258impl LocationScaleFamilyBuilder for DispersionGlmLocationScaleTermBuilder {
1259 type Family = DispersionGlmLocationScaleFamily;
1260
1261 fn meanspec(&self) -> &TermCollectionSpec {
1262 &self.meanspec
1263 }
1264
1265 fn noisespec(&self) -> &TermCollectionSpec {
1266 &self.noisespec
1267 }
1268
1269 fn noise_penalty_count(&self, noise_design: &TermCollectionDesign) -> usize {
1270 noise_design.penalties.len() + 1
1274 }
1275
1276 fn build_blocks(
1277 &self,
1278 theta: &Array1<f64>,
1279 mean_design: &TermCollectionDesign,
1280 noise_design: &TermCollectionDesign,
1281 mean_beta_hint: Option<Array1<f64>>,
1282 noise_beta_hint: Option<Array1<f64>>,
1283 ) -> Result<Vec<ParameterBlockSpec>, String> {
1284 let layout = GamlssLambdaLayout::two_block(
1285 mean_design.penalties.len(),
1286 self.noise_penalty_count(noise_design),
1287 );
1288 layout.validate_theta_len(theta.len(), "dispersion location-scale")?;
1289
1290 let mut meanspec = build_location_scale_block(
1291 "mu",
1292 mean_design.design.clone(),
1293 self.mean_offset.clone(),
1294 mean_design.penalties_as_penalty_matrix(),
1295 mean_design.nullspace_dims.clone(),
1296 layout.mean_from(theta),
1297 mean_beta_hint,
1298 0,
1299 LOCATION_SCALE_N_OUTPUTS,
1300 "DispersionLocationScale::build_blocks: mu",
1301 )?;
1302
1303 let p_disp = noise_design.design.ncols();
1304 let mut disp_penalties = noise_design.penalties_as_penalty_matrix();
1305 disp_penalties.push(PenaltyMatrix::Dense(identity_penalty(p_disp)));
1306 let mut disp_nullspace = noise_design.nullspace_dims.clone();
1307 disp_nullspace.push(0);
1308 let mut dispspec = build_location_scale_block(
1309 "log_precision",
1310 noise_design.design.clone(),
1311 self.noise_offset.clone(),
1312 disp_penalties,
1313 disp_nullspace,
1314 layout.noise_from(theta),
1315 noise_beta_hint,
1316 1,
1317 LOCATION_SCALE_N_OUTPUTS,
1318 "DispersionLocationScale::build_blocks: log_precision",
1319 )?;
1320
1321 if meanspec.initial_beta.is_none() || dispspec.initial_beta.is_none() {
1322 let (mean_beta0, disp_beta0) = dispersion_location_scale_warm_start(
1323 self.kind,
1324 &self.y,
1325 &self.weights,
1326 &meanspec,
1327 &dispspec,
1328 meanspec.initial_beta.as_ref(),
1329 dispspec.initial_beta.as_ref(),
1330 )?;
1331 if meanspec.initial_beta.is_none() {
1332 meanspec.initial_beta = Some(mean_beta0);
1333 }
1334 if dispspec.initial_beta.is_none() {
1335 dispspec.initial_beta = Some(disp_beta0);
1336 }
1337 }
1338
1339 Ok(vec![meanspec, dispspec])
1340 }
1341
1342 fn build_family(
1343 &self,
1344 mean_design: &TermCollectionDesign,
1345 noise_design: &TermCollectionDesign,
1346 ) -> Self::Family {
1347 assert_eq!(
1354 mean_design.design.nrows(),
1355 self.y.len(),
1356 "DispersionGlmLocationScale::build_family: mean design row count must match y"
1357 );
1358 assert_eq!(
1359 noise_design.design.nrows(),
1360 self.y.len(),
1361 "DispersionGlmLocationScale::build_family: noise design row count must match y"
1362 );
1363 DispersionGlmLocationScaleFamily {
1364 kind: self.kind,
1365 y: self.y.clone(),
1366 weights: self.weights.clone(),
1367 }
1368 }
1369
1370 fn extract_primary_betas(
1371 &self,
1372 fit: &UnifiedFitResult,
1373 ) -> Result<(Array1<f64>, Array1<f64>), String> {
1374 let mean_beta = fit
1375 .block_states
1376 .get(DispersionGlmLocationScaleFamily::BLOCK_MEAN)
1377 .ok_or_else(|| "missing dispersion mean block state".to_string())?
1378 .beta
1379 .clone();
1380 let disp_beta = fit
1381 .block_states
1382 .get(DispersionGlmLocationScaleFamily::BLOCK_DISP)
1383 .ok_or_else(|| "missing dispersion log-precision block state".to_string())?
1384 .beta
1385 .clone();
1386 Ok((mean_beta, disp_beta))
1387 }
1388
1389 fn build_psiderivative_blocks(
1390 &self,
1391 data: ndarray::ArrayView2<'_, f64>,
1392 meanspec: &TermCollectionSpec,
1393 noisespec: &TermCollectionSpec,
1394 mean_design: &TermCollectionDesign,
1395 noise_design: &TermCollectionDesign,
1396 ) -> Result<Vec<Vec<CustomFamilyBlockPsiDerivative>>, String> {
1397 Err(format!(
1405 "dispersion location-scale ({:?}) does not implement analytic spatial \
1406 psi derivatives; the κ/ψ joint optimizer must be disabled before \
1407 this builder is consulted. Called with data {n_rows}×{n_cols}, mean \
1408 spec (linear={mean_lin}, random={mean_re}, smooth={mean_sm}), noise \
1409 spec (linear={noise_lin}, random={noise_re}, smooth={noise_sm}), \
1410 mean design cols={mean_p}, noise design cols={noise_p}",
1411 self.kind,
1412 n_rows = data.nrows(),
1413 n_cols = data.ncols(),
1414 mean_lin = meanspec.linear_terms.len(),
1415 mean_re = meanspec.random_effect_terms.len(),
1416 mean_sm = meanspec.smooth_terms.len(),
1417 noise_lin = noisespec.linear_terms.len(),
1418 noise_re = noisespec.random_effect_terms.len(),
1419 noise_sm = noisespec.smooth_terms.len(),
1420 mean_p = mean_design.design.ncols(),
1421 noise_p = noise_design.design.ncols(),
1422 ))
1423 }
1424}
1425
1426pub fn fit_dispersion_glm_location_scale_terms(
1430 data: ndarray::ArrayView2<'_, f64>,
1431 spec: DispersionGlmLocationScaleTermSpec,
1432 options: &BlockwiseFitOptions,
1433 kappa_options: &SpatialLengthScaleOptimizationOptions,
1434) -> Result<BlockwiseTermFitResult, String> {
1435 if let DispersionFamilyKind::Tweedie { p } = spec.kind {
1436 if !(p.is_finite() && p > 1.0 && p < 2.0) {
1437 return Err(format!(
1438 "Tweedie location-scale requires a variance power strictly in (1, 2); got p={p}"
1439 ));
1440 }
1441 }
1442 let mut kappa = kappa_options.clone();
1447 kappa.enabled = false;
1448 let mut options = options.clone();
1466 options.compute_covariance = true;
1467 fit_location_scale_terms(
1468 data,
1469 DispersionGlmLocationScaleTermBuilder {
1470 kind: spec.kind,
1471 y: spec.y,
1472 weights: spec.weights,
1473 meanspec: spec.meanspec,
1474 noisespec: spec.log_dispspec,
1475 mean_offset: spec.mean_offset,
1476 noise_offset: spec.log_disp_offset,
1477 },
1478 &options,
1479 &kappa,
1480 )
1481}
1482
#[cfg(test)]
mod tests {
    //! Unit tests for the dispersion GLM location-scale family:
    //! * the specialised `Order2` jets agree with the dense `Tower4` reference;
    //! * the pruned dispersion-only kernels are bit-identical to the dispersion
    //!   channel of the full two-channel jets;
    //! * the beta mean/precision cross information is checked against its formula;
    //! * the orthogonal families report a zero cross weight;
    //! * `evaluate` / `log_likelihood_only` agree with a per-row serial reference.
    use super::*;
    use super::test_support::{dispersion_gamma_nll_order2, dispersion_nb_nll_order2};
    use crate::gamlss::test_support::dispersion_tweedie_nll_generic;

    /// Test-only, single-variable NB jet used as the "pruned" reference in
    /// `pruned_disp_towers_bit_identical_to_full_order2`.
    ///
    /// `mu` is held constant and `theta` is seeded as variable 0. The gradient and
    /// Hessian are therefore d/dθ and d²/dθ² (natural θ scale) of
    /// `-wi * [lnΓ(θ+y) − lnΓ(θ) − lnΓ(y+1) + θ·ln θ − θ·ln(θ+μ) + y·ln μ − y·ln(θ+μ)]`.
    /// The operation order is meant to mirror the full two-channel jet, because
    /// the test compares the two bit for bit.
    #[inline]
    fn dispersion_nb_disp_order2(
        yi: f64,
        mu_value: f64,
        theta_value: f64,
        wi: f64,
    ) -> gam_math::jet_scalar::Order2<1> {
        use gam_math::jet_scalar::JetScalar;
        use statrs::function::gamma::ln_gamma;
        type O1 = gam_math::jet_scalar::Order2<1>;

        let mu = O1::constant(mu_value);
        let theta = O1::variable(theta_value, 0);
        // tpm = θ + μ
        let tpm = theta.add(&mu);
        let theta_plus_y = theta.add(&O1::constant(yi));
        let loglik = order2_ln_gamma(&theta_plus_y)
            .sub(&order2_ln_gamma(&theta))
            .sub(&O1::constant(ln_gamma(yi + 1.0)))
            .add(&theta.mul(&theta.ln()))
            .sub(&theta.mul(&tpm.ln()))
            .add(&mu.ln().scale(yi))
            .sub(&tpm.ln().scale(yi));
        // Weighted negative log-likelihood.
        loglik.scale(-wi)
    }

    /// Expected (Fisher) cross information between the beta mean `mu` and the
    /// precision `phi`: `φ · [μ·ψ'(μφ) − (1−μ)·ψ'((1−μ)φ)]`, where ψ' is the
    /// trigamma value taken from `trigamma_derivative_stack(..)[0]`.
    pub(crate) fn beta_fisher_cross_info_mu_phi(mu: f64, phi: f64) -> f64 {
        let a = mu * phi;
        let b = (1.0 - mu) * phi;
        phi * (mu * gam_math::jet_tower::trigamma_derivative_stack(a)[0]
            - (1.0 - mu) * gam_math::jet_tower::trigamma_derivative_stack(b)[0])
    }

    /// Assert `|got − want| <= tol` (absolute tolerance). On failure, reports both
    /// values and the difference under `label`.
    pub(crate) fn assert_close(label: &str, got: f64, want: f64, tol: f64) {
        assert!(
            (got - want).abs() <= tol,
            "{label}: got {got:.12e}, want {want:.12e}, |diff|={:.3e}",
            (got - want).abs()
        );
    }

    /// The beta NLL tower's mixed (mean, precision) Hessian entry must equal the
    /// Fisher cross information when evaluated at a y where the mean score is zero.
    /// The η-scale cross weight must equal that information times the chain-rule
    /// factors of the logit mean link and the log precision link.
    #[test]
    pub(crate) fn beta_tower_mixed_channel_matches_cross_information_formula() {
        let mu = 0.1;
        let phi = 10.0;
        // With these values a = μφ = 1, so ψ'(a) has the closed form π²/6 used below.
        let a = mu * phi;
        let b = (1.0 - mu) * phi;
        let digamma_a = gam_math::jet_tower::digamma_derivative_stack(a)[0];
        let digamma_b = gam_math::jet_tower::digamma_derivative_stack(b)[0];
        // logit(y) = ψ(a) − ψ(b) makes the beta mean score vanish. At this y the
        // observed mixed second derivative reduces to the expected one.
        let score_neutral_y = 1.0 / (1.0 + (-(digamma_a - digamma_b)).exp());

        let tower = dispersion_beta_nll_order2(score_neutral_y, mu, phi, 1.0);
        // Independent closed form ψ'(1) = π²/6 (valid only because a == 1 here).
        let trigamma_a = std::f64::consts::PI * std::f64::consts::PI / 6.0;
        let trigamma_b = gam_math::jet_tower::trigamma_derivative_stack(b)[0];
        let analytic = phi * (mu * trigamma_a - (1.0 - mu) * trigamma_b);
        let helper = beta_fisher_cross_info_mu_phi(mu, phi);

        // Guard that the chosen example is not trivially near zero.
        assert!(
            analytic > 0.58,
            "audit example should have visibly nonzero cross information, got {analytic}"
        );
        assert_close("helper cross information", helper, analytic, 1e-12);
        assert_close("tower mixed channel", tower.h()[0][1], analytic, 1e-8);

        // On the η scale the cross term picks up dμ/dη_μ · dφ/dη_φ = μ(1−μ) · φ
        // (logit mean link, log precision link).
        let q = mu * (1.0 - mu);
        let eta_cross = beta_observed_cross_weight_eta(score_neutral_y, mu, phi, 1.0);
        assert_close(
            "eta-scale cross weight",
            eta_cross,
            q * phi * analytic,
            1e-8,
        );
    }

    /// Each family's specialised two-channel `Order2<2>` jet must match the dense
    /// `Tower4<2>` evaluation in value, gradient and Hessian. The tolerance band is
    /// mixed absolute/relative, 1e-9 each.
    #[test]
    pub(crate) fn order2_matches_dense_tower_all_channels() {
        use gam_math::jet_scalar::{JetScalar, Order2};
        use gam_math::jet_tower::Tower4;

        // Compare value, both gradient entries and all four Hessian entries.
        fn check_o2_vs_tower4(label: &str, o2: Order2<2>, t4: Tower4<2>) {
            let band = |a: f64, b: f64| 1e-9 + 1e-9 * a.abs().max(b.abs());
            assert!(
                (o2.value() - t4.v).abs() <= band(o2.value(), t4.v),
                "{label} value: {} vs {}",
                o2.value(),
                t4.v
            );
            for a in 0..2 {
                assert!(
                    (o2.g()[a] - t4.g[a]).abs() <= band(o2.g()[a], t4.g[a]),
                    "{label} grad[{a}]: {} vs {}",
                    o2.g()[a],
                    t4.g[a]
                );
                for b in 0..2 {
                    assert!(
                        (o2.h()[a][b] - t4.h[a][b]).abs() <= band(o2.h()[a][b], t4.h[a][b]),
                        "{label} hess[{a}][{b}]: {} vs {}",
                        o2.h()[a][b],
                        t4.h[a][b]
                    );
                }
            }
        }

        let wi = 1.7_f64;
        // Negative binomial: (y, μ, θ), including a zero count.
        for &(yi, mu, theta) in &[(0.0, 1.2, 3.0), (4.0, 2.5, 0.7), (10.0, 0.6, 5.0)] {
            check_o2_vs_tower4(
                "nb",
                dispersion_nb_nll_order2(yi, mu, theta, wi),
                test_support::dispersion_nb_nll_generic::<Tower4<2>>(yi, mu, theta, wi),
            );
        }
        // Gamma: (y, μ, ν); y_pos is y floored at 1e-300.
        for &(yi, mu, nu) in &[
            (0.5_f64, 1.1_f64, 2.0_f64),
            (3.0, 4.0, 0.9),
            (1.0, 0.3, 6.0),
        ] {
            let y_pos = yi.max(1e-300);
            check_o2_vs_tower4(
                "gamma",
                dispersion_gamma_nll_order2(yi, y_pos, mu, nu, wi),
                test_support::dispersion_gamma_nll_generic::<Tower4<2>>(yi, y_pos, mu, nu, wi),
            );
        }
        // Beta: (y, μ, φ) with y in (0, 1).
        for &(yi, mu, phi) in &[(0.3, 0.4, 5.0), (0.9, 0.6, 12.0), (0.01, 0.2, 3.0)] {
            check_o2_vs_tower4(
                "beta",
                dispersion_beta_nll_order2(yi, mu, phi, wi),
                test_support::dispersion_beta_nll_generic::<Tower4<2>>(yi, mu, phi, wi),
            );
        }
        // Tweedie: (y, η_μ, η_d, p); the same generic kernel is instantiated at
        // both jet types, including y = 0 rows.
        for &(yi, eta_mu, eta_d, p) in &[
            (0.0, 0.4, -0.3, 1.5),
            (2.5, -0.2, 0.5, 1.3),
            (0.0, 1.0, 0.1, 1.7),
            (5.0, 0.7, -0.6, 1.6),
        ] {
            check_o2_vs_tower4(
                "tweedie",
                dispersion_tweedie_nll_generic::<Order2<2>>(yi, eta_mu, eta_d, p, wi),
                dispersion_tweedie_nll_generic::<Tower4<2>>(yi, eta_mu, eta_d, p, wi),
            );
        }
    }

    /// For random draws, three pruned paths must reproduce the full two-channel
    /// jet bit for bit:
    /// * the dispersion-only pruned kernels must reproduce the dispersion channel
    ///   (index 1) of the full jet;
    /// * the order-1 NB kernel must match the order-2 pruned one;
    /// * the value-only `*_neg_loglik` paths must equal the negated jet value.
    ///
    /// NOTE(review): the jets hold the weighted NLL (see `loglik.scale(-wi)` in
    /// `dispersion_nb_disp_order2`). So `*_neg_loglik == -value` means those
    /// functions return the weighted log-likelihood; confirm the naming is intended.
    #[test]
    pub(crate) fn pruned_disp_towers_bit_identical_to_full_order2() {
        use gam_math::jet_scalar::{JetScalar, Order2};

        // Deterministic 64-bit LCG; the top 53 bits give a uniform f64 in [0, 1).
        // The order of `next()` calls below fixes the sampled values.
        let mut state: u64 = 0x9E3779B97F4A7C15;
        let mut next = || {
            state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
            ((state >> 11) as f64) / ((1u64 << 53) as f64)
        };
        let bits = |x: f64| x.to_bits();

        let n_per = 600; for _ in 0..n_per {
            let wi = 0.25 + 3.0 * next();
            let yi_count = (next() * 12.0).floor();

            // Negative binomial: full channel 1 vs local pruned θ-only jet,
            // plus the order-1 kernel and the value-only path.
            {
                let mu = (0.05 + 4.0 * next()).max(1e-300);
                let theta = (0.05 + 6.0 * next()).max(1e-12);
                let full = dispersion_nb_nll_order2(yi_count, mu, theta, wi);
                let prn = dispersion_nb_disp_order2(yi_count, mu, theta, wi);
                assert_eq!(bits(full.value()), bits(prn.value()), "nb value");
                assert_eq!(bits(full.g()[1]), bits(prn.g()[0]), "nb grad");
                assert_eq!(bits(full.h()[1][1]), bits(prn.h()[0][0]), "nb hess");
                let prn1 = dispersion_nb_disp_order1(yi_count, mu, theta, wi);
                assert_eq!(bits(prn.value()), bits(prn1.value()), "nb order1 value");
                assert_eq!(bits(prn.g()[0]), bits(prn1.g()[0]), "nb order1 grad");
                assert_eq!(
                    bits(dispersion_nb_neg_loglik(yi_count, mu, theta, wi)),
                    bits(-prn.value()),
                    "nb value-only"
                );
            }
            // Gamma: full channel 1 vs pruned dispersion jet, plus value-only.
            {
                let mu = (0.05 + 4.0 * next()).max(1e-300);
                let nu = (0.05 + 6.0 * next()).max(1e-12);
                let yi = 0.01 + 8.0 * next();
                let y_pos = yi.max(1e-300);
                let full = dispersion_gamma_nll_order2(yi, y_pos, mu, nu, wi);
                let prn = dispersion_gamma_disp_order2(yi, y_pos, mu, nu, wi);
                assert_eq!(bits(full.value()), bits(prn.value()), "gamma value");
                assert_eq!(bits(full.g()[1]), bits(prn.g()[0]), "gamma grad");
                assert_eq!(bits(full.h()[1][1]), bits(prn.h()[0][0]), "gamma hess");
                assert_eq!(
                    bits(dispersion_gamma_neg_loglik(yi, y_pos, mu, nu, wi)),
                    bits(-prn.value()),
                    "gamma value-only"
                );
            }
            // Beta: only the value-only path is checked against the full jet here.
            {
                let mu = (1e-6 + (1.0 - 2e-6) * next()).clamp(1e-12, 1.0 - 1e-12);
                let phi = (0.05 + 20.0 * next()).max(1e-12);
                let yi = next();
                let full = dispersion_beta_nll_order2(yi, mu, phi, wi);
                assert_eq!(
                    bits(dispersion_beta_neg_loglik(yi, mu, phi, wi)),
                    bits(-full.value()),
                    "beta value-only"
                );
            }
            // Tweedie: a y = 0 draw, a positive-y draw, and a deliberately
            // out-of-range η pair. All η values are clamped to ±DISPERSION_ETA_CLAMP
            // before evaluation.
            for &(yi, eta_mu, eta_d, p) in &[
                (0.0_f64, -4.0 + 8.0 * next(), -4.0 + 8.0 * next(), 1.1 + 0.8 * next()),
                (0.01 + 9.0 * next(), -4.0 + 8.0 * next(), -4.0 + 8.0 * next(), 1.1 + 0.8 * next()),
                (3.0, -DISPERSION_ETA_CLAMP - 5.0, DISPERSION_ETA_CLAMP + 5.0, 1.5),
            ] {
                let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
                let ed = eta_d.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
                let full = dispersion_tweedie_nll_generic::<Order2<2>>(yi, em, ed, p, wi);
                let prn = dispersion_tweedie_disp_order2(yi, em, ed, p, wi);
                assert_eq!(bits(full.value()), bits(prn.value()), "tweedie value");
                assert_eq!(bits(full.g()[1]), bits(prn.g()[0]), "tweedie grad");
                assert_eq!(bits(full.h()[1][1]), bits(prn.h()[0][0]), "tweedie hess");
                assert_eq!(
                    bits(dispersion_tweedie_neg_loglik(yi, em, ed, p, wi)),
                    bits(-prn.value()),
                    "tweedie value-only"
                );
            }
        }
    }

    /// NB, gamma and Tweedie must report a zero mean/dispersion cross weight at an
    /// arbitrary row. Beta is excluded because its cross information is non-zero,
    /// as the beta test in this module checks.
    #[test]
    pub(crate) fn orthogonal_dispersion_families_report_zero_cross_weight() {
        let cases = [
            DispersionFamilyKind::NegativeBinomial,
            DispersionFamilyKind::Gamma,
            DispersionFamilyKind::Tweedie { p: 1.5 },
        ];
        for kind in cases {
            let got = dispersion_row_cross_weight(kind, 1.25, 0.2, -0.3, 2.0);
            assert_close(kind.family_tag(), got, 0.0, 1e-12);
        }
    }

    /// `evaluate` and `log_likelihood_only` must agree with a serial reference
    /// built row by row from `dispersion_row_kernel`, for every family kind. The
    /// checked quantities are the log-likelihood and both diagonal working sets.
    #[test]
    pub(crate) fn parallel_evaluate_matches_serial_reference() {
        // Presumably chosen to exceed DISPERSION_PARALLEL_ROW_THRESHOLD so that the
        // parallel path is exercised, with a ragged remainder; confirm against
        // `evaluate`.
        let n = DISPERSION_PARALLEL_ROW_THRESHOLD * 3 + 7;
        // Same LCG scheme as the bit-identity test, with a different seed.
        let mut state: u64 = 0xD1B5_4A32_D192_ED03;
        let mut next = || {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            ((state >> 11) as f64) / ((1u64 << 53) as f64)
        };

        for kind in [
            DispersionFamilyKind::NegativeBinomial,
            DispersionFamilyKind::Gamma,
            DispersionFamilyKind::Beta,
            DispersionFamilyKind::Tweedie { p: 1.5 },
        ] {
            // Responses inside each family's support: (0, 1) for beta, integer
            // counts for NB, positive reals otherwise.
            let y = Array1::from_shape_fn(n, |_| match kind {
                DispersionFamilyKind::Beta => 1e-3 + (1.0 - 2e-3) * next(),
                DispersionFamilyKind::NegativeBinomial => (next() * 12.0).floor(),
                _ => 0.05 + 8.0 * next(),
            });
            let weights = Array1::from_shape_fn(n, |_| 0.25 + 2.0 * next());
            let eta_mu = Array1::from_shape_fn(n, |_| -1.0 + 2.0 * next());
            let eta_d = Array1::from_shape_fn(n, |_| -1.0 + 2.0 * next());

            let family = DispersionGlmLocationScaleFamily {
                kind,
                y: y.clone(),
                weights: weights.clone(),
            };
            // `beta` is left empty; this assumes evaluate reads only `eta`
            // (TODO confirm).
            let states = vec![
                ParameterBlockState {
                    beta: Array1::zeros(0),
                    eta: eta_mu.clone(),
                },
                ParameterBlockState {
                    beta: Array1::zeros(0),
                    eta: eta_d.clone(),
                },
            ];

            // Serial reference. Non-finite row log-likelihoods are skipped and
            // working weights are clamped at zero, which presumably mirrors what
            // `evaluate` does.
            let mut ll_ref = 0.0;
            let mut mw_ref = Array1::<f64>::zeros(n);
            let mut mr_ref = Array1::<f64>::zeros(n);
            let mut dw_ref = Array1::<f64>::zeros(n);
            let mut dr_ref = Array1::<f64>::zeros(n);
            for i in 0..n {
                let row = dispersion_row_kernel(kind, y[i], eta_mu[i], eta_d[i], weights[i]);
                if row.loglik.is_finite() {
                    ll_ref += row.loglik;
                }
                mw_ref[i] = row.mean_weight.max(0.0);
                mr_ref[i] = row.mean_response;
                dw_ref[i] = row.disp_weight.max(0.0);
                dr_ref[i] = row.disp_response;
            }

            // NOTE(review): an absolute 1e-9 tolerance on a sum over n rows may be
            // tight if the parallel reduction order differs; consider a relative band.
            let eval = family.evaluate(&states).expect("parallel evaluate");
            assert_close(
                &format!("{kind:?} evaluate log-likelihood"),
                eval.log_likelihood,
                ll_ref,
                1e-9,
            );

            // Both blocks are expected to produce diagonal working sets.
            let BlockWorkingSet::Diagonal {
                working_response: mr,
                working_weights: mw,
            } = &eval.blockworking_sets[0]
            else {
                panic!("mean block not diagonal");
            };
            let BlockWorkingSet::Diagonal {
                working_response: dr,
                working_weights: dw,
            } = &eval.blockworking_sets[1]
            else {
                panic!("dispersion block not diagonal");
            };
            for i in 0..n {
                assert_close("mean weight", mw[i], mw_ref[i], 1e-9);
                assert_close("mean response", mr[i], mr_ref[i], 1e-9);
                assert_close("disp weight", dw[i], dw_ref[i], 1e-9);
                assert_close("disp response", dr[i], dr_ref[i], 1e-9);
            }

            // The value-only path must agree with the same reference sum.
            let ll_only = family
                .log_likelihood_only(&states)
                .expect("parallel log_likelihood_only");
            assert_close(
                &format!("{kind:?} log_likelihood_only"),
                ll_only,
                ll_ref,
                1e-9,
            );
        }
    }
}