langen_macro/
lib.rs

use std::rc::Rc;

use parser::{Lr1Automaton, MetaSymbol, Rule, Symbol, Terminal};
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use regex_automata::dfa::dense;
use syn::{
    parse::Parse, spanned::Spanned, token::Comma, Data, DeriveInput, Expr, ExprClosure, Fields,
    Ident, LitStr,
};

mod parser;

/// Derive macro for Tokens
///
/// Every token must have a `#[token(regex [, process_func])]` attribute.
/// `process_func` is an optional closure that specifies how data should be extracted from the token characters.
/// `regex` indicates how the token should be matched.
/// Tokens defined earlier take priority over those defined later.
///
/// You can also use any number of `#[ignored(regex)]` attributes on the enum itself to indicate that certain character sequences should be ignored.
///
/// # Panics
/// Panics if any attributes are invalid
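///
/// # Example
///
/// A rough usage sketch, not compile-checked here; it assumes the `langen` facade crate
/// re-exports this derive together with the `Tokens` trait (providing `scan`), and the
/// token names below are purely illustrative:
///
/// ```ignore
/// #[derive(Tokens)]
/// #[ignored(r"[ \t\r\n]+")]
/// enum Token {
///     // The closure receives the matched text and must return a `Result`.
///     #[token(r"[0-9]+", |s: &str| s.parse::<i64>())]
///     IntLit(i64),
///     #[token(r"\+")]
///     Plus,
/// }
///
/// let tokens = Token::scan("1 + 2").unwrap();
/// ```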
#[allow(clippy::too_many_lines)]
#[proc_macro_derive(Tokens, attributes(ignored, token))]
pub fn tokens_derive(input: TokenStream) -> TokenStream {
    let input: DeriveInput = syn::parse(input).expect("Input has to be from derive");
    if let Data::Enum(data) = input.data {
        let name = input.ident;

        let mut token_indices = vec![];
        let mut token_code = vec![];
        let mut token_patterns = vec![];

        for re in input.attrs.iter().filter_map(|attr| {
            if attr.path().is_ident("ignored") {
                let t: LitStr = attr.parse_args().unwrap_or_else(|_| {
                    panic!("Argument to ignored attribute for \"{name}\" must be a string literal")
                });
                Some(t)
            } else {
                None
            }
        }) {
            token_indices.push(token_code.len());
            token_code.push(quote! {continue;});
            token_patterns.push(re.value());
        }

        for variant in data.variants {
            for input in variant.attrs.iter().filter_map(|attr| {
                if attr.path().is_ident("token") {
                    let t: TokenInput = attr.parse_args().unwrap_or_else(|e| {
                        panic!(
                            "Invalid arguments for token attribute on \"{}\": {e}",
                            variant.ident
                        )
                    });
                    Some(t)
                } else {
                    None
                }
            }) {
                let ident = &variant.ident;

                token_indices.push(token_code.len());

                token_code.push(match input.fun {
                    Some(closure) => {
                        if let Fields::Unnamed(_fields) = &variant.fields {
                            quote! {
                                let r = (#closure)(&input[start..end]);
                                match r {
                                    Ok(v) => Self::#ident(v),
                                    Err(e) => {
                                        return Err(langen::LexerError::ProcessError(Box::new(e), span))
                                    }
                                }
                            }
                        } else {
                            panic!(
                                "Variant \"{}\" must have exactly one unnamed field when given a process function",
                                variant.ident
                            )
                        }
                    }
                    None => {
                        if let Fields::Unit = variant.fields {
                            quote! {Self::#ident}
                        } else {
                            panic!(
                                "Non-unit token variant \"{}\" must have a process function",
                                variant.ident
                            )
                        }
                    }
                });
                token_patterns.push(input.re.value());
            }
        }

        let dfa = dense::DFA::new_many(&token_patterns).expect("Couldn't build regex automaton");
        let (bytes, pad) = dfa.to_bytes_little_endian();
        let le_dfa_bytes = &bytes[pad..];
        let (bytes, pad) = dfa.to_bytes_big_endian();
        let be_dfa_bytes = &bytes[pad..];

        quote! {
            impl langen::Tokens for #name {
                fn scan(input: &str) -> Result<Vec<(Self, langen::Span)>, langen::LexerError> {
                    const DFA: &langen::regex_automata::util::wire::AlignAs<[u8], u32> = &langen::regex_automata::util::wire::AlignAs {
                        _align: [],
                        #[cfg(target_endian = "big")]
                        bytes: [#(#be_dfa_bytes),*],
                        #[cfg(target_endian = "little")]
                        bytes: [#(#le_dfa_bytes),*],
                    };

                    // The DFA bytes are generated above at compile time, so we can always safely expect
                    let (dfa, _) = langen::regex_automata::dfa::dense::DFA::from_bytes(&DFA.bytes).expect("Couldn't deserialize dfa");

                    let mut re_input = langen::regex_automata::Input::new(input).anchored(langen::regex_automata::Anchored::Yes);
                    let mut tokens = vec![];
                    let mut current = 0;

                    while current != input.len() {
                        re_input.set_start(current);
                        // Input should always be fine
                        use langen::regex_automata::dfa::Automaton;
                        if let Some(m) = dfa.try_search_fwd(&re_input).expect("Regex Error") {
                            // println!("{} {:?}", current, m);
                            let start = current;
                            current = m.offset();
                            let end = current;

                            let span = langen::Span::new(start, end);
                            let token = match m.pattern().as_usize() {
                                #(#token_indices => {#token_code})*
                                _ => {unreachable!("Every pattern has to come from a regex")}
                            };

                            tokens.push((token, span));
                        } else {
                            return Err(langen::LexerError::NoToken(current));
                        }
                    }

                    Ok(tokens)
                }
            }
        }.into()
    } else {
        panic!("Langen can only be used on enums");
    }
}

struct TokenInput {
    re: LitStr,
    fun: Option<ExprClosure>,
}

impl Parse for TokenInput {
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        Ok(Self {
            re: input.parse()?,
            fun: match input.parse::<Comma>() {
                Ok(_) => {
                    let expr: Expr = input.parse()?;
                    if let Expr::Closure(closure) = expr {
                        Some(closure)
                    } else {
                        return Err(syn::Error::new(
                            expr.span(),
                            "Second argument to token must be a closure",
                        ));
                    }
                }
                Err(_) => None,
            },
        })
    }
}

/// Derive macro for Grammar
///
/// Every non-terminal symbol must have one or more `#[rule(process_func, symbol*)]` attributes; variants without a `#[rule]` attribute are treated as terminals.
/// `process_func` is a closure that specifies how data should be extracted from the symbols consumed by this rule.
/// `symbol`s are any number of variants of this enum that make up the right-hand side of the rule. See <https://en.wikipedia.org/wiki/Formal_grammar> for more info.
/// The first variant with a `#[rule]` attribute is always the starting symbol of the grammar.
///
/// # Panics
/// Panics if any attributes are invalid
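///
/// # Example
///
/// A rough usage sketch, not compile-checked here; it assumes the same enum also derives
/// `Tokens` (so terminal variants come from `#[token]` attributes), that the `langen`
/// facade crate re-exports both derives together with the `Grammar` trait (providing
/// `parse`), and that `std::convert::Infallible` is acceptable as the closures' error
/// type. All names below are purely illustrative:
///
/// ```ignore
/// #[derive(Tokens, Grammar)]
/// #[ignored(r"[ \t\r\n]+")]
/// enum Expr {
///     #[token(r"\+")]
///     Plus,
///     #[token(r"[0-9]+", |s: &str| s.parse::<i64>())]
///     IntLit(i64),
///     // The first `#[rule]` variant is the start symbol. Each closure receives the
///     // rule's span followed by the values of the consumed symbols that carry data.
///     #[rule(|_span, v: i64| Ok::<_, std::convert::Infallible>(v), IntLit)]
///     #[rule(|_span, lhs: i64, rhs: i64| Ok::<_, std::convert::Infallible>(lhs + rhs), Sum, Plus, IntLit)]
///     Sum(i64),
/// }
///
/// let result: i64 = Expr::parse(Expr::scan("1 + 2").unwrap()).unwrap();
/// ```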
#[allow(clippy::too_many_lines)]
#[proc_macro_derive(Grammar, attributes(rule))]
pub fn grammar_derive(input: TokenStream) -> TokenStream {
    let input: DeriveInput = syn::parse(input).expect("Input has to be from derive");
    if let Data::Enum(data) = input.data {
        let name = input.ident;

        let mut rules = vec![];
        let mut non_terminals = vec![];
        let mut terminals = vec![];
        let mut out_variant = None;
        let mut out_type = None;

        for variant in data.variants {
            let mut has_rule = false;
            for input in variant.attrs.iter().filter_map(|attr| {
                if attr.path().is_ident("rule") {
                    has_rule = true;
                    let t: RuleInput = attr.parse_args().unwrap_or_else(|e| {
                        panic!(
                            "Invalid arguments for rule attribute on \"{}\": {e}",
                            variant.ident
                        )
                    });
                    Some(t)
                } else {
                    None
                }
            }) {
                rules.push((variant.ident.clone(), input.parts, input.fun));
            }

            if has_rule {
                let Fields::Unnamed(fields) = variant.fields else {
                    panic!("Every non-terminal variant (one with a `#[rule]`) must have one unnamed field");
                };
                non_terminals.push(variant.ident.clone());
                if out_type.is_none() {
                    out_variant = Some(variant.ident.clone());
                    out_type = Some(fields.unnamed[0].ty.clone());
                }
            } else {
                terminals.push((
                    variant.ident.clone(),
                    matches!(variant.fields, Fields::Unnamed(_)),
                ));
            }
        }

        let parser_rules = rules
            .iter()
            .map(|rule| {
                Rc::new(Rule {
                    parts: rule
                        .1
                        .iter()
                        .map(|ident| {
                            if let Some(i) = non_terminals.iter().position(|e| e == ident) {
                                Symbol::NonTerminal(MetaSymbol::Normal(i))
                            } else if let Some(i) = terminals.iter().position(|(e, _)| e == ident) {
                                Symbol::Terminal(Terminal::Normal(i))
                            } else {
                                panic!("Symbol \"{ident}\" unknown (in rule for \"{}\")", rule.0);
                            }
                        })
                        .collect(),
                    result: MetaSymbol::Normal(
                        non_terminals
                            .iter()
                            .position(|e| *e == rule.0)
                            .expect("This comes from the variants, so should always exist"),
                    ),
                })
            })
            .collect();

        // Build the LR(1) automaton from the rules, collapse it to LALR(1), and emit the
        // action and jump tables used by the generated parser.
        let mut automaton =
            Lr1Automaton::create(parser_rules, terminals.len(), non_terminals.len());
        automaton.build_automaton();
        automaton = automaton.make_lalr1();
        let (action, jump) = automaton.generate_tables();

        let mut action_code = vec![];

        for (i, action_row) in action.iter().enumerate() {
            for (symbol, action) in action_row {
                let code_symbol = match symbol {
                    Terminal::Normal(symbol_i) => {
                        let (ident, has_value) = &terminals[*symbol_i];
                        if *has_value {
                            quote! {Some((Self::#ident(v), span))}
                        } else {
                            quote! {Some((Self::#ident, span))}
                        }
                    }
                    Terminal::Eof => quote! {None},
                };
                let action = match action {
                    parser::Action::Shift(n) => match symbol {
                        Terminal::Normal(symbol_i) => {
                            let (ident, has_value) = &terminals[*symbol_i];
                            if *has_value {
                                quote! {
                                    stack.push(#n);
                                    symbol_stack.push((Self::#ident(v), span));
                                }
                            } else {
                                quote! {
                                    stack.push(#n);
                                    symbol_stack.push((Self::#ident, span));
                                }
                            }
                        }
                        Terminal::Eof => {
                            panic!("Can't shift in EOF");
                        }
                    },
                    parser::Action::Reduce(m) => {
                        let put_back = match symbol {
                            Terminal::Normal(symbol_i) => {
                                let (ident, has_value) = &terminals[*symbol_i];
                                if *has_value {
                                    quote! {input.push((Self::#ident(v), span));}
                                } else {
                                    quote! {input.push((Self::#ident, span));}
                                }
                            }
                            Terminal::Eof => quote! {},
                        };

                        let mut pop_code = vec![];
                        let mut fields = vec![];
                        let mut spans = vec![];
                        for (i, ident) in rules[*m].1.iter().enumerate() {
                            let var_ident = format_ident!("v{i}");
                            let span_ident = format_ident!("s{i}");

                            if non_terminals.contains(ident) {
                                pop_code.push(quote! {
                                    stack.pop();
                                    let Some((Self::#ident(#var_ident), #span_ident)) = symbol_stack.pop() else {unreachable!("Stack corrupted! (1)")};
                                });
                                fields.push(var_ident);
                            } else if terminals
                                .iter()
                                .find_map(|(var_ident, has_value)| {
                                    if var_ident == ident {
                                        Some(*has_value)
                                    } else {
                                        None
                                    }
                                })
                                .expect("Has to be in terminals")
                            {
                                pop_code.push(quote! {
                                    stack.pop();
                                    let Some((Self::#ident(#var_ident), #span_ident)) = symbol_stack.pop() else {unreachable!("Stack corrupted! (2)")};
                                });
                                fields.push(var_ident);
                            } else {
                                pop_code.push(quote! {
                                    stack.pop();
                                    let Some((Self::#ident, #span_ident)) = symbol_stack.pop() else {unreachable!("Stack corrupted! (3)")};
                                });
                            }

                            spans.push(span_ident);
                        }
                        pop_code = pop_code.into_iter().rev().collect();

                        let closure = &rules[*m].2;

                        let func_code = if spans.is_empty() {
                            quote! {
                                // TODO: maybe try to find actual values for this
                                let span = langen::Span { start: 0, end: 0 };
                                let r = (#closure)(span.clone());
                                let value = match r {
                                    Ok(v) => v,
                                    Err(e) => {
                                        return Err(langen::ParserError::ProcessError(num_tokens-input.len(), Box::new(e), span));
                                    }
                                };
                            }
                        } else {
                            let first = spans.first().expect("Can't be empty");
                            let last = spans.last().expect("Can't be empty");
                            quote! {
                                let span = langen::Span { start: #first.start, end: #last.end };
                                let r = (#closure)(span.clone(), #( #fields ),*);
                                let value = match r {
                                    Ok(v) => v,
                                    Err(e) => {
                                        return Err(langen::ParserError::ProcessError(num_tokens-input.len(), Box::new(e), span));
                                    }
                                };
                            }
                        };

                        let result = &rules[*m].0;
                        let mut jump_code = vec![];
                        let meta_i = non_terminals
                            .iter()
                            .position(|elem| elem == result)
                            .expect("Must contain ident");

                        for (state, new_state) in &jump[meta_i] {
                            jump_code.push(quote! {
                                Some(#state) => {stack.push(#new_state)}
                            });
                        }

                        quote! {
                            #put_back
                            #( #pop_code )*
                            #func_code
                            symbol_stack.push((Self::#result(value), span));
                            match stack.last() {
                                #( #jump_code )*
                                _ => unreachable!("Stack corrupted! (4)"),
                            }
                        }
                    }
                    parser::Action::Accept => {
                        quote! {
                            let (Self::#out_variant(v), _) = symbol_stack.pop().expect("Stack corrupted! (5)") else {
                                unreachable!("Stack corrupted! (6)")
                            };
                            return Ok(v);
                        }
                    }
                };
                action_code.push(quote! {(Some(#i), #code_symbol) => {#action}});
            }
        }

        quote! {
            impl langen::Grammar for #name {
                type OUT = #out_type;

                fn parse(tokens: Vec<(Self, langen::Span)>) -> Result<Self::OUT, langen::ParserError<Self>> {
                    let num_tokens = tokens.len();
                    let mut input = tokens.into_iter().rev().collect::<Vec<_>>();
                    let mut symbol_stack: Vec<(Self, langen::Span)> = vec![];
                    let mut stack: Vec<usize> = vec![0];

                    loop {
                        // println!("{:?}\n\n{:?}\n\n{:?}\n\n\n", input, symbol_stack, stack);
                        match (stack.last(), input.pop()) {
                            #( #action_code )*
                            (None, _) => {return Err(langen::ParserError::UnexpectedEnd)} // An empty state stack is a different failure (probably unreachable), but report it as an unexpected end
                            (_, None) => {return Err(langen::ParserError::UnexpectedEnd)}
                            (_, Some((token, span))) => {return Err(langen::ParserError::InvalidToken(num_tokens-input.len(), token, span))}
                        }
                    }
                }
            }
        }
        .into()
    } else {
        panic!("Langen can only be used on enums");
    }
}

struct RuleInput {
    fun: ExprClosure,
    parts: Vec<Ident>,
}

impl Parse for RuleInput {
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let expr: Expr = input.parse()?;
        let Expr::Closure(fun) = expr else {
            return Err(syn::Error::new(
                expr.span(),
                "First argument to rule must be a closure",
            ));
        };
        let mut parts = vec![];
        loop {
            if input.parse::<Comma>().is_err() {
                break;
            }
            parts.push(input.parse()?);
        }
        Ok(Self { fun, parts })
    }
}