Skip to main content

gam_models/inference/
generative.rs

1use gam_custom_family::{CustomFamily, ParameterBlockState};
2use gam_solve::estimate::EstimationError;
3use crate::inference::predict_io::PredictResult;
4use gam_problem::types::{
5    LikelihoodScaleMetadata, LikelihoodSpec, ResponseFamily, is_valid_tweedie_power,
6};
7use ndarray::{Array1, Array2};
8
9/// THE single source of truth for the scalar dispersion the generative
10/// observation model uses for a fitted family — the value handed to
11/// [`NoiseModel::from_likelihood`] / [`generativespec_from_predict`] as
12/// `gaussian_scale`.
13///
14/// For every exponential-dispersion / overdispersed family the dispersion is
15/// **estimated jointly with the mean** and recorded in the fit's
16/// [`LikelihoodScaleMetadata`] (`scale`); the value embedded in the response
17/// spec (`likelihood.response`) is only the construction-time *seed* (e.g.
18/// `theta = 1.0`, `phi = 1.0`), left un-updated after the fit refreshes the
19/// estimate. Generation must therefore read the *fitted* dispersion off `scale`,
20/// falling back to the seed only for fit-free construction. Reading the seed was
21/// the shared root cause of a whole family of bugs — Gamma #678, Beta #769/#770,
22/// Tweedie #771, and the NB sibling #1124 (`Var = mu + mu^2` instead of
23/// `mu + mu^2/theta_hat`).
24///
25/// This helper exists in exactly one place precisely because that bug class
26/// recurred: the dispersion-picking logic had been duplicated across the CLI
27/// `gam generate` path and the Python `sample_replicates` path, and fixing one
28/// copy left the other drawing at the seed. Both paths now call this function,
29/// so the set of supported families and the interpretation of each dispersion
30/// parameter can never diverge again. (The per-row dispersion location-scale
31/// path, #913/#1125, is the one exception that bypasses this scalar picker — it
32/// threads a full `exp(eta_d(x))` vector via
33/// [`NoiseModel::from_likelihood_with_per_row_dispersion`] instead.)
34///
35/// `standard_deviation` is the fit's residual scale, used as the Gamma-shape and
36/// Gaussian-`sigma` fallback. Returns `None` only for families that carry no
37/// dispersion at all in the fallback arm (never, in practice, for the families
38/// above).
39pub fn family_noise_parameter(
40    scale: LikelihoodScaleMetadata,
41    standard_deviation: f64,
42    likelihood: &LikelihoodSpec,
43) -> Option<f64> {
44    match likelihood.response {
45        // Tweedie: `gaussian_scale` carries the *dispersion* phi; the variance
46        // power `p` is read straight off the family spec by `from_likelihood`.
47        // phi is estimated jointly with the mean (#771), so consult the fit's
48        // scale metadata; unit dispersion is the fit-free fallback.
49        ResponseFamily::Tweedie { .. } => scale.fixed_phi().or(Some(1.0)),
50        // NB overdispersion theta is estimated jointly with the mean and stored
51        // as `EstimatedNegBinTheta`; the spec theta is only the seed (#1124).
52        ResponseFamily::NegativeBinomial { theta, .. } => scale.negbin_theta().or(Some(theta)),
53        // Beta precision phi is estimated jointly with the mean (#567/#770); the
54        // spec phi is only the seed.
55        ResponseFamily::Beta { phi } => scale.fixed_phi().or(Some(phi)),
56        // Gamma shape k is estimated jointly with the mean (#678); fall back to
57        // the residual scale only when the fit recorded no shape.
58        ResponseFamily::Gamma => scale.gamma_shape().or(Some(standard_deviation)),
59        // Gaussian / Poisson / Binomial: the residual scale is the generative
60        // sigma (Poisson/Binomial ignore it downstream).
61        _ => Some(standard_deviation),
62    }
63}
64
/// Observation-noise model used for generative sampling.
///
/// Each variant describes how [`sampleobservations`] turns a per-observation
/// mean into a random draw. Variants holding an `Array1<f64>` carry one
/// parameter per observation. When sampling, that vector's length must equal
/// the mean length. The dispersion families (Tweedie, NegativeBinomial, Beta,
/// Gamma) additionally require every entry to be finite and `> 0`.
#[derive(Clone, Debug)]
pub enum NoiseModel {
    /// Additive normal noise around the mean.
    Gaussian {
        /// Per-observation standard deviation. A zero entry makes that row
        /// reproduce its mean exactly.
        sigma: Array1<f64>,
    },
    /// Poisson counts, using the mean as the rate.
    Poisson,
    /// Tweedie response with variance power `p` and dispersion `phi`.
    Tweedie {
        /// Variance power. `NoiseModel::from_likelihood` only builds values
        /// strictly between 1 and 2; `sampleobservations` also accepts the
        /// boundary values `p = 1` and `p = 2`.
        p: f64,
        /// Per-observation dispersion φ (> 0). A scalar-dispersion fit broadcasts
        /// one value to every row; a dispersion location-scale fit (#913/#1125)
        /// supplies the fitted per-row φ = 1/exp(eta_d(x)).
        phi: Array1<f64>,
    },
    /// Negative-binomial counts, drawn as a Gamma–Poisson mixture.
    NegativeBinomial {
        /// Per-observation overdispersion θ (> 0); see `Tweedie::phi`.
        theta: Array1<f64>,
    },
    /// Beta-distributed response on (0, 1) with the mean as its location.
    Beta {
        /// Per-observation precision φ (> 0); see `Tweedie::phi`.
        phi: Array1<f64>,
    },
    /// Gamma-distributed positive response.
    Gamma {
        /// Per-observation Gamma shape k (> 0), with mean-driven scale; see
        /// `Tweedie::phi`.
        shape: Array1<f64>,
    },
    /// 0/1 draws using the mean as the success probability. The mean is passed
    /// to `rand_distr::Bernoulli::new` without clamping.
    Bernoulli,
}
95
/// First-class generative specification: mean process + observation noise.
///
/// `mean` holds one expected response per observation, and `noise` describes
/// how [`sampleobservations`] scatters draws around it.
#[derive(Clone, Debug)]
pub struct GenerativeSpec {
    /// Per-observation mean of the response. For `Bernoulli` noise this is the
    /// success probability; for `Poisson` it is the rate. It must be finite for
    /// sampling to succeed.
    pub mean: Array1<f64>,
    /// Observation-noise model. Any per-observation parameter vector it carries
    /// must have the same length as `mean` when sampling.
    pub noise: NoiseModel,
}
102
103impl GenerativeSpec {
104    /// Number of observations `n` in the mean vector, matching the row
105    /// count of the design used to produce this generative specification.
106    pub fn nobs(&self) -> usize {
107        self.mean.len()
108    }
109}
110
111/// Build a generative specification for built-in GAM families from eta/mean.
112pub fn generativespec_from_predict(
113    prediction: PredictResult,
114    likelihood: LikelihoodSpec,
115    gaussian_scale: Option<f64>,
116) -> Result<GenerativeSpec, EstimationError> {
117    let noise = NoiseModel::from_likelihood(&likelihood, prediction.mean.len(), gaussian_scale)?;
118    Ok(GenerativeSpec {
119        mean: prediction.mean,
120        noise,
121    })
122}
123
124impl NoiseModel {
125    /// Single canonical mapping from a fitted `LikelihoodSpec` (response
126    /// distribution + dispersion `gaussian_scale`) to the observation
127    /// `NoiseModel` used for generative sampling. Both simulation
128    /// (`FamilyStrategy::simulate_noise`) and generative inference
129    /// (`generativespec_from_predict`) route through this one helper so the
130    /// set of supported likelihoods and the interpretation of dispersion
131    /// parameters can never diverge between the two paths.
132    ///
133    /// `nobs` is the number of observations the resulting per-observation
134    /// Gaussian `sigma` vector should span; it is ignored for families whose
135    /// noise carries no per-observation state.
136    pub fn from_likelihood(
137        likelihood: &LikelihoodSpec,
138        nobs: usize,
139        gaussian_scale: Option<f64>,
140    ) -> Result<NoiseModel, EstimationError> {
141        match &likelihood.response {
142            ResponseFamily::Gaussian => {
143                let sigma =
144                    Self::require_noise_parameter(likelihood, "Gaussian sigma", gaussian_scale)?;
145                if sigma < 0.0 {
146                    crate::bail_invalid_estim!(
147                        "{} generative sampling requires Gaussian sigma >= 0; got {sigma}",
148                        likelihood.pretty_name()
149                    );
150                }
151                Ok(NoiseModel::Gaussian {
152                    sigma: Array1::from_elem(nobs, sigma),
153                })
154            }
155            ResponseFamily::Binomial => Ok(NoiseModel::Bernoulli),
156            ResponseFamily::Poisson => Ok(NoiseModel::Poisson),
157            ResponseFamily::Tweedie { p } => {
158                let p = *p;
159                if !is_valid_tweedie_power(p) {
160                    crate::bail_invalid_estim!(
161                        "Tweedie variance power must be finite and strictly between 1 and 2; got {p}"
162                    );
163                }
164                let phi = Self::require_positive_noise_parameter(
165                    likelihood,
166                    "Tweedie dispersion phi",
167                    gaussian_scale,
168                )?;
169                Ok(NoiseModel::Tweedie {
170                    p,
171                    // Scalar-dispersion fit: broadcast one φ to every row. The
172                    // dispersion location-scale path (#1125) builds the per-row
173                    // vector directly in `run_generate_unified` instead.
174                    phi: Array1::from_elem(nobs, phi),
175                })
176            }
177            ResponseFamily::NegativeBinomial { theta, .. } => {
178                // The NB overdispersion θ is estimated jointly with the mean and
179                // the authoritative post-fit value is handed in as
180                // `gaussian_scale` (from `likelihood_scale.negbin_theta()`);
181                // the θ embedded in the response spec is only the seed (1.0).
182                // Reading the seed was the NB sibling of the Beta #770 bug:
183                // generate drew Var = μ + μ² (θ = 1) regardless of the fitted
184                // overdispersion (#1124). Mirror the Beta arm below.
185                let theta = gaussian_scale.unwrap_or(*theta);
186                if !(theta.is_finite() && theta > 0.0) {
187                    crate::bail_invalid_estim!(
188                        "negative-binomial theta must be finite and > 0; got {theta}"
189                    );
190                }
191                Ok(NoiseModel::NegativeBinomial {
192                    theta: Array1::from_elem(nobs, theta),
193                })
194            }
195            ResponseFamily::Beta { phi } => {
196                // The Beta precision φ is estimated jointly with the mean
197                // (issue #567), so the authoritative value after fitting is the
198                // dispersion handed in as `gaussian_scale` — exactly as Gamma's
199                // shape and Tweedie's φ already take theirs. The `phi` embedded
200                // in the response spec is only the construction-time *seed* (left
201                // at its original value, e.g. 1.0, after the fit refreshes the
202                // estimate in `likelihood_scale`), so it serves solely as a
203                // fallback for fit-free construction where no fitted dispersion
204                // is supplied. Reading the seed instead of `gaussian_scale` was
205                // issue #770: the generative/observation path drew Beta responses
206                // with φ = 1.0 regardless of the data — nearly uniform on (0,1),
207                // ~20× too much variance — even though the fit estimated φ and
208                // the caller forwarded it here.
209                let phi = gaussian_scale.unwrap_or(*phi);
210                if !(phi.is_finite() && phi > 0.0) {
211                    crate::bail_invalid_estim!(
212                        "beta-regression phi must be finite and > 0; got {phi}"
213                    );
214                }
215                Ok(NoiseModel::Beta {
216                    phi: Array1::from_elem(nobs, phi),
217                })
218            }
219            ResponseFamily::Gamma => {
220                let shape = Self::require_positive_noise_parameter(
221                    likelihood,
222                    "Gamma shape",
223                    gaussian_scale,
224                )?;
225                Ok(NoiseModel::Gamma {
226                    shape: Array1::from_elem(nobs, shape),
227                })
228            }
229            ResponseFamily::RoystonParmar => Err(EstimationError::InvalidInput(
230                "RoystonParmar generative sampling is not exposed via generic generation"
231                    .to_string(),
232            )),
233        }
234    }
235
236    /// Build the observation `NoiseModel` for a dispersion location-scale fit
237    /// (#1125) from a fitted PER-ROW dispersion surface `dispersion[i]` (the
238    /// predictor's `exp(eta_d(x_i))` mapped into NoiseModel units — NB θ, Gamma
239    /// shape, Beta φ directly, Tweedie φ as the reciprocal). Unlike
240    /// `from_likelihood`, which broadcasts a single scalar dispersion to every
241    /// row, this threads the genuine per-observation precision channel so
242    /// generated data reproduces the fitted non-constant dispersion instead of
243    /// coming out homoscedastic at the seed.
244    pub fn from_likelihood_with_per_row_dispersion(
245        likelihood: &LikelihoodSpec,
246        dispersion: Array1<f64>,
247    ) -> Result<NoiseModel, EstimationError> {
248        match &likelihood.response {
249            ResponseFamily::Tweedie { p } => {
250                let p = *p;
251                if !is_valid_tweedie_power(p) {
252                    crate::bail_invalid_estim!(
253                        "Tweedie variance power must be finite and strictly between 1 and 2; got {p}"
254                    );
255                }
256                Ok(NoiseModel::Tweedie { p, phi: dispersion })
257            }
258            ResponseFamily::NegativeBinomial { .. } => {
259                Ok(NoiseModel::NegativeBinomial { theta: dispersion })
260            }
261            ResponseFamily::Beta { .. } => Ok(NoiseModel::Beta { phi: dispersion }),
262            ResponseFamily::Gamma => Ok(NoiseModel::Gamma { shape: dispersion }),
263            other => Err(EstimationError::InvalidInput(format!(
264                "per-row dispersion generative sampling is only defined for the dispersion \
265                 location-scale families (Gamma/NegativeBinomial/Beta/Tweedie); got {other:?}"
266            ))),
267        }
268    }
269
270    fn require_noise_parameter(
271        likelihood: &LikelihoodSpec,
272        parameter_name: &str,
273        value: Option<f64>,
274    ) -> Result<f64, EstimationError> {
275        let value = value.ok_or_else(|| {
276            EstimationError::InvalidInput(format!(
277                "{} generative sampling requires fitted {parameter_name}",
278                likelihood.pretty_name()
279            ))
280        })?;
281        if value.is_finite() {
282            Ok(value)
283        } else {
284            Err(EstimationError::InvalidInput(format!(
285                "{} generative sampling requires finite {parameter_name}; got {value}",
286                likelihood.pretty_name()
287            )))
288        }
289    }
290
291    fn require_positive_noise_parameter(
292        likelihood: &LikelihoodSpec,
293        parameter_name: &str,
294        value: Option<f64>,
295    ) -> Result<f64, EstimationError> {
296        let value = Self::require_noise_parameter(likelihood, parameter_name, value)?;
297        if value > 0.0 {
298            Ok(value)
299        } else {
300            Err(EstimationError::InvalidInput(format!(
301                "{} generative sampling requires {parameter_name} > 0; got {value}",
302                likelihood.pretty_name()
303            )))
304        }
305    }
306}
307
308/// Validate that a per-observation dispersion vector matches the mean length.
309/// Scalar-dispersion fits broadcast one value across all rows (length `n`);
310/// dispersion location-scale fits (#1125) carry the genuine per-row vector.
311fn check_dispersion_len(
312    dispersion: &Array1<f64>,
313    nobs: usize,
314    name: &str,
315) -> Result<(), EstimationError> {
316    if dispersion.len() != nobs {
317        crate::bail_invalid_estim!(
318            "{name} length {} does not match mean length {nobs}",
319            dispersion.len()
320        );
321    }
322    Ok(())
323}
324
325/// Draw one synthetic observation vector from a generative spec.
326pub fn sampleobservations<R: rand::Rng + ?Sized>(
327    spec: &GenerativeSpec,
328    rng: &mut R,
329) -> Result<Array1<f64>, EstimationError> {
330    if spec.mean.iter().any(|m| !m.is_finite()) {
331        crate::bail_invalid_estim!("generative mean contains non-finite values");
332    }
333    match &spec.noise {
334        NoiseModel::Gaussian { sigma } => {
335            if sigma.len() != spec.mean.len() {
336                crate::bail_invalid_estim!(
337                    "Gaussian sigma length {} does not match mean length {}",
338                    sigma.len(),
339                    spec.mean.len()
340                );
341            }
342            let mut y = spec.mean.clone();
343            for i in 0..y.len() {
344                let sd = sigma[i].max(0.0);
345                if sd == 0.0 {
346                    continue;
347                }
348                let dist = rand_distr::Normal::new(0.0, sd).map_err(|e| {
349                    EstimationError::InvalidInput(format!("invalid Gaussian noise scale {sd}: {e}"))
350                })?;
351                y[i] += rand_distr::Distribution::sample(&dist, rng);
352            }
353            Ok(y)
354        }
355        NoiseModel::Poisson => {
356            let mut y = Array1::<f64>::zeros(spec.mean.len());
357            for i in 0..y.len() {
358                let lam = spec.mean[i].max(1e-12);
359                let dist = rand_distr::Poisson::new(lam).map_err(|e| {
360                    EstimationError::InvalidInput(format!("invalid Poisson rate {lam}: {e}"))
361                })?;
362                let draw = rand_distr::Distribution::sample(&dist, rng);
363                y[i] = draw;
364            }
365            Ok(y)
366        }
367        NoiseModel::Tweedie { p, phi } => {
368            if !(p.is_finite() && *p >= 1.0 && *p <= 2.0) {
369                crate::bail_invalid_estim!("invalid Tweedie power p: {p}");
370            }
371            check_dispersion_len(phi, spec.mean.len(), "Tweedie dispersion phi")?;
372            for (i, &phi_i) in phi.iter().enumerate() {
373                if !(phi_i.is_finite() && phi_i > 0.0) {
374                    crate::bail_invalid_estim!(
375                        "invalid Tweedie dispersion phi at row {i}: {phi_i}"
376                    );
377                }
378            }
379            let mut y = Array1::<f64>::zeros(spec.mean.len());
380            if (*p - 1.0).abs() <= 1.0e-12 {
381                for i in 0..y.len() {
382                    let phi_i = phi[i];
383                    let lam = (spec.mean[i] / phi_i).max(1e-12);
384                    let dist = rand_distr::Poisson::new(lam).map_err(|e| {
385                        EstimationError::InvalidInput(format!(
386                            "invalid Tweedie-Poisson rate {lam}: {e}"
387                        ))
388                    })?;
389                    y[i] = phi_i * rand_distr::Distribution::sample(&dist, rng);
390                }
391                return Ok(y);
392            }
393            if (*p - 2.0).abs() <= 1.0e-12 {
394                for i in 0..y.len() {
395                    let phi_i = phi[i];
396                    let shape = (1.0 / phi_i).max(1e-12);
397                    let mu = spec.mean[i].max(1e-12);
398                    let scale = (mu * phi_i).max(1e-12);
399                    let dist = rand_distr::Gamma::new(shape, scale).map_err(|e| {
400                        EstimationError::InvalidInput(format!(
401                            "invalid Tweedie-Gamma params shape={shape} scale={scale}: {e}"
402                        ))
403                    })?;
404                    y[i] = rand_distr::Distribution::sample(&dist, rng);
405                }
406                return Ok(y);
407            }
408            let alpha = (2.0 - *p) / (*p - 1.0);
409            for i in 0..y.len() {
410                let phi_i = phi[i];
411                let mu = spec.mean[i].max(1e-12);
412                let lambda = (mu.powf(2.0 - *p) / (phi_i * (2.0 - *p))).max(1e-12);
413                let scale = (phi_i * (*p - 1.0) * mu.powf(*p - 1.0)).max(1e-12);
414                let count_dist = rand_distr::Poisson::new(lambda).map_err(|e| {
415                    EstimationError::InvalidInput(format!(
416                        "invalid Tweedie compound-Poisson rate {lambda}: {e}"
417                    ))
418                })?;
419                let count = rand_distr::Distribution::sample(&count_dist, rng) as usize;
420                if count == 0 {
421                    continue;
422                }
423                let jump_dist = rand_distr::Gamma::new(alpha, scale).map_err(|e| {
424                    EstimationError::InvalidInput(format!(
425                        "invalid Tweedie jump params shape={alpha} scale={scale}: {e}"
426                    ))
427                })?;
428                y[i] = (0..count)
429                    .map(|_| rand_distr::Distribution::sample(&jump_dist, rng))
430                    .sum();
431            }
432            Ok(y)
433        }
434        NoiseModel::NegativeBinomial { theta } => {
435            check_dispersion_len(theta, spec.mean.len(), "NegativeBinomial theta")?;
436            let mut y = Array1::<f64>::zeros(spec.mean.len());
437            for i in 0..y.len() {
438                let theta_i = theta[i];
439                if !(theta_i.is_finite() && theta_i > 0.0) {
440                    crate::bail_invalid_estim!(
441                        "invalid negative-binomial theta at row {i}: {theta_i}"
442                    );
443                }
444                let mu = spec.mean[i].max(1e-12);
445                let scale = (mu / theta_i).max(1e-12);
446                let gamma = rand_distr::Gamma::new(theta_i, scale).map_err(|e| {
447                    EstimationError::InvalidInput(format!(
448                        "invalid NegativeBinomial gamma mixture params theta={theta_i} scale={scale}: {e}"
449                    ))
450                })?;
451                let lambda = rand_distr::Distribution::sample(&gamma, rng).max(1e-12);
452                let poisson = rand_distr::Poisson::new(lambda).map_err(|e| {
453                    EstimationError::InvalidInput(format!(
454                        "invalid NegativeBinomial Poisson rate {lambda}: {e}"
455                    ))
456                })?;
457                y[i] = rand_distr::Distribution::sample(&poisson, rng);
458            }
459            Ok(y)
460        }
461        NoiseModel::Beta { phi } => {
462            check_dispersion_len(phi, spec.mean.len(), "Beta phi")?;
463            let mut y = Array1::<f64>::zeros(spec.mean.len());
464            for i in 0..y.len() {
465                let phi_i = phi[i];
466                if !(phi_i.is_finite() && phi_i > 0.0) {
467                    crate::bail_invalid_estim!("invalid beta-regression phi at row {i}: {phi_i}");
468                }
469                let mu = spec.mean[i].clamp(1e-12, 1.0 - 1e-12);
470                let alpha = (mu * phi_i).max(1e-12);
471                let beta = ((1.0 - mu) * phi_i).max(1e-12);
472                let dist = rand_distr::Beta::new(alpha, beta).map_err(|e| {
473                    EstimationError::InvalidInput(format!(
474                        "invalid Beta params alpha={alpha} beta={beta}: {e}"
475                    ))
476                })?;
477                y[i] = rand_distr::Distribution::sample(&dist, rng);
478            }
479            Ok(y)
480        }
481        NoiseModel::Gamma { shape } => {
482            check_dispersion_len(shape, spec.mean.len(), "Gamma shape")?;
483            let mut y = Array1::<f64>::zeros(spec.mean.len());
484            for i in 0..y.len() {
485                let shape_i = shape[i];
486                if !shape_i.is_finite() || shape_i <= 0.0 {
487                    crate::bail_invalid_estim!("invalid Gamma shape at row {i}: {shape_i}");
488                }
489                let mu = spec.mean[i].max(1e-12);
490                let scale = (mu / shape_i).max(1e-12);
491                let dist = rand_distr::Gamma::new(shape_i, scale).map_err(|e| {
492                    EstimationError::InvalidInput(format!(
493                        "invalid Gamma params shape={shape_i} scale={scale}: {e}"
494                    ))
495                })?;
496                y[i] = rand_distr::Distribution::sample(&dist, rng);
497            }
498            Ok(y)
499        }
500        NoiseModel::Bernoulli => {
501            let mut y = Array1::<f64>::zeros(spec.mean.len());
502            for i in 0..y.len() {
503                let p = spec.mean[i];
504                let dist = rand_distr::Bernoulli::new(p).map_err(|e| {
505                    EstimationError::InvalidInput(format!("invalid Bernoulli probability {p}: {e}"))
506                })?;
507                y[i] = if rand_distr::Distribution::sample(&dist, rng) {
508                    1.0
509                } else {
510                    0.0
511                };
512            }
513            Ok(y)
514        }
515    }
516}
517
518/// Draw multiple synthetic replicates (n_draws x nobs).
519pub fn sampleobservation_replicates<R: rand::Rng + ?Sized>(
520    spec: &GenerativeSpec,
521    n_draws: usize,
522    rng: &mut R,
523) -> Result<Array2<f64>, EstimationError> {
524    let n = spec.nobs();
525    let mut out = Array2::<f64>::zeros((n_draws, n));
526    for d in 0..n_draws {
527        let draw = sampleobservations(spec, rng)?;
528        out.row_mut(d).assign(&draw);
529    }
530    Ok(out)
531}
532
/// Extension trait for custom multi-block families that provide explicit
/// generative semantics (mean + observation noise) at a fitted state.
pub trait CustomFamilyGenerative: CustomFamily {
    /// Build the [`GenerativeSpec`] (per-observation mean plus observation
    /// noise model) implied by the given per-block parameter states.
    ///
    /// NOTE(review): presumably `block_states` has one entry per parameter
    /// block, in the family's block order — confirm against `CustomFamily`.
    /// Failures are reported as a `String` message.
    fn generativespec(
        &self,
        block_states: &[ParameterBlockState],
    ) -> Result<GenerativeSpec, String>;
}
541
542#[cfg(test)]
543mod tests {
544    use super::*;
545    use crate::family_runtime::{FamilyStrategy, strategy_for_spec};
546
547    /// The canonical dispersion picker must read the *fitted* dispersion off the
548    /// scale metadata, never the construction seed embedded in the response
549    /// spec. This is the single guard for the whole "generate draws at the seed
550    /// dispersion" bug family — Gamma #678, Beta #769/#770, Tweedie #771, and
551    /// the NB sibling #1124 — now that the picker lives in exactly one place
552    /// (previously three divergent copies let a fix in one miss the others).
553    #[test]
554    fn family_noise_parameter_reads_fitted_dispersion_not_seed() {
555        // NB: spec carries the seed theta = 1; the fit estimated theta_hat.
556        let nb = LikelihoodSpec::negative_binomial_log(1.0);
557        assert_eq!(
558            family_noise_parameter(
559                LikelihoodScaleMetadata::EstimatedNegBinTheta { theta: 2.97 },
560                0.0,
561                &nb,
562            ),
563            Some(2.97),
564            "NB picker must read theta_hat (#1124), not the seed theta=1"
565        );
566
567        // Tweedie: the picker must return the dispersion phi, never the variance
568        // power p that lives on the spec.
569        let tw = LikelihoodSpec::tweedie_log(1.5);
570        assert_eq!(
571            family_noise_parameter(
572                LikelihoodScaleMetadata::EstimatedTweediePhi { phi: 7.25 },
573                0.0,
574                &tw,
575            ),
576            Some(7.25),
577            "Tweedie picker must read phi_hat (#771), not the variance power p"
578        );
579
580        // Beta: spec carries the seed phi = 1; the fit estimated phi_hat.
581        let beta = LikelihoodSpec::beta_logit(1.0);
582        assert_eq!(
583            family_noise_parameter(
584                LikelihoodScaleMetadata::EstimatedBetaPhi { phi: 12.0 },
585                0.0,
586                &beta,
587            ),
588            Some(12.0),
589            "Beta picker must read phi_hat (#770), not the seed phi=1"
590        );
591
592        // Gamma: the estimated shape must win over the residual-scale fallback.
593        let gamma = LikelihoodSpec::gamma_log();
594        assert_eq!(
595            family_noise_parameter(
596                LikelihoodScaleMetadata::EstimatedGammaShape { shape: 4.5 },
597                0.123,
598                &gamma,
599            ),
600            Some(4.5),
601            "Gamma picker must read shape_hat (#678), not the residual-scale fallback"
602        );
603    }
604
605    /// With no fitted dispersion recorded (fit-free construction), the picker
606    /// falls back to the seed on the spec / the residual scale. It must never
607    /// return `None` for a dispersion family, or generation would have nothing
608    /// to draw with.
609    #[test]
610    fn family_noise_parameter_falls_back_to_seed_when_unfitted() {
611        // `ProfiledGaussian` carries no fixed_phi / negbin_theta / gamma_shape,
612        // so every accessor returns `None` and the picker must use the fallback.
613        let none = LikelihoodScaleMetadata::ProfiledGaussian;
614        assert_eq!(
615            family_noise_parameter(none, 0.0, &LikelihoodSpec::negative_binomial_log(3.5)),
616            Some(3.5),
617            "NB picker must fall back to the spec seed theta"
618        );
619        assert_eq!(
620            family_noise_parameter(none, 0.0, &LikelihoodSpec::beta_logit(8.0)),
621            Some(8.0),
622            "Beta picker must fall back to the spec seed phi"
623        );
624        assert_eq!(
625            family_noise_parameter(none, 0.0, &LikelihoodSpec::tweedie_log(1.5)),
626            Some(1.0),
627            "Tweedie picker must fall back to unit dispersion"
628        );
629        assert_eq!(
630            family_noise_parameter(none, 2.0, &LikelihoodSpec::gamma_log()),
631            Some(2.0),
632            "Gamma picker must fall back to the residual scale"
633        );
634    }
635
636    /// End-to-end through the exact composition `gam generate` and
637    /// `sample_replicates` use — picker → `from_likelihood`. The seed-spec
638    /// theta = 1 plus an estimated theta_hat must yield a per-row NB noise model
639    /// at theta_hat, not at the seed. This is the #1124 repro at the unit level,
640    /// from the angle of the *composed* path rather than `from_likelihood` alone.
641    #[test]
642    fn picker_then_from_likelihood_threads_fitted_nb_theta() {
643        let nobs = 6usize;
644        let seed_spec = LikelihoodSpec::negative_binomial_log(1.0);
645        let scale = LikelihoodScaleMetadata::EstimatedNegBinTheta { theta: 2.751 };
646        let picked = family_noise_parameter(scale, 0.0, &seed_spec);
647        let noise =
648            NoiseModel::from_likelihood(&seed_spec, nobs, picked).expect("NB noise model builds");
649        let NoiseModel::NegativeBinomial { theta } = noise else {
650            panic!("expected an NB observation noise model");
651        };
652        assert!(
653            theta.len() == nobs && theta.iter().all(|&t| (t - 2.751).abs() < 1e-12),
654            "NB generate composes the seed theta=1 instead of theta_hat (#1124): {theta:?}"
655        );
656    }
657
658    /// Structural equality for `NoiseModel` (no derived `PartialEq` so that
659    /// the live enum can carry per-observation arrays). Two models are equal
660    /// when they are the same variant with bitwise-identical parameters.
661    fn noise_models_match(a: &NoiseModel, b: &NoiseModel) -> bool {
662        match (a, b) {
663            (NoiseModel::Gaussian { sigma: sa }, NoiseModel::Gaussian { sigma: sb }) => sa == sb,
664            (NoiseModel::Poisson, NoiseModel::Poisson) => true,
665            (NoiseModel::Bernoulli, NoiseModel::Bernoulli) => true,
666            (NoiseModel::Tweedie { p: pa, phi: pha }, NoiseModel::Tweedie { p: pb, phi: phb }) => {
667                pa == pb && pha == phb
668            }
669            (
670                NoiseModel::NegativeBinomial { theta: ta },
671                NoiseModel::NegativeBinomial { theta: tb },
672            ) => ta == tb,
673            (NoiseModel::Beta { phi: pa }, NoiseModel::Beta { phi: pb }) => pa == pb,
674            (NoiseModel::Gamma { shape: sa }, NoiseModel::Gamma { shape: sb }) => sa == sb,
675            _ => false,
676        }
677    }
678
679    /// For every supported built-in family, the canonical
680    /// `NoiseModel::from_likelihood` mapping and the simulation adapter
681    /// `FamilyStrategy::simulate_noise` must produce the same `NoiseModel`
682    /// from the same fitted dispersion — this is the single-mapping guarantee
683    /// the unification provides.
684    #[test]
685    fn from_likelihood_matches_simulate_noise_for_each_family() {
686        let nobs = 5usize;
687        let mean = Array1::from_elem(nobs, 0.5_f64);
688
689        // (spec, dispersion/gaussian_scale, expected noise variant).
690        let cases: [(LikelihoodSpec, Option<f64>, NoiseModel); 7] = [
691            (
692                LikelihoodSpec::gaussian_identity(),
693                Some(0.7),
694                NoiseModel::Gaussian {
695                    sigma: Array1::from_elem(nobs, 0.7),
696                },
697            ),
698            (
699                LikelihoodSpec::binomial_logit(),
700                None,
701                NoiseModel::Bernoulli,
702            ),
703            (LikelihoodSpec::poisson_log(), None, NoiseModel::Poisson),
704            (
705                LikelihoodSpec::tweedie_log(1.4),
706                Some(0.9),
707                NoiseModel::Tweedie {
708                    p: 1.4,
709                    phi: Array1::from_elem(nobs, 0.9),
710                },
711            ),
712            (
713                LikelihoodSpec::negative_binomial_log(2.5),
714                None,
715                NoiseModel::NegativeBinomial {
716                    theta: Array1::from_elem(nobs, 2.5),
717                },
718            ),
719            (
720                LikelihoodSpec::beta_logit(3.0),
721                None,
722                NoiseModel::Beta {
723                    phi: Array1::from_elem(nobs, 3.0),
724                },
725            ),
726            (
727                LikelihoodSpec::gamma_log(),
728                Some(1.5),
729                NoiseModel::Gamma {
730                    shape: Array1::from_elem(nobs, 1.5),
731                },
732            ),
733        ];
734
735        for (spec, scale, expected) in cases {
736            let from_helper = NoiseModel::from_likelihood(&spec, nobs, scale)
737                .expect("canonical mapping must accept a supported family");
738            let from_strategy = strategy_for_spec(&spec)
739                .simulate_noise(&mean, scale)
740                .expect("simulation adapter must accept a supported family");
741
742            assert!(
743                noise_models_match(&from_helper, &expected),
744                "{} canonical mapping produced an unexpected NoiseModel",
745                spec.pretty_name()
746            );
747            assert!(
748                noise_models_match(&from_helper, &from_strategy),
749                "{} simulation and inference disagree on the NoiseModel",
750                spec.pretty_name()
751            );
752        }
753    }
754
755    /// RoystonParmar is not exposed through the generic generative path, and
756    /// both the canonical mapping and the simulation adapter must reject it
757    /// identically so the two paths stay in lockstep.
758    #[test]
759    fn royston_parmar_rejected_on_both_paths() {
760        let spec = LikelihoodSpec::royston_parmar();
761        let mean = Array1::from_elem(3, 0.0_f64);
762        assert!(NoiseModel::from_likelihood(&spec, 3, None).is_err());
763        assert!(
764            strategy_for_spec(&spec)
765                .simulate_noise(&mean, None)
766                .is_err()
767        );
768    }
769
770    /// Invalid / missing dispersion is rejected the same way regardless of
771    /// which entry point is used.
772    #[test]
773    fn invalid_dispersion_rejected_on_both_paths() {
774        let mean = Array1::from_elem(4, 0.0_f64);
775
776        // Gaussian sigma missing.
777        let gauss = LikelihoodSpec::gaussian_identity();
778        assert!(NoiseModel::from_likelihood(&gauss, 4, None).is_err());
779        assert!(
780            strategy_for_spec(&gauss)
781                .simulate_noise(&mean, None)
782                .is_err()
783        );
784
785        // Tweedie power outside (1, 2).
786        let bad_tweedie = LikelihoodSpec::tweedie_log(2.5);
787        assert!(NoiseModel::from_likelihood(&bad_tweedie, 4, Some(0.5)).is_err());
788        assert!(
789            strategy_for_spec(&bad_tweedie)
790                .simulate_noise(&mean, Some(0.5))
791                .is_err()
792        );
793
794        // Gamma shape non-positive.
795        let gamma = LikelihoodSpec::gamma_log();
796        assert!(NoiseModel::from_likelihood(&gamma, 4, Some(-1.0)).is_err());
797        assert!(
798            strategy_for_spec(&gamma)
799                .simulate_noise(&mean, Some(-1.0))
800                .is_err()
801        );
802    }
803}