1use std::collections::BTreeSet;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{braced, parenthesized, parse_macro_input, Expr, Ident, Token, Type};
6use syn::parse::{Error, Parse, ParseStream, Result};
7use syn::punctuated::Punctuated;
8
9#[proc_macro]
10pub fn fsm(input: TokenStream) -> TokenStream {
11 let Fsm {
12 name,
13 error,
14 context,
15 states,
16 events,
17 transitions,
18 } = parse_macro_input!(input as Fsm);
19
20 let mut state_impls = quote! {};
21 for state in &states {
22 state_impls = quote! {
23 #state_impls
24
25 impl tamata::State<#name> for #state {}
26 };
27 }
28
29 let mut event_impls = quote! {};
30 for event in &events {
31 event_impls = quote! {
32 #event_impls
33
34 impl tamata::Event<#name> for #event {}
35 };
36 }
37
38 let state_enum_name = quote::format_ident!("{}State", name);
39 let mut state_enum_variants = quote! {};
40 for state in &states {
41 state_enum_variants = quote!{
42 #state_enum_variants
43 #state(#state),
44 }
45 }
46 let state_enum = quote! {
47 #[derive(Debug)]
48 pub enum #state_enum_name {
49 #state_enum_variants
50 }
51 };
52
53 let mut state_enum_from_impls = quote! {};
54 for state in &states {
55 state_enum_from_impls = quote! {
56 #state_enum_from_impls
57
58 impl From<#state> for #state_enum_name {
59 fn from(state: #state) -> #state_enum_name {
60 #state_enum_name :: #state(state)
61 }
62 }
63 }
64 }
65
66 let event_enum_name = quote::format_ident!("{}Event", name);
67 let mut event_enum_variants = quote! {};
68 for event in &events {
69 event_enum_variants = quote!{
70 #event_enum_variants
71 #event(#event),
72 }
73 }
74 let event_enum = quote! {
75 #[derive(Debug)]
76 pub enum #event_enum_name {
77 #event_enum_variants
78 }
79 };
80
81 let mut event_enum_from_impls = quote! {};
82 for event in &events {
83 event_enum_from_impls = quote! {
84 #event_enum_from_impls
85
86 impl From<#event> for #event_enum_name {
87 fn from(event: #event) -> #event_enum_name {
88 #event_enum_name :: #event(event)
89 }
90 }
91 }
92 }
93
94 let mut enum_transitions = quote! {};
95 for transition in &transitions {
96 let state = &transition.state;
97 let event = &transition.event;
98 let next = &transition.next;
99 let action = &transition.action;
100
101 if let Some(action) = action {
102 enum_transitions = quote! {
103 #enum_transitions
104
105 (#state_enum_name::#state(s), #event_enum_name::#event(e)) => {
106 impl tamata::Transition<#name, #event> for #state {
107 type Next = #next;
108
109 fn send(
110 self,
111 event: #event,
112 ctx: #context,
113 ) -> Result<#next, #error> {
114 (#action)(self, event, ctx)
115 }
116 }
117
118 let next = tamata::Transition::<#name, #event>::send(s, e, ctx)?;
119 let next = #state_enum_name::#next(next);
120 tamata::Sent::Valid(next)
121 },
122 }
123 } else {
124 enum_transitions = quote! {
125 #enum_transitions
126
127 (#state_enum_name::#state(s), #event_enum_name::#event(e)) => {
128 let next = tamata::Transition::<#name, #event>::send(s, e, ctx)?;
129 let next = #state_enum_name::from(next);
130 tamata::Sent::Valid(next)
131 },
132 }
133 };
134 }
135
136 let impl_state_enum = quote! {
137 impl #state_enum_name {
138 pub fn send(
139 self,
140 event: impl Into<#event_enum_name>,
141 ctx: #context
142 ) -> Result<tamata::Sent<#name>, #error> {
143 let next = match (self, event.into()) {
144 #enum_transitions
145 (state, event) => {
146 tamata::Sent::Invalid(state, event)
147 }
148 };
149
150 Ok(next)
151 }
152 }
153 };
154
155 let impl_fsm = quote! {
156 impl tamata::Fsm for #name {
157 type Error = #error;
158 type Context = #context;
159
160 type State = #state_enum_name;
161 type Event = #event_enum_name;
162 }
163 };
164
165 let expanded = quote! {
166 #impl_fsm
167
168 #state_impls
169
170 #event_impls
171
172 #state_enum
173
174 #state_enum_from_impls
175
176 #event_enum
177
178 #event_enum_from_impls
179
180 #impl_state_enum
181 };
182
183 TokenStream::from(expanded)
184}
185
186struct Fsm {
187 name: Ident,
188 error: Type,
189 context: Type,
190 states: Vec<Ident>,
191 events: Vec<Ident>,
192 transitions: Vec<Transition>,
193}
194
195impl Parse for Fsm {
196 fn parse(input: ParseStream) -> Result<Self> {
197 let name: Ident = input.parse()?;
198
199 input.parse::<Token![,]>()?;
200
201 let error = input.parse::<Ident>()?;
202 if error != "Error" {
203 return Err(Error::new(error.span(), "expected `Error`"));
204 }
205 input.parse::<Token![=]>()?;
206 let error: Type = input.parse()?;
207
208 input.parse::<Token![,]>()?;
209
210 let context = input.parse::<Ident>()?;
211 if context != "Context" {
212 return Err(Error::new(context.span(), "expected `Context`"));
213 }
214 input.parse::<Token![=]>()?;
215 let context: Type = input.parse()?;
216
217 let _ = input.parse::<Token![,]>();
219
220 let transitions;
221 braced!(transitions in input);
222 let transitions: Punctuated<Transition, Token![,]> =
223 transitions.parse_terminated(Transition::parse)?;
224
225 let transitions: Vec<_> = transitions.into_iter().collect();
226
227 let _ = input.parse::<Token![,]>();
229
230 let mut states = BTreeSet::default();
231 let mut events = BTreeSet::default();
232
233 for transition in &transitions {
234 states.insert(transition.state.clone());
235 states.insert(transition.next.clone());
236 events.insert(transition.event.clone());
237 }
238
239 let states: Vec<_> = states.into_iter().collect();
240 let events: Vec<_> = events.into_iter().collect();
241
242 Ok(Fsm {
243 name,
244 error,
245 context,
246 states,
247 events,
248 transitions,
249 })
250 }
251}
252
253struct Transition {
254 state: Ident,
255 event: Ident,
256 next: Ident,
257 action: Option<Expr>,
258}
259
260impl Parse for Transition {
261 fn parse(input: ParseStream) -> Result<Self> {
262 let state: Ident = input.parse()?;
263
264 let events;
265 parenthesized!(events in input);
266 let events: Punctuated<Ident, Token![,]> =
267 events.parse_terminated(Ident::parse)?;
268
269 let event: Ident = events.into_iter().next().unwrap();
270
271 input.parse::<Token![->]>()?;
272
273 let next: Ident = input.parse()?;
274
275 let action = if input.peek(Token![=]) {
276 input.parse::<Token![=]>()?;
277 let action: Expr = input.parse()?;
278 Some(action)
279 } else {
280 None
281 };
282
283 Ok(Transition {
284 state,
285 event,
286 next,
287 action,
288 })
289 }
290}