1use std::collections::VecDeque;
2
3use syn::parse::ParseStream;
4use syn::{Attribute, Error, Meta, TypePath};
5
6struct CallMethodAttribute {
7 name: syn::Ident,
8}
9
10impl syn::parse::Parse for CallMethodAttribute {
11 fn parse(input: ParseStream) -> Result<Self, Error> {
12 Ok(CallMethodAttribute {
13 name: input.parse()?,
14 })
15 }
16}
17
18struct GenerateAwaitAttribute {
19 literal: syn::LitBool,
20}
21
22impl syn::parse::Parse for GenerateAwaitAttribute {
23 fn parse(input: ParseStream) -> Result<Self, Error> {
24 Ok(GenerateAwaitAttribute {
25 literal: input.parse()?,
26 })
27 }
28}
29
30struct IntoAttribute {
31 type_path: Option<TypePath>,
32}
33
34impl syn::parse::Parse for IntoAttribute {
35 fn parse(input: ParseStream) -> Result<Self, Error> {
36 let type_path: TypePath = input.parse().map_err(|error| {
37 Error::new(
38 input.span(),
39 format!("{error}\nExpected type name, e.g. #[into(u32)]"),
40 )
41 })?;
42
43 Ok(IntoAttribute {
44 type_path: Some(type_path),
45 })
46 }
47}
48
49pub struct TraitTarget {
50 type_path: TypePath,
51}
52
53impl syn::parse::Parse for TraitTarget {
54 fn parse(input: ParseStream) -> Result<Self, Error> {
55 let type_path: TypePath = input.parse().map_err(|error| {
56 Error::new(
57 input.span(),
58 format!("{error}\nExpected trait path, e.g. #[through(foo::MyTrait)]"),
59 )
60 })?;
61
62 Ok(TraitTarget { type_path })
63 }
64}
65
66#[derive(Clone)]
67pub enum ReturnExpression {
68 Into(Option<TypePath>),
69 TryInto,
70 Unwrap,
71}
72
73enum ParsedAttribute {
74 ReturnExpression(ReturnExpression),
75 Await(bool),
76 TargetMethod(syn::Ident),
77 ThroughTrait(TraitTarget),
78}
79
80fn parse_attributes(
81 attrs: &[Attribute],
82) -> (
83 impl Iterator<Item = ParsedAttribute> + '_,
84 impl Iterator<Item = &Attribute>,
85) {
86 let (parsed, other): (Vec<_>, Vec<_>) = attrs
87 .iter()
88 .map(|attribute| {
89 let parsed = if let syn::AttrStyle::Outer = attribute.style {
90 let name = attribute
91 .path()
92 .get_ident()
93 .map(|i| i.to_string())
94 .unwrap_or_default();
95 match name.as_str() {
96 "call" => {
97 let target = attribute
98 .parse_args::<CallMethodAttribute>()
99 .expect("Cannot parse `call` attribute");
100 Some(ParsedAttribute::TargetMethod(target.name))
101 }
102 "into" => {
103 let into = match &attribute.meta {
104 Meta::NameValue(_) => {
105 panic!("Cannot parse `into` attribute: expected parentheses")
106 }
107 Meta::Path(_) => IntoAttribute { type_path: None },
108 Meta::List(meta) => meta
109 .parse_args::<IntoAttribute>()
110 .expect("Cannot parse `into` attribute"),
111 };
112 Some(ParsedAttribute::ReturnExpression(ReturnExpression::Into(
113 into.type_path,
114 )))
115 }
116 "try_into" => {
117 if let Meta::List(meta) = &attribute.meta {
118 meta.parse_nested_meta(|meta| {
119 if meta.path.is_ident("unwrap") {
120 panic!(
121 "Replace #[try_into(unwrap)] with\n#[try_into]\n#[unwrap]",
122 );
123 }
124 Ok(())
125 })
126 .expect("Invalid `try_into` arguments");
127 }
128 Some(ParsedAttribute::ReturnExpression(ReturnExpression::TryInto))
129 }
130 "unwrap" => Some(ParsedAttribute::ReturnExpression(ReturnExpression::Unwrap)),
131 "await" => {
132 let generate = attribute
133 .parse_args::<GenerateAwaitAttribute>()
134 .expect("Cannot parse `await` attribute");
135 Some(ParsedAttribute::Await(generate.literal.value))
136 }
137 "through" => Some(ParsedAttribute::ThroughTrait(
138 attribute
139 .parse_args::<TraitTarget>()
140 .expect("Cannot parse `through` attribute"),
141 )),
142 _ => None,
143 }
144 } else {
145 None
146 };
147
148 (parsed, attribute)
149 })
150 .partition(|(parsed, _)| parsed.is_some());
151 (
152 parsed.into_iter().map(|(parsed, _)| parsed.unwrap()),
153 other.into_iter().map(|(_, attr)| attr),
154 )
155}
156
157pub struct MethodAttributes<'a> {
158 pub attributes: Vec<&'a Attribute>,
159 pub target_method: Option<syn::Ident>,
160 pub expressions: VecDeque<ReturnExpression>,
161 pub generate_await: Option<bool>,
162 pub target_trait: Option<TypePath>,
163}
164
165pub fn parse_method_attributes<'a>(
173 attrs: &'a [Attribute],
174 method: &syn::TraitItemFn,
175) -> MethodAttributes<'a> {
176 let mut target_method: Option<syn::Ident> = None;
177 let mut expressions: Vec<ReturnExpression> = vec![];
178 let mut generate_await: Option<bool> = None;
179 let mut target_trait: Option<TraitTarget> = None;
180
181 let (parsed, other) = parse_attributes(attrs);
182 for attr in parsed {
183 match attr {
184 ParsedAttribute::ReturnExpression(expr) => expressions.push(expr),
185 ParsedAttribute::Await(value) => {
186 if generate_await.is_some() {
187 panic!(
188 "Multiple `await` attributes specified for {}",
189 method.sig.ident
190 )
191 }
192 generate_await = Some(value);
193 }
194 ParsedAttribute::TargetMethod(target) => {
195 if target_method.is_some() {
196 panic!(
197 "Multiple call attributes specified for {}",
198 method.sig.ident
199 )
200 }
201 target_method = Some(target);
202 }
203 ParsedAttribute::ThroughTrait(target) => {
204 if target_trait.is_some() {
205 panic!(
206 "Multiple through attributes specified for {}",
207 method.sig.ident
208 )
209 }
210 target_trait = Some(target);
211 }
212 }
213 }
214
215 MethodAttributes {
216 attributes: other.into_iter().collect(),
217 target_method,
218 generate_await,
219 expressions: expressions.into(),
220 target_trait: target_trait.map(|t| t.type_path),
221 }
222}
223
224pub struct SegmentAttributes {
225 pub expressions: Vec<ReturnExpression>,
226 pub generate_await: Option<bool>,
227 pub target_trait: Option<TypePath>,
228 pub other_attrs: Vec<Attribute>,
229}
230
231pub fn parse_segment_attributes(attrs: &[Attribute]) -> SegmentAttributes {
232 let mut expressions: Vec<ReturnExpression> = vec![];
233 let mut generate_await: Option<bool> = None;
234 let mut target_trait: Option<TraitTarget> = None;
235
236 let (parsed, other) = parse_attributes(attrs);
237
238 for attribute in parsed {
239 match attribute {
240 ParsedAttribute::ReturnExpression(expr) => expressions.push(expr),
241 ParsedAttribute::Await(value) => {
242 if generate_await.is_some() {
243 panic!("Multiple `await` attributes specified for segment");
244 }
245 generate_await = Some(value);
246 }
247 ParsedAttribute::ThroughTrait(target) => {
248 if target_trait.is_some() {
249 panic!("Multiple `through` attributes specified for segment");
250 }
251 target_trait = Some(target);
252 }
253 ParsedAttribute::TargetMethod(_) => {
254 panic!("Call attribute cannot be specified on a `to <expr>` segment.");
255 }
256 }
257 }
258 SegmentAttributes {
259 expressions,
260 generate_await,
261 target_trait: target_trait.map(|t| t.type_path),
262 other_attrs: other.cloned().collect::<Vec<_>>(),
263 }
264}
265
266pub fn combine_attributes<'a>(
268 mut method_attrs: MethodAttributes<'a>,
269 segment_attrs: &'a SegmentAttributes,
270) -> MethodAttributes<'a> {
271 let SegmentAttributes {
272 expressions,
273 generate_await,
274 target_trait,
275 other_attrs,
276 } = segment_attrs;
277
278 if method_attrs.generate_await.is_none() {
279 method_attrs.generate_await = *generate_await;
280 }
281
282 if method_attrs.target_trait.is_none() {
283 method_attrs.target_trait.clone_from(target_trait);
284 }
285
286 for expr in expressions {
287 match expr {
288 ReturnExpression::Into(path) => {
289 if !method_attrs
290 .expressions
291 .iter()
292 .any(|expr| matches!(expr, ReturnExpression::Into(_)))
293 {
294 method_attrs
295 .expressions
296 .push_front(ReturnExpression::Into(path.clone()));
297 }
298 }
299 _ => method_attrs.expressions.push_front(expr.clone()),
300 }
301 }
302
303 for other_attr in other_attrs {
304 if !method_attrs
305 .attributes
306 .iter()
307 .any(|attr| attr.path().get_ident() == other_attr.path().get_ident())
308 {
309 method_attrs.attributes.push(other_attr);
310 }
311 }
312
313 method_attrs
314}