1use ndarray::Array1;
18use std::fmt;
19use std::sync::atomic::{AtomicI32, AtomicUsize, Ordering};
20
/// `floor(log10(min_eig))` bucket most recently recorded by
/// [`should_emit_h_min_eig_diag`]. Starts at `i32::MIN` before any value has
/// been recorded.
pub static H_MIN_EIG_LOG_BUCKET: AtomicI32 = AtomicI32::new(i32::MIN);
/// Number of calls to [`should_emit_h_min_eig_diag`] that reached its
/// throttled path (finite, positive values below [`MIN_EIG_DIAG_THRESHOLD`]).
pub static H_MIN_EIG_LOG_COUNT: AtomicUsize = AtomicUsize::new(0);
/// A throttled diagnostic is emitted whenever the call counter is a multiple
/// of this value, even if the bucket did not change.
pub const MIN_EIG_DIAG_EVERY: usize = 200;
/// Finite eigenvalues at or above this threshold never trigger a diagnostic.
pub const MIN_EIG_DIAG_THRESHOLD: f64 = 1e-4;
42
43pub fn format_top_abs(values: &Array1<f64>, label: &str, max_items: usize) -> String {
47 if values.is_empty() {
48 return format!("{label}=<empty>");
49 }
50 let mut ranked: Vec<(usize, f64)> = values.iter().copied().enumerate().collect();
51 ranked.sort_by(|(_, left), (_, right)| {
52 right
53 .abs()
54 .partial_cmp(&left.abs())
55 .unwrap_or(std::cmp::Ordering::Equal)
56 });
57 let parts: Vec<String> = ranked
58 .into_iter()
59 .take(max_items)
60 .map(|(idx, value)| format!("{idx}:{value:.3e}"))
61 .collect();
62 format!("{label}=[{}]", parts.join(", "))
63}
64
65pub fn should_emit_h_min_eig_diag(min_eig: f64) -> bool {
68 if !min_eig.is_finite() || min_eig <= 0.0 {
69 return true;
70 }
71 if min_eig >= MIN_EIG_DIAG_THRESHOLD {
72 return false;
73 }
74 let bucket = if min_eig.is_finite() && min_eig > 0.0 {
75 min_eig.log10().floor() as i32
76 } else {
77 i32::MIN
78 };
79 let last = H_MIN_EIG_LOG_BUCKET.load(Ordering::Relaxed);
80 let count = H_MIN_EIG_LOG_COUNT.fetch_add(1, Ordering::Relaxed);
81 if bucket != last || count.is_multiple_of(MIN_EIG_DIAG_EVERY) {
82 H_MIN_EIG_LOG_BUCKET.store(bucket, Ordering::Relaxed);
83 true
84 } else {
85 false
86 }
87}
88
/// Tunable settings for the gradient diagnostics.
///
/// NOTE(review): none of these fields is read inside this section of the
/// file; the per-field descriptions are inferred from the names and should be
/// confirmed against the callers.
#[derive(Clone, Debug)]
pub struct DiagnosticConfig {
    /// Presumably the tolerance applied to KKT residuals. Defaults to `1e-4`.
    pub kkt_tolerance: f64,
    /// Presumably the relative-error level above which an issue is reported.
    /// Defaults to `0.1`.
    pub rel_error_threshold: f64,
    /// Presumably toggles emission of warnings. Defaults to `true`.
    pub emitwarnings: bool,
}

impl Default for DiagnosticConfig {
    /// Returns the stock configuration: `kkt_tolerance = 1e-4`,
    /// `rel_error_threshold = 0.1`, warnings enabled.
    fn default() -> Self {
        let kkt_tolerance = 1e-4;
        let rel_error_threshold = 0.1;
        DiagnosticConfig {
            kkt_tolerance,
            rel_error_threshold,
            emitwarnings: true,
        }
    }
}
113
/// Outcome of an envelope audit, as produced by `compute_envelopeaudit`.
#[derive(Clone, Debug)]
pub struct EnvelopeAudit {
    /// KKT residual norm that was passed to the audit.
    pub kkt_residual_norm: f64,
    /// Ridge the inner solver actually used.
    pub innerridge: f64,
    /// Ridge the outer gradient assumed.
    pub outerridge: f64,
    /// `true` when the audit found a KKT violation or a ridge mismatch.
    pub isviolated: bool,
    /// Human-readable description of the audit outcome.
    pub message: String,
}

