use ndarray::{Array1, ArrayView1};
use serde::{Deserialize, Serialize};
use std::ops::{Deref, DerefMut};
pub use gam_linalg::{RidgeDeterminantMode, RidgePolicy};
/// Lower bound for weights.
///
/// NOTE(review): presumably used to floor observation/working weights away
/// from zero — confirm at the call sites.
pub const MIN_WEIGHT: f64 = 1.0e-12;
pub use gam_spec::*;
/// Structural form of a stabilization ridge matrix.
///
/// The matrix is always paired with a separate scalar `delta` that scales it
/// (see `RidgePassport::delta` and `StabilizationLedger::delta`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RidgeMatrixForm {
/// `delta` times the identity matrix.
ScaledIdentity,
}
/// A ridge of magnitude `delta` together with the policy that decides which
/// objective terms account for it.
///
/// Built with `RidgePassport::scaled_identity`; its ridge accessors
/// (`penalty_logdet_ridge`, `laplacehessianridge`) return `delta` only for the
/// terms the policy includes and `0.0` otherwise.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct RidgePassport {
/// Ridge magnitude.
pub delta: f64,
/// Structure of the matrix that `delta` scales.
pub matrix_form: RidgeMatrixForm,
/// Per-term inclusion flags (quadratic penalty, penalty log-determinant,
/// Laplace Hessian) deciding where the ridge is accounted for.
pub policy: RidgePolicy,
}
impl RidgePassport {
    /// Builds a passport for a ridge of magnitude `delta` in
    /// `RidgeMatrixForm::ScaledIdentity` form, governed by `policy`.
    pub const fn scaled_identity(delta: f64, policy: RidgePolicy) -> Self {
        Self {
            delta,
            matrix_form: RidgeMatrixForm::ScaledIdentity,
            policy,
        }
    }

    /// Returns `delta` when `included` is set, otherwise `0.0`.
    ///
    /// Single gating rule shared by every per-term accessor below, so the
    /// three terms cannot drift apart.
    #[inline]
    const fn gated_delta(self, included: bool) -> f64 {
        if included {
            self.delta
        } else {
            0.0
        }
    }

    /// Ridge contribution to the quadratic penalty: `delta` when
    /// `policy.include_quadratic_penalty` is set, otherwise `0.0`.
    ///
    /// Counterpart of `penalty_logdet_ridge` / `laplacehessianridge`, so
    /// callers need not read the policy flag directly.
    #[inline]
    pub const fn quadratic_penalty_ridge(self) -> f64 {
        self.gated_delta(self.policy.include_quadratic_penalty)
    }

    /// Ridge contribution to the penalty log-determinant: `delta` when
    /// `policy.include_penalty_logdet` is set, otherwise `0.0`.
    #[inline]
    pub const fn penalty_logdet_ridge(self) -> f64 {
        self.gated_delta(self.policy.include_penalty_logdet)
    }

