1use proc_macro::TokenStream as TokenStream1;
5use proc_macro2::{Group, Literal, Punct, Span, TokenStream, TokenTree};
6use quote::{quote, quote_spanned, ToTokens, TokenStreamExt};
7use syn::parse::{self, Parse, ParseStream, Parser};
8use syn::punctuated::Punctuated;
9use syn::visit_mut::{self, VisitMut};
10use syn::{
11 token, AttrStyle, Attribute, Block, Expr, ExprAsync, Path, Signature, Stmt, Token, Visibility,
12};
13
14mod block;
15mod function;
16mod stream;
17
18#[proc_macro_attribute]
19pub fn completion(attr: TokenStream1, input: TokenStream1) -> TokenStream1 {
20 let (CompletionAttr { crate_path, boxed }, input) = match (syn::parse(attr), syn::parse(input))
21 {
22 (Ok(attr), Ok(input)) => (attr, input),
23 (Ok(_), Err(e)) | (Err(e), Ok(_)) => return e.into_compile_error().into(),
24 (Err(mut e1), Err(e2)) => {
25 e1.combine(e2);
26 return e1.into_compile_error().into();
27 }
28 };
29 match input {
30 CompletionInput::AsyncFn(f) => function::transform(f, boxed, &crate_path),
31 CompletionInput::AsyncBlock(async_block, semi) => {
32 let tokens = block::transform(async_block, &crate_path);
33 quote!(#tokens #semi)
34 }
35 }
36 .into()
37}
38
39struct CompletionAttr {
40 crate_path: CratePath,
41 boxed: Option<Boxed>,
42}
43impl Parse for CompletionAttr {
44 fn parse(input: ParseStream<'_>) -> parse::Result<Self> {
45 let mut crate_path = None;
46 let mut boxed = None;
47
48 while !input.is_empty() {
49 if input.peek(Token![crate]) {
50 if crate_path.is_some() {
51 return Err(input.error("duplicate crate option"));
52 }
53 input.parse::<Token![crate]>()?;
54 input.parse::<Token![=]>()?;
55 crate_path = Some(
56 input
57 .parse::<Path>()?
58 .into_token_stream()
59 .into_iter()
60 .map(|mut token| {
61 token.set_span(Span::call_site());
62 token
63 })
64 .collect(),
65 );
66 } else if input.peek(Token![box]) {
67 if boxed.is_some() {
68 return Err(input.error("duplicate boxed option"));
69 }
70 let span = input.parse::<Token![box]>()?.span;
71 let send = input.peek(token::Paren);
72 if send {
73 let content;
74 syn::parenthesized!(content in input);
75 content.parse::<Token![?]>()?;
76 syn::custom_keyword!(Send);
77 content.parse::<Send>()?;
78 }
79 boxed = Some(Boxed { span, send });
80 } else {
81 return Err(input.error("expected `crate` or `box`"));
82 }
83
84 if input.is_empty() {
85 break;
86 }
87
88 input.parse::<Token![,]>()?;
89 }
90
91 Ok(Self {
92 crate_path: CratePath::new(crate_path.unwrap_or_else(|| quote!(::completion))),
93 boxed,
94 })
95 }
96}
97
98struct Boxed {
99 span: Span,
100 send: bool,
101}
102
103enum CompletionInput {
105 AsyncFn(AnyFn),
106 AsyncBlock(ExprAsync, Option<Token![;]>),
107}
108impl Parse for CompletionInput {
109 fn parse(input: ParseStream<'_>) -> parse::Result<Self> {
110 let mut attrs = input.call(Attribute::parse_outer)?;
111
112 Ok(
113 if input.peek(Token![async]) && (input.peek2(Token![move]) || input.peek2(token::Brace))
114 {
115 let mut block: ExprAsync = input.parse()?;
116 block.attrs.append(&mut attrs);
117 CompletionInput::AsyncBlock(block, input.parse()?)
118 } else {
119 let mut f: AnyFn = input.parse()?;
120 f.attrs.append(&mut attrs);
121 CompletionInput::AsyncFn(f)
122 },
123 )
124 }
125}
126
127struct AnyFn {
129 attrs: Vec<Attribute>,
130 vis: Visibility,
131 sig: Signature,
132 block: Option<Block>,
133 semi_token: Option<Token![;]>,
134}
135impl Parse for AnyFn {
136 fn parse(input: ParseStream<'_>) -> parse::Result<Self> {
137 let mut attrs = input.call(Attribute::parse_outer)?;
138 let vis: Visibility = input.parse()?;
139 let sig: Signature = input.parse()?;
140
141 let (block, semi_token) = if input.peek(Token![;]) {
142 (None, Some(input.parse::<Token![;]>()?))
143 } else {
144 let content;
145 let brace_token = syn::braced!(content in input);
146 attrs.append(&mut content.call(Attribute::parse_inner)?);
147 let stmts = content.call(Block::parse_within)?;
148 (Some(Block { brace_token, stmts }), None)
149 };
150
151 Ok(Self {
152 attrs,
153 vis,
154 sig,
155 block,
156 semi_token,
157 })
158 }
159}
160impl ToTokens for AnyFn {
161 fn to_tokens(&self, tokens: &mut TokenStream) {
162 tokens.append_all(
163 self.attrs
164 .iter()
165 .filter(|attr| matches!(attr.style, AttrStyle::Outer)),
166 );
167 self.vis.to_tokens(tokens);
168 self.sig.to_tokens(tokens);
169 if let Some(block) = &self.block {
170 block.brace_token.surround(tokens, |tokens| {
171 tokens.append_all(
172 self.attrs
173 .iter()
174 .filter(|attr| matches!(attr.style, AttrStyle::Inner(_))),
175 );
176 tokens.append_all(&block.stmts);
177 });
178 }
179 if let Some(semi_token) = &self.semi_token {
180 semi_token.to_tokens(tokens);
181 }
182 }
183}
184
185#[proc_macro]
186#[doc(hidden)]
187pub fn completion_async_inner(input: TokenStream1) -> TokenStream1 {
188 completion_async_inner2(input.into(), false).into()
189}
190#[proc_macro]
191#[doc(hidden)]
192pub fn completion_async_move_inner(input: TokenStream1) -> TokenStream1 {
193 completion_async_inner2(input.into(), true).into()
194}
195
196fn completion_async_inner2(input: TokenStream, capture_move: bool) -> TokenStream {
197 let (crate_path, stmts) = match parse_bang_input.parse2(input) {
198 Ok(input) => input,
199 Err(e) => return e.into_compile_error(),
200 };
201 block::transform(call_site_async(capture_move, stmts), &crate_path)
202}
203
204#[proc_macro]
205#[doc(hidden)]
206pub fn completion_stream_inner(input: TokenStream1) -> TokenStream1 {
207 let (crate_path, stmts) = match parse_bang_input.parse(input) {
208 Ok(r) => r,
209 Err(e) => return e.into_compile_error().into(),
210 };
211 stream::transform(call_site_async(true, stmts), &crate_path).into()
212}
213
214fn parse_bang_input(input: ParseStream<'_>) -> parse::Result<(CratePath, Vec<Stmt>)> {
215 let crate_path = CratePath::new(input.parse::<Group>().unwrap().stream());
216 let item = Block::parse_within(input)?;
217 Ok((crate_path, item))
218}
219
220fn call_site_async(capture_move: bool, stmts: Vec<Stmt>) -> ExprAsync {
222 ExprAsync {
223 attrs: Vec::new(),
224 async_token: Token),
225 capture: if capture_move {
226 Some(Token))
227 } else {
228 None
229 },
230 block: Block {
231 brace_token: token::Brace {
232 span: Span::call_site(),
233 },
234 stmts,
235 },
236 }
237}
238
239struct CratePath {
240 inner: TokenStream,
241}
242impl CratePath {
243 fn new(inner: TokenStream) -> Self {
244 Self { inner }
245 }
246 fn with_span(&self, span: Span) -> impl ToTokens + '_ {
247 struct CratePathWithSpan<'a>(&'a TokenStream, Span);
248
249 impl ToTokens for CratePathWithSpan<'_> {
250 fn to_tokens(&self, tokens: &mut TokenStream) {
251 tokens.extend(self.0.clone().into_iter().map(|mut token| {
252 token.set_span(token.span().located_at(self.1));
253 token
254 }));
255 }
256 }
257
258 CratePathWithSpan(&self.inner, span)
259 }
260}
261
262fn transform_top_level(stmts: &mut [Stmt], crate_path: &CratePath, f: impl FnMut(&mut Expr)) {
264 struct Visitor<'a, F> {
265 crate_path: &'a CratePath,
266 f: F,
267 }
268
269 impl<F: FnMut(&mut Expr)> VisitMut for Visitor<'_, F> {
270 fn visit_expr_mut(&mut self, expr: &mut Expr) {
271 match expr {
272 Expr::Async(_) | Expr::Closure(_) => {
273 }
275 Expr::Macro(expr_macro) => {
276 const SPECIAL_MACROS: &[&str] = &[
281 "assert",
282 "assert_eq",
283 "assert_ne",
284 "dbg",
285 "debug_assert",
286 "debug_assert_eq",
287 "debug_assert_ne",
288 "eprint",
289 "eprintln",
290 "format",
291 "format_args",
292 "matches",
293 "panic",
294 "print",
295 "println",
296 "todo",
297 "unimplemented",
298 "unreachable",
299 "vec",
300 "write",
301 "writeln",
302 ];
303
304 let mut is_trusted =
305 token_stream_starts_with(expr_macro.mac.path.to_token_stream(), {
306 let crate_path = self.crate_path.with_span(Span::call_site());
307 quote!(#crate_path::__special_macros::)
308 });
309
310 if !is_trusted
311 && SPECIAL_MACROS
312 .iter()
313 .any(|name| expr_macro.mac.path.is_ident(name))
314 {
315 let macro_ident = expr_macro.mac.path.get_ident().unwrap();
316 let crate_path = self.crate_path.with_span(macro_ident.span());
317 let path = quote_spanned!(macro_ident.span()=> #crate_path::__special_macros::#macro_ident);
318 expr_macro.mac.path = syn::parse2(path).unwrap();
319 is_trusted = true;
320 }
321
322 if is_trusted {
323 let last_segment = expr_macro.mac.path.segments.last().unwrap();
324
325 match &*last_segment.ident.to_string() {
326 "matches" => {
327 let res =
328 expr_macro.mac.parse_body_with(|tokens: ParseStream<'_>| {
329 let expr = tokens.parse::<Expr>()?;
330 let rest = tokens.parse::<TokenStream>()?;
331 Ok((expr, rest))
332 });
333 if let Ok((mut scrutinee, rest)) = res {
334 self.visit_expr_mut(&mut scrutinee);
335 expr_macro.mac.tokens = scrutinee.into_token_stream();
336 expr_macro.mac.tokens.extend(rest.into_token_stream());
337 }
338 }
339 _ => {
340 let res = expr_macro
341 .mac
342 .parse_body_with(<Punctuated<_, Token![,]>>::parse_terminated);
343 if let Ok(mut exprs) = res {
344 for expr in &mut exprs {
345 self.visit_expr_mut(expr);
346 }
347 expr_macro.mac.tokens = exprs.into_token_stream();
348 }
349 }
350 }
351 }
352 }
353 _ => {
354 visit_mut::visit_expr_mut(self, expr);
355 }
356 }
357 (self.f)(expr);
358 }
359 fn visit_item_mut(&mut self, _: &mut syn::Item) {
360 }
362 }
363
364 let mut visitor = Visitor { crate_path, f };
365 for stmt in stmts {
366 visitor.visit_stmt_mut(stmt);
367 }
368}
369
370fn token_stream_starts_with(tokens: TokenStream, prefix: TokenStream) -> bool {
371 let mut tokens = tokens.into_iter();
372
373 for prefix_token in prefix {
374 let token = match tokens.next() {
375 Some(token) => token,
376 None => return false,
377 };
378 if !token_tree_eq(&prefix_token, &token) {
379 return false;
380 }
381 }
382
383 true
384}
385
386fn token_stream_eq(lhs: TokenStream, rhs: TokenStream) -> bool {
387 lhs.into_iter()
388 .zip(rhs)
389 .all(|(lhs, rhs)| token_tree_eq(&lhs, &rhs))
390}
391fn token_tree_eq(lhs: &TokenTree, rhs: &TokenTree) -> bool {
392 match (lhs, rhs) {
393 (TokenTree::Group(lhs), TokenTree::Group(rhs)) => group_eq(lhs, rhs),
394 (TokenTree::Ident(lhs), TokenTree::Ident(rhs)) => lhs == rhs,
395 (TokenTree::Punct(lhs), TokenTree::Punct(rhs)) => punct_eq(lhs, rhs),
396 (TokenTree::Literal(lhs), TokenTree::Literal(rhs)) => literal_eq(lhs, rhs),
397 (_, _) => false,
398 }
399}
400fn group_eq(lhs: &Group, rhs: &Group) -> bool {
401 lhs.delimiter() == rhs.delimiter() && token_stream_eq(lhs.stream(), rhs.stream())
402}
403fn punct_eq(lhs: &Punct, rhs: &Punct) -> bool {
404 lhs.as_char() == rhs.as_char() && lhs.spacing() == rhs.spacing()
405}
406fn literal_eq(lhs: &Literal, rhs: &Literal) -> bool {
407 lhs.to_string() == rhs.to_string()
408}
409
410struct OuterAttrs<'a>(&'a [Attribute]);
411impl ToTokens for OuterAttrs<'_> {
412 fn to_tokens(&self, tokens: &mut TokenStream) {
413 tokens.append_all(
414 self.0
415 .iter()
416 .filter(|attr| matches!(attr.style, AttrStyle::Outer)),
417 )
418 }
419}