impl fmt::Display for EnvelopeAudit {
    /// Writes the stored message verbatim; width, fill and precision flags
    /// of the surrounding format spec are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}
134
/// Result record for a spectral-bleed check on one penalty.
///
/// NOTE(review): nothing in this section constructs this type; the numeric
/// field descriptions are inferred from the names and should be confirmed.
#[derive(Clone, Debug)]
pub struct SpectralBleedResult {
    /// Presumably the index of the penalty that was examined.
    pub penalty_k: usize,
    /// Presumably the energy lost to truncation.
    pub truncated_energy: f64,
    /// Presumably the correction that was applied.
    pub applied_correction: f64,
    /// Flag that makes `GradientDiagnosticReport::summary` include this
    /// result in its output.
    pub has_bleed: bool,
    /// Human-readable description of the result.
    pub message: String,
}

impl fmt::Display for SpectralBleedResult {
    /// Writes the stored message verbatim; width, fill and precision flags
    /// of the surrounding format spec are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}
154
/// Outcome of a ridge-consistency check, as produced by
/// `compute_dualridge_check`.
#[derive(Clone, Debug)]
pub struct DualRidgeResult {
    /// Ridge reported for the PIRLS stage.
    pub pirlsridge: f64,
    /// Ridge reported for the cost stage.
    pub costridge: f64,
    /// Ridge reported for the gradient stage.
    pub gradientridge: f64,
    /// `pirlsridge * ||beta||`.
    pub ridge_impact: f64,
    /// `0.5 * pirlsridge * ||beta||^2`.
    pub phantom_penalty: f64,
    /// `true` when any two of the three ridges differ by more than `1e-12`.
    pub has_mismatch: bool,
    /// Human-readable description of the check outcome.
    pub message: String,
}

impl fmt::Display for DualRidgeResult {
    /// Writes the stored message verbatim; width, fill and precision flags
    /// of the surrounding format spec are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}
179
/// Residual-based accuracy summary returned by
/// [`diagnostics_from_predictions`].
#[derive(Clone, Debug, PartialEq)]
pub struct PredictionDiagnostics {
    /// Number of observation/prediction pairs.
    pub n_obs: usize,
    /// Mean absolute residual.
    pub mae: f64,
    /// Root mean squared residual.
    pub rmse: f64,
    /// Mean residual (`observed - predicted`).
    pub bias: f64,
    /// Coefficient of determination, or `None` when the observations carry
    /// no variance to explain.
    pub r_squared: Option<f64>,
    /// Per-observation residuals (`observed - predicted`), in input order.
    pub residuals: Vec<f64>,
}

/// Computes residual diagnostics for `predicted_mean` against `observed`.
///
/// Residuals are `observed[i] - predicted_mean[i]`. `r_squared` is
/// `1 - SS_res / SS_tot` and is `None` whenever every observed value is
/// identical (including the single-observation case) or `SS_tot` is not
/// strictly positive.
///
/// # Errors
///
/// Returns a descriptive message when `observed` is empty, when the two
/// slices differ in length, or when either slice contains a non-finite value.
pub fn diagnostics_from_predictions(
    observed: &[f64],
    predicted_mean: &[f64],
) -> Result<PredictionDiagnostics, String> {
    if observed.is_empty() {
        return Err("diagnostics_from_predictions requires at least one observation".to_string());
    }
    if observed.len() != predicted_mean.len() {
        return Err(format!(
            "diagnostics_from_predictions length mismatch: observed has {} values but predicted mean has {}",
            observed.len(),
            predicted_mean.len()
        ));
    }
    if observed.iter().any(|value| !value.is_finite()) {
        return Err("observed values must contain only finite numbers".to_string());
    }
    if predicted_mean.iter().any(|value| !value.is_finite()) {
        return Err("predicted mean values must contain only finite numbers".to_string());
    }

    let n_obs = observed.len();
    let n_obs_f = n_obs as f64;
    let mut residuals = Vec::with_capacity(n_obs);
    let mut abs_sum = 0.0_f64;
    let mut residual_sum = 0.0_f64;
    let mut residual_sum_squares = 0.0_f64;
    let mut observed_sum = 0.0_f64;
    for (obs, pred) in observed.iter().zip(predicted_mean.iter()) {
        let residual = obs - pred;
        residuals.push(residual);
        abs_sum += residual.abs();
        residual_sum += residual;
        residual_sum_squares += residual * residual;
        observed_sum += obs;
    }

    // Constancy is tested exactly on the inputs rather than through the
    // centered sum of squares: for values such as 0.1 the running sum rounds,
    // the mean lands one ulp away from the data, and the sum of squares
    // becomes a tiny positive number that would yield an absurd R².
    let observed_is_constant = observed.windows(2).all(|pair| pair[0] == pair[1]);
    let r_squared = if observed_is_constant {
        None
    } else {
        let observed_mean = observed_sum / n_obs_f;
        let total_sum_squares = observed
            .iter()
            .map(|value| {
                let centered = value - observed_mean;
                centered * centered
            })
            .sum::<f64>();
        // Distinct values can still underflow to a zero sum of squares.
        (total_sum_squares > 0.0).then(|| 1.0 - residual_sum_squares / total_sum_squares)
    };

    Ok(PredictionDiagnostics {
        n_obs,
        mae: abs_sum / n_obs_f,
        rmse: (residual_sum_squares / n_obs_f).sqrt(),
        bias: residual_sum / n_obs_f,
        r_squared,
        residuals,
    })
}
252
253#[derive(Clone, Debug, Default)]
255pub struct GradientDiagnosticReport {
256 pub envelopeaudit: Option<EnvelopeAudit>,
258 pub spectral_bleed: Vec<SpectralBleedResult>,
260 pub dualridge: Option<DualRidgeResult>,
262}
263
264impl GradientDiagnosticReport {
265 pub fn new() -> Self {
267 Self::default()
268 }
269
270 pub fn summary(&self) -> String {
272 let mut lines = Vec::new();
273
274 if let Some(ref audit) = self.envelopeaudit
275 && audit.isviolated
276 {
277 lines.push(format!("[DIAG] {}", audit));
278 }
279
280 for bleed in &self.spectral_bleed {
281 if bleed.has_bleed {
282 lines.push(format!("[DIAG] {}", bleed));
283 }
284 }
285
286 if let Some(ref ridge) = self.dualridge
287 && ridge.has_mismatch
288 {
289 lines.push(format!("[DIAG] {}", ridge));
290 }
291
292 if lines.is_empty() {
293 "No gradient diagnostic issues detected.".to_string()
294 } else {
295 lines.join("\n")
296 }
297 }
298}
299
300pub fn compute_envelopeaudit(
318 kkt_residual_norm: f64,
319 referencegradient: &Array1<f64>,
320 ridge_used: f64,
321 ridge_assumed: f64,
322 beta: &Array1<f64>,
323 abs_tolerance: f64,
324 rel_tolerance: f64,
325) -> EnvelopeAudit {
326 let kkt_norm = kkt_residual_norm;
327 let penalty_norm = referencegradient.dot(referencegradient).sqrt();
328 let beta_norm = beta.dot(beta).sqrt();
329 let scale = penalty_norm.max((ridge_assumed.abs() * beta_norm).max(1e-12));
330 let rel_kkt = if scale > 0.0 { kkt_norm / scale } else { 0.0 };
331 let ridge_mismatch = (ridge_used - ridge_assumed).abs() > 1e-12;
332 let kktviolation = kkt_norm > abs_tolerance && rel_kkt > rel_tolerance;
333 let isviolated = kktviolation || ridge_mismatch;
334
335 let message = if ridge_mismatch && kktviolation {
336 format!(
337 "Envelope Violation: Inner solver ridge = {:.2e}, Outer gradient assumes ridge = {:.2e}. \
338 KKT residual norm = {:.2e} (abs tol = {:.2e}, rel tol = {:.2e}). Unaccounted gradient energy: {:.2e}",
339 ridge_used, ridge_assumed, kkt_norm, abs_tolerance, rel_tolerance, kkt_norm
340 )
341 } else if ridge_mismatch {
342 format!(
343 "Ridge Mismatch: PIRLS optimized for H + {:.2e}*I, but Gradient calculated for H + {:.2e}*I",
344 ridge_used, ridge_assumed
345 )
346 } else if kktviolation {
347 format!(
348 "Envelope Violation: KKT residual ||∇_β L|| = {:.2e} (rel {:.2e}) exceeds tolerances (abs {:.2e}, rel {:.2e}). \
349 Inner solver may not have converged to true stationary point.",
350 kkt_norm, rel_kkt, abs_tolerance, rel_tolerance
351 )
352 } else {
353 format!(
354 "Envelope OK: KKT residual = {:.2e} (rel {:.2e}), ridge match = {:.2e}",
355 kkt_norm, rel_kkt, ridge_used
356 )
357 };
358
359 EnvelopeAudit {
360 kkt_residual_norm: kkt_norm,
361 innerridge: ridge_used,
362 outerridge: ridge_assumed,
363 isviolated,
364 message,
365 }
366}
367
368pub fn compute_dualridge_check(
385 pirlsridge: f64,
386 costridge: f64,
387 gradientridge: f64,
388 beta: &Array1<f64>,
389) -> DualRidgeResult {
390 let beta_norm_sq = beta.dot(beta);
391 let beta_norm = beta_norm_sq.sqrt();
392
393 let ridge_impact = pirlsridge * beta_norm;
394 let phantom_penalty = 0.5 * pirlsridge * beta_norm_sq;
395
396 let pirlscost_mismatch = (pirlsridge - costridge).abs() > 1e-12;
397 let pirlsgrad_mismatch = (pirlsridge - gradientridge).abs() > 1e-12;
398 let costgrad_mismatch = (costridge - gradientridge).abs() > 1e-12;
399 let has_mismatch = pirlscost_mismatch || pirlsgrad_mismatch || costgrad_mismatch;
400
401 let message = if has_mismatch {
402 let mut mismatches = Vec::new();
403 if pirlscost_mismatch {
404 mismatches.push(format!(
405 "PIRLS({:.2e}) vs Cost({:.2e})",
406 pirlsridge, costridge
407 ));
408 }
409 if pirlsgrad_mismatch {
410 mismatches.push(format!(
411 "PIRLS({:.2e}) vs Gradient({:.2e})",
412 pirlsridge, gradientridge
413 ));
414 }
415 if costgrad_mismatch {
416 mismatches.push(format!(
417 "Cost({:.2e}) vs Gradient({:.2e})",
418 costridge, gradientridge
419 ));
420 }
421 format!(
422 "Ridge Mismatch detected: {}. Effective ridge impact on ||β|| = {:.2e}. \
423 Phantom penalty = {:.2e}. The surface being differentiated differs from \
424 the surface being optimized.",
425 mismatches.join(", "),
426 ridge_impact,
427 phantom_penalty
428 )
429 } else if pirlsridge > 0.0 {
430 format!(
431 "Ridge Consistency OK: All stages use ridge = {:.2e}. ||β|| = {:.2e}, phantom penalty = {:.2e}",
432 pirlsridge, beta_norm, phantom_penalty
433 )
434 } else {
435 "Ridge Consistency OK: No stabilization ridge required.".to_string()
436 };
437
438 DualRidgeResult {
439 pirlsridge,
440 costridge,
441 gradientridge,
442 ridge_impact,
443 phantom_penalty,
444 has_mismatch,
445 message,
446 }
447}
448
/// Machine-readable diagnosis carried by an error message as
/// `diagnosis: <name>`.
///
/// [`as_str`](Self::as_str) gives the wire name of each variant,
/// [`parse_from_error`](Self::parse_from_error) recovers a variant from an
/// error message, and [`guidance`](Self::guidance) gives a remediation hint.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum KktRefusalDiagnosis {
    /// Wire name `rank_deficient_H_pen`.
    RankDeficientHPen,
    /// Wire name `phantom_multiplier_with_well_conditioned_H`.
    PhantomMultiplierWithWellConditionedH,
    /// Wire name `active_set_incomplete`.
    ActiveSetIncomplete,
    /// Wire name `aliasing_detected_at_fit`.
    AliasingDetectedAtFit,
}

impl KktRefusalDiagnosis {
    /// Every variant, used to parse names through `as_str` so the name table
    /// exists in exactly one place. Extend this list when adding a variant.
    const ALL: [KktRefusalDiagnosis; 4] = [
        KktRefusalDiagnosis::RankDeficientHPen,
        KktRefusalDiagnosis::PhantomMultiplierWithWellConditionedH,
        KktRefusalDiagnosis::ActiveSetIncomplete,
        KktRefusalDiagnosis::AliasingDetectedAtFit,
    ];

