use ndarray::Array1;
use std::fmt;
use std::sync::atomic::{AtomicI32, AtomicUsize, Ordering};
// Process-wide diagnostic counters. Nothing in this part of the file updates the four
// `GRAD_DIAG_*` counters, so the meanings given for them are inferred from their names
// only — TODO confirm against the call sites.
/// Counter starting at zero; presumably counts "beta collapse" events — verify at call sites.
pub static GRAD_DIAG_BETA_COLLAPSE_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Counter starting at zero; presumably counts steps whose delta was zero — verify at call sites.
pub static GRAD_DIAG_DELTA_ZERO_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Counter starting at zero; presumably counts clamped log-determinant terms — verify at call sites.
pub static GRAD_DIAG_LOGH_CLAMPED_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Counter starting at zero; presumably counts skipped KKT checks — verify at call sites.
pub static GRAD_DIAG_KKT_SKIP_COUNT: AtomicUsize = AtomicUsize::new(0);
/// `floor(log10(min_eig))` bucket recorded by the last throttled call to
/// `should_emit_h_min_eig_diag` that returned `true`; `i32::MIN` until then.
pub static H_MIN_EIG_LOG_BUCKET: AtomicI32 = AtomicI32::new(i32::MIN);
/// Number of `should_emit_h_min_eig_diag` calls that reached the throttling stage, i.e.
/// were given a finite, positive eigenvalue below `MIN_EIG_DIAG_THRESHOLD`.
pub static H_MIN_EIG_LOG_COUNT: AtomicUsize = AtomicUsize::new(0);
/// While the bucket is unchanged, a throttled call still emits when the running call
/// count is a multiple of this value.
pub const MIN_EIG_DIAG_EVERY: usize = 200;
/// Finite minimum eigenvalues at or above this value never trigger a diagnostic.
pub const MIN_EIG_DIAG_THRESHOLD: f64 = 1e-4;
/// Formats the largest-magnitude entries of `values` for a log line.
///
/// Returns `"{label}=<empty>"` when `values` has no elements. Otherwise returns
/// `"{label}=[i:v, j:w, ...]"` holding at most `max_items` `index:value` pairs, ordered
/// by decreasing absolute value, with each value printed as `{:.3e}`.
///
/// NaN entries rank ahead of every finite or infinite entry, so a poisoned vector is
/// visible at the front of the output. Entries of equal magnitude keep their original
/// index order because the sort is stable.
pub fn format_top_abs(values: &Array1<f64>, label: &str, max_items: usize) -> String {
    if values.is_empty() {
        return format!("{label}=<empty>");
    }
    let mut ranked: Vec<(usize, f64)> = values.iter().copied().enumerate().collect();
    // `total_cmp` gives a genuine total order. The previous
    // `partial_cmp(..).unwrap_or(Equal)` made NaN "equal" to everything, which is not
    // transitive; `sort_by` may then produce an unspecified order or panic.
    // `abs()` clears the sign bit, so every NaN compares above +inf here.
    ranked.sort_by(|(_, left), (_, right)| right.abs().total_cmp(&left.abs()));
    let parts: Vec<String> = ranked
        .into_iter()
        .take(max_items)
        .map(|(idx, value)| format!("{idx}:{value:.3e}"))
        .collect();
    format!("{label}=[{}]", parts.join(", "))
}
/// Decides whether a minimum-eigenvalue diagnostic should be emitted for `min_eig`.
///
/// * Non-finite or non-positive values always return `true`; they bypass throttling and
///   touch neither `H_MIN_EIG_LOG_BUCKET` nor `H_MIN_EIG_LOG_COUNT`.
/// * Values at or above `MIN_EIG_DIAG_THRESHOLD` always return `false`.
/// * Values strictly between zero and the threshold are throttled: the call returns
///   `true` when the decade bucket `floor(log10(min_eig))` differs from the last recorded
///   bucket, or when the running call count is a multiple of `MIN_EIG_DIAG_EVERY`.
///   Otherwise it returns `false`.
///
/// The bucket load and store are separate relaxed operations, so concurrent callers that
/// see a new bucket at the same time may each return `true`.
pub fn should_emit_h_min_eig_diag(min_eig: f64) -> bool {
    if !min_eig.is_finite() || min_eig <= 0.0 {
        return true;
    }
    if min_eig >= MIN_EIG_DIAG_THRESHOLD {
        return false;
    }
    // The early returns above leave only finite, strictly positive values here, so
    // `log10` is finite and no fallback bucket is needed.
    let bucket = min_eig.log10().floor() as i32;
    let last = H_MIN_EIG_LOG_BUCKET.load(Ordering::Relaxed);
    // `fetch_add` returns the previous count, so the very first throttled call sees 0
    // and therefore emits.
    let count = H_MIN_EIG_LOG_COUNT.fetch_add(1, Ordering::Relaxed);
    if bucket != last || count.is_multiple_of(MIN_EIG_DIAG_EVERY) {
        H_MIN_EIG_LOG_BUCKET.store(bucket, Ordering::Relaxed);
        true
    } else {
        false
    }
}
/// Thresholds and switches for the gradient diagnostics.
///
/// No code in this part of the file reads these fields; the per-field notes are inferred
/// from the names and should be confirmed against the consumers.
#[derive(Clone, Debug)]
pub struct DiagnosticConfig {
    /// Presumably the tolerance applied to the KKT residual — TODO confirm.
    pub kkt_tolerance: f64,
    /// Presumably the relative-error level above which a gradient is flagged — TODO confirm.
    pub rel_error_threshold: f64,
    /// Presumably toggles emission of diagnostic warnings — TODO confirm.
    pub emitwarnings: bool,
}

