Skip to main content

oxilean_std/machine_learning/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use oxilean_kernel::{BinderInfo, Declaration, Environment, Expr, Level, Name};
6
7use super::types::{
8    Activation, AdamOptimizer, DecisionStump, ElasticWeightConsolidation, GradientDescent, KMeans,
9    KnnClassifier, Layer, LinearRegression, MomentumSGD, NeuralNetwork, PACBayesBound,
10    PolynomialRegression, RandomizedSmoothingClassifier, ShapleyExplainer, UncertaintySampler,
11    UncertaintyStrategy,
12};
13
14pub fn app(f: Expr, a: Expr) -> Expr {
15    Expr::App(Box::new(f), Box::new(a))
16}
17pub fn app2(f: Expr, a: Expr, b: Expr) -> Expr {
18    app(app(f, a), b)
19}
20pub fn cst(s: &str) -> Expr {
21    Expr::Const(Name::str(s), vec![])
22}
23pub fn prop() -> Expr {
24    Expr::Sort(Level::zero())
25}
26pub fn type0() -> Expr {
27    Expr::Sort(Level::succ(Level::zero()))
28}
29pub fn pi(bi: BinderInfo, name: &str, dom: Expr, body: Expr) -> Expr {
30    Expr::Pi(bi, Name::str(name), Box::new(dom), Box::new(body))
31}
32pub fn arrow(a: Expr, b: Expr) -> Expr {
33    pi(BinderInfo::Default, "_", a, b)
34}
35pub fn nat_ty() -> Expr {
36    cst("Nat")
37}
38pub fn real_ty() -> Expr {
39    cst("Real")
40}
41#[allow(dead_code)]
42pub fn bool_ty() -> Expr {
43    cst("Bool")
44}
45pub fn list_ty(elem: Expr) -> Expr {
46    app(cst("List"), elem)
47}
48pub fn learner_ty() -> Expr {
49    type0()
50}
51pub fn vc_dimension_ty() -> Expr {
52    arrow(type0(), nat_ty())
53}
54pub fn pac_learnable_ty() -> Expr {
55    arrow(type0(), prop())
56}
57pub fn neural_network_ty() -> Expr {
58    type0()
59}
60pub fn gradient_ty() -> Expr {
61    arrow(arrow(real_ty(), real_ty()), arrow(real_ty(), real_ty()))
62}
63pub fn kernel_method_ty() -> Expr {
64    type0()
65}
66pub fn loss_function_ty() -> Expr {
67    arrow(real_ty(), arrow(real_ty(), real_ty()))
68}
69pub fn regularizer_ty() -> Expr {
70    arrow(list_ty(real_ty()), real_ty())
71}
72pub fn cross_validation_ty() -> Expr {
73    arrow(nat_ty(), arrow(type0(), real_ty()))
74}
75pub fn fundamental_thm_pac_ty() -> Expr {
76    pi(
77        BinderInfo::Default,
78        "H",
79        type0(),
80        app2(
81            cst("Iff"),
82            app(cst("PACLearnable"), cst("H")),
83            app2(
84                cst("Nat.lt"),
85                app(cst("VCDimension"), cst("H")),
86                cst("Nat.infinity"),
87            ),
88        ),
89    )
90}
91pub fn universal_approximation_ty() -> Expr {
92    prop()
93}
94pub fn vc_bound_ty() -> Expr {
95    pi(
96        BinderInfo::Default,
97        "eps",
98        real_ty(),
99        pi(
100            BinderInfo::Default,
101            "delta",
102            real_ty(),
103            arrow(
104                app2(cst("Real.lt"), cst("Real.zero"), cst("eps")),
105                arrow(
106                    app2(cst("Real.lt"), cst("Real.zero"), cst("delta")),
107                    app(
108                        cst("Exists"),
109                        pi(
110                            BinderInfo::Default,
111                            "m",
112                            nat_ty(),
113                            app2(
114                                app(cst("GeneralizationBound"), cst("m")),
115                                cst("eps"),
116                                cst("delta"),
117                            ),
118                        ),
119                    ),
120                ),
121            ),
122        ),
123    )
124}
125pub fn no_free_lunch_ty() -> Expr {
126    prop()
127}
128pub fn bias_variance_tradeoff_ty() -> Expr {
129    prop()
130}
131pub fn regularization_convergence_ty() -> Expr {
132    prop()
133}
134pub fn build_machine_learning_env(env: &mut Environment) -> Result<(), Box<dyn std::error::Error>> {
135    let axioms: &[(&str, Expr)] = &[
136        ("Learner", learner_ty()),
137        ("VCDimension", vc_dimension_ty()),
138        ("PACLearnable", pac_learnable_ty()),
139        ("NeuralNetwork", neural_network_ty()),
140        ("Gradient", gradient_ty()),
141        ("KernelMethod", kernel_method_ty()),
142        ("LossFunction", loss_function_ty()),
143        ("Regularizer", regularizer_ty()),
144        ("CrossValidation", cross_validation_ty()),
145        (
146            "GeneralizationBound",
147            arrow(nat_ty(), arrow(real_ty(), arrow(real_ty(), prop()))),
148        ),
149        ("Real.zero", real_ty()),
150        ("Nat.infinity", nat_ty()),
151        ("fundamental_thm_pac", fundamental_thm_pac_ty()),
152        ("universal_approximation", universal_approximation_ty()),
153        ("vc_bound", vc_bound_ty()),
154        ("no_free_lunch", no_free_lunch_ty()),
155        ("bias_variance_tradeoff", bias_variance_tradeoff_ty()),
156        (
157            "regularization_convergence",
158            regularization_convergence_ty(),
159        ),
160    ];
161    for (name, ty) in axioms {
162        env.add(Declaration::Axiom {
163            name: Name::str(*name),
164            univ_params: vec![],
165            ty: ty.clone(),
166        })
167        .ok();
168    }
169    Ok(())
170}
/// Mean squared error over element pairs; the longer slice's extra
/// elements are ignored. Returns 0.0 for empty input.
pub fn mse_loss(y_pred: &[f64], y_true: &[f64]) -> f64 {
    let n = y_pred.len().min(y_true.len());
    if n == 0 {
        return 0.0;
    }
    let total: f64 = y_pred
        .iter()
        .zip(y_true)
        .map(|(&p, &t)| {
            let d = p - t;
            d * d
        })
        .sum();
    total / n as f64
}
/// Mean absolute error over element pairs; the longer slice's extra
/// elements are ignored. Returns 0.0 for empty input.
pub fn mae_loss(y_pred: &[f64], y_true: &[f64]) -> f64 {
    let n = y_pred.len().min(y_true.len());
    if n == 0 {
        return 0.0;
    }
    let total: f64 = y_pred
        .iter()
        .zip(y_true)
        .map(|(&p, &t)| (p - t).abs())
        .sum();
    total / n as f64
}
/// Mean binary cross-entropy of probabilities `y_pred` against labels
/// `y_true` (expected in [0, 1]). Predictions are clamped away from 0 and 1
/// so `ln()` stays finite. Returns 0.0 for empty input.
pub fn binary_cross_entropy(y_pred: &[f64], y_true: &[f64]) -> f64 {
    const EPS: f64 = 1e-15;
    let n = y_pred.len().min(y_true.len());
    if n == 0 {
        return 0.0;
    }
    let total: f64 = y_pred
        .iter()
        .zip(y_true)
        .map(|(&p, &t)| {
            let q = p.clamp(EPS, 1.0 - EPS);
            -(t * q.ln() + (1.0 - t) * (1.0 - q).ln())
        })
        .sum();
    total / n as f64
}
/// Mean hinge loss `max(0, 1 - t*p)` for labels encoded as ±1.
/// Returns 0.0 for empty input.
pub fn hinge_loss(y_pred: &[f64], y_true: &[f64]) -> f64 {
    let n = y_pred.len().min(y_true.len());
    if n == 0 {
        return 0.0;
    }
    let total: f64 = y_pred
        .iter()
        .zip(y_true)
        .map(|(&p, &t)| (1.0 - t * p).max(0.0))
        .sum();
    total / n as f64
}
/// Mean Huber loss with threshold `delta`: quadratic for errors up to
/// `delta`, linear beyond it (continuous at the knee). Returns 0.0 for
/// empty input.
pub fn huber_loss(y_pred: &[f64], y_true: &[f64], delta: f64) -> f64 {
    let n = y_pred.len().min(y_true.len());
    if n == 0 {
        return 0.0;
    }
    let total: f64 = y_pred
        .iter()
        .zip(y_true)
        .map(|(&p, &t)| {
            let err = (p - t).abs();
            if err <= delta {
                0.5 * err * err
            } else {
                delta * err - 0.5 * delta * delta
            }
        })
        .sum();
    total / n as f64
}
/// Column-wise min–max normalization to [0, 1].
///
/// The column count is taken from the first row. Columns whose range is
/// (near) zero map to 0.0.
///
/// Bug fix: the original normalization pass indexed `mins[j]`/`maxs[j]`
/// without the `j < dim` guard used when collecting statistics, so any row
/// longer than the first row caused an out-of-bounds panic. Extra columns
/// are now truncated consistently in both passes.
pub fn min_max_normalize(data: &[Vec<f64>]) -> Vec<Vec<f64>> {
    if data.is_empty() {
        return vec![];
    }
    let dim = data[0].len();
    let mut mins = vec![f64::INFINITY; dim];
    let mut maxs = vec![f64::NEG_INFINITY; dim];
    for row in data {
        for (j, &v) in row.iter().take(dim).enumerate() {
            mins[j] = mins[j].min(v);
            maxs[j] = maxs[j].max(v);
        }
    }
    data.iter()
        .map(|row| {
            row.iter()
                // Columns beyond the first row's width have no statistics;
                // drop them instead of panicking.
                .take(dim)
                .enumerate()
                .map(|(j, &v)| {
                    let range = maxs[j] - mins[j];
                    if range.abs() < 1e-15 {
                        0.0
                    } else {
                        (v - mins[j]) / range
                    }
                })
                .collect()
        })
        .collect()
}
/// Column-wise z-score normalization `(v - mean) / std` using the
/// population standard deviation (divide by n, not n - 1).
///
/// The column count is taken from the first row; columns with (near) zero
/// standard deviation map to 0.0.
///
/// Bug fix: the original normalization pass indexed `stds[j]`/`means[j]`
/// without the `j < dim` guard used when accumulating statistics, so any
/// row longer than the first row caused an out-of-bounds panic. Extra
/// columns are now truncated consistently in every pass.
pub fn z_score_normalize(data: &[Vec<f64>]) -> Vec<Vec<f64>> {
    if data.is_empty() {
        return vec![];
    }
    let dim = data[0].len();
    let n = data.len() as f64;
    // First pass: per-column means.
    let mut means = vec![0.0f64; dim];
    for row in data {
        for (j, &v) in row.iter().take(dim).enumerate() {
            means[j] += v;
        }
    }
    for m in &mut means {
        *m /= n;
    }
    // Second pass: per-column population standard deviations.
    let mut stds = vec![0.0f64; dim];
    for row in data {
        for (j, &v) in row.iter().take(dim).enumerate() {
            stds[j] += (v - means[j]).powi(2);
        }
    }
    for s in &mut stds {
        *s = (*s / n).sqrt();
    }
    data.iter()
        .map(|row| {
            row.iter()
                // Columns beyond the first row's width have no statistics;
                // drop them instead of panicking.
                .take(dim)
                .enumerate()
                .map(|(j, &v)| {
                    if stds[j].abs() < 1e-15 {
                        0.0
                    } else {
                        (v - means[j]) / stds[j]
                    }
                })
                .collect()
        })
        .collect()
}
316pub fn backprop_mse(
317    net: &NeuralNetwork,
318    input: &[f64],
319    target: &[f64],
320) -> (Vec<Vec<Vec<f64>>>, Vec<Vec<f64>>) {
321    let (activations, z_cache, _) = net.forward_cached(input);
322    let n_layers = net.layers.len();
323    let mut weight_grads: Vec<Vec<Vec<f64>>> = Vec::with_capacity(n_layers);
324    let mut bias_grads: Vec<Vec<f64>> = Vec::with_capacity(n_layers);
325    for layer in &net.layers {
326        weight_grads.push(vec![vec![0.0; layer.n_inputs()]; layer.n_outputs()]);
327        bias_grads.push(vec![0.0; layer.n_outputs()]);
328    }
329    let output = &activations[n_layers];
330    let mut delta: Vec<f64> = output
331        .iter()
332        .zip(target.iter())
333        .enumerate()
334        .map(|(j, (&a, &t))| {
335            let dl_da = 2.0 * (a - t) / target.len() as f64;
336            let da_dz = net.layers[n_layers - 1]
337                .activation
338                .derivative(z_cache[n_layers - 1][j]);
339            dl_da * da_dz
340        })
341        .collect();
342    for l in (0..n_layers).rev() {
343        let input_to_layer = &activations[l];
344        for j in 0..delta.len() {
345            bias_grads[l][j] = delta[j];
346            for k in 0..input_to_layer.len() {
347                weight_grads[l][j][k] = delta[j] * input_to_layer[k];
348            }
349        }
350        if l > 0 {
351            let layer = &net.layers[l];
352            let prev_z = &z_cache[l - 1];
353            let mut new_delta = vec![0.0; prev_z.len()];
354            for k in 0..prev_z.len() {
355                let mut sum = 0.0;
356                for j in 0..delta.len() {
357                    if k < layer.weights[j].len() {
358                        sum += delta[j] * layer.weights[j][k];
359                    }
360                }
361                new_delta[k] = sum * net.layers[l - 1].activation.derivative(prev_z[k]);
362            }
363            delta = new_delta;
364        }
365    }
366    (weight_grads, bias_grads)
367}
368pub fn train_network(
369    net: &mut NeuralNetwork,
370    inputs: &[Vec<f64>],
371    targets: &[Vec<f64>],
372    learning_rate: f64,
373    epochs: u32,
374) -> Vec<f64> {
375    let mut loss_history = Vec::new();
376    let n = inputs.len().min(targets.len());
377    if n == 0 {
378        return loss_history;
379    }
380    for _epoch in 0..epochs {
381        let mut total_loss = 0.0;
382        for i in 0..n {
383            let output = net.forward(&inputs[i]);
384            total_loss += mse_loss(&output, &targets[i]);
385            let (w_grads, b_grads) = backprop_mse(net, &inputs[i], &targets[i]);
386            for (l, layer) in net.layers.iter_mut().enumerate() {
387                for j in 0..layer.weights.len() {
388                    for k in 0..layer.weights[j].len() {
389                        layer.weights[j][k] -= learning_rate * w_grads[l][j][k];
390                    }
391                    layer.biases[j] -= learning_rate * b_grads[l][j];
392                }
393            }
394        }
395        loss_history.push(total_loss / n as f64);
396    }
397    loss_history
398}
/// Fraction of positions where `predicted` agrees with `actual`, compared
/// over the shorter of the two slices. Returns 0.0 when either is empty.
pub fn accuracy(predicted: &[usize], actual: &[usize]) -> f64 {
    let n = predicted.len().min(actual.len());
    if n == 0 {
        return 0.0;
    }
    let hits = predicted
        .iter()
        .zip(actual)
        .filter(|(p, a)| p == a)
        .count();
    hits as f64 / n as f64
}
/// Precision for `class`: true positives / predicted positives.
///
/// Pairs are compared over the shorter slice, while predicted positives are
/// counted over all of `predicted`. Returns 0.0 when nothing was predicted
/// as `class`.
pub fn precision(predicted: &[usize], actual: &[usize], class: usize) -> f64 {
    let true_pos = predicted
        .iter()
        .zip(actual)
        .filter(|&(&p, &a)| p == class && a == class)
        .count();
    let pred_pos = predicted.iter().filter(|&&p| p == class).count();
    if pred_pos == 0 {
        0.0
    } else {
        true_pos as f64 / pred_pos as f64
    }
}
/// Recall for `class`: true positives / actual positives.
///
/// Pairs are compared over the shorter slice, while actual positives are
/// counted over all of `actual`. Returns 0.0 when `class` never occurs in
/// `actual`.
pub fn recall(predicted: &[usize], actual: &[usize], class: usize) -> f64 {
    let true_pos = predicted
        .iter()
        .zip(actual)
        .filter(|&(&p, &a)| p == class && a == class)
        .count();
    let actual_pos = actual.iter().filter(|&&a| a == class).count();
    if actual_pos == 0 {
        0.0
    } else {
        true_pos as f64 / actual_pos as f64
    }
}
437pub fn f1_score(predicted: &[usize], actual: &[usize], class: usize) -> f64 {
438    let p = precision(predicted, actual, class);
439    let r = recall(predicted, actual, class);
440    if p + r < 1e-15 {
441        0.0
442    } else {
443        2.0 * p * r / (p + r)
444    }
445}
/// L2 (ridge) penalty: `lambda * Σ w_i^2`.
pub fn l2_penalty(weights: &[f64], lambda: f64) -> f64 {
    lambda * weights.iter().map(|&w| w * w).sum::<f64>()
}

