roto_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse::Parse, parse_macro_input, spanned::Spanned, Error, Token};
4
5#[proc_macro_derive(Context)]
6pub fn roto_context(item: TokenStream) -> TokenStream {
7    let item = parse_macro_input!(item as syn::DeriveInput);
8
9    let struct_name = &item.ident;
10
11    let syn::Data::Struct(s) = &item.data else {
12        panic!("Only structs can be used as context");
13    };
14
15    let syn::Fields::Named(fields) = &s.fields else {
16        panic!("Fields must be named");
17    };
18
19    let fields: Vec<_> = fields
20        .named
21        .iter()
22        .map(|f| {
23            if !matches!(f.vis, syn::Visibility::Public(_)) {
24                panic!("All fields must be marked pub")
25            }
26
27            let field_name = f.ident.as_ref().unwrap();
28            let field_ty = &f.ty;
29            let offset = quote!(std::mem::offset_of!(Self, #field_name));
30            let type_name = quote!(std::any::type_name::<#field_ty>());
31            let type_id = quote!(std::any::TypeId::of::<#field_ty>());
32            let docstring = gather_docstring(&f.attrs);
33
34            quote!(
35                roto::__internal::ContextField {
36                    name: stringify!(#field_name),
37                    offset: #offset,
38                    type_name: #type_name,
39                    type_id: #type_id,
40                    docstring: #docstring,
41                }
42            )
43        })
44        .collect();
45
46    let expanded = quote!(
47        unsafe impl Context for #struct_name {
48            fn fields() -> Vec<roto::__internal::ContextField> {
49                vec![
50                    #(#fields),*
51                ]
52            }
53        }
54    );
55
56    TokenStream::from(expanded)
57}
58
59struct ItemList {
60    items: Vec<ItemWithDocs>,
61}
62
63struct ItemWithDocs {
64    doc: proc_macro2::TokenStream,
65    item: Item,
66}
67
68enum Item {
69    Type(syn::ItemType),
70    Let(syn::ExprLet),
71    Fn(syn::ItemFn),
72    Mod(syn::Ident, ItemList),
73    Impl(proc_macro2::Span, syn::Type, ItemList),
74    Const(syn::ItemConst),
75    Include(proc_macro2::TokenStream),
76    Use(syn::ItemUse),
77}
78
79mod kw {
80    syn::custom_keyword!(include);
81}
82
83impl Parse for ItemList {
84    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
85        let mut items = Vec::new();
86
87        while !input.is_empty() {
88            // We need to parse the docstring and other attributes first, so
89            // that we can then branch on the next keyword. Afterwards, we can
90            // then assign the attributes to the item we parsed.
91            let attributes = syn::Attribute::parse_outer(input)?;
92            let doc = gather_docstring(&attributes);
93
94            let look = input.lookahead1();
95            let item = match () {
96                // Case 1: A normal function
97                _ if look.peek(Token![fn]) => {
98                    let mut item: syn::ItemFn = input.parse()?;
99                    item.attrs = attributes;
100                    Item::Fn(item)
101                }
102                // Case 2: A module
103                _ if look.peek(Token![mod]) => {
104                    input.parse::<Token![mod]>()?;
105                    let ident: syn::Ident = input.parse()?;
106                    let content;
107                    syn::braced!(content in input);
108                    let item_list: ItemList = content.parse()?;
109                    Item::Mod(ident, item_list)
110                }
111                // Case 3: A let binding with a closure
112                _ if look.peek(Token![let]) => {
113                    let mut item: syn::ExprLet = input.parse()?;
114                    input.parse::<Token![;]>()?;
115                    item.attrs = attributes;
116                    Item::Let(item)
117                }
118                // Case 4: An impl block
119                _ if look.peek(Token![impl]) => {
120                    let tok = input.parse::<Token![impl]>()?;
121                    let ty: syn::Type = input.parse()?;
122
123                    let content;
124                    syn::braced!(content in input);
125                    let item_list: ItemList = content.parse()?;
126                    Item::Impl(tok.span, ty, item_list)
127                }
128                // Case 5: A constant
129                _ if look.peek(Token![const]) => {
130                    let mut item: syn::ItemConst = input.parse()?;
131                    item.attrs = attributes;
132                    Item::Const(item)
133                }
134                // Case 6: A clone type
135                _ if look.peek(Token![type]) => {
136                    let mut item: syn::ItemType = input.parse()?;
137                    item.attrs = attributes;
138                    Item::Type(item)
139                }
140                _ if look.peek(Token![use]) => {
141                    let item = input.parse()?;
142                    Item::Use(item)
143                }
144                _ if look.peek(kw::include) => {
145                    let m: syn::Macro = input.parse()?;
146                    input.parse::<Token![;]>()?;
147                    Item::Include(m.tokens)
148                }
149                _ => return Err(look.error()),
150            };
151
152            items.push(ItemWithDocs { doc, item });
153        }
154
155        Ok(ItemList { items })
156    }
157}
158
159#[proc_macro]
160pub fn library(input: TokenStream) -> TokenStream {
161    let parsed_items: ItemList = syn::parse_macro_input!(input);
162
163    let expanded = to_tokens(parsed_items, None)
164        .unwrap_or_else(Error::into_compile_error);
165
166    TokenStream::from(expanded)
167}
168
169fn location(s: proc_macro2::Span) -> proc_macro2::TokenStream {
170    let start = s.end();
171    let line = start.line as u32;
172    let column = start.column as u32;
173    quote! {
174        roto::Location {
175            file: file!(),
176            line: #line,
177            column: #column,
178        }
179    }
180}
181
182fn to_tokens(
183    item_list: ItemList,
184    ty: Option<&syn::Type>,
185) -> syn::Result<proc_macro2::TokenStream> {
186    let mut items = Vec::new();
187    for ItemWithDocs { doc, item } in item_list.items {
188        let new = match item {
189            Item::Type(item) => {
190                let ident = &item.ident;
191                let location = location(ident.span());
192                let ident_str = ident.to_string();
193                let ty = &item.ty;
194                let movability = get_movability(item.span(), &item.attrs)?;
195                quote! {
196                    roto::Type::#movability::<#ty>(
197                        #ident_str,
198                        #doc,
199                        #location,
200                    ).unwrap()
201                }
202            }
203            Item::Let(item) => {
204                let pat = item.pat;
205
206                let syn::Pat::Ident(ident) = &*pat else {
207                    todo!("good error message");
208                };
209                let location = location(ident.ident.span());
210                let ident_str = ident.ident.to_string();
211
212                let expr = item.expr;
213                let syn::Expr::Closure(closure) = &*expr else {
214                    todo!("good error message");
215                };
216
217                let params: Vec<_> = closure
218                    .inputs
219                    .iter()
220                    .map(|p| param_name(p).unwrap())
221                    .collect();
222
223                quote! {
224                    roto::Function::new(
225                        #ident_str,
226                        #doc,
227                        { let x: Vec<&'static str> = vec![#(#params),*]; x },
228                        #expr,
229                        #location,
230                    ).unwrap()
231                }
232            }
233            Item::Fn(item) => {
234                let sig = &item.sig;
235                let ident = &sig.ident;
236                let location = location(ident.span());
237                let ident_str = ident.to_string();
238                let params: Vec<_> = item
239                    .sig
240                    .inputs
241                    .iter()
242                    .map(|arg| match arg {
243                        syn::FnArg::Receiver(_) => "self".into(),
244                        syn::FnArg::Typed(pat) => {
245                            param_name(&pat.pat).unwrap()
246                        }
247                    })
248                    .collect();
249
250                let expr = if let Some(ty) = ty {
251                    // This is a trick to allow method syntax:
252                    //  - We define a private extension trait in a const block.
253                    //  - Then we implement that trait, which won't conflict
254                    //    with anything.
255                    //  - Then we export that method as a free function.
256                    //
257                    // We do need to map each pattern to `_` because patterns
258                    // are not allowed to appear in trait definitions.
259                    let new_inputs = sig
260                        .inputs
261                        .iter()
262                        .map(|arg| match arg {
263                            syn::FnArg::Receiver(rec) => {
264                                syn::FnArg::Receiver(syn::Receiver {
265                                    mutability: None,
266                                    ..rec.clone()
267                                })
268                            }
269                            syn::FnArg::Typed(pat_type) => {
270                                syn::FnArg::Typed(syn::PatType {
271                                    pat: Box::new(syn::Pat::Wild(
272                                        syn::PatWild {
273                                            attrs: Vec::new(),
274                                            underscore_token:
275                                                syn::token::Underscore {
276                                                    spans: [pat_type.span()],
277                                                },
278                                        },
279                                    )),
280                                    ..pat_type.clone()
281                                })
282                            }
283                        })
284                        .collect();
285
286                    let new_sig = syn::Signature {
287                        inputs: new_inputs,
288                        ident: syn::Ident::new("__ext__", sig.ident.span()),
289                        ..sig.clone()
290                    };
291
292                    let mut new_item = item.clone();
293                    new_item.sig.ident =
294                        syn::Ident::new("__ext__", sig.ident.span());
295
296                    quote!(const {
297                        trait Ext {
298                            #new_sig;
299                        }
300
301                        impl Ext for #ty {
302                            #new_item
303                        }
304
305                        <#ty as Ext>::__ext__
306                    })
307                } else {
308                    quote!({ #item #ident })
309                };
310
311                quote! {
312                    roto::Function::new(
313                        #ident_str,
314                        #doc,
315                        { let x: Vec<&'static str> = vec![#(#params),*]; x },
316                        #expr,
317                        #location,
318                    ).unwrap()
319                }
320            }
321            Item::Mod(ident, items) => {
322                let ident_str = ident.to_string();
323                let location = location(ident.span());
324                let items = to_tokens(items, None)?;
325                quote! {{
326                    let mut module = roto::Module::new(
327                        #ident_str,
328                        #doc,
329                        #location,
330                    ).unwrap();
331                    module.add(#items);
332                    module
333                }}
334            }
335            Item::Impl(span, ty, items) => {
336                let items = to_tokens(items, Some(&ty))?;
337                let location = location(span);
338                quote! {{
339                    let mut impl_block = roto::Impl::new::<#ty>(#location);
340                    impl_block.add(#items);
341                    impl_block
342                }}
343            }
344            Item::Const(item) => {
345                let ident = item.ident;
346                let location = location(ident.span());
347                let ident_str = ident.to_string();
348                let ty = item.ty;
349                let expr = item.expr;
350                quote! {
351                    roto::Constant::new::<#ty>(
352                        #ident_str,
353                        #doc,
354                        #expr,
355                        #location,
356                    ).unwrap()
357                }
358            }
359            Item::Include(item) => {
360                quote! { #item }
361            }
362            Item::Use(item) => {
363                let imports = flatten_use_tree(&item.tree);
364                quote! {
365                    roto::Use::new(
366                        vec![#(vec![#(#imports.to_string()),*]),*],
367                        roto::location!(),
368                    )
369                }
370            }
371        };
372
373        items.push(quote! { #new });
374    }
375
376    Ok(quote! {{
377        let mut lib = roto::Library::new();
378        #(roto::Registerable::add_to_lib(#items, &mut lib);)*
379        lib
380    }})
381}
382
383fn get_movability(
384    span: proc_macro2::Span,
385    attrs: &[syn::Attribute],
386) -> syn::Result<syn::Ident> {
387    let mut clone = 0;
388    let mut copy = 0;
389    let mut value = 0;
390    let mut ident_span = None;
391
392    for attr in attrs {
393        if let syn::Meta::Path(p) = &attr.meta {
394            if p.is_ident("clone") {
395                clone += 1;
396                ident_span = Some(p.span());
397            } else if p.is_ident("copy") {
398                copy += 1;
399                ident_span = Some(p.span());
400            } else if p.is_ident("value") {
401                value += 1;
402                ident_span = Some(p.span());
403            }
404        }
405    }
406
407    let s =
408        match (clone, copy, value) {
409            (1, 0, 0) => "clone",
410            (0, 1, 0) => "copy",
411            (0, 0, 1) => "value",
412            _ => return Err(syn::Error::new(
413                span,
414                "specify exactly 1 of `#[clone]`, `#[copy]` or `#[value]`",
415            )),
416        };
417
418    Ok(syn::Ident::new(s, ident_span.unwrap()))
419}
420
421fn flatten_use_tree(tree: &syn::UseTree) -> Vec<Vec<String>> {
422    match tree {
423        syn::UseTree::Path(p) => {
424            let recursive = flatten_use_tree(&p.tree);
425            recursive
426                .into_iter()
427                .map(|v| {
428                    let mut new_v = vec![p.ident.to_string()];
429                    new_v.extend(v);
430                    new_v
431                })
432                .collect()
433        }
434        syn::UseTree::Name(name) => {
435            vec![vec![name.ident.to_string()]]
436        }
437        syn::UseTree::Group(g) => {
438            g.items.iter().flat_map(flatten_use_tree).collect()
439        }
440        syn::UseTree::Rename(_) => panic!(),
441        syn::UseTree::Glob(_) => panic!(),
442    }
443}
444
445fn param_name(pat: &syn::Pat) -> Option<String> {
446    match pat {
447        syn::Pat::Ident(ident) => Some(ident.ident.to_string()),
448        syn::Pat::Paren(paren) => param_name(&paren.pat),
449        syn::Pat::Reference(reference) => param_name(&reference.pat),
450        syn::Pat::TupleStruct(pat) => {
451            let elems: Vec<_> = pat.elems.iter().collect();
452            let [elem] = &*elems else { return None };
453            param_name(elem)
454        }
455        syn::Pat::Type(p) => param_name(&p.pat),
456        syn::Pat::Wild(_) => Some("_".to_string()),
457
458        // Rust will ensure that any name bound in any or the or pattern cases
459        // will also appear in the other cases and error out otherwise.
460        // Therefore, we can just look at the first case to extract the name.
461        syn::Pat::Or(p) => param_name(p.cases.first()?),
462
463        // ---
464        syn::Pat::Verbatim(_) => None,
465        syn::Pat::Tuple(_) => None,
466        syn::Pat::Struct(_) => None,
467        syn::Pat::Slice(_) => None,
468        syn::Pat::Rest(_) => None,
469        syn::Pat::Range(_) => None,
470        syn::Pat::Path(_) => None,
471        syn::Pat::Const(_) => None,
472        syn::Pat::Lit(_) => None,
473        syn::Pat::Macro(_) => None,
474        _ => None,
475    }
476}
477
478struct Intermediate {
479    function: proc_macro2::TokenStream,
480    ident: syn::Ident,
481    docstring: proc_macro2::TokenStream,
482    parameter_names: proc_macro2::TokenStream,
483}
484
485struct FunctionArgs {
486    runtime_ident: syn::Ident,
487    name: Option<syn::Ident>,
488}
489
490impl syn::parse::Parse for FunctionArgs {
491    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
492        let runtime_ident = input.parse()?;
493
494        let mut name = None;
495        if input.peek(Token![,]) {
496            input.parse::<Token![,]>()?;
497            if input.peek(syn::Ident) {
498                name = input.parse()?;
499            }
500        }
501
502        Ok(Self {
503            runtime_ident,
504            name,
505        })
506    }
507}
508
509#[proc_macro_attribute]
510pub fn roto_function(attr: TokenStream, item: TokenStream) -> TokenStream {
511    let item = parse_macro_input!(item as syn::ItemFn);
512    let Intermediate {
513        function,
514        ident,
515        docstring,
516        parameter_names,
517    } = generate_function(item);
518
519    let FunctionArgs {
520        runtime_ident,
521        name,
522    } = syn::parse(attr).unwrap();
523
524    let name = name.unwrap_or(ident.clone());
525
526    let expanded = quote! {
527        #function
528
529        #runtime_ident.register_fn(
530            stringify!(#name),
531            #docstring,
532            #parameter_names,
533            #ident,
534        ).unwrap();
535    };
536
537    TokenStream::from(expanded)
538}
539
540struct MethodArgs {
541    runtime_ident: syn::Ident,
542    ty: syn::Type,
543    name: Option<syn::Ident>,
544}
545
546impl syn::parse::Parse for MethodArgs {
547    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
548        let runtime_ident = input.parse()?;
549        input.parse::<Token![,]>()?;
550        let ty = input.parse()?;
551
552        let mut name = None;
553        if input.peek(Token![,]) {
554            input.parse::<Token![,]>()?;
555            if input.peek(syn::Ident) {
556                name = input.parse()?;
557            }
558        }
559        Ok(Self {
560            runtime_ident,
561            ty,
562            name,
563        })
564    }
565}
566
567#[proc_macro_attribute]
568pub fn roto_method(attr: TokenStream, item: TokenStream) -> TokenStream {
569    let item = parse_macro_input!(item as syn::ItemFn);
570    let Intermediate {
571        function,
572        ident,
573        docstring,
574        parameter_names,
575    } = generate_function(item);
576
577    let MethodArgs {
578        runtime_ident,
579        ty,
580        name,
581    } = parse_macro_input!(attr as MethodArgs);
582
583    let name = name.unwrap_or(ident.clone());
584
585    let expanded = quote! {
586        #function
587
588        #runtime_ident.register_method::<#ty, _, _>(
589            stringify!(#name),
590            #docstring,
591            #parameter_names,
592            #ident
593        ).unwrap();
594    };
595
596    TokenStream::from(expanded)
597}
598
599#[proc_macro_attribute]
600pub fn roto_static_method(
601    attr: TokenStream,
602    item: TokenStream,
603) -> TokenStream {
604    let item = parse_macro_input!(item as syn::ItemFn);
605    let Intermediate {
606        function,
607        ident,
608        docstring,
609        parameter_names,
610    } = generate_function(item);
611
612    let MethodArgs {
613        runtime_ident,
614        ty,
615        name,
616    } = parse_macro_input!(attr as MethodArgs);
617
618    let name = name.unwrap_or(ident.clone());
619
620    let expanded = quote! {
621        #function
622
623        #runtime_ident.register_static_method::<#ty, _, _>(
624            stringify!(#name),
625            #docstring.to_string(),
626            #parameter_names,
627            #ident
628        ).unwrap();
629    };
630
631    TokenStream::from(expanded)
632}
633
634fn gather_docstring(attrs: &[syn::Attribute]) -> proc_macro2::TokenStream {
635    let mut docstring = Vec::new();
636
637    for attr in attrs {
638        if attr.path().is_ident("doc") {
639            let value = match &attr.meta {
640                syn::Meta::NameValue(name_value) => &name_value.value,
641                _ => panic!("doc attribute must be a name and a value"),
642            };
643            docstring.push(value);
644        }
645    }
646
647    quote! {{
648        let x: Vec<String> = vec![#({
649            let s: String = #docstring.to_string();
650            s.strip_prefix(" ").unwrap_or(&s).to_string()
651        }),*];
652        x.join("\n")
653    }}
654}
655
656fn generate_function(item: syn::ItemFn) -> Intermediate {
657    let syn::ItemFn {
658        attrs,
659        vis: _,
660        sig,
661        block: _,
662    } = item.clone();
663
664    let docstring = gather_docstring(&attrs);
665
666    assert!(sig.unsafety.is_none());
667    assert!(sig.variadic.is_none());
668
669    let ident = sig.ident;
670    let args: Vec<_> = sig
671        .inputs
672        .iter()
673        .map(|i| {
674            let syn::FnArg::Typed(syn::PatType { pat, .. }) = i else {
675                panic!()
676            };
677            pat
678        })
679        .collect();
680
681    let parameter_names = quote!( [#(stringify!(#args)),*] );
682
683    Intermediate {
684        function: quote!(#item),
685        ident,
686        docstring,
687        parameter_names,
688    }
689}