gam 0.3.59

Generalized penalized likelihood engine
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
//! Analytic diagnostic helpers for LAML/REML optimization.
//!
//! Production diagnostics inspect analytic invariants only. Runtime fitting,
//! prediction, and diagnostic APIs must consume quantities the optimizer
//! already computes. This module implements diagnostic strategies that identify
//! root causes of gradient pathologies from those analytic quantities:
//!
//! 1. KKT Audit (Envelope Theorem Check): Detects violations of the stationarity
//!    assumption used in implicit differentiation.
//!
//! 2. Spectral Bleed Trace: Detects when truncated eigenspace corrections are
//!    inconsistent with the penalty's energy in that subspace.
//!
//! 3. Dual-Ridge Consistency Check: Verifies that the ridge used by the inner
//!    solver (PIRLS) matches what the outer gradient calculation assumes.

use ndarray::Array1;
use std::fmt;
use std::sync::atomic::{AtomicI32, AtomicUsize, Ordering};

// =============================================================================
// Rate-Limited Diagnostic Output
// =============================================================================
// These helpers prevent diagnostic spam while ensuring important messages are seen.
// Pattern: show first occurrence, then every Nth occurrence, with count indicator.

/// Occurrence counter for the `BETA_COLLAPSE` gradient diagnostic.
///
/// The four `GRAD_DIAG_*` counters start at zero and are not read or written
/// anywhere in this module. NOTE(review): presumably gradient code elsewhere
/// increments them to rate-limit its own messages — confirm at the call sites.
pub static GRAD_DIAG_BETA_COLLAPSE_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Occurrence counter for the `DELTA_ZERO` gradient diagnostic; see
/// [`GRAD_DIAG_BETA_COLLAPSE_COUNT`].
pub static GRAD_DIAG_DELTA_ZERO_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Occurrence counter for the `LOGH_CLAMPED` gradient diagnostic; see
/// [`GRAD_DIAG_BETA_COLLAPSE_COUNT`].
pub static GRAD_DIAG_LOGH_CLAMPED_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Occurrence counter for the `KKT_SKIP` gradient diagnostic; see
/// [`GRAD_DIAG_BETA_COLLAPSE_COUNT`].
pub static GRAD_DIAG_KKT_SKIP_COUNT: AtomicUsize = AtomicUsize::new(0);

/// Decade bucket, `floor(log10(min_eig))`, of the last small Hessian minimum
/// eigenvalue for which [`should_emit_h_min_eig_diag`] returned `true`.
/// Starts at `i32::MIN`, which no real bucket equals, so the first small
/// eigenvalue is always reported.
pub static H_MIN_EIG_LOG_BUCKET: AtomicI32 = AtomicI32::new(i32::MIN);
/// Number of [`should_emit_h_min_eig_diag`] calls that reached the
/// rate-limiting step (finite eigenvalues in `(0, MIN_EIG_DIAG_THRESHOLD)`).
pub static H_MIN_EIG_LOG_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Within an unchanged bucket, one diagnostic is emitted per this many
/// rate-limited calls (the first such call included).
pub const MIN_EIG_DIAG_EVERY: usize = 200;
/// Positive finite minimum eigenvalues at or above this value are never
/// reported by [`should_emit_h_min_eig_diag`].
pub const MIN_EIG_DIAG_THRESHOLD: f64 = 1e-4;

/// Diagnostic formatter shared across the outer optimizer and the custom-family
/// fitter: shows the `max_items` entries of `values` with largest absolute
/// value, formatted as `label=[i:value, ...]`.
///
/// Entries are ranked by `|value|` in descending order using
/// [`f64::total_cmp`], which is a true total order even when `values` contains
/// NaN. NaN entries rank above every finite or infinite magnitude, so they
/// surface first. Ties keep their original index order because the sort is
/// stable. Each value is printed with `{:.3e}`.
///
/// Returns `label=<empty>` when `values` has no elements, and `label=[]` when
/// `max_items` is zero.
pub fn format_top_abs(values: &Array1<f64>, label: &str, max_items: usize) -> String {
    if values.is_empty() {
        return format!("{label}=<empty>");
    }
    let mut ranked: Vec<(usize, f64)> = values.iter().copied().enumerate().collect();
    // `partial_cmp(..).unwrap_or(Equal)` would treat NaN as equal to every
    // value, which is not a total order. `sort_by` may panic on such
    // comparators (Rust >= 1.81). `total_cmp` is total, so a NaN in a
    // diagnostic vector cannot abort the formatter.
    ranked.sort_by(|(_, left), (_, right)| right.abs().total_cmp(&left.abs()));
    let parts: Vec<String> = ranked
        .into_iter()
        .take(max_items)
        .map(|(idx, value)| format!("{idx}:{value:.3e}"))
        .collect();
    format!("{label}=[{}]", parts.join(", "))
}

