use ndarray::Array1;
use std::fmt;
use std::sync::atomic::{AtomicI32, AtomicUsize, Ordering};
/// Process-wide counter, starts at 0.
/// NOTE(review): never modified in the visible part of this file; by its name it
/// presumably counts "beta collapse" events — confirm against the code that bumps it.
pub static GRAD_DIAG_BETA_COLLAPSE_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Process-wide counter, starts at 0.
/// NOTE(review): never modified in the visible part of this file; by its name it
/// presumably counts "delta is zero" events — confirm against the code that bumps it.
pub static GRAD_DIAG_DELTA_ZERO_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Process-wide counter, starts at 0.
/// NOTE(review): never modified in the visible part of this file; by its name it
/// presumably counts clamped log-determinant evaluations — confirm against the code that bumps it.
pub static GRAD_DIAG_LOGH_CLAMPED_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Process-wide counter, starts at 0.
/// NOTE(review): never modified in the visible part of this file; by its name it
/// presumably counts skipped KKT checks — confirm against the code that bumps it.
pub static GRAD_DIAG_KKT_SKIP_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Most recent bucket, `floor(log10(min_eig))`, recorded by
/// [`should_emit_h_min_eig_diag`]. Starts at `i32::MIN` so the first recorded
/// bucket always differs from it.
pub static H_MIN_EIG_LOG_BUCKET: AtomicI32 = AtomicI32::new(i32::MIN);
/// Number of calls to [`should_emit_h_min_eig_diag`] that reached its
/// rate-limited path (finite, positive value below [`MIN_EIG_DIAG_THRESHOLD`]).
pub static H_MIN_EIG_LOG_COUNT: AtomicUsize = AtomicUsize::new(0);
/// While the bucket is unchanged, [`should_emit_h_min_eig_diag`] still emits on
/// every call whose running count is a multiple of this value.
pub const MIN_EIG_DIAG_EVERY: usize = 200;
/// Finite positive minimum eigenvalues at or above this value never trigger a
/// diagnostic in [`should_emit_h_min_eig_diag`].
pub const MIN_EIG_DIAG_THRESHOLD: f64 = 1e-4;
/// Decides whether a minimum-eigenvalue diagnostic should be emitted for `min_eig`.
///
/// Rules, in order:
/// 1. A non-finite or non-positive `min_eig` always emits (returns `true`).
/// 2. A value at or above [`MIN_EIG_DIAG_THRESHOLD`] never emits.
/// 3. Otherwise the call is rate limited: it emits when the order-of-magnitude
///    bucket `floor(log10(min_eig))` differs from the previously recorded one,
///    or when the running call count is a multiple of [`MIN_EIG_DIAG_EVERY`]
///    (which includes the very first rate-limited call, whose count is 0).
///
/// Side effects: on the rate-limited path the function records the current
/// bucket in [`H_MIN_EIG_LOG_BUCKET`] and increments [`H_MIN_EIG_LOG_COUNT`].
pub fn should_emit_h_min_eig_diag(min_eig: f64) -> bool {
    if !min_eig.is_finite() || min_eig <= 0.0 {
        return true;
    }
    if min_eig >= MIN_EIG_DIAG_THRESHOLD {
        return false;
    }

    // At this point `min_eig` is finite and strictly positive, so `log10` is
    // finite. The smallest positive f64 has log10 of roughly -323.3, so the
    // floored value always fits in an i32.
    let bucket = min_eig.log10().floor() as i32;

    // Read the previous bucket and publish the new one in a single atomic step.
    // A separate load followed by a store lets two threads both observe the
    // stale bucket and both report a "bucket change". Storing unconditionally
    // is equivalent to storing only on emit: when nothing is emitted because
    // the bucket is unchanged, the stored value is the same one already there.
    let previous = H_MIN_EIG_LOG_BUCKET.swap(bucket, Ordering::Relaxed);
    let count = H_MIN_EIG_LOG_COUNT.fetch_add(1, Ordering::Relaxed);

    bucket != previous || count.is_multiple_of(MIN_EIG_DIAG_EVERY)
}
/// Tunable settings for the gradient diagnostics.
///
/// NOTE(review): none of these fields is read in the visible part of this file;
/// the per-field descriptions are inferred from the names — confirm against callers.
#[derive(Clone, Debug)]
pub struct DiagnosticConfig {
    /// Presumably the tolerance applied to a KKT residual check.
    pub kkt_tolerance: f64,
    /// Presumably the relative-error level above which a discrepancy is reported.
    pub rel_error_threshold: f64,
    /// Presumably toggles whether diagnostics are surfaced as warnings.
    pub emitwarnings: bool,
}

