fn_bnf_macro/
lib.rs

1//! Proc macro for [fn-bnf](https://docs.rs/fn-bnf). Go there. This is not a place of honor.
2
3#![doc(hidden)] // it's the proc macro who gives a shit
4
5use indexmap::IndexMap;
6use proc_macro_error2::{Diagnostic, Level};
7use proc_macro2::Span;
8use proc_macro::TokenStream;
9use quote::{quote, TokenStreamExt};
10use syn::{
11    ext::IdentExt, parse::{Parse, ParseStream}, parse_macro_input, punctuated::Punctuated, spanned::Spanned, Attribute, Expr, GenericParam, Generics, Ident, Lifetime, LifetimeParam, Stmt, Type, TypeInfer, Visibility
12};
13use itertools::Itertools;
14
15mod kw {
16    syn::custom_keyword!(grammar);
17    syn::custom_keyword!(from);
18    syn::custom_keyword!(try_from);
19}
20
21fn cr8_name() -> syn::Path {
22    use proc_macro_crate::{crate_name, FoundCrate};
23    let found_crate = crate_name("fn-bnf").expect("fn-bnf should be present in `Cargo.toml`");
24
25    match found_crate {
26        FoundCrate::Itself => syn::parse_quote!( ::fn_bnf ),
27        FoundCrate::Name(name) => syn::Path {
28            leading_colon: Some(syn::token::PathSep::default()),
29            segments: [syn::PathSegment { ident: Ident::new(&name, Span::call_site()), arguments: syn::PathArguments::None }]
30                .into_iter().collect()
31        }
32    }
33}
34
35#[derive(Debug)]
36struct Grammar {
37    attrs: Vec<Attribute>,
38    vis: Visibility,
39    ident: Ident,
40    ty: Type,
41    rules: IndexMap<Ident, RuleBody>,
42}
43
44impl Parse for Grammar {
45    fn parse(input: ParseStream) -> syn::Result<Self> {
46        let attrs = input.call(Attribute::parse_outer)?;
47        let vis: Visibility = input.parse()?;
48        input.parse::<kw::grammar>()?;
49        let ident: Ident = input.parse()?;
50        input.parse::<syn::Token![<]>()?;
51        let ty: Type = input.parse()?;
52        input.parse::<syn::Token![>]>()?;
53        let body;
54        syn::braced!(body in input);
55        let mut rules = IndexMap::new();
56
57        while !body.is_empty() {
58            let attrs = body.call(syn::Attribute::parse_outer)?;
59            let vis: syn::Visibility = body.parse()?;
60            let name = body.call(Ident::parse_any)?;
61            let generics = Generics::parse(&body)?;
62            let mut fields = Fields::Unit;
63            if body.peek(syn::token::Brace) {
64                fields = Fields::Structured(body.parse()?);
65            }
66            body.parse::<syn::Token![->]>()?;
67            let ty = body.parse::<Type>()?;
68            let mut func = MapFunc::Empty;
69            if body.peek(kw::from) {
70                body.parse::<kw::from>()?;
71                let parens;
72                syn::parenthesized!(parens in body);
73                func = MapFunc::From(parens.parse::<syn::Expr>()?);
74            } else if body.peek(kw::try_from) {
75                body.parse::<kw::try_from>()?;
76                let parens;
77                syn::parenthesized!(parens in body);
78                func = MapFunc::TryFrom(parens.parse::<syn::Expr>()?);
79            }
80            let mut where_clause = None;
81            if body.peek(syn::Token![where]) {
82                where_clause = Some(body.parse::<syn::WhereClause>()?);
83            }
84            body.parse::<syn::Token![=]>()?;
85            let definition = body.parse::<RuleGroup>()?;
86            let end = body.parse::<syn::Token![;]>()?;
87            let mut span = vis.span();
88            if let Some(joined) = span.join(end.span) {
89                span = joined;
90            }
91
92            if let Some(
93                RuleBody { span: dupe_span, .. }
94            ) = rules.insert(name,
95                RuleBody { vis, attrs, generics, fields, ty, func, where_clause, span, group: definition }
96            ) {
97                Diagnostic::spanned(span, Level::Error, "cannot have duplicate rules".into())
98                    .span_note(dupe_span, "first definition of rule here".into())
99                    .abort();
100            }
101        }
102
103        Ok(Grammar {
104            attrs,
105            vis,
106            ident,
107            ty,
108            rules,
109        })
110    }
111}
112
113#[derive(Debug, Clone)]
114struct RuleGroup {
115    options: Punctuated<RulePath, syn::Token![:]>,
116}
117
118impl Parse for RuleGroup {
119    fn parse(input: ParseStream) -> syn::Result<Self> {
120        Punctuated::<RulePath, syn::Token![:]>::parse_separated_nonempty(input).map(|options| Self { 
121            options 
122        })
123    }
124}
125
126#[derive(Debug, Clone)]
127enum Fields {
128    Unit,
129    Structured(syn::FieldsNamed)
130}
131
132impl quote::ToTokens for Fields {
133    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
134        match self {
135            Fields::Unit => tokens.append_all(quote!(;)),
136            Fields::Structured(fields) 
137                => fields.to_tokens(tokens),
138        }
139    }
140}
141
142#[derive(Debug, Clone)]
143enum MapFunc {
144    Empty,
145    From(Expr),
146    TryFrom(Expr)
147}
148
149#[derive(Debug, Clone)]
150struct RuleBody {
151    vis: syn::Visibility,
152    attrs: Vec<syn::Attribute>,
153    generics: Generics,
154    fields: Fields,
155    ty: Type,
156    func: MapFunc,
157    where_clause: Option<syn::WhereClause>,
158    span: Span,
159    group: RuleGroup,
160}
161
162#[derive(Debug, Clone)]
163struct ElementTy {
164    silent: bool,
165    inner: Expr
166}
167
168impl Parse for ElementTy {
169    fn parse(input: ParseStream) -> syn::Result<Self> {
170        let silent = input.peek(syn::Token![_]);
171        if silent { input.parse::<syn::Token![_]>()?; }
172        input.parse().map(|inner| Self { silent, inner })
173    }
174}
175
176impl quote::ToTokens for ElementTy {
177    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
178        self.inner.to_tokens(tokens);
179    }
180}
181
182#[derive(Debug, Clone)]
183struct RulePath {
184    elements: Punctuated<ElementTy, syn::Token![,]>,
185}
186
187impl Parse for RulePath {
188    fn parse(input: ParseStream) -> syn::Result<Self> {
189        Ok(Self {
190            elements: Punctuated::<ElementTy, syn::Token![,]>::parse_separated_nonempty(input)?
191        })
192    }
193}
194
195impl RuleBody {
196    #[allow(clippy::too_many_lines)]
197    fn tokenize(self, name: &Ident, rule_ty: &Type) -> Vec<Stmt> {
198        let cr8 = cr8_name();
199
200        let Self { vis, attrs, mut generics, fields, ty, func, where_clause, span, group } = self;
201        let infer = Type::Infer(TypeInfer { underscore_token: Default::default() });
202
203        let tys = match &ty {
204            t @ Type::Tuple(tup) if tup.elems.is_empty() => vec![t],
205            Type::Tuple(tup) => tup.elems.iter().collect(),
206            other => vec![other]
207        }.into_iter();
208
209        let mut impl_generics = generics.clone();
210        let rule_generics = impl_generics.clone();
211        if !impl_generics.params.iter().filter_map(|param| {
212            let GenericParam::Lifetime(p) = param else { return None; };
213            Some(&p.lifetime)
214        }).any(|lt| lt.ident == "input") {
215            impl_generics.params.push(GenericParam::Lifetime(
216                LifetimeParam::new(Lifetime::new("'input", name.span()))
217            ));
218        }
219        for param in &mut generics.params {
220            match param {
221                GenericParam::Lifetime(lt) => {
222                    lt.colon_token = None;
223                    lt.bounds.clear();
224                }
225                GenericParam::Type(ty) => {
226                    ty.colon_token = None;
227                    ty.bounds.clear();
228                }
229                GenericParam::Const(cst) => {
230                    *param = GenericParam::Type(syn::TypeParam {
231                        attrs: cst.attrs.clone(), 
232                        ident: cst.ident.clone(),
233                        colon_token: None,
234                        bounds: Punctuated::default(),
235                        eq_token: None,
236                        default: None
237                    });
238                }
239            }
240        }
241        let mut min_options = None::<usize>;
242        let mut max_options = None::<usize>;
243        for option in &group.options {
244            let count = option.elements.iter().filter(|el| !el.silent).count();
245            min_options = Some(min_options.map_or(count, |m| m.min(count)));
246            max_options = Some(max_options.map_or(count, |m| m.max(count)));
247        }
248        let min_options = min_options.unwrap_or(0);
249        let max_options = max_options.unwrap_or(0);
250        let mut element_defs = Vec::<Stmt>::new();
251
252        element_defs.push(syn::parse_quote!(
253            let _ = (
254                "### DEBUG INFORMATION ###",
255                "min_options: ", #min_options,
256                "max_options:", #max_options
257            );
258        ));
259
260        let variable_names = (0..max_options)
261            .map(|n| syn::Ident::new_raw(&format!("arg_{n}"), span))
262            .collect_vec();
263        let optional_variable_defs = variable_names.iter()
264            .skip(min_options)
265            .map(|id| -> Stmt {syn::parse_quote_spanned! {span=> let mut #id = None;}})
266            .collect_vec();
267    
268        for (i, option) in group.options.iter().enumerate() {
269            let at_end = i + 1 >= group.options.len();
270            let mut next_args = Vec::<Stmt>::new();
271
272            let return_expr: Expr = match &func {
273                MapFunc::Empty =>
274                    syn::parse_quote_spanned!(span=> Ok((#(#variable_names),*))),
275                MapFunc::From(func) =>
276                    syn::parse_quote_spanned!(span=> Ok((#func)(#(#variable_names),*))),
277                MapFunc::TryFrom(func) => 
278                    syn::parse_quote_spanned!(
279                        span=> 
280                        (#func)(#(#variable_names),*)
281                            .map_err(|err| {
282                                (*input, *index) = __before;
283                                #cr8::ParseError::new(
284                                    Some(#cr8::Box::new(err)),
285                                    rule.name(),
286                                    *index
287                                )
288                            }
289                        )
290                    )                
291            };
292            let mut tys = tys.clone();
293            let mut iter = option.elements.iter();
294            let Some(first) = iter.next() else {
295                element_defs.extend::<Vec<Stmt>>(syn::parse_quote_spanned!(span=>
296                    return #return_expr;
297                ));
298                break;
299            };
300
301            let mut name_iter = variable_names.iter();
302            if !first.silent {
303                let _ = name_iter.next();
304            }
305
306            let fail_condition: Stmt = if at_end {
307                syn::parse_quote_spanned!(span=> {
308                    (*input, *index) = __before;
309                    return Err(#cr8::ParseError::new(
310                        Some(#cr8::Box::new(err)),
311                        rule.name().or(self.name()),
312                        *index
313                    ))
314                })
315            } else {
316                syn::parse_quote_spanned!(span=> {
317                    (*input, *index) = __before;
318                    break 'b;
319                })
320            };
321
322            let names = &mut name_iter;
323
324            let maybe_arg0 = (!first.silent).then(|| -> Stmt {
325                if first.silent {
326                    syn::parse_quote_spanned!(first.span()=> ;)
327                } else {
328                    let first_ty = if let MapFunc::Empty = func {
329                        tys.next().unwrap_or_else(||
330                            Diagnostic::spanned(ty.span(), Level::Error, "returned type count does not match options".into())
331                                .abort()
332                        )
333                    } else { &infer };
334                    if min_options == 0 {
335                        syn::parse_quote_spanned!(first.span()=> let mut arg_0: #first_ty = Some(first);)
336                    } else {
337                        syn::parse_quote_spanned!(first.span()=> let mut arg_0: #first_ty = first;)
338                    }
339                }
340            }).into_iter();
341
342            let take_count = min_options.saturating_sub(usize::from(!first.silent));
343            for el in (&mut iter).take(take_count) {
344                let opt_ty = if let MapFunc::Empty = func {
345                    tys.next().unwrap_or_else(||
346                        Diagnostic::spanned(ty.span(), Level::Error, "returned type count does not match options".into())
347                            .abort()
348                    )
349                } else { &infer };
350                if el.silent {
351                    next_args.extend::<Vec<Stmt>>(syn::parse_quote_spanned!(el.span()=>
352                        let rule = { #el };
353                        if let Err(err) = #cr8::Rule::<'input, #rule_ty>::parse_at(&rule, input, index) {
354                            #fail_condition
355                        };
356                    ));
357                    continue;
358                }
359                let Some(name) = names.next() else { break };
360                next_args.extend::<Vec<Stmt>>(syn::parse_quote_spanned!(el.span()=>
361                    let rule = { #el };
362                    let mut #name: #opt_ty = match #cr8::Rule::<'input, #rule_ty>::parse_at(&rule, input, index) {
363                        Ok(val) => val,
364                        Err(err) => #fail_condition
365                    };
366                ));
367            }
368            for el in &mut iter {
369                if el.silent {
370                    next_args.extend::<Vec<Stmt>>(syn::parse_quote_spanned!(el.span()=>
371                        let rule = { #el };
372                        if let Err(err) = #cr8::Rule::<'input, #rule_ty>::parse_at(&rule, input, index) {
373                            #fail_condition
374                        };
375                    ));
376                    continue;
377                }
378                let Some(name) = names.next() else { break };
379                next_args.extend::<Vec<Stmt>>(syn::parse_quote_spanned!(el.span()=>
380                    let rule = { #el };
381                    let mut #name = match #cr8::Rule::<'input, #rule_ty>::parse_at(&rule, input, index) {
382                        Ok(val) => Some(val),
383                        Err(err) => #fail_condition
384                    };
385                ));
386            }
387            
388            element_defs.extend::<Vec<Stmt>>(syn::parse_quote_spanned!(first.span()=> 'b: {
389                let rule = { #first };
390                match #cr8::Rule::<'input, #rule_ty>::parse_at(&rule, input, index) {
391                    Ok(first) => {
392                        #(#maybe_arg0)*
393                        #(#next_args)*
394                        return #return_expr;
395                    }
396                    Err(err) => {
397                        #fail_condition
398                    }
399                }
400            }));
401        }
402
403        syn::parse_quote_spanned!(self.span=>
404            #(#attrs)*
405            #vis struct #name #rule_generics #where_clause #fields
406
407            impl #rule_generics #cr8::NamedRule for #name #generics #where_clause {
408                fn name(&self) -> Option<&'static str> { Some(stringify!(#name)) }
409            }
410
411            impl #impl_generics #cr8::Rule<'input, #rule_ty> for #name #generics #where_clause {
412                type Output = #ty;
413
414                #[allow(unused_variables, unreachable_code, unused_labels, unused_parens, unused_mut)]
415                fn parse_at<'cursor, 'this, 'index>(&'this self, input: &'cursor mut &'input #rule_ty, index: &'index mut usize)
416                    -> Result<Self::Output, #cr8::ParseError>
417                    where 'input : 'this
418                {
419                    #[allow(unused)]
420                    let __before = (*input, *index);
421                    #(#optional_variable_defs)*
422
423                    #(#element_defs)*
424
425                    Err(#cr8::ParseError::new(Some(#cr8::Box::new(#cr8::errors::ExhaustedInput)), self.name(), *index))
426                }
427            }
428        )
429    }
430}
431
432#[proc_macro_derive(NamedRule)]
433pub fn derive_named(input: TokenStream) -> TokenStream {
434    use syn::Data;
435
436    let cr8 = cr8_name();
437
438    let input = parse_macro_input!(input as syn::DeriveInput);
439
440    let name = input.ident;
441    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
442    let expanded = match &input.data {
443        Data::Struct(_) | Data::Union(_) => quote! {
444            #[automatically_derived]
445            impl #impl_generics #cr8::NamedRule for #name #ty_generics #where_clause {
446                #[inline]
447                fn name(&self) -> Option<&'static str> {
448                    Some(stringify!(#name))
449                }
450            }
451        },
452        Data::Enum(data_enum) => {
453            let variants = &data_enum.variants;
454            let variant_pats = variants.iter().map(|variant| -> syn::Pat {
455                let name = &variant.ident;
456                match variant.fields {
457                    syn::Fields::Unit => 
458                        syn::Pat::Path(syn::parse_quote!(Self::#name)),
459                    syn::Fields::Named(_) => 
460                        syn::Pat::Struct(syn::PatStruct {
461                            attrs: vec![],
462                            qself: None,
463                            path: syn::parse_quote!(Self::#name),
464                            brace_token: Default::default(),
465                            fields: Punctuated::new(),
466                            rest: Some(syn::PatRest { attrs: vec![], dot2_token: Default::default(), })
467                        }),
468                    syn::Fields::Unnamed(_) => 
469                        syn::Pat::TupleStruct(syn::PatTupleStruct {
470                            attrs: vec![],
471                            qself: None,
472                            path: syn::parse_quote!(Self::#name),
473                            paren_token: Default::default(),
474                            elems: [syn::Pat::Rest(syn::PatRest { attrs: vec![], dot2_token: Default::default(), })]
475                                    .into_iter().collect()
476                        })
477
478                }
479            });
480            let field_names = variants.iter()
481                .map(|variant| format!("{name}::{}", variant.ident));
482            
483            quote! {
484                #[automatically_derived]
485                impl #impl_generics #cr8::NamedRule for #name #ty_generics #where_clause  {
486                    #[inline]
487                    fn name(&self) -> Option<&'static str> {
488                        Some(match &self {
489                            #(#variant_pats => #field_names),*
490                        })
491                    }
492                }
493            }
494        }
495    };
496
497    TokenStream::from(expanded)
498}
499
500#[proc_macro_error2::proc_macro_error]
501#[proc_macro]
502pub fn define(input: TokenStream) -> TokenStream {
503    let Grammar {
504        attrs, vis, ident, ty, rules, ..
505    } = parse_macro_input!(input as Grammar);
506    let cr8 = cr8_name();
507
508    let rules = rules.iter().map(|(name, body)| body.clone().tokenize(name, &ty));
509
510    quote!(
511        #( #attrs )*
512        #[allow(non_snake_case)]
513        #vis mod #ident {
514            #![allow(unused_imports)]
515            #![allow(clippy::double_parens)]
516            use super::*;
517            use #cr8::NamedRule as _;
518            #(#(#rules)*)*
519        }
520    ).into()
521}