1use gam_custom_family::{CustomFamily, ParameterBlockState};
2use gam_solve::estimate::EstimationError;
3use crate::inference::predict_io::PredictResult;
4use gam_problem::types::{
5 LikelihoodScaleMetadata, LikelihoodSpec, ResponseFamily, is_valid_tweedie_power,
6};
7use ndarray::{Array1, Array2};
8
9pub fn family_noise_parameter(
40 scale: LikelihoodScaleMetadata,
41 standard_deviation: f64,
42 likelihood: &LikelihoodSpec,
43) -> Option<f64> {
44 match likelihood.response {
45 ResponseFamily::Tweedie { .. } => scale.fixed_phi().or(Some(1.0)),
50 ResponseFamily::NegativeBinomial { theta, .. } => scale.negbin_theta().or(Some(theta)),
53 ResponseFamily::Beta { phi } => scale.fixed_phi().or(Some(phi)),
56 ResponseFamily::Gamma => scale.gamma_shape().or(Some(standard_deviation)),
59 _ => Some(standard_deviation),
62 }
63}
64
/// Observation-noise model used by `sampleobservations` to draw synthetic
/// responses around a per-row mean vector.
///
/// Every per-row parameter vector (`sigma`, `phi`, `theta`, `shape`) must hold
/// one entry per observation; the sampler rejects length mismatches.
#[derive(Clone, Debug)]
pub enum NoiseModel {
    /// Additive normal noise: `y_i = mean_i + N(0, sigma_i^2)`.
    Gaussian {
        /// Per-row standard deviation; rows with `sigma_i == 0` stay at the mean.
        sigma: Array1<f64>,
    },
    /// Poisson counts with rate `mean_i`.
    Poisson,
    /// Tweedie response with variance power `p`. The sampler accepts `1 <= p <= 2`:
    /// p = 1 is a scaled Poisson, p = 2 is a Gamma, and anything in between is a
    /// compound Poisson-Gamma.
    Tweedie {
        /// Variance power.
        p: f64,
        /// Per-row dispersion; must be finite and > 0.
        phi: Array1<f64>,
    },
    /// Gamma-Poisson mixture: `lambda_i ~ Gamma(theta_i, mean_i / theta_i)`,
    /// then `y_i ~ Poisson(lambda_i)`.
    NegativeBinomial {
        /// Per-row Gamma mixing shape; must be finite and > 0.
        theta: Array1<f64>,
    },
    /// Beta regression: `y_i ~ Beta(mean_i * phi_i, (1 - mean_i) * phi_i)`, with
    /// `mean_i` clamped into the open unit interval.
    Beta {
        /// Per-row precision; must be finite and > 0.
        phi: Array1<f64>,
    },
    /// Gamma response: `y_i ~ Gamma(shape_i, scale = mean_i / shape_i)`.
    Gamma {
        /// Per-row shape; must be finite and > 0.
        shape: Array1<f64>,
    },
    /// 0/1 draws with success probability `mean_i`, which must lie in `[0, 1]`.
    Bernoulli,
    /// Transformation model sampled by inverse transform. For row `i`, draw
    /// `z ~ N(0, 1)` and return the `y` where `h_i(y) = z`, found by linear
    /// interpolation over `grid_y`. The mean vector is only used for its length
    /// and finiteness check, not to place the draw.
    TransformationNormalQuantile {
        /// Response-scale grid shared by all rows; needs at least 2 points.
        /// NOTE(review): presumably strictly increasing — confirm with producers.
        grid_y: Array1<f64>,
        /// `(nobs, grid_y.len())` matrix whose row `i` is `h_i` evaluated at
        /// `grid_y`. Rows should be non-decreasing for the bisection inversion to
        /// be meaningful.
        h_grid: Array2<f64>,
    },
}
114
115fn invert_monotone_grid(grid_y: &Array1<f64>, h_row: ndarray::ArrayView1<'_, f64>, target: f64) -> f64 {
122 let g = grid_y.len();
123 if target <= h_row[0] {
124 return grid_y[0];
125 }
126 if target >= h_row[g - 1] {
127 return grid_y[g - 1];
128 }
129 let mut lo = 0usize;
130 let mut hi = g - 1;
131 while hi - lo > 1 {
132 let mid = (lo + hi) / 2;
133 if h_row[mid] <= target {
134 lo = mid;
135 } else {
136 hi = mid;
137 }
138 }
139 let t = (target - h_row[lo]) / (h_row[hi] - h_row[lo]);
140 grid_y[lo] + t * (grid_y[hi] - grid_y[lo])
141}
142
/// A data-generating process: a response-scale mean for every observation plus
/// the observation-noise model used to draw around it.
#[derive(Clone, Debug)]
pub struct GenerativeSpec {
    /// Per-observation mean; its length is the number of observations.
    pub mean: Array1<f64>,
    /// Noise model; any per-row parameter vectors must match `mean` in length.
    pub noise: NoiseModel,
}
149
150impl GenerativeSpec {
151 pub fn nobs(&self) -> usize {
154 self.mean.len()
155 }
156}
157
158pub fn generativespec_from_predict(
160 prediction: PredictResult,
161 likelihood: LikelihoodSpec,
162 gaussian_scale: Option<f64>,
163) -> Result<GenerativeSpec, EstimationError> {
164 let noise = NoiseModel::from_likelihood(&likelihood, prediction.mean.len(), gaussian_scale)?;
165 Ok(GenerativeSpec {
166 mean: prediction.mean,
167 noise,
168 })
169}
170
171impl NoiseModel {
172 pub fn from_likelihood(
184 likelihood: &LikelihoodSpec,
185 nobs: usize,
186 gaussian_scale: Option<f64>,
187 ) -> Result<NoiseModel, EstimationError> {
188 match &likelihood.response {
189 ResponseFamily::Gaussian => {
190 let sigma =
191 Self::require_noise_parameter(likelihood, "Gaussian sigma", gaussian_scale)?;
192 if sigma < 0.0 {
193 crate::bail_invalid_estim!(
194 "{} generative sampling requires Gaussian sigma >= 0; got {sigma}",
195 likelihood.pretty_name()
196 );
197 }
198 Ok(NoiseModel::Gaussian {
199 sigma: Array1::from_elem(nobs, sigma),
200 })
201 }
202 ResponseFamily::Binomial => Ok(NoiseModel::Bernoulli),
203 ResponseFamily::Poisson => Ok(NoiseModel::Poisson),
204 ResponseFamily::Tweedie { p } => {
205 let p = *p;
206 if !is_valid_tweedie_power(p) {
207 crate::bail_invalid_estim!(
208 "Tweedie variance power must be finite and strictly between 1 and 2; got {p}"
209 );
210 }
211 let phi = Self::require_positive_noise_parameter(
212 likelihood,
213 "Tweedie dispersion phi",
214 gaussian_scale,
215 )?;
216 Ok(NoiseModel::Tweedie {
217 p,
218 phi: Array1::from_elem(nobs, phi),
222 })
223 }
224 ResponseFamily::NegativeBinomial { theta, .. } => {
225 let theta = gaussian_scale.unwrap_or(*theta);
233 if !(theta.is_finite() && theta > 0.0) {
234 crate::bail_invalid_estim!(
235 "negative-binomial theta must be finite and > 0; got {theta}"
236 );
237 }
238 Ok(NoiseModel::NegativeBinomial {
239 theta: Array1::from_elem(nobs, theta),
240 })
241 }
242 ResponseFamily::Beta { phi } => {
243 let phi = gaussian_scale.unwrap_or(*phi);
257 if !(phi.is_finite() && phi > 0.0) {
258 crate::bail_invalid_estim!(
259 "beta-regression phi must be finite and > 0; got {phi}"
260 );
261 }
262 Ok(NoiseModel::Beta {
263 phi: Array1::from_elem(nobs, phi),
264 })
265 }
266 ResponseFamily::Gamma => {
267 let shape = Self::require_positive_noise_parameter(
268 likelihood,
269 "Gamma shape",
270 gaussian_scale,
271 )?;
272 Ok(NoiseModel::Gamma {
273 shape: Array1::from_elem(nobs, shape),
274 })
275 }
276 ResponseFamily::RoystonParmar => Err(EstimationError::InvalidInput(
277 "RoystonParmar generative sampling is not exposed via generic generation"
278 .to_string(),
279 )),
280 }
281 }
282
283 pub fn from_likelihood_with_per_row_dispersion(
292 likelihood: &LikelihoodSpec,
293 dispersion: Array1<f64>,
294 ) -> Result<NoiseModel, EstimationError> {
295 match &likelihood.response {
296 ResponseFamily::Tweedie { p } => {
297 let p = *p;
298 if !is_valid_tweedie_power(p) {
299 crate::bail_invalid_estim!(
300 "Tweedie variance power must be finite and strictly between 1 and 2; got {p}"
301 );
302 }
303 Ok(NoiseModel::Tweedie { p, phi: dispersion })
304 }
305 ResponseFamily::NegativeBinomial { .. } => {
306 Ok(NoiseModel::NegativeBinomial { theta: dispersion })
307 }
308 ResponseFamily::Beta { .. } => Ok(NoiseModel::Beta { phi: dispersion }),
309 ResponseFamily::Gamma => Ok(NoiseModel::Gamma { shape: dispersion }),
310 other => Err(EstimationError::InvalidInput(format!(
311 "per-row dispersion generative sampling is only defined for the dispersion \
312 location-scale families (Gamma/NegativeBinomial/Beta/Tweedie); got {other:?}"
313 ))),
314 }
315 }
316
317 fn require_noise_parameter(
318 likelihood: &LikelihoodSpec,
319 parameter_name: &str,
320 value: Option<f64>,
321 ) -> Result<f64, EstimationError> {
322 let value = value.ok_or_else(|| {
323 EstimationError::InvalidInput(format!(
324 "{} generative sampling requires fitted {parameter_name}",
325 likelihood.pretty_name()
326 ))
327 })?;
328 if value.is_finite() {
329 Ok(value)
330 } else {
331 Err(EstimationError::InvalidInput(format!(
332 "{} generative sampling requires finite {parameter_name}; got {value}",
333 likelihood.pretty_name()
334 )))
335 }
336 }
337
338 fn require_positive_noise_parameter(
339 likelihood: &LikelihoodSpec,
340 parameter_name: &str,
341 value: Option<f64>,
342 ) -> Result<f64, EstimationError> {
343 let value = Self::require_noise_parameter(likelihood, parameter_name, value)?;
344 if value > 0.0 {
345 Ok(value)
346 } else {
347 Err(EstimationError::InvalidInput(format!(
348 "{} generative sampling requires {parameter_name} > 0; got {value}",
349 likelihood.pretty_name()
350 )))
351 }
352 }
353}
354
355fn check_dispersion_len(
359 dispersion: &Array1<f64>,
360 nobs: usize,
361 name: &str,
362) -> Result<(), EstimationError> {
363 if dispersion.len() != nobs {
364 crate::bail_invalid_estim!(
365 "{name} length {} does not match mean length {nobs}",
366 dispersion.len()
367 );
368 }
369 Ok(())
370}
371
372pub fn sampleobservations<R: rand::Rng + ?Sized>(
374 spec: &GenerativeSpec,
375 rng: &mut R,
376) -> Result<Array1<f64>, EstimationError> {
377 if spec.mean.iter().any(|m| !m.is_finite()) {
378 crate::bail_invalid_estim!("generative mean contains non-finite values");
379 }
380 match &spec.noise {
381 NoiseModel::Gaussian { sigma } => {
382 if sigma.len() != spec.mean.len() {
383 crate::bail_invalid_estim!(
384 "Gaussian sigma length {} does not match mean length {}",
385 sigma.len(),
386 spec.mean.len()
387 );
388 }
389 let mut y = spec.mean.clone();
390 for i in 0..y.len() {
391 let sd = sigma[i].max(0.0);
392 if sd == 0.0 {
393 continue;
394 }
395 let dist = rand_distr::Normal::new(0.0, sd).map_err(|e| {
396 EstimationError::InvalidInput(format!("invalid Gaussian noise scale {sd}: {e}"))
397 })?;
398 y[i] += rand_distr::Distribution::sample(&dist, rng);
399 }
400 Ok(y)
401 }
402 NoiseModel::Poisson => {
403 let mut y = Array1::<f64>::zeros(spec.mean.len());
404 for i in 0..y.len() {
405 let lam = spec.mean[i].max(1e-12);
406 let dist = rand_distr::Poisson::new(lam).map_err(|e| {
407 EstimationError::InvalidInput(format!("invalid Poisson rate {lam}: {e}"))
408 })?;
409 let draw = rand_distr::Distribution::sample(&dist, rng);
410 y[i] = draw;
411 }
412 Ok(y)
413 }
414 NoiseModel::Tweedie { p, phi } => {
415 if !(p.is_finite() && *p >= 1.0 && *p <= 2.0) {
416 crate::bail_invalid_estim!("invalid Tweedie power p: {p}");
417 }
418 check_dispersion_len(phi, spec.mean.len(), "Tweedie dispersion phi")?;
419 for (i, &phi_i) in phi.iter().enumerate() {
420 if !(phi_i.is_finite() && phi_i > 0.0) {
421 crate::bail_invalid_estim!(
422 "invalid Tweedie dispersion phi at row {i}: {phi_i}"
423 );
424 }
425 }
426 let mut y = Array1::<f64>::zeros(spec.mean.len());
427 if (*p - 1.0).abs() <= 1.0e-12 {
428 for i in 0..y.len() {
429 let phi_i = phi[i];
430 let lam = (spec.mean[i] / phi_i).max(1e-12);
431 let dist = rand_distr::Poisson::new(lam).map_err(|e| {
432 EstimationError::InvalidInput(format!(
433 "invalid Tweedie-Poisson rate {lam}: {e}"
434 ))
435 })?;
436 y[i] = phi_i * rand_distr::Distribution::sample(&dist, rng);
437 }
438 return Ok(y);
439 }
440 if (*p - 2.0).abs() <= 1.0e-12 {
441 for i in 0..y.len() {
442 let phi_i = phi[i];
443 let shape = (1.0 / phi_i).max(1e-12);
444 let mu = spec.mean[i].max(1e-12);
445 let scale = (mu * phi_i).max(1e-12);
446 let dist = rand_distr::Gamma::new(shape, scale).map_err(|e| {
447 EstimationError::InvalidInput(format!(
448 "invalid Tweedie-Gamma params shape={shape} scale={scale}: {e}"
449 ))
450 })?;
451 y[i] = rand_distr::Distribution::sample(&dist, rng);
452 }
453 return Ok(y);
454 }
455 let alpha = (2.0 - *p) / (*p - 1.0);
456 for i in 0..y.len() {
457 let phi_i = phi[i];
458 let mu = spec.mean[i].max(1e-12);
459 let lambda = (mu.powf(2.0 - *p) / (phi_i * (2.0 - *p))).max(1e-12);
460 let scale = (phi_i * (*p - 1.0) * mu.powf(*p - 1.0)).max(1e-12);
461 let count_dist = rand_distr::Poisson::new(lambda).map_err(|e| {
462 EstimationError::InvalidInput(format!(
463 "invalid Tweedie compound-Poisson rate {lambda}: {e}"
464 ))
465 })?;
466 let count = rand_distr::Distribution::sample(&count_dist, rng) as usize;
467 if count == 0 {
468 continue;
469 }
470 let jump_dist = rand_distr::Gamma::new(alpha, scale).map_err(|e| {
471 EstimationError::InvalidInput(format!(
472 "invalid Tweedie jump params shape={alpha} scale={scale}: {e}"
473 ))
474 })?;
475 y[i] = (0..count)
476 .map(|_| rand_distr::Distribution::sample(&jump_dist, rng))
477 .sum();
478 }
479 Ok(y)
480 }
481 NoiseModel::NegativeBinomial { theta } => {
482 check_dispersion_len(theta, spec.mean.len(), "NegativeBinomial theta")?;
483 let mut y = Array1::<f64>::zeros(spec.mean.len());
484 for i in 0..y.len() {
485 let theta_i = theta[i];
486 if !(theta_i.is_finite() && theta_i > 0.0) {
487 crate::bail_invalid_estim!(
488 "invalid negative-binomial theta at row {i}: {theta_i}"
489 );
490 }
491 let mu = spec.mean[i].max(1e-12);
492 let scale = (mu / theta_i).max(1e-12);
493 let gamma = rand_distr::Gamma::new(theta_i, scale).map_err(|e| {
494 EstimationError::InvalidInput(format!(
495 "invalid NegativeBinomial gamma mixture params theta={theta_i} scale={scale}: {e}"
496 ))
497 })?;
498 let lambda = rand_distr::Distribution::sample(&gamma, rng).max(1e-12);
499 let poisson = rand_distr::Poisson::new(lambda).map_err(|e| {
500 EstimationError::InvalidInput(format!(
501 "invalid NegativeBinomial Poisson rate {lambda}: {e}"
502 ))
503 })?;
504 y[i] = rand_distr::Distribution::sample(&poisson, rng);
505 }
506 Ok(y)
507 }
508 NoiseModel::Beta { phi } => {
509 check_dispersion_len(phi, spec.mean.len(), "Beta phi")?;
510 let mut y = Array1::<f64>::zeros(spec.mean.len());
511 for i in 0..y.len() {
512 let phi_i = phi[i];
513 if !(phi_i.is_finite() && phi_i > 0.0) {
514 crate::bail_invalid_estim!("invalid beta-regression phi at row {i}: {phi_i}");
515 }
516 let mu = spec.mean[i].clamp(1e-12, 1.0 - 1e-12);
517 let alpha = (mu * phi_i).max(1e-12);
518 let beta = ((1.0 - mu) * phi_i).max(1e-12);
519 let dist = rand_distr::Beta::new(alpha, beta).map_err(|e| {
520 EstimationError::InvalidInput(format!(
521 "invalid Beta params alpha={alpha} beta={beta}: {e}"
522 ))
523 })?;
524 y[i] = rand_distr::Distribution::sample(&dist, rng);
525 }
526 Ok(y)
527 }
528 NoiseModel::Gamma { shape } => {
529 check_dispersion_len(shape, spec.mean.len(), "Gamma shape")?;
530 let mut y = Array1::<f64>::zeros(spec.mean.len());
531 for i in 0..y.len() {
532 let shape_i = shape[i];
533 if !shape_i.is_finite() || shape_i <= 0.0 {
534 crate::bail_invalid_estim!("invalid Gamma shape at row {i}: {shape_i}");
535 }
536 let mu = spec.mean[i].max(1e-12);
537 let scale = (mu / shape_i).max(1e-12);
538 let dist = rand_distr::Gamma::new(shape_i, scale).map_err(|e| {
539 EstimationError::InvalidInput(format!(
540 "invalid Gamma params shape={shape_i} scale={scale}: {e}"
541 ))
542 })?;
543 y[i] = rand_distr::Distribution::sample(&dist, rng);
544 }
545 Ok(y)
546 }
547 NoiseModel::Bernoulli => {
548 let mut y = Array1::<f64>::zeros(spec.mean.len());
549 for i in 0..y.len() {
550 let p = spec.mean[i];
551 let dist = rand_distr::Bernoulli::new(p).map_err(|e| {
552 EstimationError::InvalidInput(format!("invalid Bernoulli probability {p}: {e}"))
553 })?;
554 y[i] = if rand_distr::Distribution::sample(&dist, rng) {
555 1.0
556 } else {
557 0.0
558 };
559 }
560 Ok(y)
561 }
562 NoiseModel::TransformationNormalQuantile { grid_y, h_grid } => {
563 let n = spec.mean.len();
564 if h_grid.nrows() != n {
565 crate::bail_invalid_estim!(
566 "transformation-normal h_grid has {} rows but mean length is {n}",
567 h_grid.nrows()
568 );
569 }
570 let g = grid_y.len();
571 if g < 2 || h_grid.ncols() != g {
572 crate::bail_invalid_estim!(
573 "transformation-normal grid is degenerate: grid_y len {g}, h_grid cols {}",
574 h_grid.ncols()
575 );
576 }
577 let dist = rand_distr::Normal::new(0.0, 1.0).map_err(|e| {
581 EstimationError::InvalidInput(format!(
582 "invalid standard-normal latent sampler: {e}"
583 ))
584 })?;
585 let mut y = Array1::<f64>::zeros(n);
586 for i in 0..n {
587 let z: f64 = rand_distr::Distribution::sample(&dist, rng);
588 y[i] = invert_monotone_grid(grid_y, h_grid.row(i), z);
589 }
590 Ok(y)
591 }
592 }
593}
594
595pub fn sampleobservation_replicates<R: rand::Rng + ?Sized>(
597 spec: &GenerativeSpec,
598 n_draws: usize,
599 rng: &mut R,
600) -> Result<Array2<f64>, EstimationError> {
601 let n = spec.nobs();
602 let mut out = Array2::<f64>::zeros((n_draws, n));
603 for d in 0..n_draws {
604 let draw = sampleobservations(spec, rng)?;
605 out.row_mut(d).assign(&draw);
606 }
607 Ok(out)
608}
609
/// A `CustomFamily` that can describe its data-generating process as a
/// `GenerativeSpec`, which `sampleobservations` can then draw from.
pub trait CustomFamilyGenerative: CustomFamily {
    /// Builds the generative spec (per-row mean plus noise model) from the
    /// family's parameter block states.
    ///
    /// NOTE(review): presumably `block_states` are the fitted states — confirm.
    ///
    /// # Errors
    /// A human-readable message when a spec cannot be built.
    fn generativespec(
        &self,
        block_states: &[ParameterBlockState],
    ) -> Result<GenerativeSpec, String>;
}
618
#[cfg(test)]
mod tests {
    use super::*;
    use crate::family_runtime::{FamilyStrategy, strategy_for_spec};

