1use oxilean_kernel::{BinderInfo, Declaration, Environment, Expr, Level, Name};
6
7use super::types::{
8 Activation, AdamOptimizer, DecisionStump, ElasticWeightConsolidation, GradientDescent, KMeans,
9 KnnClassifier, Layer, LinearRegression, MomentumSGD, NeuralNetwork, PACBayesBound,
10 PolynomialRegression, RandomizedSmoothingClassifier, ShapleyExplainer, UncertaintySampler,
11 UncertaintyStrategy,
12};
13
/// Builds a single application node `f a`.
pub fn app(f: Expr, a: Expr) -> Expr {
    Expr::App(Box::new(f), Box::new(a))
}
/// Builds a curried two-argument application `(f a) b`.
pub fn app2(f: Expr, a: Expr, b: Expr) -> Expr {
    app(app(f, a), b)
}
/// Constant reference by name, with no universe arguments.
pub fn cst(s: &str) -> Expr {
    Expr::Const(Name::str(s), vec![])
}
/// `Sort 0`, i.e. `Prop`.
pub fn prop() -> Expr {
    Expr::Sort(Level::zero())
}
/// `Sort 1`, i.e. `Type`.
pub fn type0() -> Expr {
    Expr::Sort(Level::succ(Level::zero()))
}
/// Dependent function type `Pi (name : dom), body`.
pub fn pi(bi: BinderInfo, name: &str, dom: Expr, body: Expr) -> Expr {
    Expr::Pi(bi, Name::str(name), Box::new(dom), Box::new(body))
}
/// Non-dependent arrow `a -> b`, encoded as a `Pi` with an unused binder.
pub fn arrow(a: Expr, b: Expr) -> Expr {
    pi(BinderInfo::Default, "_", a, b)
}
/// The `Nat` constant.
pub fn nat_ty() -> Expr {
    cst("Nat")
}
/// The `Real` constant.
pub fn real_ty() -> Expr {
    cst("Real")
}
/// The `Bool` constant (kept for completeness; currently unused).
#[allow(dead_code)]
pub fn bool_ty() -> Expr {
    cst("Bool")
}
/// `List elem`.
pub fn list_ty(elem: Expr) -> Expr {
    app(cst("List"), elem)
}
/// Learners / hypothesis classes are modelled as an arbitrary `Type`.
pub fn learner_ty() -> Expr {
    type0()
}
/// `VCDimension : Type -> Nat`.
pub fn vc_dimension_ty() -> Expr {
    arrow(type0(), nat_ty())
}
/// `PACLearnable : Type -> Prop`.
pub fn pac_learnable_ty() -> Expr {
    arrow(type0(), prop())
}
/// Neural networks are modelled as an arbitrary `Type`.
pub fn neural_network_ty() -> Expr {
    type0()
}
/// `Gradient : (Real -> Real) -> Real -> Real`.
pub fn gradient_ty() -> Expr {
    arrow(arrow(real_ty(), real_ty()), arrow(real_ty(), real_ty()))
}
/// Kernel methods are modelled as an arbitrary `Type`.
pub fn kernel_method_ty() -> Expr {
    type0()
}
/// `LossFunction : Real -> Real -> Real`.
pub fn loss_function_ty() -> Expr {
    arrow(real_ty(), arrow(real_ty(), real_ty()))
}
/// `Regularizer : List Real -> Real`.
pub fn regularizer_ty() -> Expr {
    arrow(list_ty(real_ty()), real_ty())
}
/// `CrossValidation : Nat -> Type -> Real`.
pub fn cross_validation_ty() -> Expr {
    arrow(nat_ty(), arrow(type0(), real_ty()))
}
/// Statement of the fundamental theorem of PAC learning:
/// `forall H : Type, PACLearnable H <-> VCDimension H < Nat.infinity`.
///
/// NOTE(review): the bound variable is referenced via the free constant
/// `cst("H")` rather than a bound-variable node — confirm this matches how
/// the kernel resolves names under a `Pi` binder.
pub fn fundamental_thm_pac_ty() -> Expr {
    pi(
        BinderInfo::Default,
        "H",
        type0(),
        app2(
            cst("Iff"),
            app(cst("PACLearnable"), cst("H")),
            app2(
                cst("Nat.lt"),
                app(cst("VCDimension"), cst("H")),
                cst("Nat.infinity"),
            ),
        ),
    )
}
/// Universal approximation is stated as an opaque proposition.
pub fn universal_approximation_ty() -> Expr {
    prop()
}
/// Statement of a VC-style generalization bound:
/// `forall eps delta : Real, 0 < eps -> 0 < delta ->
///  Exists (fun m : Nat => GeneralizationBound m eps delta)`.
///
/// NOTE(review): as in `fundamental_thm_pac_ty`, binders are referenced via
/// free constants (`cst("eps")`, `cst("delta")`, `cst("m")`) — confirm
/// against the kernel's binder representation.
pub fn vc_bound_ty() -> Expr {
    pi(
        BinderInfo::Default,
        "eps",
        real_ty(),
        pi(
            BinderInfo::Default,
            "delta",
            real_ty(),
            arrow(
                app2(cst("Real.lt"), cst("Real.zero"), cst("eps")),
                arrow(
                    app2(cst("Real.lt"), cst("Real.zero"), cst("delta")),
                    app(
                        cst("Exists"),
                        // The witness predicate is encoded as a Pi over `m`.
                        pi(
                            BinderInfo::Default,
                            "m",
                            nat_ty(),
                            app2(
                                app(cst("GeneralizationBound"), cst("m")),
                                cst("eps"),
                                cst("delta"),
                            ),
                        ),
                    ),
                ),
            ),
        ),
    )
}
/// The no-free-lunch theorem as an opaque proposition.
pub fn no_free_lunch_ty() -> Expr {
    prop()
}
/// The bias-variance tradeoff as an opaque proposition.
pub fn bias_variance_tradeoff_ty() -> Expr {
    prop()
}
/// Convergence of regularized training as an opaque proposition.
pub fn regularization_convergence_ty() -> Expr {
    prop()
}
/// Registers the core machine-learning vocabulary (type formers plus the
/// theorem statements above) as axioms in `env`. Always returns `Ok(())`.
///
/// Declarations that fail to add (e.g. duplicate names on a repeated call)
/// are skipped silently via `.ok()` — the same best-effort convention used
/// by `register_advanced_ml_axioms`.
///
/// NOTE(review): some constants referenced by the theorem statements
/// (`Iff`, `Nat.lt`, `Real.lt`, `Exists`, `Real`) are not registered here —
/// presumably they come from a prelude; confirm.
pub fn build_machine_learning_env(env: &mut Environment) -> Result<(), Box<dyn std::error::Error>> {
    // (constant name, type) pairs; each becomes an axiom with no universe
    // parameters.
    let axioms: &[(&str, Expr)] = &[
        ("Learner", learner_ty()),
        ("VCDimension", vc_dimension_ty()),
        ("PACLearnable", pac_learnable_ty()),
        ("NeuralNetwork", neural_network_ty()),
        ("Gradient", gradient_ty()),
        ("KernelMethod", kernel_method_ty()),
        ("LossFunction", loss_function_ty()),
        ("Regularizer", regularizer_ty()),
        ("CrossValidation", cross_validation_ty()),
        (
            "GeneralizationBound",
            arrow(nat_ty(), arrow(real_ty(), arrow(real_ty(), prop()))),
        ),
        ("Real.zero", real_ty()),
        ("Nat.infinity", nat_ty()),
        ("fundamental_thm_pac", fundamental_thm_pac_ty()),
        ("universal_approximation", universal_approximation_ty()),
        ("vc_bound", vc_bound_ty()),
        ("no_free_lunch", no_free_lunch_ty()),
        ("bias_variance_tradeoff", bias_variance_tradeoff_ty()),
        (
            "regularization_convergence",
            regularization_convergence_ty(),
        ),
    ];
    for (name, ty) in axioms {
        env.add(Declaration::Axiom {
            name: Name::str(*name),
            univ_params: vec![],
            ty: ty.clone(),
        })
        .ok();
    }
    Ok(())
}
/// Mean squared error averaged over the paired entries of the two slices.
/// Extra entries in the longer slice are ignored; empty input yields 0.0.
pub fn mse_loss(y_pred: &[f64], y_true: &[f64]) -> f64 {
    let mut count = 0usize;
    let mut total = 0.0f64;
    for (p, t) in y_pred.iter().zip(y_true) {
        let diff = p - t;
        total += diff * diff;
        count += 1;
    }
    if count == 0 {
        0.0
    } else {
        total / count as f64
    }
}
/// Mean absolute error averaged over the paired entries of the two slices.
/// Extra entries in the longer slice are ignored; empty input yields 0.0.
pub fn mae_loss(y_pred: &[f64], y_true: &[f64]) -> f64 {
    let n = y_pred.len().min(y_true.len());
    match n {
        0 => 0.0,
        _ => {
            let total = y_pred
                .iter()
                .zip(y_true)
                .fold(0.0f64, |acc, (p, t)| acc + (p - t).abs());
            total / n as f64
        }
    }
}
/// Binary cross-entropy averaged over paired entries, with predictions
/// clamped to `[EPS, 1 - EPS]` to avoid `ln(0)`. Empty input yields 0.0.
pub fn binary_cross_entropy(y_pred: &[f64], y_true: &[f64]) -> f64 {
    const EPS: f64 = 1e-15;
    let n = y_pred.len().min(y_true.len());
    if n == 0 {
        return 0.0;
    }
    let mut total = 0.0;
    for (&p, &t) in y_pred.iter().zip(y_true) {
        let q = p.clamp(EPS, 1.0 - EPS);
        // Accumulate -[t ln q + (1-t) ln(1-q)].
        total -= t * q.ln() + (1.0 - t) * (1.0 - q).ln();
    }
    total / n as f64
}
/// Average hinge loss `max(0, 1 - t*p)` over paired entries, for labels in
/// {-1, +1}. Empty input yields 0.0.
pub fn hinge_loss(y_pred: &[f64], y_true: &[f64]) -> f64 {
    let n = y_pred.len().min(y_true.len());
    if n == 0 {
        return 0.0;
    }
    let mut total = 0.0;
    for (&p, &t) in y_pred.iter().zip(y_true) {
        let margin = 1.0 - t * p;
        if margin > 0.0 {
            total += margin;
        }
    }
    total / n as f64
}
/// Huber loss averaged over paired entries: quadratic (`0.5 e^2`) for
/// errors within `delta`, linear (`delta*|e| - 0.5*delta^2`) beyond it.
/// Empty input yields 0.0.
pub fn huber_loss(y_pred: &[f64], y_true: &[f64], delta: f64) -> f64 {
    let n = y_pred.len().min(y_true.len());
    if n == 0 {
        return 0.0;
    }
    let mut total = 0.0;
    for (&p, &t) in y_pred.iter().zip(y_true) {
        let abs_err = (p - t).abs();
        total += if abs_err <= delta {
            0.5 * abs_err * abs_err
        } else {
            delta * abs_err - 0.5 * delta * delta
        };
    }
    total / n as f64
}
/// Column-wise min-max normalization to `[0, 1]`.
///
/// Column statistics are computed over the first `data[0].len()` columns.
/// Constant columns map to 0.0. For ragged input, entries in columns beyond
/// the first row's width have no statistics and are passed through
/// unchanged (previously this indexed `mins[j]`/`maxs[j]` out of bounds
/// and panicked). Empty input yields an empty vector.
pub fn min_max_normalize(data: &[Vec<f64>]) -> Vec<Vec<f64>> {
    if data.is_empty() {
        return vec![];
    }
    let dim = data[0].len();
    let mut mins = vec![f64::INFINITY; dim];
    let mut maxs = vec![f64::NEG_INFINITY; dim];
    for row in data {
        for (j, &v) in row.iter().enumerate().take(dim) {
            mins[j] = mins[j].min(v);
            maxs[j] = maxs[j].max(v);
        }
    }
    data.iter()
        .map(|row| {
            row.iter()
                .enumerate()
                .map(|(j, &v)| {
                    if j >= dim {
                        // No statistics for this trailing column; pass through.
                        return v;
                    }
                    let range = maxs[j] - mins[j];
                    if range.abs() < 1e-15 {
                        0.0
                    } else {
                        (v - mins[j]) / range
                    }
                })
                .collect()
        })
        .collect()
}
/// Column-wise z-score normalization: `(v - mean) / std` with the
/// population standard deviation (divisor `n`).
///
/// Statistics are computed over the first `data[0].len()` columns.
/// Zero-variance columns map to 0.0. For ragged input, entries in columns
/// beyond the first row's width have no statistics and are passed through
/// unchanged (previously this indexed `stds[j]` out of bounds and
/// panicked). Empty input yields an empty vector.
pub fn z_score_normalize(data: &[Vec<f64>]) -> Vec<Vec<f64>> {
    if data.is_empty() {
        return vec![];
    }
    let dim = data[0].len();
    let n = data.len() as f64;
    // Per-column means.
    let mut means = vec![0.0f64; dim];
    for row in data {
        for (j, &v) in row.iter().enumerate().take(dim) {
            means[j] += v;
        }
    }
    for m in &mut means {
        *m /= n;
    }
    // Per-column population standard deviations.
    let mut stds = vec![0.0f64; dim];
    for row in data {
        for (j, &v) in row.iter().enumerate().take(dim) {
            stds[j] += (v - means[j]).powi(2);
        }
    }
    for s in &mut stds {
        *s = (*s / n).sqrt();
    }
    data.iter()
        .map(|row| {
            row.iter()
                .enumerate()
                .map(|(j, &v)| {
                    if j >= dim {
                        // No statistics for this trailing column; pass through.
                        return v;
                    }
                    if stds[j].abs() < 1e-15 {
                        0.0
                    } else {
                        (v - means[j]) / stds[j]
                    }
                })
                .collect()
        })
        .collect()
}
/// Backpropagation of a mean-squared-error loss through `net` for a single
/// `(input, target)` sample.
///
/// Returns `(weight_grads, bias_grads)` shaped exactly like the network's
/// weight matrices and bias vectors.
///
/// NOTE(review): the output delta divides by `target.len()` (average over
/// output units) — confirm this matches the normalization used by
/// `mse_loss`, which divides by `min(pred.len(), true.len())`.
pub fn backprop_mse(
    net: &NeuralNetwork,
    input: &[f64],
    target: &[f64],
) -> (Vec<Vec<Vec<f64>>>, Vec<Vec<f64>>) {
    // Forward pass keeping per-layer pre-activations (z) and activations.
    let (activations, z_cache, _) = net.forward_cached(input);
    let n_layers = net.layers.len();
    // Zeroed gradient buffers shaped like the network parameters.
    let mut weight_grads: Vec<Vec<Vec<f64>>> = Vec::with_capacity(n_layers);
    let mut bias_grads: Vec<Vec<f64>> = Vec::with_capacity(n_layers);
    for layer in &net.layers {
        weight_grads.push(vec![vec![0.0; layer.n_inputs()]; layer.n_outputs()]);
        bias_grads.push(vec![0.0; layer.n_outputs()]);
    }
    // Output-layer delta: dL/dz = 2(a - t)/len * activation'(z).
    let output = &activations[n_layers];
    let mut delta: Vec<f64> = output
        .iter()
        .zip(target.iter())
        .enumerate()
        .map(|(j, (&a, &t))| {
            let dl_da = 2.0 * (a - t) / target.len() as f64;
            let da_dz = net.layers[n_layers - 1]
                .activation
                .derivative(z_cache[n_layers - 1][j]);
            dl_da * da_dz
        })
        .collect();
    // Walk layers backwards: fill this layer's gradients from delta, then
    // propagate delta through the layer's weights to the previous layer.
    for l in (0..n_layers).rev() {
        let input_to_layer = &activations[l];
        for j in 0..delta.len() {
            // dL/db_j = delta_j; dL/dW_jk = delta_j * a_k (layer input).
            bias_grads[l][j] = delta[j];
            for k in 0..input_to_layer.len() {
                weight_grads[l][j][k] = delta[j] * input_to_layer[k];
            }
        }
        if l > 0 {
            // delta_prev[k] = (sum_j delta_j * W_jk) * activation'(z_prev[k]).
            let layer = &net.layers[l];
            let prev_z = &z_cache[l - 1];
            let mut new_delta = vec![0.0; prev_z.len()];
            for k in 0..prev_z.len() {
                let mut sum = 0.0;
                for j in 0..delta.len() {
                    // Guard against ragged weight rows.
                    if k < layer.weights[j].len() {
                        sum += delta[j] * layer.weights[j][k];
                    }
                }
                new_delta[k] = sum * net.layers[l - 1].activation.derivative(prev_z[k]);
            }
            delta = new_delta;
        }
    }
    (weight_grads, bias_grads)
}
368pub fn train_network(
369 net: &mut NeuralNetwork,
370 inputs: &[Vec<f64>],
371 targets: &[Vec<f64>],
372 learning_rate: f64,
373 epochs: u32,
374) -> Vec<f64> {
375 let mut loss_history = Vec::new();
376 let n = inputs.len().min(targets.len());
377 if n == 0 {
378 return loss_history;
379 }
380 for _epoch in 0..epochs {
381 let mut total_loss = 0.0;
382 for i in 0..n {
383 let output = net.forward(&inputs[i]);
384 total_loss += mse_loss(&output, &targets[i]);
385 let (w_grads, b_grads) = backprop_mse(net, &inputs[i], &targets[i]);
386 for (l, layer) in net.layers.iter_mut().enumerate() {
387 for j in 0..layer.weights.len() {
388 for k in 0..layer.weights[j].len() {
389 layer.weights[j][k] -= learning_rate * w_grads[l][j][k];
390 }
391 layer.biases[j] -= learning_rate * b_grads[l][j];
392 }
393 }
394 }
395 loss_history.push(total_loss / n as f64);
396 }
397 loss_history
398}
/// Fraction of paired positions where prediction equals label.
/// Extra entries in the longer slice are ignored; empty input yields 0.0.
pub fn accuracy(predicted: &[usize], actual: &[usize]) -> f64 {
    let total = predicted.len().min(actual.len());
    if total == 0 {
        return 0.0;
    }
    let mut correct = 0usize;
    for (p, a) in predicted.iter().zip(actual) {
        if p == a {
            correct += 1;
        }
    }
    correct as f64 / total as f64
}
/// Precision for `class`: true positives over all positive predictions.
/// Counts positive predictions over the full `predicted` slice; returns
/// 0.0 when the class is never predicted.
pub fn precision(predicted: &[usize], actual: &[usize], class: usize) -> f64 {
    let true_pos = predicted
        .iter()
        .zip(actual)
        .filter(|&(&p, &a)| p == class && a == class)
        .count();
    let predicted_pos = predicted.iter().filter(|&&p| p == class).count();
    match predicted_pos {
        0 => 0.0,
        d => true_pos as f64 / d as f64,
    }
}
/// Recall for `class`: true positives over all actual positives.
/// Counts actual positives over the full `actual` slice; returns 0.0 when
/// the class never occurs in `actual`.
pub fn recall(predicted: &[usize], actual: &[usize], class: usize) -> f64 {
    let true_pos = predicted
        .iter()
        .zip(actual)
        .filter(|&(&p, &a)| p == class && a == class)
        .count();
    let actual_pos = actual.iter().filter(|&&a| a == class).count();
    match actual_pos {
        0 => 0.0,
        d => true_pos as f64 / d as f64,
    }
}
437pub fn f1_score(predicted: &[usize], actual: &[usize], class: usize) -> f64 {
438 let p = precision(predicted, actual, class);
439 let r = recall(predicted, actual, class);
440 if p + r < 1e-15 {
441 0.0
442 } else {
443 2.0 * p * r / (p + r)
444 }
445}
/// Scaled L2 (ridge) penalty: `lambda * sum(w_i^2)`.
pub fn l2_penalty(weights: &[f64], lambda: f64) -> f64 {
    let sum_of_squares = weights.iter().fold(0.0, |acc, &w| acc + w * w);
    lambda * sum_of_squares
}
/// Scaled L1 (lasso) penalty: `lambda * sum(|w_i|)`.
pub fn l1_penalty(weights: &[f64], lambda: f64) -> f64 {
    let sum_abs = weights.iter().fold(0.0, |acc, &w| acc + w.abs());
    lambda * sum_abs
}
452pub fn elastic_net_penalty(weights: &[f64], lambda: f64, alpha: f64) -> f64 {
453 alpha * l1_penalty(weights, lambda) + (1.0 - alpha) * l2_penalty(weights, lambda)
454}
/// Splits `data` at `round(len * frac)` (clamped to the slice length) into
/// owned `(train, test)` vectors; no shuffling is performed.
pub fn train_test_split<T: Clone>(data: &[T], frac: f64) -> (Vec<T>, Vec<T>) {
    let cut = (((data.len() as f64) * frac).round() as usize).min(data.len());
    let (head, tail) = data.split_at(cut);
    (head.to_vec(), tail.to_vec())
}
/// Splits `0..n` into `k` cross-validation folds as `(train, test)` index
/// vectors. Folds are contiguous; the last fold absorbs the remainder when
/// `k` does not divide `n`. Returns an empty vector when `k == 0` or
/// `n == 0`.
pub fn k_fold_indices(n: usize, k: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
    if k == 0 || n == 0 {
        return vec![];
    }
    let fold_size = n / k;
    (0..k)
        .map(|i| {
            let test_start = i * fold_size;
            let test_end = if i == k - 1 { n } else { (i + 1) * fold_size };
            let test_indices: Vec<usize> = (test_start..test_end).collect();
            // The test block is contiguous, so the train set is just the two
            // surrounding ranges — O(n) per fold instead of the previous
            // O(n^2) `Vec::contains` scan per index.
            let train_indices: Vec<usize> = (0..test_start).chain(test_end..n).collect();
            (train_indices, test_indices)
        })
        .collect()
}
// -------------------------------------------------------------------------
// Type signatures for the advanced ML axiom vocabulary registered by
// `register_advanced_ml_axioms`. Each function builds the `Expr` type of one
// axiom constant; the constant names are supplied at registration time.
// -------------------------------------------------------------------------

