// gam 0.1.17
//
// Generalized penalized likelihood engine
// Documentation
use gam::construction::CanonicalPenalty;
use gam::estimate::{ExternalOptimOptions, PenaltySpec, evaluate_externalcost_andridge};
use gam::pirls::{PenaltyConfig, PirlsConfig, PirlsProblem, fit_model_for_fixed_rho};
use gam::smooth::BlockwisePenalty;
use gam::types::{
    GlmLikelihoodFamily, GlmLikelihoodSpec, InverseLink, LikelihoodFamily, LinkFunction,
    LogSmoothingParamsView,
};
use ndarray::{Array1, Array2, array};
use rand::rngs::StdRng;
use rand::{RngExt, SeedableRng};

/// Converts dense penalty matrices into `CanonicalPenalty` values for use in tests.
///
/// Each matrix is wrapped in `PenaltySpec::Dense` and handed to
/// `gam::construction::canonicalize_penalty_spec` together with the row count of the
/// first matrix, the matrix's position in `s_list`, and the label `"test"`. Entries for
/// which canonicalization yields `None` are left out of the result.
///
/// An empty `s_list` yields an empty vector.
///
/// # Panics
///
/// Panics with "canonicalize test penalty" when the `expect` on a canonicalization
/// result fails.
fn canonicalize_test_penalties(s_list: &[Array2<f64>]) -> Vec<CanonicalPenalty> {
    // The dimension is read from the first matrix; with no matrices there is nothing to
    // canonicalize, so return early instead of indexing into an empty slice.
    let p = match s_list.first() {
        Some(first) => first.nrows(),
        None => return Vec::new(),
    };
    s_list
        .iter()
        .enumerate()
        .filter_map(|(idx, s)| {
            gam::construction::canonicalize_penalty_spec(
                &PenaltySpec::Dense(s.clone()),
                p,
                idx,
                "test",
            )
            .expect("canonicalize test penalty")
        })
        .collect()
}

/// Builds a seeded synthetic binary-response fixture.
///
/// Returns `(x, y, w, s, penalties)`:
/// - `x`: 100 x 10 design matrix whose first column is all ones and whose remaining
///   entries are drawn with `random_range(-1.0..1.0)`, row by row.
/// - `y`: 0/1 responses; an entry is 1.0 when a `random::<f64>()` draw falls below the
///   logistic transform of the linear predictor `x · beta`, where `beta[0] = -0.1` and
///   `beta[j] = 0.2 / j` for `j >= 1`.
/// - `w`: a vector of 100 ones.
/// - `s`: 10 x 10 matrix with ones on the diagonal except for entry `(0, 0)`, which is 0.
/// - `penalties`: a single `BlockwisePenalty` built from the range `0..10` and `s`.
///
/// All randomness comes from a `StdRng` seeded with `seed`.
fn make_problem(
    seed: u64,
) -> (
    Array2<f64>,
    Array1<f64>,
    Array1<f64>,
    Array2<f64>,
    Vec<BlockwisePenalty>,
) {
    const N_OBS: usize = 100;
    const N_COEF: usize = 10;

    let mut rng = StdRng::seed_from_u64(seed);

    // Fill the design row by row so the random draws are consumed in row-major order.
    let mut design = Array2::<f64>::zeros((N_OBS, N_COEF));
    for mut row in design.rows_mut() {
        row[0] = 1.0;
        for cell in row.iter_mut().skip(1) {
            *cell = rng.random_range(-1.0..1.0);
        }
    }

    let true_beta = Array1::from_shape_fn(N_COEF, |j| match j {
        0 => -0.1,
        _ => 0.2 / j as f64,
    });
    let linear_predictor = design.dot(&true_beta);

    // The response draws happen after every design entry has been generated.
    let response = linear_predictor.mapv(|e| {
        let prob = 1.0 / (1.0 + (-e).exp());
        let draw: f64 = rng.random();
        if draw < prob { 1.0 } else { 0.0 }
    });

    let weights = Array1::<f64>::ones(N_OBS);

    // Identity on every coefficient except the first, which is left unpenalized.
    let penalty = Array2::<f64>::from_shape_fn((N_COEF, N_COEF), |(r, c)| {
        if r == c && r != 0 { 1.0 } else { 0.0 }
    });
    let blocks = vec![BlockwisePenalty::new(0..N_COEF, penalty.clone())];

    (design, response, weights, penalty, blocks)
}

