// call_fsm/lib.rs

1extern crate alloc;
2extern crate core;
3
4use core::fmt::{Display, Formatter};
5
/// Declares the `DataType` type alias that the other helper macros
/// (`declare_state_machine!`, `new_state!`, `new_transition!`) refer to.
///
/// The alias is emitted at the call site, so invoke this once in the scope
/// where those macros are used.
#[macro_export]
macro_rules! declare_data_type {
    ( $data_type:ty ) => {
        type DataType = $data_type;
    };
}
12
/// Declares a mutable local `StateMachine<DataType>` named `$name`.
///
/// `$data` is the initial user data (any expression of type `DataType`; a
/// plain identifier still works) and `$num_states` is the maximum number of
/// states the machine can hold. Requires a `DataType` alias in scope
/// (see `declare_data_type!`).
#[macro_export]
macro_rules! declare_state_machine {
    ($name:ident, $data:expr, $num_states:expr) => {
        // `$crate::` lets callers in other crates use this macro without
        // importing `StateMachine` themselves.
        let mut $name: $crate::StateMachine<DataType> =
            $crate::StateMachine::new($data, $num_states);
    };
}
19
/// Creates a `State<DataType>` named after `$name` with the given `init` and
/// `exec` callbacks, adds it to the state machine `$sm`, and rebinds `$name`
/// to the index returned by `add_state`.
///
/// Panics (via `expect`) if `add_state` returns an error, e.g. when the
/// machine already holds its maximum number of states.
#[macro_export]
macro_rules! new_state {
    ($sm:ident, $name:ident, $init:expr, $exec:expr) => {
        // `$crate::` lets callers in other crates use this macro without
        // importing `State` themselves.
        let $name: $crate::State<DataType> =
            $crate::State::new(stringify!($name), $init, $exec);
        let $name = $sm.add_state($name).expect("Failed to add state");
    };
}
30
/// Creates a `Transition<DataType>` from state `$src` to state `$dst` and
/// registers it on the state machine `$sm`.
///
/// `$src` and `$dst` are state indices (as bound by `new_state!`). The
/// transition is named `"<src>__<dst>"`. Panics (via `expect`) if
/// `add_transition` returns an error.
#[macro_export]
macro_rules! new_transition {
    ($sm:ident, $src:ident, $dst:ident, $check:expr, $done:expr) => {
        // `$crate::` lets callers in other crates use this macro without
        // importing `Transition` themselves.
        let _t: $crate::Transition<DataType> = $crate::Transition::new(
            concat!(stringify!($src), "__", stringify!($dst)),
            $src,
            $dst,
            $check,
            $done,
        );
        $sm.add_transition(_t, $src, $dst)
            .expect("Failed to add transition");
    };
}
43
/// Result of a fallible state-machine operation that produces no value.
pub type FsmResult = Result<(), FsmError>;

/// Errors reported by state-machine operations and returned by user callbacks.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum FsmError {
    /// A state index was not below the number of states added so far.
    StateIndexOutOfBounds,
    /// A transition's source or destination index was out of bounds.
    TransitionIndexOutOfBounds,
    /// A state was added when the machine already held its maximum number.
    MaxNumberOfStatesExceeded,
    /// A transition was added whose source and destination are the same state.
    AddTransitionSrcDstStatesEqual,
    /// The addressed state slot holds no state.
    StateIsEmpty,
    /// No transition is registered between the given states.
    TransitionIsEmpty,
}

impl Display for FsmError {
    /// Formats the error as its variant name (the same text as `Debug`).
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        write!(f, "{:?}", self)
    }
}