// --- Deep learning theory ---
/// `NeuralTangentKernel : (Real -> Real) -> (Real -> Real) -> Real`.
pub fn neural_tangent_kernel_ty() -> Expr {
    arrow(
        arrow(real_ty(), real_ty()),
        arrow(arrow(real_ty(), real_ty()), real_ty()),
    )
}
/// `LazyTrainingRegime : Nat -> Prop`.
pub fn lazy_training_regime_ty() -> Expr {
    arrow(nat_ty(), prop())
}
/// `MeanFieldLimit : Nat -> Real -> Prop`.
pub fn mean_field_limit_ty() -> Expr {
    arrow(nat_ty(), arrow(real_ty(), prop()))
}
/// `InfiniteWidthLimit : Nat -> Prop`.
pub fn infinite_width_limit_ty() -> Expr {
    arrow(nat_ty(), prop())
}
/// `LossLandscape : List Real -> Real`.
pub fn loss_landscape_ty() -> Expr {
    arrow(list_ty(real_ty()), real_ty())
}

// --- Representation / self-supervised learning ---
/// `Autoencoder : Type -> Nat -> Type`.
pub fn autoencoder_ty() -> Expr {
    arrow(type0(), arrow(nat_ty(), type0()))
}
/// `ContrastiveLoss : Real -> Real -> Real -> Real`.
pub fn contrastive_loss_ty() -> Expr {
    arrow(real_ty(), arrow(real_ty(), arrow(real_ty(), real_ty())))
}
/// `SelfSupervisedObjective : Type -> Prop`.
pub fn self_supervised_objective_ty() -> Expr {
    arrow(type0(), prop())
}
/// `RepresentationCollapse : Type -> Prop`.
pub fn representation_collapse_ty() -> Expr {
    arrow(type0(), prop())
}
/// `Disentanglement : Nat -> Type -> Prop`.
pub fn disentanglement_ty() -> Expr {
    arrow(nat_ty(), arrow(type0(), prop()))
}

