// compiler_tools_derive/src/lib.rs

1use std::collections::{BTreeMap, HashSet};
2
3use indexmap::IndexMap;
4use proc_macro::TokenStream;
5use proc_macro2::{Ident, TokenStream as TokenStream2, TokenTree};
6use quote::{format_ident, quote, quote_spanned, ToTokens, TokenStreamExt};
7use regex::Regex;
8use syn::{parse_macro_input, spanned::Spanned, DeriveInput, Expr, ExprLit, ExprPath, Fields, FieldsUnnamed, Lifetime, Lit, Meta, Type};
9
10use crate::{gen::class_match::gen_class_match, lit_table::LitTable, simple_regex::SimpleRegex};
11
12mod gen;
13mod lit_table;
14mod simple_regex;
15
16#[proc_macro_attribute]
17pub fn token_parse(_metadata: TokenStream, input: TokenStream) -> TokenStream {
18    let ast = parse_macro_input!(input as DeriveInput);
19    impl_token_parse(&ast).into()
20}
21
22fn flatten<S: ToTokens, T: IntoIterator<Item = S>>(iter: T) -> TokenStream2 {
23    let mut out = quote! {};
24    out.append_all(iter);
25    out
26}
27
/// Per-variant parsing configuration collected from `#[token(...)]` attributes
/// and string discriminants on the derive-input enum.
struct TokenParseData {
    // True when the variant has a single unnamed field to receive the matched text.
    has_target: bool,
    // True when that field is not `&str` and must go through `str::parse`.
    target_needs_parse: bool,
    // True for the (at most one) `#[token(illegal)]` catch-all variant.
    is_illegal: bool,
    // Literal strings matched via the generated literal table.
    literals: Vec<String>,
    // Patterns handled by the crate's own `SimpleRegex` engine (`regex = ...`).
    simple_regexes: Vec<String>,
    // Patterns handled by the full `regex` crate (`regex_full = ...`).
    regexes: Vec<String>,
    // Path to a user-supplied `fn(&str) -> Option<(&str, &str)>` parser, if any.
    parse_fn: Option<String>,
    // The enum variant's identifier.
    ident: Ident,
}
38
39fn parse_attributes(input: TokenStream2) -> Option<IndexMap<String, Option<String>>> {
40    let mut tokens = input.into_iter();
41
42    let mut attributes = IndexMap::<String, Option<String>>::new();
43
44    loop {
45        let name = match tokens.next() {
46            None => break,
47            Some(TokenTree::Ident(ident)) => ident,
48            _ => return None,
49        };
50        match tokens.next() {
51            None => {
52                attributes.insert(name.to_string(), None);
53                break;
54            }
55            Some(TokenTree::Punct(p)) if p.as_char() == ',' => {
56                attributes.insert(name.to_string(), None);
57            }
58            Some(TokenTree::Punct(p)) if p.as_char() == '=' => {
59                let value = if let TokenTree::Literal(literal) = tokens.next()? {
60                    let lit = Lit::new(literal);
61                    Some(match lit {
62                        Lit::Str(s) => s.value(),
63                        Lit::Char(c) => c.value().to_string(),
64                        _ => return None,
65                    })
66                } else {
67                    return None;
68                };
69                attributes.insert(name.to_string(), value);
70            }
71            _ => return None,
72        }
73    }
74    Some(attributes)
75}
76
77fn construct_variant(item: &TokenParseData, enum_ident: &Ident) -> TokenStream2 {
78    let variant = &item.ident;
79    if item.has_target {
80        if item.target_needs_parse {
81            //TODO: emit better error for parsefail
82            quote! {
83                #enum_ident::#variant(passed.parse().ok()?)
84            }
85        } else {
86            quote! {
87                #enum_ident::#variant(passed)
88            }
89        }
90    } else {
91        quote! {
92            #enum_ident::#variant
93        }
94    }
95}
96
/// A compiled `SimpleRegex` paired with the index of the variant it belongs to,
/// used for ordering/priority checks against literals declared later.
struct SimpleRegexData {
    pub token_index: usize,
    pub regex: SimpleRegex,
}
101
/// A compiled full `regex::Regex` (anchored with a leading `^` by the caller)
/// paired with the index of the variant it belongs to.
struct RegexData {
    pub token_index: usize,
    pub regex: Regex,
}
106
107fn impl_token_parse(input: &DeriveInput) -> proc_macro2::TokenStream {
108    if input.generics.params.len() > 1 || !matches!(input.generics.params.first(), None | Some(syn::GenericParam::Lifetime(_))) {
109        return quote_spanned! {
110            input.generics.span() =>
111            compile_error!("TokenParse can only have a single lifetime type parameter");
112        };
113    }
114    let has_lifetime_param = input.generics.params.len() == 1;
115    let original_lifetime_param = if let Some(syn::GenericParam::Lifetime(lifetime)) = input.generics.params.first() {
116        Some(lifetime.lifetime.ident.clone())
117    } else {
118        None
119    };
120
121    let items = match &input.data {
122        syn::Data::Enum(items) => items,
123        _ => {
124            return quote_spanned! {
125                input.span() =>
126                compile_error!("TokenParse can only be derived on enums");
127            }
128        }
129    };
130
131    let mut tokens_to_parse = vec![];
132    let mut has_illegal = false;
133    for variant in &items.variants {
134        let mut parse_data = TokenParseData {
135            has_target: false,
136            target_needs_parse: false,
137            is_illegal: false,
138            literals: vec![],
139            simple_regexes: vec![],
140            regexes: vec![],
141            parse_fn: None,
142            ident: variant.ident.clone(),
143        };
144
145        for attribute in &variant.attrs {
146            if !attribute.path().is_ident("token") {
147                continue;
148            }
149
150            let Meta::List(meta) = &attribute.meta else {
151                continue;
152            };
153
154            let attributes = match parse_attributes(meta.tokens.clone()) {
155                Some(x) => x,
156                None => {
157                    return quote_spanned! {
158                        attribute.span() =>
159                        compile_error!("invalid attribute syntax");
160                    }
161                }
162            };
163            for (name, value) in attributes {
164                if name != "illegal" && value.is_none() {
165                    return quote_spanned! {
166                        attribute.span() =>
167                        compile_error!("missing attribute value");
168                    };
169                }
170
171                match &*name {
172                    "literal" => {
173                        parse_data.literals.push(value.unwrap());
174                    }
175                    "regex" => {
176                        parse_data.simple_regexes.push(value.unwrap());
177                    }
178                    "regex_full" => {
179                        parse_data.regexes.push(value.unwrap());
180                    }
181                    "parse_fn" => {
182                        if parse_data.parse_fn.is_some() {
183                            return quote_spanned! {
184                                attribute.span() =>
185                                compile_error!("redefined 'parse_fn' attribute");
186                            };
187                        }
188                        parse_data.parse_fn = Some(value.unwrap());
189                    }
190                    "illegal" => {
191                        if value.is_some() {
192                            return quote_spanned! {
193                                attribute.span() =>
194                                compile_error!("unexpected attribute value");
195                            };
196                        }
197                        if parse_data.is_illegal || has_illegal {
198                            return quote_spanned! {
199                                attribute.span() =>
200                                compile_error!("redefined 'illegal' attribute");
201                            };
202                        }
203                        parse_data.is_illegal = true;
204                        has_illegal = true;
205                    }
206                    _ => {
207                        return quote_spanned! {
208                            attribute.span() =>
209                            compile_error!("unknown attribute");
210                        }
211                    }
212                }
213            }
214        }
215        if let Some((_, discriminant)) = &variant.discriminant {
216            if let Expr::Lit(ExprLit {
217                lit: Lit::Str(lit_str),
218                ..
219            }) = discriminant
220            {
221                parse_data.literals.push(lit_str.value());
222            } else {
223                return quote_spanned! {
224                    input.span() =>
225                    compile_error!("TokenParse enums cannot have non-string discriminants");
226                };
227            }
228        }
229        if parse_data.parse_fn.is_some() && (!parse_data.literals.is_empty() || !parse_data.simple_regexes.is_empty() || !parse_data.regexes.is_empty()) {
230            return quote_spanned! {
231                input.span() =>
232                compile_error!("cannot have a 'parse_fn' attribute and a 'literal', 'regex', or 'regex_full' attribute");
233            };
234        }
235        let has_anything =
236            parse_data.parse_fn.is_some() || !parse_data.literals.is_empty() || !parse_data.simple_regexes.is_empty() || !parse_data.regexes.is_empty();
237        if parse_data.is_illegal && has_anything {
238            return quote_spanned! {
239                input.span() =>
240                compile_error!("cannot have an 'illegal' attribute and a 'literal', 'regex', 'regex_full', or 'parse_fn' attribute");
241            };
242        } else if !parse_data.is_illegal && !has_anything {
243            return quote_spanned! {
244                input.span() =>
245                compile_error!("must have an enum discriminant or 'illegal', 'literal', 'regex', 'regex_full', or 'parse_fn' attribute");
246            };
247        }
248
249        match &variant.fields {
250            Fields::Named(_) => {
251                return quote_spanned! {
252                    variant.fields.span() =>
253                    compile_error!("cannot have enum struct in TokenParse variant");
254                };
255            }
256            Fields::Unnamed(FieldsUnnamed {
257                unnamed,
258                ..
259            }) => {
260                if unnamed.len() != 1 {
261                    return quote_spanned! {
262                        unnamed.span() =>
263                        compile_error!("must have single target type in TokenParse variant");
264                    };
265                }
266                let field = unnamed.first().unwrap();
267                match &field.ty {
268                    Type::Reference(ty) => {
269                        if ty.mutability.is_some() {
270                            return quote_spanned! {
271                                unnamed.span() =>
272                                compile_error!("cannot have `&mut` in TokenParse variant");
273                            };
274                        }
275                        if !matches!(&ty.lifetime, Some(Lifetime { ident, ..}) if Some(ident) == original_lifetime_param.as_ref()) {
276                            return quote_spanned! {
277                                unnamed.span() =>
278                                compile_error!("unexpected lifetime in TokenParse variant (use the same one as defined in enum declaration)");
279                            };
280                        }
281                        if let Type::Path(path) = &*ty.elem {
282                            if path.qself.is_some() || path.path.segments.len() != 1 || path.path.segments.first().unwrap().ident != "str" {
283                                return quote_spanned! {
284                                    unnamed.span() =>
285                                    compile_error!("invalid type in reference for TokenParse (only &str allowed)");
286                                };
287                            }
288                        } else {
289                            return quote_spanned! {
290                                unnamed.span() =>
291                                compile_error!("invalid type in reference for TokenParse (only &str allowed)");
292                            };
293                        }
294                        parse_data.has_target = true;
295                    }
296                    _ => {
297                        parse_data.has_target = true;
298                        parse_data.target_needs_parse = true;
299                    }
300                }
301            }
302            // no target
303            Fields::Unit => {
304                if parse_data.is_illegal {
305                    return quote_spanned! {
306                        variant.span() =>
307                        compile_error!("'illegal' attributed tokens must have a single field (usually 'char' or '&str')");
308                    };
309                }
310            }
311        }
312
313        tokens_to_parse.push(parse_data)
314    }
315
316    let mut simple_regexes = BTreeMap::new();
317    for (token_index, item) in tokens_to_parse.iter().enumerate() {
318        for simple_regex in &item.simple_regexes {
319            let parsed = match SimpleRegex::parse(simple_regex) {
320                Some(x) => x,
321                None => {
322                    return quote_spanned! {
323                        item.ident.span() =>
324                        compile_error!("invalid simple regex");
325                    }
326                }
327            };
328            simple_regexes.insert(
329                (item.ident.clone(), simple_regex.clone()),
330                SimpleRegexData {
331                    token_index,
332                    regex: parsed,
333                },
334            );
335        }
336    }
337
338    let mut regexes = BTreeMap::new();
339    for (token_index, item) in tokens_to_parse.iter().enumerate() {
340        for regex in &item.regexes {
341            let modified_regex = format!("^{}", regex);
342            let parsed = match Regex::new(&*modified_regex) {
343                Ok(x) => x,
344                Err(_) => {
345                    return quote_spanned! {
346                        item.ident.span() =>
347                        compile_error!("invalid simple regex");
348                    }
349                }
350            };
351            regexes.insert(
352                (item.ident.clone(), regex.clone()),
353                RegexData {
354                    token_index,
355                    regex: parsed,
356                },
357            );
358        }
359    }
360
361    // (regex ident, raw regex) => (literal ident, literal)
362    let mut simple_regex_ident_conflicts: BTreeMap<(Ident, String), Vec<(Ident, String)>> = BTreeMap::new();
363    let mut regex_ident_conflicts: BTreeMap<(Ident, String), Vec<(Ident, String)>> = BTreeMap::new();
364    let mut known_literals = HashSet::new();
365
366    let mut parse_fns: BTreeMap<usize, Vec<TokenStream2>> = BTreeMap::new();
367
368    let mut lit_table = LitTable::default();
369    for (token_index, item) in tokens_to_parse.iter().enumerate() {
370        for literal in &item.literals {
371            if !known_literals.insert(literal.clone()) {
372                return quote_spanned! {
373                    item.ident.span() =>
374                    compile_error!("conflicting literals");
375                };
376            }
377            let mut any_matched = false;
378            for ((ident, raw_regex), regex) in &simple_regexes {
379                if regex.token_index > token_index && regex.regex.matches(&**literal) {
380                    simple_regex_ident_conflicts
381                        .entry((ident.clone(), raw_regex.clone()))
382                        .or_default()
383                        .push((item.ident.clone(), literal.clone()));
384                    any_matched = true;
385                }
386            }
387            if any_matched {
388                continue;
389            }
390            for ((ident, raw_regex), regex) in &regexes {
391                if regex.token_index > token_index && regex.regex.is_match(&**literal) {
392                    regex_ident_conflicts
393                        .entry((ident.clone(), raw_regex.clone()))
394                        .or_default()
395                        .push((item.ident.clone(), literal.clone()));
396                    any_matched = true;
397                }
398            }
399            if any_matched {
400                continue;
401            }
402            lit_table.push(item, &**literal, &mut literal.chars());
403        }
404    }
405
406    let lit_table_name = format_ident!("parse_lits");
407    let lit_table = lit_table.emit(&lit_table_name, &input.ident);
408
409    if let Err(e) = gen::simple_regex::gen_simple_regex(&tokens_to_parse[..], &simple_regexes, &simple_regex_ident_conflicts, &input.ident, &mut parse_fns) {
410        return e;
411    }
412    if let Err(e) = gen::full_regex::gen_full_regex(&tokens_to_parse[..], &regex_ident_conflicts, &input.ident, &mut parse_fns) {
413        return e;
414    }
415
416    let lifetime_param = if has_lifetime_param {
417        quote! { <'a> }
418    } else {
419        quote! {}
420    };
421    let ident_raw = input.ident.to_string();
422    let tokenizer_ident = if ident_raw.contains("Token") {
423        format_ident!("{}", ident_raw.replace("Token", "Tokenizer"))
424    } else {
425        format_ident!("{}Tokenizer", ident_raw)
426    };
427    let token_ident = &input.ident;
428    let vis = &input.vis;
429
430    let display_fields = gen::display::gen_display(&tokens_to_parse[..], &input.ident);
431
432    let illegal_emission = if let Some(illegal) = tokens_to_parse.iter().find(|x| x.is_illegal) {
433        let constructor = construct_variant(illegal, &input.ident);
434        quote! {
435            if let Some(value) = self.inner.chars().next() {
436                let span = ::compiler_tools::Span {
437                    line_start: self.line,
438                    col_start: self.col,
439                    line_stop: if value == '\n' {
440                        self.line
441                    } else {
442                        self.line += 1;
443                        self.line
444                    },
445                    col_stop: if value != '\n' {
446                        self.col += value.len_utf8() as u64;
447                        self.col
448                    } else {
449                        self.col = 0;
450                        self.col
451                    },
452                };
453                let passed = &self.inner[..value.len_utf8()];
454                self.inner = &self.inner[value.len_utf8()..];
455                return Some(::compiler_tools::Spanned {
456                    token: #constructor,
457                    span,
458                })
459            } else {
460                None
461            }
462        }
463    } else {
464        quote! {
465            None
466        }
467    };
468
469    let reinput = {
470        let attrs = flatten(&input.attrs);
471        let vis = &input.vis;
472        let ident = &input.ident;
473        let generics = &input.generics;
474        let mut variants = vec![];
475        for variant in &items.variants {
476            let attrs = flatten(
477                variant
478                    .attrs
479                    .iter()
480                    .filter(|a| a.path().segments.len() != 1 || a.path().segments.first().unwrap().ident != "token"),
481            );
482            let ident = &variant.ident;
483            let fields = &variant.fields;
484            // discriminant ignored
485            variants.push(quote! {
486                #attrs
487                #ident #fields,
488            });
489        }
490        let variants = flatten(variants);
491        quote! {
492            #attrs
493            #vis enum #ident #generics {
494                #variants
495            }
496        }
497    };
498
499    let class_matches = gen_class_match(&tokens_to_parse[..], &input.ident);
500
501    for (token_index, token) in tokens_to_parse.iter().enumerate() {
502        if let Some(parse_fn) = &token.parse_fn {
503            let path_expr: ExprPath = match syn::parse_str(&parse_fn) {
504                Ok(x) => x,
505                Err(_e) => {
506                    parse_fns
507                        .entry(token_index)
508                        .or_default()
509                        .push(quote! { compile_error!("can't parse path for parse_fn"); });
510                    continue;
511                }
512            };
513            let constructed = construct_variant(token, &input.ident);
514            parse_fns.entry(token_index).or_default().push(quote! {
515                {
516                    if let Some((passed, remaining)) = #path_expr(self.inner) {
517                        let span = ::compiler_tools::Span {
518                            line_start: self.line,
519                            col_start: self.col,
520                            line_stop: {
521                                self.line += passed.chars().filter(|x| *x == '\n').count() as u64;
522                                self.line
523                            },
524                            //todo: handle utf8 better with newline seeking here
525                            col_stop: if let Some(newline_offset) = passed.as_bytes().iter().rev().position(|x| *x == b'\n') {
526                                let newline_offset = passed.len() - newline_offset;
527                                self.col = (newline_offset as u64).saturating_sub(1);
528                                self.col
529                            } else {
530                                self.col += passed.len() as u64;
531                                self.col
532                            },
533                        };
534                        self.inner = remaining;
535                        match passed {
536                            passed => return Some(::compiler_tools::Spanned {
537                                token: #constructed,
538                                span,
539                            }),
540                        }
541                    }
542                }
543            });
544        }
545    }
546
547    let lit_table_parse = quote! {
548        match #lit_table_name(self.inner) {
549            Some((token, remaining, newlines)) => {
550                let span = ::compiler_tools::Span {
551                    line_start: self.line,
552                    col_start: self.col,
553                    line_stop: if newlines == 0 {
554                        self.line
555                    } else {
556                        self.line += newlines;
557                        self.line
558                    },
559                    col_stop: if newlines == 0 {
560                        self.col += (self.inner.len() - remaining.len()) as u64;
561                        self.col
562                    } else {
563                        //todo: handle utf8 better with newline seeking here
564                        let newline_offset = self.inner[..self.inner.len() - remaining.len()].as_bytes().iter().rev().position(|x| *x == b'\n').expect("malformed newline state");
565                        let newline_offset = (self.inner.len() - remaining.len()) - newline_offset;
566                        self.col = (newline_offset as u64).saturating_sub(1);
567                        self.col
568                    },
569                };
570                self.inner = remaining;
571                return Some(::compiler_tools::Spanned {
572                    token,
573                    span,
574                });
575            },
576            None => (),
577        }
578    };
579
580    if let Some(lit_table_index) = tokens_to_parse.iter().position(|x| !x.literals.is_empty()) {
581        parse_fns.entry(lit_table_index).or_default().push(lit_table_parse);
582    }
583
584    let parse_fns = flatten(parse_fns.into_values().flatten());
585
586    quote! {
587        #reinput
588
589        impl #lifetime_param ::core::fmt::Display for #token_ident #lifetime_param {
590            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
591                match self {
592                    #display_fields
593                }
594            }
595        }
596
597        impl #lifetime_param ::compiler_tools::TokenExt for #token_ident #lifetime_param {
598            fn matches_class(&self, other: &Self) -> bool {
599                match (self, other) {
600                    #class_matches
601                }
602            }
603        }
604
605        #vis struct #tokenizer_ident<'a> {
606            line: u64,
607            col: u64,
608            inner: &'a str,
609        }
610
611        impl<'a> #tokenizer_ident<'a> {
612            pub fn new(input: &'a str) -> Self {
613                Self {
614                    line: 0,
615                    col: 0,
616                    inner: input,
617                }
618            }
619        }
620
621        impl<'a> ::compiler_tools::TokenParse<'a> for #tokenizer_ident<'a> {
622            type Token = #token_ident #lifetime_param;
623
624            #[allow(non_snake_case, unreachable_pattern, unreachable_code)]
625            fn next(&mut self) -> Option<::compiler_tools::Spanned<Self::Token>> {
626                #lit_table
627                #parse_fns
628                #illegal_emission
629            }
630        }
631    }
632}