    /// The transformation-normal sampler must draw on the response scale by
    /// inverting each row's transform at a standard-normal latent draw.
    ///
    /// Each row uses a linear transform `h_i(y) = slope_i * (y - center_i)`, so
    /// inverting at `z ~ N(0, 1)` gives `y = center_i + z / slope_i`. Draws should
    /// therefore have mean `center_i` and standard deviation `1 / slope_i`.
    #[test]
    fn transformation_normal_quantile_sampler_is_inverse_transform() {
        use rand::SeedableRng;

        // 801-point response grid on [-12, 12], wide enough that clamping at the
        // grid ends is negligible for these slopes.
        let g = 801usize;
        let (y_lo, y_hi) = (-12.0_f64, 12.0_f64);
        let grid_y = Array1::from_shape_fn(g, |k| {
            y_lo + (y_hi - y_lo) * (k as f64) / ((g - 1) as f64)
        });
        let centers = [-1.0_f64, 2.0_f64];
        // Different slopes give the two rows different response-scale spreads.
        let slopes = [2.0_f64, 4.0_f64];
        let mut h_grid = Array2::<f64>::zeros((2, g));
        for i in 0..2 {
            for k in 0..g {
                h_grid[[i, k]] = slopes[i] * (grid_y[k] - centers[i]);
            }
        }
        let spec = GenerativeSpec {
            mean: Array1::from_vec(vec![centers[0], centers[1]]),
            noise: NoiseModel::TransformationNormalQuantile {
                grid_y: grid_y.clone(),
                h_grid,
            },
        };

        // Fixed seed keeps the statistical tolerances below reproducible.
        let mut rng = rand::rngs::StdRng::seed_from_u64(20240613);
        let n_draws = 40_000usize;
        let draws = sampleobservation_replicates(&spec, n_draws, &mut rng).unwrap();
        assert_eq!(draws.shape(), &[n_draws, 2]);

        let mut row_means = [0.0_f64; 2];
        for i in 0..2 {
            // Column i collects all replicate draws for observation i.
            let col = draws.column(i);
            let mean = col.sum() / (n_draws as f64);
            let var =
                col.iter().map(|v| (v - mean) * (v - mean)).sum::<f64>() / (n_draws as f64);
            let sd = var.sqrt();
            row_means[i] = mean;
            assert!(
                (mean - centers[i]).abs() < 0.02,
                "row {i} draw mean {mean:.4} should be the response-scale center {:.4}",
                centers[i]
            );
            let expected_sd = 1.0 / slopes[i];
            assert!(
                (sd - expected_sd).abs() < 0.02,
                "row {i} draw sd {sd:.4} should be the response-scale 1/slope {expected_sd:.4}, \
                 not the latent ≈1 of the old Gaussian-noise path"
            );
        }
        // Ordering check: the row with the larger center must have the larger
        // draw mean.
        assert!(
            row_means[1] > row_means[0],
            "draw means must increase with center: row0={:.4} row1={:.4}",
            row_means[0],
            row_means[1]
        );
    }

