1use crate::ipopt_cq::IpoptCqHandle;
24use crate::ipopt_data::IpoptDataHandle;
25use crate::ipopt_nlp::IpoptNlp;
26use crate::iterates_vector::IteratesVector;
27use crate::kkt::pd_search_dir_calc::PdSearchDirCalc;
28use crate::mu::oracle::r#trait::MuOracle;
29use pounce_common::types::Number;
30use pounce_linalg::Vector;
31use std::cell::RefCell;
32use std::rc::Rc;
33
/// Norm used to aggregate the dual infeasibility, primal infeasibility and
/// complementarity vectors that make up the quality function.
///
/// [`evaluate_quality_function`] normalises each aggregate by the number of
/// components `n` it was built from, as noted on each variant.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NormType {
    /// Sum of absolute values (`asum`); divided by `n`.
    OneNorm,
    /// Squared Euclidean norm (sum of squared `nrm2`); divided by `n`.
    TwoNormSquared,
    /// Euclidean norm; divided by `sqrt(n)`.
    TwoNorm,
    /// Largest absolute value (`amax`); not normalised.
    MaxNorm,
}
43
/// Optional centrality penalty added to the quality function.
///
/// It uses the centrality measure `xi` (smallest trial complementarity product
/// divided by their average) and the complementarity part `compl` of `q`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CentralityType {
    /// No centrality term (and `xi` is not computed).
    None,
    /// Adds `-compl * ln(xi)`.
    LogCenter,
    /// Adds `compl / xi`.
    ReciprocalCenter,
    /// Adds `compl / xi^3`.
    CubedReciprocalCenter,
}
51
/// Optional term that penalises infeasibility outweighing complementarity in
/// the quality function.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BalancingTermType {
    /// No balancing term.
    None,
    /// Adds `max(0, max(dual, primal) - compl)^3`.
    CubicTerm,
}
57
58pub struct QualityFunctionMuOracle {
59 pub norm_type: NormType,
60 pub centrality_type: CentralityType,
61 pub balancing_term: BalancingTermType,
62 pub max_section_steps: i32,
63 pub section_sigma_tol: Number,
64 pub section_qf_tol: Number,
65 pub sigma_max: Number,
66 pub sigma_min: Number,
67 pub mu_min: Number,
68 pub mu_max: Number,
69}
70
71impl Default for QualityFunctionMuOracle {
72 fn default() -> Self {
73 Self {
75 norm_type: NormType::TwoNormSquared,
76 centrality_type: CentralityType::None,
77 balancing_term: BalancingTermType::None,
78 max_section_steps: 8,
79 section_sigma_tol: 1e-2,
80 section_qf_tol: 0.0,
81 sigma_max: 100.0,
82 sigma_min: 1e-6,
91 mu_min: 1e-11,
92 mu_max: 1e5,
93 }
94 }
95}
96
impl QualityFunctionMuOracle {
    /// Creates an oracle with the settings from its `Default` impl.
    pub fn new() -> Self {
        Self::default()
    }

