completion_macro/
lib.rs

1//! Macro to generate completion-based async functions and blocks. This crate shouldn't be used
2//! directly, instead use `completion`.
3
4use proc_macro::TokenStream as TokenStream1;
5use proc_macro2::{Group, Literal, Punct, Span, TokenStream, TokenTree};
6use quote::{quote, quote_spanned, ToTokens, TokenStreamExt};
7use syn::parse::{self, Parse, ParseStream, Parser};
8use syn::punctuated::Punctuated;
9use syn::visit_mut::{self, VisitMut};
10use syn::{
11    token, AttrStyle, Attribute, Block, Expr, ExprAsync, Path, Signature, Stmt, Token, Visibility,
12};
13
14mod block;
15mod function;
16mod stream;
17
18#[proc_macro_attribute]
19pub fn completion(attr: TokenStream1, input: TokenStream1) -> TokenStream1 {
20    let (CompletionAttr { crate_path, boxed }, input) = match (syn::parse(attr), syn::parse(input))
21    {
22        (Ok(attr), Ok(input)) => (attr, input),
23        (Ok(_), Err(e)) | (Err(e), Ok(_)) => return e.into_compile_error().into(),
24        (Err(mut e1), Err(e2)) => {
25            e1.combine(e2);
26            return e1.into_compile_error().into();
27        }
28    };
29    match input {
30        CompletionInput::AsyncFn(f) => function::transform(f, boxed, &crate_path),
31        CompletionInput::AsyncBlock(async_block, semi) => {
32            let tokens = block::transform(async_block, &crate_path);
33            quote!(#tokens #semi)
34        }
35    }
36    .into()
37}
38
39struct CompletionAttr {
40    crate_path: CratePath,
41    boxed: Option<Boxed>,
42}
43impl Parse for CompletionAttr {
44    fn parse(input: ParseStream<'_>) -> parse::Result<Self> {
45        let mut crate_path = None;
46        let mut boxed = None;
47
48        while !input.is_empty() {
49            if input.peek(Token![crate]) {
50                if crate_path.is_some() {
51                    return Err(input.error("duplicate crate option"));
52                }
53                input.parse::<Token![crate]>()?;
54                input.parse::<Token![=]>()?;
55                crate_path = Some(
56                    input
57                        .parse::<Path>()?
58                        .into_token_stream()
59                        .into_iter()
60                        .map(|mut token| {
61                            token.set_span(Span::call_site());
62                            token
63                        })
64                        .collect(),
65                );
66            } else if input.peek(Token![box]) {
67                if boxed.is_some() {
68                    return Err(input.error("duplicate boxed option"));
69                }
70                let span = input.parse::<Token![box]>()?.span;
71                let send = input.peek(token::Paren);
72                if send {
73                    let content;
74                    syn::parenthesized!(content in input);
75                    content.parse::<Token![?]>()?;
76                    syn::custom_keyword!(Send);
77                    content.parse::<Send>()?;
78                }
79                boxed = Some(Boxed { span, send });
80            } else {
81                return Err(input.error("expected `crate` or `box`"));
82            }
83
84            if input.is_empty() {
85                break;
86            }
87
88            input.parse::<Token![,]>()?;
89        }
90
91        Ok(Self {
92            crate_path: CratePath::new(crate_path.unwrap_or_else(|| quote!(::completion))),
93            boxed,
94        })
95    }
96}
97
/// Data recorded for the `box` option of `#[completion(...)]`.
struct Boxed {
    /// Span of the `box` keyword in the attribute.
    span: Span,
    /// Flag set by the attribute parser depending on whether `(?Send)` followed `box`.
    /// NOTE(review): presumably `true` means the boxed future is required to be `Send` —
    /// confirm against the consumer in `function::transform`.
    send: bool,
}
102
103/// Input to the `#[completion]` attribute macro.
104enum CompletionInput {
105    AsyncFn(AnyFn),
106    AsyncBlock(ExprAsync, Option<Token![;]>),
107}
108impl Parse for CompletionInput {
109    fn parse(input: ParseStream<'_>) -> parse::Result<Self> {
110        let mut attrs = input.call(Attribute::parse_outer)?;
111
112        Ok(
113            if input.peek(Token![async]) && (input.peek2(Token![move]) || input.peek2(token::Brace))
114            {
115                let mut block: ExprAsync = input.parse()?;
116                block.attrs.append(&mut attrs);
117                CompletionInput::AsyncBlock(block, input.parse()?)
118            } else {
119                let mut f: AnyFn = input.parse()?;
120                f.attrs.append(&mut attrs);
121                CompletionInput::AsyncFn(f)
122            },
123        )
124    }
125}
126
127/// Any kind of function.
128struct AnyFn {
129    attrs: Vec<Attribute>,
130    vis: Visibility,
131    sig: Signature,
132    block: Option<Block>,
133    semi_token: Option<Token![;]>,
134}
135impl Parse for AnyFn {
136    fn parse(input: ParseStream<'_>) -> parse::Result<Self> {
137        let mut attrs = input.call(Attribute::parse_outer)?;
138        let vis: Visibility = input.parse()?;
139        let sig: Signature = input.parse()?;
140
141        let (block, semi_token) = if input.peek(Token![;]) {
142            (None, Some(input.parse::<Token![;]>()?))
143        } else {
144            let content;
145            let brace_token = syn::braced!(content in input);
146            attrs.append(&mut content.call(Attribute::parse_inner)?);
147            let stmts = content.call(Block::parse_within)?;
148            (Some(Block { brace_token, stmts }), None)
149        };
150
151        Ok(Self {
152            attrs,
153            vis,
154            sig,
155            block,
156            semi_token,
157        })
158    }
159}
160impl ToTokens for AnyFn {
161    fn to_tokens(&self, tokens: &mut TokenStream) {
162        tokens.append_all(
163            self.attrs
164                .iter()
165                .filter(|attr| matches!(attr.style, AttrStyle::Outer)),
166        );
167        self.vis.to_tokens(tokens);
168        self.sig.to_tokens(tokens);
169        if let Some(block) = &self.block {
170            block.brace_token.surround(tokens, |tokens| {
171                tokens.append_all(
172                    self.attrs
173                        .iter()
174                        .filter(|attr| matches!(attr.style, AttrStyle::Inner(_))),
175                );
176                tokens.append_all(&block.stmts);
177            });
178        }
179        if let Some(semi_token) = &self.semi_token {
180            semi_token.to_tokens(tokens);
181        }
182    }
183}
184
185#[proc_macro]
186#[doc(hidden)]
187pub fn completion_async_inner(input: TokenStream1) -> TokenStream1 {
188    completion_async_inner2(input.into(), false).into()
189}
190#[proc_macro]
191#[doc(hidden)]
192pub fn completion_async_move_inner(input: TokenStream1) -> TokenStream1 {
193    completion_async_inner2(input.into(), true).into()
194}
195
196fn completion_async_inner2(input: TokenStream, capture_move: bool) -> TokenStream {
197    let (crate_path, stmts) = match parse_bang_input.parse2(input) {
198        Ok(input) => input,
199        Err(e) => return e.into_compile_error(),
200    };
201    block::transform(call_site_async(capture_move, stmts), &crate_path)
202}
203
204#[proc_macro]
205#[doc(hidden)]
206pub fn completion_stream_inner(input: TokenStream1) -> TokenStream1 {
207    let (crate_path, stmts) = match parse_bang_input.parse(input) {
208        Ok(r) => r,
209        Err(e) => return e.into_compile_error().into(),
210    };
211    stream::transform(call_site_async(true, stmts), &crate_path).into()
212}
213
214fn parse_bang_input(input: ParseStream<'_>) -> parse::Result<(CratePath, Vec<Stmt>)> {
215    let crate_path = CratePath::new(input.parse::<Group>().unwrap().stream());
216    let item = Block::parse_within(input)?;
217    Ok((crate_path, item))
218}
219
220/// Create an async block at the call site.
221fn call_site_async(capture_move: bool, stmts: Vec<Stmt>) -> ExprAsync {
222    ExprAsync {
223        attrs: Vec::new(),
224        async_token: Token![async](Span::call_site()),
225        capture: if capture_move {
226            Some(Token![move](Span::call_site()))
227        } else {
228            None
229        },
230        block: Block {
231            brace_token: token::Brace {
232                span: Span::call_site(),
233            },
234            stmts,
235        },
236    }
237}
238
239struct CratePath {
240    inner: TokenStream,
241}
242impl CratePath {
243    fn new(inner: TokenStream) -> Self {
244        Self { inner }
245    }
246    fn with_span(&self, span: Span) -> impl ToTokens + '_ {
247        struct CratePathWithSpan<'a>(&'a TokenStream, Span);
248
249        impl ToTokens for CratePathWithSpan<'_> {
250            fn to_tokens(&self, tokens: &mut TokenStream) {
251                tokens.extend(self.0.clone().into_iter().map(|mut token| {
252                    token.set_span(token.span().located_at(self.1));
253                    token
254                }));
255            }
256        }
257
258        CratePathWithSpan(&self.inner, span)
259    }
260}
261
/// Transform the top level of a list of statements.
///
/// Walks `stmts` with a [`VisitMut`] visitor and calls `f` on every expression it reaches,
/// after that expression's children have been visited. The walk does not descend into `async`
/// blocks, closures or nested items, although `f` is still called on the `async` block or
/// closure expression itself. Invocations of a fixed list of standard library macros are
/// redirected to `<crate_path>::__special_macros::<name>` and, when their body parses as
/// expressions, those expressions are visited as well.
fn transform_top_level(stmts: &mut [Stmt], crate_path: &CratePath, f: impl FnMut(&mut Expr)) {
    struct Visitor<'a, F> {
        crate_path: &'a CratePath,
        f: F,
    }

    impl<F: FnMut(&mut Expr)> VisitMut for Visitor<'_, F> {
        fn visit_expr_mut(&mut self, expr: &mut Expr) {
            match expr {
                Expr::Async(_) | Expr::Closure(_) => {
                    // Don't do anything, we don't want to touch inner async blocks or closures.
                }
                Expr::Macro(expr_macro) => {
                    // Normally we don't transform the bodies of macros as they could do anything
                    // with the tokens they're given. However we special-case the standard
                    // library macros listed below so that their arguments are transformed too.

                    const SPECIAL_MACROS: &[&str] = &[
                        "assert",
                        "assert_eq",
                        "assert_ne",
                        "dbg",
                        "debug_assert",
                        "debug_assert_eq",
                        "debug_assert_ne",
                        "eprint",
                        "eprintln",
                        "format",
                        "format_args",
                        "matches",
                        "panic",
                        "print",
                        "println",
                        "todo",
                        "unimplemented",
                        "unreachable",
                        "vec",
                        "write",
                        "writeln",
                    ];

                    // A macro whose path already starts with
                    // `<crate_path>::__special_macros::` is trusted as it is.
                    let mut is_trusted =
                        token_stream_starts_with(expr_macro.mac.path.to_token_stream(), {
                            let crate_path = self.crate_path.with_span(Span::call_site());
                            quote!(#crate_path::__special_macros::)
                        });

                    // Otherwise, a path that is exactly one of the listed identifiers is
                    // rewritten to point at the `__special_macros` re-export, spanned at the
                    // original macro name.
                    if !is_trusted
                        && SPECIAL_MACROS
                            .iter()
                            .any(|name| expr_macro.mac.path.is_ident(name))
                    {
                        let macro_ident = expr_macro.mac.path.get_ident().unwrap();
                        let crate_path = self.crate_path.with_span(macro_ident.span());
                        let path = quote_spanned!(macro_ident.span()=> #crate_path::__special_macros::#macro_ident);
                        expr_macro.mac.path = syn::parse2(path).unwrap();
                        is_trusted = true;
                    }

                    if is_trusted {
                        let last_segment = expr_macro.mac.path.segments.last().unwrap();

                        match &*last_segment.ident.to_string() {
                            "matches" => {
                                // Only the leading expression of `matches!` is visited; the
                                // tokens after it are kept verbatim. If the body does not
                                // parse this way the macro tokens are left untouched.
                                let res =
                                    expr_macro.mac.parse_body_with(|tokens: ParseStream<'_>| {
                                        let expr = tokens.parse::<Expr>()?;
                                        let rest = tokens.parse::<TokenStream>()?;
                                        Ok((expr, rest))
                                    });
                                if let Ok((mut scrutinee, rest)) = res {
                                    self.visit_expr_mut(&mut scrutinee);
                                    expr_macro.mac.tokens = scrutinee.into_token_stream();
                                    expr_macro.mac.tokens.extend(rest.into_token_stream());
                                }
                            }
                            _ => {
                                // Every other trusted macro is treated as a comma-separated
                                // list of expressions, each of which is visited. If the body
                                // does not parse this way the macro tokens are left untouched.
                                let res = expr_macro
                                    .mac
                                    .parse_body_with(<Punctuated<_, Token![,]>>::parse_terminated);
                                if let Ok(mut exprs) = res {
                                    for expr in &mut exprs {
                                        self.visit_expr_mut(expr);
                                    }
                                    expr_macro.mac.tokens = exprs.into_token_stream();
                                }
                            }
                        }
                    }
                }
                _ => {
                    // Any other expression: recurse into its children first.
                    visit_mut::visit_expr_mut(self, expr);
                }
            }
            // The callback runs last, for every expression reached, including the async
            // blocks, closures and macros handled above.
            (self.f)(expr);
        }
        fn visit_item_mut(&mut self, _: &mut syn::Item) {
            // Don't do anything, we don't want to touch inner items.
        }
    }

    let mut visitor = Visitor { crate_path, f };
    for stmt in stmts {
        visitor.visit_stmt_mut(stmt);
    }
}
369
370fn token_stream_starts_with(tokens: TokenStream, prefix: TokenStream) -> bool {
371    let mut tokens = tokens.into_iter();
372
373    for prefix_token in prefix {
374        let token = match tokens.next() {
375            Some(token) => token,
376            None => return false,
377        };
378        if !token_tree_eq(&prefix_token, &token) {
379            return false;
380        }
381    }
382
383    true
384}
385
386fn token_stream_eq(lhs: TokenStream, rhs: TokenStream) -> bool {
387    lhs.into_iter()
388        .zip(rhs)
389        .all(|(lhs, rhs)| token_tree_eq(&lhs, &rhs))
390}
391fn token_tree_eq(lhs: &TokenTree, rhs: &TokenTree) -> bool {
392    match (lhs, rhs) {
393        (TokenTree::Group(lhs), TokenTree::Group(rhs)) => group_eq(lhs, rhs),
394        (TokenTree::Ident(lhs), TokenTree::Ident(rhs)) => lhs == rhs,
395        (TokenTree::Punct(lhs), TokenTree::Punct(rhs)) => punct_eq(lhs, rhs),
396        (TokenTree::Literal(lhs), TokenTree::Literal(rhs)) => literal_eq(lhs, rhs),
397        (_, _) => false,
398    }
399}
400fn group_eq(lhs: &Group, rhs: &Group) -> bool {
401    lhs.delimiter() == rhs.delimiter() && token_stream_eq(lhs.stream(), rhs.stream())
402}
403fn punct_eq(lhs: &Punct, rhs: &Punct) -> bool {
404    lhs.as_char() == rhs.as_char() && lhs.spacing() == rhs.spacing()
405}
406fn literal_eq(lhs: &Literal, rhs: &Literal) -> bool {
407    lhs.to_string() == rhs.to_string()
408}
409
410struct OuterAttrs<'a>(&'a [Attribute]);
411impl ToTokens for OuterAttrs<'_> {
412    fn to_tokens(&self, tokens: &mut TokenStream) {
413        tokens.append_all(
414            self.0
415                .iter()
416                .filter(|attr| matches!(attr.style, AttrStyle::Outer)),
417        )
418    }
419}