1use ndarray::Array1;
18use std::fmt;
19use std::sync::atomic::{AtomicI32, AtomicUsize, Ordering};
20
/// Decade bucket (`floor(log10(min_eig))`) most recently recorded by
/// `should_emit_h_min_eig_diag`. Starts at `i32::MIN`, which no real bucket can equal,
/// so the first small eigenvalue always counts as a bucket change.
pub static H_MIN_EIG_LOG_BUCKET: AtomicI32 = AtomicI32::new(i32::MIN);
/// Number of `should_emit_h_min_eig_diag` calls that reached the throttling stage
/// (finite, positive values below [`MIN_EIG_DIAG_THRESHOLD`]).
pub static H_MIN_EIG_LOG_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Within an unchanged bucket, a diagnostic is still emitted once per this many throttled calls.
pub const MIN_EIG_DIAG_EVERY: usize = 200;
/// Minimum eigenvalues at or above this value are never reported.
pub const MIN_EIG_DIAG_THRESHOLD: f64 = 1.0e-4;
41
42pub fn format_top_abs(values: &Array1<f64>, label: &str, max_items: usize) -> String {
46 if values.is_empty() {
47 return format!("{label}=<empty>");
48 }
49 let mut ranked: Vec<(usize, f64)> = values.iter().copied().enumerate().collect();
50 ranked.sort_by(|(_, left), (_, right)| {
51 right
52 .abs()
53 .partial_cmp(&left.abs())
54 .unwrap_or(std::cmp::Ordering::Equal)
55 });
56 let parts: Vec<String> = ranked
57 .into_iter()
58 .take(max_items)
59 .map(|(idx, value)| format!("{idx}:{value:.3e}"))
60 .collect();
61 format!("{label}=[{}]", parts.join(", "))
62}
63
64pub fn should_emit_h_min_eig_diag(min_eig: f64) -> bool {
67 if !min_eig.is_finite() || min_eig <= 0.0 {
68 return true;
69 }
70 if min_eig >= MIN_EIG_DIAG_THRESHOLD {
71 return false;
72 }
73 let bucket = if min_eig.is_finite() && min_eig > 0.0 {
74 min_eig.log10().floor() as i32
75 } else {
76 i32::MIN
77 };
78 let last = H_MIN_EIG_LOG_BUCKET.load(Ordering::Relaxed);
79 let count = H_MIN_EIG_LOG_COUNT.fetch_add(1, Ordering::Relaxed);
80 if bucket != last || count.is_multiple_of(MIN_EIG_DIAG_EVERY) {
81 H_MIN_EIG_LOG_BUCKET.store(bucket, Ordering::Relaxed);
82 true
83 } else {
84 false
85 }
86}
87
/// Tunable thresholds for gradient diagnostics.
///
/// The consumers of these fields are not visible in this section; the descriptions below
/// reflect the field names and defaults only.
#[derive(Debug, Clone)]
pub struct DiagnosticConfig {
    /// Tolerance for KKT residual checks (default `1e-4`).
    pub kkt_tolerance: f64,
    /// Relative-error threshold (default `0.1`); presumably used for comparisons such as
    /// gradient checks — confirm against callers.
    pub rel_error_threshold: f64,
    /// Whether warnings should be emitted (default `true`).
    pub emitwarnings: bool,
}

impl Default for DiagnosticConfig {
    /// Returns `kkt_tolerance = 1e-4`, `rel_error_threshold = 0.1`, `emitwarnings = true`.
    fn default() -> Self {
        DiagnosticConfig {
            emitwarnings: true,
            kkt_tolerance: 1.0e-4,
            rel_error_threshold: 0.1,
        }
    }
}
112
/// Result of `compute_envelopeaudit`: the values that were audited plus a verdict.
#[derive(Debug, Clone)]
pub struct EnvelopeAudit {
    /// KKT residual norm that was audited (copied from the caller's input).
    pub kkt_residual_norm: f64,
    /// Ridge the inner solver actually used.
    pub innerridge: f64,
    /// Ridge the outer gradient assumed.
    pub outerridge: f64,
    /// Set by `compute_envelopeaudit` when it detects a ridge mismatch or a KKT violation.
    pub isviolated: bool,
    /// Human-readable verdict; `Display` prints exactly this text.
    pub message: String,
}