    /// `family_noise_parameter` must prefer the fitted dispersion carried by the
    /// scale metadata over the seed value stored in the likelihood spec.
    #[test]
    fn family_noise_parameter_reads_fitted_dispersion_not_seed() {
        let nb = LikelihoodSpec::negative_binomial_log(1.0);
        // Fitted theta_hat = 2.97 must beat the seed theta = 1.
        assert_eq!(
            family_noise_parameter(
                LikelihoodScaleMetadata::EstimatedNegBinTheta { theta: 2.97 },
                0.0,
                &nb,
            ),
            Some(2.97),
            "NB picker must read theta_hat (#1124), not the seed theta=1"
        );

        let tw = LikelihoodSpec::tweedie_log(1.5);
        // Tweedie's noise parameter is the dispersion phi, not the variance power.
        assert_eq!(
            family_noise_parameter(
                LikelihoodScaleMetadata::EstimatedTweediePhi { phi: 7.25 },
                0.0,
                &tw,
            ),
            Some(7.25),
            "Tweedie picker must read phi_hat (#771), not the variance power p"
        );

        let beta = LikelihoodSpec::beta_logit(1.0);
        // Fitted Beta precision must beat the seed phi = 1.
        assert_eq!(
            family_noise_parameter(
                LikelihoodScaleMetadata::EstimatedBetaPhi { phi: 12.0 },
                0.0,
                &beta,
            ),
            Some(12.0),
            "Beta picker must read phi_hat (#770), not the seed phi=1"
        );

        let gamma = LikelihoodSpec::gamma_log();
        // Fitted Gamma shape must beat the `standard_deviation` fallback (0.123).
        assert_eq!(
            family_noise_parameter(
                LikelihoodScaleMetadata::EstimatedGammaShape { shape: 4.5 },
                0.123,
                &gamma,
            ),
            Some(4.5),
            "Gamma picker must read shape_hat (#678), not the residual-scale fallback"
        );
    }