/// Rate-limited check for Hessian minimum eigenvalue diagnostics.
///
/// Returns `true` when `min_eig` warrants a diagnostic message:
/// * always for non-finite or non-positive values;
/// * never for values `>= MIN_EIG_DIAG_THRESHOLD`;
/// * otherwise, either when the decade bucket `floor(log10(min_eig))` differs
///   from the last reported bucket, or on every `MIN_EIG_DIAG_EVERY`-th call
///   that reaches this rate-limited path (the first such call included).
///
/// The bucket load and the later store are separate relaxed atomic
/// operations, not one atomic step. Concurrent callers may therefore both
/// emit for the same bucket change, which is harmless for log throttling.
pub fn should_emit_h_min_eig_diag(min_eig: f64) -> bool {
    // Non-finite or non-positive eigenvalues are always worth reporting.
    if !min_eig.is_finite() || min_eig <= 0.0 {
        return true;
    }
    if min_eig >= MIN_EIG_DIAG_THRESHOLD {
        return false;
    }
    // The guards above leave only finite values in (0, threshold), so `log10`
    // is finite here and no fallback bucket is needed.
    let bucket = min_eig.log10().floor() as i32;
    let last = H_MIN_EIG_LOG_BUCKET.load(Ordering::Relaxed);
    let count = H_MIN_EIG_LOG_COUNT.fetch_add(1, Ordering::Relaxed);
    if bucket != last || count.is_multiple_of(MIN_EIG_DIAG_EVERY) {
        H_MIN_EIG_LOG_BUCKET.store(bucket, Ordering::Relaxed);
        true
    } else {
        false
    }
}

// =============================================================================
// Diagnostic Configuration and Result Types
// =============================================================================

/// Tunable thresholds and switches for gradient diagnostics.
#[derive(Clone, Debug)]
pub struct DiagnosticConfig {
    /// Tolerance on the KKT residual norm (envelope theorem violation).
    pub kkt_tolerance: f64,
    /// Relative-error threshold above which an issue is flagged.
    pub rel_error_threshold: f64,
    /// Whether warnings should be emitted to stderr.
    pub emitwarnings: bool,
}

impl Default for DiagnosticConfig {
    /// Defaults: `kkt_tolerance = 1e-4`, `rel_error_threshold = 0.1`,
    /// `emitwarnings = true`.
    fn default() -> Self {
        let kkt_tolerance = 1e-4;
        let rel_error_threshold = 0.1;
        DiagnosticConfig {
            emitwarnings: true,
            kkt_tolerance,
            rel_error_threshold,
        }
    }
}

/// Outcome of the envelope-theorem (KKT) audit.
#[derive(Clone, Debug)]
pub struct EnvelopeAudit {
    /// Norm of the inner KKT residual ∇_β L(β*, ρ).
    pub kkt_residual_norm: f64,
    /// Ridge the inner solver actually used.
    pub innerridge: f64,
    /// Ridge the outer gradient calculation assumed.
    pub outerridge: f64,
    /// `true` when the envelope theorem is considered violated.
    pub isviolated: bool,
    /// Human-readable description of the audit outcome.
    pub message: String,
}

