Skip to main content

oxilean_std/probabilistic_programming/
functions.rs

//! Auto-generated module
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use oxilean_kernel::{BinderInfo, Declaration, Environment, Expr, Level, Name};
6
7use super::types::{Distribution, ParticleFilter};
8
9pub fn app(f: Expr, a: Expr) -> Expr {
10    Expr::App(Box::new(f), Box::new(a))
11}
12pub fn app2(f: Expr, a: Expr, b: Expr) -> Expr {
13    app(app(f, a), b)
14}
15pub fn app3(f: Expr, a: Expr, b: Expr, c: Expr) -> Expr {
16    app(app2(f, a, b), c)
17}
18pub fn cst(s: &str) -> Expr {
19    Expr::Const(Name::str(s), vec![])
20}
21pub fn prop() -> Expr {
22    Expr::Sort(Level::zero())
23}
24pub fn type0() -> Expr {
25    Expr::Sort(Level::succ(Level::zero()))
26}
27pub fn type1() -> Expr {
28    Expr::Sort(Level::succ(Level::succ(Level::zero())))
29}
30pub fn pi(bi: BinderInfo, name: &str, dom: Expr, body: Expr) -> Expr {
31    Expr::Pi(bi, Name::str(name), Box::new(dom), Box::new(body))
32}
33pub fn arrow(a: Expr, b: Expr) -> Expr {
34    pi(BinderInfo::Default, "_", a, b)
35}
36pub fn bvar(n: u32) -> Expr {
37    Expr::BVar(n)
38}
39pub fn nat_ty() -> Expr {
40    cst("Nat")
41}
42pub fn real_ty() -> Expr {
43    cst("Real")
44}
45pub fn bool_ty() -> Expr {
46    cst("Bool")
47}
48pub fn list_ty(elem: Expr) -> Expr {
49    app(cst("List"), elem)
50}
51/// `Measure : Type → Type` — a σ-finite measure on a measurable space.
52pub fn measure_ty() -> Expr {
53    arrow(type0(), type0())
54}
55/// `SigmaAlgebra : Type → Type` — a σ-algebra of subsets.
56pub fn sigma_algebra_ty() -> Expr {
57    arrow(type0(), type0())
58}
59/// `MeasurableSpace : Type` — a type equipped with a σ-algebra.
60pub fn measurable_space_ty() -> Expr {
61    type0()
62}
63/// `ProbabilityMonad : (Type → Type)` — the distribution / Giry monad.
64pub fn probability_monad_ty() -> Expr {
65    arrow(type0(), type0())
66}
67/// `Kernel : Type → Type → Type` — a Markov kernel k(x, A).
68pub fn kernel_ty() -> Expr {
69    arrow(type0(), arrow(type0(), type0()))
70}
71/// `Sampler : Type → Type` — a procedure that draws samples from a distribution.
72pub fn sampler_ty() -> Expr {
73    arrow(type0(), type0())
74}
75/// `Density : Type → Type` — a density/pmf function on a type.
76pub fn density_ty() -> Expr {
77    arrow(type0(), arrow(type0(), real_ty()))
78}
79/// `PPLProgram : Type → Type` — a probabilistic program returning values of type A.
80pub fn ppl_program_ty() -> Expr {
81    arrow(type0(), type0())
82}
83/// `ELBO : (Type → Type) → Type → Real` — evidence lower bound for variational inference.
84pub fn elbo_ty() -> Expr {
85    arrow(arrow(type0(), type0()), arrow(type0(), real_ty()))
86}
87/// `ImportanceWeight : Type → Real` — self-normalised importance weight.
88pub fn importance_weight_ty() -> Expr {
89    arrow(type0(), real_ty())
90}
91/// `ParticleFilter : Type → Type` — sequential Monte Carlo state estimator.
92pub fn particle_filter_ty() -> Expr {
93    arrow(type0(), type0())
94}
95/// `GradientEstimator : Type → Type` — a Monte Carlo gradient estimator.
96pub fn gradient_estimator_ty() -> Expr {
97    arrow(type0(), type0())
98}
99/// **Measure-Theoretic Bayes**: the posterior is the Radon-Nikodym derivative
100/// of the joint w.r.t. the marginal likelihood.
101///
102/// `bayes_measure_theory : āˆ€ (prior : Measure X) (likelihood : Kernel X Y),
103///   Posterior prior likelihood = RadonNikodym (Joint prior likelihood) (Marginal likelihood)`
104pub fn bayes_measure_theory_ty() -> Expr {
105    pi(
106        BinderInfo::Default,
107        "X",
108        type0(),
109        pi(
110            BinderInfo::Default,
111            "Y",
112            type0(),
113            pi(
114                BinderInfo::Default,
115                "prior",
116                app(cst("Measure"), bvar(1)),
117                pi(
118                    BinderInfo::Default,
119                    "likelihood",
120                    app2(cst("Kernel"), bvar(2), bvar(1)),
121                    app2(
122                        cst("Eq"),
123                        app2(cst("Posterior"), bvar(1), bvar(0)),
124                        app2(
125                            cst("RadonNikodym"),
126                            app2(cst("Joint"), bvar(1), bvar(0)),
127                            app(cst("Marginal"), bvar(0)),
128                        ),
129                    ),
130                ),
131            ),
132        ),
133    )
134}
135/// **Giry Monad Laws**: the distribution monad satisfies monad laws.
136///
137/// `giry_monad_laws : MonadLaws ProbabilityMonad`
138pub fn giry_monad_laws_ty() -> Expr {
139    app(cst("MonadLaws"), cst("ProbabilityMonad"))
140}
141/// **Importance Sampling Consistency**: the IS estimator is consistent as Nā†’āˆž.
142///
143/// `is_consistency : āˆ€ (f : X → Real) (q p : Measure X),
144///   AbsContinuous p q → ConsistentEstimator (ISSampler f q p)`
145pub fn is_consistency_ty() -> Expr {
146    pi(
147        BinderInfo::Default,
148        "X",
149        type0(),
150        pi(
151            BinderInfo::Default,
152            "f",
153            arrow(bvar(0), real_ty()),
154            pi(
155                BinderInfo::Default,
156                "q",
157                app(cst("Measure"), bvar(1)),
158                pi(
159                    BinderInfo::Default,
160                    "p",
161                    app(cst("Measure"), bvar(2)),
162                    arrow(
163                        app2(cst("AbsContinuous"), bvar(0), bvar(1)),
164                        app(
165                            cst("ConsistentEstimator"),
166                            app3(cst("ISSampler"), bvar(3), bvar(1), bvar(0)),
167                        ),
168                    ),
169                ),
170            ),
171        ),
172    )
173}
174/// **ELBO Lower Bound**: ELBO(q) ≤ log p(x) for all variational families q.
175///
176/// `elbo_lower_bound : āˆ€ (q : Measure Z) (p : Joint Z X) (x : X),
177///   ELBO q p x ≤ LogMarginalLikelihood p x`
178pub fn elbo_lower_bound_ty() -> Expr {
179    pi(
180        BinderInfo::Default,
181        "Z",
182        type0(),
183        pi(
184            BinderInfo::Default,
185            "X",
186            type0(),
187            pi(
188                BinderInfo::Default,
189                "q",
190                app(cst("Measure"), bvar(1)),
191                pi(
192                    BinderInfo::Default,
193                    "p",
194                    app2(cst("JointMeasure"), bvar(2), bvar(1)),
195                    pi(
196                        BinderInfo::Default,
197                        "x",
198                        bvar(2),
199                        app2(
200                            cst("Real.le"),
201                            app3(cst("ELBO"), bvar(2), bvar(1), bvar(0)),
202                            app2(cst("LogMarginalLikelihood"), bvar(1), bvar(0)),
203                        ),
204                    ),
205                ),
206            ),
207        ),
208    )
209}
210/// **Reparameterisation Gradient**: the reparameterised gradient estimator is
211/// an unbiased estimator of āˆ‡_φ E_{z~q_φ}[f(z)].
212///
213/// `reparam_unbiased : āˆ€ (φ : Params) (f : Z → Real),
214///   Unbiased (ReparamGradient f) (GradExpectation f φ)`
215pub fn reparam_unbiased_ty() -> Expr {
216    pi(
217        BinderInfo::Default,
218        "Z",
219        type0(),
220        pi(
221            BinderInfo::Default,
222            "Params",
223            type0(),
224            pi(
225                BinderInfo::Default,
226                "phi",
227                bvar(0),
228                pi(
229                    BinderInfo::Default,
230                    "f",
231                    arrow(bvar(2), real_ty()),
232                    app2(
233                        cst("Unbiased"),
234                        app(cst("ReparamGradient"), bvar(0)),
235                        app2(cst("GradExpectation"), bvar(0), bvar(1)),
236                    ),
237                ),
238            ),
239        ),
240    )
241}
242/// **HMC Correctness**: Hamiltonian Monte Carlo leaves the target distribution invariant.
243///
244/// `hmc_invariant : āˆ€ (target : Measure X), InvariantUnder (HMCKernel target) target`
245pub fn hmc_invariant_ty() -> Expr {
246    pi(
247        BinderInfo::Default,
248        "X",
249        type0(),
250        pi(
251            BinderInfo::Default,
252            "target",
253            app(cst("Measure"), bvar(0)),
254            app2(
255                cst("InvariantUnder"),
256                app(cst("HMCKernel"), bvar(0)),
257                bvar(0),
258            ),
259        ),
260    )
261}
262/// **SMC Consistency**: sequential Monte Carlo converges to the true filtering distribution.
263pub fn smc_consistency_ty() -> Expr {
264    pi(
265        BinderInfo::Default,
266        "X",
267        type0(),
268        pi(
269            BinderInfo::Default,
270            "ssm",
271            app(cst("StateSpaceModel"), bvar(0)),
272            app(cst("ConsistentFilteringDist"), bvar(0)),
273        ),
274    )
275}
276/// **Stochastic Variational Inference (SVI) Convergence**: SVI converges to a local ELBO maximum.
277///
278/// `svi_convergence : āˆ€ (q_family : VariationalFamily) (lr_schedule : LRSchedule),
279///   ConvergesToLocalOptimum (SVIOptimizer q_family lr_schedule)`
280pub fn svi_convergence_ty() -> Expr {
281    pi(
282        BinderInfo::Default,
283        "VF",
284        type0(),
285        pi(
286            BinderInfo::Default,
287            "lr",
288            type0(),
289            app(
290                cst("ConvergesToLocalOptimum"),
291                app2(cst("SVIOptimizer"), bvar(1), bvar(0)),
292            ),
293        ),
294    )
295}
296/// **Normalizing Flow Change of Variables**: the pushforward density satisfies the
297/// change-of-variables formula.
298///
299/// `normalizing_flow_cov : āˆ€ (f : X → Y) (p : Measure X),
300///   Bijective f → Density (Pushforward f p) y = Density p (f_inv y) * AbsDetJac f_inv y`
301pub fn normalizing_flow_cov_ty() -> Expr {
302    pi(
303        BinderInfo::Default,
304        "X",
305        type0(),
306        pi(
307            BinderInfo::Default,
308            "Y",
309            type0(),
310            pi(
311                BinderInfo::Default,
312                "f",
313                arrow(bvar(1), bvar(0)),
314                pi(
315                    BinderInfo::Default,
316                    "p",
317                    app(cst("Measure"), bvar(2)),
318                    arrow(
319                        app(cst("Bijective"), bvar(1)),
320                        app(cst("FlowDensityEq"), bvar(2)),
321                    ),
322                ),
323            ),
324        ),
325    )
326}
327/// **Score Function Estimator Unbiasedness**: the REINFORCE estimator is an unbiased
328/// gradient estimator under mild regularity conditions.
329///
330/// `score_fn_unbiased : āˆ€ (q_phi : ParametricMeasure) (f : Z → Real),
331///   RegularFamily q_phi → Unbiased (ScoreFnGrad q_phi f) (GradELBO q_phi f)`
332pub fn score_fn_unbiased_ty() -> Expr {
333    pi(
334        BinderInfo::Default,
335        "Z",
336        type0(),
337        pi(
338            BinderInfo::Default,
339            "q_phi",
340            app(cst("ParametricMeasure"), bvar(0)),
341            pi(
342                BinderInfo::Default,
343                "f",
344                arrow(bvar(1), real_ty()),
345                arrow(
346                    app(cst("RegularFamily"), bvar(1)),
347                    app2(
348                        cst("Unbiased"),
349                        app2(cst("ScoreFnGrad"), bvar(1), bvar(0)),
350                        app2(cst("GradELBO"), bvar(1), bvar(0)),
351                    ),
352                ),
353            ),
354        ),
355    )
356}
357/// **Pathwise Gradient Unbiasedness**: the reparameterised (pathwise) gradient is
358/// an unbiased estimator when the reparameterisation is differentiable.
359///
360/// `pathwise_gradient_unbiased : āˆ€ (g : Eps → Phi → Z) (f : Z → Real),
361///   DiffReparameterisation g → Unbiased (PathwiseGrad g f) (GradELBO (Reparam g) f)`
362pub fn pathwise_gradient_unbiased_ty() -> Expr {
363    pi(
364        BinderInfo::Default,
365        "Eps",
366        type0(),
367        pi(
368            BinderInfo::Default,
369            "Phi",
370            type0(),
371            pi(
372                BinderInfo::Default,
373                "Z",
374                type0(),
375                pi(
376                    BinderInfo::Default,
377                    "g",
378                    arrow(bvar(2), arrow(bvar(1), bvar(0))),
379                    pi(
380                        BinderInfo::Default,
381                        "f",
382                        arrow(bvar(1), real_ty()),
383                        arrow(
384                            app(cst("DiffReparameterisation"), bvar(1)),
385                            app2(
386                                cst("Unbiased"),
387                                app2(cst("PathwiseGrad"), bvar(1), bvar(0)),
388                                app2(cst("GradELBO"), app(cst("Reparam"), bvar(1)), bvar(0)),
389                            ),
390                        ),
391                    ),
392                ),
393            ),
394        ),
395    )
396}
397/// **Measure Transport Existence**: for any two probability measures with the same
398/// total mass there exists a measurable transport map.
399///
400/// `measure_transport_exists : āˆ€ (mu nu : ProbMeasure X),
401///   ∃ (T : X → X), Pushforward T mu = nu`
402pub fn measure_transport_exists_ty() -> Expr {
403    pi(
404        BinderInfo::Default,
405        "X",
406        type0(),
407        pi(
408            BinderInfo::Default,
409            "mu",
410            app(cst("ProbMeasure"), bvar(0)),
411            pi(
412                BinderInfo::Default,
413                "nu",
414                app(cst("ProbMeasure"), bvar(1)),
415                app(
416                    cst("Exists"),
417                    pi(
418                        BinderInfo::Default,
419                        "T",
420                        arrow(bvar(2), bvar(2)),
421                        app2(
422                            cst("Eq"),
423                            app2(cst("Pushforward"), bvar(0), bvar(2)),
424                            bvar(1),
425                        ),
426                    ),
427                ),
428            ),
429        ),
430    )
431}
432/// **Optimal Transport Kantorovich Duality**: the Wasserstein-1 distance equals
433/// the supremum of Lipschitz-1 functions.
434///
435/// `ot_kantorovich : āˆ€ (mu nu : ProbMeasure X),
436///   W1 mu nu = sup_{f : 1-Lip} (E_mu[f] - E_nu[f])`
437pub fn ot_kantorovich_ty() -> Expr {
438    pi(
439        BinderInfo::Default,
440        "X",
441        type0(),
442        pi(
443            BinderInfo::Default,
444            "mu",
445            app(cst("ProbMeasure"), bvar(0)),
446            pi(
447                BinderInfo::Default,
448                "nu",
449                app(cst("ProbMeasure"), bvar(1)),
450                app2(
451                    cst("Eq"),
452                    app2(cst("W1Dist"), bvar(1), bvar(0)),
453                    app2(cst("KantorovichDual"), bvar(1), bvar(0)),
454                ),
455            ),
456        ),
457    )
458}
459/// **Stein Identity**: for any smooth function h and score function s_p = āˆ‡ log p,
460/// E_p[āˆ‡ h(x) + h(x) s_p(x)] = 0.
461///
462/// `stein_identity : āˆ€ (p : SmoothMeasure X) (h : X → Real),
463///   SmoothTestFn h → E_p[SteinOp p h] = 0`
464pub fn stein_identity_ty() -> Expr {
465    pi(
466        BinderInfo::Default,
467        "X",
468        type0(),
469        pi(
470            BinderInfo::Default,
471            "p",
472            app(cst("SmoothMeasure"), bvar(0)),
473            pi(
474                BinderInfo::Default,
475                "h",
476                arrow(bvar(1), real_ty()),
477                arrow(
478                    app(cst("SmoothTestFn"), bvar(0)),
479                    app2(
480                        cst("Eq"),
481                        app2(
482                            cst("Expectation"),
483                            bvar(1),
484                            app2(cst("SteinOp"), bvar(1), bvar(0)),
485                        ),
486                        cst("Real.zero"),
487                    ),
488                ),
489            ),
490        ),
491    )
492}
493/// **Stein Variational Gradient Descent Convergence**: SVGD converges to the target
494/// distribution in the Stein discrepancy sense.
495///
496/// `svgd_convergence : āˆ€ (target : SmoothMeasure X) (n : Nat),
497///   SteinDiscrepancy (SVGDUpdate target n) target ≤ SVGDBound n`
498pub fn svgd_convergence_ty() -> Expr {
499    pi(
500        BinderInfo::Default,
501        "X",
502        type0(),
503        pi(
504            BinderInfo::Default,
505            "target",
506            app(cst("SmoothMeasure"), bvar(0)),
507            pi(
508                BinderInfo::Default,
509                "n",
510                nat_ty(),
511                app2(
512                    cst("Real.le"),
513                    app2(
514                        cst("SteinDiscrepancy"),
515                        app2(cst("SVGDUpdate"), bvar(1), bvar(0)),
516                        bvar(1),
517                    ),
518                    app(cst("SVGDBound"), bvar(0)),
519                ),
520            ),
521        ),
522    )
523}
524/// **Population Monte Carlo Consistency**: PMC estimators are consistent as
525/// population size and iterations grow.
526///
527/// `pmc_consistency : āˆ€ (N : Nat) (T : Nat) (target : Measure X),
528///   ConsistentEstimator (PMCEstimator target N T)`
529pub fn pmc_consistency_ty() -> Expr {
530    pi(
531        BinderInfo::Default,
532        "X",
533        type0(),
534        pi(
535            BinderInfo::Default,
536            "N",
537            nat_ty(),
538            pi(
539                BinderInfo::Default,
540                "T",
541                nat_ty(),
542                pi(
543                    BinderInfo::Default,
544                    "target",
545                    app(cst("Measure"), bvar(2)),
546                    app(
547                        cst("ConsistentEstimator"),
548                        app3(cst("PMCEstimator"), bvar(0), bvar(2), bvar(1)),
549                    ),
550                ),
551            ),
552        ),
553    )
554}
555/// **Evolutionary MCMC Detailed Balance**: evolutionary MCMC satisfies detailed
556/// balance with respect to a product-form invariant distribution.
557///
558/// `evol_mcmc_detailed_balance : āˆ€ (target : Measure X) (temp : Tempering),
559///   DetailedBalance (EvolMCMCKernel target temp) (TemperedTarget target temp)`
560pub fn evol_mcmc_detailed_balance_ty() -> Expr {
561    pi(
562        BinderInfo::Default,
563        "X",
564        type0(),
565        pi(
566            BinderInfo::Default,
567            "target",
568            app(cst("Measure"), bvar(0)),
569            pi(
570                BinderInfo::Default,
571                "temp",
572                cst("Tempering"),
573                app2(
574                    cst("DetailedBalance"),
575                    app2(cst("EvolMCMCKernel"), bvar(1), bvar(0)),
576                    app2(cst("TemperedTarget"), bvar(1), bvar(0)),
577                ),
578            ),
579        ),
580    )
581}
582/// **Parallel Tempering Exchange Correctness**: the swap move in parallel tempering
583/// preserves the joint invariant distribution.
584///
585/// `parallel_tempering_exchange : āˆ€ (temps : List Real) (joint : Measure (ProductSpace temps)),
586///   InvariantUnder (SwapKernel temps) joint`
587pub fn parallel_tempering_exchange_ty() -> Expr {
588    pi(
589        BinderInfo::Default,
590        "temps",
591        list_ty(real_ty()),
592        pi(
593            BinderInfo::Default,
594            "joint",
595            app(cst("Measure"), app(cst("ProductSpace"), bvar(0))),
596            app2(
597                cst("InvariantUnder"),
598                app(cst("SwapKernel"), bvar(1)),
599                bvar(0),
600            ),
601        ),
602    )
603}
604/// **Simulated Annealing Convergence**: simulated annealing converges to a global
605/// optimum under a logarithmic cooling schedule.
606///
607/// `simulated_annealing_convergence : āˆ€ (f : X → Real) (T : CoolingSchedule),
608///   LogarithmicSchedule T → ConvergesToGlobalOpt (SAChain f T)`
609pub fn simulated_annealing_convergence_ty() -> Expr {
610    pi(
611        BinderInfo::Default,
612        "X",
613        type0(),
614        pi(
615            BinderInfo::Default,
616            "f",
617            arrow(bvar(0), real_ty()),
618            pi(
619                BinderInfo::Default,
620                "T",
621                cst("CoolingSchedule"),
622                arrow(
623                    app(cst("LogarithmicSchedule"), bvar(0)),
624                    app(
625                        cst("ConvergesToGlobalOpt"),
626                        app2(cst("SAChain"), bvar(2), bvar(0)),
627                    ),
628                ),
629            ),
630        ),
631    )
632}
633/// **VAE ELBO Decomposition**: the VAE objective decomposes as reconstruction term
634/// minus KL divergence.
635///
636/// `vae_elbo_decomp : āˆ€ (encoder decoder : NeuralNet) (x : X),
637///   VAELBO encoder decoder x = Reconstruction encoder decoder x - KLDivQ encoder x`
638pub fn vae_elbo_decomp_ty() -> Expr {
639    pi(
640        BinderInfo::Default,
641        "X",
642        type0(),
643        pi(
644            BinderInfo::Default,
645            "encoder",
646            cst("NeuralNet"),
647            pi(
648                BinderInfo::Default,
649                "decoder",
650                cst("NeuralNet"),
651                pi(
652                    BinderInfo::Default,
653                    "x",
654                    bvar(2),
655                    app2(
656                        cst("Eq"),
657                        app3(cst("VAELBO"), bvar(2), bvar(1), bvar(0)),
658                        app2(
659                            cst("Real.sub"),
660                            app3(cst("Reconstruction"), bvar(2), bvar(1), bvar(0)),
661                            app2(cst("KLDivQ"), bvar(2), bvar(0)),
662                        ),
663                    ),
664                ),
665            ),
666        ),
667    )
668}
669/// **Diffusion Model Score Matching**: the reverse diffusion score function
670/// minimises the denoising score matching objective.
671///
672/// `diffusion_score_matching : āˆ€ (p_data : Measure X) (sigma : Real),
673///   OptimScore (DSMObjective p_data sigma) = ScoreFunction (GaussianSmooth p_data sigma)`
674pub fn diffusion_score_matching_ty() -> Expr {
675    pi(
676        BinderInfo::Default,
677        "X",
678        type0(),
679        pi(
680            BinderInfo::Default,
681            "p_data",
682            app(cst("Measure"), bvar(0)),
683            pi(
684                BinderInfo::Default,
685                "sigma",
686                real_ty(),
687                app2(
688                    cst("Eq"),
689                    app(
690                        cst("OptimScore"),
691                        app2(cst("DSMObjective"), bvar(1), bvar(0)),
692                    ),
693                    app(
694                        cst("ScoreFunction"),
695                        app2(cst("GaussianSmooth"), bvar(1), bvar(0)),
696                    ),
697                ),
698            ),
699        ),
700    )
701}
702/// **Flow Matching ODE Correctness**: the conditional flow matching ODE generates
703/// the correct marginal distribution at time t=1.
704///
705/// `flow_matching_ode : āˆ€ (p_0 p_1 : Measure X) (vt : VectorField),
706///   CondFlowMatchingField vt p_0 p_1 → PushforwardODE vt p_0 1 = p_1`
707pub fn flow_matching_ode_ty() -> Expr {
708    pi(
709        BinderInfo::Default,
710        "X",
711        type0(),
712        pi(
713            BinderInfo::Default,
714            "p_0",
715            app(cst("Measure"), bvar(0)),
716            pi(
717                BinderInfo::Default,
718                "p_1",
719                app(cst("Measure"), bvar(1)),
720                pi(
721                    BinderInfo::Default,
722                    "vt",
723                    cst("VectorField"),
724                    arrow(
725                        app3(cst("CondFlowMatchingField"), bvar(0), bvar(2), bvar(1)),
726                        app2(
727                            cst("Eq"),
728                            app3(cst("PushforwardODE"), bvar(0), bvar(2), cst("Real.one")),
729                            bvar(1),
730                        ),
731                    ),
732                ),
733            ),
734        ),
735    )
736}
737/// **Gaussian Process Posterior**: the posterior of a GP given observations is also a GP.
738///
739/// `gp_posterior_is_gp : āˆ€ (prior : GaussianProcess X) (obs : Observations),
740///   IsGaussianProcess (GPPosterior prior obs)`
741pub fn gp_posterior_is_gp_ty() -> Expr {
742    pi(
743        BinderInfo::Default,
744        "X",
745        type0(),
746        pi(
747            BinderInfo::Default,
748            "prior",
749            app(cst("GaussianProcess"), bvar(0)),
750            pi(
751                BinderInfo::Default,
752                "obs",
753                cst("Observations"),
754                app(
755                    cst("IsGaussianProcess"),
756                    app2(cst("GPPosterior"), bvar(1), bvar(0)),
757                ),
758            ),
759        ),
760    )
761}
762/// **GP Marginal Likelihood**: the marginal likelihood of a GP is Gaussian.
763///
764/// `gp_marginal_gaussian : āˆ€ (gp : GaussianProcess X) (X_train : List X),
765///   MarginalLikelihood gp X_train = MultivariateGaussian (GPMean gp X_train) (GPKernelMatrix gp X_train)`
766pub fn gp_marginal_gaussian_ty() -> Expr {
767    pi(
768        BinderInfo::Default,
769        "X",
770        type0(),
771        pi(
772            BinderInfo::Default,
773            "gp",
774            app(cst("GaussianProcess"), bvar(0)),
775            pi(
776                BinderInfo::Default,
777                "X_train",
778                list_ty(bvar(1)),
779                app2(
780                    cst("Eq"),
781                    app2(cst("MarginalLikelihood"), bvar(1), bvar(0)),
782                    app2(
783                        cst("MultivariateGaussian"),
784                        app2(cst("GPMean"), bvar(1), bvar(0)),
785                        app2(cst("GPKernelMatrix"), bvar(1), bvar(0)),
786                    ),
787                ),
788            ),
789        ),
790    )
791}
792/// **Probabilistic Numerics Integration**: Bayesian quadrature produces a posterior
793/// over integrals.
794///
795/// `pn_integration : āˆ€ (f : X → Real) (prior_gp : GaussianProcess X),
796///   BQPosterior prior_gp f = GaussianMeasureOver Real`
797pub fn pn_integration_ty() -> Expr {
798    pi(
799        BinderInfo::Default,
800        "X",
801        type0(),
802        pi(
803            BinderInfo::Default,
804            "f",
805            arrow(bvar(0), real_ty()),
806            pi(
807                BinderInfo::Default,
808                "prior_gp",
809                app(cst("GaussianProcess"), bvar(1)),
810                app2(
811                    cst("Eq"),
812                    app2(cst("BQPosterior"), bvar(0), bvar(1)),
813                    app(cst("GaussianMeasureOver"), real_ty()),
814                ),
815            ),
816        ),
817    )
818}
819/// **Stein Discrepancy Zero Iff Same Distribution**: the kernel Stein discrepancy
820/// between two measures is zero if and only if they are equal.
821///
822/// `stein_disc_zero_iff : āˆ€ (p q : Measure X) (k : SteinKernel X),
823///   KSD p q k = 0 ↔ p = q`
824pub fn stein_disc_zero_iff_ty() -> Expr {
825    pi(
826        BinderInfo::Default,
827        "X",
828        type0(),
829        pi(
830            BinderInfo::Default,
831            "p",
832            app(cst("Measure"), bvar(0)),
833            pi(
834                BinderInfo::Default,
835                "q",
836                app(cst("Measure"), bvar(1)),
837                pi(
838                    BinderInfo::Default,
839                    "k",
840                    app(cst("SteinKernel"), bvar(2)),
841                    app2(
842                        cst("Iff"),
843                        app2(
844                            cst("Eq"),
845                            app3(cst("KSD"), bvar(2), bvar(1), bvar(0)),
846                            cst("Real.zero"),
847                        ),
848                        app2(cst("Eq"), bvar(2), bvar(1)),
849                    ),
850                ),
851            ),
852        ),
853    )
854}
855/// **SMC Feynman-Kac**: SMC computes the Feynman-Kac normalising constant exactly
856/// in expectation.
857///
858/// `smc_feynman_kac : āˆ€ (fk : FeynmanKacModel X) (N : Nat),
859///   E[SMCNormConst fk N] = FeynmanKacNormConst fk`
860pub fn smc_feynman_kac_ty() -> Expr {
861    pi(
862        BinderInfo::Default,
863        "X",
864        type0(),
865        pi(
866            BinderInfo::Default,
867            "fk",
868            app(cst("FeynmanKacModel"), bvar(0)),
869            pi(
870                BinderInfo::Default,
871                "N",
872                nat_ty(),
873                app2(
874                    cst("Eq"),
875                    app(
876                        cst("Expectation"),
877                        app2(cst("SMCNormConst"), bvar(1), bvar(0)),
878                    ),
879                    app(cst("FeynmanKacNormConst"), bvar(1)),
880                ),
881            ),
882        ),
883    )
884}
885/// **Particle Marginal Metropolis-Hastings Correctness**: PMMH targeting the exact
886/// posterior is asymptotically exact.
887///
888/// `pmmh_correctness : āˆ€ (model : LatentModel X Y) (obs : Y),
889///   TargetsExactPosterior (PMMHKernel model obs)`
890pub fn pmmh_correctness_ty() -> Expr {
891    pi(
892        BinderInfo::Default,
893        "X",
894        type0(),
895        pi(
896            BinderInfo::Default,
897            "Y",
898            type0(),
899            pi(
900                BinderInfo::Default,
901                "model",
902                app2(cst("LatentModel"), bvar(1), bvar(0)),
903                pi(
904                    BinderInfo::Default,
905                    "obs",
906                    bvar(1),
907                    app(
908                        cst("TargetsExactPosterior"),
909                        app2(cst("PMMHKernel"), bvar(1), bvar(0)),
910                    ),
911                ),
912            ),
913        ),
914    )
915}
916/// **Annealed Importance Sampling Unbiasedness**: AIS produces unbiased estimates
917/// of the normalising constant.
918///
919/// `ais_unbiased : āˆ€ (p_0 p_T : Measure X) (beta_sched : AnnealingSchedule),
920///   Unbiased (AISEstimator p_0 p_T beta_sched) (NormConst p_T)`
921pub fn ais_unbiased_ty() -> Expr {
922    pi(
923        BinderInfo::Default,
924        "X",
925        type0(),
926        pi(
927            BinderInfo::Default,
928            "p_0",
929            app(cst("Measure"), bvar(0)),
930            pi(
931                BinderInfo::Default,
932                "p_T",
933                app(cst("Measure"), bvar(1)),
934                pi(
935                    BinderInfo::Default,
936                    "beta_sched",
937                    cst("AnnealingSchedule"),
938                    app2(
939                        cst("Unbiased"),
940                        app3(cst("AISEstimator"), bvar(2), bvar(1), bvar(0)),
941                        app(cst("NormConst"), bvar(1)),
942                    ),
943                ),
944            ),
945        ),
946    )
947}
948/// **Denoising Score Matching Objective**: DSM objective equals implicit score matching
949/// objective under Gaussian noise.
950///
951/// `dsm_equals_sm : āˆ€ (p_data : Measure X) (sigma : Real),
952///   DSMObjective p_data sigma = ImplicitSMObjective (GaussianConvolution p_data sigma)`
953pub fn dsm_equals_sm_ty() -> Expr {
954    pi(
955        BinderInfo::Default,
956        "X",
957        type0(),
958        pi(
959            BinderInfo::Default,
960            "p_data",
961            app(cst("Measure"), bvar(0)),
962            pi(
963                BinderInfo::Default,
964                "sigma",
965                real_ty(),
966                app2(
967                    cst("Eq"),
968                    app2(cst("DSMObjective"), bvar(1), bvar(0)),
969                    app(
970                        cst("ImplicitSMObjective"),
971                        app2(cst("GaussianConvolution"), bvar(1), bvar(0)),
972                    ),
973                ),
974            ),
975        ),
976    )
977}
978/// **Langevin Dynamics Convergence**: the unadjusted Langevin algorithm (ULA) converges
979/// to the target in 2-Wasserstein under strong convexity.
980///
981/// `langevin_convergence : āˆ€ (target : LogConcaveMeasure X) (eps : Real),
982///   W2 (ULADist target eps n) target ≤ LangevinBound target eps n`
983pub fn langevin_convergence_ty() -> Expr {
984    pi(
985        BinderInfo::Default,
986        "X",
987        type0(),
988        pi(
989            BinderInfo::Default,
990            "target",
991            app(cst("LogConcaveMeasure"), bvar(0)),
992            pi(
993                BinderInfo::Default,
994                "eps",
995                real_ty(),
996                pi(
997                    BinderInfo::Default,
998                    "n",
999                    nat_ty(),
1000                    app2(
1001                        cst("Real.le"),
1002                        app3(
1003                            cst("W2"),
1004                            app3(cst("ULADist"), bvar(2), bvar(1), bvar(0)),
1005                            bvar(2),
1006                            bvar(1),
1007                        ),
1008                        app3(cst("LangevinBound"), bvar(2), bvar(1), bvar(0)),
1009                    ),
1010                ),
1011            ),
1012        ),
1013    )
1014}
1015/// **Metropolis-Hastings Detailed Balance**: MH kernel satisfies detailed balance
1016/// w.r.t. the target distribution.
1017///
1018/// `mh_detailed_balance : āˆ€ (target : Measure X) (proposal : Kernel X X),
1019///   DetailedBalance (MHKernel target proposal) target`
1020pub fn mh_detailed_balance_ty() -> Expr {
1021    pi(
1022        BinderInfo::Default,
1023        "X",
1024        type0(),
1025        pi(
1026            BinderInfo::Default,
1027            "target",
1028            app(cst("Measure"), bvar(0)),
1029            pi(
1030                BinderInfo::Default,
1031                "proposal",
1032                app2(cst("Kernel"), bvar(1), bvar(1)),
1033                app2(
1034                    cst("DetailedBalance"),
1035                    app2(cst("MHKernel"), bvar(1), bvar(0)),
1036                    bvar(1),
1037                ),
1038            ),
1039        ),
1040    )
1041}
1042/// **Gibbs Sampling Invariance**: the Gibbs sampler leaves the joint distribution invariant.
1043///
1044/// `gibbs_invariant : āˆ€ (joint : Measure (X Ɨ Y)),
1045///   InvariantUnder (GibbsKernel joint) joint`
1046pub fn gibbs_invariant_ty() -> Expr {
1047    pi(
1048        BinderInfo::Default,
1049        "X",
1050        type0(),
1051        pi(
1052            BinderInfo::Default,
1053            "Y",
1054            type0(),
1055            pi(
1056                BinderInfo::Default,
1057                "joint",
1058                app(cst("Measure"), app2(cst("Prod"), bvar(1), bvar(0))),
1059                app2(
1060                    cst("InvariantUnder"),
1061                    app(cst("GibbsKernel"), bvar(0)),
1062                    bvar(0),
1063                ),
1064            ),
1065        ),
1066    )
1067}
1068/// **Variational Autoencoder Posterior Collapse**: with a sufficiently expressive decoder
1069/// there exists a risk of posterior collapse.
1070///
1071/// `vae_posterior_collapse_risk : āˆ€ (decoder : ExpressiveDecoder) (beta : Real),
1072///   beta < 1 → AvoidsCollapse (BetaVAE decoder beta)`
1073pub fn vae_posterior_collapse_risk_ty() -> Expr {
1074    pi(
1075        BinderInfo::Default,
1076        "decoder",
1077        cst("ExpressiveDecoder"),
1078        pi(
1079            BinderInfo::Default,
1080            "beta",
1081            real_ty(),
1082            arrow(
1083                app2(cst("Real.lt"), bvar(0), cst("Real.one")),
1084                app(
1085                    cst("AvoidsCollapse"),
1086                    app2(cst("BetaVAE"), bvar(1), bvar(0)),
1087                ),
1088            ),
1089        ),
1090    )
1091}
1092/// **Gradient of Log Normalizer**: the gradient of the log normalizer of an exponential
1093/// family equals the mean of the sufficient statistics.
1094///
1095/// `grad_log_normalizer : āˆ€ (eta : NaturalParams) (T : SuffStat),
1096///   Gradient (LogNormalizer T) eta = MeanSuffStat T (ExpFamilyDist T eta)`
1097pub fn grad_log_normalizer_ty() -> Expr {
1098    pi(
1099        BinderInfo::Default,
1100        "T",
1101        type0(),
1102        pi(
1103            BinderInfo::Default,
1104            "eta",
1105            app(cst("NaturalParams"), bvar(0)),
1106            app2(
1107                cst("Eq"),
1108                app(cst("Gradient"), app(cst("LogNormalizer"), bvar(0))),
1109                app2(
1110                    cst("MeanSuffStat"),
1111                    bvar(0),
1112                    app2(cst("ExpFamilyDist"), bvar(0), bvar(1)),
1113                ),
1114            ),
1115        ),
1116    )
1117}
1118/// **Sequential Monte Carlo Genealogy**: the ancestral lineage in SMC traces back
1119/// through the resampling steps.
1120///
1121/// `smc_genealogy : āˆ€ (pf : ParticleSystem X T),
1122///   AncestralLineage pf = CoalescentProcess (ResamplingTimes pf)`
1123pub fn smc_genealogy_ty() -> Expr {
1124    pi(
1125        BinderInfo::Default,
1126        "X",
1127        type0(),
1128        pi(
1129            BinderInfo::Default,
1130            "T",
1131            nat_ty(),
1132            pi(
1133                BinderInfo::Default,
1134                "pf",
1135                app2(cst("ParticleSystem"), bvar(1), bvar(0)),
1136                app2(
1137                    cst("Eq"),
1138                    app(cst("AncestralLineage"), bvar(0)),
1139                    app(
1140                        cst("CoalescentProcess"),
1141                        app(cst("ResamplingTimes"), bvar(0)),
1142                    ),
1143                ),
1144            ),
1145        ),
1146    )
1147}
1148/// **Kernel Density Estimation Consistency**: the KDE converges to the true density
1149/// in L2 as n → āˆž with optimal bandwidth.
1150///
1151/// `kde_consistency : āˆ€ (p : SmoothMeasure X) (h : BandwidthSeq),
1152///   OptimalBandwidth h → L2Convergence (KDEn p h)`
1153pub fn kde_consistency_ty() -> Expr {
1154    pi(
1155        BinderInfo::Default,
1156        "X",
1157        type0(),
1158        pi(
1159            BinderInfo::Default,
1160            "p",
1161            app(cst("SmoothMeasure"), bvar(0)),
1162            pi(
1163                BinderInfo::Default,
1164                "h",
1165                cst("BandwidthSeq"),
1166                arrow(
1167                    app(cst("OptimalBandwidth"), bvar(0)),
1168                    app(cst("L2Convergence"), app2(cst("KDEn"), bvar(1), bvar(0))),
1169                ),
1170            ),
1171        ),
1172    )
1173}
1174/// **Variational Inference Mean-Field Factorization**: the mean-field approximation
1175/// optimises each factor holding others fixed via coordinate ascent.
1176///
1177/// `mean_field_cavi : āˆ€ (joint : Measure Z) (q_factors : List (Measure Z)),
1178///   CAVIStep joint q_factors = UpdatedFactors joint q_factors`
1179pub fn mean_field_cavi_ty() -> Expr {
1180    pi(
1181        BinderInfo::Default,
1182        "Z",
1183        type0(),
1184        pi(
1185            BinderInfo::Default,
1186            "joint",
1187            app(cst("Measure"), bvar(0)),
1188            pi(
1189                BinderInfo::Default,
1190                "q_factors",
1191                list_ty(app(cst("Measure"), bvar(1))),
1192                app2(
1193                    cst("Eq"),
1194                    app2(cst("CAVIStep"), bvar(1), bvar(0)),
1195                    app2(cst("UpdatedFactors"), bvar(1), bvar(0)),
1196                ),
1197            ),
1198        ),
1199    )
1200}
1201/// **Probabilistic Backpropagation Gaussian Propagation**: PBP propagates a
1202/// Gaussian approximation through each layer of a neural network.
1203///
1204/// `pbp_gaussian_propagation : āˆ€ (net : BayesianNeuralNet) (x : Input),
1205///   GaussianApproxActivations (PBP net x) = PBPActivations net x`
1206pub fn pbp_gaussian_propagation_ty() -> Expr {
1207    pi(
1208        BinderInfo::Default,
1209        "net",
1210        cst("BayesianNeuralNet"),
1211        pi(
1212            BinderInfo::Default,
1213            "x",
1214            cst("Input"),
1215            app2(
1216                cst("Eq"),
1217                app(
1218                    cst("GaussianApproxActivations"),
1219                    app2(cst("PBP"), bvar(1), bvar(0)),
1220                ),
1221                app2(cst("PBPActivations"), bvar(1), bvar(0)),
1222            ),
1223        ),
1224    )
1225}
1226/// **Expectation Propagation Fixed Point**: EP converges when the cavity distribution
1227/// and tilted distribution agree.
1228///
1229/// `ep_fixed_point : āˆ€ (model : FactorGraph) (approx : GaussianApprox),
1230///   EPFixedPoint approx model ↔ CavityAgreement approx model`
1231pub fn ep_fixed_point_ty() -> Expr {
1232    pi(
1233        BinderInfo::Default,
1234        "model",
1235        cst("FactorGraph"),
1236        pi(
1237            BinderInfo::Default,
1238            "approx",
1239            cst("GaussianApprox"),
1240            app2(
1241                cst("Iff"),
1242                app2(cst("EPFixedPoint"), bvar(0), bvar(1)),
1243                app2(cst("CavityAgreement"), bvar(0), bvar(1)),
1244            ),
1245        ),
1246    )
1247}
1248/// **Nested Monte Carlo Estimator Bias**: nested MC estimators are biased but
1249/// consistent as the inner sample size grows.
1250///
1251/// `nested_mc_bias : āˆ€ (outer inner : Nat) (f : X → Real),
1252///   Bias (NestedMCEstimator f outer inner) ≤ NestedMCBiasRate inner`
1253pub fn nested_mc_bias_ty() -> Expr {
1254    pi(
1255        BinderInfo::Default,
1256        "X",
1257        type0(),
1258        pi(
1259            BinderInfo::Default,
1260            "outer",
1261            nat_ty(),
1262            pi(
1263                BinderInfo::Default,
1264                "inner",
1265                nat_ty(),
1266                pi(
1267                    BinderInfo::Default,
1268                    "f",
1269                    arrow(bvar(2), real_ty()),
1270                    app2(
1271                        cst("Real.le"),
1272                        app(
1273                            cst("Bias"),
1274                            app3(cst("NestedMCEstimator"), bvar(0), bvar(2), bvar(1)),
1275                        ),
1276                        app(cst("NestedMCBiasRate"), bvar(1)),
1277                    ),
1278                ),
1279            ),
1280        ),
1281    )
1282}
1283/// **Approximate Bayesian Computation Consistency**: ABC-SMC converges to the
1284/// correct posterior as the tolerance ε → 0.
1285///
1286/// `abc_smc_consistency : āˆ€ (prior : Measure Theta) (sim : Theta → Measure Y) (eps : Real),
1287///   eps > 0 → ApproxPosterior (ABCSMC prior sim eps) eps (TruePosterior prior sim)`
1288pub fn abc_smc_consistency_ty() -> Expr {
1289    pi(
1290        BinderInfo::Default,
1291        "Theta",
1292        type0(),
1293        pi(
1294            BinderInfo::Default,
1295            "Y",
1296            type0(),
1297            pi(
1298                BinderInfo::Default,
1299                "prior",
1300                app(cst("Measure"), bvar(1)),
1301                pi(
1302                    BinderInfo::Default,
1303                    "sim",
1304                    arrow(bvar(2), app(cst("Measure"), bvar(1))),
1305                    pi(
1306                        BinderInfo::Default,
1307                        "eps",
1308                        real_ty(),
1309                        arrow(
1310                            app2(cst("Real.gt"), bvar(0), cst("Real.zero")),
1311                            app3(
1312                                cst("ApproxPosterior"),
1313                                app3(cst("ABCSMC"), bvar(3), bvar(1), bvar(0)),
1314                                bvar(0),
1315                                app2(cst("TruePosterior"), bvar(3), bvar(1)),
1316                            ),
1317                        ),
1318                    ),
1319                ),
1320            ),
1321        ),
1322    )
1323}
1324/// Populate an `Environment` with all probabilistic-programming axiom declarations.
1325pub fn build_probabilistic_programming_env(
1326    env: &mut Environment,
1327) -> Result<(), Box<dyn std::error::Error>> {
1328    let axioms: &[(&str, Expr)] = &[
1329        ("Measure", measure_ty()),
1330        ("SigmaAlgebra", sigma_algebra_ty()),
1331        ("MeasurableSpace", measurable_space_ty()),
1332        ("ProbabilityMonad", probability_monad_ty()),
1333        ("Kernel", kernel_ty()),
1334        ("Sampler", sampler_ty()),
1335        ("PPLProgram", ppl_program_ty()),
1336        ("GradientEstimator", gradient_estimator_ty()),
1337        ("ParticleFilter", particle_filter_ty()),
1338        (
1339            "Joint",
1340            arrow(
1341                app(cst("Measure"), type0()),
1342                arrow(app2(cst("Kernel"), type0(), type0()), type0()),
1343            ),
1344        ),
1345        ("JointMeasure", arrow(type0(), arrow(type0(), type0()))),
1346        (
1347            "Marginal",
1348            arrow(app2(cst("Kernel"), type0(), type0()), type0()),
1349        ),
1350        (
1351            "Posterior",
1352            arrow(
1353                app(cst("Measure"), type0()),
1354                arrow(
1355                    app2(cst("Kernel"), type0(), type0()),
1356                    app(cst("Measure"), type0()),
1357                ),
1358            ),
1359        ),
1360        ("RadonNikodym", arrow(type0(), arrow(type0(), type0()))),
1361        (
1362            "AbsContinuous",
1363            arrow(
1364                app(cst("Measure"), type0()),
1365                arrow(app(cst("Measure"), type0()), prop()),
1366            ),
1367        ),
1368        ("ConsistentEstimator", arrow(type0(), prop())),
1369        (
1370            "ISSampler",
1371            arrow(
1372                arrow(type0(), real_ty()),
1373                arrow(
1374                    app(cst("Measure"), type0()),
1375                    arrow(app(cst("Measure"), type0()), type0()),
1376                ),
1377            ),
1378        ),
1379        (
1380            "ELBO",
1381            arrow(
1382                app(cst("Measure"), type0()),
1383                arrow(type0(), arrow(type0(), real_ty())),
1384            ),
1385        ),
1386        (
1387            "LogMarginalLikelihood",
1388            arrow(type0(), arrow(type0(), real_ty())),
1389        ),
1390        ("MonadLaws", arrow(arrow(type0(), type0()), prop())),
1391        ("Unbiased", arrow(type0(), arrow(type0(), prop()))),
1392        ("ReparamGradient", arrow(arrow(type0(), real_ty()), type0())),
1393        (
1394            "GradExpectation",
1395            arrow(arrow(type0(), real_ty()), arrow(type0(), type0())),
1396        ),
1397        (
1398            "HMCKernel",
1399            arrow(
1400                app(cst("Measure"), type0()),
1401                app2(cst("Kernel"), type0(), type0()),
1402            ),
1403        ),
1404        (
1405            "InvariantUnder",
1406            arrow(
1407                app2(cst("Kernel"), type0(), type0()),
1408                arrow(app(cst("Measure"), type0()), prop()),
1409            ),
1410        ),
1411        ("StateSpaceModel", arrow(type0(), type0())),
1412        ("ConsistentFilteringDist", arrow(type0(), prop())),
1413        ("bayes_measure_theory", bayes_measure_theory_ty()),
1414        ("giry_monad_laws", giry_monad_laws_ty()),
1415        ("is_consistency", is_consistency_ty()),
1416        ("elbo_lower_bound", elbo_lower_bound_ty()),
1417        ("reparam_unbiased", reparam_unbiased_ty()),
1418        ("hmc_invariant", hmc_invariant_ty()),
1419        ("smc_consistency", smc_consistency_ty()),
1420        ("VariationalFamily", arrow(type0(), type0())),
1421        ("LRSchedule", type0()),
1422        ("SVIOptimizer", arrow(type0(), arrow(type0(), type0()))),
1423        ("ConvergesToLocalOptimum", arrow(type0(), prop())),
1424        ("Bijective", arrow(arrow(type0(), type0()), prop())),
1425        ("FlowDensityEq", arrow(type0(), prop())),
1426        ("ParametricMeasure", arrow(type0(), type0())),
1427        ("RegularFamily", arrow(type0(), prop())),
1428        (
1429            "ScoreFnGrad",
1430            arrow(type0(), arrow(arrow(type0(), real_ty()), type0())),
1431        ),
1432        (
1433            "GradELBO",
1434            arrow(type0(), arrow(arrow(type0(), real_ty()), type0())),
1435        ),
1436        (
1437            "DiffReparameterisation",
1438            arrow(arrow(type0(), arrow(type0(), type0())), prop()),
1439        ),
1440        (
1441            "PathwiseGrad",
1442            arrow(
1443                arrow(type0(), arrow(type0(), type0())),
1444                arrow(arrow(type0(), real_ty()), type0()),
1445            ),
1446        ),
1447        (
1448            "Reparam",
1449            arrow(arrow(type0(), arrow(type0(), type0())), type0()),
1450        ),
1451        ("ProbMeasure", arrow(type0(), type0())),
1452        (
1453            "Pushforward",
1454            arrow(
1455                arrow(type0(), type0()),
1456                arrow(app(cst("Measure"), type0()), app(cst("Measure"), type0())),
1457            ),
1458        ),
1459        (
1460            "W1Dist",
1461            arrow(
1462                app(cst("Measure"), type0()),
1463                arrow(app(cst("Measure"), type0()), real_ty()),
1464            ),
1465        ),
1466        (
1467            "KantorovichDual",
1468            arrow(
1469                app(cst("Measure"), type0()),
1470                arrow(app(cst("Measure"), type0()), real_ty()),
1471            ),
1472        ),
1473        ("SmoothMeasure", arrow(type0(), type0())),
1474        ("SmoothTestFn", arrow(arrow(type0(), real_ty()), prop())),
1475        (
1476            "Expectation",
1477            arrow(app(cst("Measure"), type0()), arrow(type0(), real_ty())),
1478        ),
1479        (
1480            "SteinOp",
1481            arrow(
1482                app(cst("SmoothMeasure"), type0()),
1483                arrow(arrow(type0(), real_ty()), type0()),
1484            ),
1485        ),
1486        ("Real.zero", real_ty()),
1487        ("Real.one", real_ty()),
1488        (
1489            "SteinDiscrepancy",
1490            arrow(
1491                app(cst("Measure"), type0()),
1492                arrow(app(cst("Measure"), type0()), real_ty()),
1493            ),
1494        ),
1495        (
1496            "SVGDUpdate",
1497            arrow(
1498                app(cst("SmoothMeasure"), type0()),
1499                arrow(nat_ty(), app(cst("Measure"), type0())),
1500            ),
1501        ),
1502        ("SVGDBound", arrow(nat_ty(), real_ty())),
1503        ("SteinKernel", arrow(type0(), type0())),
1504        (
1505            "KSD",
1506            arrow(
1507                app(cst("Measure"), type0()),
1508                arrow(
1509                    app(cst("Measure"), type0()),
1510                    arrow(app(cst("SteinKernel"), type0()), real_ty()),
1511                ),
1512            ),
1513        ),
1514        ("Iff", arrow(prop(), arrow(prop(), prop()))),
1515        (
1516            "PMCEstimator",
1517            arrow(
1518                app(cst("Measure"), type0()),
1519                arrow(nat_ty(), arrow(nat_ty(), type0())),
1520            ),
1521        ),
1522        ("Tempering", type0()),
1523        (
1524            "EvolMCMCKernel",
1525            arrow(
1526                app(cst("Measure"), type0()),
1527                arrow(cst("Tempering"), app2(cst("Kernel"), type0(), type0())),
1528            ),
1529        ),
1530        (
1531            "TemperedTarget",
1532            arrow(
1533                app(cst("Measure"), type0()),
1534                arrow(cst("Tempering"), app(cst("Measure"), type0())),
1535            ),
1536        ),
1537        (
1538            "DetailedBalance",
1539            arrow(
1540                app2(cst("Kernel"), type0(), type0()),
1541                arrow(app(cst("Measure"), type0()), prop()),
1542            ),
1543        ),
1544        ("ProductSpace", arrow(list_ty(real_ty()), type0())),
1545        (
1546            "SwapKernel",
1547            arrow(list_ty(real_ty()), app2(cst("Kernel"), type0(), type0())),
1548        ),
1549        ("CoolingSchedule", type0()),
1550        ("LogarithmicSchedule", arrow(cst("CoolingSchedule"), prop())),
1551        (
1552            "SAChain",
1553            arrow(
1554                arrow(type0(), real_ty()),
1555                arrow(cst("CoolingSchedule"), type0()),
1556            ),
1557        ),
1558        ("ConvergesToGlobalOpt", arrow(type0(), prop())),
1559        ("NeuralNet", type0()),
1560        (
1561            "VAELBO",
1562            arrow(
1563                cst("NeuralNet"),
1564                arrow(cst("NeuralNet"), arrow(type0(), real_ty())),
1565            ),
1566        ),
1567        (
1568            "Reconstruction",
1569            arrow(
1570                cst("NeuralNet"),
1571                arrow(cst("NeuralNet"), arrow(type0(), real_ty())),
1572            ),
1573        ),
1574        ("KLDivQ", arrow(cst("NeuralNet"), arrow(type0(), real_ty()))),
1575        ("Real.sub", arrow(real_ty(), arrow(real_ty(), real_ty()))),
1576        (
1577            "DSMObjective",
1578            arrow(app(cst("Measure"), type0()), arrow(real_ty(), type0())),
1579        ),
1580        (
1581            "GaussianSmooth",
1582            arrow(
1583                app(cst("Measure"), type0()),
1584                arrow(real_ty(), app(cst("Measure"), type0())),
1585            ),
1586        ),
1587        ("OptimScore", arrow(type0(), type0())),
1588        (
1589            "ScoreFunction",
1590            arrow(app(cst("Measure"), type0()), type0()),
1591        ),
1592        ("VectorField", type0()),
1593        (
1594            "CondFlowMatchingField",
1595            arrow(
1596                cst("VectorField"),
1597                arrow(
1598                    app(cst("Measure"), type0()),
1599                    arrow(app(cst("Measure"), type0()), prop()),
1600                ),
1601            ),
1602        ),
1603        (
1604            "PushforwardODE",
1605            arrow(
1606                cst("VectorField"),
1607                arrow(
1608                    app(cst("Measure"), type0()),
1609                    arrow(real_ty(), app(cst("Measure"), type0())),
1610                ),
1611            ),
1612        ),
1613        ("GaussianProcess", arrow(type0(), type0())),
1614        ("Observations", type0()),
1615        ("IsGaussianProcess", arrow(type0(), prop())),
1616        (
1617            "GPPosterior",
1618            arrow(
1619                app(cst("GaussianProcess"), type0()),
1620                arrow(cst("Observations"), type0()),
1621            ),
1622        ),
1623        (
1624            "MarginalLikelihood",
1625            arrow(
1626                app(cst("GaussianProcess"), type0()),
1627                arrow(list_ty(type0()), app(cst("Measure"), type0())),
1628            ),
1629        ),
1630        (
1631            "MultivariateGaussian",
1632            arrow(type0(), arrow(type0(), app(cst("Measure"), type0()))),
1633        ),
1634        (
1635            "GPMean",
1636            arrow(
1637                app(cst("GaussianProcess"), type0()),
1638                arrow(list_ty(type0()), type0()),
1639            ),
1640        ),
1641        (
1642            "GPKernelMatrix",
1643            arrow(
1644                app(cst("GaussianProcess"), type0()),
1645                arrow(list_ty(type0()), type0()),
1646            ),
1647        ),
1648        (
1649            "BQPosterior",
1650            arrow(
1651                app(cst("GaussianProcess"), type0()),
1652                arrow(arrow(type0(), real_ty()), app(cst("Measure"), real_ty())),
1653            ),
1654        ),
1655        (
1656            "GaussianMeasureOver",
1657            arrow(type0(), app(cst("Measure"), real_ty())),
1658        ),
1659        ("FeynmanKacModel", arrow(type0(), type0())),
1660        (
1661            "SMCNormConst",
1662            arrow(
1663                app(cst("FeynmanKacModel"), type0()),
1664                arrow(nat_ty(), real_ty()),
1665            ),
1666        ),
1667        (
1668            "FeynmanKacNormConst",
1669            arrow(app(cst("FeynmanKacModel"), type0()), real_ty()),
1670        ),
1671        ("LatentModel", arrow(type0(), arrow(type0(), type0()))),
1672        (
1673            "PMMHKernel",
1674            arrow(
1675                app2(cst("LatentModel"), type0(), type0()),
1676                arrow(type0(), app2(cst("Kernel"), type0(), type0())),
1677            ),
1678        ),
1679        (
1680            "TargetsExactPosterior",
1681            arrow(app2(cst("Kernel"), type0(), type0()), prop()),
1682        ),
1683        ("AnnealingSchedule", type0()),
1684        (
1685            "AISEstimator",
1686            arrow(
1687                app(cst("Measure"), type0()),
1688                arrow(
1689                    app(cst("Measure"), type0()),
1690                    arrow(cst("AnnealingSchedule"), type0()),
1691                ),
1692            ),
1693        ),
1694        ("NormConst", arrow(app(cst("Measure"), type0()), real_ty())),
1695        (
1696            "ImplicitSMObjective",
1697            arrow(app(cst("Measure"), type0()), type0()),
1698        ),
1699        (
1700            "GaussianConvolution",
1701            arrow(
1702                app(cst("Measure"), type0()),
1703                arrow(real_ty(), app(cst("Measure"), type0())),
1704            ),
1705        ),
1706        ("LogConcaveMeasure", arrow(type0(), type0())),
1707        (
1708            "ULADist",
1709            arrow(
1710                app(cst("LogConcaveMeasure"), type0()),
1711                arrow(real_ty(), arrow(nat_ty(), app(cst("Measure"), type0()))),
1712            ),
1713        ),
1714        (
1715            "W2",
1716            arrow(
1717                app(cst("Measure"), type0()),
1718                arrow(app(cst("Measure"), type0()), real_ty()),
1719            ),
1720        ),
1721        (
1722            "LangevinBound",
1723            arrow(
1724                app(cst("LogConcaveMeasure"), type0()),
1725                arrow(real_ty(), arrow(nat_ty(), real_ty())),
1726            ),
1727        ),
1728        (
1729            "MHKernel",
1730            arrow(
1731                app(cst("Measure"), type0()),
1732                arrow(
1733                    app2(cst("Kernel"), type0(), type0()),
1734                    app2(cst("Kernel"), type0(), type0()),
1735                ),
1736            ),
1737        ),
1738        ("Prod", arrow(type0(), arrow(type0(), type0()))),
1739        (
1740            "GibbsKernel",
1741            arrow(
1742                app(cst("Measure"), app2(cst("Prod"), type0(), type0())),
1743                app2(
1744                    cst("Kernel"),
1745                    app2(cst("Prod"), type0(), type0()),
1746                    app2(cst("Prod"), type0(), type0()),
1747                ),
1748            ),
1749        ),
1750        ("ExpressiveDecoder", type0()),
1751        (
1752            "BetaVAE",
1753            arrow(cst("ExpressiveDecoder"), arrow(real_ty(), type0())),
1754        ),
1755        ("AvoidsCollapse", arrow(type0(), prop())),
1756        ("Real.lt", arrow(real_ty(), arrow(real_ty(), prop()))),
1757        ("NaturalParams", arrow(type0(), type0())),
1758        ("LogNormalizer", arrow(type0(), arrow(type0(), real_ty()))),
1759        ("Gradient", arrow(arrow(type0(), real_ty()), type0())),
1760        (
1761            "MeanSuffStat",
1762            arrow(type0(), arrow(app(cst("Measure"), type0()), type0())),
1763        ),
1764        (
1765            "ExpFamilyDist",
1766            arrow(type0(), arrow(type0(), app(cst("Measure"), type0()))),
1767        ),
1768        ("ParticleSystem", arrow(type0(), arrow(nat_ty(), type0()))),
1769        ("AncestralLineage", arrow(type0(), type0())),
1770        ("CoalescentProcess", arrow(type0(), type0())),
1771        ("ResamplingTimes", arrow(type0(), type0())),
1772        ("BandwidthSeq", type0()),
1773        (
1774            "KDEn",
1775            arrow(
1776                app(cst("SmoothMeasure"), type0()),
1777                arrow(cst("BandwidthSeq"), type0()),
1778            ),
1779        ),
1780        ("OptimalBandwidth", arrow(cst("BandwidthSeq"), prop())),
1781        ("L2Convergence", arrow(type0(), prop())),
1782        (
1783            "CAVIStep",
1784            arrow(
1785                app(cst("Measure"), type0()),
1786                arrow(
1787                    list_ty(app(cst("Measure"), type0())),
1788                    list_ty(app(cst("Measure"), type0())),
1789                ),
1790            ),
1791        ),
1792        (
1793            "UpdatedFactors",
1794            arrow(
1795                app(cst("Measure"), type0()),
1796                arrow(
1797                    list_ty(app(cst("Measure"), type0())),
1798                    list_ty(app(cst("Measure"), type0())),
1799                ),
1800            ),
1801        ),
1802        ("BayesianNeuralNet", type0()),
1803        ("Input", type0()),
1804        (
1805            "PBP",
1806            arrow(cst("BayesianNeuralNet"), arrow(cst("Input"), type0())),
1807        ),
1808        ("GaussianApproxActivations", arrow(type0(), type0())),
1809        (
1810            "PBPActivations",
1811            arrow(cst("BayesianNeuralNet"), arrow(cst("Input"), type0())),
1812        ),
1813        ("FactorGraph", type0()),
1814        ("GaussianApprox", type0()),
1815        (
1816            "EPFixedPoint",
1817            arrow(cst("GaussianApprox"), arrow(cst("FactorGraph"), prop())),
1818        ),
1819        (
1820            "CavityAgreement",
1821            arrow(cst("GaussianApprox"), arrow(cst("FactorGraph"), prop())),
1822        ),
1823        (
1824            "NestedMCEstimator",
1825            arrow(
1826                arrow(type0(), real_ty()),
1827                arrow(nat_ty(), arrow(nat_ty(), type0())),
1828            ),
1829        ),
1830        ("Bias", arrow(type0(), real_ty())),
1831        ("NestedMCBiasRate", arrow(nat_ty(), real_ty())),
1832        (
1833            "ABCSMC",
1834            arrow(
1835                app(cst("Measure"), type0()),
1836                arrow(
1837                    arrow(type0(), app(cst("Measure"), type0())),
1838                    arrow(real_ty(), app(cst("Measure"), type0())),
1839                ),
1840            ),
1841        ),
1842        (
1843            "TruePosterior",
1844            arrow(
1845                app(cst("Measure"), type0()),
1846                arrow(
1847                    arrow(type0(), app(cst("Measure"), type0())),
1848                    app(cst("Measure"), type0()),
1849                ),
1850            ),
1851        ),
1852        (
1853            "ApproxPosterior",
1854            arrow(
1855                app(cst("Measure"), type0()),
1856                arrow(real_ty(), arrow(app(cst("Measure"), type0()), prop())),
1857            ),
1858        ),
1859        ("Real.gt", arrow(real_ty(), arrow(real_ty(), prop()))),
1860        ("svi_convergence", svi_convergence_ty()),
1861        ("normalizing_flow_cov", normalizing_flow_cov_ty()),
1862        ("score_fn_unbiased", score_fn_unbiased_ty()),
1863        (
1864            "pathwise_gradient_unbiased",
1865            pathwise_gradient_unbiased_ty(),
1866        ),
1867        ("measure_transport_exists", measure_transport_exists_ty()),
1868        ("ot_kantorovich", ot_kantorovich_ty()),
1869        ("stein_identity", stein_identity_ty()),
1870        ("svgd_convergence", svgd_convergence_ty()),
1871        ("pmc_consistency", pmc_consistency_ty()),
1872        (
1873            "evol_mcmc_detailed_balance",
1874            evol_mcmc_detailed_balance_ty(),
1875        ),
1876        (
1877            "parallel_tempering_exchange",
1878            parallel_tempering_exchange_ty(),
1879        ),
1880        (
1881            "simulated_annealing_convergence",
1882            simulated_annealing_convergence_ty(),
1883        ),
1884        ("vae_elbo_decomp", vae_elbo_decomp_ty()),
1885        ("diffusion_score_matching", diffusion_score_matching_ty()),
1886        ("flow_matching_ode", flow_matching_ode_ty()),
1887        ("gp_posterior_is_gp", gp_posterior_is_gp_ty()),
1888        ("gp_marginal_gaussian", gp_marginal_gaussian_ty()),
1889        ("pn_integration", pn_integration_ty()),
1890        ("stein_disc_zero_iff", stein_disc_zero_iff_ty()),
1891        ("smc_feynman_kac", smc_feynman_kac_ty()),
1892        ("pmmh_correctness", pmmh_correctness_ty()),
1893        ("ais_unbiased", ais_unbiased_ty()),
1894        ("dsm_equals_sm", dsm_equals_sm_ty()),
1895        ("langevin_convergence", langevin_convergence_ty()),
1896        ("mh_detailed_balance", mh_detailed_balance_ty()),
1897        ("gibbs_invariant", gibbs_invariant_ty()),
1898        (
1899            "vae_posterior_collapse_risk",
1900            vae_posterior_collapse_risk_ty(),
1901        ),
1902        ("grad_log_normalizer", grad_log_normalizer_ty()),
1903        ("smc_genealogy", smc_genealogy_ty()),
1904        ("kde_consistency", kde_consistency_ty()),
1905        ("mean_field_cavi", mean_field_cavi_ty()),
1906        ("pbp_gaussian_propagation", pbp_gaussian_propagation_ty()),
1907        ("ep_fixed_point", ep_fixed_point_ty()),
1908        ("nested_mc_bias", nested_mc_bias_ty()),
1909        ("abc_smc_consistency", abc_smc_consistency_ty()),
1910    ];
1911    for (name, ty) in axioms {
1912        env.add(Declaration::Axiom {
1913            name: Name::str(*name),
1914            univ_params: vec![],
1915            ty: ty.clone(),
1916        })
1917        .ok();
1918    }
1919    Ok(())
1920}