1use oxilean_kernel::{BinderInfo, Declaration, Environment, Expr, Level, Name};
6
7use super::types::{
8 AdaBoost, BiasVarianceTradeoff, CausalBackdoor, CrossValidation, DoubleRademacher,
9 EarlyStoppingReg, ExponentialWeightsAlgorithm, FeatureMap, FisherInformation,
10 GaussianComplexity, GaussianProcess, GradientBoosting, GrowthFunction, KernelFunction,
11 KernelMatrix, KernelSVM, KernelSVMTrainer, LassoReg, OnlineGradientDescent,
12 PACBayesGeneralization, PACLearner, Perceptron, RademacherComplexity, RegretBound,
13 SVMClassifier, SampleComplexity, TikhonovReg, UCBBandit, UniformConvergence, VCDimension, ELBO,
14};
15
16pub fn app(f: Expr, a: Expr) -> Expr {
17 Expr::App(Box::new(f), Box::new(a))
18}
19pub fn app2(f: Expr, a: Expr, b: Expr) -> Expr {
20 app(app(f, a), b)
21}
22pub fn app3(f: Expr, a: Expr, b: Expr, c: Expr) -> Expr {
23 app(app2(f, a, b), c)
24}
25pub fn cst(s: &str) -> Expr {
26 Expr::Const(Name::str(s), vec![])
27}
28pub fn prop() -> Expr {
29 Expr::Sort(Level::zero())
30}
31pub fn type0() -> Expr {
32 Expr::Sort(Level::succ(Level::zero()))
33}
34pub fn pi(bi: BinderInfo, name: &str, dom: Expr, body: Expr) -> Expr {
35 Expr::Pi(bi, Name::str(name), Box::new(dom), Box::new(body))
36}
37pub fn arrow(a: Expr, b: Expr) -> Expr {
38 pi(BinderInfo::Default, "_", a, b)
39}
40pub fn bvar(n: u32) -> Expr {
41 Expr::BVar(n)
42}
43pub fn nat_ty() -> Expr {
44 cst("Nat")
45}
46pub fn real_ty() -> Expr {
47 cst("Real")
48}
49pub fn bool_ty() -> Expr {
50 cst("Bool")
51}
52pub fn list_ty(elem: Expr) -> Expr {
53 app(cst("List"), elem)
54}
55pub fn pac_learner_ty() -> Expr {
58 arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), type0())))
59}
60pub fn sample_complexity_ty() -> Expr {
63 arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), nat_ty())))
64}
65pub fn vc_dimension_ty() -> Expr {
68 arrow(type0(), nat_ty())
69}
70pub fn growth_function_ty() -> Expr {
73 arrow(type0(), arrow(nat_ty(), nat_ty()))
74}
75pub fn pac_learnability_ty() -> Expr {
78 arrow(type0(), prop())
79}
80pub fn fundamental_theorem_pac_ty() -> Expr {
83 pi(BinderInfo::Default, "H", type0(), prop())
84}
85pub fn sauer_shelah_ty() -> Expr {
88 pi(
89 BinderInfo::Default,
90 "H",
91 type0(),
92 pi(BinderInfo::Default, "m", nat_ty(), prop()),
93 )
94}
95pub fn sample_complexity_bound_ty() -> Expr {
98 pi(
99 BinderInfo::Default,
100 "eps",
101 real_ty(),
102 pi(
103 BinderInfo::Default,
104 "delta",
105 real_ty(),
106 pi(BinderInfo::Default, "d", nat_ty(), prop()),
107 ),
108 )
109}
110pub fn rademacher_complexity_ty() -> Expr {
113 arrow(type0(), arrow(nat_ty(), real_ty()))
114}
115pub fn uniform_convergence_ty() -> Expr {
118 arrow(type0(), arrow(real_ty(), arrow(real_ty(), prop())))
119}
120pub fn double_rademacher_ty() -> Expr {
123 arrow(type0(), arrow(nat_ty(), arrow(real_ty(), prop())))
124}
125pub fn gaussian_complexity_ty() -> Expr {
128 arrow(type0(), arrow(nat_ty(), real_ty()))
129}
130pub fn rademacher_bound_ty() -> Expr {
133 pi(
134 BinderInfo::Default,
135 "H",
136 type0(),
137 pi(
138 BinderInfo::Default,
139 "n",
140 nat_ty(),
141 pi(BinderInfo::Default, "delta", real_ty(), prop()),
142 ),
143 )
144}
145pub fn symmetrization_lemma_ty() -> Expr {
148 pi(
149 BinderInfo::Default,
150 "H",
151 type0(),
152 pi(BinderInfo::Default, "n", nat_ty(), prop()),
153 )
154}
155pub fn kernel_function_ty() -> Expr {
158 arrow(type0(), type0())
159}
160pub fn rkhs_ty() -> Expr {
163 arrow(arrow(type0(), type0()), type0())
164}
165pub fn feature_map_ty() -> Expr {
168 arrow(type0(), arrow(type0(), type0()))
169}
170pub fn kernel_matrix_ty() -> Expr {
173 arrow(arrow(type0(), type0()), arrow(nat_ty(), type0()))
174}
175pub fn kernel_svm_ty() -> Expr {
178 arrow(arrow(type0(), type0()), arrow(real_ty(), type0()))
179}
180pub fn mercer_theorem_ty() -> Expr {
183 pi(BinderInfo::Default, "k", arrow(type0(), type0()), prop())
184}
185pub fn representer_theorem_ty() -> Expr {
188 pi(
189 BinderInfo::Default,
190 "k",
191 arrow(type0(), type0()),
192 pi(BinderInfo::Default, "n", nat_ty(), prop()),
193 )
194}
195pub fn kernel_pca_ty() -> Expr {
198 arrow(
199 arrow(type0(), type0()),
200 arrow(nat_ty(), arrow(nat_ty(), type0())),
201 )
202}
203pub fn regularized_objective_ty() -> Expr {
206 arrow(real_ty(), type0())
207}
208pub fn tikhonov_reg_ty() -> Expr {
211 arrow(real_ty(), arrow(arrow(type0(), type0()), type0()))
212}
213pub fn lasso_reg_ty() -> Expr {
216 arrow(real_ty(), type0())
217}
218pub fn early_stopping_reg_ty() -> Expr {
221 arrow(nat_ty(), type0())
222}
223pub fn bias_variance_tradeoff_ty() -> Expr {
226 arrow(real_ty(), arrow(real_ty(), arrow(real_ty(), prop())))
227}
228pub fn ridge_regression_solution_ty() -> Expr {
231 pi(
232 BinderInfo::Default,
233 "n",
234 nat_ty(),
235 pi(
236 BinderInfo::Default,
237 "d",
238 nat_ty(),
239 pi(BinderInfo::Default, "lam", real_ty(), prop()),
240 ),
241 )
242}
243pub fn bias_variance_decomposition_ty() -> Expr {
246 prop()
247}
248pub fn online_algorithm_ty() -> Expr {
251 arrow(nat_ty(), type0())
252}
253pub fn perceptron_ty() -> Expr {
256 arrow(nat_ty(), type0())
257}
258pub fn adaboost_ty() -> Expr {
261 arrow(nat_ty(), arrow(type0(), type0()))
262}
263pub fn online_gradient_descent_ty() -> Expr {
266 arrow(real_ty(), arrow(nat_ty(), type0()))
267}
268pub fn regret_bound_ty() -> Expr {
271 arrow(nat_ty(), arrow(real_ty(), prop()))
272}
273pub fn perceptron_mistake_bound_ty() -> Expr {
276 pi(
277 BinderInfo::Default,
278 "R",
279 real_ty(),
280 pi(BinderInfo::Default, "gamma", real_ty(), prop()),
281 )
282}
283pub fn ogd_regret_bound_ty() -> Expr {
286 pi(
287 BinderInfo::Default,
288 "T",
289 nat_ty(),
290 pi(BinderInfo::Default, "eta", real_ty(), prop()),
291 )
292}
293pub fn adaboost_training_error_ty() -> Expr {
296 pi(BinderInfo::Default, "T", nat_ty(), prop())
297}
298pub fn ml_mutual_information_ty() -> Expr {
301 arrow(list_ty(real_ty()), arrow(list_ty(real_ty()), real_ty()))
302}
303pub fn ml_kl_divergence_ty() -> Expr {
306 arrow(list_ty(real_ty()), arrow(list_ty(real_ty()), real_ty()))
307}
308pub fn fisher_information_ty() -> Expr {
311 arrow(arrow(real_ty(), real_ty()), arrow(real_ty(), real_ty()))
312}
313pub fn elbo_ty() -> Expr {
316 arrow(
317 arrow(real_ty(), real_ty()),
318 arrow(arrow(real_ty(), real_ty()), real_ty()),
319 )
320}
321pub fn data_processing_inequality_ty() -> Expr {
324 prop()
325}
326pub fn chain_rule_mi_ty() -> Expr {
329 prop()
330}
331pub fn cramer_rao_bound_ty() -> Expr {
334 pi(
335 BinderInfo::Default,
336 "p",
337 arrow(real_ty(), real_ty()),
338 pi(BinderInfo::Default, "theta", real_ty(), prop()),
339 )
340}
341pub fn pac_bayes_bound_ty() -> Expr {
344 pi(
345 BinderInfo::Default,
346 "n",
347 nat_ty(),
348 pi(BinderInfo::Default, "delta", real_ty(), prop()),
349 )
350}
351pub fn build_env(env: &mut Environment) -> Result<(), String> {
353 let axioms: &[(&str, Expr)] = &[
354 ("PACLearner", pac_learner_ty()),
355 ("SampleComplexity", sample_complexity_ty()),
356 ("VCDimension", vc_dimension_ty()),
357 ("GrowthFunction", growth_function_ty()),
358 ("PACLearnability", pac_learnability_ty()),
359 ("fundamental_theorem_pac", fundamental_theorem_pac_ty()),
360 ("sauer_shelah_lemma", sauer_shelah_ty()),
361 ("sample_complexity_bound", sample_complexity_bound_ty()),
362 ("RademacherComplexity", rademacher_complexity_ty()),
363 ("UniformConvergence", uniform_convergence_ty()),
364 ("DoubleRademacher", double_rademacher_ty()),
365 ("GaussianComplexity", gaussian_complexity_ty()),
366 ("rademacher_bound", rademacher_bound_ty()),
367 ("symmetrization_lemma", symmetrization_lemma_ty()),
368 ("KernelFunction", kernel_function_ty()),
369 ("RKHS", rkhs_ty()),
370 ("FeatureMap", feature_map_ty()),
371 ("KernelMatrix", kernel_matrix_ty()),
372 ("KernelSVM", kernel_svm_ty()),
373 ("mercer_theorem", mercer_theorem_ty()),
374 ("representer_theorem", representer_theorem_ty()),
375 ("KernelPCA", kernel_pca_ty()),
376 ("RegularizedObjective", regularized_objective_ty()),
377 ("TikhonovReg", tikhonov_reg_ty()),
378 ("LassoReg", lasso_reg_ty()),
379 ("EarlyStoppingReg", early_stopping_reg_ty()),
380 ("BiasVarianceTradeoff", bias_variance_tradeoff_ty()),
381 ("ridge_regression_solution", ridge_regression_solution_ty()),
382 (
383 "bias_variance_decomposition",
384 bias_variance_decomposition_ty(),
385 ),
386 ("OnlineAlgorithm", online_algorithm_ty()),
387 ("Perceptron", perceptron_ty()),
388 ("AdaBoost", adaboost_ty()),
389 ("OnlineGradientDescent", online_gradient_descent_ty()),
390 ("RegretBound", regret_bound_ty()),
391 ("perceptron_mistake_bound", perceptron_mistake_bound_ty()),
392 ("ogd_regret_bound", ogd_regret_bound_ty()),
393 ("adaboost_training_error", adaboost_training_error_ty()),
394 ("MLMutualInformation", ml_mutual_information_ty()),
395 ("MLKLDivergence", ml_kl_divergence_ty()),
396 ("FisherInformation", fisher_information_ty()),
397 ("ELBO", elbo_ty()),
398 (
399 "data_processing_inequality",
400 data_processing_inequality_ty(),
401 ),
402 ("chain_rule_mutual_information", chain_rule_mi_ty()),
403 ("cramer_rao_bound", cramer_rao_bound_ty()),
404 ("pac_bayes_bound", pac_bayes_bound_ty()),
405 ];
406 for (name, ty) in axioms {
407 env.add(Declaration::Axiom {
408 name: Name::str(*name),
409 univ_params: vec![],
410 ty: ty.clone(),
411 })
412 .ok();
413 }
414 Ok(())
415}
416pub(super) fn dot(a: &[f64], b: &[f64]) -> f64 {
418 a.iter().zip(b.iter()).map(|(ai, bi)| ai * bi).sum()
419}
420pub fn ewa_algorithm_ty() -> Expr {
423 arrow(nat_ty(), arrow(real_ty(), type0()))
424}
425pub fn multiplicative_weights_update_ty() -> Expr {
428 arrow(nat_ty(), arrow(real_ty(), type0()))
429}
430pub fn ewa_regret_bound_ty() -> Expr {
433 pi(
434 BinderInfo::Default,
435 "n",
436 nat_ty(),
437 pi(
438 BinderInfo::Default,
439 "T",
440 nat_ty(),
441 pi(BinderInfo::Default, "eta", real_ty(), prop()),
442 ),
443 )
444}
445pub fn bandit_algorithm_ty() -> Expr {
448 arrow(nat_ty(), arrow(nat_ty(), type0()))
449}
450pub fn ucb_algorithm_ty() -> Expr {
453 arrow(nat_ty(), arrow(real_ty(), type0()))
454}
455pub fn ucb_regret_bound_ty() -> Expr {
458 pi(
459 BinderInfo::Default,
460 "n",
461 nat_ty(),
462 pi(BinderInfo::Default, "T", nat_ty(), prop()),
463 )
464}
465pub fn thompson_sampling_ty() -> Expr {
468 arrow(nat_ty(), type0())
469}
470pub fn bayesian_regret_bound_ty() -> Expr {
473 pi(
474 BinderInfo::Default,
475 "n",
476 nat_ty(),
477 pi(BinderInfo::Default, "T", nat_ty(), prop()),
478 )
479}
480pub fn data_dependent_bound_ty() -> Expr {
483 arrow(nat_ty(), arrow(real_ty(), prop()))
484}
485pub fn localized_rademacher_ty() -> Expr {
488 arrow(type0(), arrow(nat_ty(), arrow(real_ty(), real_ty())))
489}
490pub fn localized_bound_ty() -> Expr {
493 pi(
494 BinderInfo::Default,
495 "H",
496 type0(),
497 pi(
498 BinderInfo::Default,
499 "n",
500 nat_ty(),
501 pi(BinderInfo::Default, "delta", real_ty(), prop()),
502 ),
503 )
504}
505pub fn pac_bayes_mcallester_ty() -> Expr {
508 pi(
509 BinderInfo::Default,
510 "n",
511 nat_ty(),
512 pi(BinderInfo::Default, "delta", real_ty(), prop()),
513 )
514}
515pub fn catoni_pac_bayes_ty() -> Expr {
518 pi(
519 BinderInfo::Default,
520 "n",
521 nat_ty(),
522 pi(BinderInfo::Default, "delta", real_ty(), prop()),
523 )
524}
525pub fn rkhs_norm_ty() -> Expr {
528 arrow(arrow(type0(), type0()), real_ty())
529}
530pub fn kernel_pca_projection_ty() -> Expr {
533 arrow(
534 arrow(type0(), type0()),
535 arrow(nat_ty(), arrow(nat_ty(), type0())),
536 )
537}
538pub fn svm_generalization_bound_ty() -> Expr {
541 pi(
542 BinderInfo::Default,
543 "n",
544 nat_ty(),
545 pi(
546 BinderInfo::Default,
547 "R",
548 real_ty(),
549 pi(BinderInfo::Default, "gamma", real_ty(), prop()),
550 ),
551 )
552}
553pub fn neural_network_ty() -> Expr {
556 arrow(nat_ty(), arrow(nat_ty(), type0()))
557}
558pub fn depth_separation_ty() -> Expr {
561 arrow(nat_ty(), prop())
562}
563pub fn barron_space_ty() -> Expr {
566 arrow(real_ty(), type0())
567}
568pub fn barron_approximation_ty() -> Expr {
571 pi(
572 BinderInfo::Default,
573 "B",
574 real_ty(),
575 pi(BinderInfo::Default, "m", nat_ty(), prop()),
576 )
577}
578pub fn nn_expressivity_ty() -> Expr {
581 arrow(nat_ty(), arrow(nat_ty(), nat_ty()))
582}
583pub fn nn_generalization_bound_ty() -> Expr {
586 pi(
587 BinderInfo::Default,
588 "depth",
589 nat_ty(),
590 pi(
591 BinderInfo::Default,
592 "width",
593 nat_ty(),
594 pi(
595 BinderInfo::Default,
596 "n",
597 nat_ty(),
598 pi(BinderInfo::Default, "delta", real_ty(), prop()),
599 ),
600 ),
601 )
602}
603pub fn double_descent_ty() -> Expr {
606 arrow(nat_ty(), real_ty())
607}
608pub fn benign_overfitting_ty() -> Expr {
611 pi(
612 BinderInfo::Default,
613 "n",
614 nat_ty(),
615 pi(BinderInfo::Default, "d", nat_ty(), prop()),
616 )
617}
618pub fn implicit_regularization_ty() -> Expr {
621 pi(
622 BinderInfo::Default,
623 "T",
624 nat_ty(),
625 pi(BinderInfo::Default, "eta", real_ty(), prop()),
626 )
627}
628pub fn min_norm_interpolation_ty() -> Expr {
631 arrow(nat_ty(), arrow(nat_ty(), type0()))
632}
633pub fn uniform_stability_ty() -> Expr {
636 arrow(real_ty(), prop())
637}
638pub fn on_average_stability_ty() -> Expr {
641 arrow(real_ty(), prop())
642}
643pub fn stability_generalization_bound_ty() -> Expr {
646 pi(
647 BinderInfo::Default,
648 "beta",
649 real_ty(),
650 pi(
651 BinderInfo::Default,
652 "n",
653 nat_ty(),
654 pi(BinderInfo::Default, "delta", real_ty(), prop()),
655 ),
656 )
657}
658pub fn data_deletion_ty() -> Expr {
661 arrow(nat_ty(), arrow(real_ty(), type0()))
662}
663pub fn dp_sgd_algorithm_ty() -> Expr {
666 arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), type0())))
667}
668pub fn private_pac_learning_ty() -> Expr {
671 arrow(real_ty(), arrow(real_ty(), type0()))
672}
673pub fn private_query_mechanism_ty() -> Expr {
676 arrow(real_ty(), arrow(real_ty(), type0()))
677}
678pub fn dp_generalization_bound_ty() -> Expr {
681 pi(
682 BinderInfo::Default,
683 "n",
684 nat_ty(),
685 pi(
686 BinderInfo::Default,
687 "eps_priv",
688 real_ty(),
689 pi(BinderInfo::Default, "delta_priv", real_ty(), prop()),
690 ),
691 )
692}
693pub fn dp_sample_complexity_ty() -> Expr {
696 arrow(
697 real_ty(),
698 arrow(real_ty(), arrow(real_ty(), arrow(real_ty(), nat_ty()))),
699 )
700}
701pub fn calibration_error_ty() -> Expr {
704 arrow(list_ty(real_ty()), arrow(list_ty(real_ty()), real_ty()))
705}
706pub fn proper_scoring_rule_ty() -> Expr {
709 arrow(arrow(real_ty(), real_ty()), arrow(real_ty(), real_ty()))
710}
711pub fn reliability_diagram_ty() -> Expr {
714 arrow(nat_ty(), type0())
715}
716pub fn sharpness_measure_ty() -> Expr {
719 arrow(list_ty(real_ty()), real_ty())
720}
721pub fn do_calculus_ty() -> Expr {
724 arrow(real_ty(), arrow(real_ty(), real_ty()))
725}
726pub fn interventional_dist_ty() -> Expr {
729 arrow(type0(), arrow(type0(), type0()))
730}
731pub fn backdoor_criterion_ty() -> Expr {
734 arrow(type0(), arrow(type0(), arrow(type0(), prop())))
735}
736pub fn backdoor_adjustment_ty() -> Expr {
739 pi(
740 BinderInfo::Default,
741 "X",
742 type0(),
743 pi(
744 BinderInfo::Default,
745 "Y",
746 type0(),
747 pi(BinderInfo::Default, "Z", type0(), prop()),
748 ),
749 )
750}
751pub fn confounding_bias_ty() -> Expr {
754 arrow(real_ty(), prop())
755}
756pub fn domain_adaptation_ty() -> Expr {
759 arrow(type0(), arrow(type0(), arrow(nat_ty(), type0())))
760}
761pub fn covariate_shift_ty() -> Expr {
764 arrow(type0(), prop())
765}
766pub fn importance_weighting_ty() -> Expr {
769 arrow(
770 arrow(real_ty(), real_ty()),
771 arrow(arrow(real_ty(), real_ty()), arrow(real_ty(), real_ty())),
772 )
773}
774pub fn domain_adaptation_bound_ty() -> Expr {
777 pi(
778 BinderInfo::Default,
779 "n",
780 nat_ty(),
781 pi(BinderInfo::Default, "delta", real_ty(), prop()),
782 )
783}
784pub fn federated_learning_ty() -> Expr {
787 arrow(nat_ty(), arrow(nat_ty(), type0()))
788}
789pub fn heterogeneity_measure_ty() -> Expr {
792 arrow(nat_ty(), arrow(real_ty(), prop()))
793}
794pub fn fedavg_convergence_ty() -> Expr {
797 pi(
798 BinderInfo::Default,
799 "T",
800 nat_ty(),
801 pi(BinderInfo::Default, "m", nat_ty(), prop()),
802 )
803}
804pub fn byzantine_fault_tolerance_ty() -> Expr {
807 arrow(nat_ty(), arrow(nat_ty(), type0()))
808}
809pub fn communication_complexity_ty() -> Expr {
812 arrow(real_ty(), arrow(nat_ty(), nat_ty()))
813}
814pub fn build_env_extended(env: &mut Environment) -> Result<(), String> {
816 build_env(env)?;
817 let axioms: &[(&str, Expr)] = &[
818 ("EWAAlgorithm", ewa_algorithm_ty()),
819 (
820 "MultiplicativeWeightsUpdate",
821 multiplicative_weights_update_ty(),
822 ),
823 ("ewa_regret_bound", ewa_regret_bound_ty()),
824 ("BanditAlgorithm", bandit_algorithm_ty()),
825 ("UCBAlgorithm", ucb_algorithm_ty()),
826 ("ucb_regret_bound", ucb_regret_bound_ty()),
827 ("ThompsonSampling", thompson_sampling_ty()),
828 ("bayesian_regret_bound", bayesian_regret_bound_ty()),
829 ("DataDependentBound", data_dependent_bound_ty()),
830 ("LocalizedRademacher", localized_rademacher_ty()),
831 ("localized_bound", localized_bound_ty()),
832 ("pac_bayes_mcallester", pac_bayes_mcallester_ty()),
833 ("catoni_pac_bayes", catoni_pac_bayes_ty()),
834 ("RKHSNorm", rkhs_norm_ty()),
835 ("KernelPCAProjection", kernel_pca_projection_ty()),
836 ("svm_generalization_bound", svm_generalization_bound_ty()),
837 ("NeuralNetwork", neural_network_ty()),
838 ("depth_separation", depth_separation_ty()),
839 ("BarronSpace", barron_space_ty()),
840 ("barron_approximation", barron_approximation_ty()),
841 ("NNExpressivity", nn_expressivity_ty()),
842 ("nn_generalization_bound", nn_generalization_bound_ty()),
843 ("DoubleDescent", double_descent_ty()),
844 ("benign_overfitting", benign_overfitting_ty()),
845 ("implicit_regularization", implicit_regularization_ty()),
846 ("MinNormInterpolation", min_norm_interpolation_ty()),
847 ("UniformStability", uniform_stability_ty()),
848 ("OnAverageStability", on_average_stability_ty()),
849 (
850 "stability_generalization_bound",
851 stability_generalization_bound_ty(),
852 ),
853 ("DataDeletion", data_deletion_ty()),
854 ("DPSGDAlgorithm", dp_sgd_algorithm_ty()),
855 ("PrivatePACLearning", private_pac_learning_ty()),
856 ("PrivateQueryMechanism", private_query_mechanism_ty()),
857 ("dp_generalization_bound", dp_generalization_bound_ty()),
858 ("DPSampleComplexity", dp_sample_complexity_ty()),
859 ("CalibrationError", calibration_error_ty()),
860 ("ProperScoringRule", proper_scoring_rule_ty()),
861 ("ReliabilityDiagram", reliability_diagram_ty()),
862 ("SharpnessMeasure", sharpness_measure_ty()),
863 ("DoCalculus", do_calculus_ty()),
864 ("InterventionalDist", interventional_dist_ty()),
865 ("BackdoorCriterion", backdoor_criterion_ty()),
866 ("backdoor_adjustment", backdoor_adjustment_ty()),
867 ("ConfoundingBias", confounding_bias_ty()),
868 ("DomainAdaptation", domain_adaptation_ty()),
869 ("CovariateShift", covariate_shift_ty()),
870 ("ImportanceWeighting", importance_weighting_ty()),
871 ("domain_adaptation_bound", domain_adaptation_bound_ty()),
872 ("FederatedLearning", federated_learning_ty()),
873 ("HeterogeneityMeasure", heterogeneity_measure_ty()),
874 ("fedavg_convergence", fedavg_convergence_ty()),
875 ("ByzantineFaultTolerance", byzantine_fault_tolerance_ty()),
876 ("CommunicationComplexity", communication_complexity_ty()),
877 ];
878 for (name, ty) in axioms {
879 env.add(Declaration::Axiom {
880 name: Name::str(*name),
881 univ_params: vec![],
882 ty: ty.clone(),
883 })
884 .ok();
885 }
886 Ok(())
887}
#[cfg(test)]
mod extended_tests {
    // Unit tests for the extended learning-theory types and for
    // `build_env_extended`.
    use super::*;
    #[test]
    fn test_ewa_regret_bound() {
        let n = 4;
        let t = 100;
        let eta = ExponentialWeightsAlgorithm::optimal_eta(n, t);
        let mut ewa = ExponentialWeightsAlgorithm::new(n, eta);
        for _ in 0..t {
            ewa.update(&[0.1, 0.2, 0.3, 0.4]);
        }
        let bound = ewa.regret_bound();
        assert!(bound > 0.0, "EWA regret bound must be positive");
        assert!(bound.is_finite(), "EWA bound must be finite");
    }
    #[test]
    fn test_ewa_distribution_sums_to_one() {
        let mut ewa = ExponentialWeightsAlgorithm::new(3, 0.1);
        ewa.update(&[0.5, 0.1, 0.8]);
        let dist = ewa.distribution();
        let sum: f64 = dist.iter().sum();
        assert!((sum - 1.0).abs() < 1e-9, "EWA distribution must sum to 1");
    }
    #[test]
    fn test_ucb_bandit_selects_all_arms_initially() {
        let mut bandit = UCBBandit::new(3);
        let arm0 = bandit.select();
        bandit.update(arm0, 0.5);
        let arm1 = bandit.select();
        bandit.update(arm1, 0.8);
        let arm2 = bandit.select();
        bandit.update(arm2, 0.3);
        assert_eq!(arm0, 0);
        assert_eq!(arm1, 1);
        assert_eq!(arm2, 2);
    }
    #[test]
    fn test_ucb_regret_bound_positive() {
        let mut bandit = UCBBandit::new(2);
        for i in 0..10 {
            let arm = bandit.select();
            bandit.update(arm, if i % 2 == 0 { 1.0 } else { 0.0 });
        }
        let bound = bandit.regret_bound_upper();
        assert!(bound > 0.0 && bound.is_finite());
    }
    #[test]
    fn test_pac_bayes_mcallester_bound() {
        let pb = PACBayesGeneralization::new(0.5, 1000, 0.05);
        let bound = pb.mcallester_bound(0.1);
        assert!(bound > 0.1, "PAC-Bayes bound must exceed empirical loss");
        assert!(
            bound < 1.0,
            "PAC-Bayes bound must be less than 1 for reasonable params"
        );
    }
    #[test]
    fn test_pac_bayes_catoni_bound() {
        let pb = PACBayesGeneralization::new(0.3, 500, 0.05);
        let lam = pb.optimal_lambda(0.1);
        assert!(lam > 0.0 && lam < 1.0, "optimal lambda must be in (0,1)");
        let bound = pb.catoni_bound(0.1, lam);
        assert!(bound > 0.0 && bound.is_finite());
    }
    #[test]
    fn test_kernel_svm_trainer_smo_step() {
        let labels = vec![1.0, -1.0, 1.0, -1.0];
        let mut svm = KernelSVMTrainer::new(4, labels, 1.0);
        let k = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.0, 0.0, 0.0, 1.0],
        ];
        let _updated = svm.smo_step(0, 1, &k);
        // Box constraint: every dual variable stays within [0, C].
        for &a in &svm.alphas {
            assert!(a >= 0.0 && a <= svm.c + 1e-9);
        }
    }
    #[test]
    fn test_kernel_svm_generalization_bound() {
        let bound = KernelSVMTrainer::generalization_bound(1.0, 0.5, 100);
        assert!(
            (bound - 0.04).abs() < 1e-9,
            "R²/(γ²n) = 1/(0.25*100) = 0.04"
        );
    }
    #[test]
    fn test_causal_backdoor_adjust() {
        let bd = CausalBackdoor::new(vec![0.8, 0.4], vec![0.6, 0.4]);
        assert!(bd.is_valid(), "stratum probs must sum to 1");
        let adjusted = bd.adjust();
        assert!(
            (adjusted - 0.64).abs() < 1e-9,
            "backdoor adjustment must be 0.64"
        );
    }
    #[test]
    fn test_causal_backdoor_confounding_bias() {
        let bd = CausalBackdoor::new(vec![0.8, 0.4], vec![0.6, 0.4]);
        let bias = bd.confounding_bias(0.75);
        assert!(
            (bias - 0.11).abs() < 1e-9,
            "confounding bias = |0.75 - 0.64| = 0.11"
        );
    }
    #[test]
    fn test_build_env_extended() {
        let mut env = Environment::new();
        let result = build_env_extended(&mut env);
        assert!(result.is_ok(), "build_env_extended must succeed");
        // A core axiom registered by `build_env`, which `build_env_extended` calls first.
        assert!(env.get(&Name::str("PACLearner")).is_some());
        assert!(env.get(&Name::str("EWAAlgorithm")).is_some());
        assert!(env.get(&Name::str("UCBAlgorithm")).is_some());
        assert!(env.get(&Name::str("ThompsonSampling")).is_some());
        assert!(env.get(&Name::str("NeuralNetwork")).is_some());
        assert!(env.get(&Name::str("BarronSpace")).is_some());
        assert!(env.get(&Name::str("DoubleDescent")).is_some());
        assert!(env.get(&Name::str("ByzantineFaultTolerance")).is_some());
        assert!(env.get(&Name::str("BackdoorCriterion")).is_some());
        assert!(env.get(&Name::str("DPSGDAlgorithm")).is_some());
        assert!(env.get(&Name::str("CalibrationError")).is_some());
    }
}
1012pub(super) fn gauss_solve(a: &[Vec<f64>], b: &[f64], d: usize, _n: usize) -> Vec<f64> {
1015 if d == 0 {
1016 return vec![];
1017 }
1018 let mut mat: Vec<Vec<f64>> = (0..d)
1019 .map(|i| {
1020 let mut row = a[i].clone();
1021 row.push(b[i]);
1022 row
1023 })
1024 .collect();
1025 for col in 0..d {
1026 let pivot = (col..d).max_by(|&i, &j| {
1027 mat[i][col]
1028 .abs()
1029 .partial_cmp(&mat[j][col].abs())
1030 .unwrap_or(std::cmp::Ordering::Equal)
1031 });
1032 if let Some(pivot_row) = pivot {
1033 mat.swap(col, pivot_row);
1034 }
1035 let diag = mat[col][col];
1036 if diag.abs() < 1e-12 {
1037 continue;
1038 }
1039 for row in (col + 1)..d {
1040 let factor = mat[row][col] / diag;
1041 for k in col..=d {
1042 let val = mat[col][k] * factor;
1043 mat[row][k] -= val;
1044 }
1045 }
1046 }
1047 let mut x = vec![0.0f64; d];
1048 for i in (0..d).rev() {
1049 let mut sum = mat[i][d];
1050 for j in (i + 1)..d {
1051 sum -= mat[i][j] * x[j];
1052 }
1053 x[i] = if mat[i][i].abs() < 1e-12 {
1054 0.0
1055 } else {
1056 sum / mat[i][i]
1057 };
1058 }
1059 x
1060}
/// Public dot product of two `f64` slices; performs the same zip-multiply-sum
/// computation as the module-private `dot`.
///
/// Elements are paired by position. If the lengths differ, the extra trailing
/// elements of the longer slice are ignored.
// NOTE(review): `dead_code` is allowed presumably because nothing in this crate
// calls `dot_ext`; confirm whether it is still needed.
#[allow(dead_code)]
pub fn dot_ext(a: &[f64], b: &[f64]) -> f64 {
    a.iter()
        .zip(b.iter())
        .map(|pair| pair.0 * pair.1)
        .sum::<f64>()
}
#[cfg(test)]
mod tests_sl_extra {
    // Unit tests for the statistical-learning helper types
    // (Gaussian process, SVM, gradient boosting, cross-validation).
    use super::*;
    #[test]
    fn test_gaussian_process() {
        let gp = GaussianProcess::default_rbf();
        let k = gp.rbf_kernel(0.0, 0.0);
        assert!((k - 1.0).abs() < 1e-9, "k(x,x) = σ^2 = 1.0");
        let k_far = gp.rbf_kernel(0.0, 100.0);
        assert!(k_far < 1e-9, "k(0, 100) should be ~0");
    }
    #[test]
    fn test_svm_kernel() {
        let svm = SVMClassifier::rbf(1.0, 1.0);
        let x = vec![1.0, 0.0];
        let xp = vec![1.0, 0.0];
        let k = svm.kernel_value(&x, &xp);
        assert!((k - 1.0).abs() < 1e-9, "RBF(x,x)=1 for γ=1");
    }
    #[test]
    fn test_gradient_boosting() {
        let gb = GradientBoosting::xgboost_style(100);
        assert!(gb.is_regularized());
        assert_eq!(gb.n_leaves_upper_bound(), 64);
    }
    #[test]
    fn test_cross_validation() {
        let cv = CrossValidation::k_fold_5(100);
        assert_eq!(cv.fold_size(), 20);
        assert_eq!(cv.train_size(), 80);
        assert_eq!(cv.n_train_test_splits(), 5);
        let loocv = CrossValidation::loocv(10);
        assert_eq!(loocv.n_folds, 10);
    }
}