Skip to main content

gam_problem/
diagnostics.rs

1//! Analytic diagnostic helpers for LAML/REML optimization.
2//!
3//! Production diagnostics inspect analytic invariants only. Runtime fitting,
4//! prediction, and diagnostic APIs must consume quantities the optimizer
5//! already computes. This module implements diagnostic strategies that identify
6//! root causes of gradient pathologies from those analytic quantities:
7//!
8//! 1. KKT Audit (Envelope Theorem Check): Detects violations of the stationarity
9//!    assumption used in implicit differentiation.
10//!
11//! 2. Spectral Bleed Trace: Detects when truncated eigenspace corrections are
12//!    inconsistent with the penalty's energy in that subspace.
13//!
14//! 3. Dual-Ridge Consistency Check: Verifies that the ridge used by the inner
15//!    solver (PIRLS) matches what the outer gradient calculation assumes.
16
17use ndarray::Array1;
18use std::fmt;
19use std::sync::atomic::{AtomicI32, AtomicUsize, Ordering};
20
21// =============================================================================
22// Rate-Limited Diagnostic Output
23// =============================================================================
24// These helpers prevent diagnostic spam while ensuring important messages are seen.
25// Pattern: show first occurrence, then every Nth occurrence, with count indicator.
26
27
/// Decade bucket, `floor(log10(min_eig))`, most recently emitted by
/// `should_emit_h_min_eig_diag` for an in-range eigenvalue (finite, positive,
/// and below `MIN_EIG_DIAG_THRESHOLD`). Starts at `i32::MIN`, a value no real
/// bucket can take, so the first in-range eigenvalue always opens a new bucket.
pub static H_MIN_EIG_LOG_BUCKET: AtomicI32 = AtomicI32::new(i32::MIN);
/// Number of in-range calls (finite, positive `min_eig` below
/// `MIN_EIG_DIAG_THRESHOLD`) seen by `should_emit_h_min_eig_diag`. Calls outside
/// that range return before touching this counter. The counter is global, not
/// per bucket: an in-range call emits when it enters a new decade bucket (see
/// `H_MIN_EIG_LOG_BUCKET`) or when the count before it is a multiple of
/// `MIN_EIG_DIAG_EVERY`.
pub static H_MIN_EIG_LOG_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Emission period, measured in in-range calls, of the Hessian-minimum-eigenvalue
/// diagnostic. The in-range call with 0-based index `k` emits whenever `k` is a
/// multiple of this value, in addition to every call that enters a new bucket.
pub const MIN_EIG_DIAG_EVERY: usize = 200;
/// Threshold below which a positive Hessian minimum eigenvalue is treated as
/// nearly singular and routed through the rate-limited diagnostic; values at or
/// above it are never reported by `should_emit_h_min_eig_diag`.
pub const MIN_EIG_DIAG_THRESHOLD: f64 = 1e-4;
42
43/// Diagnostic formatter shared across the outer optimizer and the custom-family
44/// fitter: shows the `max_items` entries of `values` with largest absolute
45/// value, formatted as `label=[i:value, ...]`.
46pub fn format_top_abs(values: &Array1<f64>, label: &str, max_items: usize) -> String {
47    if values.is_empty() {
48        return format!("{label}=<empty>");
49    }
50    let mut ranked: Vec<(usize, f64)> = values.iter().copied().enumerate().collect();
51    ranked.sort_by(|(_, left), (_, right)| {
52        right
53            .abs()
54            .partial_cmp(&left.abs())
55            .unwrap_or(std::cmp::Ordering::Equal)
56    });
57    let parts: Vec<String> = ranked
58        .into_iter()
59        .take(max_items)
60        .map(|(idx, value)| format!("{idx}:{value:.3e}"))
61        .collect();
62    format!("{label}=[{}]", parts.join(", "))
63}
64
65/// Rate-limited check for Hessian minimum eigenvalue diagnostics.
66/// Returns true if this eigenvalue warrants a diagnostic message.
67pub fn should_emit_h_min_eig_diag(min_eig: f64) -> bool {
68    if !min_eig.is_finite() || min_eig <= 0.0 {
69        return true;
70    }
71    if min_eig >= MIN_EIG_DIAG_THRESHOLD {
72        return false;
73    }
74    let bucket = if min_eig.is_finite() && min_eig > 0.0 {
75        min_eig.log10().floor() as i32
76    } else {
77        i32::MIN
78    };
79    let last = H_MIN_EIG_LOG_BUCKET.load(Ordering::Relaxed);
80    let count = H_MIN_EIG_LOG_COUNT.fetch_add(1, Ordering::Relaxed);
81    if bucket != last || count.is_multiple_of(MIN_EIG_DIAG_EVERY) {
82        H_MIN_EIG_LOG_BUCKET.store(bucket, Ordering::Relaxed);
83        true
84    } else {
85        false
86    }
87}
88
89// =============================================================================
// Diagnostic Configuration and Result Types
91// =============================================================================
92
/// Tunable thresholds and switches for the gradient diagnostics.
#[derive(Clone, Debug)]
pub struct DiagnosticConfig {
    /// Tolerance on the inner KKT residual norm (envelope-theorem check).
    pub kkt_tolerance: f64,
    /// Relative-error threshold above which an issue is flagged.
    pub rel_error_threshold: f64,
    /// Whether warnings should be emitted to stderr.
    pub emitwarnings: bool,
}

