//! gam 0.3.80
//!
//! Generalized penalized likelihood engine.
//!
//! Documentation for the arrow-Schur fit-start convergence checks.
use ndarray::{Array1, Array2, ArrayView1};

use crate::linalg::triangular::{CholeskyGuard, cholesky_factor_in_place, cholesky_solve_matrix};
use crate::solver::arrow_schur::{ArrowSchurError, ArrowSchurSystem};

/// Severity of a single fit-start convergence diagnostic.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ArrowConvergenceIssueSeverity {
    /// Non-fatal: `ArrowSchurConvergenceReport::warn_or_fail` never returns an
    /// error because of a warning.
    Warning,
    /// Fatal: `ArrowSchurConvergenceReport::warn_or_fail` returns
    /// `ArrowSchurError::AdaptiveCorrectionFailed` when one is present.
    Failure,
}

/// One diagnostic recorded by the fit-start checks (`check_arrow_schur_fit_start`).
#[derive(Debug, Clone)]
pub struct ArrowConvergenceIssue {
    /// Whether this issue only warns or blocks the fit.
    pub severity: ArrowConvergenceIssueSeverity,
    /// Human-readable description of the detected problem.
    pub message: String,
}

/// Per-coordinate bounds describing a compact box for the fit-start latent vector.
///
/// The fit-start check requires both bound vectors to have the latent vector's
/// length, every bound (and latent value) to be finite, and `lower[i] <= upper[i]`.
#[derive(Debug, Clone)]
pub struct LatentCompactBox {
    /// Inclusive lower bound for each latent coordinate.
    pub lower: Array1<f64>,
    /// Inclusive upper bound for each latent coordinate.
    pub upper: Array1<f64>,
}

/// Tuning knobs for the arrow-Schur fit-start convergence checks.
#[derive(Debug, Clone)]
pub struct ArrowSchurConvergenceCheckOptions {
    /// Ridge added to the diagonal of every row block and of the Schur complement.
    /// Negative (or NaN) values are treated as `0.0` by the checker.
    pub proximal_ridge_floor: f64,
    /// Shifted Schur condition numbers above this limit are reported as failures.
    pub max_condition_number: f64,
    /// Optional Lipschitz bound on the penalty gradient; the checker only tests
    /// whether one was supplied.
    pub penalty_gradient_lipschitz_bound: Option<f64>,
    /// Must be `true`; disabling adaptive proximal correction is reported as a failure.
    pub adaptive_proximal_correction_enabled: bool,
    /// When set, a missing latent vector or compact box is a failure rather than a warning.
    pub require_compact_box: bool,
    /// Escalate every warning-level issue to a failure.
    pub fail_on_warnings: bool,
}

impl Default for ArrowSchurConvergenceCheckOptions {
    /// Conservative defaults: tiny ridge, condition limit `1e12`, adaptive
    /// correction on, compact box required, warnings left as warnings.
    fn default() -> Self {
        let proximal_ridge_floor = 1e-8;
        let max_condition_number = 1e12;
        Self {
            fail_on_warnings: false,
            require_compact_box: true,
            adaptive_proximal_correction_enabled: true,
            penalty_gradient_lipschitz_bound: None,
            max_condition_number,
            proximal_ridge_floor,
        }
    }
}

/// Outcome of the arrow-Schur fit-start checks.
#[derive(Debug, Clone)]
pub struct ArrowSchurConvergenceReport {
    /// Every issue found, in the order the checks ran.
    pub issues: Vec<ArrowConvergenceIssue>,
    /// Smallest eigenvalue estimate over the ridge-shifted row blocks
    /// `H_tt + ridge·I`; `f64::INFINITY` when there are no rows and NaN when a
    /// row was structurally invalid (inconsistent dimensions or non-finite entries).
    pub min_row_eigenvalue: f64,
    /// Condition number of the ridge-shifted Schur complement; `1.0` when there
    /// is no β block and `f64::INFINITY` when it could not be computed.
    pub schur_condition_number: f64,
}

