gam 0.3.125

Generalized penalized likelihood engine
use crate::util::quantile::quantile_from_sorted;
use std::cmp::Ordering;

/// Computes an equal-tailed credible interval for every coefficient of a posterior sample.
///
/// `samples_flat` is a row-major `n_draws × n_coeffs` matrix: draw `k` of coefficient `j`
/// is stored at index `k * n_coeffs + j`. For each coefficient the draws are sorted
/// ascending, and `quantile_from_sorted` is evaluated at `alpha = (1 - level) / 2` and at
/// `1 - alpha`.
///
/// # Returns
///
/// A vector of length `2 * n_coeffs` laid out as `[lo_0, hi_0, lo_1, hi_1, ...]`.
///
/// # Errors
///
/// - `level` is not strictly inside `(0, 1)` (a NaN `level` is rejected too);
/// - `n_draws * n_coeffs` overflows `usize` or differs from `samples_flat.len()`;
/// - `n_draws == 0`.
///
/// # NaN draws
///
/// Draws are ordered with [`f64::total_cmp`], so a NaN draw lands at an end of its column
/// (positive NaN after `+inf`, negative NaN before `-inf`). A comparator that reports NaN
/// as "equal" to everything is not a total order, and the slice sort APIs may panic on one.
pub fn credible_interval(
    samples_flat: &[f64],
    n_draws: usize,
    n_coeffs: usize,
    level: f64,
) -> Result<Vec<f64>, String> {
    // Written as a negated conjunction so that a NaN `level` also fails the check.
    if !(level > 0.0 && level < 1.0) {
        return Err(format!("interval level must lie in (0, 1); got {level}"));
    }
    // `checked_mul`: an oversized caller-supplied shape must produce this error instead of
    // an overflow panic (debug) or a wrapped product that slips past the check (release).
    if n_draws.checked_mul(n_coeffs) != Some(samples_flat.len()) {
        return Err(format!(
            "posterior_credible_interval samples shape mismatch: got {} floats, expected {} * {}",
            samples_flat.len(),
            n_draws,
            n_coeffs
        ));
    }
    if n_draws == 0 {
        return Err("posterior_credible_interval requires at least one posterior draw".to_string());
    }
    let alpha = (1.0 - level) / 2.0;
    let mut out = Vec::with_capacity(2 * n_coeffs);
    // Single buffer reused for every coefficient's draws.
    let mut column: Vec<f64> = Vec::with_capacity(n_draws);
    for j in 0..n_coeffs {
        column.clear();
        // Each chunk of `n_coeffs` values is one draw; take coefficient `j` from each.
        // (Only reached when `n_coeffs > 0`, so `chunks_exact` never sees a zero size.)
        column.extend(samples_flat.chunks_exact(n_coeffs).map(|draw| draw[j]));
        column.sort_unstable_by(|a: &f64, b: &f64| -> Ordering { a.total_cmp(b) });
        out.push(quantile_from_sorted(&column, alpha));
        out.push(quantile_from_sorted(&column, 1.0 - alpha));
    }
    Ok(out)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zero_draws_rejects_credible_interval() {
        let err = credible_interval(&[], 0, 2, 0.95).expect_err("zero draws must fail");
        assert!(err.contains("requires at least one posterior draw"));
    }

    /// Every `level` outside the open interval (0, 1), including NaN, must be rejected.
    #[test]
    fn out_of_range_level_rejects_credible_interval() {
        for level in [0.0, 1.0, -0.5, 1.5, f64::NAN] {
            let err = credible_interval(&[1.0, 2.0], 2, 1, level)
                .expect_err("level outside (0, 1) must fail");
            assert!(
                err.contains("interval level must lie in (0, 1)"),
                "level {level}: unexpected error {err}"
            );
        }
    }

    /// A flat buffer whose length is not `n_draws * n_coeffs` must be rejected.
    #[test]
    fn shape_mismatch_rejects_credible_interval() {
        let err = credible_interval(&[1.0, 2.0, 3.0], 2, 2, 0.9)
            .expect_err("shape mismatch must fail");
        assert!(err.contains("samples shape mismatch"), "unexpected error {err}");
    }

    #[test]
    fn credible_interval_brackets_sample_mean_with_draws() {
        let n_draws = 5;
        let n_coeffs = 2;
        let samples = vec![
            -2.0, 1.0, //
            -1.0, 2.0, //
            0.0, 3.0, //
            1.0, 4.0, //
            2.0, 5.0, //
        ];
        let ci = credible_interval(&samples, n_draws, n_coeffs, 0.80).expect("interval");
        assert_eq!(ci.len(), 2 * n_coeffs);
        for j in 0..n_coeffs {
            let mean =
                (0..n_draws).map(|k| samples[k * n_coeffs + j]).sum::<f64>() / n_draws as f64;
            assert!(
                ci[j * 2] <= mean && mean <= ci[j * 2 + 1],
                "coefficient {j} mean {mean} must sit inside [{}, {}]",
                ci[j * 2],
                ci[j * 2 + 1]
            );
        }
    }
}