impl Default for DiagnosticConfig {
    /// Returns `kkt_tolerance = 1e-4`, `rel_error_threshold = 0.1`, `emitwarnings = true`.
    fn default() -> Self {
        let kkt_tolerance = 1e-4;
        let rel_error_threshold = 0.1;
        DiagnosticConfig {
            emitwarnings: true,
            rel_error_threshold,
            kkt_tolerance,
        }
    }
}
/// Outcome of the envelope audit produced by `compute_envelopeaudit`.
#[derive(Clone, Debug)]
pub struct EnvelopeAudit {
    /// KKT residual norm exactly as passed to the audit.
    pub kkt_residual_norm: f64,
    /// Ridge the audit was told the inner solver used (`ridge_used`).
    pub innerridge: f64,
    /// Ridge the audit was told the outer gradient assumes (`ridge_assumed`).
    pub outerridge: f64,
    /// `true` when the audit found a KKT violation or a ridge mismatch.
    pub isviolated: bool,
    /// Human-readable description of the outcome.
    pub message: String,
}

impl fmt::Display for EnvelopeAudit {
    /// Writes `message` verbatim; width and fill flags of the formatter are not applied.
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        out.write_str(&self.message)
    }
}
/// Result of a spectral-bleed check for one penalty.
///
/// Nothing in this part of the file constructs this type; the per-field notes are
/// inferred from the names and should be confirmed against the producer.
#[derive(Clone, Debug)]
pub struct SpectralBleedResult {
    /// Presumably the index of the penalty being checked — TODO confirm.
    pub penalty_k: usize,
    /// Presumably the energy lost to spectral truncation — TODO confirm.
    pub truncated_energy: f64,
    /// Presumably the correction that was actually applied — TODO confirm.
    pub applied_correction: f64,
    /// When `true`, `GradientDiagnosticReport::summary` includes this result.
    pub has_bleed: bool,
    /// Human-readable description of the outcome.
    pub message: String,
}

impl fmt::Display for SpectralBleedResult {
    /// Writes `message` verbatim; width and fill flags of the formatter are not applied.
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        out.write_str(&self.message)
    }
}
/// Outcome of the ridge-consistency check produced by `compute_dualridge_check`.
#[derive(Clone, Debug)]
pub struct DualRidgeResult {
    /// Ridge reported for the PIRLS stage.
    pub pirlsridge: f64,
    /// Ridge reported for the cost evaluation.
    pub costridge: f64,
    /// Ridge reported for the gradient evaluation.
    pub gradientridge: f64,
    /// `pirlsridge * ||beta||`.
    pub ridge_impact: f64,
    /// `0.5 * pirlsridge * ||beta||^2`.
    pub phantom_penalty: f64,
    /// `true` when any two of the three ridges differ by more than `1e-12`.
    pub has_mismatch: bool,
    /// Human-readable description of the outcome.
    pub message: String,
}

impl fmt::Display for DualRidgeResult {
    /// Writes `message` verbatim; width and fill flags of the formatter are not applied.
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        out.write_str(&self.message)
    }
}
/// Residual-based accuracy summary returned by [`diagnostics_from_predictions`].
///
/// Residuals are defined as `observed - predicted`.
#[derive(Clone, Debug, PartialEq)]
pub struct PredictionDiagnostics {
    /// Number of observation/prediction pairs.
    pub n_obs: usize,
    /// Mean absolute residual.
    pub mae: f64,
    /// Root of the mean squared residual.
    pub rmse: f64,
    /// Mean signed residual.
    pub bias: f64,
    /// `1 - SS_res / SS_tot`, or `None` when the observed values have no variance.
    pub r_squared: Option<f64>,
    /// Residual for every pair, in input order.
    pub residuals: Vec<f64>,
}