impl Default for DiagnosticConfig {
    /// Builds the stock configuration: `kkt_tolerance = 1e-4`,
    /// `rel_error_threshold = 0.1`, and `emitwarnings = true`.
    fn default() -> Self {
        DiagnosticConfig {
            emitwarnings: true,
            rel_error_threshold: 0.1,
            kkt_tolerance: 1e-4,
        }
    }
}
/// Outcome of an envelope audit, as produced by `compute_envelopeaudit`.
#[derive(Clone, Debug)]
pub struct EnvelopeAudit {
    /// KKT residual norm that was supplied to the audit.
    pub kkt_residual_norm: f64,
    /// Ridge value reported as used by the inner solver.
    pub innerridge: f64,
    /// Ridge value the outer gradient computation assumes.
    pub outerridge: f64,
    /// `true` when the audit found a KKT violation or a ridge mismatch.
    pub isviolated: bool,
    /// Human-readable description of the audit result.
    pub message: String,
}

impl fmt::Display for EnvelopeAudit {
    /// Writes the stored `message` verbatim; width, fill, and precision
    /// requested by the caller are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}
/// Result record for a spectral-bleed check on one penalty.
///
/// NOTE(review): nothing in the visible part of this file constructs this type;
/// the numeric field descriptions are inferred from their names — confirm
/// against the producer.
#[derive(Clone, Debug)]
pub struct SpectralBleedResult {
    /// Presumably the index of the penalty this result refers to.
    pub penalty_k: usize,
    /// Presumably the energy lost to truncation.
    pub truncated_energy: f64,
    /// Presumably the correction that was actually applied.
    pub applied_correction: f64,
    /// Flag consulted by `GradientDiagnosticReport::summary` to decide whether
    /// this result is listed.
    pub has_bleed: bool,
    /// Human-readable description of the result.
    pub message: String,
}

impl fmt::Display for SpectralBleedResult {
    /// Writes the stored `message` verbatim; width, fill, and precision
    /// requested by the caller are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}
/// Outcome of a ridge-consistency check, as produced by `compute_dualridge_check`.
#[derive(Clone, Debug)]
pub struct DualRidgeResult {
    /// Ridge value supplied for the PIRLS stage.
    pub pirlsridge: f64,
    /// Ridge value supplied for the cost stage.
    pub costridge: f64,
    /// Ridge value supplied for the gradient stage.
    pub gradientridge: f64,
    /// `pirlsridge * ||beta||`.
    pub ridge_impact: f64,
    /// `0.5 * pirlsridge * ||beta||^2`.
    pub phantom_penalty: f64,
    /// `true` when any two of the three ridge values differ by more than `1e-12`.
    pub has_mismatch: bool,
    /// Human-readable description of the check result.
    pub message: String,
}

impl fmt::Display for DualRidgeResult {
    /// Writes the stored `message` verbatim; width, fill, and precision
    /// requested by the caller are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}
#[derive(Clone, Debug, Default)]
pub struct GradientDiagnosticReport {
pub envelopeaudit: Option<EnvelopeAudit>,
pub spectral_bleed: Vec<SpectralBleedResult>,
pub dualridge: Option<DualRidgeResult>,
}
impl GradientDiagnosticReport {
pub fn new() -> Self {
Self::default()
}
pub fn summary(&self) -> String {
let mut lines = Vec::new();
if let Some(ref audit) = self.envelopeaudit
&& audit.isviolated
{
lines.push(format!("[DIAG] {}", audit));
}
for bleed in &self.spectral_bleed {
if bleed.has_bleed {
lines.push(format!("[DIAG] {}", bleed));
}
}
if let Some(ref ridge) = self.dualridge
&& ridge.has_mismatch
{
lines.push(format!("[DIAG] {}", ridge));
}
if lines.is_empty() {
"No gradient diagnostic issues detected.".to_string()
} else {
lines.join("\n")
}
}
}
/// Audits the inner solution: checks the KKT residual against tolerances and
/// checks that the inner and outer ridge values agree.
///
/// # Arguments
/// * `kkt_residual_norm` - norm of the KKT residual at the inner solution.
/// * `referencegradient` - vector whose Euclidean norm sets the scale for the relative residual.
/// * `ridge_used` - ridge value used by the inner solver.
/// * `ridge_assumed` - ridge value assumed by the outer gradient.
/// * `beta` - coefficient vector; its norm times `|ridge_assumed|` is the alternative scale.
/// * `abs_tolerance` / `rel_tolerance` - a KKT violation requires exceeding **both**.
///
/// # Returns
/// An [`EnvelopeAudit`] whose `isviolated` flag is set when the KKT check
/// fails or the two ridge values differ by more than `1e-12`, together with a
/// message describing which condition(s) fired.
pub fn compute_envelopeaudit(
    kkt_residual_norm: f64,
    referencegradient: &Array1<f64>,
    ridge_used: f64,
    ridge_assumed: f64,
    beta: &Array1<f64>,
    abs_tolerance: f64,
    rel_tolerance: f64,
) -> EnvelopeAudit {
    // Scale for the relative residual: the larger of ||referencegradient|| and
    // |ridge_assumed| * ||beta||, with the latter floored at 1e-12.
    let gradient_norm = referencegradient.dot(referencegradient).sqrt();
    let coefficient_norm = beta.dot(beta).sqrt();
    let ridge_scale = (ridge_assumed.abs() * coefficient_norm).max(1e-12);
    let scale = gradient_norm.max(ridge_scale);
    // The 1e-12 floor keeps `scale` positive; the guard is purely defensive.
    let rel_kkt = if scale > 0.0 {
        kkt_residual_norm / scale
    } else {
        0.0
    };

    let ridge_mismatch = (ridge_used - ridge_assumed).abs() > 1e-12;
    let kktviolation = kkt_residual_norm > abs_tolerance && rel_kkt > rel_tolerance;

    let message = match (ridge_mismatch, kktviolation) {
        (true, true) => format!(
            "Envelope Violation: Inner solver ridge = {:.2e}, Outer gradient assumes ridge = {:.2e}. \
             KKT residual norm = {:.2e} (abs tol = {:.2e}, rel tol = {:.2e}). Unaccounted gradient energy: {:.2e}",
            ridge_used,
            ridge_assumed,
            kkt_residual_norm,
            abs_tolerance,
            rel_tolerance,
            kkt_residual_norm
        ),
        (true, false) => format!(
            "Ridge Mismatch: PIRLS optimized for H + {:.2e}*I, but Gradient calculated for H + {:.2e}*I",
            ridge_used, ridge_assumed
        ),
        (false, true) => format!(
            "Envelope Violation: KKT residual ||∇_β L|| = {:.2e} (rel {:.2e}) exceeds tolerances (abs {:.2e}, rel {:.2e}). \
             Inner solver may not have converged to true stationary point.",
            kkt_residual_norm, rel_kkt, abs_tolerance, rel_tolerance
        ),
        (false, false) => format!(
            "Envelope OK: KKT residual = {:.2e} (rel {:.2e}), ridge match = {:.2e}",
            kkt_residual_norm, rel_kkt, ridge_used
        ),
    };

    EnvelopeAudit {
        kkt_residual_norm,
        innerridge: ridge_used,
        outerridge: ridge_assumed,
        isviolated: kktviolation || ridge_mismatch,
        message,
    }
}
/// Checks that the PIRLS, cost, and gradient stages all use the same ridge.
///
/// Two ridge values count as different when they are more than `1e-12` apart.
/// The result also carries `ridge_impact = pirlsridge * ||beta||` and
/// `phantom_penalty = 0.5 * pirlsridge * ||beta||^2`.
///
/// # Returns
/// A [`DualRidgeResult`] with `has_mismatch` set when any pair disagrees and a
/// message that lists each disagreeing pair, or an "OK" message otherwise.
pub fn compute_dualridge_check(
    pirlsridge: f64,
    costridge: f64,
    gradientridge: f64,
    beta: &Array1<f64>,
) -> DualRidgeResult {
    let beta_norm_sq = beta.dot(beta);
    let beta_norm = beta_norm_sq.sqrt();
    let ridge_impact = pirlsridge * beta_norm;
    let phantom_penalty = 0.5 * pirlsridge * beta_norm_sq;

    let differs = |lhs: f64, rhs: f64| (lhs - rhs).abs() > 1e-12;

    // One entry per disagreeing pair, always in the order
    // PIRLS/Cost, PIRLS/Gradient, Cost/Gradient.
    let mismatches: Vec<String> = differs(pirlsridge, costridge)
        .then(|| {
            format!(
                "PIRLS({:.2e}) vs Cost({:.2e})",
                pirlsridge, costridge
            )
        })
        .into_iter()
        .chain(differs(pirlsridge, gradientridge).then(|| {
            format!(
                "PIRLS({:.2e}) vs Gradient({:.2e})",
                pirlsridge, gradientridge
            )
        }))
        .chain(differs(costridge, gradientridge).then(|| {
            format!(
                "Cost({:.2e}) vs Gradient({:.2e})",
                costridge, gradientridge
            )
        }))
        .collect();
    let has_mismatch = !mismatches.is_empty();

    let message = match (has_mismatch, pirlsridge > 0.0) {
        (true, _) => format!(
            "Ridge Mismatch detected: {}. Effective ridge impact on ||β|| = {:.2e}. \
             Phantom penalty = {:.2e}. The surface being differentiated differs from \
             the surface being optimized.",
            mismatches.join(", "),
            ridge_impact,
            phantom_penalty
        ),
        (false, true) => format!(
            "Ridge Consistency OK: All stages use ridge = {:.2e}. ||β|| = {:.2e}, phantom penalty = {:.2e}",
            pirlsridge, beta_norm, phantom_penalty
        ),
        (false, false) => "Ridge Consistency OK: No stabilization ridge required.".to_string(),
    };

    DualRidgeResult {
        pirlsridge,
        costridge,
        gradientridge,
        ridge_impact,
        phantom_penalty,
        has_mismatch,
        message,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr1;

    /// Zero residual, zero reference gradient, and matching (zero) ridges
    /// must not be flagged.
    #[test]
    fn test_envelopeaudit_noviolation() {
        let zero_gradient = arr1(&[0.0, 0.0, 0.0]);
        let coefficients = arr1(&[0.1, 0.2, 0.3]);
        let audit = compute_envelopeaudit(
            0.0,
            &zero_gradient,
            0.0,
            0.0,
            &coefficients,
            1e-8,
            1e-6,
        );
        assert!(!audit.isviolated);
    }

    /// A tiny residual combined with differing inner/outer ridges is
    /// reported as a ridge mismatch.
    #[test]
    fn test_envelopeaudit_detects_ridge_mismatch() {
        let unit_gradient = arr1(&[1.0, 0.0, 0.0]);
        let coefficients = arr1(&[0.1, 0.2, 0.3]);
        let audit = compute_envelopeaudit(
            1e-10,
            &unit_gradient,
            0.1,
            0.0,
            &coefficients,
            1e-8,
            1e-6,
        );
        assert!(audit.isviolated);
        assert!(audit.message.contains("Ridge Mismatch"));
    }

    /// Identical ridge values across all three stages are consistent.
    #[test]
    fn test_dualridge_check_no_mismatch() {
        let coefficients = arr1(&[0.1, 0.2, 0.3]);
        let check = compute_dualridge_check(0.0, 0.0, 0.0, &coefficients);
        assert!(!check.has_mismatch);
    }

    /// A PIRLS ridge that differs from the cost and gradient ridges is flagged.
    #[test]
    fn test_dualridge_check_detects_mismatch() {
        let coefficients = arr1(&[0.1, 0.2, 0.3]);
        let check = compute_dualridge_check(1e-4, 0.0, 0.0, &coefficients);
        assert!(check.has_mismatch);
        assert!(check.message.contains("Ridge Mismatch detected"));
    }
}