1#![deny(missing_docs, rust_2018_idioms, unused, unused_crate_dependencies, unused_import_braces, unused_lifetimes, unused_qualifications, warnings)]
2#![forbid(unsafe_code)]
3
4use {
7 std::convert::TryFrom as _,
8 itertools::Itertools as _,
9 proc_macro::TokenStream,
10 proc_macro2::Span,
11 quote::{
12 quote,
13 quote_spanned,
14 },
15 syn::{
16 *,
17 parse::{
18 Parse,
19 ParseStream,
20 },
21 punctuated::Punctuated,
22 spanned::Spanned as _,
23 token::{
24 Brace,
25 Paren,
26 },
27 },
28};
29
30fn read_fields(internal: bool, sync: bool, fields: &Fields) -> proc_macro2::TokenStream {
31 let async_proto_crate = if internal { quote!(crate) } else { quote!(::async_proto) };
32 let read = if sync { quote!(::read_sync(stream)) } else { quote!(::read(stream).await) };
33 match fields {
34 Fields::Unit => quote!(),
35 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
36 let read_fields = unnamed.iter()
37 .enumerate()
38 .map(|(idx, Field { attrs, ty, .. })| {
39 let mut max_len = None;
40 for attr in attrs.into_iter().filter(|attr| attr.path().is_ident("async_proto")) {
41 match attr.parse_args_with(Punctuated::<FieldAttr, Token![,]>::parse_terminated) {
42 Ok(attrs) => for attr in attrs {
43 match attr {
44 FieldAttr::MaxLen(new_max_len) => if max_len.replace(new_max_len).is_some() {
45 return quote!(compile_error!("#[async_proto(max_len = ...)] specified multiple times");).into()
46 },
47 }
48 },
49 Err(e) => return e.to_compile_error().into(),
50 }
51 }
52 let read = if let Some(max_len) = max_len {
53 let read = if sync { quote!(::read_length_prefixed_sync(stream, #max_len)) } else { quote!(::read_length_prefixed(stream, #max_len).await) };
54 quote_spanned! {ty.span()=>
55 <#ty as #async_proto_crate::LengthPrefixed>#read
56 }
57 } else {
58 quote_spanned! {ty.span()=>
59 <#ty as #async_proto_crate::Protocol>#read
60 }
61 };
62 quote_spanned! {ty.span()=>
63 #read.map_err(|#async_proto_crate::ReadError { context, kind }| #async_proto_crate::ReadError {
64 context: #async_proto_crate::ErrorContext::UnnamedField {
65 idx: #idx,
66 source: Box::new(context),
67 },
68 kind,
69 })?
70 }
71 })
72 .collect_vec();
73 quote!((#(#read_fields,)*))
74 }
75 Fields::Named(FieldsNamed { named, .. }) => {
76 let read_fields = named.iter()
77 .map(|Field { attrs, ident, ty, .. }| {
78 let mut max_len = None;
79 for attr in attrs.into_iter().filter(|attr| attr.path().is_ident("async_proto")) {
80 match attr.parse_args_with(Punctuated::<FieldAttr, Token![,]>::parse_terminated) {
81 Ok(attrs) => for attr in attrs {
82 match attr {
83 FieldAttr::MaxLen(new_max_len) => if max_len.replace(new_max_len).is_some() {
84 return quote!(compile_error!("#[async_proto(max_len = ...)] specified multiple times");).into()
85 },
86 }
87 },
88 Err(e) => return e.to_compile_error().into(),
89 }
90 }
91 let name = ident.as_ref().expect("FieldsNamed with unnamed field").to_string();
92 let read = if let Some(max_len) = max_len {
93 let read = if sync { quote!(::read_length_prefixed_sync(stream, #max_len)) } else { quote!(::read_length_prefixed(stream, #max_len).await) };
94 quote_spanned! {ty.span()=>
95 <#ty as #async_proto_crate::LengthPrefixed>#read
96 }
97 } else {
98 quote_spanned! {ty.span()=>
99 <#ty as #async_proto_crate::Protocol>#read
100 }
101 };
102 quote_spanned! {ty.span()=>
103 #ident: #read.map_err(|#async_proto_crate::ReadError { context, kind }| #async_proto_crate::ReadError {
104 context: #async_proto_crate::ErrorContext::NamedField {
105 name: #name,
106 source: Box::new(context),
107 },
108 kind,
109 })?
110 }
111 })
112 .collect_vec();
113 quote!({ #(#read_fields,)* })
114 }
115 }
116}
117
118fn fields_pat(fields: &Fields) -> proc_macro2::TokenStream {
119 match fields {
120 Fields::Unit => quote!(),
121 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
122 let field_idents = unnamed.iter()
123 .enumerate()
124 .map(|(idx, _)| Ident::new(&format!("__field{}", idx), Span::call_site()))
125 .collect_vec();
126 quote!((#(#field_idents,)*))
127 }
128 Fields::Named(FieldsNamed { named, .. }) => {
129 let field_idents = named.iter()
130 .map(|Field { ident, .. }| ident)
131 .collect_vec();
132 quote!({ #(#field_idents,)* })
133 }
134 }
135}
136
137fn write_fields(internal: bool, sync: bool, fields: &Fields) -> proc_macro2::TokenStream {
138 let async_proto_crate = if internal { quote!(crate) } else { quote!(::async_proto) };
139 match fields {
140 Fields::Unit => quote!(),
141 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
142 let write_fields = unnamed.iter()
143 .enumerate()
144 .map(|(idx, Field { attrs, ty, .. })| {
145 let mut max_len = None;
146 for attr in attrs.into_iter().filter(|attr| attr.path().is_ident("async_proto")) {
147 match attr.parse_args_with(Punctuated::<FieldAttr, Token![,]>::parse_terminated) {
148 Ok(attrs) => for attr in attrs {
149 match attr {
150 FieldAttr::MaxLen(new_max_len) => if max_len.replace(new_max_len).is_some() {
151 return quote!(compile_error!("#[async_proto(max_len = ...)] specified multiple times");).into()
152 },
153 }
154 },
155 Err(e) => return e.to_compile_error().into(),
156 }
157 }
158 let ident = Ident::new(&format!("__field{}", idx), Span::call_site());
159 let write = if let Some(max_len) = max_len {
160 let write = if sync { quote!(::write_length_prefixed_sync(#ident, sink, #max_len)) } else { quote!(::write_length_prefixed(#ident, sink, #max_len).await) };
161 quote_spanned! {ty.span()=>
162 <#ty as #async_proto_crate::LengthPrefixed>#write
163 }
164 } else {
165 let write = if sync { quote!(::write_sync(#ident, sink)) } else { quote!(::write(#ident, sink).await) };
166 quote_spanned! {ty.span()=>
167 <#ty as #async_proto_crate::Protocol>#write
168 }
169 };
170 quote!(#write.map_err(|#async_proto_crate::WriteError { context, kind }| #async_proto_crate::WriteError {
171 context: #async_proto_crate::ErrorContext::UnnamedField {
172 idx: #idx,
173 source: Box::new(context),
174 },
175 kind,
176 })?;)
177 });
178 quote!(#(#write_fields)*)
179 }
180 Fields::Named(FieldsNamed { named, .. }) => {
181 let write_fields = named.iter()
182 .map(|Field { attrs, ident, ty, .. }| {
183 let mut max_len = None;
184 for attr in attrs.into_iter().filter(|attr| attr.path().is_ident("async_proto")) {
185 match attr.parse_args_with(Punctuated::<FieldAttr, Token![,]>::parse_terminated) {
186 Ok(attrs) => for attr in attrs {
187 match attr {
188 FieldAttr::MaxLen(new_max_len) => if max_len.replace(new_max_len).is_some() {
189 return quote!(compile_error!("#[async_proto(max_len = ...)] specified multiple times");).into()
190 },
191 }
192 },
193 Err(e) => return e.to_compile_error().into(),
194 }
195 }
196 let write = if let Some(max_len) = max_len {
197 let write = if sync { quote!(::write_length_prefixed_sync(#ident, sink, #max_len)) } else { quote!(::write_length_prefixed(#ident, sink, #max_len).await) };
198 quote_spanned! {ty.span()=>
199 <#ty as #async_proto_crate::LengthPrefixed>#write
200 }
201 } else {
202 let write = if sync { quote!(::write_sync(#ident, sink)) } else { quote!(::write(#ident, sink).await) };
203 quote_spanned! {ty.span()=>
204 <#ty as #async_proto_crate::Protocol>#write
205 }
206 };
207 let name = ident.as_ref().expect("FieldsNamed with unnamed field").to_string();
208 quote!(#write.map_err(|#async_proto_crate::WriteError { context, kind }| #async_proto_crate::WriteError {
209 context: #async_proto_crate::ErrorContext::NamedField {
210 name: #name,
211 source: Box::new(context),
212 },
213 kind,
214 })?;)
215 });
216 quote!(#(#write_fields)*)
217 }
218 }
219}
220
/// One entry of a type-level `#[async_proto(...)]` attribute.
enum AsyncProtoAttr {
    /// `as_string`: represent the type as a `String`, converting via `FromStr` and `ToString`.
    AsString,
    /// `attr(...)`: attributes to place on the generated `impl` block.
    Attr(Punctuated<Meta, Token![,]>),
    /// `clone`: with `via`, clone `self` and convert the owned value instead of converting from a reference.
    Clone,
    /// `internal`: refer to the runtime crate as `crate` instead of `::async_proto`.
    Internal,
    /// `map_err = <expr>`: function applied to the conversion error when reading with `as_string` or `via`.
    MapErr(Expr),
    /// `via = <type>`: represent the type as the given proxy type.
    Via(Type),
    /// `where(...)`: predicates added to the `impl` block instead of the default bounds on each type parameter.
    Where(Punctuated<WherePredicate, Token![,]>),
}
230
231impl Parse for AsyncProtoAttr {
232 fn parse(input: ParseStream<'_>) -> Result<Self> {
233 Ok(if input.peek(Token![where]) {
234 let _ = input.parse::<Token![where]>()?;
235 let content;
236 parenthesized!(content in input);
237 Self::Where(Punctuated::parse_terminated(&content)?)
238 } else {
239 let ident = input.parse::<Ident>()?;
240 match &*ident.to_string() {
241 "as_string" => Self::AsString,
242 "attr" => {
243 let content;
244 parenthesized!(content in input);
245 Self::Attr(Punctuated::parse_terminated(&content)?)
246 }
247 "clone" => Self::Clone,
248 "internal" => Self::Internal,
249 "map_err" => {
250 let _ = input.parse::<Token![=]>()?;
251 Self::MapErr(input.parse()?)
252 }
253 "via" => {
254 let _ = input.parse::<Token![=]>()?;
255 Self::Via(input.parse()?)
256 }
257 _ => return Err(Error::new(ident.span(), "unknown async_proto type attribute")),
258 }
259 })
260 }
261}
262
/// One entry of a field-level `#[async_proto(...)]` attribute.
enum FieldAttr {
    /// `max_len = <integer>`: the field is read and written via the `LengthPrefixed` trait, with this value passed as the maximum length.
    MaxLen(u64),
}
266
267impl Parse for FieldAttr {
268 fn parse(input: ParseStream<'_>) -> Result<Self> {
269 let ident = input.parse::<Ident>()?;
270 Ok(match &*ident.to_string() {
271 "max_len" => {
272 let _ = input.parse::<Token![=]>()?;
273 Self::MaxLen(input.parse::<LitInt>()?.base10_parse()?)
274 }
275 _ => return Err(Error::new(ident.span(), "unknown async_proto field attribute")),
276 })
277 }
278}
279
280fn impl_protocol_inner(mut internal: bool, attrs: Vec<Attribute>, qual_ty: Path, generics: Generics, data: Option<Data>) -> proc_macro2::TokenStream {
281 let for_type = quote!(#qual_ty).to_string();
282 let mut as_string = false;
283 let mut via = None;
284 let mut clone = false;
285 let mut map_err = None;
286 let mut where_predicates = None;
287 let mut impl_attrs = Vec::default();
288 for attr in attrs.into_iter().filter(|attr| attr.path().is_ident("async_proto")) {
289 match attr.parse_args_with(Punctuated::<AsyncProtoAttr, Token![,]>::parse_terminated) {
290 Ok(attrs) => for attr in attrs {
291 match attr {
292 AsyncProtoAttr::AsString => {
293 if via.is_some() { return quote!(compile_error!("#[async_proto(as_str)] and #[async_proto(via = ...)] are incompatible");).into() }
294 as_string = true;
295 }
296 AsyncProtoAttr::Attr(attr) => impl_attrs.extend(attr),
297 AsyncProtoAttr::Clone => clone = true,
298 AsyncProtoAttr::Internal => internal = true,
299 AsyncProtoAttr::MapErr(expr) => if map_err.replace(expr).is_some() {
300 return quote!(compile_error!("#[async_proto(map_err = ...)] specified multiple times");).into()
301 },
302 AsyncProtoAttr::Via(ty) => if via.replace(ty).is_some() {
303 return quote!(compile_error!("#[async_proto(via = ...)] specified multiple times");).into()
304 },
305 AsyncProtoAttr::Where(predicates) => if where_predicates.replace(predicates).is_some() {
306 return quote!(compile_error!("#[async_proto(where(...))] specified multiple times");).into()
307 },
308 }
309 },
310 Err(e) => return e.to_compile_error().into(),
311 }
312 }
313 let async_proto_crate = if internal { quote!(crate) } else { quote!(::async_proto) };
314 let mut impl_generics = generics.clone();
315 if let Some(predicates) = where_predicates {
316 impl_generics.make_where_clause().predicates.extend(predicates);
317 } else {
318 for param in impl_generics.type_params_mut() {
319 param.colon_token.get_or_insert_with(<Token![:]>::default);
320 param.bounds.push(parse_quote!(#async_proto_crate::Protocol));
321 param.bounds.push(parse_quote!(::core::marker::Send));
322 param.bounds.push(parse_quote!(::core::marker::Sync));
323 param.bounds.push(parse_quote!('static));
324 }
325 };
326 let (impl_read, impl_write, impl_read_sync, impl_write_sync) = if as_string {
327 if internal && data.is_some() { return quote!(compile_error!("redundant type layout specification with #[async_proto(as_string)]");).into() }
328 let map_err = map_err.unwrap_or(parse_quote!(::core::convert::Into::<#async_proto_crate::ReadErrorKind>::into));
329 (
330 quote!(<Self as ::std::str::FromStr>::from_str(&<::std::string::String as #async_proto_crate::Protocol>::read(stream).await.map_err(|#async_proto_crate::ReadError { context, kind }| #async_proto_crate::ReadError {
331 context: #async_proto_crate::ErrorContext::AsString {
332 source: Box::new(context),
333 },
334 kind,
335 })?).map_err(|e| #async_proto_crate::ReadError {
336 context: #async_proto_crate::ErrorContext::FromStr,
337 kind: (#map_err)(e),
338 })),
339 quote!(<::std::string::String as #async_proto_crate::Protocol>::write(&<Self as ::std::string::ToString>::to_string(self), sink).await.map_err(|#async_proto_crate::WriteError { context, kind }| #async_proto_crate::WriteError {
340 context: #async_proto_crate::ErrorContext::AsString {
341 source: Box::new(context),
342 },
343 kind,
344 })),
345 quote!(<Self as ::std::str::FromStr>::from_str(&<::std::string::String as #async_proto_crate::Protocol>::read_sync(stream).map_err(|#async_proto_crate::ReadError { context, kind }| #async_proto_crate::ReadError {
346 context: #async_proto_crate::ErrorContext::AsString {
347 source: Box::new(context),
348 },
349 kind,
350 })?).map_err(|e| #async_proto_crate::ReadError {
351 context: #async_proto_crate::ErrorContext::FromStr,
352 kind: (#map_err)(e),
353 })),
354 quote!(<::std::string::String as #async_proto_crate::Protocol>::write_sync(&<Self as ::std::string::ToString>::to_string(self), sink).map_err(|#async_proto_crate::WriteError { context, kind }| #async_proto_crate::WriteError {
355 context: #async_proto_crate::ErrorContext::AsString {
356 source: Box::new(context),
357 },
358 kind,
359 })),
360 )
361 } else if let Some(proxy_ty) = via {
362 if internal && data.is_some() { return quote!(compile_error!("redundant type layout specification with #[async_proto(via = ...)]");).into() }
363 let (write_proxy, write_sync_proxy) = if clone {
364 (
365 quote!(<Self as ::core::convert::TryInto<#proxy_ty>>::try_into(<Self as ::core::clone::Clone>::clone(self)).map_err(|e| #async_proto_crate::WriteError {
366 context: #async_proto_crate::ErrorContext::TryInto,
367 kind: ::core::convert::Into::<#async_proto_crate::WriteErrorKind>::into(e),
368 })?),
369 quote!(<Self as ::core::convert::TryInto<#proxy_ty>>::try_into(<Self as ::core::clone::Clone>::clone(self)).map_err(|e| #async_proto_crate::WriteError {
370 context: #async_proto_crate::ErrorContext::TryInto,
371 kind: ::core::convert::Into::<#async_proto_crate::WriteErrorKind>::into(e),
372 })?),
373 )
374 } else {
375 (
376 quote!(<&'a Self as ::core::convert::TryInto<#proxy_ty>>::try_into(self).map_err(|e| #async_proto_crate::WriteError {
377 context: #async_proto_crate::ErrorContext::TryInto,
378 kind: ::core::convert::Into::<#async_proto_crate::WriteErrorKind>::into(e),
379 })?),
380 quote!(<&Self as ::core::convert::TryInto<#proxy_ty>>::try_into(self).map_err(|e| #async_proto_crate::WriteError {
381 context: #async_proto_crate::ErrorContext::TryInto,
382 kind: ::core::convert::Into::<#async_proto_crate::WriteErrorKind>::into(e),
383 })?),
384 )
385 };
386 let map_err = map_err.unwrap_or(parse_quote!(::core::convert::Into::<#async_proto_crate::ReadErrorKind>::into));
387 (
388 quote!(<#proxy_ty as ::core::convert::TryInto<Self>>::try_into(<#proxy_ty as #async_proto_crate::Protocol>::read(stream).await.map_err(|#async_proto_crate::ReadError { context, kind }| #async_proto_crate::ReadError {
389 context: #async_proto_crate::ErrorContext::Via {
390 source: Box::new(context),
391 },
392 kind,
393 })?).map_err(|e| #async_proto_crate::ReadError {
394 context: #async_proto_crate::ErrorContext::TryInto,
395 kind: (#map_err)(e),
396 })),
397 quote!(<#proxy_ty as #async_proto_crate::Protocol>::write(&#write_proxy, sink).await.map_err(|#async_proto_crate::WriteError { context, kind }| #async_proto_crate::WriteError {
398 context: #async_proto_crate::ErrorContext::Via {
399 source: Box::new(context),
400 },
401 kind,
402 })),
403 quote!(<Self as ::core::convert::TryFrom<#proxy_ty>>::try_from(<#proxy_ty as #async_proto_crate::Protocol>::read_sync(stream).map_err(|#async_proto_crate::ReadError { context, kind }| #async_proto_crate::ReadError {
404 context: #async_proto_crate::ErrorContext::Via {
405 source: Box::new(context),
406 },
407 kind,
408 })?).map_err(|e| #async_proto_crate::ReadError {
409 context: #async_proto_crate::ErrorContext::TryInto,
410 kind: (#map_err)(e),
411 })),
412 quote!(<#proxy_ty as #async_proto_crate::Protocol>::write_sync(&#write_sync_proxy, sink).map_err(|#async_proto_crate::WriteError { context, kind }| #async_proto_crate::WriteError {
413 context: #async_proto_crate::ErrorContext::Via {
414 source: Box::new(context),
415 },
416 kind,
417 })),
418 )
419 } else {
420 if map_err.is_some() { return quote!(compile_error!("#[async_proto(map_err = ...)] does nothing without #[async_proto(as_string)] or #[async_proto(via = ...)]");).into() }
421 match data {
422 Some(Data::Struct(DataStruct { fields, .. })) => {
423 let fields_pat = fields_pat(&fields);
424 let read_fields_async = read_fields(internal, false, &fields);
425 let write_fields_async = write_fields(internal, false, &fields);
426 let read_fields_sync = read_fields(internal, true, &fields);
427 let write_fields_sync = write_fields(internal, true, &fields);
428 (
429 quote!(::core::result::Result::Ok(Self #read_fields_async)),
430 quote! {
431 let Self #fields_pat = self;
432 #write_fields_async
433 ::core::result::Result::Ok(())
434 },
435 quote!(::core::result::Result::Ok(Self #read_fields_sync)),
436 quote! {
437 let Self #fields_pat = self;
438 #write_fields_sync
439 ::core::result::Result::Ok(())
440 },
441 )
442 }
443 Some(Data::Enum(DataEnum { variants, .. })) => {
444 if variants.is_empty() {
445 (
446 quote!(::core::result::Result::Err(#async_proto_crate::ReadError {
447 context: #async_proto_crate::ErrorContext::Derived { for_type: #for_type },
448 kind: #async_proto_crate::ReadErrorKind::ReadNever,
449 })),
450 quote!(match *self {}),
451 quote!(::core::result::Result::Err(#async_proto_crate::ReadError {
452 context: #async_proto_crate::ErrorContext::Derived { for_type: #for_type },
453 kind: #async_proto_crate::ReadErrorKind::ReadNever,
454 })),
455 quote!(match *self {}),
456 )
457 } else {
458 let (discrim_ty, unknown_variant_variant, get_discrim) = match variants.len() {
459 0 => unreachable!(), 1..=256 => (quote!(u8), quote!(UnknownVariant8), (&|idx| {
461 let idx = u8::try_from(idx).expect("variant index unexpectedly high");
462 quote!(#idx)
463 }) as &dyn Fn(usize) -> proc_macro2::TokenStream),
464 257..=65_536 => (quote!(u16), quote!(UnknownVariant16), (&|idx| {
465 let idx = u16::try_from(idx).expect("variant index unexpectedly high");
466 quote!(#idx)
467 }) as &dyn Fn(usize) -> proc_macro2::TokenStream),
468 #[cfg(target_pointer_width = "32")]
469 _ => (quote!(u32), quote!(UnknownVariant32), (&|idx| {
470 let idx = u32::try_from(idx).expect("variant index unexpectedly high");
471 quote!(#idx)
472 }) as &dyn Fn(usize) -> proc_macro2::TokenStream),
473 #[cfg(target_pointer_width = "64")]
474 65_537..=4_294_967_296 => (quote!(u32), quote!(UnknownVariant32), (&|idx| {
475 let idx = u32::try_from(idx).expect("variant index unexpectedly high");
476 quote!(#idx)
477 }) as &dyn Fn(usize) -> proc_macro2::TokenStream),
478 #[cfg(target_pointer_width = "64")]
479 _ => (quote!(u64), quote!(UnknownVariant64), (&|idx| {
480 let idx = u64::try_from(idx).expect("variant index unexpectedly high");
481 quote!(#idx)
482 }) as &dyn Fn(usize) -> proc_macro2::TokenStream),
483 };
484 let read_arms = variants.iter()
485 .enumerate()
486 .map(|(idx, Variant { ident: var, fields, .. })| {
487 let idx = get_discrim(idx);
488 let read_fields = read_fields(internal, false, fields);
489 quote!(#idx => ::core::result::Result::Ok(Self::#var #read_fields))
490 })
491 .collect_vec();
492 let write_arms = variants.iter()
493 .enumerate()
494 .map(|(idx, Variant { ident: var, fields, .. })| {
495 let idx = get_discrim(idx);
496 let fields_pat = fields_pat(&fields);
497 let write_fields = write_fields(internal, false, fields);
498 quote! {
499 Self::#var #fields_pat => {
500 #idx.write(sink).await.map_err(|#async_proto_crate::WriteError { context, kind }| #async_proto_crate::WriteError {
501 context: #async_proto_crate::ErrorContext::EnumDiscrim {
502 source: Box::new(context),
503 },
504 kind,
505 })?;
506 #write_fields
507 }
508 }
509 })
510 .collect_vec();
511 let read_sync_arms = variants.iter()
512 .enumerate()
513 .map(|(idx, Variant { ident: var, fields, .. })| {
514 let idx = get_discrim(idx);
515 let read_fields = read_fields(internal, true, fields);
516 quote!(#idx => ::core::result::Result::Ok(Self::#var #read_fields))
517 })
518 .collect_vec();
519 let write_sync_arms = variants.iter()
520 .enumerate()
521 .map(|(idx, Variant { ident: var, fields, .. })| {
522 let idx = get_discrim(idx);
523 let fields_pat = fields_pat(&fields);
524 let write_fields = write_fields(internal, true, fields);
525 quote! {
526 Self::#var #fields_pat => {
527 #idx.write_sync(sink).map_err(|#async_proto_crate::WriteError { context, kind }| #async_proto_crate::WriteError {
528 context: #async_proto_crate::ErrorContext::EnumDiscrim {
529 source: Box::new(context),
530 },
531 kind,
532 })?;
533 #write_fields
534 }
535 }
536 })
537 .collect_vec();
538 (
539 quote! {
540 match <#discrim_ty as #async_proto_crate::Protocol>::read(stream).await.map_err(|#async_proto_crate::ReadError { context, kind }| #async_proto_crate::ReadError {
541 context: #async_proto_crate::ErrorContext::EnumDiscrim {
542 source: Box::new(context),
543 },
544 kind,
545 })? {
546 #(#read_arms,)*
547 n => ::core::result::Result::Err(#async_proto_crate::ReadError {
548 context: #async_proto_crate::ErrorContext::Derived { for_type: #for_type },
549 kind: #async_proto_crate::ReadErrorKind::#unknown_variant_variant(n),
550 }),
551 }
552 },
553 quote! {
554 match self {
555 #(#write_arms,)*
556 }
557 ::core::result::Result::Ok(())
558 },
559 quote! {
560 match <#discrim_ty as #async_proto_crate::Protocol>::read_sync(stream).map_err(|#async_proto_crate::ReadError { context, kind }| #async_proto_crate::ReadError {
561 context: #async_proto_crate::ErrorContext::EnumDiscrim {
562 source: Box::new(context),
563 },
564 kind,
565 })? {
566 #(#read_sync_arms,)*
567 n => ::core::result::Result::Err(#async_proto_crate::ReadError {
568 context: #async_proto_crate::ErrorContext::Derived { for_type: #for_type },
569 kind: #async_proto_crate::ReadErrorKind::#unknown_variant_variant(n),
570 }),
571 }
572 },
573 quote! {
574 match self {
575 #(#write_sync_arms,)*
576 }
577 ::core::result::Result::Ok(())
578 },
579 )
580 }
581 }
582 Some(Data::Union(_)) => return quote!(compile_error!("unions not supported in derive(Protocol)");).into(),
583 None => return quote!(compile_error!("missing type layout specification or #[async_proto(via = ...)]");).into(),
584 }
585 };
586 let (impl_generics, ty_generics, where_clause) = impl_generics.split_for_impl();
587 quote! {
588 #(#[#impl_attrs])*
589 impl #impl_generics #async_proto_crate::Protocol for #qual_ty #ty_generics #where_clause {
590 fn read<'a, R: #async_proto_crate::tokio::io::AsyncRead + ::core::marker::Unpin + ::core::marker::Send + 'a>(stream: &'a mut R) -> ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output = ::core::result::Result<Self, #async_proto_crate::ReadError>> + ::core::marker::Send + 'a>> {
591 ::std::boxed::Box::pin(async move { #impl_read })
592 }
593
594 fn write<'a, W: #async_proto_crate::tokio::io::AsyncWrite + ::core::marker::Unpin + ::core::marker::Send + 'a>(&'a self, sink: &'a mut W) -> ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output = ::core::result::Result<(), #async_proto_crate::WriteError>> + ::core::marker::Send + 'a>> {
595 ::std::boxed::Box::pin(async move { #impl_write })
596 }
597
598 fn read_sync(mut stream: &mut impl ::std::io::Read) -> ::core::result::Result<Self, #async_proto_crate::ReadError> { #impl_read_sync }
599 fn write_sync(&self, mut sink: &mut impl ::std::io::Write) -> ::core::result::Result<(), #async_proto_crate::WriteError> { #impl_write_sync }
600 }
601 }
602}
603
604#[proc_macro_derive(Protocol, attributes(async_proto))]
637pub fn derive_protocol(input: TokenStream) -> TokenStream {
638 let DeriveInput { attrs, ident, generics, data, .. } = parse_macro_input!(input);
639 impl_protocol_inner(false, attrs, parse_quote!(#ident), generics, Some(data)).into()
640}
641
/// Input of the `impl_protocol_for` macro: a list of declarations, each consisting of its attributes, the path and generics of the type, and its layout (`None` for `type` declarations, which carry no layout).
struct ImplProtocolFor(Vec<(Vec<Attribute>, Path, Generics, Option<Data>)>);
643
644impl Parse for ImplProtocolFor {
645 fn parse(input: ParseStream<'_>) -> Result<Self> {
646 let mut decls = Vec::default();
647 while !input.is_empty() {
648 let attrs = Attribute::parse_outer(input)?;
649 let lookahead = input.lookahead1();
650 decls.push(if lookahead.peek(Token![enum]) {
651 let enum_token = input.parse()?;
652 let path = Path::parse_mod_style(input)?;
653 let generics = input.parse()?;
654 let content;
655 let brace_token = braced!(content in input);
656 let variants = Punctuated::parse_terminated(&content)?;
657 (attrs, path, generics, Some(Data::Enum(DataEnum { enum_token, brace_token, variants })))
658 } else if lookahead.peek(Token![struct]) {
659 let struct_token = input.parse()?;
660 let path = Path::parse_mod_style(input)?;
661 let generics = input.parse()?;
662 let lookahead = input.lookahead1();
663 let fields = if lookahead.peek(Token![;]) {
664 Fields::Unit
665 } else if lookahead.peek(Paren) {
666 let content;
667 let paren_token = parenthesized!(content in input);
668 let unnamed = Punctuated::parse_terminated_with(&content, Field::parse_unnamed)?;
669 Fields::Unnamed(FieldsUnnamed { paren_token, unnamed })
670 } else if lookahead.peek(Brace) {
671 let content;
672 let brace_token = braced!(content in input);
673 let named = Punctuated::parse_terminated_with(&content, Field::parse_named)?;
674 Fields::Named(FieldsNamed { brace_token, named })
675 } else {
676 return Err(lookahead.error())
677 };
678 let semi_token = input.peek(Token![;]).then(|| input.parse()).transpose()?;
679 (attrs, path, generics, Some(Data::Struct(DataStruct { struct_token, fields, semi_token })))
680 } else if lookahead.peek(Token![type]) {
681 let _ = input.parse::<Token![type]>()?;
682 let path = Path::parse_mod_style(input)?;
683 let mut generics = input.parse::<Generics>()?;
684 generics.where_clause = input.parse()?;
685 let _ = input.parse::<Token![;]>()?;
686 (attrs, path, generics, None)
687 } else {
688 return Err(lookahead.error())
689 });
690 }
691 Ok(ImplProtocolFor(decls))
692 }
693}
694
695#[doc(hidden)]
696#[proc_macro]
697pub fn impl_protocol_for(input: TokenStream) -> TokenStream {
698 let impls = parse_macro_input!(input as ImplProtocolFor)
699 .0.into_iter()
700 .map(|(attrs, path, generics, data)| impl_protocol_inner(true, attrs, path, generics, data));
701 TokenStream::from(quote!(#(#impls)*))
702}
703
/// Input of the `bitflags` macro, written as `Name: Repr`.
struct Bitflags {
    /// The type to implement `Protocol` for.
    name: Ident,
    /// The type whose `Protocol` implementation is used to read and write the bits.
    repr: Ident,
}
708
709impl Parse for Bitflags {
710 fn parse(input: ParseStream<'_>) -> Result<Self> {
711 let name = input.parse()?;
712 input.parse::<Token![:]>()?;
713 let repr = input.parse()?;
714 Ok(Self { name, repr })
715 }
716}
717
718#[proc_macro]
737pub fn bitflags(input: TokenStream) -> TokenStream {
738 let Bitflags { name, repr } = parse_macro_input!(input);
739 TokenStream::from(quote! {
740 impl ::async_proto::Protocol for #name {
741 fn read<'a, R: ::async_proto::tokio::io::AsyncRead + ::core::marker::Unpin + ::core::marker::Send + 'a>(stream: &'a mut R) -> ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output = ::core::result::Result<Self, ::async_proto::ReadError>> + ::core::marker::Send + 'a>> {
742 ::std::boxed::Box::pin(async move {
743 Ok(Self::from_bits_truncate(<#repr as ::async_proto::Protocol>::read(stream).await.map_err(|::async_proto::ReadError { context, kind }| ::async_proto::ReadError {
744 context: ::async_proto::ErrorContext::Bitflags {
745 source: Box::new(context),
746 },
747 kind,
748 })?))
749 })
750 }
751
752 fn write<'a, W: ::async_proto::tokio::io::AsyncWrite + ::core::marker::Unpin + ::core::marker::Send + 'a>(&'a self, sink: &'a mut W) -> ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output = ::core::result::Result<(), ::async_proto::WriteError>> + ::core::marker::Send + 'a>> {
753 ::std::boxed::Box::pin(async move {
754 <#repr as ::async_proto::Protocol>::write(&self.bits(), sink).await.map_err(|::async_proto::WriteError { context, kind }| ::async_proto::WriteError {
755 context: ::async_proto::ErrorContext::Bitflags {
756 source: Box::new(context),
757 },
758 kind,
759 })
760 })
761 }
762
763 fn read_sync(stream: &mut impl ::std::io::Read) -> ::core::result::Result<Self, ::async_proto::ReadError> {
764 Ok(Self::from_bits_truncate(<#repr as ::async_proto::Protocol>::read_sync(stream).map_err(|::async_proto::ReadError { context, kind }| ::async_proto::ReadError {
765 context: ::async_proto::ErrorContext::Bitflags {
766 source: Box::new(context),
767 },
768 kind,
769 })?))
770 }
771
772 fn write_sync(&self, sink: &mut impl ::std::io::Write) -> ::core::result::Result<(), ::async_proto::WriteError> {
773 <#repr as ::async_proto::Protocol>::write_sync(&self.bits(), sink).map_err(|::async_proto::WriteError { context, kind }| ::async_proto::WriteError {
774 context: ::async_proto::ErrorContext::Bitflags {
775 source: Box::new(context),
776 },
777 kind,
778 })
779 }
780 }
781 })
782}