1use oxilean_kernel::{BinderInfo, Declaration, Environment, Expr, Level, Name};
6
7use super::types::{
8 BundleMethodSolver, CuttingPlaneSolver, FISTASolver, GeometricProgramSolver, GradientDescent,
9 L1NormFunction, MirrorDescentSolver, OnlineLearner, ProjectedGradient, ProximalGradientSolver,
10 QuadraticFunction, RipVerifier, SDPRelaxation, SinkhornSolver, ADMM,
11};
12
13pub fn app(f: Expr, a: Expr) -> Expr {
14 Expr::App(Box::new(f), Box::new(a))
15}
16pub fn app2(f: Expr, a: Expr, b: Expr) -> Expr {
17 app(app(f, a), b)
18}
19pub fn app3(f: Expr, a: Expr, b: Expr, c: Expr) -> Expr {
20 app(app2(f, a, b), c)
21}
22pub fn cst(s: &str) -> Expr {
23 Expr::Const(Name::str(s), vec![])
24}
25pub fn prop() -> Expr {
26 Expr::Sort(Level::zero())
27}
28pub fn type0() -> Expr {
29 Expr::Sort(Level::succ(Level::zero()))
30}
31pub fn pi(bi: BinderInfo, name: &str, dom: Expr, body: Expr) -> Expr {
32 Expr::Pi(bi, Name::str(name), Box::new(dom), Box::new(body))
33}
34pub fn arrow(a: Expr, b: Expr) -> Expr {
35 pi(BinderInfo::Default, "_", a, b)
36}
37pub fn nat_ty() -> Expr {
38 cst("Nat")
39}
40pub fn real_ty() -> Expr {
41 cst("Real")
42}
43pub fn list_ty(elem: Expr) -> Expr {
44 app(cst("List"), elem)
45}
46pub fn fn_ty(dom: Expr, cod: Expr) -> Expr {
47 arrow(dom, cod)
48}
49pub fn convex_set_ty() -> Expr {
52 arrow(fn_ty(list_ty(real_ty()), prop()), prop())
53}
54pub fn convex_function_ty() -> Expr {
57 arrow(fn_ty(list_ty(real_ty()), real_ty()), prop())
58}
59pub fn kkt_conditions_ty() -> Expr {
62 let rn_to_r = fn_ty(list_ty(real_ty()), real_ty());
63 let list_rn_to_r = list_ty(rn_to_r.clone());
64 arrow(
65 rn_to_r,
66 arrow(list_rn_to_r, arrow(list_ty(real_ty()), prop())),
67 )
68}
69pub fn lagrangian_ty() -> Expr {
72 let rn_to_r = fn_ty(list_ty(real_ty()), real_ty());
73 let list_rn_to_r = list_ty(rn_to_r.clone());
74 arrow(
75 rn_to_r,
76 arrow(
77 list_rn_to_r,
78 arrow(list_ty(real_ty()), arrow(list_ty(real_ty()), real_ty())),
79 ),
80 )
81}
82pub fn strong_duality_ty() -> Expr {
85 prop()
86}
87pub fn projection_theorem_ty() -> Expr {
90 prop()
91}
92pub fn supporting_hyperplane_ty() -> Expr {
95 prop()
96}
97pub fn jensen_inequality_ty() -> Expr {
100 arrow(fn_ty(list_ty(real_ty()), real_ty()), prop())
101}
102pub fn slater_condition_ty() -> Expr {
105 prop()
106}
107pub fn fenchel_conjugate_ty() -> Expr {
110 let rn_to_r = fn_ty(list_ty(real_ty()), real_ty());
111 arrow(rn_to_r, fn_ty(list_ty(real_ty()), real_ty()))
112}
113pub fn fenchel_rockafellar_duality_ty() -> Expr {
116 prop()
117}
118pub fn conjugate_subgradient_ty() -> Expr {
121 let rn_to_r = fn_ty(list_ty(real_ty()), real_ty());
122 let lr = list_ty(real_ty());
123 arrow(rn_to_r, arrow(lr.clone(), arrow(lr, prop())))
124}
125pub fn lagrangian_dual_function_ty() -> Expr {
128 let rn_to_r = fn_ty(list_ty(real_ty()), real_ty());
129 let list_rn_to_r = list_ty(rn_to_r.clone());
130 arrow(
131 rn_to_r,
132 arrow(list_rn_to_r, fn_ty(list_ty(real_ty()), real_ty())),
133 )
134}
135pub fn licq_ty() -> Expr {
138 prop()
139}
140pub fn mfcq_ty() -> Expr {
143 prop()
144}
145pub fn complementary_slackness_ty() -> Expr {
148 let lr = list_ty(real_ty());
149 arrow(lr.clone(), arrow(lr, prop()))
150}
151pub fn kkt_sufficiency_ty() -> Expr {
154 prop()
155}
156pub fn barrier_function_ty() -> Expr {
159 let rn_to_prop = fn_ty(list_ty(real_ty()), prop());
160 let rn_to_r = fn_ty(list_ty(real_ty()), real_ty());
161 arrow(rn_to_prop, arrow(rn_to_r, prop()))
162}
163pub fn path_following_method_ty() -> Expr {
166 prop()
167}
168pub fn predictor_corrector_method_ty() -> Expr {
171 prop()
172}
173pub fn central_path_convergence_ty() -> Expr {
176 arrow(real_ty(), prop())
177}
178pub fn positive_semidefinite_ty() -> Expr {
181 let mat = list_ty(list_ty(real_ty()));
182 arrow(mat, prop())
183}
184pub fn sdp_feasibility_ty() -> Expr {
187 let mat = list_ty(list_ty(real_ty()));
188 let mats = list_ty(mat);
189 let lr = list_ty(real_ty());
190 arrow(mats, arrow(lr, prop()))
191}
192pub fn sdp_duality_ty() -> Expr {
195 prop()
196}
197pub fn sdp_optimality_ty() -> Expr {
200 let mat = list_ty(list_ty(real_ty()));
201 let mats = list_ty(mat);
202 let lr = list_ty(real_ty());
203 arrow(mats, arrow(lr.clone(), arrow(lr, arrow(real_ty(), prop()))))
204}
205pub fn second_order_cone_ty() -> Expr {
208 let lr = list_ty(real_ty());
209 arrow(lr, arrow(real_ty(), prop()))
210}
211pub fn socp_feasibility_ty() -> Expr {
214 prop()
215}
216pub fn monomial_ty() -> Expr {
219 let lr = list_ty(real_ty());
220 arrow(lr.clone(), arrow(lr, arrow(real_ty(), real_ty())))
221}
222pub fn posynomial_ty() -> Expr {
225 let rn_to_r = fn_ty(list_ty(real_ty()), real_ty());
226 let lr = list_ty(real_ty());
227 arrow(list_ty(rn_to_r), arrow(lr, real_ty()))
228}
229pub fn geometric_program_duality_ty() -> Expr {
232 prop()
233}
234pub fn smooth_gradient_convergence_ty() -> Expr {
237 arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), prop())))
238}
239pub fn strongly_convex_convergence_ty() -> Expr {
242 arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), prop())))
243}
244pub fn proximal_operator_ty() -> Expr {
247 let rn_to_r = fn_ty(list_ty(real_ty()), real_ty());
248 let lr = list_ty(real_ty());
249 arrow(rn_to_r, arrow(real_ty(), fn_ty(lr.clone(), lr)))
250}
251pub fn ista_convergence_ty() -> Expr {
254 arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), prop())))
255}
256pub fn fista_convergence_ty() -> Expr {
259 arrow(real_ty(), arrow(nat_ty(), prop()))
260}
261pub fn proximal_gradient_descent_ty() -> Expr {
264 let rn_to_r = fn_ty(list_ty(real_ty()), real_ty());
265 arrow(rn_to_r.clone(), arrow(rn_to_r, prop()))
266}
267pub fn douglas_rachford_splitting_ty() -> Expr {
270 prop()
271}
272pub fn chambolle_pock_ty() -> Expr {
275 arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), prop())))
276}
277pub fn augmented_lagrangian_ty() -> Expr {
280 let rn_to_r = fn_ty(list_ty(real_ty()), real_ty());
281 let list_rn_to_r = list_ty(rn_to_r.clone());
282 arrow(
283 rn_to_r,
284 arrow(
285 list_rn_to_r,
286 arrow(
287 real_ty(),
288 arrow(list_ty(real_ty()), arrow(list_ty(real_ty()), real_ty())),
289 ),
290 ),
291 )
292}
293pub fn admm_convergence_ty() -> Expr {
296 arrow(real_ty(), prop())
297}
298pub fn supporting_hyperplane_cut_ty() -> Expr {
301 let rn_to_r = fn_ty(list_ty(real_ty()), real_ty());
302 let lr = list_ty(real_ty());
303 arrow(
304 rn_to_r,
305 arrow(lr.clone(), arrow(lr, arrow(real_ty(), prop()))),
306 )
307}
308pub fn kelley_method_ty() -> Expr {
311 prop()
312}
313pub fn bundle_method_convergence_ty() -> Expr {
316 arrow(real_ty(), prop())
317}
318pub fn ellipsoid_method_complexity_ty() -> Expr {
321 arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), prop())))
322}
323pub fn center_of_gravity_method_ty() -> Expr {
326 arrow(nat_ty(), prop())
327}
328pub fn subgradient_inequality_ty() -> Expr {
331 let rn_to_r = fn_ty(list_ty(real_ty()), real_ty());
332 let lr = list_ty(real_ty());
333 arrow(rn_to_r, arrow(lr.clone(), arrow(lr, prop())))
334}
335pub fn subgradient_method_convergence_ty() -> Expr {
338 arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), prop())))
339}
340pub fn polyak_stepsize_ty() -> Expr {
343 arrow(real_ty(), arrow(real_ty(), prop()))
344}
345pub fn sgd_convergence_ty() -> Expr {
348 arrow(real_ty(), arrow(nat_ty(), prop()))
349}
350pub fn svrg_convergence_ty() -> Expr {
353 arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), prop())))
354}
355pub fn sarah_convergence_ty() -> Expr {
358 arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), prop())))
359}
360pub fn spider_convergence_ty() -> Expr {
363 arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), prop())))
364}
365pub fn dcp_atom_convex_ty() -> Expr {
368 let rn_to_r = fn_ty(list_ty(real_ty()), real_ty());
369 arrow(rn_to_r, prop())
370}
371pub fn dcp_composition_rule_ty() -> Expr {
374 prop()
375}
376pub fn dcp_verification_ty() -> Expr {
379 let rn_to_r = fn_ty(list_ty(real_ty()), real_ty());
380 arrow(rn_to_r, prop())
381}
382pub fn self_concordant_barrier_ty() -> Expr {
385 let rn_to_r = fn_ty(list_ty(real_ty()), real_ty());
386 arrow(rn_to_r, arrow(real_ty(), prop()))
387}
388pub fn self_concordant_complexity_ty() -> Expr {
391 arrow(real_ty(), arrow(nat_ty(), prop()))
392}
393pub fn logarithmic_barrier_ty() -> Expr {
396 let rn_to_r = fn_ty(list_ty(real_ty()), real_ty());
397 arrow(nat_ty(), arrow(rn_to_r, prop()))
398}
399pub fn newton_decrement_ty() -> Expr {
402 let rn_to_r = fn_ty(list_ty(real_ty()), real_ty());
403 let lr = list_ty(real_ty());
404 arrow(rn_to_r, arrow(lr, real_ty()))
405}
406pub fn sdp_slater_condition_ty() -> Expr {
409 prop()
410}
411pub fn sdp_complementarity_ty() -> Expr {
414 let mat = list_ty(list_ty(real_ty()));
415 arrow(mat.clone(), arrow(mat, prop()))
416}
417pub fn sdp_duality_gap_ty() -> Expr {
420 let mat = list_ty(list_ty(real_ty()));
421 let lr = list_ty(real_ty());
422 arrow(mat, arrow(lr, real_ty()))
423}
424pub fn lorentz_cone_ty() -> Expr {
427 arrow(nat_ty(), prop())
428}
429pub fn socp_duality_ty() -> Expr {
432 prop()
433}
434pub fn rotated_lorentz_cone_ty() -> Expr {
437 let lr = list_ty(real_ty());
438 arrow(lr, arrow(real_ty(), arrow(real_ty(), prop())))
439}
440pub fn admm_linear_convergence_ty() -> Expr {
443 arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), prop())))
444}
445pub fn admm_primal_residual_ty() -> Expr {
448 let lr = list_ty(real_ty());
449 arrow(lr.clone(), arrow(lr, arrow(real_ty(), prop())))
450}
451pub fn admm_dual_residual_ty() -> Expr {
454 let lr = list_ty(real_ty());
455 arrow(lr, arrow(real_ty(), prop()))
456}
457pub fn proximal_point_algorithm_ty() -> Expr {
460 let rn_to_r = fn_ty(list_ty(real_ty()), real_ty());
461 arrow(rn_to_r, arrow(real_ty(), arrow(nat_ty(), prop())))
462}
463pub fn resolvent_operator_ty() -> Expr {
466 let rn_rn_prop = fn_ty(list_ty(real_ty()), fn_ty(list_ty(real_ty()), prop()));
467 let lr = list_ty(real_ty());
468 arrow(rn_rn_prop, arrow(real_ty(), fn_ty(lr.clone(), lr)))
469}
470pub fn firmly_nonexpansive_ty() -> Expr {
473 let lr = list_ty(real_ty());
474 let rn_to_rn = fn_ty(lr.clone(), lr);
475 arrow(rn_to_rn, prop())
476}
477pub fn bregman_divergence_ty() -> Expr {
480 let rn_to_r = fn_ty(list_ty(real_ty()), real_ty());
481 let lr = list_ty(real_ty());
482 arrow(rn_to_r, arrow(lr.clone(), arrow(lr, real_ty())))
483}
484pub fn mirror_descent_convergence_ty() -> Expr {
487 arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), prop())))
488}
489pub fn negative_entropy_function_ty() -> Expr {
492 let lr = list_ty(real_ty());
493 fn_ty(lr, real_ty())
494}
495pub fn exponential_weights_algorithm_ty() -> Expr {
498 arrow(real_ty(), arrow(nat_ty(), prop()))
499}
500pub fn saga_convergence_ty() -> Expr {
503 arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), prop())))
504}
505pub fn adam_convergence_ty() -> Expr {
508 arrow(
509 real_ty(),
510 arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), prop()))),
511 )
512}
513pub fn momentum_sgd_ty() -> Expr {
516 arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), prop())))
517}
518pub fn maximal_monotone_operator_ty() -> Expr {
521 let lr = list_ty(real_ty());
522 let set_valued = fn_ty(lr.clone(), list_ty(lr));
523 arrow(set_valued, prop())
524}
525pub fn monotone_inclusion_problem_ty() -> Expr {
528 let lr = list_ty(real_ty());
529 let set_valued = fn_ty(lr.clone(), list_ty(lr.clone()));
530 arrow(set_valued, arrow(lr, prop()))
531}
532pub fn splitting_convergence_ty() -> Expr {
535 prop()
536}
537pub fn kantorovich_problem_ty() -> Expr {
540 let lr = list_ty(real_ty());
541 let cost = fn_ty(lr.clone(), fn_ty(lr.clone(), real_ty()));
542 arrow(cost, arrow(lr.clone(), arrow(lr, real_ty())))
543}
544pub fn kantorovich_duality_ty() -> Expr {
547 prop()
548}
549pub fn wasserstein_distance_ty() -> Expr {
552 let lr = list_ty(real_ty());
553 arrow(real_ty(), arrow(lr.clone(), arrow(lr, real_ty())))
554}
555pub fn sinkhorn_algorithm_ty() -> Expr {
558 arrow(real_ty(), arrow(nat_ty(), prop()))
559}
560pub fn restricted_isometry_property_ty() -> Expr {
563 let mat = list_ty(list_ty(real_ty()));
564 arrow(nat_ty(), arrow(real_ty(), arrow(mat, prop())))
565}
566pub fn basis_pursuit_recovery_ty() -> Expr {
569 let mat = list_ty(list_ty(real_ty()));
570 let lr = list_ty(real_ty());
571 arrow(mat, arrow(lr, arrow(nat_ty(), prop())))
572}
573pub fn lasso_sparsity_ty() -> Expr {
576 arrow(real_ty(), arrow(nat_ty(), prop()))
577}
578pub fn nuclear_norm_ty() -> Expr {
581 let mat = list_ty(list_ty(real_ty()));
582 fn_ty(mat, real_ty())
583}
584pub fn matrix_completion_theorem_ty() -> Expr {
588 arrow(
589 nat_ty(),
590 arrow(nat_ty(), arrow(nat_ty(), arrow(real_ty(), prop()))),
591 )
592}
593pub fn robust_pca_ty() -> Expr {
596 let mat = list_ty(list_ty(real_ty()));
597 arrow(mat.clone(), arrow(mat.clone(), arrow(mat, prop())))
598}
599pub fn chance_constraint_ty() -> Expr {
602 let rn_prop = fn_ty(list_ty(real_ty()), prop());
603 arrow(rn_prop, arrow(real_ty(), prop()))
604}
605pub fn distributionally_robust_objective_ty() -> Expr {
608 let rn_to_r = fn_ty(list_ty(real_ty()), real_ty());
609 arrow(rn_to_r, arrow(real_ty(), prop()))
610}
611pub fn wasserstein_ambiguity_set_ty() -> Expr {
614 let lr = list_ty(real_ty());
615 arrow(lr, arrow(real_ty(), prop()))
616}
617pub fn cvar_constraint_ty() -> Expr {
620 let rn_to_r = fn_ty(list_ty(real_ty()), real_ty());
621 arrow(rn_to_r, arrow(real_ty(), arrow(real_ty(), prop())))
622}
623pub fn online_convex_optimization_ty() -> Expr {
626 arrow(nat_ty(), arrow(real_ty(), prop()))
627}
628pub fn ftrl_regret_bound_ty() -> Expr {
631 arrow(real_ty(), arrow(nat_ty(), prop()))
632}
633pub fn adaptive_regret_bound_ty() -> Expr {
636 arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), prop())))
637}
638pub fn online_gradient_descent_regret_ty() -> Expr {
641 arrow(real_ty(), arrow(nat_ty(), prop()))
642}
643pub fn build_convex_optimization_env() -> Environment {
645 let mut env = Environment::new();
646 let axioms: &[(&str, Expr)] = &[
647 ("ConvexSet", convex_set_ty()),
648 ("ConvexFunction", convex_function_ty()),
649 ("KktConditions", kkt_conditions_ty()),
650 ("Lagrangian", lagrangian_ty()),
651 ("StrongDuality", strong_duality_ty()),
652 ("projection_theorem", projection_theorem_ty()),
653 ("supporting_hyperplane", supporting_hyperplane_ty()),
654 ("jensen_inequality", jensen_inequality_ty()),
655 ("slater_condition", slater_condition_ty()),
656 ("ConvexProgram", prop()),
657 ("DualProgram", prop()),
658 ("OptimalityGap", arrow(real_ty(), prop())),
659 ("FenchelConjugate", fenchel_conjugate_ty()),
660 (
661 "FenchelRockafellarDuality",
662 fenchel_rockafellar_duality_ty(),
663 ),
664 ("ConjugateSubgradient", conjugate_subgradient_ty()),
665 ("LagrangianDualFunction", lagrangian_dual_function_ty()),
666 ("LinearIndependenceCQ", licq_ty()),
667 ("MangasarianFromovitzCQ", mfcq_ty()),
668 ("ComplementarySlackness", complementary_slackness_ty()),
669 ("KktSufficiency", kkt_sufficiency_ty()),
670 ("BarrierFunction", barrier_function_ty()),
671 ("PathFollowingMethod", path_following_method_ty()),
672 ("PredictorCorrectorMethod", predictor_corrector_method_ty()),
673 ("CentralPathConvergence", central_path_convergence_ty()),
674 ("PositiveSemidefinite", positive_semidefinite_ty()),
675 ("SdpFeasibility", sdp_feasibility_ty()),
676 ("SdpDuality", sdp_duality_ty()),
677 ("SdpOptimality", sdp_optimality_ty()),
678 ("SecondOrderCone", second_order_cone_ty()),
679 ("SocpFeasibility", socp_feasibility_ty()),
680 ("Monomial", monomial_ty()),
681 ("Posynomial", posynomial_ty()),
682 ("GeometricProgramDuality", geometric_program_duality_ty()),
683 (
684 "SmoothGradientConvergence",
685 smooth_gradient_convergence_ty(),
686 ),
687 (
688 "StronglyConvexConvergence",
689 strongly_convex_convergence_ty(),
690 ),
691 ("ProximalOperator", proximal_operator_ty()),
692 ("IstaConvergence", ista_convergence_ty()),
693 ("FistaConvergence", fista_convergence_ty()),
694 ("ProximalGradientDescent", proximal_gradient_descent_ty()),
695 ("DouglasRachfordSplitting", douglas_rachford_splitting_ty()),
696 ("ChambollePockAlgorithm", chambolle_pock_ty()),
697 ("AugmentedLagrangian", augmented_lagrangian_ty()),
698 ("AdmmConvergence", admm_convergence_ty()),
699 ("SupportingHyperplaneCut", supporting_hyperplane_cut_ty()),
700 ("KelleyMethod", kelley_method_ty()),
701 ("BundleMethodConvergence", bundle_method_convergence_ty()),
702 (
703 "EllipsoidMethodComplexity",
704 ellipsoid_method_complexity_ty(),
705 ),
706 ("CenterOfGravityMethod", center_of_gravity_method_ty()),
707 ("SubgradientInequality", subgradient_inequality_ty()),
708 (
709 "SubgradientMethodConvergence",
710 subgradient_method_convergence_ty(),
711 ),
712 ("PolyakStepsize", polyak_stepsize_ty()),
713 ("SgdConvergence", sgd_convergence_ty()),
714 ("SvrgConvergence", svrg_convergence_ty()),
715 ("SarahConvergence", sarah_convergence_ty()),
716 ("SpiderConvergence", spider_convergence_ty()),
717 ("DcpAtomConvex", dcp_atom_convex_ty()),
718 ("DcpCompositionRule", dcp_composition_rule_ty()),
719 ("DcpVerification", dcp_verification_ty()),
720 ("SelfConcordantBarrier", self_concordant_barrier_ty()),
721 ("SelfConcordantComplexity", self_concordant_complexity_ty()),
722 ("LogarithmicBarrier", logarithmic_barrier_ty()),
723 ("NewtonDecrement", newton_decrement_ty()),
724 ("SdpSlaterCondition", sdp_slater_condition_ty()),
725 ("SdpComplementarity", sdp_complementarity_ty()),
726 ("SdpDualityGap", sdp_duality_gap_ty()),
727 ("LorentzCone", lorentz_cone_ty()),
728 ("SocpDuality", socp_duality_ty()),
729 ("RotatedLorentzCone", rotated_lorentz_cone_ty()),
730 ("AdmmLinearConvergence", admm_linear_convergence_ty()),
731 ("AdmmPrimalResidual", admm_primal_residual_ty()),
732 ("AdmmDualResidual", admm_dual_residual_ty()),
733 ("ProximalPointAlgorithm", proximal_point_algorithm_ty()),
734 ("ResolventOperator", resolvent_operator_ty()),
735 ("FirmlyNonexpansive", firmly_nonexpansive_ty()),
736 ("BregmanDivergence", bregman_divergence_ty()),
737 ("MirrorDescentConvergence", mirror_descent_convergence_ty()),
738 ("NegativeEntropyFunction", negative_entropy_function_ty()),
739 (
740 "ExponentialWeightsAlgorithm",
741 exponential_weights_algorithm_ty(),
742 ),
743 ("SagaConvergence", saga_convergence_ty()),
744 ("AdamConvergence", adam_convergence_ty()),
745 ("MomentumSgd", momentum_sgd_ty()),
746 ("MaximalMonotoneOperator", maximal_monotone_operator_ty()),
747 ("MonotoneInclusionProblem", monotone_inclusion_problem_ty()),
748 ("SplittingConvergence", splitting_convergence_ty()),
749 ("KantorovichProblem", kantorovich_problem_ty()),
750 ("KantorovichDuality", kantorovich_duality_ty()),
751 ("WassersteinDistance", wasserstein_distance_ty()),
752 ("SinkhornAlgorithm", sinkhorn_algorithm_ty()),
753 (
754 "RestrictedIsometryProperty",
755 restricted_isometry_property_ty(),
756 ),
757 ("BasisPursuitRecovery", basis_pursuit_recovery_ty()),
758 ("LassoSparsity", lasso_sparsity_ty()),
759 ("NuclearNorm", nuclear_norm_ty()),
760 ("MatrixCompletionTheorem", matrix_completion_theorem_ty()),
761 ("RobustPca", robust_pca_ty()),
762 ("ChanceConstraint", chance_constraint_ty()),
763 (
764 "DistributionallyRobustObjective",
765 distributionally_robust_objective_ty(),
766 ),
767 ("WassersteinAmbiguitySet", wasserstein_ambiguity_set_ty()),
768 ("CvarConstraint", cvar_constraint_ty()),
769 ("OnlineConvexOptimization", online_convex_optimization_ty()),
770 ("FtrlRegretBound", ftrl_regret_bound_ty()),
771 ("AdaptiveRegretBound", adaptive_regret_bound_ty()),
772 (
773 "OnlineGradientDescentRegret",
774 online_gradient_descent_regret_ty(),
775 ),
776 ];
777 for (name, ty) in axioms {
778 env.add(Declaration::Axiom {
779 name: Name::str(*name),
780 univ_params: vec![],
781 ty: ty.clone(),
782 })
783 .ok();
784 }
785 env
786}
/// A convex objective over `f64` slices that can report its value and gradient.
pub trait ConvexFunction {
    /// Returns the objective value at `x`.
    fn eval(&self, x: &[f64]) -> f64;
    /// Returns the gradient at `x`.
    ///
    /// The result is presumably the same length as `x`; TODO confirm against implementors.
    fn gradient(&self, x: &[f64]) -> Vec<f64>;
    /// Reports whether the implementor considers itself strongly convex.
    fn is_strongly_convex(&self) -> bool;
}
/// A [`ConvexFunction`] that additionally exposes a proximal operator.
pub trait ProxableFunction: ConvexFunction {
    /// Applies the proximal step with parameter `t` at the point `v`.
    ///
    /// NOTE(review): this is presumably `argmin_x g(x) + ‖x − v‖² / (2t)`. For example,
    /// the unit tests expect the L1 norm's prox with `t = 1` to shrink `3.0` to `2.0`.
    fn prox(&self, v: &[f64], t: f64) -> Vec<f64>;
}
// Unit tests for the axiom environment and for the numerical solvers imported from
// `super::types`.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_quadratic_eval_origin() {
        let q = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let f = QuadraticFunction::new(q, vec![0.0, 0.0], 0.0);
        assert!((f.eval(&[0.0, 0.0])).abs() < 1e-12);
    }
    #[test]
    fn test_quadratic_eval_nonzero() {
        let q = vec![vec![2.0]];
        let f = QuadraticFunction::new(q, vec![0.0], 0.0);
        assert!((f.eval(&[2.0]) - 4.0).abs() < 1e-9);
    }
    #[test]
    fn test_quadratic_gradient() {
        let q = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let f = QuadraticFunction::new(q, vec![0.0, 0.0], 0.0);
        let grad = f.gradient(&[3.0, -1.0]);
        assert!((grad[0] - 3.0).abs() < 1e-9);
        assert!((grad[1] + 1.0).abs() < 1e-9);
    }
    #[test]
    fn test_strongly_convex() {
        let q = vec![vec![2.0, 0.0], vec![0.0, 3.0]];
        let f = QuadraticFunction::new(q, vec![0.0, 0.0], 0.0);
        assert!(f.is_strongly_convex());
    }
    #[test]
    fn test_gradient_descent_minimizes_quadratic() {
        let q = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let f = QuadraticFunction::new(q, vec![0.0, 0.0], 0.0);
        let gd = GradientDescent::new(0.1, 500, 1e-6);
        let (x, fval, _iters) = gd.minimize(&f, &[5.0, -3.0]);
        assert!(fval < 1e-6, "fval={fval}");
        assert!(x[0].abs() < 1e-3);
        assert!(x[1].abs() < 1e-3);
    }
    #[test]
    fn test_projected_gradient_box_constraint() {
        let q = vec![vec![1.0]];
        let f = QuadraticFunction::new(q, vec![0.0], 0.0);
        let pg = ProjectedGradient::new(0.1, 300, 1e-6, vec![1.0], vec![5.0]);
        let (x, fval) = pg.minimize(&f, &[4.0]);
        assert!((x[0] - 1.0).abs() < 1e-3, "x={}", x[0]);
        assert!((fval - 0.5).abs() < 1e-3, "fval={fval}");
    }
    #[test]
    fn test_admm_solve_lasso_stub() {
        let a = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let b = vec![1.0, 2.0];
        let admm = ADMM::new(1.0);
        let x = admm.solve_lasso(&a, &b, 0.1);
        assert_eq!(x.len(), 2);
        assert_eq!(x, vec![0.0, 0.0]);
    }
    #[test]
    fn test_build_convex_optimization_env() {
        let env = build_convex_optimization_env();
        assert!(env.get(&Name::str("ConvexSet")).is_some());
        assert!(env.get(&Name::str("ConvexFunction")).is_some());
        assert!(env.get(&Name::str("KktConditions")).is_some());
        assert!(env.get(&Name::str("projection_theorem")).is_some());
        assert!(env.get(&Name::str("jensen_inequality")).is_some());
        assert!(env.get(&Name::str("FenchelConjugate")).is_some());
        assert!(env.get(&Name::str("FenchelRockafellarDuality")).is_some());
        assert!(env.get(&Name::str("PositiveSemidefinite")).is_some());
        assert!(env.get(&Name::str("SdpDuality")).is_some());
        assert!(env.get(&Name::str("FistaConvergence")).is_some());
        assert!(env.get(&Name::str("ProximalOperator")).is_some());
        assert!(env.get(&Name::str("DouglasRachfordSplitting")).is_some());
        assert!(env.get(&Name::str("ChambollePockAlgorithm")).is_some());
        assert!(env.get(&Name::str("EllipsoidMethodComplexity")).is_some());
        assert!(env.get(&Name::str("SvrgConvergence")).is_some());
        assert!(env.get(&Name::str("SpiderConvergence")).is_some());
        assert!(env.get(&Name::str("DcpVerification")).is_some());
    }
    #[test]
    fn test_l1_norm_prox_soft_threshold() {
        let g = L1NormFunction::new(1.0);
        let result = g.prox(&[3.0, -0.5], 1.0);
        assert!((result[0] - 2.0).abs() < 1e-12, "result[0]={}", result[0]);
        assert!(result[1].abs() < 1e-12, "result[1]={}", result[1]);
    }
    #[test]
    fn test_fista_minimizes_smooth_quadratic() {
        let smooth = QuadraticFunction::new(vec![vec![1.0]], vec![0.0], 0.0);
        let reg = L1NormFunction::new(0.0);
        let solver = FISTASolver::new(1.0, 200, 1e-6);
        let (x, _iters) = solver.minimize(&smooth, &reg, &[5.0]);
        assert!(x[0].abs() < 1e-3, "x[0]={}", x[0]);
    }
    #[test]
    fn test_sdp_relaxation_psd_check() {
        let id = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        assert!(SDPRelaxation::is_psd(&id));
        let neg = vec![vec![-1.0, 0.0], vec![0.0, 1.0]];
        assert!(!SDPRelaxation::is_psd(&neg));
    }
    #[test]
    fn test_sdp_relaxation_solve_returns_bound() {
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let c = vec![0.0, 0.0];
        let a = vec![vec![1.0, 0.0]];
        let b = vec![1.0];
        let sdp = SDPRelaxation::new(q, c, a, b);
        let bound = sdp.solve();
        assert!((bound - 0.0).abs() < 1e-12);
    }
    #[test]
    fn test_geometric_program_log_sum_exp() {
        let monomials = vec![(0.0_f64, vec![1.0_f64])];
        let lse = GeometricProgramSolver::log_sum_exp_posynomial(&monomials, &[1.0]);
        assert!((lse - 1.0).abs() < 1e-9, "lse={lse}");
    }
    #[test]
    fn test_cutting_plane_minimizes_quadratic() {
        let f = QuadraticFunction::new(vec![vec![2.0]], vec![0.0], 0.0);
        let solver = CuttingPlaneSolver::new(100, 1e-4, 2.0);
        let (x, fval, _iters) = solver.minimize(&f, &[3.0]);
        assert!(fval < 1.0, "fval={fval}");
        let _ = x;
    }
    #[test]
    fn test_bundle_method_minimizes_quadratic() {
        let f = QuadraticFunction::new(vec![vec![2.0]], vec![0.0], 0.0);
        let solver = BundleMethodSolver::new(1.0, 0.5, 20, 200, 1e-5);
        let (x, fval, _iters) = solver.minimize(&f, &[4.0]);
        assert!(fval < 1.0, "fval={fval}");
        let _ = x;
    }
    #[test]
    fn test_new_axioms_in_env() {
        let env = build_convex_optimization_env();
        assert!(env.get(&Name::str("SelfConcordantBarrier")).is_some());
        assert!(env.get(&Name::str("SelfConcordantComplexity")).is_some());
        assert!(env.get(&Name::str("LogarithmicBarrier")).is_some());
        assert!(env.get(&Name::str("NewtonDecrement")).is_some());
        assert!(env.get(&Name::str("SdpSlaterCondition")).is_some());
        assert!(env.get(&Name::str("SdpComplementarity")).is_some());
        assert!(env.get(&Name::str("SdpDualityGap")).is_some());
        assert!(env.get(&Name::str("LorentzCone")).is_some());
        assert!(env.get(&Name::str("SocpDuality")).is_some());
        assert!(env.get(&Name::str("RotatedLorentzCone")).is_some());
        assert!(env.get(&Name::str("AdmmLinearConvergence")).is_some());
        assert!(env.get(&Name::str("AdmmPrimalResidual")).is_some());
        assert!(env.get(&Name::str("AdmmDualResidual")).is_some());
        assert!(env.get(&Name::str("ProximalPointAlgorithm")).is_some());
        assert!(env.get(&Name::str("ResolventOperator")).is_some());
        assert!(env.get(&Name::str("FirmlyNonexpansive")).is_some());
        assert!(env.get(&Name::str("BregmanDivergence")).is_some());
        assert!(env.get(&Name::str("MirrorDescentConvergence")).is_some());
        assert!(env.get(&Name::str("NegativeEntropyFunction")).is_some());
        assert!(env.get(&Name::str("ExponentialWeightsAlgorithm")).is_some());
        assert!(env.get(&Name::str("SagaConvergence")).is_some());
        assert!(env.get(&Name::str("AdamConvergence")).is_some());
        assert!(env.get(&Name::str("MomentumSgd")).is_some());
        assert!(env.get(&Name::str("MaximalMonotoneOperator")).is_some());
        assert!(env.get(&Name::str("MonotoneInclusionProblem")).is_some());
        assert!(env.get(&Name::str("SplittingConvergence")).is_some());
        assert!(env.get(&Name::str("KantorovichProblem")).is_some());
        assert!(env.get(&Name::str("KantorovichDuality")).is_some());
        assert!(env.get(&Name::str("WassersteinDistance")).is_some());
        assert!(env.get(&Name::str("SinkhornAlgorithm")).is_some());
        assert!(env.get(&Name::str("RestrictedIsometryProperty")).is_some());
        assert!(env.get(&Name::str("BasisPursuitRecovery")).is_some());
        assert!(env.get(&Name::str("LassoSparsity")).is_some());
        assert!(env.get(&Name::str("NuclearNorm")).is_some());
        assert!(env.get(&Name::str("MatrixCompletionTheorem")).is_some());
        assert!(env.get(&Name::str("RobustPca")).is_some());
        assert!(env.get(&Name::str("ChanceConstraint")).is_some());
        assert!(env
            .get(&Name::str("DistributionallyRobustObjective"))
            .is_some());
        assert!(env.get(&Name::str("WassersteinAmbiguitySet")).is_some());
        assert!(env.get(&Name::str("CvarConstraint")).is_some());
        assert!(env.get(&Name::str("OnlineConvexOptimization")).is_some());
        assert!(env.get(&Name::str("FtrlRegretBound")).is_some());
        assert!(env.get(&Name::str("AdaptiveRegretBound")).is_some());
        assert!(env.get(&Name::str("OnlineGradientDescentRegret")).is_some());
    }
    #[test]
    fn test_mirror_descent_project_simplex() {
        let v = vec![0.5, 0.5];
        let p = MirrorDescentSolver::project_simplex(&v);
        assert!((p.iter().sum::<f64>() - 1.0).abs() < 1e-9);
        assert!(p.iter().all(|&x| x >= 0.0));
    }
    #[test]
    fn test_mirror_descent_simplex_sum_to_one() {
        let v = vec![3.0, -1.0, 2.0];
        let p = MirrorDescentSolver::project_simplex(&v);
        assert!((p.iter().sum::<f64>() - 1.0).abs() < 1e-9);
        assert!(p.iter().all(|&x| x >= 0.0));
    }
    #[test]
    fn test_mirror_descent_bregman_kl_zero() {
        let x = vec![0.25, 0.25, 0.5];
        let kl = MirrorDescentSolver::bregman_kl(&x, &x);
        assert!(kl.abs() < 1e-10, "kl={kl}");
    }
    #[test]
    fn test_mirror_descent_minimizes_quadratic() {
        let f = QuadraticFunction::new(vec![vec![1.0]], vec![0.0], 0.0);
        let solver = MirrorDescentSolver::new(0.05, 500, 1e-6, false);
        let (x, fval, _iters) = solver.minimize(&f, &[3.0]);
        assert!(fval < 0.1, "fval={fval}");
        let _ = x;
    }
    #[test]
    fn test_proximal_gradient_ista_minimizes() {
        let smooth = QuadraticFunction::new(vec![vec![2.0]], vec![0.0], 0.0);
        let reg = L1NormFunction::new(0.0);
        let solver = ProximalGradientSolver::new(2.0, 300, 1e-7, false);
        let (x, iters) = solver.minimize(&smooth, &reg, &[5.0]);
        assert!(x[0].abs() < 0.01, "x[0]={}", x[0]);
        let _ = iters;
    }
    #[test]
    fn test_proximal_gradient_fista_accelerated() {
        let smooth = QuadraticFunction::new(vec![vec![2.0]], vec![0.0], 0.0);
        let reg = L1NormFunction::new(0.0);
        let solver_fista = ProximalGradientSolver::new(2.0, 200, 1e-7, true);
        let solver_ista = ProximalGradientSolver::new(2.0, 200, 1e-7, false);
        let (x_fista, iters_fista) = solver_fista.minimize(&smooth, &reg, &[5.0]);
        let (x_ista, iters_ista) = solver_ista.minimize(&smooth, &reg, &[5.0]);
        assert!(
            iters_fista <= iters_ista + 10,
            "fista={iters_fista}, ista={iters_ista}"
        );
        assert!(x_fista[0].abs() < 0.01, "x_fista={}", x_fista[0]);
        let _ = x_ista;
    }
    #[test]
    fn test_proximal_gradient_estimate_lipschitz() {
        let f = QuadraticFunction::new(vec![vec![2.0]], vec![0.0], 0.0);
        let l_est = ProximalGradientSolver::estimate_lipschitz(&f, &[1.0], 1);
        assert!(l_est > 0.5, "L_est={l_est}");
    }
    #[test]
    fn test_sinkhorn_uniform_transport() {
        let mu = vec![0.5, 0.5];
        let nu = vec![0.5, 0.5];
        let cost = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
        let solver = SinkhornSolver::new(0.1, 200, 1e-8);
        let (gamma, w) = solver.solve(&mu, &nu, &cost);
        let total: f64 = gamma.iter().flat_map(|r| r.iter()).sum();
        assert!((total - 1.0).abs() < 1e-4, "total={total}");
        assert!(w >= 0.0, "w={w}");
    }
    #[test]
    fn test_sinkhorn_same_distribution() {
        let mu = vec![0.5, 0.5];
        let nu = vec![0.5, 0.5];
        let cost = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
        let solver = SinkhornSolver::new(0.01, 500, 1e-10);
        let (_gamma, w) = solver.solve(&mu, &nu, &cost);
        assert!(w < 0.6, "w={w}");
    }
    #[test]
    fn test_sinkhorn_wasserstein2_1d_zero() {
        let pts = vec![0.0, 1.0];
        let wts = vec![0.5, 0.5];
        let w2 = SinkhornSolver::wasserstein2_1d(&pts, &wts, &pts, &wts);
        assert!(w2 < 0.2, "w2={w2}");
    }
    #[test]
    fn test_rip_identity_satisfies_rip() {
        let id = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let verifier = RipVerifier::new(1, 10);
        let (dl, du) = verifier.estimate_rip_constant(&id);
        assert!(du < 0.01, "delta_upper={du}");
        let _ = dl;
    }
    #[test]
    fn test_rip_soft_threshold() {
        let x = vec![3.0, -0.5, 0.2];
        let result = RipVerifier::soft_threshold(&x, 1.0);
        assert!((result[0] - 2.0).abs() < 1e-12);
        assert!(result[1].abs() < 1e-12);
        assert!(result[2].abs() < 1e-12);
    }
    #[test]
    fn test_rip_satisfies_rip_check() {
        let id = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let verifier = RipVerifier::new(1, 5);
        assert!(verifier.satisfies_rip(&id, 0.5));
    }
    #[test]
    fn test_online_learner_initial_decision() {
        let learner = OnlineLearner::new(0.1, 3);
        let x = learner.current_decision();
        assert_eq!(x.len(), 3);
        assert!(x.iter().all(|&xi| xi.abs() < 1e-12));
    }
    #[test]
    fn test_online_learner_update_accumulates() {
        let mut learner = OnlineLearner::new(0.1, 2);
        learner.update(&[1.0, 0.0]);
        let x = learner.current_decision();
        assert!((x[0] + 0.1).abs() < 1e-9, "x[0]={}", x[0]);
        assert!(x[1].abs() < 1e-9, "x[1]={}", x[1]);
    }
    #[test]
    fn test_online_learner_regret_bound_positive() {
        let mut learner = OnlineLearner::new(0.01, 2);
        for _ in 0..10 {
            learner.update(&[0.5, -0.5]);
        }
        let bound = learner.regret_bound(1.0, 1.0);
        assert!(bound >= 0.0, "bound={bound}");
    }
    #[test]
    fn test_online_learner_reset() {
        let mut learner = OnlineLearner::new(0.1, 2);
        learner.update(&[1.0, 2.0]);
        learner.reset();
        assert_eq!(learner.round, 0);
        let x = learner.current_decision();
        assert!(x.iter().all(|&xi| xi.abs() < 1e-12));
    }
}
1123}