// --- Graph neural networks ---
/// `MessagePassing : Nat -> Type -> Type`.
pub fn message_passing_ty() -> Expr {
    arrow(nat_ty(), arrow(type0(), type0()))
}
/// `WLExpressiveness : Nat -> Prop`.
pub fn wl_expressiveness_ty() -> Expr {
    arrow(nat_ty(), prop())
}
/// `GraphIsomorphismPower : Nat -> Prop`.
pub fn graph_isomorphism_power_ty() -> Expr {
    arrow(nat_ty(), prop())
}
/// `OverSmoothing : Nat -> Real -> Prop`.
pub fn over_smoothing_ty() -> Expr {
    arrow(nat_ty(), arrow(real_ty(), prop()))
}

// --- Transformers ---
/// `AttentionMechanism : List Real -> List Real -> List Real`.
pub fn attention_mechanism_ty() -> Expr {
    arrow(
        list_ty(real_ty()),
        arrow(list_ty(real_ty()), list_ty(real_ty())),
    )
}
/// `TransformerUniversality : Nat -> Prop`.
pub fn transformer_universality_ty() -> Expr {
    arrow(nat_ty(), prop())
}
/// `InContextLearning : Nat -> Type -> Prop`.
pub fn in_context_learning_ty() -> Expr {
    arrow(nat_ty(), arrow(type0(), prop()))
}
/// `PositionalEncoding : Nat -> Nat -> List Real`.
pub fn positional_encoding_ty() -> Expr {
    arrow(nat_ty(), arrow(nat_ty(), list_ty(real_ty())))
}

