// delegate/lib.rs

1//! This crate removes some boilerplate for structs that simply delegate
2//! some of their methods to one or more of their fields.
3//!
4//! It gives you the `delegate!` macro, which delegates method calls to selected expressions (usually inner fields).
5//!
6//! ## Features:
7//! - Delegate to a method with a different name
8//! ```rust
9//! use delegate::delegate;
10//!
11//! struct Stack { inner: Vec<u32> }
12//! impl Stack {
13//!     delegate! {
14//!         to self.inner {
15//!             #[call(push)]
16//!             pub fn add(&mut self, value: u32);
17//!         }
18//!     }
19//! }
20//! ```
21//! - Use an arbitrary inner field expression
22//! ```rust
23//! use delegate::delegate;
24//!
25//! use std::rc::Rc;
26//! use std::cell::RefCell;
27//! use std::ops::Deref;
28//!
29//! struct Wrapper { inner: Rc<RefCell<Vec<u32>>> }
30//! impl Wrapper {
31//!     delegate! {
32//!         to self.inner.deref().borrow_mut() {
33//!             pub fn push(&mut self, val: u32);
34//!         }
35//!     }
36//! }
37//! ```
38//!
39//! - Delegate to enum variants
40//!
41//! ```rust
42//! use delegate::delegate;
43//!
44//! enum Enum {
45//!     A(A),
46//!     B(B),
47//!     C { v: C },
48//! }
49//!
50//! struct A {
51//!     val: usize,
52//! }
53//!
54//! impl A {
55//!     fn dbg_inner(&self) -> usize {
56//!         dbg!(self.val);
57//!         1
58//!     }
59//! }
60//! struct B {
61//!     val_a: String,
62//! }
63//!
64//! impl B {
65//!     fn dbg_inner(&self) -> usize {
66//!         dbg!(self.val_a.clone());
67//!         2
68//!     }
69//! }
70//!
71//! struct C {
72//!     val_c: f64,
73//! }
74//!
75//! impl C {
76//!     fn dbg_inner(&self) -> usize {
77//!         dbg!(self.val_c);
78//!         3
79//!     }
80//! }
81//!
82//! impl Enum {
83//!     delegate! {
84//!         // transformed to
85//!         //
86//!         // ```rust
87//!         // match self {
88//!         //     Enum::A(a) => a.dbg_inner(),
89//!         //     Enum::B(b) => { println!("i am b"); b }.dbg_inner(),
90//!         //     Enum::C { v: c } => { c }.dbg_inner(),
91//!         // }
92//!         // ```
93//!         to match self {
94//!             Enum::A(a) => a,
95//!             Enum::B(b) => { println!("i am b"); b },
96//!             Enum::C { v: c } => { c },
97//!         } {
98//!             fn dbg_inner(&self) -> usize;
99//!         }
100//!     }
101//! }
102//! ```
103//!
104//! - Use modifiers that alter the generated method body
105//! ```rust
106//! use delegate::delegate;
107//! struct Inner;
108//! impl Inner {
109//!     pub fn method(&self, num: u32) -> u32 { num }
110//!     pub fn method_res(&self, num: u32) -> Result<u32, ()> { Ok(num) }
111//! }
112//! struct Wrapper { inner: Inner }
113//! impl Wrapper {
114//!     delegate! {
115//!         to self.inner {
116//!             // calls method, converts result to u64 using `From`
117//!             #[into]
118//!             pub fn method(&self, num: u32) -> u64;
119//!
120//!             // calls method, returns ()
121//!             #[call(method)]
122//!             pub fn method_noreturn(&self, num: u32);
123//!
//!             // calls method, converts result to u16 using `TryFrom`
125//!             #[try_into]
126//!             #[call(method)]
127//!             pub fn method2(&self, num: u32) -> Result<u16, std::num::TryFromIntError>;
128//!
129//!             // calls method_res, unwraps the result
130//!             #[unwrap]
131//!             pub fn method_res(&self, num: u32) -> u32;
132//!
133//!             // calls method_res, unwraps the result, then calls into
134//!             #[unwrap]
135//!             #[into]
136//!             #[call(method_res)]
137//!             pub fn method_res_into(&self, num: u32) -> u64;
138//!
139//!             // specify explicit type for into
140//!             #[into(u64)]
141//!             #[call(method)]
142//!             pub fn method_into_explicit(&self, num: u32) -> u64;
143//!         }
144//!     }
145//! }
146//! ```
147//! - Call `await` on async functions
148//! ```rust
149//! use delegate::delegate;
150//!
151//! struct Inner;
152//! impl Inner {
153//!     pub async fn method(&self, num: u32) -> u32 { num }
154//! }
155//! struct Wrapper { inner: Inner }
156//! impl Wrapper {
157//!     delegate! {
158//!         to self.inner {
159//!             // calls method(num).await, returns impl Future<Output = u32>
160//!             pub async fn method(&self, num: u32) -> u32;
161//!
162//!             // calls method(num).await.into(), returns impl Future<Output = u64>
163//!             #[into]
164//!             #[call(method)]
165//!             pub async fn method_into(&self, num: u32) -> u64;
166//!         }
167//!     }
168//! }
169//! ```
170//! You can use the `#[await(true/false)]` attribute on delegated methods to specify if `.await` should
171//! be generated after the delegated expression. It will be generated by default if the delegated
172//! method is `async`.
173//! - Delegate to multiple fields
174//! ```rust
175//! use delegate::delegate;
176//!
177//! struct MultiStack {
178//!     left: Vec<u32>,
179//!     right: Vec<u32>,
180//! }
181//! impl MultiStack {
182//!     delegate! {
183//!         to self.left {
184//!             // Push an item to the top of the left stack
185//!             #[call(push)]
186//!             pub fn push_left(&mut self, value: u32);
187//!         }
188//!         to self.right {
189//!             // Push an item to the top of the right stack
190//!             #[call(push)]
191//!             pub fn push_right(&mut self, value: u32);
192//!         }
193//!     }
194//! }
195//! ```
//! - Inserts `#[inline]` automatically (unless you specify an `#[inline]` attribute manually on the method)
197//! - You can use an attribute on a whole segment to automatically apply it to all methods in that
198//!   segment:
199//! ```rust
200//! use delegate::delegate;
201//!
202//! struct Inner;
203//!
204//! impl Inner {
205//!   fn foo(&self) -> Result<u32, ()> { Ok(0) }
206//!   fn bar(&self) -> Result<u32, ()> { Ok(1) }
207//! }
208//!
209//! struct Wrapper { inner: Inner }
210//!
211//! impl Wrapper {
212//!   delegate! {
213//!     #[unwrap]
214//!     to self.inner {
215//!       fn foo(&self) -> u32; // calls self.inner.foo().unwrap()
216//!       fn bar(&self) -> u32; // calls self.inner.bar().unwrap()
217//!     }
218//!   }
219//! }
220//! ```
221//! - Specify expressions in the signature that will be used as delegated arguments
222//! ```rust
223//! use delegate::delegate;
224//! struct Inner;
225//! impl Inner {
226//!     pub fn polynomial(&self, a: i32, x: i32, b: i32, y: i32, c: i32) -> i32 {
227//!         a + x * x + b * y + c
228//!     }
229//! }
230//! struct Wrapper { inner: Inner, a: i32, b: i32, c: i32 }
231//! impl Wrapper {
232//!     delegate! {
233//!         to self.inner {
234//!             // Calls `polynomial` on `inner` with `self.a`, `self.b` and
235//!             // `self.c` passed as arguments `a`, `b`, and `c`, effectively
236//!             // calling `polynomial(self.a, x, self.b, y, self.c)`.
237//!             pub fn polynomial(&self, [ self.a ], x: i32, [ self.b ], y: i32, [ self.c ]) -> i32 ;
238//!             // Calls `polynomial` on `inner` with `0`s passed for arguments
239//!             // `a` and `x`, and `self.b` and `self.c` for `b` and `c`,
240//!             // effectively calling `polynomial(0, 0, self.b, y, self.c)`.
241//!             #[call(polynomial)]
242//!             pub fn linear(&self, [ 0 ], [ 0 ], [ self.b ], y: i32, [ self.c ]) -> i32 ;
243//!         }
244//!     }
245//! }
246//! ```
//! - Modify how an input parameter will be passed to the delegated method with parameter attribute modifiers.
//!   Currently, the following modifiers are supported:
//!     - `#[into]`: Calls `.into()` on the parameter passed to the delegated method.
//!     - `#[as_ref]`: Calls `.as_ref()` on the parameter passed to the delegated method.
//!     - `#[newtype]`: Accesses the `.0` field of the parameter passed to the delegated method.
252//! ```rust
253//! use delegate::delegate;
254//!
255//! struct InnerType {}
256//! impl InnerType {
257//!     fn foo(&self, other: Self) {}
258//! }
259//!
260//! impl From<Wrapper> for InnerType {
261//!     fn from(wrapper: Wrapper) -> Self {
262//!         wrapper.0
263//!     }
264//! }
265//!
266//! struct Wrapper(InnerType);
267//! impl Wrapper {
268//!     delegate! {
269//!         to self.0 {
270//!             // Calls `self.0.foo(other.into());`
271//!             pub fn foo(&self, #[into] other: Self);
272//!         }
273//!     }
274//! }
275//! ```
//! - Specify a trait through which the delegated method will be called
//!   (using [UFCS](https://doc.rust-lang.org/reference/expressions/call-expr.html#disambiguating-function-calls)).
278//! ```rust
279//! use delegate::delegate;
280//!
281//! struct InnerType {}
282//! impl InnerType {
283//!     
284//! }
285//!
286//! trait MyTrait {
287//!   fn foo(&self);
288//! }
289//! impl MyTrait for InnerType {
290//!   fn foo(&self) {}
291//! }
292//!
293//! struct Wrapper(InnerType);
294//! impl Wrapper {
295//!     delegate! {
296//!         to &self.0 {
297//!             // Calls `MyTrait::foo(&self.0)`
298//!             #[through(MyTrait)]
299//!             pub fn foo(&self);
300//!         }
301//!     }
302//! }
303//! ```
304//!
//! - Add extra parameters to the generated method by delegating to a closure
306//!
307//!  ```rust
308//!  use delegate::delegate;
309//!  use std::cell::OnceCell;
310//!  struct Inner(u32);
311//!  impl Inner {
312//!      pub fn new(m: u32) -> Self {
313//!          // some "very complex" constructing work
314//!          Self(m)
315//!      }
316//!      pub fn method(&self, n: u32) -> u32 {
317//!          self.0 + n
318//!      }
319//!  }
320//!  
321//!  struct Wrapper {
322//!      inner: OnceCell<Inner>,
323//!  }
324//!  
325//!  impl Wrapper {
326//!      pub fn new() -> Self {
327//!          Self {
328//!              inner: OnceCell::new(),
329//!          }
330//!      }
331//!      fn content(&self, val: u32) -> &Inner {
332//!          self.inner.get_or_init(|| Inner(val))
333//!      }
334//!      delegate! {
335//!          to |k: u32| self.content(k) {
336//!              // `wrapper.method(k, num)` will call `self.content(k).method(num)`
337//!              pub fn method(&self, num: u32) -> u32;
338//!          }
339//!      }
340//!  }
341//!  ```
342//! - Delegate associated functions
343//!   ```rust
344//!   use delegate::delegate;
345//!
346//!   struct A {}
347//!   impl A {
348//!       fn foo(a: u32) -> u32 {
349//!           a + 1
350//!       }
351//!   }
352//!
353//!   struct B;
354//!
355//!   impl B {
356//!       delegate! {
357//!           to A {
358//!               fn foo(a: u32) -> u32;
359//!           }
360//!       }
361//!   }
362//!
363//!   assert_eq!(B::foo(1), 2);
364//!   ```
365
366extern crate proc_macro;
367use std::mem;
368
369use proc_macro::TokenStream;
370
371use proc_macro2::Ident;
372use quote::{quote, ToTokens};
373use syn::parse::{Parse, ParseStream};
374use syn::spanned::Spanned;
375use syn::visit_mut::VisitMut;
376use syn::{parse_quote, Error, Expr, ExprField, ExprMethodCall, FnArg, GenericParam, Meta};
377
378use crate::attributes::{
379    combine_attributes, parse_method_attributes, parse_segment_attributes, ReturnExpression,
380    SegmentAttributes,
381};
382
383mod attributes;
384
385mod kw {
386    syn::custom_keyword!(to);
387    syn::custom_keyword!(target);
388}
389
/// A modifier attribute written on a delegated method parameter, which changes
/// the expression forwarded to the delegated call for that parameter.
#[derive(Clone)]
enum ArgumentModifier {
    /// `#[into] x: T` forwards `x.into()`.
    Into,
    /// `#[as_ref] x: T` forwards `x.as_ref()`.
    AsRef,
    /// `#[newtype] x: T` forwards `x.0`.
    Newtype,
}
396
397#[derive(Clone)]
398enum DelegatedInput {
399    Input {
400        parameter: syn::FnArg,
401        modifier: Option<ArgumentModifier>,
402    },
403    Argument(syn::Expr),
404}
405
406fn get_argument_modifier(attribute: syn::Attribute) -> Result<ArgumentModifier, Error> {
407    if let Meta::Path(mut path) = attribute.meta {
408        if path.segments.len() == 1 {
409            let segment = path.segments.pop().unwrap();
410            if segment.value().arguments.is_empty() {
411                let ident = segment.value().ident.to_string();
412                let ident = ident.as_str();
413
414                match ident {
415                    "into" => return Ok(ArgumentModifier::Into),
416                    "as_ref" => return Ok(ArgumentModifier::AsRef),
417                    "newtype" => return Ok(ArgumentModifier::Newtype),
418                    _ => (),
419                }
420            }
421        }
422    };
423
424    panic!("The attribute argument has to be `into` or `as_ref`, like this: `#[into] a: u32`.")
425}
426
427impl syn::parse::Parse for DelegatedInput {
428    fn parse(input: ParseStream) -> Result<Self, Error> {
429        let lookahead = input.lookahead1();
430        if lookahead.peek(syn::token::Bracket) {
431            let content;
432            let _bracket_token = syn::bracketed!(content in input);
433            let expression: syn::Expr = content.parse()?;
434            Ok(Self::Argument(expression))
435        } else {
436            let (input, modifier) = if lookahead.peek(syn::token::Pound) {
437                let mut attributes = input.call(tolerant_outer_attributes)?;
438                if attributes.len() > 1 {
439                    panic!("You can specify at most a single attribute for each parameter in a delegated method");
440                }
441                let modifier = get_argument_modifier(attributes.pop().unwrap())
442                    .expect("Could not parse argument modifier attribute");
443
444                let input: syn::FnArg = input.parse()?;
445                (input, Some(modifier))
446            } else {
447                (input.parse()?, None)
448            };
449
450            Ok(Self::Input {
451                parameter: input,
452                modifier,
453            })
454        }
455    }
456}
457
458struct DelegatedMethod {
459    method: syn::TraitItemFn,
460    attributes: Vec<syn::Attribute>,
461    visibility: syn::Visibility,
462    arguments: syn::punctuated::Punctuated<syn::Expr, syn::Token![,]>,
463}
464
465// Given an input parameter from a function signature, create a function
466// argument used to call the delegate function: omit receiver, extract an
467// identifier from a typed input parameter (and wrap it in an `Expr`).
468fn parse_input_into_argument_expression(
469    function_name: &Ident,
470    input: &syn::FnArg,
471) -> Option<syn::Expr> {
472    match input {
473        // Parse inputs of the form `x: T` to retrieve their identifiers.
474        syn::FnArg::Typed(typed) => {
475            match &*typed.pat {
476                // This should not happen, I think. If it does,
477                // it will be ignored as if it were the
478                // receiver.
479                syn::Pat::Ident(ident) if ident.ident == "self" => None,
480                // Expression in the form `x: T`. Extract the
481                // identifier, wrap it in Expr for type compatibility with bracketed expressions,
482                // and append it
483                // to the argument list.
484                syn::Pat::Ident(ident) => {
485                    let path_segment = syn::PathSegment {
486                        ident: ident.ident.clone(),
487                        arguments: syn::PathArguments::None,
488                    };
489                    let mut segments = syn::punctuated::Punctuated::new();
490                    segments.push(path_segment);
491                    let path = syn::Path {
492                        leading_colon: None,
493                        segments,
494                    };
495                    let ident_as_expr = syn::Expr::from(syn::ExprPath {
496                        attrs: Vec::new(),
497                        qself: None,
498                        path,
499                    });
500                    Some(ident_as_expr)
501                }
502                // Other more complex argument expressions are not covered.
503                _ => panic!(
504                    "You have to use simple identifiers for delegated method parameters ({})",
505                    function_name // The signature is not constructed yet. We make due.
506                ),
507            }
508        }
509        // Skip any `self`/`&self`/`&mut self` argument, since
510        // it does not appear in the argument list and it's
511        // already added to the parameter list.
512        syn::FnArg::Receiver(_receiver) => None,
513    }
514}
515
516impl syn::parse::Parse for DelegatedMethod {
517    fn parse(input: ParseStream) -> Result<Self, Error> {
518        let attributes = input.call(tolerant_outer_attributes)?;
519        let visibility = input.call(syn::Visibility::parse)?;
520
521        // Unchanged from Parse from TraitItemMethod
522        let constness: Option<syn::Token![const]> = input.parse()?;
523        let asyncness: Option<syn::Token![async]> = input.parse()?;
524        let unsafety: Option<syn::Token![unsafe]> = input.parse()?;
525        let abi: Option<syn::Abi> = input.parse()?;
526        let fn_token: syn::Token![fn] = input.parse()?;
527        let ident: Ident = input.parse()?;
528        let generics: syn::Generics = input.parse()?;
529
530        let content;
531        let paren_token = syn::parenthesized!(content in input);
532
533        // Parse inputs (method parameters) and arguments. The parameters
534        // constitute the parameter list of the signature of the delegating
535        // method so it must include all inputs, except bracketed expressions.
536        // The argument list constitutes the list of arguments used to call the
537        // delegated function. It must include all inputs, excluding the
538        // receiver (self-type) input. The arguments must all be parsed to
539        // retrieve the expressions inside of the brackets as well as variable
540        // identifiers of ordinary inputs. The arguments must preserve the order
541        // of the inputs.
542        let delegated_inputs = content.parse_terminated(DelegatedInput::parse, syn::Token![,])?;
543        let mut inputs: syn::punctuated::Punctuated<syn::FnArg, syn::Token![,]> =
544            syn::punctuated::Punctuated::new();
545        let mut arguments: syn::punctuated::Punctuated<syn::Expr, syn::Token![,]> =
546            syn::punctuated::Punctuated::new();
547
548        // First, combine the cases for pairs with cases for end, to remove
549        // redundancy below.
550        delegated_inputs
551            .into_pairs()
552            .map(|punctuated_pair| match punctuated_pair {
553                syn::punctuated::Pair::Punctuated(item, comma) => (item, Some(comma)),
554                syn::punctuated::Pair::End(item) => (item, None),
555            })
556            .for_each(|pair| match pair {
557                // This input is a bracketed argument (eg. `[ self.x ]`). It
558                // is omitted in the signature of the delegator, but the
559                // expression inside the brackets is used in the body of the
560                // delegator, as an arugnment to the delegated function (eg.
561                // `self.x`). The argument needs to be generated in the
562                // appropriate position with respect other arguments and non-
563                // argument inputs. As long as inputs are added to the
564                // `arguments` vector in order of occurance, this is trivial.
565                (DelegatedInput::Argument(argument), maybe_comma) => {
566                    arguments.push_value(argument);
567                    if let Some(comma) = maybe_comma {
568                        arguments.push_punct(comma)
569                    }
570                }
571                // The input is a standard function parameter with a name and
572                // a type (eg. `x: T`). This input needs to be reflected in
573                // the delegator signature as is (eg. `x: T`). The identifier
574                // also needs to be included in the argument list in part
575                // (eg. `x`). The argument list needs to preserve the order of
576                // the inputs with relation to arguments (see above), so the
577                // parsing is best done here (previously it was done at
578                // generation).
579                (
580                    DelegatedInput::Input {
581                        parameter,
582                        modifier,
583                    },
584                    maybe_comma,
585                ) => {
586                    inputs.push_value(parameter.clone());
587                    if let Some(comma) = maybe_comma {
588                        inputs.push_punct(comma);
589                    }
590                    let maybe_argument = parse_input_into_argument_expression(&ident, &parameter);
591                    if let Some(mut argument) = maybe_argument {
592                        let span = argument.span();
593
594                        if let Some(modifier) = modifier {
595                            let method_call = |name: &str| {
596                                syn::Expr::from(ExprMethodCall {
597                                    attrs: vec![],
598                                    receiver: Box::new(argument.clone()),
599                                    dot_token: Default::default(),
600                                    method: Ident::new(name, span),
601                                    turbofish: None,
602                                    paren_token,
603                                    args: Default::default(),
604                                })
605                            };
606
607                            let field_call = || {
608                                syn::Expr::from(ExprField {
609                                    attrs: vec![],
610                                    base: Box::new(argument.clone()),
611                                    dot_token: Default::default(),
612                                    member: syn::Member::Unnamed(0.into()),
613                                })
614                            };
615
616                            match modifier {
617                                ArgumentModifier::Into => {
618                                    argument = method_call("into");
619                                }
620                                ArgumentModifier::AsRef => {
621                                    argument = method_call("as_ref");
622                                }
623                                ArgumentModifier::Newtype => argument = field_call(),
624                            }
625                        }
626
627                        arguments.push(argument);
628                        if let Some(comma) = maybe_comma {
629                            arguments.push_punct(comma);
630                        }
631                    }
632                }
633            });
634
635        // Unchanged from Parse from TraitItemMethod
636        let output: syn::ReturnType = input.parse()?;
637        let where_clause: Option<syn::WhereClause> = input.parse()?;
638
639        // This needs to be generated manually, because inputs need to be
640        // separated into actual inputs that go in the signature (the
641        // parameters) and the additional expressions in square brackets which
642        // go into the arguments vector (artguments of the call on the method
643        // on the inner object).
644        let signature = syn::Signature {
645            constness,
646            asyncness,
647            unsafety,
648            abi,
649            fn_token,
650            ident,
651            paren_token,
652            inputs,
653            output,
654            variadic: None,
655            generics: syn::Generics {
656                where_clause,
657                ..generics
658            },
659        };
660
661        // Check if the input contains a semicolon or a brace. If it contains
662        // a semicolon, we parse it (to retain token location information) and
663        // continue. However, if it contains a brace, this indicates that
664        // there is a default definition of the method. This is not supported,
665        // so in that case we error out.
666        let lookahead = input.lookahead1();
667        let semi_token: Option<syn::Token![;]> = if lookahead.peek(syn::Token![;]) {
668            Some(input.parse()?)
669        } else {
670            panic!(
671                "Do not include implementation of delegated functions ({})",
672                signature.ident
673            );
674        };
675
676        // This needs to be populated from scratch because of the signature above.
677        let method = syn::TraitItemFn {
678            // All attributes are attached to `DelegatedMethod`, since they
679            // presumably pertain to the process of delegation, not the
680            // signature of the delegator.
681            attrs: Vec::new(),
682            sig: signature,
683            default: None,
684            semi_token,
685        };
686
687        Ok(DelegatedMethod {
688            method,
689            attributes,
690            visibility,
691            arguments,
692        })
693    }
694}
695
696struct DelegatedSegment {
697    delegator: syn::Expr,
698    methods: Vec<DelegatedMethod>,
699    segment_attrs: SegmentAttributes,
700}
701
702impl syn::parse::Parse for DelegatedSegment {
703    fn parse(input: ParseStream) -> Result<Self, Error> {
704        let attributes = input.call(tolerant_outer_attributes)?;
705        let segment_attrs = parse_segment_attributes(&attributes);
706
707        if let Ok(keyword) = input.parse::<kw::target>() {
708            return Err(Error::new(keyword.span(), "You are using the old `target` expression, which is deprecated. Please replace `target` with `to`."));
709        } else {
710            input.parse::<kw::to>()?;
711        }
712
713        syn::Expr::parse_without_eager_brace(input).and_then(|delegator| {
714            let content;
715            syn::braced!(content in input);
716
717            let mut methods = vec![];
718            while !content.is_empty() {
719                methods.push(
720                    content
721                        .parse::<DelegatedMethod>()
722                        .expect("Cannot parse delegated method"),
723                );
724            }
725
726            Ok(DelegatedSegment {
727                delegator,
728                methods,
729                segment_attrs,
730            })
731        })
732    }
733}
734
735struct DelegationBlock {
736    segments: Vec<DelegatedSegment>,
737}
738
739impl syn::parse::Parse for DelegationBlock {
740    fn parse(input: ParseStream) -> Result<Self, Error> {
741        let mut segments = vec![];
742        while !input.is_empty() {
743            segments.push(input.parse()?);
744        }
745
746        Ok(DelegationBlock { segments })
747    }
748}
749
750/// Returns true if there are any `inline` attributes in the input.
751fn has_inline_attribute(attrs: &[&syn::Attribute]) -> bool {
752    attrs.iter().any(|attr| {
753        if let syn::AttrStyle::Outer = attr.style {
754            attr.path().is_ident("inline")
755        } else {
756            false
757        }
758    })
759}
760
761struct MatchVisitor<F>(F);
762
763impl<F: Fn(&Expr) -> proc_macro2::TokenStream> VisitMut for MatchVisitor<F> {
764    fn visit_arm_mut(&mut self, arm: &mut syn::Arm) {
765        let transformed = self.0(&arm.body);
766        arm.body = parse_quote!(#transformed);
767    }
768}
769
770#[proc_macro]
771pub fn delegate(tokens: TokenStream) -> TokenStream {
772    let block: DelegationBlock = syn::parse_macro_input!(tokens);
773    let sections = block.segments.iter().map(|delegator| {
774        let delegated_expr = &delegator.delegator;
775        let functions = delegator.methods.iter().map(|method| {
776            let input = &method.method;
777            let mut signature = input.sig.clone();
778            if let Expr::Closure(closure) = delegated_expr {
779                let additional_inputs: Vec<FnArg> = closure
780                    .inputs
781                    .iter()
782                    .map(|input| {
783                        if let syn::Pat::Type(pat_type) = input {
784                            syn::parse_quote!(#pat_type)
785                        } else {
786                            panic!(
787                                "Use a type pattern (`a: u32`) for delegation closure arguments"
788                            );
789                        }
790                    })
791                    .collect();
792                let mut origin_inputs = mem::take(&mut signature.inputs).into_iter();
793                // When delegating methods, `first_input` should be self or similar receivers
794                // Then we need to move it to first
795                // When delegating associated methods, it may be a trivial argument or does not even exist
796                // We just keep the origin order.
797                let first_input = origin_inputs.next();
798                match first_input {
799                    Some(FnArg::Receiver(receiver)) => {
800                        signature.inputs.push(FnArg::Receiver(receiver));
801                        signature.inputs.extend(additional_inputs);
802                    }
803                    Some(first_input) => {
804                        signature.inputs.extend(additional_inputs);
805                        signature.inputs.push(first_input);
806                    }
807                    _ => {
808                        signature.inputs.extend(additional_inputs);
809                    }
810                }
811                signature.inputs.extend(origin_inputs);
812            }
813            let attributes = parse_method_attributes(&method.attributes, input);
814            let attributes = combine_attributes(attributes, &delegator.segment_attrs);
815            if input.default.is_some() {
816                panic!(
817                    "Do not include implementation of delegated functions ({})",
818                    signature.ident
819                );
820            }
821
822            // Generate an argument vector from Punctuated list.
823            let args: Vec<Expr> = method.arguments.clone().into_iter().collect();
824            let name = match &attributes.target_method {
825                Some(n) => n,
826                None => &input.sig.ident,
827            };
828            let inline = if has_inline_attribute(&attributes.attributes) {
829                quote!()
830            } else {
831                quote! { #[inline] }
832            };
833            let visibility = &method.visibility;
834
835            let is_method = method.method.sig.receiver().is_some();
836
837            // Use the body of a closure (like `|k: u32| <body>`) as the delegation expression
838            let delegated_body = if let Expr::Closure(closure) = delegated_expr {
839                &closure.body
840            } else {
841                delegated_expr
842            };
843
844            let span = input.span();
845            let generate_await = attributes
846                .generate_await
847                .unwrap_or_else(|| method.method.sig.asyncness.is_some());
848
849            // fn method<'a, A, B> -> method::<'a, A, B>
850            let generic_params = &method.method.sig.generics.params;
851            let generics = if generic_params.is_empty() {
852                quote::quote! {}
853            } else {
854                let span = generic_params.span();
855                let mut params: syn::punctuated::Punctuated<
856                    proc_macro2::TokenStream,
857                    syn::Token![,],
858                > = syn::punctuated::Punctuated::new();
859                for param in generic_params.iter() {
860                    let token = match param {
861                        GenericParam::Lifetime(l) => {
862                            let token = &l.lifetime;
863                            let span = l.span();
864                            quote::quote_spanned! {span=> #token }
865                        }
866                        GenericParam::Type(t) => {
867                            let token = &t.ident;
868                            let span = t.span();
869                            quote::quote_spanned! {span=> #token }
870                        }
871                        GenericParam::Const(c) => {
872                            let token = &c.ident;
873                            let span = c.span();
874                            quote::quote_spanned! {span=> #token }
875                        }
876                    };
877                    params.push(token);
878                }
879                quote::quote_spanned! {span=> ::<#params> }
880            };
881
882            let modify_expr = |expr: &Expr| {
883                let body = if let Some(target_trait) = &attributes.target_trait {
884                    quote::quote! { #target_trait::#name#generics(#expr, #(#args),*) }
885                } else if is_method {
886                    quote::quote! { #expr.#name#generics(#(#args),*) }
887                } else {
888                    quote::quote! { #expr::#name#generics(#(#args),*) }
889                };
890
891                let mut body = if generate_await {
892                    quote::quote! { #body.await }
893                } else {
894                    body
895                };
896
897                for expression in &attributes.expressions {
898                    match expression {
899                        ReturnExpression::Into(type_name) => {
900                            body = match type_name {
901                                Some(name) => {
902                                    quote::quote! { ::core::convert::Into::<#name>::into(#body) }
903                                }
904                                None => quote::quote! { ::core::convert::Into::into(#body) },
905                            };
906                        }
907                        ReturnExpression::TryInto => {
908                            body = quote::quote! { ::core::convert::TryInto::try_into(#body) };
909                        }
910                        ReturnExpression::Unwrap => {
911                            body = quote::quote! { #body.unwrap() };
912                        }
913                    }
914                }
915                body
916            };
917            let mut body = if let Expr::Match(expr_match) = delegated_body {
918                let mut expr_match = expr_match.clone();
919                MatchVisitor(modify_expr).visit_expr_match_mut(&mut expr_match);
920                expr_match.into_token_stream()
921            } else {
922                modify_expr(delegated_body)
923            };
924
925            if let syn::ReturnType::Default = &signature.output {
926                body = quote::quote! { #body; };
927            };
928
929            let attrs = &attributes.attributes;
930            quote::quote_spanned! {span=>
931                #(#attrs)*
932                #inline
933                #visibility #signature {
934                    #body
935                }
936            }
937        });
938
939        quote! { #(#functions)* }
940    });
941
942    let result = quote! {
943        #(#sections)*
944    };
945    result.into()
946}
947
948// we cannot use `Attributes::parse_outer` directly, because it does not allow keywords to appear
949// in meta path positions, i.e., it does not accept `#[await(true)]`.
950// related issue: https://github.com/dtolnay/syn/issues/1458
951fn tolerant_outer_attributes(input: ParseStream) -> syn::Result<Vec<syn::Attribute>> {
952    use proc_macro2::{Delimiter, TokenTree};
953    use syn::{
954        bracketed,
955        ext::IdentExt,
956        parse::discouraged::Speculative,
957        token::{Brace, Bracket, Paren},
958        AttrStyle, Attribute, ExprLit, Lit, MacroDelimiter, MetaList, MetaNameValue, Path, Result,
959        Token,
960    };
961
962    fn tolerant_attr(input: ParseStream) -> Result<Attribute> {
963        let content;
964        Ok(Attribute {
965            pound_token: input.parse()?,
966            style: AttrStyle::Outer,
967            bracket_token: bracketed!(content in input),
968            meta: content.call(tolerant_meta)?,
969        })
970    }
971
972    // adapted from `impl Parse for Meta`
973    fn tolerant_meta(input: ParseStream) -> Result<Meta> {
974        // Try to parse as Meta
975        if let Ok(meta) = input.call(Meta::parse) {
976            Ok(meta)
977        } else {
978            // If it's not possible, try to parse it as any identifier, to support #[await]
979            let path = Path::from(input.call(Ident::parse_any)?);
980            if input.peek(Paren) || input.peek(Bracket) || input.peek(Brace) {
981                // adapted from the private `syn::attr::parse_meta_after_path`
982                input.step(|cursor| {
983                    if let Some((TokenTree::Group(g), rest)) = cursor.token_tree() {
984                        let span = g.delim_span();
985                        let delimiter = match g.delimiter() {
986                            Delimiter::Parenthesis => MacroDelimiter::Paren(Paren(span)),
987                            Delimiter::Brace => MacroDelimiter::Brace(Brace(span)),
988                            Delimiter::Bracket => MacroDelimiter::Bracket(Bracket(span)),
989                            Delimiter::None => {
990                                return Err(cursor.error("expected delimiter"));
991                            }
992                        };
993                        Ok((
994                            Meta::List(MetaList {
995                                path,
996                                delimiter,
997                                tokens: g.stream(),
998                            }),
999                            rest,
1000                        ))
1001                    } else {
1002                        Err(cursor.error("expected delimiter"))
1003                    }
1004                })
1005            } else if input.peek(Token![=]) {
1006                // adapted from the private `syn::attr::parse_meta_name_value_after_path`
1007                let eq_token = input.parse()?;
1008                let ahead = input.fork();
1009                let value = match ahead.parse::<Option<Lit>>()? {
1010                    // this branch is probably for speeding up the parsing for doc comments etc.
1011                    Some(lit) if ahead.is_empty() => {
1012                        input.advance_to(&ahead);
1013                        Expr::Lit(ExprLit {
1014                            attrs: Vec::new(),
1015                            lit,
1016                        })
1017                    }
1018                    _ => input.parse()?,
1019                };
1020                Ok(Meta::NameValue(MetaNameValue {
1021                    path,
1022                    eq_token,
1023                    value,
1024                }))
1025            } else {
1026                Ok(Meta::Path(path))
1027            }
1028        }
1029    }
1030
1031    let mut attrs = Vec::new();
1032    while input.peek(Token![#]) {
1033        attrs.push(input.call(tolerant_attr)?);
1034    }
1035    Ok(attrs)
1036}