1use gam_custom_family::{CustomFamily, ParameterBlockState};
2use gam_solve::estimate::EstimationError;
3use crate::inference::predict_io::PredictResult;
4use gam_problem::types::{
5 LikelihoodScaleMetadata, LikelihoodSpec, ResponseFamily, is_valid_tweedie_power,
6};
7use ndarray::{Array1, Array2};
8
9pub fn family_noise_parameter(
40 scale: LikelihoodScaleMetadata,
41 standard_deviation: f64,
42 likelihood: &LikelihoodSpec,
43) -> Option<f64> {
44 match likelihood.response {
45 ResponseFamily::Tweedie { .. } => scale.fixed_phi().or(Some(1.0)),
50 ResponseFamily::NegativeBinomial { theta, .. } => scale.negbin_theta().or(Some(theta)),
53 ResponseFamily::Beta { phi } => scale.fixed_phi().or(Some(phi)),
56 ResponseFamily::Gamma => scale.gamma_shape().or(Some(standard_deviation)),
59 _ => Some(standard_deviation),
62 }
63}
64
65#[derive(Clone, Debug)]
67pub enum NoiseModel {
68 Gaussian {
69 sigma: Array1<f64>,
71 },
72 Poisson,
73 Tweedie {
74 p: f64,
75 phi: Array1<f64>,
79 },
80 NegativeBinomial {
81 theta: Array1<f64>,
83 },
84 Beta {
85 phi: Array1<f64>,
87 },
88 Gamma {
89 shape: Array1<f64>,
92 },
93 Bernoulli,
94 TransformationNormalQuantile {
107 grid_y: Array1<f64>,
109 h_grid: Array2<f64>,
112 },
113}
114
115fn invert_monotone_grid(grid_y: &Array1<f64>, h_row: ndarray::ArrayView1<'_, f64>, target: f64) -> f64 {
122 let g = grid_y.len();
123 if target <= h_row[0] {
124 return grid_y[0];
125 }
126 if target >= h_row[g - 1] {
127 return grid_y[g - 1];
128 }
129 let mut lo = 0usize;
130 let mut hi = g - 1;
131 while hi - lo > 1 {
132 let mid = (lo + hi) / 2;
133 if h_row[mid] <= target {
134 lo = mid;
135 } else {
136 hi = mid;
137 }
138 }
139 let t = (target - h_row[lo]) / (h_row[hi] - h_row[lo]);
140 grid_y[lo] + t * (grid_y[hi] - grid_y[lo])
141}
142
143#[derive(Clone, Debug)]
145pub struct GenerativeSpec {
146 pub mean: Array1<f64>,
147 pub noise: NoiseModel,
148}
149
150impl GenerativeSpec {
151 pub fn nobs(&self) -> usize {
154 self.mean.len()
155 }
156}
157
158pub fn generativespec_from_predict(
160 prediction: PredictResult,
161 likelihood: LikelihoodSpec,
162 gaussian_scale: Option<f64>,
163 prior_weights: Option<&Array1<f64>>,
164) -> Result<GenerativeSpec, EstimationError> {
165 let mut noise =
166 NoiseModel::from_likelihood(&likelihood, prediction.mean.len(), gaussian_scale)?;
167 if let (NoiseModel::Gaussian { sigma }, Some(weights)) = (&mut noise, prior_weights) {
176 scale_gaussian_sigma_by_prior_weights(sigma, weights)?;
177 }
178 Ok(GenerativeSpec {
179 mean: prediction.mean,
180 noise,
181 })
182}
183
184fn scale_gaussian_sigma_by_prior_weights(
190 sigma: &mut Array1<f64>,
191 weights: &Array1<f64>,
192) -> Result<(), EstimationError> {
193 if weights.len() != sigma.len() {
194 crate::bail_invalid_estim!(
195 "prior weights length {} does not match observation count {}",
196 weights.len(),
197 sigma.len()
198 );
199 }
200 for (s, &w) in sigma.iter_mut().zip(weights.iter()) {
201 if !(w.is_finite() && w > 0.0) {
202 crate::bail_invalid_estim!(
203 "Gaussian replicate prior weights must be finite and > 0; got {w}"
204 );
205 }
206 *s /= w.sqrt();
207 }
208 Ok(())
209}
210
211impl NoiseModel {
212 pub fn from_likelihood(
224 likelihood: &LikelihoodSpec,
225 nobs: usize,
226 gaussian_scale: Option<f64>,
227 ) -> Result<NoiseModel, EstimationError> {
228 match &likelihood.response {
229 ResponseFamily::Gaussian => {
230 let sigma =
231 Self::require_noise_parameter(likelihood, "Gaussian sigma", gaussian_scale)?;
232 if sigma < 0.0 {
233 crate::bail_invalid_estim!(
234 "{} generative sampling requires Gaussian sigma >= 0; got {sigma}",
235 likelihood.pretty_name()
236 );
237 }
238 Ok(NoiseModel::Gaussian {
239 sigma: Array1::from_elem(nobs, sigma),
240 })
241 }
242 ResponseFamily::Binomial => Ok(NoiseModel::Bernoulli),
243 ResponseFamily::Poisson => Ok(NoiseModel::Poisson),
244 ResponseFamily::Tweedie { p } => {
245 let p = *p;
246 if !is_valid_tweedie_power(p) {
247 crate::bail_invalid_estim!(
248 "Tweedie variance power must be finite and strictly between 1 and 2; got {p}"
249 );
250 }
251 let phi = Self::require_positive_noise_parameter(
252 likelihood,
253 "Tweedie dispersion phi",
254 gaussian_scale,
255 )?;
256 Ok(NoiseModel::Tweedie {
257 p,
258 phi: Array1::from_elem(nobs, phi),
262 })
263 }
264 ResponseFamily::NegativeBinomial { theta, .. } => {
265 let theta = gaussian_scale.unwrap_or(*theta);
273 if !(theta.is_finite() && theta > 0.0) {
274 crate::bail_invalid_estim!(
275 "negative-binomial theta must be finite and > 0; got {theta}"
276 );
277 }
278 Ok(NoiseModel::NegativeBinomial {
279 theta: Array1::from_elem(nobs, theta),
280 })
281 }
282 ResponseFamily::Beta { phi } => {
283 let phi = gaussian_scale.unwrap_or(*phi);
297 if !(phi.is_finite() && phi > 0.0) {
298 crate::bail_invalid_estim!(
299 "beta-regression phi must be finite and > 0; got {phi}"
300 );
301 }
302 Ok(NoiseModel::Beta {
303 phi: Array1::from_elem(nobs, phi),
304 })
305 }
306 ResponseFamily::Gamma => {
307 let shape = Self::require_positive_noise_parameter(
308 likelihood,
309 "Gamma shape",
310 gaussian_scale,
311 )?;
312 Ok(NoiseModel::Gamma {
313 shape: Array1::from_elem(nobs, shape),
314 })
315 }
316 ResponseFamily::RoystonParmar => Err(EstimationError::InvalidInput(
317 "RoystonParmar generative sampling is not exposed via generic generation"
318 .to_string(),
319 )),
320 }
321 }
322
323 pub fn from_likelihood_with_per_row_dispersion(
332 likelihood: &LikelihoodSpec,
333 dispersion: Array1<f64>,
334 ) -> Result<NoiseModel, EstimationError> {
335 match &likelihood.response {
336 ResponseFamily::Tweedie { p } => {
337 let p = *p;
338 if !is_valid_tweedie_power(p) {
339 crate::bail_invalid_estim!(
340 "Tweedie variance power must be finite and strictly between 1 and 2; got {p}"
341 );
342 }
343 Ok(NoiseModel::Tweedie { p, phi: dispersion })
344 }
345 ResponseFamily::NegativeBinomial { .. } => {
346 Ok(NoiseModel::NegativeBinomial { theta: dispersion })
347 }
348 ResponseFamily::Beta { .. } => Ok(NoiseModel::Beta { phi: dispersion }),
349 ResponseFamily::Gamma => Ok(NoiseModel::Gamma { shape: dispersion }),
350 other => Err(EstimationError::InvalidInput(format!(
351 "per-row dispersion generative sampling is only defined for the dispersion \
352 location-scale families (Gamma/NegativeBinomial/Beta/Tweedie); got {other:?}"
353 ))),
354 }
355 }
356
357 fn require_noise_parameter(
358 likelihood: &LikelihoodSpec,
359 parameter_name: &str,
360 value: Option<f64>,
361 ) -> Result<f64, EstimationError> {
362 let value = value.ok_or_else(|| {
363 EstimationError::InvalidInput(format!(
364 "{} generative sampling requires fitted {parameter_name}",
365 likelihood.pretty_name()
366 ))
367 })?;
368 if value.is_finite() {
369 Ok(value)
370 } else {
371 Err(EstimationError::InvalidInput(format!(
372 "{} generative sampling requires finite {parameter_name}; got {value}",
373 likelihood.pretty_name()
374 )))
375 }
376 }
377
378 fn require_positive_noise_parameter(
379 likelihood: &LikelihoodSpec,
380 parameter_name: &str,
381 value: Option<f64>,
382 ) -> Result<f64, EstimationError> {
383 let value = Self::require_noise_parameter(likelihood, parameter_name, value)?;
384 if value > 0.0 {
385 Ok(value)
386 } else {
387 Err(EstimationError::InvalidInput(format!(
388 "{} generative sampling requires {parameter_name} > 0; got {value}",
389 likelihood.pretty_name()
390 )))
391 }
392 }
393}
394
395fn check_dispersion_len(
399 dispersion: &Array1<f64>,
400 nobs: usize,
401 name: &str,
402) -> Result<(), EstimationError> {
403 if dispersion.len() != nobs {
404 crate::bail_invalid_estim!(
405 "{name} length {} does not match mean length {nobs}",
406 dispersion.len()
407 );
408 }
409 Ok(())
410}
411
412pub fn sampleobservations<R: rand::Rng + ?Sized>(
414 spec: &GenerativeSpec,
415 rng: &mut R,
416) -> Result<Array1<f64>, EstimationError> {
417 if spec.mean.iter().any(|m| !m.is_finite()) {
418 crate::bail_invalid_estim!("generative mean contains non-finite values");
419 }
420 match &spec.noise {
421 NoiseModel::Gaussian { sigma } => {
422 if sigma.len() != spec.mean.len() {
423 crate::bail_invalid_estim!(
424 "Gaussian sigma length {} does not match mean length {}",
425 sigma.len(),
426 spec.mean.len()
427 );
428 }
429 let mut y = spec.mean.clone();
430 for i in 0..y.len() {
431 let sd = sigma[i].max(0.0);
432 if sd == 0.0 {
433 continue;
434 }
435 let dist = rand_distr::Normal::new(0.0, sd).map_err(|e| {
436 EstimationError::InvalidInput(format!("invalid Gaussian noise scale {sd}: {e}"))
437 })?;
438 y[i] += rand_distr::Distribution::sample(&dist, rng);
439 }
440 Ok(y)
441 }
442 NoiseModel::Poisson => {
443 let mut y = Array1::<f64>::zeros(spec.mean.len());
444 for i in 0..y.len() {
445 let lam = spec.mean[i].max(1e-12);
446 let dist = rand_distr::Poisson::new(lam).map_err(|e| {
447 EstimationError::InvalidInput(format!("invalid Poisson rate {lam}: {e}"))
448 })?;
449 let draw = rand_distr::Distribution::sample(&dist, rng);
450 y[i] = draw;
451 }
452 Ok(y)
453 }
454 NoiseModel::Tweedie { p, phi } => {
455 if !(p.is_finite() && *p >= 1.0 && *p <= 2.0) {
456 crate::bail_invalid_estim!("invalid Tweedie power p: {p}");
457 }
458 check_dispersion_len(phi, spec.mean.len(), "Tweedie dispersion phi")?;
459 for (i, &phi_i) in phi.iter().enumerate() {
460 if !(phi_i.is_finite() && phi_i > 0.0) {
461 crate::bail_invalid_estim!(
462 "invalid Tweedie dispersion phi at row {i}: {phi_i}"
463 );
464 }
465 }
466 let mut y = Array1::<f64>::zeros(spec.mean.len());
467 if (*p - 1.0).abs() <= 1.0e-12 {
468 for i in 0..y.len() {
469 let phi_i = phi[i];
470 let lam = (spec.mean[i] / phi_i).max(1e-12);
471 let dist = rand_distr::Poisson::new(lam).map_err(|e| {
472 EstimationError::InvalidInput(format!(
473 "invalid Tweedie-Poisson rate {lam}: {e}"
474 ))
475 })?;
476 y[i] = phi_i * rand_distr::Distribution::sample(&dist, rng);
477 }
478 return Ok(y);
479 }
480 if (*p - 2.0).abs() <= 1.0e-12 {
481 for i in 0..y.len() {
482 let phi_i = phi[i];
483 let shape = (1.0 / phi_i).max(1e-12);
484 let mu = spec.mean[i].max(1e-12);
485 let scale = (mu * phi_i).max(1e-12);
486 let dist = rand_distr::Gamma::new(shape, scale).map_err(|e| {
487 EstimationError::InvalidInput(format!(
488 "invalid Tweedie-Gamma params shape={shape} scale={scale}: {e}"
489 ))
490 })?;
491 y[i] = rand_distr::Distribution::sample(&dist, rng);
492 }
493 return Ok(y);
494 }
495 let alpha = (2.0 - *p) / (*p - 1.0);
496 for i in 0..y.len() {
497 let phi_i = phi[i];
498 let mu = spec.mean[i].max(1e-12);
499 let lambda = (mu.powf(2.0 - *p) / (phi_i * (2.0 - *p))).max(1e-12);
500 let scale = (phi_i * (*p - 1.0) * mu.powf(*p - 1.0)).max(1e-12);
501 let count_dist = rand_distr::Poisson::new(lambda).map_err(|e| {
502 EstimationError::InvalidInput(format!(
503 "invalid Tweedie compound-Poisson rate {lambda}: {e}"
504 ))
505 })?;
506 let count = rand_distr::Distribution::sample(&count_dist, rng) as usize;
507 if count == 0 {
508 continue;
509 }
510 let jump_dist = rand_distr::Gamma::new(alpha, scale).map_err(|e| {
511 EstimationError::InvalidInput(format!(
512 "invalid Tweedie jump params shape={alpha} scale={scale}: {e}"
513 ))
514 })?;
515 y[i] = (0..count)
516 .map(|_| rand_distr::Distribution::sample(&jump_dist, rng))
517 .sum();
518 }
519 Ok(y)
520 }
521 NoiseModel::NegativeBinomial { theta } => {
522 check_dispersion_len(theta, spec.mean.len(), "NegativeBinomial theta")?;
523 let mut y = Array1::<f64>::zeros(spec.mean.len());
524 for i in 0..y.len() {
525 let theta_i = theta[i];
526 if !(theta_i.is_finite() && theta_i > 0.0) {
527 crate::bail_invalid_estim!(
528 "invalid negative-binomial theta at row {i}: {theta_i}"
529 );
530 }
531 let mu = spec.mean[i].max(1e-12);
532 let scale = (mu / theta_i).max(1e-12);
533 let gamma = rand_distr::Gamma::new(theta_i, scale).map_err(|e| {
534 EstimationError::InvalidInput(format!(
535 "invalid NegativeBinomial gamma mixture params theta={theta_i} scale={scale}: {e}"
536 ))
537 })?;
538 let lambda = rand_distr::Distribution::sample(&gamma, rng).max(1e-12);
539 let poisson = rand_distr::Poisson::new(lambda).map_err(|e| {
540 EstimationError::InvalidInput(format!(
541 "invalid NegativeBinomial Poisson rate {lambda}: {e}"
542 ))
543 })?;
544 y[i] = rand_distr::Distribution::sample(&poisson, rng);
545 }
546 Ok(y)
547 }
548 NoiseModel::Beta { phi } => {
549 check_dispersion_len(phi, spec.mean.len(), "Beta phi")?;
550 let mut y = Array1::<f64>::zeros(spec.mean.len());
551 for i in 0..y.len() {
552 let phi_i = phi[i];
553 if !(phi_i.is_finite() && phi_i > 0.0) {
554 crate::bail_invalid_estim!("invalid beta-regression phi at row {i}: {phi_i}");
555 }
556 let mu = spec.mean[i].clamp(1e-12, 1.0 - 1e-12);
557 let alpha = (mu * phi_i).max(1e-12);
558 let beta = ((1.0 - mu) * phi_i).max(1e-12);
559 let dist = rand_distr::Beta::new(alpha, beta).map_err(|e| {
560 EstimationError::InvalidInput(format!(
561 "invalid Beta params alpha={alpha} beta={beta}: {e}"
562 ))
563 })?;
564 y[i] = rand_distr::Distribution::sample(&dist, rng);
565 }
566 Ok(y)
567 }
568 NoiseModel::Gamma { shape } => {
569 check_dispersion_len(shape, spec.mean.len(), "Gamma shape")?;
570 let mut y = Array1::<f64>::zeros(spec.mean.len());
571 for i in 0..y.len() {
572 let shape_i = shape[i];
573 if !shape_i.is_finite() || shape_i <= 0.0 {
574 crate::bail_invalid_estim!("invalid Gamma shape at row {i}: {shape_i}");
575 }
576 let mu = spec.mean[i].max(1e-12);
577 let scale = (mu / shape_i).max(1e-12);
578 let dist = rand_distr::Gamma::new(shape_i, scale).map_err(|e| {
579 EstimationError::InvalidInput(format!(
580 "invalid Gamma params shape={shape_i} scale={scale}: {e}"
581 ))
582 })?;
583 y[i] = rand_distr::Distribution::sample(&dist, rng);
584 }
585 Ok(y)
586 }
587 NoiseModel::Bernoulli => {
588 let mut y = Array1::<f64>::zeros(spec.mean.len());
589 for i in 0..y.len() {
590 let p = spec.mean[i];
591 let dist = rand_distr::Bernoulli::new(p).map_err(|e| {
592 EstimationError::InvalidInput(format!("invalid Bernoulli probability {p}: {e}"))
593 })?;
594 y[i] = if rand_distr::Distribution::sample(&dist, rng) {
595 1.0
596 } else {
597 0.0
598 };
599 }
600 Ok(y)
601 }
602 NoiseModel::TransformationNormalQuantile { grid_y, h_grid } => {
603 let n = spec.mean.len();
604 if h_grid.nrows() != n {
605 crate::bail_invalid_estim!(
606 "transformation-normal h_grid has {} rows but mean length is {n}",
607 h_grid.nrows()
608 );
609 }
610 let g = grid_y.len();
611 if g < 2 || h_grid.ncols() != g {
612 crate::bail_invalid_estim!(
613 "transformation-normal grid is degenerate: grid_y len {g}, h_grid cols {}",
614 h_grid.ncols()
615 );
616 }
617 let dist = rand_distr::Normal::new(0.0, 1.0).map_err(|e| {
621 EstimationError::InvalidInput(format!(
622 "invalid standard-normal latent sampler: {e}"
623 ))
624 })?;
625 let mut y = Array1::<f64>::zeros(n);
626 for i in 0..n {
627 let z: f64 = rand_distr::Distribution::sample(&dist, rng);
628 y[i] = invert_monotone_grid(grid_y, h_grid.row(i), z);
629 }
630 Ok(y)
631 }
632 }
633}
634
635pub fn sampleobservation_replicates<R: rand::Rng + ?Sized>(
637 spec: &GenerativeSpec,
638 n_draws: usize,
639 rng: &mut R,
640) -> Result<Array2<f64>, EstimationError> {
641 let n = spec.nobs();
642 let mut out = Array2::<f64>::zeros((n_draws, n));
643 for d in 0..n_draws {
644 let draw = sampleobservations(spec, rng)?;
645 out.row_mut(d).assign(&draw);
646 }
647 Ok(out)
648}
649
/// A [`CustomFamily`] that can describe its observation model generatively.
///
/// Implementors turn the supplied parameter-block states into a [`GenerativeSpec`]: a
/// response-scale mean plus a [`NoiseModel`] that [`sampleobservations`] can draw from.
pub trait CustomFamilyGenerative: CustomFamily {
    /// Build the generative spec for `block_states`.
    ///
    /// # Errors
    /// Returns a human-readable message when no spec can be built.
    /// NOTE(review): the exact failure conditions are implementation-specific — confirm
    /// per implementor.
    fn generativespec(
        &self,
        block_states: &[ParameterBlockState],
    ) -> Result<GenerativeSpec, String>;
}
658
/// Tests for noise-parameter selection, noise-model construction, and sampling.
#[cfg(test)]
mod tests {
    use super::*;
    // `FamilyStrategy` is imported alongside `strategy_for_spec` so its `simulate_noise`
    // method is callable on the returned strategy.
    use crate::family_runtime::{FamilyStrategy, strategy_for_spec};

