// oxilean_std/statistical_learning/functions.rs
//! Auto-generated module
//!
//! πŸ€– Generated with [SplitRS](https://github.com/cool-japan/splitrs)
5use oxilean_kernel::{BinderInfo, Declaration, Environment, Expr, Level, Name};
6
7use super::types::{
8    AdaBoost, BiasVarianceTradeoff, CausalBackdoor, CrossValidation, DoubleRademacher,
9    EarlyStoppingReg, ExponentialWeightsAlgorithm, FeatureMap, FisherInformation,
10    GaussianComplexity, GaussianProcess, GradientBoosting, GrowthFunction, KernelFunction,
11    KernelMatrix, KernelSVM, KernelSVMTrainer, LassoReg, OnlineGradientDescent,
12    PACBayesGeneralization, PACLearner, Perceptron, RademacherComplexity, RegretBound,
13    SVMClassifier, SampleComplexity, TikhonovReg, UCBBandit, UniformConvergence, VCDimension, ELBO,
14};
15
/// Build the application node `f a`.
pub fn app(f: Expr, a: Expr) -> Expr {
    Expr::App(Box::new(f), Box::new(a))
}
/// Build the left-nested two-argument application `(f a) b`.
pub fn app2(f: Expr, a: Expr, b: Expr) -> Expr {
    app(app(f, a), b)
}
/// Build the left-nested three-argument application `((f a) b) c`.
pub fn app3(f: Expr, a: Expr, b: Expr, c: Expr) -> Expr {
    app(app2(f, a, b), c)
}
/// Reference a constant by name, with no universe-level arguments.
pub fn cst(s: &str) -> Expr {
    Expr::Const(Name::str(s), vec![])
}
/// The sort `Prop` (`Sort 0`).
pub fn prop() -> Expr {
    Expr::Sort(Level::zero())
}
/// The sort `Type` (`Sort (succ 0)`).
pub fn type0() -> Expr {
    Expr::Sort(Level::succ(Level::zero()))
}
/// A dependent function type `Ξ  (name : dom), body` with the given binder style.
pub fn pi(bi: BinderInfo, name: &str, dom: Expr, body: Expr) -> Expr {
    Expr::Pi(bi, Name::str(name), Box::new(dom), Box::new(body))
}
/// A non-dependent arrow `a β†’ b`, encoded as a Pi whose binder name is "_".
pub fn arrow(a: Expr, b: Expr) -> Expr {
    pi(BinderInfo::Default, "_", a, b)
}
/// A bound variable with de Bruijn index `n`.
pub fn bvar(n: u32) -> Expr {
    Expr::BVar(n)
}
/// The constant `Nat`.
pub fn nat_ty() -> Expr {
    cst("Nat")
}
/// The constant `Real`.
pub fn real_ty() -> Expr {
    cst("Real")
}
/// The constant `Bool`.
pub fn bool_ty() -> Expr {
    cst("Bool")
}
/// The type `List elem`.
pub fn list_ty(elem: Expr) -> Expr {
    app(cst("List"), elem)
}
55/// `PACLearner`: a learning algorithm that for any Ξ΅, Ξ΄ > 0 returns h with L_D(h) ≀ Ξ΅
56/// Type: Real β†’ Real β†’ Nat β†’ Type
57pub fn pac_learner_ty() -> Expr {
58    arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), type0())))
59}
60/// `SampleComplexity`: m = O((d log(d/Ξ΅) + log(1/Ξ΄)) / Ξ΅)
61/// Type: Real β†’ Real β†’ Nat β†’ Nat
62pub fn sample_complexity_ty() -> Expr {
63    arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), nat_ty())))
64}
65/// `VCDimension`: maximal shattered set size for a hypothesis class H
66/// Type: Type β†’ Nat
67pub fn vc_dimension_ty() -> Expr {
68    arrow(type0(), nat_ty())
69}
70/// `GrowthFunction`: Π_H(m) = max_{S,|S|=m} |{h|_S : h ∈ H}|
71/// Type: Type β†’ Nat β†’ Nat
72pub fn growth_function_ty() -> Expr {
73    arrow(type0(), arrow(nat_ty(), nat_ty()))
74}
75/// `PACLearnability`: a hypothesis class is PAC learnable
76/// Type: Type β†’ Prop
77pub fn pac_learnability_ty() -> Expr {
78    arrow(type0(), prop())
79}
80/// Fundamental theorem of PAC learning: finite VC dimension ↔ PAC learnability
81/// Type: βˆ€ (H : Type), VCDimension H < ∞ β†’ PACLearnability H
82pub fn fundamental_theorem_pac_ty() -> Expr {
83    pi(BinderInfo::Default, "H", type0(), prop())
84}
85/// Sauer-Shelah lemma: Ξ _H(m) ≀ Ξ£_{i=0}^{d} C(m,i)
86/// Type: βˆ€ (H : Type) (m : Nat), GrowthFunction H m ≀ Sauer_bound (VCDimension H) m
87pub fn sauer_shelah_ty() -> Expr {
88    pi(
89        BinderInfo::Default,
90        "H",
91        type0(),
92        pi(BinderInfo::Default, "m", nat_ty(), prop()),
93    )
94}
95/// Sample complexity upper bound for PAC learning
96/// Type: βˆ€ (Ξ΅ Ξ΄ : Real) (d : Nat), m β‰₯ SampleComplexity Ξ΅ Ξ΄ d β†’ Prop
97pub fn sample_complexity_bound_ty() -> Expr {
98    pi(
99        BinderInfo::Default,
100        "eps",
101        real_ty(),
102        pi(
103            BinderInfo::Default,
104            "delta",
105            real_ty(),
106            pi(BinderInfo::Default, "d", nat_ty(), prop()),
107        ),
108    )
109}
110/// `RademacherComplexity`: R_n(H) = E_{Οƒ,S}[sup_{h∈H} (1/n) Ξ£ Οƒ_i h(x_i)]
111/// Type: Type β†’ Nat β†’ Real
112pub fn rademacher_complexity_ty() -> Expr {
113    arrow(type0(), arrow(nat_ty(), real_ty()))
114}
115/// `UniformConvergence`: sup_{h∈H} |L_D(h) - L_S(h)| ≀ Ξ΅ w.p. β‰₯ 1βˆ’Ξ΄
116/// Type: Type β†’ Real β†’ Real β†’ Prop
117pub fn uniform_convergence_ty() -> Expr {
118    arrow(type0(), arrow(real_ty(), arrow(real_ty(), prop())))
119}
120/// `DoubleRademacher`: two-sided Rademacher bound (two-sided uniform convergence)
121/// Type: Type β†’ Nat β†’ Real β†’ Prop
122pub fn double_rademacher_ty() -> Expr {
123    arrow(type0(), arrow(nat_ty(), arrow(real_ty(), prop())))
124}
125/// `GaussianComplexity`: G_n(H) = E_{g,S}[sup_{h∈H} (1/n) Σ g_i h(x_i)]
126/// Type: Type β†’ Nat β†’ Real
127pub fn gaussian_complexity_ty() -> Expr {
128    arrow(type0(), arrow(nat_ty(), real_ty()))
129}
130/// Rademacher generalization bound: L_D(h) ≀ L_S(h) + 2 R_n(H) + O(√(log(1/Ξ΄)/n))
131/// Type: βˆ€ (H : Type) (n : Nat) (Ξ΄ : Real), Prop
132pub fn rademacher_bound_ty() -> Expr {
133    pi(
134        BinderInfo::Default,
135        "H",
136        type0(),
137        pi(
138            BinderInfo::Default,
139            "n",
140            nat_ty(),
141            pi(BinderInfo::Default, "delta", real_ty(), prop()),
142        ),
143    )
144}
145/// Symmetrization lemma: relates population risk to Rademacher complexity
146/// Type: βˆ€ (H : Type) (n : Nat), Prop
147pub fn symmetrization_lemma_ty() -> Expr {
148    pi(
149        BinderInfo::Default,
150        "H",
151        type0(),
152        pi(BinderInfo::Default, "n", nat_ty(), prop()),
153    )
154}
155/// `KernelFunction`: k: X Γ— X β†’ ℝ, a positive-definite symmetric function
156/// Type: Type β†’ Type (representing a kernel on X)
157pub fn kernel_function_ty() -> Expr {
158    arrow(type0(), type0())
159}
160/// `RKHS`: reproducing kernel Hilbert space H_k associated to kernel k
161/// Type: (Type β†’ Type) β†’ Type
162pub fn rkhs_ty() -> Expr {
163    arrow(arrow(type0(), type0()), type0())
164}
165/// `FeatureMap`: Ο†: X β†’ H_k with k(x,x') = βŸ¨Ο†(x),Ο†(x')⟩
166/// Type: Type β†’ Type β†’ Type
167pub fn feature_map_ty() -> Expr {
168    arrow(type0(), arrow(type0(), type0()))
169}
170/// `KernelMatrix`: Gram matrix K_{ij} = k(x_i, x_j) ∈ ℝ^{nΓ—n}
171/// Type: (Type β†’ Type) β†’ Nat β†’ Type
172pub fn kernel_matrix_ty() -> Expr {
173    arrow(arrow(type0(), type0()), arrow(nat_ty(), type0()))
174}
175/// `KernelSVM`: support vector machine with kernel trick
176/// Type: (Type β†’ Type) β†’ Real β†’ Type
177pub fn kernel_svm_ty() -> Expr {
178    arrow(arrow(type0(), type0()), arrow(real_ty(), type0()))
179}
180/// Mercer's theorem: k is p.d. ↔ βˆƒ feature map Ο† with k(x,x') = βŸ¨Ο†(x),Ο†(x')⟩
181/// Type: βˆ€ (k : Type β†’ Type), isPositiveDefinite k ↔ βˆƒ feature map, Prop
182pub fn mercer_theorem_ty() -> Expr {
183    pi(BinderInfo::Default, "k", arrow(type0(), type0()), prop())
184}
185/// Representer theorem: optimal solution in RKHS lies in span of kernel evaluations
186/// Type: βˆ€ (k : Type β†’ Type) (n : Nat), Prop
187pub fn representer_theorem_ty() -> Expr {
188    pi(
189        BinderInfo::Default,
190        "k",
191        arrow(type0(), type0()),
192        pi(BinderInfo::Default, "n", nat_ty(), prop()),
193    )
194}
195/// Kernel PCA: principal components in feature space via eigendecomposition of K
196/// Type: (Type β†’ Type) β†’ Nat β†’ Nat β†’ Type
197pub fn kernel_pca_ty() -> Expr {
198    arrow(
199        arrow(type0(), type0()),
200        arrow(nat_ty(), arrow(nat_ty(), type0())),
201    )
202}
203/// `RegularizedObjective`: L(h) + Ξ» Ξ©(h)
204/// Type: Real β†’ Type (regularization weight β†’ regularized problem)
205pub fn regularized_objective_ty() -> Expr {
206    arrow(real_ty(), type0())
207}
208/// `TikhonovReg`: Tikhonov regularization Ξ»β€–hβ€–Β²_{H_k}
209/// Type: Real β†’ (Type β†’ Type) β†’ Type
210pub fn tikhonov_reg_ty() -> Expr {
211    arrow(real_ty(), arrow(arrow(type0(), type0()), type0()))
212}
213/// `LassoReg`: ℓ₁ regularization Ξ»β€–h‖₁ (sparsity-promoting)
214/// Type: Real β†’ Type
215pub fn lasso_reg_ty() -> Expr {
216    arrow(real_ty(), type0())
217}
218/// `EarlyStoppingReg`: implicit regularization via iteration count T
219/// Type: Nat β†’ Type
220pub fn early_stopping_reg_ty() -> Expr {
221    arrow(nat_ty(), type0())
222}
223/// `BiasVarianceTradeoff`: MSE = BiasΒ² + Variance + Noise
224/// Type: Real β†’ Real β†’ Real β†’ Prop
225pub fn bias_variance_tradeoff_ty() -> Expr {
226    arrow(real_ty(), arrow(real_ty(), arrow(real_ty(), prop())))
227}
228/// Ridge regression solution: (X^T X + Ξ»I)^{-1} X^T y
229/// Type: βˆ€ (n d : Nat) (Ξ» : Real), List (List Real) β†’ List Real β†’ List Real
230pub fn ridge_regression_solution_ty() -> Expr {
231    pi(
232        BinderInfo::Default,
233        "n",
234        nat_ty(),
235        pi(
236            BinderInfo::Default,
237            "d",
238            nat_ty(),
239            pi(BinderInfo::Default, "lam", real_ty(), prop()),
240        ),
241    )
242}
243/// Bias-variance decomposition: E[(Ε· - y)Β²] = BiasΒ²(Ε·) + Var(Ε·) + σ²
244/// Type: Prop
245pub fn bias_variance_decomposition_ty() -> Expr {
246    prop()
247}
248/// `OnlineAlgorithm`: sequential prediction protocol over T rounds
249/// Type: Nat β†’ Type
250pub fn online_algorithm_ty() -> Expr {
251    arrow(nat_ty(), type0())
252}
253/// `Perceptron`: online linear classifier with mistake bound
254/// Type: Nat β†’ Type (dimension β†’ perceptron)
255pub fn perceptron_ty() -> Expr {
256    arrow(nat_ty(), type0())
257}
258/// `AdaBoost`: adaptive boosting with exponential loss, T weak learners
259/// Type: Nat β†’ Type β†’ Type
260pub fn adaboost_ty() -> Expr {
261    arrow(nat_ty(), arrow(type0(), type0()))
262}
263/// `OnlineGradientDescent`: OGD with regret O(√T)
264/// Type: Real β†’ Nat β†’ Type (learning rate Ξ·, rounds T)
265pub fn online_gradient_descent_ty() -> Expr {
266    arrow(real_ty(), arrow(nat_ty(), type0()))
267}
268/// `RegretBound`: R_T = Ξ£ L_t(h_t) - min_h Ξ£ L_t(h) ≀ O(√T)
269/// Type: Nat β†’ Real β†’ Prop (rounds T, bound Ξ΅)
270pub fn regret_bound_ty() -> Expr {
271    arrow(nat_ty(), arrow(real_ty(), prop()))
272}
273/// Perceptron mistake bound: M ≀ (R/Ξ³)Β² where R = maxβ€–xβ€–, Ξ³ = margin
274/// Type: βˆ€ (R Ξ³ : Real), mistakes ≀ (R/Ξ³)Β²
275pub fn perceptron_mistake_bound_ty() -> Expr {
276    pi(
277        BinderInfo::Default,
278        "R",
279        real_ty(),
280        pi(BinderInfo::Default, "gamma", real_ty(), prop()),
281    )
282}
283/// OGD regret bound: R_T ≀ DΒ²/(2Ξ·) + Ξ· Ξ£β€–βˆ‡_tβ€–Β² ≀ DΒ·G·√(2T)
284/// Type: βˆ€ (T : Nat) (eta : Real), Prop
285pub fn ogd_regret_bound_ty() -> Expr {
286    pi(
287        BinderInfo::Default,
288        "T",
289        nat_ty(),
290        pi(BinderInfo::Default, "eta", real_ty(), prop()),
291    )
292}
293/// AdaBoost training error: ≀ exp(-2 Ξ£ Ξ³_tΒ²) after T rounds
294/// Type: βˆ€ (T : Nat), Prop
295pub fn adaboost_training_error_ty() -> Expr {
296    pi(BinderInfo::Default, "T", nat_ty(), prop())
297}
298/// `MLMutualInformation`: I(X;Y) = H(X) - H(X|Y) for learning-theoretic analysis
299/// Type: (List Real) β†’ (List Real) β†’ Real
300pub fn ml_mutual_information_ty() -> Expr {
301    arrow(list_ty(real_ty()), arrow(list_ty(real_ty()), real_ty()))
302}
303/// `MLKLDivergence`: D_KL(Pβ€–Q) = Ξ£ P log(P/Q) β€” used in PAC-Bayes bounds
304/// Type: (List Real) β†’ (List Real) β†’ Real
305pub fn ml_kl_divergence_ty() -> Expr {
306    arrow(list_ty(real_ty()), arrow(list_ty(real_ty()), real_ty()))
307}
308/// `FisherInformation`: I(ΞΈ) = E[(βˆ‚/βˆ‚ΞΈ log p(x;ΞΈ))Β²]
309/// Type: (Real β†’ Real) β†’ Real β†’ Real
310pub fn fisher_information_ty() -> Expr {
311    arrow(arrow(real_ty(), real_ty()), arrow(real_ty(), real_ty()))
312}
313/// `ELBO`: evidence lower bound β„’(q) = E_q[log p(x,z)] - E_q[log q(z)]
314/// Type: (Real β†’ Real) β†’ (Real β†’ Real) β†’ Real
315pub fn elbo_ty() -> Expr {
316    arrow(
317        arrow(real_ty(), real_ty()),
318        arrow(arrow(real_ty(), real_ty()), real_ty()),
319    )
320}
321/// Data processing inequality: I(X;Z) ≀ I(X;Y) for Markov chain Xβ†’Yβ†’Z
322/// Type: Prop
323pub fn data_processing_inequality_ty() -> Expr {
324    prop()
325}
326/// Chain rule of mutual information: I(X;Y,Z) = I(X;Y) + I(X;Z|Y)
327/// Type: Prop
328pub fn chain_rule_mi_ty() -> Expr {
329    prop()
330}
331/// CramΓ©r-Rao bound: Var(ΞΈΜ‚) β‰₯ 1/I(ΞΈ) for any unbiased estimator
332/// Type: βˆ€ (p : Real β†’ Real) (ΞΈ : Real), Prop
333pub fn cramer_rao_bound_ty() -> Expr {
334    pi(
335        BinderInfo::Default,
336        "p",
337        arrow(real_ty(), real_ty()),
338        pi(BinderInfo::Default, "theta", real_ty(), prop()),
339    )
340}
341/// PAC-Bayes bound: L_D(Q) ≀ L_S(Q) + √((D_KL(Qβ€–P) + log(n/Ξ΄))/(2n))
342/// Type: βˆ€ (n : Nat) (Ξ΄ : Real), Prop
343pub fn pac_bayes_bound_ty() -> Expr {
344    pi(
345        BinderInfo::Default,
346        "n",
347        nat_ty(),
348        pi(BinderInfo::Default, "delta", real_ty(), prop()),
349    )
350}
/// Register all statistical learning theory axioms and theorems in the kernel environment.
///
/// Each `(name, type)` pair below is inserted as an `Axiom` declaration with
/// no universe parameters. Errors from `env.add` are discarded via `.ok()` β€”
/// presumably so re-registering an already-present name stays harmless.
/// NOTE(review): confirm that swallowing `env.add` errors is intended; the
/// `Result<(), String>` return suggests they could be propagated instead.
pub fn build_env(env: &mut Environment) -> Result<(), String> {
    // Declaration name paired with the type built by the matching *_ty() helper above.
    let axioms: &[(&str, Expr)] = &[
        ("PACLearner", pac_learner_ty()),
        ("SampleComplexity", sample_complexity_ty()),
        ("VCDimension", vc_dimension_ty()),
        ("GrowthFunction", growth_function_ty()),
        ("PACLearnability", pac_learnability_ty()),
        ("fundamental_theorem_pac", fundamental_theorem_pac_ty()),
        ("sauer_shelah_lemma", sauer_shelah_ty()),
        ("sample_complexity_bound", sample_complexity_bound_ty()),
        ("RademacherComplexity", rademacher_complexity_ty()),
        ("UniformConvergence", uniform_convergence_ty()),
        ("DoubleRademacher", double_rademacher_ty()),
        ("GaussianComplexity", gaussian_complexity_ty()),
        ("rademacher_bound", rademacher_bound_ty()),
        ("symmetrization_lemma", symmetrization_lemma_ty()),
        ("KernelFunction", kernel_function_ty()),
        ("RKHS", rkhs_ty()),
        ("FeatureMap", feature_map_ty()),
        ("KernelMatrix", kernel_matrix_ty()),
        ("KernelSVM", kernel_svm_ty()),
        ("mercer_theorem", mercer_theorem_ty()),
        ("representer_theorem", representer_theorem_ty()),
        ("KernelPCA", kernel_pca_ty()),
        ("RegularizedObjective", regularized_objective_ty()),
        ("TikhonovReg", tikhonov_reg_ty()),
        ("LassoReg", lasso_reg_ty()),
        ("EarlyStoppingReg", early_stopping_reg_ty()),
        ("BiasVarianceTradeoff", bias_variance_tradeoff_ty()),
        ("ridge_regression_solution", ridge_regression_solution_ty()),
        (
            "bias_variance_decomposition",
            bias_variance_decomposition_ty(),
        ),
        ("OnlineAlgorithm", online_algorithm_ty()),
        ("Perceptron", perceptron_ty()),
        ("AdaBoost", adaboost_ty()),
        ("OnlineGradientDescent", online_gradient_descent_ty()),
        ("RegretBound", regret_bound_ty()),
        ("perceptron_mistake_bound", perceptron_mistake_bound_ty()),
        ("ogd_regret_bound", ogd_regret_bound_ty()),
        ("adaboost_training_error", adaboost_training_error_ty()),
        ("MLMutualInformation", ml_mutual_information_ty()),
        ("MLKLDivergence", ml_kl_divergence_ty()),
        ("FisherInformation", fisher_information_ty()),
        ("ELBO", elbo_ty()),
        (
            "data_processing_inequality",
            data_processing_inequality_ty(),
        ),
        ("chain_rule_mutual_information", chain_rule_mi_ty()),
        ("cramer_rao_bound", cramer_rao_bound_ty()),
        ("pac_bayes_bound", pac_bayes_bound_ty()),
    ];
    for (name, ty) in axioms {
        // Best-effort insert: any error from `env.add` is deliberately ignored.
        env.add(Declaration::Axiom {
            name: Name::str(*name),
            univ_params: vec![],
            ty: ty.clone(),
        })
        .ok();
    }
    Ok(())
}
416/// Dot product of two equal-length slices.
417pub(super) fn dot(a: &[f64], b: &[f64]) -> f64 {
418    a.iter().zip(b.iter()).map(|(ai, bi)| ai * bi).sum()
419}
420/// `ExponentialWeightsAlgorithm`: EWA (Hedge) distribution over n experts
421/// Type: Nat β†’ Real β†’ Type (n experts, learning rate Ξ·)
422pub fn ewa_algorithm_ty() -> Expr {
423    arrow(nat_ty(), arrow(real_ty(), type0()))
424}
425/// `MultiplicativeWeightsUpdate`: MW update step at round t
426/// Type: Nat β†’ Real β†’ Type (n, Ξ·)
427pub fn multiplicative_weights_update_ty() -> Expr {
428    arrow(nat_ty(), arrow(real_ty(), type0()))
429}
430/// `EWARegretBound`: R_T(EWA) ≀ ln(n)/Ξ· + Ξ· T/2
431/// Type: βˆ€ (n : Nat) (T : Nat) (Ξ· : Real), Prop
432pub fn ewa_regret_bound_ty() -> Expr {
433    pi(
434        BinderInfo::Default,
435        "n",
436        nat_ty(),
437        pi(
438            BinderInfo::Default,
439            "T",
440            nat_ty(),
441            pi(BinderInfo::Default, "eta", real_ty(), prop()),
442        ),
443    )
444}
445/// `BanditAlgorithm`: protocol with only loss feedback (no gradient)
446/// Type: Nat β†’ Nat β†’ Type (n arms, T rounds)
447pub fn bandit_algorithm_ty() -> Expr {
448    arrow(nat_ty(), arrow(nat_ty(), type0()))
449}
450/// `UCBAlgorithm`: upper confidence bound algorithm
451/// Type: Nat β†’ Real β†’ Type (n arms, exploration param c)
452pub fn ucb_algorithm_ty() -> Expr {
453    arrow(nat_ty(), arrow(real_ty(), type0()))
454}
455/// `UCBRegretBound`: UCB1 regret O(√(n T ln T))
456/// Type: βˆ€ (n T : Nat), Prop
457pub fn ucb_regret_bound_ty() -> Expr {
458    pi(
459        BinderInfo::Default,
460        "n",
461        nat_ty(),
462        pi(BinderInfo::Default, "T", nat_ty(), prop()),
463    )
464}
465/// `ThompsonSampling`: Bayesian bandit via posterior sampling
466/// Type: Nat β†’ Type (n arms)
467pub fn thompson_sampling_ty() -> Expr {
468    arrow(nat_ty(), type0())
469}
470/// `BayesianRegretBound`: Bayesian regret for Thompson sampling O(√(n T))
471/// Type: βˆ€ (n T : Nat), Prop
472pub fn bayesian_regret_bound_ty() -> Expr {
473    pi(
474        BinderInfo::Default,
475        "n",
476        nat_ty(),
477        pi(BinderInfo::Default, "T", nat_ty(), prop()),
478    )
479}
480/// `DataDependentBound`: a bound that depends on the observed dataset S
481/// Type: Nat β†’ Real β†’ Prop (n samples, bound value)
482pub fn data_dependent_bound_ty() -> Expr {
483    arrow(nat_ty(), arrow(real_ty(), prop()))
484}
485/// `LocalizedRademacher`: local Rademacher complexity around minimizer
486/// Type: Type β†’ Nat β†’ Real β†’ Real (H, n, radius β†’ complexity)
487pub fn localized_rademacher_ty() -> Expr {
488    arrow(type0(), arrow(nat_ty(), arrow(real_ty(), real_ty())))
489}
490/// `LocalizedBound`: generalization bound via localized Rademacher
491/// Type: βˆ€ (H : Type) (n : Nat) (Ξ΄ : Real), Prop
492pub fn localized_bound_ty() -> Expr {
493    pi(
494        BinderInfo::Default,
495        "H",
496        type0(),
497        pi(
498            BinderInfo::Default,
499            "n",
500            nat_ty(),
501            pi(BinderInfo::Default, "delta", real_ty(), prop()),
502        ),
503    )
504}
505/// `PACBayesBound`: McAllester's PAC-Bayes bound
506/// Type: βˆ€ (n : Nat) (Ξ΄ : Real), (List Real) β†’ (List Real) β†’ Prop
507pub fn pac_bayes_mcallester_ty() -> Expr {
508    pi(
509        BinderInfo::Default,
510        "n",
511        nat_ty(),
512        pi(BinderInfo::Default, "delta", real_ty(), prop()),
513    )
514}
515/// `CatoniPACBayes`: Catoni's PAC-Bayes bound with tighter constants
516/// Type: βˆ€ (n : Nat) (Ξ΄ : Real), Prop
517pub fn catoni_pac_bayes_ty() -> Expr {
518    pi(
519        BinderInfo::Default,
520        "n",
521        nat_ty(),
522        pi(BinderInfo::Default, "delta", real_ty(), prop()),
523    )
524}
525/// `RKHSNorm`: β€–fβ€–_{H_k}Β² = Ξ£_{i,j} Ξ±_i Ξ±_j k(x_i, x_j)
526/// Type: (Type β†’ Type) β†’ Real (kernel β†’ normΒ²)
527pub fn rkhs_norm_ty() -> Expr {
528    arrow(arrow(type0(), type0()), real_ty())
529}
530/// `KernelPCAProjection`: projection onto top-k kernel PCA components
531/// Type: (Type β†’ Type) β†’ Nat β†’ Nat β†’ Type (kernel, n, k β†’ projector)
532pub fn kernel_pca_projection_ty() -> Expr {
533    arrow(
534        arrow(type0(), type0()),
535        arrow(nat_ty(), arrow(nat_ty(), type0())),
536    )
537}
538/// `SVMGeneralizationBound`: margin-based bound ≀ RΒ²/(Ξ³Β² n)
539/// Type: βˆ€ (n : Nat) (R Ξ³ : Real), Prop
540pub fn svm_generalization_bound_ty() -> Expr {
541    pi(
542        BinderInfo::Default,
543        "n",
544        nat_ty(),
545        pi(
546            BinderInfo::Default,
547            "R",
548            real_ty(),
549            pi(BinderInfo::Default, "gamma", real_ty(), prop()),
550        ),
551    )
552}
553/// `NeuralNetwork`: a feedforward network of given depth and width
554/// Type: Nat β†’ Nat β†’ Type (depth, width β†’ network)
555pub fn neural_network_ty() -> Expr {
556    arrow(nat_ty(), arrow(nat_ty(), type0()))
557}
558/// `DepthSeparation`: a function requiring exponentially wide shallow net
559/// Type: Nat β†’ Prop (depth d β†’ separation result)
560pub fn depth_separation_ty() -> Expr {
561    arrow(nat_ty(), prop())
562}
563/// `BarronSpace`: functions with finite Barron norm (representable by shallow nets)
564/// Type: Real β†’ Type (radius B β†’ function class)
565pub fn barron_space_ty() -> Expr {
566    arrow(real_ty(), type0())
567}
568/// `BarronApproximation`: shallow nets approximate Barron functions at rate 1/√m
569/// Type: βˆ€ (B : Real) (m : Nat), Prop (B = norm bound, m = neurons)
570pub fn barron_approximation_ty() -> Expr {
571    pi(
572        BinderInfo::Default,
573        "B",
574        real_ty(),
575        pi(BinderInfo::Default, "m", nat_ty(), prop()),
576    )
577}
578/// `NNExpressivity`: VC dimension / capacity of a neural net class
579/// Type: Nat β†’ Nat β†’ Nat (depth, width β†’ VC dim)
580pub fn nn_expressivity_ty() -> Expr {
581    arrow(nat_ty(), arrow(nat_ty(), nat_ty()))
582}
583/// `NNGeneralizationBound`: Rademacher-based bound for neural networks
584/// Type: βˆ€ (depth width n : Nat) (Ξ΄ : Real), Prop
585pub fn nn_generalization_bound_ty() -> Expr {
586    pi(
587        BinderInfo::Default,
588        "depth",
589        nat_ty(),
590        pi(
591            BinderInfo::Default,
592            "width",
593            nat_ty(),
594            pi(
595                BinderInfo::Default,
596                "n",
597                nat_ty(),
598                pi(BinderInfo::Default, "delta", real_ty(), prop()),
599            ),
600        ),
601    )
602}
603/// `DoubleDescent`: test error as function of model complexity
604/// Type: Nat β†’ Real (n_params β†’ test error curve)
605pub fn double_descent_ty() -> Expr {
606    arrow(nat_ty(), real_ty())
607}
608/// `BenignOverfitting`: interpolating solution still generalizes well
609/// Type: βˆ€ (n d : Nat), Prop (n samples, d features)
610pub fn benign_overfitting_ty() -> Expr {
611    pi(
612        BinderInfo::Default,
613        "n",
614        nat_ty(),
615        pi(BinderInfo::Default, "d", nat_ty(), prop()),
616    )
617}
618/// `ImplicitRegularization`: GD with zero init converges to min-norm solution
619/// Type: Nat β†’ Real β†’ Prop (steps T, step size Ξ·)
620pub fn implicit_regularization_ty() -> Expr {
621    pi(
622        BinderInfo::Default,
623        "T",
624        nat_ty(),
625        pi(BinderInfo::Default, "eta", real_ty(), prop()),
626    )
627}
628/// `MinNormInterpolation`: min-β€–wβ€– solution that fits training data exactly
629/// Type: Nat β†’ Nat β†’ Type (n samples, d features β†’ solution)
630pub fn min_norm_interpolation_ty() -> Expr {
631    arrow(nat_ty(), arrow(nat_ty(), type0()))
632}
633/// `UniformStability`: |L(A(S), z) - L(A(S'), z)| ≀ Ξ² for any S, S' differing in 1 point
634/// Type: Real β†’ Prop (Ξ² β†’ stability predicate)
635pub fn uniform_stability_ty() -> Expr {
636    arrow(real_ty(), prop())
637}
638/// `OnAverageStability`: E_S[E_z|L(A(S),z) - L(A(S^{(i)}),z)|] ≀ Ξ²
639/// Type: Real β†’ Prop
640pub fn on_average_stability_ty() -> Expr {
641    arrow(real_ty(), prop())
642}
643/// `StabilityGeneralizationBound`: Ξ²-stable β†’ gen error ≀ 2Ξ² + O(1/√n)
644/// Type: βˆ€ (Ξ² : Real) (n : Nat) (Ξ΄ : Real), Prop
645pub fn stability_generalization_bound_ty() -> Expr {
646    pi(
647        BinderInfo::Default,
648        "beta",
649        real_ty(),
650        pi(
651            BinderInfo::Default,
652            "n",
653            nat_ty(),
654            pi(BinderInfo::Default, "delta", real_ty(), prop()),
655        ),
656    )
657}
658/// `DataDeletion`: privacy-adjacent: model update after removing one point
659/// Type: Nat β†’ Real β†’ Type (n, budget β†’ deletion mechanism)
660pub fn data_deletion_ty() -> Expr {
661    arrow(nat_ty(), arrow(real_ty(), type0()))
662}
663/// `DPSGDAlgorithm`: DP-SGD with noise Οƒ, clipping C, epochs T
664/// Type: Real β†’ Real β†’ Nat β†’ Type (Οƒ, C, T β†’ algorithm)
665pub fn dp_sgd_algorithm_ty() -> Expr {
666    arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), type0())))
667}
668/// `PrivatePACLearning`: PAC learning with (Ξ΅, Ξ΄)-differential privacy
669/// Type: Real β†’ Real β†’ Type (Ξ΅_priv, Ξ΄_priv β†’ learner)
670pub fn private_pac_learning_ty() -> Expr {
671    arrow(real_ty(), arrow(real_ty(), type0()))
672}
673/// `PrivateQueryMechanism`: answering statistical queries with DP
674/// Type: Real β†’ Real β†’ Type (Ξ΅, Ξ΄ β†’ mechanism)
675pub fn private_query_mechanism_ty() -> Expr {
676    arrow(real_ty(), arrow(real_ty(), type0()))
677}
678/// `DPGeneralizationBound`: utility bound for DP learning
679/// Type: βˆ€ (n : Nat) (eps_priv delta_priv : Real), Prop
680pub fn dp_generalization_bound_ty() -> Expr {
681    pi(
682        BinderInfo::Default,
683        "n",
684        nat_ty(),
685        pi(
686            BinderInfo::Default,
687            "eps_priv",
688            real_ty(),
689            pi(BinderInfo::Default, "delta_priv", real_ty(), prop()),
690        ),
691    )
692}
693/// `DPSampleComplexity`: extra samples needed for privacy
694/// Type: Real β†’ Real β†’ Real β†’ Real β†’ Nat (Ξ΅_priv, Ξ΄_priv, Ξ΅_learn, Ξ΄_learn β†’ m)
695pub fn dp_sample_complexity_ty() -> Expr {
696    arrow(
697        real_ty(),
698        arrow(real_ty(), arrow(real_ty(), arrow(real_ty(), nat_ty()))),
699    )
700}
701/// `CalibrationError`: ECE = E[|P(Y=1|score=p) - p|]
702/// Type: (List Real) β†’ (List Real) β†’ Real (predictions, labels β†’ ECE)
703pub fn calibration_error_ty() -> Expr {
704    arrow(list_ty(real_ty()), arrow(list_ty(real_ty()), real_ty()))
705}
706/// `ProperScoringRule`: S(f, y) is proper if E[S(f,Y)] maximized by true dist
707/// Type: (Real β†’ Real) β†’ Real β†’ Real (forecast function, outcome β†’ score)
708pub fn proper_scoring_rule_ty() -> Expr {
709    arrow(arrow(real_ty(), real_ty()), arrow(real_ty(), real_ty()))
710}
711/// `ReliabilityDiagram`: calibration curve P(Y=1 | score ∈ bin) vs. score
712/// Type: Nat β†’ Type (n bins β†’ diagram)
713pub fn reliability_diagram_ty() -> Expr {
714    arrow(nat_ty(), type0())
715}
716/// `SharpnessMeasure`: variance of the forecast distribution
717/// Type: (List Real) β†’ Real (forecasts β†’ sharpness)
718pub fn sharpness_measure_ty() -> Expr {
719    arrow(list_ty(real_ty()), real_ty())
720}
721/// `DoCalculus`: causal interventional distribution P(Y | do(X=x))
722/// Type: Real β†’ Real β†’ Real (x, context β†’ P(Y | do(X=x)))
723pub fn do_calculus_ty() -> Expr {
724    arrow(real_ty(), arrow(real_ty(), real_ty()))
725}
726/// `InterventionalDist`: P(Y | do(X)) in a structural causal model
727/// Type: Type β†’ Type β†’ Type (X space, Y space β†’ dist)
728pub fn interventional_dist_ty() -> Expr {
729    arrow(type0(), arrow(type0(), type0()))
730}
731/// `BackdoorCriterion`: set Z satisfies backdoor criterion for (X, Y)
732/// Type: Type β†’ Type β†’ Type β†’ Prop (X, Y, Z β†’ criterion)
733pub fn backdoor_criterion_ty() -> Expr {
734    arrow(type0(), arrow(type0(), arrow(type0(), prop())))
735}
736/// `BackdoorAdjustment`: P(Y|do(X)) = Ξ£_z P(Y|X,Z=z) P(Z=z)
737/// Type: βˆ€ (X Y Z : Type), BackdoorCriterion X Y Z β†’ Prop
738pub fn backdoor_adjustment_ty() -> Expr {
739    pi(
740        BinderInfo::Default,
741        "X",
742        type0(),
743        pi(
744            BinderInfo::Default,
745            "Y",
746            type0(),
747            pi(BinderInfo::Default, "Z", type0(), prop()),
748        ),
749    )
750}
751/// `ConfoundingBias`: bias from ignoring confounders
752/// Type: Real β†’ Prop (bias magnitude β†’ bounded)
753pub fn confounding_bias_ty() -> Expr {
754    arrow(real_ty(), prop())
755}
756/// `DomainAdaptation`: learning when source and target domains differ
757/// Type: Type β†’ Type β†’ Nat β†’ Type (source, target, n β†’ adapted model)
758pub fn domain_adaptation_ty() -> Expr {
759    arrow(type0(), arrow(type0(), arrow(nat_ty(), type0())))
760}
761/// `CovariateShift`: p_S(x) β‰  p_T(x) but p(y|x) the same
762/// Type: Type β†’ Prop
763pub fn covariate_shift_ty() -> Expr {
764    arrow(type0(), prop())
765}
766/// `ImportanceWeighting`: reweight source samples by p_T(x)/p_S(x)
767/// Type: (Real β†’ Real) β†’ (Real β†’ Real) β†’ Real β†’ Real (p_T, p_S, x β†’ weight)
768pub fn importance_weighting_ty() -> Expr {
769    arrow(
770        arrow(real_ty(), real_ty()),
771        arrow(arrow(real_ty(), real_ty()), arrow(real_ty(), real_ty())),
772    )
773}
774/// `DomainAdaptationBound`: generalization bound under covariate shift
775/// Type: βˆ€ (n : Nat) (delta : Real), Prop
776pub fn domain_adaptation_bound_ty() -> Expr {
777    pi(
778        BinderInfo::Default,
779        "n",
780        nat_ty(),
781        pi(BinderInfo::Default, "delta", real_ty(), prop()),
782    )
783}
784/// `FederatedLearning`: distributed optimization across m clients
785/// Type: Nat β†’ Nat β†’ Type (m clients, T rounds β†’ protocol)
786pub fn federated_learning_ty() -> Expr {
787    arrow(nat_ty(), arrow(nat_ty(), type0()))
788}
789/// `HeterogeneityMeasure`: degree of statistical heterogeneity across clients
790/// Type: Nat β†’ Real β†’ Prop (m clients, Ξ“ measure β†’ bound)
791pub fn heterogeneity_measure_ty() -> Expr {
792    arrow(nat_ty(), arrow(real_ty(), prop()))
793}
794/// `FedAvgConvergence`: FedAvg converges at rate O(1/√(T m))
795/// Type: βˆ€ (T m : Nat), Prop
796pub fn fedavg_convergence_ty() -> Expr {
797    pi(
798        BinderInfo::Default,
799        "T",
800        nat_ty(),
801        pi(BinderInfo::Default, "m", nat_ty(), prop()),
802    )
803}
804/// `ByzantineFaultTolerance`: learning with f Byzantine clients out of m
805/// Type: Nat β†’ Nat β†’ Type (m total, f Byzantine β†’ robust protocol)
806pub fn byzantine_fault_tolerance_ty() -> Expr {
807    arrow(nat_ty(), arrow(nat_ty(), type0()))
808}
809/// `CommunicationComplexity`: total bits communicated to reach Ξ΅ accuracy
810/// Type: Real β†’ Nat β†’ Nat (Ξ΅, m β†’ bits)
811pub fn communication_complexity_ty() -> Expr {
812    arrow(real_ty(), arrow(nat_ty(), nat_ty()))
813}
814/// Register the extended set of statistical learning theory axioms (Β§7–§17).
815pub fn build_env_extended(env: &mut Environment) -> Result<(), String> {
816    build_env(env)?;
817    let axioms: &[(&str, Expr)] = &[
818        ("EWAAlgorithm", ewa_algorithm_ty()),
819        (
820            "MultiplicativeWeightsUpdate",
821            multiplicative_weights_update_ty(),
822        ),
823        ("ewa_regret_bound", ewa_regret_bound_ty()),
824        ("BanditAlgorithm", bandit_algorithm_ty()),
825        ("UCBAlgorithm", ucb_algorithm_ty()),
826        ("ucb_regret_bound", ucb_regret_bound_ty()),
827        ("ThompsonSampling", thompson_sampling_ty()),
828        ("bayesian_regret_bound", bayesian_regret_bound_ty()),
829        ("DataDependentBound", data_dependent_bound_ty()),
830        ("LocalizedRademacher", localized_rademacher_ty()),
831        ("localized_bound", localized_bound_ty()),
832        ("pac_bayes_mcallester", pac_bayes_mcallester_ty()),
833        ("catoni_pac_bayes", catoni_pac_bayes_ty()),
834        ("RKHSNorm", rkhs_norm_ty()),
835        ("KernelPCAProjection", kernel_pca_projection_ty()),
836        ("svm_generalization_bound", svm_generalization_bound_ty()),
837        ("NeuralNetwork", neural_network_ty()),
838        ("depth_separation", depth_separation_ty()),
839        ("BarronSpace", barron_space_ty()),
840        ("barron_approximation", barron_approximation_ty()),
841        ("NNExpressivity", nn_expressivity_ty()),
842        ("nn_generalization_bound", nn_generalization_bound_ty()),
843        ("DoubleDescent", double_descent_ty()),
844        ("benign_overfitting", benign_overfitting_ty()),
845        ("implicit_regularization", implicit_regularization_ty()),
846        ("MinNormInterpolation", min_norm_interpolation_ty()),
847        ("UniformStability", uniform_stability_ty()),
848        ("OnAverageStability", on_average_stability_ty()),
849        (
850            "stability_generalization_bound",
851            stability_generalization_bound_ty(),
852        ),
853        ("DataDeletion", data_deletion_ty()),
854        ("DPSGDAlgorithm", dp_sgd_algorithm_ty()),
855        ("PrivatePACLearning", private_pac_learning_ty()),
856        ("PrivateQueryMechanism", private_query_mechanism_ty()),
857        ("dp_generalization_bound", dp_generalization_bound_ty()),
858        ("DPSampleComplexity", dp_sample_complexity_ty()),
859        ("CalibrationError", calibration_error_ty()),
860        ("ProperScoringRule", proper_scoring_rule_ty()),
861        ("ReliabilityDiagram", reliability_diagram_ty()),
862        ("SharpnessMeasure", sharpness_measure_ty()),
863        ("DoCalculus", do_calculus_ty()),
864        ("InterventionalDist", interventional_dist_ty()),
865        ("BackdoorCriterion", backdoor_criterion_ty()),
866        ("backdoor_adjustment", backdoor_adjustment_ty()),
867        ("ConfoundingBias", confounding_bias_ty()),
868        ("DomainAdaptation", domain_adaptation_ty()),
869        ("CovariateShift", covariate_shift_ty()),
870        ("ImportanceWeighting", importance_weighting_ty()),
871        ("domain_adaptation_bound", domain_adaptation_bound_ty()),
872        ("FederatedLearning", federated_learning_ty()),
873        ("HeterogeneityMeasure", heterogeneity_measure_ty()),
874        ("fedavg_convergence", fedavg_convergence_ty()),
875        ("ByzantineFaultTolerance", byzantine_fault_tolerance_ty()),
876        ("CommunicationComplexity", communication_complexity_ty()),
877    ];
878    for (name, ty) in axioms {
879        env.add(Declaration::Axiom {
880            name: Name::str(*name),
881            univ_params: vec![],
882            ty: ty.clone(),
883        })
884        .ok();
885    }
886    Ok(())
887}
#[cfg(test)]
mod extended_tests {
    //! Unit tests for the extended (Β§7–§17) statistical-learning items:
    //! online learning, bandits, PAC-Bayes, kernel SVM, causal adjustment,
    //! and the extended axiom environment.
    use super::*;
    // After T rounds of a fixed loss vector, the EWA regret bound must be a
    // positive, finite number.
    #[test]
    fn test_ewa_regret_bound() {
        let n = 4;
        let t = 100;
        let eta = ExponentialWeightsAlgorithm::optimal_eta(n, t);
        let mut ewa = ExponentialWeightsAlgorithm::new(n, eta);
        for _ in 0..t {
            ewa.update(&[0.1, 0.2, 0.3, 0.4]);
        }
        let bound = ewa.regret_bound();
        assert!(bound > 0.0, "EWA regret bound must be positive");
        assert!(bound.is_finite(), "EWA bound must be finite");
    }
    // The exponential-weights distribution must stay a probability
    // distribution (sums to 1) after an update.
    #[test]
    fn test_ewa_distribution_sums_to_one() {
        let mut ewa = ExponentialWeightsAlgorithm::new(3, 0.1);
        ewa.update(&[0.5, 0.1, 0.8]);
        let dist = ewa.distribution();
        let sum: f64 = dist.iter().sum();
        assert!((sum - 1.0).abs() < 1e-9, "EWA distribution must sum to 1");
    }
    // UCB must pull every arm once before exploiting: the first three
    // selections are arms 0, 1, 2 in order. The select/update call order is
    // significant here β€” each update feeds the next selection.
    #[test]
    fn test_ucb_bandit_selects_all_arms_initially() {
        let mut bandit = UCBBandit::new(3);
        let arm0 = bandit.select();
        bandit.update(arm0, 0.5);
        let arm1 = bandit.select();
        bandit.update(arm1, 0.8);
        let arm2 = bandit.select();
        bandit.update(arm2, 0.3);
        assert_eq!(arm0, 0);
        assert_eq!(arm1, 1);
        assert_eq!(arm2, 2);
    }
    // After 10 alternating-reward rounds, the UCB regret upper bound is
    // positive and finite.
    #[test]
    fn test_ucb_regret_bound_positive() {
        let mut bandit = UCBBandit::new(2);
        for i in 0..10 {
            let arm = bandit.select();
            bandit.update(arm, if i % 2 == 0 { 1.0 } else { 0.0 });
        }
        let bound = bandit.regret_bound_upper();
        assert!(bound > 0.0 && bound.is_finite());
    }
    // McAllester bound: must dominate the empirical loss (0.1) yet stay
    // below 1 for these (KL, n, Ξ΄) parameters.
    #[test]
    fn test_pac_bayes_mcallester_bound() {
        let pb = PACBayesGeneralization::new(0.5, 1000, 0.05);
        let bound = pb.mcallester_bound(0.1);
        assert!(bound > 0.1, "PAC-Bayes bound must exceed empirical loss");
        assert!(
            bound < 1.0,
            "PAC-Bayes bound must be less than 1 for reasonable params"
        );
    }
    // Catoni bound: the optimized Ξ» lies in (0,1) and yields a finite,
    // positive bound.
    #[test]
    fn test_pac_bayes_catoni_bound() {
        let pb = PACBayesGeneralization::new(0.3, 500, 0.05);
        let lam = pb.optimal_lambda(0.1);
        assert!(lam > 0.0 && lam < 1.0, "optimal lambda must be in (0,1)");
        let bound = pb.catoni_bound(0.1, lam);
        assert!(bound > 0.0 && bound.is_finite());
    }
    // One SMO step on an identity kernel must keep every alpha inside the
    // box constraint [0, C] (with a small numerical slack).
    #[test]
    fn test_kernel_svm_trainer_smo_step() {
        let labels = vec![1.0, -1.0, 1.0, -1.0];
        let mut svm = KernelSVMTrainer::new(4, labels, 1.0);
        let k = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.0, 0.0, 0.0, 1.0],
        ];
        let _updated = svm.smo_step(0, 1, &k);
        for &a in &svm.alphas {
            assert!(a >= 0.0 && a <= svm.c + 1e-9);
        }
    }
    // Margin-based bound RΒ²/(Ξ³Β²n) at R=1, Ξ³=0.5, n=100 is exactly 0.04.
    #[test]
    fn test_kernel_svm_generalization_bound() {
        let bound = KernelSVMTrainer::generalization_bound(1.0, 0.5, 100);
        assert!(
            (bound - 0.04).abs() < 1e-9,
            "RΒ²/(Ξ³Β²n) = 1/(0.25*100) = 0.04"
        );
    }
    // Backdoor adjustment: Ξ£_z P(y|x,z)P(z) = 0.8Β·0.6 + 0.4Β·0.4 = 0.64.
    #[test]
    fn test_causal_backdoor_adjust() {
        let bd = CausalBackdoor::new(vec![0.8, 0.4], vec![0.6, 0.4]);
        assert!(bd.is_valid(), "stratum probs must sum to 1");
        let adjusted = bd.adjust();
        assert!(
            (adjusted - 0.64).abs() < 1e-9,
            "backdoor adjustment must be 0.64"
        );
    }
    // Confounding bias is the gap between the naive estimate (0.75) and the
    // adjusted value (0.64).
    #[test]
    fn test_causal_backdoor_confounding_bias() {
        let bd = CausalBackdoor::new(vec![0.8, 0.4], vec![0.6, 0.4]);
        let bias = bd.confounding_bias(0.75);
        assert!(
            (bias - 0.11).abs() < 1e-9,
            "confounding bias = |0.75 - 0.64| = 0.11"
        );
    }
    // Environment build succeeds and a sample of the extended axiom names
    // resolves in the resulting environment.
    #[test]
    fn test_build_env_extended() {
        let mut env = Environment::new();
        let result = build_env_extended(&mut env);
        assert!(result.is_ok(), "build_env_extended must succeed");
        assert!(env.get(&Name::str("EWAAlgorithm")).is_some());
        assert!(env.get(&Name::str("UCBAlgorithm")).is_some());
        assert!(env.get(&Name::str("ThompsonSampling")).is_some());
        assert!(env.get(&Name::str("NeuralNetwork")).is_some());
        assert!(env.get(&Name::str("BarronSpace")).is_some());
        assert!(env.get(&Name::str("DoubleDescent")).is_some());
        assert!(env.get(&Name::str("ByzantineFaultTolerance")).is_some());
        assert!(env.get(&Name::str("BackdoorCriterion")).is_some());
        assert!(env.get(&Name::str("DPSGDAlgorithm")).is_some());
        assert!(env.get(&Name::str("CalibrationError")).is_some());
    }
}
1012/// Solve a dΓ—d linear system Ax = b using Gaussian elimination with partial pivoting.
1013/// `n` is kept for signature clarity; the key dimension is `d`.
1014pub(super) fn gauss_solve(a: &[Vec<f64>], b: &[f64], d: usize, _n: usize) -> Vec<f64> {
1015    if d == 0 {
1016        return vec![];
1017    }
1018    let mut mat: Vec<Vec<f64>> = (0..d)
1019        .map(|i| {
1020            let mut row = a[i].clone();
1021            row.push(b[i]);
1022            row
1023        })
1024        .collect();
1025    for col in 0..d {
1026        let pivot = (col..d).max_by(|&i, &j| {
1027            mat[i][col]
1028                .abs()
1029                .partial_cmp(&mat[j][col].abs())
1030                .unwrap_or(std::cmp::Ordering::Equal)
1031        });
1032        if let Some(pivot_row) = pivot {
1033            mat.swap(col, pivot_row);
1034        }
1035        let diag = mat[col][col];
1036        if diag.abs() < 1e-12 {
1037            continue;
1038        }
1039        for row in (col + 1)..d {
1040            let factor = mat[row][col] / diag;
1041            for k in col..=d {
1042                let val = mat[col][k] * factor;
1043                mat[row][k] -= val;
1044            }
1045        }
1046    }
1047    let mut x = vec![0.0f64; d];
1048    for i in (0..d).rev() {
1049        let mut sum = mat[i][d];
1050        for j in (i + 1)..d {
1051            sum -= mat[i][j] * x[j];
1052        }
1053        x[i] = if mat[i][i].abs() < 1e-12 {
1054            0.0
1055        } else {
1056            sum / mat[i][i]
1057        };
1058    }
1059    x
1060}
/// Dot product of two slices; trailing elements of the longer slice are
/// ignored (pairing stops at the shorter length).
#[allow(dead_code)]
pub fn dot_ext(a: &[f64], b: &[f64]) -> f64 {
    let mut acc = 0.0;
    for (x, y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
#[cfg(test)]
mod tests_sl_extra {
    //! Smoke tests for the kernel-methods / model-selection helper types.
    use super::*;
    // RBF kernel sanity: k(x,x) equals the signal variance, and widely
    // separated points are nearly uncorrelated.
    #[test]
    fn test_gaussian_process() {
        let gp = GaussianProcess::default_rbf();
        let k = gp.rbf_kernel(0.0, 0.0);
        assert!((k - 1.0).abs() < 1e-9, "k(x,x) = Οƒ^2 = 1.0");
        let k_far = gp.rbf_kernel(0.0, 100.0);
        assert!(k_far < 1e-9, "k(0, 100) should be ~0");
    }
    // An RBF kernel evaluated at identical inputs must equal 1.
    #[test]
    fn test_svm_kernel() {
        let svm = SVMClassifier::rbf(1.0, 1.0);
        let x = vec![1.0, 0.0];
        let xp = vec![1.0, 0.0];
        let k = svm.kernel_value(&x, &xp);
        assert!((k - 1.0).abs() < 1e-9, "RBF(x,x)=1 for Ξ³=1");
    }
    // XGBoost-style configuration: regularization on, and the leaf-count
    // upper bound for this preset is 64.
    #[test]
    fn test_gradient_boosting() {
        let gb = GradientBoosting::xgboost_style(100);
        assert!(gb.is_regularized());
        assert_eq!(gb.n_leaves_upper_bound(), 64);
    }
    // 5-fold CV on 100 samples: 20 per fold, 80 train, 5 splits; LOOCV on
    // 10 samples gives 10 folds.
    #[test]
    fn test_cross_validation() {
        let cv = CrossValidation::k_fold_5(100);
        assert_eq!(cv.fold_size(), 20);
        assert_eq!(cv.train_size(), 80);
        assert_eq!(cv.n_train_test_splits(), 5);
        let loocv = CrossValidation::loocv(10);
        assert_eq!(loocv.n_folds, 10);
    }
}