1#![recursion_limit = "4096"]
3
4extern crate alloc;
5extern crate proc_macro;
6
7use anyhow::{bail, Error};
8use field::{Kind, OneofVariant};
9use itertools::Itertools;
10use proc_macro::TokenStream;
11use proc_macro2::Span;
12use quote::quote;
13use syn::punctuated::Punctuated;
14use syn::{Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed, FieldsUnnamed, Ident, Index, Variant};
15
16mod field;
17use crate::field::Field;
18
19fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
20 let input: DeriveInput = syn::parse(input)?;
21
22 let ident = input.ident;
23
24 let variant_data = match input.data {
25 Data::Struct(variant_data) => variant_data,
26 Data::Enum(..) => bail!("Message can not be derived for an enum"),
27 Data::Union(..) => bail!("Message can not be derived for a union"),
28 };
29
30 let generics = &input.generics;
31 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
32
33 let (_is_struct, fields) = match variant_data {
34 DataStruct {
35 fields: Fields::Named(FieldsNamed { named: fields, .. }),
36 ..
37 } => (true, fields.into_iter().collect()),
38 DataStruct {
39 fields: Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }),
40 ..
41 } => (false, fields.into_iter().collect()),
42 DataStruct {
43 fields: Fields::Unit, ..
44 } => (false, Vec::new()),
45 };
46
47 let mut fields = fields
48 .into_iter()
49 .enumerate()
50 .map(|(i, field)| {
51 let field_ident = field.ident.map(|x| quote!(#x)).unwrap_or_else(|| {
52 let index = Index {
53 index: i as u32,
54 span: Span::call_site(),
55 };
56 quote!(#index)
57 });
58 match Field::new(field.attrs) {
59 Ok(field) => Ok((field_ident, field)),
60 Err(err) => Err(err.context(format!("invalid message field {}.{}", ident, field_ident))),
61 }
62 })
63 .collect::<Result<Vec<_>, _>>()?;
64
65 fields.sort_by_key(|&(_, ref field)| field.tags.iter().copied().min().unwrap());
70 let fields = fields;
71
72 let mut tags = fields.iter().flat_map(|(_, field)| &field.tags).collect::<Vec<_>>();
73 let num_tags = tags.len();
74 tags.sort_unstable();
75 tags.dedup();
76 if tags.len() != num_tags {
77 bail!("message {} has fields with duplicate tags", ident);
78 }
79
80 let write = fields.iter().map(|&(ref field_ident, ref field)| {
81 let tag = field.tags[0];
82 let ident = quote!(self.#field_ident);
83 match field.kind {
84 Kind::Single => quote!(w.write_field(#tag, &#ident)?;),
85 Kind::Repeated => quote!(w.write_repeated(#tag, &#ident)?;),
86 Kind::Optional => quote!(w.write_optional(#tag, &#ident)?;),
87 Kind::Oneof => quote!(w.write_oneof(&#ident)?;),
88 }
89 });
90
91 let read = fields.iter().map(|&(ref field_ident, ref field)| {
92 let ident = quote!(self.#field_ident);
93 let read = match field.kind {
94 Kind::Single => quote!(r.read(&mut #ident)?;),
95 Kind::Repeated => quote!(r.read_repeated(&mut #ident)?;),
96 Kind::Optional => quote!(r.read_optional(&mut #ident)?;),
97 Kind::Oneof => quote!(r.read_oneof(&mut #ident)?;),
98 };
99
100 let tags = field.tags.iter().map(|&tag| quote!(#tag));
101 let tags = Itertools::intersperse(tags, quote!(|));
102
103 quote!(#(#tags)* => { #read })
104 });
105
106 let expanded = quote! {
107 impl #impl_generics ::noproto::Message for #ident #ty_generics #where_clause {
108 const WIRE_TYPE: ::noproto::WireType = ::noproto::WireType::LengthDelimited;
109
110 fn write_raw(&self, w: &mut ::noproto::encoding::ByteWriter) -> Result<(), ::noproto::WriteError> {
111 #(#write)*
112 Ok(())
113 }
114
115 fn read_raw(&mut self, r: &mut ::noproto::encoding::ByteReader) -> Result<(), ::noproto::ReadError> {
116 for r in r.read_fields() {
117 let r = r?;
118 match r.tag() {
119 #(#read)*
120 _ => {}
121 }
122 }
123 Ok(())
124 }
125 }
126 };
127
128 Ok(expanded.into())
129}
130
131#[proc_macro_derive(Message, attributes(noproto))]
132pub fn message(input: TokenStream) -> TokenStream {
133 try_message(input).unwrap()
134}
135
136fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
137 let input: DeriveInput = syn::parse(input)?;
138 let ident = input.ident;
139
140 let generics = &input.generics;
141 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
142
143 let punctuated_variants = match input.data {
144 Data::Enum(DataEnum { variants, .. }) => variants,
145 Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
146 Data::Union(..) => bail!("Enumeration can not be derived for a union"),
147 };
148
149 let mut variants: Vec<(Ident, Expr)> = Vec::new();
151 for Variant {
152 ident,
153 fields,
154 discriminant,
155 ..
156 } in punctuated_variants
157 {
158 match fields {
159 Fields::Unit => (),
160 Fields::Named(_) | Fields::Unnamed(_) => {
161 bail!("Enumeration variants may not have fields")
162 }
163 }
164
165 match discriminant {
166 Some((_, expr)) => variants.push((ident, expr)),
167 None => bail!("Enumeration variants must have a discriminant"),
168 }
169 }
170
171 if variants.is_empty() {
172 panic!("Enumeration must have at least one variant");
173 }
174
175 let _default = variants[0].0.clone();
176
177 let _is_valid = variants.iter().map(|&(_, ref value)| quote!(#value => true));
178
179 let write = variants
180 .iter()
181 .map(|(variant, value)| quote!(#ident::#variant => #value));
182
183 let read = variants
184 .iter()
185 .map(|(variant, value)| quote!(#value => #ident::#variant ));
186
187 let expanded = quote! {
188 impl #impl_generics ::noproto::Message for #ident #ty_generics #where_clause {
189
190 const WIRE_TYPE: ::noproto::WireType = ::noproto::WireType::Varint;
191
192 fn write_raw(&self, w: &mut ::noproto::encoding::ByteWriter) -> Result<(), ::noproto::WriteError> {
193 let val = match self {
194 #(#write,)*
195 };
196 w.write_varuint32(*self as _)
197 }
198
199 fn read_raw(&mut self, r: &mut ::noproto::encoding::ByteReader) -> Result<(), ::noproto::ReadError> {
200 *self = match r.read_varuint32()? {
201 #(#read,)*
202 _ => return Err(::noproto::ReadError),
203 };
204 Ok(())
205 }
206 }
207 };
208
209 Ok(expanded.into())
210}
211
212#[proc_macro_derive(Enumeration, attributes(noproto))]
213pub fn enumeration(input: TokenStream) -> TokenStream {
214 try_enumeration(input).unwrap()
215}
216
217fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
218 let input: DeriveInput = syn::parse(input)?;
219
220 let ident = input.ident;
221
222 let variants = match input.data {
223 Data::Enum(DataEnum { variants, .. }) => variants,
224 Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
225 Data::Union(..) => bail!("Oneof can not be derived for a union"),
226 };
227
228 let generics = &input.generics;
229 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
230
231 let mut oneof_variants: Vec<(Ident, OneofVariant)> = Vec::new();
233 for Variant {
234 attrs,
235 ident: variant_ident,
236 fields: variant_fields,
237 ..
238 } in variants
239 {
240 let variant_fields = match variant_fields {
241 Fields::Unit => Punctuated::new(),
242 Fields::Named(FieldsNamed { named: fields, .. })
243 | Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) => fields,
244 };
245 if variant_fields.len() != 1 {
246 bail!("Oneof enum variants must have a single field");
247 }
248
249 match OneofVariant::new(attrs) {
250 Ok(variant) => oneof_variants.push((variant_ident, variant)),
251 Err(err) => bail!("invalid oneof variant {}.{}: {}", ident, variant_ident, err),
252 }
253 }
254
255 let mut tags = oneof_variants.iter().map(|(_, v)| v.tag).collect::<Vec<_>>();
256 tags.sort_unstable();
257 tags.dedup();
258 if tags.len() != oneof_variants.len() {
259 panic!("invalid oneof {}: variants have duplicate tags", ident);
260 }
261
262 let write = oneof_variants.iter().map(|(variant_ident, variant)| {
263 let tag = variant.tag;
264 quote!(#ident::#variant_ident(value) => { w.write_field(#tag, value)?; })
265 });
266
267 let read = oneof_variants.iter().map(|(variant_ident, variant)| {
268 let tag = variant.tag;
269 quote!(#tag => {
270 *self = #ident::#variant_ident(r.read_oneof_variant()?);
271 })
272 });
273
274 let read_option = oneof_variants.iter().map(|(variant_ident, variant)| {
275 let tag = variant.tag;
276 quote!(#tag => {
277 *this = Some(#ident::#variant_ident(r.read_oneof_variant()?));
278 })
279 });
280
281 let expanded = quote! {
282 impl #impl_generics ::noproto::Oneof for #ident #ty_generics #where_clause {
283 fn write_raw(&self, w: &mut ::noproto::encoding::ByteWriter) -> Result<(), ::noproto::WriteError> {
284 match self {
285 #(#write)*
286 }
287 Ok(())
288 }
289
290 fn read_raw(&mut self, r: ::noproto::encoding::FieldReader) -> Result<(), ::noproto::ReadError> {
291 match r.tag() {
292 #(#read)*
293 _ => return Err(::noproto::ReadError),
294 }
295 Ok(())
296 }
297
298 fn read_raw_option(this: &mut Option<Self>, r: ::noproto::encoding::FieldReader) -> Result<(), ::noproto::ReadError> {
299 match r.tag() {
300 #(#read_option)*
301 _ => return Err(::noproto::ReadError),
302 }
303 Ok(())
304 }
305 }
306 };
307
308 Ok(expanded.into())
309}
310
311#[proc_macro_derive(Oneof, attributes(noproto))]
312pub fn oneof(input: TokenStream) -> TokenStream {
313 try_oneof(input).unwrap()
314}