Skip to main content

pounce_algorithm/mu/oracle/
quality_function.rs

1//! Quality-function mu oracle — port of
2//! `IpQualityFunctionMuOracle.{hpp,cpp}`. Phase 10.
3//!
4//! The oracle picks `μ_new = σ * avrg_compl` by minimizing a 1-D
5//! quality function `q(σ)` over `σ ∈ [σ_lo, σ_up]` via golden section.
6//! The full vector-valued evaluator (which builds the trial slack /
7//! multiplier vectors at a candidate σ and reduces them to a scalar
8//! norm) is split into two pieces:
9//!
10//! * `evaluate_quality_function` — a *pure-scalar* reducer that takes
11//!   already-computed `‖·‖` aggregates and combines them per the
12//!   `(norm, centrality, balancing)` triple per
13//!   `IpQualityFunctionMuOracle.cpp:566-658`. The vector→aggregate
14//!   step is the caller's responsibility.
15//! * `pick_sigma` — orchestrator that mirrors
16//!   `IpQualityFunctionMuOracle.cpp::CalculateMu` lines 329-385: picks
17//!   the σ-bracket, evaluates `q(1)` and `q(1−ε)` to decide whether
18//!   to search above or below 1, then drives `golden_section`.
19//!
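//! [`QualityFunctionMuOracle::calculate_mu_with_predictor_centering`] wires
//! `pick_sigma` to a fully populated `q(σ)` evaluator, including the
//! affine-predictor and centering solves. The argument-free
//! [`MuOracle::calculate_mu`] entry point still returns `None` (LOQO
//! fallback), because it receives no data/CQ/NLP handles.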
22
23use crate::ipopt_cq::IpoptCqHandle;
24use crate::ipopt_data::IpoptDataHandle;
25use crate::ipopt_nlp::IpoptNlp;
26use crate::iterates_vector::IteratesVector;
27use crate::kkt::pd_search_dir_calc::PdSearchDirCalc;
28use crate::mu::oracle::r#trait::MuOracle;
29use pounce_common::types::Number;
30use pounce_linalg::Vector;
31use std::cell::RefCell;
32use std::rc::Rc;
33
/// Norm used to reduce the dual, primal and complementarity residuals
/// inside the quality function `q(σ)` (see [`evaluate_quality_function`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NormType {
    /// `asum`-based 1-norm. Dual/primal terms are weighted by `(1−α)`;
    /// every term is averaged by its block dimension `n`.
    OneNorm,
    /// Squared 2-norm — upstream `NM_NORM_2_SQUARED` (default).
    /// Aggregates are `||·||²` (no sqrt) and `(1−α)²` weighting; every
    /// term is averaged by its block dimension `n`.
    TwoNormSquared,
    /// Plain 2-norm. Dual/primal terms are weighted by `(1−α)`; every
    /// term is divided by `√n`.
    TwoNorm,
    /// `amax`-based max-norm. Dual/primal terms are weighted by `(1−α)`;
    /// no dimension averaging is applied.
    MaxNorm,
}
43
/// Centrality penalty added to `q(σ)` as a function of the centrality
/// measure `ξ` of the trial complementarity products and of the averaged
/// complementarity term `compl_inf`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CentralityType {
    /// No centrality term (default); the oracle passes `ξ = 1` without
    /// computing it.
    None,
    /// `q −= compl_inf · ln ξ`.
    LogCenter,
    /// `q += compl_inf / ξ`.
    ReciprocalCenter,
    /// `q += compl_inf / ξ³`.
    CubedReciprocalCenter,
}
51
/// Optional balancing penalty added to `q(σ)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BalancingTermType {
    /// No balancing term (default).
    None,
    /// `q += max(0, max(dual_inf, primal_inf) − compl_inf)³`, i.e. a cubic
    /// penalty whenever the predicted infeasibility exceeds the predicted
    /// complementarity.
    CubicTerm,
}
57
58pub struct QualityFunctionMuOracle {
59    pub norm_type: NormType,
60    pub centrality_type: CentralityType,
61    pub balancing_term: BalancingTermType,
62    pub max_section_steps: i32,
63    pub section_sigma_tol: Number,
64    pub section_qf_tol: Number,
65    pub sigma_max: Number,
66    pub sigma_min: Number,
67    pub mu_min: Number,
68    pub mu_max: Number,
69}
70
71impl Default for QualityFunctionMuOracle {
72    fn default() -> Self {
73        // Defaults from `IpQualityFunctionMuOracle.cpp:RegisterOptions`.
74        Self {
75            norm_type: NormType::TwoNormSquared,
76            centrality_type: CentralityType::None,
77            balancing_term: BalancingTermType::None,
78            max_section_steps: 8,
79            section_sigma_tol: 1e-2,
80            section_qf_tol: 0.0,
81            sigma_max: 100.0,
82            // Upstream `IpQualityFunctionMuOracle.cpp:62-69`
83            // `RegisterOptions` default is 1e-6, not 1e-9. Setting it
84            // too low lets golden-section collapse σ all the way to
85            // the floor on outer iterations where q(σ) is nearly
86            // flat over the bracket — which then drives μ to ~1e-11
87            // in a single step and triggers a kappa_sigma blow-up
88            // that pushes the algorithm into restoration. (HS1NE
89            // and ~50 other CUTEst problems exhibited this.)
90            sigma_min: 1e-6,
91            mu_min: 1e-11,
92            mu_max: 1e5,
93        }
94    }
95}
96
97impl QualityFunctionMuOracle {
98    pub fn new() -> Self {
99        Self::default()
100    }
101
102    /// Drive the predictor + centring solves through `pd_search_dir`,
103    /// project the results onto the four bound-mask subspaces, then
104    /// run [`pick_sigma`] over a `q(σ)` closure that builds the σ-step
105    /// and reduces it to [`QualityFunctionAggregates`] before invoking
106    /// [`evaluate_quality_function`]. Mirrors upstream
107    /// `IpQualityFunctionMuOracle.cpp::CalculateMu` lines 188-485.
108    ///
109    /// Returns `None` if either linear solve fails (caller falls back
110    /// to LOQO, matching upstream's
111    /// `IpAdaptiveMuUpdate.cpp::CalculateMuFromOracle:330-340`).
112    #[allow(clippy::too_many_lines)]
113    pub fn calculate_mu_with_predictor_centering(
114        &mut self,
115        data: &IpoptDataHandle,
116        cq: &IpoptCqHandle,
117        nlp: &Rc<RefCell<dyn IpoptNlp>>,
118        pd_search_dir: &mut PdSearchDirCalc,
119    ) -> Option<Number> {
120        if !pd_search_dir.compute_affine_step(data, cq, nlp) {
121            return None;
122        }
123        if !pd_search_dir.compute_centering_step(data, cq, nlp) {
124            return None;
125        }
126
127        let delta_aff: IteratesVector = data.borrow().delta_aff.clone()?;
128        let delta_cen: IteratesVector = data.borrow().delta_cen.clone()?;
129
130        // Project step.x onto the bound subspaces. Each block matches
131        // the `step_aff_x_L = P_L^T·δ_aff_x` setup in
132        // `IpQualityFunctionMuOracle.cpp:308-323`.
133        let nlp_ref = nlp.borrow();
134        let cq_ref = cq.borrow();
135        let curr_iv = cq_ref.curr_iv();
136        let curr_slack_x_l = cq_ref.curr_slack_x_l();
137        let curr_slack_x_u = cq_ref.curr_slack_x_u();
138        let curr_slack_s_l = cq_ref.curr_slack_s_l();
139        let curr_slack_s_u = cq_ref.curr_slack_s_u();
140        let avrg_compl = cq_ref.curr_avrg_compl();
141
142        let project = |sign_l_x: Number,
143                       sign_u_x: Number,
144                       step_x: &dyn Vector,
145                       step_s: &dyn Vector|
146         -> [Rc<dyn Vector>; 4] {
147            let mut x_l = curr_slack_x_l.make_new();
148            nlp_ref
149                .px_l()
150                .trans_mult_vector(sign_l_x, step_x, 0.0, &mut *x_l);
151            let mut x_u = curr_slack_x_u.make_new();
152            nlp_ref
153                .px_u()
154                .trans_mult_vector(sign_u_x, step_x, 0.0, &mut *x_u);
155            let mut s_l = curr_slack_s_l.make_new();
156            nlp_ref
157                .pd_l()
158                .trans_mult_vector(sign_l_x, step_s, 0.0, &mut *s_l);
159            let mut s_u = curr_slack_s_u.make_new();
160            nlp_ref
161                .pd_u()
162                .trans_mult_vector(sign_u_x, step_s, 0.0, &mut *s_u);
163            [Rc::from(x_l), Rc::from(x_u), Rc::from(s_l), Rc::from(s_u)]
164        };
165
166        let [step_aff_x_l, step_aff_x_u, step_aff_s_l, step_aff_s_u] =
167            project(1.0, -1.0, &*delta_aff.x, &*delta_aff.s);
168        let [step_cen_x_l, step_cen_x_u, step_cen_s_l, step_cen_s_u] =
169            project(1.0, -1.0, &*delta_cen.x, &*delta_cen.s);
170
171        // The z/v step blocks are stored directly on the iterate — no
172        // projection needed (upstream lines 318-323 use the raw blocks).
173        let step_aff_z_l = delta_aff.z_l.clone();
174        let step_aff_z_u = delta_aff.z_u.clone();
175        let step_aff_v_l = delta_aff.v_l.clone();
176        let step_aff_v_u = delta_aff.v_u.clone();
177        let step_cen_z_l = delta_cen.z_l.clone();
178        let step_cen_z_u = delta_cen.z_u.clone();
179        let step_cen_v_l = delta_cen.v_l.clone();
180        let step_cen_v_u = delta_cen.v_u.clone();
181
182        // Drop the immutable nlp borrow before invoking CQ accessors
183        // that may take a `nlp.borrow_mut()` (e.g. `curr_grad_lag_x` →
184        // `curr_grad_f` → `nlp.eval_grad_f`).
185        drop(nlp_ref);
186
187        // Constant-in-σ aggregates: `dual_aggr` from ‖∇L_x‖, ‖∇L_s‖;
188        // `primal_aggr` from ‖c‖, ‖d−s‖. Norm choice driven by
189        // `self.norm_type`. Upstream `cpp:283-303`.
190        let grad_lag_x = cq_ref.curr_grad_lag_x();
191        let grad_lag_s = cq_ref.curr_grad_lag_s();
192        let c = cq_ref.curr_c();
193        let d_minus_s = cq_ref.curr_d_minus_s();
194        let dual_aggr = match self.norm_type {
195            NormType::OneNorm => grad_lag_x.asum() + grad_lag_s.asum(),
196            NormType::TwoNormSquared => {
197                let nx = grad_lag_x.nrm2();
198                let ns = grad_lag_s.nrm2();
199                nx * nx + ns * ns
200            }
201            NormType::TwoNorm => {
202                let nx = grad_lag_x.nrm2();
203                let ns = grad_lag_s.nrm2();
204                (nx * nx + ns * ns).sqrt()
205            }
206            NormType::MaxNorm => grad_lag_x.amax().max(grad_lag_s.amax()),
207        };
208        let primal_aggr = match self.norm_type {
209            NormType::OneNorm => c.asum() + d_minus_s.asum(),
210            NormType::TwoNormSquared => {
211                let nc = c.nrm2();
212                let nd = d_minus_s.nrm2();
213                nc * nc + nd * nd
214            }
215            NormType::TwoNorm => {
216                let nc = c.nrm2();
217                let nd = d_minus_s.nrm2();
218                (nc * nc + nd * nd).sqrt()
219            }
220            NormType::MaxNorm => c.amax().max(d_minus_s.amax()),
221        };
222
223        let n_dual = curr_iv.x.dim() + curr_iv.s.dim();
224        let n_pri = curr_iv.y_c.dim() + curr_iv.y_d.dim();
225        let n_comp = curr_iv.z_l.dim() + curr_iv.z_u.dim() + curr_iv.v_l.dim() + curr_iv.v_u.dim();
226        let tau = data.borrow().curr_tau;
227
228        let curr_z_l = curr_iv.z_l.clone();
229        let curr_z_u = curr_iv.z_u.clone();
230        let curr_v_l = curr_iv.v_l.clone();
231        let curr_v_u = curr_iv.v_u.clone();
232
233        drop(cq_ref);
234
235        let norm_type = self.norm_type;
236        let centrality = self.centrality_type;
237        let balancing = self.balancing_term;
238
239        // q(σ) closure. Captures the eight aff/cen step projections,
240        // the four current slacks, the four current bound multipliers,
241        // and the constant aggregates; pure scalar work per call.
242        let mut eval_q = |sigma: Number| -> Number {
243            // step_σ = step_aff + σ · step_cen, projected blocks.
244            let combine = |aff: &Rc<dyn Vector>, cen: &Rc<dyn Vector>| -> Box<dyn Vector> {
245                let mut out = aff.make_new();
246                out.set(0.0);
247                out.add_two_vectors(1.0, &**aff, sigma, &**cen, 0.0);
248                out
249            };
250            let stp_x_l = combine(&step_aff_x_l, &step_cen_x_l);
251            let stp_x_u = combine(&step_aff_x_u, &step_cen_x_u);
252            let stp_s_l = combine(&step_aff_s_l, &step_cen_s_l);
253            let stp_s_u = combine(&step_aff_s_u, &step_cen_s_u);
254            let stp_z_l = combine(&step_aff_z_l, &step_cen_z_l);
255            let stp_z_u = combine(&step_aff_z_u, &step_cen_z_u);
256            let stp_v_l = combine(&step_aff_v_l, &step_cen_v_l);
257            let stp_v_u = combine(&step_aff_v_u, &step_cen_v_u);
258
259            // α_pri = min over slacks of frac_to_bound(curr_slack, step, τ).
260            let alpha_pri = curr_slack_x_l
261                .frac_to_bound(&*stp_x_l, tau)
262                .min(curr_slack_x_u.frac_to_bound(&*stp_x_u, tau))
263                .min(curr_slack_s_l.frac_to_bound(&*stp_s_l, tau))
264                .min(curr_slack_s_u.frac_to_bound(&*stp_s_u, tau));
265            let alpha_du = curr_z_l
266                .frac_to_bound(&*stp_z_l, tau)
267                .min(curr_z_u.frac_to_bound(&*stp_z_u, tau))
268                .min(curr_v_l.frac_to_bound(&*stp_v_l, tau))
269                .min(curr_v_u.frac_to_bound(&*stp_v_u, tau));
270
271            // Build σ-step trial slacks/duals: trial = curr + α·step.
272            let mut trial_s_x_l = curr_slack_x_l.make_new();
273            trial_s_x_l.set(0.0);
274            trial_s_x_l.add_two_vectors(1.0, &*curr_slack_x_l, alpha_pri, &*stp_x_l, 0.0);
275            let mut trial_s_x_u = curr_slack_x_u.make_new();
276            trial_s_x_u.set(0.0);
277            trial_s_x_u.add_two_vectors(1.0, &*curr_slack_x_u, alpha_pri, &*stp_x_u, 0.0);
278            let mut trial_s_s_l = curr_slack_s_l.make_new();
279            trial_s_s_l.set(0.0);
280            trial_s_s_l.add_two_vectors(1.0, &*curr_slack_s_l, alpha_pri, &*stp_s_l, 0.0);
281            let mut trial_s_s_u = curr_slack_s_u.make_new();
282            trial_s_s_u.set(0.0);
283            trial_s_s_u.add_two_vectors(1.0, &*curr_slack_s_u, alpha_pri, &*stp_s_u, 0.0);
284
285            let mut trial_z_l = curr_z_l.make_new();
286            trial_z_l.set(0.0);
287            trial_z_l.add_two_vectors(1.0, &*curr_z_l, alpha_du, &*stp_z_l, 0.0);
288            let mut trial_z_u = curr_z_u.make_new();
289            trial_z_u.set(0.0);
290            trial_z_u.add_two_vectors(1.0, &*curr_z_u, alpha_du, &*stp_z_u, 0.0);
291            let mut trial_v_l = curr_v_l.make_new();
292            trial_v_l.set(0.0);
293            trial_v_l.add_two_vectors(1.0, &*curr_v_l, alpha_du, &*stp_v_l, 0.0);
294            let mut trial_v_u = curr_v_u.make_new();
295            trial_v_u.set(0.0);
296            trial_v_u.add_two_vectors(1.0, &*curr_v_u, alpha_du, &*stp_v_u, 0.0);
297
298            // Complementarity products at the σ-trial point.
299            trial_s_x_l.element_wise_multiply(&*trial_z_l);
300            trial_s_x_u.element_wise_multiply(&*trial_z_u);
301            trial_s_s_l.element_wise_multiply(&*trial_v_l);
302            trial_s_s_u.element_wise_multiply(&*trial_v_u);
303
304            let compl_aggr = match norm_type {
305                NormType::OneNorm => {
306                    trial_s_x_l.asum()
307                        + trial_s_x_u.asum()
308                        + trial_s_s_l.asum()
309                        + trial_s_s_u.asum()
310                }
311                NormType::TwoNormSquared => {
312                    let a = trial_s_x_l.nrm2();
313                    let b = trial_s_x_u.nrm2();
314                    let c = trial_s_s_l.nrm2();
315                    let d = trial_s_s_u.nrm2();
316                    a * a + b * b + c * c + d * d
317                }
318                NormType::TwoNorm => {
319                    let a = trial_s_x_l.nrm2();
320                    let b = trial_s_x_u.nrm2();
321                    let c = trial_s_s_l.nrm2();
322                    let d = trial_s_s_u.nrm2();
323                    (a * a + b * b + c * c + d * d).sqrt()
324                }
325                NormType::MaxNorm => trial_s_x_l
326                    .amax()
327                    .max(trial_s_x_u.amax())
328                    .max(trial_s_s_l.amax())
329                    .max(trial_s_s_u.amax()),
330            };
331
332            let xi = if matches!(centrality, CentralityType::None) {
333                1.0
334            } else {
335                // Centrality: min(s_i z_i) / avg(s_i z_i). Cheap proxy
336                // when centrality != None — upstream computes the same
337                // ratio at line 612 onward.
338                let total = trial_s_x_l.asum()
339                    + trial_s_x_u.asum()
340                    + trial_s_s_l.asum()
341                    + trial_s_s_u.asum();
342                let avg = if n_comp > 0 {
343                    total / n_comp as Number
344                } else {
345                    1.0
346                };
347                let mn = trial_s_x_l
348                    .min()
349                    .min(trial_s_x_u.min())
350                    .min(trial_s_s_l.min())
351                    .min(trial_s_s_u.min());
352                if avg > 0.0 {
353                    mn / avg
354                } else {
355                    1.0
356                }
357            };
358
359            let aggr = QualityFunctionAggregates {
360                dual_aggr,
361                primal_aggr,
362                compl_aggr,
363                n_dual,
364                n_pri,
365                n_comp,
366            };
367
368            if std::env::var("POUNCE_DBG_QF_AGGR").is_ok() {
369                tracing::debug!(target: "pounce::mu",
370                    "[QF_AGGR] σ={:.6e} α_pri={:.6e} α_du={:.6e} xi={:.6e} dual_aggr={:.6e} primal_aggr={:.6e} compl_aggr={:.6e} n_dual={} n_pri={} n_comp={}",
371                    sigma, alpha_pri, alpha_du, xi,
372                    dual_aggr, primal_aggr, compl_aggr,
373                    n_dual, n_pri, n_comp
374                );
375            }
376
377            evaluate_quality_function(
378                norm_type, centrality, balancing, alpha_pri, alpha_du, xi, aggr,
379            )
380        };
381
382        // One-shot σ-sweep dump for iter==N: emits q(σ) at 21 σ values
383        // spanning [σ_min, σ_max] log-uniform. Enable with
384        // `POUNCE_DBG_QF_SWEEP=<iter>` (matches `data.iter_count`).
385        if let Ok(s) = std::env::var("POUNCE_DBG_QF_SWEEP") {
386            if let Ok(target_iter) = s.parse::<i32>() {
387                if data.borrow().iter_count == target_iter {
388                    let lo = self.sigma_min.max(self.mu_min / avrg_compl);
389                    let hi = self.sigma_max.min(self.mu_max / avrg_compl).max(lo * 10.0);
390                    let log_lo = lo.ln();
391                    let log_hi = hi.ln();
392                    tracing::debug!(target: "pounce::mu", "[QF_SWEEP] iter={} avrg_compl={:.6e} σ_range=[{:.3e},{:.3e}] sigma_min={:.3e} sigma_max={:.3e} mu_min={:.3e} mu_max={:.3e}",
393                        target_iter, avrg_compl, lo, hi,
394                        self.sigma_min, self.sigma_max, self.mu_min, self.mu_max);
395                    let n = 21;
396                    for i in 0..n {
397                        let frac = i as f64 / (n - 1) as f64;
398                        let sig = (log_lo + frac * (log_hi - log_lo)).exp();
399                        let q = eval_q(sig);
400                        tracing::debug!(target: "pounce::mu", "[QF_SWEEP] σ={:.6e} q={:.10e}", sig, q);
401                    }
402                    let q1 = eval_q(1.0);
403                    let s1m = 1.0 - self.section_sigma_tol.max(1e-4);
404                    let q1m = eval_q(s1m);
405                    tracing::debug!(target: "pounce::mu",
406                        "[QF_SWEEP] σ=1.0 q={:.10e}  σ={:.6e} q={:.10e}  (q_1minus>q_1: {})",
407                        q1,
408                        s1m,
409                        q1m,
410                        q1m > q1
411                    );
412                }
413            }
414        }
415
416        let sigma = pick_sigma(
417            self.sigma_min,
418            self.sigma_max,
419            self.mu_min,
420            self.mu_max,
421            avrg_compl,
422            self.section_sigma_tol,
423            self.section_qf_tol,
424            self.max_section_steps,
425            &mut eval_q,
426        );
427
428        let mu_new = sigma * avrg_compl;
429        let mu_clamped = mu_new.clamp(self.mu_min, self.mu_max);
430        if std::env::var("POUNCE_DBG_QF").is_ok() {
431            let iter_count = data.borrow().iter_count;
432            let curr_mu = data.borrow().curr_mu;
433            let sigma_floor = self.sigma_min.max(self.mu_min / avrg_compl);
434            let sigma_up_dn = sigma_floor
435                .max(1.0 - self.section_sigma_tol.max(1e-4))
436                .min(self.mu_max / avrg_compl);
437            tracing::debug!(target: "pounce::mu",
438                "[QF] iter={} curr_mu={:.3e} avrg_compl={:.3e} sigma={:.3e} mu_new={:.3e} mu_clamped={:.3e} | sigma_min={:.3e} mu_min={:.3e} sigma_lo_dn={:.3e} sigma_up_dn={:.3e} mu_min/avrg={:.3e}",
439                iter_count, curr_mu, avrg_compl, sigma, mu_new, mu_clamped,
440                self.sigma_min, self.mu_min, sigma_floor, sigma_up_dn,
441                self.mu_min / avrg_compl,
442            );
443        }
444        Some(mu_clamped)
445    }
446}
447
impl MuOracle for QualityFunctionMuOracle {
    /// Trait entry point; always returns `None`.
    ///
    /// This signature receives no iterate data, CQ or NLP handles, so the
    /// affine and centering steps the oracle needs cannot be computed here.
    /// The real computation lives in
    /// [`QualityFunctionMuOracle::calculate_mu_with_predictor_centering`].
    fn calculate_mu(&mut self) -> Option<Number> {
        // Returning None lets the adaptive μ update fall through to the
        // LOQO fallback, as upstream does at
        // `IpAdaptiveMuUpdate.cpp:CheckSufficientProgress`.
        // NOTE(review): the doc on `calculate_mu_with_predictor_centering`
        // cites `IpAdaptiveMuUpdate.cpp::CalculateMuFromOracle:330-340` for
        // the same fallback — confirm which upstream location is meant.
        None
    }
}
457
458/// Pure-scalar golden-section minimizer used by
459/// `QualityFunctionMuOracle::PerformGoldenSection`
460/// (`IpQualityFunctionMuOracle.cpp:668-790`).
461///
462/// Searches for `argmin_{σ ∈ [σ_lo, σ_up]} q(σ)` via golden-section.
463/// Stops when *either*:
464/// * `(σ_up − σ_lo) < σ_tol · σ_up` (relative width), or
465/// * `1 − min(q_corners) / max(q_corners) < qf_tol` (function flat),
466/// * `nsections ≥ max_steps`.
467///
468/// `q_lo` / `q_up` are the function values at the bracket endpoints,
469/// as in upstream where they're often pre-evaluated and a sentinel
470/// `-100.0` is passed when the value isn't yet known.
471pub fn golden_section(
472    sigma_lo_in: Number,
473    sigma_up_in: Number,
474    q_lo_in: Number,
475    q_up_in: Number,
476    sigma_tol: Number,
477    qf_tol: Number,
478    max_steps: i32,
479    mut q: impl FnMut(Number) -> Number,
480) -> Number {
481    let mut sigma_lo = sigma_lo_in;
482    let mut sigma_up = sigma_up_in;
483    let mut q_lo = q_lo_in;
484    let mut q_up = q_up_in;
485
486    let gfac = (3.0 - 5.0_f64.sqrt()) / 2.0;
487    let mut sigma_mid1 = sigma_lo + gfac * (sigma_up - sigma_lo);
488    let mut sigma_mid2 = sigma_lo + (1.0 - gfac) * (sigma_up - sigma_lo);
489    let mut qmid1 = q(sigma_mid1);
490    let mut qmid2 = q(sigma_mid2);
491
492    let mut nsections = 0;
493    let mut width_ok;
494    let mut qf_ok;
495    loop {
496        width_ok = (sigma_up - sigma_lo) >= sigma_tol * sigma_up;
497        let qmin = q_lo.min(q_up).min(qmid1).min(qmid2);
498        let qmax = q_lo.max(q_up).max(qmid1).max(qmid2);
499        qf_ok = qmax > 0.0 && (1.0 - qmin / qmax) >= qf_tol;
500        if !(width_ok && qf_ok && nsections < max_steps) {
501            break;
502        }
503        nsections += 1;
504        if qmid1 > qmid2 {
505            sigma_lo = sigma_mid1;
506            q_lo = qmid1;
507            sigma_mid1 = sigma_mid2;
508            qmid1 = qmid2;
509            sigma_mid2 = sigma_lo + (1.0 - gfac) * (sigma_up - sigma_lo);
510            qmid2 = q(sigma_mid2);
511        } else {
512            sigma_up = sigma_mid2;
513            q_up = qmid2;
514            sigma_mid2 = sigma_mid1;
515            qmid2 = qmid1;
516            sigma_mid1 = sigma_lo + gfac * (sigma_up - sigma_lo);
517            qmid1 = q(sigma_mid1);
518        }
519    }
520
521    // Post-loop selection — mirrors `IpQualityFunctionMuOracle.cpp:749-826`.
522    //
523    // Two distinct cases:
524    //  * **qf_tol stop** (`width_ok && !qf_ok`): the four sampled values
525    //    have converged to within `qf_tol`. Pick whichever of the four
526    //    has the smallest q. Upstream reaches this branch only with real
527    //    values — its loop condition `(1 - qmin/qmax) >= qf_tol` keeps a
528    //    sentinel state alive (sentinel `-100.0` yields a large positive
529    //    ratio) until the slot is overwritten, so `DBG_ASSERT(qf_min > -100.)`
530    //    holds. pounce, however, adds a `qmax > 0.0` guard to `qf_ok`
531    //    (line 499) to avoid a divide-by-zero when every sample is ≤ 0; that
532    //    guard can force `qf_ok = false` while an endpoint still holds the
533    //    sentinel, routing it here. So this branch must re-evaluate an unmoved
534    //    sentinel endpoint first (below), exactly like the else-branch (L4).
535    //  * **Else** (`!width_ok || nsections == max_steps`): pick min of
536    //    the two midpoints, then check whether either endpoint *never
537    //    moved during the loop*. If an unmoved endpoint was passed in
538    //    with the `-100.0` sentinel, it has not been evaluated yet —
539    //    compute its q now and compare. Without this, callers that
540    //    pass a sentinel endpoint (every `pick_sigma` call does — one
541    //    of `q_lo`/`q_up` is always `-100.0`) can have the routine
542    //    return that *unevaluated* endpoint as the minimum, which is
543    //    how DECONVBNE used to land on `sigma = sigma_min`.
544    if width_ok && !qf_ok {
545        // Re-evaluate any endpoint that *never moved during the loop* and is
546        // still carrying the `-100.0` sentinel, before selecting the minimum.
547        // Upstream only reaches this branch with real values (its loop keeps a
548        // sentinel state alive because it lacks the `qmax > 0.0` guard); the
549        // guard pounce adds at line 499 can route a sentinel-containing state
550        // here, so we must mirror the else-branch / upstream re-evaluation or
551        // we would return an unevaluated endpoint as the spurious minimum (L4).
552        if sigma_lo == sigma_lo_in && q_lo < 0.0 {
553            q_lo = q(sigma_lo);
554        }
555        if sigma_up == sigma_up_in && q_up < 0.0 {
556            q_up = q(sigma_up);
557        }
558        let mut best_s = sigma_lo;
559        let mut best_q = q_lo;
560        if q_up < best_q {
561            best_s = sigma_up;
562            best_q = q_up;
563        }
564        if qmid1 < best_q {
565            best_s = sigma_mid1;
566            best_q = qmid1;
567        }
568        if qmid2 < best_q {
569            best_s = sigma_mid2;
570        }
571        return best_s;
572    }
573    let (mut sigma, mut qval) = if qmid1 < qmid2 {
574        (sigma_mid1, qmid1)
575    } else {
576        (sigma_mid2, qmid2)
577    };
578    if sigma_up == sigma_up_in {
579        let qtmp = if q_up < 0.0 { q(sigma_up) } else { q_up };
580        if qtmp < qval {
581            sigma = sigma_up;
582            qval = qtmp;
583        }
584    } else if sigma_lo == sigma_lo_in {
585        let qtmp = if q_lo < 0.0 { q(sigma_lo) } else { q_lo };
586        if qtmp < qval {
587            sigma = sigma_lo;
588        }
589    }
590    let _ = qval;
591    sigma
592}
593
594/// Per-norm aggregates feeding [`evaluate_quality_function`].
595///
596/// All four arrays of pre-reduced complementarity infeasibilities are
597/// caller-provided so the evaluator stays pure-scalar:
598///
599/// * `dual_aggr` — norm of `(grad_lag_x, grad_lag_s)` *before* scaling
600///   by `(1 − α_du)`.
601/// * `primal_aggr` — norm of `(c, d − s)` before `(1 − α_pri)` scaling.
602/// * `compl_aggr` — norm of the four trial-complementarity products
603///   `(s_L · z_L, s_U · z_U, σ_L · v_L, σ_U · v_U)` after the σ-step
604///   has been applied.
605/// * `n_dual`, `n_pri`, `n_comp` — block dimensions used by the
606///   `1`-norm and `2`-norm averaging (the `2_squared` and `max`
607///   variants do not divide).
608#[derive(Debug, Clone, Copy)]
609pub struct QualityFunctionAggregates {
610    pub dual_aggr: Number,
611    pub primal_aggr: Number,
612    pub compl_aggr: Number,
613    pub n_dual: i32,
614    pub n_pri: i32,
615    pub n_comp: i32,
616}
617
618/// Pure-scalar reducer corresponding to
619/// `IpQualityFunctionMuOracle.cpp::CalculateQualityFunction`
620/// lines 566-658 minus the vector→aggregate reduction. Combines the
621/// caller-provided norm aggregates per the configured `(norm,
622/// centrality, balancing)` triple.
623///
624/// `xi` is the centrality measure of the trial complementarity
625/// products; ignored when `centrality == None`.
626pub fn evaluate_quality_function(
627    norm: NormType,
628    centrality: CentralityType,
629    balancing: BalancingTermType,
630    alpha_primal: Number,
631    alpha_dual: Number,
632    xi: Number,
633    aggr: QualityFunctionAggregates,
634) -> Number {
635    let (mut dual_inf, mut primal_inf, mut compl_inf) = match norm {
636        NormType::OneNorm => {
637            let mut d = (1.0 - alpha_dual) * aggr.dual_aggr;
638            let mut p = (1.0 - alpha_primal) * aggr.primal_aggr;
639            let mut c = aggr.compl_aggr;
640            d /= aggr.n_dual as Number;
641            if aggr.n_pri > 0 {
642                p /= aggr.n_pri as Number;
643            }
644            debug_assert!(aggr.n_comp > 0);
645            c /= aggr.n_comp as Number;
646            (d, p, c)
647        }
648        NormType::TwoNormSquared => {
649            // Upstream `IpQualityFunctionMuOracle.cpp:584-595`. The
650            // (1−α)² weight and per-n averaging differ from the plain
651            // 2-norm branch — and this is the upstream default.
652            let mut d = (1.0 - alpha_dual).powi(2) * aggr.dual_aggr;
653            let mut p = (1.0 - alpha_primal).powi(2) * aggr.primal_aggr;
654            let mut c = aggr.compl_aggr;
655            d /= aggr.n_dual as Number;
656            if aggr.n_pri > 0 {
657                p /= aggr.n_pri as Number;
658            }
659            debug_assert!(aggr.n_comp > 0);
660            c /= aggr.n_comp as Number;
661            (d, p, c)
662        }
663        NormType::MaxNorm => (
664            (1.0 - alpha_dual) * aggr.dual_aggr,
665            (1.0 - alpha_primal) * aggr.primal_aggr,
666            aggr.compl_aggr,
667        ),
668        NormType::TwoNorm => {
669            let mut d = (1.0 - alpha_dual) * aggr.dual_aggr;
670            let mut p = (1.0 - alpha_primal) * aggr.primal_aggr;
671            let mut c = aggr.compl_aggr;
672            d /= (aggr.n_dual as Number).sqrt();
673            if aggr.n_pri > 0 {
674                p /= (aggr.n_pri as Number).sqrt();
675            }
676            debug_assert!(aggr.n_comp > 0);
677            c /= (aggr.n_comp as Number).sqrt();
678            (d, p, c)
679        }
680    };
681
682    // Repair fp damage from the divisions when the input was already 0.
683    if dual_inf.is_nan() {
684        dual_inf = 0.0;
685    }
686    if primal_inf.is_nan() {
687        primal_inf = 0.0;
688    }
689    if compl_inf.is_nan() {
690        compl_inf = 0.0;
691    }
692
693    let mut q = dual_inf + primal_inf + compl_inf;
694
695    match centrality {
696        CentralityType::None => {}
697        CentralityType::LogCenter => q -= compl_inf * xi.ln(),
698        CentralityType::ReciprocalCenter => q += compl_inf / xi,
699        CentralityType::CubedReciprocalCenter => q += compl_inf / xi.powi(3),
700    }
701
702    match balancing {
703        BalancingTermType::None => {}
704        BalancingTermType::CubicTerm => {
705            let dom = dual_inf.max(primal_inf) - compl_inf;
706            q += dom.max(0.0).powi(3);
707        }
708    }
709
710    q
711}
712
713/// Sigma-bracket selection + golden-section orchestrator. Mirrors
714/// `IpQualityFunctionMuOracle.cpp::CalculateMu` lines 329-385.
715///
716/// `q` is a black-box `q(σ)` evaluator (typically constructed by
717/// composing the affine + σ·centering step into a trial point and
718/// calling [`evaluate_quality_function`]).
719///
720/// Returns the σ that approximately minimizes `q` on the picked
721/// bracket; the caller then sets `μ_new = σ · avrg_compl` and clamps
722/// to `[mu_min, mu_max]`.
723#[allow(clippy::too_many_arguments)]
724pub fn pick_sigma(
725    sigma_min: Number,
726    sigma_max: Number,
727    mu_min: Number,
728    mu_max: Number,
729    avrg_compl: Number,
730    sigma_tol: Number,
731    qf_tol: Number,
732    max_steps: i32,
733    mut q: impl FnMut(Number) -> Number,
734) -> Number {
735    let qf_1 = q(1.0);
736    let sigma_1minus = 1.0 - sigma_tol.max(1e-4);
737    let qf_1minus = q(sigma_1minus);
738
739    if qf_1minus > qf_1 {
740        // q decreases for σ > 1 — search up.
741        let sigma_up = sigma_max.min(mu_max / avrg_compl);
742        let sigma_lo = 1.0;
743        if sigma_lo >= sigma_up {
744            sigma_up
745        } else {
746            golden_section(
747                sigma_lo, sigma_up, qf_1, -100.0, sigma_tol, qf_tol, max_steps, q,
748            )
749        }
750    } else {
751        // q decreases for σ < 1 — search down.
752        let sigma_lo = sigma_min.max(mu_min / avrg_compl);
753        let sigma_up = sigma_lo.max(sigma_1minus).min(mu_max / avrg_compl);
754        if sigma_lo >= sigma_up {
755            sigma_lo
756        } else {
757            golden_section(
758                sigma_lo, sigma_up, -100.0, qf_1minus, sigma_tol, qf_tol, max_steps, q,
759            )
760        }
761    }
762}
763
#[cfg(test)]
mod tests {
    // Unit tests for the golden-section search, the scalar quality-function
    // reducer (`evaluate_quality_function`) and the σ-bracket orchestrator
    // (`pick_sigma`).
    use super::*;

    #[test]
    fn golden_section_minimizes_parabola() {
        // q(σ) = (σ − 0.3)²; minimum at σ = 0.3.
        let f = |s: f64| (s - 0.3).powi(2);
        let s = golden_section(0.0, 1.0, f(0.0), f(1.0), 1e-6, 0.0, 50, f);
        assert!((s - 0.3).abs() < 1e-3);
    }

    #[test]
    fn golden_section_respects_max_steps() {
        // Heavy max-step cap should still produce a reasonable σ.
        let f = |s: f64| (s - 0.5).powi(2);
        let s = golden_section(0.0, 1.0, f(0.0), f(1.0), 1e-12, 0.0, 5, f);
        assert!((s - 0.5).abs() < 0.2);
    }

    #[test]
    fn golden_section_handles_monotone() {
        // q monotone increasing → minimum at lo end.
        let f = |s: f64| s;
        let s = golden_section(0.1, 2.0, 0.1, 2.0, 1e-6, 0.0, 50, f);
        assert!(s < 0.2, "got s = {}", s);
    }

    #[test]
    fn golden_section_never_returns_unevaluated_sentinel() {
        // Regression for L4. `pick_sigma` always passes one endpoint with the
        // `-100.0` sentinel as its q-value (search-up → q_up = -100,
        // search-down → q_lo = -100). When every *evaluated* sample is ≤ 0,
        // pounce's added `qmax > 0.0` guard forces `qf_ok = false` on the
        // first pass and drops into the `width_ok && !qf_ok` branch. Before
        // the fix that branch compared the raw q values — including the
        // unevaluated `-100.0` — and returned the sentinel endpoint as the
        // spurious minimum, even though its true quality value is the *worst*
        // of the bracket. The fix re-evaluates any unmoved sentinel endpoint
        // first, mirroring the else-branch and upstream's `if( q_up < 0. )`.
        let sigma_lo = 1.0_f64;
        let sigma_up = 3.0_f64;
        // Negative on the interior/lo points (so qmax ≤ 0) but large and
        // positive exactly at the upper endpoint — the worst place to land.
        let q = move |s: f64| if s == sigma_up { 50.0 } else { -s };
        // search-up style: the upper endpoint carries the -100 sentinel.
        let s = golden_section(sigma_lo, sigma_up, q(sigma_lo), -100.0, 1e-3, 0.0, 50, q);
        assert!(
            s < sigma_up,
            "golden_section returned the unevaluated sentinel endpoint σ = {} \
             (true q there = {}, the bracket maximum); it must re-evaluate the \
             sentinel before selecting a minimum",
            s,
            q(s)
        );
    }

    #[test]
    fn calculate_mu_returns_none_until_plumbed() {
        let mut o = QualityFunctionMuOracle::new();
        assert!(o.calculate_mu().is_none());
    }

    /// Builds a `QualityFunctionAggregates` from the three aggregates
    /// (dual, primal, complementarity) and their element counts.
    fn aggr(
        d: Number,
        p: Number,
        c: Number,
        nd: i32,
        np: i32,
        nc: i32,
    ) -> QualityFunctionAggregates {
        QualityFunctionAggregates {
            dual_aggr: d,
            primal_aggr: p,
            compl_aggr: c,
            n_dual: nd,
            n_pri: np,
            n_comp: nc,
        }
    }

    #[test]
    fn evaluate_one_norm_averages_by_n() {
        // (1−α_du)*d/n_d + (1−α_pri)*p/n_p + c/n_c.
        let q = evaluate_quality_function(
            NormType::OneNorm,
            CentralityType::None,
            BalancingTermType::None,
            0.5,  // α_pri
            0.25, // α_du
            1.0,
            aggr(8.0, 4.0, 6.0, 4, 2, 3),
        );
        // d = 0.75 * 8 / 4 = 1.5; p = 0.5 * 4 / 2 = 1.0; c = 6/3 = 2.0; total = 4.5
        assert!((q - 4.5).abs() < 1e-12, "got {}", q);
    }

    #[test]
    fn evaluate_max_norm_does_not_divide() {
        let q = evaluate_quality_function(
            NormType::MaxNorm,
            CentralityType::None,
            BalancingTermType::None,
            0.0,
            0.0,
            1.0,
            aggr(2.0, 3.0, 5.0, 10, 10, 10),
        );
        assert!((q - 10.0).abs() < 1e-12);
    }

    #[test]
    fn evaluate_two_norm_divides_by_sqrt_n() {
        let q = evaluate_quality_function(
            NormType::TwoNorm,
            CentralityType::None,
            BalancingTermType::None,
            0.0,
            0.0,
            1.0,
            aggr(2.0, 0.0, 4.0, 4, 0, 16),
        );
        // d = 2/2 = 1.0; p stays 0 (n_pri = 0 → no divide); c = 4/4 = 1.0
        assert!((q - 2.0).abs() < 1e-12, "got {}", q);
    }

    #[test]
    fn evaluate_one_norm_handles_zero_pri_dim() {
        // n_pri = 0 ⇒ primal must not be divided.
        let q = evaluate_quality_function(
            NormType::OneNorm,
            CentralityType::None,
            BalancingTermType::None,
            0.0,
            0.0,
            1.0,
            aggr(0.0, 0.0, 1.0, 1, 0, 1),
        );
        assert!(q.is_finite() && (q - 1.0).abs() < 1e-12);
    }

    #[test]
    fn evaluate_log_centrality_subtracts_compl_log_xi() {
        let base = evaluate_quality_function(
            NormType::MaxNorm,
            CentralityType::None,
            BalancingTermType::None,
            0.0,
            0.0,
            std::f64::consts::E,
            aggr(0.0, 0.0, 4.0, 1, 1, 1),
        );
        let logc = evaluate_quality_function(
            NormType::MaxNorm,
            CentralityType::LogCenter,
            BalancingTermType::None,
            0.0,
            0.0,
            std::f64::consts::E,
            aggr(0.0, 0.0, 4.0, 1, 1, 1),
        );
        // Difference is −compl_inf · ln(xi) = −4 · 1 = −4.
        assert!((base - logc - 4.0).abs() < 1e-12, "base={base} logc={logc}");
    }

    #[test]
    fn evaluate_reciprocal_centrality_adds_c_over_xi() {
        let q = evaluate_quality_function(
            NormType::MaxNorm,
            CentralityType::ReciprocalCenter,
            BalancingTermType::None,
            0.0,
            0.0,
            0.5,
            aggr(0.0, 0.0, 1.0, 1, 1, 1),
        );
        // 1.0 + 1.0/0.5 = 3.0.
        assert!((q - 3.0).abs() < 1e-12);
    }

    #[test]
    fn evaluate_cubed_reciprocal_centrality_adds_c_over_xi3() {
        let q = evaluate_quality_function(
            NormType::MaxNorm,
            CentralityType::CubedReciprocalCenter,
            BalancingTermType::None,
            0.0,
            0.0,
            0.5,
            aggr(0.0, 0.0, 1.0, 1, 1, 1),
        );
        // 1.0 + 1.0/0.125 = 9.0.
        assert!((q - 9.0).abs() < 1e-12);
    }

    #[test]
    fn evaluate_cubic_balancing_adds_when_dual_dominates() {
        let q = evaluate_quality_function(
            NormType::MaxNorm,
            CentralityType::None,
            BalancingTermType::CubicTerm,
            0.0,
            0.0,
            1.0,
            aggr(5.0, 1.0, 2.0, 1, 1, 1),
        );
        // base = 5+1+2 = 8; dom = max(5,1) − 2 = 3; +27 → 35.
        assert!((q - 35.0).abs() < 1e-12, "got {}", q);
    }

    #[test]
    fn evaluate_cubic_balancing_zero_when_compl_dominates() {
        let q = evaluate_quality_function(
            NormType::MaxNorm,
            CentralityType::None,
            BalancingTermType::CubicTerm,
            0.0,
            0.0,
            1.0,
            aggr(1.0, 1.0, 5.0, 1, 1, 1),
        );
        // dom = max(1,1) − 5 = −4 → clamped to 0; total = 7.
        assert!((q - 7.0).abs() < 1e-12);
    }

    #[test]
    fn pick_sigma_searches_below_one_for_decreasing_q() {
        // Parabola minimum at σ = 0.4 (well below 1).
        let f = |s: f64| (s - 0.4).powi(2);
        let s = pick_sigma(1e-9, 100.0, 1e-11, 1e5, 1.0, 1e-6, 0.0, 50, f);
        assert!((s - 0.4).abs() < 1e-2, "got s = {}", s);
    }

    #[test]
    fn pick_sigma_searches_above_one_for_q_decreasing_in_sigma() {
        // q decreases as σ grows ⇒ minimum at top of bracket.
        let f = |s: f64| -s;
        let s = pick_sigma(1e-9, 10.0, 1e-11, 1e5, 1.0, 1e-6, 0.0, 50, f);
        // bracket up-end is min(sigma_max=10, mu_max/avrg=1e5) = 10.
        assert!(s > 5.0, "got s = {}", s);
    }

    #[test]
    fn pick_sigma_clamps_to_mu_max_over_avrg_in_up_search() {
        // mu_max/avrg = 2.0 should cap σ_up below sigma_max = 100.
        let f = |s: f64| -s;
        let s = pick_sigma(1e-9, 100.0, 1e-11, 2.0, 1.0, 1e-6, 0.0, 50, f);
        assert!(s <= 2.0 + 1e-9 && s >= 1.0, "got s = {}", s);
    }

    #[test]
    fn pick_sigma_clamps_to_mu_min_over_avrg_in_down_search() {
        // mu_min/avrg = 0.5 must dominate σ_min = 1e-9.
        // q(σ) = σ grows with σ, so the down-search minimizer is the low end
        // of the bracket.
        let f = |s: f64| s;
        let s = pick_sigma(1e-9, 100.0, 0.5, 1e5, 1.0, 1e-6, 0.0, 50, f);
        assert!(s >= 0.5 - 1e-9 && s <= 1.0, "got s = {}", s);
    }

    #[test]
    fn pick_sigma_up_search_returns_cap_when_bracket_is_empty() {
        // q(σ) = −σ ⇒ up-search; mu_max/avrg = 0.5 ≤ 1 leaves [1, 0.5]
        // empty, so the search is skipped and σ_up is returned as-is.
        let f = |s: f64| -s;
        let s = pick_sigma(1e-9, 100.0, 1e-11, 0.5, 1.0, 1e-6, 0.0, 50, f);
        assert!((s - 0.5).abs() < 1e-12, "got s = {}", s);
    }

    #[test]
    fn pick_sigma_down_search_returns_floor_when_bracket_is_empty() {
        // q(σ) = σ ⇒ down-search; mu_min/avrg = 2.0 lifts σ_lo above 1 − ε,
        // so the bracket collapses and σ_lo is returned without a search.
        let f = |s: f64| s;
        let s = pick_sigma(1e-9, 100.0, 2.0, 1e5, 1.0, 1e-6, 0.0, 50, f);
        assert!((s - 2.0).abs() < 1e-12, "got s = {}", s);
    }
}