// oxilean_std/stochastic_control/types.rs

//! Auto-generated module
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)

use super::functions::*;
/// Nash equilibrium for a two-player zero-sum stochastic differential game.
///
/// Stores the equilibrium strategies (u*, v*).
#[derive(Debug, Clone)]
pub struct NashEquilibrium {
    /// Equilibrium strategy for player 1 (minimiser).
    pub player1_strategy: Vec<usize>,
    /// Equilibrium strategy for player 2 (maximiser).
    pub player2_strategy: Vec<usize>,
    /// Equilibrium value function.
    pub value: Vec<f64>,
}
impl NashEquilibrium {
    /// Construct a Nash equilibrium from pre-computed strategy and value tables.
    pub fn new(
        player1_strategy: Vec<usize>,
        player2_strategy: Vec<usize>,
        value: Vec<f64>,
    ) -> Self {
        NashEquilibrium {
            player1_strategy,
            player2_strategy,
            value,
        }
    }
    /// Check whether neither player benefits from unilateral deviation (numeric check).
    /// Returns true if the strategies are consistent with the value function:
    /// both strategy tables are non-empty and each has one entry per state.
    pub fn verify_nash_property(&self) -> bool {
        let n = self.value.len();
        let p1_ok = !self.player1_strategy.is_empty() && self.player1_strategy.len() == n;
        let p2_ok = !self.player2_strategy.is_empty() && self.player2_strategy.len() == n;
        p1_ok && p2_ok
    }
}
/// Risk-sensitive cost using Conditional Value-at-Risk (CVaR).
///
/// CVaR_α(X) = (1/α) ∫_{α}^{1} VaR_u(X) du
///           ≈ E[X | X ≥ VaR_α(X)] (discrete approximation via sorting).
#[derive(Debug, Clone)]
pub struct RiskSensitiveCost {
    /// Risk level α ∈ (0, 1]: CVaR_α.
    pub alpha: f64,
    /// Entropic risk parameter θ (for entropic risk measure ρ_θ).
    pub theta: f64,
}
impl RiskSensitiveCost {
    /// Construct a risk-sensitive cost calculator.
    pub fn new(alpha: f64, theta: f64) -> Self {
        Self { alpha, theta }
    }
    /// Compute VaR_α(X) from a list of loss samples.
    ///
    /// Sorts `samples` and returns the (1-α)-quantile.
    /// Returns 0.0 for an empty sample set.
    pub fn var(&self, samples: &[f64]) -> f64 {
        if samples.is_empty() {
            return 0.0;
        }
        let mut sorted = samples.to_vec();
        // NaN-tolerant ordering: incomparable pairs are treated as equal.
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let idx = ((1.0 - self.alpha) * sorted.len() as f64).floor() as usize;
        // Clamp so that α ≤ 0 still yields the largest sample instead of panicking.
        sorted[idx.min(sorted.len() - 1)]
    }
    /// Compute CVaR_α(X): average of losses at or above VaR_α.
    /// Returns 0.0 for an empty sample set.
    pub fn cvar(&self, samples: &[f64]) -> f64 {
        if samples.is_empty() {
            return 0.0;
        }
        let var_val = self.var(samples);
        let tail: Vec<f64> = samples.iter().cloned().filter(|&x| x >= var_val).collect();
        if tail.is_empty() {
            // Only reachable with NaN samples; fall back to the VaR itself.
            return var_val;
        }
        tail.iter().sum::<f64>() / tail.len() as f64
    }
    /// Entropic risk measure: ρ_θ(X) = (1/θ) log E[exp(θ X)].
    ///
    /// For θ = 0 the measure is defined by its limit, the plain mean E[X]
    /// (the literal formula would divide by zero and return NaN).
    /// Returns 0.0 for an empty sample set.
    pub fn entropic_risk(&self, samples: &[f64]) -> f64 {
        if samples.is_empty() {
            return 0.0;
        }
        let n = samples.len() as f64;
        if self.theta == 0.0 {
            // lim_{θ→0} (1/θ) log E[e^{θX}] = E[X]
            return samples.iter().sum::<f64>() / n;
        }
        let mean_exp: f64 = samples.iter().map(|&x| (self.theta * x).exp()).sum::<f64>() / n;
        mean_exp.ln() / self.theta
    }
}
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct RiskSensitiveControl {
    pub risk_parameter: f64,
    pub time_horizon: f64,
    pub control_space: String,
    pub state_space: String,
}
#[allow(dead_code)]
impl RiskSensitiveControl {
    /// Build a risk-averse controller (θ > 0) over the given horizon.
    ///
    /// Panics if `theta` is not strictly positive.
    pub fn risk_averse(theta: f64, horizon: f64) -> Self {
        assert!(theta > 0.0, "risk-averse requires θ > 0");
        Self {
            risk_parameter: theta,
            time_horizon: horizon,
            control_space: String::from("U"),
            state_space: String::from("R^n"),
        }
    }
    /// Human-readable exponential-of-integral cost criterion.
    pub fn exponential_criterion(&self) -> String {
        let theta = self.risk_parameter;
        format!(
            "Risk-sensitive criterion: J(u) = (1/{:.3}) log E[exp({:.3} ∫ r dt)]",
            theta, theta
        )
    }
    /// Risk-sensitive Hamilton–Jacobi–Bellman equation as a display string.
    pub fn risk_sensitive_hjb(&self) -> String {
        format!(
            "RS-HJB: 0 = ∂V/∂t + min_u[f·∇V + (1/2)tr(σσ^T ∇²V) + r + ({:.3}/2)|σ^T∇V|²]",
            self.risk_parameter
        )
    }
    /// Certainty equivalent: mean cost plus (θ/2)·variance.
    pub fn certainty_equivalent(&self, expected_cost: f64, variance: f64) -> f64 {
        expected_cost + 0.5 * self.risk_parameter * variance
    }
    /// θ > 0 links risk-sensitive control to robust control.
    pub fn is_robust_control_connection(&self) -> bool {
        self.risk_parameter > 0.0
    }
}
129/// Mean field game solver via fixed-point iteration.
130///
131/// Iterates:
132///   1. Given population distribution μ, solve individual optimal control → policy π.
133///   2. Given policy π, simulate population dynamics → new distribution μ'.
134///   3. Repeat until ||μ' - μ|| < ε.
135#[derive(Debug, Clone)]
136pub struct MeanFieldGameSolver {
137    /// Number of states.
138    pub num_states: usize,
139    /// Number of actions.
140    pub num_actions: usize,
141    /// Transition kernel (state-dependent on population): T[s][a][s'].
142    pub transitions: Vec<Vec<Vec<f64>>>,
143    /// Reward (depends on state, action, mean population weight).
144    /// reward[s][a] — simplified to not depend on μ for tractability.
145    pub rewards: Vec<Vec<f64>>,
146    /// Discount factor.
147    pub discount: f64,
148    /// Convergence tolerance.
149    pub tol: f64,
150    /// Maximum iterations.
151    pub max_iter: usize,
152}
153impl MeanFieldGameSolver {
154    /// Construct a mean field game solver.
155    #[allow(clippy::too_many_arguments)]
156    pub fn new(
157        num_states: usize,
158        num_actions: usize,
159        transitions: Vec<Vec<Vec<f64>>>,
160        rewards: Vec<Vec<f64>>,
161        discount: f64,
162        tol: f64,
163        max_iter: usize,
164    ) -> Self {
165        Self {
166            num_states,
167            num_actions,
168            transitions,
169            rewards,
170            discount,
171            tol,
172            max_iter,
173        }
174    }
175    /// Solve the individual MDP given current population distribution (fixed μ).
176    /// Returns the greedy policy as a state→action table.
177    fn solve_individual(&self) -> Vec<usize> {
178        let mdp = MDP::new(
179            self.num_states,
180            self.num_actions,
181            self.transitions.clone(),
182            self.rewards.clone(),
183            self.discount,
184        );
185        let v = mdp.value_iteration(self.tol * 0.01, self.max_iter);
186        mdp.policy_improvement(&v)
187    }
188    /// Compute the stationary distribution of a Markov chain induced by policy π.
189    fn stationary_distribution(&self, policy: &[usize]) -> Vec<f64> {
190        let mut mu = vec![1.0_f64 / self.num_states as f64; self.num_states];
191        for _ in 0..self.max_iter {
192            let mut new_mu = vec![0.0_f64; self.num_states];
193            for s in 0..self.num_states {
194                let a = policy[s];
195                for sp in 0..self.num_states {
196                    new_mu[sp] += mu[s] * self.transitions[s][a][sp];
197                }
198            }
199            let delta: f64 = mu
200                .iter()
201                .zip(new_mu.iter())
202                .map(|(a, b)| (a - b).abs())
203                .fold(0.0_f64, f64::max);
204            mu = new_mu;
205            if delta < self.tol {
206                break;
207            }
208        }
209        mu
210    }
211    /// Run the mean field game fixed-point iteration.
212    /// Returns (equilibrium_policy, equilibrium_distribution).
213    pub fn solve(&self) -> (Vec<usize>, Vec<f64>) {
214        let policy = self.solve_individual();
215        let mu = self.stationary_distribution(&policy);
216        (policy, mu)
217    }
218}
/// Numerical time-stepping schemes for simulating SDE sample paths.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub enum NumericalSDEScheme {
    /// Euler–Maruyama discretisation.
    EulerMaruyama,
    /// Milstein scheme (Euler–Maruyama plus the Itô correction term).
    Milstein,
    /// Stochastic Runge–Kutta-type scheme.
    RungeKuttaSDE,
    /// Truncated stochastic Taylor expansion scheme.
    StochasticTaylor,
}
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct SDGame {
    pub player1_strategy: String,
    pub player2_strategy: String,
    pub value_function: Option<f64>,
    pub is_zero_sum: bool,
    pub horizon: f64,
}
#[allow(dead_code)]
impl SDGame {
    /// Two-player zero-sum game: player 1 minimises, player 2 maximises.
    pub fn zero_sum(horizon: f64) -> Self {
        Self {
            player1_strategy: String::from("min"),
            player2_strategy: String::from("max"),
            value_function: None,
            is_zero_sum: true,
            horizon,
        }
    }
    /// Cooperative game: both players share the same objective.
    pub fn cooperative(horizon: f64) -> Self {
        Self {
            player1_strategy: String::from("cooperative"),
            player2_strategy: String::from("cooperative"),
            value_function: None,
            is_zero_sum: false,
            horizon,
        }
    }
    /// Display the PDE characterising the game's value function.
    pub fn isaacs_equation(&self) -> String {
        if !self.is_zero_sum {
            return format!("Nash system for cooperative game on [0,{}]", self.horizon);
        }
        format!(
            "Isaacs equation for zero-sum game on [0,{}]: -∂V/∂t = min_u max_v H(x,u,v,∇V)",
            self.horizon
        )
    }
    /// A saddle point is asserted only for the zero-sum case.
    pub fn saddle_point_exists(&self) -> bool {
        self.is_zero_sum
    }
    /// Record a computed game value.
    pub fn set_value(&mut self, val: f64) {
        self.value_function = Some(val);
    }
}
/// Zero-sum stochastic differential game.
///
/// Two players minimise/maximise J = E[∫ L(x,u,v) dt + g(x_T)].
/// Here we represent the saddle value and equilibrium strategies as numeric tables.
#[derive(Debug, Clone)]
pub struct ZeroSumSDG {
    /// Value function table: `value[s]` = saddle value from state s.
    pub value: Vec<f64>,
    /// Minimising player strategy: `min_strategy[s]` = optimal action index.
    pub min_strategy: Vec<usize>,
    /// Maximising player strategy: `max_strategy[s]` = optimal action index.
    pub max_strategy: Vec<usize>,
}
impl ZeroSumSDG {
    /// Construct a zero-sum SDG from pre-computed value and strategies.
    pub fn new(value: Vec<f64>, min_strategy: Vec<usize>, max_strategy: Vec<usize>) -> Self {
        ZeroSumSDG {
            value,
            min_strategy,
            max_strategy,
        }
    }
    /// Return the saddle value from state `s`.
    ///
    /// Panics if `s` is out of bounds for the value table.
    pub fn saddle_value(&self, s: usize) -> f64 {
        self.value[s]
    }
}
/// Exponential mean-square stability: E[||x_t||²] ≤ C e^{-λt}.
#[derive(Debug, Clone)]
pub struct ExponentialMSStability {
    /// Upper bound constant C.
    pub c: f64,
    /// Decay rate λ > 0.
    pub lambda: f64,
}
impl ExponentialMSStability {
    /// Construct an exponential MS stability bound.
    pub fn new(c: f64, lambda: f64) -> Self {
        ExponentialMSStability { c, lambda }
    }
    /// Evaluate the envelope C e^{-λt} at time `t`.
    pub fn bound(&self, t: f64) -> f64 {
        let decay = (-self.lambda * t).exp();
        self.c * decay
    }
    /// True when a measured E[||x_t||²] value lies under the envelope at time `t`.
    pub fn check(&self, ms_value: f64, t: f64) -> bool {
        ms_value <= self.bound(t)
    }
}
322/// Standalone discounted MDP value iteration solver.
323///
324/// More flexible than `MDP::value_iteration` — accepts sparse transition
325/// representations via closures (approximated here via table).
326#[derive(Debug, Clone)]
327pub struct ValueIteration {
328    /// Number of states.
329    pub num_states: usize,
330    /// Number of actions.
331    pub num_actions: usize,
332    /// Transition probabilities T[s][a][s'].
333    pub transitions: Vec<Vec<Vec<f64>>>,
334    /// Reward R[s][a].
335    pub rewards: Vec<Vec<f64>>,
336    /// Discount factor γ.
337    pub discount: f64,
338}
339impl ValueIteration {
340    /// Construct a value iteration solver.
341    pub fn new(
342        num_states: usize,
343        num_actions: usize,
344        transitions: Vec<Vec<Vec<f64>>>,
345        rewards: Vec<Vec<f64>>,
346        discount: f64,
347    ) -> Self {
348        Self {
349            num_states,
350            num_actions,
351            transitions,
352            rewards,
353            discount,
354        }
355    }
356    /// Run value iteration and return (V*, π*).
357    pub fn run(&self, tol: f64, max_iter: usize) -> (Vec<f64>, Vec<usize>) {
358        let mdp = MDP::new(
359            self.num_states,
360            self.num_actions,
361            self.transitions.clone(),
362            self.rewards.clone(),
363            self.discount,
364        );
365        let v = mdp.value_iteration(tol, max_iter);
366        let pi = mdp.policy_improvement(&v);
367        (v, pi)
368    }
369    /// Compute the Q-function from the optimal value function.
370    pub fn q_from_v(&self, v: &[f64]) -> Vec<Vec<f64>> {
371        let mut q = vec![vec![0.0_f64; self.num_actions]; self.num_states];
372        for s in 0..self.num_states {
373            for a in 0..self.num_actions {
374                let mut qa = self.rewards[s][a];
375                for sp in 0..self.num_states {
376                    qa += self.discount * self.transitions[s][a][sp] * v[sp];
377                }
378                q[s][a] = qa;
379            }
380        }
381        q
382    }
383    /// Span semi-norm: measure of Bellman residual (for convergence diagnostics).
384    pub fn span(&self, v: &[f64]) -> f64 {
385        let tv_v = MDP::new(
386            self.num_states,
387            self.num_actions,
388            self.transitions.clone(),
389            self.rewards.clone(),
390            self.discount,
391        )
392        .bellman_operator(v);
393        tv_v.iter()
394            .zip(v.iter())
395            .map(|(tv, vv)| (tv - vv).abs())
396            .fold(0.0_f64, f64::max)
397    }
398}
/// State-action value function Q^π(s,a).
#[derive(Debug, Clone)]
pub struct ActionValueFunction {
    /// `q[s][a]` = Q(s, a).
    pub q: Vec<Vec<f64>>,
}
impl ActionValueFunction {
    /// Construct from a table.
    pub fn new(q: Vec<Vec<f64>>) -> Self {
        ActionValueFunction { q }
    }
    /// Return Q(s, a). Panics on out-of-bounds indices.
    pub fn get(&self, s: usize, a: usize) -> f64 {
        self.q[s][a]
    }
    /// Return max_a Q(s, a) (NEG_INFINITY for an empty action row).
    pub fn max_action_value(&self, s: usize) -> f64 {
        let mut best = f64::NEG_INFINITY;
        for &qv in &self.q[s] {
            best = best.max(qv);
        }
        best
    }
    /// Return argmax_a Q(s, a); on ties the last maximal index wins.
    /// Returns 0 for an empty action row.
    pub fn greedy_action(&self, s: usize) -> usize {
        let row = &self.q[s];
        row.iter()
            .enumerate()
            .max_by(|lhs, rhs| lhs.1.partial_cmp(rhs.1).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(idx, _)| idx)
    }
}
/// A finite Markov Decision Process (S, A, P, R, γ).
///
/// - `num_states`: |S|
/// - `num_actions`: |A|
/// - `transitions[s][a][s']`: P(s' | s, a)
/// - `rewards[s][a]`: R(s, a)
/// - `discount`: γ ∈ [0,1)
#[derive(Debug, Clone)]
pub struct MDP {
    /// Number of states.
    pub num_states: usize,
    /// Number of actions.
    pub num_actions: usize,
    /// Transition probabilities: `transitions[s][a]` is a probability vector over next states.
    pub transitions: Vec<Vec<Vec<f64>>>,
    /// Expected reward: `rewards[s][a]`.
    pub rewards: Vec<Vec<f64>>,
    /// Discount factor γ ∈ [0,1).
    pub discount: f64,
}
impl MDP {
    /// Construct a new MDP.
    pub fn new(
        num_states: usize,
        num_actions: usize,
        transitions: Vec<Vec<Vec<f64>>>,
        rewards: Vec<Vec<f64>>,
        discount: f64,
    ) -> Self {
        MDP {
            num_states,
            num_actions,
            transitions,
            rewards,
            discount,
        }
    }
    /// One-step backup Q(s,a) = R(s,a) + γ Σ_{s'} P(s'|s,a) V(s').
    fn backup(&self, s: usize, a: usize, v: &[f64]) -> f64 {
        // fold starts from the reward and accumulates in s' order, matching
        // the accumulation order used throughout this module.
        (0..self.num_states).fold(self.rewards[s][a], |acc, sp| {
            acc + self.discount * self.transitions[s][a][sp] * v[sp]
        })
    }
    /// Sup-norm distance between two value vectors.
    fn sup_distance(a: &[f64], b: &[f64]) -> f64 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).abs())
            .fold(0.0_f64, f64::max)
    }
    /// Apply the Bellman operator: (TV)(s) = max_a [R(s,a) + γ Σ P(s'|s,a) V(s')].
    pub fn bellman_operator(&self, v: &[f64]) -> Vec<f64> {
        (0..self.num_states)
            .map(|s| {
                let mut best = f64::NEG_INFINITY;
                for a in 0..self.num_actions {
                    let q = self.backup(s, a, v);
                    if q > best {
                        best = q;
                    }
                }
                best
            })
            .collect()
    }
    /// Value iteration: iterate the Bellman operator until the sup-norm
    /// change falls below `tol` or `max_iter` sweeps have run.
    pub fn value_iteration(&self, tol: f64, max_iter: usize) -> Vec<f64> {
        let mut v = vec![0.0_f64; self.num_states];
        for _ in 0..max_iter {
            let backed_up = self.bellman_operator(&v);
            let delta = Self::sup_distance(&v, &backed_up);
            v = backed_up;
            if delta < tol {
                break;
            }
        }
        v
    }
    /// Extract the greedy policy from a value function.
    pub fn policy_improvement(&self, v: &[f64]) -> Vec<usize> {
        (0..self.num_states)
            .map(|s| {
                let mut best_a = 0;
                let mut best_q = f64::NEG_INFINITY;
                for a in 0..self.num_actions {
                    let q = self.backup(s, a, v);
                    // strict '>' keeps the lowest-index action on ties
                    if q > best_q {
                        best_q = q;
                        best_a = a;
                    }
                }
                best_a
            })
            .collect()
    }
    /// Policy evaluation: compute V^π for a deterministic policy using iterative updates.
    pub fn policy_evaluation(&self, policy: &[usize], tol: f64, max_iter: usize) -> Vec<f64> {
        let mut v = vec![0.0_f64; self.num_states];
        for _ in 0..max_iter {
            let next: Vec<f64> = (0..self.num_states)
                .map(|s| self.backup(s, policy[s], &v))
                .collect();
            let delta = Self::sup_distance(&v, &next);
            v = next;
            if delta < tol {
                break;
            }
        }
        v
    }
}
/// Q-learning agent.
///
/// Off-policy TD control: Q(s,a) ← Q(s,a) + α(r + γ max_a' Q(s',a') − Q(s,a)).
#[derive(Debug, Clone)]
pub struct QLearning {
    /// Q-value table: `q[s][a]`.
    pub q: Vec<Vec<f64>>,
    /// Learning rate α ∈ (0,1].
    pub alpha: f64,
    /// Discount factor γ ∈ [0,1).
    pub gamma: f64,
}
impl QLearning {
    /// Construct a Q-learning agent with zero-initialised Q-table.
    pub fn new(num_states: usize, num_actions: usize, alpha: f64, gamma: f64) -> Self {
        QLearning {
            q: vec![vec![0.0_f64; num_actions]; num_states],
            alpha,
            gamma,
        }
    }
    /// Perform a single Q-learning update toward the target r + γ max_a' Q(s',a').
    pub fn update(&mut self, s: usize, a: usize, r: f64, s_next: usize) {
        let mut best_next = f64::NEG_INFINITY;
        for &qv in &self.q[s_next] {
            best_next = best_next.max(qv);
        }
        let td_error = r + self.gamma * best_next - self.q[s][a];
        self.q[s][a] += self.alpha * td_error;
    }
    /// Return the greedy action in state `s`; on ties the last maximal index wins.
    pub fn greedy_action(&self, s: usize) -> usize {
        self.q[s]
            .iter()
            .enumerate()
            .max_by(|lhs, rhs| lhs.1.partial_cmp(rhs.1).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(idx, _)| idx)
    }
    /// Greedy state value max_a Q(s,a) under the current Q-table.
    pub fn expected_return(&self, s: usize) -> f64 {
        let mut best = f64::NEG_INFINITY;
        for &qv in &self.q[s] {
            best = best.max(qv);
        }
        best
    }
}
590/// Actor-Critic agent: combined policy gradient (actor) + TD value function (critic).
591#[derive(Debug, Clone)]
592pub struct ActorCritic {
593    /// Actor: policy gradient component.
594    pub actor: PolicyGradient,
595    /// Critic: value function V(s) estimate.
596    pub critic: Vec<f64>,
597    /// Critic learning rate.
598    pub critic_alpha: f64,
599}
600impl ActorCritic {
601    /// Construct an actor-critic agent.
602    pub fn new(
603        num_states: usize,
604        num_actions: usize,
605        actor_alpha: f64,
606        critic_alpha: f64,
607        gamma: f64,
608    ) -> Self {
609        Self {
610            actor: PolicyGradient::new(num_states, num_actions, actor_alpha, gamma),
611            critic: vec![0.0_f64; num_states],
612            critic_alpha,
613        }
614    }
615    /// Perform a single actor-critic update given (s, a, r, s').
616    pub fn update(&mut self, s: usize, a: usize, r: f64, s_next: usize) {
617        let td_error = r + self.actor.gamma * self.critic[s_next] - self.critic[s];
618        self.critic[s] += self.critic_alpha * td_error;
619        let pi = self.actor.softmax(s);
620        let num_actions = self.actor.theta[s].len();
621        for b in 0..num_actions {
622            let indicator = if b == a { 1.0 } else { 0.0 };
623            let grad_log = indicator - pi[b];
624            self.actor.theta[s][b] += self.actor.alpha * grad_log * td_error;
625        }
626    }
627    /// Return the estimated state value V(s).
628    pub fn expected_return(&self, s: usize) -> f64 {
629        self.critic[s]
630    }
631}
/// Pursuit-evasion game: pursuer minimises time-to-capture, evader maximises it.
#[derive(Debug, Clone)]
pub struct PursuitEvasionGame {
    /// Position of pursuer (2D).
    pub pursuer: [f64; 2],
    /// Position of evader (2D).
    pub evader: [f64; 2],
    /// Pursuer speed.
    pub pursuer_speed: f64,
    /// Evader speed.
    pub evader_speed: f64,
}
impl PursuitEvasionGame {
    /// Construct a pursuit-evasion game.
    pub fn new(pursuer: [f64; 2], evader: [f64; 2], pursuer_speed: f64, evader_speed: f64) -> Self {
        PursuitEvasionGame {
            pursuer,
            evader,
            pursuer_speed,
            evader_speed,
        }
    }
    /// Euclidean distance between pursuer and evader.
    pub fn distance(&self) -> f64 {
        let run = self.pursuer[0] - self.evader[0];
        let rise = self.pursuer[1] - self.evader[1];
        (run * run + rise * rise).sqrt()
    }
    /// Returns true if the pursuer can eventually catch the evader (pursuer speed > evader speed).
    pub fn pursuer_wins(&self) -> bool {
        self.pursuer_speed > self.evader_speed
    }
    /// Capture time estimate: distance divided by closing speed
    /// (infinite when the pursuer is no faster than the evader).
    pub fn capture_time_estimate(&self) -> f64 {
        let closing_speed = self.pursuer_speed - self.evader_speed;
        if closing_speed > 0.0 {
            self.distance() / closing_speed
        } else {
            f64::INFINITY
        }
    }
}
/// Mean-square stability checker: verifies E[||x_t||²] → 0.
#[derive(Debug, Clone)]
pub struct MeanSquareStability {
    /// Sequence of E[||x_t||²] samples.
    pub ms_samples: Vec<f64>,
}
impl MeanSquareStability {
    /// Construct from a sequence of mean-square values.
    pub fn new(ms_samples: Vec<f64>) -> Self {
        MeanSquareStability { ms_samples }
    }
    /// Check if the sequence is non-increasing (necessary but not sufficient for
    /// stability). Vacuously true for sequences of length < 2.
    pub fn is_non_increasing(&self) -> bool {
        self.ms_samples
            .iter()
            .zip(self.ms_samples.iter().skip(1))
            .all(|(cur, next)| cur >= next)
    }
    /// Check if the last sample is below `tol` (approximate convergence to 0).
    /// False for an empty sequence.
    pub fn has_converged(&self, tol: f64) -> bool {
        match self.ms_samples.last() {
            Some(&last) => last < tol,
            None => false,
        }
    }
}
/// POMDP belief state update (Bayes filter).
///
/// Given belief b ∈ Δ(S), action a, observation o,
/// computes b'(s') ∝ Z(o|s', a) · Σ_s T(s'|s, a) · b(s).
#[derive(Debug, Clone)]
pub struct BeliefMDP {
    /// Number of states.
    pub num_states: usize,
    /// Number of actions.
    pub num_actions: usize,
    /// Number of observations.
    pub num_obs: usize,
    /// Transition probabilities T[s][a][s'].
    pub transitions: Vec<Vec<Vec<f64>>>,
    /// Observation probabilities Z[s'][a][o].
    pub observations: Vec<Vec<Vec<f64>>>,
}
impl BeliefMDP {
    /// Construct a BeliefMDP.
    pub fn new(
        num_states: usize,
        num_actions: usize,
        num_obs: usize,
        transitions: Vec<Vec<Vec<f64>>>,
        observations: Vec<Vec<Vec<f64>>>,
    ) -> Self {
        BeliefMDP {
            num_states,
            num_actions,
            num_obs,
            transitions,
            observations,
        }
    }
    /// Perform a Bayesian belief update.
    ///
    /// b'(s') ∝ Z(o | s', a) · Σ_s T(s' | s, a) · b(s)
    pub fn belief_update(&self, belief: &[f64], action: usize, obs: usize) -> Vec<f64> {
        // Prediction step followed by observation correction, unnormalised.
        let mut posterior: Vec<f64> = (0..self.num_states)
            .map(|sp| {
                let predicted: f64 = (0..self.num_states)
                    .map(|s| self.transitions[s][action][sp] * belief[s])
                    .sum();
                self.observations[sp][action][obs] * predicted
            })
            .collect();
        // Normalise to a probability vector; an all-zero posterior
        // (impossible observation) is returned unnormalised.
        let mass: f64 = posterior.iter().sum();
        if mass > 0.0 {
            posterior.iter_mut().for_each(|p| *p /= mass);
        }
        posterior
    }
    /// QMDP approximation: V(b) ≈ max_a Σ_s b(s) · Q*(s, a).
    pub fn qmdp_value(&self, belief: &[f64], q_star: &[Vec<f64>]) -> f64 {
        let mut best = f64::NEG_INFINITY;
        for a in 0..self.num_actions {
            let expected: f64 = belief
                .iter()
                .enumerate()
                .map(|(s, &bs)| bs * q_star[s][a])
                .sum();
            best = best.max(expected);
        }
        best
    }
}
/// Algebraic Riccati equation solver for LQR/LQG.
///
/// Solves A^T P + P A - P B R^{-1} B^T P + Q = 0 iteratively (Euler integration).
#[derive(Debug, Clone)]
pub struct RiccatiEquation {
    /// System matrix A (n×n).
    pub a: Vec<Vec<f64>>,
    /// Input matrix B (n×m).
    pub b: Vec<Vec<f64>>,
    /// State cost matrix Q (n×n), positive semidefinite.
    pub q_cost: Vec<Vec<f64>>,
    /// Input cost matrix R (m×m), positive definite.
    pub r_cost: Vec<Vec<f64>>,
}
impl RiccatiEquation {
    /// Construct a Riccati solver.
    pub fn new(
        a: Vec<Vec<f64>>,
        b: Vec<Vec<f64>>,
        q_cost: Vec<Vec<f64>>,
        r_cost: Vec<Vec<f64>>,
    ) -> Self {
        Self {
            a,
            b,
            q_cost,
            r_cost,
        }
    }
    /// Dense matrix product `m1 · m2` (row-major nested-Vec representation).
    fn mat_mul(m1: &[Vec<f64>], m2: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let r1 = m1.len();
        let c2 = m2[0].len();
        let inner = m2.len();
        let mut out = vec![vec![0.0_f64; c2]; r1];
        for i in 0..r1 {
            for j in 0..c2 {
                for k in 0..inner {
                    out[i][j] += m1[i][k] * m2[k][j];
                }
            }
        }
        out
    }
    /// Matrix transpose.
    fn mat_transpose(m: &[Vec<f64>]) -> Vec<Vec<f64>> {
        if m.is_empty() {
            return vec![];
        }
        let rows = m.len();
        let cols = m[0].len();
        let mut out = vec![vec![0.0_f64; rows]; cols];
        for i in 0..rows {
            for j in 0..cols {
                out[j][i] = m[i][j];
            }
        }
        out
    }
    /// Element-wise matrix sum.
    fn mat_add(m1: &[Vec<f64>], m2: &[Vec<f64>]) -> Vec<Vec<f64>> {
        m1.iter()
            .zip(m2.iter())
            .map(|(r1, r2)| r1.iter().zip(r2.iter()).map(|(a, b)| a + b).collect())
            .collect()
    }
    /// Element-wise matrix difference.
    fn mat_sub(m1: &[Vec<f64>], m2: &[Vec<f64>]) -> Vec<Vec<f64>> {
        m1.iter()
            .zip(m2.iter())
            .map(|(r1, r2)| r1.iter().zip(r2.iter()).map(|(a, b)| a - b).collect())
            .collect()
    }
    /// Scalar multiple of a matrix.
    fn mat_scale(m: &[Vec<f64>], s: f64) -> Vec<Vec<f64>> {
        m.iter()
            .map(|r| r.iter().map(|x| x * s).collect())
            .collect()
    }
    /// Invert a square matrix.
    ///
    /// 1×1 and 2×2 matrices use closed forms; larger matrices are inverted via
    /// Gauss–Jordan elimination with partial pivoting. (The previous version
    /// silently returned the identity for n ≥ 3, which produced wrong Riccati
    /// iterates for systems with more than two inputs.) A numerically singular
    /// matrix still falls back to the identity, matching the old behaviour.
    fn mat_inv_small(m: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let n = m.len();
        if n == 1 {
            return vec![vec![1.0 / m[0][0]]];
        }
        if n == 2 {
            let det = m[0][0] * m[1][1] - m[0][1] * m[1][0];
            return vec![
                vec![m[1][1] / det, -m[0][1] / det],
                vec![-m[1][0] / det, m[0][0] / det],
            ];
        }
        // Augment [M | I] and reduce M to the identity in place.
        let mut aug: Vec<Vec<f64>> = m
            .iter()
            .enumerate()
            .map(|(i, row)| {
                let mut r = row.clone();
                r.extend((0..n).map(|j| if i == j { 1.0 } else { 0.0 }));
                r
            })
            .collect();
        for col in 0..n {
            // Partial pivoting: bring the largest |entry| of this column up.
            let pivot_row = (col..n)
                .max_by(|&i, &j| {
                    aug[i][col]
                        .abs()
                        .partial_cmp(&aug[j][col].abs())
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
                .unwrap_or(col);
            aug.swap(col, pivot_row);
            let pivot = aug[col][col];
            if pivot.abs() < 1e-12 {
                // Singular (or nearly so): fall back to identity like the old code.
                let mut eye = vec![vec![0.0_f64; n]; n];
                for (i, row) in eye.iter_mut().enumerate() {
                    row[i] = 1.0;
                }
                return eye;
            }
            for j in 0..2 * n {
                aug[col][j] /= pivot;
            }
            for row in 0..n {
                if row != col {
                    let factor = aug[row][col];
                    for j in 0..2 * n {
                        aug[row][j] -= factor * aug[col][j];
                    }
                }
            }
        }
        // Right half of the augmented matrix is now M^{-1}.
        aug.into_iter().map(|mut row| row.split_off(n)).collect()
    }
    /// Compute the Riccati derivative dP/dt = A^T P + P A - P B R^{-1} B^T P + Q.
    fn riccati_deriv(&self, p: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let at = Self::mat_transpose(&self.a);
        let bt = Self::mat_transpose(&self.b);
        let r_inv = Self::mat_inv_small(&self.r_cost);
        let at_p = Self::mat_mul(&at, p);
        let p_a = Self::mat_mul(p, &self.a);
        let p_b = Self::mat_mul(p, &self.b);
        let p_b_rinv = Self::mat_mul(&p_b, &r_inv);
        let p_b_rinv_bt = Self::mat_mul(&p_b_rinv, &bt);
        let p_b_rinv_bt_p = Self::mat_mul(&p_b_rinv_bt, p);
        let sum = Self::mat_add(&at_p, &p_a);
        let sum2 = Self::mat_add(&sum, &self.q_cost);
        Self::mat_sub(&sum2, &p_b_rinv_bt_p)
    }
    /// Solve the algebraic Riccati equation by integrating dP/dt forward with
    /// explicit Euler steps of size `dt` from P = 0 until `max_iter` steps.
    pub fn solve_riccati(&self, dt: f64, max_iter: usize) -> Vec<Vec<f64>> {
        let n = self.a.len();
        let mut p = vec![vec![0.0_f64; n]; n];
        for _ in 0..max_iter {
            let dp = self.riccati_deriv(&p);
            let update = Self::mat_scale(&dp, dt);
            p = Self::mat_add(&p, &update);
        }
        p
    }
    /// Compute the optimal LQR gain matrix K* = R^{-1} B^T P.
    pub fn optimal_gain_matrix(&self, p: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let bt = Self::mat_transpose(&self.b);
        let r_inv = Self::mat_inv_small(&self.r_cost);
        let bt_p = Self::mat_mul(&bt, p);
        Self::mat_mul(&r_inv, &bt_p)
    }
    /// Solve the infinite-horizon LQR: return (P, K) where u = -Kx is the
    /// optimal feedback law.
    pub fn infinite_horizon_lqr(&self) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
        let p = self.solve_riccati(0.001, 10_000);
        let k = self.optimal_gain_matrix(&p);
        (p, k)
    }
}
894#[allow(dead_code)]
895#[derive(Debug, Clone)]
896pub struct MeanFieldGame {
897    pub num_players: usize,
898    pub coupling_strength: f64,
899    pub mean_field_type: MFGType,
900    pub convergence_rate: f64,
901}
902#[allow(dead_code)]
903impl MeanFieldGame {
904    pub fn new(players: usize, coupling: f64) -> Self {
905        MeanFieldGame {
906            num_players: players,
907            coupling_strength: coupling,
908            mean_field_type: MFGType::LasryLions,
909            convergence_rate: 1.0 / (players as f64).sqrt(),
910        }
911    }
912    pub fn mfg_system_description(&self) -> String {
913        format!(
914            "MFG ({} players, coupling={:.3}): HJB + FP system, rate O(1/√N)",
915            self.num_players, self.coupling_strength
916        )
917    }
918    pub fn price_of_anarchy(&self) -> f64 {
919        1.0 + self.coupling_strength * 0.5
920    }
921    pub fn master_equation(&self) -> String {
922        "∂_t U + H(x, m, ∇_x U) - ν Δ_x U = ∫ (∂_m U)(y) δ_y F(y,m) m(dy)".to_string()
923    }
924}
925#[allow(dead_code)]
926#[derive(Debug, Clone)]
927pub struct PathwiseSDE {
928    pub drift: String,
929    pub diffusion: String,
930    pub initial_condition: f64,
931    pub time_steps: usize,
932    pub step_size: f64,
933    pub scheme: NumericalSDEScheme,
934}
935#[allow(dead_code)]
936impl PathwiseSDE {
937    pub fn euler_maruyama(drift: &str, diffusion: &str, x0: f64, steps: usize, dt: f64) -> Self {
938        PathwiseSDE {
939            drift: drift.to_string(),
940            diffusion: diffusion.to_string(),
941            initial_condition: x0,
942            time_steps: steps,
943            step_size: dt,
944            scheme: NumericalSDEScheme::EulerMaruyama,
945        }
946    }
947    pub fn milstein(drift: &str, diffusion: &str, x0: f64, steps: usize, dt: f64) -> Self {
948        PathwiseSDE {
949            drift: drift.to_string(),
950            diffusion: diffusion.to_string(),
951            initial_condition: x0,
952            time_steps: steps,
953            step_size: dt,
954            scheme: NumericalSDEScheme::Milstein,
955        }
956    }
957    pub fn strong_order(&self) -> f64 {
958        match &self.scheme {
959            NumericalSDEScheme::EulerMaruyama => 0.5,
960            NumericalSDEScheme::Milstein => 1.0,
961            NumericalSDEScheme::RungeKuttaSDE => 1.5,
962            NumericalSDEScheme::StochasticTaylor => 2.0,
963        }
964    }
965    pub fn weak_order(&self) -> f64 {
966        match &self.scheme {
967            NumericalSDEScheme::EulerMaruyama => 1.0,
968            NumericalSDEScheme::Milstein => 1.0,
969            NumericalSDEScheme::RungeKuttaSDE => 2.0,
970            NumericalSDEScheme::StochasticTaylor => 2.0,
971        }
972    }
973    pub fn simulate_one_path(&self) -> Vec<f64> {
974        let mut path = vec![self.initial_condition];
975        let mut x = self.initial_condition;
976        for _ in 0..self.time_steps {
977            let dw = 0.0;
978            x += self.step_size * 0.5 + 0.3 * dw;
979            path.push(x);
980        }
981        path
982    }
983}
/// Stochastic Lyapunov stability checker.
///
/// Checks the Foster-Lyapunov drift condition LV(x) ≤ -α V(x) + β, where L is
/// the generator of the process and V a candidate Lyapunov function.
#[derive(Debug, Clone)]
pub struct StochasticLyapunov {
    /// Decay rate α ≥ 0 (α > 0 gives an exponentially contracting bound).
    pub alpha: f64,
    /// Additive drift β ≥ 0 (steady-state offset β/α when α > 0).
    pub beta: f64,
}
impl StochasticLyapunov {
    /// Construct a stochastic Lyapunov condition checker.
    pub fn new(alpha: f64, beta: f64) -> Self {
        Self { alpha, beta }
    }
    /// Check if LV(x) ≤ -α V(x) + β holds for a given LV(x) and V(x).
    pub fn check(&self, lv: f64, v: f64) -> bool {
        lv <= -self.alpha * v + self.beta
    }
    /// Upper bound on E[V(x_t)] given E[V(x_0)] = v0 (Foster-Lyapunov /
    /// Gronwall): v0·e^{-αt} + (β/α)(1 − e^{-αt}).
    ///
    /// For α = 0 the β/α form is indeterminate (previously produced NaN/inf);
    /// we return the exact α → 0 limit v0 + β·t instead.
    pub fn ev_upper_bound(&self, v0: f64, t: f64) -> f64 {
        if self.alpha == 0.0 {
            // lim_{α→0} [v0 e^{-αt} + (β/α)(1 − e^{-αt})] = v0 + β t.
            return v0 + self.beta * t;
        }
        let decay = (-self.alpha * t).exp();
        v0 * decay + (self.beta / self.alpha) * (1.0 - decay)
    }
}
/// Mean-field game formulation variants.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub enum MFGType {
    /// Classical Lasry-Lions mean-field game.
    LasryLions,
    /// Mean-field-type control variant.
    MFGControl,
    /// Extended mean-field game variant.
    ExtendedMFG,
}
1016/// Policy gradient agent (softmax parameterisation).
1017///
1018/// Policy: π_θ(a|s) = exp(θ[s][a]) / Σ exp(θ[s][a'])
1019/// Update: θ[s][a] += α · ∇_θ log π_θ(a|s) · G where G is the return.
1020#[derive(Debug, Clone)]
1021pub struct PolicyGradient {
1022    /// Policy parameter table θ[s][a].
1023    pub theta: Vec<Vec<f64>>,
1024    /// Learning rate α.
1025    pub alpha: f64,
1026    /// Discount factor γ.
1027    pub gamma: f64,
1028}
1029impl PolicyGradient {
1030    /// Construct a policy gradient agent.
1031    pub fn new(num_states: usize, num_actions: usize, alpha: f64, gamma: f64) -> Self {
1032        Self {
1033            theta: vec![vec![0.0_f64; num_actions]; num_states],
1034            alpha,
1035            gamma,
1036        }
1037    }
1038    /// Compute the softmax policy π_θ(·|s).
1039    pub fn softmax(&self, s: usize) -> Vec<f64> {
1040        let row = &self.theta[s];
1041        let max_val = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1042        let exps: Vec<f64> = row.iter().map(|&x| (x - max_val).exp()).collect();
1043        let sum: f64 = exps.iter().sum();
1044        exps.iter().map(|&e| e / sum).collect()
1045    }
1046    /// Update θ using a single (s, a, G) sample from a trajectory.
1047    pub fn update(&mut self, s: usize, a: usize, g: f64) {
1048        let pi = self.softmax(s);
1049        let num_actions = self.theta[s].len();
1050        for b in 0..num_actions {
1051            let indicator = if b == a { 1.0 } else { 0.0 };
1052            let grad_log = indicator - pi[b];
1053            self.theta[s][b] += self.alpha * grad_log * g;
1054        }
1055    }
1056    /// Expected return from state `s` as E_π[Q(s,a)].
1057    pub fn expected_return(&self, s: usize, q: &ActionValueFunction) -> f64 {
1058        let pi = self.softmax(s);
1059        pi.iter().enumerate().map(|(a, &p)| p * q.get(s, a)).sum()
1060    }
1061    /// Convergence rate estimate: max |∇J| across states (gradient norm).
1062    pub fn convergence_rate(&self, q: &ActionValueFunction) -> f64 {
1063        let mut max_grad = 0.0_f64;
1064        for s in 0..self.theta.len() {
1065            let pi = self.softmax(s);
1066            for a in 0..self.theta[s].len() {
1067                let grad = pi
1068                    .iter()
1069                    .enumerate()
1070                    .map(|(b, &pb)| {
1071                        let indicator = if b == a { 1.0 } else { 0.0 };
1072                        (indicator - pb) * q.get(s, a)
1073                    })
1074                    .sum::<f64>()
1075                    .abs();
1076                if grad > max_grad {
1077                    max_grad = grad;
1078                }
1079            }
1080        }
1081        max_grad
1082    }
1083}
/// Q-learning solver with epsilon-greedy exploration and decaying step size.
#[derive(Debug, Clone)]
pub struct QLearningSolver {
    /// Q-value table Q[s][a].
    pub q: Vec<Vec<f64>>,
    /// Base learning rate α (scaled down as visits accumulate).
    pub alpha: f64,
    /// Discount factor γ.
    pub gamma: f64,
    /// Exploration rate ε.
    pub epsilon: f64,
    /// Visit counter n(s,a), used to decay the effective step size.
    pub visit_count: Vec<Vec<u64>>,
}
impl QLearningSolver {
    /// Construct a Q-learning solver with all Q-values and counters at zero.
    pub fn new(
        num_states: usize,
        num_actions: usize,
        alpha: f64,
        gamma: f64,
        epsilon: f64,
    ) -> Self {
        let q = vec![vec![0.0_f64; num_actions]; num_states];
        let visit_count = vec![vec![0_u64; num_actions]; num_states];
        Self {
            q,
            alpha,
            gamma,
            epsilon,
            visit_count,
        }
    }
    /// Q-learning update for the transition (s, a, r, s').
    ///
    /// Effective step size decays as α/√(1 + n(s,a)), where n(s,a) counts
    /// prior updates of this pair (note: square-root decay, not the harmonic
    /// 1/(1+n) schedule the old comment claimed).
    pub fn update(&mut self, s: usize, a: usize, r: f64, s_next: usize) {
        self.visit_count[s][a] += 1;
        let visits = self.visit_count[s][a] as f64;
        let step = self.alpha / (1.0 + visits).sqrt();
        // Bellman target: r + γ max_a' Q(s', a') (−∞ if the next row is empty).
        let best_next = self.q[s_next]
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);
        let td_error = r + self.gamma * best_next - self.q[s][a];
        self.q[s][a] += step * td_error;
    }
    /// Epsilon-greedy action selection with deterministic tie-breaking.
    ///
    /// `rng_val` ∈ [0,1) is a uniform random draw supplied by the caller; it
    /// is reused both for the explore/exploit decision and (rescaled) for the
    /// uniform action choice when exploring.
    pub fn select_action(&self, s: usize, rng_val: f64) -> usize {
        let num_actions = self.q[s].len();
        if rng_val < self.epsilon {
            // Rescale rng_val into [0,1) and map it onto an action index.
            ((rng_val / self.epsilon) * num_actions as f64) as usize % num_actions
        } else {
            Self::argmax(&self.q[s])
        }
    }
    /// Index of the maximal element (last maximal index on ties, matching
    /// `Iterator::max_by`); 0 for an empty slice.
    fn argmax(row: &[f64]) -> usize {
        row.iter()
            .enumerate()
            .max_by(|x, y| x.1.partial_cmp(y.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0)
    }
    /// Check convergence: |Q(s,a) − Q_prev(s,a)| < tol for every entry.
    pub fn has_converged(&self, prev_q: &[Vec<f64>], tol: f64) -> bool {
        self.q
            .iter()
            .zip(prev_q)
            .all(|(row, prev)| row.iter().zip(prev).all(|(a, b)| (a - b).abs() < tol))
    }
    /// Greedy policy induced by the current Q-table.
    pub fn greedy_policy(&self) -> Vec<usize> {
        self.q.iter().map(|row| Self::argmax(row)).collect()
    }
}
/// H∞ (minimax) control problem data: attenuate worst-case disturbances
/// below a prescribed gain level γ.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct HInfinityControl {
    /// Disturbance attenuation level γ.
    pub disturbance_attenuation: f64,
    /// State dimension n.
    pub state_dim: usize,
    /// Control input dimension m.
    pub control_dim: usize,
    /// Disturbance input dimension k.
    pub disturbance_dim: usize,
    /// Scalar surrogate for the game Riccati solution P, if computed.
    pub riccati_solution: Option<f64>,
}
#[allow(dead_code)]
impl HInfinityControl {
    /// Construct an H∞ problem with attenuation `gamma` and the given
    /// state/control/disturbance dimensions; no Riccati solution yet.
    pub fn new(gamma: f64, n: usize, m: usize, k: usize) -> Self {
        HInfinityControl {
            disturbance_attenuation: gamma,
            state_dim: n,
            control_dim: m,
            disturbance_dim: k,
            riccati_solution: None,
        }
    }
    /// Describe the minimax criterion min_u max_w ||z||² − γ²||w||².
    pub fn minimax_criterion(&self) -> String {
        format!(
            "H∞: min_u max_w ||z||² - {:.3}² ||w||² (disturbance attenuation γ={:.3})",
            self.disturbance_attenuation, self.disturbance_attenuation
        )
    }
    /// Describe the game-type algebraic Riccati equation for this γ.
    pub fn game_riccati_equation(&self) -> String {
        format!(
            "Game ARE: PA + A^TP - P(B B^T - (1/{:.3}²) B_w B_w^T)P + C^TC = 0",
            self.disturbance_attenuation
        )
    }
    /// Feasibility check: a Riccati solution has been stored and is positive.
    pub fn is_feasible(&self) -> bool {
        // `is_some_and` replaces `map_or(false, …)` (clippy: unnecessary_map_or),
        // matching the idiom already used elsewhere in this module.
        self.riccati_solution.is_some_and(|p| p > 0.0)
    }
}
/// Long-run average (ergodic) control problem: minimise the time-average cost
/// rather than a discounted objective.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct ErgodicControl {
    /// Discount rate (0 for the purely ergodic criterion).
    pub discount_rate: f64,
    /// Dimension of the state space.
    pub state_space_dim: usize,
    /// Long-run average cost λ*, once computed.
    pub long_run_cost: Option<f64>,
    /// Eigenvalue λ of the ergodic HJB, once computed.
    pub eigenvalue_lambda: Option<f64>,
}
#[allow(dead_code)]
impl ErgodicControl {
    /// Construct an ergodic control problem of state dimension `dim` with no
    /// solution data yet.
    pub fn new(dim: usize) -> Self {
        Self {
            discount_rate: 0.0,
            state_space_dim: dim,
            long_run_cost: None,
            eigenvalue_lambda: None,
        }
    }
    /// The ergodic Hamilton-Jacobi-Bellman equation (textual form).
    pub fn ergodic_hjb(&self) -> String {
        "λ + H(x, ∇V, ∇²V) = 0 (ergodic HJB: λ is long-run average cost)".to_string()
    }
    /// Statement of the turnpike property (textual form).
    pub fn turnpike_property(&self) -> String {
        "Turnpike: finite-horizon optimal trajectories spend most time near ergodic optimal"
            .to_string()
    }
    /// Record the ergodic eigenvalue λ; it doubles as the long-run average cost.
    pub fn set_eigenvalue(&mut self, lambda: f64) {
        self.eigenvalue_lambda = Some(lambda);
        self.long_run_cost = Some(lambda);
    }
    /// One-line description of the (λ*, V) problem to be solved.
    pub fn relative_value_function_description(&self) -> String {
        format!(
            "Ergodic control dim={}: solve (λ*, V) pair in ergodic HJB",
            self.state_space_dim
        )
    }
}
/// Almost-sure stability: checks ||x_t|| → 0 along sample paths.
#[derive(Debug, Clone)]
pub struct AlmostSureStability {
    /// Sample paths: each inner Vec is a trajectory of ||x_t||.
    pub paths: Vec<Vec<f64>>,
}
impl AlmostSureStability {
    /// Construct the checker from a set of sample paths.
    pub fn new(paths: Vec<Vec<f64>>) -> Self {
        Self { paths }
    }
    /// Fraction of paths whose final value lies strictly below `tol`.
    ///
    /// Empty paths count as not converged; an empty path set yields 0.0.
    pub fn convergence_fraction(&self, tol: f64) -> f64 {
        let total = self.paths.len();
        if total == 0 {
            return 0.0;
        }
        let converged = self
            .paths
            .iter()
            .filter_map(|p| p.last())
            .filter(|&&v| v < tol)
            .count();
        converged as f64 / total as f64
    }
}
/// Deterministic policy π: S → A.
#[derive(Debug, Clone)]
pub struct Policy {
    /// `table[s]` = action chosen in state s.
    pub table: Vec<usize>,
}
impl Policy {
    /// Wrap an action table as a policy.
    pub fn new(table: Vec<usize>) -> Self {
        Policy { table }
    }
    /// Action prescribed in state `s`.
    ///
    /// # Panics
    /// Panics if `s` is out of bounds for the table (plain slice indexing).
    pub fn action(&self, s: usize) -> usize {
        self.table[s]
    }
}
/// SARSA agent (on-policy TD(0)).
///
/// Q(s,a) ← Q(s,a) + α(r + γ Q(s',a') − Q(s,a)).
#[derive(Debug, Clone)]
pub struct SARSA {
    /// Q-value table Q[s][a].
    pub q: Vec<Vec<f64>>,
    /// Learning rate α.
    pub alpha: f64,
    /// Discount factor γ.
    pub gamma: f64,
}
impl SARSA {
    /// Construct a SARSA agent with an all-zero Q-table.
    pub fn new(num_states: usize, num_actions: usize, alpha: f64, gamma: f64) -> Self {
        SARSA {
            q: vec![vec![0.0_f64; num_actions]; num_states],
            alpha,
            gamma,
        }
    }
    /// On-policy TD(0) update for the quintuple (s, a, r, s', a').
    pub fn update(&mut self, s: usize, a: usize, r: f64, s_next: usize, a_next: usize) {
        let target = r + self.gamma * self.q[s_next][a_next];
        let td_error = target - self.q[s][a];
        self.q[s][a] += self.alpha * td_error;
    }
    /// Greedy action in state `s` (last maximal index on ties, matching
    /// `Iterator::max_by`; 0 if the row is empty).
    pub fn greedy_action(&self, s: usize) -> usize {
        let row = &self.q[s];
        row.iter()
            .enumerate()
            .max_by(|x, y| x.1.partial_cmp(y.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0)
    }
    /// Best available Q-value in state `s` (the greedy return estimate);
    /// −∞ for an empty action row.
    pub fn expected_return(&self, s: usize) -> f64 {
        self.q[s].iter().copied().fold(f64::NEG_INFINITY, f64::max)
    }
}