Skip to main content

evolve_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    Ident, Token, bracketed,
5    parse::{Parse, ParseStream},
6    punctuated::Punctuated,
7};
8
9struct Production(Vec<Ident>);
10
11impl Parse for Production {
12    fn parse(input: ParseStream) -> syn::Result<Self> {
13        let content;
14        bracketed!(content in input);
15        let symbols = Punctuated::<Ident, Token![,]>::parse_terminated(&content)?;
16        Ok(Production(symbols.into_iter().collect()))
17    }
18}
19
20struct Rule {
21    name: Ident,
22    productions: Vec<Production>,
23}
24
25impl Parse for Rule {
26    fn parse(input: ParseStream) -> syn::Result<Self> {
27        let name: Ident = input.parse()?;
28        input.parse::<Token![=>]>()?;
29        let mut productions = vec![input.parse::<Production>()?];
30        while input.peek(Token![|]) {
31            input.parse::<Token![|]>()?;
32            productions.push(input.parse()?);
33        }
34        input.parse::<Token![;]>()?;
35        Ok(Rule { name, productions })
36    }
37}
38
39struct GrammarInput {
40    grammar_name: Ident,
41    symbol_name: Ident,
42    start: Ident,
43    rules: Vec<Rule>,
44}
45
46impl Parse for GrammarInput {
47    fn parse(input: ParseStream) -> syn::Result<Self> {
48        let kw: Ident = input.parse()?;
49        if kw != "grammar" {
50            return Err(syn::Error::new(kw.span(), "expected `grammar`"));
51        }
52        let grammar_name: Ident = input.parse()?;
53        input.parse::<Token![;]>()?;
54
55        let kw: Ident = input.parse()?;
56        if kw != "symbol" {
57            return Err(syn::Error::new(kw.span(), "expected `symbol`"));
58        }
59        let symbol_name: Ident = input.parse()?;
60        input.parse::<Token![;]>()?;
61
62        let kw: Ident = input.parse()?;
63        if kw != "start" {
64            return Err(syn::Error::new(kw.span(), "expected `start`"));
65        }
66        let start: Ident = input.parse()?;
67        input.parse::<Token![;]>()?;
68
69        let mut rules = Vec::new();
70        while !input.is_empty() {
71            rules.push(input.parse()?);
72        }
73        Ok(GrammarInput {
74            grammar_name,
75            symbol_name,
76            start,
77            rules,
78        })
79    }
80}
81
82/// Generates a zero-cost grammar with compile-time dispatch.
83///
84/// Produces a grammar struct, a symbol enum, and a `GrammarDef` implementation
85/// with all dispatch resolved via match arms (no allocations, no HashMap lookups).
86///
87/// # Syntax
88///
89/// ```ignore
90/// grammar! {
91///     grammar MyGrammar;
92///     symbol MySymbol;
93///     start Expr;
94///
95///     Expr => [Expr, Expr, BinOp] | [Val];
96///     BinOp => [Add] | [Sub] | [Mul];
97///     Val => [X] | [One];
98/// }
99/// ```
100///
101/// - `grammar <name>` — name of the generated struct implementing `GrammarDef`
102/// - `symbol <name>` — name of the generated enum with all grammar symbols
103/// - `start <rule>` — the start rule
104/// - Rules: `<name> => [symbols...] | [symbols...];`
105///
106/// Symbols appearing on the left side of `=>` are non-terminals.
107/// All other symbols are terminals.
108#[proc_macro]
109pub fn grammar(input: TokenStream) -> TokenStream {
110    let input = syn::parse_macro_input!(input as GrammarInput);
111
112    let grammar_name = &input.grammar_name;
113    let symbol_name = &input.symbol_name;
114    let rule_names: Vec<&Ident> = input.rules.iter().map(|r| &r.name).collect();
115
116    // Collect all symbols (rule names + terminals)
117    let mut all_symbols: Vec<Ident> = rule_names.iter().map(|i| (*i).clone()).collect();
118    for rule in &input.rules {
119        for prod in &rule.productions {
120            for sym in &prod.0 {
121                if !all_symbols.iter().any(|s| s == sym) {
122                    all_symbols.push(sym.clone());
123                }
124            }
125        }
126    }
127
128    let start_sym = &input.start;
129
130    // Generate num_productions match arms
131    let num_prod_arms: Vec<_> = input
132        .rules
133        .iter()
134        .map(|r| {
135            let name = &r.name;
136            let count = r.productions.len();
137            quote! { #symbol_name::#name => #count, }
138        })
139        .collect();
140
141    // Generate production match arms
142    let mut prod_arms = Vec::new();
143    for rule in &input.rules {
144        let name = &rule.name;
145        for (i, prod) in rule.productions.iter().enumerate() {
146            let syms = &prod.0;
147            prod_arms.push(quote! {
148                (#symbol_name::#name, #i) => &[#(#symbol_name::#syms),*],
149            });
150        }
151    }
152
153    // Generate is_terminal match: rule names are non-terminals
154    let non_terminal_names: Vec<_> = rule_names.iter().collect();
155
156    let expanded = quote! {
157        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
158        enum #symbol_name {
159            #(#all_symbols),*
160        }
161
162        struct #grammar_name;
163
164        impl evolve::grammar::grammar_def::GrammarDef for #grammar_name {
165            type Symbol = #symbol_name;
166            type Terminal = #symbol_name;
167
168            fn start(&self) -> #symbol_name {
169                #symbol_name::#start_sym
170            }
171
172            fn num_productions(&self, symbol: #symbol_name) -> usize {
173                match symbol {
174                    #(#num_prod_arms)*
175                    _ => 0,
176                }
177            }
178
179            fn production(&self, symbol: #symbol_name, index: usize) -> &[#symbol_name] {
180                match (symbol, index) {
181                    #(#prod_arms)*
182                    _ => &[],
183                }
184            }
185
186            fn is_terminal(&self, symbol: #symbol_name) -> bool {
187                !matches!(symbol, #(#symbol_name::#non_terminal_names)|*)
188            }
189
190            fn terminal_value(&self, symbol: #symbol_name) -> #symbol_name {
191                symbol
192            }
193        }
194    };
195
196    TokenStream::from(expanded)
197}