    /// Linear transformations `h_i(y) = slope_i * (y - center_i)` make inverse-transform draws
    /// `y = center_i + z / slope_i` with `z ~ N(0, 1)`. Draw means should therefore match the
    /// centers and draw standard deviations should match `1 / slope_i`.
    #[test]
    fn transformation_normal_quantile_sampler_is_inverse_transform() {
        use rand::SeedableRng;

        let g = 801usize;
        let (y_lo, y_hi) = (-12.0_f64, 12.0_f64);
        let grid_y = Array1::from_shape_fn(g, |k| {
            y_lo + (y_hi - y_lo) * (k as f64) / ((g - 1) as f64)
        });
        let centers = [-1.0_f64, 2.0_f64];
        let slopes = [2.0_f64, 4.0_f64];
        let mut h_grid = Array2::<f64>::zeros((2, g));
        for i in 0..2 {
            for k in 0..g {
                h_grid[[i, k]] = slopes[i] * (grid_y[k] - centers[i]);
            }
        }
        let spec = GenerativeSpec {
            mean: Array1::from_vec(vec![centers[0], centers[1]]),
            noise: NoiseModel::TransformationNormalQuantile {
                grid_y: grid_y.clone(),
                h_grid,
            },
        };

        // Fixed seed: the statistical tolerances below are checked against this stream.
        let mut rng = rand::rngs::StdRng::seed_from_u64(20240613);
        let n_draws = 40_000usize;
        let draws = sampleobservation_replicates(&spec, n_draws, &mut rng).unwrap();
        assert_eq!(draws.shape(), &[n_draws, 2]);

        let mut row_means = [0.0_f64; 2];
        for i in 0..2 {
            let col = draws.column(i);
            let mean = col.sum() / (n_draws as f64);
            let var =
                col.iter().map(|v| (v - mean) * (v - mean)).sum::<f64>() / (n_draws as f64);
            let sd = var.sqrt();
            row_means[i] = mean;
            assert!(
                (mean - centers[i]).abs() < 0.02,
                "row {i} draw mean {mean:.4} should be the response-scale center {:.4}",
                centers[i]
            );
            let expected_sd = 1.0 / slopes[i];
            assert!(
                (sd - expected_sd).abs() < 0.02,
                "row {i} draw sd {sd:.4} should be the response-scale 1/slope {expected_sd:.4}, \
                 not the latent ≈1 of the old Gaussian-noise path"
            );
        }
        // Ordering check: the row with the larger center must have the larger draw mean.
        assert!(
            row_means[1] > row_means[0],
            "draw means must increase with center: row0={:.4} row1={:.4}",
            row_means[0],
            row_means[1]
        );
    }