impl ArrowSchurConvergenceReport {
    /// Returns `true` when at least one recorded issue has `Failure` severity.
    pub fn has_failures(&self) -> bool {
        self.issues
            .iter()
            .any(|issue| issue.severity == ArrowConvergenceIssueSeverity::Failure)
    }

    /// Logs every warning and turns any failures into an error.
    ///
    /// All warnings are logged via `log::warn!` regardless of where they appear
    /// relative to failures, so no diagnostic is silently dropped.
    ///
    /// # Errors
    ///
    /// Returns `ArrowSchurError::AdaptiveCorrectionFailed` when at least one
    /// issue has `Failure` severity. Its `reason` is the failure message, or all
    /// failure messages joined with `"; "` when there are several.
    pub fn warn_or_fail(&self) -> Result<(), ArrowSchurError> {
        let mut failures: Vec<&str> = Vec::new();
        for issue in self.issues.iter() {
            match issue.severity {
                ArrowConvergenceIssueSeverity::Warning => {
                    log::warn!("[arrow-Schur convergence] {}", issue.message);
                }
                ArrowConvergenceIssueSeverity::Failure => {
                    // Collect instead of returning early so later warnings are
                    // still logged and later failures are still reported.
                    failures.push(issue.message.as_str());
                }
            }
        }
        if failures.is_empty() {
            Ok(())
        } else {
            Err(ArrowSchurError::AdaptiveCorrectionFailed {
                reason: failures.join("; "),
            })
        }
    }
}

/// Runs the fit-start convergence checks for an arrow-Schur Newton system.
///
/// Checks, in order:
/// 1. adaptive proximal correction is enabled (failure otherwise);
/// 2. the fit-start latent vector lies inside the supplied compact box;
/// 3. a penalty-gradient Lipschitz bound was supplied (warning, or failure when
///    `fail_on_warnings` is set);
/// 4. every proximal row block `H_tt + ridge·I` is positive definite;
/// 5. the ridge-shifted Schur complement is positive definite and its
///    condition number does not exceed `max_condition_number`.
///
/// `ridge` is `options.proximal_ridge_floor` clamped to be non-negative. This
/// function never fails itself. Callers inspect the returned report or call
/// [`ArrowSchurConvergenceReport::warn_or_fail`].
pub fn check_arrow_schur_fit_start(
    sys: &ArrowSchurSystem,
    latent_flat: Option<ArrayView1<'_, f64>>,
    compact_box: Option<&LatentCompactBox>,
    options: &ArrowSchurConvergenceCheckOptions,
) -> ArrowSchurConvergenceReport {
    let mut issues = Vec::new();
    if !options.adaptive_proximal_correction_enabled {
        issues.push(issue(
            ArrowConvergenceIssueSeverity::Failure,
            "undamped full-step arrow-Schur Newton has a two-cycle counterexample; enable adaptive proximal correction",
        ));
    }
    check_latent_compactness(latent_flat, compact_box, options, &mut issues);

    if options.penalty_gradient_lipschitz_bound.is_none() {
        let severity = if options.fail_on_warnings {
            ArrowConvergenceIssueSeverity::Failure
        } else {
            ArrowConvergenceIssueSeverity::Warning
        };
        issues.push(issue(
            severity,
            "no penalty-gradient Lipschitz certificate was supplied; fit-start can only check finite local Hessian blocks",
        ));
    }

    // `f64::max` also maps a NaN floor to 0.0.
    let ridge = options.proximal_ridge_floor.max(0.0);
    let min_row_eigenvalue = min_shifted_row_eigenvalue(sys, ridge, &mut issues);
    // `min_shifted_row_eigenvalue` returns NaN only after it has recorded a
    // structural failure: inconsistent row dimensions or non-finite entries.
    // Building the Schur complement from such rows can index past the ends of
    // `htbeta` and panic, or just repeat the failure. So skip it and report the
    // complement as uncomputable. With no β block (k == 0) the Schur check
    // never touches the rows, so it stays safe to run.
    let schur_condition_number = if sys.k > 0 && min_row_eigenvalue.is_nan() {
        f64::INFINITY
    } else {
        schur_condition_number(sys, ridge, options, &mut issues)
    };

    ArrowSchurConvergenceReport {
        issues,
        min_row_eigenvalue,
        schur_condition_number,
    }
}

