/// A finite Markov decision process (MDP).
///
/// `transitions[s][a]` lists the possible outcomes of taking action `a` in
/// state `s` as `(next_state, probability, reward)` tuples; the probabilities
/// for each state-action pair must sum to 1.
#[derive(Debug)]
pub struct MarkovDecisionProcess {
    pub num_states: usize,
    pub num_actions: usize,
    /// Discount factor in [0, 1].
    pub gamma: f64,
    pub transitions: Vec<Vec<Vec<(usize, f64, f64)>>>,
}
impl MarkovDecisionProcess {
    /// Builds an MDP after validating its shape and transition probabilities.
    pub fn new(
        num_states: usize,
        num_actions: usize,
        gamma: f64,
        transitions: Vec<Vec<Vec<(usize, f64, f64)>>>,
    ) -> Self {
        assert!(
            (0.0..=1.0).contains(&gamma),
            "Discount factor gamma must be between 0 and 1"
        );
        assert_eq!(transitions.len(), num_states);
        for sa in &transitions {
            assert_eq!(sa.len(), num_actions);
        }
        // Every state-action pair must describe a proper probability
        // distribution over next states.
        for (s, sa) in transitions.iter().enumerate() {
            for (a, outcomes) in sa.iter().enumerate() {
                let prob_sum: f64 = outcomes.iter().map(|(_, p, _)| p).sum();
                assert!(
                    (prob_sum - 1.0).abs() < 1e-8,
                    "Probabilities in state {}, action {} must sum to 1.0, but got {}",
                    s,
                    a,
                    prob_sum
                );
            }
        }
        Self {
            num_states,
            num_actions,
            gamma,
            transitions,
        }
    }
}
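
// A minimal illustrative sketch (the function name is hypothetical): building
// a small stochastic MDP, where one action's outcome list holds several
// (next_state, probability, reward) tuples whose probabilities sum to 1.0.
#[allow(dead_code)]
fn example_stochastic_mdp() -> MarkovDecisionProcess {
    MarkovDecisionProcess::new(
        2,
        1,
        0.9,
        vec![
            // State 0, action 0: stay with probability 0.8 (reward 0) or
            // move to state 1 with probability 0.2 (reward 5).
            vec![vec![(0, 0.8, 0.0), (1, 0.2, 5.0)]],
            // State 1, action 0: absorbing self-loop with reward 1.
            vec![vec![(1, 1.0, 1.0)]],
        ],
    )
}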

/// Solves the MDP with value iteration and returns `(values, policy)`.
///
/// Repeatedly applies the Bellman optimality backup
/// `V(s) <- max_a sum_{s'} p(s' | s, a) * (r + gamma * V(s'))`
/// until the largest per-state change falls below `tolerance` or
/// `max_iterations` is exhausted, then extracts the greedy policy.
pub fn value_iteration(
    mdp: &MarkovDecisionProcess,
    max_iterations: usize,
    tolerance: f64,
) -> (Vec<f64>, Vec<usize>) {
    let n = mdp.num_states;
    let mut v = vec![0.0; n];
    for _iter in 0..max_iterations {
        let mut delta = 0.0_f64;
        let mut v_new = vec![0.0; n];
        for s in 0..n {
            // Bellman optimality backup: the best Q-value over all actions.
            let mut best_val = f64::NEG_INFINITY;
            for a in 0..mdp.num_actions {
                best_val = best_val.max(compute_q_value(s, a, &v, mdp));
            }
            v_new[s] = best_val;
            delta = delta.max((v_new[s] - v[s]).abs());
        }
        v = v_new;
        if delta < tolerance {
            break;
        }
    }
    // Extract the policy that is greedy with respect to the final values.
    let mut policy = vec![0_usize; n];
    for (s, policy_s) in policy.iter_mut().enumerate() {
        let mut best_a = 0;
        let mut best_val = f64::NEG_INFINITY;
        for a in 0..mdp.num_actions {
            let q_sa = compute_q_value(s, a, &v, mdp);
            if q_sa > best_val {
                best_val = q_sa;
                best_a = a;
            }
        }
        *policy_s = best_a;
    }
    (v, policy)
}
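
// A companion sketch, assuming the usual textbook formulation: iterative
// policy evaluation for a fixed policy. It applies the same backup as
// `value_iteration` but follows `policy[s]` instead of maximizing over
// actions. The name `policy_evaluation` and its signature are illustrative.
pub fn policy_evaluation(
    mdp: &MarkovDecisionProcess,
    policy: &[usize],
    max_iterations: usize,
    tolerance: f64,
) -> Vec<f64> {
    let n = mdp.num_states;
    let mut v = vec![0.0; n];
    for _ in 0..max_iterations {
        let mut delta = 0.0_f64;
        let mut v_new = vec![0.0; n];
        for s in 0..n {
            // Back up only the action the policy prescribes in state s.
            v_new[s] = compute_q_value(s, policy[s], &v, mdp);
            delta = delta.max((v_new[s] - v[s]).abs());
        }
        v = v_new;
        if delta < tolerance {
            break;
        }
    }
    v
}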

/// Expected discounted return of taking action `a` in state `s` and then
/// following `values`: the sum over `(s', p, r)` of `p * (r + gamma * V(s'))`.
/// (The discount factor is read from `mdp` rather than passed separately.)
fn compute_q_value(s: usize, a: usize, values: &[f64], mdp: &MarkovDecisionProcess) -> f64 {
    mdp.transitions[s][a]
        .iter()
        .map(|&(s_next, prob, reward)| prob * (reward + mdp.gamma * values[s_next]))
        .sum()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_mdp() {
        // Action 1 in state 1 is an absorbing self-loop with reward 2, so
        // V(1) approaches 2 / (1 - 0.9) = 20.
        let mdp = MarkovDecisionProcess::new(
            2,
            2,
            0.9,
            vec![
                vec![vec![(0, 1.0, 1.0)], vec![(1, 1.0, 0.0)]],
                vec![vec![(0, 1.0, 0.0)], vec![(1, 1.0, 2.0)]],
            ],
        );
        let (values, policy) = value_iteration(&mdp, 100, 1e-6);
        assert!(values[1] > 10.0);
        assert_eq!(policy[1], 1);
        // From state 0, moving to state 1 (q = 0.9 * V(1), about 18) beats
        // staying for reward 1 (q = 1 + 0.9 * V(0), about 17.2).
        assert_eq!(policy[0], 1);
    }

    #[test]
    #[should_panic]
    fn test_invalid_probabilities() {
        // The single outcome has probability 0.5, so the constructor panics.
        MarkovDecisionProcess::new(1, 1, 0.99, vec![vec![vec![(0, 0.5, 10.0)]]]);
    }
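
    // A minimal extra check: a single absorbing state with reward 1 and
    // gamma = 0.5 has the geometric-series value 1 / (1 - 0.5) = 2, which
    // value iteration should recover to within the tolerance.
    #[test]
    fn test_matches_geometric_series() {
        let mdp = MarkovDecisionProcess::new(1, 1, 0.5, vec![vec![vec![(0, 1.0, 1.0)]]]);
        let (values, _policy) = value_iteration(&mdp, 1000, 1e-9);
        assert!((values[0] - 2.0).abs() < 1e-6);
    }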
}