/// Fits the binomial-logit model at one log-smoothing value and returns the Euclidean
/// norm of the fitted `beta_transformed` vector.
///
/// `rho` is passed as the only entry of the log-smoothing-parameter view and `firth`
/// sets `firth_bias_reduction`. The fit uses a zero offset, at most 500 iterations and a
/// convergence tolerance of `1e-10`.
///
/// # Panics
///
/// Panics with "fit" when `fit_model_for_fixed_rho` does not return a fit.
fn fit_beta_norm(
    x: &Array2<f64>,
    y: &Array1<f64>,
    w: &Array1<f64>,
    penalties: &[CanonicalPenalty],
    rho: f64,
    firth: bool,
) -> f64 {
    let zero_offset = Array1::<f64>::zeros(y.len());
    let rho_values = array![rho];

    let data = PirlsProblem {
        x: x.view(),
        offset: zero_offset.view(),
        y: y.view(),
        priorweights: w.view(),
        covariate_se: None,
    };
    let penalty_setup = PenaltyConfig {
        canonical_penalties: penalties,
        balanced_penalty_root: None,
        reparam_invariant: None,
        p: x.ncols(),
        coefficient_lower_bounds: None,
        linear_constraints_original: None,
        penalty_shrinkage_floor: None,
        kronecker_factored: None,
    };
    let solver = PirlsConfig {
        likelihood: GlmLikelihoodSpec::canonical(GlmLikelihoodFamily::BinomialLogit),
        link_kind: InverseLink::Standard(LinkFunction::Logit),
        max_iterations: 500,
        convergence_tolerance: 1e-10,
        firth_bias_reduction: firth,
        initial_lm_lambda: None,
    };

    let (fit, _) = fit_model_for_fixed_rho(
        LogSmoothingParamsView::new(rho_values.view()),
        data,
        penalty_setup,
        &solver,
        None,
    )
    .expect("fit");

    // ||beta|| = sqrt(beta · beta)
    let squared_norm: f64 = fit.beta_transformed.dot(fit.beta_transformed.as_ref());
    squared_norm.sqrt()
}

/// Fits the binomial-logit model at one log-smoothing value and returns a proxy cost:
/// `fit.deviance + 0.5 * exp(rho) * bᵀ S b`, where `b` is the fitted `beta_transformed`
/// vector and `S` is the caller-supplied matrix `s`.
///
/// `rho` is passed as the only entry of the log-smoothing-parameter view and `firth`
/// sets `firth_bias_reduction`. The fit uses a zero offset, at most 500 iterations and a
/// convergence tolerance of `1e-10`.
///
/// # Panics
///
/// Panics with "fit" when `fit_model_for_fixed_rho` does not return a fit.
fn proxycostwith_pirls(
    x: &Array2<f64>,
    y: &Array1<f64>,
    w: &Array1<f64>,
    penalties: &[CanonicalPenalty],
    s: &Array2<f64>,
    rho: f64,
    firth: bool,
) -> f64 {
    let zero_offset = Array1::<f64>::zeros(y.len());
    let rho_values = array![rho];

    let data = PirlsProblem {
        x: x.view(),
        offset: zero_offset.view(),
        y: y.view(),
        priorweights: w.view(),
        covariate_se: None,
    };
    let penalty_setup = PenaltyConfig {
        canonical_penalties: penalties,
        balanced_penalty_root: None,
        reparam_invariant: None,
        p: x.ncols(),
        coefficient_lower_bounds: None,
        linear_constraints_original: None,
        penalty_shrinkage_floor: None,
        kronecker_factored: None,
    };
    let solver = PirlsConfig {
        likelihood: GlmLikelihoodSpec::canonical(GlmLikelihoodFamily::BinomialLogit),
        link_kind: InverseLink::Standard(LinkFunction::Logit),
        max_iterations: 500,
        convergence_tolerance: 1e-10,
        firth_bias_reduction: firth,
        initial_lm_lambda: None,
    };

    let (fit, _) = fit_model_for_fixed_rho(
        LogSmoothingParamsView::new(rho_values.view()),
        data,
        penalty_setup,
        &solver,
        None,
    )
    .expect("fit");

    // NOTE(review): the quadratic form pairs `beta_transformed` with the caller's `s`;
    // presumably both live in the same coordinate basis — confirm against the fitter.
    let coeffs = fit.beta_transformed.as_ref().to_owned();
    let quadratic_form = coeffs.dot(&s.dot(&coeffs));
    fit.deviance + 0.5 * rho.exp() * quadratic_form
}

