1use oxilean_kernel::{BinderInfo, Declaration, Environment, Expr, Level, Name};
6
7use super::types::{
8 ADMMSolver, AdamConfig, AdamOptimizer, BinaryIntegerProgram, FrankWolfeOptimizer,
9 GradientDescentConfig, GradientDescentOptimizer, LBFGSState, RegretTracker,
10 RobustOptimizationProblem, SGDConfig, TwoStageStochasticProgram,
11};
12
/// Function application node: `f a`.
pub fn app(f: Expr, a: Expr) -> Expr {
    Expr::App(Box::new(f), Box::new(a))
}
/// Curried two-argument application: `(f a) b`.
pub fn app2(f: Expr, a: Expr, b: Expr) -> Expr {
    app(app(f, a), b)
}
/// Curried three-argument application: `((f a) b) c`.
pub fn app3(f: Expr, a: Expr, b: Expr, c: Expr) -> Expr {
    app(app2(f, a, b), c)
}
/// Bound variable with index `i`.
pub fn bvar(i: u32) -> Expr {
    Expr::BVar(i)
}
/// Constant reference by name, with an empty universe-argument list.
pub fn cst(s: &str) -> Expr {
    Expr::Const(Name::str(s), vec![])
}
/// `Sort 0` (the sort of propositions).
pub fn prop() -> Expr {
    Expr::Sort(Level::zero())
}
/// `Sort 1` (the first proper type universe).
pub fn type0() -> Expr {
    Expr::Sort(Level::succ(Level::zero()))
}
/// Dependent function type `Π (name : dom), body`.
pub fn pi(bi: BinderInfo, name: &str, dom: Expr, body: Expr) -> Expr {
    Expr::Pi(bi, Name::str(name), Box::new(dom), Box::new(body))
}
/// Non-dependent function type `a → b` (a `Pi` whose binder is named `_`).
pub fn arrow(a: Expr, b: Expr) -> Expr {
    pi(BinderInfo::Default, "_", a, b)
}
/// The constant `Nat`.
pub fn nat_ty() -> Expr {
    cst("Nat")
}
/// The constant `Real`.
pub fn real_ty() -> Expr {
    cst("Real")
}
/// The constant `Bool`.
pub fn bool_ty() -> Expr {
    cst("Bool")
}
/// `List elem`.
pub fn list_ty(elem: Expr) -> Expr {
    app(cst("List"), elem)
}
/// Alias of [`arrow`]: the function type `dom → cod`.
pub fn fn_ty(dom: Expr, cod: Expr) -> Expr {
    arrow(dom, cod)
}
/// `List Real → Real` — scalar functions on ℝⁿ (vectors encoded as lists).
pub fn rn_to_r() -> Expr {
    fn_ty(list_ty(real_ty()), real_ty())
}
/// `List Real → List Real` — vector-valued maps on ℝⁿ.
pub fn rn_to_rn() -> Expr {
    fn_ty(list_ty(real_ty()), list_ty(real_ty()))
}
61pub fn first_order_optimal_ty() -> Expr {
64 arrow(rn_to_r(), arrow(list_ty(real_ty()), prop()))
65}
66pub fn second_order_optimal_ty() -> Expr {
69 arrow(rn_to_r(), arrow(list_ty(real_ty()), prop()))
70}
71pub fn local_minimum_ty() -> Expr {
74 arrow(rn_to_r(), arrow(list_ty(real_ty()), prop()))
75}
76pub fn global_minimum_ty() -> Expr {
79 arrow(rn_to_r(), arrow(list_ty(real_ty()), prop()))
80}
81pub fn kkt_point_ty() -> Expr {
85 let list_rn_to_r = list_ty(rn_to_r());
86 let list_r = list_ty(real_ty());
87 arrow(
88 rn_to_r(),
89 arrow(
90 list_rn_to_r.clone(),
91 arrow(
92 list_rn_to_r,
93 arrow(list_r.clone(), arrow(list_r.clone(), arrow(list_r, prop()))),
94 ),
95 ),
96 )
97}
/// `List (ℝⁿ → ℝ) → List Real → List Real → Prop`: complementary
/// slackness between constraints, a point and multipliers.
pub fn complementary_slackness_ty() -> Expr {
    let list_rn_to_r = list_ty(rn_to_r());
    let list_r = list_ty(real_ty());
    arrow(list_rn_to_r, arrow(list_r.clone(), arrow(list_r, prop())))
}
/// `List Real → Prop`: dual feasibility of a multiplier vector.
pub fn dual_feasible_ty() -> Expr {
    arrow(list_ty(real_ty()), prop())
}
/// `List (ℝⁿ → ℝ) → List Real → Prop`: the LICQ constraint qualification.
pub fn licq_ty() -> Expr {
    arrow(list_ty(rn_to_r()), arrow(list_ty(real_ty()), prop()))
}
/// `List (ℝⁿ → ℝ) → Prop`: Slater's condition for a constraint list.
pub fn slater_condition_ty() -> Expr {
    arrow(list_ty(rn_to_r()), prop())
}
/// `List (ℝⁿ → ℝ) → List Real → Prop`: the Mangasarian–Fromovitz
/// constraint qualification (same shape as LICQ).
pub fn mfcq_ty() -> Expr {
    arrow(list_ty(rn_to_r()), arrow(list_ty(real_ty()), prop()))
}
/// `(ℝⁿ → ℝ) → List (ℝⁿ → ℝ) → Prop`: weak duality for a problem.
pub fn weak_duality_ty() -> Expr {
    arrow(rn_to_r(), arrow(list_ty(rn_to_r()), prop()))
}
/// `(ℝⁿ → ℝ) → List (ℝⁿ → ℝ) → Prop`: strong duality (same shape as weak).
pub fn strong_duality_ty() -> Expr {
    arrow(rn_to_r(), arrow(list_ty(rn_to_r()), prop()))
}
/// `(ℝⁿ → ℝ) → List (ℝⁿ → ℝ) → List Real → List Real → Real`: the
/// Lagrangian as a scalar-valued operator.
pub fn lagrangian_ty() -> Expr {
    let list_rn_to_r = list_ty(rn_to_r());
    let list_r = list_ty(real_ty());
    arrow(
        rn_to_r(),
        arrow(
            list_rn_to_r,
            arrow(list_r.clone(), arrow(list_r, real_ty())),
        ),
    )
}
/// `(ℝⁿ → ℝ) → List (ℝⁿ → ℝ) → List Real → Real`: the dual function.
pub fn dual_function_ty() -> Expr {
    let list_rn_to_r = list_ty(rn_to_r());
    let list_r = list_ty(real_ty());
    arrow(rn_to_r(), arrow(list_rn_to_r, arrow(list_r, real_ty())))
}
/// `(ℝⁿ → ℝ) → List (ℝⁿ → ℝ) → List Real → List Real → Real → Prop`:
/// relates primal/dual data to a gap value.
pub fn duality_gap_ty() -> Expr {
    let list_rn_to_r = list_ty(rn_to_r());
    let list_r = list_ty(real_ty());
    arrow(
        rn_to_r(),
        arrow(
            list_rn_to_r,
            arrow(list_r.clone(), arrow(list_r, arrow(real_ty(), prop()))),
        ),
    )
}
/// `(ℝⁿ → ℝ) → List (ℝⁿ → ℝ) → Real → (ℝⁿ → ℝ)`: objective plus
/// penalty-weighted constraints, yielding a new objective.
pub fn penalty_objective_ty() -> Expr {
    let list_rn_to_r = list_ty(rn_to_r());
    arrow(rn_to_r(), arrow(list_rn_to_r, arrow(real_ty(), rn_to_r())))
}
/// `(ℝⁿ → ℝ) → List (ℝⁿ → ℝ) → List Real → Real → (ℝⁿ → ℝ)`: the
/// augmented Lagrangian as an objective-valued operator.
pub fn augmented_lagrangian_ty() -> Expr {
    let list_rn_to_r = list_ty(rn_to_r());
    let list_r = list_ty(real_ty());
    arrow(
        rn_to_r(),
        arrow(list_rn_to_r, arrow(list_r, arrow(real_ty(), rn_to_r()))),
    )
}
/// `(List (List Real) → Real) → Nat → Real → Prop`: a regret bound for a
/// loss functional over iterate sequences, a horizon and a bound.
pub fn regret_bound_ty() -> Expr {
    let seq_to_r = fn_ty(list_ty(list_ty(real_ty())), real_ty());
    arrow(seq_to_r, arrow(nat_ty(), arrow(real_ty(), prop())))
}
/// `(List (List Real) → Real) → Prop`: the no-regret property.
pub fn no_regret_ty() -> Expr {
    let seq_to_r = fn_ty(list_ty(list_ty(real_ty())), real_ty());
    arrow(seq_to_r, prop())
}
/// `(ℝⁿ → ℝ) → (ℝⁿ → ℝ) → Nat → Real → Prop`.
pub fn stochastic_convergence_ty() -> Expr {
    arrow(
        rn_to_r(),
        arrow(rn_to_r(), arrow(nat_ty(), arrow(real_ty(), prop()))),
    )
}
/// `(ℝⁿ → ℝ) → Real → Real → Nat → Real → Prop`.
pub fn grad_descent_convergence_ty() -> Expr {
    arrow(
        rn_to_r(),
        arrow(
            real_ty(),
            arrow(real_ty(), arrow(nat_ty(), arrow(real_ty(), prop()))),
        ),
    )
}
/// `(ℝⁿ → ℝ) → Real → Nat → Real → Prop`.
pub fn nesterov_acceleration_ty() -> Expr {
    arrow(
        rn_to_r(),
        arrow(real_ty(), arrow(nat_ty(), arrow(real_ty(), prop()))),
    )
}
/// `(ℝⁿ → ℝ) → Real → Real → Nat → Real → Prop`.
pub fn sgd_convergence_convex_ty() -> Expr {
    arrow(
        rn_to_r(),
        arrow(
            real_ty(),
            arrow(real_ty(), arrow(nat_ty(), arrow(real_ty(), prop()))),
        ),
    )
}
/// `(ℝⁿ → ℝ) → Real → Real → Nat → Real → Prop` (strongly convex case;
/// same shape as the convex one).
pub fn sgd_convergence_sc_ty() -> Expr {
    arrow(
        rn_to_r(),
        arrow(
            real_ty(),
            arrow(real_ty(), arrow(nat_ty(), arrow(real_ty(), prop()))),
        ),
    )
}
/// `(ℝⁿ → ℝ) → Real → Nat → Real → Prop`.
pub fn adagrad_convergence_ty() -> Expr {
    arrow(
        rn_to_r(),
        arrow(real_ty(), arrow(nat_ty(), arrow(real_ty(), prop()))),
    )
}
/// `(ℝⁿ → ℝ) → Real → Real → Nat → Real → Prop`.
pub fn rmsprop_convergence_ty() -> Expr {
    arrow(
        rn_to_r(),
        arrow(
            real_ty(),
            arrow(real_ty(), arrow(nat_ty(), arrow(real_ty(), prop()))),
        ),
    )
}
/// `(ℝⁿ → ℝ) → Real → Real → Real → Nat → Real → Prop`.
pub fn adam_convergence_ty() -> Expr {
    arrow(
        rn_to_r(),
        arrow(
            real_ty(),
            arrow(
                real_ty(),
                arrow(real_ty(), arrow(nat_ty(), arrow(real_ty(), prop()))),
            ),
        ),
    )
}
/// `(ℝⁿ → ℝ) → Real → Nat → Real → Prop`.
pub fn frank_wolfe_convergence_ty() -> Expr {
    arrow(
        rn_to_r(),
        arrow(real_ty(), arrow(nat_ty(), arrow(real_ty(), prop()))),
    )
}
/// `(ℝⁿ → ℝ) → List (ℝⁿ → ℝ) → List Real → Nat → Prop`.
pub fn frank_wolfe_feasible_ty() -> Expr {
    arrow(
        rn_to_r(),
        arrow(
            list_ty(rn_to_r()),
            arrow(list_ty(real_ty()), arrow(nat_ty(), prop())),
        ),
    )
}
/// `(ℝⁿ → ℝ) → List Real → List Real → Real`: a Bregman divergence
/// generated by a potential, applied to two points.
pub fn bregman_divergence_ty() -> Expr {
    arrow(
        rn_to_r(),
        arrow(list_ty(real_ty()), arrow(list_ty(real_ty()), real_ty())),
    )
}
/// `(ℝⁿ → ℝ) → (ℝⁿ → ℝ) → Real → Nat → Real → Prop`.
pub fn mirror_descent_convergence_ty() -> Expr {
    arrow(
        rn_to_r(),
        arrow(
            rn_to_r(),
            arrow(real_ty(), arrow(nat_ty(), arrow(real_ty(), prop()))),
        ),
    )
}
/// `Nat → Nat → Real → Prop`.
pub fn ucb_regret_ty() -> Expr {
    arrow(nat_ty(), arrow(nat_ty(), arrow(real_ty(), prop())))
}
/// `Nat → Nat → Real → Prop` (same shape as the UCB bound).
pub fn thompson_sampling_regret_ty() -> Expr {
    arrow(nat_ty(), arrow(nat_ty(), arrow(real_ty(), prop())))
}
/// `Nat → Nat → Real → Prop` (same shape as the UCB bound).
pub fn exp3_regret_ty() -> Expr {
    arrow(nat_ty(), arrow(nat_ty(), arrow(real_ty(), prop())))
}
/// `(ℝⁿ → ℝ) → (ℝⁿ → ℝ) → Real → Nat → Real → Prop`.
pub fn admm_convergence_ty() -> Expr {
    arrow(
        rn_to_r(),
        arrow(
            rn_to_r(),
            arrow(real_ty(), arrow(nat_ty(), arrow(real_ty(), prop()))),
        ),
    )
}
/// `(ℝⁿ → ℝ) → (ℝⁿ → ℝ) → Nat → Real → Prop`.
pub fn douglas_rachford_convergence_ty() -> Expr {
    arrow(
        rn_to_r(),
        arrow(rn_to_r(), arrow(nat_ty(), arrow(real_ty(), prop()))),
    )
}
/// `(ℝⁿ → ℝ) → (ℝⁿ → ℝ) → Real → Real → Nat → Real → Prop`.
pub fn chambolle_pock_convergence_ty() -> Expr {
    arrow(
        rn_to_r(),
        arrow(
            rn_to_r(),
            arrow(
                real_ty(),
                arrow(real_ty(), arrow(nat_ty(), arrow(real_ty(), prop()))),
            ),
        ),
    )
}
/// `List (ℝⁿ → ℝ) → List Real → Nat → Real → Prop`.
pub fn dykstra_convergence_ty() -> Expr {
    arrow(
        list_ty(rn_to_r()),
        arrow(
            list_ty(real_ty()),
            arrow(nat_ty(), arrow(real_ty(), prop())),
        ),
    )
}
/// `(ℝⁿ → ℝ) → Nat → Real → Prop`.
pub fn coordinate_descent_convergence_ty() -> Expr {
    arrow(rn_to_r(), arrow(nat_ty(), arrow(real_ty(), prop())))
}
/// `(ℝⁿ → ℝ) → Nat → Nat → Real → Prop`.
pub fn block_cd_convergence_ty() -> Expr {
    arrow(
        rn_to_r(),
        arrow(nat_ty(), arrow(nat_ty(), arrow(real_ty(), prop()))),
    )
}
/// `(ℝⁿ → ℝ) → Real → Real → Nat → Prop`.
pub fn trust_region_convergence_ty() -> Expr {
    arrow(
        rn_to_r(),
        arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), prop()))),
    )
}
/// `(ℝⁿ → ℝ) → Real → Nat → Real → Prop`.
pub fn levenberg_marquardt_convergence_ty() -> Expr {
    arrow(
        rn_to_r(),
        arrow(real_ty(), arrow(nat_ty(), arrow(real_ty(), prop()))),
    )
}
/// `(ℝⁿ → ℝ) → Nat → Nat → Real → Prop`.
pub fn lbfgs_convergence_ty() -> Expr {
    arrow(
        rn_to_r(),
        arrow(nat_ty(), arrow(nat_ty(), arrow(real_ty(), prop()))),
    )
}
/// `(ℝⁿ → ℝ) → Real → Real → Nat → Real → Prop`.
pub fn conjugate_gradient_convergence_ty() -> Expr {
    arrow(
        rn_to_r(),
        arrow(
            real_ty(),
            arrow(real_ty(), arrow(nat_ty(), arrow(real_ty(), prop()))),
        ),
    )
}
/// `(ℝⁿ → ℝ) → List (ℝⁿ → ℝ) → Nat → Real → Prop`.
pub fn successive_convex_approx_ty() -> Expr {
    arrow(
        rn_to_r(),
        arrow(
            list_ty(rn_to_r()),
            arrow(nat_ty(), arrow(real_ty(), prop())),
        ),
    )
}
/// `(ℝⁿ → ℝ) → List (ℝⁿ → ℝ) → Prop`.
pub fn sdp_weak_duality_ty() -> Expr {
    arrow(rn_to_r(), arrow(list_ty(rn_to_r()), prop()))
}
/// `(ℝⁿ → ℝ) → List (ℝⁿ → ℝ) → Prop` (same shape as weak duality).
pub fn sdp_strong_duality_ty() -> Expr {
    arrow(rn_to_r(), arrow(list_ty(rn_to_r()), prop()))
}
/// `Nat → Nat → Real → Prop`.
pub fn sdp_rank_bound_ty() -> Expr {
    arrow(nat_ty(), arrow(nat_ty(), arrow(real_ty(), prop())))
}
/// Builds an [`Environment`] populated with the optimization-theory
/// vocabulary defined above: one axiom per predicate/operator type, plus a
/// set of named theorem placeholders each typed as a bare `Prop`.
///
/// Every declaration is added as an axiom with no universe parameters;
/// insertion results are deliberately discarded via `.ok()`, so duplicate
/// or otherwise rejected names are silently skipped.
pub fn build_optimization_theory_env() -> Environment {
    let mut env = Environment::new();
    let axioms: &[(&str, Expr)] = &[
        // Predicate / operator types.
        ("FirstOrderOptimal", first_order_optimal_ty()),
        ("SecondOrderOptimal", second_order_optimal_ty()),
        ("LocalMinimum", local_minimum_ty()),
        ("GlobalMinimum", global_minimum_ty()),
        ("KKTPoint", kkt_point_ty()),
        ("ComplementarySlackness", complementary_slackness_ty()),
        ("DualFeasible", dual_feasible_ty()),
        ("LICQ", licq_ty()),
        ("SlaterCondition", slater_condition_ty()),
        ("MangasarianFromovitz", mfcq_ty()),
        ("WeakDuality", weak_duality_ty()),
        ("StrongDuality", strong_duality_ty()),
        ("Lagrangian", lagrangian_ty()),
        ("DualFunction", dual_function_ty()),
        ("DualityGap", duality_gap_ty()),
        ("PenaltyObjective", penalty_objective_ty()),
        ("AugmentedLagrangian", augmented_lagrangian_ty()),
        ("RegretBound", regret_bound_ty()),
        ("NoRegretAlgorithm", no_regret_ty()),
        ("StochasticConvergence", stochastic_convergence_ty()),
        ("GradDescentConvergence", grad_descent_convergence_ty()),
        ("NesterovAcceleration", nesterov_acceleration_ty()),
        ("SGDConvergenceConvex", sgd_convergence_convex_ty()),
        ("SGDConvergenceStronglyConvex", sgd_convergence_sc_ty()),
        ("AdaGradConvergence", adagrad_convergence_ty()),
        ("RMSPropConvergence", rmsprop_convergence_ty()),
        ("AdamConvergence", adam_convergence_ty()),
        ("FrankWolfeConvergence", frank_wolfe_convergence_ty()),
        ("FrankWolfeFeasible", frank_wolfe_feasible_ty()),
        ("BregmanDivergence", bregman_divergence_ty()),
        ("MirrorDescentConvergence", mirror_descent_convergence_ty()),
        ("UCBRegretBound", ucb_regret_ty()),
        ("ThompsonSamplingRegret", thompson_sampling_regret_ty()),
        ("ExpThreeRegret", exp3_regret_ty()),
        ("ADMMConvergence", admm_convergence_ty()),
        (
            "DouglasRachfordConvergence",
            douglas_rachford_convergence_ty(),
        ),
        ("ChambollePockConvergence", chambolle_pock_convergence_ty()),
        ("DykstraConvergence", dykstra_convergence_ty()),
        (
            "CoordinateDescentConvergence",
            coordinate_descent_convergence_ty(),
        ),
        (
            "BlockCoordinateDescentConvergence",
            block_cd_convergence_ty(),
        ),
        ("TrustRegionConvergence", trust_region_convergence_ty()),
        (
            "LevenbergMarquardtConvergence",
            levenberg_marquardt_convergence_ty(),
        ),
        ("LBFGSConvergence", lbfgs_convergence_ty()),
        (
            "ConjugateGradientConvergence",
            conjugate_gradient_convergence_ty(),
        ),
        ("SuccessiveConvexApprox", successive_convex_approx_ty()),
        ("SDPWeakDuality", sdp_weak_duality_ty()),
        ("SDPStrongDuality", sdp_strong_duality_ty()),
        ("SDPRankBound", sdp_rank_bound_ty()),
        // Theorem placeholders: statements abstracted to a bare Prop.
        ("kkt_necessary_licq", prop()),
        ("kkt_sufficient_convex", prop()),
        ("weak_duality_theorem", prop()),
        ("strong_duality_slater", prop()),
        ("penalty_exact_kkt", prop()),
        ("sqp_superlinear_convergence", prop()),
        ("sgd_convergence_convex_smooth", prop()),
        ("ogd_regret_sqrt_t", prop()),
        ("mirror_descent_regret", prop()),
        ("interior_point_barrier_convergence", prop()),
        ("nesterov_optimal_rate", prop()),
        ("admm_linear_convergence", prop()),
        ("frank_wolfe_away_steps", prop()),
        ("lbfgs_superlinear_convergence", prop()),
        ("coordinate_descent_linear_sc", prop()),
    ];
    for (name, ty) in axioms {
        env.add(Declaration::Axiom {
            name: Name::str(*name),
            univ_params: vec![],
            ty: ty.clone(),
        })
        .ok();
    }
    env
}
/// Central-difference approximation of the gradient of `f` at `x`.
///
/// Each component is `(f(x + h·eᵢ) − f(x − h·eᵢ)) / (2h)`, so `f` is
/// evaluated `2n` times. The two probe vectors are reused across
/// coordinates, restoring the perturbed entry after each evaluation.
pub fn finite_diff_gradient(f: &dyn Fn(&[f64]) -> f64, x: &[f64], h: f64) -> Vec<f64> {
    let mut forward = x.to_vec();
    let mut backward = x.to_vec();
    (0..x.len())
        .map(|i| {
            forward[i] += h;
            backward[i] -= h;
            let slope = (f(&forward) - f(&backward)) / (2.0 * h);
            // Restore the probes for the next coordinate.
            forward[i] = x[i];
            backward[i] = x[i];
            slope
        })
        .collect()
}
/// Central-difference approximation of the Hessian of `f` at `x`.
///
/// Diagonal entries use `(f(x + h·eᵢ) − 2f(x) + f(x − h·eᵢ)) / h²`;
/// off-diagonal entries use the four-point cross formula
/// `(f₊₊ − f₊₋ − f₋₊ + f₋₋) / (4h²)`. The result is symmetrised by
/// construction (`hess[j][i] = hess[i][j]`).
///
/// Bug fix: the previous version reset the reused `(+h, +h)` probe's `j`-th
/// entry back to `x[i]` instead of `x[j]`, corrupting every cross term after
/// the first in a row whenever the coordinates differ (visible for n ≥ 3).
/// Two scratch vectors that were perturbed but never evaluated were removed.
pub fn finite_diff_hessian(f: &dyn Fn(&[f64]) -> f64, x: &[f64], h: f64) -> Vec<Vec<f64>> {
    let n = x.len();
    let mut hess = vec![vec![0.0; n]; n];
    let f0 = f(x);
    // Reusable probes for the diagonal second differences.
    let mut xph = x.to_vec();
    let mut xmh = x.to_vec();
    // Reusable (+h, +h) probe for the cross terms.
    let mut xpp = x.to_vec();
    for i in 0..n {
        xph[i] += h;
        xmh[i] -= h;
        hess[i][i] = (f(&xph) - 2.0 * f0 + f(&xmh)) / (h * h);
        xph[i] = x[i];
        xmh[i] = x[i];
        for j in (i + 1)..n {
            xpp[i] += h;
            xpp[j] += h;
            let mut xpm = x.to_vec();
            xpm[i] += h;
            xpm[j] -= h;
            let mut xmp = x.to_vec();
            xmp[i] -= h;
            xmp[j] += h;
            let mut xmm = x.to_vec();
            xmm[i] -= h;
            xmm[j] -= h;
            hess[i][j] = (f(&xpp) - f(&xpm) - f(&xmp) + f(&xmm)) / (4.0 * h * h);
            hess[j][i] = hess[i][j];
            // Restore the shared probe exactly (the old code wrote x[i] into
            // the j-th slot here).
            xpp[i] = x[i];
            xpp[j] = x[j];
        }
    }
    hess
}
599pub fn sgd(
604 f: &dyn Fn(&[f64]) -> f64,
605 grad_f: &dyn Fn(&[f64]) -> Vec<f64>,
606 x0: &[f64],
607 cfg: &SGDConfig,
608) -> (Vec<f64>, f64, usize) {
609 let n = x0.len();
610 let mut x = x0.to_vec();
611 let mut iters = 0;
612 for t in 0..cfg.max_iter {
613 let g = grad_f(&x);
614 let gnorm: f64 = g.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
615 if gnorm < cfg.tol {
616 iters = t;
617 break;
618 }
619 let lr_t = if cfg.decay {
620 cfg.lr / ((t as f64 + 1.0).sqrt())
621 } else {
622 cfg.lr
623 };
624 for i in 0..n {
625 x[i] -= lr_t * g[i];
626 }
627 iters = t + 1;
628 }
629 (x.clone(), f(&x), iters)
630}
631pub fn adam(
635 f: &dyn Fn(&[f64]) -> f64,
636 grad_f: &dyn Fn(&[f64]) -> Vec<f64>,
637 x0: &[f64],
638 cfg: &AdamConfig,
639) -> (Vec<f64>, f64, usize) {
640 let n = x0.len();
641 let mut x = x0.to_vec();
642 let mut m = vec![0.0; n];
643 let mut v = vec![0.0; n];
644 let mut iters = 0;
645 for t in 1..=cfg.max_iter {
646 let g = grad_f(&x);
647 let gnorm: f64 = g.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
648 if gnorm < cfg.tol {
649 iters = t - 1;
650 break;
651 }
652 for i in 0..n {
653 m[i] = cfg.beta1 * m[i] + (1.0 - cfg.beta1) * g[i];
654 v[i] = cfg.beta2 * v[i] + (1.0 - cfg.beta2) * g[i] * g[i];
655 let m_hat = m[i] / (1.0 - cfg.beta1.powi(t as i32));
656 let v_hat = v[i] / (1.0 - cfg.beta2.powi(t as i32));
657 x[i] -= cfg.lr * m_hat / (v_hat.sqrt() + cfg.eps);
658 }
659 iters = t;
660 }
661 (x.clone(), f(&x), iters)
662}
/// Augmented-Lagrangian (method of multipliers) loop.
///
/// Each outer round approximately minimises `f + λᵀg + (ρ/2)‖g‖²` with up
/// to `max_inner` plain gradient steps (step size `0.01 / (1 + round)`),
/// then updates the multipliers `λ ← λ + ρ·g(x)` and stops once
/// `‖g(x)‖ < tol`. Returns (point, multipliers, total inner iterations).
/// The objective value itself is never consulted — only `grad_f` — so `_f`
/// exists for signature symmetry with the other solvers.
#[allow(clippy::too_many_arguments)]
pub fn augmented_lagrangian_method(
    _f: &dyn Fn(&[f64]) -> f64,
    grad_f: &dyn Fn(&[f64]) -> Vec<f64>,
    g: &dyn Fn(&[f64]) -> Vec<f64>,
    jac_g: &dyn Fn(&[f64]) -> Vec<Vec<f64>>,
    x0: &[f64],
    rho: f64,
    max_outer: usize,
    max_inner: usize,
    tol: f64,
) -> (Vec<f64>, Vec<f64>, usize) {
    let dim = x0.len();
    let n_constraints = g(x0).len();
    let mut x = x0.to_vec();
    let mut multipliers = vec![0.0; n_constraints];
    let mut inner_steps = 0;
    for round in 0..max_outer {
        // Gradient of the augmented Lagrangian at the current multipliers.
        let aug_grad = |point: &[f64]| -> Vec<f64> {
            let constraints = g(point);
            let jacobian = jac_g(point);
            let mut grad = grad_f(point);
            for c in 0..n_constraints {
                let weight = multipliers[c] + rho * constraints[c];
                for i in 0..dim {
                    grad[i] += weight * jacobian[c][i];
                }
            }
            grad
        };
        let step_size = 0.01 / (1.0 + round as f64);
        for _ in 0..max_inner {
            let grad = aug_grad(&x);
            let norm = grad.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
            if norm < tol * 0.1 {
                break;
            }
            for i in 0..dim {
                x[i] -= step_size * grad[i];
            }
            inner_steps += 1;
        }
        let residual = g(&x);
        for c in 0..n_constraints {
            multipliers[c] += rho * residual[c];
        }
        let violation = residual.iter().map(|r| r * r).sum::<f64>().sqrt();
        if violation < tol {
            break;
        }
    }
    (x, multipliers, inner_steps)
}
/// Log-barrier interior-point loop for `min f  s.t.  g(x) < 0`.
///
/// Inner loop: gradient steps on `f` plus barrier terms weighted by `t`
/// (each constraint contributes `(−t / g_c)·∇g_c`, the gradient of
/// `−t·log(−g_c)`), with a halving backtracking search that only accepts
/// strictly feasible trial points. Outer loop: `t` shrinks by `mu` each
/// round and iteration stops once `t < inner_tol`. If an iterate is ever
/// not strictly feasible the barrier gradient is undefined (`None`) and the
/// inner loop exits. Returns the final point and the number of accepted
/// inner iterations. `_f` is unused: only `grad_f` drives the steps.
///
/// NOTE(review): here `t` multiplies the barrier term and *decreases*
/// (`t /= mu`), i.e. it plays the role of the barrier weight µ rather than
/// the growing `t` of the textbook formulation — confirm intent with callers.
#[allow(clippy::too_many_arguments)]
pub fn interior_point(
    _f: &dyn Fn(&[f64]) -> f64,
    grad_f: &dyn Fn(&[f64]) -> Vec<f64>,
    g: &dyn Fn(&[f64]) -> Vec<f64>,
    jac_g: &dyn Fn(&[f64]) -> Vec<Vec<f64>>,
    x0: &[f64],
    mut t: f64,
    mu: f64,
    max_outer: usize,
    inner_tol: f64,
) -> (Vec<f64>, usize) {
    let n = x0.len();
    let mut x = x0.to_vec();
    let mut total_iters = 0;
    for _ in 0..max_outer {
        // Gradient of f plus barrier terms; None if x is not strictly feasible.
        let barrier_grad = |xk: &[f64]| -> Option<Vec<f64>> {
            let gval = g(xk);
            let jg = jac_g(xk);
            for &gv in &gval {
                if gv >= 0.0 {
                    return None;
                }
            }
            let mut grad = grad_f(xk);
            for (c, &gv) in gval.iter().enumerate() {
                // d/dx of −t·log(−g_c) is (−t / g_c)·∇g_c.
                let scale = -t / gv;
                for i in 0..n {
                    grad[i] += scale * jg[c][i];
                }
            }
            Some(grad)
        };
        let lr = 0.1 * t;
        for _ in 0..100 {
            match barrier_grad(&x) {
                None => break,
                Some(gr) => {
                    let gnorm: f64 = gr.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
                    if gnorm < inner_tol {
                        break;
                    }
                    // Backtracking: halve the step (at most 20 times) until the
                    // trial point stays strictly inside the feasible region.
                    let mut step = lr;
                    for _ in 0..20 {
                        let xnew: Vec<f64> =
                            x.iter().zip(&gr).map(|(xi, gi)| xi - step * gi).collect();
                        let gnew = g(&xnew);
                        if gnew.iter().all(|&gv| gv < 0.0) {
                            x = xnew;
                            break;
                        }
                        step *= 0.5;
                    }
                    total_iters += 1;
                }
            }
        }
        t /= mu;
        if t < inner_tol {
            break;
        }
    }
    (x, total_iters)
}
/// One simplified SQP iteration.
///
/// Forms the Lagrangian gradient `∇f + Σ λ_c ∇g_c`, scales each component
/// by the inverse of the absolute Hessian diagonal (floored at 1e-8) — a
/// diagonal quasi-Newton step rather than a full QP solve — and applies the
/// projected multiplier update `λ ← max(0, λ + lr·g)`.
pub fn sqp_step(
    grad_f: &dyn Fn(&[f64]) -> Vec<f64>,
    hess_f: &dyn Fn(&[f64]) -> Vec<Vec<f64>>,
    g: &dyn Fn(&[f64]) -> Vec<f64>,
    jac_g: &dyn Fn(&[f64]) -> Vec<Vec<f64>>,
    x: &[f64],
    lam: &[f64],
    lr: f64,
) -> (Vec<f64>, Vec<f64>) {
    let dim = x.len();
    let constraints = g(x);
    let jacobian = jac_g(x);
    let hessian = hess_f(x);
    let mut lagrangian_grad = grad_f(x);
    for (c, multiplier) in lam.iter().enumerate() {
        for i in 0..dim {
            lagrangian_grad[i] += multiplier * jacobian[c][i];
        }
    }
    // Diagonal-Hessian scaled descent direction.
    let direction: Vec<f64> = lagrangian_grad
        .iter()
        .enumerate()
        .map(|(i, gi)| -gi / hessian[i][i].abs().max(1e-8))
        .collect();
    let x_next: Vec<f64> = x.iter().zip(&direction).map(|(xi, di)| xi + di).collect();
    let lam_next: Vec<f64> = lam
        .iter()
        .zip(&constraints)
        .map(|(li, gi)| (li + lr * gi).max(0.0))
        .collect();
    (x_next, lam_next)
}
833#[allow(clippy::too_many_arguments)]
837pub fn sqp(
838 f: &dyn Fn(&[f64]) -> f64,
839 grad_f: &dyn Fn(&[f64]) -> Vec<f64>,
840 hess_f: &dyn Fn(&[f64]) -> Vec<Vec<f64>>,
841 g: &dyn Fn(&[f64]) -> Vec<f64>,
842 jac_g: &dyn Fn(&[f64]) -> Vec<Vec<f64>>,
843 x0: &[f64],
844 max_iter: usize,
845 tol: f64,
846) -> (Vec<f64>, Vec<f64>, usize) {
847 let m_c = g(x0).len();
848 let mut x = x0.to_vec();
849 let mut lam = vec![0.0; m_c];
850 for iter in 0..max_iter {
851 let (xnew, lam_new) = sqp_step(grad_f, hess_f, g, jac_g, &x, &lam, 0.1);
852 let gval = g(&xnew);
853 let feas: f64 = gval.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
854 let gf = grad_f(&xnew);
855 let opt: f64 = gf.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
856 x = xnew;
857 lam = lam_new;
858 if feas < tol && opt < tol {
859 return (x, lam, iter + 1);
860 }
861 let _ = f;
862 }
863 (x, lam, max_iter)
864}
/// Quadratic-penalty method: gradient descent on `f` plus ρ-weighted
/// constraint terms, for a growing penalty parameter ρ.
///
/// Each outer round runs up to 200 gradient steps with step size
/// `1 / (ρ + 1)` on the penalised objective, then multiplies ρ by
/// `rho_factor` unless `‖g(x)‖ < tol`, in which case it stops early.
/// Returns the final point and the total count of inner gradient steps.
/// The raw objective `f` is never evaluated — only `grad_f` is used.
pub fn penalty_method(
    f: &dyn Fn(&[f64]) -> f64,
    grad_f: &dyn Fn(&[f64]) -> Vec<f64>,
    g: &dyn Fn(&[f64]) -> Vec<f64>,
    jac_g: &dyn Fn(&[f64]) -> Vec<Vec<f64>>,
    x0: &[f64],
    rho0: f64,
    rho_factor: f64,
    max_outer: usize,
    tol: f64,
) -> (Vec<f64>, usize) {
    let _ = f;
    let dim = x0.len();
    let mut x = x0.to_vec();
    let mut penalty = rho0;
    let mut steps = 0;
    for _ in 0..max_outer {
        // Gradient of f + (penalty/2)·sum g_c^2 at the current penalty weight.
        let penalised_grad = |point: &[f64]| -> Vec<f64> {
            let constraints = g(point);
            let jacobian = jac_g(point);
            let mut grad = grad_f(point);
            for (c, &cv) in constraints.iter().enumerate() {
                for i in 0..dim {
                    grad[i] += penalty * cv * jacobian[c][i];
                }
            }
            grad
        };
        let step_size = 1.0 / (penalty + 1.0);
        for _ in 0..200 {
            let grad = penalised_grad(&x);
            let norm = grad.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
            if norm < tol * 0.1 {
                break;
            }
            for i in 0..dim {
                x[i] -= step_size * grad[i];
            }
            steps += 1;
        }
        let violation = g(&x).iter().map(|gi| gi * gi).sum::<f64>().sqrt();
        if violation < tol {
            break;
        }
        penalty *= rho_factor;
    }
    (x, steps)
}
/// Online gradient descent over a fixed sequence of loss functions.
///
/// At round t the learner is charged `f_t(x_t)`, then updates
/// `x_{t+1} = Π(x_t − η ∇f_t(x_t))`, where `Π` is the supplied projection.
/// Returns the full iterate trajectory (length `T + 1`, starting from `x0`)
/// and the per-round losses actually incurred.
pub fn online_gradient_descent(
    losses: &[(Box<dyn Fn(&[f64]) -> f64>, Box<dyn Fn(&[f64]) -> Vec<f64>>)],
    x0: &[f64],
    eta: f64,
    project: &dyn Fn(Vec<f64>) -> Vec<f64>,
) -> (Vec<Vec<f64>>, Vec<f64>) {
    let dim = x0.len();
    let mut iterate = x0.to_vec();
    let mut trajectory = Vec::with_capacity(losses.len() + 1);
    trajectory.push(iterate.clone());
    let mut incurred = Vec::with_capacity(losses.len());
    for (loss, loss_grad) in losses {
        incurred.push(loss(&iterate));
        let grad = loss_grad(&iterate);
        let stepped: Vec<f64> = (0..dim).map(|i| iterate[i] - eta * grad[i]).collect();
        iterate = project(stepped);
        trajectory.push(iterate.clone());
    }
    (trajectory, incurred)
}
/// Total regret of a trajectory against a fixed comparator point:
/// `Σ_t f_t(x_t) − f_t(u)`. Rounds beyond the shorter of `losses` and
/// `trajectory` are ignored (the zip truncates).
pub fn compute_regret(
    losses: &[Box<dyn Fn(&[f64]) -> f64>],
    trajectory: &[Vec<f64>],
    comparator: &[f64],
) -> f64 {
    let mut regret = 0.0;
    for (loss, point) in losses.iter().zip(trajectory.iter()) {
        regret += loss(point) - loss(comparator);
    }
    regret
}
/// Robbins–Monro stochastic-approximation iteration
/// `x_{t+1} = x_t − a/(t + A)^α · H(x_t, t)`.
///
/// `h_oracle` receives the current iterate and the round index; the gain
/// sequence is the classical `a / (t + A)^α`. Returns every iterate
/// (including the starting point, so `max_iter + 1` entries) together with
/// the final point.
#[allow(clippy::too_many_arguments)]
pub fn robbins_monro(
    h_oracle: &dyn Fn(&[f64], usize) -> Vec<f64>,
    x0: &[f64],
    a: f64,
    big_a: f64,
    alpha: f64,
    max_iter: usize,
) -> (Vec<Vec<f64>>, Vec<f64>) {
    let dim = x0.len();
    let mut point = x0.to_vec();
    let mut history = Vec::with_capacity(max_iter + 1);
    history.push(point.clone());
    for round in 0..max_iter {
        let drift = h_oracle(&point, round);
        let gain = a / (round as f64 + big_a).powf(alpha);
        for i in 0..dim {
            point[i] -= gain * drift[i];
        }
        history.push(point.clone());
    }
    (history, point)
}
/// Measures how far `(x, λ)` is from a KKT point of
/// `min f  s.t.  g(x) ≤ 0`.
///
/// Returns `(stationarity, primal_feasibility, complementarity)` where
/// * stationarity = `‖∇f(x) + Σ λ_c ∇g_c(x)‖`,
/// * primal feasibility = `‖max(g(x), 0)‖` (only violations count),
/// * complementarity = `‖(λ_c · g_c(x))_c‖`.
pub fn check_kkt(
    grad_f: &dyn Fn(&[f64]) -> Vec<f64>,
    g: &dyn Fn(&[f64]) -> Vec<f64>,
    jac_g: &dyn Fn(&[f64]) -> Vec<Vec<f64>>,
    x: &[f64],
    lam: &[f64],
) -> (f64, f64, f64) {
    let dim = x.len();
    let constraints = g(x);
    let jacobian = jac_g(x);
    let mut stationarity = grad_f(x);
    for (c, &multiplier) in lam.iter().enumerate() {
        for i in 0..dim {
            stationarity[i] += multiplier * jacobian[c][i];
        }
    }
    let stationarity_err = stationarity.iter().map(|s| s * s).sum::<f64>().sqrt();
    let primal_violation = constraints
        .iter()
        .map(|&cv| cv.max(0.0).powi(2))
        .sum::<f64>()
        .sqrt();
    let complementarity = lam
        .iter()
        .zip(&constraints)
        .map(|(&l, &cv)| (l * cv).powi(2))
        .sum::<f64>()
        .sqrt();
    (stationarity_err, primal_violation, complementarity)
}
/// Accelerated gradient descent with the FISTA momentum schedule.
///
/// Uses the classical sequence `t_{k+1} = (1 + √(1 + 4t_k²)) / 2` and
/// extrapolates each plain gradient step by `(t_k − 1)/t_{k+1}` times the
/// difference from the previous step's result. Stops when `‖∇f(x)‖ < tol`
/// or after `max_iter` iterations; returns (point, objective, iterations).
pub fn nesterov_gradient(
    f: &dyn Fn(&[f64]) -> f64,
    grad_f: &dyn Fn(&[f64]) -> Vec<f64>,
    x0: &[f64],
    alpha: f64,
    max_iter: usize,
    tol: f64,
) -> (Vec<f64>, f64, usize) {
    let mut x = x0.to_vec();
    let mut prev_step = x0.to_vec();
    let mut t_k = 1.0_f64;
    let mut iterations = 0;
    for k in 0..max_iter {
        let grad = grad_f(&x);
        let norm = grad.iter().map(|g| g * g).sum::<f64>().sqrt();
        if norm < tol {
            iterations = k;
            break;
        }
        // Plain gradient step taken from the extrapolated point x.
        let step: Vec<f64> = x
            .iter()
            .zip(&grad)
            .map(|(xi, gi)| xi - alpha * gi)
            .collect();
        let t_next = (1.0 + (1.0 + 4.0 * t_k * t_k).sqrt()) / 2.0;
        let momentum = (t_k - 1.0) / t_next;
        x = step
            .iter()
            .zip(&prev_step)
            .map(|(s, p)| s + momentum * (s - p))
            .collect();
        prev_step = step;
        t_k = t_next;
        iterations = k + 1;
    }
    let value = f(&x);
    (x, value, iterations)
}
// Unit tests: each solver is exercised on a small quadratic (or otherwise
// analytically solvable instance) and checked against the known optimum.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_sgd_minimises_quadratic() {
        let f = |x: &[f64]| x[0] * x[0] + x[1] * x[1];
        let grad_f = |x: &[f64]| vec![2.0 * x[0], 2.0 * x[1]];
        let cfg = SGDConfig::new(0.1, 1000, 1e-6);
        let (x, fval, _iters) = sgd(&f, &grad_f, &[3.0, -2.0], &cfg);
        assert!(fval < 1e-6, "SGD fval={fval}");
        assert!(x[0].abs() < 1e-3, "x[0]={}", x[0]);
        assert!(x[1].abs() < 1e-3, "x[1]={}", x[1]);
    }
    #[test]
    fn test_adam_minimises_quadratic() {
        let f = |x: &[f64]| x[0] * x[0] + 4.0 * x[1] * x[1];
        let grad_f = |x: &[f64]| vec![2.0 * x[0], 8.0 * x[1]];
        let cfg = AdamConfig::default_params(0.1, 2000);
        let (x, fval, _) = adam(&f, &grad_f, &[5.0, 3.0], &cfg);
        assert!(fval < 1e-4, "Adam fval={fval}");
        assert!(x[0].abs() < 0.02, "x[0]={}", x[0]);
        assert!(x[1].abs() < 0.02, "x[1]={}", x[1]);
    }
    #[test]
    fn test_finite_diff_gradient() {
        let f = |x: &[f64]| x[0] * x[0] + 2.0 * x[1] * x[1];
        let x = vec![3.0, -1.0];
        let g = finite_diff_gradient(&f, &x, 1e-5);
        // Analytic gradient at (3, -1) is (6, -4).
        assert!((g[0] - 6.0).abs() < 1e-6, "g[0]={}", g[0]);
        assert!((g[1] + 4.0).abs() < 1e-6, "g[1]={}", g[1]);
    }
    #[test]
    fn test_finite_diff_hessian() {
        let f = |x: &[f64]| x[0] * x[0] + 3.0 * x[1] * x[1];
        let x = vec![1.0, 1.0];
        let h = finite_diff_hessian(&f, &x, 1e-4);
        // Analytic Hessian is diag(2, 6).
        assert!((h[0][0] - 2.0).abs() < 1e-4, "H[0,0]={}", h[0][0]);
        assert!((h[1][1] - 6.0).abs() < 1e-4, "H[1,1]={}", h[1][1]);
        assert!(h[0][1].abs() < 1e-4, "H[0,1]={}", h[0][1]);
    }
    #[test]
    fn test_check_kkt_unconstrained_minimum() {
        // At x = 0 with no constraints, all three KKT residuals vanish.
        let grad_f = |x: &[f64]| vec![2.0 * x[0]];
        let g = |_x: &[f64]| -> Vec<f64> { vec![] };
        let jac_g = |_x: &[f64]| -> Vec<Vec<f64>> { vec![] };
        let (stat_err, prim_feas, comp) = check_kkt(&grad_f, &g, &jac_g, &[0.0], &[]);
        assert!(stat_err < 1e-12, "stat_err={stat_err}");
        assert!(prim_feas < 1e-12, "prim_feas={prim_feas}");
        assert!(comp < 1e-12, "comp={comp}");
    }
    #[test]
    fn test_penalty_method_equality() {
        // min x0^2 + x1^2 s.t. x0 + x1 = 1, optimum at (0.5, 0.5).
        let f = |x: &[f64]| x[0] * x[0] + x[1] * x[1];
        let grad_f = |x: &[f64]| vec![2.0 * x[0], 2.0 * x[1]];
        let g = |x: &[f64]| vec![x[0] + x[1] - 1.0];
        let jac_g = |_x: &[f64]| vec![vec![1.0, 1.0]];
        let (x, _iters) = penalty_method(&f, &grad_f, &g, &jac_g, &[0.5, 0.5], 1.0, 3.0, 15, 1e-5);
        assert!((x[0] - 0.5).abs() < 0.05, "x[0]={}", x[0]);
        assert!((x[1] - 0.5).abs() < 0.05, "x[1]={}", x[1]);
    }
    #[test]
    fn test_ogd_cumulative_loss() {
        let t_max = 20;
        let losses: Vec<(Box<dyn Fn(&[f64]) -> f64>, Box<dyn Fn(&[f64]) -> Vec<f64>>)> = (0..t_max)
            .map(|_| {
                let f: Box<dyn Fn(&[f64]) -> f64> = Box::new(|x: &[f64]| (x[0] - 1.0).powi(2));
                let g: Box<dyn Fn(&[f64]) -> Vec<f64>> =
                    Box::new(|x: &[f64]| vec![2.0 * (x[0] - 1.0)]);
                (f, g)
            })
            .collect();
        let project = |x: Vec<f64>| x;
        let (_traj, cum_loss) = online_gradient_descent(&losses, &[0.0], 0.1, &project);
        let avg_loss: f64 = cum_loss.iter().sum::<f64>() / t_max as f64;
        assert!(avg_loss < 1.0, "avg_loss={avg_loss}");
    }
    #[test]
    fn test_build_optimization_theory_env() {
        // Spot-check that representative axioms landed in the environment.
        let env = build_optimization_theory_env();
        assert!(env.get(&Name::str("FirstOrderOptimal")).is_some());
        assert!(env.get(&Name::str("KKTPoint")).is_some());
        assert!(env.get(&Name::str("WeakDuality")).is_some());
        assert!(env.get(&Name::str("AugmentedLagrangian")).is_some());
        assert!(env.get(&Name::str("RegretBound")).is_some());
        assert!(env.get(&Name::str("StochasticConvergence")).is_some());
        assert!(env.get(&Name::str("GradDescentConvergence")).is_some());
        assert!(env.get(&Name::str("NesterovAcceleration")).is_some());
        assert!(env.get(&Name::str("AdamConvergence")).is_some());
        assert!(env.get(&Name::str("FrankWolfeConvergence")).is_some());
        assert!(env.get(&Name::str("BregmanDivergence")).is_some());
        assert!(env.get(&Name::str("MirrorDescentConvergence")).is_some());
        assert!(env.get(&Name::str("UCBRegretBound")).is_some());
        assert!(env.get(&Name::str("ADMMConvergence")).is_some());
        assert!(env.get(&Name::str("DouglasRachfordConvergence")).is_some());
        assert!(env.get(&Name::str("ChambollePockConvergence")).is_some());
        assert!(env.get(&Name::str("DykstraConvergence")).is_some());
        assert!(env
            .get(&Name::str("CoordinateDescentConvergence"))
            .is_some());
        assert!(env.get(&Name::str("TrustRegionConvergence")).is_some());
        assert!(env.get(&Name::str("LBFGSConvergence")).is_some());
        assert!(env.get(&Name::str("SDPStrongDuality")).is_some());
    }
    #[test]
    fn test_gradient_descent_armijo() {
        let f = |x: &[f64]| x[0] * x[0] + 2.0 * x[1] * x[1];
        let grad_f = |x: &[f64]| vec![2.0 * x[0], 4.0 * x[1]];
        let cfg = GradientDescentConfig::new(500, 1e-6);
        let mut opt = GradientDescentOptimizer::new(vec![4.0, -3.0], cfg);
        let (x, fval, _iters) = opt.run(&f, &grad_f);
        assert!(fval < 1e-6, "GD fval={fval}");
        assert!(x[0].abs() < 1e-3, "x[0]={}", x[0]);
        assert!(x[1].abs() < 1e-3, "x[1]={}", x[1]);
    }
    #[test]
    fn test_adam_optimizer_struct() {
        let f = |x: &[f64]| (x[0] - 2.0).powi(2) + (x[1] + 1.0).powi(2);
        let grad_f = |x: &[f64]| vec![2.0 * (x[0] - 2.0), 2.0 * (x[1] + 1.0)];
        let mut opt = AdamOptimizer::new(vec![0.0, 0.0], 0.05, 0.9, 0.999, 1e-8);
        let (x, fval, _steps) = opt.run(&f, &grad_f, 3000, 1e-6);
        assert!(fval < 0.01, "Adam struct fval={fval}");
        assert!((x[0] - 2.0).abs() < 0.05, "x[0]={}", x[0]);
        assert!((x[1] + 1.0).abs() < 0.05, "x[1]={}", x[1]);
    }
    #[test]
    fn test_frank_wolfe_simplex() {
        let f = |x: &[f64]| (x[0] - 0.3).powi(2) + (x[1] - 0.7).powi(2);
        let grad_f = |x: &[f64]| vec![2.0 * (x[0] - 0.3), 2.0 * (x[1] - 0.7)];
        // Linear minimisation oracle over the 2-simplex: pick the vertex
        // with the smaller gradient component.
        let lmo = |g: &[f64]| {
            let imin = if g[0] < g[1] { 0 } else { 1 };
            let mut s = vec![0.0; g.len()];
            s[imin] = 1.0;
            s
        };
        let mut fw = FrankWolfeOptimizer::new(vec![0.5, 0.5]);
        let (_x, fval, _iters) = fw.run(&f, &grad_f, &lmo, 200, 1e-6);
        assert!(fval < 0.01, "FW fval={fval}");
    }
    #[test]
    fn test_admm_consensus() {
        let rho = 1.0_f64;
        let x_update = move |z: &[f64], u: &[f64]| vec![rho * (z[0] - u[0]) / (2.0 + rho)];
        let z_update = move |x: &[f64], u: &[f64]| vec![rho * (x[0] + u[0]) / (2.0 + rho)];
        let constraint = |x: &[f64], z: &[f64]| vec![x[0] - z[0]];
        let mut admm = ADMMSolver::new(rho, vec![1.0], vec![1.0], vec![0.0]);
        let (x, z, prim, _iters) = admm.run(&x_update, &z_update, &constraint, 200, 1e-8, 1e-4);
        assert!(prim < 1e-6, "ADMM primal residual={prim}");
        assert!(x[0].abs() < 0.01, "x={}", x[0]);
        assert!(z[0].abs() < 0.01, "z={}", z[0]);
    }
    #[test]
    fn test_nesterov_gradient_quadratic() {
        let f = |x: &[f64]| x[0] * x[0] + 4.0 * x[1] * x[1];
        let grad_f = |x: &[f64]| vec![2.0 * x[0], 8.0 * x[1]];
        let (x, fval, _iters) = nesterov_gradient(&f, &grad_f, &[3.0, 2.0], 0.1, 500, 1e-7);
        assert!(fval < 1e-6, "Nesterov fval={fval}");
        assert!(x[0].abs() < 1e-3, "x[0]={}", x[0]);
    }
    #[test]
    fn test_regret_tracker() {
        // Playing x=0 against comparator 1 with loss (x-1)^2 costs 1/round.
        let mut tracker = RegretTracker::new();
        for _ in 0..10 {
            let ft = |x: &[f64]| (x[0] - 1.0).powi(2);
            tracker.record(&ft, &[0.0], &[1.0]);
        }
        assert_eq!(tracker.rounds, 10);
        assert!((tracker.total_regret() - 10.0).abs() < 1e-10);
        assert_eq!(tracker.average_regret(), 1.0);
        assert!(!tracker.is_no_regret(0.5));
    }
    #[test]
    fn test_lbfgs_quadratic() {
        let f = |x: &[f64]| x[0].powi(2) + 4.0 * x[1].powi(2) + x[2].powi(2);
        let grad_f = |x: &[f64]| vec![2.0 * x[0], 8.0 * x[1], 2.0 * x[2]];
        let mut lbfgs = LBFGSState::new(vec![2.0, 1.0, -3.0], 5);
        let (x, fval, _iters) = lbfgs.run(&f, &grad_f, 200, 1e-8);
        assert!(fval < 1e-8, "L-BFGS fval={fval}");
        assert!(x[0].abs() < 1e-4, "x[0]={}", x[0]);
        assert!(x[1].abs() < 1e-4, "x[1]={}", x[1]);
        assert!(x[2].abs() < 1e-4, "x[2]={}", x[2]);
    }
}
// Tests for the robust / integer / stochastic programming helper types
// defined in `super::types`.
#[cfg(test)]
mod tests_optimization_extended {
    use super::*;
    #[test]
    fn test_robust_worst_case_cost() {
        let prob = RobustOptimizationProblem::new(2, 2, 0.5, vec![1.0, 2.0]);
        let x = vec![1.0, 1.0];
        let wc = prob.worst_case_cost(&x);
        assert!((wc - 4.0).abs() < 1e-10);
    }
    #[test]
    fn test_ellipsoidal_worst_case() {
        let prob = RobustOptimizationProblem::new(2, 2, 1.0, vec![0.0, 0.0]);
        let x = vec![3.0, 4.0];
        let wc = prob.ellipsoidal_worst_case(&x);
        // ||(3, 4)|| = 5 with unit uncertainty radius and zero nominal cost.
        assert!((wc - 5.0).abs() < 1e-10);
    }
    #[test]
    fn test_bip_feasibility() {
        let bip = BinaryIntegerProgram::new(vec![1.0, 2.0], vec![vec![1.0, 1.0]], vec![1.0]);
        assert!(bip.is_feasible(&[true, false]));
        assert!(!bip.is_feasible(&[true, true]));
    }
    #[test]
    fn test_bip_greedy_solution() {
        let bip = BinaryIntegerProgram::new(vec![-3.0, -1.0], vec![vec![1.0, 1.0]], vec![1.0]);
        let sol = bip.greedy_solution();
        assert!(sol[0]);
    }
    #[test]
    fn test_two_stage_expected_cost() {
        let prog = TwoStageStochasticProgram::new(vec![1.0], vec![0.3, 0.7], vec![10.0, 5.0]);
        let eq = prog.expected_second_stage(&[1.0]);
        // 0.3 * 10 + 0.7 * 5 = 6.5.
        assert!((eq - 6.5).abs() < 1e-10);
    }
    #[test]
    fn test_value_of_perfect_information() {
        let prog = TwoStageStochasticProgram::new(vec![1.0], vec![0.5, 0.5], vec![10.0, 2.0]);
        let vpi = prog.value_of_perfect_information();
        assert!(vpi >= 0.0);
    }
}