1use std::collections::HashMap;
2
3use crate::errors::NotFound;
4
5use super::{Action, Sampler, State};
6
7#[derive(Debug, PartialEq, Eq)]
11pub struct Policy<S, A>(HashMap<S, A>)
12where
13 S: State,
14 A: Action;
15
16impl<S, A> Policy<S, A>
17where
18 S: State,
19 A: Action,
20{
21 pub fn new(states: &Sampler<S>, actions: &Sampler<A>) -> Self {
23 let mut map = HashMap::new();
24 for state in states {
25 map.insert(state.clone(), actions.get_random().clone());
26 }
27 Self(map)
28 }
29
30 pub fn get(&self, state: &S) -> &A {
32 self.0
33 .get(state)
34 .unwrap_or_else(|| panic!("{}", NotFound::StateInPolicy))
35 }
36
37 pub fn insert(&mut self, state: &S, action: &A) {
39 self.0.insert(state.clone(), action.clone());
40 }
41}