use crate::mdp::model::{Action, State};
use rand::seq::SliceRandom;
use std::collections::HashMap;
/// A deterministic policy: a lookup table that assigns at most one action to each state.
///
/// Both the states (keys) and the actions (values) are borrowed from storage that
/// must outlive the policy, captured by the lifetime `'a`.
#[derive(Debug, PartialEq, Eq)]
pub struct Policy<'a, S, A>
where
    S: State,
    A: Action,
{
    /// State → action assignments; a state absent from this map has no action.
    mapping: HashMap<&'a S, &'a A>,
}
impl<'a, S: State, A: Action> Policy<'a, S, A> {
    /// Creates a policy from an explicit state → action mapping.
    ///
    /// States that are not keys of `mapping` have no action; for them
    /// [`Policy::select_action`] returns `None`.
    pub fn new(mapping: HashMap<&'a S, &'a A>) -> Self {
        Self { mapping }
    }

    /// Creates a policy that assigns each state in `states` an action chosen at
    /// random from `actions`, using the thread-local RNG.
    ///
    /// Equal states (under `Eq`/`Hash`) collapse into a single map entry.
    ///
    /// # Panics
    ///
    /// Panics with "Actions must not be empty" if `actions` is empty while
    /// `states` is not.
    pub fn random(states: &'a [S], actions: &'a [A]) -> Self {
        Self::random_with_rng(states, actions, &mut rand::thread_rng())
    }

    /// Like [`Policy::random`], but draws actions from the caller-supplied `rng`.
    ///
    /// Passing a seeded generator makes the resulting policy reproducible, e.g.
    /// in tests or experiments.
    ///
    /// # Panics
    ///
    /// Panics with "Actions must not be empty" if `actions` is empty while
    /// `states` is not.
    pub fn random_with_rng<R>(states: &'a [S], actions: &'a [A], rng: &mut R) -> Self
    where
        R: rand::Rng + ?Sized,
    {
        let mapping = states
            .iter()
            .map(|state| {
                // Only evaluated once per state, so an empty `actions` slice is
                // accepted as long as there are no states to assign.
                let action = actions.choose(rng).expect("Actions must not be empty");
                (state, action)
            })
            .collect();
        Self { mapping }
    }

    /// Returns the action this policy assigns to `state`, or `None` if the state
    /// is not in the mapping.
    ///
    /// The returned reference borrows from the original action storage (`'a`),
    /// not from the policy, so it remains valid after the policy is dropped.
    pub fn select_action(&self, state: &S) -> Option<&'a A> {
        self.mapping.get(state).copied()
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use crate::mdp::{
        model::{Action, State},
        policy::Policy,
    };

    #[derive(Debug, Hash, PartialEq, Eq)]
    struct S {
        id: usize,
    }

    impl State for S {
        fn id(&self) -> usize {
            self.id
        }
    }

    #[derive(Debug, PartialEq, Eq)]
    struct A {
        id: usize,
    }

    impl Action for A {
        fn id(&self) -> usize {
            self.id
        }
    }

    /// An explicit mapping passed to `new` is returned verbatim by `select_action`.
    #[test]
    fn new_uses_given_mapping() {
        let states: Vec<S> = (0..3).map(|id| S { id }).collect();
        let actions: Vec<A> = (0..2).map(|id| A { id }).collect();
        // States 0, 1, 2 map to actions 0, 1, 0 respectively.
        let mapping: HashMap<&S, &A> = states.iter().zip(actions.iter().cycle()).collect();
        let policy = Policy::new(mapping);

        assert_eq!(policy.select_action(&S { id: 0 }), Some(&A { id: 0 }));
        assert_eq!(policy.select_action(&S { id: 1 }), Some(&A { id: 1 }));
        assert_eq!(policy.select_action(&S { id: 2 }), Some(&A { id: 0 }));
        assert!(policy.select_action(&S { id: 3 }).is_none());
    }

    /// Every known state gets an action, and that action is one of the supplied ones.
    #[test]
    fn random_policy() {
        let states: Vec<S> = (0..5).map(|id| S { id }).collect();
        let actions: Vec<A> = (0..4).map(|id| A { id }).collect();
        let random_policy = Policy::random(&states, &actions);

        assert_eq!(random_policy.mapping.len(), states.len());
        for state in &states {
            let action = random_policy
                .select_action(state)
                .expect("every known state must have an action");
            // Identity check: the policy must borrow from `actions`, not a copy.
            assert!(actions.iter().any(|candidate| std::ptr::eq(candidate, action)));
        }
        assert!(random_policy.select_action(&S { id: 10 }).is_none());
    }

    /// Assigning actions to at least one state with no actions available is a bug.
    #[test]
    #[should_panic(expected = "Actions must not be empty")]
    fn random_policy_panics_without_actions() {
        let states = [S { id: 0 }];
        let actions: [A; 0] = [];
        let _ = Policy::random(&states, &actions);
    }

    /// With no states there is nothing to assign, so empty actions are accepted.
    #[test]
    fn random_policy_with_no_states_is_empty() {
        let states: [S; 0] = [];
        let actions: [A; 0] = [];
        let policy = Policy::random(&states, &actions);
        assert!(policy.mapping.is_empty());
    }
}