impl Default for DiagnosticConfig {
    /// `kkt_tolerance = 1e-4`, `rel_error_threshold = 0.1`, warnings enabled.
    fn default() -> Self {
        DiagnosticConfig {
            emitwarnings: true,
            kkt_tolerance: 1.0e-4,
            rel_error_threshold: 1.0e-1,
        }
    }
}
113
/// Outcome of the envelope-theorem (KKT) audit of the inner solution.
#[derive(Clone, Debug)]
pub struct EnvelopeAudit {
    /// Norm of the inner KKT residual ∇_β L(β*, ρ).
    pub kkt_residual_norm: f64,
    /// Stabilization ridge the inner solver actually applied.
    pub innerridge: f64,
    /// Ridge the outer gradient calculation assumes.
    pub outerridge: f64,
    /// `true` when the audit found an envelope-theorem violation.
    pub isviolated: bool,
    /// Human-readable explanation; `Display` prints exactly this text.
    pub message: String,
}

impl fmt::Display for EnvelopeAudit {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Written verbatim: width/fill flags on the outer formatter are not applied.
        f.write_str(&self.message)
    }
}
134
/// Outcome of the spectral-bleed trace for a single penalty.
#[derive(Clone, Debug)]
pub struct SpectralBleedResult {
    /// Index `k` of the penalty `S_k` this result describes.
    pub penalty_k: usize,
    /// Energy of penalty S_k in the truncated subspace: trace(U_⊥' S_k U_⊥).
    pub truncated_energy: f64,
    /// Correction term that was actually applied in the gradient.
    pub applied_correction: f64,
    /// `true` when a spectral-bleed issue was detected.
    pub has_bleed: bool,
    /// Human-readable explanation; `Display` prints exactly this text.
    pub message: String,
}

impl fmt::Display for SpectralBleedResult {
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        out.write_str(self.message.as_str())
    }
}
154
/// Outcome of the dual-ridge consistency check across the fitting stages.
#[derive(Clone, Debug)]
pub struct DualRidgeResult {
    /// Ridge applied during P-IRLS optimization.
    pub pirlsridge: f64,
    /// Ridge used in the LAML cost function.
    pub costridge: f64,
    /// Ridge used in the gradient calculation.
    pub gradientridge: f64,
    /// Effective ridge impact: ||ridge * β||.
    pub ridge_impact: f64,
    /// Phantom penalty contribution: 0.5 * ridge * ||β||².
    pub phantom_penalty: f64,
    /// `true` when any two stages disagree on the ridge.
    pub has_mismatch: bool,
    /// Human-readable explanation; `Display` prints exactly this text.
    pub message: String,
}

