1use oxilean_kernel::{BinderInfo, Declaration, Environment, Expr, Level, Name};
6
7use super::types::{Distribution, ParticleFilter};
8
9pub fn app(f: Expr, a: Expr) -> Expr {
10 Expr::App(Box::new(f), Box::new(a))
11}
12pub fn app2(f: Expr, a: Expr, b: Expr) -> Expr {
13 app(app(f, a), b)
14}
15pub fn app3(f: Expr, a: Expr, b: Expr, c: Expr) -> Expr {
16 app(app2(f, a, b), c)
17}
18pub fn cst(s: &str) -> Expr {
19 Expr::Const(Name::str(s), vec![])
20}
21pub fn prop() -> Expr {
22 Expr::Sort(Level::zero())
23}
24pub fn type0() -> Expr {
25 Expr::Sort(Level::succ(Level::zero()))
26}
27pub fn type1() -> Expr {
28 Expr::Sort(Level::succ(Level::succ(Level::zero())))
29}
30pub fn pi(bi: BinderInfo, name: &str, dom: Expr, body: Expr) -> Expr {
31 Expr::Pi(bi, Name::str(name), Box::new(dom), Box::new(body))
32}
33pub fn arrow(a: Expr, b: Expr) -> Expr {
34 pi(BinderInfo::Default, "_", a, b)
35}
36pub fn bvar(n: u32) -> Expr {
37 Expr::BVar(n)
38}
39pub fn nat_ty() -> Expr {
40 cst("Nat")
41}
42pub fn real_ty() -> Expr {
43 cst("Real")
44}
45pub fn bool_ty() -> Expr {
46 cst("Bool")
47}
48pub fn list_ty(elem: Expr) -> Expr {
49 app(cst("List"), elem)
50}
51pub fn measure_ty() -> Expr {
53 arrow(type0(), type0())
54}
55pub fn sigma_algebra_ty() -> Expr {
57 arrow(type0(), type0())
58}
59pub fn measurable_space_ty() -> Expr {
61 type0()
62}
63pub fn probability_monad_ty() -> Expr {
65 arrow(type0(), type0())
66}
67pub fn kernel_ty() -> Expr {
69 arrow(type0(), arrow(type0(), type0()))
70}
71pub fn sampler_ty() -> Expr {
73 arrow(type0(), type0())
74}
75pub fn density_ty() -> Expr {
77 arrow(type0(), arrow(type0(), real_ty()))
78}
79pub fn ppl_program_ty() -> Expr {
81 arrow(type0(), type0())
82}
83pub fn elbo_ty() -> Expr {
85 arrow(arrow(type0(), type0()), arrow(type0(), real_ty()))
86}
87pub fn importance_weight_ty() -> Expr {
89 arrow(type0(), real_ty())
90}
91pub fn particle_filter_ty() -> Expr {
93 arrow(type0(), type0())
94}
95pub fn gradient_estimator_ty() -> Expr {
97 arrow(type0(), type0())
98}
99pub fn bayes_measure_theory_ty() -> Expr {
105 pi(
106 BinderInfo::Default,
107 "X",
108 type0(),
109 pi(
110 BinderInfo::Default,
111 "Y",
112 type0(),
113 pi(
114 BinderInfo::Default,
115 "prior",
116 app(cst("Measure"), bvar(1)),
117 pi(
118 BinderInfo::Default,
119 "likelihood",
120 app2(cst("Kernel"), bvar(2), bvar(1)),
121 app2(
122 cst("Eq"),
123 app2(cst("Posterior"), bvar(1), bvar(0)),
124 app2(
125 cst("RadonNikodym"),
126 app2(cst("Joint"), bvar(1), bvar(0)),
127 app(cst("Marginal"), bvar(0)),
128 ),
129 ),
130 ),
131 ),
132 ),
133 )
134}
135pub fn giry_monad_laws_ty() -> Expr {
139 app(cst("MonadLaws"), cst("ProbabilityMonad"))
140}
141pub fn is_consistency_ty() -> Expr {
146 pi(
147 BinderInfo::Default,
148 "X",
149 type0(),
150 pi(
151 BinderInfo::Default,
152 "f",
153 arrow(bvar(0), real_ty()),
154 pi(
155 BinderInfo::Default,
156 "q",
157 app(cst("Measure"), bvar(1)),
158 pi(
159 BinderInfo::Default,
160 "p",
161 app(cst("Measure"), bvar(2)),
162 arrow(
163 app2(cst("AbsContinuous"), bvar(0), bvar(1)),
164 app(
165 cst("ConsistentEstimator"),
166 app3(cst("ISSampler"), bvar(3), bvar(1), bvar(0)),
167 ),
168 ),
169 ),
170 ),
171 ),
172 )
173}
174pub fn elbo_lower_bound_ty() -> Expr {
179 pi(
180 BinderInfo::Default,
181 "Z",
182 type0(),
183 pi(
184 BinderInfo::Default,
185 "X",
186 type0(),
187 pi(
188 BinderInfo::Default,
189 "q",
190 app(cst("Measure"), bvar(1)),
191 pi(
192 BinderInfo::Default,
193 "p",
194 app2(cst("JointMeasure"), bvar(2), bvar(1)),
195 pi(
196 BinderInfo::Default,
197 "x",
198 bvar(2),
199 app2(
200 cst("Real.le"),
201 app3(cst("ELBO"), bvar(2), bvar(1), bvar(0)),
202 app2(cst("LogMarginalLikelihood"), bvar(1), bvar(0)),
203 ),
204 ),
205 ),
206 ),
207 ),
208 )
209}
210pub fn reparam_unbiased_ty() -> Expr {
216 pi(
217 BinderInfo::Default,
218 "Z",
219 type0(),
220 pi(
221 BinderInfo::Default,
222 "Params",
223 type0(),
224 pi(
225 BinderInfo::Default,
226 "phi",
227 bvar(0),
228 pi(
229 BinderInfo::Default,
230 "f",
231 arrow(bvar(2), real_ty()),
232 app2(
233 cst("Unbiased"),
234 app(cst("ReparamGradient"), bvar(0)),
235 app2(cst("GradExpectation"), bvar(0), bvar(1)),
236 ),
237 ),
238 ),
239 ),
240 )
241}
242pub fn hmc_invariant_ty() -> Expr {
246 pi(
247 BinderInfo::Default,
248 "X",
249 type0(),
250 pi(
251 BinderInfo::Default,
252 "target",
253 app(cst("Measure"), bvar(0)),
254 app2(
255 cst("InvariantUnder"),
256 app(cst("HMCKernel"), bvar(0)),
257 bvar(0),
258 ),
259 ),
260 )
261}
262pub fn smc_consistency_ty() -> Expr {
264 pi(
265 BinderInfo::Default,
266 "X",
267 type0(),
268 pi(
269 BinderInfo::Default,
270 "ssm",
271 app(cst("StateSpaceModel"), bvar(0)),
272 app(cst("ConsistentFilteringDist"), bvar(0)),
273 ),
274 )
275}
276pub fn svi_convergence_ty() -> Expr {
281 pi(
282 BinderInfo::Default,
283 "VF",
284 type0(),
285 pi(
286 BinderInfo::Default,
287 "lr",
288 type0(),
289 app(
290 cst("ConvergesToLocalOptimum"),
291 app2(cst("SVIOptimizer"), bvar(1), bvar(0)),
292 ),
293 ),
294 )
295}
296pub fn normalizing_flow_cov_ty() -> Expr {
302 pi(
303 BinderInfo::Default,
304 "X",
305 type0(),
306 pi(
307 BinderInfo::Default,
308 "Y",
309 type0(),
310 pi(
311 BinderInfo::Default,
312 "f",
313 arrow(bvar(1), bvar(0)),
314 pi(
315 BinderInfo::Default,
316 "p",
317 app(cst("Measure"), bvar(2)),
318 arrow(
319 app(cst("Bijective"), bvar(1)),
320 app(cst("FlowDensityEq"), bvar(2)),
321 ),
322 ),
323 ),
324 ),
325 )
326}
327pub fn score_fn_unbiased_ty() -> Expr {
333 pi(
334 BinderInfo::Default,
335 "Z",
336 type0(),
337 pi(
338 BinderInfo::Default,
339 "q_phi",
340 app(cst("ParametricMeasure"), bvar(0)),
341 pi(
342 BinderInfo::Default,
343 "f",
344 arrow(bvar(1), real_ty()),
345 arrow(
346 app(cst("RegularFamily"), bvar(1)),
347 app2(
348 cst("Unbiased"),
349 app2(cst("ScoreFnGrad"), bvar(1), bvar(0)),
350 app2(cst("GradELBO"), bvar(1), bvar(0)),
351 ),
352 ),
353 ),
354 ),
355 )
356}
357pub fn pathwise_gradient_unbiased_ty() -> Expr {
363 pi(
364 BinderInfo::Default,
365 "Eps",
366 type0(),
367 pi(
368 BinderInfo::Default,
369 "Phi",
370 type0(),
371 pi(
372 BinderInfo::Default,
373 "Z",
374 type0(),
375 pi(
376 BinderInfo::Default,
377 "g",
378 arrow(bvar(2), arrow(bvar(1), bvar(0))),
379 pi(
380 BinderInfo::Default,
381 "f",
382 arrow(bvar(1), real_ty()),
383 arrow(
384 app(cst("DiffReparameterisation"), bvar(1)),
385 app2(
386 cst("Unbiased"),
387 app2(cst("PathwiseGrad"), bvar(1), bvar(0)),
388 app2(cst("GradELBO"), app(cst("Reparam"), bvar(1)), bvar(0)),
389 ),
390 ),
391 ),
392 ),
393 ),
394 ),
395 )
396}
397pub fn measure_transport_exists_ty() -> Expr {
403 pi(
404 BinderInfo::Default,
405 "X",
406 type0(),
407 pi(
408 BinderInfo::Default,
409 "mu",
410 app(cst("ProbMeasure"), bvar(0)),
411 pi(
412 BinderInfo::Default,
413 "nu",
414 app(cst("ProbMeasure"), bvar(1)),
415 app(
416 cst("Exists"),
417 pi(
418 BinderInfo::Default,
419 "T",
420 arrow(bvar(2), bvar(2)),
421 app2(
422 cst("Eq"),
423 app2(cst("Pushforward"), bvar(0), bvar(2)),
424 bvar(1),
425 ),
426 ),
427 ),
428 ),
429 ),
430 )
431}
432pub fn ot_kantorovich_ty() -> Expr {
438 pi(
439 BinderInfo::Default,
440 "X",
441 type0(),
442 pi(
443 BinderInfo::Default,
444 "mu",
445 app(cst("ProbMeasure"), bvar(0)),
446 pi(
447 BinderInfo::Default,
448 "nu",
449 app(cst("ProbMeasure"), bvar(1)),
450 app2(
451 cst("Eq"),
452 app2(cst("W1Dist"), bvar(1), bvar(0)),
453 app2(cst("KantorovichDual"), bvar(1), bvar(0)),
454 ),
455 ),
456 ),
457 )
458}
459pub fn stein_identity_ty() -> Expr {
465 pi(
466 BinderInfo::Default,
467 "X",
468 type0(),
469 pi(
470 BinderInfo::Default,
471 "p",
472 app(cst("SmoothMeasure"), bvar(0)),
473 pi(
474 BinderInfo::Default,
475 "h",
476 arrow(bvar(1), real_ty()),
477 arrow(
478 app(cst("SmoothTestFn"), bvar(0)),
479 app2(
480 cst("Eq"),
481 app2(
482 cst("Expectation"),
483 bvar(1),
484 app2(cst("SteinOp"), bvar(1), bvar(0)),
485 ),
486 cst("Real.zero"),
487 ),
488 ),
489 ),
490 ),
491 )
492}
493pub fn svgd_convergence_ty() -> Expr {
499 pi(
500 BinderInfo::Default,
501 "X",
502 type0(),
503 pi(
504 BinderInfo::Default,
505 "target",
506 app(cst("SmoothMeasure"), bvar(0)),
507 pi(
508 BinderInfo::Default,
509 "n",
510 nat_ty(),
511 app2(
512 cst("Real.le"),
513 app2(
514 cst("SteinDiscrepancy"),
515 app2(cst("SVGDUpdate"), bvar(1), bvar(0)),
516 bvar(1),
517 ),
518 app(cst("SVGDBound"), bvar(0)),
519 ),
520 ),
521 ),
522 )
523}
524pub fn pmc_consistency_ty() -> Expr {
530 pi(
531 BinderInfo::Default,
532 "X",
533 type0(),
534 pi(
535 BinderInfo::Default,
536 "N",
537 nat_ty(),
538 pi(
539 BinderInfo::Default,
540 "T",
541 nat_ty(),
542 pi(
543 BinderInfo::Default,
544 "target",
545 app(cst("Measure"), bvar(2)),
546 app(
547 cst("ConsistentEstimator"),
548 app3(cst("PMCEstimator"), bvar(0), bvar(2), bvar(1)),
549 ),
550 ),
551 ),
552 ),
553 )
554}
555pub fn evol_mcmc_detailed_balance_ty() -> Expr {
561 pi(
562 BinderInfo::Default,
563 "X",
564 type0(),
565 pi(
566 BinderInfo::Default,
567 "target",
568 app(cst("Measure"), bvar(0)),
569 pi(
570 BinderInfo::Default,
571 "temp",
572 cst("Tempering"),
573 app2(
574 cst("DetailedBalance"),
575 app2(cst("EvolMCMCKernel"), bvar(1), bvar(0)),
576 app2(cst("TemperedTarget"), bvar(1), bvar(0)),
577 ),
578 ),
579 ),
580 )
581}
582pub fn parallel_tempering_exchange_ty() -> Expr {
588 pi(
589 BinderInfo::Default,
590 "temps",
591 list_ty(real_ty()),
592 pi(
593 BinderInfo::Default,
594 "joint",
595 app(cst("Measure"), app(cst("ProductSpace"), bvar(0))),
596 app2(
597 cst("InvariantUnder"),
598 app(cst("SwapKernel"), bvar(1)),
599 bvar(0),
600 ),
601 ),
602 )
603}
604pub fn simulated_annealing_convergence_ty() -> Expr {
610 pi(
611 BinderInfo::Default,
612 "X",
613 type0(),
614 pi(
615 BinderInfo::Default,
616 "f",
617 arrow(bvar(0), real_ty()),
618 pi(
619 BinderInfo::Default,
620 "T",
621 cst("CoolingSchedule"),
622 arrow(
623 app(cst("LogarithmicSchedule"), bvar(0)),
624 app(
625 cst("ConvergesToGlobalOpt"),
626 app2(cst("SAChain"), bvar(2), bvar(0)),
627 ),
628 ),
629 ),
630 ),
631 )
632}
633pub fn vae_elbo_decomp_ty() -> Expr {
639 pi(
640 BinderInfo::Default,
641 "X",
642 type0(),
643 pi(
644 BinderInfo::Default,
645 "encoder",
646 cst("NeuralNet"),
647 pi(
648 BinderInfo::Default,
649 "decoder",
650 cst("NeuralNet"),
651 pi(
652 BinderInfo::Default,
653 "x",
654 bvar(2),
655 app2(
656 cst("Eq"),
657 app3(cst("VAELBO"), bvar(2), bvar(1), bvar(0)),
658 app2(
659 cst("Real.sub"),
660 app3(cst("Reconstruction"), bvar(2), bvar(1), bvar(0)),
661 app2(cst("KLDivQ"), bvar(2), bvar(0)),
662 ),
663 ),
664 ),
665 ),
666 ),
667 )
668}
669pub fn diffusion_score_matching_ty() -> Expr {
675 pi(
676 BinderInfo::Default,
677 "X",
678 type0(),
679 pi(
680 BinderInfo::Default,
681 "p_data",
682 app(cst("Measure"), bvar(0)),
683 pi(
684 BinderInfo::Default,
685 "sigma",
686 real_ty(),
687 app2(
688 cst("Eq"),
689 app(
690 cst("OptimScore"),
691 app2(cst("DSMObjective"), bvar(1), bvar(0)),
692 ),
693 app(
694 cst("ScoreFunction"),
695 app2(cst("GaussianSmooth"), bvar(1), bvar(0)),
696 ),
697 ),
698 ),
699 ),
700 )
701}
702pub fn flow_matching_ode_ty() -> Expr {
708 pi(
709 BinderInfo::Default,
710 "X",
711 type0(),
712 pi(
713 BinderInfo::Default,
714 "p_0",
715 app(cst("Measure"), bvar(0)),
716 pi(
717 BinderInfo::Default,
718 "p_1",
719 app(cst("Measure"), bvar(1)),
720 pi(
721 BinderInfo::Default,
722 "vt",
723 cst("VectorField"),
724 arrow(
725 app3(cst("CondFlowMatchingField"), bvar(0), bvar(2), bvar(1)),
726 app2(
727 cst("Eq"),
728 app3(cst("PushforwardODE"), bvar(0), bvar(2), cst("Real.one")),
729 bvar(1),
730 ),
731 ),
732 ),
733 ),
734 ),
735 )
736}
737pub fn gp_posterior_is_gp_ty() -> Expr {
742 pi(
743 BinderInfo::Default,
744 "X",
745 type0(),
746 pi(
747 BinderInfo::Default,
748 "prior",
749 app(cst("GaussianProcess"), bvar(0)),
750 pi(
751 BinderInfo::Default,
752 "obs",
753 cst("Observations"),
754 app(
755 cst("IsGaussianProcess"),
756 app2(cst("GPPosterior"), bvar(1), bvar(0)),
757 ),
758 ),
759 ),
760 )
761}
762pub fn gp_marginal_gaussian_ty() -> Expr {
767 pi(
768 BinderInfo::Default,
769 "X",
770 type0(),
771 pi(
772 BinderInfo::Default,
773 "gp",
774 app(cst("GaussianProcess"), bvar(0)),
775 pi(
776 BinderInfo::Default,
777 "X_train",
778 list_ty(bvar(1)),
779 app2(
780 cst("Eq"),
781 app2(cst("MarginalLikelihood"), bvar(1), bvar(0)),
782 app2(
783 cst("MultivariateGaussian"),
784 app2(cst("GPMean"), bvar(1), bvar(0)),
785 app2(cst("GPKernelMatrix"), bvar(1), bvar(0)),
786 ),
787 ),
788 ),
789 ),
790 )
791}
792pub fn pn_integration_ty() -> Expr {
798 pi(
799 BinderInfo::Default,
800 "X",
801 type0(),
802 pi(
803 BinderInfo::Default,
804 "f",
805 arrow(bvar(0), real_ty()),
806 pi(
807 BinderInfo::Default,
808 "prior_gp",
809 app(cst("GaussianProcess"), bvar(1)),
810 app2(
811 cst("Eq"),
812 app2(cst("BQPosterior"), bvar(0), bvar(1)),
813 app(cst("GaussianMeasureOver"), real_ty()),
814 ),
815 ),
816 ),
817 )
818}
819pub fn stein_disc_zero_iff_ty() -> Expr {
825 pi(
826 BinderInfo::Default,
827 "X",
828 type0(),
829 pi(
830 BinderInfo::Default,
831 "p",
832 app(cst("Measure"), bvar(0)),
833 pi(
834 BinderInfo::Default,
835 "q",
836 app(cst("Measure"), bvar(1)),
837 pi(
838 BinderInfo::Default,
839 "k",
840 app(cst("SteinKernel"), bvar(2)),
841 app2(
842 cst("Iff"),
843 app2(
844 cst("Eq"),
845 app3(cst("KSD"), bvar(2), bvar(1), bvar(0)),
846 cst("Real.zero"),
847 ),
848 app2(cst("Eq"), bvar(2), bvar(1)),
849 ),
850 ),
851 ),
852 ),
853 )
854}
855pub fn smc_feynman_kac_ty() -> Expr {
861 pi(
862 BinderInfo::Default,
863 "X",
864 type0(),
865 pi(
866 BinderInfo::Default,
867 "fk",
868 app(cst("FeynmanKacModel"), bvar(0)),
869 pi(
870 BinderInfo::Default,
871 "N",
872 nat_ty(),
873 app2(
874 cst("Eq"),
875 app(
876 cst("Expectation"),
877 app2(cst("SMCNormConst"), bvar(1), bvar(0)),
878 ),
879 app(cst("FeynmanKacNormConst"), bvar(1)),
880 ),
881 ),
882 ),
883 )
884}
885pub fn pmmh_correctness_ty() -> Expr {
891 pi(
892 BinderInfo::Default,
893 "X",
894 type0(),
895 pi(
896 BinderInfo::Default,
897 "Y",
898 type0(),
899 pi(
900 BinderInfo::Default,
901 "model",
902 app2(cst("LatentModel"), bvar(1), bvar(0)),
903 pi(
904 BinderInfo::Default,
905 "obs",
906 bvar(1),
907 app(
908 cst("TargetsExactPosterior"),
909 app2(cst("PMMHKernel"), bvar(1), bvar(0)),
910 ),
911 ),
912 ),
913 ),
914 )
915}
916pub fn ais_unbiased_ty() -> Expr {
922 pi(
923 BinderInfo::Default,
924 "X",
925 type0(),
926 pi(
927 BinderInfo::Default,
928 "p_0",
929 app(cst("Measure"), bvar(0)),
930 pi(
931 BinderInfo::Default,
932 "p_T",
933 app(cst("Measure"), bvar(1)),
934 pi(
935 BinderInfo::Default,
936 "beta_sched",
937 cst("AnnealingSchedule"),
938 app2(
939 cst("Unbiased"),
940 app3(cst("AISEstimator"), bvar(2), bvar(1), bvar(0)),
941 app(cst("NormConst"), bvar(1)),
942 ),
943 ),
944 ),
945 ),
946 )
947}
948pub fn dsm_equals_sm_ty() -> Expr {
954 pi(
955 BinderInfo::Default,
956 "X",
957 type0(),
958 pi(
959 BinderInfo::Default,
960 "p_data",
961 app(cst("Measure"), bvar(0)),
962 pi(
963 BinderInfo::Default,
964 "sigma",
965 real_ty(),
966 app2(
967 cst("Eq"),
968 app2(cst("DSMObjective"), bvar(1), bvar(0)),
969 app(
970 cst("ImplicitSMObjective"),
971 app2(cst("GaussianConvolution"), bvar(1), bvar(0)),
972 ),
973 ),
974 ),
975 ),
976 )
977}
978pub fn langevin_convergence_ty() -> Expr {
984 pi(
985 BinderInfo::Default,
986 "X",
987 type0(),
988 pi(
989 BinderInfo::Default,
990 "target",
991 app(cst("LogConcaveMeasure"), bvar(0)),
992 pi(
993 BinderInfo::Default,
994 "eps",
995 real_ty(),
996 pi(
997 BinderInfo::Default,
998 "n",
999 nat_ty(),
1000 app2(
1001 cst("Real.le"),
1002 app3(
1003 cst("W2"),
1004 app3(cst("ULADist"), bvar(2), bvar(1), bvar(0)),
1005 bvar(2),
1006 bvar(1),
1007 ),
1008 app3(cst("LangevinBound"), bvar(2), bvar(1), bvar(0)),
1009 ),
1010 ),
1011 ),
1012 ),
1013 )
1014}
1015pub fn mh_detailed_balance_ty() -> Expr {
1021 pi(
1022 BinderInfo::Default,
1023 "X",
1024 type0(),
1025 pi(
1026 BinderInfo::Default,
1027 "target",
1028 app(cst("Measure"), bvar(0)),
1029 pi(
1030 BinderInfo::Default,
1031 "proposal",
1032 app2(cst("Kernel"), bvar(1), bvar(1)),
1033 app2(
1034 cst("DetailedBalance"),
1035 app2(cst("MHKernel"), bvar(1), bvar(0)),
1036 bvar(1),
1037 ),
1038 ),
1039 ),
1040 )
1041}
1042pub fn gibbs_invariant_ty() -> Expr {
1047 pi(
1048 BinderInfo::Default,
1049 "X",
1050 type0(),
1051 pi(
1052 BinderInfo::Default,
1053 "Y",
1054 type0(),
1055 pi(
1056 BinderInfo::Default,
1057 "joint",
1058 app(cst("Measure"), app2(cst("Prod"), bvar(1), bvar(0))),
1059 app2(
1060 cst("InvariantUnder"),
1061 app(cst("GibbsKernel"), bvar(0)),
1062 bvar(0),
1063 ),
1064 ),
1065 ),
1066 )
1067}
1068pub fn vae_posterior_collapse_risk_ty() -> Expr {
1074 pi(
1075 BinderInfo::Default,
1076 "decoder",
1077 cst("ExpressiveDecoder"),
1078 pi(
1079 BinderInfo::Default,
1080 "beta",
1081 real_ty(),
1082 arrow(
1083 app2(cst("Real.lt"), bvar(0), cst("Real.one")),
1084 app(
1085 cst("AvoidsCollapse"),
1086 app2(cst("BetaVAE"), bvar(1), bvar(0)),
1087 ),
1088 ),
1089 ),
1090 )
1091}
1092pub fn grad_log_normalizer_ty() -> Expr {
1098 pi(
1099 BinderInfo::Default,
1100 "T",
1101 type0(),
1102 pi(
1103 BinderInfo::Default,
1104 "eta",
1105 app(cst("NaturalParams"), bvar(0)),
1106 app2(
1107 cst("Eq"),
1108 app(cst("Gradient"), app(cst("LogNormalizer"), bvar(0))),
1109 app2(
1110 cst("MeanSuffStat"),
1111 bvar(0),
1112 app2(cst("ExpFamilyDist"), bvar(0), bvar(1)),
1113 ),
1114 ),
1115 ),
1116 )
1117}
1118pub fn smc_genealogy_ty() -> Expr {
1124 pi(
1125 BinderInfo::Default,
1126 "X",
1127 type0(),
1128 pi(
1129 BinderInfo::Default,
1130 "T",
1131 nat_ty(),
1132 pi(
1133 BinderInfo::Default,
1134 "pf",
1135 app2(cst("ParticleSystem"), bvar(1), bvar(0)),
1136 app2(
1137 cst("Eq"),
1138 app(cst("AncestralLineage"), bvar(0)),
1139 app(
1140 cst("CoalescentProcess"),
1141 app(cst("ResamplingTimes"), bvar(0)),
1142 ),
1143 ),
1144 ),
1145 ),
1146 )
1147}
1148pub fn kde_consistency_ty() -> Expr {
1154 pi(
1155 BinderInfo::Default,
1156 "X",
1157 type0(),
1158 pi(
1159 BinderInfo::Default,
1160 "p",
1161 app(cst("SmoothMeasure"), bvar(0)),
1162 pi(
1163 BinderInfo::Default,
1164 "h",
1165 cst("BandwidthSeq"),
1166 arrow(
1167 app(cst("OptimalBandwidth"), bvar(0)),
1168 app(cst("L2Convergence"), app2(cst("KDEn"), bvar(1), bvar(0))),
1169 ),
1170 ),
1171 ),
1172 )
1173}
1174pub fn mean_field_cavi_ty() -> Expr {
1180 pi(
1181 BinderInfo::Default,
1182 "Z",
1183 type0(),
1184 pi(
1185 BinderInfo::Default,
1186 "joint",
1187 app(cst("Measure"), bvar(0)),
1188 pi(
1189 BinderInfo::Default,
1190 "q_factors",
1191 list_ty(app(cst("Measure"), bvar(1))),
1192 app2(
1193 cst("Eq"),
1194 app2(cst("CAVIStep"), bvar(1), bvar(0)),
1195 app2(cst("UpdatedFactors"), bvar(1), bvar(0)),
1196 ),
1197 ),
1198 ),
1199 )
1200}
1201pub fn pbp_gaussian_propagation_ty() -> Expr {
1207 pi(
1208 BinderInfo::Default,
1209 "net",
1210 cst("BayesianNeuralNet"),
1211 pi(
1212 BinderInfo::Default,
1213 "x",
1214 cst("Input"),
1215 app2(
1216 cst("Eq"),
1217 app(
1218 cst("GaussianApproxActivations"),
1219 app2(cst("PBP"), bvar(1), bvar(0)),
1220 ),
1221 app2(cst("PBPActivations"), bvar(1), bvar(0)),
1222 ),
1223 ),
1224 )
1225}
1226pub fn ep_fixed_point_ty() -> Expr {
1232 pi(
1233 BinderInfo::Default,
1234 "model",
1235 cst("FactorGraph"),
1236 pi(
1237 BinderInfo::Default,
1238 "approx",
1239 cst("GaussianApprox"),
1240 app2(
1241 cst("Iff"),
1242 app2(cst("EPFixedPoint"), bvar(0), bvar(1)),
1243 app2(cst("CavityAgreement"), bvar(0), bvar(1)),
1244 ),
1245 ),
1246 )
1247}
1248pub fn nested_mc_bias_ty() -> Expr {
1254 pi(
1255 BinderInfo::Default,
1256 "X",
1257 type0(),
1258 pi(
1259 BinderInfo::Default,
1260 "outer",
1261 nat_ty(),
1262 pi(
1263 BinderInfo::Default,
1264 "inner",
1265 nat_ty(),
1266 pi(
1267 BinderInfo::Default,
1268 "f",
1269 arrow(bvar(2), real_ty()),
1270 app2(
1271 cst("Real.le"),
1272 app(
1273 cst("Bias"),
1274 app3(cst("NestedMCEstimator"), bvar(0), bvar(2), bvar(1)),
1275 ),
1276 app(cst("NestedMCBiasRate"), bvar(1)),
1277 ),
1278 ),
1279 ),
1280 ),
1281 )
1282}
1283pub fn abc_smc_consistency_ty() -> Expr {
1289 pi(
1290 BinderInfo::Default,
1291 "Theta",
1292 type0(),
1293 pi(
1294 BinderInfo::Default,
1295 "Y",
1296 type0(),
1297 pi(
1298 BinderInfo::Default,
1299 "prior",
1300 app(cst("Measure"), bvar(1)),
1301 pi(
1302 BinderInfo::Default,
1303 "sim",
1304 arrow(bvar(2), app(cst("Measure"), bvar(1))),
1305 pi(
1306 BinderInfo::Default,
1307 "eps",
1308 real_ty(),
1309 arrow(
1310 app2(cst("Real.gt"), bvar(0), cst("Real.zero")),
1311 app3(
1312 cst("ApproxPosterior"),
1313 app3(cst("ABCSMC"), bvar(3), bvar(1), bvar(0)),
1314 bvar(0),
1315 app2(cst("TruePosterior"), bvar(3), bvar(1)),
1316 ),
1317 ),
1318 ),
1319 ),
1320 ),
1321 ),
1322 )
1323}
1324pub fn build_probabilistic_programming_env(
1326 env: &mut Environment,
1327) -> Result<(), Box<dyn std::error::Error>> {
1328 let axioms: &[(&str, Expr)] = &[
1329 ("Measure", measure_ty()),
1330 ("SigmaAlgebra", sigma_algebra_ty()),
1331 ("MeasurableSpace", measurable_space_ty()),
1332 ("ProbabilityMonad", probability_monad_ty()),
1333 ("Kernel", kernel_ty()),
1334 ("Sampler", sampler_ty()),
1335 ("PPLProgram", ppl_program_ty()),
1336 ("GradientEstimator", gradient_estimator_ty()),
1337 ("ParticleFilter", particle_filter_ty()),
1338 (
1339 "Joint",
1340 arrow(
1341 app(cst("Measure"), type0()),
1342 arrow(app2(cst("Kernel"), type0(), type0()), type0()),
1343 ),
1344 ),
1345 ("JointMeasure", arrow(type0(), arrow(type0(), type0()))),
1346 (
1347 "Marginal",
1348 arrow(app2(cst("Kernel"), type0(), type0()), type0()),
1349 ),
1350 (
1351 "Posterior",
1352 arrow(
1353 app(cst("Measure"), type0()),
1354 arrow(
1355 app2(cst("Kernel"), type0(), type0()),
1356 app(cst("Measure"), type0()),
1357 ),
1358 ),
1359 ),
1360 ("RadonNikodym", arrow(type0(), arrow(type0(), type0()))),
1361 (
1362 "AbsContinuous",
1363 arrow(
1364 app(cst("Measure"), type0()),
1365 arrow(app(cst("Measure"), type0()), prop()),
1366 ),
1367 ),
1368 ("ConsistentEstimator", arrow(type0(), prop())),
1369 (
1370 "ISSampler",
1371 arrow(
1372 arrow(type0(), real_ty()),
1373 arrow(
1374 app(cst("Measure"), type0()),
1375 arrow(app(cst("Measure"), type0()), type0()),
1376 ),
1377 ),
1378 ),
1379 (
1380 "ELBO",
1381 arrow(
1382 app(cst("Measure"), type0()),
1383 arrow(type0(), arrow(type0(), real_ty())),
1384 ),
1385 ),
1386 (
1387 "LogMarginalLikelihood",
1388 arrow(type0(), arrow(type0(), real_ty())),
1389 ),
1390 ("MonadLaws", arrow(arrow(type0(), type0()), prop())),
1391 ("Unbiased", arrow(type0(), arrow(type0(), prop()))),
1392 ("ReparamGradient", arrow(arrow(type0(), real_ty()), type0())),
1393 (
1394 "GradExpectation",
1395 arrow(arrow(type0(), real_ty()), arrow(type0(), type0())),
1396 ),
1397 (
1398 "HMCKernel",
1399 arrow(
1400 app(cst("Measure"), type0()),
1401 app2(cst("Kernel"), type0(), type0()),
1402 ),
1403 ),
1404 (
1405 "InvariantUnder",
1406 arrow(
1407 app2(cst("Kernel"), type0(), type0()),
1408 arrow(app(cst("Measure"), type0()), prop()),
1409 ),
1410 ),
1411 ("StateSpaceModel", arrow(type0(), type0())),
1412 ("ConsistentFilteringDist", arrow(type0(), prop())),
1413 ("bayes_measure_theory", bayes_measure_theory_ty()),
1414 ("giry_monad_laws", giry_monad_laws_ty()),
1415 ("is_consistency", is_consistency_ty()),
1416 ("elbo_lower_bound", elbo_lower_bound_ty()),
1417 ("reparam_unbiased", reparam_unbiased_ty()),
1418 ("hmc_invariant", hmc_invariant_ty()),
1419 ("smc_consistency", smc_consistency_ty()),
1420 ("VariationalFamily", arrow(type0(), type0())),
1421 ("LRSchedule", type0()),
1422 ("SVIOptimizer", arrow(type0(), arrow(type0(), type0()))),
1423 ("ConvergesToLocalOptimum", arrow(type0(), prop())),
1424 ("Bijective", arrow(arrow(type0(), type0()), prop())),
1425 ("FlowDensityEq", arrow(type0(), prop())),
1426 ("ParametricMeasure", arrow(type0(), type0())),
1427 ("RegularFamily", arrow(type0(), prop())),
1428 (
1429 "ScoreFnGrad",
1430 arrow(type0(), arrow(arrow(type0(), real_ty()), type0())),
1431 ),
1432 (
1433 "GradELBO",
1434 arrow(type0(), arrow(arrow(type0(), real_ty()), type0())),
1435 ),
1436 (
1437 "DiffReparameterisation",
1438 arrow(arrow(type0(), arrow(type0(), type0())), prop()),
1439 ),
1440 (
1441 "PathwiseGrad",
1442 arrow(
1443 arrow(type0(), arrow(type0(), type0())),
1444 arrow(arrow(type0(), real_ty()), type0()),
1445 ),
1446 ),
1447 (
1448 "Reparam",
1449 arrow(arrow(type0(), arrow(type0(), type0())), type0()),
1450 ),
1451 ("ProbMeasure", arrow(type0(), type0())),
1452 (
1453 "Pushforward",
1454 arrow(
1455 arrow(type0(), type0()),
1456 arrow(app(cst("Measure"), type0()), app(cst("Measure"), type0())),
1457 ),
1458 ),
1459 (
1460 "W1Dist",
1461 arrow(
1462 app(cst("Measure"), type0()),
1463 arrow(app(cst("Measure"), type0()), real_ty()),
1464 ),
1465 ),
1466 (
1467 "KantorovichDual",
1468 arrow(
1469 app(cst("Measure"), type0()),
1470 arrow(app(cst("Measure"), type0()), real_ty()),
1471 ),
1472 ),
1473 ("SmoothMeasure", arrow(type0(), type0())),
1474 ("SmoothTestFn", arrow(arrow(type0(), real_ty()), prop())),
1475 (
1476 "Expectation",
1477 arrow(app(cst("Measure"), type0()), arrow(type0(), real_ty())),
1478 ),
1479 (
1480 "SteinOp",
1481 arrow(
1482 app(cst("SmoothMeasure"), type0()),
1483 arrow(arrow(type0(), real_ty()), type0()),
1484 ),
1485 ),
1486 ("Real.zero", real_ty()),
1487 ("Real.one", real_ty()),
1488 (
1489 "SteinDiscrepancy",
1490 arrow(
1491 app(cst("Measure"), type0()),
1492 arrow(app(cst("Measure"), type0()), real_ty()),
1493 ),
1494 ),
1495 (
1496 "SVGDUpdate",
1497 arrow(
1498 app(cst("SmoothMeasure"), type0()),
1499 arrow(nat_ty(), app(cst("Measure"), type0())),
1500 ),
1501 ),
1502 ("SVGDBound", arrow(nat_ty(), real_ty())),
1503 ("SteinKernel", arrow(type0(), type0())),
1504 (
1505 "KSD",
1506 arrow(
1507 app(cst("Measure"), type0()),
1508 arrow(
1509 app(cst("Measure"), type0()),
1510 arrow(app(cst("SteinKernel"), type0()), real_ty()),
1511 ),
1512 ),
1513 ),
1514 ("Iff", arrow(prop(), arrow(prop(), prop()))),
1515 (
1516 "PMCEstimator",
1517 arrow(
1518 app(cst("Measure"), type0()),
1519 arrow(nat_ty(), arrow(nat_ty(), type0())),
1520 ),
1521 ),
1522 ("Tempering", type0()),
1523 (
1524 "EvolMCMCKernel",
1525 arrow(
1526 app(cst("Measure"), type0()),
1527 arrow(cst("Tempering"), app2(cst("Kernel"), type0(), type0())),
1528 ),
1529 ),
1530 (
1531 "TemperedTarget",
1532 arrow(
1533 app(cst("Measure"), type0()),
1534 arrow(cst("Tempering"), app(cst("Measure"), type0())),
1535 ),
1536 ),
1537 (
1538 "DetailedBalance",
1539 arrow(
1540 app2(cst("Kernel"), type0(), type0()),
1541 arrow(app(cst("Measure"), type0()), prop()),
1542 ),
1543 ),
1544 ("ProductSpace", arrow(list_ty(real_ty()), type0())),
1545 (
1546 "SwapKernel",
1547 arrow(list_ty(real_ty()), app2(cst("Kernel"), type0(), type0())),
1548 ),
1549 ("CoolingSchedule", type0()),
1550 ("LogarithmicSchedule", arrow(cst("CoolingSchedule"), prop())),
1551 (
1552 "SAChain",
1553 arrow(
1554 arrow(type0(), real_ty()),
1555 arrow(cst("CoolingSchedule"), type0()),
1556 ),
1557 ),
1558 ("ConvergesToGlobalOpt", arrow(type0(), prop())),
1559 ("NeuralNet", type0()),
1560 (
1561 "VAELBO",
1562 arrow(
1563 cst("NeuralNet"),
1564 arrow(cst("NeuralNet"), arrow(type0(), real_ty())),
1565 ),
1566 ),
1567 (
1568 "Reconstruction",
1569 arrow(
1570 cst("NeuralNet"),
1571 arrow(cst("NeuralNet"), arrow(type0(), real_ty())),
1572 ),
1573 ),
1574 ("KLDivQ", arrow(cst("NeuralNet"), arrow(type0(), real_ty()))),
1575 ("Real.sub", arrow(real_ty(), arrow(real_ty(), real_ty()))),
1576 (
1577 "DSMObjective",
1578 arrow(app(cst("Measure"), type0()), arrow(real_ty(), type0())),
1579 ),
1580 (
1581 "GaussianSmooth",
1582 arrow(
1583 app(cst("Measure"), type0()),
1584 arrow(real_ty(), app(cst("Measure"), type0())),
1585 ),
1586 ),
1587 ("OptimScore", arrow(type0(), type0())),
1588 (
1589 "ScoreFunction",
1590 arrow(app(cst("Measure"), type0()), type0()),
1591 ),
1592 ("VectorField", type0()),
1593 (
1594 "CondFlowMatchingField",
1595 arrow(
1596 cst("VectorField"),
1597 arrow(
1598 app(cst("Measure"), type0()),
1599 arrow(app(cst("Measure"), type0()), prop()),
1600 ),
1601 ),
1602 ),
1603 (
1604 "PushforwardODE",
1605 arrow(
1606 cst("VectorField"),
1607 arrow(
1608 app(cst("Measure"), type0()),
1609 arrow(real_ty(), app(cst("Measure"), type0())),
1610 ),
1611 ),
1612 ),
1613 ("GaussianProcess", arrow(type0(), type0())),
1614 ("Observations", type0()),
1615 ("IsGaussianProcess", arrow(type0(), prop())),
1616 (
1617 "GPPosterior",
1618 arrow(
1619 app(cst("GaussianProcess"), type0()),
1620 arrow(cst("Observations"), type0()),
1621 ),
1622 ),
1623 (
1624 "MarginalLikelihood",
1625 arrow(
1626 app(cst("GaussianProcess"), type0()),
1627 arrow(list_ty(type0()), app(cst("Measure"), type0())),
1628 ),
1629 ),
1630 (
1631 "MultivariateGaussian",
1632 arrow(type0(), arrow(type0(), app(cst("Measure"), type0()))),
1633 ),
1634 (
1635 "GPMean",
1636 arrow(
1637 app(cst("GaussianProcess"), type0()),
1638 arrow(list_ty(type0()), type0()),
1639 ),
1640 ),
1641 (
1642 "GPKernelMatrix",
1643 arrow(
1644 app(cst("GaussianProcess"), type0()),
1645 arrow(list_ty(type0()), type0()),
1646 ),
1647 ),
1648 (
1649 "BQPosterior",
1650 arrow(
1651 app(cst("GaussianProcess"), type0()),
1652 arrow(arrow(type0(), real_ty()), app(cst("Measure"), real_ty())),
1653 ),
1654 ),
1655 (
1656 "GaussianMeasureOver",
1657 arrow(type0(), app(cst("Measure"), real_ty())),
1658 ),
1659 ("FeynmanKacModel", arrow(type0(), type0())),
1660 (
1661 "SMCNormConst",
1662 arrow(
1663 app(cst("FeynmanKacModel"), type0()),
1664 arrow(nat_ty(), real_ty()),
1665 ),
1666 ),
1667 (
1668 "FeynmanKacNormConst",
1669 arrow(app(cst("FeynmanKacModel"), type0()), real_ty()),
1670 ),
1671 ("LatentModel", arrow(type0(), arrow(type0(), type0()))),
1672 (
1673 "PMMHKernel",
1674 arrow(
1675 app2(cst("LatentModel"), type0(), type0()),
1676 arrow(type0(), app2(cst("Kernel"), type0(), type0())),
1677 ),
1678 ),
1679 (
1680 "TargetsExactPosterior",
1681 arrow(app2(cst("Kernel"), type0(), type0()), prop()),
1682 ),
1683 ("AnnealingSchedule", type0()),
1684 (
1685 "AISEstimator",
1686 arrow(
1687 app(cst("Measure"), type0()),
1688 arrow(
1689 app(cst("Measure"), type0()),
1690 arrow(cst("AnnealingSchedule"), type0()),
1691 ),
1692 ),
1693 ),
1694 ("NormConst", arrow(app(cst("Measure"), type0()), real_ty())),
1695 (
1696 "ImplicitSMObjective",
1697 arrow(app(cst("Measure"), type0()), type0()),
1698 ),
1699 (
1700 "GaussianConvolution",
1701 arrow(
1702 app(cst("Measure"), type0()),
1703 arrow(real_ty(), app(cst("Measure"), type0())),
1704 ),
1705 ),
1706 ("LogConcaveMeasure", arrow(type0(), type0())),
1707 (
1708 "ULADist",
1709 arrow(
1710 app(cst("LogConcaveMeasure"), type0()),
1711 arrow(real_ty(), arrow(nat_ty(), app(cst("Measure"), type0()))),
1712 ),
1713 ),
1714 (
1715 "W2",
1716 arrow(
1717 app(cst("Measure"), type0()),
1718 arrow(app(cst("Measure"), type0()), real_ty()),
1719 ),
1720 ),
1721 (
1722 "LangevinBound",
1723 arrow(
1724 app(cst("LogConcaveMeasure"), type0()),
1725 arrow(real_ty(), arrow(nat_ty(), real_ty())),
1726 ),
1727 ),
1728 (
1729 "MHKernel",
1730 arrow(
1731 app(cst("Measure"), type0()),
1732 arrow(
1733 app2(cst("Kernel"), type0(), type0()),
1734 app2(cst("Kernel"), type0(), type0()),
1735 ),
1736 ),
1737 ),
1738 ("Prod", arrow(type0(), arrow(type0(), type0()))),
1739 (
1740 "GibbsKernel",
1741 arrow(
1742 app(cst("Measure"), app2(cst("Prod"), type0(), type0())),
1743 app2(
1744 cst("Kernel"),
1745 app2(cst("Prod"), type0(), type0()),
1746 app2(cst("Prod"), type0(), type0()),
1747 ),
1748 ),
1749 ),
1750 ("ExpressiveDecoder", type0()),
1751 (
1752 "BetaVAE",
1753 arrow(cst("ExpressiveDecoder"), arrow(real_ty(), type0())),
1754 ),
1755 ("AvoidsCollapse", arrow(type0(), prop())),
1756 ("Real.lt", arrow(real_ty(), arrow(real_ty(), prop()))),
1757 ("NaturalParams", arrow(type0(), type0())),
1758 ("LogNormalizer", arrow(type0(), arrow(type0(), real_ty()))),
1759 ("Gradient", arrow(arrow(type0(), real_ty()), type0())),
1760 (
1761 "MeanSuffStat",
1762 arrow(type0(), arrow(app(cst("Measure"), type0()), type0())),
1763 ),
1764 (
1765 "ExpFamilyDist",
1766 arrow(type0(), arrow(type0(), app(cst("Measure"), type0()))),
1767 ),
1768 ("ParticleSystem", arrow(type0(), arrow(nat_ty(), type0()))),
1769 ("AncestralLineage", arrow(type0(), type0())),
1770 ("CoalescentProcess", arrow(type0(), type0())),
1771 ("ResamplingTimes", arrow(type0(), type0())),
1772 ("BandwidthSeq", type0()),
1773 (
1774 "KDEn",
1775 arrow(
1776 app(cst("SmoothMeasure"), type0()),
1777 arrow(cst("BandwidthSeq"), type0()),
1778 ),
1779 ),
1780 ("OptimalBandwidth", arrow(cst("BandwidthSeq"), prop())),
1781 ("L2Convergence", arrow(type0(), prop())),
1782 (
1783 "CAVIStep",
1784 arrow(
1785 app(cst("Measure"), type0()),
1786 arrow(
1787 list_ty(app(cst("Measure"), type0())),
1788 list_ty(app(cst("Measure"), type0())),
1789 ),
1790 ),
1791 ),
1792 (
1793 "UpdatedFactors",
1794 arrow(
1795 app(cst("Measure"), type0()),
1796 arrow(
1797 list_ty(app(cst("Measure"), type0())),
1798 list_ty(app(cst("Measure"), type0())),
1799 ),
1800 ),
1801 ),
1802 ("BayesianNeuralNet", type0()),
1803 ("Input", type0()),
1804 (
1805 "PBP",
1806 arrow(cst("BayesianNeuralNet"), arrow(cst("Input"), type0())),
1807 ),
1808 ("GaussianApproxActivations", arrow(type0(), type0())),
1809 (
1810 "PBPActivations",
1811 arrow(cst("BayesianNeuralNet"), arrow(cst("Input"), type0())),
1812 ),
1813 ("FactorGraph", type0()),
1814 ("GaussianApprox", type0()),
1815 (
1816 "EPFixedPoint",
1817 arrow(cst("GaussianApprox"), arrow(cst("FactorGraph"), prop())),
1818 ),
1819 (
1820 "CavityAgreement",
1821 arrow(cst("GaussianApprox"), arrow(cst("FactorGraph"), prop())),
1822 ),
1823 (
1824 "NestedMCEstimator",
1825 arrow(
1826 arrow(type0(), real_ty()),
1827 arrow(nat_ty(), arrow(nat_ty(), type0())),
1828 ),
1829 ),
1830 ("Bias", arrow(type0(), real_ty())),
1831 ("NestedMCBiasRate", arrow(nat_ty(), real_ty())),
1832 (
1833 "ABCSMC",
1834 arrow(
1835 app(cst("Measure"), type0()),
1836 arrow(
1837 arrow(type0(), app(cst("Measure"), type0())),
1838 arrow(real_ty(), app(cst("Measure"), type0())),
1839 ),
1840 ),
1841 ),
1842 (
1843 "TruePosterior",
1844 arrow(
1845 app(cst("Measure"), type0()),
1846 arrow(
1847 arrow(type0(), app(cst("Measure"), type0())),
1848 app(cst("Measure"), type0()),
1849 ),
1850 ),
1851 ),
1852 (
1853 "ApproxPosterior",
1854 arrow(
1855 app(cst("Measure"), type0()),
1856 arrow(real_ty(), arrow(app(cst("Measure"), type0()), prop())),
1857 ),
1858 ),
1859 ("Real.gt", arrow(real_ty(), arrow(real_ty(), prop()))),
1860 ("svi_convergence", svi_convergence_ty()),
1861 ("normalizing_flow_cov", normalizing_flow_cov_ty()),
1862 ("score_fn_unbiased", score_fn_unbiased_ty()),
1863 (
1864 "pathwise_gradient_unbiased",
1865 pathwise_gradient_unbiased_ty(),
1866 ),
1867 ("measure_transport_exists", measure_transport_exists_ty()),
1868 ("ot_kantorovich", ot_kantorovich_ty()),
1869 ("stein_identity", stein_identity_ty()),
1870 ("svgd_convergence", svgd_convergence_ty()),
1871 ("pmc_consistency", pmc_consistency_ty()),
1872 (
1873 "evol_mcmc_detailed_balance",
1874 evol_mcmc_detailed_balance_ty(),
1875 ),
1876 (
1877 "parallel_tempering_exchange",
1878 parallel_tempering_exchange_ty(),
1879 ),
1880 (
1881 "simulated_annealing_convergence",
1882 simulated_annealing_convergence_ty(),
1883 ),
1884 ("vae_elbo_decomp", vae_elbo_decomp_ty()),
1885 ("diffusion_score_matching", diffusion_score_matching_ty()),
1886 ("flow_matching_ode", flow_matching_ode_ty()),
1887 ("gp_posterior_is_gp", gp_posterior_is_gp_ty()),
1888 ("gp_marginal_gaussian", gp_marginal_gaussian_ty()),
1889 ("pn_integration", pn_integration_ty()),
1890 ("stein_disc_zero_iff", stein_disc_zero_iff_ty()),
1891 ("smc_feynman_kac", smc_feynman_kac_ty()),
1892 ("pmmh_correctness", pmmh_correctness_ty()),
1893 ("ais_unbiased", ais_unbiased_ty()),
1894 ("dsm_equals_sm", dsm_equals_sm_ty()),
1895 ("langevin_convergence", langevin_convergence_ty()),
1896 ("mh_detailed_balance", mh_detailed_balance_ty()),
1897 ("gibbs_invariant", gibbs_invariant_ty()),
1898 (
1899 "vae_posterior_collapse_risk",
1900 vae_posterior_collapse_risk_ty(),
1901 ),
1902 ("grad_log_normalizer", grad_log_normalizer_ty()),
1903 ("smc_genealogy", smc_genealogy_ty()),
1904 ("kde_consistency", kde_consistency_ty()),
1905 ("mean_field_cavi", mean_field_cavi_ty()),
1906 ("pbp_gaussian_propagation", pbp_gaussian_propagation_ty()),
1907 ("ep_fixed_point", ep_fixed_point_ty()),
1908 ("nested_mc_bias", nested_mc_bias_ty()),
1909 ("abc_smc_consistency", abc_smc_consistency_ty()),
1910 ];
1911 for (name, ty) in axioms {
1912 env.add(Declaration::Axiom {
1913 name: Name::str(*name),
1914 univ_params: vec![],
1915 ty: ty.clone(),
1916 })
1917 .ok();
1918 }
1919 Ok(())
1920}