    /// When the fit metadata carries an estimated dispersion, the picker must return it
    /// instead of the spec's seed value or the residual-scale fallback.
    #[test]
    fn family_noise_parameter_reads_fitted_dispersion_not_seed() {
        let nb = LikelihoodSpec::negative_binomial_log(1.0);
        assert_eq!(
            family_noise_parameter(
                LikelihoodScaleMetadata::EstimatedNegBinTheta { theta: 2.97 },
                0.0,
                &nb,
            ),
            Some(2.97),
            "NB picker must read theta_hat (#1124), not the seed theta=1"
        );

        let tw = LikelihoodSpec::tweedie_log(1.5);
        assert_eq!(
            family_noise_parameter(
                LikelihoodScaleMetadata::EstimatedTweediePhi { phi: 7.25 },
                0.0,
                &tw,
            ),
            Some(7.25),
            "Tweedie picker must read phi_hat (#771), not the variance power p"
        );

        let beta = LikelihoodSpec::beta_logit(1.0);
        assert_eq!(
            family_noise_parameter(
                LikelihoodScaleMetadata::EstimatedBetaPhi { phi: 12.0 },
                0.0,
                &beta,
            ),
            Some(12.0),
            "Beta picker must read phi_hat (#770), not the seed phi=1"
        );

        let gamma = LikelihoodSpec::gamma_log();
        assert_eq!(
            family_noise_parameter(
                LikelihoodScaleMetadata::EstimatedGammaShape { shape: 4.5 },
                0.123,
                &gamma,
            ),
            Some(4.5),
            "Gamma picker must read shape_hat (#678), not the residual-scale fallback"
        );
    }