/// Computes residual metrics for `predicted_mean` against `observed`.
///
/// `r_squared` is `None` whenever all observed values are equal (this includes a single
/// observation), because the total sum of squares is then zero and the ratio is
/// undefined.
///
/// # Errors
///
/// Returns a descriptive `String` when `observed` is empty, when the two slices differ
/// in length, or when either slice contains a NaN or an infinity.
pub fn diagnostics_from_predictions(
    observed: &[f64],
    predicted_mean: &[f64],
) -> Result<PredictionDiagnostics, String> {
    if observed.is_empty() {
        return Err("diagnostics_from_predictions requires at least one observation".to_string());
    }
    if observed.len() != predicted_mean.len() {
        return Err(format!(
            "diagnostics_from_predictions length mismatch: observed has {} values but predicted mean has {}",
            observed.len(),
            predicted_mean.len()
        ));
    }
    if observed.iter().any(|value| !value.is_finite()) {
        return Err("observed values must contain only finite numbers".to_string());
    }
    if predicted_mean.iter().any(|value| !value.is_finite()) {
        return Err("predicted mean values must contain only finite numbers".to_string());
    }

    let n_obs = observed.len();
    let n_obs_f = n_obs as f64;
    let mut residuals = Vec::with_capacity(n_obs);
    let mut abs_sum = 0.0_f64;
    let mut residual_sum = 0.0_f64;
    let mut residual_sum_squares = 0.0_f64;
    let mut observed_sum = 0.0_f64;
    for (obs, pred) in observed.iter().zip(predicted_mean.iter()) {
        let residual = obs - pred;
        residuals.push(residual);
        abs_sum += residual.abs();
        residual_sum += residual;
        residual_sum_squares += residual * residual;
        observed_sum += obs;
    }

    let observed_mean = observed_sum / n_obs_f;
    let total_sum_squares = observed
        .iter()
        .map(|value| {
            let centered = value - observed_mean;
            centered * centered
        })
        .sum::<f64>();

    // Detect constant observations exactly instead of relying on
    // `total_sum_squares == 0`. The rounded mean of e.g. [0.1, 0.1, 0.1] is
    // 0.10000000000000002, which leaves a total sum of squares of about 6e-34 and
    // would turn R² into an astronomically large, meaningless value.
    let observed_is_constant = observed.windows(2).all(|pair| pair[0] == pair[1]);
    let r_squared = if !observed_is_constant && total_sum_squares > 0.0 {
        Some(1.0 - residual_sum_squares / total_sum_squares)
    } else {
        None
    };

    Ok(PredictionDiagnostics {
        n_obs,
        mae: abs_sum / n_obs_f,
        rmse: (residual_sum_squares / n_obs_f).sqrt(),
        bias: residual_sum / n_obs_f,
        r_squared,
        residuals,
    })
}
#[derive(Clone, Debug, Default)]
pub struct GradientDiagnosticReport {
pub envelopeaudit: Option<EnvelopeAudit>,
pub spectral_bleed: Vec<SpectralBleedResult>,
pub dualridge: Option<DualRidgeResult>,
}
impl GradientDiagnosticReport {
pub fn new() -> Self {
Self::default()
}
pub fn summary(&self) -> String {
let mut lines = Vec::new();
if let Some(ref audit) = self.envelopeaudit
&& audit.isviolated
{
lines.push(format!("[DIAG] {}", audit));
}
for bleed in &self.spectral_bleed {
if bleed.has_bleed {
lines.push(format!("[DIAG] {}", bleed));
}
}
if let Some(ref ridge) = self.dualridge
&& ridge.has_mismatch
{
lines.push(format!("[DIAG] {}", ridge));
}
if lines.is_empty() {
"No gradient diagnostic issues detected.".to_string()
} else {
lines.join("\n")
}
}
}
/// Audits whether the inner solution satisfies the conditions the outer gradient relies on.
///
/// Two conditions are checked:
///
/// * **KKT violation** — `kkt_residual_norm` exceeds `abs_tolerance` *and* the residual
///   relative to `scale` exceeds `rel_tolerance`, where
///   `scale = max(||referencegradient||, |ridge_assumed| * ||beta||, 1e-12)`.
/// * **Ridge mismatch** — `ridge_used` and `ridge_assumed` differ by more than `1e-12`.
///
/// The returned [`EnvelopeAudit`] is marked violated when either condition holds, and
/// its message describes which one(s). A NaN residual compares false against both
/// tolerances and is therefore not reported as a KKT violation.
pub fn compute_envelopeaudit(
    kkt_residual_norm: f64,
    referencegradient: &Array1<f64>,
    ridge_used: f64,
    ridge_assumed: f64,
    beta: &Array1<f64>,
    abs_tolerance: f64,
    rel_tolerance: f64,
) -> EnvelopeAudit {
    let kkt_norm = kkt_residual_norm;
    let penalty_norm = referencegradient.dot(referencegradient).sqrt();
    let beta_norm = beta.dot(beta).sqrt();

    // `f64::max` returns the non-NaN operand, so the 1e-12 floor always wins over NaN
    // and `scale` is guaranteed to be at least 1e-12. The division below needs no
    // zero guard.
    let scale = penalty_norm.max((ridge_assumed.abs() * beta_norm).max(1e-12));
    let rel_kkt = kkt_norm / scale;

    let ridge_mismatch = (ridge_used - ridge_assumed).abs() > 1e-12;
    let kktviolation = kkt_norm > abs_tolerance && rel_kkt > rel_tolerance;
    let isviolated = kktviolation || ridge_mismatch;

    let message = match (ridge_mismatch, kktviolation) {
        (true, true) => format!(
            "Envelope Violation: Inner solver ridge = {:.2e}, Outer gradient assumes ridge = {:.2e}. \
             KKT residual norm = {:.2e} (abs tol = {:.2e}, rel tol = {:.2e}). Unaccounted gradient energy: {:.2e}",
            ridge_used, ridge_assumed, kkt_norm, abs_tolerance, rel_tolerance, kkt_norm
        ),
        (true, false) => format!(
            "Ridge Mismatch: PIRLS optimized for H + {:.2e}*I, but Gradient calculated for H + {:.2e}*I",
            ridge_used, ridge_assumed
        ),
        (false, true) => format!(
            "Envelope Violation: KKT residual ||∇_β L|| = {:.2e} (rel {:.2e}) exceeds tolerances (abs {:.2e}, rel {:.2e}). \
             Inner solver may not have converged to true stationary point.",
            kkt_norm, rel_kkt, abs_tolerance, rel_tolerance
        ),
        (false, false) => format!(
            "Envelope OK: KKT residual = {:.2e} (rel {:.2e}), ridge match = {:.2e}",
            kkt_norm, rel_kkt, ridge_used
        ),
    };

    EnvelopeAudit {
        kkt_residual_norm: kkt_norm,
        innerridge: ridge_used,
        outerridge: ridge_assumed,
        isviolated,
        message,
    }
}
/// Checks that the PIRLS, cost and gradient stages all used the same ridge.
///
/// Two ridges count as different when they are more than `1e-12` apart. The result
/// also carries `ridge_impact = pirlsridge * ||beta||` and
/// `phantom_penalty = 0.5 * pirlsridge * ||beta||^2`, and a message that either lists
/// every mismatching pair or confirms consistency.
pub fn compute_dualridge_check(
    pirlsridge: f64,
    costridge: f64,
    gradientridge: f64,
    beta: &Array1<f64>,
) -> DualRidgeResult {
    let differs = |a: f64, b: f64| (a - b).abs() > 1e-12;

    let beta_norm_sq = beta.dot(beta);
    let beta_norm = beta_norm_sq.sqrt();
    let ridge_impact = pirlsridge * beta_norm;
    let phantom_penalty = 0.5 * pirlsridge * beta_norm_sq;

    // One optional description per stage pair, in fixed reporting order.
    let mismatches: Vec<String> = vec![
        differs(pirlsridge, costridge).then(|| {
            format!(
                "PIRLS({:.2e}) vs Cost({:.2e})",
                pirlsridge, costridge
            )
        }),
        differs(pirlsridge, gradientridge).then(|| {
            format!(
                "PIRLS({:.2e}) vs Gradient({:.2e})",
                pirlsridge, gradientridge
            )
        }),
        differs(costridge, gradientridge).then(|| {
            format!(
                "Cost({:.2e}) vs Gradient({:.2e})",
                costridge, gradientridge
            )
        }),
    ]
    .into_iter()
    .flatten()
    .collect();
    let has_mismatch = !mismatches.is_empty();

    let message = if has_mismatch {
        format!(
            "Ridge Mismatch detected: {}. Effective ridge impact on ||β|| = {:.2e}. \
             Phantom penalty = {:.2e}. The surface being differentiated differs from \
             the surface being optimized.",
            mismatches.join(", "),
            ridge_impact,
            phantom_penalty
        )
    } else if pirlsridge > 0.0 {
        format!(
            "Ridge Consistency OK: All stages use ridge = {:.2e}. ||β|| = {:.2e}, phantom penalty = {:.2e}",
            pirlsridge, beta_norm, phantom_penalty
        )
    } else {
        "Ridge Consistency OK: No stabilization ridge required.".to_string()
    };

    DualRidgeResult {
        pirlsridge,
        costridge,
        gradientridge,
        ridge_impact,
        phantom_penalty,
        has_mismatch,
        message,
    }
}
// Unit tests for the envelope audit, the ridge-consistency check and the prediction
// diagnostics defined in this file.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::arr1;
// Zero residual with matching (zero) ridges must not be flagged.
#[test]
fn test_envelopeaudit_noviolation() {
let reference = arr1(&[0.0, 0.0, 0.0]);
let beta = arr1(&[0.1, 0.2, 0.3]);
let result = compute_envelopeaudit(0.0, &reference, 0.0, 0.0, &beta, 1e-8, 1e-6);
assert!(!result.isviolated);
}
// A residual below the absolute tolerance but differing ridges (0.1 vs 0.0) must be
// flagged and described as a ridge mismatch.
#[test]
fn test_envelopeaudit_detects_ridge_mismatch() {
let reference = arr1(&[1.0, 0.0, 0.0]);
let beta = arr1(&[0.1, 0.2, 0.3]);
let result = compute_envelopeaudit(1e-10, &reference, 0.1, 0.0, &beta, 1e-8, 1e-6);
assert!(result.isviolated);
assert!(result.message.contains("Ridge Mismatch"));
}
// Three identical ridges are consistent.
#[test]
fn test_dualridge_check_no_mismatch() {
let beta = arr1(&[0.1, 0.2, 0.3]);
let result = compute_dualridge_check(0.0, 0.0, 0.0, &beta);
assert!(!result.has_mismatch);
}
// A PIRLS ridge that differs from the cost and gradient ridges is reported.
#[test]
fn test_dualridge_check_detects_mismatch() {
let beta = arr1(&[0.1, 0.2, 0.3]);
let result = compute_dualridge_check(1e-4, 0.0, 0.0, &beta);
assert!(result.has_mismatch);
assert!(result.message.contains("Ridge Mismatch detected"));
}
// Residuals are observed - predicted = [-0.5, 0.5, 1.0]; the expected metrics are
// derived from those values.
#[test]
fn diagnostics_from_predictions_computes_residual_metrics() {
let observed = [1.0, 2.0, 4.0];
let predicted = [1.5, 1.5, 3.0];
let result = diagnostics_from_predictions(&observed, &predicted).unwrap();
assert_eq!(result.residuals, vec![-0.5, 0.5, 1.0]);
assert_eq!(result.n_obs, 3);
assert_eq!(result.mae, 2.0 / 3.0);
assert_eq!(result.bias, 1.0 / 3.0);
assert_eq!(result.rmse, (1.5_f64 / 3.0).sqrt());
assert_eq!(result.r_squared, Some(1.0 - 1.5 / (14.0 / 3.0)));
}
// Constant observations have zero total sum of squares, so R² is undefined.
#[test]
fn diagnostics_from_predictions_omits_r_squared_for_constant_observed() {
let observed = [2.0, 2.0];
let predicted = [1.0, 3.0];
let result = diagnostics_from_predictions(&observed, &predicted).unwrap();
assert_eq!(result.r_squared, None);
}
// Each validation failure returns its own exact error message: empty input, length
// mismatch, non-finite observed value, non-finite predicted value.
#[test]
fn diagnostics_from_predictions_rejects_invalid_inputs() {
assert_eq!(
diagnostics_from_predictions(&[], &[]),
Err("diagnostics_from_predictions requires at least one observation".to_string())
);
assert_eq!(
diagnostics_from_predictions(&[1.0], &[1.0, 2.0]),
Err(
"diagnostics_from_predictions length mismatch: observed has 1 values but predicted mean has 2"
.to_string()
)
);
assert_eq!(
diagnostics_from_predictions(&[f64::NAN], &[1.0]),
Err("observed values must contain only finite numbers".to_string())
);
assert_eq!(
diagnostics_from_predictions(&[1.0], &[f64::INFINITY]),
Err("predicted mean values must contain only finite numbers".to_string())
);
}
}