Skip to main content

pounce_studio_core/
analysis.rs

//! Derived series and diagnostics over a [`SolveReport`].
//!
//! Mirrors the Python analysis helpers in
//! `studio/mcp/pounce_studio_mcp/reports.py` so the desktop / VS Code shells
//! and the MCP server agree on the same notion of "stall window",
//! "restoration window", and "common failure modes". Where a heuristic is
//! tunable (e.g. [`find_stalls_with`]), its defaults are the ones the Python
//! `diagnose` tool ships with; [`diagnose`] itself uses fixed thresholds.
9
10use serde::{Deserialize, Serialize};
11
12use crate::report::{Error, IterRecord, SolveReport};
13
14/// Compact view-model derived from a [`SolveReport`]. Suitable as the
15/// "headline summary" that an LLM or dashboard reads first.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Summary {
18    pub schema: String,
19    pub result_id: String,
20    pub solver: String,
21    pub solver_version: String,
22    pub elapsed_seconds: f64,
23    pub n_variables: i32,
24    pub n_constraints: i32,
25    pub status: String,
26    pub final_objective: f64,
27    pub iteration_count: i32,
28    pub final_kkt_error: f64,
29    pub final_dual_inf: f64,
30    pub final_constr_viol: f64,
31    pub final_compl: f64,
32    pub restoration_calls: i32,
33    pub restoration_outer_iters: i32,
34    pub restoration_wall_secs: f64,
35    pub iterations_captured: usize,
36}
37
38pub fn summarize(report: &SolveReport) -> Summary {
39    Summary {
40        schema: report.schema.clone(),
41        result_id: report.fair_metadata.result_id.clone(),
42        solver: report.fair_metadata.solver.name.clone(),
43        solver_version: report.fair_metadata.solver.version.clone(),
44        elapsed_seconds: report.fair_metadata.elapsed_seconds,
45        n_variables: report.problem.n_variables,
46        n_constraints: report.problem.n_constraints,
47        status: report.solution.status.clone(),
48        final_objective: report.statistics.final_objective,
49        iteration_count: report.statistics.iteration_count,
50        final_kkt_error: report.statistics.final_kkt_error,
51        final_dual_inf: report.statistics.final_dual_inf,
52        final_constr_viol: report.statistics.final_constr_viol,
53        final_compl: report.statistics.final_compl,
54        restoration_calls: report.statistics.restoration_calls,
55        restoration_outer_iters: report.statistics.restoration_outer_iters,
56        restoration_wall_secs: report.statistics.restoration_wall_secs,
57        iterations_captured: report.iterations.len(),
58    }
59}
60
61/// Per-iteration trajectory in column-oriented form. More compact than
62/// a `Vec<IterRecord>` when serialised, since the column names are
63/// emitted once.
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct ConvergenceTrace {
66    pub iter: Vec<i32>,
67    pub objective: Vec<f64>,
68    pub inf_pr: Vec<f64>,
69    pub inf_du: Vec<f64>,
70    pub mu: Vec<f64>,
71    pub d_norm: Vec<f64>,
72    pub regularization: Vec<f64>,
73    pub alpha_dual: Vec<f64>,
74    pub alpha_primal: Vec<f64>,
75    pub alpha_primal_char: Vec<char>,
76    pub ls_trials: Vec<i32>,
77}
78
79pub fn convergence_trace(report: &SolveReport) -> ConvergenceTrace {
80    let n = report.iterations.len();
81    let mut t = ConvergenceTrace {
82        iter: Vec::with_capacity(n),
83        objective: Vec::with_capacity(n),
84        inf_pr: Vec::with_capacity(n),
85        inf_du: Vec::with_capacity(n),
86        mu: Vec::with_capacity(n),
87        d_norm: Vec::with_capacity(n),
88        regularization: Vec::with_capacity(n),
89        alpha_dual: Vec::with_capacity(n),
90        alpha_primal: Vec::with_capacity(n),
91        alpha_primal_char: Vec::with_capacity(n),
92        ls_trials: Vec::with_capacity(n),
93    };
94    for r in &report.iterations {
95        t.iter.push(r.iter);
96        t.objective.push(r.objective);
97        t.inf_pr.push(r.inf_pr);
98        t.inf_du.push(r.inf_du);
99        t.mu.push(r.mu);
100        t.d_norm.push(r.d_norm);
101        t.regularization.push(r.regularization);
102        t.alpha_dual.push(r.alpha_dual);
103        t.alpha_primal.push(r.alpha_primal);
104        t.alpha_primal_char.push(r.alpha_primal_char);
105        t.ls_trials.push(r.ls_trials);
106    }
107    t
108}
109
110/// One stalled-progress window: consecutive iterations whose
111/// log10-residual moved by less than the configured threshold.
112#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
113pub struct Stall {
114    pub start_iter: i32,
115    pub end_iter: i32,
116    pub metric: &'static str,
117    pub delta_log10: f64,
118}
119
120/// Default stall detection: 5+ consecutive iters with <0.3 orders of
121/// magnitude movement in either `inf_pr` or `inf_du`.
122pub fn find_stalls(report: &SolveReport) -> Vec<Stall> {
123    find_stalls_with(report, 5, 0.3)
124}
125
126pub fn find_stalls_with(
127    report: &SolveReport,
128    min_window: usize,
129    max_log10_progress: f64,
130) -> Vec<Stall> {
131    let mut out = Vec::new();
132    for (metric, series) in [
133        ("inf_pr", series_log10(&report.iterations, |r| r.inf_pr)),
134        ("inf_du", series_log10(&report.iterations, |r| r.inf_du)),
135    ] {
136        scan_stalls(
137            &series,
138            &report.iterations,
139            metric,
140            min_window,
141            max_log10_progress,
142            &mut out,
143        );
144    }
145    out
146}
147
148fn series_log10<F: Fn(&IterRecord) -> f64>(iters: &[IterRecord], f: F) -> Vec<Option<f64>> {
149    iters
150        .iter()
151        .map(|r| {
152            let v = f(r);
153            if v > 0.0 && v.is_finite() {
154                Some(v.log10())
155            } else {
156                None
157            }
158        })
159        .collect()
160}
161
162fn scan_stalls(
163    series: &[Option<f64>],
164    iters: &[IterRecord],
165    metric: &'static str,
166    min_window: usize,
167    max_log10_progress: f64,
168    out: &mut Vec<Stall>,
169) {
170    let mut i = 0;
171    let n = series.len();
172    while i < n {
173        if series[i].is_none() {
174            i += 1;
175            continue;
176        }
177        // Greedy: extend j while [i..=j] remains a stall.
178        let mut j = i;
179        let mut win_min = series[i].unwrap_or(0.0);
180        let mut win_max = win_min;
181        while j + 1 < n {
182            let Some(next) = series[j + 1] else {
183                break;
184            };
185            let new_min = win_min.min(next);
186            let new_max = win_max.max(next);
187            if new_max - new_min > max_log10_progress {
188                break;
189            }
190            win_min = new_min;
191            win_max = new_max;
192            j += 1;
193        }
194        if j - i + 1 >= min_window {
195            out.push(Stall {
196                start_iter: iters[i].iter,
197                end_iter: iters[j].iter,
198                metric,
199                delta_log10: win_max - win_min,
200            });
201            i = j + 1;
202        } else {
203            i += 1;
204        }
205    }
206}
207
208/// Contiguous runs of iters tagged `'r'` in the alpha-primal char
209/// column — one entry per restoration entry → exit cycle.
210#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
211pub struct RestorationWindow {
212    pub start_iter: i32,
213    pub end_iter: i32,
214}
215
216pub fn restoration_windows(report: &SolveReport) -> Vec<RestorationWindow> {
217    let mut out: Vec<RestorationWindow> = Vec::new();
218    let mut current: Option<RestorationWindow> = None;
219    for r in &report.iterations {
220        if r.alpha_primal_char.to_ascii_lowercase() == 'r' {
221            match &mut current {
222                Some(w) => w.end_iter = r.iter,
223                None => {
224                    current = Some(RestorationWindow {
225                        start_iter: r.iter,
226                        end_iter: r.iter,
227                    })
228                }
229            }
230        } else if let Some(w) = current.take() {
231            out.push(w);
232        }
233    }
234    if let Some(w) = current {
235        out.push(w);
236    }
237    out
238}
239
240#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
241#[serde(rename_all = "lowercase")]
242pub enum Severity {
243    Info,
244    Warning,
245    Error,
246}
247
248/// One finding from [`diagnose`].
249#[derive(Debug, Clone, Serialize, Deserialize)]
250pub struct Finding {
251    pub severity: Severity,
252    /// Stable machine-readable identifier (e.g. `"max_iter_exceeded"`).
253    pub code: &'static str,
254    pub message: String,
255}
256
257/// Run common Ipopt-failure heuristics and return all findings.
258///
259/// Heuristics:
260/// - `converged` (info): solver succeeded
261/// - `max_iter_exceeded` (error): hit max_iter without converging
262/// - `restoration_used` (warning): restoration phase entered ≥1 times
263/// - `restoration_loop` (warning): multiple restoration entries
264/// - `mu_stuck` (warning): barrier parameter barely decreased
265/// - `heavy_line_search` (warning): backtracking ≥10 trials
266/// - `hessian_regularized` (info): δ_w applied on any iter
267/// - `convergence_stall` (warning): suppressed on clean convergence
268///   unless the stall window is long (≥8 iters)
269pub fn diagnose(report: &SolveReport) -> Vec<Finding> {
270    let mut findings = Vec::new();
271    let stats = &report.statistics;
272    let solution = &report.solution;
273    let iters = &report.iterations;
274    let status = solution.status.as_str();
275
276    if status == "SolveSucceeded" {
277        findings.push(Finding {
278            severity: Severity::Info,
279            code: "converged",
280            message: format!(
281                "Solver converged in {} iterations to objective {:.6e}; KKT error {:.2e}.",
282                stats.iteration_count, stats.final_objective, stats.final_kkt_error,
283            ),
284        });
285    } else if status == "MaximumIterationsExceeded" {
286        findings.push(Finding {
287            severity: Severity::Error,
288            code: "max_iter_exceeded",
289            message: format!(
290                "Hit max_iter without converging. KKT error at termination: {:.2e}. \
291                 Consider raising max_iter, tightening initial guess, or relaxing tol.",
292                stats.final_kkt_error,
293            ),
294        });
295    }
296
297    if stats.restoration_calls > 0 {
298        findings.push(Finding {
299            severity: Severity::Warning,
300            code: "restoration_used",
301            message: format!(
302                "Restoration phase entered {} time(s); {} outer iters spent in \
303                 restoration ({:.3}s). Indicates the line search couldn't make \
304                 progress on the original problem.",
305                stats.restoration_calls, stats.restoration_outer_iters, stats.restoration_wall_secs,
306            ),
307        });
308    }
309
310    if iters.len() >= 10 {
311        let mu_first = iters[..3].iter().map(|r| r.mu).fold(0.0_f64, f64::max);
312        let mu_last = iters[iters.len() - 3..]
313            .iter()
314            .map(|r| r.mu)
315            .fold(f64::INFINITY, f64::min);
316        if mu_first > 0.0 && mu_last > 0.0 {
317            let log_drop = mu_first.log10() - mu_last.log10();
318            if log_drop < 1.0 {
319                findings.push(Finding {
320                    severity: Severity::Warning,
321                    code: "mu_stuck",
322                    message: format!(
323                        "Barrier parameter μ dropped only {log_drop:.2} orders of magnitude across \
324                         {} iterations (from {mu_first:.2e} to {mu_last:.2e}). Try \
325                         mu_strategy=adaptive or a smaller mu_init.",
326                        iters.len(),
327                    ),
328                });
329            }
330        }
331    }
332
333    let heavy_ls: Vec<&IterRecord> = iters.iter().filter(|r| r.ls_trials >= 10).collect();
334    if let Some(worst) = heavy_ls.iter().max_by_key(|r| r.ls_trials) {
335        findings.push(Finding {
336            severity: Severity::Warning,
337            code: "heavy_line_search",
338            message: format!(
339                "{} iteration(s) needed >=10 backtracking trials (worst: iter {} with {} \
340                 trials). Search direction quality may be poor — check Hessian accuracy.",
341                heavy_ls.len(),
342                worst.iter,
343                worst.ls_trials,
344            ),
345        });
346    }
347
348    let big_reg: Vec<f64> = iters
349        .iter()
350        .map(|r| r.regularization)
351        .filter(|&r| r > 1e-4)
352        .collect();
353    if !big_reg.is_empty() {
354        let max_reg = big_reg.iter().copied().fold(0.0_f64, f64::max);
355        findings.push(Finding {
356            severity: Severity::Info,
357            code: "hessian_regularized",
358            message: format!(
359                "Hessian regularization applied on {} iteration(s) (max δ_w = {max_reg:.2e}). \
360                 The KKT system was indefinite; this is normal near saddle points but \
361                 persistent regularization suggests a problematic Hessian.",
362                big_reg.len(),
363            ),
364        });
365    }
366
367    let rwins = restoration_windows(report);
368    if rwins.len() > 1 {
369        findings.push(Finding {
370            severity: Severity::Warning,
371            code: "restoration_loop",
372            message: format!(
373                "Restoration was entered {} separate times. Repeated re-entry often means \
374                 the problem is infeasible at the working point. Verify constraints.",
375                rwins.len(),
376            ),
377        });
378    }
379
380    let stalls = find_stalls(report);
381    if !stalls.is_empty() {
382        let longest = stalls
383            .iter()
384            .map(|s| (s.end_iter - s.start_iter + 1) as usize)
385            .max()
386            .unwrap_or(0);
387        if status != "SolveSucceeded" || longest >= 8 {
388            findings.push(Finding {
389                severity: Severity::Warning,
390                code: "convergence_stall",
391                message: format!(
392                    "Detected {} stall window(s) where log-residual barely moved (longest: {} \
393                     iters). Either the problem is ill-conditioned, scaling is off, or \
394                     termination tolerance is too tight.",
395                    stalls.len(),
396                    longest,
397                ),
398            });
399        }
400    }
401
402    findings
403}
404
405/// Augmented [`IterRecord`] returned by [`get_iterate`]: the raw row
406/// plus derived log10 values handy for tooltip / LLM rendering.
407#[derive(Debug, Clone, Serialize, Deserialize)]
408pub struct AugmentedIterate {
409    #[serde(flatten)]
410    pub raw: IterRecord,
411    pub log10_inf_pr: Option<f64>,
412    pub log10_inf_du: Option<f64>,
413    pub log10_mu: Option<f64>,
414}
415
416pub fn get_iterate(report: &SolveReport, k: usize) -> Result<AugmentedIterate, Error> {
417    let n = report.iterations.len();
418    if n == 0 {
419        return Err(Error::NoIterations);
420    }
421    if k >= n {
422        return Err(Error::IterOutOfRange { k, n });
423    }
424    let raw = report.iterations[k].clone();
425    Ok(AugmentedIterate {
426        log10_inf_pr: safe_log10(raw.inf_pr),
427        log10_inf_du: safe_log10(raw.inf_du),
428        log10_mu: safe_log10(raw.mu),
429        raw,
430    })
431}
432
/// log10 of `x`, or `None` when `x` is zero, negative, NaN or infinite.
///
/// In other words, a value is returned only when the logarithm is itself finite.
fn safe_log10(x: f64) -> Option<f64> {
    (x.is_finite() && x > 0.0).then(|| x.log10())
}
440
441/// One row in a side-by-side comparison.
442#[derive(Debug, Clone, Serialize, Deserialize)]
443pub struct CompareRow {
444    pub label: String,
445    pub status: String,
446    pub iter_count: i32,
447    pub final_objective: f64,
448    pub final_kkt_error: f64,
449    pub restoration_calls: i32,
450    pub elapsed_seconds: f64,
451}
452
453pub fn compare_runs<'a, I>(runs: I) -> Vec<CompareRow>
454where
455    I: IntoIterator<Item = (&'a str, &'a SolveReport)>,
456{
457    runs.into_iter()
458        .map(|(label, r)| CompareRow {
459            label: label.to_string(),
460            status: r.solution.status.clone(),
461            iter_count: r.statistics.iteration_count,
462            final_objective: r.statistics.final_objective,
463            final_kkt_error: r.statistics.final_kkt_error,
464            restoration_calls: r.statistics.restoration_calls,
465            elapsed_seconds: r.fair_metadata.elapsed_seconds,
466        })
467        .collect()
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use crate::report::IterRecord;

    /// Minimal record: only `iter`, `mu`, `inf_du` and a non-restoration
    /// alpha-primal marker are set; everything else is `IterRecord::default()`.
    fn iter(idx: i32, mu: f64, inf_du: f64) -> IterRecord {
        IterRecord {
            iter: idx,
            inf_du,
            mu,
            alpha_primal_char: 'f',
            ..IterRecord::default()
        }
    }

    /// Wrap `iters` in a successful single-variable report with zeroed statistics.
    fn report_with(iters: Vec<IterRecord>) -> SolveReport {
        use crate::report::*;
        SolveReport {
            schema: SOLVE_REPORT_SCHEMA.into(),
            fair_metadata: FairMetadata {
                result_id: "t".into(),
                created_at_iso: "2026-05-24T00:00:00.000Z".into(),
                created_at_unix_nanos: 0,
                elapsed_seconds: 0.0,
                solver: SolverIdentity {
                    name: "pounce".into(),
                    version: "0.0.0".into(),
                    git_commit: None,
                    target_triple: "test".into(),
                },
                license: "EPL-2.0".into(),
                input: InputDescriptor::TnlpDirect,
            },
            problem: ProblemInfo {
                n_variables: 1,
                n_constraints: 0,
                n_objectives: 1,
                minimize: true,
                nnz_jac_g: None,
                nnz_h_lag: None,
            },
            solution: SolutionInfo {
                status: "SolveSucceeded".into(),
                solve_result_num: 0,
                objective: 0.0,
                x: vec![],
                lambda: vec![],
                suffixes: vec![],
            },
            statistics: StatisticsInfo {
                iteration_count: iters.len() as i32,
                final_objective: 0.0,
                final_scaled_objective: 0.0,
                final_dual_inf: 0.0,
                final_constr_viol: 0.0,
                final_compl: 0.0,
                final_kkt_error: 0.0,
                num_obj_evals: 0,
                num_constr_evals: 0,
                num_obj_grad_evals: 0,
                num_constr_jac_evals: 0,
                num_hess_evals: 0,
                total_wallclock_time_secs: 0.0,
                restoration_calls: 0,
                restoration_inner_iters: 0,
                restoration_outer_iters: 0,
                restoration_wall_secs: 0.0,
            },
            iterations: iters,
            linear_solver: None,
        }
    }

    #[test]
    fn stall_detection_flat_residual() {
        // 5 iters where inf_du barely moves -> one stall.
        let iters = (0..5)
            .map(|i| iter(i, 0.1, 1e-3 + (i as f64) * 1e-6))
            .collect();
        let stalls = find_stalls(&report_with(iters));
        assert_eq!(stalls.len(), 1);
        assert_eq!(stalls[0].start_iter, 0);
        assert_eq!(stalls[0].end_iter, 4);
    }

    #[test]
    fn stall_detection_progress_not_flagged() {
        // 5 iters where inf_du drops by orders of magnitude each step.
        let iters = (0..5).map(|i| iter(i, 0.1, 10f64.powi(-i))).collect();
        let stalls = find_stalls(&report_with(iters));
        assert!(stalls.is_empty(), "got {stalls:?}");
    }

    #[test]
    fn stall_detection_respects_min_window() {
        // 4 flat iters: too short for the default 5-iter window, but enough
        // for an explicit 4-iter window.
        let report = report_with((0..4).map(|i| iter(i, 0.1, 1e-3)).collect());
        assert!(find_stalls(&report).is_empty());
        let stalls = find_stalls_with(&report, 4, 0.3);
        assert_eq!(stalls.len(), 1);
        assert_eq!(stalls[0].metric, "inf_du");
        assert_eq!((stalls[0].start_iter, stalls[0].end_iter), (0, 3));
    }

    #[test]
    fn restoration_windows_grouped() {
        let mut iters = vec![iter(0, 0.1, 1e-2), iter(1, 0.1, 1e-3)];
        for i in 2..5 {
            let mut r = iter(i, 0.1, 1e-3);
            r.alpha_primal_char = 'r';
            iters.push(r);
        }
        iters.push(iter(5, 0.1, 1e-4));
        let windows = restoration_windows(&report_with(iters));
        assert_eq!(windows.len(), 1);
        assert_eq!(windows[0].start_iter, 2);
        assert_eq!(windows[0].end_iter, 4);
    }

    #[test]
    fn restoration_windows_split_and_trailing() {
        // Two separate episodes, the second still open when the report ends;
        // an uppercase 'R' marker counts as restoration too.
        let tags = ['f', 'R', 'r', 'f', 'f', 'r'];
        let iters = tags
            .into_iter()
            .zip(0..)
            .map(|(c, i)| {
                let mut r = iter(i, 0.1, 1e-3);
                r.alpha_primal_char = c;
                r
            })
            .collect();
        let windows = restoration_windows(&report_with(iters));
        assert_eq!(
            windows,
            vec![
                RestorationWindow { start_iter: 1, end_iter: 2 },
                RestorationWindow { start_iter: 5, end_iter: 5 },
            ]
        );
    }

    #[test]
    fn get_iterate_out_of_range() {
        let report = report_with(vec![iter(0, 0.1, 1e-3)]);
        assert!(matches!(
            get_iterate(&report, 5),
            Err(Error::IterOutOfRange { k: 5, n: 1 }),
        ));
    }

    #[test]
    fn get_iterate_empty_report() {
        let report = report_with(vec![]);
        assert!(matches!(get_iterate(&report, 0), Err(Error::NoIterations)));
    }

    #[test]
    fn get_iterate_attaches_log10_views() {
        let report = report_with(vec![iter(0, 0.1, 1e-3)]);
        let Ok(it) = get_iterate(&report, 0) else {
            panic!("iterate 0 should exist");
        };
        assert_eq!(it.raw.iter, 0);
        assert!((it.log10_inf_du.unwrap() + 3.0).abs() < 1e-12);
        assert!((it.log10_mu.unwrap() + 1.0).abs() < 1e-12);
        // The default `inf_pr` has no finite log (the stall tests rely on this too).
        assert_eq!(it.log10_inf_pr, None);
    }

    #[test]
    fn diagnose_clean_convergence_no_stall_warning() {
        // Quick converging run: just the `converged` finding, no stall noise.
        let iters: Vec<IterRecord> = (0..5)
            .map(|i| iter(i, 10f64.powi(-(i + 1)), 10f64.powi(-i)))
            .collect();
        let findings = diagnose(&report_with(iters));
        let codes: Vec<&str> = findings.iter().map(|f| f.code).collect();
        assert!(codes.contains(&"converged"), "got {codes:?}");
        assert!(
            !codes.contains(&"convergence_stall"),
            "stall shouldn't trip on healthy convergence: {codes:?}",
        );
    }

    #[test]
    fn diagnose_flags_restoration_loop() {
        let mut iters: Vec<IterRecord> = (0..6).map(|i| iter(i, 0.1, 1e-3)).collect();
        iters[1].alpha_primal_char = 'r';
        iters[4].alpha_primal_char = 'r';
        let findings = diagnose(&report_with(iters));
        let codes: Vec<&str> = findings.iter().map(|f| f.code).collect();
        assert!(codes.contains(&"restoration_loop"), "got {codes:?}");
    }

    #[test]
    fn diagnose_reports_worst_line_search() {
        let mut iters: Vec<IterRecord> =
            (0..3).map(|i| iter(i, 0.1, 10f64.powi(-i))).collect();
        iters[1].ls_trials = 12;
        iters[2].ls_trials = 15;
        let findings = diagnose(&report_with(iters));
        let heavy = findings
            .iter()
            .find(|f| f.code == "heavy_line_search")
            .expect("heavy_line_search finding");
        assert_eq!(heavy.severity, Severity::Warning);
        assert!(
            heavy.message.contains("worst: iter 2 with 15"),
            "{}",
            heavy.message
        );
    }

    #[test]
    fn summarize_copies_headline_fields() {
        let report = report_with(vec![iter(0, 0.1, 1e-3), iter(1, 0.1, 1e-4)]);
        let s = summarize(&report);
        assert_eq!(s.solver, "pounce");
        assert_eq!(s.status, "SolveSucceeded");
        assert_eq!(s.iteration_count, 2);
        assert_eq!(s.iterations_captured, 2);
    }

    #[test]
    fn convergence_trace_is_column_aligned() {
        let report = report_with((0..3).map(|i| iter(i, 0.1, 1e-3)).collect());
        let t = convergence_trace(&report);
        assert_eq!(t.iter, vec![0, 1, 2]);
        assert_eq!(t.mu, vec![0.1; 3]);
        assert_eq!(t.alpha_primal_char, vec!['f'; 3]);
        assert_eq!(t.inf_du.len(), 3);
    }

    #[test]
    fn compare_runs_preserves_labels_and_order() {
        let a = report_with(vec![iter(0, 0.1, 1e-3)]);
        let b = report_with(vec![iter(0, 0.1, 1e-3), iter(1, 0.1, 1e-4)]);
        let rows = compare_runs([("a", &a), ("b", &b)]);
        assert_eq!(rows.len(), 2);
        assert_eq!((rows[0].label.as_str(), rows[0].iter_count), ("a", 1));
        assert_eq!((rows[1].label.as_str(), rows[1].iter_count), ("b", 2));
        assert_eq!(rows[1].status, "SolveSucceeded");
    }
}