    /// Returns the stable wire name of this diagnosis.
    pub fn as_str(&self) -> &'static str {
        match self {
            KktRefusalDiagnosis::RankDeficientHPen => "rank_deficient_H_pen",
            KktRefusalDiagnosis::PhantomMultiplierWithWellConditionedH => {
                "phantom_multiplier_with_well_conditioned_H"
            }
            KktRefusalDiagnosis::ActiveSetIncomplete => "active_set_incomplete",
            KktRefusalDiagnosis::AliasingDetectedAtFit => "aliasing_detected_at_fit",
        }
    }

    /// Extracts the diagnosis named after the *last* `diagnosis: ` marker in
    /// `message`.
    ///
    /// The name is the run of ASCII letters, digits and underscores that
    /// follows the marker, so it may be followed by any delimiter: `;`,
    /// whitespace (including `\r\n`), or punctuation added by an outer error
    /// wrapper such as `)`, `.` or `,`. Returns `None` when the marker is
    /// absent or the name is not a known diagnosis.
    pub fn parse_from_error(message: &str) -> Option<Self> {
        const MARKER: &str = "diagnosis: ";
        let (_, tail) = message.rsplit_once(MARKER)?;
        let end = tail
            .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
            .unwrap_or(tail.len());
        let token = &tail[..end];
        Self::ALL
            .into_iter()
            .find(|diagnosis| diagnosis.as_str() == token)
    }