    /// Without a matching fitted value in the metadata, `family_noise_parameter`
    /// falls back to each family's documented default.
    #[test]
    fn family_noise_parameter_falls_back_to_seed_when_unfitted() {
        // `ProfiledGaussian` carries none of the dispersion families' fitted values.
        let none = LikelihoodScaleMetadata::ProfiledGaussian;
        assert_eq!(
            family_noise_parameter(none, 0.0, &LikelihoodSpec::negative_binomial_log(3.5)),
            Some(3.5),
            "NB picker must fall back to the spec seed theta"
        );
        assert_eq!(
            family_noise_parameter(none, 0.0, &LikelihoodSpec::beta_logit(8.0)),
            Some(8.0),
            "Beta picker must fall back to the spec seed phi"
        );
        assert_eq!(
            family_noise_parameter(none, 0.0, &LikelihoodSpec::tweedie_log(1.5)),
            Some(1.0),
            "Tweedie picker must fall back to unit dispersion"
        );
        assert_eq!(
            family_noise_parameter(none, 2.0, &LikelihoodSpec::gamma_log()),
            Some(2.0),
            "Gamma picker must fall back to the residual scale"
        );
    }

    /// End-to-end: the picker's output passed to `from_likelihood` must produce an
    /// NB noise model carrying the fitted theta on every row.
    #[test]
    fn picker_then_from_likelihood_threads_fitted_nb_theta() {
        let nobs = 6usize;
        let seed_spec = LikelihoodSpec::negative_binomial_log(1.0);
        let scale = LikelihoodScaleMetadata::EstimatedNegBinTheta { theta: 2.751 };
        let picked = family_noise_parameter(scale, 0.0, &seed_spec);
        let noise =
            NoiseModel::from_likelihood(&seed_spec, nobs, picked).expect("NB noise model builds");
        let NoiseModel::NegativeBinomial { theta } = noise else {
            panic!("expected an NB observation noise model");
        };
        assert!(
            theta.len() == nobs && theta.iter().all(|&t| (t - 2.751).abs() < 1e-12),
            "NB generate composes the seed theta=1 instead of theta_hat (#1124): {theta:?}"
        );
    }