impl fmt::Display for EnvelopeAudit {
    /// Writes `message` verbatim; width and fill flags are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}
133
/// Result of a spectral-bleed check for one penalty.
///
/// The code that builds this value is not visible here, so the field notes are hedged.
#[derive(Debug, Clone)]
pub struct SpectralBleedResult {
    /// Presumably the index of the penalty this result refers to — confirm with the producer.
    pub penalty_k: usize,
    /// Truncated energy measured by the check; its exact definition is not shown here.
    pub truncated_energy: f64,
    /// Correction that was applied; semantics to be confirmed against the producer.
    pub applied_correction: f64,
    /// Whether bleed was detected. `GradientDiagnosticReport::summary` only lists entries
    /// where this is `true`.
    pub has_bleed: bool,
    /// Human-readable description; `Display` prints exactly this text.
    pub message: String,
}

impl fmt::Display for SpectralBleedResult {
    /// Writes `message` verbatim; width and fill flags are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}
153
/// Result of `compute_dualridge_check`: the three stage ridges, derived magnitudes, and a verdict.
#[derive(Debug, Clone)]
pub struct DualRidgeResult {
    /// Ridge used by the PIRLS solve.
    pub pirlsridge: f64,
    /// Ridge used when evaluating the cost.
    pub costridge: f64,
    /// Ridge used when evaluating the gradient.
    pub gradientridge: f64,
    /// `pirlsridge * ||beta||` as computed by `compute_dualridge_check`.
    pub ridge_impact: f64,
    /// `0.5 * pirlsridge * ||beta||^2` as computed by `compute_dualridge_check`.
    pub phantom_penalty: f64,
    /// `true` when any pair of the three ridges differs by more than `1e-12`.
    pub has_mismatch: bool,
    /// Human-readable verdict; `Display` prints exactly this text.
    pub message: String,
}

impl fmt::Display for DualRidgeResult {
    /// Writes `message` verbatim; width and fill flags are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}
178
/// Residual summary produced by `diagnostics_from_predictions`.
///
/// Every residual is `observed - predicted_mean`.
#[derive(Debug, Clone, PartialEq)]
pub struct PredictionDiagnostics {
    /// Number of observations compared.
    pub n_obs: usize,
    /// Mean absolute residual.
    pub mae: f64,
    /// Root mean squared residual.
    pub rmse: f64,
    /// Mean residual (positive means predictions are low on average).
    pub bias: f64,
    /// `1 - RSS / TSS`, or `None` when R² is undefined because the observed values have no
    /// positive spread.
    pub r_squared: Option<f64>,
    /// Per-observation residuals, in input order.
    pub residuals: Vec<f64>,
}
189
190pub fn diagnostics_from_predictions(
192 observed: &[f64],
193 predicted_mean: &[f64],
194) -> Result<PredictionDiagnostics, String> {
195 if observed.is_empty() {
196 return Err("diagnostics_from_predictions requires at least one observation".to_string());
197 }
198 if observed.len() != predicted_mean.len() {
199 return Err(format!(
200 "diagnostics_from_predictions length mismatch: observed has {} values but predicted mean has {}",
201 observed.len(),
202 predicted_mean.len()
203 ));
204 }
205 if observed.iter().any(|value| !value.is_finite()) {
206 return Err("observed values must contain only finite numbers".to_string());
207 }
208 if predicted_mean.iter().any(|value| !value.is_finite()) {
209 return Err("predicted mean values must contain only finite numbers".to_string());
210 }
211
212 let n_obs = observed.len();
213 let n_obs_f = n_obs as f64;
214 let mut residuals = Vec::with_capacity(n_obs);
215 let mut abs_sum = 0.0_f64;
216 let mut residual_sum = 0.0_f64;
217 let mut residual_sum_squares = 0.0_f64;
218 let mut observed_sum = 0.0_f64;
219 for (obs, pred) in observed.iter().zip(predicted_mean.iter()) {
220 let residual = obs - pred;
221 residuals.push(residual);
222 abs_sum += residual.abs();
223 residual_sum += residual;
224 residual_sum_squares += residual * residual;
225 observed_sum += obs;
226 }
227
228 let observed_mean = observed_sum / n_obs_f;
229 let total_sum_squares = observed
230 .iter()
231 .map(|value| {
232 let centered = value - observed_mean;
233 centered * centered
234 })
235 .sum::<f64>();
236 let r_squared = if total_sum_squares > 0.0 {
237 Some(1.0 - residual_sum_squares / total_sum_squares)
238 } else {
239 None
240 };
241
242 Ok(PredictionDiagnostics {
243 n_obs,
244 mae: abs_sum / n_obs_f,
245 rmse: (residual_sum_squares / n_obs_f).sqrt(),
246 bias: residual_sum / n_obs_f,
247 r_squared,
248 residuals,
249 })
250}
251
252#[derive(Clone, Debug, Default)]
254pub struct GradientDiagnosticReport {
255 pub envelopeaudit: Option<EnvelopeAudit>,
257 pub spectral_bleed: Vec<SpectralBleedResult>,
259 pub dualridge: Option<DualRidgeResult>,
261}
262
263impl GradientDiagnosticReport {
264 pub fn new() -> Self {
266 Self::default()
267 }
268
269 pub fn summary(&self) -> String {
271 let mut lines = Vec::new();
272
273 if let Some(ref audit) = self.envelopeaudit
274 && audit.isviolated
275 {
276 lines.push(format!("[DIAG] {}", audit));
277 }
278
279 for bleed in &self.spectral_bleed {
280 if bleed.has_bleed {
281 lines.push(format!("[DIAG] {}", bleed));
282 }
283 }
284
285 if let Some(ref ridge) = self.dualridge
286 && ridge.has_mismatch
287 {
288 lines.push(format!("[DIAG] {}", ridge));
289 }
290
291 if lines.is_empty() {
292 "No gradient diagnostic issues detected.".to_string()
293 } else {
294 lines.join("\n")
295 }
296 }
297}
298
299pub fn compute_envelopeaudit(
317 kkt_residual_norm: f64,
318 referencegradient: &Array1<f64>,
319 ridge_used: f64,
320 ridge_assumed: f64,
321 beta: &Array1<f64>,
322 abs_tolerance: f64,
323 rel_tolerance: f64,
324) -> EnvelopeAudit {
325 let kkt_norm = kkt_residual_norm;
326 let penalty_norm = referencegradient.dot(referencegradient).sqrt();
327 let beta_norm = beta.dot(beta).sqrt();
328 let scale = penalty_norm.max((ridge_assumed.abs() * beta_norm).max(1e-12));
329 let rel_kkt = if scale > 0.0 { kkt_norm / scale } else { 0.0 };
330 let ridge_mismatch = (ridge_used - ridge_assumed).abs() > 1e-12;
331 let kktviolation = kkt_norm > abs_tolerance && rel_kkt > rel_tolerance;
332 let isviolated = kktviolation || ridge_mismatch;
333
334 let message = if ridge_mismatch && kktviolation {
335 format!(
336 "Envelope Violation: Inner solver ridge = {:.2e}, Outer gradient assumes ridge = {:.2e}. \
337 KKT residual norm = {:.2e} (abs tol = {:.2e}, rel tol = {:.2e}). Unaccounted gradient energy: {:.2e}",
338 ridge_used, ridge_assumed, kkt_norm, abs_tolerance, rel_tolerance, kkt_norm
339 )
340 } else if ridge_mismatch {
341 format!(
342 "Ridge Mismatch: PIRLS optimized for H + {:.2e}*I, but Gradient calculated for H + {:.2e}*I",
343 ridge_used, ridge_assumed
344 )
345 } else if kktviolation {
346 format!(
347 "Envelope Violation: KKT residual ||∇_β L|| = {:.2e} (rel {:.2e}) exceeds tolerances (abs {:.2e}, rel {:.2e}). \
348 Inner solver may not have converged to true stationary point.",
349 kkt_norm, rel_kkt, abs_tolerance, rel_tolerance
350 )
351 } else {
352 format!(
353 "Envelope OK: KKT residual = {:.2e} (rel {:.2e}), ridge match = {:.2e}",
354 kkt_norm, rel_kkt, ridge_used
355 )
356 };
357
358 EnvelopeAudit {
359 kkt_residual_norm: kkt_norm,
360 innerridge: ridge_used,
361 outerridge: ridge_assumed,
362 isviolated,
363 message,
364 }
365}
366
367pub fn compute_dualridge_check(
384 pirlsridge: f64,
385 costridge: f64,
386 gradientridge: f64,
387 beta: &Array1<f64>,
388) -> DualRidgeResult {
389 let beta_norm_sq = beta.dot(beta);
390 let beta_norm = beta_norm_sq.sqrt();
391
392 let ridge_impact = pirlsridge * beta_norm;
393 let phantom_penalty = 0.5 * pirlsridge * beta_norm_sq;
394
395 let pirlscost_mismatch = (pirlsridge - costridge).abs() > 1e-12;
396 let pirlsgrad_mismatch = (pirlsridge - gradientridge).abs() > 1e-12;
397 let costgrad_mismatch = (costridge - gradientridge).abs() > 1e-12;
398 let has_mismatch = pirlscost_mismatch || pirlsgrad_mismatch || costgrad_mismatch;
399
400 let message = if has_mismatch {
401 let mut mismatches = Vec::new();
402 if pirlscost_mismatch {
403 mismatches.push(format!(
404 "PIRLS({:.2e}) vs Cost({:.2e})",
405 pirlsridge, costridge
406 ));
407 }
408 if pirlsgrad_mismatch {
409 mismatches.push(format!(
410 "PIRLS({:.2e}) vs Gradient({:.2e})",
411 pirlsridge, gradientridge
412 ));
413 }
414 if costgrad_mismatch {
415 mismatches.push(format!(
416 "Cost({:.2e}) vs Gradient({:.2e})",
417 costridge, gradientridge
418 ));
419 }
420 format!(
421 "Ridge Mismatch detected: {}. Effective ridge impact on ||β|| = {:.2e}. \
422 Phantom penalty = {:.2e}. The surface being differentiated differs from \
423 the surface being optimized.",
424 mismatches.join(", "),
425 ridge_impact,
426 phantom_penalty
427 )
428 } else if pirlsridge > 0.0 {
429 format!(
430 "Ridge Consistency OK: All stages use ridge = {:.2e}. ||β|| = {:.2e}, phantom penalty = {:.2e}",
431 pirlsridge, beta_norm, phantom_penalty
432 )
433 } else {
434 "Ridge Consistency OK: No stabilization ridge required.".to_string()
435 };
436
437 DualRidgeResult {
438 pirlsridge,
439 costridge,
440 gradientridge,
441 ridge_impact,
442 phantom_penalty,
443 has_mismatch,
444 message,
445 }
446}
447
/// Reason code attached to a KKT-certification refusal.
///
/// Each variant has a stable snake_case code returned by [`KktRefusalDiagnosis::as_str`].
/// [`KktRefusalDiagnosis::parse_from_error`] recovers the variant from a message that
/// contains `diagnosis: <code>`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum KktRefusalDiagnosis {
    /// Code `rank_deficient_H_pen`: the named block appears to have a structural or numerical
    /// null direction (see [`KktRefusalDiagnosis::guidance`]).
    RankDeficientHPen,
    /// Code `phantom_multiplier_with_well_conditioned_H`: a near-separated or weakly identified
    /// direction despite a well-conditioned penalized Hessian.
    PhantomMultiplierWithWellConditionedH,
    /// Code `active_set_incomplete`: an active-set certification failure.
    ActiveSetIncomplete,
    /// Code `aliasing_detected_at_fit`: the named block aliases another block.
    AliasingDetectedAtFit,
}

impl KktRefusalDiagnosis {
    /// Returns the stable code string for this diagnosis.
    pub fn as_str(&self) -> &'static str {
        match self {
            KktRefusalDiagnosis::RankDeficientHPen => "rank_deficient_H_pen",
            KktRefusalDiagnosis::PhantomMultiplierWithWellConditionedH => {
                "phantom_multiplier_with_well_conditioned_H"
            }
            KktRefusalDiagnosis::ActiveSetIncomplete => "active_set_incomplete",
            KktRefusalDiagnosis::AliasingDetectedAtFit => "aliasing_detected_at_fit",
        }
    }