// --- Reinforcement learning ---
/// `PACMDP : Nat -> Real -> Real -> Prop`.
pub fn pac_mdp_ty() -> Expr {
    arrow(nat_ty(), arrow(real_ty(), arrow(real_ty(), prop())))
}
/// `RegretBound : Nat -> Real -> Prop`.
pub fn regret_bound_ty() -> Expr {
    arrow(nat_ty(), arrow(real_ty(), prop()))
}
/// `SampleComplexityRL : Real -> Real -> Nat`.
pub fn sample_complexity_rl_ty() -> Expr {
    arrow(real_ty(), arrow(real_ty(), nat_ty()))
}
/// `ExplorationExploitation : Nat -> Prop`.
pub fn exploration_exploitation_ty() -> Expr {
    arrow(nat_ty(), prop())
}

// --- Generative models ---
/// `VAELBO : Type -> Type -> Real`.
pub fn vae_elbo_ty() -> Expr {
    arrow(type0(), arrow(type0(), real_ty()))
}
/// `GANEquilibrium : Type -> Type -> Prop`.
pub fn gan_equilibrium_ty() -> Expr {
    arrow(type0(), arrow(type0(), prop()))
}
/// `NormalizingFlow : Nat -> Type -> Type`.
pub fn normalizing_flow_ty() -> Expr {
    arrow(nat_ty(), arrow(type0(), type0()))
}
/// `DiffusionProcess : Nat -> Real -> Type -> Type`.
pub fn diffusion_process_ty() -> Expr {
    arrow(nat_ty(), arrow(real_ty(), arrow(type0(), type0())))
}
/// `ScoreMatching : Type -> Real -> List Real`.
pub fn score_matching_ty() -> Expr {
    arrow(type0(), arrow(real_ty(), list_ty(real_ty())))
}