    /// Ridge contribution to the Laplace Hessian: `delta` when
    /// `policy.include_laplacehessian` is set, otherwise `0.0`.
    #[inline]
    pub const fn laplacehessianridge(self) -> f64 {
        self.gated_delta(self.policy.include_laplacehessian)
    }
}
/// Eigenvalue sign counts of a matrix.
///
/// NOTE(review): presumably the Sylvester inertia of a symmetric matrix
/// (numbers of positive, zero, and negative eigenvalues) — confirm how the
/// counts are produced and what tolerance decides `zero`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct Inertia {
/// Count of positive eigenvalues.
pub positive: usize,
/// Count of eigenvalues treated as zero.
pub zero: usize,
/// Count of negative eigenvalues.
pub negative: usize,
}
/// How the stabilization magnitude `delta` recorded in a `StabilizationLedger`
/// was chosen.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum StabilizationRule {
/// A fixed, predetermined value (recorded by `StabilizationLedger::none` and
/// `StabilizationLedger::from_passport`).
FixedConstant,
/// Chosen to hit a target inertia.
/// NOTE(review): `spd_floor` is presumably the positive floor used when
/// targeting positive definiteness — confirm.
InertiaTarget { spd_floor: f64 },
/// Chosen by a heuristic rule.
Heuristic,
/// Supplied by the user (recorded by `StabilizationLedger::explicit_prior`).
UserSpecified,
/// Found by escalating the ridge after failures.
/// NOTE(review): `attempts` presumably counts the escalation steps taken — confirm.
BackoffEscalation { attempts: usize },
}
/// Role a stabilization ridge plays in the fit.
///
/// Only `ExplicitPrior` makes `StabilizationLedger`'s `quadratic_delta`,
/// `laplace_hessian_delta`, and `penalty_logdet_delta` report the ridge; every
/// other kind reports `0.0` for all three.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum StabilizationKind {
/// No stabilization (used by `StabilizationLedger::none` with `delta == 0.0`).
None,
/// Ridge used only as solver damping (`StabilizationLedger::solver_damping`).
SolverDampingOnly,
/// Ridge treated as a numerical perturbation, outside the objective.
NumericalPerturbation {
/// Optional bound attached to the perturbation.
/// NOTE(review): presumably a bound on the backward error the ridge
/// introduces — confirm with the producer of this value.
backward_error_bound: Option<f64>,
},
/// Ridge that is part of the model and enters objective accounting.
ExplicitPrior,
}
/// Record of a stabilization ridge: what it is, how big it is, its matrix
/// form, how it was chosen, and optional inertia snapshots.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct StabilizationLedger {
/// Role of the ridge; decides whether the objective accessors report `delta`.
pub kind: StabilizationKind,
/// Ridge magnitude.
pub delta: f64,
/// Structure of the matrix that `delta` scales.
pub matrix_form: RidgeMatrixForm,
/// Rule that produced `delta`.
pub chosen_by: StabilizationRule,
/// Inertia before stabilization, if recorded (the constructors in the
/// accompanying impl leave this `None`).
pub inertia_before: Option<Inertia>,
/// Inertia after stabilization, if recorded (the constructors in the
/// accompanying impl leave this `None`).
pub inertia_after: Option<Inertia>,
}
impl StabilizationLedger {
pub const fn none() -> Self {
Self {
kind: StabilizationKind::None,
delta: 0.0,
matrix_form: RidgeMatrixForm::ScaledIdentity,
chosen_by: StabilizationRule::FixedConstant,
inertia_before: None,
inertia_after: None,
}
}
pub const fn solver_damping(delta: f64, chosen_by: StabilizationRule) -> Self {
Self {
kind: StabilizationKind::SolverDampingOnly,
delta,
matrix_form: RidgeMatrixForm::ScaledIdentity,
chosen_by,
inertia_before: None,
inertia_after: None,
}
}
pub const fn numerical_perturbation(
delta: f64,
chosen_by: StabilizationRule,
backward_error_bound: Option<f64>,
) -> Self {
Self {
kind: StabilizationKind::NumericalPerturbation {
backward_error_bound,
},
delta,
matrix_form: RidgeMatrixForm::ScaledIdentity,
chosen_by,
inertia_before: None,
inertia_after: None,
}
}
pub const fn explicit_prior(delta: f64, matrix_form: RidgeMatrixForm) -> Self {
Self {
kind: StabilizationKind::ExplicitPrior,
delta,
matrix_form,
chosen_by: StabilizationRule::UserSpecified,
inertia_before: None,
inertia_after: None,
}
}
pub const fn from_passport(passport: RidgePassport) -> Self {
let any_included = passport.policy.include_quadratic_penalty
|| passport.policy.include_laplacehessian
|| passport.policy.include_penalty_logdet;
let kind = if any_included {
StabilizationKind::ExplicitPrior
} else {
StabilizationKind::NumericalPerturbation {
backward_error_bound: None,
}
};
Self {
kind,
delta: passport.delta,
matrix_form: passport.matrix_form,
chosen_by: StabilizationRule::FixedConstant,
inertia_before: None,
inertia_after: None,
}
}
#[inline]
pub const fn quadratic_delta(&self) -> f64 {
match self.kind {
StabilizationKind::ExplicitPrior => self.delta,
StabilizationKind::None
| StabilizationKind::SolverDampingOnly
| StabilizationKind::NumericalPerturbation { .. } => 0.0,
}
}
#[inline]
pub const fn laplace_hessian_delta(&self) -> f64 {
match self.kind {
StabilizationKind::ExplicitPrior => self.delta,
StabilizationKind::None
| StabilizationKind::SolverDampingOnly
| StabilizationKind::NumericalPerturbation { .. } => 0.0,
}
}
#[inline]
pub const fn penalty_logdet_delta(&self) -> f64 {
match self.kind {
StabilizationKind::ExplicitPrior => self.delta,
StabilizationKind::None
| StabilizationKind::SolverDampingOnly
| StabilizationKind::NumericalPerturbation { .. } => 0.0,
}
}
}
// Declares a `#[repr(transparent)]` newtype over `Array1<f64>` with
// constructors, `Deref`/`DerefMut`/`AsRef` to the inner array, and `From`
// conversions in both directions. Extra comma-separated idents request
// optional add-ons; `exp` adds an element-wise exponential.
macro_rules! array1_f64_newtype {
    // Internal add-on arm: element-wise `exp` into a freshly allocated array.
    (@extra $ty:ident exp) => {
        impl $ty {
            #[inline]
            pub fn exp(&self) -> Array1<f64> {
                self.0.mapv(f64::exp)
            }
        }
    };
    // Public entry arm: the newtype itself plus any requested add-ons.
    ($ty:ident $(, $addon:ident)*) => {
        #[repr(transparent)]
        #[derive(Clone, Debug, PartialEq)]
        pub struct $ty(pub Array1<f64>);

        impl $ty {
            /// Wraps `values` without copying.
            #[inline]
            pub fn new(values: Array1<f64>) -> Self {
                Self(values)
            }

            /// All-zero vector of length `len`.
            #[inline]
            pub fn zeros(len: usize) -> Self {
                Self(Array1::zeros(len))
            }
        }

        impl From<Array1<f64>> for $ty {
            #[inline]
            fn from(values: Array1<f64>) -> Self {
                Self(values)
            }
        }

        impl From<$ty> for Array1<f64> {
            #[inline]
            fn from(values: $ty) -> Self {
                values.0
            }
        }

        impl AsRef<Array1<f64>> for $ty {
            #[inline]
            fn as_ref(&self) -> &Array1<f64> {
                &self.0
            }
        }

        impl Deref for $ty {
            type Target = Array1<f64>;
            #[inline]
            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl DerefMut for $ty {
            #[inline]
            fn deref_mut(&mut self) -> &mut Self::Target {
                &mut self.0
            }
        }

        $( array1_f64_newtype!(@extra $ty $addon); )*
    };
}
// Owned coefficient vector.
array1_f64_newtype! { Coefficients }
// Owned linear-predictor vector.
array1_f64_newtype! { LinearPredictor }
// Owned log-scale smoothing parameters; also gets an element-wise `exp`.
array1_f64_newtype! { LogSmoothingParams, exp }
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct SmoothTermIdx(usize);
impl SmoothTermIdx {
#[inline]
pub const fn new(idx: usize) -> Self {
Self(idx)
}
#[inline]
pub const fn placeholder() -> Self {
Self(usize::MAX)
}
#[inline]
pub const fn get(self) -> usize {
self.0
}
#[inline]
pub const fn is_placeholder(self) -> bool {
self.0 == usize::MAX
}
}
impl std::fmt::Display for SmoothTermIdx {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct PenaltyIdx(usize);
impl PenaltyIdx {
#[inline]
pub const fn new(idx: usize) -> Self {
Self(idx)
}
#[inline]
pub const fn get(self) -> usize {
self.0
}
}
impl std::fmt::Display for PenaltyIdx {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct BasisIdx(usize);
impl BasisIdx {
#[inline]
pub const fn new(idx: usize) -> Self {
Self(idx)
}
#[inline]
pub const fn get(self) -> usize {
self.0
}
}
impl std::fmt::Display for BasisIdx {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ColIdx(usize);
impl ColIdx {
#[inline]
pub const fn new(idx: usize) -> Self {
Self(idx)
}
#[inline]
pub const fn get(self) -> usize {
self.0
}
}
impl std::fmt::Display for ColIdx {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct RowIdx(usize);
impl RowIdx {
#[inline]
pub const fn new(idx: usize) -> Self {
Self(idx)
}
#[inline]
pub const fn get(self) -> usize {
self.0
}
}
impl std::fmt::Display for RowIdx {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
/// Borrowed counterpart of `LogSmoothingParams`: a non-owning view over
/// log-scale smoothing parameters.
#[repr(transparent)]
#[derive(Clone, Copy, Debug)]
pub struct LogSmoothingParamsView<'a>(pub ArrayView1<'a, f64>);

impl<'a> LogSmoothingParamsView<'a> {
    /// Wraps an existing view without copying.
    // `#[inline]` matches the macro-generated owned newtypes, so these tiny
    // wrappers stay inlinable across crate boundaries.
    #[inline]
    pub fn new(values: ArrayView1<'a, f64>) -> Self {
        Self(values)
    }

    /// Exponentiates every entry into a newly allocated array (e.g. mapping
    /// log smoothing parameters back to their natural scale).
    #[inline]
    pub fn exp(&self) -> Array1<f64> {
        self.0.mapv(f64::exp)
    }
}

impl<'a> Deref for LogSmoothingParamsView<'a> {
    type Target = ArrayView1<'a, f64>;

    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
#[cfg(test)]
mod ridge_policy_tests {
    use super::{RidgePassport, RidgePolicy, StabilizationKind, StabilizationLedger};

    /// A solver-only ridge must stay invisible to every objective term, both on
    /// the passport and on the ledger derived from it.
    #[test]
    fn solver_only_ridge_policy_stays_off_objective_accounting() {
        let solver_only = RidgePolicy::solver_only();
        let passport = RidgePassport::scaled_identity(1.0e-4, solver_only);

        // Passport-level gating.
        let quadratic_included = passport.policy.include_quadratic_penalty;
        assert!(
            !quadratic_included,
            "solver-only ridge must not add a quadratic prior"
        );
        let passport_checks = [
            (
                passport.penalty_logdet_ridge(),
                "solver-only ridge must not shift the penalty logdet",
            ),
            (
                passport.laplacehessianridge(),
                "solver-only ridge must not shift the Laplace Hessian",
            ),
        ];
        for &(ridge, why) in passport_checks.iter() {
            assert_eq!(ridge, 0.0, "{}", why);
        }

        // Ledger-level classification and accounting.
        let ledger = StabilizationLedger::from_passport(passport);
        let classified_as_numerical = matches!(
            ledger.kind,
            StabilizationKind::NumericalPerturbation {
                backward_error_bound: None
            }
        );
        assert!(
            classified_as_numerical,
            "solver-only ridge is a numerical perturbation, not an explicit prior"
        );
        let ledger_checks = [
            (
                ledger.quadratic_delta(),
                "solver-only ridge must not contribute to the optimized objective",
            ),
            (
                ledger.laplace_hessian_delta(),
                "solver-only ridge must not contribute to REML curvature accounting",
            ),
            (
                ledger.penalty_logdet_delta(),
                "solver-only ridge must not contribute to determinant accounting",
            ),
        ];
        for &(delta, why) in ledger_checks.iter() {
            assert_eq!(delta, 0.0, "{}", why);
        }
    }
}