    /// Extracts the diagnosis from an error message.
    ///
    /// Looks for the *last* occurrence of `diagnosis: ` and reads the code that follows. A code
    /// is a run of ASCII letters, digits and `_`, so it may be followed by any other character
    /// (`;`, whitespace, `\r`, `,`, `.`, end of string, ...).
    ///
    /// Returns `None` when there is no marker or the code is unknown.
    pub fn parse_from_error(message: &str) -> Option<Self> {
        // Every variant; matching through `as_str` keeps a single source of truth for codes.
        const VARIANTS: [KktRefusalDiagnosis; 4] = [
            KktRefusalDiagnosis::RankDeficientHPen,
            KktRefusalDiagnosis::PhantomMultiplierWithWellConditionedH,
            KktRefusalDiagnosis::ActiveSetIncomplete,
            KktRefusalDiagnosis::AliasingDetectedAtFit,
        ];
        let marker = "diagnosis: ";
        let start = message.rfind(marker)? + marker.len();
        let tail = &message[start..];
        // `find` returns a char boundary, so slicing below cannot split a UTF-8 sequence.
        let end = tail
            .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
            .unwrap_or(tail.len());
        let code = &tail[..end];
        VARIANTS
            .iter()
            .copied()
            .find(|variant| variant.as_str() == code)
    }

