async_proto_derive/
lib.rs

1#![deny(missing_docs, rust_2018_idioms, unused, unused_crate_dependencies, unused_import_braces, unused_lifetimes, unused_qualifications, warnings)]
2#![forbid(unsafe_code)]
3
4//! Procedural macros for the [`async-proto`](https://docs.rs/async-proto) crate.
5
6use {
7    std::convert::TryFrom as _,
8    itertools::Itertools as _,
9    proc_macro::TokenStream,
10    proc_macro2::Span,
11    quote::{
12        quote,
13        quote_spanned,
14    },
15    syn::{
16        *,
17        parse::{
18            Parse,
19            ParseStream,
20        },
21        punctuated::Punctuated,
22        spanned::Spanned as _,
23        token::{
24            Brace,
25            Paren,
26        },
27    },
28};
29
30fn read_fields(internal: bool, sync: bool, fields: &Fields) -> proc_macro2::TokenStream {
31    let async_proto_crate = if internal { quote!(crate) } else { quote!(::async_proto) };
32    let read = if sync { quote!(::read_sync(stream)) } else { quote!(::read(stream).await) };
33    match fields {
34        Fields::Unit => quote!(),
35        Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
36            let read_fields = unnamed.iter()
37                .enumerate()
38                .map(|(idx, Field { attrs, ty, .. })| {
39                    let mut max_len = None;
40                    for attr in attrs.into_iter().filter(|attr| attr.path().is_ident("async_proto")) {
41                        match attr.parse_args_with(Punctuated::<FieldAttr, Token![,]>::parse_terminated) {
42                            Ok(attrs) => for attr in attrs {
43                                match attr {
44                                    FieldAttr::MaxLen(new_max_len) => if max_len.replace(new_max_len).is_some() {
45                                        return quote!(compile_error!("#[async_proto(max_len = ...)] specified multiple times");).into()
46                                    },
47                                }
48                            },
49                            Err(e) => return e.to_compile_error().into(),
50                        }
51                    }
52                    let read = if let Some(max_len) = max_len {
53                        let read = if sync { quote!(::read_length_prefixed_sync(stream, #max_len)) } else { quote!(::read_length_prefixed(stream, #max_len).await) };
54                        quote_spanned! {ty.span()=>
55                            <#ty as #async_proto_crate::LengthPrefixed>#read
56                        }
57                    } else {
58                        quote_spanned! {ty.span()=>
59                            <#ty as #async_proto_crate::Protocol>#read
60                        }
61                    };
62                    quote_spanned! {ty.span()=>
63                        #read.map_err(|#async_proto_crate::ReadError { context, kind }| #async_proto_crate::ReadError {
64                            context: #async_proto_crate::ErrorContext::UnnamedField {
65                                idx: #idx,
66                                source: Box::new(context),
67                            },
68                            kind,
69                        })?
70                    }
71                })
72                .collect_vec();
73            quote!((#(#read_fields,)*))
74        }
75        Fields::Named(FieldsNamed { named, .. }) => {
76            let read_fields = named.iter()
77                .map(|Field { attrs, ident, ty, .. }| {
78                    let mut max_len = None;
79                    for attr in attrs.into_iter().filter(|attr| attr.path().is_ident("async_proto")) {
80                        match attr.parse_args_with(Punctuated::<FieldAttr, Token![,]>::parse_terminated) {
81                            Ok(attrs) => for attr in attrs {
82                                match attr {
83                                    FieldAttr::MaxLen(new_max_len) => if max_len.replace(new_max_len).is_some() {
84                                        return quote!(compile_error!("#[async_proto(max_len = ...)] specified multiple times");).into()
85                                    },
86                                }
87                            },
88                            Err(e) => return e.to_compile_error().into(),
89                        }
90                    }
91                    let name = ident.as_ref().expect("FieldsNamed with unnamed field").to_string();
92                    let read = if let Some(max_len) = max_len {
93                        let read = if sync { quote!(::read_length_prefixed_sync(stream, #max_len)) } else { quote!(::read_length_prefixed(stream, #max_len).await) };
94                        quote_spanned! {ty.span()=>
95                            <#ty as #async_proto_crate::LengthPrefixed>#read
96                        }
97                    } else {
98                        quote_spanned! {ty.span()=>
99                            <#ty as #async_proto_crate::Protocol>#read
100                        }
101                    };
102                    quote_spanned! {ty.span()=>
103                        #ident: #read.map_err(|#async_proto_crate::ReadError { context, kind }| #async_proto_crate::ReadError {
104                            context: #async_proto_crate::ErrorContext::NamedField {
105                                name: #name,
106                                source: Box::new(context),
107                            },
108                            kind,
109                        })?
110                    }
111                })
112                .collect_vec();
113            quote!({ #(#read_fields,)* })
114        }
115    }
116}
117
118fn fields_pat(fields: &Fields) -> proc_macro2::TokenStream {
119    match fields {
120        Fields::Unit => quote!(),
121        Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
122            let field_idents = unnamed.iter()
123                .enumerate()
124                .map(|(idx, _)| Ident::new(&format!("__field{}", idx), Span::call_site()))
125                .collect_vec();
126            quote!((#(#field_idents,)*))
127        }
128        Fields::Named(FieldsNamed { named, .. }) => {
129            let field_idents = named.iter()
130                .map(|Field { ident, .. }| ident)
131                .collect_vec();
132            quote!({ #(#field_idents,)* })
133        }
134    }
135}
136
137fn write_fields(internal: bool, sync: bool, fields: &Fields) -> proc_macro2::TokenStream {
138    let async_proto_crate = if internal { quote!(crate) } else { quote!(::async_proto) };
139    match fields {
140        Fields::Unit => quote!(),
141        Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
142            let write_fields = unnamed.iter()
143                .enumerate()
144                .map(|(idx, Field { attrs, ty, .. })| {
145                    let mut max_len = None;
146                    for attr in attrs.into_iter().filter(|attr| attr.path().is_ident("async_proto")) {
147                        match attr.parse_args_with(Punctuated::<FieldAttr, Token![,]>::parse_terminated) {
148                            Ok(attrs) => for attr in attrs {
149                                match attr {
150                                    FieldAttr::MaxLen(new_max_len) => if max_len.replace(new_max_len).is_some() {
151                                        return quote!(compile_error!("#[async_proto(max_len = ...)] specified multiple times");).into()
152                                    },
153                                }
154                            },
155                            Err(e) => return e.to_compile_error().into(),
156                        }
157                    }
158                    let ident = Ident::new(&format!("__field{}", idx), Span::call_site());
159                    let write = if let Some(max_len) = max_len {
160                        let write = if sync { quote!(::write_length_prefixed_sync(#ident, sink, #max_len)) } else { quote!(::write_length_prefixed(#ident, sink, #max_len).await) };
161                        quote_spanned! {ty.span()=>
162                            <#ty as #async_proto_crate::LengthPrefixed>#write
163                        }
164                    } else {
165                        let write = if sync { quote!(::write_sync(#ident, sink)) } else { quote!(::write(#ident, sink).await) };
166                        quote_spanned! {ty.span()=>
167                            <#ty as #async_proto_crate::Protocol>#write
168                        }
169                    };
170                    quote!(#write.map_err(|#async_proto_crate::WriteError { context, kind }| #async_proto_crate::WriteError {
171                        context: #async_proto_crate::ErrorContext::UnnamedField {
172                            idx: #idx,
173                            source: Box::new(context),
174                        },
175                        kind,
176                    })?;)
177                });
178            quote!(#(#write_fields)*)
179        }
180        Fields::Named(FieldsNamed { named, .. }) => {
181            let write_fields = named.iter()
182                .map(|Field { attrs, ident, ty, .. }| {
183                    let mut max_len = None;
184                    for attr in attrs.into_iter().filter(|attr| attr.path().is_ident("async_proto")) {
185                        match attr.parse_args_with(Punctuated::<FieldAttr, Token![,]>::parse_terminated) {
186                            Ok(attrs) => for attr in attrs {
187                                match attr {
188                                    FieldAttr::MaxLen(new_max_len) => if max_len.replace(new_max_len).is_some() {
189                                        return quote!(compile_error!("#[async_proto(max_len = ...)] specified multiple times");).into()
190                                    },
191                                }
192                            },
193                            Err(e) => return e.to_compile_error().into(),
194                        }
195                    }
196                    let write = if let Some(max_len) = max_len {
197                        let write = if sync { quote!(::write_length_prefixed_sync(#ident, sink, #max_len)) } else { quote!(::write_length_prefixed(#ident, sink, #max_len).await) };
198                        quote_spanned! {ty.span()=>
199                            <#ty as #async_proto_crate::LengthPrefixed>#write
200                        }
201                    } else {
202                        let write = if sync { quote!(::write_sync(#ident, sink)) } else { quote!(::write(#ident, sink).await) };
203                        quote_spanned! {ty.span()=>
204                            <#ty as #async_proto_crate::Protocol>#write
205                        }
206                    };
207                    let name = ident.as_ref().expect("FieldsNamed with unnamed field").to_string();
208                    quote!(#write.map_err(|#async_proto_crate::WriteError { context, kind }| #async_proto_crate::WriteError {
209                        context: #async_proto_crate::ErrorContext::NamedField {
210                            name: #name,
211                            source: Box::new(context),
212                        },
213                        kind,
214                    })?;)
215                });
216            quote!(#(#write_fields)*)
217        }
218    }
219}
220
221enum AsyncProtoAttr {
222    AsString,
223    Attr(Punctuated<Meta, Token![,]>),
224    Clone,
225    Internal,
226    MapErr(Expr),
227    Via(Type),
228    Where(Punctuated<WherePredicate, Token![,]>),
229}
230
231impl Parse for AsyncProtoAttr {
232    fn parse(input: ParseStream<'_>) -> Result<Self> {
233        Ok(if input.peek(Token![where]) {
234            let _ = input.parse::<Token![where]>()?;
235            let content;
236            parenthesized!(content in input);
237            Self::Where(Punctuated::parse_terminated(&content)?)
238        } else {
239            let ident = input.parse::<Ident>()?;
240            match &*ident.to_string() {
241                "as_string" => Self::AsString,
242                "attr" => {
243                    let content;
244                    parenthesized!(content in input);
245                    Self::Attr(Punctuated::parse_terminated(&content)?)
246                }
247                "clone" => Self::Clone,
248                "internal" => Self::Internal,
249                "map_err" => {
250                    let _ = input.parse::<Token![=]>()?;
251                    Self::MapErr(input.parse()?)
252                }
253                "via" => {
254                    let _ = input.parse::<Token![=]>()?;
255                    Self::Via(input.parse()?)
256                }
257                _ => return Err(Error::new(ident.span(), "unknown async_proto type attribute")),
258            }
259        })
260    }
261}
262
263enum FieldAttr {
264    MaxLen(u64),
265}
266
267impl Parse for FieldAttr {
268    fn parse(input: ParseStream<'_>) -> Result<Self> {
269        let ident = input.parse::<Ident>()?;
270        Ok(match &*ident.to_string() {
271            "max_len" => {
272                let _ = input.parse::<Token![=]>()?;
273                Self::MaxLen(input.parse::<LitInt>()?.base10_parse()?)
274            }
275            _ => return Err(Error::new(ident.span(), "unknown async_proto field attribute")),
276        })
277    }
278}
279
280fn impl_protocol_inner(mut internal: bool, attrs: Vec<Attribute>, qual_ty: Path, generics: Generics, data: Option<Data>) -> proc_macro2::TokenStream {
281    let for_type = quote!(#qual_ty).to_string();
282    let mut as_string = false;
283    let mut via = None;
284    let mut clone = false;
285    let mut map_err = None;
286    let mut where_predicates = None;
287    let mut impl_attrs = Vec::default();
288    for attr in attrs.into_iter().filter(|attr| attr.path().is_ident("async_proto")) {
289        match attr.parse_args_with(Punctuated::<AsyncProtoAttr, Token![,]>::parse_terminated) {
290            Ok(attrs) => for attr in attrs {
291                match attr {
292                    AsyncProtoAttr::AsString => {
293                        if via.is_some() { return quote!(compile_error!("#[async_proto(as_str)] and #[async_proto(via = ...)] are incompatible");).into() }
294                        as_string = true;
295                    }
296                    AsyncProtoAttr::Attr(attr) => impl_attrs.extend(attr),
297                    AsyncProtoAttr::Clone => clone = true,
298                    AsyncProtoAttr::Internal => internal = true,
299                    AsyncProtoAttr::MapErr(expr) => if map_err.replace(expr).is_some() {
300                        return quote!(compile_error!("#[async_proto(map_err = ...)] specified multiple times");).into()
301                    },
302                    AsyncProtoAttr::Via(ty) => if via.replace(ty).is_some() {
303                        return quote!(compile_error!("#[async_proto(via = ...)] specified multiple times");).into()
304                    },
305                    AsyncProtoAttr::Where(predicates) => if where_predicates.replace(predicates).is_some() {
306                        return quote!(compile_error!("#[async_proto(where(...))] specified multiple times");).into()
307                    },
308                }
309            },
310            Err(e) => return e.to_compile_error().into(),
311        }
312    }
313    let async_proto_crate = if internal { quote!(crate) } else { quote!(::async_proto) };
314    let mut impl_generics = generics.clone();
315    if let Some(predicates) = where_predicates {
316        impl_generics.make_where_clause().predicates.extend(predicates);
317    } else {
318        for param in impl_generics.type_params_mut() {
319            param.colon_token.get_or_insert_with(<Token![:]>::default);
320            param.bounds.push(parse_quote!(#async_proto_crate::Protocol));
321            param.bounds.push(parse_quote!(::core::marker::Send));
322            param.bounds.push(parse_quote!(::core::marker::Sync));
323            param.bounds.push(parse_quote!('static));
324        }
325    };
326    let (impl_read, impl_write, impl_read_sync, impl_write_sync) = if as_string {
327        if internal && data.is_some() { return quote!(compile_error!("redundant type layout specification with #[async_proto(as_string)]");).into() }
328        let map_err = map_err.unwrap_or(parse_quote!(::core::convert::Into::<#async_proto_crate::ReadErrorKind>::into));
329        (
330            quote!(<Self as ::std::str::FromStr>::from_str(&<::std::string::String as #async_proto_crate::Protocol>::read(stream).await.map_err(|#async_proto_crate::ReadError { context, kind }| #async_proto_crate::ReadError {
331                context: #async_proto_crate::ErrorContext::AsString {
332                    source: Box::new(context),
333                },
334                kind,
335            })?).map_err(|e| #async_proto_crate::ReadError {
336                context: #async_proto_crate::ErrorContext::FromStr,
337                kind: (#map_err)(e),
338            })),
339            quote!(<::std::string::String as #async_proto_crate::Protocol>::write(&<Self as ::std::string::ToString>::to_string(self), sink).await.map_err(|#async_proto_crate::WriteError { context, kind }| #async_proto_crate::WriteError {
340                context: #async_proto_crate::ErrorContext::AsString {
341                    source: Box::new(context),
342                },
343                kind,
344            })),
345            quote!(<Self as ::std::str::FromStr>::from_str(&<::std::string::String as #async_proto_crate::Protocol>::read_sync(stream).map_err(|#async_proto_crate::ReadError { context, kind }| #async_proto_crate::ReadError {
346                context: #async_proto_crate::ErrorContext::AsString {
347                    source: Box::new(context),
348                },
349                kind,
350            })?).map_err(|e| #async_proto_crate::ReadError {
351                context: #async_proto_crate::ErrorContext::FromStr,
352                kind: (#map_err)(e),
353            })),
354            quote!(<::std::string::String as #async_proto_crate::Protocol>::write_sync(&<Self as ::std::string::ToString>::to_string(self), sink).map_err(|#async_proto_crate::WriteError { context, kind }| #async_proto_crate::WriteError {
355                context: #async_proto_crate::ErrorContext::AsString {
356                    source: Box::new(context),
357                },
358                kind,
359            })),
360        )
361    } else if let Some(proxy_ty) = via {
362        if internal && data.is_some() { return quote!(compile_error!("redundant type layout specification with #[async_proto(via = ...)]");).into() }
363        let (write_proxy, write_sync_proxy) = if clone {
364            (
365                quote!(<Self as ::core::convert::TryInto<#proxy_ty>>::try_into(<Self as ::core::clone::Clone>::clone(self)).map_err(|e| #async_proto_crate::WriteError {
366                    context: #async_proto_crate::ErrorContext::TryInto,
367                    kind: ::core::convert::Into::<#async_proto_crate::WriteErrorKind>::into(e),
368                })?),
369                quote!(<Self as ::core::convert::TryInto<#proxy_ty>>::try_into(<Self as ::core::clone::Clone>::clone(self)).map_err(|e| #async_proto_crate::WriteError {
370                    context: #async_proto_crate::ErrorContext::TryInto,
371                    kind: ::core::convert::Into::<#async_proto_crate::WriteErrorKind>::into(e),
372                })?),
373            )
374        } else {
375            (
376                quote!(<&'a Self as ::core::convert::TryInto<#proxy_ty>>::try_into(self).map_err(|e| #async_proto_crate::WriteError {
377                    context: #async_proto_crate::ErrorContext::TryInto,
378                    kind: ::core::convert::Into::<#async_proto_crate::WriteErrorKind>::into(e),
379                })?),
380                quote!(<&Self as ::core::convert::TryInto<#proxy_ty>>::try_into(self).map_err(|e| #async_proto_crate::WriteError {
381                    context: #async_proto_crate::ErrorContext::TryInto,
382                    kind: ::core::convert::Into::<#async_proto_crate::WriteErrorKind>::into(e),
383                })?),
384            )
385        };
386        let map_err = map_err.unwrap_or(parse_quote!(::core::convert::Into::<#async_proto_crate::ReadErrorKind>::into));
387        (
388            quote!(<#proxy_ty as ::core::convert::TryInto<Self>>::try_into(<#proxy_ty as #async_proto_crate::Protocol>::read(stream).await.map_err(|#async_proto_crate::ReadError { context, kind }| #async_proto_crate::ReadError {
389                context: #async_proto_crate::ErrorContext::Via {
390                    source: Box::new(context),
391                },
392                kind,
393            })?).map_err(|e| #async_proto_crate::ReadError {
394                context: #async_proto_crate::ErrorContext::TryInto,
395                kind: (#map_err)(e),
396            })),
397            quote!(<#proxy_ty as #async_proto_crate::Protocol>::write(&#write_proxy, sink).await.map_err(|#async_proto_crate::WriteError { context, kind }| #async_proto_crate::WriteError {
398                context: #async_proto_crate::ErrorContext::Via {
399                    source: Box::new(context),
400                },
401                kind,
402            })),
403            quote!(<Self as ::core::convert::TryFrom<#proxy_ty>>::try_from(<#proxy_ty as #async_proto_crate::Protocol>::read_sync(stream).map_err(|#async_proto_crate::ReadError { context, kind }| #async_proto_crate::ReadError {
404                context: #async_proto_crate::ErrorContext::Via {
405                    source: Box::new(context),
406                },
407                kind,
408            })?).map_err(|e| #async_proto_crate::ReadError {
409                context: #async_proto_crate::ErrorContext::TryInto,
410                kind: (#map_err)(e),
411            })),
412            quote!(<#proxy_ty as #async_proto_crate::Protocol>::write_sync(&#write_sync_proxy, sink).map_err(|#async_proto_crate::WriteError { context, kind }| #async_proto_crate::WriteError {
413                context: #async_proto_crate::ErrorContext::Via {
414                    source: Box::new(context),
415                },
416                kind,
417            })),
418        )
419    } else {
420        if map_err.is_some() { return quote!(compile_error!("#[async_proto(map_err = ...)] does nothing without #[async_proto(as_string)] or #[async_proto(via = ...)]");).into() }
421        match data {
422            Some(Data::Struct(DataStruct { fields, .. })) => {
423                let fields_pat = fields_pat(&fields);
424                let read_fields_async = read_fields(internal, false, &fields);
425                let write_fields_async = write_fields(internal, false, &fields);
426                let read_fields_sync = read_fields(internal, true, &fields);
427                let write_fields_sync = write_fields(internal, true, &fields);
428                (
429                    quote!(::core::result::Result::Ok(Self #read_fields_async)),
430                    quote! {
431                        let Self #fields_pat = self;
432                        #write_fields_async
433                        ::core::result::Result::Ok(())
434                    },
435                    quote!(::core::result::Result::Ok(Self #read_fields_sync)),
436                    quote! {
437                        let Self #fields_pat = self;
438                        #write_fields_sync
439                        ::core::result::Result::Ok(())
440                    },
441                )
442            }
443            Some(Data::Enum(DataEnum { variants, .. })) => {
444                if variants.is_empty() {
445                    (
446                        quote!(::core::result::Result::Err(#async_proto_crate::ReadError {
447                            context: #async_proto_crate::ErrorContext::Derived { for_type: #for_type },
448                            kind: #async_proto_crate::ReadErrorKind::ReadNever,
449                        })),
450                        quote!(match *self {}),
451                        quote!(::core::result::Result::Err(#async_proto_crate::ReadError {
452                            context: #async_proto_crate::ErrorContext::Derived { for_type: #for_type },
453                            kind: #async_proto_crate::ReadErrorKind::ReadNever,
454                        })),
455                        quote!(match *self {}),
456                    )
457                } else {
458                    let (discrim_ty, unknown_variant_variant, get_discrim) = match variants.len() {
459                        0 => unreachable!(), // empty enum handled above
460                        1..=256 => (quote!(u8), quote!(UnknownVariant8), (&|idx| {
461                            let idx = u8::try_from(idx).expect("variant index unexpectedly high");
462                            quote!(#idx)
463                        }) as &dyn Fn(usize) -> proc_macro2::TokenStream),
464                        257..=65_536 => (quote!(u16), quote!(UnknownVariant16), (&|idx| {
465                            let idx = u16::try_from(idx).expect("variant index unexpectedly high");
466                            quote!(#idx)
467                        }) as &dyn Fn(usize) -> proc_macro2::TokenStream),
468                        #[cfg(target_pointer_width = "32")]
469                        _ => (quote!(u32), quote!(UnknownVariant32), (&|idx| {
470                            let idx = u32::try_from(idx).expect("variant index unexpectedly high");
471                            quote!(#idx)
472                        }) as &dyn Fn(usize) -> proc_macro2::TokenStream),
473                        #[cfg(target_pointer_width = "64")]
474                        65_537..=4_294_967_296 => (quote!(u32), quote!(UnknownVariant32), (&|idx| {
475                            let idx = u32::try_from(idx).expect("variant index unexpectedly high");
476                            quote!(#idx)
477                        }) as &dyn Fn(usize) -> proc_macro2::TokenStream),
478                        #[cfg(target_pointer_width = "64")]
479                        _ => (quote!(u64), quote!(UnknownVariant64), (&|idx| {
480                            let idx = u64::try_from(idx).expect("variant index unexpectedly high");
481                            quote!(#idx)
482                        }) as &dyn Fn(usize) -> proc_macro2::TokenStream),
483                    };
484                    let read_arms = variants.iter()
485                        .enumerate()
486                        .map(|(idx, Variant { ident: var, fields, .. })| {
487                            let idx = get_discrim(idx);
488                            let read_fields = read_fields(internal, false, fields);
489                            quote!(#idx => ::core::result::Result::Ok(Self::#var #read_fields))
490                        })
491                        .collect_vec();
492                    let write_arms = variants.iter()
493                        .enumerate()
494                        .map(|(idx, Variant { ident: var, fields, .. })| {
495                            let idx = get_discrim(idx);
496                            let fields_pat = fields_pat(&fields);
497                            let write_fields = write_fields(internal, false, fields);
498                            quote! {
499                                Self::#var #fields_pat => {
500                                    #idx.write(sink).await.map_err(|#async_proto_crate::WriteError { context, kind }| #async_proto_crate::WriteError {
501                                        context: #async_proto_crate::ErrorContext::EnumDiscrim {
502                                            source: Box::new(context),
503                                        },
504                                        kind,
505                                    })?;
506                                    #write_fields
507                                }
508                            }
509                        })
510                        .collect_vec();
511                    let read_sync_arms = variants.iter()
512                        .enumerate()
513                        .map(|(idx, Variant { ident: var, fields, .. })| {
514                            let idx = get_discrim(idx);
515                            let read_fields = read_fields(internal, true, fields);
516                            quote!(#idx => ::core::result::Result::Ok(Self::#var #read_fields))
517                        })
518                        .collect_vec();
519                    let write_sync_arms = variants.iter()
520                        .enumerate()
521                        .map(|(idx, Variant { ident: var, fields, .. })| {
522                            let idx = get_discrim(idx);
523                            let fields_pat = fields_pat(&fields);
524                            let write_fields = write_fields(internal, true, fields);
525                            quote! {
526                                Self::#var #fields_pat => {
527                                    #idx.write_sync(sink).map_err(|#async_proto_crate::WriteError { context, kind }| #async_proto_crate::WriteError {
528                                        context: #async_proto_crate::ErrorContext::EnumDiscrim {
529                                            source: Box::new(context),
530                                        },
531                                        kind,
532                                    })?;
533                                    #write_fields
534                                }
535                            }
536                        })
537                        .collect_vec();
538                    (
539                        quote! {
540                            match <#discrim_ty as #async_proto_crate::Protocol>::read(stream).await.map_err(|#async_proto_crate::ReadError { context, kind }| #async_proto_crate::ReadError {
541                                context: #async_proto_crate::ErrorContext::EnumDiscrim {
542                                    source: Box::new(context),
543                                },
544                                kind,
545                            })? {
546                                #(#read_arms,)*
547                                n => ::core::result::Result::Err(#async_proto_crate::ReadError {
548                                    context: #async_proto_crate::ErrorContext::Derived { for_type: #for_type },
549                                    kind: #async_proto_crate::ReadErrorKind::#unknown_variant_variant(n),
550                                }),
551                            }
552                        },
553                        quote! {
554                            match self {
555                                #(#write_arms,)*
556                            }
557                            ::core::result::Result::Ok(())
558                        },
559                        quote! {
560                            match <#discrim_ty as #async_proto_crate::Protocol>::read_sync(stream).map_err(|#async_proto_crate::ReadError { context, kind }| #async_proto_crate::ReadError {
561                                context: #async_proto_crate::ErrorContext::EnumDiscrim {
562                                    source: Box::new(context),
563                                },
564                                kind,
565                            })? {
566                                #(#read_sync_arms,)*
567                                n => ::core::result::Result::Err(#async_proto_crate::ReadError {
568                                    context: #async_proto_crate::ErrorContext::Derived { for_type: #for_type },
569                                    kind: #async_proto_crate::ReadErrorKind::#unknown_variant_variant(n),
570                                }),
571                            }
572                        },
573                        quote! {
574                            match self {
575                                #(#write_sync_arms,)*
576                            }
577                            ::core::result::Result::Ok(())
578                        },
579                    )
580                }
581            }
582            Some(Data::Union(_)) => return quote!(compile_error!("unions not supported in derive(Protocol)");).into(),
583            None => return quote!(compile_error!("missing type layout specification or #[async_proto(via = ...)]");).into(),
584        }
585    };
586    let (impl_generics, ty_generics, where_clause) = impl_generics.split_for_impl();
587    quote! {
588        #(#[#impl_attrs])*
589        impl #impl_generics #async_proto_crate::Protocol for #qual_ty #ty_generics #where_clause {
590            fn read<'a, R: #async_proto_crate::tokio::io::AsyncRead + ::core::marker::Unpin + ::core::marker::Send + 'a>(stream: &'a mut R) -> ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output = ::core::result::Result<Self, #async_proto_crate::ReadError>> + ::core::marker::Send + 'a>> {
591                ::std::boxed::Box::pin(async move { #impl_read })
592            }
593
594            fn write<'a, W: #async_proto_crate::tokio::io::AsyncWrite + ::core::marker::Unpin + ::core::marker::Send + 'a>(&'a self, sink: &'a mut W) -> ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output = ::core::result::Result<(), #async_proto_crate::WriteError>> + ::core::marker::Send + 'a>> {
595                ::std::boxed::Box::pin(async move { #impl_write })
596            }
597
598            fn read_sync(mut stream: &mut impl ::std::io::Read) -> ::core::result::Result<Self, #async_proto_crate::ReadError> { #impl_read_sync }
599            fn write_sync(&self, mut sink: &mut impl ::std::io::Write) -> ::core::result::Result<(), #async_proto_crate::WriteError> { #impl_write_sync }
600        }
601    }
602}
603
604/// Implements the `Protocol` trait for this type.
605///
606/// By default, the network representation is very simple:
607///
608/// * Attempting to read an `enum` with no variants errors immediately, without waiting for data to appear on the stream.
609/// * For non-empty `enum`s, the representation starts with the discriminant (a number representing the variant), starting with `0` for the first variant declared and so on.
610///     * For `enum`s with up to 256 variants, the discriminant is represented as a [`u8`]. For `enums` with 257 to 65536 variants, as a [`u16`], and so on.
611/// * Then follow the `Protocol` representations of any fields of the `struct` or variant, in the order declared.
612///
613/// This representation can waste bandwidth for some types, e.g. `struct`s with multiple [`bool`] fields. For those, you may want to implement `Protocol` manually.
614///
615/// # Attributes
616///
617/// This macro's behavior can be modified using attributes. Multiple attributes can be specified as `#[async_proto(attr1, attr2, ...)]` or `#[async_proto(attr1)] #[async_proto(attr2)] ...`. The following attributes are available:
618///
619/// * `#[async_proto(as_string)]`: Implements `Protocol` for this type by converting from and to a string using the `FromStr` and `ToString` traits. The `FromStr` error type must implement `Into<ReadErrorKind>`.
620///     * `#[async_proto(map_err = ...)]`: Removes the requirement for the `FromStr` error type to implement `Into<ReadErrorKind>` and instead uses the given expression (which should be an `FnOnce(<T as FromStr>::Err) -> ReadErrorKind`) to convert the error.
621/// * `#[async_proto(attr(...))]`: Adds the given attribute(s) to the `Protocol` implementation. For example, the implementation can be documented using `#[async_proto(attr(doc = "..."))]`. May be specified multiple times.
622/// * `#[async_proto(via = Proxy)]`: Implements `Protocol` for this type (let's call it `T`) in terms of another type (`Proxy` in this case) instead of using the variant- and field-based representation described above. `&'a T` must implement `TryInto<Proxy>` for all `'a`, with an `Error` type that implements `Into<WriteErrorKind>`, and `Proxy` must implement `Protocol` and `TryInto<T>`, with an `Error` type that implements `Into<ReadErrorKind>`.
623///     * `#[async_proto(clone)]`: Replaces the requirement for `&'a T` to implement `TryInto<Proxy>` with requirements for `T` to implement `Clone` and `TryInto<Proxy>`.
624///     * `#[async_proto(map_err = ...)]`: Removes the requirement for `<Proxy as TryInto<T>>::Error` to implement `Into<ReadErrorKind>` and instead uses the given expression (which should be an `FnOnce(<Proxy as TryInto<T>>::Error) -> ReadErrorKind`) to convert the error.
625/// * `#[async_proto(where(...))]`: Overrides the bounds for the generated `Protocol` implementation. The default is to require `Protocol + Send + Sync + 'static` for each type parameter of this type.
626///
627/// # Field attributes
628///
629/// Additionally, the following attributes can be set on struct or enum fields, rather than the entire type for which `Protocol` is being derived:
630///
631/// * `#[async_proto(max_len = ...)]`: Can be used on a field implementing the `LengthPrefixed` trait to limit the allowable length. Note that this alters the network representation of the length prefix (with a `max_len` of up to 255, the length is represented as a [`u8`]; with a `max_len` of 256 to 65535, as a [`u16`]; and so on), so adding/removing/changing this attribute may break protocol compatibility.
632///
633/// # Compile errors
634///
635/// * This macro can't be used with `union`s.
636#[proc_macro_derive(Protocol, attributes(async_proto))]
637pub fn derive_protocol(input: TokenStream) -> TokenStream {
638    let DeriveInput { attrs, ident, generics, data, .. } = parse_macro_input!(input);
639    impl_protocol_inner(false, attrs, parse_quote!(#ident), generics, Some(data)).into()
640}
641
/// Parsed input of the `impl_protocol_for` macro: one entry per declaration, holding its outer attributes,
/// its module-style path, its generics, and its data layout. The layout is `None` for `type` declarations,
/// which list no variants or fields.
struct ImplProtocolFor(Vec<(Vec<Attribute>, Path, Generics, Option<Data>)>);
643
644impl Parse for ImplProtocolFor {
645    fn parse(input: ParseStream<'_>) -> Result<Self> {
646        let mut decls = Vec::default();
647        while !input.is_empty() {
648            let attrs = Attribute::parse_outer(input)?;
649            let lookahead = input.lookahead1();
650            decls.push(if lookahead.peek(Token![enum]) {
651                let enum_token = input.parse()?;
652                let path = Path::parse_mod_style(input)?;
653                let generics = input.parse()?;
654                let content;
655                let brace_token = braced!(content in input);
656                let variants = Punctuated::parse_terminated(&content)?;
657                (attrs, path, generics, Some(Data::Enum(DataEnum { enum_token, brace_token, variants })))
658            } else if lookahead.peek(Token![struct]) {
659                let struct_token = input.parse()?;
660                let path = Path::parse_mod_style(input)?;
661                let generics = input.parse()?;
662                let lookahead = input.lookahead1();
663                let fields = if lookahead.peek(Token![;]) {
664                    Fields::Unit
665                } else if lookahead.peek(Paren) {
666                    let content;
667                    let paren_token = parenthesized!(content in input);
668                    let unnamed = Punctuated::parse_terminated_with(&content, Field::parse_unnamed)?;
669                    Fields::Unnamed(FieldsUnnamed { paren_token, unnamed })
670                } else if lookahead.peek(Brace) {
671                    let content;
672                    let brace_token = braced!(content in input);
673                    let named = Punctuated::parse_terminated_with(&content, Field::parse_named)?;
674                    Fields::Named(FieldsNamed { brace_token, named })
675                } else {
676                    return Err(lookahead.error())
677                };
678                let semi_token = input.peek(Token![;]).then(|| input.parse()).transpose()?;
679                (attrs, path, generics, Some(Data::Struct(DataStruct { struct_token, fields, semi_token })))
680            } else if lookahead.peek(Token![type]) {
681                let _ = input.parse::<Token![type]>()?;
682                let path = Path::parse_mod_style(input)?;
683                let mut generics = input.parse::<Generics>()?;
684                generics.where_clause = input.parse()?;
685                let _ = input.parse::<Token![;]>()?;
686                (attrs, path, generics, None)
687            } else {
688                return Err(lookahead.error())
689            });
690        }
691        Ok(ImplProtocolFor(decls))
692    }
693}
694
695#[doc(hidden)]
696#[proc_macro]
697pub fn impl_protocol_for(input: TokenStream) -> TokenStream {
698    let impls = parse_macro_input!(input as ImplProtocolFor)
699        .0.into_iter()
700        .map(|(attrs, path, generics, data)| impl_protocol_inner(true, attrs, path, generics, data));
701    TokenStream::from(quote!(#(#impls)*))
702}
703
/// Parsed input of the `bitflags` macro, written as `Name: repr`.
struct Bitflags {
    /// The flags type that `Protocol` is implemented for.
    name: Ident,
    /// The type whose `Protocol` implementation is used to read and write the flags' bits.
    repr: Ident,
}
708
709impl Parse for Bitflags {
710    fn parse(input: ParseStream<'_>) -> Result<Self> {
711        let name = input.parse()?;
712        input.parse::<Token![:]>()?;
713        let repr = input.parse()?;
714        Ok(Self { name, repr })
715    }
716}
717
718/// Implements `Protocol` for a type defined using the [`bitflags::bitflags`](https://docs.rs/bitflags/latest/bitflags/macro.bitflags.html) macro.
719///
720/// The type will be read via [`from_bits_truncate`](https://docs.rs/bitflags/latest/bitflags/example_generated/struct.Flags.html#method.from_bits_truncate), dropping any bits that do not correspond to flags.
721///
722/// # Usage
723///
724/// ```rust
725/// bitflags::bitflags! {
726///     struct Flags: u32 {
727///         const A = 0b00000001;
728///         const B = 0b00000010;
729///         const C = 0b00000100;
730///         const ABC = Self::A.bits | Self::B.bits | Self::C.bits;
731///     }
732/// }
733///
734/// async_proto::bitflags!(Flags: u32);
735/// ```
736#[proc_macro]
737pub fn bitflags(input: TokenStream) -> TokenStream {
738    let Bitflags { name, repr } = parse_macro_input!(input);
739    TokenStream::from(quote! {
740        impl ::async_proto::Protocol for #name {
741            fn read<'a, R: ::async_proto::tokio::io::AsyncRead + ::core::marker::Unpin + ::core::marker::Send + 'a>(stream: &'a mut R) -> ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output = ::core::result::Result<Self, ::async_proto::ReadError>> + ::core::marker::Send + 'a>> {
742                ::std::boxed::Box::pin(async move {
743                    Ok(Self::from_bits_truncate(<#repr as ::async_proto::Protocol>::read(stream).await.map_err(|::async_proto::ReadError { context, kind }| ::async_proto::ReadError {
744                        context: ::async_proto::ErrorContext::Bitflags {
745                            source: Box::new(context),
746                        },
747                        kind,
748                    })?))
749                })
750            }
751
752            fn write<'a, W: ::async_proto::tokio::io::AsyncWrite + ::core::marker::Unpin + ::core::marker::Send + 'a>(&'a self, sink: &'a mut W) -> ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output = ::core::result::Result<(), ::async_proto::WriteError>> + ::core::marker::Send + 'a>> {
753                ::std::boxed::Box::pin(async move {
754                    <#repr as ::async_proto::Protocol>::write(&self.bits(), sink).await.map_err(|::async_proto::WriteError { context, kind }| ::async_proto::WriteError {
755                        context: ::async_proto::ErrorContext::Bitflags {
756                            source: Box::new(context),
757                        },
758                        kind,
759                    })
760                })
761            }
762
763            fn read_sync(stream: &mut impl ::std::io::Read) -> ::core::result::Result<Self, ::async_proto::ReadError> {
764                Ok(Self::from_bits_truncate(<#repr as ::async_proto::Protocol>::read_sync(stream).map_err(|::async_proto::ReadError { context, kind }| ::async_proto::ReadError {
765                    context: ::async_proto::ErrorContext::Bitflags {
766                        source: Box::new(context),
767                    },
768                    kind,
769                })?))
770            }
771
772            fn write_sync(&self, sink: &mut impl ::std::io::Write) -> ::core::result::Result<(), ::async_proto::WriteError> {
773                <#repr as ::async_proto::Protocol>::write_sync(&self.bits(), sink).map_err(|::async_proto::WriteError { context, kind }| ::async_proto::WriteError {
774                    context: ::async_proto::ErrorContext::Bitflags {
775                        source: Box::new(context),
776                    },
777                    kind,
778                })
779            }
780        }
781    })
782}