// --- Meta-learning ---
/// `MAMLConvergence : Nat -> Real -> Prop`.
pub fn maml_convergence_ty() -> Expr {
    arrow(nat_ty(), arrow(real_ty(), prop()))
}
/// `FewShotBound : Nat -> Nat -> Real -> Prop`.
pub fn few_shot_bound_ty() -> Expr {
    arrow(nat_ty(), arrow(nat_ty(), arrow(real_ty(), prop())))
}
/// `TaskDistribution : Type -> Real`.
pub fn task_distribution_ty() -> Expr {
    arrow(type0(), real_ty())
}

// --- Continual learning ---
/// `CatastrophicForgetting : Nat -> Prop`.
pub fn catastrophic_forgetting_ty() -> Expr {
    arrow(nat_ty(), prop())
}
/// `EWCRegularizer : List Real -> List Real -> Real`.
pub fn ewc_regularizer_ty() -> Expr {
    arrow(list_ty(real_ty()), arrow(list_ty(real_ty()), real_ty()))
}
/// `MemoryReplayBound : Nat -> Real -> Prop`.
pub fn memory_replay_bound_ty() -> Expr {
    arrow(nat_ty(), arrow(real_ty(), prop()))
}

// --- Transfer learning ---
/// `NegativeTransfer : Type -> Type -> Prop`.
pub fn negative_transfer_ty() -> Expr {
    arrow(type0(), arrow(type0(), prop()))
}
/// `TaskRelatedness : Type -> Type -> Real`.
pub fn task_relatedness_ty() -> Expr {
    arrow(type0(), arrow(type0(), real_ty()))
}
/// `TransferExcessRisk : Type -> Type -> Nat -> Real`.
pub fn transfer_excess_risk_ty() -> Expr {
    arrow(type0(), arrow(type0(), arrow(nat_ty(), real_ty())))
}

// --- Fairness ---
/// `DemographicParity : Type -> Type -> Prop`.
pub fn demographic_parity_ty() -> Expr {
    arrow(type0(), arrow(type0(), prop()))
}
/// `EqualizedOdds : Type -> Type -> Prop`.
pub fn equalized_odds_ty() -> Expr {
    arrow(type0(), arrow(type0(), prop()))
}
/// `IndividualFairness : (Type -> Real) -> Prop`.
pub fn individual_fairness_ty() -> Expr {
    arrow(arrow(type0(), real_ty()), prop())
}
/// The fairness/accuracy tradeoff as an opaque proposition.
pub fn fairness_accuracy_tradeoff_ty() -> Expr {
    prop()
}

// --- Explainability ---
/// `ShapleyValue : (List Real -> Real) -> Nat -> Real`.
pub fn shapley_value_ty() -> Expr {
    arrow(
        arrow(list_ty(real_ty()), real_ty()),
        arrow(nat_ty(), real_ty()),
    )
}
/// `SHAPAttribution : Type -> List Real -> List Real`.
pub fn shap_attribution_ty() -> Expr {
    arrow(type0(), arrow(list_ty(real_ty()), list_ty(real_ty())))
}
/// `CounterfactualExplanation : List Real -> Type -> List Real`.
pub fn counterfactual_explanation_ty() -> Expr {
    arrow(list_ty(real_ty()), arrow(type0(), list_ty(real_ty())))
}

// --- Adversarial robustness ---
/// `LpAdversarialAttack : Real -> List Real -> List Real`.
pub fn lp_adversarial_attack_ty() -> Expr {
    arrow(real_ty(), arrow(list_ty(real_ty()), list_ty(real_ty())))
}
/// `CertifiedDefense : Real -> Type -> Prop`.
pub fn certified_defense_ty() -> Expr {
    arrow(real_ty(), arrow(type0(), prop()))
}
/// `RandomizedSmoothing : Real -> Type -> Real -> Prop`.
pub fn randomized_smoothing_ty() -> Expr {
    arrow(real_ty(), arrow(type0(), arrow(real_ty(), prop())))
}

// --- Neural architecture search ---
/// `NASSearchSpace : Nat -> Type`.
pub fn nas_search_space_ty() -> Expr {
    arrow(nat_ty(), type0())
}
/// `OneShotNAS : Type -> Nat -> Prop`.
pub fn one_shot_nas_ty() -> Expr {
    arrow(type0(), arrow(nat_ty(), prop()))
}
/// `CellBasedNAS : Nat -> Nat -> Type`.
pub fn cell_based_nas_ty() -> Expr {
    arrow(nat_ty(), arrow(nat_ty(), type0()))
}

// --- Active learning ---
/// `QueryComplexity : Real -> Real -> Nat`.
pub fn query_complexity_ty() -> Expr {
    arrow(real_ty(), arrow(real_ty(), nat_ty()))
}
/// `UncertaintySampling : Type -> List Real -> Real`.
pub fn uncertainty_sampling_ty() -> Expr {
    arrow(type0(), arrow(list_ty(real_ty()), real_ty()))
}
/// `OptimalStoppingAL : Nat -> Real -> Prop`.
pub fn optimal_stopping_al_ty() -> Expr {
    arrow(nat_ty(), arrow(real_ty(), prop()))
}

