machine/
lib.rs

1//! # Machine
2//!
3//! ## Features
4//!
5//! This crate defines three procedural macros to help you write enum based state machines,
6//! without writing the associated boilerplate.
7//!
8//! * define a state machine as an enum, each variant can contain members
9//! * an Error state for invalid transitions is added automatically
//! * transitions can have multiple end states if needed (conditions depending on message content, etc.)
11//! * accessors can be generated for state members
12//! * wrapper methods and accessors are generated on the parent enum
13//! * the generated code is also written in the `target/machine` directory for further inspection
14//! * a dot file is written in the `target/machine` directory for graph generation
15//!
16//! ## Usage
17//!
18//! machine is available on [crates.io](https://crates.io/crates/machine) and can be included in your Cargo enabled project like this:
19//!
20//! ```toml
21//! [dependencies]
22//! machine = "^0.2"
23//! ```
24//!
25//! Then include it in your code like this:
26//!
27//! ```rust,ignore
28//! #[macro_use]
29//! extern crate machine;
30//! ```
31//!
32//! ## Example: the traffic light
33//!
34//! We'll define a state machine representing a traffic light, specifying a maximum
35//! number of cars passing while in the green state.
36//!
37//! The following machine definition:
38//!
39//! ```rust,ignore
40//! machine!(
41//!   enum Traffic {
42//!     Green { count: u8 },
43//!     Orange,
44//!     Red
45//!   }
46//! );
47//! ```
48//!
49//! will produce the following code:
50//!
51//! ```rust,ignore
52//! #[derive(Clone, Debug, PartialEq)]
53//! pub enum Traffic {
54//!     Error,
55//!     Green(Green),
56//!     Orange(Orange),
57//!     Red(Red),
58//! }
59//!
60//! #[derive(Clone, Debug, PartialEq)]
61//! pub struct Green {
62//!     count: u8,
63//! }
64//!
65//! #[derive(Clone, Debug, PartialEq)]
66//! pub struct Orange {}
67//!
68//! #[derive(Clone, Debug, PartialEq)]
69//! pub struct Red {}
70//!
71//! impl Traffic {
72//!   pub fn green(count: u8) -> Traffic {
73//!     Traffic::Green(Green { count })
74//!   }
75//!   pub fn orange() -> Traffic {
76//!     Traffic::Orange(Orange {})
77//!   }
78//!   pub fn red() -> Traffic {
79//!     Traffic::Red(Red {})
80//!   }
81//!   pub fn error() -> Traffic {
82//!     Traffic::Error
83//!   }
84//! }
85//! ```
86//!
87//! ### Transitions
88//!
89//! From there, we can define the `Advance` message to go to the next color, and the associated
90//! transitions:
91//!
92//! ```rust,ignore
93//! #[derive(Clone,Debug,PartialEq)]
94//! pub struct Advance;
95//!
96//! transitions!(Traffic,
97//!   [
98//!     (Green, Advance) => Orange,
99//!     (Orange, Advance) => Red,
100//!     (Red, Advance) => Green
101//!   ]
102//! );
103//! ```
104//!
105//! This will generate an enum holding the messages for that state machine,
//! and an `on_advance` method on the parent enum.
107//!
108//! ```rust,ignore
109//! #[derive(Clone, Debug, PartialEq)]
110//! pub enum TrafficMessages {
111//!     Advance(Advance),
112//! }
113//!
114//! impl Traffic {
115//!   pub fn on_advance(self, input: Advance) -> Traffic {
116//!     match self {
117//!       Traffic::Green(state) => Traffic::Orange(state.on_advance(input)),
118//!       Traffic::Orange(state) => Traffic::Red(state.on_advance(input)),
119//!       Traffic::Red(state) => Traffic::Green(state.on_advance(input)),
120//!       _ => Traffic::Error,
121//!     }
122//!   }
123//! }
124//! ```
125//!
//! The compiler will then complain that the `on_advance` method is missing on the
127//! `Green`, `Orange` and `Red` structures:
128//!
129//! ```text,ignore
130//! error[E0599]: no method named on_advance found for type Green in the current scope
131//!   --> tests/t.rs:18:1
132//!    |
133//! 4  | / machine!(
134//! 5  | |   enum Traffic {
135//! 6  | |     Green { count: u8 },
136//! 7  | |     Orange,
137//! 8  | |     Red,
138//! 9  | |   }
139//! 10 | | );
140//!    | |__- method `on_advance` not found for this
141//! ...
142//! 18 | / transitions!(Traffic,
143//! 19 | |   [
144//! 20 | |     (Green, Advance) => Orange,
145//! 21 | |     (Orange, Advance) => Red,
146//! 22 | |     (Red, Advance) => Green
147//! 23 | |   ]
148//! 24 | | );
149//!    | |__^
150//!
151//! [...]
152//! ```
153//!
154//! The `transitions` macro takes care of the boilerplate, writing the wrapper
155//! methods, and making sure that a state machine receiving the wrong message
//! will get into the error state. But we still need to manually define the
157//! transition functions for each of our states, since most of the work will
158//! be done there:
159//!
160//! ```rust,ignore
161//! impl Green {
162//!   pub fn on_advance(self, _: Advance) -> Orange {
163//!     Orange {}
164//!   }
165//! }
166//!
167//! impl Orange {
168//!   pub fn on_advance(self, _: Advance) -> Red {
169//!     Red {}
170//!   }
171//! }
172//!
173//! impl Red {
174//!   pub fn on_advance(self, _: Advance) -> Green {
175//!     Green {
176//!       count: 0
177//!     }
178//!   }
179//! }
180//! ```
181//!
182//! Now we want to add a message to count passing cars when in the green state,
183//! and switch to the orange state if at least 10 cars have passed.
184//! So the `PassCar` message is only accepted by the green state, and the
185//! transition has two possible end states, green and orange.
//! We could keep the state machine strict, with only one end state per state
//! and message combination, but some transitions depend on message values or
//! state members; allowing several end states handles those conditions without
//! creating new states or messages:
190//!
191//! ```rust,ignore
192//! #[derive(Clone,Debug,PartialEq)]
193//! pub struct PassCar { count: u8 }
194//!
195//! transitions!(Traffic,
196//!   [
197//!     (Green, Advance) => Orange,
198//!     (Orange, Advance) => Red,
199//!     (Red, Advance) => Green,
200//!     (Green, PassCar) => [Green, Orange]
201//!   ]
202//! );
203//!
204//! impl Green {
205//!   pub fn on_pass_car(self, input: PassCar) -> Traffic {
206//!     let count = self.count + input.count;
207//!     if count >= 10 {
208//!       println!("reached max cars count: {}", count);
209//!       Traffic::orange()
210//!     } else {
211//!       Traffic::green(count)
212//!     }
213//!   }
214//! }
215//! ```
216//!
217//! The `on_pass_car` method can have multiple end states, so it must
218//! return a `Traffic`.
219//!
//! The generated code will now contain an `on_pass_car` method for the
221//! `Traffic` enum. Note that if a state other than `Green`
222//! receives the `PassCar` message, the state machine will go
223//! into the `Error` state and stay there indefinitely.
224//!
225//! ```rust,ignore
226//! #[derive(Clone, Debug, PartialEq)]
227//! pub enum TrafficMessages {
228//!   Advance(Advance),
229//!   PassCar(PassCar),
230//! }
231//!
232//! impl Traffic {
233//!   pub fn on_advance(self, input: Advance) -> Traffic {
234//!     match self {
235//!       Traffic::Green(state) => Traffic::Orange(state.on_advance(input)),
236//!       Traffic::Orange(state) => Traffic::Red(state.on_advance(input)),
237//!       Traffic::Red(state) => Traffic::Green(state.on_advance(input)),
238//!       _ => Traffic::Error,
239//!     }
240//!   }
241//!
242//!   pub fn on_pass_car(self, input: PassCar) -> Traffic {
243//!     match self {
244//!       Traffic::Green(state) => state.on_pass_car(input),
245//!       _ => Traffic::Error,
246//!     }
247//!   }
248//! }
249//! ```
250//!
251//! The complete generated code can be found in `target/machine/traffic.rs`.
252//!
253//! The machine crate will also generate the `target/machine/traffic.dot` file
254//! for graphviz usage:
255//!
256//! ```dot
257//! digraph Traffic {
258//! Green -> Orange [ label = "Advance" ];
259//! Orange -> Red [ label = "Advance" ];
260//! Red -> Green [ label = "Advance" ];
261//! Green -> Green [ label = "PassCar" ];
262//! Green -> Orange [ label = "PassCar" ];
263//! }
264//! ```
265//!
266//! `dot -Tpng target/machine/traffic.dot > traffic.png` will generate the following image:
267//!
268//! ![traffic light transitions graph](https://raw.githubusercontent.com/rust-bakery/machine/master/assets/traffic.png)
269//!
270//! We can then use the messages to trigger transitions:
271//!
272//! ```rust,ignore
273//! // starting in green state, no cars have passed
274//! let mut t = Traffic::Green(Green { count: 0 });
275//!
276//! t = t.on_pass_car(PassCar { count: 1});
277//! t = t.on_pass_car(PassCar { count: 2});
278//! // still in green state, 3 cars have passed
279//! assert_eq!(t, Traffic::green(3));
280//!
281//! // each advance call will move to the next color
282//! t = t.on_advance(Advance);
283//! assert_eq!(t, Traffic::orange());
284//!
285//! t = t.on_advance(Advance);
286//! assert_eq!(t, Traffic::red());
287//!
288//! t = t.on_advance(Advance);
289//! assert_eq!(t, Traffic::green(0));
290//! t = t.on_pass_car(PassCar { count: 5 });
291//! assert_eq!(t, Traffic::green(5));
292//!
//! // when at least 10 cars have passed, go to the orange state
294//! t = t.on_pass_car(PassCar { count: 7 });
295//! assert_eq!(t, Traffic::orange());
296//! t = t.on_advance(Advance);
297//! assert_eq!(t, Traffic::red());
298//!
299//! // if we try to use the PassCar message on state other than Green,
300//! // we go into the error state
301//! t = t.on_pass_car(PassCar { count: 7 });
302//! assert_eq!(t, Traffic::error());
303//!
304//! // once in the error state, we stay in the error state
305//! t = t.on_advance(Advance);
306//! assert_eq!(t, Traffic::error());
307//! ```
308//!
309//! ### Methods
310//!
311//! The `methods!` procedural macro can generate wrapper methods for state member
312//! accessors, or require method implementations on states:
313//!
314//! ```rust,ignore
315//! methods!(Traffic,
316//!   [
317//!     Green => get count: u8,
318//!     Green => set count: u8,
319//!     Green, Orange, Red => fn can_pass(&self) -> bool
320//!   ]
321//! );
322//! ```
323//!
324//! This will generate:
325//! - a `count()` getter for the `Green` state (`get`) and the wrapping enum
//! - a `count_mut()` mutable accessor for the `Green` state (`set`) and the wrapping enum
//! - a `can_pass()` method for the wrapping enum, requiring an implementation on each listed state
328//!
329//! Methods can have arguments, and those will be passed to the corresponding method
330//! on states, as expected.
331//!
332//! ```rust,ignore
333//! impl Orange {}
334//! impl Red {}
335//! impl Green {
336//!   pub fn count(&self) -> &u8 {
337//!     &self.count
338//!   }
339//!
340//!   pub fn count_mut(&mut self) -> &mut u8 {
341//!     &mut self.count
342//!   }
343//! }
344//!
345//! impl Traffic {
346//!   pub fn count(&self) -> Option<&u8> {
347//!     match self {
348//!       Traffic::Green(ref v) => Some(v.count()),
349//!       _ => None,
350//!     }
351//!   }
352//!
353//!   pub fn count_mut(&mut self) -> Option<&mut u8> {
354//!     match self {
355//!       Traffic::Green(ref mut v) => Some(v.count_mut()),
356//!       _ => None,
357//!     }
358//!   }
359//!
360//!   pub fn can_pass(&self) -> Option<bool> {
361//!     match self {
362//!       Traffic::Green(ref v) => Some(v.can_pass()),
363//!       Traffic::Orange(ref v) => Some(v.can_pass()),
364//!       Traffic::Red(ref v) => Some(v.can_pass()),
365//!       _ => None,
366//!     }
367//!   }
368//! }
369//! ```
370//!
371//! We can now add the remaining methods and get a working state machine:
372//!
373//! ```rust,ignore
374//! impl Green {
375//!   pub fn can_pass(&self) -> bool {
376//!     true
377//!   }
378//! }
379//!
380//! impl Orange {
381//!   pub fn can_pass(&self) -> bool {
382//!     false
383//!   }
384//! }
385//!
386//! impl Red {
387//!   pub fn can_pass(&self) -> bool {
388//!     false
389//!   }
390//! }
391//! ```
392
393extern crate case;
394extern crate proc_macro;
395/*
396#[macro_use] mod dynamic_machine;
397
398#[macro_export]
399macro_rules! machine(
400  ( $($token:tt)* ) => ( static_machine!( $($token)* ); );
401);
402*/
403
404#[macro_use]
405extern crate log;
406#[macro_use]
407extern crate syn;
408#[macro_use]
409extern crate quote;
410
411use std::collections::{HashMap, HashSet};
412use std::fs::{File, OpenOptions, create_dir};
413use std::io::{Seek, Write};
414
415use case::CaseExt;
416use syn::export::Span;
417use syn::punctuated::Pair;
418use syn::parse::{Parse, ParseStream, Result};
419use syn::{
420    Abi, Attribute, Expr, FnArg, FnDecl, Generics, Ident, ItemEnum, MethodSig, ReturnType, Type,
421    WhereClause, PathArguments, GenericArgument,
422};
423use quote::ToTokens;
424
425struct Machine {
426    attributes: Vec<Attribute>,
427    data: ItemEnum,
428}
429
430impl Parse for Machine {
431    fn parse(input: ParseStream) -> Result<Self> {
432        let attributes: Vec<Attribute> = input.call(Attribute::parse_outer)?;
433        let data: syn::ItemEnum = input.parse()?;
434
435        Ok(Machine { attributes, data })
436    }
437}
438
439#[proc_macro]
440pub fn machine(input: proc_macro::TokenStream) -> syn::export::TokenStream {
441    let ast = parse_macro_input!(input as Machine);
442
443    // Build the impl
444    let (name, gen) = impl_machine(&ast);
445
446    trace!("generated: {}", gen);
447
448    let file_name = format!("target/machine/{}.rs", name.to_string().to_lowercase());
449    let _ = create_dir("target/machine");
450    File::create(&file_name)
451        .and_then(|mut file| {
452            file.seek(std::io::SeekFrom::End(0))?;
453            file.write_all(gen.to_string().as_bytes())?;
454            file.flush()
455        })
456        .expect("error writing machine definition");
457
458    gen
459}
460
461fn impl_machine(m: &Machine) -> (&Ident, syn::export::TokenStream) {
462    let Machine { attributes, data } = m;
463    let ast = data;
464    //println!("attributes: {:?}", attributes);
465    //println!("ast: {:#?}", ast);
466
467    let machine_name = &ast.ident;
468    let variants_names = &ast.variants.iter().map(|v| &v.ident).collect::<Vec<_>>();
469    let structs_names = variants_names.clone();
470
471    // define the state enum
472    let toks = quote! {
473      #(#attributes)*
474      pub enum #machine_name {
475        Error,
476        #(#variants_names(#structs_names)),*
477      }
478    };
479
480    let mut stream = proc_macro::TokenStream::from(toks);
481
482    // define structs for each state
483    for ref variant in ast.variants.iter() {
484        let name = &variant.ident;
485
486        let fields = &variant
487            .fields
488            .iter()
489            .map(|f| {
490                let vis = &f.vis;
491                let ident = &f.ident;
492                let ty = &f.ty;
493
494                quote! {
495                  #vis #ident: #ty
496                }
497            })
498            .collect::<Vec<_>>();
499
500        let toks = quote! {
501          #(#attributes)*
502          pub struct #name {
503            #(#fields),*
504          }
505        };
506
507        stream.extend(proc_macro::TokenStream::from(toks));
508    }
509
510    let methods = &ast
511        .variants
512        .iter()
513        .map(|variant| {
514            let fn_name = Ident::new(&variant.ident.to_string().to_snake(), Span::call_site());
515            let struct_name = &variant.ident;
516
517            let args = &variant
518                .fields
519                .iter()
520                .map(|f| {
521                    let ident = &f.ident;
522                    let ty = &f.ty;
523
524                    quote! {
525                      #ident: #ty
526                    }
527                })
528                .collect::<Vec<_>>();
529
530            let arg_names = &variant.fields.iter().map(|f| &f.ident).collect::<Vec<_>>();
531
532            quote! {
533              pub fn #fn_name(#(#args),*) -> #machine_name {
534                #machine_name::#struct_name(#struct_name {
535                  #(#arg_names),*
536                })
537              }
538            }
539        })
540        .collect::<Vec<_>>();
541
542    let toks = quote! {
543      impl #machine_name {
544        #(#methods)*
545
546        pub fn error() -> #machine_name {
547          #machine_name::Error
548        }
549      }
550    };
551
552    stream.extend(proc_macro::TokenStream::from(toks));
553
554    (machine_name, stream)
555}
556
557#[derive(Debug)]
558struct Transitions {
559    pub machine_name: Ident,
560    pub transitions: Vec<Transition>,
561}
562
563#[derive(Debug)]
564struct Transition {
565    pub start: Ident,
566    pub message: Type,
567    pub end: Vec<Ident>,
568}
569
570impl Parse for Transitions {
571    fn parse(input: ParseStream) -> Result<Self> {
572        let machine_name: Ident = input.parse()?;
573        let _: Token![,] = input.parse()?;
574
575        let content;
576        bracketed!(content in input);
577
578        trace!("content: {:?}", content);
579        let mut transitions = Vec::new();
580
581        let t: Transition = content.parse()?;
582        transitions.push(t);
583
584        loop {
585            let lookahead = content.lookahead1();
586            if lookahead.peek(Token![,]) {
587                let _: Token![,] = content.parse()?;
588                let t: Transition = content.parse()?;
589                transitions.push(t);
590            } else {
591                break;
592            }
593        }
594
595        Ok(Transitions {
596            machine_name,
597            transitions,
598        })
599    }
600}
601
602impl Parse for Transition {
603    fn parse(input: ParseStream) -> Result<Self> {
604        let left;
605        parenthesized!(left in input);
606
607        let start: Ident = left.parse()?;
608        let _: Token![,] = left.parse()?;
609        let message: Type = left.parse()?;
610
611        let _: Token![=>] = input.parse()?;
612
613        let end = match input.parse::<Ident>() {
614            Ok(i) => vec![i],
615            Err(_) => {
616                let content;
617                bracketed!(content in input);
618
619                //println!("content: {:?}", content);
620                let mut states = Vec::new();
621
622                let t: Ident = content.parse()?;
623                states.push(t);
624
625                loop {
626                    let lookahead = content.lookahead1();
627                    if lookahead.peek(Token![,]) {
628                        let _: Token![,] = content.parse()?;
629                        let t: Ident = content.parse()?;
630                        states.push(t);
631                    } else {
632                        break;
633                    }
634                }
635
636                states
637            }
638        };
639
640        Ok(Transition {
641            start,
642            message,
643            end,
644        })
645    }
646}
647
648impl Transitions {
649    pub fn render(&self) {
650        let file_name = format!(
651            "target/machine/{}.dot",
652            self.machine_name.to_string().to_lowercase()
653        );
654        let _ = create_dir("target/machine");
655        let mut file = File::create(&file_name).expect("error opening dot file");
656
657        file.write_all(format!("digraph {} {{\n", self.machine_name.to_string()).as_bytes())
658            .expect("error writing to dot file");
659
660        let mut edges = Vec::new();
661        for transition in self.transitions.iter() {
662            for state in transition.end.iter() {
663                edges.push((&transition.start, &transition.message, state));
664            }
665        }
666
667        for edge in edges.iter() {
668            file.write_all(
669                &format!("{} -> {} [ label = \"{}\" ];\n", edge.0, edge.2, edge.1.into_token_stream()).as_bytes(),
670            )
671            .expect("error writing to dot file");
672        }
673
674        file.write_all(&b"}"[..])
675            .expect("error writing to dot file");
676        file.flush().expect("error flushhing dot file");
677    }
678}
679
680#[proc_macro]
681pub fn transitions(input: proc_macro::TokenStream) -> syn::export::TokenStream {
682    //println!("\ninput: {:?}", input);
683    let mut stream = proc_macro::TokenStream::new();
684
685    let transitions = parse_macro_input!(input as Transitions);
686    trace!("\nparsed transitions: {:#?}", transitions);
687
688    transitions.render();
689
690    let machine_name = transitions.machine_name;
691
692    let mut messages = HashMap::new();
693    for t in transitions.transitions.iter() {
694        let entry = messages.entry(&t.message).or_insert(Vec::new());
695        entry.push((&t.start, &t.end));
696    }
697
698    //let mut message_types = transitions.transitions.iter().map(|t| &t.message).collect::<Vec<_>>();
699
700    let mut type_arguments = HashSet::new();
701    for t in transitions.transitions.iter() {
702      let mut args = type_args(&t.message);
703      type_arguments.extend(args.drain());
704    }
705
706    let type_arguments = reorder_type_arguments(type_arguments);
707
708    // create an enum from the messages
709    let message_enum_ident = Ident::new(
710        &format!("{}Messages", &machine_name.to_string()),
711        Span::call_site(),
712    );
713    let structs_names = messages.keys().collect::<Vec<_>>();
714    let variants_names = structs_names.iter().map(|t| type_last_ident(*t)).collect::<Vec<_>>();
715
716
717    let type_arg_toks = if type_arguments.is_empty() {
718      quote!{}
719    } else {
720      quote!{
721        < #(#type_arguments),* >
722      }
723    };
724
725    // define the state enum
726    let toks = quote! {
727      #[derive(Clone,Debug,PartialEq)]
728      pub enum #message_enum_ident #type_arg_toks {
729        #(#variants_names(#structs_names)),*
730      }
731    };
732
733    stream.extend(proc_macro::TokenStream::from(toks));
734    let functions = messages
735      .iter()
736      .map(|(msg, moves)| {
737        let fn_ident = Ident::new(
738          //&format!("on_{}", &msg.to_string().to_snake()),
739          &format!("on_{}", type_to_snake(msg)),
740          Span::call_site(),
741          );
742        let mv = moves.iter().map(|(start, end)| {
743          if end.len() == 1 {
744            let end_state = &end[0];
745            quote!{
746              #machine_name::#start(state) => #machine_name::#end_state(state.#fn_ident(input)),
747            }
748          } else {
749            quote!{
750              #machine_name::#start(state) => state.#fn_ident(input),
751            }
752          }
753        }).collect::<Vec<_>>();
754
755        let type_arguments = reorder_type_arguments(type_args(msg));
756        let type_arg_toks = if type_arguments.is_empty() {
757          quote!{}
758        } else {
759          quote!{
760            < #(#type_arguments),* >
761          }
762        };
763
764        quote! {
765          pub fn #fn_ident #type_arg_toks(self, input: #msg) -> #machine_name {
766            match self {
767              #(#mv)*
768              _ => #machine_name::Error,
769            }
770          }
771        }
772      })
773    .collect::<Vec<_>>();
774
775    let matches = messages
776      .keys()
777      .map(|msg| {
778        let fn_ident = Ident::new(
779          //&format!("on_{}", &msg.to_string().to_snake()),
780          &format!("on_{}", type_to_snake(msg)),
781          Span::call_site(),
782          );
783
784          let id = type_last_ident(msg);
785
786          quote!{
787            #message_enum_ident::#id(message) => self.#fn_ident(message),
788          }
789
790      })
791    .collect::<Vec<_>>();
792
793    /*let type_arg_toks = if type_arguments.is_empty() {
794      quote!{}
795    } else {
796      quote!{
797        < #(#type_arguments),* >
798      }
799    };*/
800
801    let execute = quote! {
802      pub fn execute #type_arg_toks(self, input: #message_enum_ident #type_arg_toks) -> #machine_name {
803        match input {
804          #(#matches)*
805          _ => #machine_name::Error,
806        }
807      }
808    };
809
810    let toks = quote! {
811      impl #machine_name {
812        #(#functions)*
813
814        #execute
815      }
816    };
817
818    stream.extend(proc_macro::TokenStream::from(toks));
819
820    //println!("generated: {:?}", gen);
821    trace!("generated transitions: {}", stream);
822    let _ = create_dir("target/machine");
823    let file_name = format!("target/machine/{}.rs", machine_name.to_string().to_lowercase());
824    OpenOptions::new()
825        .create(true)
826        .write(true)
827        .open(&file_name)
828        .and_then(|mut file| {
829            file.seek(std::io::SeekFrom::End(0))?;
830            file.write_all(stream.to_string().as_bytes())?;
831            file.flush()
832        })
833        .expect("error writing transitions");
834
835    stream
836}
837
838#[proc_macro]
839pub fn methods(input: proc_macro::TokenStream) -> syn::export::TokenStream {
840    //println!("\ninput: {:?}", input);
841    let mut stream = proc_macro::TokenStream::new();
842
843    let methods = parse_macro_input!(input as Methods);
844    trace!("\nparsed methods: {:#?}", methods);
845
846    let mut h = HashMap::new();
847    for method in methods.methods.iter() {
848        for state in method.states.iter() {
849            let entry = h.entry(state).or_insert(Vec::new());
850            entry.push(&method.method_type);
851        }
852    }
853
854    for (state, methods) in h.iter() {
855        let method_toks = methods
856            .iter()
857            .map(|method| {
858                match method {
859                    MethodType::Get(ident, ty) => {
860                        quote! {
861                          pub fn #ident(&self) -> &#ty {
862                            &self.#ident
863                          }
864                        }
865                    }
866                    MethodType::Set(ident, ty) => {
867                        let mut_ident =
868                            Ident::new(&format!("{}_mut", &ident.to_string()), Span::call_site());
869                        quote! {
870                          pub fn #mut_ident(&mut self) -> &mut #ty {
871                            &mut self.#ident
872                          }
873                        }
874                    }
875                    MethodType::Fn(_) => {
876                        // we let the user implement these methods on the types
877                        quote! {}
878                    }
879                }
880            })
881            .collect::<Vec<_>>();
882
883        let toks = quote! {
884          impl #state {
885            #(#method_toks)*
886          }
887        };
888
889        stream.extend(proc_macro::TokenStream::from(toks));
890    }
891
892    let machine_name = methods.machine_name;
893    let wrapper_methods = methods
894        .methods
895        .iter()
896        .map(|method| match &method.method_type {
897            MethodType::Get(ident, ty) => {
898                let variants = method
899                    .states
900                    .iter()
901                    .map(|state| {
902                        quote! {
903                          #machine_name::#state(ref v) => Some(v.#ident()),
904                        }
905                    })
906                    .collect::<Vec<_>>();
907                quote! {
908                  pub fn #ident(&self) -> Option<&#ty> {
909                    match self {
910                      #(#variants)*
911                      _ => None,
912                    }
913                  }
914                }
915            }
916            MethodType::Set(ident, ty) => {
917                let mut_ident =
918                    Ident::new(&format!("{}_mut", &ident.to_string()), Span::call_site());
919
920                let variants = method
921                    .states
922                    .iter()
923                    .map(|state| {
924                        quote! {
925                          #machine_name::#state(ref mut v) => Some(v.#mut_ident()),
926                        }
927                    })
928                    .collect::<Vec<_>>();
929                quote! {
930                  pub fn #mut_ident(&mut self) -> Option<&mut #ty> {
931                    match self {
932                      #(#variants)*
933                      _ => None,
934                    }
935                  }
936                }
937            }
938            MethodType::Fn(m) => {
939                let ident = &m.ident;
940                let args = m
941                    .decl
942                    .inputs
943                    .iter()
944                    .filter(|arg| match arg {
945                        FnArg::Captured(_) => true,
946                        _ => false,
947                    })
948                    .map(|arg| {
949                        if let FnArg::Captured(a) = arg {
950                            &a.pat
951                        } else {
952                            panic!();
953                        }
954                    })
955                    .collect::<Vec<_>>();
956
957                let variants = method
958                    .states
959                    .iter()
960                    .map(|state| {
961                        let a = args.clone();
962                        if method.default.is_default() {
963                            quote! {
964                              #machine_name::#state(ref v) => v.#ident( #(#a),* ),
965                            }
966                        } else {
967                            quote! {
968                              #machine_name::#state(ref v) => Some(v.#ident( #(#a),* )),
969                            }
970                        }
971                    })
972                    .collect::<Vec<_>>();
973
974                let inputs = &m.decl.inputs;
975                let output = match &m.decl.output {
976                    ReturnType::Default => quote! {},
977                    ReturnType::Type(arrow, ty) => {
978                        if method.default.is_default() {
979                            quote! {
980                              #arrow #ty
981                            }
982                        } else {
983                            quote! {
984                              #arrow Option<#ty>
985                            }
986                        }
987                    }
988                };
989
990                match method.default {
991                    DefaultValue::None => {
992                        quote! {
993                          pub fn #ident(#inputs) #output {
994                            match self {
995                              #(#variants)*
996                              _ => None,
997                            }
998                          }
999                        }
1000                    }
1001                    DefaultValue::Default => {
1002                        quote! {
1003                          pub fn #ident(#inputs) #output {
1004                            match self {
1005                              #(#variants)*
1006                              _ => std::default::Default::default(),
1007                            }
1008                          }
1009                        }
1010                    }
1011                    DefaultValue::Val(ref expr) => {
1012                        quote! {
1013                          pub fn #ident(#inputs) #output {
1014                            match self {
1015                              #(#variants)*
1016                              _ => #expr,
1017                            }
1018                          }
1019                        }
1020                    }
1021                }
1022            }
1023        })
1024        .collect::<Vec<_>>();
1025
1026    let toks = quote! {
1027      impl #machine_name {
1028        #(#wrapper_methods)*
1029      }
1030    };
1031
1032    stream.extend(proc_macro::TokenStream::from(toks));
1033
1034    let file_name = format!("target/machine/{}.rs", machine_name.to_string().to_lowercase());
1035    let _ = create_dir("target/machine");
1036    OpenOptions::new()
1037        .create(true)
1038        .write(true)
1039        .open(&file_name)
1040        .and_then(|mut file| {
1041            file.seek(std::io::SeekFrom::End(0))?;
1042            file.write_all(stream.to_string().as_bytes())?;
1043            file.flush()
1044        })
1045        .expect("error writing methods");
1046
1047    stream
1048}
1049
1050#[derive(Debug)]
1051struct Methods {
1052    pub machine_name: Ident,
1053    pub methods: Vec<Method>,
1054}
1055
1056#[derive(Debug)]
1057struct Method {
1058    pub states: Vec<Ident>,
1059    pub method_type: MethodType,
1060    pub default: DefaultValue,
1061}
1062
1063#[derive(Debug)]
1064enum MethodType {
1065    Get(Ident, Type),
1066    Set(Ident, Type),
1067    Fn(MethodSig),
1068}
1069
1070#[derive(Debug)]
1071enum DefaultValue {
1072    None,
1073    Default,
1074    Val(Expr),
1075}
1076
1077impl DefaultValue {
1078    pub fn is_default(&self) -> bool {
1079        match self {
1080            DefaultValue::None => false,
1081            _ => true,
1082        }
1083    }
1084}
1085
1086impl Parse for Methods {
1087    fn parse(input: ParseStream) -> Result<Self> {
1088        let machine_name: Ident = input.parse()?;
1089        let _: Token![,] = input.parse()?;
1090
1091        let content;
1092        bracketed!(content in input);
1093
1094        let mut methods = Vec::new();
1095
1096        let t: Method = content.parse()?;
1097        methods.push(t);
1098
1099        loop {
1100            let lookahead = content.lookahead1();
1101            if lookahead.peek(Token![,]) {
1102                let _: Token![,] = content.parse()?;
1103                let t: Method = content.parse()?;
1104                methods.push(t);
1105            } else {
1106                break;
1107            }
1108        }
1109
1110        Ok(Methods {
1111            machine_name,
1112            methods,
1113        })
1114    }
1115}
1116struct ParenVal {
1117    expr: Expr,
1118}
1119
1120impl Parse for ParenVal {
1121    fn parse(input: ParseStream) -> Result<Self> {
1122        let stream;
1123        parenthesized!(stream in input);
1124        let expr: Expr = stream.parse()?;
1125        Ok(ParenVal { expr })
1126    }
1127}
1128
1129impl Parse for Method {
1130    fn parse(input: ParseStream) -> Result<Self> {
1131        let mut states = Vec::new();
1132
1133        let state: Ident = input.parse()?;
1134        states.push(state);
1135
1136        loop {
1137            let lookahead = input.lookahead1();
1138            if lookahead.peek(Token![,]) {
1139                let _: Token![,] = input.parse()?;
1140                let state: Ident = input.parse()?;
1141                states.push(state);
1142            } else {
1143                break;
1144            }
1145        }
1146
1147        let _: Token![=>] = input.parse()?;
1148        let default_token: Option<Token![default]> = input.parse()?;
1149        let default = if default_token.is_some() {
1150            match input.parse::<ParenVal>() {
1151                Ok(content) => DefaultValue::Val(content.expr),
1152                Err(_) => DefaultValue::Default,
1153            }
1154        } else {
1155            DefaultValue::None
1156        };
1157
1158        let method_type = match parse_method_sig(input) {
1159            Ok(f) => MethodType::Fn(f),
1160            Err(_) => {
1161                let i: Ident = input.parse()?;
1162                let name: Ident = input.parse()?;
1163                let _: Token![:] = input.parse()?;
1164                let ty: Type = input.parse()?;
1165
1166                if i.to_string() == "get" {
1167                    MethodType::Get(name, ty)
1168                } else if i.to_string() == "set" {
1169                    MethodType::Set(name, ty)
1170                } else {
1171                    return Err(syn::Error::new(i.span(), "expected `get` or `set`"));
1172                }
1173            }
1174        };
1175
1176        Ok(Method {
1177            states,
1178            method_type,
1179            default,
1180        })
1181    }
1182}
1183
1184fn parse_method_sig(input: ParseStream) -> Result<MethodSig> {
1185    //let vis: Visibility = input.parse()?;
1186    let constness: Option<Token![const]> = input.parse()?;
1187    let unsafety: Option<Token![unsafe]> = input.parse()?;
1188    let asyncness: Option<Token![async]> = input.parse()?;
1189    let abi: Option<Abi> = input.parse()?;
1190    let fn_token: Token![fn] = input.parse()?;
1191    let ident: Ident = input.parse()?;
1192    let generics: Generics = input.parse()?;
1193
1194    let content;
1195    let paren_token = parenthesized!(content in input);
1196    let inputs = content.parse_terminated(FnArg::parse)?;
1197
1198    let output: ReturnType = input.parse()?;
1199    let where_clause: Option<WhereClause> = input.parse()?;
1200
1201    Ok(MethodSig {
1202        constness,
1203        unsafety,
1204        asyncness,
1205        abi,
1206        ident,
1207        decl: FnDecl {
1208            fn_token: fn_token,
1209            paren_token: paren_token,
1210            inputs: inputs,
1211            output: output,
1212            variadic: None,
1213            generics: Generics {
1214                where_clause: where_clause,
1215                ..generics
1216            },
1217        },
1218    })
1219}
1220
1221fn type_to_snake(t: &Type) -> String {
1222  match t {
1223    Type::Path(ref p) => {
1224      match p.path.segments.last() {
1225        Some(Pair::End(segment)) => {
1226          segment.ident.to_string().to_snake()
1227        },
1228        _ => panic!("expected a path segment"),
1229      }
1230    },
1231    t => panic!("expected a Type::Path, got {:?}", t),
1232  }
1233}
1234
1235fn type_last_ident(t: &Type) -> &Ident {
1236  match t {
1237    Type::Path(ref p) => {
1238      match p.path.segments.last() {
1239        Some(Pair::End(segment)) => {
1240          &segment.ident
1241        },
1242        _ => panic!("expected a path segment"),
1243      }
1244    },
1245    t => panic!("expected a Type::Path, got {:?}", t),
1246  }
1247}
1248
1249fn type_args(t: &Type) -> HashSet<GenericArgument> {
1250  match t {
1251    Type::Path(ref p) => {
1252      match p.path.segments.last() {
1253        Some(Pair::End(segment)) => {
1254          match &segment.arguments {
1255            PathArguments::AngleBracketed(a) => {
1256              a.args.iter().cloned().collect()
1257            },
1258            PathArguments::None => HashSet::new(),
1259            a => panic!("expected angle bracketed arguments, got {:?}", a),
1260          }
1261        },
1262        _ => panic!("expected a path segment"),
1263      }
1264    },
1265    t => panic!("expected a Type::Path, got {:?}", t),
1266  }
1267}
1268
1269// lifetimes must appear before other type arguments
1270fn reorder_type_arguments(mut t: HashSet<GenericArgument>) -> Vec<GenericArgument> {
1271  let mut lifetimes = Vec::new();
1272  let mut others = Vec::new();
1273
1274  for arg in t.drain() {
1275    if let GenericArgument::Lifetime(_) = arg {
1276      lifetimes.push(arg);
1277    } else {
1278      others.push(arg);
1279    }
1280  }
1281
1282  lifetimes.extend(others.drain(..));
1283  lifetimes
1284}