impl fmt::Display for EnvelopeAudit {
    /// Writes the stored `message` verbatim; width/fill flags on the
    /// formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}

/// Outcome of the spectral-bleed trace for a single penalty.
#[derive(Clone, Debug)]
pub struct SpectralBleedResult {
    /// Index `k` of the penalty S_k that the quantities below refer to.
    pub penalty_k: usize,
    /// Energy of penalty S_k in the truncated subspace: trace(U_⊥' S_k U_⊥).
    pub truncated_energy: f64,
    /// Correction term that was actually applied in the gradient.
    pub applied_correction: f64,
    /// `true` when a spectral-bleed issue was detected.
    pub has_bleed: bool,
    /// Human-readable description of the trace outcome.
    pub message: String,
}

impl fmt::Display for SpectralBleedResult {
    /// Writes the stored `message` verbatim; width/fill flags on the
    /// formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.message.as_str())
    }
}

/// Outcome of the dual-ridge consistency check.
#[derive(Clone, Debug)]
pub struct DualRidgeResult {
    /// Ridge used during the P-IRLS optimization.
    pub pirlsridge: f64,
    /// Ridge used in the LAML cost function.
    pub costridge: f64,
    /// Ridge used in the gradient calculation.
    pub gradientridge: f64,
    /// Effective ridge impact: ||ridge * β||.
    pub ridge_impact: f64,
    /// Phantom penalty contribution: 0.5 * ridge * ||β||².
    pub phantom_penalty: f64,
    /// `true` when any two stages disagree on the ridge.
    pub has_mismatch: bool,
    /// Human-readable description of the check outcome.
    pub message: String,
}

impl fmt::Display for DualRidgeResult {
    /// Writes the stored `message` verbatim; width/fill flags on the
    /// formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}

/// Complete diagnostic report for a gradient evaluation
#[derive(Clone, Debug, Default)]
pub struct GradientDiagnosticReport {
    /// Envelope theorem audit results
    pub envelopeaudit: Option<EnvelopeAudit>,
    /// Spectral bleed results for each penalty
    pub spectral_bleed: Vec<SpectralBleedResult>,
    /// Dual-ridge consistency result
    pub dualridge: Option<DualRidgeResult>,
}

impl GradientDiagnosticReport {
    /// Create an empty report
    pub fn new() -> Self {
        Self::default()
    }

