delegate/
attributes.rs

1use std::collections::VecDeque;
2
3use syn::parse::ParseStream;
4use syn::{Attribute, Error, Meta, TypePath};
5
6struct CallMethodAttribute {
7    name: syn::Ident,
8}
9
10impl syn::parse::Parse for CallMethodAttribute {
11    fn parse(input: ParseStream) -> Result<Self, Error> {
12        Ok(CallMethodAttribute {
13            name: input.parse()?,
14        })
15    }
16}
17
18struct GenerateAwaitAttribute {
19    literal: syn::LitBool,
20}
21
22impl syn::parse::Parse for GenerateAwaitAttribute {
23    fn parse(input: ParseStream) -> Result<Self, Error> {
24        Ok(GenerateAwaitAttribute {
25            literal: input.parse()?,
26        })
27    }
28}
29
30struct IntoAttribute {
31    type_path: Option<TypePath>,
32}
33
34impl syn::parse::Parse for IntoAttribute {
35    fn parse(input: ParseStream) -> Result<Self, Error> {
36        let type_path: TypePath = input.parse().map_err(|error| {
37            Error::new(
38                input.span(),
39                format!("{error}\nExpected type name, e.g. #[into(u32)]"),
40            )
41        })?;
42
43        Ok(IntoAttribute {
44            type_path: Some(type_path),
45        })
46    }
47}
48
49pub struct TraitTarget {
50    type_path: TypePath,
51}
52
53impl syn::parse::Parse for TraitTarget {
54    fn parse(input: ParseStream) -> Result<Self, Error> {
55        let type_path: TypePath = input.parse().map_err(|error| {
56            Error::new(
57                input.span(),
58                format!("{error}\nExpected trait path, e.g. #[through(foo::MyTrait)]"),
59            )
60        })?;
61
62        Ok(TraitTarget { type_path })
63    }
64}
65
66#[derive(Clone)]
67pub enum ReturnExpression {
68    Into(Option<TypePath>),
69    TryInto,
70    Unwrap,
71}
72
73enum ParsedAttribute {
74    ReturnExpression(ReturnExpression),
75    Await(bool),
76    TargetMethod(syn::Ident),
77    ThroughTrait(TraitTarget),
78}
79
80fn parse_attributes(
81    attrs: &[Attribute],
82) -> (
83    impl Iterator<Item = ParsedAttribute> + '_,
84    impl Iterator<Item = &Attribute>,
85) {
86    let (parsed, other): (Vec<_>, Vec<_>) = attrs
87        .iter()
88        .map(|attribute| {
89            let parsed = if let syn::AttrStyle::Outer = attribute.style {
90                let name = attribute
91                    .path()
92                    .get_ident()
93                    .map(|i| i.to_string())
94                    .unwrap_or_default();
95                match name.as_str() {
96                    "call" => {
97                        let target = attribute
98                            .parse_args::<CallMethodAttribute>()
99                            .expect("Cannot parse `call` attribute");
100                        Some(ParsedAttribute::TargetMethod(target.name))
101                    }
102                    "into" => {
103                        let into = match &attribute.meta {
104                            Meta::NameValue(_) => {
105                                panic!("Cannot parse `into` attribute: expected parentheses")
106                            }
107                            Meta::Path(_) => IntoAttribute { type_path: None },
108                            Meta::List(meta) => meta
109                                .parse_args::<IntoAttribute>()
110                                .expect("Cannot parse `into` attribute"),
111                        };
112                        Some(ParsedAttribute::ReturnExpression(ReturnExpression::Into(
113                            into.type_path,
114                        )))
115                    }
116                    "try_into" => {
117                        if let Meta::List(meta) = &attribute.meta {
118                            meta.parse_nested_meta(|meta| {
119                                if meta.path.is_ident("unwrap") {
120                                    panic!(
121                                        "Replace #[try_into(unwrap)] with\n#[try_into]\n#[unwrap]",
122                                    );
123                                }
124                                Ok(())
125                            })
126                            .expect("Invalid `try_into` arguments");
127                        }
128                        Some(ParsedAttribute::ReturnExpression(ReturnExpression::TryInto))
129                    }
130                    "unwrap" => Some(ParsedAttribute::ReturnExpression(ReturnExpression::Unwrap)),
131                    "await" => {
132                        let generate = attribute
133                            .parse_args::<GenerateAwaitAttribute>()
134                            .expect("Cannot parse `await` attribute");
135                        Some(ParsedAttribute::Await(generate.literal.value))
136                    }
137                    "through" => Some(ParsedAttribute::ThroughTrait(
138                        attribute
139                            .parse_args::<TraitTarget>()
140                            .expect("Cannot parse `through` attribute"),
141                    )),
142                    _ => None,
143                }
144            } else {
145                None
146            };
147
148            (parsed, attribute)
149        })
150        .partition(|(parsed, _)| parsed.is_some());
151    (
152        parsed.into_iter().map(|(parsed, _)| parsed.unwrap()),
153        other.into_iter().map(|(_, attr)| attr),
154    )
155}
156
157pub struct MethodAttributes<'a> {
158    pub attributes: Vec<&'a Attribute>,
159    pub target_method: Option<syn::Ident>,
160    pub expressions: VecDeque<ReturnExpression>,
161    pub generate_await: Option<bool>,
162    pub target_trait: Option<TypePath>,
163}
164
165/// Iterates through the attributes of a method and filters special attributes.
166/// - call => sets the name of the target method to call
167/// - into => generates a `into()` call after the delegated expression
168/// - try_into => generates a `try_into()` call after the delegated expression
169/// - await => generates an `.await` expression after the delegated expression
170/// - unwrap => generates a `unwrap()` call after the delegated expression
171/// - through => generates a UFCS call (`Target::method(&<expr>, ...)`) around the delegated expression
172pub fn parse_method_attributes<'a>(
173    attrs: &'a [Attribute],
174    method: &syn::TraitItemFn,
175) -> MethodAttributes<'a> {
176    let mut target_method: Option<syn::Ident> = None;
177    let mut expressions: Vec<ReturnExpression> = vec![];
178    let mut generate_await: Option<bool> = None;
179    let mut target_trait: Option<TraitTarget> = None;
180
181    let (parsed, other) = parse_attributes(attrs);
182    for attr in parsed {
183        match attr {
184            ParsedAttribute::ReturnExpression(expr) => expressions.push(expr),
185            ParsedAttribute::Await(value) => {
186                if generate_await.is_some() {
187                    panic!(
188                        "Multiple `await` attributes specified for {}",
189                        method.sig.ident
190                    )
191                }
192                generate_await = Some(value);
193            }
194            ParsedAttribute::TargetMethod(target) => {
195                if target_method.is_some() {
196                    panic!(
197                        "Multiple call attributes specified for {}",
198                        method.sig.ident
199                    )
200                }
201                target_method = Some(target);
202            }
203            ParsedAttribute::ThroughTrait(target) => {
204                if target_trait.is_some() {
205                    panic!(
206                        "Multiple through attributes specified for {}",
207                        method.sig.ident
208                    )
209                }
210                target_trait = Some(target);
211            }
212        }
213    }
214
215    MethodAttributes {
216        attributes: other.into_iter().collect(),
217        target_method,
218        generate_await,
219        expressions: expressions.into(),
220        target_trait: target_trait.map(|t| t.type_path),
221    }
222}
223
224pub struct SegmentAttributes {
225    pub expressions: Vec<ReturnExpression>,
226    pub generate_await: Option<bool>,
227    pub target_trait: Option<TypePath>,
228    pub other_attrs: Vec<Attribute>,
229}
230
231pub fn parse_segment_attributes(attrs: &[Attribute]) -> SegmentAttributes {
232    let mut expressions: Vec<ReturnExpression> = vec![];
233    let mut generate_await: Option<bool> = None;
234    let mut target_trait: Option<TraitTarget> = None;
235
236    let (parsed, other) = parse_attributes(attrs);
237
238    for attribute in parsed {
239        match attribute {
240            ParsedAttribute::ReturnExpression(expr) => expressions.push(expr),
241            ParsedAttribute::Await(value) => {
242                if generate_await.is_some() {
243                    panic!("Multiple `await` attributes specified for segment");
244                }
245                generate_await = Some(value);
246            }
247            ParsedAttribute::ThroughTrait(target) => {
248                if target_trait.is_some() {
249                    panic!("Multiple `through` attributes specified for segment");
250                }
251                target_trait = Some(target);
252            }
253            ParsedAttribute::TargetMethod(_) => {
254                panic!("Call attribute cannot be specified on a `to <expr>` segment.");
255            }
256        }
257    }
258    SegmentAttributes {
259        expressions,
260        generate_await,
261        target_trait: target_trait.map(|t| t.type_path),
262        other_attrs: other.cloned().collect::<Vec<_>>(),
263    }
264}
265
266/// Applies default values from the segment and adds them to the method attributes.
267pub fn combine_attributes<'a>(
268    mut method_attrs: MethodAttributes<'a>,
269    segment_attrs: &'a SegmentAttributes,
270) -> MethodAttributes<'a> {
271    let SegmentAttributes {
272        expressions,
273        generate_await,
274        target_trait,
275        other_attrs,
276    } = segment_attrs;
277
278    if method_attrs.generate_await.is_none() {
279        method_attrs.generate_await = *generate_await;
280    }
281
282    if method_attrs.target_trait.is_none() {
283        method_attrs.target_trait.clone_from(target_trait);
284    }
285
286    for expr in expressions {
287        match expr {
288            ReturnExpression::Into(path) => {
289                if !method_attrs
290                    .expressions
291                    .iter()
292                    .any(|expr| matches!(expr, ReturnExpression::Into(_)))
293                {
294                    method_attrs
295                        .expressions
296                        .push_front(ReturnExpression::Into(path.clone()));
297                }
298            }
299            _ => method_attrs.expressions.push_front(expr.clone()),
300        }
301    }
302
303    for other_attr in other_attrs {
304        if !method_attrs
305            .attributes
306            .iter()
307            .any(|attr| attr.path().get_ident() == other_attr.path().get_ident())
308        {
309            method_attrs.attributes.push(other_attr);
310        }
311    }
312
313    method_attrs
314}