/// Builds an [`ArrowConvergenceIssue`] holding an owned copy of `message`.
fn issue(severity: ArrowConvergenceIssueSeverity, message: &str) -> ArrowConvergenceIssue {
    let message = String::from(message);
    ArrowConvergenceIssue { severity, message }
}

/// Checks that the fit-start latent vector lies inside the supplied compact box.
///
/// Pushes at most one issue onto `issues`:
/// - a missing latent vector or box is a failure when `require_compact_box` or
///   `fail_on_warnings` is set, otherwise a warning;
/// - a length mismatch, a non-finite value/bound, an inverted bound pair, or an
///   out-of-box latent value is always a failure. Only the first offending
///   coordinate is reported.
fn check_latent_compactness(
    latent_flat: Option<ArrayView1<'_, f64>>,
    compact_box: Option<&LatentCompactBox>,
    options: &ArrowSchurConvergenceCheckOptions,
    issues: &mut Vec<ArrowConvergenceIssue>,
) {
    let escalate = options.require_compact_box || options.fail_on_warnings;
    let missing_severity = match escalate {
        true => ArrowConvergenceIssueSeverity::Failure,
        false => ArrowConvergenceIssueSeverity::Warning,
    };

    let (latent, bounds) = match (latent_flat, compact_box) {
        (None, _) => {
            issues.push(issue(
                missing_severity,
                "latent compactness was not checked because no fit-start latent vector was supplied",
            ));
            return;
        }
        (Some(_), None) => {
            issues.push(issue(
                missing_severity,
                "latent compactness was not checked because no compact box was supplied",
            ));
            return;
        }
        (Some(latent), Some(bounds)) => (latent, bounds),
    };

    let n = latent.len();
    if bounds.lower.len() != n || bounds.upper.len() != n {
        issues.push(issue(
            ArrowConvergenceIssueSeverity::Failure,
            "latent compact box dimension does not match latent vector length",
        ));
        return;
    }

    // Per coordinate, the checks run in a fixed order: finiteness, then bound
    // ordering, then containment. The first coordinate that fails decides the message.
    let violation = latent
        .iter()
        .zip(bounds.lower.iter())
        .zip(bounds.upper.iter())
        .find_map(|((&value, &lo), &hi)| {
            if !value.is_finite() || !lo.is_finite() || !hi.is_finite() {
                Some("latent compact box check found a non-finite value or bound")
            } else if lo > hi {
                Some("latent compact box has lower bound greater than upper bound")
            } else if value < lo || value > hi {
                Some("fit-start latent value lies outside the supplied compact box")
            } else {
                None
            }
        });
    if let Some(message) = violation {
        issues.push(issue(ArrowConvergenceIssueSeverity::Failure, message));
    }
}