    /// With metadata that carries no family dispersion (`ProfiledGaussian`), each family
    /// falls back to its default: spec seed (NB, Beta), unit phi (Tweedie), or the supplied
    /// standard deviation (Gamma).
    #[test]
    fn family_noise_parameter_falls_back_to_seed_when_unfitted() {
        let none = LikelihoodScaleMetadata::ProfiledGaussian;
        assert_eq!(
            family_noise_parameter(none, 0.0, &LikelihoodSpec::negative_binomial_log(3.5)),
            Some(3.5),
            "NB picker must fall back to the spec seed theta"
        );
        assert_eq!(
            family_noise_parameter(none, 0.0, &LikelihoodSpec::beta_logit(8.0)),
            Some(8.0),
            "Beta picker must fall back to the spec seed phi"
        );
        assert_eq!(
            family_noise_parameter(none, 0.0, &LikelihoodSpec::tweedie_log(1.5)),
            Some(1.0),
            "Tweedie picker must fall back to unit dispersion"
        );
        assert_eq!(
            family_noise_parameter(none, 2.0, &LikelihoodSpec::gamma_log()),
            Some(2.0),
            "Gamma picker must fall back to the residual scale"
        );
    }

    /// End-to-end: the NB theta chosen by the picker must reach every row of the
    /// constructed noise model, overriding the spec's seed theta.
    #[test]
    fn picker_then_from_likelihood_threads_fitted_nb_theta() {
        let nobs = 6usize;
        let seed_spec = LikelihoodSpec::negative_binomial_log(1.0);
        let scale = LikelihoodScaleMetadata::EstimatedNegBinTheta { theta: 2.751 };
        let picked = family_noise_parameter(scale, 0.0, &seed_spec);
        let noise =
            NoiseModel::from_likelihood(&seed_spec, nobs, picked).expect("NB noise model builds");
        let NoiseModel::NegativeBinomial { theta } = noise else {
            panic!("expected an NB observation noise model");
        };
        assert!(
            theta.len() == nobs && theta.iter().all(|&t| (t - 2.751).abs() < 1e-12),
            "NB generate composes the seed theta=1 instead of theta_hat (#1124): {theta:?}"
        );
    }