/// Allows `FsmError` to interoperate with standard error handling, e.g. being
/// boxed as `Box<dyn std::error::Error>` or propagated with `?`.
impl std::error::Error for FsmError {}
61
62pub type StateCallback<T> = dyn Fn(&State<T>, &mut T) -> Result<(), FsmError>;
63pub type TransCheckCallback<T> = dyn Fn(&Transition<T>, &T) -> bool;
64pub type TransDoneCallback<T> = dyn Fn(&Transition<T>, &mut T) -> Result<(), FsmError>;
65pub type ErrorCallback<T> = dyn Fn(FsmError, &mut T) -> Option<Destination>;
66
/// Target state an error callback can redirect the machine to, identified
/// either by index or by name.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Destination {
    /// State index, as returned when the state was added.
    Index(usize),
    /// State name, as given when the state was created.
    Name(String),
}
71
72pub struct StateMachine<T: 'static + Clone> {
73    data: T,
74
75    states: Vec<Option<State<T>>>,
76    num_states: usize,
77
78    transitions: Vec<Vec<Option<Transition<T>>>>,
79    active_state: Option<usize>,
80    active_state_initialized: bool,
81
82    error: Option<(&'static ErrorCallback<T>, &'static ErrorCallback<T>)>,
83}
84
85impl<T: Clone> StateMachine<T> {
86    pub fn new(data: T, max_states: usize) -> StateMachine<T> {
87        StateMachine {
88            data,
89            states: vec![None; max_states],
90            num_states: 0,
91            transitions: vec![vec![None; max_states]; max_states],
92            active_state: None,
93            active_state_initialized: false,
94            error: None
95        }
96    }
97
98    pub fn state(&self, index: usize) -> Result<&State<T>, FsmError> {
99        if index >= self.num_states {
100            Err(FsmError::StateIndexOutOfBounds)
101        } else if let Some(ref state) = self.states[index] {
102            Ok(state)
103        } else {
104            Err(FsmError::StateIsEmpty)
105        }
106    }
107
108    pub fn state_by_name(&self, name: String) -> Option<usize> {
109        for (i, s) in self.states.iter().enumerate() {
110            if let Some(state) = s {
111                if name == state.name {
112                    return Some(i);
113                }
114            }
115        }
116        None
117    }
118
119    pub fn mut_state(&mut self, index: usize) -> Result<&mut State<T>, FsmError> {
120        if index >= self.num_states {
121            Err(FsmError::StateIndexOutOfBounds)
122        } else if let Some(ref mut state) = self.states[index] {
123            Ok(state)
124        } else {
125            Err(FsmError::StateIsEmpty)
126        }
127    }
128
129    pub fn transition(&self, src: usize, dst: usize) -> Result<&Transition<T>, FsmError> {
130        if src >= self.num_states || dst >= self.num_states {
131            Err(FsmError::TransitionIndexOutOfBounds)
132        } else if let Some(ref transition) = self.transitions[src][dst] {
133            Ok(transition)
134        } else {
135            Err(FsmError::TransitionIsEmpty)
136        }
137    }
138
139    pub fn active_transitions(&self, src: usize) -> Result<&[Option<Transition<T>>], FsmError> {
140        if src >= self.num_states {
141            Err(FsmError::TransitionIndexOutOfBounds)
142        } else {
143            Ok(&self.transitions[src][..])
144        }
145    }
146
147    pub fn add_state(&mut self, s: State<T>) -> Result<usize, FsmError> {
148        if self.num_states >= self.states.capacity() {
149            Err(FsmError::MaxNumberOfStatesExceeded)
150        } else {
151            self.states[self.num_states] = Some(s);
152            let index = self.num_states;
153            self.num_states += 1;
154            Ok(index)
155        }
156    }
157
158    pub fn add_transition(&mut self, t: Transition<T>, src: usize, dst: usize) -> Result<(), FsmError>{
159        if src >= self.num_states || dst >= self.num_states {
160            Err(FsmError::TransitionIndexOutOfBounds)
161        } else if src == dst {
162            Err(FsmError::AddTransitionSrcDstStatesEqual)
163        } else {
164            self.transitions[src][dst] = Some(t);
165            Ok(())
166        }
167    }
168
169    pub fn set_active_state(&mut self, s: usize) -> Result<(), FsmError> {
170        match self.state(s) {
171            Ok(_) => {
172                self.active_state = Some(s);
173                Ok(())
174            },
175            Err(e) => Err(e),
176        }
177
178    }
179
180    pub fn set_error_callbacks(&mut self, init: &'static ErrorCallback<T>, exec: &'static ErrorCallback<T>) {
181        self.error = Some((init, exec))
182    }
183
184    pub fn run(&mut self) {
185        if let Some(active_state_index) = self.active_state {
186            let active_state = self.state(active_state_index).expect("Failed to acquire active state").to_owned();
187
188            // Initialize state if needed
189            if !&self.active_state_initialized {
190                if let Err(e) = active_state.do_init(&mut self.data) {
191                    self.do_error_callback(e);
192                    return;
193                }
194            }
195
196            self.active_state_initialized = true;
197
198            if let Err(e) = active_state.do_exec(&mut self.data) {
199                self.do_error_callback(e);
200                return;
201            }
202
203            let mut next_state_index = active_state_index;
204            let next_state_trans = self.active_transitions(active_state_index).expect("Failed to acquire active transitions");
205            let mut check = false;
206
207            // Check transitions
208            for t in next_state_trans {
209                if let Some(transition) = t {
210                    let transition = transition.to_owned();
211                    check = transition.do_check(&self.data);
212                    if check {
213                        next_state_index = transition.dst;
214                        match transition.do_done(&mut self.data) {
215                            Err(e) => {
216                                self.do_error_callback(e);
217                                return;
218                            },
219                            Ok(_) => break
220                        }
221                    }
222                }
223            }
224
225            if !check {
226                // No transition check returned true, stay in the same active state
227                return;
228            }
229
230            // Some transition check returned true, move to dst state
231            self.active_state = Some(next_state_index);
232            self.active_state_initialized = false;
233        }
234
235        // for s in &mut self.states {
236        //     if let Some(state) = s {
237        //         state.do_init().unwrap();
238        //         state.do_exec().unwrap();
239        //     }
240        // }
241        // for trans_src in &self.transitions {
242        //     for trans_dst in trans_src {
243        //         if let Some(trans) = trans_dst {
244        //             trans.do_check();
245        //             trans.do_done().unwrap();
246        //         }
247        //     }
248        // }
249    }
250
251    fn do_error_callback(&mut self, error: FsmError) {
252        println!("Error state: {}", error);
253        if let Some((callback_init, callback_exec)) = self.error {
254            callback_init(error, &mut self.data);
255            if let Some(next_state) = callback_exec(error, &mut self.data) {
256                match next_state {
257                    Destination::Index(next_state_index) => {
258                        if next_state_index < self.num_states {
259                            self.active_state = Some(next_state_index);
260                            self.active_state_initialized = false;
261                        }
262                    },
263                    Destination::Name(next_state_name) => {
264                        if let Some(next_state_index) = self.state_by_name(next_state_name) {
265                            self.active_state = Some(next_state_index);
266                            self.active_state_initialized = false;
267                        }
268                    }
269                }
270            }
271        }
272    }
273}
274
275#[derive(Clone)]
276pub struct State<T: 'static> {
277    pub name: String,
278    pub init: &'static StateCallback<T>,
279    pub exec: &'static StateCallback<T>,
280}
281
282impl<T> State<T> {
283    pub fn new<'b>(name: impl Into<alloc::borrow::Cow<'b, str>>,
284                   init: &'static StateCallback<T>,
285                   exec: &'static StateCallback<T>
286    ) -> State<T> {
287        State { name: name.into().into_owned(), init, exec }
288    }
289
290    pub fn do_init(&self, data: &mut T) -> Result<(), FsmError> {
291        (self.init)(self, data)
292    }
293
294    pub fn do_exec(&self, data: &mut T) -> Result<(), FsmError> {
295        (self.exec)(self, data)
296    }
297}
298
299#[derive(Clone)]
300pub struct Transition<T: 'static + Clone> {
301    pub name: String,
302    pub src: usize,
303    pub dst: usize,
304    pub check: &'static TransCheckCallback<T>,
305    pub done: &'static TransDoneCallback<T>,
306}
307
308impl<T: Clone> Transition<T> {
309    pub fn new<'b>(name: impl Into<alloc::borrow::Cow<'b, str>>,
310                   src: usize,
311                   dst: usize,
312                   check: &'static TransCheckCallback<T>,
313                   done: &'static TransDoneCallback<T>
314    ) -> Transition<T> {
315        Transition {
316            name: name.into().into_owned(),
317            src, dst,
318            check, done }
319    }
320
321    pub fn do_check(&self, data: &T) -> bool {
322        (self.check)(self, data)
323    }
324
325    pub fn do_done(&self, data: &mut T) -> Result<(), FsmError> {
326        (self.done)(self, data)
327    }
328}