/// L1 (lasso) penalty: `lambda * Σ |w_i|`.
pub fn l1_penalty(weights: &[f64], lambda: f64) -> f64 {
    lambda * weights.iter().map(|&w| w.abs()).sum::<f64>()
}

/// Elastic-net penalty: convex mix `alpha * L1 + (1 - alpha) * L2`.
pub fn elastic_net_penalty(weights: &[f64], lambda: f64, alpha: f64) -> f64 {
    alpha * l1_penalty(weights, lambda) + (1.0 - alpha) * l2_penalty(weights, lambda)
}
/// Deterministically splits `data` in order: the first `round(len * frac)`
/// elements (clamped to the slice length) form the train set, the rest the
/// test set. No shuffling is performed.
pub fn train_test_split<T: Clone>(data: &[T], frac: f64) -> (Vec<T>, Vec<T>) {
    let cut = (((data.len() as f64) * frac).round() as usize).min(data.len());
    let (train, test) = data.split_at(cut);
    (train.to_vec(), test.to_vec())
}
/// Builds `k` (train, test) index pairs over `0..n` for cross-validation.
///
/// Folds are contiguous ranges of `n / k` indices; the last fold absorbs
/// the remainder. Returns an empty vector when `k == 0` or `n == 0`.
///
/// Performance fix: the original tested membership with
/// `test_indices.contains(idx)` — a linear scan per index, O(n²) overall.
/// Since each test fold is a contiguous range, an O(1) range comparison
/// produces identical output in O(n) per fold.
pub fn k_fold_indices(n: usize, k: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
    if k == 0 || n == 0 {
        return vec![];
    }
    let fold_size = n / k;
    (0..k)
        .map(|i| {
            let test_start = i * fold_size;
            let test_end = if i == k - 1 { n } else { (i + 1) * fold_size };
            let test_indices: Vec<usize> = (test_start..test_end).collect();
            let train_indices: Vec<usize> = (0..n)
                .filter(|&idx| idx < test_start || idx >= test_end)
                .collect();
            (train_indices, test_indices)
        })
        .collect()
}
476pub fn neural_tangent_kernel_ty() -> Expr {
477    arrow(
478        arrow(real_ty(), real_ty()),
479        arrow(arrow(real_ty(), real_ty()), real_ty()),
480    )
481}
482pub fn lazy_training_regime_ty() -> Expr {
483    arrow(nat_ty(), prop())
484}
485pub fn mean_field_limit_ty() -> Expr {
486    arrow(nat_ty(), arrow(real_ty(), prop()))
487}
488pub fn infinite_width_limit_ty() -> Expr {
489    arrow(nat_ty(), prop())
490}
491pub fn loss_landscape_ty() -> Expr {
492    arrow(list_ty(real_ty()), real_ty())
493}
494pub fn autoencoder_ty() -> Expr {
495    arrow(type0(), arrow(nat_ty(), type0()))
496}
497pub fn contrastive_loss_ty() -> Expr {
498    arrow(real_ty(), arrow(real_ty(), arrow(real_ty(), real_ty())))
499}
500pub fn self_supervised_objective_ty() -> Expr {
501    arrow(type0(), prop())
502}
503pub fn representation_collapse_ty() -> Expr {
504    arrow(type0(), prop())
505}
506pub fn disentanglement_ty() -> Expr {
507    arrow(nat_ty(), arrow(type0(), prop()))
508}
509pub fn message_passing_ty() -> Expr {
510    arrow(nat_ty(), arrow(type0(), type0()))
511}
512pub fn wl_expressiveness_ty() -> Expr {
513    arrow(nat_ty(), prop())
514}
515pub fn graph_isomorphism_power_ty() -> Expr {
516    arrow(nat_ty(), prop())
517}
518pub fn over_smoothing_ty() -> Expr {
519    arrow(nat_ty(), arrow(real_ty(), prop()))
520}
521pub fn attention_mechanism_ty() -> Expr {
522    arrow(
523        list_ty(real_ty()),
524        arrow(list_ty(real_ty()), list_ty(real_ty())),
525    )
526}
527pub fn transformer_universality_ty() -> Expr {
528    arrow(nat_ty(), prop())
529}
530pub fn in_context_learning_ty() -> Expr {
531    arrow(nat_ty(), arrow(type0(), prop()))
532}
533pub fn positional_encoding_ty() -> Expr {
534    arrow(nat_ty(), arrow(nat_ty(), list_ty(real_ty())))
535}
536pub fn pac_mdp_ty() -> Expr {
537    arrow(nat_ty(), arrow(real_ty(), arrow(real_ty(), prop())))
538}
539pub fn regret_bound_ty() -> Expr {
540    arrow(nat_ty(), arrow(real_ty(), prop()))
541}
542pub fn sample_complexity_rl_ty() -> Expr {
543    arrow(real_ty(), arrow(real_ty(), nat_ty()))
544}
545pub fn exploration_exploitation_ty() -> Expr {
546    arrow(nat_ty(), prop())
547}
548pub fn vae_elbo_ty() -> Expr {
549    arrow(type0(), arrow(type0(), real_ty()))
550}
551pub fn gan_equilibrium_ty() -> Expr {
552    arrow(type0(), arrow(type0(), prop()))
553}
554pub fn normalizing_flow_ty() -> Expr {
555    arrow(nat_ty(), arrow(type0(), type0()))
556}
557pub fn diffusion_process_ty() -> Expr {
558    arrow(nat_ty(), arrow(real_ty(), arrow(type0(), type0())))
559}
560pub fn score_matching_ty() -> Expr {
561    arrow(type0(), arrow(real_ty(), list_ty(real_ty())))
562}
563pub fn maml_convergence_ty() -> Expr {
564    arrow(nat_ty(), arrow(real_ty(), prop()))
565}
566pub fn few_shot_bound_ty() -> Expr {
567    arrow(nat_ty(), arrow(nat_ty(), arrow(real_ty(), prop())))
568}
569pub fn task_distribution_ty() -> Expr {
570    arrow(type0(), real_ty())
571}
572pub fn catastrophic_forgetting_ty() -> Expr {
573    arrow(nat_ty(), prop())
574}
575pub fn ewc_regularizer_ty() -> Expr {
576    arrow(list_ty(real_ty()), arrow(list_ty(real_ty()), real_ty()))
577}
578pub fn memory_replay_bound_ty() -> Expr {
579    arrow(nat_ty(), arrow(real_ty(), prop()))
580}
581pub fn negative_transfer_ty() -> Expr {
582    arrow(type0(), arrow(type0(), prop()))
583}
584pub fn task_relatedness_ty() -> Expr {
585    arrow(type0(), arrow(type0(), real_ty()))
586}
587pub fn transfer_excess_risk_ty() -> Expr {
588    arrow(type0(), arrow(type0(), arrow(nat_ty(), real_ty())))
589}
590pub fn demographic_parity_ty() -> Expr {
591    arrow(type0(), arrow(type0(), prop()))
592}
593pub fn equalized_odds_ty() -> Expr {
594    arrow(type0(), arrow(type0(), prop()))
595}
596pub fn individual_fairness_ty() -> Expr {
597    arrow(arrow(type0(), real_ty()), prop())
598}
599pub fn fairness_accuracy_tradeoff_ty() -> Expr {
600    prop()
601}
602pub fn shapley_value_ty() -> Expr {
603    arrow(
604        arrow(list_ty(real_ty()), real_ty()),
605        arrow(nat_ty(), real_ty()),
606    )
607}
608pub fn shap_attribution_ty() -> Expr {
609    arrow(type0(), arrow(list_ty(real_ty()), list_ty(real_ty())))
610}
611pub fn counterfactual_explanation_ty() -> Expr {
612    arrow(list_ty(real_ty()), arrow(type0(), list_ty(real_ty())))
613}
614pub fn lp_adversarial_attack_ty() -> Expr {
615    arrow(real_ty(), arrow(list_ty(real_ty()), list_ty(real_ty())))
616}
617pub fn certified_defense_ty() -> Expr {
618    arrow(real_ty(), arrow(type0(), prop()))
619}
620pub fn randomized_smoothing_ty() -> Expr {
621    arrow(real_ty(), arrow(type0(), arrow(real_ty(), prop())))
622}
623pub fn nas_search_space_ty() -> Expr {
624    arrow(nat_ty(), type0())
625}
626pub fn one_shot_nas_ty() -> Expr {
627    arrow(type0(), arrow(nat_ty(), prop()))
628}
629pub fn cell_based_nas_ty() -> Expr {
630    arrow(nat_ty(), arrow(nat_ty(), type0()))
631}
632pub fn query_complexity_ty() -> Expr {
633    arrow(real_ty(), arrow(real_ty(), nat_ty()))
634}
635pub fn uncertainty_sampling_ty() -> Expr {
636    arrow(type0(), arrow(list_ty(real_ty()), real_ty()))
637}
638pub fn optimal_stopping_al_ty() -> Expr {
639    arrow(nat_ty(), arrow(real_ty(), prop()))
640}
641pub fn pac_bayes_mcallester_ty() -> Expr {
642    arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), real_ty())))
643}
644pub fn pac_bayes_catoni_ty() -> Expr {
645    arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), real_ty())))
646}
647pub fn data_dependent_prior_ty() -> Expr {
648    arrow(list_ty(type0()), arrow(type0(), real_ty()))
649}
650pub fn kl_divergence_bound_ty() -> Expr {
651    arrow(type0(), arrow(type0(), arrow(nat_ty(), real_ty())))
652}
653pub fn register_advanced_ml_axioms(env: &mut Environment) {
654    let axioms: &[(&str, Expr)] = &[
655        ("NeuralTangentKernel", neural_tangent_kernel_ty()),
656        ("LazyTrainingRegime", lazy_training_regime_ty()),
657        ("MeanFieldLimit", mean_field_limit_ty()),
658        ("InfiniteWidthLimit", infinite_width_limit_ty()),
659        ("LossLandscape", loss_landscape_ty()),
660        ("Autoencoder", autoencoder_ty()),
661        ("ContrastiveLoss", contrastive_loss_ty()),
662        ("SelfSupervisedObjective", self_supervised_objective_ty()),
663        ("RepresentationCollapse", representation_collapse_ty()),
664        ("Disentanglement", disentanglement_ty()),
665        ("MessagePassing", message_passing_ty()),
666        ("WLExpressiveness", wl_expressiveness_ty()),
667        ("GraphIsomorphismPower", graph_isomorphism_power_ty()),
668        ("OverSmoothing", over_smoothing_ty()),
669        ("AttentionMechanism", attention_mechanism_ty()),
670        ("TransformerUniversality", transformer_universality_ty()),
671        ("InContextLearning", in_context_learning_ty()),
672        ("PositionalEncoding", positional_encoding_ty()),
673        ("PACMDP", pac_mdp_ty()),
674        ("RegretBound", regret_bound_ty()),
675        ("SampleComplexityRL", sample_complexity_rl_ty()),
676        ("ExplorationExploitation", exploration_exploitation_ty()),
677        ("VAELBO", vae_elbo_ty()),
678        ("GANEquilibrium", gan_equilibrium_ty()),
679        ("NormalizingFlow", normalizing_flow_ty()),
680        ("DiffusionProcess", diffusion_process_ty()),
681        ("ScoreMatching", score_matching_ty()),
682        ("MAMLConvergence", maml_convergence_ty()),
683        ("FewShotBound", few_shot_bound_ty()),
684        ("TaskDistribution", task_distribution_ty()),
685        ("CatastrophicForgetting", catastrophic_forgetting_ty()),
686        ("EWCRegularizer", ewc_regularizer_ty()),
687        ("MemoryReplayBound", memory_replay_bound_ty()),
688        ("NegativeTransfer", negative_transfer_ty()),
689        ("TaskRelatedness", task_relatedness_ty()),
690        ("TransferExcessRisk", transfer_excess_risk_ty()),
691        ("DemographicParity", demographic_parity_ty()),
692        ("EqualizedOdds", equalized_odds_ty()),
693        ("IndividualFairness", individual_fairness_ty()),
694        ("FairnessAccuracyTradeoff", fairness_accuracy_tradeoff_ty()),
695        ("ShapleyValue", shapley_value_ty()),
696        ("SHAPAttribution", shap_attribution_ty()),
697        ("CounterfactualExplanation", counterfactual_explanation_ty()),
698        ("LpAdversarialAttack", lp_adversarial_attack_ty()),
699        ("CertifiedDefense", certified_defense_ty()),
700        ("RandomizedSmoothing", randomized_smoothing_ty()),
701        ("NASSearchSpace", nas_search_space_ty()),
702        ("OneShotNAS", one_shot_nas_ty()),
703        ("CellBasedNAS", cell_based_nas_ty()),
704        ("QueryComplexity", query_complexity_ty()),
705        ("UncertaintySampling", uncertainty_sampling_ty()),
706        ("OptimalStoppingAL", optimal_stopping_al_ty()),
707        ("PACBayesMcAllester", pac_bayes_mcallester_ty()),
708        ("PACBayesCatoni", pac_bayes_catoni_ty()),
709        ("DataDependentPrior", data_dependent_prior_ty()),
710        ("KLDivergenceBound", kl_divergence_bound_ty()),
711    ];
712    for (name, ty) in axioms {
713        env.add(Declaration::Axiom {
714            name: Name::str(*name),
715            univ_params: vec![],
716            ty: ty.clone(),
717        })
718        .ok();
719    }
720}
721#[cfg(test)]
722mod tests {
723    use super::*;
724    #[test]
725    fn test_activation_relu() {
726        let a = Activation::ReLU;
727        assert_eq!(a.apply(-1.0), 0.0);
728        assert_eq!(a.apply(2.0), 2.0);
729        assert_eq!(a.apply(0.0), 0.0);
730    }
731    #[test]
732    fn test_activation_sigmoid() {
733        let a = Activation::Sigmoid;
734        assert!((a.apply(0.0) - 0.5).abs() < 1e-10);
735        assert!(a.apply(100.0) > 0.999);
736        assert!(a.apply(-100.0) < 0.001);
737    }
738    #[test]
739    fn test_activation_leaky_relu() {
740        let a = Activation::LeakyReLU;
741        assert_eq!(a.apply(5.0), 5.0);
742        assert!((a.apply(-10.0) - (-0.1)).abs() < 1e-10);
743    }
744    #[test]
745    fn test_activation_elu() {
746        let a = Activation::ELU;
747        assert_eq!(a.apply(3.0), 3.0);
748        assert!(a.apply(-1.0) < 0.0);
749    }
750    #[test]
751    fn test_softmax() {
752        let vals = vec![1.0, 2.0, 3.0];
753        let sm = Activation::apply_softmax(&vals);
754        let sum: f64 = sm.iter().sum();
755        assert!((sum - 1.0).abs() < 1e-10);
756        assert!(sm[2] > sm[1] && sm[1] > sm[0]);
757    }
758    #[test]
759    fn test_layer_forward() {
760        let layer = Layer::from_weights(
761            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
762            vec![0.0, 0.0],
763            Activation::Linear,
764        );
765        let out = layer.forward(&[3.0, 5.0]);
766        assert!((out[0] - 3.0).abs() < 1e-10);
767        assert!((out[1] - 5.0).abs() < 1e-10);
768    }
769    #[test]
770    fn test_layer_with_cache() {
771        let layer = Layer::from_weights(
772            vec![vec![2.0, 0.0], vec![0.0, 3.0]],
773            vec![1.0, -1.0],
774            Activation::ReLU,
775        );
776        let (z, a) = layer.forward_with_cache(&[1.0, 2.0]);
777        assert!((z[0] - 3.0).abs() < 1e-10);
778        assert!((z[1] - 5.0).abs() < 1e-10);
779        assert!((a[0] - 3.0).abs() < 1e-10);
780        assert!((a[1] - 5.0).abs() < 1e-10);
781    }
782    #[test]
783    fn test_neural_network_forward() {
784        let layer = Layer::from_weights(
785            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
786            vec![0.0, 0.0],
787            Activation::Linear,
788        );
789        let net = NeuralNetwork::new(vec![layer]);
790        let out = net.forward(&[1.0, 2.0]);
791        assert!((out[0] - 1.0).abs() < 1e-10);
792        assert!((out[1] - 2.0).abs() < 1e-10);
793        assert_eq!(net.predict_class(&[1.0, 2.0]), 1);
794        assert_eq!(net.depth(), 1);
795    }
796    #[test]
797    fn test_backprop_reduces_loss() {
798        let layer = Layer::from_weights(vec![vec![0.5]], vec![0.0], Activation::Linear);
799        let mut net = NeuralNetwork::new(vec![layer]);
800        let inputs = vec![vec![1.0], vec![2.0], vec![3.0]];
801        let targets = vec![vec![1.0], vec![2.0], vec![3.0]];
802        let history = train_network(&mut net, &inputs, &targets, 0.01, 100);
803        assert!(
804            history.last().expect("last should succeed")
805                < history.first().expect("first should succeed")
806        );
807    }
808    #[test]
809    fn test_loss_functions() {
810        let pred = vec![1.0, 2.0, 3.0];
811        let true_vals = vec![1.0, 2.0, 3.0];
812        assert!(mse_loss(&pred, &true_vals) < 1e-10);
813        assert!(mae_loss(&pred, &true_vals) < 1e-10);
814    }
815    #[test]
816    fn test_binary_cross_entropy() {
817        let pred = vec![0.9, 0.1];
818        let true_vals = vec![1.0, 0.0];
819        let bce = binary_cross_entropy(&pred, &true_vals);
820        let bad_pred = vec![0.1, 0.9];
821        let bad_bce = binary_cross_entropy(&bad_pred, &true_vals);
822        assert!(bad_bce > bce);
823    }
824    #[test]
825    fn test_hinge_loss() {
826        let pred = vec![2.0, -2.0];
827        let true_labels = vec![1.0, -1.0];
828        assert!(hinge_loss(&pred, &true_labels) < 1e-10);
829    }
830    #[test]
831    fn test_huber_loss() {
832        let pred = vec![1.0, 2.0];
833        let true_vals = vec![1.0, 2.0];
834        assert!(huber_loss(&pred, &true_vals, 1.0) < 1e-10);
835    }
836    #[test]
837    fn test_knn_classifier_simple() {
838        let mut knn = KnnClassifier::new(1);
839        knn.fit(vec![
840            (vec![0.0, 0.0], 0),
841            (vec![1.0, 0.0], 1),
842            (vec![0.0, 1.0], 0),
843        ]);
844        assert_eq!(knn.predict(&[0.1, 0.1]), 0);
845        assert_eq!(knn.predict(&[0.9, 0.1]), 1);
846    }
847    #[test]
848    fn test_knn_predict_proba() {
849        let mut knn = KnnClassifier::new(3);
850        knn.fit(vec![(vec![0.0], 0), (vec![0.1], 0), (vec![0.2], 1)]);
851        let proba = knn.predict_proba(&[0.05]);
852        let p0 = proba.get(&0).copied().unwrap_or(0.0);
853        assert!(p0 >= 0.5);
854    }
855    #[test]
856    fn test_decision_stump() {
857        let stump = DecisionStump::new(0, 5.0, 1);
858        assert_eq!(stump.predict(&[6.0]), 1);
859        assert_eq!(stump.predict(&[4.0]), 0);
860    }
861    #[test]
862    fn test_decision_stump_find_best() {
863        let data = vec![
864            (vec![1.0], 0),
865            (vec![2.0], 0),
866            (vec![3.0], 1),
867            (vec![4.0], 1),
868        ];
869        let weights = vec![0.25, 0.25, 0.25, 0.25];
870        let best = DecisionStump::find_best(&data, &weights);
871        let correct: usize = data.iter().filter(|(x, y)| best.predict(x) == *y).count();
872        assert!(correct >= 3);
873    }
874    #[test]
875    fn test_kmeans_fit_2_clusters() {
876        let data = vec![
877            vec![0.0, 0.0],
878            vec![0.1, 0.0],
879            vec![10.0, 10.0],
880            vec![10.1, 10.0],
881        ];
882        let mut km = KMeans::new(2, 100);
883        let assignments = km.fit(&data, 42);
884        assert_eq!(assignments[0], assignments[1]);
885        assert_eq!(assignments[2], assignments[3]);
886        assert_ne!(assignments[0], assignments[2]);
887    }
888    #[test]
889    fn test_kmeans_inertia() {
890        let data = vec![vec![0.0], vec![10.0]];
891        let mut km = KMeans::new(2, 100);
892        km.fit(&data, 0);
893        assert!(km.inertia(&data) < 1.0);
894    }
895    #[test]
896    fn test_linear_regression_fit_with_bias() {
897        let x_data = vec![vec![1.0], vec![2.0], vec![3.0]];
898        let y_data = vec![3.0, 5.0, 7.0];
899        let model = LinearRegression::fit_least_squares(&x_data, &y_data);
900        assert!((model.weights[0] - 2.0).abs() < 1e-8);
901        assert!((model.bias - 1.0).abs() < 1e-8);
902        assert!(model.r_squared(&x_data, &y_data) > 0.999);
903    }
904    #[test]
905    fn test_linear_regression_mse() {
906        let model = LinearRegression {
907            weights: vec![1.0],
908            bias: 0.0,
909        };
910        assert!(model.mse(&[vec![1.0], vec![2.0]], &[1.0, 2.0]) < 1e-10);
911    }
912    #[test]
913    fn test_polynomial_regression() {
914        let x_data: Vec<f64> = (0..10).map(|i| i as f64).collect();
915        let y_data: Vec<f64> = x_data.iter().map(|&x| x * x).collect();
916        let model = PolynomialRegression::fit(&x_data, &y_data, 2, 0.0001, 5000);
917        let pred = model.predict(5.0);
918        assert!((pred - 25.0).abs() < 5.0, "Got {}", pred);
919    }
920    #[test]
921    fn test_gradient_descent_quadratic() {
922        let gd = GradientDescent::new(0.1, 10_000);
923        let x_min = gd.minimize_quadratic(1.0, 0.0, 5.0);
924        assert!(x_min.abs() < 1e-4, "got {x_min}");
925    }
926    #[test]
927    fn test_gradient_descent_numerical() {
928        let gd = GradientDescent::new(0.01, 10_000);
929        let x_min = gd.minimize_numerical(&|x: f64| (x - 3.0).powi(2), 0.0);
930        assert!((x_min - 3.0).abs() < 0.1, "got {x_min}");
931    }
932    #[test]
933    fn test_momentum_sgd() {
934        let opt = MomentumSGD::new(0.01, 0.9, 10_000);
935        let x_min = opt.minimize_quadratic(1.0, 0.0, 10.0);
936        assert!(x_min.abs() < 0.1, "got {x_min}");
937    }
938    #[test]
939    fn test_adam_optimizer() {
940        let opt = AdamOptimizer::new(0.1, 10_000);
941        let x_min = opt.minimize_quadratic(1.0, 0.0, 10.0);
942        assert!(x_min.abs() < 0.1, "got {x_min}");
943    }
944    #[test]
945    fn test_min_max_normalize() {
946        let data = vec![vec![0.0, 10.0], vec![5.0, 20.0], vec![10.0, 30.0]];
947        let normed = min_max_normalize(&data);
948        assert!((normed[0][0]).abs() < 1e-10);
949        assert!((normed[2][0] - 1.0).abs() < 1e-10);
950        assert!((normed[1][0] - 0.5).abs() < 1e-10);
951    }
952    #[test]
953    fn test_z_score_normalize() {
954        let data = vec![vec![1.0], vec![2.0], vec![3.0]];
955        let normed = z_score_normalize(&data);
956        let mean: f64 = normed.iter().map(|r| r[0]).sum::<f64>() / 3.0;
957        assert!(mean.abs() < 1e-10);
958    }
959    #[test]
960    fn test_accuracy_metric() {
961        assert!((accuracy(&[0, 1, 1, 0], &[0, 1, 0, 0]) - 0.75).abs() < 1e-10);
962    }
963    #[test]
964    fn test_precision_recall_f1() {
965        let pred = vec![1, 1, 0, 0, 1];
966        let actual = vec![1, 0, 0, 1, 1];
967        let p = precision(&pred, &actual, 1);
968        let r = recall(&pred, &actual, 1);
969        let _f1 = f1_score(&pred, &actual, 1);
970        assert!((p - 2.0 / 3.0).abs() < 1e-10);
971        assert!((r - 2.0 / 3.0).abs() < 1e-10);
972    }
973    #[test]
974    fn test_regularization() {
975        let w = vec![1.0, 2.0, 3.0];
976        assert!((l2_penalty(&w, 0.1) - 1.4).abs() < 1e-10);
977        assert!((l1_penalty(&w, 0.1) - 0.6).abs() < 1e-10);
978    }
979    #[test]
980    fn test_elastic_net() {
981        let w = vec![1.0, 2.0];
982        assert!((elastic_net_penalty(&w, 1.0, 0.0) - l2_penalty(&w, 1.0)).abs() < 1e-10);
983        assert!((elastic_net_penalty(&w, 1.0, 1.0) - l1_penalty(&w, 1.0)).abs() < 1e-10);
984    }
985    #[test]
986    fn test_train_test_split() {
987        let data: Vec<i32> = (0..10).collect();
988        let (train, test) = train_test_split(&data, 0.8);
989        assert_eq!(train.len(), 8);
990        assert_eq!(test.len(), 2);
991    }
992    #[test]
993    fn test_k_fold_indices() {
994        let folds = k_fold_indices(10, 5);
995        assert_eq!(folds.len(), 5);
996        for (train, test) in &folds {
997            assert_eq!(test.len(), 2);
998            assert_eq!(train.len(), 8);
999        }
1000    }
1001    #[test]
1002    fn test_build_machine_learning_env() {
1003        let mut env = Environment::new();
1004        build_machine_learning_env(&mut env).expect("build_machine_learning_env should succeed");
1005        assert!(!env.is_empty());
1006    }
1007    #[test]
1008    fn test_ewc_penalty_zero_at_theta_star() {
1009        let mut ewc = ElasticWeightConsolidation::new(1.0);
1010        let params = vec![1.0, 2.0, 3.0];
1011        let grads = vec![vec![0.5, 0.5, 0.5], vec![0.5, 0.5, 0.5]];
1012        ewc.consolidate(&params, &grads);
1013        let pen = ewc.penalty(&params);
1014        assert!(pen.abs() < 1e-10, "penalty at theta* = {pen}");
1015    }
1016    #[test]
1017    fn test_ewc_penalty_nonzero_elsewhere() {
1018        let mut ewc = ElasticWeightConsolidation::new(1.0);
1019        let params = vec![1.0, 2.0];
1020        let grads = vec![vec![1.0, 1.0]];
1021        ewc.consolidate(&params, &grads);
1022        let shifted = vec![2.0, 3.0];
1023        let pen = ewc.penalty(&shifted);
1024        assert!(pen > 0.0, "penalty should be positive when shifted");
1025    }
1026    #[test]
1027    fn test_ewc_penalty_gradient_direction() {
1028        let mut ewc = ElasticWeightConsolidation::new(1.0);
1029        let theta_star = vec![0.0, 0.0];
1030        let grads = vec![vec![1.0, 1.0]];
1031        ewc.consolidate(&theta_star, &grads);
1032        let params = vec![1.0, -1.0];
1033        let grad = ewc.penalty_gradient(&params);
1034        assert!(grad[0] > 0.0);
1035        assert!(grad[1] < 0.0);
1036    }
1037    #[test]
1038    fn test_shapley_explainer_constant_model() {
1039        let bg = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
1040        let explainer = ShapleyExplainer::new(2, 200, bg);
1041        let x = vec![3.0, 4.0];
1042        let phi = explainer.explain(&x, &|_inp: &[f64]| 5.0_f64);
1043        for &v in &phi {
1044            assert!(
1045                v.abs() < 1e-10,
1046                "constant model => all Shapley = 0, got {v}"
1047            );
1048        }
1049    }
1050    #[test]
1051    fn test_shapley_explainer_linear_model() {
1052        let bg = vec![vec![0.0, 0.0]];
1053        let explainer = ShapleyExplainer::new(2, 500, bg);
1054        let x = vec![3.0, 5.0];
1055        let phi = explainer.explain(&x, &|inp: &[f64]| inp.iter().sum::<f64>());
1056        assert!((phi[0] - 3.0).abs() < 0.5, "phi[0]={}", phi[0]);
1057        assert!((phi[1] - 5.0).abs() < 0.5, "phi[1]={}", phi[1]);
1058    }
1059    #[test]
1060    fn test_randomized_smoothing_predict() {
1061        let smoother = RandomizedSmoothingClassifier::new(0.1, 200, 0.95);
1062        let base = |x: &[f64]| if x[0] > 0.5 { 1usize } else { 0usize };
1063        let cls = smoother.smooth_predict(&[1.0, 0.0], &base);
1064        assert_eq!(cls, 1);
1065    }
1066    #[test]
1067    fn test_randomized_smoothing_certify() {
1068        let smoother = RandomizedSmoothingClassifier::new(0.1, 500, 0.95);
1069        let base = |x: &[f64]| if x[0] > 0.5 { 1usize } else { 0usize };
1070        let (cls, radius) = smoother.certify(&[1.0, 0.0], &base);
1071        assert_eq!(cls, 1);
1072        assert!(radius >= 0.0);
1073    }
1074    #[test]
1075    fn test_pac_bayes_mcallester() {
1076        let bound_calc = PACBayesBound::new(0.05, 1000);
1077        let bound = bound_calc.mcallester(0.1, 1.0);
1078        assert!(bound > 0.1);
1079        assert!(bound.is_finite());
1080    }
1081    #[test]
1082    fn test_pac_bayes_catoni() {
1083        let bound_calc = PACBayesBound::new(0.05, 1000);
1084        let bound = bound_calc.catoni(0.1, 1.0, 1.0);
1085        assert!(bound.is_finite());
1086        assert!(bound > 0.0);
1087    }
1088    #[test]
1089    fn test_pac_bayes_kl_bernoulli() {
1090        let kl = PACBayesBound::kl_bernoulli(0.3, 0.3);
1091        assert!(kl.abs() < 1e-10);
1092        assert!(PACBayesBound::kl_bernoulli(0.1, 0.9) > 0.0);
1093    }
1094    #[test]
1095    fn test_pac_bayes_kl_gaussians() {
1096        let kl = PACBayesBound::kl_gaussians(1.0, 1.0, 1.0);
1097        assert!(kl.abs() < 1e-10);
1098        assert!(PACBayesBound::kl_gaussians(0.0, 1.0, 1.0) > 0.0);
1099    }
1100    #[test]
1101    fn test_uncertainty_sampler_least_confident() {
1102        let sampler = UncertaintySampler::new(UncertaintyStrategy::LeastConfident);
1103        let certain = vec![0.99, 0.01];
1104        let uncertain = vec![0.5, 0.5];
1105        assert!(sampler.score(&uncertain) > sampler.score(&certain));
1106    }
1107    #[test]
1108    fn test_uncertainty_sampler_margin() {
1109        let sampler = UncertaintySampler::new(UncertaintyStrategy::MarginSampling);
1110        let small_margin = vec![0.51, 0.49];
1111        let large_margin = vec![0.99, 0.01];
1112        assert!(sampler.score(&small_margin) > sampler.score(&large_margin));
1113    }
1114    #[test]
1115    fn test_uncertainty_sampler_entropy() {
1116        let sampler = UncertaintySampler::new(UncertaintyStrategy::Entropy);
1117        let uniform = vec![0.25, 0.25, 0.25, 0.25];
1118        let peaked = vec![0.97, 0.01, 0.01, 0.01];
1119        assert!(sampler.score(&uniform) > sampler.score(&peaked));
1120    }
1121    #[test]
1122    fn test_uncertainty_sampler_select_query() {
1123        let sampler = UncertaintySampler::new(UncertaintyStrategy::Entropy);
1124        let candidates = vec![vec![0.99, 0.01], vec![0.5, 0.5], vec![0.7, 0.3]];
1125        let idx = sampler.select_query(&candidates);
1126        assert_eq!(idx, 1);
1127    }
1128    #[test]
1129    fn test_register_advanced_ml_axioms() {
1130        let mut env = Environment::new();
1131        register_advanced_ml_axioms(&mut env);
1132        assert!(!env.is_empty());
1133    }
1134}