    /// Gaussian prior weights must rescale sigma per row to `sigma_hat / sqrt(w_i)`, and
    /// unit weights must leave the pooled scalar unchanged.
    #[test]
    fn gaussian_generativespec_scales_sigma_by_prior_weights() {
        let sigma_hat = 2.0_f64;
        let weights = Array1::from(vec![1.0, 4.0, 0.25]);
        let mean = Array1::from(vec![0.0, 1.0, -1.0]);
        let prediction = PredictResult {
            eta: mean.clone(),
            mean: mean.clone(),
        };
        let spec = generativespec_from_predict(
            prediction,
            LikelihoodSpec::gaussian_identity(),
            Some(sigma_hat),
            Some(&weights),
        )
        .expect("weighted Gaussian generative spec builds");
        let NoiseModel::Gaussian { sigma } = spec.noise else {
            panic!("expected Gaussian observation noise");
        };
        // 2/sqrt(1) = 2, 2/sqrt(4) = 1, 2/sqrt(0.25) = 4.
        let expected = [2.0_f64, 1.0, 4.0];
        for (i, (&got, &want)) in sigma.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - want).abs() < 1e-12,
                "row {i}: sigma must be sigma_hat/sqrt(w_i)={want}, got {got} \
                 (flat sigma_hat={sigma_hat} drops the prior weights, #2025)"
            );
        }
        assert!(
            sigma.iter().any(|&s| (s - sigma_hat).abs() > 1e-9),
            "sigma is flat at the pooled scalar; prior weights were dropped (#2025)"
        );

        // Control case: all-ones weights are a no-op.
        let unit = Array1::from_elem(3, 1.0_f64);
        let unweighted = generativespec_from_predict(
            PredictResult {
                eta: mean.clone(),
                mean,
            },
            LikelihoodSpec::gaussian_identity(),
            Some(sigma_hat),
            Some(&unit),
        )
        .expect("unit-weight Gaussian generative spec builds");
        let NoiseModel::Gaussian { sigma: flat } = unweighted.noise else {
            panic!("expected Gaussian observation noise");
        };
        assert!(
            flat.iter().all(|&s| (s - sigma_hat).abs() < 1e-12),
            "unit prior weights must leave sigma at the pooled scalar sigma_hat"
        );
    }

    /// Structural equality on noise models, exact on the parameter arrays.
    /// Variant pairs not listed here, including `TransformationNormalQuantile`, compare as
    /// `false`.
    fn noise_models_match(a: &NoiseModel, b: &NoiseModel) -> bool {
        match (a, b) {
            (NoiseModel::Gaussian { sigma: sa }, NoiseModel::Gaussian { sigma: sb }) => sa == sb,
            (NoiseModel::Poisson, NoiseModel::Poisson) => true,
            (NoiseModel::Bernoulli, NoiseModel::Bernoulli) => true,
            (NoiseModel::Tweedie { p: pa, phi: pha }, NoiseModel::Tweedie { p: pb, phi: phb }) => {
                pa == pb && pha == phb
            }
            (
                NoiseModel::NegativeBinomial { theta: ta },
                NoiseModel::NegativeBinomial { theta: tb },
            ) => ta == tb,
            (NoiseModel::Beta { phi: pa }, NoiseModel::Beta { phi: pb }) => pa == pb,
            (NoiseModel::Gamma { shape: sa }, NoiseModel::Gamma { shape: sb }) => sa == sb,
            _ => false,
        }
    }

    /// For every supported family, two noise models must agree: the one built by
    /// `NoiseModel::from_likelihood` and the one built by the family strategy's
    /// `simulate_noise`. Both must also equal the hand-written expectation.
    #[test]
    fn from_likelihood_matches_simulate_noise_for_each_family() {
        let nobs = 5usize;
        let mean = Array1::from_elem(nobs, 0.5_f64);

        let cases: [(LikelihoodSpec, Option<f64>, NoiseModel); 7] = [
            (
                LikelihoodSpec::gaussian_identity(),
                Some(0.7),
                NoiseModel::Gaussian {
                    sigma: Array1::from_elem(nobs, 0.7),
                },
            ),
            (
                LikelihoodSpec::binomial_logit(),
                None,
                NoiseModel::Bernoulli,
            ),
            (LikelihoodSpec::poisson_log(), None, NoiseModel::Poisson),
            (
                LikelihoodSpec::tweedie_log(1.4),
                Some(0.9),
                NoiseModel::Tweedie {
                    p: 1.4,
                    phi: Array1::from_elem(nobs, 0.9),
                },
            ),
            (
                LikelihoodSpec::negative_binomial_log(2.5),
                None,
                NoiseModel::NegativeBinomial {
                    theta: Array1::from_elem(nobs, 2.5),
                },
            ),
            (
                LikelihoodSpec::beta_logit(3.0),
                None,
                NoiseModel::Beta {
                    phi: Array1::from_elem(nobs, 3.0),
                },
            ),
            (
                LikelihoodSpec::gamma_log(),
                Some(1.5),
                NoiseModel::Gamma {
                    shape: Array1::from_elem(nobs, 1.5),
                },
            ),
        ];

        for (spec, scale, expected) in cases {
            let from_helper = NoiseModel::from_likelihood(&spec, nobs, scale)
                .expect("canonical mapping must accept a supported family");
            let from_strategy = strategy_for_spec(&spec)
                .simulate_noise(&mean, scale)
                .expect("simulation adapter must accept a supported family");

            assert!(
                noise_models_match(&from_helper, &expected),
                "{} canonical mapping produced an unexpected NoiseModel",
                spec.pretty_name()
            );
            assert!(
                noise_models_match(&from_helper, &from_strategy),
                "{} simulation and inference disagree on the NoiseModel",
                spec.pretty_name()
            );
        }
    }

    /// Royston–Parmar must be rejected by both the canonical mapping and the strategy adapter.
    #[test]
    fn royston_parmar_rejected_on_both_paths() {
        let spec = LikelihoodSpec::royston_parmar();
        let mean = Array1::from_elem(3, 0.0_f64);
        assert!(NoiseModel::from_likelihood(&spec, 3, None).is_err());
        assert!(
            strategy_for_spec(&spec)
                .simulate_noise(&mean, None)
                .is_err()
        );
    }

    /// Invalid noise parameters must be rejected identically on both paths.
    #[test]
    fn invalid_dispersion_rejected_on_both_paths() {
        let mean = Array1::from_elem(4, 0.0_f64);

        // Gaussian with no sigma supplied.
        let gauss = LikelihoodSpec::gaussian_identity();
        assert!(NoiseModel::from_likelihood(&gauss, 4, None).is_err());
        assert!(
            strategy_for_spec(&gauss)
                .simulate_noise(&mean, None)
                .is_err()
        );

        // Tweedie power outside (1, 2).
        let bad_tweedie = LikelihoodSpec::tweedie_log(2.5);
        assert!(NoiseModel::from_likelihood(&bad_tweedie, 4, Some(0.5)).is_err());
        assert!(
            strategy_for_spec(&bad_tweedie)
                .simulate_noise(&mean, Some(0.5))
                .is_err()
        );

        // Non-positive Gamma shape.
        let gamma = LikelihoodSpec::gamma_log();
        assert!(NoiseModel::from_likelihood(&gamma, 4, Some(-1.0)).is_err());
        assert!(
            strategy_for_spec(&gamma)
                .simulate_noise(&mean, Some(-1.0))
                .is_err()
        );
    }
}