// --- PAC-Bayes ---
/// `PACBayesMcAllester : Real -> Real -> Nat -> Real`.
pub fn pac_bayes_mcallester_ty() -> Expr {
    arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), real_ty())))
}
/// `PACBayesCatoni : Real -> Real -> Nat -> Real`.
pub fn pac_bayes_catoni_ty() -> Expr {
    arrow(real_ty(), arrow(real_ty(), arrow(nat_ty(), real_ty())))
}
/// `DataDependentPrior : List Type -> Type -> Real`.
pub fn data_dependent_prior_ty() -> Expr {
    arrow(list_ty(type0()), arrow(type0(), real_ty()))
}
/// `KLDivergenceBound : Type -> Type -> Nat -> Real`.
pub fn kl_divergence_bound_ty() -> Expr {
    arrow(type0(), arrow(type0(), arrow(nat_ty(), real_ty())))
}
/// Registers the advanced machine-learning axiom vocabulary in `env`.
///
/// Failures from `env.add` (e.g. duplicate names when called twice) are
/// ignored via `.ok()`, so registration is best-effort and idempotent —
/// the same convention as `build_machine_learning_env`.
pub fn register_advanced_ml_axioms(env: &mut Environment) {
    // (constant name, type) pairs; each becomes an axiom with no universe
    // parameters.
    let axioms: &[(&str, Expr)] = &[
        ("NeuralTangentKernel", neural_tangent_kernel_ty()),
        ("LazyTrainingRegime", lazy_training_regime_ty()),
        ("MeanFieldLimit", mean_field_limit_ty()),
        ("InfiniteWidthLimit", infinite_width_limit_ty()),
        ("LossLandscape", loss_landscape_ty()),
        ("Autoencoder", autoencoder_ty()),
        ("ContrastiveLoss", contrastive_loss_ty()),
        ("SelfSupervisedObjective", self_supervised_objective_ty()),
        ("RepresentationCollapse", representation_collapse_ty()),
        ("Disentanglement", disentanglement_ty()),
        ("MessagePassing", message_passing_ty()),
        ("WLExpressiveness", wl_expressiveness_ty()),
        ("GraphIsomorphismPower", graph_isomorphism_power_ty()),
        ("OverSmoothing", over_smoothing_ty()),
        ("AttentionMechanism", attention_mechanism_ty()),
        ("TransformerUniversality", transformer_universality_ty()),
        ("InContextLearning", in_context_learning_ty()),
        ("PositionalEncoding", positional_encoding_ty()),
        ("PACMDP", pac_mdp_ty()),
        ("RegretBound", regret_bound_ty()),
        ("SampleComplexityRL", sample_complexity_rl_ty()),
        ("ExplorationExploitation", exploration_exploitation_ty()),
        ("VAELBO", vae_elbo_ty()),
        ("GANEquilibrium", gan_equilibrium_ty()),
        ("NormalizingFlow", normalizing_flow_ty()),
        ("DiffusionProcess", diffusion_process_ty()),
        ("ScoreMatching", score_matching_ty()),
        ("MAMLConvergence", maml_convergence_ty()),
        ("FewShotBound", few_shot_bound_ty()),
        ("TaskDistribution", task_distribution_ty()),
        ("CatastrophicForgetting", catastrophic_forgetting_ty()),
        ("EWCRegularizer", ewc_regularizer_ty()),
        ("MemoryReplayBound", memory_replay_bound_ty()),
        ("NegativeTransfer", negative_transfer_ty()),
        ("TaskRelatedness", task_relatedness_ty()),
        ("TransferExcessRisk", transfer_excess_risk_ty()),
        ("DemographicParity", demographic_parity_ty()),
        ("EqualizedOdds", equalized_odds_ty()),
        ("IndividualFairness", individual_fairness_ty()),
        ("FairnessAccuracyTradeoff", fairness_accuracy_tradeoff_ty()),
        ("ShapleyValue", shapley_value_ty()),
        ("SHAPAttribution", shap_attribution_ty()),
        ("CounterfactualExplanation", counterfactual_explanation_ty()),
        ("LpAdversarialAttack", lp_adversarial_attack_ty()),
        ("CertifiedDefense", certified_defense_ty()),
        ("RandomizedSmoothing", randomized_smoothing_ty()),
        ("NASSearchSpace", nas_search_space_ty()),
        ("OneShotNAS", one_shot_nas_ty()),
        ("CellBasedNAS", cell_based_nas_ty()),
        ("QueryComplexity", query_complexity_ty()),
        ("UncertaintySampling", uncertainty_sampling_ty()),
        ("OptimalStoppingAL", optimal_stopping_al_ty()),
        ("PACBayesMcAllester", pac_bayes_mcallester_ty()),
        ("PACBayesCatoni", pac_bayes_catoni_ty()),
        ("DataDependentPrior", data_dependent_prior_ty()),
        ("KLDivergenceBound", kl_divergence_bound_ty()),
    ];
    for (name, ty) in axioms {
        env.add(Declaration::Axiom {
            name: Name::str(*name),
            univ_params: vec![],
            ty: ty.clone(),
        })
        .ok();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE: the `&params` arguments in the EWC tests below had been
    // corrupted to `¶ms` by an HTML-entity round-trip (`&para` -> `¶`),
    // which does not compile; they are restored here.
    #[test]
    fn test_activation_relu() {
        let a = Activation::ReLU;
        assert_eq!(a.apply(-1.0), 0.0);
        assert_eq!(a.apply(2.0), 2.0);
        assert_eq!(a.apply(0.0), 0.0);
    }
    #[test]
    fn test_activation_sigmoid() {
        let a = Activation::Sigmoid;
        assert!((a.apply(0.0) - 0.5).abs() < 1e-10);
        assert!(a.apply(100.0) > 0.999);
        assert!(a.apply(-100.0) < 0.001);
    }
    #[test]
    fn test_activation_leaky_relu() {
        let a = Activation::LeakyReLU;
        assert_eq!(a.apply(5.0), 5.0);
        assert!((a.apply(-10.0) - (-0.1)).abs() < 1e-10);
    }
    #[test]
    fn test_activation_elu() {
        let a = Activation::ELU;
        assert_eq!(a.apply(3.0), 3.0);
        assert!(a.apply(-1.0) < 0.0);
    }
    #[test]
    fn test_softmax() {
        let vals = vec![1.0, 2.0, 3.0];
        let sm = Activation::apply_softmax(&vals);
        let sum: f64 = sm.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
        assert!(sm[2] > sm[1] && sm[1] > sm[0]);
    }
    #[test]
    fn test_layer_forward() {
        let layer = Layer::from_weights(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            Activation::Linear,
        );
        let out = layer.forward(&[3.0, 5.0]);
        assert!((out[0] - 3.0).abs() < 1e-10);
        assert!((out[1] - 5.0).abs() < 1e-10);
    }
    #[test]
    fn test_layer_with_cache() {
        let layer = Layer::from_weights(
            vec![vec![2.0, 0.0], vec![0.0, 3.0]],
            vec![1.0, -1.0],
            Activation::ReLU,
        );
        let (z, a) = layer.forward_with_cache(&[1.0, 2.0]);
        assert!((z[0] - 3.0).abs() < 1e-10);
        assert!((z[1] - 5.0).abs() < 1e-10);
        assert!((a[0] - 3.0).abs() < 1e-10);
        assert!((a[1] - 5.0).abs() < 1e-10);
    }
    #[test]
    fn test_neural_network_forward() {
        let layer = Layer::from_weights(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            Activation::Linear,
        );
        let net = NeuralNetwork::new(vec![layer]);
        let out = net.forward(&[1.0, 2.0]);
        assert!((out[0] - 1.0).abs() < 1e-10);
        assert!((out[1] - 2.0).abs() < 1e-10);
        assert_eq!(net.predict_class(&[1.0, 2.0]), 1);
        assert_eq!(net.depth(), 1);
    }
    #[test]
    fn test_backprop_reduces_loss() {
        let layer = Layer::from_weights(vec![vec![0.5]], vec![0.0], Activation::Linear);
        let mut net = NeuralNetwork::new(vec![layer]);
        let inputs = vec![vec![1.0], vec![2.0], vec![3.0]];
        let targets = vec![vec![1.0], vec![2.0], vec![3.0]];
        let history = train_network(&mut net, &inputs, &targets, 0.01, 100);
        assert!(
            history.last().expect("last should succeed")
                < history.first().expect("first should succeed")
        );
    }
    #[test]
    fn test_loss_functions() {
        let pred = vec![1.0, 2.0, 3.0];
        let true_vals = vec![1.0, 2.0, 3.0];
        assert!(mse_loss(&pred, &true_vals) < 1e-10);
        assert!(mae_loss(&pred, &true_vals) < 1e-10);
    }
    #[test]
    fn test_binary_cross_entropy() {
        let pred = vec![0.9, 0.1];
        let true_vals = vec![1.0, 0.0];
        let bce = binary_cross_entropy(&pred, &true_vals);
        let bad_pred = vec![0.1, 0.9];
        let bad_bce = binary_cross_entropy(&bad_pred, &true_vals);
        assert!(bad_bce > bce);
    }
    #[test]
    fn test_hinge_loss() {
        let pred = vec![2.0, -2.0];
        let true_labels = vec![1.0, -1.0];
        assert!(hinge_loss(&pred, &true_labels) < 1e-10);
    }
    #[test]
    fn test_huber_loss() {
        let pred = vec![1.0, 2.0];
        let true_vals = vec![1.0, 2.0];
        assert!(huber_loss(&pred, &true_vals, 1.0) < 1e-10);
    }
    #[test]
    fn test_knn_classifier_simple() {
        let mut knn = KnnClassifier::new(1);
        knn.fit(vec![
            (vec![0.0, 0.0], 0),
            (vec![1.0, 0.0], 1),
            (vec![0.0, 1.0], 0),
        ]);
        assert_eq!(knn.predict(&[0.1, 0.1]), 0);
        assert_eq!(knn.predict(&[0.9, 0.1]), 1);
    }
    #[test]
    fn test_knn_predict_proba() {
        let mut knn = KnnClassifier::new(3);
        knn.fit(vec![(vec![0.0], 0), (vec![0.1], 0), (vec![0.2], 1)]);
        let proba = knn.predict_proba(&[0.05]);
        let p0 = proba.get(&0).copied().unwrap_or(0.0);
        assert!(p0 >= 0.5);
    }
    #[test]
    fn test_decision_stump() {
        let stump = DecisionStump::new(0, 5.0, 1);
        assert_eq!(stump.predict(&[6.0]), 1);
        assert_eq!(stump.predict(&[4.0]), 0);
    }
    #[test]
    fn test_decision_stump_find_best() {
        let data = vec![
            (vec![1.0], 0),
            (vec![2.0], 0),
            (vec![3.0], 1),
            (vec![4.0], 1),
        ];
        let weights = vec![0.25, 0.25, 0.25, 0.25];
        let best = DecisionStump::find_best(&data, &weights);
        let correct: usize = data.iter().filter(|(x, y)| best.predict(x) == *y).count();
        assert!(correct >= 3);
    }
    #[test]
    fn test_kmeans_fit_2_clusters() {
        let data = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.0],
            vec![10.0, 10.0],
            vec![10.1, 10.0],
        ];
        let mut km = KMeans::new(2, 100);
        let assignments = km.fit(&data, 42);
        assert_eq!(assignments[0], assignments[1]);
        assert_eq!(assignments[2], assignments[3]);
        assert_ne!(assignments[0], assignments[2]);
    }
    #[test]
    fn test_kmeans_inertia() {
        let data = vec![vec![0.0], vec![10.0]];
        let mut km = KMeans::new(2, 100);
        km.fit(&data, 0);
        assert!(km.inertia(&data) < 1.0);
    }
    #[test]
    fn test_linear_regression_fit_with_bias() {
        let x_data = vec![vec![1.0], vec![2.0], vec![3.0]];
        let y_data = vec![3.0, 5.0, 7.0];
        let model = LinearRegression::fit_least_squares(&x_data, &y_data);
        assert!((model.weights[0] - 2.0).abs() < 1e-8);
        assert!((model.bias - 1.0).abs() < 1e-8);
        assert!(model.r_squared(&x_data, &y_data) > 0.999);
    }
    #[test]
    fn test_linear_regression_mse() {
        let model = LinearRegression {
            weights: vec![1.0],
            bias: 0.0,
        };
        assert!(model.mse(&[vec![1.0], vec![2.0]], &[1.0, 2.0]) < 1e-10);
    }
    #[test]
    fn test_polynomial_regression() {
        let x_data: Vec<f64> = (0..10).map(|i| i as f64).collect();
        let y_data: Vec<f64> = x_data.iter().map(|&x| x * x).collect();
        let model = PolynomialRegression::fit(&x_data, &y_data, 2, 0.0001, 5000);
        let pred = model.predict(5.0);
        assert!((pred - 25.0).abs() < 5.0, "Got {}", pred);
    }
    #[test]
    fn test_gradient_descent_quadratic() {
        let gd = GradientDescent::new(0.1, 10_000);
        let x_min = gd.minimize_quadratic(1.0, 0.0, 5.0);
        assert!(x_min.abs() < 1e-4, "got {x_min}");
    }
    #[test]
    fn test_gradient_descent_numerical() {
        let gd = GradientDescent::new(0.01, 10_000);
        let x_min = gd.minimize_numerical(&|x: f64| (x - 3.0).powi(2), 0.0);
        assert!((x_min - 3.0).abs() < 0.1, "got {x_min}");
    }
    #[test]
    fn test_momentum_sgd() {
        let opt = MomentumSGD::new(0.01, 0.9, 10_000);
        let x_min = opt.minimize_quadratic(1.0, 0.0, 10.0);
        assert!(x_min.abs() < 0.1, "got {x_min}");
    }
    #[test]
    fn test_adam_optimizer() {
        let opt = AdamOptimizer::new(0.1, 10_000);
        let x_min = opt.minimize_quadratic(1.0, 0.0, 10.0);
        assert!(x_min.abs() < 0.1, "got {x_min}");
    }
    #[test]
    fn test_min_max_normalize() {
        let data = vec![vec![0.0, 10.0], vec![5.0, 20.0], vec![10.0, 30.0]];
        let normed = min_max_normalize(&data);
        assert!((normed[0][0]).abs() < 1e-10);
        assert!((normed[2][0] - 1.0).abs() < 1e-10);
        assert!((normed[1][0] - 0.5).abs() < 1e-10);
    }
    #[test]
    fn test_z_score_normalize() {
        let data = vec![vec![1.0], vec![2.0], vec![3.0]];
        let normed = z_score_normalize(&data);
        let mean: f64 = normed.iter().map(|r| r[0]).sum::<f64>() / 3.0;
        assert!(mean.abs() < 1e-10);
    }
    #[test]
    fn test_accuracy_metric() {
        assert!((accuracy(&[0, 1, 1, 0], &[0, 1, 0, 0]) - 0.75).abs() < 1e-10);
    }
    #[test]
    fn test_precision_recall_f1() {
        let pred = vec![1, 1, 0, 0, 1];
        let actual = vec![1, 0, 0, 1, 1];
        let p = precision(&pred, &actual, 1);
        let r = recall(&pred, &actual, 1);
        let _f1 = f1_score(&pred, &actual, 1);
        assert!((p - 2.0 / 3.0).abs() < 1e-10);
        assert!((r - 2.0 / 3.0).abs() < 1e-10);
    }
    #[test]
    fn test_regularization() {
        let w = vec![1.0, 2.0, 3.0];
        assert!((l2_penalty(&w, 0.1) - 1.4).abs() < 1e-10);
        assert!((l1_penalty(&w, 0.1) - 0.6).abs() < 1e-10);
    }
    #[test]
    fn test_elastic_net() {
        let w = vec![1.0, 2.0];
        assert!((elastic_net_penalty(&w, 1.0, 0.0) - l2_penalty(&w, 1.0)).abs() < 1e-10);
        assert!((elastic_net_penalty(&w, 1.0, 1.0) - l1_penalty(&w, 1.0)).abs() < 1e-10);
    }
    #[test]
    fn test_train_test_split() {
        let data: Vec<i32> = (0..10).collect();
        let (train, test) = train_test_split(&data, 0.8);
        assert_eq!(train.len(), 8);
        assert_eq!(test.len(), 2);
    }
    #[test]
    fn test_k_fold_indices() {
        let folds = k_fold_indices(10, 5);
        assert_eq!(folds.len(), 5);
        for (train, test) in &folds {
            assert_eq!(test.len(), 2);
            assert_eq!(train.len(), 8);
        }
    }
    #[test]
    fn test_build_machine_learning_env() {
        let mut env = Environment::new();
        build_machine_learning_env(&mut env).expect("build_machine_learning_env should succeed");
        assert!(!env.is_empty());
    }
    #[test]
    fn test_ewc_penalty_zero_at_theta_star() {
        let mut ewc = ElasticWeightConsolidation::new(1.0);
        let params = vec![1.0, 2.0, 3.0];
        let grads = vec![vec![0.5, 0.5, 0.5], vec![0.5, 0.5, 0.5]];
        ewc.consolidate(&params, &grads);
        let pen = ewc.penalty(&params);
        assert!(pen.abs() < 1e-10, "penalty at theta* = {pen}");
    }
    #[test]
    fn test_ewc_penalty_nonzero_elsewhere() {
        let mut ewc = ElasticWeightConsolidation::new(1.0);
        let params = vec![1.0, 2.0];
        let grads = vec![vec![1.0, 1.0]];
        ewc.consolidate(&params, &grads);
        let shifted = vec![2.0, 3.0];
        let pen = ewc.penalty(&shifted);
        assert!(pen > 0.0, "penalty should be positive when shifted");
    }
    #[test]
    fn test_ewc_penalty_gradient_direction() {
        let mut ewc = ElasticWeightConsolidation::new(1.0);
        let theta_star = vec![0.0, 0.0];
        let grads = vec![vec![1.0, 1.0]];
        ewc.consolidate(&theta_star, &grads);
        let params = vec![1.0, -1.0];
        let grad = ewc.penalty_gradient(&params);
        assert!(grad[0] > 0.0);
        assert!(grad[1] < 0.0);
    }
    #[test]
    fn test_shapley_explainer_constant_model() {
        let bg = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
        let explainer = ShapleyExplainer::new(2, 200, bg);
        let x = vec![3.0, 4.0];
        let phi = explainer.explain(&x, &|_inp: &[f64]| 5.0_f64);
        for &v in &phi {
            assert!(
                v.abs() < 1e-10,
                "constant model => all Shapley = 0, got {v}"
            );
        }
    }
    #[test]
    fn test_shapley_explainer_linear_model() {
        let bg = vec![vec![0.0, 0.0]];
        let explainer = ShapleyExplainer::new(2, 500, bg);
        let x = vec![3.0, 5.0];
        let phi = explainer.explain(&x, &|inp: &[f64]| inp.iter().sum::<f64>());
        assert!((phi[0] - 3.0).abs() < 0.5, "phi[0]={}", phi[0]);
        assert!((phi[1] - 5.0).abs() < 0.5, "phi[1]={}", phi[1]);
    }
    #[test]
    fn test_randomized_smoothing_predict() {
        let smoother = RandomizedSmoothingClassifier::new(0.1, 200, 0.95);
        let base = |x: &[f64]| if x[0] > 0.5 { 1usize } else { 0usize };
        let cls = smoother.smooth_predict(&[1.0, 0.0], &base);
        assert_eq!(cls, 1);
    }
    #[test]
    fn test_randomized_smoothing_certify() {
        let smoother = RandomizedSmoothingClassifier::new(0.1, 500, 0.95);
        let base = |x: &[f64]| if x[0] > 0.5 { 1usize } else { 0usize };
        let (cls, radius) = smoother.certify(&[1.0, 0.0], &base);
        assert_eq!(cls, 1);
        assert!(radius >= 0.0);
    }
    #[test]
    fn test_pac_bayes_mcallester() {
        let bound_calc = PACBayesBound::new(0.05, 1000);
        let bound = bound_calc.mcallester(0.1, 1.0);
        assert!(bound > 0.1);
        assert!(bound.is_finite());
    }
    #[test]
    fn test_pac_bayes_catoni() {
        let bound_calc = PACBayesBound::new(0.05, 1000);
        let bound = bound_calc.catoni(0.1, 1.0, 1.0);
        assert!(bound.is_finite());
        assert!(bound > 0.0);
    }
    #[test]
    fn test_pac_bayes_kl_bernoulli() {
        let kl = PACBayesBound::kl_bernoulli(0.3, 0.3);
        assert!(kl.abs() < 1e-10);
        assert!(PACBayesBound::kl_bernoulli(0.1, 0.9) > 0.0);
    }
    #[test]
    fn test_pac_bayes_kl_gaussians() {
        let kl = PACBayesBound::kl_gaussians(1.0, 1.0, 1.0);
        assert!(kl.abs() < 1e-10);
        assert!(PACBayesBound::kl_gaussians(0.0, 1.0, 1.0) > 0.0);
    }
    #[test]
    fn test_uncertainty_sampler_least_confident() {
        let sampler = UncertaintySampler::new(UncertaintyStrategy::LeastConfident);
        let certain = vec![0.99, 0.01];
        let uncertain = vec![0.5, 0.5];
        assert!(sampler.score(&uncertain) > sampler.score(&certain));
    }
    #[test]
    fn test_uncertainty_sampler_margin() {
        let sampler = UncertaintySampler::new(UncertaintyStrategy::MarginSampling);
        let small_margin = vec![0.51, 0.49];
        let large_margin = vec![0.99, 0.01];
        assert!(sampler.score(&small_margin) > sampler.score(&large_margin));
    }
    #[test]
    fn test_uncertainty_sampler_entropy() {
        let sampler = UncertaintySampler::new(UncertaintyStrategy::Entropy);
        let uniform = vec![0.25, 0.25, 0.25, 0.25];
        let peaked = vec![0.97, 0.01, 0.01, 0.01];
        assert!(sampler.score(&uniform) > sampler.score(&peaked));
    }
    #[test]
    fn test_uncertainty_sampler_select_query() {
        let sampler = UncertaintySampler::new(UncertaintyStrategy::Entropy);
        let candidates = vec![vec![0.99, 0.01], vec![0.5, 0.5], vec![0.7, 0.3]];
        let idx = sampler.select_query(&candidates);
        assert_eq!(idx, 1);
    }
    #[test]
    fn test_register_advanced_ml_axioms() {
        let mut env = Environment::new();
        register_advanced_ml_axioms(&mut env);
        assert!(!env.is_empty());
    }
}