    /// Generate a summary string of all issues found
    pub fn summary(&self) -> String {
        let mut lines = Vec::new();

        if let Some(ref audit) = self.envelopeaudit
            && audit.isviolated
        {
            lines.push(format!("[DIAG] {}", audit));
        }

        for bleed in &self.spectral_bleed {
            if bleed.has_bleed {
                lines.push(format!("[DIAG] {}", bleed));
            }
        }

        if let Some(ref ridge) = self.dualridge
            && ridge.has_mismatch
        {
            lines.push(format!("[DIAG] {}", ridge));
        }

        if lines.is_empty() {
            "No gradient diagnostic issues detected.".to_string()
        } else {
            lines.join("\n")
        }
    }
}

// =============================================================================
// Strategy 1: Envelope Theorem (KKT) Audit
// =============================================================================

/// Audit the inner KKT residual to detect envelope theorem violations.
///
/// The analytic gradient calculation assumes that P-IRLS found an exact stationary
/// point where ∇_β L = 0. If this is not true (due to stabilization ridge, Firth
/// adjustments, or early termination), the "indirect term" of the chain rule becomes
/// significant and the gradient will be wrong.
///
/// The residual is judged against the scale
/// `max(||referencegradient||, |ridge_assumed| * ||beta||, 1e-12)`. A KKT
/// violation requires the residual to exceed BOTH `abs_tolerance` and
/// `rel_tolerance` (relative to that scale). A non-finite residual, such as NaN
/// from a diverged inner solve, is always reported as a violation. The two
/// ridges are mismatched when they differ by more than `1e-12`, or when either
/// is NaN.
///
/// # Arguments
/// * `kkt_residual_norm` - Norm of the full inner gradient ||∇_β L|| at the PIRLS solution
/// * `referencegradient` - Reference gradient scale (typically S_λ β) for relative normalization
/// * `ridge_used` - Ridge added by PIRLS for stabilization
/// * `ridge_assumed` - Ridge assumed by the outer gradient calculation
/// * `beta` - Current coefficient estimate
/// * `abs_tolerance` - Absolute threshold on the KKT residual norm
/// * `rel_tolerance` - Threshold on the residual relative to the scale above
///
/// # Returns
/// An [`EnvelopeAudit`] whose `isviolated` is set for a KKT violation, a ridge
/// mismatch, or both, together with a matching human-readable `message`.
pub fn compute_envelopeaudit(
    kkt_residual_norm: f64,
    referencegradient: &Array1<f64>,
    ridge_used: f64,
    ridge_assumed: f64,
    beta: &Array1<f64>,
    abs_tolerance: f64,
    rel_tolerance: f64,
) -> EnvelopeAudit {
    // Ridges closer than this are treated as identical.
    const RIDGE_MATCH_TOL: f64 = 1e-12;

    // `a != b` keeps exactly equal values (including equal infinities)
    // matched. The explicit NaN test is needed because `NaN > tol` is false.
    fn ridges_differ(a: f64, b: f64) -> bool {
        let diff = (a - b).abs();
        a != b && (diff.is_nan() || diff > RIDGE_MATCH_TOL)
    }

    let kkt_norm = kkt_residual_norm;
    let penalty_norm = referencegradient.dot(referencegradient).sqrt();
    let beta_norm = beta.dot(beta).sqrt();
    // `f64::max` discards a NaN operand, so `scale` is always >= 1e-12 and the
    // division below never divides by zero.
    let scale = penalty_norm.max((ridge_assumed.abs() * beta_norm).max(1e-12));
    let rel_kkt = kkt_norm / scale;
    let ridge_mismatch = ridges_differ(ridge_used, ridge_assumed);
    // Every comparison against NaN is false. Without the finiteness check, a
    // NaN residual would be reported as "Envelope OK".
    let kktviolation =
        !kkt_norm.is_finite() || (kkt_norm > abs_tolerance && rel_kkt > rel_tolerance);
    let isviolated = kktviolation || ridge_mismatch;

    let message = if ridge_mismatch && kktviolation {
        format!(
            "Envelope Violation: Inner solver ridge = {:.2e}, Outer gradient assumes ridge = {:.2e}. \
             KKT residual norm = {:.2e} (abs tol = {:.2e}, rel tol = {:.2e}). Unaccounted gradient energy: {:.2e}",
            ridge_used, ridge_assumed, kkt_norm, abs_tolerance, rel_tolerance, kkt_norm
        )
    } else if ridge_mismatch {
        format!(
            "Ridge Mismatch: PIRLS optimized for H + {:.2e}*I, but Gradient calculated for H + {:.2e}*I",
            ridge_used, ridge_assumed
        )
    } else if kktviolation {
        format!(
            "Envelope Violation: KKT residual ||∇_β L|| = {:.2e} (rel {:.2e}) exceeds tolerances (abs {:.2e}, rel {:.2e}). \
             Inner solver may not have converged to true stationary point.",
            kkt_norm, rel_kkt, abs_tolerance, rel_tolerance
        )
    } else {
        format!(
            "Envelope OK: KKT residual = {:.2e} (rel {:.2e}), ridge match = {:.2e}",
            kkt_norm, rel_kkt, ridge_used
        )
    };

    EnvelopeAudit {
        kkt_residual_norm: kkt_norm,
        innerridge: ridge_used,
        outerridge: ridge_assumed,
        isviolated,
        message,
    }
}

// =============================================================================
// Strategy 3: Dual-Ridge Consistency Check
// =============================================================================

/// Check consistency between the ridge used in different stages of computation.
///
/// When the Hessian is non-positive-definite, `ensure_positive_definitewithridge`
/// adds a stabilization ridge during P-IRLS. This ridge changes the objective
/// surface being optimized. If the gradient calculation uses a different ridge
/// value, it will point in the wrong direction.
///
/// Each pair of ridges (PIRLS/cost, PIRLS/gradient, cost/gradient) is compared
/// with an absolute tolerance of `1e-12`. A NaN ridge always counts as a
/// mismatch, so it can never be reported as "No stabilization ridge required".
///
/// # Arguments
/// * `pirlsridge` - Ridge actually used during P-IRLS iteration
/// * `costridge` - Ridge used when computing LAML cost
/// * `gradientridge` - Ridge assumed when computing analytic gradient
/// * `beta` - Current coefficient estimate
///
/// # Returns
/// A [`DualRidgeResult`] carrying:
/// * the three ridges;
/// * `ridge_impact = pirlsridge * ||β||`;
/// * `phantom_penalty = 0.5 * pirlsridge * ||β||²`;
/// * the mismatch flag;
/// * a message listing every mismatched pair.
pub fn compute_dualridge_check(
    pirlsridge: f64,
    costridge: f64,
    gradientridge: f64,
    beta: &Array1<f64>,
) -> DualRidgeResult {
    // Ridges closer than this are treated as identical.
    const RIDGE_MATCH_TOL: f64 = 1e-12;

    // `a != b` keeps exactly equal values (including equal infinities)
    // matched. The explicit NaN test is needed because `NaN > tol` is false.
    fn ridges_differ(a: f64, b: f64) -> bool {
        let diff = (a - b).abs();
        a != b && (diff.is_nan() || diff > RIDGE_MATCH_TOL)
    }

    let beta_norm_sq = beta.dot(beta);
    let beta_norm = beta_norm_sq.sqrt();

    let ridge_impact = pirlsridge * beta_norm;
    let phantom_penalty = 0.5 * pirlsridge * beta_norm_sq;

    let pirlscost_mismatch = ridges_differ(pirlsridge, costridge);
    let pirlsgrad_mismatch = ridges_differ(pirlsridge, gradientridge);
    let costgrad_mismatch = ridges_differ(costridge, gradientridge);
    let has_mismatch = pirlscost_mismatch || pirlsgrad_mismatch || costgrad_mismatch;

    let message = if has_mismatch {
        let mut mismatches = Vec::new();
        if pirlscost_mismatch {
            mismatches.push(format!(
                "PIRLS({:.2e}) vs Cost({:.2e})",
                pirlsridge, costridge
            ));
        }
        if pirlsgrad_mismatch {
            mismatches.push(format!(
                "PIRLS({:.2e}) vs Gradient({:.2e})",
                pirlsridge, gradientridge
            ));
        }
        if costgrad_mismatch {
            mismatches.push(format!(
                "Cost({:.2e}) vs Gradient({:.2e})",
                costridge, gradientridge
            ));
        }
        format!(
            "Ridge Mismatch detected: {}. Effective ridge impact on ||β|| = {:.2e}. \
             Phantom penalty = {:.2e}. The surface being differentiated differs from \
             the surface being optimized.",
            mismatches.join(", "),
            ridge_impact,
            phantom_penalty
        )
    } else if pirlsridge > 0.0 {
        format!(
            "Ridge Consistency OK: All stages use ridge = {:.2e}. ||β|| = {:.2e}, phantom penalty = {:.2e}",
            pirlsridge, beta_norm, phantom_penalty
        )
    } else {
        "Ridge Consistency OK: No stabilization ridge required.".to_string()
    };

    DualRidgeResult {
        pirlsridge,
        costridge,
        gradientridge,
        ridge_impact,
        phantom_penalty,
        has_mismatch,
        message,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr1;

    #[test]
    fn test_envelopeaudit_noviolation() {
        let reference = arr1(&[0.0, 0.0, 0.0]);
        let beta = arr1(&[0.1, 0.2, 0.3]);
        let result = compute_envelopeaudit(0.0, &reference, 0.0, 0.0, &beta, 1e-8, 1e-6);

        assert!(!result.isviolated);
    }

    #[test]
    fn test_envelopeaudit_detects_ridge_mismatch() {
        let reference = arr1(&[1.0, 0.0, 0.0]);
        let beta = arr1(&[0.1, 0.2, 0.3]);
        let result = compute_envelopeaudit(1e-10, &reference, 0.1, 0.0, &beta, 1e-8, 1e-6);

        assert!(result.isviolated);
        assert!(result.message.contains("Ridge Mismatch"));
    }

    /// With matching ridges, a residual that exceeds both tolerances is a
    /// pure KKT violation.
    #[test]
    fn test_envelopeaudit_detects_kkt_violation() {
        let reference = arr1(&[1.0, 0.0, 0.0]);
        let beta = arr1(&[0.1, 0.2, 0.3]);
        let result = compute_envelopeaudit(0.5, &reference, 0.0, 0.0, &beta, 1e-8, 1e-6);

        assert!(result.isviolated);
        assert!(result.message.starts_with("Envelope Violation: KKT residual"));
        assert_eq!(result.kkt_residual_norm, 0.5);
    }

    /// A residual above the absolute tolerance but tiny relative to the
    /// reference scale is not a violation: both thresholds must be exceeded.
    #[test]
    fn test_envelopeaudit_requires_both_tolerances() {
        let reference = arr1(&[1e6, 0.0, 0.0]);
        let beta = arr1(&[0.1, 0.2, 0.3]);
        let result = compute_envelopeaudit(1e-3, &reference, 0.0, 0.0, &beta, 1e-8, 1e-6);

        assert!(!result.isviolated);
        assert!(result.message.starts_with("Envelope OK"));
    }

    #[test]
    fn test_dualridge_check_no_mismatch() {
        let beta = arr1(&[0.1, 0.2, 0.3]);
        let result = compute_dualridge_check(0.0, 0.0, 0.0, &beta);

        assert!(!result.has_mismatch);
    }

    #[test]
    fn test_dualridge_check_detects_mismatch() {
        let beta = arr1(&[0.1, 0.2, 0.3]);
        let result = compute_dualridge_check(1e-4, 0.0, 0.0, &beta);

        assert!(result.has_mismatch);
        assert!(result.message.contains("Ridge Mismatch detected"));
    }

    /// Identical nonzero ridges are consistent, and the phantom penalty is
    /// 0.5 * ridge * ||β||² = 0.5 * 1e-4 * 0.14.
    #[test]
    fn test_dualridge_check_consistent_nonzero_ridge() {
        let beta = arr1(&[0.1, 0.2, 0.3]);
        let result = compute_dualridge_check(1e-4, 1e-4, 1e-4, &beta);

        assert!(!result.has_mismatch);
        assert!(result
            .message
            .starts_with("Ridge Consistency OK: All stages use ridge"));
        assert!((result.phantom_penalty - 7e-6).abs() < 1e-15);
    }

    #[test]
    fn test_format_top_abs_ranks_by_magnitude() {
        let values = arr1(&[1.0, -3.0, 2.0]);
        assert_eq!(format_top_abs(&values, "g", 2), "g=[1:-3.000e0, 2:2.000e0]");
    }

    #[test]
    fn test_format_top_abs_empty() {
        let values = arr1::<f64>(&[]);
        assert_eq!(format_top_abs(&values, "g", 3), "g=<empty>");
    }

    /// These inputs all return before the shared rate-limit counters are
    /// touched, so the test is independent of other tests' global state.
    #[test]
    fn test_min_eig_diag_guards() {
        assert!(should_emit_h_min_eig_diag(f64::NAN));
        assert!(should_emit_h_min_eig_diag(-1.0));
        assert!(should_emit_h_min_eig_diag(0.0));
        assert!(!should_emit_h_min_eig_diag(MIN_EIG_DIAG_THRESHOLD));
        assert!(!should_emit_h_min_eig_diag(1.0));
    }

    #[test]
    fn test_empty_report_summary() {
        let report = GradientDiagnosticReport::new();
        assert_eq!(report.summary(), "No gradient diagnostic issues detected.");
    }

    #[test]
    fn test_report_summary_lists_only_flagged_items() {
        let beta = arr1(&[0.1, 0.2, 0.3]);
        let report = GradientDiagnosticReport {
            envelopeaudit: None,
            spectral_bleed: Vec::new(),
            dualridge: Some(compute_dualridge_check(1e-4, 0.0, 0.0, &beta)),
        };
        let summary = report.summary();

        assert!(summary.starts_with("[DIAG] Ridge Mismatch detected"));
        assert_eq!(summary.lines().count(), 1);
    }
}