// fire_stream_api_codegen/action.rs
use crate::util::fire_api_crate;

use proc_macro2::TokenStream;
use syn::{
    Result, DeriveInput, Error, Ident, Attribute, Data, Expr, ExprLit, Lit,
    LitInt, Fields, TypePath, Variant, Meta
};
use syn::punctuated::Punctuated;
use syn::token::Comma;
use quote::quote;

13pub(crate) fn expand(input: DeriveInput) -> Result<TokenStream> {
14 let DeriveInput { attrs, ident, generics, data, .. } = input;
15
16 let d = match data {
17 Data::Enum(e) => e,
18 _ => return Err(Error::new_spanned(ident, "only enums supported"))
19 };
20
21 if !repr_as_u16(attrs)? {
22 return Err(Error::new_spanned(ident, "#[repr(u16)] required"))
23 }
24
25 if !generics.params.is_empty() {
26 return Err(Error::new_spanned(generics, "generics not supported"))
27 }
28
29 let variants = variants_no_fields(d.variants)?;
31
32 let fire = fire_api_crate()?;
33
34 let from_variants = variants.iter()
35 .map(|(num, id)| quote!(#num => Some(Self::#id)));
36
37 let from_u16 = quote!(
38 fn from_u16(num: u16) -> Option<Self> {
39 match num {
40 #(#from_variants),*,
41 _ => None
42 }
43 }
44 );
45
46 let to_variants = variants.iter()
47 .map(|(num, id)| quote!(Self::#id => #num));
48
49 let as_u16 = quote!(
50 fn as_u16(&self) -> u16 {
51 match self {
52 #(#to_variants),*
53 }
54 }
55 );
56
57 Ok(quote!(
58 impl #fire::message::Action for #ident {
59 #from_u16
60 #as_u16
61 }
62 ))
63}
64
65fn repr_as_u16(attrs: Vec<Attribute>) -> Result<bool> {
66 let mut repr_as = None;
67
68 for attr in attrs {
69 if !attr.path().is_ident("repr") {
70 continue
71 }
72
73 let ty: TypePath = attr.parse_args()?;
74
75 repr_as = Some(ty);
76 }
77
78 match repr_as {
79 Some(path) => {
80 if !path.path.is_ident("u16") {
81 return Err(Error::new_spanned(path, "expected u16"));
82 }
83
84 Ok(true)
85 },
86 None => Ok(false)
87 }
88}
89
90fn variants_no_fields(
91 variants: Punctuated<Variant, Comma>
92) -> Result<Vec<(LitInt, Ident)>> {
93 variants.into_iter()
94 .map(|v| {
95 let fieldnum_expr = v.discriminant
96 .ok_or_else(|| Error::new_spanned(
97 &v.ident,
98 "needs to have a field number `Ident = x`"
99 ))?
100 .1;
101 let fieldnum = match fieldnum_expr {
102 Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => int,
103 e => return Err(Error::new_spanned(e, "expected = int"))
104 };
105
106 let fieldnum_zero = fieldnum.base10_digits() == "0";
107
108 if fieldnum_zero {
109 return Err(Error::new_spanned(
110 fieldnum_zero,
111 "zero not allowed"
112 ))
113 }
114
115 if !matches!(v.fields, Fields::Unit) {
116 return Err(Error::new_spanned(v.fields, "no fields allowed"))
117 }
118
119 Ok((fieldnum, v.ident))
120 })
121 .collect()
122}