/// Returns the smallest eigenvalue estimate over all ridge-shifted row blocks
/// `H_tt + ridge·I`, recording a failure for each block that is not positive definite.
///
/// Returns `f64::INFINITY` when `sys.rows` is empty. Returns NaN immediately,
/// after pushing a failure, when a row has inconsistent dimensions or
/// non-finite Hessian/gradient entries.
fn min_shifted_row_eigenvalue(
    sys: &ArrowSchurSystem,
    ridge: f64,
    issues: &mut Vec<ArrowConvergenceIssue>,
) -> f64 {
    let mut smallest = f64::INFINITY;
    for (row_idx, row) in sys.rows.iter().enumerate() {
        let dim = row.htt.nrows();
        let shapes_agree = row.htt.ncols() == dim
            && row.htbeta.nrows() == dim
            && row.htbeta.ncols() == sys.k
            && row.gt.len() == dim;
        if !shapes_agree {
            issues.push(issue(
                ArrowConvergenceIssueSeverity::Failure,
                "arrow-Schur row dimensions are inconsistent",
            ));
            return f64::NAN;
        }

        let entries_finite = row.htt.iter().all(|v| v.is_finite())
            && row.htbeta.iter().all(|v| v.is_finite())
            && row.gt.iter().all(|v| v.is_finite());
        if !entries_finite {
            issues.push(issue(
                ArrowConvergenceIssueSeverity::Failure,
                "arrow-Schur row contains non-finite Hessian or gradient entries",
            ));
            return f64::NAN;
        }

        // `htt` is square here, so the diagonal view covers exactly `dim` entries.
        let mut shifted = row.htt.clone();
        shifted.diag_mut().mapv_inplace(|v| v + ridge);

        let (lo, _) = symmetric_eigenvalue_bounds(&shifted);
        smallest = smallest.min(lo);
        if !lo.is_finite() || lo <= 0.0 {
            issues.push(issue(
                ArrowConvergenceIssueSeverity::Failure,
                &format!(
                    "proximal row block {row_idx} is not positive definite at ridge {ridge}; min eigenvalue {lo}"
                ),
            ));
        }
    }
    smallest
}

/// Computes the condition number `λ_max / λ_min` of the ridge-shifted Schur
/// complement `S + ridge·I`, recording issues on the way.
///
/// Returns `1.0` when there is no β block (`k == 0`). Returns `f64::INFINITY`
/// when the β block is not a dense `k × k` matrix (warning, or failure under
/// `fail_on_warnings`), when a row block cannot be factorized, or when the
/// shifted complement is not positive definite. A finite condition number above
/// `max_condition_number` is returned but also recorded as a failure.
fn schur_condition_number(
    sys: &ArrowSchurSystem,
    ridge: f64,
    options: &ArrowSchurConvergenceCheckOptions,
    issues: &mut Vec<ArrowConvergenceIssue>,
) -> f64 {
    let k = sys.k;
    if k == 0 {
        return 1.0;
    }

    if sys.hbb.dim() != (k, k) {
        let severity = match options.fail_on_warnings {
            true => ArrowConvergenceIssueSeverity::Failure,
            false => ArrowConvergenceIssueSeverity::Warning,
        };
        issues.push(issue(
            severity,
            "matrix-free beta block: fit-start checker cannot materialize the Schur condition number",
        ));
        return f64::INFINITY;
    }

    let mut schur = match build_shifted_schur(sys, ridge) {
        Some(matrix) => matrix,
        None => {
            issues.push(issue(
                ArrowConvergenceIssueSeverity::Failure,
                "could not build shifted Schur complement because a row block was not positive definite",
            ));
            return f64::INFINITY;
        }
    };
    // The proximal ridge applies to the β block as well as to the row blocks.
    for j in 0..k {
        schur[[j, j]] += ridge;
    }

    let (lo, hi) = symmetric_eigenvalue_bounds(&schur);
    let positive_definite = lo.is_finite() && lo > 0.0 && hi.is_finite();
    if !positive_definite {
        issues.push(issue(
            ArrowConvergenceIssueSeverity::Failure,
            &format!(
                "shifted Schur complement is not positive definite; eigen bounds [{lo}, {hi}]"
            ),
        ));
        return f64::INFINITY;
    }

    let condition = hi / lo;
    let limit = options.max_condition_number;
    if condition > limit {
        issues.push(issue(
            ArrowConvergenceIssueSeverity::Failure,
            &format!("shifted Schur condition number {condition:.3e} exceeds limit {limit:.3e}"),
        ));
    }
    condition
}