    /// Returns a human-readable hint on what to investigate for this diagnosis.
    pub fn guidance(self) -> &'static str {
        match self {
            KktRefusalDiagnosis::RankDeficientHPen => {
                "check whether the named block has a structural or numerical null direction \
                 not identified by the likelihood/penalty combination; for Duchon-style \
                 smooths this may be a polynomial null space, while marginal-slope fits can \
                 also expose callback-owned weak directions"
            }
            KktRefusalDiagnosis::PhantomMultiplierWithWellConditionedH => {
                "check whether the named block has a near-separated or weakly identified \
                 direction despite a well-conditioned penalized Hessian; in marginal-slope \
                 fits this often indicates marginal/logslope coupling rather than a \
                 Matérn/Duchon polynomial-nullspace failure"
            }
            KktRefusalDiagnosis::ActiveSetIncomplete => {
                "check whether the named block's linear constraints need an additional \
                 active row or a tighter constrained re-solve; this is an active-set \
                 certification failure, not a polynomial-nullspace diagnosis"
            }
            KktRefusalDiagnosis::AliasingDetectedAtFit => {
                "check whether the named block aliases another block after runtime \
                 constraints or callbacks materialize; drop or reparameterize the aliased \
                 direction before fitting"
            }
        }
    }
}
533
// Unit tests for the envelope audit, dual-ridge check, and prediction diagnostics.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr1;

    // Zero residual with matching (zero) ridges: nothing should be flagged.
    #[test]
    fn test_envelopeaudit_noviolation() {
        let reference = arr1(&[0.0, 0.0, 0.0]);
        let beta = arr1(&[0.1, 0.2, 0.3]);
        let result = compute_envelopeaudit(0.0, &reference, 0.0, 0.0, &beta, 1e-8, 1e-6);

        assert!(!result.isviolated);
    }

    // The residual (1e-10) is below the absolute tolerance (1e-8), so only the ridge mismatch
    // (0.1 vs 0.0) fires, selecting the ridge-only message.
    #[test]
    fn test_envelopeaudit_detects_ridge_mismatch() {
        let reference = arr1(&[1.0, 0.0, 0.0]);
        let beta = arr1(&[0.1, 0.2, 0.3]);
        let result = compute_envelopeaudit(1e-10, &reference, 0.1, 0.0, &beta, 1e-8, 1e-6);

        assert!(result.isviolated);
        assert!(result.message.contains("Ridge Mismatch"));
    }

    // All three stage ridges equal (zero): consistent.
    #[test]
    fn test_dualridge_check_no_mismatch() {
        let beta = arr1(&[0.1, 0.2, 0.3]);
        let result = compute_dualridge_check(0.0, 0.0, 0.0, &beta);

        assert!(!result.has_mismatch);
    }

    // PIRLS ridge differs from cost/gradient ridges by 1e-4, well above the 1e-12 tolerance.
    #[test]
    fn test_dualridge_check_detects_mismatch() {
        let beta = arr1(&[0.1, 0.2, 0.3]);
        let result = compute_dualridge_check(1e-4, 0.0, 0.0, &beta);

        assert!(result.has_mismatch);
        assert!(result.message.contains("Ridge Mismatch detected"));
    }

    // Residuals are observed - predicted = [-0.5, 0.5, 1.0]; RSS = 1.5. The observed mean is
    // 7/3, so TSS = 14/3. These expectations compare floats exactly and therefore depend on
    // the exact arithmetic order inside diagnostics_from_predictions.
    #[test]
    fn diagnostics_from_predictions_computes_residual_metrics() {
        let observed = [1.0, 2.0, 4.0];
        let predicted = [1.5, 1.5, 3.0];

        let result = diagnostics_from_predictions(&observed, &predicted).unwrap();

        assert_eq!(result.residuals, vec![-0.5, 0.5, 1.0]);
        assert_eq!(result.n_obs, 3);
        assert_eq!(result.mae, 2.0 / 3.0);
        assert_eq!(result.bias, 1.0 / 3.0);
        assert_eq!(result.rmse, (1.5_f64 / 3.0).sqrt());
        assert_eq!(result.r_squared, Some(1.0 - 1.5 / (14.0 / 3.0)));
    }

    // Constant observed values have zero total sum of squares, so R² is undefined.
    #[test]
    fn diagnostics_from_predictions_omits_r_squared_for_constant_observed() {
        let observed = [2.0, 2.0];
        let predicted = [1.0, 3.0];

        let result = diagnostics_from_predictions(&observed, &predicted).unwrap();

        assert_eq!(result.r_squared, None);
    }

    // Each validation branch returns its exact error message.
    #[test]
    fn diagnostics_from_predictions_rejects_invalid_inputs() {
        assert_eq!(
            diagnostics_from_predictions(&[], &[]),
            Err("diagnostics_from_predictions requires at least one observation".to_string())
        );
        assert_eq!(
            diagnostics_from_predictions(&[1.0], &[1.0, 2.0]),
            Err(
                "diagnostics_from_predictions length mismatch: observed has 1 values but predicted mean has 2"
                    .to_string()
            )
        );
        assert_eq!(
            diagnostics_from_predictions(&[f64::NAN], &[1.0]),
            Err("observed values must contain only finite numbers".to_string())
        );
        assert_eq!(
            diagnostics_from_predictions(&[1.0], &[f64::INFINITY]),
            Err("predicted mean values must contain only finite numbers".to_string())
        );
    }
}