1use crate::ipopt_cq::IpoptCqHandle;
24use crate::ipopt_data::IpoptDataHandle;
25use crate::ipopt_nlp::IpoptNlp;
26use crate::iterates_vector::IteratesVector;
27use crate::kkt::pd_search_dir_calc::PdSearchDirCalc;
28use crate::mu::oracle::r#trait::MuOracle;
29use pounce_common::types::Number;
30use pounce_linalg::Vector;
31use std::cell::RefCell;
32use std::rc::Rc;
33
/// Norm in which the dual, primal and complementarity residuals are measured
/// by the quality function.
///
/// The variant also decides how [`evaluate_quality_function`] normalises each
/// residual by its dimension: 1-norm and squared 2-norm values are averaged
/// (divided by `n`), 2-norm values are divided by `sqrt(n)`, and max-norm
/// values are left unscaled.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NormType {
    /// Sum of absolute values.
    OneNorm,
    /// Squared Euclidean norm; the step-length factor is squared as well.
    TwoNormSquared,
    /// Euclidean norm.
    TwoNorm,
    /// Largest absolute value.
    MaxNorm,
}
43
/// Optional centrality penalty added to the quality function.
///
/// With `compl` the normalised complementarity term and `xi` the centrality
/// measure (smallest trial complementarity product over their average),
/// [`evaluate_quality_function`] adds:
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CentralityType {
    /// No centrality term.
    None,
    /// Subtracts `compl * ln(xi)`.
    LogCenter,
    /// Adds `compl / xi`.
    ReciprocalCenter,
    /// Adds `compl / xi^3`.
    CubedReciprocalCenter,
}
51
/// Optional term that penalises infeasibility outpacing complementarity.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BalancingTermType {
    /// No balancing term.
    None,
    /// Adds `max(0, max(dual, primal) - compl)^3` to the quality function.
    CubicTerm,
}
57
58pub struct QualityFunctionMuOracle {
59 pub norm_type: NormType,
60 pub centrality_type: CentralityType,
61 pub balancing_term: BalancingTermType,
62 pub max_section_steps: i32,
63 pub section_sigma_tol: Number,
64 pub section_qf_tol: Number,
65 pub sigma_max: Number,
66 pub sigma_min: Number,
67 pub mu_min: Number,
68 pub mu_max: Number,
69}
70
71impl Default for QualityFunctionMuOracle {
72 fn default() -> Self {
73 Self {
75 norm_type: NormType::TwoNormSquared,
76 centrality_type: CentralityType::None,
77 balancing_term: BalancingTermType::None,
78 max_section_steps: 8,
79 section_sigma_tol: 1e-2,
80 section_qf_tol: 0.0,
81 sigma_max: 100.0,
82 sigma_min: 1e-6,
91 mu_min: 1e-11,
92 mu_max: 1e5,
93 }
94 }
95}
96
97impl QualityFunctionMuOracle {
98 pub fn new() -> Self {
99 Self::default()
100 }
101
102 #[allow(clippy::too_many_lines)]
113 pub fn calculate_mu_with_predictor_centering(
114 &mut self,
115 data: &IpoptDataHandle,
116 cq: &IpoptCqHandle,
117 nlp: &Rc<RefCell<dyn IpoptNlp>>,
118 pd_search_dir: &mut PdSearchDirCalc,
119 ) -> Option<Number> {
120 if !pd_search_dir.compute_affine_step(data, cq, nlp) {
121 return None;
122 }
123 if !pd_search_dir.compute_centering_step(data, cq, nlp) {
124 return None;
125 }
126
127 let delta_aff: IteratesVector = data.borrow().delta_aff.clone()?;
128 let delta_cen: IteratesVector = data.borrow().delta_cen.clone()?;
129
130 let nlp_ref = nlp.borrow();
134 let cq_ref = cq.borrow();
135 let curr_iv = cq_ref.curr_iv();
136 let curr_slack_x_l = cq_ref.curr_slack_x_l();
137 let curr_slack_x_u = cq_ref.curr_slack_x_u();
138 let curr_slack_s_l = cq_ref.curr_slack_s_l();
139 let curr_slack_s_u = cq_ref.curr_slack_s_u();
140 let avrg_compl = cq_ref.curr_avrg_compl();
141
142 let project = |sign_l_x: Number,
143 sign_u_x: Number,
144 step_x: &dyn Vector,
145 step_s: &dyn Vector|
146 -> [Rc<dyn Vector>; 4] {
147 let mut x_l = curr_slack_x_l.make_new();
148 nlp_ref
149 .px_l()
150 .trans_mult_vector(sign_l_x, step_x, 0.0, &mut *x_l);
151 let mut x_u = curr_slack_x_u.make_new();
152 nlp_ref
153 .px_u()
154 .trans_mult_vector(sign_u_x, step_x, 0.0, &mut *x_u);
155 let mut s_l = curr_slack_s_l.make_new();
156 nlp_ref
157 .pd_l()
158 .trans_mult_vector(sign_l_x, step_s, 0.0, &mut *s_l);
159 let mut s_u = curr_slack_s_u.make_new();
160 nlp_ref
161 .pd_u()
162 .trans_mult_vector(sign_u_x, step_s, 0.0, &mut *s_u);
163 [Rc::from(x_l), Rc::from(x_u), Rc::from(s_l), Rc::from(s_u)]
164 };
165
166 let [step_aff_x_l, step_aff_x_u, step_aff_s_l, step_aff_s_u] =
167 project(1.0, -1.0, &*delta_aff.x, &*delta_aff.s);
168 let [step_cen_x_l, step_cen_x_u, step_cen_s_l, step_cen_s_u] =
169 project(1.0, -1.0, &*delta_cen.x, &*delta_cen.s);
170
171 let step_aff_z_l = delta_aff.z_l.clone();
174 let step_aff_z_u = delta_aff.z_u.clone();
175 let step_aff_v_l = delta_aff.v_l.clone();
176 let step_aff_v_u = delta_aff.v_u.clone();
177 let step_cen_z_l = delta_cen.z_l.clone();
178 let step_cen_z_u = delta_cen.z_u.clone();
179 let step_cen_v_l = delta_cen.v_l.clone();
180 let step_cen_v_u = delta_cen.v_u.clone();
181
182 drop(nlp_ref);
186
187 let grad_lag_x = cq_ref.curr_grad_lag_x();
191 let grad_lag_s = cq_ref.curr_grad_lag_s();
192 let c = cq_ref.curr_c();
193 let d_minus_s = cq_ref.curr_d_minus_s();
194 let dual_aggr = match self.norm_type {
195 NormType::OneNorm => grad_lag_x.asum() + grad_lag_s.asum(),
196 NormType::TwoNormSquared => {
197 let nx = grad_lag_x.nrm2();
198 let ns = grad_lag_s.nrm2();
199 nx * nx + ns * ns
200 }
201 NormType::TwoNorm => {
202 let nx = grad_lag_x.nrm2();
203 let ns = grad_lag_s.nrm2();
204 (nx * nx + ns * ns).sqrt()
205 }
206 NormType::MaxNorm => grad_lag_x.amax().max(grad_lag_s.amax()),
207 };
208 let primal_aggr = match self.norm_type {
209 NormType::OneNorm => c.asum() + d_minus_s.asum(),
210 NormType::TwoNormSquared => {
211 let nc = c.nrm2();
212 let nd = d_minus_s.nrm2();
213 nc * nc + nd * nd
214 }
215 NormType::TwoNorm => {
216 let nc = c.nrm2();
217 let nd = d_minus_s.nrm2();
218 (nc * nc + nd * nd).sqrt()
219 }
220 NormType::MaxNorm => c.amax().max(d_minus_s.amax()),
221 };
222
223 let n_dual = curr_iv.x.dim() + curr_iv.s.dim();
224 let n_pri = curr_iv.y_c.dim() + curr_iv.y_d.dim();
225 let n_comp = curr_iv.z_l.dim() + curr_iv.z_u.dim() + curr_iv.v_l.dim() + curr_iv.v_u.dim();
226 let tau = data.borrow().curr_tau;
227
228 let curr_z_l = curr_iv.z_l.clone();
229 let curr_z_u = curr_iv.z_u.clone();
230 let curr_v_l = curr_iv.v_l.clone();
231 let curr_v_u = curr_iv.v_u.clone();
232
233 drop(cq_ref);
234
235 let norm_type = self.norm_type;
236 let centrality = self.centrality_type;
237 let balancing = self.balancing_term;
238
239 let mut eval_q = |sigma: Number| -> Number {
243 let combine = |aff: &Rc<dyn Vector>, cen: &Rc<dyn Vector>| -> Box<dyn Vector> {
245 let mut out = aff.make_new();
246 out.set(0.0);
247 out.add_two_vectors(1.0, &**aff, sigma, &**cen, 0.0);
248 out
249 };
250 let stp_x_l = combine(&step_aff_x_l, &step_cen_x_l);
251 let stp_x_u = combine(&step_aff_x_u, &step_cen_x_u);
252 let stp_s_l = combine(&step_aff_s_l, &step_cen_s_l);
253 let stp_s_u = combine(&step_aff_s_u, &step_cen_s_u);
254 let stp_z_l = combine(&step_aff_z_l, &step_cen_z_l);
255 let stp_z_u = combine(&step_aff_z_u, &step_cen_z_u);
256 let stp_v_l = combine(&step_aff_v_l, &step_cen_v_l);
257 let stp_v_u = combine(&step_aff_v_u, &step_cen_v_u);
258
259 let alpha_pri = curr_slack_x_l
261 .frac_to_bound(&*stp_x_l, tau)
262 .min(curr_slack_x_u.frac_to_bound(&*stp_x_u, tau))
263 .min(curr_slack_s_l.frac_to_bound(&*stp_s_l, tau))
264 .min(curr_slack_s_u.frac_to_bound(&*stp_s_u, tau));
265 let alpha_du = curr_z_l
266 .frac_to_bound(&*stp_z_l, tau)
267 .min(curr_z_u.frac_to_bound(&*stp_z_u, tau))
268 .min(curr_v_l.frac_to_bound(&*stp_v_l, tau))
269 .min(curr_v_u.frac_to_bound(&*stp_v_u, tau));
270
271 let mut trial_s_x_l = curr_slack_x_l.make_new();
273 trial_s_x_l.set(0.0);
274 trial_s_x_l.add_two_vectors(1.0, &*curr_slack_x_l, alpha_pri, &*stp_x_l, 0.0);
275 let mut trial_s_x_u = curr_slack_x_u.make_new();
276 trial_s_x_u.set(0.0);
277 trial_s_x_u.add_two_vectors(1.0, &*curr_slack_x_u, alpha_pri, &*stp_x_u, 0.0);
278 let mut trial_s_s_l = curr_slack_s_l.make_new();
279 trial_s_s_l.set(0.0);
280 trial_s_s_l.add_two_vectors(1.0, &*curr_slack_s_l, alpha_pri, &*stp_s_l, 0.0);
281 let mut trial_s_s_u = curr_slack_s_u.make_new();
282 trial_s_s_u.set(0.0);
283 trial_s_s_u.add_two_vectors(1.0, &*curr_slack_s_u, alpha_pri, &*stp_s_u, 0.0);
284
285 let mut trial_z_l = curr_z_l.make_new();
286 trial_z_l.set(0.0);
287 trial_z_l.add_two_vectors(1.0, &*curr_z_l, alpha_du, &*stp_z_l, 0.0);
288 let mut trial_z_u = curr_z_u.make_new();
289 trial_z_u.set(0.0);
290 trial_z_u.add_two_vectors(1.0, &*curr_z_u, alpha_du, &*stp_z_u, 0.0);
291 let mut trial_v_l = curr_v_l.make_new();
292 trial_v_l.set(0.0);
293 trial_v_l.add_two_vectors(1.0, &*curr_v_l, alpha_du, &*stp_v_l, 0.0);
294 let mut trial_v_u = curr_v_u.make_new();
295 trial_v_u.set(0.0);
296 trial_v_u.add_two_vectors(1.0, &*curr_v_u, alpha_du, &*stp_v_u, 0.0);
297
298 trial_s_x_l.element_wise_multiply(&*trial_z_l);
300 trial_s_x_u.element_wise_multiply(&*trial_z_u);
301 trial_s_s_l.element_wise_multiply(&*trial_v_l);
302 trial_s_s_u.element_wise_multiply(&*trial_v_u);
303
304 let compl_aggr = match norm_type {
305 NormType::OneNorm => {
306 trial_s_x_l.asum()
307 + trial_s_x_u.asum()
308 + trial_s_s_l.asum()
309 + trial_s_s_u.asum()
310 }
311 NormType::TwoNormSquared => {
312 let a = trial_s_x_l.nrm2();
313 let b = trial_s_x_u.nrm2();
314 let c = trial_s_s_l.nrm2();
315 let d = trial_s_s_u.nrm2();
316 a * a + b * b + c * c + d * d
317 }
318 NormType::TwoNorm => {
319 let a = trial_s_x_l.nrm2();
320 let b = trial_s_x_u.nrm2();
321 let c = trial_s_s_l.nrm2();
322 let d = trial_s_s_u.nrm2();
323 (a * a + b * b + c * c + d * d).sqrt()
324 }
325 NormType::MaxNorm => trial_s_x_l
326 .amax()
327 .max(trial_s_x_u.amax())
328 .max(trial_s_s_l.amax())
329 .max(trial_s_s_u.amax()),
330 };
331
332 let xi = if matches!(centrality, CentralityType::None) {
333 1.0
334 } else {
335 let total = trial_s_x_l.asum()
339 + trial_s_x_u.asum()
340 + trial_s_s_l.asum()
341 + trial_s_s_u.asum();
342 let avg = if n_comp > 0 {
343 total / n_comp as Number
344 } else {
345 1.0
346 };
347 let mn = trial_s_x_l
348 .min()
349 .min(trial_s_x_u.min())
350 .min(trial_s_s_l.min())
351 .min(trial_s_s_u.min());
352 if avg > 0.0 {
353 mn / avg
354 } else {
355 1.0
356 }
357 };
358
359 let aggr = QualityFunctionAggregates {
360 dual_aggr,
361 primal_aggr,
362 compl_aggr,
363 n_dual,
364 n_pri,
365 n_comp,
366 };
367
368 if std::env::var("POUNCE_DBG_QF_AGGR").is_ok() {
369 tracing::debug!(target: "pounce::mu",
370 "[QF_AGGR] σ={:.6e} α_pri={:.6e} α_du={:.6e} xi={:.6e} dual_aggr={:.6e} primal_aggr={:.6e} compl_aggr={:.6e} n_dual={} n_pri={} n_comp={}",
371 sigma, alpha_pri, alpha_du, xi,
372 dual_aggr, primal_aggr, compl_aggr,
373 n_dual, n_pri, n_comp
374 );
375 }
376
377 evaluate_quality_function(
378 norm_type, centrality, balancing, alpha_pri, alpha_du, xi, aggr,
379 )
380 };
381
382 if let Ok(s) = std::env::var("POUNCE_DBG_QF_SWEEP") {
386 if let Ok(target_iter) = s.parse::<i32>() {
387 if data.borrow().iter_count == target_iter {
388 let lo = self.sigma_min.max(self.mu_min / avrg_compl);
389 let hi = self.sigma_max.min(self.mu_max / avrg_compl).max(lo * 10.0);
390 let log_lo = lo.ln();
391 let log_hi = hi.ln();
392 tracing::debug!(target: "pounce::mu", "[QF_SWEEP] iter={} avrg_compl={:.6e} σ_range=[{:.3e},{:.3e}] sigma_min={:.3e} sigma_max={:.3e} mu_min={:.3e} mu_max={:.3e}",
393 target_iter, avrg_compl, lo, hi,
394 self.sigma_min, self.sigma_max, self.mu_min, self.mu_max);
395 let n = 21;
396 for i in 0..n {
397 let frac = i as f64 / (n - 1) as f64;
398 let sig = (log_lo + frac * (log_hi - log_lo)).exp();
399 let q = eval_q(sig);
400 tracing::debug!(target: "pounce::mu", "[QF_SWEEP] σ={:.6e} q={:.10e}", sig, q);
401 }
402 let q1 = eval_q(1.0);
403 let s1m = 1.0 - self.section_sigma_tol.max(1e-4);
404 let q1m = eval_q(s1m);
405 tracing::debug!(target: "pounce::mu",
406 "[QF_SWEEP] σ=1.0 q={:.10e} σ={:.6e} q={:.10e} (q_1minus>q_1: {})",
407 q1,
408 s1m,
409 q1m,
410 q1m > q1
411 );
412 }
413 }
414 }
415
416 let sigma = pick_sigma(
417 self.sigma_min,
418 self.sigma_max,
419 self.mu_min,
420 self.mu_max,
421 avrg_compl,
422 self.section_sigma_tol,
423 self.section_qf_tol,
424 self.max_section_steps,
425 &mut eval_q,
426 );
427
428 let mu_new = sigma * avrg_compl;
429 let mu_clamped = mu_new.clamp(self.mu_min, self.mu_max);
430 if std::env::var("POUNCE_DBG_QF").is_ok() {
431 let iter_count = data.borrow().iter_count;
432 let curr_mu = data.borrow().curr_mu;
433 let sigma_floor = self.sigma_min.max(self.mu_min / avrg_compl);
434 let sigma_up_dn = sigma_floor
435 .max(1.0 - self.section_sigma_tol.max(1e-4))
436 .min(self.mu_max / avrg_compl);
437 tracing::debug!(target: "pounce::mu",
438 "[QF] iter={} curr_mu={:.3e} avrg_compl={:.3e} sigma={:.3e} mu_new={:.3e} mu_clamped={:.3e} | sigma_min={:.3e} mu_min={:.3e} sigma_lo_dn={:.3e} sigma_up_dn={:.3e} mu_min/avrg={:.3e}",
439 iter_count, curr_mu, avrg_compl, sigma, mu_new, mu_clamped,
440 self.sigma_min, self.mu_min, sigma_floor, sigma_up_dn,
441 self.mu_min / avrg_compl,
442 );
443 }
444 Some(mu_clamped)
445 }
446}
447
448impl MuOracle for QualityFunctionMuOracle {
449 fn calculate_mu(&mut self) -> Option<Number> {
450 None
455 }
456}
457
458pub fn golden_section(
472 sigma_lo_in: Number,
473 sigma_up_in: Number,
474 q_lo_in: Number,
475 q_up_in: Number,
476 sigma_tol: Number,
477 qf_tol: Number,
478 max_steps: i32,
479 mut q: impl FnMut(Number) -> Number,
480) -> Number {
481 let mut sigma_lo = sigma_lo_in;
482 let mut sigma_up = sigma_up_in;
483 let mut q_lo = q_lo_in;
484 let mut q_up = q_up_in;
485
486 let gfac = (3.0 - 5.0_f64.sqrt()) / 2.0;
487 let mut sigma_mid1 = sigma_lo + gfac * (sigma_up - sigma_lo);
488 let mut sigma_mid2 = sigma_lo + (1.0 - gfac) * (sigma_up - sigma_lo);
489 let mut qmid1 = q(sigma_mid1);
490 let mut qmid2 = q(sigma_mid2);
491
492 let mut nsections = 0;
493 let mut width_ok;
494 let mut qf_ok;
495 loop {
496 width_ok = (sigma_up - sigma_lo) >= sigma_tol * sigma_up;
497 let qmin = q_lo.min(q_up).min(qmid1).min(qmid2);
498 let qmax = q_lo.max(q_up).max(qmid1).max(qmid2);
499 qf_ok = qmax > 0.0 && (1.0 - qmin / qmax) >= qf_tol;
500 if !(width_ok && qf_ok && nsections < max_steps) {
501 break;
502 }
503 nsections += 1;
504 if qmid1 > qmid2 {
505 sigma_lo = sigma_mid1;
506 q_lo = qmid1;
507 sigma_mid1 = sigma_mid2;
508 qmid1 = qmid2;
509 sigma_mid2 = sigma_lo + (1.0 - gfac) * (sigma_up - sigma_lo);
510 qmid2 = q(sigma_mid2);
511 } else {
512 sigma_up = sigma_mid2;
513 q_up = qmid2;
514 sigma_mid2 = sigma_mid1;
515 qmid2 = qmid1;
516 sigma_mid1 = sigma_lo + gfac * (sigma_up - sigma_lo);
517 qmid1 = q(sigma_mid1);
518 }
519 }
520
521 if width_ok && !qf_ok {
545 if sigma_lo == sigma_lo_in && q_lo < 0.0 {
553 q_lo = q(sigma_lo);
554 }
555 if sigma_up == sigma_up_in && q_up < 0.0 {
556 q_up = q(sigma_up);
557 }
558 let mut best_s = sigma_lo;
559 let mut best_q = q_lo;
560 if q_up < best_q {
561 best_s = sigma_up;
562 best_q = q_up;
563 }
564 if qmid1 < best_q {
565 best_s = sigma_mid1;
566 best_q = qmid1;
567 }
568 if qmid2 < best_q {
569 best_s = sigma_mid2;
570 }
571 return best_s;
572 }
573 let (mut sigma, mut qval) = if qmid1 < qmid2 {
574 (sigma_mid1, qmid1)
575 } else {
576 (sigma_mid2, qmid2)
577 };
578 if sigma_up == sigma_up_in {
579 let qtmp = if q_up < 0.0 { q(sigma_up) } else { q_up };
580 if qtmp < qval {
581 sigma = sigma_up;
582 qval = qtmp;
583 }
584 } else if sigma_lo == sigma_lo_in {
585 let qtmp = if q_lo < 0.0 { q(sigma_lo) } else { q_lo };
586 if qtmp < qval {
587 sigma = sigma_lo;
588 }
589 }
590 let _ = qval;
591 sigma
592}
593
594#[derive(Debug, Clone, Copy)]
609pub struct QualityFunctionAggregates {
610 pub dual_aggr: Number,
611 pub primal_aggr: Number,
612 pub compl_aggr: Number,
613 pub n_dual: i32,
614 pub n_pri: i32,
615 pub n_comp: i32,
616}
617
618pub fn evaluate_quality_function(
627 norm: NormType,
628 centrality: CentralityType,
629 balancing: BalancingTermType,
630 alpha_primal: Number,
631 alpha_dual: Number,
632 xi: Number,
633 aggr: QualityFunctionAggregates,
634) -> Number {
635 let (mut dual_inf, mut primal_inf, mut compl_inf) = match norm {
636 NormType::OneNorm => {
637 let mut d = (1.0 - alpha_dual) * aggr.dual_aggr;
638 let mut p = (1.0 - alpha_primal) * aggr.primal_aggr;
639 let mut c = aggr.compl_aggr;
640 d /= aggr.n_dual as Number;
641 if aggr.n_pri > 0 {
642 p /= aggr.n_pri as Number;
643 }
644 debug_assert!(aggr.n_comp > 0);
645 c /= aggr.n_comp as Number;
646 (d, p, c)
647 }
648 NormType::TwoNormSquared => {
649 let mut d = (1.0 - alpha_dual).powi(2) * aggr.dual_aggr;
653 let mut p = (1.0 - alpha_primal).powi(2) * aggr.primal_aggr;
654 let mut c = aggr.compl_aggr;
655 d /= aggr.n_dual as Number;
656 if aggr.n_pri > 0 {
657 p /= aggr.n_pri as Number;
658 }
659 debug_assert!(aggr.n_comp > 0);
660 c /= aggr.n_comp as Number;
661 (d, p, c)
662 }
663 NormType::MaxNorm => (
664 (1.0 - alpha_dual) * aggr.dual_aggr,
665 (1.0 - alpha_primal) * aggr.primal_aggr,
666 aggr.compl_aggr,
667 ),
668 NormType::TwoNorm => {
669 let mut d = (1.0 - alpha_dual) * aggr.dual_aggr;
670 let mut p = (1.0 - alpha_primal) * aggr.primal_aggr;
671 let mut c = aggr.compl_aggr;
672 d /= (aggr.n_dual as Number).sqrt();
673 if aggr.n_pri > 0 {
674 p /= (aggr.n_pri as Number).sqrt();
675 }
676 debug_assert!(aggr.n_comp > 0);
677 c /= (aggr.n_comp as Number).sqrt();
678 (d, p, c)
679 }
680 };
681
682 if dual_inf.is_nan() {
684 dual_inf = 0.0;
685 }
686 if primal_inf.is_nan() {
687 primal_inf = 0.0;
688 }
689 if compl_inf.is_nan() {
690 compl_inf = 0.0;
691 }
692
693 let mut q = dual_inf + primal_inf + compl_inf;
694
695 match centrality {
696 CentralityType::None => {}
697 CentralityType::LogCenter => q -= compl_inf * xi.ln(),
698 CentralityType::ReciprocalCenter => q += compl_inf / xi,
699 CentralityType::CubedReciprocalCenter => q += compl_inf / xi.powi(3),
700 }
701
702 match balancing {
703 BalancingTermType::None => {}
704 BalancingTermType::CubicTerm => {
705 let dom = dual_inf.max(primal_inf) - compl_inf;
706 q += dom.max(0.0).powi(3);
707 }
708 }
709
710 q
711}
712
713#[allow(clippy::too_many_arguments)]
724pub fn pick_sigma(
725 sigma_min: Number,
726 sigma_max: Number,
727 mu_min: Number,
728 mu_max: Number,
729 avrg_compl: Number,
730 sigma_tol: Number,
731 qf_tol: Number,
732 max_steps: i32,
733 mut q: impl FnMut(Number) -> Number,
734) -> Number {
735 let qf_1 = q(1.0);
736 let sigma_1minus = 1.0 - sigma_tol.max(1e-4);
737 let qf_1minus = q(sigma_1minus);
738
739 if qf_1minus > qf_1 {
740 let sigma_up = sigma_max.min(mu_max / avrg_compl);
742 let sigma_lo = 1.0;
743 if sigma_lo >= sigma_up {
744 sigma_up
745 } else {
746 golden_section(
747 sigma_lo, sigma_up, qf_1, -100.0, sigma_tol, qf_tol, max_steps, q,
748 )
749 }
750 } else {
751 let sigma_lo = sigma_min.max(mu_min / avrg_compl);
753 let sigma_up = sigma_lo.max(sigma_1minus).min(mu_max / avrg_compl);
754 if sigma_lo >= sigma_up {
755 sigma_lo
756 } else {
757 golden_section(
758 sigma_lo, sigma_up, -100.0, qf_1minus, sigma_tol, qf_tol, max_steps, q,
759 )
760 }
761 }
762}
763
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds aggregates from positional values, in field order.
    fn aggr(
        d: Number,
        p: Number,
        c: Number,
        nd: i32,
        np: i32,
        nc: i32,
    ) -> QualityFunctionAggregates {
        QualityFunctionAggregates {
            dual_aggr: d,
            primal_aggr: p,
            compl_aggr: c,
            n_dual: nd,
            n_pri: np,
            n_comp: nc,
        }
    }

    /// Max-norm quality function with zero primal/dual step lengths, the
    /// setting shared by most term-by-term checks below.
    fn max_norm_q(
        centrality: CentralityType,
        balancing: BalancingTermType,
        xi: Number,
        a: QualityFunctionAggregates,
    ) -> Number {
        evaluate_quality_function(NormType::MaxNorm, centrality, balancing, 0.0, 0.0, xi, a)
    }

    #[test]
    fn golden_section_minimizes_parabola() {
        let parabola = |s: f64| (s - 0.3).powi(2);
        let sigma = golden_section(0.0, 1.0, parabola(0.0), parabola(1.0), 1e-6, 0.0, 50, parabola);
        assert!((sigma - 0.3).abs() < 1e-3);
    }

    #[test]
    fn golden_section_respects_max_steps() {
        let parabola = |s: f64| (s - 0.5).powi(2);
        let sigma = golden_section(0.0, 1.0, parabola(0.0), parabola(1.0), 1e-12, 0.0, 5, parabola);
        assert!((sigma - 0.5).abs() < 0.2);
    }

    #[test]
    fn golden_section_handles_monotone() {
        let identity = |s: f64| s;
        let sigma = golden_section(0.1, 2.0, 0.1, 2.0, 1e-6, 0.0, 50, identity);
        assert!(sigma < 0.2, "got s = {}", sigma);
    }

    #[test]
    fn golden_section_never_returns_unevaluated_sentinel() {
        // The upper end is passed as the -100 sentinel although its true
        // value (50) is the largest in the bracket.
        let lower = 1.0_f64;
        let upper = 3.0_f64;
        let q = move |s: f64| if s == upper { 50.0 } else { -s };
        let sigma = golden_section(lower, upper, q(lower), -100.0, 1e-3, 0.0, 50, q);
        assert!(
            sigma < upper,
            "golden_section returned the unevaluated sentinel endpoint σ = {} \
             (true q there = {}, the bracket maximum); it must re-evaluate the \
             sentinel before selecting a minimum",
            sigma,
            q(sigma)
        );
    }

    #[test]
    fn calculate_mu_returns_none_until_plumbed() {
        let mut oracle = QualityFunctionMuOracle::new();
        assert!(oracle.calculate_mu().is_none());
    }

    #[test]
    fn evaluate_one_norm_averages_by_n() {
        // dual 0.75*8/4 = 1.5, primal 0.5*4/2 = 1, compl 6/3 = 2.
        let q = evaluate_quality_function(
            NormType::OneNorm,
            CentralityType::None,
            BalancingTermType::None,
            0.5,
            0.25,
            1.0,
            aggr(8.0, 4.0, 6.0, 4, 2, 3),
        );
        let expected = 4.5;
        assert!((q - expected).abs() < 1e-12, "got {}", q);
    }

    #[test]
    fn evaluate_max_norm_does_not_divide() {
        let q = max_norm_q(
            CentralityType::None,
            BalancingTermType::None,
            1.0,
            aggr(2.0, 3.0, 5.0, 10, 10, 10),
        );
        assert!((q - 10.0).abs() < 1e-12);
    }

    #[test]
    fn evaluate_two_norm_divides_by_sqrt_n() {
        // dual 2/sqrt(4) = 1, primal 0 (n_pri = 0 skips scaling), compl 4/sqrt(16) = 1.
        let q = evaluate_quality_function(
            NormType::TwoNorm,
            CentralityType::None,
            BalancingTermType::None,
            0.0,
            0.0,
            1.0,
            aggr(2.0, 0.0, 4.0, 4, 0, 16),
        );
        let expected = 2.0;
        assert!((q - expected).abs() < 1e-12, "got {}", q);
    }

    #[test]
    fn evaluate_one_norm_handles_zero_pri_dim() {
        let q = evaluate_quality_function(
            NormType::OneNorm,
            CentralityType::None,
            BalancingTermType::None,
            0.0,
            0.0,
            1.0,
            aggr(0.0, 0.0, 1.0, 1, 0, 1),
        );
        assert!(q.is_finite() && (q - 1.0).abs() < 1e-12);
    }

    #[test]
    fn evaluate_log_centrality_subtracts_compl_log_xi() {
        // With xi = e, the log-centrality term removes exactly compl = 4.
        let xi = std::f64::consts::E;
        let base = max_norm_q(
            CentralityType::None,
            BalancingTermType::None,
            xi,
            aggr(0.0, 0.0, 4.0, 1, 1, 1),
        );
        let logc = max_norm_q(
            CentralityType::LogCenter,
            BalancingTermType::None,
            xi,
            aggr(0.0, 0.0, 4.0, 1, 1, 1),
        );
        assert!((base - logc - 4.0).abs() < 1e-12, "base={base} logc={logc}");
    }

    #[test]
    fn evaluate_reciprocal_centrality_adds_c_over_xi() {
        // 1 + 1/0.5 = 3.
        let q = max_norm_q(
            CentralityType::ReciprocalCenter,
            BalancingTermType::None,
            0.5,
            aggr(0.0, 0.0, 1.0, 1, 1, 1),
        );
        assert!((q - 3.0).abs() < 1e-12);
    }

    #[test]
    fn evaluate_cubed_reciprocal_centrality_adds_c_over_xi3() {
        // 1 + 1/0.125 = 9.
        let q = max_norm_q(
            CentralityType::CubedReciprocalCenter,
            BalancingTermType::None,
            0.5,
            aggr(0.0, 0.0, 1.0, 1, 1, 1),
        );
        assert!((q - 9.0).abs() < 1e-12);
    }

    #[test]
    fn evaluate_cubic_balancing_adds_when_dual_dominates() {
        // 5 + 1 + 2 = 8, plus (5 - 2)^3 = 27.
        let q = max_norm_q(
            CentralityType::None,
            BalancingTermType::CubicTerm,
            1.0,
            aggr(5.0, 1.0, 2.0, 1, 1, 1),
        );
        let expected = 35.0;
        assert!((q - expected).abs() < 1e-12, "got {}", q);
    }

    #[test]
    fn evaluate_cubic_balancing_zero_when_compl_dominates() {
        let q = max_norm_q(
            CentralityType::None,
            BalancingTermType::CubicTerm,
            1.0,
            aggr(1.0, 1.0, 5.0, 1, 1, 1),
        );
        assert!((q - 7.0).abs() < 1e-12);
    }

    #[test]
    fn pick_sigma_searches_below_one_for_decreasing_q() {
        let parabola = |s: f64| (s - 0.4).powi(2);
        let sigma = pick_sigma(1e-9, 100.0, 1e-11, 1e5, 1.0, 1e-6, 0.0, 50, parabola);
        assert!((sigma - 0.4).abs() < 1e-2, "got s = {}", sigma);
    }

    #[test]
    fn pick_sigma_searches_above_one_for_q_decreasing_in_sigma() {
        let descending = |s: f64| -s;
        let sigma = pick_sigma(1e-9, 10.0, 1e-11, 1e5, 1.0, 1e-6, 0.0, 50, descending);
        assert!(sigma > 5.0, "got s = {}", sigma);
    }

    #[test]
    fn pick_sigma_clamps_to_mu_max_over_avrg_in_up_search() {
        let descending = |s: f64| -s;
        let sigma = pick_sigma(1e-9, 100.0, 1e-11, 2.0, 1.0, 1e-6, 0.0, 50, descending);
        assert!(sigma <= 2.0 + 1e-9 && sigma >= 1.0, "got s = {}", sigma);
    }

    #[test]
    fn pick_sigma_clamps_to_mu_min_over_avrg_in_down_search() {
        let ascending = |s: f64| s;
        let sigma = pick_sigma(1e-9, 100.0, 0.5, 1e5, 1.0, 1e-6, 0.0, 50, ascending);
        assert!(sigma >= 0.5 - 1e-9 && sigma <= 1.0, "got s = {}", sigma);
    }
}