use crate::construction::ReparamResult;
use crate::estimate::EstimationError;
use crate::matrix::{
DesignMatrix, PsdWeightsView, ReparamOperator, SignedWeightsView, SymmetricMatrix,
};
use crate::solver::active_set::{ConstraintKktDiagnostics, LinearInequalityConstraints};
use crate::types::{Coefficients, GlmLikelihoodSpec, InverseLink, LinearPredictor, RidgePassport};
use ndarray::{Array1, Array2, ArrayView1};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use super::{compute_observed_hessian_curvature_arrays, computeworkingweight_derivatives_from_eta};
/// Tag identifying a linear-solve path for PIRLS.
///
/// NOTE(review): nothing in this part of the file constructs or matches on this enum, so the
/// per-variant notes below are read off the variant names only — confirm against the solver.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PirlsLinearSolvePath {
/// Presumably a dense solve in the transformed (reparameterized) basis — TODO confirm.
DenseTransformed,
/// Presumably a sparse solve in the original, untransformed basis — TODO confirm.
SparseNative,
}
/// Coordinate-frame tag carried by `PirlsResult::coordinate_frame`.
///
/// The cache compaction / rehydration helpers on `PirlsResult` copy this value through
/// unchanged; no code in this section branches on it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PirlsCoordinateFrame {
/// Presumably: quantities are expressed in the `Qs`-transformed basis — TODO confirm.
TransformedQs,
/// Presumably: quantities are expressed in the original basis used by the sparse-native
/// path — TODO confirm.
OriginalSparseNative,
}
/// Diagnostics attached to a fit for the Firth term.
///
/// The default value is [`FirthDiagnostics::Inactive`], which carries no data.
#[derive(Debug, Clone, Default)]
pub enum FirthDiagnostics {
/// No Firth diagnostics are available; `jeffreys_logdet()` returns `None`.
#[default]
Inactive,
/// Firth diagnostics are available.
Active {
/// Value reported by `jeffreys_logdet()`.
/// Presumably the log-determinant of the Jeffreys-prior term — confirm with the producer.
jeffreys_logdet: f64,
/// Presumably the hat-matrix diagonal, one entry per observation — TODO confirm.
hat_diag: Array1<f64>,
},
}
impl FirthDiagnostics {
    /// Returns the stored `jeffreys_logdet` when the diagnostics are `Active`,
    /// and `None` when they are `Inactive`.
    #[inline]
    pub fn jeffreys_logdet(&self) -> Option<f64> {
        // `f64` is `Copy`, so the field can be bound by value straight out of `*self`;
        // `hat_diag` is skipped by `..` and therefore never moved.
        match *self {
            Self::Active {
                jeffreys_logdet, ..
            } => Some(jeffreys_logdet),
            Self::Inactive => None,
        }
    }
}
/// Which kind of curvature a Hessian was built from.
///
/// `PirlsResult::rehydrate_after_reml_cache` checks for `Observed` to decide whether the
/// curvature arrays are recomputed through `compute_observed_hessian_curvature_arrays`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum HessianCurvatureKind {
/// Presumably expected (Fisher) information — confirm with the code that sets it.
Fisher,
/// Presumably observed information — confirm with the code that sets it.
Observed,
}
/// Describes the curvature exported alongside a PIRLS result
/// (`exported_laplace_curvature` on `WorkingModelPirlsResult` and `PirlsResult`).
///
/// NOTE(review): the code that selects a variant is not in this section; the per-variant
/// notes are inferred from names and should be confirmed against that code.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ExportedLaplaceCurvature {
/// Presumably: the exact observed curvature was exported.
ObservedExact,
/// Presumably: expected information was exported as a surrogate.
ExpectedInformationSurrogate,
/// Presumably: the observed curvature failed a validity check; the fields record the
/// numbers involved in that decision.
InvalidObservedCurvature {
/// Presumably the smallest eigenvalue found — TODO confirm.
min_eigenvalue: f64,
/// Presumably the positive-definiteness tolerance it was compared against — TODO confirm.
pd_tolerance: f64,
/// Presumably the gradient norm at the time of the check — TODO confirm.
gradient_norm: f64,
},
}
/// Snapshot of a working model at one point of the PIRLS iteration.
///
/// The methods on this type read `firth` (for `jeffreys_logdet`), `gradient_natural_scale`
/// (for `relative_gradient_norm`) and the lengths of `eta` and `gradient` (for the KKT
/// dimension scale); the remaining fields are plain data as far as this section shows.
#[derive(Debug, Clone)]
pub struct WorkingState {
/// Linear predictor; its length is treated as the observation count in the KKT scale.
pub eta: LinearPredictor,
/// Gradient vector; its length is treated as the coefficient count in the KKT scale.
pub gradient: Array1<f64>,
// NOTE(review): spelled with the full path `crate::linalg::matrix::SymmetricMatrix`, while
// the file imports `SymmetricMatrix` from `crate::matrix` — confirm these are the same type.
pub hessian: crate::linalg::matrix::SymmetricMatrix,
pub log_likelihood: f64,
pub deviance: f64,
pub penalty_term: f64,
/// Firth diagnostics; source of `WorkingState::jeffreys_logdet`.
pub firth: FirthDiagnostics,
pub ridge_used: f64,
/// Kind of curvature `hessian` was built from.
pub hessian_curvature: HessianCurvatureKind,
/// Scale used to normalise gradient norms: `relative_gradient_norm` divides by `1 + this`.
pub gradient_natural_scale: f64,
}
impl WorkingState {
    /// Forwards to [`FirthDiagnostics::jeffreys_logdet`] on the `firth` field.
    #[inline]
    pub fn jeffreys_logdet(&self) -> Option<f64> {
        self.firth.jeffreys_logdet()
    }