/// Builds the symmetrized Schur complement
/// `S = B − Σ_i H_tβ,iᵀ (H_tt,i + ridge·I)⁻¹ H_tβ,i`,
/// where `B` is the dense form of `sys.effective_penalty_op()`.
///
/// Returns `None` as soon as the project Cholesky helper rejects a shifted row
/// block. The guard is `CholeskyGuard::FiniteStrict`; presumably it rejects
/// non-PD or non-finite blocks — confirm in `linalg::triangular`.
fn build_shifted_schur(sys: &ArrowSchurSystem, ridge: f64) -> Option<Array2<f64>> {
    let k = sys.k;
    // Start from the effective β-block operator, not the dense `hbb`
    // accumulator. β contributions may live in a structured `BetaPenaltyOp`
    // (e.g. the SAE data-fit `G ⊗ I_p` block). With no `penalty_op` installed
    // this reduces to `hbb.clone()`.
    let mut complement = sys.effective_penalty_op().to_dense();
    for row in sys.rows.iter() {
        let dim = row.htt.nrows();
        let mut block = row.htt.clone();
        (0..dim).for_each(|j| block[[j, j]] += ridge);

        let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::FiniteStrict)?;
        // solved = (H_tt + ridge·I)⁻¹ H_tβ, shape dim × k.
        let solved = cholesky_solve_matrix(&factor, &row.htbeta);

        // complement -= H_tβᵀ · solved. Zero couplings are skipped, so
        // non-finite `solved` entries behind a zero weight never contaminate
        // the result.
        for c in 0..dim {
            for a in 0..k {
                let weight = row.htbeta[[c, a]];
                if weight != 0.0 {
                    for b in 0..k {
                        complement[[a, b]] -= weight * solved[[c, b]];
                    }
                }
            }
        }
    }
    symmetrize(&mut complement);
    Some(complement)
}

fn symmetric_eigenvalue_bounds(a: &Array2<f64>) -> (f64, f64) {
    let n = a.nrows();
    if n == 0 {
        return (1.0, 1.0);
    }
    if a.ncols() != n || a.iter().any(|v| !v.is_finite()) {
        return (f64::NAN, f64::NAN);
    }
    let mut m = a.clone();
    for _ in 0..(80 * n.max(1)) {
        let mut p = 0;
        let mut q = 0;
        let mut max_off = 0.0;
        for i in 0..n {
            for j in (i + 1)..n {
                let off = m[[i, j]].abs();
                if off > max_off {
                    max_off = off;
                    p = i;
                    q = j;
                }
            }
        }
        if max_off <= 1e-12 {
            break;
        }
        let app = m[[p, p]];
        let aqq = m[[q, q]];
        let apq = m[[p, q]];
        let tau = (aqq - app) / (2.0 * apq);
        let t = tau.signum() / (tau.abs() + (1.0 + tau * tau).sqrt());
        let c = 1.0 / (1.0 + t * t).sqrt();
        let s = t * c;
        for r in 0..n {
            if r != p && r != q {
                let mrp = m[[r, p]];
                let mrq = m[[r, q]];
                m[[r, p]] = c * mrp - s * mrq;
                m[[p, r]] = m[[r, p]];
                m[[r, q]] = s * mrp + c * mrq;
                m[[q, r]] = m[[r, q]];
            }
        }
        m[[p, p]] = c * c * app - 2.0 * s * c * apq + s * s * aqq;
        m[[q, q]] = s * s * app + 2.0 * s * c * apq + c * c * aqq;
        m[[p, q]] = 0.0;
        m[[q, p]] = 0.0;
    }
    let mut lo = f64::INFINITY;
    let mut hi = f64::NEG_INFINITY;
    for i in 0..n {
        lo = lo.min(m[[i, i]]);
        hi = hi.max(m[[i, i]]);
    }
    (lo, hi)
}

/// Makes `a` symmetric in place using the crate's canonical helper.
///
/// Every caller in this module passes a square Schur complement.
#[inline]
fn symmetrize(a: &mut Array2<f64>) {
    use crate::linalg::utils::enforce_symmetry;
    enforce_symmetry(a);
}