    /// Computes a new `mu = sigma * avrg_compl` by minimising a quality
    /// function over the centering parameter `sigma`.
    ///
    /// Outline:
    /// 1. `pd_search_dir` computes the affine (predictor) and centering steps.
    ///    They are then read back from `data.delta_aff` / `data.delta_cen`.
    /// 2. The x / s parts of both steps are mapped into the four bound-slack
    ///    spaces. The bound-multiplier parts are taken as-is, so that a trial
    ///    step for a given sigma is `aff + sigma * cen`.
    /// 3. The sigma-independent dual (`grad_lag_x`, `grad_lag_s`) and primal
    ///    (`c`, `d - s`) aggregates are computed once in the configured norm.
    /// 4. `eval_q(sigma)` does the following:
    ///    - builds the combined step and takes the primal/dual step lengths
    ///      from `frac_to_bound`;
    ///    - forms the trial complementarity products and aggregates them;
    ///    - calls [`evaluate_quality_function`].
    /// 5. [`pick_sigma`] chooses sigma. The resulting mu is clamped to
    ///    `[mu_min, mu_max]`.
    ///
    /// Returns `None` if either step computation reports failure, or if
    /// `delta_aff` / `delta_cen` are absent afterwards.
    ///
    /// Debug environment variables, logged via `tracing::debug!` with target
    /// `pounce::mu`:
    /// * `POUNCE_DBG_QF_AGGR`: log the aggregates on every `eval_q` call.
    /// * `POUNCE_DBG_QF_SWEEP=<iter>`: at that iteration, log `q` on a 21-point
    ///   log-spaced sigma grid plus sigma = 1 and 1-minus. These are extra
    ///   quality-function evaluations.
    /// * `POUNCE_DBG_QF`: log the chosen sigma and mu.
    #[allow(clippy::too_many_lines)]
    pub fn calculate_mu_with_predictor_centering(
        &mut self,
        data: &IpoptDataHandle,
        cq: &IpoptCqHandle,
        nlp: &Rc<RefCell<dyn IpoptNlp>>,
        pd_search_dir: &mut PdSearchDirCalc,
    ) -> Option<Number> {
        // Both steps are stored into `data` by the search-direction calculator.
        if !pd_search_dir.compute_affine_step(data, cq, nlp) {
            return None;
        }
        if !pd_search_dir.compute_centering_step(data, cq, nlp) {
            return None;
        }

        let delta_aff: IteratesVector = data.borrow().delta_aff.clone()?;
        let delta_cen: IteratesVector = data.borrow().delta_cen.clone()?;

        let nlp_ref = nlp.borrow();
        let cq_ref = cq.borrow();
        let curr_iv = cq_ref.curr_iv();
        let curr_slack_x_l = cq_ref.curr_slack_x_l();
        let curr_slack_x_u = cq_ref.curr_slack_x_u();
        let curr_slack_s_l = cq_ref.curr_slack_s_l();
        let curr_slack_s_u = cq_ref.curr_slack_s_u();
        let avrg_compl = cq_ref.curr_avrg_compl();

        // Maps a step in x / s into the four slack spaces with
        // `trans_mult_vector` on `px_l`, `px_u`, `pd_l`, `pd_u`. Lower-bound
        // slacks get `sign_l_x`, upper-bound slacks `sign_u_x`; the calls
        // below pass +1 / -1.
        let project = |sign_l_x: Number,
                       sign_u_x: Number,
                       step_x: &dyn Vector,
                       step_s: &dyn Vector|
         -> [Rc<dyn Vector>; 4] {
            let mut x_l = curr_slack_x_l.make_new();
            nlp_ref
                .px_l()
                .trans_mult_vector(sign_l_x, step_x, 0.0, &mut *x_l);
            let mut x_u = curr_slack_x_u.make_new();
            nlp_ref
                .px_u()
                .trans_mult_vector(sign_u_x, step_x, 0.0, &mut *x_u);
            let mut s_l = curr_slack_s_l.make_new();
            nlp_ref
                .pd_l()
                .trans_mult_vector(sign_l_x, step_s, 0.0, &mut *s_l);
            let mut s_u = curr_slack_s_u.make_new();
            nlp_ref
                .pd_u()
                .trans_mult_vector(sign_u_x, step_s, 0.0, &mut *s_u);
            [Rc::from(x_l), Rc::from(x_u), Rc::from(s_l), Rc::from(s_u)]
        };

        let [step_aff_x_l, step_aff_x_u, step_aff_s_l, step_aff_s_u] =
            project(1.0, -1.0, &*delta_aff.x, &*delta_aff.s);
        let [step_cen_x_l, step_cen_x_u, step_cen_s_l, step_cen_s_u] =
            project(1.0, -1.0, &*delta_cen.x, &*delta_cen.s);

        // Bound-multiplier components of both steps; used without projection.
        let step_aff_z_l = delta_aff.z_l.clone();
        let step_aff_z_u = delta_aff.z_u.clone();
        let step_aff_v_l = delta_aff.v_l.clone();
        let step_aff_v_u = delta_aff.v_u.clone();
        let step_cen_z_l = delta_cen.z_l.clone();
        let step_cen_z_u = delta_cen.z_u.clone();
        let step_cen_v_l = delta_cen.v_l.clone();
        let step_cen_v_u = delta_cen.v_u.clone();

        // The NLP borrow is only needed for the projections above.
        drop(nlp_ref);

        // Sigma-independent aggregates: dual infeasibility from the Lagrangian
        // gradients, primal infeasibility from c and d - s.
        let grad_lag_x = cq_ref.curr_grad_lag_x();
        let grad_lag_s = cq_ref.curr_grad_lag_s();
        let c = cq_ref.curr_c();
        let d_minus_s = cq_ref.curr_d_minus_s();
        let dual_aggr = match self.norm_type {
            NormType::OneNorm => grad_lag_x.asum() + grad_lag_s.asum(),
            NormType::TwoNormSquared => {
                let nx = grad_lag_x.nrm2();
                let ns = grad_lag_s.nrm2();
                nx * nx + ns * ns
            }
            NormType::TwoNorm => {
                let nx = grad_lag_x.nrm2();
                let ns = grad_lag_s.nrm2();
                (nx * nx + ns * ns).sqrt()
            }
            NormType::MaxNorm => grad_lag_x.amax().max(grad_lag_s.amax()),
        };
        let primal_aggr = match self.norm_type {
            NormType::OneNorm => c.asum() + d_minus_s.asum(),
            NormType::TwoNormSquared => {
                let nc = c.nrm2();
                let nd = d_minus_s.nrm2();
                nc * nc + nd * nd
            }
            NormType::TwoNorm => {
                let nc = c.nrm2();
                let nd = d_minus_s.nrm2();
                (nc * nc + nd * nd).sqrt()
            }
            NormType::MaxNorm => c.amax().max(d_minus_s.amax()),
        };

        // Component counts used by `evaluate_quality_function` for averaging.
        let n_dual = curr_iv.x.dim() + curr_iv.s.dim();
        let n_pri = curr_iv.y_c.dim() + curr_iv.y_d.dim();
        let n_comp = curr_iv.z_l.dim() + curr_iv.z_u.dim() + curr_iv.v_l.dim() + curr_iv.v_u.dim();
        let tau = data.borrow().curr_tau;

        let curr_z_l = curr_iv.z_l.clone();
        let curr_z_u = curr_iv.z_u.clone();
        let curr_v_l = curr_iv.v_l.clone();
        let curr_v_u = curr_iv.v_u.clone();

        // Release the `cq` borrow; everything `eval_q` needs is held in locals.
        drop(cq_ref);

        let norm_type = self.norm_type;
        let centrality = self.centrality_type;
        let balancing = self.balancing_term;

        // Quality function q(sigma) for the trial step `aff + sigma * cen`.
        let mut eval_q = |sigma: Number| -> Number {
            // aff + sigma * cen, in a fresh vector created from `aff`.
            let combine = |aff: &Rc<dyn Vector>, cen: &Rc<dyn Vector>| -> Box<dyn Vector> {
                let mut out = aff.make_new();
                out.set(0.0);
                out.add_two_vectors(1.0, &**aff, sigma, &**cen, 0.0);
                out
            };
            let stp_x_l = combine(&step_aff_x_l, &step_cen_x_l);
            let stp_x_u = combine(&step_aff_x_u, &step_cen_x_u);
            let stp_s_l = combine(&step_aff_s_l, &step_cen_s_l);
            let stp_s_u = combine(&step_aff_s_u, &step_cen_s_u);
            let stp_z_l = combine(&step_aff_z_l, &step_cen_z_l);
            let stp_z_u = combine(&step_aff_z_u, &step_cen_z_u);
            let stp_v_l = combine(&step_aff_v_l, &step_cen_v_l);
            let stp_v_u = combine(&step_aff_v_u, &step_cen_v_u);

            // Separate step lengths for the primal slacks and the dual bound
            // multipliers: the minimum of `frac_to_bound(.., tau)` over each group.
            let alpha_pri = curr_slack_x_l
                .frac_to_bound(&*stp_x_l, tau)
                .min(curr_slack_x_u.frac_to_bound(&*stp_x_u, tau))
                .min(curr_slack_s_l.frac_to_bound(&*stp_s_l, tau))
                .min(curr_slack_s_u.frac_to_bound(&*stp_s_u, tau));
            let alpha_du = curr_z_l
                .frac_to_bound(&*stp_z_l, tau)
                .min(curr_z_u.frac_to_bound(&*stp_z_u, tau))
                .min(curr_v_l.frac_to_bound(&*stp_v_l, tau))
                .min(curr_v_u.frac_to_bound(&*stp_v_u, tau));

            // Trial slacks: current + alpha_pri * step.
            let mut trial_s_x_l = curr_slack_x_l.make_new();
            trial_s_x_l.set(0.0);
            trial_s_x_l.add_two_vectors(1.0, &*curr_slack_x_l, alpha_pri, &*stp_x_l, 0.0);
            let mut trial_s_x_u = curr_slack_x_u.make_new();
            trial_s_x_u.set(0.0);
            trial_s_x_u.add_two_vectors(1.0, &*curr_slack_x_u, alpha_pri, &*stp_x_u, 0.0);
            let mut trial_s_s_l = curr_slack_s_l.make_new();
            trial_s_s_l.set(0.0);
            trial_s_s_l.add_two_vectors(1.0, &*curr_slack_s_l, alpha_pri, &*stp_s_l, 0.0);
            let mut trial_s_s_u = curr_slack_s_u.make_new();
            trial_s_s_u.set(0.0);
            trial_s_s_u.add_two_vectors(1.0, &*curr_slack_s_u, alpha_pri, &*stp_s_u, 0.0);

            // Trial bound multipliers: current + alpha_du * step.
            let mut trial_z_l = curr_z_l.make_new();
            trial_z_l.set(0.0);
            trial_z_l.add_two_vectors(1.0, &*curr_z_l, alpha_du, &*stp_z_l, 0.0);
            let mut trial_z_u = curr_z_u.make_new();
            trial_z_u.set(0.0);
            trial_z_u.add_two_vectors(1.0, &*curr_z_u, alpha_du, &*stp_z_u, 0.0);
            let mut trial_v_l = curr_v_l.make_new();
            trial_v_l.set(0.0);
            trial_v_l.add_two_vectors(1.0, &*curr_v_l, alpha_du, &*stp_v_l, 0.0);
            let mut trial_v_u = curr_v_u.make_new();
            trial_v_u.set(0.0);
            trial_v_u.add_two_vectors(1.0, &*curr_v_u, alpha_du, &*stp_v_u, 0.0);

            // In place: each trial_s_* now holds slack .* multiplier, i.e. the
            // trial complementarity products.
            trial_s_x_l.element_wise_multiply(&*trial_z_l);
            trial_s_x_u.element_wise_multiply(&*trial_z_u);
            trial_s_s_l.element_wise_multiply(&*trial_v_l);
            trial_s_s_u.element_wise_multiply(&*trial_v_u);

            let compl_aggr = match norm_type {
                NormType::OneNorm => {
                    trial_s_x_l.asum()
                        + trial_s_x_u.asum()
                        + trial_s_s_l.asum()
                        + trial_s_s_u.asum()
                }
                NormType::TwoNormSquared => {
                    let a = trial_s_x_l.nrm2();
                    let b = trial_s_x_u.nrm2();
                    let c = trial_s_s_l.nrm2();
                    let d = trial_s_s_u.nrm2();
                    a * a + b * b + c * c + d * d
                }
                NormType::TwoNorm => {
                    let a = trial_s_x_l.nrm2();
                    let b = trial_s_x_u.nrm2();
                    let c = trial_s_s_l.nrm2();
                    let d = trial_s_s_u.nrm2();
                    (a * a + b * b + c * c + d * d).sqrt()
                }
                NormType::MaxNorm => trial_s_x_l
                    .amax()
                    .max(trial_s_x_u.amax())
                    .max(trial_s_s_l.amax())
                    .max(trial_s_s_u.amax()),
            };

            // Centrality measure xi = min(product) / mean(product), computed
            // only when a centrality term is configured (1.0 otherwise or when
            // the mean is not positive).
            // NOTE(review): the minimum spans all four product vectors; confirm
            // what `Vector::min()` returns for a zero-length vector so an empty
            // bound class cannot skew xi.
            let xi = if matches!(centrality, CentralityType::None) {
                1.0
            } else {
                let total = trial_s_x_l.asum()
                    + trial_s_x_u.asum()
                    + trial_s_s_l.asum()
                    + trial_s_s_u.asum();
                let avg = if n_comp > 0 {
                    total / n_comp as Number
                } else {
                    1.0
                };
                let mn = trial_s_x_l
                    .min()
                    .min(trial_s_x_u.min())
                    .min(trial_s_s_l.min())
                    .min(trial_s_s_u.min());
                if avg > 0.0 {
                    mn / avg
                } else {
                    1.0
                }
            };

            let aggr = QualityFunctionAggregates {
                dual_aggr,
                primal_aggr,
                compl_aggr,
                n_dual,
                n_pri,
                n_comp,
            };

            if std::env::var("POUNCE_DBG_QF_AGGR").is_ok() {
                tracing::debug!(target: "pounce::mu",
                    "[QF_AGGR] σ={:.6e} α_pri={:.6e} α_du={:.6e} xi={:.6e} dual_aggr={:.6e} primal_aggr={:.6e} compl_aggr={:.6e} n_dual={} n_pri={} n_comp={}",
                    sigma, alpha_pri, alpha_du, xi,
                    dual_aggr, primal_aggr, compl_aggr,
                    n_dual, n_pri, n_comp
                );
            }

            evaluate_quality_function(
                norm_type, centrality, balancing, alpha_pri, alpha_du, xi, aggr,
            )
        };

        // Optional diagnostic sweep of q over the sigma range at one iteration.
        if let Ok(s) = std::env::var("POUNCE_DBG_QF_SWEEP") {
            if let Ok(target_iter) = s.parse::<i32>() {
                if data.borrow().iter_count == target_iter {
                    let lo = self.sigma_min.max(self.mu_min / avrg_compl);
                    let hi = self.sigma_max.min(self.mu_max / avrg_compl).max(lo * 10.0);
                    let log_lo = lo.ln();
                    let log_hi = hi.ln();
                    tracing::debug!(target: "pounce::mu", "[QF_SWEEP] iter={} avrg_compl={:.6e} σ_range=[{:.3e},{:.3e}] sigma_min={:.3e} sigma_max={:.3e} mu_min={:.3e} mu_max={:.3e}",
                        target_iter, avrg_compl, lo, hi,
                        self.sigma_min, self.sigma_max, self.mu_min, self.mu_max);
                    let n = 21;
                    for i in 0..n {
                        let frac = i as f64 / (n - 1) as f64;
                        let sig = (log_lo + frac * (log_hi - log_lo)).exp();
                        let q = eval_q(sig);
                        tracing::debug!(target: "pounce::mu", "[QF_SWEEP] σ={:.6e} q={:.10e}", sig, q);
                    }
                    let q1 = eval_q(1.0);
                    let s1m = 1.0 - self.section_sigma_tol.max(1e-4);
                    let q1m = eval_q(s1m);
                    tracing::debug!(target: "pounce::mu",
                        "[QF_SWEEP] σ=1.0 q={:.10e} σ={:.6e} q={:.10e} (q_1minus>q_1: {})",
                        q1,
                        s1m,
                        q1m,
                        q1m > q1
                    );
                }
            }
        }

        let sigma = pick_sigma(
            self.sigma_min,
            self.sigma_max,
            self.mu_min,
            self.mu_max,
            avrg_compl,
            self.section_sigma_tol,
            self.section_qf_tol,
            self.max_section_steps,
            &mut eval_q,
        );

        let mu_new = sigma * avrg_compl;
        let mu_clamped = mu_new.clamp(self.mu_min, self.mu_max);
        if std::env::var("POUNCE_DBG_QF").is_ok() {
            let iter_count = data.borrow().iter_count;
            let curr_mu = data.borrow().curr_mu;
            // Recomputes the down-search interval used by `pick_sigma` for logging.
            let sigma_floor = self.sigma_min.max(self.mu_min / avrg_compl);
            let sigma_up_dn = sigma_floor
                .max(1.0 - self.section_sigma_tol.max(1e-4))
                .min(self.mu_max / avrg_compl);
            tracing::debug!(target: "pounce::mu",
                "[QF] iter={} curr_mu={:.3e} avrg_compl={:.3e} sigma={:.3e} mu_new={:.3e} mu_clamped={:.3e} | sigma_min={:.3e} mu_min={:.3e} sigma_lo_dn={:.3e} sigma_up_dn={:.3e} mu_min/avrg={:.3e}",
                iter_count, curr_mu, avrg_compl, sigma, mu_new, mu_clamped,
                self.sigma_min, self.mu_min, sigma_floor, sigma_up_dn,
                self.mu_min / avrg_compl,
            );
        }
        Some(mu_clamped)
    }
}
447
448impl MuOracle for QualityFunctionMuOracle {
449 fn calculate_mu(&mut self) -> Option<Number> {
450 None
455 }
456}
457
458pub fn golden_section(
472 sigma_lo_in: Number,
473 sigma_up_in: Number,
474 q_lo_in: Number,
475 q_up_in: Number,
476 sigma_tol: Number,
477 qf_tol: Number,
478 max_steps: i32,
479 mut q: impl FnMut(Number) -> Number,
480) -> Number {
481 let mut sigma_lo = sigma_lo_in;
482 let mut sigma_up = sigma_up_in;
483 let mut q_lo = q_lo_in;
484 let mut q_up = q_up_in;
485
486 let gfac = (3.0 - 5.0_f64.sqrt()) / 2.0;
487 let mut sigma_mid1 = sigma_lo + gfac * (sigma_up - sigma_lo);
488 let mut sigma_mid2 = sigma_lo + (1.0 - gfac) * (sigma_up - sigma_lo);
489 let mut qmid1 = q(sigma_mid1);
490 let mut qmid2 = q(sigma_mid2);
491
492 let mut nsections = 0;
493 let mut width_ok;
494 let mut qf_ok;
495 loop {
496 width_ok = (sigma_up - sigma_lo) >= sigma_tol * sigma_up;
497 let qmin = q_lo.min(q_up).min(qmid1).min(qmid2);
498 let qmax = q_lo.max(q_up).max(qmid1).max(qmid2);
499 qf_ok = qmax > 0.0 && (1.0 - qmin / qmax) >= qf_tol;
500 if !(width_ok && qf_ok && nsections < max_steps) {
501 break;
502 }
503 nsections += 1;
504 if qmid1 > qmid2 {
505 sigma_lo = sigma_mid1;
506 q_lo = qmid1;
507 sigma_mid1 = sigma_mid2;
508 qmid1 = qmid2;
509 sigma_mid2 = sigma_lo + (1.0 - gfac) * (sigma_up - sigma_lo);
510 qmid2 = q(sigma_mid2);
511 } else {
512 sigma_up = sigma_mid2;
513 q_up = qmid2;
514 sigma_mid2 = sigma_mid1;
515 qmid2 = qmid1;
516 sigma_mid1 = sigma_lo + gfac * (sigma_up - sigma_lo);
517 qmid1 = q(sigma_mid1);
518 }
519 }
520
521 if width_ok && !qf_ok {
541 let mut best_s = sigma_lo;
542 let mut best_q = q_lo;
543 if q_up < best_q {
544 best_s = sigma_up;
545 best_q = q_up;
546 }
547 if qmid1 < best_q {
548 best_s = sigma_mid1;
549 best_q = qmid1;
550 }
551 if qmid2 < best_q {
552 best_s = sigma_mid2;
553 }
554 return best_s;
555 }
556 let (mut sigma, mut qval) = if qmid1 < qmid2 {
557 (sigma_mid1, qmid1)
558 } else {
559 (sigma_mid2, qmid2)
560 };
561 if sigma_up == sigma_up_in {
562 let qtmp = if q_up < 0.0 { q(sigma_up) } else { q_up };
563 if qtmp < qval {
564 sigma = sigma_up;
565 qval = qtmp;
566 }
567 } else if sigma_lo == sigma_lo_in {
568 let qtmp = if q_lo < 0.0 { q(sigma_lo) } else { q_lo };
569 if qtmp < qval {
570 sigma = sigma_lo;
571 }
572 }
573 let _ = qval;
574 sigma
575}
576
577#[derive(Debug, Clone, Copy)]
592pub struct QualityFunctionAggregates {
593 pub dual_aggr: Number,
594 pub primal_aggr: Number,
595 pub compl_aggr: Number,
596 pub n_dual: i32,
597 pub n_pri: i32,
598 pub n_comp: i32,
599}
600
601pub fn evaluate_quality_function(
610 norm: NormType,
611 centrality: CentralityType,
612 balancing: BalancingTermType,
613 alpha_primal: Number,
614 alpha_dual: Number,
615 xi: Number,
616 aggr: QualityFunctionAggregates,
617) -> Number {
618 let (mut dual_inf, mut primal_inf, mut compl_inf) = match norm {
619 NormType::OneNorm => {
620 let mut d = (1.0 - alpha_dual) * aggr.dual_aggr;
621 let mut p = (1.0 - alpha_primal) * aggr.primal_aggr;
622 let mut c = aggr.compl_aggr;
623 d /= aggr.n_dual as Number;
624 if aggr.n_pri > 0 {
625 p /= aggr.n_pri as Number;
626 }
627 debug_assert!(aggr.n_comp > 0);
628 c /= aggr.n_comp as Number;
629 (d, p, c)
630 }
631 NormType::TwoNormSquared => {
632 let mut d = (1.0 - alpha_dual).powi(2) * aggr.dual_aggr;
636 let mut p = (1.0 - alpha_primal).powi(2) * aggr.primal_aggr;
637 let mut c = aggr.compl_aggr;
638 d /= aggr.n_dual as Number;
639 if aggr.n_pri > 0 {
640 p /= aggr.n_pri as Number;
641 }
642 debug_assert!(aggr.n_comp > 0);
643 c /= aggr.n_comp as Number;
644 (d, p, c)
645 }
646 NormType::MaxNorm => (
647 (1.0 - alpha_dual) * aggr.dual_aggr,
648 (1.0 - alpha_primal) * aggr.primal_aggr,
649 aggr.compl_aggr,
650 ),
651 NormType::TwoNorm => {
652 let mut d = (1.0 - alpha_dual) * aggr.dual_aggr;
653 let mut p = (1.0 - alpha_primal) * aggr.primal_aggr;
654 let mut c = aggr.compl_aggr;
655 d /= (aggr.n_dual as Number).sqrt();
656 if aggr.n_pri > 0 {
657 p /= (aggr.n_pri as Number).sqrt();
658 }
659 debug_assert!(aggr.n_comp > 0);
660 c /= (aggr.n_comp as Number).sqrt();
661 (d, p, c)
662 }
663 };
664
665 if dual_inf.is_nan() {
667 dual_inf = 0.0;
668 }
669 if primal_inf.is_nan() {
670 primal_inf = 0.0;
671 }
672 if compl_inf.is_nan() {
673 compl_inf = 0.0;
674 }
675
676 let mut q = dual_inf + primal_inf + compl_inf;
677
678 match centrality {
679 CentralityType::None => {}
680 CentralityType::LogCenter => q -= compl_inf * xi.ln(),
681 CentralityType::ReciprocalCenter => q += compl_inf / xi,
682 CentralityType::CubedReciprocalCenter => q += compl_inf / xi.powi(3),
683 }
684
685 match balancing {
686 BalancingTermType::None => {}
687 BalancingTermType::CubicTerm => {
688 let dom = dual_inf.max(primal_inf) - compl_inf;
689 q += dom.max(0.0).powi(3);
690 }
691 }
692
693 q
694}
695
696#[allow(clippy::too_many_arguments)]
707pub fn pick_sigma(
708 sigma_min: Number,
709 sigma_max: Number,
710 mu_min: Number,
711 mu_max: Number,
712 avrg_compl: Number,
713 sigma_tol: Number,
714 qf_tol: Number,
715 max_steps: i32,
716 mut q: impl FnMut(Number) -> Number,
717) -> Number {
718 let qf_1 = q(1.0);
719 let sigma_1minus = 1.0 - sigma_tol.max(1e-4);
720 let qf_1minus = q(sigma_1minus);
721
722 if qf_1minus > qf_1 {
723 let sigma_up = sigma_max.min(mu_max / avrg_compl);
725 let sigma_lo = 1.0;
726 if sigma_lo >= sigma_up {
727 sigma_up
728 } else {
729 golden_section(
730 sigma_lo, sigma_up, qf_1, -100.0, sigma_tol, qf_tol, max_steps, q,
731 )
732 }
733 } else {
734 let sigma_lo = sigma_min.max(mu_min / avrg_compl);
736 let sigma_up = sigma_lo.max(sigma_1minus).min(mu_max / avrg_compl);
737 if sigma_lo >= sigma_up {
738 sigma_lo
739 } else {
740 golden_section(
741 sigma_lo, sigma_up, -100.0, qf_1minus, sigma_tol, qf_tol, max_steps, q,
742 )
743 }
744 }
745}
746
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds aggregates from (dual, primal, compl, n_dual, n_pri, n_comp).
    fn aggr(
        d: Number,
        p: Number,
        c: Number,
        nd: i32,
        np: i32,
        nc: i32,
    ) -> QualityFunctionAggregates {
        QualityFunctionAggregates {
            dual_aggr: d,
            primal_aggr: p,
            compl_aggr: c,
            n_dual: nd,
            n_pri: np,
            n_comp: nc,
        }
    }

    /// Max-norm quality function with zero step lengths: aggregates enter
    /// unscaled and undivided.
    fn max_norm_q(
        centrality: CentralityType,
        balancing: BalancingTermType,
        xi: Number,
        a: QualityFunctionAggregates,
    ) -> Number {
        evaluate_quality_function(NormType::MaxNorm, centrality, balancing, 0.0, 0.0, xi, a)
    }

    // ---- golden_section ------------------------------------------------

    #[test]
    fn golden_section_minimizes_parabola() {
        let parabola = |s: f64| (s - 0.3).powi(2);
        let (lo, hi) = (0.0, 1.0);
        let s = golden_section(lo, hi, parabola(lo), parabola(hi), 1e-6, 0.0, 50, parabola);
        assert!((s - 0.3).abs() < 1e-3);
    }

    #[test]
    fn golden_section_respects_max_steps() {
        let parabola = |s: f64| (s - 0.5).powi(2);
        let (lo, hi) = (0.0, 1.0);
        let s = golden_section(lo, hi, parabola(lo), parabola(hi), 1e-12, 0.0, 5, parabola);
        assert!((s - 0.5).abs() < 0.2);
    }

    #[test]
    fn golden_section_handles_monotone() {
        let identity = |s: f64| s;
        let s = golden_section(0.1, 2.0, 0.1, 2.0, 1e-6, 0.0, 50, identity);
        assert!(s < 0.2, "got s = {}", s);
    }

    // ---- MuOracle trait ------------------------------------------------

    #[test]
    fn calculate_mu_returns_none_until_plumbed() {
        let mut oracle = QualityFunctionMuOracle::new();
        let mu = oracle.calculate_mu();
        assert!(mu.is_none());
    }

    // ---- evaluate_quality_function -------------------------------------

    #[test]
    fn evaluate_one_norm_averages_by_n() {
        let (alpha_primal, alpha_dual, xi) = (0.5, 0.25, 1.0);
        let q = evaluate_quality_function(
            NormType::OneNorm,
            CentralityType::None,
            BalancingTermType::None,
            alpha_primal,
            alpha_dual,
            xi,
            aggr(8.0, 4.0, 6.0, 4, 2, 3),
        );
        assert!((q - 4.5).abs() < 1e-12, "got {}", q);
    }

    #[test]
    fn evaluate_max_norm_does_not_divide() {
        let q = max_norm_q(
            CentralityType::None,
            BalancingTermType::None,
            1.0,
            aggr(2.0, 3.0, 5.0, 10, 10, 10),
        );
        assert!((q - 10.0).abs() < 1e-12);
    }

    #[test]
    fn evaluate_two_norm_divides_by_sqrt_n() {
        let q = evaluate_quality_function(
            NormType::TwoNorm,
            CentralityType::None,
            BalancingTermType::None,
            0.0,
            0.0,
            1.0,
            aggr(2.0, 0.0, 4.0, 4, 0, 16),
        );
        assert!((q - 2.0).abs() < 1e-12, "got {}", q);
    }

    #[test]
    fn evaluate_one_norm_handles_zero_pri_dim() {
        let q = evaluate_quality_function(
            NormType::OneNorm,
            CentralityType::None,
            BalancingTermType::None,
            0.0,
            0.0,
            1.0,
            aggr(0.0, 0.0, 1.0, 1, 0, 1),
        );
        assert!(q.is_finite() && (q - 1.0).abs() < 1e-12);
    }

    #[test]
    fn evaluate_log_centrality_subtracts_compl_log_xi() {
        let xi = std::f64::consts::E;
        let a = aggr(0.0, 0.0, 4.0, 1, 1, 1);
        let base = max_norm_q(CentralityType::None, BalancingTermType::None, xi, a);
        let logc = max_norm_q(CentralityType::LogCenter, BalancingTermType::None, xi, a);
        assert!((base - logc - 4.0).abs() < 1e-12, "base={base} logc={logc}");
    }

    #[test]
    fn evaluate_reciprocal_centrality_adds_c_over_xi() {
        let q = max_norm_q(
            CentralityType::ReciprocalCenter,
            BalancingTermType::None,
            0.5,
            aggr(0.0, 0.0, 1.0, 1, 1, 1),
        );
        assert!((q - 3.0).abs() < 1e-12);
    }

    #[test]
    fn evaluate_cubed_reciprocal_centrality_adds_c_over_xi3() {
        let q = max_norm_q(
            CentralityType::CubedReciprocalCenter,
            BalancingTermType::None,
            0.5,
            aggr(0.0, 0.0, 1.0, 1, 1, 1),
        );
        assert!((q - 9.0).abs() < 1e-12);
    }

    #[test]
    fn evaluate_cubic_balancing_adds_when_dual_dominates() {
        let q = max_norm_q(
            CentralityType::None,
            BalancingTermType::CubicTerm,
            1.0,
            aggr(5.0, 1.0, 2.0, 1, 1, 1),
        );
        assert!((q - 35.0).abs() < 1e-12, "got {}", q);
    }

    #[test]
    fn evaluate_cubic_balancing_zero_when_compl_dominates() {
        let q = max_norm_q(
            CentralityType::None,
            BalancingTermType::CubicTerm,
            1.0,
            aggr(1.0, 1.0, 5.0, 1, 1, 1),
        );
        assert!((q - 7.0).abs() < 1e-12);
    }

    // ---- pick_sigma ----------------------------------------------------

    #[test]
    fn pick_sigma_searches_below_one_for_decreasing_q() {
        let bowl = |s: f64| (s - 0.4).powi(2);
        let s = pick_sigma(1e-9, 100.0, 1e-11, 1e5, 1.0, 1e-6, 0.0, 50, bowl);
        assert!((s - 0.4).abs() < 1e-2, "got s = {}", s);
    }

    #[test]
    fn pick_sigma_searches_above_one_for_q_decreasing_in_sigma() {
        let descending = |s: f64| -s;
        let s = pick_sigma(1e-9, 10.0, 1e-11, 1e5, 1.0, 1e-6, 0.0, 50, descending);
        assert!(s > 5.0, "got s = {}", s);
    }

    #[test]
    fn pick_sigma_clamps_to_mu_max_over_avrg_in_up_search() {
        let descending = |s: f64| -s;
        let s = pick_sigma(1e-9, 100.0, 1e-11, 2.0, 1.0, 1e-6, 0.0, 50, descending);
        assert!(s <= 2.0 + 1e-9 && s >= 1.0, "got s = {}", s);
    }

    #[test]
    fn pick_sigma_clamps_to_mu_min_over_avrg_in_down_search() {
        let ascending = |s: f64| s;
        let s = pick_sigma(1e-9, 100.0, 0.5, 1e5, 1.0, 1e-6, 0.0, 50, ascending);
        assert!(s >= 0.5 - 1e-9 && s <= 1.0, "got s = {}", s);
    }
}