Skip to main content

fluent_fsm/machine/
active.rs

1// MIT License
2//
3// Copyright (c) 2024 Wes Kelly
4//
5// Permission is hereby granted, free of charge, to any person obtaining a copy
6// of this software and associated documentation files (the "Software"), to deal
7// in the Software without restriction, including without limitation the rights
8// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9// copies of the Software, and to permit persons to whom the Software is
10// furnished to do so, subject to the following conditions:
11//
12// The above copyright notice and this permission notice shall be included in all
13// copies or substantial portions of the Software.
14//
15// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21// SOFTWARE.
22
23use crate::active::ActiveMachineEvent::*;
24use crate::passive::PassiveStateMachine;
25use std::hash::Hash;
26use std::sync::{Arc, RwLock, mpsc};
27use std::thread;
28use std::thread::JoinHandle;
29
/// Commands sent from an `ActiveStateMachine` handle to its background thread.
///
/// The definition carries no trait bounds on `T`: nothing in the enum itself needs them, and the
/// code that sends and receives these commands already constrains the event type it uses.
enum ActiveMachineEvent<T> {
    /// Start the wrapped machine; the background thread calls its `start` method.
    Start,
    /// Make the background thread exit its loop.
    Stop,
    /// Fire the carried event on the wrapped machine.
    ExternalEvent(T),
}
35
36pub struct ActiveStateMachine<TState, TModel = (), TEvent = ()>
37where
38    TState: Eq + Hash + Copy,
39    TEvent: Eq + Hash + Copy,
40{
41    internal_state: Arc<RwLock<PassiveStateMachine<TState, TModel, TEvent>>>,
42    machine_loop: JoinHandle<()>,
43    tx: mpsc::Sender<ActiveMachineEvent<TEvent>>,
44}
45
46impl<TState, TModel, TEvent> ActiveStateMachine<TState, TModel, TEvent>
47where
48    TEvent: Eq + Hash + Copy + Sync + Send + 'static,
49    TState: Eq + Hash + Copy + Sync + Send + 'static,
50    TModel: Sync + Send + 'static,
51{
52    pub(crate) fn create(
53        active_action: impl Fn(&TState, &TModel) -> Option<TState> + 'static + Send + Sync,
54        machine: PassiveStateMachine<TState, TModel, TEvent>,
55    ) -> Self {
56        let (tx, rx) = mpsc::channel();
57        let machine = Arc::new(RwLock::new(machine));
58        let internal_state = Arc::clone(&machine);
59
60        let machine_loop = thread::spawn(move || {
61            loop {
62                match rx.try_recv() {
63                    Ok(Start) => {
64                        let mut machine = machine.write().unwrap();
65                        machine.start();
66                    }
67                    Ok(ExternalEvent(event)) => {
68                        let mut machine = machine.write().unwrap();
69                        machine.fire(event);
70                    }
71                    Ok(Stop) => {
72                        return;
73                    }
74                    Err(mpsc::TryRecvError::Empty) => {
75                        let mut machine = machine.write().unwrap();
76                        if let Some(state) = active_action(machine.current_state(), machine.model())
77                        {
78                            machine.goto(state);
79                        }
80                    }
81                    Err(mpsc::TryRecvError::Disconnected) => {
82                        return;
83                    }
84                }
85
86                thread::yield_now();
87            }
88        });
89
90        Self {
91            internal_state,
92            machine_loop,
93            tx,
94        }
95    }
96
97    pub fn fire(&self, event: TEvent) {
98        self.tx.send(ExternalEvent(event)).unwrap();
99    }
100
101    pub fn start(&self) {
102        self.tx.send(Start).unwrap();
103    }
104
105    pub fn stop(self) {
106        self.tx.send(Stop).unwrap();
107        self.machine_loop.join().unwrap();
108    }
109
110    pub fn write_model(&mut self, update: impl Fn(&mut TModel) + Send + Sync + 'static) {
111        let mut model = self.internal_state.write().unwrap();
112        update(model.model_mut())
113    }
114
115    pub fn read_state<R>(&self, read: impl Fn(&TModel) -> R) -> R {
116        let state = self.internal_state.read().unwrap();
117        read(state.model())
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::super::builder::StateMachineBuilder;
    use super::*;
    use std::time::{Duration, Instant};

    /// Test model recording the current/previous state and when the last transition happened.
    struct Model<TState> {
        in_state: TState,
        num_transitions: u32,
        prev_state: Option<TState>,
        // `Instant` is monotonic, unlike `SystemTime`, so elapsed-time measurements cannot
        // fail or jump when the wall clock is adjusted.
        last_transition: Instant,
    }

    impl Model<u32> {
        pub fn new() -> Self {
            Self {
                in_state: 0,
                num_transitions: 0,
                prev_state: None,
                last_transition: Instant::now(),
            }
        }

        /// Time elapsed since the most recent `on_enter` hook ran (or since construction).
        pub fn time_since_last_transition(&self) -> Duration {
            self.last_transition.elapsed()
        }
    }

    #[test]
    fn test_active_state_machine() {
        const STATE_1: u32 = 111;
        const STATE_2: u32 = 222;
        const MAX_TRANSITIONS: u32 = 5;
        // Generous upper bound so slow CI machines do not cause spurious failures; the loop
        // below exits as soon as the target is reached.
        const TIMEOUT: Duration = Duration::from_secs(5);
        // Extra wait after reaching the target, to check that no further transitions happen.
        const SETTLE: Duration = Duration::from_millis(20);

        // Active action: alternates between the two states, staying in each for >5 ms, and
        // stops transitioning once MAX_TRANSITIONS has been reached.
        fn tick(state: &u32, model: &Model<u32>) -> Option<u32> {
            if model.num_transitions >= MAX_TRANSITIONS {
                return None;
            }

            let dwelt_long_enough = model.time_since_last_transition() > Duration::from_millis(5);
            match *state {
                STATE_1 => {
                    if let Some(prev) = model.prev_state {
                        assert_eq!(prev, STATE_2);
                    }
                    dwelt_long_enough.then_some(STATE_2)
                }
                STATE_2 => {
                    assert_eq!(model.prev_state, Some(STATE_1));
                    dwelt_long_enough.then_some(STATE_1)
                }
                v => panic!("unexpected state: {v}"),
            }
        }

        let builder = StateMachineBuilder::<u32, Model<u32>>::create(STATE_1, Model::<u32>::new())
            .on_enter_mut(|model| {
                model.in_state = STATE_1;
                model.num_transitions += 1;
                model.last_transition = Instant::now();
            })
            .on_leave_mut(|model| {
                model.prev_state = Some(STATE_1);
            })
            .in_state(STATE_2)
            .on_enter_mut(|model| {
                model.in_state = STATE_2;
                model.num_transitions += 1;
                model.last_transition = Instant::now();
            })
            .on_leave_mut(|model| {
                model.prev_state = Some(STATE_2);
            });

        let machine = builder.build_active(tick);
        machine.start();

        // Poll instead of sleeping a fixed amount, so the test is not timing-sensitive.
        let deadline = Instant::now() + TIMEOUT;
        while machine.read_state(|model| model.num_transitions) < MAX_TRANSITIONS
            && Instant::now() < deadline
        {
            thread::sleep(Duration::from_millis(1));
        }
        thread::sleep(SETTLE);

        assert_eq!(
            machine.read_state(|model| model.num_transitions),
            MAX_TRANSITIONS
        );

        machine.stop();
    }
}