    /// Structural equality for `NoiseModel` (the enum does not derive `PartialEq`).
    /// `TransformationNormalQuantile` is not compared and always yields `false`.
    fn noise_models_match(a: &NoiseModel, b: &NoiseModel) -> bool {
        match (a, b) {
            (NoiseModel::Gaussian { sigma: sa }, NoiseModel::Gaussian { sigma: sb }) => sa == sb,
            (NoiseModel::Poisson, NoiseModel::Poisson) => true,
            (NoiseModel::Bernoulli, NoiseModel::Bernoulli) => true,
            (NoiseModel::Tweedie { p: pa, phi: pha }, NoiseModel::Tweedie { p: pb, phi: phb }) => {
                pa == pb && pha == phb
            }
            (
                NoiseModel::NegativeBinomial { theta: ta },
                NoiseModel::NegativeBinomial { theta: tb },
            ) => ta == tb,
            (NoiseModel::Beta { phi: pa }, NoiseModel::Beta { phi: pb }) => pa == pb,
            (NoiseModel::Gamma { shape: sa }, NoiseModel::Gamma { shape: sb }) => sa == sb,
            _ => false,
        }
    }

    /// For every supported family, `NoiseModel::from_likelihood` must produce the
    /// expected model, and that model must agree with the simulation adapter
    /// (`strategy_for_spec(..).simulate_noise`).
    #[test]
    fn from_likelihood_matches_simulate_noise_for_each_family() {
        let nobs = 5usize;
        let mean = Array1::from_elem(nobs, 0.5_f64);

        // (likelihood, noise parameter passed in, expected model)
        let cases: [(LikelihoodSpec, Option<f64>, NoiseModel); 7] = [
            (
                LikelihoodSpec::gaussian_identity(),
                Some(0.7),
                NoiseModel::Gaussian {
                    sigma: Array1::from_elem(nobs, 0.7),
                },
            ),
            (
                LikelihoodSpec::binomial_logit(),
                None,
                NoiseModel::Bernoulli,
            ),
            (LikelihoodSpec::poisson_log(), None, NoiseModel::Poisson),
            (
                LikelihoodSpec::tweedie_log(1.4),
                Some(0.9),
                NoiseModel::Tweedie {
                    p: 1.4,
                    phi: Array1::from_elem(nobs, 0.9),
                },
            ),
            (
                LikelihoodSpec::negative_binomial_log(2.5),
                None,
                NoiseModel::NegativeBinomial {
                    theta: Array1::from_elem(nobs, 2.5),
                },
            ),
            (
                LikelihoodSpec::beta_logit(3.0),
                None,
                NoiseModel::Beta {
                    phi: Array1::from_elem(nobs, 3.0),
                },
            ),
            (
                LikelihoodSpec::gamma_log(),
                Some(1.5),
                NoiseModel::Gamma {
                    shape: Array1::from_elem(nobs, 1.5),
                },
            ),
        ];

        for (spec, scale, expected) in cases {
            let from_helper = NoiseModel::from_likelihood(&spec, nobs, scale)
                .expect("canonical mapping must accept a supported family");
            let from_strategy = strategy_for_spec(&spec)
                .simulate_noise(&mean, scale)
                .expect("simulation adapter must accept a supported family");

            assert!(
                noise_models_match(&from_helper, &expected),
                "{} canonical mapping produced an unexpected NoiseModel",
                spec.pretty_name()
            );
            assert!(
                noise_models_match(&from_helper, &from_strategy),
                "{} simulation and inference disagree on the NoiseModel",
                spec.pretty_name()
            );
        }
    }