    /// Returns a human-readable hint on what to investigate for this
    /// diagnosis.
    pub fn guidance(self) -> &'static str {
        match self {
            KktRefusalDiagnosis::RankDeficientHPen => {
                "check whether the named block has a structural or numerical null direction \
                 not identified by the likelihood/penalty combination; for Duchon-style \
                 smooths this may be a polynomial null space, while marginal-slope fits can \
                 also expose callback-owned weak directions"
            }
            KktRefusalDiagnosis::PhantomMultiplierWithWellConditionedH => {
                "check whether the named block has a near-separated or weakly identified \
                 direction despite a well-conditioned penalized Hessian; in marginal-slope \
                 fits this often indicates marginal/logslope coupling rather than a \
                 Matérn/Duchon polynomial-nullspace failure"
            }
            KktRefusalDiagnosis::ActiveSetIncomplete => {
                "check whether the named block's linear constraints need an additional \
                 active row or a tighter constrained re-solve; this is an active-set \
                 certification failure, not a polynomial-nullspace diagnosis"
            }
            KktRefusalDiagnosis::AliasingDetectedAtFit => {
                "check whether the named block aliases another block after runtime \
                 constraints or callbacks materialize; drop or reparameterize the aliased \
                 direction before fitting"
            }
        }
    }
}
534
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr1;

    /// A zero residual with identical ridges must not be flagged.
    #[test]
    fn test_envelopeaudit_noviolation() {
        let zero_gradient = arr1(&[0.0, 0.0, 0.0]);
        let coefficients = arr1(&[0.1, 0.2, 0.3]);

        let audit =
            compute_envelopeaudit(0.0, &zero_gradient, 0.0, 0.0, &coefficients, 1e-8, 1e-6);

        assert!(!audit.isviolated);
    }

    /// Differing inner/outer ridges are reported as a ridge mismatch.
    #[test]
    fn test_envelopeaudit_detects_ridge_mismatch() {
        let unit_gradient = arr1(&[1.0, 0.0, 0.0]);
        let coefficients = arr1(&[0.1, 0.2, 0.3]);

        let audit =
            compute_envelopeaudit(1e-10, &unit_gradient, 0.1, 0.0, &coefficients, 1e-8, 1e-6);

        assert!(audit.isviolated);
        assert!(audit.message.contains("Ridge Mismatch"));
    }

    /// Three identical ridges are consistent.
    #[test]
    fn test_dualridge_check_no_mismatch() {
        let coefficients = arr1(&[0.1, 0.2, 0.3]);

        let check = compute_dualridge_check(0.0, 0.0, 0.0, &coefficients);

        assert!(!check.has_mismatch);
    }

    /// A PIRLS ridge that the other stages do not share is a mismatch.
    #[test]
    fn test_dualridge_check_detects_mismatch() {
        let coefficients = arr1(&[0.1, 0.2, 0.3]);

        let check = compute_dualridge_check(1e-4, 0.0, 0.0, &coefficients);

        assert!(check.has_mismatch);
        assert!(check.message.contains("Ridge Mismatch detected"));
    }

    /// Residuals and every summary metric match hand-computed values.
    #[test]
    fn diagnostics_from_predictions_computes_residual_metrics() {
        let truth = [1.0, 2.0, 4.0];
        let fitted = [1.5, 1.5, 3.0];

        let PredictionDiagnostics {
            n_obs,
            mae,
            rmse,
            bias,
            r_squared,
            residuals,
        } = diagnostics_from_predictions(&truth, &fitted).unwrap();

        assert_eq!(residuals, vec![-0.5, 0.5, 1.0]);
        assert_eq!(n_obs, 3);
        assert_eq!(mae, 2.0 / 3.0);
        assert_eq!(bias, 1.0 / 3.0);
        assert_eq!(rmse, (1.5_f64 / 3.0).sqrt());
        assert_eq!(r_squared, Some(1.0 - 1.5 / (14.0 / 3.0)));
    }

    /// Constant observations leave no variance to explain.
    #[test]
    fn diagnostics_from_predictions_omits_r_squared_for_constant_observed() {
        let truth = [2.0, 2.0];
        let fitted = [1.0, 3.0];

        let summary = diagnostics_from_predictions(&truth, &fitted).unwrap();

        assert_eq!(summary.r_squared, None);
    }

    /// Each validation failure yields its own exact error message.
    #[test]
    fn diagnostics_from_predictions_rejects_invalid_inputs() {
        let empty = diagnostics_from_predictions(&[], &[]);
        assert_eq!(
            empty,
            Err("diagnostics_from_predictions requires at least one observation".to_string())
        );

        let mismatched = diagnostics_from_predictions(&[1.0], &[1.0, 2.0]);
        assert_eq!(
            mismatched,
            Err(
                "diagnostics_from_predictions length mismatch: observed has 1 values but predicted mean has 2"
                    .to_string()
            )
        );

        let nan_observed = diagnostics_from_predictions(&[f64::NAN], &[1.0]);
        assert_eq!(
            nan_observed,
            Err("observed values must contain only finite numbers".to_string())
        );

        let infinite_prediction = diagnostics_from_predictions(&[1.0], &[f64::INFINITY]);
        assert_eq!(
            infinite_prediction,
            Err("predicted mean values must contain only finite numbers".to_string())
        );
    }
}