    /// Scales `g_norm` by `1 / (1 + gradient_natural_scale)`.
    #[inline]
    pub fn relative_gradient_norm(&self, g_norm: f64) -> f64 {
        let denominator = 1.0 + self.gradient_natural_scale;
        g_norm / denominator
    }

    /// Returns `sqrt(n) * sqrt(p)`, where `n` is the length of `eta` and `p` the length of
    /// `gradient`, each clamped to at least 1 so the scale is never zero.
    #[inline]
    fn kkt_dimension_scale(&self) -> f64 {
        let observations = self.eta.len().max(1) as f64;
        let coefficients = self.gradient.len().max(1) as f64;
        observations.sqrt() * coefficients.sqrt()
    }

    /// Strict KKT certificate: true when `g_norm` is strictly below `tol` times the dimension
    /// scale, or when the relative gradient norm is strictly below `tol`.
    #[inline]
    pub fn certifies_kkt(&self, g_norm: f64, tol: f64) -> bool {
        let absolute_bound = tol * self.kkt_dimension_scale();
        if g_norm < absolute_bound {
            return true;
        }
        self.relative_gradient_norm(g_norm) < tol
    }

    /// Relaxed variant of [`Self::certifies_kkt`]: the tolerance is widened to `10 * tol` and
    /// both comparisons are inclusive (`<=`).
    #[inline]
    pub fn near_stationary_kkt(&self, g_norm: f64, tol: f64) -> bool {
        let near_tol = tol * 10.0;
        let absolute_bound = near_tol * self.kkt_dimension_scale();
        if g_norm <= absolute_bound {
            return true;
        }
        self.relative_gradient_norm(g_norm) <= near_tol
    }
}
/// Euclidean (L2) norm of `v`: the square root of the sum of its squared entries.
///
/// The squares are accumulated with a plain `f64` sum and no rescaling is applied.
#[inline]
pub(crate) fn array1_l2_norm(v: &Array1<f64>) -> f64 {
    let sum_of_squares: f64 = v.iter().map(|&entry| entry * entry).sum();
    sum_of_squares.sqrt()
}
/// Parameter bundle for an adaptive KKT tolerance.
///
/// NOTE(review): no code in this section reads these fields, so how they combine into a
/// tolerance is not visible here; the field notes below are guesses from the names.
#[derive(Clone, Copy, Debug)]
pub struct AdaptiveKktTolerance {
/// Presumably a scaling factor applied to `outer_grad_norm` — TODO confirm.
/// (Unrelated to the linear predictor `eta` used elsewhere in this file.)
pub eta: f64,
/// Presumably the lower clamp for the resulting tolerance — TODO confirm.
pub floor: f64,
/// Presumably the upper clamp for the resulting tolerance — TODO confirm.
pub ceiling: f64,
/// Presumably the gradient norm of the outer optimisation — TODO confirm.
pub outer_grad_norm: f64,
}
/// Per-iteration progress record for a working-model PIRLS loop.
///
/// NOTE(review): the producer of this struct is outside this section; field notes are
/// inferred from names.
#[derive(Clone, Debug)]
pub struct WorkingModelIterationInfo {
/// Iteration counter (whether it is 0- or 1-based is not visible here).
pub iteration: usize,
pub deviance: f64,
pub gradient_norm: f64,
pub step_size: f64,
/// Presumably the number of step halvings performed in this iteration — TODO confirm.
pub step_halving: usize,
}
/// Result of running PIRLS on a working model: the coefficients, the final
/// [`WorkingState`], the termination [`PirlsStatus`] and summary diagnostics.
///
/// NOTE(review): this struct is only declared here; the meaning of the `last_*` / `final_*`
/// diagnostics is inferred from their names and should be confirmed against the loop that
/// fills them in.
#[derive(Clone)]
pub struct WorkingModelPirlsResult {
pub beta: Coefficients,
pub state: WorkingState,
pub status: PirlsStatus,
pub iterations: usize,
pub lastgradient_norm: f64,
pub last_deviance_change: f64,
pub last_step_size: f64,
pub last_step_halving: usize,
pub max_abs_eta: f64,
/// KKT diagnostics for constraints; `None` when not available.
pub constraint_kkt: Option<ConstraintKktDiagnostics>,
/// Presumably the final Levenberg–Marquardt damping value — TODO confirm.
pub final_lm_lambda: f64,
/// Presumably the acceptance ratio of the last accepted step, if any — TODO confirm.
pub final_accept_rho: Option<f64>,
pub min_penalized_deviance: f64,
pub exported_laplace_curvature: ExportedLaplaceCurvature,
}
/// Termination status of a PIRLS run.
///
/// `PirlsStatus::is_failed_max_iterations` returns `true` for `MaxIterationsReached` and
/// `LmStepSearchExhausted`, and `false` for the other variants.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum PirlsStatus {
Converged,
/// Presumably: progress stalled at a point still accepted as a valid minimum — TODO confirm.
StalledAtValidMinimum,
/// Counted as a failure by `is_failed_max_iterations`.
MaxIterationsReached,
/// Counted as a failure by `is_failed_max_iterations`.
LmStepSearchExhausted,
Unstable,
}
impl PirlsStatus {
    /// Returns `true` when the run ended in `MaxIterationsReached` or
    /// `LmStepSearchExhausted`, and `false` for every other status.
    #[inline]
    pub const fn is_failed_max_iterations(self) -> bool {
        // Spelled out exhaustively so that adding a variant forces a decision here.
        match self {
            PirlsStatus::MaxIterationsReached | PirlsStatus::LmStepSearchExhausted => true,
            PirlsStatus::Converged
            | PirlsStatus::StalledAtValidMinimum
            | PirlsStatus::Unstable => false,
        }
    }
}
/// Full PIRLS fit result, including the arrays needed by downstream derivative code.
///
/// `compact_for_reml_cache` produces a copy in which `finalweights`, `final_offset`,
/// `finalmu`, `solve_dmu_deta`, `solve_d2mu_deta2`, `solve_d3mu_deta3` are empty arrays and
/// `x_transformed` is a 0×0 dense matrix, with `cache_compacted` set to `true`.
/// `rehydrate_after_reml_cache` rebuilds those fields from the original data.
#[derive(Clone)]
pub struct PirlsResult {
pub likelihood: GlmLikelihoodSpec,
pub beta_transformed: Coefficients,
pub penalized_hessian_transformed: SymmetricMatrix,
/// Densified on demand by `dense_stabilizedhessian_transformed`.
pub stabilizedhessian_transformed: SymmetricMatrix,
pub ridge_passport: RidgePassport,
pub ridge_used: f64,
pub deviance: f64,
pub edf: f64,
pub stable_penalty_term: f64,
/// Firth diagnostics; source of `PirlsResult::jeffreys_logdet`.
pub firth: FirthDiagnostics,
/// Exposed as a signed-weights view by `final_weights_signed`. Emptied by compaction.
pub finalweights: Array1<f64>,
/// Emptied by compaction; rehydration refills it from the caller-supplied offset.
pub final_offset: Array1<f64>,
/// Kept by compaction; rehydration recomputes the derivative arrays from it.
pub final_eta: Array1<f64>,
/// Emptied by compaction; rehydration refills it with a copy of `solvemu`.
pub finalmu: Array1<f64>,
/// Exposed as a PSD-weights view by `solve_weights_psd` without validation.
pub solveweights: Array1<f64>,
pub solveworking_response: Array1<f64>,
pub solvemu: Array1<f64>,
// The three link-derivative arrays below are emptied by compaction and recomputed by
// rehydration via `computeworkingweight_derivatives_from_eta`.
pub solve_dmu_deta: Array1<f64>,
pub solve_d2mu_deta2: Array1<f64>,
pub solve_d3mu_deta3: Array1<f64>,
// Kept by compaction, but overwritten with recomputed values on rehydration.
pub solve_c_array: Array1<f64>,
pub solve_d_array: Array1<f64>,
pub derivatives_unsupported: bool,
pub status: PirlsStatus,
pub iteration: usize,
pub max_abs_eta: f64,
/// Numerator of `relative_gradient_norm`.
pub lastgradient_norm: f64,
/// `relative_gradient_norm` divides by `1 + this`.
pub gradient_natural_scale: f64,
pub last_deviance_change: f64,
pub last_step_halving: usize,
/// Decides how rehydration rebuilds `finalweights`, `solve_c_array` and `solve_d_array`.
pub hessian_curvature: HessianCurvatureKind,
pub exported_laplace_curvature: ExportedLaplaceCurvature,
pub final_lm_lambda: f64,
pub final_accept_rho: Option<f64>,
pub constraint_kkt: Option<ConstraintKktDiagnostics>,
pub linear_constraints_transformed: Option<LinearInequalityConstraints>,
/// Reparameterization data; its `qs` is used to rebuild `x_transformed` on rehydration.
pub reparam_result: ReparamResult,
/// Design matrix in the fit's coordinates. Replaced by a 0×0 dense matrix on compaction.
pub x_transformed: DesignMatrix,
pub coordinate_frame: PirlsCoordinateFrame,
/// `true` for values produced by `compact_for_reml_cache`, `false` after rehydration.
pub cache_compacted: bool,
pub min_penalized_deviance: f64,
}
impl PirlsResult {
    /// Returns `stabilizedhessian_transformed` as a dense matrix.
    ///
    /// `context` is passed through to `try_to_dense_exact`.
    ///
    /// # Errors
    /// A failure of `try_to_dense_exact` is wrapped in `EstimationError::InvalidInput`.
    pub fn dense_stabilizedhessian_transformed(
        &self,
        context: &str,
    ) -> Result<Array2<f64>, EstimationError> {
        self.stabilizedhessian_transformed
            .try_to_dense_exact(context)
            .map_err(EstimationError::InvalidInput)
    }

