use oxilean_kernel::{BinderInfo, Declaration, Environment, Expr, Level, Name};

use super::types::{
    ActionValueFunction, ActorCritic, AlmostSureStability, BeliefMDP, ErgodicControl,
    ExponentialMSStability, HInfinityControl, MeanFieldGame, MeanFieldGameSolver,
    MeanSquareStability, NashEquilibrium, PathwiseSDE, Policy, PolicyGradient, PursuitEvasionGame,
    QLearning, QLearningSolver, RiccatiEquation, RiskSensitiveControl, RiskSensitiveCost, SDGame,
    StochasticLyapunov, ValueIteration, ZeroSumSDG, MDP, SARSA,
};
pub fn app(f: Expr, a: Expr) -> Expr {
    Expr::App(Box::new(f), Box::new(a))
}
pub fn app2(f: Expr, a: Expr, b: Expr) -> Expr {
    app(app(f, a), b)
}
pub fn app3(f: Expr, a: Expr, b: Expr, c: Expr) -> Expr {
    app(app2(f, a, b), c)
}
pub fn cst(s: &str) -> Expr {
    Expr::Const(Name::str(s), vec![])
}
pub fn prop() -> Expr {
    Expr::Sort(Level::zero())
}
pub fn type0() -> Expr {
    Expr::Sort(Level::succ(Level::zero()))
}
pub fn pi(bi: BinderInfo, name: &str, dom: Expr, body: Expr) -> Expr {
    Expr::Pi(bi, Name::str(name), Box::new(dom), Box::new(body))
}
pub fn arrow(a: Expr, b: Expr) -> Expr {
    pi(BinderInfo::Default, "_", a, b)
}
pub fn bvar(n: u32) -> Expr {
    Expr::BVar(n)
}
pub fn nat_ty() -> Expr {
    cst("Nat")
}
pub fn real_ty() -> Expr {
    cst("Real")
}
pub fn list_ty(elem: Expr) -> Expr {
    app(cst("List"), elem)
}
pub fn bool_ty() -> Expr {
    cst("Bool")
}
pub fn fn_ty(dom: Expr, cod: Expr) -> Expr {
    arrow(dom, cod)
}
/// `MDP` is axiomatized as a predicate over a transition/reward relation of
/// type `Nat → Nat → Real → Nat → Real → Prop`.
pub fn mdp_ty() -> Expr {
    let rel = fn_ty(
        nat_ty(),
        fn_ty(
            nat_ty(),
            fn_ty(real_ty(), fn_ty(nat_ty(), fn_ty(real_ty(), prop()))),
        ),
    );
    arrow(rel, prop())
}
/// Deterministic policies are state-to-action maps.
pub fn policy_ty() -> Expr {
    fn_ty(nat_ty(), nat_ty())
}
/// Stochastic policies map each state to a real weight per action.
pub fn stochastic_policy_ty() -> Expr {
    fn_ty(nat_ty(), fn_ty(nat_ty(), real_ty()))
}
/// A value function maps a policy to state values.
pub fn value_function_ty() -> Expr {
    fn_ty(fn_ty(nat_ty(), nat_ty()), fn_ty(nat_ty(), real_ty()))
}
/// An action-value function maps a policy to Q-values over state-action pairs.
pub fn action_value_function_ty() -> Expr {
    fn_ty(
        fn_ty(nat_ty(), nat_ty()),
        fn_ty(nat_ty(), fn_ty(nat_ty(), real_ty())),
    )
}
// Concepts whose type is `prop()` are axiomatized as opaque propositions.
pub fn bellman_operator_ty() -> Expr {
    prop()
}
pub fn policy_evaluation_ty() -> Expr {
    prop()
}
pub fn policy_improvement_ty() -> Expr {
    prop()
}
pub fn value_iteration_ty() -> Expr {
    prop()
}
pub fn hjb_ty() -> Expr {
    prop()
}
pub fn stochastic_optimal_control_ty() -> Expr {
    prop()
}
/// LQR is a predicate over four real matrices (read as, e.g., A, B, Q, R).
pub fn lqr_ty() -> Expr {
    let mat = list_ty(list_ty(real_ty()));
    arrow(
        mat.clone(),
        arrow(mat.clone(), arrow(mat.clone(), arrow(mat, prop()))),
    )
}
pub fn riccati_equation_ty() -> Expr {
    prop()
}
pub fn solve_riccati_ty() -> Expr {
    prop()
}
pub fn optimal_gain_matrix_ty() -> Expr {
    prop()
}
pub fn infinite_horizon_lqr_ty() -> Expr {
    prop()
}
/// Q-learning update relation over `Real → Real → Nat → Nat → Real → Nat`,
/// read as (α, γ, s, a, r, s').
pub fn q_learning_ty() -> Expr {
    arrow(
        real_ty(),
        arrow(
            real_ty(),
            arrow(
                nat_ty(),
                arrow(nat_ty(), arrow(real_ty(), arrow(nat_ty(), prop()))),
            ),
        ),
    )
}
/// SARSA update relation: as Q-learning, with the next action appended,
/// read as (α, γ, s, a, r, s', a').
pub fn sarsa_ty() -> Expr {
    arrow(
        real_ty(),
        arrow(
            real_ty(),
            arrow(
                nat_ty(),
                arrow(
                    nat_ty(),
                    arrow(real_ty(), arrow(nat_ty(), arrow(nat_ty(), prop()))),
                ),
            ),
        ),
    )
}
pub fn policy_gradient_ty() -> Expr {
    prop()
}
pub fn actor_critic_ty() -> Expr {
    prop()
}
pub fn convergence_rate_rl_ty() -> Expr {
    prop()
}
pub fn zero_sum_sdg_ty() -> Expr {
    prop()
}
pub fn nash_equilibrium_ty() -> Expr {
    prop()
}
pub fn isaac_equation_ty() -> Expr {
    prop()
}
pub fn pursuit_evasion_game_ty() -> Expr {
    prop()
}
pub fn stochastic_lyapunov_ty() -> Expr {
    prop()
}
pub fn mean_square_stability_ty() -> Expr {
    prop()
}
pub fn almost_sure_stability_ty() -> Expr {
    prop()
}
pub fn exponential_ms_stability_ty() -> Expr {
    prop()
}
pub fn pomdp_ty() -> Expr {
    prop()
}
/// A belief state is a predicate on distributions over states (`Nat → Real`).
pub fn belief_state_ty() -> Expr {
    arrow(fn_ty(nat_ty(), real_ty()), prop())
}
/// Belief update relates a belief, an action, and an observation to the
/// posterior belief.
pub fn belief_update_ty() -> Expr {
    arrow(
        fn_ty(nat_ty(), real_ty()),
        arrow(
            nat_ty(),
            arrow(nat_ty(), arrow(fn_ty(nat_ty(), real_ty()), prop())),
        ),
    )
}
pub fn belief_mdp_ty() -> Expr {
    prop()
}
pub fn qmdp_approximation_ty() -> Expr {
    prop()
}
/// Entropic risk maps a risk parameter and a cost profile to a real value.
pub fn entropic_risk_ty() -> Expr {
    arrow(real_ty(), arrow(fn_ty(nat_ty(), real_ty()), real_ty()))
}
pub fn cvar_optimization_ty() -> Expr {
    arrow(real_ty(), prop())
}
/// A coherent risk measure is a predicate on functionals of cost profiles.
pub fn coherent_risk_measure_ty() -> Expr {
    arrow(fn_ty(fn_ty(nat_ty(), real_ty()), real_ty()), prop())
}
pub fn risk_sensitive_bellman_ty() -> Expr {
    arrow(real_ty(), prop())
}
pub fn minimax_control_ty() -> Expr {
    prop()
}
pub fn h_infinity_stochastic_ty() -> Expr {
    prop()
}
/// An ambiguity set is a predicate on sets of distributions.
pub fn ambiguity_set_ty() -> Expr {
    arrow(fn_ty(fn_ty(nat_ty(), real_ty()), prop()), prop())
}
pub fn distributionally_robust_mdp_ty() -> Expr {
    prop()
}
pub fn mean_field_game_ty() -> Expr {
    prop()
}
pub fn mckean_vlasov_sde_ty() -> Expr {
    prop()
}
/// MFG Nash equilibrium is a predicate on the population distribution.
pub fn mfg_nash_equilibrium_ty() -> Expr {
    arrow(fn_ty(nat_ty(), real_ty()), prop())
}
pub fn mfg_consistency_ty() -> Expr {
    prop()
}
/// A two-player zero-sum game is a predicate on a reward kernel
/// `Nat → Nat → Nat → Real` (read as state and both players' actions).
pub fn two_player_zero_sum_game_ty() -> Expr {
    let reward_kernel = fn_ty(nat_ty(), fn_ty(nat_ty(), fn_ty(nat_ty(), real_ty())));
    arrow(reward_kernel, prop())
}
pub fn shapley_operator_ty() -> Expr {
    prop()
}
pub fn average_cost_mdp_ty() -> Expr {
    prop()
}
pub fn poisson_equation_ty() -> Expr {
    prop()
}
/// A bias function is a predicate on state-indexed real functions.
pub fn bias_function_ty() -> Expr {
    arrow(fn_ty(nat_ty(), real_ty()), prop())
}
pub fn optimal_stopping_ty() -> Expr {
    prop()
}
pub fn dynkin_formula_ty() -> Expr {
    prop()
}
pub fn quasi_variational_inequality_ty() -> Expr {
    prop()
}
pub fn impulse_control_policy_ty() -> Expr {
    arrow(fn_ty(nat_ty(), real_ty()), arrow(nat_ty(), prop()))
}
pub fn certainty_equivalence_ty() -> Expr {
    prop()
}
pub fn self_tuning_regulator_ty() -> Expr {
    prop()
}
pub fn q_learning_convergence_ty() -> Expr {
    arrow(real_ty(), arrow(real_ty(), prop()))
}
pub fn fitted_value_iteration_ty() -> Expr {
    prop()
}
pub fn joint_policy_convergence_ty() -> Expr {
    arrow(nat_ty(), prop())
}
/// Correlated equilibrium is a predicate on joint action distributions.
pub fn correlated_equilibrium_ty() -> Expr {
    let joint_dist = fn_ty(fn_ty(nat_ty(), fn_ty(nat_ty(), real_ty())), prop());
    arrow(joint_dist, prop())
}
pub fn regret_minimisation_ty() -> Expr {
    arrow(real_ty(), prop())
}
/// Registers every concept above as an opaque axiom in the environment.
pub fn build_env(env: &mut Environment) {
    let axioms: &[(&str, Expr)] = &[
        ("MDP", mdp_ty()),
        ("Policy", policy_ty()),
        ("StochasticPolicy", stochastic_policy_ty()),
        ("ValueFunction", value_function_ty()),
        ("ActionValueFunction", action_value_function_ty()),
        ("BellmanOperator", bellman_operator_ty()),
        ("PolicyEvaluation", policy_evaluation_ty()),
        ("PolicyImprovement", policy_improvement_ty()),
        ("ValueIteration", value_iteration_ty()),
        ("HamiltonJacobiBellman", hjb_ty()),
        ("StochasticOptimalControl", stochastic_optimal_control_ty()),
        ("LinearQuadraticRegulator", lqr_ty()),
        ("RiccatiEquation", riccati_equation_ty()),
        ("SolveRiccati", solve_riccati_ty()),
        ("OptimalGainMatrix", optimal_gain_matrix_ty()),
        ("InfiniteHorizonLqr", infinite_horizon_lqr_ty()),
        ("QLearning", q_learning_ty()),
        ("SARSA", sarsa_ty()),
        ("PolicyGradient", policy_gradient_ty()),
        ("ActorCritic", actor_critic_ty()),
        ("ConvergenceRateRL", convergence_rate_rl_ty()),
        ("ZeroSumSDG", zero_sum_sdg_ty()),
        ("NashEquilibrium", nash_equilibrium_ty()),
        ("IsaacEquation", isaac_equation_ty()),
        ("PursuitEvasionGame", pursuit_evasion_game_ty()),
        ("StochasticLyapunov", stochastic_lyapunov_ty()),
        ("MeanSquareStability", mean_square_stability_ty()),
        ("AlmostSureStability", almost_sure_stability_ty()),
        ("ExponentialMSStability", exponential_ms_stability_ty()),
        ("POMDP", pomdp_ty()),
        ("BeliefState", belief_state_ty()),
        ("BeliefUpdate", belief_update_ty()),
        ("BeliefMDP", belief_mdp_ty()),
        ("QMDPApproximation", qmdp_approximation_ty()),
        ("EntropicRisk", entropic_risk_ty()),
        ("CVaROptimization", cvar_optimization_ty()),
        ("CoherentRiskMeasure", coherent_risk_measure_ty()),
        ("RiskSensitiveBellman", risk_sensitive_bellman_ty()),
        ("MinimaxControl", minimax_control_ty()),
        ("HInfinityStochastic", h_infinity_stochastic_ty()),
        ("AmbiguitySet", ambiguity_set_ty()),
        ("DistributionallyRobustMDP", distributionally_robust_mdp_ty()),
        ("MeanFieldGame", mean_field_game_ty()),
        ("McKeanVlasovSDE", mckean_vlasov_sde_ty()),
        ("MFGNashEquilibrium", mfg_nash_equilibrium_ty()),
        ("MFGConsistencyCondition", mfg_consistency_ty()),
        ("TwoPlayerZeroSumGame", two_player_zero_sum_game_ty()),
        ("ShapleyOperator", shapley_operator_ty()),
        ("AverageCostMDP", average_cost_mdp_ty()),
        ("PoissonEquation", poisson_equation_ty()),
        ("BiasFunction", bias_function_ty()),
        ("OptimalStopping", optimal_stopping_ty()),
        ("DynkinFormula", dynkin_formula_ty()),
        ("QuasiVariationalInequality", quasi_variational_inequality_ty()),
        ("ImpulseControlPolicy", impulse_control_policy_ty()),
        ("CertaintyEquivalence", certainty_equivalence_ty()),
        ("SelfTuningRegulator", self_tuning_regulator_ty()),
        ("QlearningConvergence", q_learning_convergence_ty()),
        ("FittedValueIteration", fitted_value_iteration_ty()),
        ("JointPolicyConvergence", joint_policy_convergence_ty()),
        ("CorrelatedEquilibrium", correlated_equilibrium_ty()),
        ("RegretMinimisation", regret_minimisation_ty()),
    ];
    for (name, ty) in axioms {
        env.add(Declaration::Axiom {
            name: Name::str(*name),
            univ_params: vec![],
            ty: ty.clone(),
        })
        .ok(); // discard the result; a failed add (e.g. a duplicate name) is ignored
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_build_env() {
        let mut env = Environment::new();
        build_env(&mut env);
        assert!(env.get(&Name::str("MDP")).is_some());
        assert!(env.get(&Name::str("ValueFunction")).is_some());
        assert!(env.get(&Name::str("QLearning")).is_some());
        assert!(env.get(&Name::str("NashEquilibrium")).is_some());
        assert!(env.get(&Name::str("StochasticLyapunov")).is_some());
    }
    // Two-state chain: action 1 moves state 0 to the absorbing, reward-1
    // state 1, so V*(1) = 1/(1 - 0.9) = 10 and V*(0) = 0.9 * V*(1) = 9.
    #[test]
    fn test_mdp_value_iteration() {
        let transitions = vec![
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![vec![0.0, 1.0], vec![0.0, 1.0]],
        ];
        let rewards = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
        let mdp = MDP::new(2, 2, transitions, rewards, 0.9);
        let v = mdp.value_iteration(1e-8, 1000);
        assert!((v[1] - 10.0).abs() < 0.01, "V*(1) ≈ 10, got {}", v[1]);
        assert!((v[0] - 9.0).abs() < 0.01, "V*(0) ≈ 9, got {}", v[0]);
    }
    #[test]
    fn test_policy_evaluation() {
        let transitions = vec![
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![vec![0.0, 1.0], vec![0.0, 1.0]],
        ];
        let rewards = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
        let mdp = MDP::new(2, 2, transitions, rewards, 0.9);
        let policy = vec![1, 0];
        let v = mdp.policy_evaluation(&policy, 1e-8, 1000);
        assert!(v[1] > v[0], "Good state should have higher value");
    }
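    // Added sketch (assumes value_iteration and policy_evaluation share the
    // same convergence semantics): the greedy policy [1, 0] is optimal on this
    // chain, so evaluating it should reproduce the value-iteration fixed point.
    #[test]
    fn test_policy_evaluation_matches_value_iteration() {
        let transitions = vec![
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![vec![0.0, 1.0], vec![0.0, 1.0]],
        ];
        let rewards = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
        let mdp = MDP::new(2, 2, transitions, rewards, 0.9);
        let v_star = mdp.value_iteration(1e-8, 1000);
        let optimal_policy = vec![1, 0];
        let v_pi = mdp.policy_evaluation(&optimal_policy, 1e-8, 1000);
        for s in 0..2 {
            assert!(
                (v_star[s] - v_pi[s]).abs() < 1e-6,
                "evaluating the optimal policy should match V* at state {s}"
            );
        }
    }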
    // With a zero-initialised table, the TD target is r + γ·max_a Q(1, a) = 1,
    // so the update gives Q(0, 1) = 0 + α(1 - 0) = 0.5.
    #[test]
    fn test_q_learning_update() {
        let mut agent = QLearning::new(3, 2, 0.5, 0.9);
        agent.update(0, 1, 1.0, 1);
        assert!((agent.q[0][1] - 0.5).abs() < 1e-12, "Q(0,1) should be 0.5");
    }
    // Same arithmetic for SARSA, whose target uses the supplied next action.
    #[test]
    fn test_sarsa_update() {
        let mut agent = SARSA::new(3, 2, 0.5, 0.9);
        agent.update(0, 1, 1.0, 1, 0);
        assert!(
            (agent.q[0][1] - 0.5).abs() < 1e-12,
            "SARSA Q(0,1) should be 0.5"
        );
    }
    // Fresh (zero) preferences give a uniform softmax over the 3 actions.
    #[test]
    fn test_policy_gradient_softmax() {
        let agent = PolicyGradient::new(2, 3, 0.01, 0.99);
        let pi = agent.softmax(0);
        let total: f64 = pi.iter().sum();
        assert!((total - 1.0).abs() < 1e-12, "softmax should sum to 1");
        assert!((pi[0] - 1.0 / 3.0).abs() < 1e-12);
    }
    #[test]
    fn test_actor_critic_update() {
        let mut ac = ActorCritic::new(2, 2, 0.1, 0.1, 0.9);
        let v0 = ac.expected_return(0);
        ac.update(0, 1, 1.0, 1);
        assert!(
            ac.expected_return(0) > v0,
            "critic value should increase after positive reward"
        );
    }
    // Pursuer at the origin, evader at (3, 4): distance 5, speed advantage
    // 2 - 1 = 1, so capture in roughly 5 time units.
    #[test]
    fn test_pursuit_evasion() {
        let game = PursuitEvasionGame::new([0.0, 0.0], [3.0, 4.0], 2.0, 1.0);
        assert!((game.distance() - 5.0).abs() < 1e-12);
        assert!(game.pursuer_wins());
        assert!((game.capture_time_estimate() - 5.0).abs() < 1e-12);
    }
    #[test]
    fn test_stochastic_lyapunov_bound() {
        let lv = StochasticLyapunov::new(0.5, 0.1);
        assert!(lv.check(-0.9, 2.0));
        let v0 = 10.0;
        // At t = 0 the expected-value upper bound should reduce to V(x_0) itself.
        assert!((lv.ev_upper_bound(v0, 0.0) - v0).abs() < 1e-12);
    }
    #[test]
    fn test_exponential_ms_stability() {
        let stab = ExponentialMSStability::new(5.0, 0.5);
        // At t = 0 the exponential envelope equals its constant, 5.
        assert!((stab.bound(0.0) - 5.0).abs() < 1e-12);
        assert!(stab.bound(10.0) < 0.1);
        assert!(stab.check(4.9, 0.0));
        assert!(!stab.check(5.1, 0.0));
    }
    #[test]
    fn test_riccati_solve() {
        let riccati = RiccatiEquation::new(
            vec![vec![-1.0]],
            vec![vec![1.0]],
            vec![vec![1.0]],
            vec![vec![1.0]],
        );
        let (p, k) = riccati.infinite_horizon_lqr();
        assert!(p[0][0] > 0.0, "P should be positive");
        assert!(k[0][0] > 0.0, "K should be positive");
    }
    #[test]
    fn test_build_env_new_axioms() {
        let mut env = Environment::new();
        build_env(&mut env);
        assert!(env.get(&Name::str("POMDP")).is_some());
        assert!(env.get(&Name::str("BeliefState")).is_some());
        assert!(env.get(&Name::str("BeliefUpdate")).is_some());
        assert!(env.get(&Name::str("BeliefMDP")).is_some());
        assert!(env.get(&Name::str("QMDPApproximation")).is_some());
        assert!(env.get(&Name::str("EntropicRisk")).is_some());
        assert!(env.get(&Name::str("CVaROptimization")).is_some());
        assert!(env.get(&Name::str("CoherentRiskMeasure")).is_some());
        assert!(env.get(&Name::str("RiskSensitiveBellman")).is_some());
        assert!(env.get(&Name::str("MinimaxControl")).is_some());
        assert!(env.get(&Name::str("HInfinityStochastic")).is_some());
        assert!(env.get(&Name::str("AmbiguitySet")).is_some());
        assert!(env.get(&Name::str("DistributionallyRobustMDP")).is_some());
        assert!(env.get(&Name::str("MeanFieldGame")).is_some());
        assert!(env.get(&Name::str("McKeanVlasovSDE")).is_some());
        assert!(env.get(&Name::str("MFGNashEquilibrium")).is_some());
        assert!(env.get(&Name::str("MFGConsistencyCondition")).is_some());
        assert!(env.get(&Name::str("TwoPlayerZeroSumGame")).is_some());
        assert!(env.get(&Name::str("ShapleyOperator")).is_some());
        assert!(env.get(&Name::str("AverageCostMDP")).is_some());
        assert!(env.get(&Name::str("PoissonEquation")).is_some());
        assert!(env.get(&Name::str("BiasFunction")).is_some());
        assert!(env.get(&Name::str("OptimalStopping")).is_some());
        assert!(env.get(&Name::str("DynkinFormula")).is_some());
        assert!(env.get(&Name::str("QuasiVariationalInequality")).is_some());
        assert!(env.get(&Name::str("ImpulseControlPolicy")).is_some());
        assert!(env.get(&Name::str("CertaintyEquivalence")).is_some());
        assert!(env.get(&Name::str("SelfTuningRegulator")).is_some());
        assert!(env.get(&Name::str("QlearningConvergence")).is_some());
        assert!(env.get(&Name::str("FittedValueIteration")).is_some());
        assert!(env.get(&Name::str("JointPolicyConvergence")).is_some());
        assert!(env.get(&Name::str("CorrelatedEquilibrium")).is_some());
        assert!(env.get(&Name::str("RegretMinimisation")).is_some());
    }
    // Bayes update from the uniform prior: the transitions are deterministic
    // and P(o = 0 | s) is 0.9 and 0.2 respectively, so the posterior is
    // b'(0) = 0.45 / (0.45 + 0.10) = 0.45 / 0.55 ≈ 0.818.
    #[test]
    fn test_belief_update_normalises() {
        let transitions = vec![vec![vec![1.0, 0.0]], vec![vec![0.0, 1.0]]];
        let observations = vec![vec![vec![0.9, 0.1]], vec![vec![0.2, 0.8]]];
        let pomdp = BeliefMDP::new(2, 1, 2, transitions, observations);
        let belief = vec![0.5, 0.5];
        let new_belief = pomdp.belief_update(&belief, 0, 0);
        let total: f64 = new_belief.iter().sum();
        assert!((total - 1.0).abs() < 1e-10, "belief should sum to 1");
        assert!((new_belief[0] - 0.45 / 0.55).abs() < 1e-10, "b'(0) ≈ 0.818");
    }
    // With a single action, QMDP gives the belief-weighted Q-value:
    // 0.5 * 1 + 0.5 * 5 = 3.
    #[test]
    fn test_qmdp_value() {
        let transitions = vec![vec![vec![1.0, 0.0]], vec![vec![0.0, 1.0]]];
        let observations = vec![vec![vec![1.0, 0.0]], vec![vec![0.0, 1.0]]];
        let pomdp = BeliefMDP::new(2, 1, 2, transitions, observations);
        let q_star = vec![vec![1.0], vec![5.0]];
        let belief = vec![0.5, 0.5];
        let v = pomdp.qmdp_value(&belief, &q_star);
        assert!((v - 3.0).abs() < 1e-12, "QMDP value should be 3.0");
    }
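    // Added sketch (assumes qmdp_value implements V(b) = max_a Σ_s b(s) Q*(s, a)):
    // a degenerate belief should recover the underlying Q-value exactly.
    #[test]
    fn test_qmdp_value_degenerate_belief() {
        let transitions = vec![vec![vec![1.0, 0.0]], vec![vec![0.0, 1.0]]];
        let observations = vec![vec![vec![1.0, 0.0]], vec![vec![0.0, 1.0]]];
        let pomdp = BeliefMDP::new(2, 1, 2, transitions, observations);
        let q_star = vec![vec![1.0], vec![5.0]];
        assert!((pomdp.qmdp_value(&vec![1.0, 0.0], &q_star) - 1.0).abs() < 1e-12);
        assert!((pomdp.qmdp_value(&vec![0.0, 1.0], &q_star) - 5.0).abs() < 1e-12);
    }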
    // For samples 1..=10 the mean of the worst half is 8, so CVaR at level
    // 0.5 should comfortably exceed 6 under any standard convention.
    #[test]
    fn test_cvar_basic() {
        let rsc = RiskSensitiveCost::new(0.5, 1.0);
        let samples: Vec<f64> = (1..=10).map(|x| x as f64).collect();
        let cvar = rsc.cvar(&samples);
        assert!(cvar >= 6.0, "CVaR_0.5 should be ≥ 6.0, got {cvar}");
    }
    // The entropic risk of a constant is that constant:
    // ρ_θ(c) = (1/θ) log E[exp(θ c)] = c.
    #[test]
    fn test_entropic_risk() {
        let rsc = RiskSensitiveCost::new(0.5, 2.0);
        let samples = vec![1.0_f64; 100];
        let er = rsc.entropic_risk(&samples);
        assert!((er - 1.0).abs() < 1e-10, "ρ_θ(1) = 1, got {er}");
    }
    #[test]
    fn test_mean_field_game_solver() {
        let transitions = vec![
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![vec![0.0, 1.0], vec![0.0, 1.0]],
        ];
        let rewards = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
        let solver = MeanFieldGameSolver::new(2, 2, transitions, rewards, 0.9, 1e-6, 1000);
        let (policy, mu) = solver.solve();
        assert_eq!(policy.len(), 2);
        let total: f64 = mu.iter().sum();
        assert!((total - 1.0).abs() < 1e-6, "MFG distribution sums to 1");
    }
    #[test]
    fn test_value_iteration_solver() {
        let transitions = vec![
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![vec![0.0, 1.0], vec![0.0, 1.0]],
        ];
        let rewards = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
        let vi = ValueIteration::new(2, 2, transitions, rewards, 0.9);
        let (v, pi) = vi.run(1e-8, 1000);
        assert!((v[1] - 10.0).abs() < 0.01, "V*(1) ≈ 10");
        assert!((v[0] - 9.0).abs() < 0.01, "V*(0) ≈ 9");
        assert_eq!(pi.len(), 2);
    }
    #[test]
    fn test_q_from_v() {
        let transitions = vec![
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![vec![0.0, 1.0], vec![0.0, 1.0]],
        ];
        let rewards = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
        let vi = ValueIteration::new(2, 2, transitions, rewards, 0.9);
        let (v, _) = vi.run(1e-8, 1000);
        let q = vi.q_from_v(&v);
        assert!((q[1][0] - 10.0).abs() < 0.1, "Q*(1,0) ≈ 10");
    }
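    // Added sketch (assumes q_from_v computes Q(s, a) = r(s, a) + γ Σ p V):
    // at the fixed point V*(s) = max_a Q*(s, a), so the recovered Q-table
    // should be Bellman-consistent with the converged values.
    #[test]
    fn test_q_from_v_bellman_consistency() {
        let transitions = vec![
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![vec![0.0, 1.0], vec![0.0, 1.0]],
        ];
        let rewards = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
        let vi = ValueIteration::new(2, 2, transitions, rewards, 0.9);
        let (v, _) = vi.run(1e-8, 1000);
        let q = vi.q_from_v(&v);
        for s in 0..2 {
            let best = q[s].iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            assert!((v[s] - best).abs() < 1e-6, "V*(s) should equal max_a Q*(s,a)");
        }
    }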
    #[test]
    fn test_span_convergence() {
        let transitions = vec![
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![vec![0.0, 1.0], vec![0.0, 1.0]],
        ];
        let rewards = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
        let vi = ValueIteration::new(2, 2, transitions, rewards, 0.9);
        let (v, _) = vi.run(1e-8, 1000);
        let span = vi.span(&v);
        assert!(
            span < 1e-6,
            "span should be near 0 at convergence, got {span}"
        );
    }
    #[test]
    fn test_q_learning_solver_update() {
        let mut solver = QLearningSolver::new(3, 2, 1.0, 0.9, 0.1);
        let prev_q = solver.q.clone();
        solver.update(0, 1, 1.0, 1);
        assert!(
            (solver.q[0][1] - prev_q[0][1]).abs() > 1e-12,
            "Q(0,1) should have been updated"
        );
    }
    #[test]
    fn test_q_learning_solver_greedy_policy() {
        let mut solver = QLearningSolver::new(2, 3, 0.5, 0.9, 0.0);
        solver.q[0][2] = 5.0;
        solver.q[1][1] = 3.0;
        let policy = solver.greedy_policy();
        assert_eq!(policy[0], 2, "state 0 should choose action 2");
        assert_eq!(policy[1], 1, "state 1 should choose action 1");
    }
    #[test]
    fn test_q_learning_solver_convergence_check() {
        let solver = QLearningSolver::new(2, 2, 0.5, 0.9, 0.1);
        let same_q = solver.q.clone();
        assert!(
            solver.has_converged(&same_q, 1e-10),
            "identical Q tables should be converged"
        );
        let mut diff_q = same_q.clone();
        diff_q[0][0] += 1.0;
        assert!(
            !solver.has_converged(&diff_q, 1e-10),
            "different Q tables should not be converged"
        );
    }
}
#[cfg(test)]
mod tests_stoch_control_ext {
    use super::*;
    #[test]
    fn test_sd_game_zero_sum() {
        let mut game = SDGame::zero_sum(1.0);
        assert!(game.is_zero_sum);
        assert!(game.saddle_point_exists());
        let isaacs = game.isaacs_equation();
        assert!(isaacs.contains("Isaacs"));
        game.set_value(2.5);
        assert_eq!(game.value_function, Some(2.5));
    }
    // Propagation of chaos gives the O(1/√N) convergence rate for the
    // N-player approximation of the mean-field limit.
    #[test]
    fn test_mean_field_game() {
        let mfg = MeanFieldGame::new(1000, 0.5);
        assert!((mfg.convergence_rate - 1.0 / (1000_f64).sqrt()).abs() < 1e-10);
        let desc = mfg.mfg_system_description();
        assert!(desc.contains("1000"));
        let poa = mfg.price_of_anarchy();
        assert!(poa > 1.0);
        let master = mfg.master_equation();
        assert!(master.contains("FP") || master.contains("∂_t"));
    }
    // Risk-averse certainty equivalent: CE ≈ mean + (θ/2)·variance
    // = 1.0 + 0.25 · 0.5 = 1.125 for θ = 0.5.
    #[test]
    fn test_risk_sensitive_control() {
        let rsc = RiskSensitiveControl::risk_averse(0.5, 1.0);
        assert!(rsc.risk_parameter > 0.0);
        let ce = rsc.certainty_equivalent(1.0, 0.5);
        assert!((ce - 1.125).abs() < 1e-10);
        assert!(rsc.is_robust_control_connection());
        let crit = rsc.exponential_criterion();
        assert!(crit.contains("Risk-sensitive"));
    }
    #[test]
    fn test_hinf_control() {
        let hinf = HInfinityControl::new(2.0, 3, 2, 2);
        assert!(!hinf.is_feasible());
        let minimax = hinf.minimax_criterion();
        assert!(minimax.contains("H∞"));
        let riccati = hinf.game_riccati_equation();
        assert!(riccati.contains("ARE"));
    }
    #[test]
    fn test_ergodic_control() {
        let mut ec = ErgodicControl::new(3);
        ec.set_eigenvalue(1.5);
        assert_eq!(ec.long_run_cost, Some(1.5));
        let hjb = ec.ergodic_hjb();
        assert!(hjb.contains("ergodic"));
        let tp = ec.turnpike_property();
        assert!(tp.contains("Turnpike"));
    }
    // Euler–Maruyama has strong order 1/2 and weak order 1; 10 steps yield
    // 11 path points including the initial condition.
    #[test]
    fn test_pathwise_sde_euler() {
        let sde = PathwiseSDE::euler_maruyama("ax", "bx", 1.0, 10, 0.01);
        assert!((sde.strong_order() - 0.5).abs() < 1e-10);
        assert!((sde.weak_order() - 1.0).abs() < 1e-10);
        let path = sde.simulate_one_path();
        assert_eq!(path.len(), 11);
    }
    // The Milstein scheme improves the strong order to 1.
    #[test]
    fn test_pathwise_sde_milstein() {
        let sde = PathwiseSDE::milstein("ax", "bx", 1.0, 5, 0.01);
        assert!((sde.strong_order() - 1.0).abs() < 1e-10);
    }
}