    /// Royston-Parmar has no generic generative model; both construction paths
    /// must refuse it.
    #[test]
    fn royston_parmar_rejected_on_both_paths() {
        let spec = LikelihoodSpec::royston_parmar();
        let mean = Array1::from_elem(3, 0.0_f64);
        assert!(NoiseModel::from_likelihood(&spec, 3, None).is_err());
        assert!(
            strategy_for_spec(&spec)
                .simulate_noise(&mean, None)
                .is_err()
        );
    }

    /// Missing or invalid noise parameters must be rejected by both the canonical
    /// mapping and the simulation adapter.
    #[test]
    fn invalid_dispersion_rejected_on_both_paths() {
        let mean = Array1::from_elem(4, 0.0_f64);

        let gauss = LikelihoodSpec::gaussian_identity();
        // Gaussian requires a fitted sigma; `None` must fail.
        assert!(NoiseModel::from_likelihood(&gauss, 4, None).is_err());
        assert!(
            strategy_for_spec(&gauss)
                .simulate_noise(&mean, None)
                .is_err()
        );

        let bad_tweedie = LikelihoodSpec::tweedie_log(2.5);
        // p = 2.5 is outside (1, 2), even with a valid dispersion.
        assert!(NoiseModel::from_likelihood(&bad_tweedie, 4, Some(0.5)).is_err());
        assert!(
            strategy_for_spec(&bad_tweedie)
                .simulate_noise(&mean, Some(0.5))
                .is_err()
        );

        let gamma = LikelihoodSpec::gamma_log();
        // Gamma shape must be strictly positive.
        assert!(NoiseModel::from_likelihood(&gamma, 4, Some(-1.0)).is_err());
        assert!(
            strategy_for_spec(&gamma)
                .simulate_noise(&mean, Some(-1.0))
                .is_err()
        );
    }
}