fire_stream_api_codegen/
action.rs

1use crate::util::fire_api_crate;
2
3use proc_macro2::TokenStream;
4use syn::{
5	Result, DeriveInput, Error, Ident, Attribute, Data, Expr, ExprLit, Lit,
6	LitInt, Fields, TypePath, Variant
7};
8use syn::punctuated::Punctuated;
9use syn::token::Comma;
10use quote::quote;
11
12
13pub(crate) fn expand(input: DeriveInput) -> Result<TokenStream> {
14	let DeriveInput { attrs, ident, generics, data, .. } = input;
15
16	let d = match data {
17		Data::Enum(e) => e,
18		_ => return Err(Error::new_spanned(ident, "only enums supported"))
19	};
20
21	if !repr_as_u16(attrs)? {
22		return Err(Error::new_spanned(ident, "#[repr(u16)] required"))
23	}
24
25	if !generics.params.is_empty() {
26		return Err(Error::new_spanned(generics, "generics not supported"))
27	}
28
29	// (fieldnum, ident)
30	let variants = variants_no_fields(d.variants)?;
31
32	let fire = fire_api_crate()?;
33
34	let from_variants = variants.iter()
35		.map(|(num, id)| quote!(#num => Some(Self::#id)));
36
37	let from_u16 = quote!(
38		fn from_u16(num: u16) -> Option<Self> {
39			match num {
40				#(#from_variants),*,
41				_ => None
42			}
43		}
44	);
45
46	let to_variants = variants.iter()
47		.map(|(num, id)| quote!(Self::#id => #num));
48
49	let as_u16 = quote!(
50		fn as_u16(&self) -> u16 {
51			match self {
52				#(#to_variants),*
53			}
54		}
55	);
56
57	Ok(quote!(
58		impl #fire::message::Action for #ident {
59			#from_u16
60			#as_u16
61		}
62	))
63}
64
65fn repr_as_u16(attrs: Vec<Attribute>) -> Result<bool> {
66	let mut repr_as = None;
67
68	for attr in attrs {
69		if !attr.path().is_ident("repr") {
70			continue
71		}
72
73		let ty: TypePath = attr.parse_args()?;
74
75		repr_as = Some(ty);
76	}
77
78	match repr_as {
79		Some(path) => {
80			if !path.path.is_ident("u16") {
81				return Err(Error::new_spanned(path, "expected u16"));
82			}
83
84			Ok(true)
85		},
86		None => Ok(false)
87	}
88}
89
90fn variants_no_fields(
91	variants: Punctuated<Variant, Comma>
92) -> Result<Vec<(LitInt, Ident)>> {
93	variants.into_iter()
94		.map(|v| {
95			let fieldnum_expr = v.discriminant
96				.ok_or_else(|| Error::new_spanned(
97					&v.ident,
98					"needs to have a field number `Ident = x`"
99				))?
100				.1;
101			let fieldnum = match fieldnum_expr {
102				Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => int,
103				e => return Err(Error::new_spanned(e, "expected = int"))
104			};
105
106			let fieldnum_zero = fieldnum.base10_digits() == "0";
107
108			if fieldnum_zero {
109				return Err(Error::new_spanned(
110					fieldnum_zero,
111					"zero not allowed"
112				))
113			}
114
115			if !matches!(v.fields, Fields::Unit) {
116				return Err(Error::new_spanned(v.fields, "no fields allowed"))
117			}
118
119			Ok((fieldnum, v.ident))
120		})
121		.collect()
122}