momba_explore/simulate/
mod.rs1use rand::seq::IteratorRandom;
4use rand::Rng;
5
6use super::*;
7
8pub trait Oracle<T: time::TimeType> {
13    fn choose<'e, 't>(
15        &self,
16        state: &State<T::Valuations>,
17        transitions: &'t Vec<Transition<'e, T>>,
18    ) -> &'t Transition<'e, T>;
19}
20
/// A stateless oracle type.
///
/// Besides [`UniformOracle::new`], the type can be constructed through
/// [`Default`], and it implements [`Debug`] and [`Clone`] so it can be used
/// in derived impls of containing types.
#[derive(Debug, Clone, Default)]
pub struct UniformOracle {}

impl UniformOracle {
    /// Creates a new oracle; equivalent to `UniformOracle::default()`.
    pub fn new() -> Self {
        UniformOracle {}
    }
}
30
31impl<T: time::TimeType> Oracle<T> for UniformOracle {
32    fn choose<'e, 't>(
33        &self,
34        _state: &State<T::Valuations>,
35        transitions: &'t Vec<Transition<'e, T>>,
36    ) -> &'t Transition<'e, T> {
37        let mut rng = rand::thread_rng();
38        transitions.iter().choose(&mut rng).unwrap()
39    }
40}
41
/// An empty oracle type.
///
/// NOTE(review): no `Oracle` implementation for this type is visible in this
/// part of the file — it looks like a placeholder; confirm before relying on it.
#[derive(Debug, Clone, Default)]
pub struct InjectionOracle {}
49
50pub struct Simulator<O: Oracle<T>, T: time::TimeType> {
52    pub(crate) oracle: O,
53
54    _phontom_time_type: std::marker::PhantomData<T>,
55}
56
57impl<O: Oracle<T>, T: time::TimeType> Simulator<O, T> {
58    pub fn new(oracle: O) -> Self {
60        Simulator {
61            oracle,
62            _phontom_time_type: std::marker::PhantomData,
63        }
64    }
65
66    pub fn oracle(&self) -> &O {
68        &self.oracle
69    }
70
71    pub fn simulate(&self, explorer: &Explorer<T>, steps: usize) {
73        let mut rng = rand::thread_rng();
74        let mut state = explorer
75            .initial_states()
76            .into_iter()
77            .choose(&mut rng)
78            .unwrap();
79
80        for _ in 0..steps {
81            let transition = explorer
82                .transitions(&state)
83                .into_iter()
84                .choose(&mut rng)
85                .unwrap();
86
87            match transition.result_action() {
88                Action::Silent => println!("τ"),
89                Action::Labeled(labeled) => println!(
90                    "{} {:?}",
91                    labeled.label(&explorer.network).unwrap(),
92                    labeled.arguments()
93                ),
94            }
95
96            let destinations = explorer.destinations(&state, &transition);
97
98            let threshold: f64 = rng.gen();
99            let mut accumulated = 0.0;
100
101            for destination in destinations {
102                accumulated += destination.probability();
103                if accumulated >= threshold {
104                    state = explorer.successor(&state, &transition, &destination);
105                    break;
106                }
107            }
108        }
109    }
110}