Skip to main content

pounce_algorithm/mu/oracle/
quality_function.rs

1//! Quality-function mu oracle — port of
2//! `IpQualityFunctionMuOracle.{hpp,cpp}`. Phase 10.
3//!
4//! The oracle picks `μ_new = σ * avrg_compl` by minimizing a 1-D
5//! quality function `q(σ)` over `σ ∈ [σ_lo, σ_up]` via golden section.
6//! The full vector-valued evaluator (which builds the trial slack /
7//! multiplier vectors at a candidate σ and reduces them to a scalar
8//! norm) is split into two pieces:
9//!
10//! * `evaluate_quality_function` — a *pure-scalar* reducer that takes
11//!   already-computed `‖·‖` aggregates and combines them per the
12//!   `(norm, centrality, balancing)` triple per
13//!   `IpQualityFunctionMuOracle.cpp:566-658`. The vector→aggregate
14//!   step is the caller's responsibility.
15//! * `pick_sigma` — orchestrator that mirrors
16//!   `IpQualityFunctionMuOracle.cpp::CalculateMu` lines 329-385: picks
17//!   the σ-bracket, evaluates `q(1)` and `q(1−ε)` to decide whether
18//!   to search above or below 1, then drives `golden_section`.
19//!
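//! `QualityFunctionMuOracle::calculate_mu_with_predictor_centering`
//! wires `pick_sigma` to a fully populated `q(σ)` evaluator, including
//! the affine and centering solves. The parameterless
//! `MuOracle::calculate_mu` trait entry point is still a stub that
//! returns `None`; routing it to the full evaluator is the remaining scope.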
22
23use crate::ipopt_cq::IpoptCqHandle;
24use crate::ipopt_data::IpoptDataHandle;
25use crate::ipopt_nlp::IpoptNlp;
26use crate::iterates_vector::IteratesVector;
27use crate::kkt::pd_search_dir_calc::PdSearchDirCalc;
28use crate::mu::oracle::r#trait::MuOracle;
29use pounce_common::types::Number;
30use pounce_linalg::Vector;
31use std::cell::RefCell;
32use std::rc::Rc;
33
/// Norm used to reduce the dual, primal and complementarity residual
/// vectors to the scalar aggregates fed into the quality function.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NormType {
    /// 1-norm: aggregates are sums of absolute values, averaged by the
    /// block dimension.
    OneNorm,
    /// Squared 2-norm — upstream `NM_NORM_2_SQUARED` (default).
    /// Aggregates are `||·||²` (no sqrt) and `(1−α)²` weighting.
    TwoNormSquared,
    /// Plain 2-norm: aggregates are `||·||`, scaled by the square root
    /// of the block dimension.
    TwoNorm,
    /// Max-norm: aggregates are the largest absolute entry; no scaling
    /// by the block dimension.
    MaxNorm,
}
43
/// Optional centrality penalty added to the quality function. Each
/// variant is expressed in terms of the centrality measure `xi` and
/// the complementarity term `compl_inf`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CentralityType {
    /// No centrality term; `xi` is ignored.
    None,
    /// Subtracts `compl_inf · ln(xi)`.
    LogCenter,
    /// Adds `compl_inf / xi`.
    ReciprocalCenter,
    /// Adds `compl_inf / xi³`.
    CubedReciprocalCenter,
}
51
/// Optional balancing term added to the quality function.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BalancingTermType {
    /// No balancing term.
    None,
    /// Adds `max(0, max(dual_inf, primal_inf) − compl_inf)³`.
    CubicTerm,
}
57
58pub struct QualityFunctionMuOracle {
59    pub norm_type: NormType,
60    pub centrality_type: CentralityType,
61    pub balancing_term: BalancingTermType,
62    pub max_section_steps: i32,
63    pub section_sigma_tol: Number,
64    pub section_qf_tol: Number,
65    pub sigma_max: Number,
66    pub sigma_min: Number,
67    pub mu_min: Number,
68    pub mu_max: Number,
69}
70
71impl Default for QualityFunctionMuOracle {
72    fn default() -> Self {
73        // Defaults from `IpQualityFunctionMuOracle.cpp:RegisterOptions`.
74        Self {
75            norm_type: NormType::TwoNormSquared,
76            centrality_type: CentralityType::None,
77            balancing_term: BalancingTermType::None,
78            max_section_steps: 8,
79            section_sigma_tol: 1e-2,
80            section_qf_tol: 0.0,
81            sigma_max: 100.0,
82            // Upstream `IpQualityFunctionMuOracle.cpp:62-69`
83            // `RegisterOptions` default is 1e-6, not 1e-9. Setting it
84            // too low lets golden-section collapse σ all the way to
85            // the floor on outer iterations where q(σ) is nearly
86            // flat over the bracket — which then drives μ to ~1e-11
87            // in a single step and triggers a kappa_sigma blow-up
88            // that pushes the algorithm into restoration. (HS1NE
89            // and ~50 other CUTEst problems exhibited this.)
90            sigma_min: 1e-6,
91            mu_min: 1e-11,
92            mu_max: 1e5,
93        }
94    }
95}
96
97impl QualityFunctionMuOracle {
98    pub fn new() -> Self {
99        Self::default()
100    }
101
102    /// Drive the predictor + centring solves through `pd_search_dir`,
103    /// project the results onto the four bound-mask subspaces, then
104    /// run [`pick_sigma`] over a `q(σ)` closure that builds the σ-step
105    /// and reduces it to [`QualityFunctionAggregates`] before invoking
106    /// [`evaluate_quality_function`]. Mirrors upstream
107    /// `IpQualityFunctionMuOracle.cpp::CalculateMu` lines 188-485.
108    ///
109    /// Returns `None` if either linear solve fails (caller falls back
110    /// to LOQO, matching upstream's
111    /// `IpAdaptiveMuUpdate.cpp::CalculateMuFromOracle:330-340`).
112    #[allow(clippy::too_many_lines)]
113    pub fn calculate_mu_with_predictor_centering(
114        &mut self,
115        data: &IpoptDataHandle,
116        cq: &IpoptCqHandle,
117        nlp: &Rc<RefCell<dyn IpoptNlp>>,
118        pd_search_dir: &mut PdSearchDirCalc,
119    ) -> Option<Number> {
120        if !pd_search_dir.compute_affine_step(data, cq, nlp) {
121            return None;
122        }
123        if !pd_search_dir.compute_centering_step(data, cq, nlp) {
124            return None;
125        }
126
127        let delta_aff: IteratesVector = data.borrow().delta_aff.clone()?;
128        let delta_cen: IteratesVector = data.borrow().delta_cen.clone()?;
129
130        // Project step.x onto the bound subspaces. Each block matches
131        // the `step_aff_x_L = P_L^T·δ_aff_x` setup in
132        // `IpQualityFunctionMuOracle.cpp:308-323`.
133        let nlp_ref = nlp.borrow();
134        let cq_ref = cq.borrow();
135        let curr_iv = cq_ref.curr_iv();
136        let curr_slack_x_l = cq_ref.curr_slack_x_l();
137        let curr_slack_x_u = cq_ref.curr_slack_x_u();
138        let curr_slack_s_l = cq_ref.curr_slack_s_l();
139        let curr_slack_s_u = cq_ref.curr_slack_s_u();
140        let avrg_compl = cq_ref.curr_avrg_compl();
141
142        let project = |sign_l_x: Number,
143                       sign_u_x: Number,
144                       step_x: &dyn Vector,
145                       step_s: &dyn Vector|
146         -> [Rc<dyn Vector>; 4] {
147            let mut x_l = curr_slack_x_l.make_new();
148            nlp_ref
149                .px_l()
150                .trans_mult_vector(sign_l_x, step_x, 0.0, &mut *x_l);
151            let mut x_u = curr_slack_x_u.make_new();
152            nlp_ref
153                .px_u()
154                .trans_mult_vector(sign_u_x, step_x, 0.0, &mut *x_u);
155            let mut s_l = curr_slack_s_l.make_new();
156            nlp_ref
157                .pd_l()
158                .trans_mult_vector(sign_l_x, step_s, 0.0, &mut *s_l);
159            let mut s_u = curr_slack_s_u.make_new();
160            nlp_ref
161                .pd_u()
162                .trans_mult_vector(sign_u_x, step_s, 0.0, &mut *s_u);
163            [Rc::from(x_l), Rc::from(x_u), Rc::from(s_l), Rc::from(s_u)]
164        };
165
166        let [step_aff_x_l, step_aff_x_u, step_aff_s_l, step_aff_s_u] =
167            project(1.0, -1.0, &*delta_aff.x, &*delta_aff.s);
168        let [step_cen_x_l, step_cen_x_u, step_cen_s_l, step_cen_s_u] =
169            project(1.0, -1.0, &*delta_cen.x, &*delta_cen.s);
170
171        // The z/v step blocks are stored directly on the iterate — no
172        // projection needed (upstream lines 318-323 use the raw blocks).
173        let step_aff_z_l = delta_aff.z_l.clone();
174        let step_aff_z_u = delta_aff.z_u.clone();
175        let step_aff_v_l = delta_aff.v_l.clone();
176        let step_aff_v_u = delta_aff.v_u.clone();
177        let step_cen_z_l = delta_cen.z_l.clone();
178        let step_cen_z_u = delta_cen.z_u.clone();
179        let step_cen_v_l = delta_cen.v_l.clone();
180        let step_cen_v_u = delta_cen.v_u.clone();
181
182        // Drop the immutable nlp borrow before invoking CQ accessors
183        // that may take a `nlp.borrow_mut()` (e.g. `curr_grad_lag_x` →
184        // `curr_grad_f` → `nlp.eval_grad_f`).
185        drop(nlp_ref);
186
187        // Constant-in-σ aggregates: `dual_aggr` from ‖∇L_x‖, ‖∇L_s‖;
188        // `primal_aggr` from ‖c‖, ‖d−s‖. Norm choice driven by
189        // `self.norm_type`. Upstream `cpp:283-303`.
190        let grad_lag_x = cq_ref.curr_grad_lag_x();
191        let grad_lag_s = cq_ref.curr_grad_lag_s();
192        let c = cq_ref.curr_c();
193        let d_minus_s = cq_ref.curr_d_minus_s();
194        let dual_aggr = match self.norm_type {
195            NormType::OneNorm => grad_lag_x.asum() + grad_lag_s.asum(),
196            NormType::TwoNormSquared => {
197                let nx = grad_lag_x.nrm2();
198                let ns = grad_lag_s.nrm2();
199                nx * nx + ns * ns
200            }
201            NormType::TwoNorm => {
202                let nx = grad_lag_x.nrm2();
203                let ns = grad_lag_s.nrm2();
204                (nx * nx + ns * ns).sqrt()
205            }
206            NormType::MaxNorm => grad_lag_x.amax().max(grad_lag_s.amax()),
207        };
208        let primal_aggr = match self.norm_type {
209            NormType::OneNorm => c.asum() + d_minus_s.asum(),
210            NormType::TwoNormSquared => {
211                let nc = c.nrm2();
212                let nd = d_minus_s.nrm2();
213                nc * nc + nd * nd
214            }
215            NormType::TwoNorm => {
216                let nc = c.nrm2();
217                let nd = d_minus_s.nrm2();
218                (nc * nc + nd * nd).sqrt()
219            }
220            NormType::MaxNorm => c.amax().max(d_minus_s.amax()),
221        };
222
223        let n_dual = curr_iv.x.dim() + curr_iv.s.dim();
224        let n_pri = curr_iv.y_c.dim() + curr_iv.y_d.dim();
225        let n_comp = curr_iv.z_l.dim() + curr_iv.z_u.dim() + curr_iv.v_l.dim() + curr_iv.v_u.dim();
226        let tau = data.borrow().curr_tau;
227
228        let curr_z_l = curr_iv.z_l.clone();
229        let curr_z_u = curr_iv.z_u.clone();
230        let curr_v_l = curr_iv.v_l.clone();
231        let curr_v_u = curr_iv.v_u.clone();
232
233        drop(cq_ref);
234
235        let norm_type = self.norm_type;
236        let centrality = self.centrality_type;
237        let balancing = self.balancing_term;
238
239        // q(σ) closure. Captures the eight aff/cen step projections,
240        // the four current slacks, the four current bound multipliers,
241        // and the constant aggregates; pure scalar work per call.
242        let mut eval_q = |sigma: Number| -> Number {
243            // step_σ = step_aff + σ · step_cen, projected blocks.
244            let combine = |aff: &Rc<dyn Vector>, cen: &Rc<dyn Vector>| -> Box<dyn Vector> {
245                let mut out = aff.make_new();
246                out.set(0.0);
247                out.add_two_vectors(1.0, &**aff, sigma, &**cen, 0.0);
248                out
249            };
250            let stp_x_l = combine(&step_aff_x_l, &step_cen_x_l);
251            let stp_x_u = combine(&step_aff_x_u, &step_cen_x_u);
252            let stp_s_l = combine(&step_aff_s_l, &step_cen_s_l);
253            let stp_s_u = combine(&step_aff_s_u, &step_cen_s_u);
254            let stp_z_l = combine(&step_aff_z_l, &step_cen_z_l);
255            let stp_z_u = combine(&step_aff_z_u, &step_cen_z_u);
256            let stp_v_l = combine(&step_aff_v_l, &step_cen_v_l);
257            let stp_v_u = combine(&step_aff_v_u, &step_cen_v_u);
258
259            // α_pri = min over slacks of frac_to_bound(curr_slack, step, τ).
260            let alpha_pri = curr_slack_x_l
261                .frac_to_bound(&*stp_x_l, tau)
262                .min(curr_slack_x_u.frac_to_bound(&*stp_x_u, tau))
263                .min(curr_slack_s_l.frac_to_bound(&*stp_s_l, tau))
264                .min(curr_slack_s_u.frac_to_bound(&*stp_s_u, tau));
265            let alpha_du = curr_z_l
266                .frac_to_bound(&*stp_z_l, tau)
267                .min(curr_z_u.frac_to_bound(&*stp_z_u, tau))
268                .min(curr_v_l.frac_to_bound(&*stp_v_l, tau))
269                .min(curr_v_u.frac_to_bound(&*stp_v_u, tau));
270
271            // Build σ-step trial slacks/duals: trial = curr + α·step.
272            let mut trial_s_x_l = curr_slack_x_l.make_new();
273            trial_s_x_l.set(0.0);
274            trial_s_x_l.add_two_vectors(1.0, &*curr_slack_x_l, alpha_pri, &*stp_x_l, 0.0);
275            let mut trial_s_x_u = curr_slack_x_u.make_new();
276            trial_s_x_u.set(0.0);
277            trial_s_x_u.add_two_vectors(1.0, &*curr_slack_x_u, alpha_pri, &*stp_x_u, 0.0);
278            let mut trial_s_s_l = curr_slack_s_l.make_new();
279            trial_s_s_l.set(0.0);
280            trial_s_s_l.add_two_vectors(1.0, &*curr_slack_s_l, alpha_pri, &*stp_s_l, 0.0);
281            let mut trial_s_s_u = curr_slack_s_u.make_new();
282            trial_s_s_u.set(0.0);
283            trial_s_s_u.add_two_vectors(1.0, &*curr_slack_s_u, alpha_pri, &*stp_s_u, 0.0);
284
285            let mut trial_z_l = curr_z_l.make_new();
286            trial_z_l.set(0.0);
287            trial_z_l.add_two_vectors(1.0, &*curr_z_l, alpha_du, &*stp_z_l, 0.0);
288            let mut trial_z_u = curr_z_u.make_new();
289            trial_z_u.set(0.0);
290            trial_z_u.add_two_vectors(1.0, &*curr_z_u, alpha_du, &*stp_z_u, 0.0);
291            let mut trial_v_l = curr_v_l.make_new();
292            trial_v_l.set(0.0);
293            trial_v_l.add_two_vectors(1.0, &*curr_v_l, alpha_du, &*stp_v_l, 0.0);
294            let mut trial_v_u = curr_v_u.make_new();
295            trial_v_u.set(0.0);
296            trial_v_u.add_two_vectors(1.0, &*curr_v_u, alpha_du, &*stp_v_u, 0.0);
297
298            // Complementarity products at the σ-trial point.
299            trial_s_x_l.element_wise_multiply(&*trial_z_l);
300            trial_s_x_u.element_wise_multiply(&*trial_z_u);
301            trial_s_s_l.element_wise_multiply(&*trial_v_l);
302            trial_s_s_u.element_wise_multiply(&*trial_v_u);
303
304            let compl_aggr = match norm_type {
305                NormType::OneNorm => {
306                    trial_s_x_l.asum()
307                        + trial_s_x_u.asum()
308                        + trial_s_s_l.asum()
309                        + trial_s_s_u.asum()
310                }
311                NormType::TwoNormSquared => {
312                    let a = trial_s_x_l.nrm2();
313                    let b = trial_s_x_u.nrm2();
314                    let c = trial_s_s_l.nrm2();
315                    let d = trial_s_s_u.nrm2();
316                    a * a + b * b + c * c + d * d
317                }
318                NormType::TwoNorm => {
319                    let a = trial_s_x_l.nrm2();
320                    let b = trial_s_x_u.nrm2();
321                    let c = trial_s_s_l.nrm2();
322                    let d = trial_s_s_u.nrm2();
323                    (a * a + b * b + c * c + d * d).sqrt()
324                }
325                NormType::MaxNorm => trial_s_x_l
326                    .amax()
327                    .max(trial_s_x_u.amax())
328                    .max(trial_s_s_l.amax())
329                    .max(trial_s_s_u.amax()),
330            };
331
332            let xi = if matches!(centrality, CentralityType::None) {
333                1.0
334            } else {
335                // Centrality: min(s_i z_i) / avg(s_i z_i). Cheap proxy
336                // when centrality != None — upstream computes the same
337                // ratio at line 612 onward.
338                let total = trial_s_x_l.asum()
339                    + trial_s_x_u.asum()
340                    + trial_s_s_l.asum()
341                    + trial_s_s_u.asum();
342                let avg = if n_comp > 0 {
343                    total / n_comp as Number
344                } else {
345                    1.0
346                };
347                let mn = trial_s_x_l
348                    .min()
349                    .min(trial_s_x_u.min())
350                    .min(trial_s_s_l.min())
351                    .min(trial_s_s_u.min());
352                if avg > 0.0 {
353                    mn / avg
354                } else {
355                    1.0
356                }
357            };
358
359            let aggr = QualityFunctionAggregates {
360                dual_aggr,
361                primal_aggr,
362                compl_aggr,
363                n_dual,
364                n_pri,
365                n_comp,
366            };
367
368            if std::env::var("POUNCE_DBG_QF_AGGR").is_ok() {
369                tracing::debug!(target: "pounce::mu",
370                    "[QF_AGGR] σ={:.6e} α_pri={:.6e} α_du={:.6e} xi={:.6e} dual_aggr={:.6e} primal_aggr={:.6e} compl_aggr={:.6e} n_dual={} n_pri={} n_comp={}",
371                    sigma, alpha_pri, alpha_du, xi,
372                    dual_aggr, primal_aggr, compl_aggr,
373                    n_dual, n_pri, n_comp
374                );
375            }
376
377            evaluate_quality_function(
378                norm_type, centrality, balancing, alpha_pri, alpha_du, xi, aggr,
379            )
380        };
381
382        // One-shot σ-sweep dump for iter==N: emits q(σ) at 21 σ values
383        // spanning [σ_min, σ_max] log-uniform. Enable with
384        // `POUNCE_DBG_QF_SWEEP=<iter>` (matches `data.iter_count`).
385        if let Ok(s) = std::env::var("POUNCE_DBG_QF_SWEEP") {
386            if let Ok(target_iter) = s.parse::<i32>() {
387                if data.borrow().iter_count == target_iter {
388                    let lo = self.sigma_min.max(self.mu_min / avrg_compl);
389                    let hi = self.sigma_max.min(self.mu_max / avrg_compl).max(lo * 10.0);
390                    let log_lo = lo.ln();
391                    let log_hi = hi.ln();
392                    tracing::debug!(target: "pounce::mu", "[QF_SWEEP] iter={} avrg_compl={:.6e} σ_range=[{:.3e},{:.3e}] sigma_min={:.3e} sigma_max={:.3e} mu_min={:.3e} mu_max={:.3e}",
393                        target_iter, avrg_compl, lo, hi,
394                        self.sigma_min, self.sigma_max, self.mu_min, self.mu_max);
395                    let n = 21;
396                    for i in 0..n {
397                        let frac = i as f64 / (n - 1) as f64;
398                        let sig = (log_lo + frac * (log_hi - log_lo)).exp();
399                        let q = eval_q(sig);
400                        tracing::debug!(target: "pounce::mu", "[QF_SWEEP] σ={:.6e} q={:.10e}", sig, q);
401                    }
402                    let q1 = eval_q(1.0);
403                    let s1m = 1.0 - self.section_sigma_tol.max(1e-4);
404                    let q1m = eval_q(s1m);
405                    tracing::debug!(target: "pounce::mu",
406                        "[QF_SWEEP] σ=1.0 q={:.10e}  σ={:.6e} q={:.10e}  (q_1minus>q_1: {})",
407                        q1,
408                        s1m,
409                        q1m,
410                        q1m > q1
411                    );
412                }
413            }
414        }
415
416        let sigma = pick_sigma(
417            self.sigma_min,
418            self.sigma_max,
419            self.mu_min,
420            self.mu_max,
421            avrg_compl,
422            self.section_sigma_tol,
423            self.section_qf_tol,
424            self.max_section_steps,
425            &mut eval_q,
426        );
427
428        let mu_new = sigma * avrg_compl;
429        let mu_clamped = mu_new.clamp(self.mu_min, self.mu_max);
430        if std::env::var("POUNCE_DBG_QF").is_ok() {
431            let iter_count = data.borrow().iter_count;
432            let curr_mu = data.borrow().curr_mu;
433            let sigma_floor = self.sigma_min.max(self.mu_min / avrg_compl);
434            let sigma_up_dn = sigma_floor
435                .max(1.0 - self.section_sigma_tol.max(1e-4))
436                .min(self.mu_max / avrg_compl);
437            tracing::debug!(target: "pounce::mu",
438                "[QF] iter={} curr_mu={:.3e} avrg_compl={:.3e} sigma={:.3e} mu_new={:.3e} mu_clamped={:.3e} | sigma_min={:.3e} mu_min={:.3e} sigma_lo_dn={:.3e} sigma_up_dn={:.3e} mu_min/avrg={:.3e}",
439                iter_count, curr_mu, avrg_compl, sigma, mu_new, mu_clamped,
440                self.sigma_min, self.mu_min, sigma_floor, sigma_up_dn,
441                self.mu_min / avrg_compl,
442            );
443        }
444        Some(mu_clamped)
445    }
446}
447
448impl MuOracle for QualityFunctionMuOracle {
449    fn calculate_mu(&mut self) -> Option<Number> {
450        // The full oracle needs the affine and centering steps; until
451        // the iterate plumbing is finalized, return None so the
452        // adaptive μ update falls through to the LOQO fallback as
453        // upstream does at `IpAdaptiveMuUpdate.cpp:CheckSufficientProgress`.
454        None
455    }
456}
457
458/// Pure-scalar golden-section minimizer used by
459/// `QualityFunctionMuOracle::PerformGoldenSection`
460/// (`IpQualityFunctionMuOracle.cpp:668-790`).
461///
462/// Searches for `argmin_{σ ∈ [σ_lo, σ_up]} q(σ)` via golden-section.
463/// Stops when *either*:
464/// * `(σ_up − σ_lo) < σ_tol · σ_up` (relative width), or
465/// * `1 − min(q_corners) / max(q_corners) < qf_tol` (function flat),
466/// * `nsections ≥ max_steps`.
467///
468/// `q_lo` / `q_up` are the function values at the bracket endpoints,
469/// as in upstream where they're often pre-evaluated and a sentinel
470/// `-100.0` is passed when the value isn't yet known.
471pub fn golden_section(
472    sigma_lo_in: Number,
473    sigma_up_in: Number,
474    q_lo_in: Number,
475    q_up_in: Number,
476    sigma_tol: Number,
477    qf_tol: Number,
478    max_steps: i32,
479    mut q: impl FnMut(Number) -> Number,
480) -> Number {
481    let mut sigma_lo = sigma_lo_in;
482    let mut sigma_up = sigma_up_in;
483    let mut q_lo = q_lo_in;
484    let mut q_up = q_up_in;
485
486    let gfac = (3.0 - 5.0_f64.sqrt()) / 2.0;
487    let mut sigma_mid1 = sigma_lo + gfac * (sigma_up - sigma_lo);
488    let mut sigma_mid2 = sigma_lo + (1.0 - gfac) * (sigma_up - sigma_lo);
489    let mut qmid1 = q(sigma_mid1);
490    let mut qmid2 = q(sigma_mid2);
491
492    let mut nsections = 0;
493    let mut width_ok;
494    let mut qf_ok;
495    loop {
496        width_ok = (sigma_up - sigma_lo) >= sigma_tol * sigma_up;
497        let qmin = q_lo.min(q_up).min(qmid1).min(qmid2);
498        let qmax = q_lo.max(q_up).max(qmid1).max(qmid2);
499        qf_ok = qmax > 0.0 && (1.0 - qmin / qmax) >= qf_tol;
500        if !(width_ok && qf_ok && nsections < max_steps) {
501            break;
502        }
503        nsections += 1;
504        if qmid1 > qmid2 {
505            sigma_lo = sigma_mid1;
506            q_lo = qmid1;
507            sigma_mid1 = sigma_mid2;
508            qmid1 = qmid2;
509            sigma_mid2 = sigma_lo + (1.0 - gfac) * (sigma_up - sigma_lo);
510            qmid2 = q(sigma_mid2);
511        } else {
512            sigma_up = sigma_mid2;
513            q_up = qmid2;
514            sigma_mid2 = sigma_mid1;
515            qmid2 = qmid1;
516            sigma_mid1 = sigma_lo + gfac * (sigma_up - sigma_lo);
517            qmid1 = q(sigma_mid1);
518        }
519    }
520
521    // Post-loop selection — mirrors `IpQualityFunctionMuOracle.cpp:749-826`.
522    //
523    // Two distinct cases:
524    //  * **qf_tol stop** (`width_ok && !qf_ok`): the four sampled values
525    //    have converged to within `qf_tol`. Pick whichever of the four
526    //    has the smallest q. Upstream `DBG_ASSERT(qf_min > -100.)`
527    //    holds because the qf_ok predicate `(1 - qmin/qmax) < qf_tol`
528    //    forces qmin to be a real positive value (sentinel `-100.0`
529    //    would yield `1 + 100/qmax > 1 > qf_tol`, keeping the loop
530    //    alive until the sentinel slot is overwritten).
531    //  * **Else** (`!width_ok || nsections == max_steps`): pick min of
532    //    the two midpoints, then check whether either endpoint *never
533    //    moved during the loop*. If an unmoved endpoint was passed in
534    //    with the `-100.0` sentinel, it has not been evaluated yet —
535    //    compute its q now and compare. Without this, callers that
536    //    pass a sentinel endpoint (every `pick_sigma` call does — one
537    //    of `q_lo`/`q_up` is always `-100.0`) can have the routine
538    //    return that *unevaluated* endpoint as the minimum, which is
539    //    how DECONVBNE used to land on `sigma = sigma_min`.
540    if width_ok && !qf_ok {
541        let mut best_s = sigma_lo;
542        let mut best_q = q_lo;
543        if q_up < best_q {
544            best_s = sigma_up;
545            best_q = q_up;
546        }
547        if qmid1 < best_q {
548            best_s = sigma_mid1;
549            best_q = qmid1;
550        }
551        if qmid2 < best_q {
552            best_s = sigma_mid2;
553        }
554        return best_s;
555    }
556    let (mut sigma, mut qval) = if qmid1 < qmid2 {
557        (sigma_mid1, qmid1)
558    } else {
559        (sigma_mid2, qmid2)
560    };
561    if sigma_up == sigma_up_in {
562        let qtmp = if q_up < 0.0 { q(sigma_up) } else { q_up };
563        if qtmp < qval {
564            sigma = sigma_up;
565            qval = qtmp;
566        }
567    } else if sigma_lo == sigma_lo_in {
568        let qtmp = if q_lo < 0.0 { q(sigma_lo) } else { q_lo };
569        if qtmp < qval {
570            sigma = sigma_lo;
571        }
572    }
573    let _ = qval;
574    sigma
575}
576
577/// Per-norm aggregates feeding [`evaluate_quality_function`].
578///
579/// All four arrays of pre-reduced complementarity infeasibilities are
580/// caller-provided so the evaluator stays pure-scalar:
581///
582/// * `dual_aggr` — norm of `(grad_lag_x, grad_lag_s)` *before* scaling
583///   by `(1 − α_du)`.
584/// * `primal_aggr` — norm of `(c, d − s)` before `(1 − α_pri)` scaling.
585/// * `compl_aggr` — norm of the four trial-complementarity products
586///   `(s_L · z_L, s_U · z_U, σ_L · v_L, σ_U · v_U)` after the σ-step
587///   has been applied.
588/// * `n_dual`, `n_pri`, `n_comp` — block dimensions used by the
589///   `1`-norm and `2`-norm averaging (the `2_squared` and `max`
590///   variants do not divide).
591#[derive(Debug, Clone, Copy)]
592pub struct QualityFunctionAggregates {
593    pub dual_aggr: Number,
594    pub primal_aggr: Number,
595    pub compl_aggr: Number,
596    pub n_dual: i32,
597    pub n_pri: i32,
598    pub n_comp: i32,
599}
600
601/// Pure-scalar reducer corresponding to
602/// `IpQualityFunctionMuOracle.cpp::CalculateQualityFunction`
603/// lines 566-658 minus the vector→aggregate reduction. Combines the
604/// caller-provided norm aggregates per the configured `(norm,
605/// centrality, balancing)` triple.
606///
607/// `xi` is the centrality measure of the trial complementarity
608/// products; ignored when `centrality == None`.
609pub fn evaluate_quality_function(
610    norm: NormType,
611    centrality: CentralityType,
612    balancing: BalancingTermType,
613    alpha_primal: Number,
614    alpha_dual: Number,
615    xi: Number,
616    aggr: QualityFunctionAggregates,
617) -> Number {
618    let (mut dual_inf, mut primal_inf, mut compl_inf) = match norm {
619        NormType::OneNorm => {
620            let mut d = (1.0 - alpha_dual) * aggr.dual_aggr;
621            let mut p = (1.0 - alpha_primal) * aggr.primal_aggr;
622            let mut c = aggr.compl_aggr;
623            d /= aggr.n_dual as Number;
624            if aggr.n_pri > 0 {
625                p /= aggr.n_pri as Number;
626            }
627            debug_assert!(aggr.n_comp > 0);
628            c /= aggr.n_comp as Number;
629            (d, p, c)
630        }
631        NormType::TwoNormSquared => {
632            // Upstream `IpQualityFunctionMuOracle.cpp:584-595`. The
633            // (1−α)² weight and per-n averaging differ from the plain
634            // 2-norm branch — and this is the upstream default.
635            let mut d = (1.0 - alpha_dual).powi(2) * aggr.dual_aggr;
636            let mut p = (1.0 - alpha_primal).powi(2) * aggr.primal_aggr;
637            let mut c = aggr.compl_aggr;
638            d /= aggr.n_dual as Number;
639            if aggr.n_pri > 0 {
640                p /= aggr.n_pri as Number;
641            }
642            debug_assert!(aggr.n_comp > 0);
643            c /= aggr.n_comp as Number;
644            (d, p, c)
645        }
646        NormType::MaxNorm => (
647            (1.0 - alpha_dual) * aggr.dual_aggr,
648            (1.0 - alpha_primal) * aggr.primal_aggr,
649            aggr.compl_aggr,
650        ),
651        NormType::TwoNorm => {
652            let mut d = (1.0 - alpha_dual) * aggr.dual_aggr;
653            let mut p = (1.0 - alpha_primal) * aggr.primal_aggr;
654            let mut c = aggr.compl_aggr;
655            d /= (aggr.n_dual as Number).sqrt();
656            if aggr.n_pri > 0 {
657                p /= (aggr.n_pri as Number).sqrt();
658            }
659            debug_assert!(aggr.n_comp > 0);
660            c /= (aggr.n_comp as Number).sqrt();
661            (d, p, c)
662        }
663    };
664
665    // Repair fp damage from the divisions when the input was already 0.
666    if dual_inf.is_nan() {
667        dual_inf = 0.0;
668    }
669    if primal_inf.is_nan() {
670        primal_inf = 0.0;
671    }
672    if compl_inf.is_nan() {
673        compl_inf = 0.0;
674    }
675
676    let mut q = dual_inf + primal_inf + compl_inf;
677
678    match centrality {
679        CentralityType::None => {}
680        CentralityType::LogCenter => q -= compl_inf * xi.ln(),
681        CentralityType::ReciprocalCenter => q += compl_inf / xi,
682        CentralityType::CubedReciprocalCenter => q += compl_inf / xi.powi(3),
683    }
684
685    match balancing {
686        BalancingTermType::None => {}
687        BalancingTermType::CubicTerm => {
688            let dom = dual_inf.max(primal_inf) - compl_inf;
689            q += dom.max(0.0).powi(3);
690        }
691    }
692
693    q
694}
695
696/// Sigma-bracket selection + golden-section orchestrator. Mirrors
697/// `IpQualityFunctionMuOracle.cpp::CalculateMu` lines 329-385.
698///
699/// `q` is a black-box `q(σ)` evaluator (typically constructed by
700/// composing the affine + σ·centering step into a trial point and
701/// calling [`evaluate_quality_function`]).
702///
703/// Returns the σ that approximately minimizes `q` on the picked
704/// bracket; the caller then sets `μ_new = σ · avrg_compl` and clamps
705/// to `[mu_min, mu_max]`.
706#[allow(clippy::too_many_arguments)]
707pub fn pick_sigma(
708    sigma_min: Number,
709    sigma_max: Number,
710    mu_min: Number,
711    mu_max: Number,
712    avrg_compl: Number,
713    sigma_tol: Number,
714    qf_tol: Number,
715    max_steps: i32,
716    mut q: impl FnMut(Number) -> Number,
717) -> Number {
718    let qf_1 = q(1.0);
719    let sigma_1minus = 1.0 - sigma_tol.max(1e-4);
720    let qf_1minus = q(sigma_1minus);
721
722    if qf_1minus > qf_1 {
723        // q decreases for σ > 1 — search up.
724        let sigma_up = sigma_max.min(mu_max / avrg_compl);
725        let sigma_lo = 1.0;
726        if sigma_lo >= sigma_up {
727            sigma_up
728        } else {
729            golden_section(
730                sigma_lo, sigma_up, qf_1, -100.0, sigma_tol, qf_tol, max_steps, q,
731            )
732        }
733    } else {
734        // q decreases for σ < 1 — search down.
735        let sigma_lo = sigma_min.max(mu_min / avrg_compl);
736        let sigma_up = sigma_lo.max(sigma_1minus).min(mu_max / avrg_compl);
737        if sigma_lo >= sigma_up {
738            sigma_lo
739        } else {
740            golden_section(
741                sigma_lo, sigma_up, -100.0, qf_1minus, sigma_tol, qf_tol, max_steps, q,
742            )
743        }
744    }
745}
746
#[cfg(test)]
mod tests {
    use super::*;

    /// Identity objective: strictly increasing in σ.
    fn rising(sigma: f64) -> f64 {
        sigma
    }

    /// Negated identity: strictly decreasing in σ.
    fn falling(sigma: f64) -> f64 {
        -sigma
    }

    // ---------------------------------------------------------------
    // golden_section
    // ---------------------------------------------------------------

    /// An interior minimizer of a convex parabola is located.
    #[test]
    fn golden_section_minimizes_parabola() {
        // Objective (σ − 0.3)², whose unique minimizer is σ = 0.3.
        fn parabola(sigma: f64) -> f64 {
            (sigma - 0.3).powi(2)
        }
        let (lo, up) = (0.0, 1.0);
        let s = golden_section(lo, up, parabola(lo), parabola(up), 1e-6, 0.0, 50, parabola);
        assert!((s - 0.3).abs() < 1e-3);
    }

    /// With only five steps allowed and a very tight σ tolerance, the result
    /// must still land in the neighbourhood of the minimizer.
    #[test]
    fn golden_section_respects_max_steps() {
        fn parabola(sigma: f64) -> f64 {
            (sigma - 0.5).powi(2)
        }
        let (lo, up) = (0.0, 1.0);
        let s = golden_section(lo, up, parabola(lo), parabola(up), 1e-12, 0.0, 5, parabola);
        assert!((s - 0.5).abs() < 0.2);
    }

    /// For an increasing objective the search must end near the lower end.
    #[test]
    fn golden_section_handles_monotone() {
        let (lo, up) = (0.1, 2.0);
        let s = golden_section(lo, up, rising(lo), rising(up), 1e-6, 0.0, 50, rising);
        assert!(s < 0.2, "got s = {}", s);
    }

    // ---------------------------------------------------------------
    // QualityFunctionMuOracle
    // ---------------------------------------------------------------

    /// A freshly constructed oracle yields no μ.
    #[test]
    fn calculate_mu_returns_none_until_plumbed() {
        let mut o = QualityFunctionMuOracle::new();
        assert!(o.calculate_mu().is_none());
    }

    // ---------------------------------------------------------------
    // evaluate_quality_function
    // ---------------------------------------------------------------

    /// Shorthand constructor for [`QualityFunctionAggregates`]: the three
    /// aggregate values first, then the three dimension counts.
    fn aggr(
        dual_aggr: Number,
        primal_aggr: Number,
        compl_aggr: Number,
        n_dual: i32,
        n_pri: i32,
        n_comp: i32,
    ) -> QualityFunctionAggregates {
        QualityFunctionAggregates {
            dual_aggr,
            primal_aggr,
            compl_aggr,
            n_dual,
            n_pri,
            n_comp,
        }
    }

    #[test]
    fn evaluate_one_norm_averages_by_n() {
        // Expected form: (1−α_du)·d/n_d + (1−α_pri)·p/n_p + c/n_c.
        let alpha_pri = 0.5;
        let alpha_du = 0.25;
        let totals = aggr(8.0, 4.0, 6.0, 4, 2, 3);
        let q = evaluate_quality_function(
            NormType::OneNorm,
            CentralityType::None,
            BalancingTermType::None,
            alpha_pri,
            alpha_du,
            1.0,
            totals,
        );
        // dual:   0.75 · 8 / 4 = 1.5
        // primal: 0.5 · 4 / 2  = 1.0
        // compl:  6 / 3        = 2.0
        // sum                  = 4.5
        assert!((q - 4.5).abs() < 1e-12, "got {}", q);
    }

    #[test]
    fn evaluate_max_norm_does_not_divide() {
        // Dimension counts of 10 must leave the aggregates untouched: 2 + 3 + 5 = 10.
        let totals = aggr(2.0, 3.0, 5.0, 10, 10, 10);
        let q = evaluate_quality_function(
            NormType::MaxNorm,
            CentralityType::None,
            BalancingTermType::None,
            0.0,
            0.0,
            1.0,
            totals,
        );
        assert!((q - 10.0).abs() < 1e-12);
    }

    #[test]
    fn evaluate_two_norm_divides_by_sqrt_n() {
        let totals = aggr(2.0, 0.0, 4.0, 4, 0, 16);
        let q = evaluate_quality_function(
            NormType::TwoNorm,
            CentralityType::None,
            BalancingTermType::None,
            0.0,
            0.0,
            1.0,
            totals,
        );
        // dual:   2 / √4  = 1.0
        // primal: stays 0 (n_pri = 0, so no division)
        // compl:  4 / √16 = 1.0
        assert!((q - 2.0).abs() < 1e-12, "got {}", q);
    }

    #[test]
    fn evaluate_one_norm_handles_zero_pri_dim() {
        // With n_pri = 0 the primal part must not be divided, so q stays finite.
        let totals = aggr(0.0, 0.0, 1.0, 1, 0, 1);
        let q = evaluate_quality_function(
            NormType::OneNorm,
            CentralityType::None,
            BalancingTermType::None,
            0.0,
            0.0,
            1.0,
            totals,
        );
        assert!(q.is_finite() && (q - 1.0).abs() < 1e-12);
    }

    #[test]
    fn evaluate_log_centrality_subtracts_compl_log_xi() {
        // Same inputs for both evaluations; only the centrality selector varies.
        let eval_with = |centrality: CentralityType| {
            evaluate_quality_function(
                NormType::MaxNorm,
                centrality,
                BalancingTermType::None,
                0.0,
                0.0,
                std::f64::consts::E,
                aggr(0.0, 0.0, 4.0, 1, 1, 1),
            )
        };
        let base = eval_with(CentralityType::None);
        let logc = eval_with(CentralityType::LogCenter);
        // With ξ = e we have ln ξ = 1, so the log term lowers q by 4 · 1 = 4.
        assert!((base - logc - 4.0).abs() < 1e-12, "base={base} logc={logc}");
    }

    #[test]
    fn evaluate_reciprocal_centrality_adds_c_over_xi() {
        let xi = 0.5;
        let totals = aggr(0.0, 0.0, 1.0, 1, 1, 1);
        let q = evaluate_quality_function(
            NormType::MaxNorm,
            CentralityType::ReciprocalCenter,
            BalancingTermType::None,
            0.0,
            0.0,
            xi,
            totals,
        );
        // compl + compl / ξ = 1 + 1 / 0.5 = 3.
        assert!((q - 3.0).abs() < 1e-12);
    }

    #[test]
    fn evaluate_cubed_reciprocal_centrality_adds_c_over_xi3() {
        let xi = 0.5;
        let totals = aggr(0.0, 0.0, 1.0, 1, 1, 1);
        let q = evaluate_quality_function(
            NormType::MaxNorm,
            CentralityType::CubedReciprocalCenter,
            BalancingTermType::None,
            0.0,
            0.0,
            xi,
            totals,
        );
        // compl + compl / ξ³ = 1 + 1 / 0.125 = 9.
        assert!((q - 9.0).abs() < 1e-12);
    }

    #[test]
    fn evaluate_cubic_balancing_adds_when_dual_dominates() {
        let totals = aggr(5.0, 1.0, 2.0, 1, 1, 1);
        let q = evaluate_quality_function(
            NormType::MaxNorm,
            CentralityType::None,
            BalancingTermType::CubicTerm,
            0.0,
            0.0,
            1.0,
            totals,
        );
        // Plain sum is 5 + 1 + 2 = 8; the excess max(5, 1) − 2 = 3 is cubed
        // and added: 8 + 27 = 35.
        assert!((q - 35.0).abs() < 1e-12, "got {}", q);
    }

    #[test]
    fn evaluate_cubic_balancing_zero_when_compl_dominates() {
        let totals = aggr(1.0, 1.0, 5.0, 1, 1, 1);
        let q = evaluate_quality_function(
            NormType::MaxNorm,
            CentralityType::None,
            BalancingTermType::CubicTerm,
            0.0,
            0.0,
            1.0,
            totals,
        );
        // The excess max(1, 1) − 5 = −4 is clamped to 0, leaving 1 + 1 + 5 = 7.
        assert!((q - 7.0).abs() < 1e-12);
    }

    // ---------------------------------------------------------------
    // pick_sigma
    // ---------------------------------------------------------------

    #[test]
    fn pick_sigma_searches_below_one_for_decreasing_q() {
        // The minimizer σ = 0.4 sits well below 1.
        fn parabola(sigma: f64) -> f64 {
            (sigma - 0.4).powi(2)
        }
        let s = pick_sigma(
            1e-9,  // sigma_min
            100.0, // sigma_max
            1e-11, // mu_min
            1e5,   // mu_max
            1.0,   // avrg_compl
            1e-6,  // sigma_tol
            0.0,   // qf_tol
            50,    // max_steps
            parabola,
        );
        assert!((s - 0.4).abs() < 1e-2, "got s = {}", s);
    }

    #[test]
    fn pick_sigma_searches_above_one_for_q_decreasing_in_sigma() {
        // q keeps falling as σ grows, so the answer lies at the top of the
        // bracket, whose upper end is min(sigma_max = 10, mu_max / avrg = 1e5) = 10.
        let s = pick_sigma(
            1e-9,  // sigma_min
            10.0,  // sigma_max
            1e-11, // mu_min
            1e5,   // mu_max
            1.0,   // avrg_compl
            1e-6,  // sigma_tol
            0.0,   // qf_tol
            50,    // max_steps
            falling,
        );
        assert!(s > 5.0, "got s = {}", s);
    }

    #[test]
    fn pick_sigma_clamps_to_mu_max_over_avrg_in_up_search() {
        // mu_max / avrg = 2.0 has to cap the upper end below sigma_max = 100.
        let s = pick_sigma(
            1e-9,  // sigma_min
            100.0, // sigma_max
            1e-11, // mu_min
            2.0,   // mu_max
            1.0,   // avrg_compl
            1e-6,  // sigma_tol
            0.0,   // qf_tol
            50,    // max_steps
            falling,
        );
        assert!(s <= 2.0 + 1e-9 && s >= 1.0, "got s = {}", s);
    }

    #[test]
    fn pick_sigma_clamps_to_mu_min_over_avrg_in_down_search() {
        // mu_min / avrg = 0.5 has to override sigma_min = 1e-9. q shrinks as
        // σ moves toward 0, so the search heads for the low end of the bracket.
        let s = pick_sigma(
            1e-9,  // sigma_min
            100.0, // sigma_max
            0.5,   // mu_min
            1e5,   // mu_max
            1.0,   // avrg_compl
            1e-6,  // sigma_tol
            0.0,   // qf_tol
            50,    // max_steps
            rising,
        );
        assert!(s >= 0.5 - 1e-9 && s <= 1.0, "got s = {}", s);
    }
}