    /// Forwards to [`FirthDiagnostics::jeffreys_logdet`] on the `firth` field.
    #[inline]
    pub fn jeffreys_logdet(&self) -> Option<f64> {
        self.firth.jeffreys_logdet()
    }

    /// Borrows `finalweights` as a signed-weights view.
    #[inline]
    pub fn final_weights_signed(&self) -> SignedWeightsView<'_> {
        SignedWeightsView::from_array(&self.finalweights)
    }

    /// Borrows `solveweights` as a PSD-weights view.
    ///
    /// Uses the `_unchecked` constructor, so no validation of the weights happens here.
    #[inline]
    pub fn solve_weights_psd(&self) -> PsdWeightsView<'_> {
        PsdWeightsView::from_view_unchecked(self.solveweights.view())
    }

    /// Returns `lastgradient_norm / (1 + gradient_natural_scale)`.
    #[inline]
    pub fn relative_gradient_norm(&self) -> f64 {
        self.lastgradient_norm / (1.0 + self.gradient_natural_scale)
    }

    /// Builds a slimmed-down copy of this result for caching.
    ///
    /// The copy has `cache_compacted == true`, empty arrays for `finalweights`,
    /// `final_offset`, `finalmu`, `solve_dmu_deta`, `solve_d2mu_deta2` and
    /// `solve_d3mu_deta3`, and a 0×0 dense `x_transformed`. Every other field is cloned or
    /// copied unchanged. Use [`Self::rehydrate_after_reml_cache`] to restore the dropped
    /// fields.
    ///
    /// All fields are listed explicitly (no `..self.clone()`) so that the large arrays are
    /// never cloned just to be discarded.
    pub(crate) fn compact_for_reml_cache(&self) -> Self {
        Self {
            likelihood: self.likelihood.clone(),
            beta_transformed: self.beta_transformed.clone(),
            penalized_hessian_transformed: self.penalized_hessian_transformed.clone(),
            stabilizedhessian_transformed: self.stabilizedhessian_transformed.clone(),
            ridge_passport: self.ridge_passport,
            ridge_used: self.ridge_used,
            deviance: self.deviance,
            edf: self.edf,
            stable_penalty_term: self.stable_penalty_term,
            firth: self.firth.clone(),
            // Dropped: rebuilt on rehydration.
            finalweights: Array1::zeros(0),
            final_offset: Array1::zeros(0),
            final_eta: self.final_eta.clone(),
            finalmu: Array1::zeros(0),
            solveweights: self.solveweights.clone(),
            solveworking_response: self.solveworking_response.clone(),
            solvemu: self.solvemu.clone(),
            // Dropped: recomputed from `final_eta` on rehydration.
            solve_dmu_deta: Array1::zeros(0),
            solve_d2mu_deta2: Array1::zeros(0),
            solve_d3mu_deta3: Array1::zeros(0),
            solve_c_array: self.solve_c_array.clone(),
            solve_d_array: self.solve_d_array.clone(),
            derivatives_unsupported: self.derivatives_unsupported,
            status: self.status,
            iteration: self.iteration,
            max_abs_eta: self.max_abs_eta,
            lastgradient_norm: self.lastgradient_norm,
            gradient_natural_scale: self.gradient_natural_scale,
            last_deviance_change: self.last_deviance_change,
            last_step_halving: self.last_step_halving,
            hessian_curvature: self.hessian_curvature,
            exported_laplace_curvature: self.exported_laplace_curvature.clone(),
            final_lm_lambda: self.final_lm_lambda,
            final_accept_rho: self.final_accept_rho,
            constraint_kkt: self.constraint_kkt.clone(),
            linear_constraints_transformed: self.linear_constraints_transformed.clone(),
            reparam_result: self.reparam_result.clone(),
            // Dropped: replaced by an empty placeholder design matrix.
            x_transformed: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
                Array2::zeros((0, 0)),
            )),
            coordinate_frame: self.coordinate_frame,
            cache_compacted: true,
            min_penalized_deviance: self.min_penalized_deviance,
        }
    }

    /// Restores the fields dropped by [`Self::compact_for_reml_cache`].
    ///
    /// If this result is not compacted (`cache_compacted == false`) a plain clone is
    /// returned. Otherwise:
    ///
    /// * the link-derivative arrays are recomputed from `final_eta` through
    ///   `computeworkingweight_derivatives_from_eta`;
    /// * `finalweights`, `solve_c_array` and `solve_d_array` come from
    ///   `compute_observed_hessian_curvature_arrays` when `hessian_curvature` is `Observed`,
    ///   and otherwise from `solveweights` plus the score arrays computed in the previous
    ///   step (the `solve_c_array` / `solve_d_array` retained by compaction are overwritten
    ///   either way);
    /// * `final_offset` is copied from `offset` and `finalmu` from `solvemu`;
    /// * `x_transformed` is rebuilt as a `ReparamOperator` over `x_original` and
    ///   `reparam_result.qs`.
    ///
    /// NOTE(review): `x_transformed` is rebuilt the same way regardless of
    /// `coordinate_frame` — confirm that is intended for `OriginalSparseNative` results.
    ///
    /// # Parameters
    /// * `x_original` – design matrix wrapped, together with `qs`, into the rebuilt
    ///   `x_transformed`.
    /// * `y` – response, used only on the observed-curvature path.
    /// * `priorweights` – prior weights forwarded to both recomputation helpers.
    /// * `offset` – stored as `final_offset`.
    /// * `inverse_link` – inverse link forwarded to both recomputation helpers.
    ///
    /// # Errors
    /// Propagates any error returned by the two recomputation helpers.
    pub(crate) fn rehydrate_after_reml_cache(
        &self,
        x_original: &DesignMatrix,
        y: ArrayView1<'_, f64>,
        priorweights: ArrayView1<'_, f64>,
        offset: ArrayView1<'_, f64>,
        inverse_link: &InverseLink,
    ) -> Result<Self, EstimationError> {
        if !self.cache_compacted {
            return Ok(self.clone());
        }
        let (score_c_array, score_d_array, solve_dmu_deta, solve_d2mu_deta2, solve_d3mu_deta3) =
            computeworkingweight_derivatives_from_eta(
                &self.likelihood,
                inverse_link,
                &self.final_eta,
                priorweights,
            )?;
        let (finalweights, solve_c_array, solve_d_array) =
            if self.hessian_curvature == HessianCurvatureKind::Observed {
                compute_observed_hessian_curvature_arrays(
                    &self.likelihood,
                    inverse_link,
                    &self.final_eta,
                    y,
                    &self.solveweights,
                    priorweights,
                )?
            } else {
                // The score arrays are not used again after this point, so they are moved
                // into the result rather than cloned.
                (self.solveweights.clone(), score_c_array, score_d_array)
            };
        let qs_arc = Arc::new(self.reparam_result.qs.clone());
        Ok(Self {
            likelihood: self.likelihood.clone(),
            beta_transformed: self.beta_transformed.clone(),
            penalized_hessian_transformed: self.penalized_hessian_transformed.clone(),
            stabilizedhessian_transformed: self.stabilizedhessian_transformed.clone(),
            ridge_passport: self.ridge_passport,
            ridge_used: self.ridge_used,
            deviance: self.deviance,
            edf: self.edf,
            stable_penalty_term: self.stable_penalty_term,
            firth: self.firth.clone(),
            finalweights,
            final_offset: offset.to_owned(),
            final_eta: self.final_eta.clone(),
            finalmu: self.solvemu.clone(),
            solveweights: self.solveweights.clone(),
            solveworking_response: self.solveworking_response.clone(),
            solvemu: self.solvemu.clone(),
            solve_dmu_deta,
            solve_d2mu_deta2,
            solve_d3mu_deta3,
            solve_c_array,
            solve_d_array,
            derivatives_unsupported: self.derivatives_unsupported,
            status: self.status,
            iteration: self.iteration,
            max_abs_eta: self.max_abs_eta,
            lastgradient_norm: self.lastgradient_norm,
            gradient_natural_scale: self.gradient_natural_scale,
            last_deviance_change: self.last_deviance_change,
            last_step_halving: self.last_step_halving,
            hessian_curvature: self.hessian_curvature,
            exported_laplace_curvature: self.exported_laplace_curvature.clone(),
            final_lm_lambda: self.final_lm_lambda,
            final_accept_rho: self.final_accept_rho,
            constraint_kkt: self.constraint_kkt.clone(),
            linear_constraints_transformed: self.linear_constraints_transformed.clone(),
            reparam_result: self.reparam_result.clone(),
            x_transformed: DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(Arc::new(
                ReparamOperator::new(x_original.clone(), qs_arc),
            ))),
            coordinate_frame: self.coordinate_frame,
            cache_compacted: false,
            min_penalized_deviance: self.min_penalized_deviance,
        })
    }
}