/// Checks that central finite differences of the external cost around `rho = 12` agree
/// in sign with the wide trend `cost(13) - cost(11)` for at least half of the tested
/// step sizes.
///
/// NOTE(review): the options set `firth_bias_reduction: None` although the test name
/// mentions Firth — confirm what `None` selects for this family.
#[test]
fn firthfd_step_size_sensitivity() {
    let (x, y, w, s_dense, s_list) = make_problem(31);
    let zero_offset = Array1::<f64>::zeros(y.len());
    let options = ExternalOptimOptions {
        latent_cloglog: None,
        mixture_link: None,
        optimize_mixture: false,
        sas_link: None,
        optimize_sas: false,
        family: LikelihoodFamily::BinomialLogit,
        compute_inference: true,
        tol: 1e-10,
        max_iter: 500,
        nullspace_dims: vec![1],
        linear_constraints: None,
        firth_bias_reduction: None,
        penalty_shrinkage_floor: None,
        rho_prior: Default::default(),
        kronecker_penalty_system: None,
        kronecker_factored: None,
    };

    // Evaluates the external cost at a single log-smoothing value, discarding the
    // second element of the returned pair.
    let evaluate = |rho: f64| -> f64 {
        let (cost, _) = evaluate_externalcost_andridge(
            y.view(),
            w.view(),
            x.view(),
            zero_offset.view(),
            &s_list,
            &options,
            &array![rho],
        )
        .expect("cost");
        cost
    };

    assert!(s_dense.iter().all(|v| v.is_finite()));

    let center = 12.0;
    let wide_difference = evaluate(center + 1.0) - evaluate(center - 1.0);
    let rising = wide_difference > 0.0;

    let step_sizes = [0.02, 0.01, 0.005, 0.002, 0.001, 0.0005];
    let agreeing = step_sizes
        .iter()
        .filter(|&&h| {
            let slope = (evaluate(center + h) - evaluate(center - h)) / (2.0 * h);
            (slope > 0.0) == rising
        })
        .count();

    assert!(agreeing >= step_sizes.len() / 2);
}

/// Compares how often the fitted coefficient norm reverses direction across a small
/// sweep of `rho` values around 12, with and without Firth bias reduction.
///
/// The test passes when the non-Firth fit reverses no more often than the Firth fit, or
/// reverses at most twice.
#[test]
fn firth_beta_monotonicity_comparison() {
    /// Counts direction reversals in `values`.
    ///
    /// Consecutive differences that are exactly zero (or NaN) carry no direction and are
    /// skipped; a reversal is counted whenever two successive remaining differences have
    /// opposite signs.
    fn count_sign_changes(values: &[f64]) -> usize {
        // `f64::signum` maps +0.0 to 1.0 and -0.0 to -1.0, so it cannot identify flat
        // steps; compare the differences against zero directly instead.
        let slopes: Vec<f64> = values
            .windows(2)
            .map(|pair| pair[1] - pair[0])
            .filter(|d| *d != 0.0 && !d.is_nan())
            .collect();
        // Every remaining slope is strictly positive or strictly negative, so comparing
        // signs avoids the underflow a product of two tiny slopes could hit.
        slopes
            .windows(2)
            .filter(|pair| (pair[0] > 0.0) != (pair[1] > 0.0))
            .count()
    }

    let (x, y, w, s_dense, _) = make_problem(31);
    let penalties = canonicalize_test_penalties(&[s_dense.clone()]);
    let deltas = [
        -0.010_f64, -0.005, -0.002, -0.001, 0.0, 0.001, 0.002, 0.005, 0.010,
    ];
    let betas_firth: Vec<f64> = deltas
        .iter()
        .map(|&d| fit_beta_norm(&x, &y, &w, &penalties, 12.0 + d, true))
        .collect();
    let betas_no_firth: Vec<f64> = deltas
        .iter()
        .map(|&d| fit_beta_norm(&x, &y, &w, &penalties, 12.0 + d, false))
        .collect();

    let changes_firth = count_sign_changes(&betas_firth);
    let changes_no_firth = count_sign_changes(&betas_no_firth);
    assert!(
        changes_no_firth <= changes_firth || changes_no_firth <= 2,
        "coefficient norm reversed direction {} times without Firth vs {} times with Firth",
        changes_no_firth,
        changes_firth
    );
}

/// Compares how often the proxy cost changes direction across 41 evenly spaced `rho`
/// values in `12 ± 0.02`, with and without Firth bias reduction.
///
/// The test passes when the non-Firth cost changes direction no more often than the
/// Firth cost, or changes direction at most five times.
#[test]
fn firthcost_oscillationvs_no_firth() {
    let (x, y, w, s_dense, s_list) = make_problem(31);
    let penalties = canonicalize_test_penalties(&[s_dense.clone()]);
    assert_eq!(s_list.len(), 1);

    let offsets: Vec<f64> = (-20..=20).map(|k| k as f64 * 0.001).collect();

    // Evaluates the proxy cost over the whole sweep for one Firth setting.
    let sweep = |firth: bool| -> Vec<f64> {
        offsets
            .iter()
            .map(|&d| proxycostwith_pirls(&x, &y, &w, &penalties, &s_dense, 12.0 + d, firth))
            .collect()
    };
    // The Firth sweep runs first, then the plain sweep.
    let cost_firth = sweep(true);
    let cost_no_firth = sweep(false);

    // An interior point is a direction change when the step into it and the step out of
    // it have a strictly negative product.
    let count_direction_changes = |costs: &[f64]| -> usize {
        costs
            .windows(3)
            .filter(|trio| (trio[1] - trio[0]) * (trio[2] - trio[1]) < 0.0)
            .count()
    };

    let firth_changes = count_direction_changes(&cost_firth);
    let no_firth_changes = count_direction_changes(&cost_no_firth);
    assert!(no_firth_changes <= firth_changes || no_firth_changes <= 5);
}