madepro/models/policy.rs

1use std::collections::HashMap;
2
3use crate::errors::NotFound;
4
5use super::{Action, Sampler, State};
6
7/// # Policy
8///
9/// Represents a mapping from states to actions.
10#[derive(Debug, PartialEq, Eq)]
11pub struct Policy<S, A>(HashMap<S, A>)
12where
13    S: State,
14    A: Action;
15
16impl<S, A> Policy<S, A>
17where
18    S: State,
19    A: Action,
20{
21    /// Creates a new policy with each state mapped to a random action.
22    pub fn new(states: &Sampler<S>, actions: &Sampler<A>) -> Self {
23        let mut map = HashMap::new();
24        for state in states {
25            map.insert(state.clone(), actions.get_random().clone());
26        }
27        Self(map)
28    }
29
30    /// Returns the action associated with the given state.
31    pub fn get(&self, state: &S) -> &A {
32        self.0
33            .get(state)
34            .unwrap_or_else(|| panic!("{}", NotFound::StateInPolicy))
35    }
36
37    /// Inserts the given action for the given state.
38    pub fn insert(&mut self, state: &S, action: &A) {
39        self.0.insert(state.clone(), action.clone());
40    }
41}