tamata_macros/
lib.rs

1use std::collections::BTreeSet;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{braced, parenthesized, parse_macro_input, Expr, Ident, Token, Type};
6use syn::parse::{Error, Parse, ParseStream, Result};
7use syn::punctuated::Punctuated;
8
9#[proc_macro]
10pub fn fsm(input: TokenStream) -> TokenStream {
11    let Fsm {
12        name,
13        error,
14        context,
15        states,
16        events,
17        transitions,
18    } = parse_macro_input!(input as Fsm);
19
20    let mut state_impls = quote! {};
21    for state in &states {
22        state_impls = quote! {
23            #state_impls
24
25            impl tamata::State<#name> for #state {}
26        };
27    }
28
29    let mut event_impls = quote! {};
30    for event in &events {
31        event_impls = quote! {
32            #event_impls
33
34            impl tamata::Event<#name> for #event {}
35        };
36    }
37
38    let state_enum_name = quote::format_ident!("{}State", name);
39    let mut state_enum_variants = quote! {};
40    for state in &states {
41        state_enum_variants = quote!{
42            #state_enum_variants
43            #state(#state),
44        }
45    }
46    let state_enum = quote! {
47        #[derive(Debug)]
48        pub enum #state_enum_name {
49            #state_enum_variants
50        }
51    };
52
53    let mut state_enum_from_impls = quote! {};
54    for state in &states {
55        state_enum_from_impls = quote! {
56            #state_enum_from_impls
57
58            impl From<#state> for #state_enum_name {
59                fn from(state: #state) -> #state_enum_name {
60                    #state_enum_name :: #state(state)
61                }
62            }
63        }
64    }
65
66    let event_enum_name = quote::format_ident!("{}Event", name);
67    let mut event_enum_variants = quote! {};
68    for event in &events {
69        event_enum_variants = quote!{
70            #event_enum_variants
71            #event(#event),
72        }
73    }
74    let event_enum = quote! {
75        #[derive(Debug)]
76        pub enum #event_enum_name {
77            #event_enum_variants
78        }
79    };
80
81    let mut event_enum_from_impls = quote! {};
82    for event in &events {
83        event_enum_from_impls = quote! {
84            #event_enum_from_impls
85
86            impl From<#event> for #event_enum_name {
87                fn from(event: #event) -> #event_enum_name {
88                    #event_enum_name :: #event(event)
89                }
90            }
91        }
92    }
93
94    let mut enum_transitions = quote! {};
95    for transition in &transitions {
96        let state = &transition.state;
97        let event = &transition.event;
98        let next = &transition.next;
99        let action = &transition.action;
100
101        if let Some(action) = action {
102            enum_transitions = quote! {
103                #enum_transitions
104
105                (#state_enum_name::#state(s), #event_enum_name::#event(e)) => {
106                    impl tamata::Transition<#name, #event> for #state {
107                        type Next = #next;
108
109                        fn send(
110                            self,
111                            event: #event,
112                            ctx: #context,
113                        ) -> Result<#next, #error> {
114                            (#action)(self, event, ctx)
115                        }
116                    }
117
118                    let next = tamata::Transition::<#name, #event>::send(s, e, ctx)?;
119                    let next = #state_enum_name::#next(next);
120                    tamata::Sent::Valid(next)
121                },
122            }
123        } else {
124            enum_transitions = quote! {
125                #enum_transitions
126
127                (#state_enum_name::#state(s), #event_enum_name::#event(e)) => {
128                    let next = tamata::Transition::<#name, #event>::send(s, e, ctx)?;
129                    let next = #state_enum_name::from(next);
130                    tamata::Sent::Valid(next)
131                },
132            }
133        };
134    }
135
136    let impl_state_enum = quote! {
137        impl #state_enum_name {
138            pub fn send(
139                self,
140                event: impl Into<#event_enum_name>,
141                ctx: #context
142            ) -> Result<tamata::Sent<#name>, #error> {
143                let next = match (self, event.into()) {
144                    #enum_transitions
145                    (state, event) => {
146                        tamata::Sent::Invalid(state, event)
147                    }
148                };
149
150                Ok(next)
151            }
152        }
153    };
154
155    let impl_fsm = quote! {
156        impl tamata::Fsm for #name {
157            type Error = #error;
158            type Context = #context;
159
160            type State = #state_enum_name;
161            type Event = #event_enum_name;
162        }
163    };
164
165    let expanded = quote! {
166        #impl_fsm
167
168        #state_impls
169
170        #event_impls
171
172        #state_enum
173
174        #state_enum_from_impls
175
176        #event_enum
177
178        #event_enum_from_impls
179
180        #impl_state_enum
181    };
182
183    TokenStream::from(expanded)
184}
185
186struct Fsm {
187    name: Ident,
188    error: Type,
189    context: Type,
190    states: Vec<Ident>,
191    events: Vec<Ident>,
192    transitions: Vec<Transition>,
193}
194
195impl Parse for Fsm {
196    fn parse(input: ParseStream) -> Result<Self> {
197        let name: Ident = input.parse()?;
198
199        input.parse::<Token![,]>()?;
200
201        let error = input.parse::<Ident>()?;
202        if error != "Error" {
203            return Err(Error::new(error.span(), "expected `Error`"));
204        }
205        input.parse::<Token![=]>()?;
206        let error: Type = input.parse()?;
207
208        input.parse::<Token![,]>()?;
209
210        let context = input.parse::<Ident>()?;
211        if context != "Context" {
212            return Err(Error::new(context.span(), "expected `Context`"));
213        }
214        input.parse::<Token![=]>()?;
215        let context: Type = input.parse()?;
216
217        // Optional trailing comma.
218        let _ = input.parse::<Token![,]>();
219
220        let transitions;
221        braced!(transitions in input);
222        let transitions: Punctuated<Transition, Token![,]> =
223            transitions.parse_terminated(Transition::parse)?;
224
225        let transitions: Vec<_> = transitions.into_iter().collect();
226
227        // Optional trailing comma.
228        let _ = input.parse::<Token![,]>();
229
230        let mut states = BTreeSet::default();
231        let mut events = BTreeSet::default();
232
233        for transition in &transitions {
234            states.insert(transition.state.clone());
235            states.insert(transition.next.clone());
236            events.insert(transition.event.clone());
237        }
238
239        let states: Vec<_> = states.into_iter().collect();
240        let events: Vec<_> = events.into_iter().collect();
241
242        Ok(Fsm {
243            name,
244            error,
245            context,
246            states,
247            events,
248            transitions,
249        })
250    }
251}
252
253struct Transition {
254    state: Ident,
255    event: Ident,
256    next: Ident,
257    action: Option<Expr>,
258}
259
260impl Parse for Transition {
261    fn parse(input: ParseStream) -> Result<Self> {
262        let state: Ident = input.parse()?;
263
264        let events;
265        parenthesized!(events in input);
266        let events: Punctuated<Ident, Token![,]> =
267            events.parse_terminated(Ident::parse)?;
268
269        let event: Ident = events.into_iter().next().unwrap();
270
271        input.parse::<Token![->]>()?;
272
273        let next: Ident = input.parse()?;
274
275        let action = if input.peek(Token![=]) {
276            input.parse::<Token![=]>()?;
277            let action: Expr = input.parse()?;
278            Some(action)
279        } else {
280            None
281        };
282
283        Ok(Transition {
284            state,
285            event,
286            next,
287            action,
288        })
289    }
290}