impl fmt::Display for DualRidgeResult {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_fmt(format_args!("{}", self.message))
    }
}
179
/// Goodness-of-fit summary comparing observed values with predicted means.
#[derive(Debug, Clone, PartialEq)]
pub struct PredictionDiagnostics {
    /// Number of observation/prediction pairs.
    pub n_obs: usize,
    /// Mean absolute residual.
    pub mae: f64,
    /// Root mean squared residual.
    pub rmse: f64,
    /// Mean signed residual; positive when predictions undershoot on average.
    pub bias: f64,
    /// Coefficient of determination; `None` when the observed values are constant.
    pub r_squared: Option<f64>,
    /// Per-observation residuals, `observed - predicted`.
    pub residuals: Vec<f64>,
}
190
191/// Compute prediction residual diagnostics from observed values and predicted means.
192pub fn diagnostics_from_predictions(
193    observed: &[f64],
194    predicted_mean: &[f64],
195) -> Result<PredictionDiagnostics, String> {
196    if observed.is_empty() {
197        return Err("diagnostics_from_predictions requires at least one observation".to_string());
198    }
199    if observed.len() != predicted_mean.len() {
200        return Err(format!(
201            "diagnostics_from_predictions length mismatch: observed has {} values but predicted mean has {}",
202            observed.len(),
203            predicted_mean.len()
204        ));
205    }
206    if observed.iter().any(|value| !value.is_finite()) {
207        return Err("observed values must contain only finite numbers".to_string());
208    }
209    if predicted_mean.iter().any(|value| !value.is_finite()) {
210        return Err("predicted mean values must contain only finite numbers".to_string());
211    }
212
213    let n_obs = observed.len();
214    let n_obs_f = n_obs as f64;
215    let mut residuals = Vec::with_capacity(n_obs);
216    let mut abs_sum = 0.0_f64;
217    let mut residual_sum = 0.0_f64;
218    let mut residual_sum_squares = 0.0_f64;
219    let mut observed_sum = 0.0_f64;
220    for (obs, pred) in observed.iter().zip(predicted_mean.iter()) {
221        let residual = obs - pred;
222        residuals.push(residual);
223        abs_sum += residual.abs();
224        residual_sum += residual;
225        residual_sum_squares += residual * residual;
226        observed_sum += obs;
227    }
228
229    let observed_mean = observed_sum / n_obs_f;
230    let total_sum_squares = observed
231        .iter()
232        .map(|value| {
233            let centered = value - observed_mean;
234            centered * centered
235        })
236        .sum::<f64>();
237    let r_squared = if total_sum_squares > 0.0 {
238        Some(1.0 - residual_sum_squares / total_sum_squares)
239    } else {
240        None
241    };
242
243    Ok(PredictionDiagnostics {
244        n_obs,
245        mae: abs_sum / n_obs_f,
246        rmse: (residual_sum_squares / n_obs_f).sqrt(),
247        bias: residual_sum / n_obs_f,
248        r_squared,
249        residuals,
250    })
251}
252
253/// Complete diagnostic report for a gradient evaluation
254#[derive(Clone, Debug, Default)]
255pub struct GradientDiagnosticReport {
256    /// Envelope theorem audit results
257    pub envelopeaudit: Option<EnvelopeAudit>,
258    /// Spectral bleed results for each penalty
259    pub spectral_bleed: Vec<SpectralBleedResult>,
260    /// Dual-ridge consistency result
261    pub dualridge: Option<DualRidgeResult>,
262}
263
264impl GradientDiagnosticReport {
265    /// Create an empty report
266    pub fn new() -> Self {
267        Self::default()
268    }
269
270    /// Generate a summary string of all issues found
271    pub fn summary(&self) -> String {
272        let mut lines = Vec::new();
273
274        if let Some(ref audit) = self.envelopeaudit
275            && audit.isviolated
276        {
277            lines.push(format!("[DIAG] {}", audit));
278        }
279
280        for bleed in &self.spectral_bleed {
281            if bleed.has_bleed {
282                lines.push(format!("[DIAG] {}", bleed));
283            }
284        }
285
286        if let Some(ref ridge) = self.dualridge
287            && ridge.has_mismatch
288        {
289            lines.push(format!("[DIAG] {}", ridge));
290        }
291
292        if lines.is_empty() {
293            "No gradient diagnostic issues detected.".to_string()
294        } else {
295            lines.join("\n")
296        }
297    }
298}
299
300// =============================================================================
301// Strategy 1: Envelope Theorem (KKT) Audit
302// =============================================================================
303
304/// Compute the inner KKT residual to detect envelope theorem violations.
305///
306/// The analytic gradient calculation assumes that P-IRLS found an exact stationary
307/// point where ∇_β L = 0. If this is not true (due to stabilization ridge, Firth
308/// adjustments, or early termination), the "indirect term" of the chain rule becomes
309/// significant and the gradient will be wrong.
310///
311/// # Arguments
312/// * `kkt_residual_norm` - Norm of the full inner gradient ||∇_β L|| at the PIRLS solution
313/// * `referencegradient` - Reference gradient scale (typically S_λ β) for relative normalization
314/// * `ridge_used` - Ridge added by PIRLS for stabilization
315/// * `beta` - Current coefficient estimate
316/// * `tolerance` - Threshold for flagging violations
317pub fn compute_envelopeaudit(
318    kkt_residual_norm: f64,
319    referencegradient: &Array1<f64>,
320    ridge_used: f64,
321    ridge_assumed: f64,
322    beta: &Array1<f64>,
323    abs_tolerance: f64,
324    rel_tolerance: f64,
325) -> EnvelopeAudit {
326    let kkt_norm = kkt_residual_norm;
327    let penalty_norm = referencegradient.dot(referencegradient).sqrt();
328    let beta_norm = beta.dot(beta).sqrt();
329    let scale = penalty_norm.max((ridge_assumed.abs() * beta_norm).max(1e-12));
330    let rel_kkt = if scale > 0.0 { kkt_norm / scale } else { 0.0 };
331    let ridge_mismatch = (ridge_used - ridge_assumed).abs() > 1e-12;
332    let kktviolation = kkt_norm > abs_tolerance && rel_kkt > rel_tolerance;
333    let isviolated = kktviolation || ridge_mismatch;
334
335    let message = if ridge_mismatch && kktviolation {
336        format!(
337            "Envelope Violation: Inner solver ridge = {:.2e}, Outer gradient assumes ridge = {:.2e}. \
338             KKT residual norm = {:.2e} (abs tol = {:.2e}, rel tol = {:.2e}). Unaccounted gradient energy: {:.2e}",
339            ridge_used, ridge_assumed, kkt_norm, abs_tolerance, rel_tolerance, kkt_norm
340        )
341    } else if ridge_mismatch {
342        format!(
343            "Ridge Mismatch: PIRLS optimized for H + {:.2e}*I, but Gradient calculated for H + {:.2e}*I",
344            ridge_used, ridge_assumed
345        )
346    } else if kktviolation {
347        format!(
348            "Envelope Violation: KKT residual ||∇_β L|| = {:.2e} (rel {:.2e}) exceeds tolerances (abs {:.2e}, rel {:.2e}). \
349             Inner solver may not have converged to true stationary point.",
350            kkt_norm, rel_kkt, abs_tolerance, rel_tolerance
351        )
352    } else {
353        format!(
354            "Envelope OK: KKT residual = {:.2e} (rel {:.2e}), ridge match = {:.2e}",
355            kkt_norm, rel_kkt, ridge_used
356        )
357    };
358
359    EnvelopeAudit {
360        kkt_residual_norm: kkt_norm,
361        innerridge: ridge_used,
362        outerridge: ridge_assumed,
363        isviolated,
364        message,
365    }
366}
367
368// =============================================================================
// Strategy 3: Dual-Ridge Consistency Check
370// =============================================================================
371
372/// Check consistency between the ridge used in different stages of computation.
373///
374/// When the Hessian is non-positive-definite, ensure_positive_definitewithridge
375/// adds a stabilization ridge during P-IRLS. This ridge changes the objective
376/// surface being optimized. If the gradient calculation uses a different ridge
377/// value, it will point in the wrong direction.
378///
379/// # Arguments
380/// * `pirlsridge` - Ridge actually used during P-IRLS iteration
381/// * `costridge` - Ridge used when computing LAML cost
382/// * `gradientridge` - Ridge assumed when computing analytic gradient
383/// * `beta` - Current coefficient estimate
384pub fn compute_dualridge_check(
385    pirlsridge: f64,
386    costridge: f64,
387    gradientridge: f64,
388    beta: &Array1<f64>,
389) -> DualRidgeResult {
390    let beta_norm_sq = beta.dot(beta);
391    let beta_norm = beta_norm_sq.sqrt();
392
393    let ridge_impact = pirlsridge * beta_norm;
394    let phantom_penalty = 0.5 * pirlsridge * beta_norm_sq;
395
396    let pirlscost_mismatch = (pirlsridge - costridge).abs() > 1e-12;
397    let pirlsgrad_mismatch = (pirlsridge - gradientridge).abs() > 1e-12;
398    let costgrad_mismatch = (costridge - gradientridge).abs() > 1e-12;
399    let has_mismatch = pirlscost_mismatch || pirlsgrad_mismatch || costgrad_mismatch;
400
401    let message = if has_mismatch {
402        let mut mismatches = Vec::new();
403        if pirlscost_mismatch {
404            mismatches.push(format!(
405                "PIRLS({:.2e}) vs Cost({:.2e})",
406                pirlsridge, costridge
407            ));
408        }
409        if pirlsgrad_mismatch {
410            mismatches.push(format!(
411                "PIRLS({:.2e}) vs Gradient({:.2e})",
412                pirlsridge, gradientridge
413            ));
414        }
415        if costgrad_mismatch {
416            mismatches.push(format!(
417                "Cost({:.2e}) vs Gradient({:.2e})",
418                costridge, gradientridge
419            ));
420        }
421        format!(
422            "Ridge Mismatch detected: {}. Effective ridge impact on ||β|| = {:.2e}. \
423             Phantom penalty = {:.2e}. The surface being differentiated differs from \
424             the surface being optimized.",
425            mismatches.join(", "),
426            ridge_impact,
427            phantom_penalty
428        )
429    } else if pirlsridge > 0.0 {
430        format!(
431            "Ridge Consistency OK: All stages use ridge = {:.2e}. ||β|| = {:.2e}, phantom penalty = {:.2e}",
432            pirlsridge, beta_norm, phantom_penalty
433        )
434    } else {
435        "Ridge Consistency OK: No stabilization ridge required.".to_string()
436    };
437
438    DualRidgeResult {
439        pirlsridge,
440        costridge,
441        gradientridge,
442        ridge_impact,
443        phantom_penalty,
444        has_mismatch,
445        message,
446    }
447}
448
/// Classification of why the KKT certificate refused, computed from the
/// H_pen spectrum and the projected residual at the refusing iterate. There
/// are four outcomes. Each one has a stable snake_case label (`as_str`) that
/// is embedded after `diagnosis: ` in bubbled error strings and recovered by
/// `parse_from_error`.
///
/// `RankDeficientHPen` is the regression canary the nullspace lead's
/// smooth-construction rework is intended to eliminate; keep this variant
/// intact when extending — it doubles as the user-facing signal for
/// "an unconstrained polynomial null space slipped past absorption."
///
/// Relocated from `gam-solve`'s `custom_family/joint_newton.rs` (issue #1521
/// crate carve): this is the neutral diagnostic carrier that `gam-solve`'s
/// REML/PIRLS core consumes when classifying a custom-family cert refusal,
/// so it must live BELOW both the core and the (extracted) custom-family
/// subsystem.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum KktRefusalDiagnosis {
    /// The penalized Hessian `H_pen` has a structural or numerical null
    /// direction that the likelihood/penalty combination does not identify
    /// (label `rank_deficient_H_pen`).
    RankDeficientHPen,
    /// The certificate refused even though the penalized Hessian is well
    /// conditioned; the guidance points at a near-separated or weakly
    /// identified direction (label `phantom_multiplier_with_well_conditioned_H`).
    PhantomMultiplierWithWellConditionedH,
    /// Active-set certification failure: the block's linear constraints may
    /// need an additional active row or a tighter constrained re-solve
    /// (label `active_set_incomplete`).
    ActiveSetIncomplete,
    /// Cross-block identifiability aliasing surfaced mid-inner-solve
    /// (e.g., a binding active set materialised a 2-way alias that
    /// the pre-fit audit could not see at the cold design). The fix
    /// is structural — drop or reparameterise the aliased block;
    /// rho-anneal will not recover (label `aliasing_detected_at_fit`).
    AliasingDetectedAtFit,
}
473
474impl KktRefusalDiagnosis {
475    pub fn as_str(&self) -> &'static str {
476        match self {
477            KktRefusalDiagnosis::RankDeficientHPen => "rank_deficient_H_pen",
478            KktRefusalDiagnosis::PhantomMultiplierWithWellConditionedH => {
479                "phantom_multiplier_with_well_conditioned_H"
480            }
481            KktRefusalDiagnosis::ActiveSetIncomplete => "active_set_incomplete",
482            KktRefusalDiagnosis::AliasingDetectedAtFit => "aliasing_detected_at_fit",
483        }
484    }
485
486    /// Parse the textual `diagnosis:` field embedded in the structured
487    /// bubbled error string. Returns `None` when no recognised label is
488    /// present (legacy / non-cert-refusal error strings).
489    pub fn parse_from_error(message: &str) -> Option<Self> {
490        let marker = "diagnosis: ";
491        let start = message.rfind(marker)? + marker.len();
492        let tail = &message[start..];
493        let end = tail
494            .find(|c: char| c == ';' || c == '\n' || c == ' ')
495            .unwrap_or(tail.len());
496        match &tail[..end] {
497            "rank_deficient_H_pen" => Some(KktRefusalDiagnosis::RankDeficientHPen),
498            "phantom_multiplier_with_well_conditioned_H" => {
499                Some(KktRefusalDiagnosis::PhantomMultiplierWithWellConditionedH)
500            }
501            "active_set_incomplete" => Some(KktRefusalDiagnosis::ActiveSetIncomplete),
502            "aliasing_detected_at_fit" => Some(KktRefusalDiagnosis::AliasingDetectedAtFit),
503            _ => None,
504        }
505    }
506
507    pub fn guidance(self) -> &'static str {
508        match self {
509            KktRefusalDiagnosis::RankDeficientHPen => {
510                "check whether the named block has a structural or numerical null direction \
511                 not identified by the likelihood/penalty combination; for Duchon-style \
512                 smooths this may be a polynomial null space, while marginal-slope fits can \
513                 also expose callback-owned weak directions"
514            }
515            KktRefusalDiagnosis::PhantomMultiplierWithWellConditionedH => {
516                "check whether the named block has a near-separated or weakly identified \
517                 direction despite a well-conditioned penalized Hessian; in marginal-slope \
518                 fits this often indicates marginal/logslope coupling rather than a \
519                 Matérn/Duchon polynomial-nullspace failure"
520            }
521            KktRefusalDiagnosis::ActiveSetIncomplete => {
522                "check whether the named block's linear constraints need an additional \
523                 active row or a tighter constrained re-solve; this is an active-set \
524                 certification failure, not a polynomial-nullspace diagnosis"
525            }
526            KktRefusalDiagnosis::AliasingDetectedAtFit => {
527                "check whether the named block aliases another block after runtime \
528                 constraints or callbacks materialize; drop or reparameterize the aliased \
529                 direction before fitting"
530            }
531        }
532    }
533}
534
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr1;

    /// Coefficient vector shared by the envelope-audit and dual-ridge tests.
    fn beta_fixture() -> ndarray::Array1<f64> {
        arr1(&[0.1, 0.2, 0.3])
    }

    /// Assert that `diagnostics_from_predictions` rejects the pair with exactly `expected`.
    fn assert_rejected(observed: &[f64], predicted: &[f64], expected: &str) {
        assert_eq!(
            diagnostics_from_predictions(observed, predicted),
            Err(expected.to_string())
        );
    }

    #[test]
    fn test_envelopeaudit_noviolation() {
        let zero_reference = arr1(&[0.0, 0.0, 0.0]);
        let audit =
            compute_envelopeaudit(0.0, &zero_reference, 0.0, 0.0, &beta_fixture(), 1e-8, 1e-6);

        assert!(!audit.isviolated, "unexpected violation: {}", audit.message);
    }

    #[test]
    fn test_envelopeaudit_detects_ridge_mismatch() {
        let unit_reference = arr1(&[1.0, 0.0, 0.0]);
        let audit =
            compute_envelopeaudit(1e-10, &unit_reference, 0.1, 0.0, &beta_fixture(), 1e-8, 1e-6);

        assert!(audit.isviolated);
        assert!(
            audit.message.contains("Ridge Mismatch"),
            "message was: {}",
            audit.message
        );
    }

    #[test]
    fn test_dualridge_check_no_mismatch() {
        let check = compute_dualridge_check(0.0, 0.0, 0.0, &beta_fixture());

        assert!(!check.has_mismatch, "unexpected mismatch: {}", check.message);
    }

    #[test]
    fn test_dualridge_check_detects_mismatch() {
        let check = compute_dualridge_check(1e-4, 0.0, 0.0, &beta_fixture());

        assert!(check.has_mismatch);
        assert!(
            check.message.contains("Ridge Mismatch detected"),
            "message was: {}",
            check.message
        );
    }

    #[test]
    fn diagnostics_from_predictions_computes_residual_metrics() {
        let diag = diagnostics_from_predictions(&[1.0, 2.0, 4.0], &[1.5, 1.5, 3.0]).unwrap();

        assert_eq!(diag.residuals, vec![-0.5, 0.5, 1.0]);
        assert_eq!(diag.n_obs, 3);
        assert_eq!(diag.mae, 2.0 / 3.0);
        assert_eq!(diag.bias, 1.0 / 3.0);
        assert_eq!(diag.rmse, (1.5_f64 / 3.0).sqrt());
        assert_eq!(diag.r_squared, Some(1.0 - 1.5 / (14.0 / 3.0)));
    }

    #[test]
    fn diagnostics_from_predictions_omits_r_squared_for_constant_observed() {
        let diag = diagnostics_from_predictions(&[2.0, 2.0], &[1.0, 3.0]).unwrap();

        assert_eq!(diag.r_squared, None);
    }

    #[test]
    fn diagnostics_from_predictions_rejects_invalid_inputs() {
        assert_rejected(
            &[],
            &[],
            "diagnostics_from_predictions requires at least one observation",
        );
        assert_rejected(
            &[1.0],
            &[1.0, 2.0],
            "diagnostics_from_predictions length mismatch: observed has 1 values but predicted mean has 2",
        );
        assert_rejected(
            &[f64::NAN],
            &[1.0],
            "observed values must contain only finite numbers",
        );
        assert_rejected(
            &[1.0],
            &[f64::INFINITY],
            "predicted mean values must contain only finite numbers",
        );
    }
}