noproto_derive/
lib.rs

1// The `quote!` macro requires deep recursion.
2#![recursion_limit = "4096"]
3
4extern crate alloc;
5extern crate proc_macro;
6
7use anyhow::{bail, Error};
8use field::{Kind, OneofVariant};
9use itertools::Itertools;
10use proc_macro::TokenStream;
11use proc_macro2::Span;
12use quote::quote;
13use syn::punctuated::Punctuated;
14use syn::{Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed, FieldsUnnamed, Ident, Index, Variant};
15
16mod field;
17use crate::field::Field;
18
19fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
20    let input: DeriveInput = syn::parse(input)?;
21
22    let ident = input.ident;
23
24    let variant_data = match input.data {
25        Data::Struct(variant_data) => variant_data,
26        Data::Enum(..) => bail!("Message can not be derived for an enum"),
27        Data::Union(..) => bail!("Message can not be derived for a union"),
28    };
29
30    let generics = &input.generics;
31    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
32
33    let (_is_struct, fields) = match variant_data {
34        DataStruct {
35            fields: Fields::Named(FieldsNamed { named: fields, .. }),
36            ..
37        } => (true, fields.into_iter().collect()),
38        DataStruct {
39            fields: Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }),
40            ..
41        } => (false, fields.into_iter().collect()),
42        DataStruct {
43            fields: Fields::Unit, ..
44        } => (false, Vec::new()),
45    };
46
47    let mut fields = fields
48        .into_iter()
49        .enumerate()
50        .map(|(i, field)| {
51            let field_ident = field.ident.map(|x| quote!(#x)).unwrap_or_else(|| {
52                let index = Index {
53                    index: i as u32,
54                    span: Span::call_site(),
55                };
56                quote!(#index)
57            });
58            match Field::new(field.attrs) {
59                Ok(field) => Ok((field_ident, field)),
60                Err(err) => Err(err.context(format!("invalid message field {}.{}", ident, field_ident))),
61            }
62        })
63        .collect::<Result<Vec<_>, _>>()?;
64
65    // Sort the fields by tag number so that fields will be encoded in tag order.
66    // TODO: This encodes oneof fields in the position of their lowest tag,
67    // regardless of the currently occupied variant, is that consequential?
68    // See: https://developers.google.com/protocol-buffers/docs/encoding#order
69    fields.sort_by_key(|&(_, ref field)| field.tags.iter().copied().min().unwrap());
70    let fields = fields;
71
72    let mut tags = fields.iter().flat_map(|(_, field)| &field.tags).collect::<Vec<_>>();
73    let num_tags = tags.len();
74    tags.sort_unstable();
75    tags.dedup();
76    if tags.len() != num_tags {
77        bail!("message {} has fields with duplicate tags", ident);
78    }
79
80    let write = fields.iter().map(|&(ref field_ident, ref field)| {
81        let tag = field.tags[0];
82        let ident = quote!(self.#field_ident);
83        match field.kind {
84            Kind::Single => quote!(w.write_field(#tag, &#ident)?;),
85            Kind::Repeated => quote!(w.write_repeated(#tag, &#ident)?;),
86            Kind::Optional => quote!(w.write_optional(#tag, &#ident)?;),
87            Kind::Oneof => quote!(w.write_oneof(&#ident)?;),
88        }
89    });
90
91    let read = fields.iter().map(|&(ref field_ident, ref field)| {
92        let ident = quote!(self.#field_ident);
93        let read = match field.kind {
94            Kind::Single => quote!(r.read(&mut #ident)?;),
95            Kind::Repeated => quote!(r.read_repeated(&mut #ident)?;),
96            Kind::Optional => quote!(r.read_optional(&mut #ident)?;),
97            Kind::Oneof => quote!(r.read_oneof(&mut #ident)?;),
98        };
99
100        let tags = field.tags.iter().map(|&tag| quote!(#tag));
101        let tags = Itertools::intersperse(tags, quote!(|));
102
103        quote!(#(#tags)* => { #read })
104    });
105
106    let expanded = quote! {
107        impl #impl_generics ::noproto::Message for #ident #ty_generics #where_clause {
108            const WIRE_TYPE: ::noproto::WireType = ::noproto::WireType::LengthDelimited;
109
110            fn write_raw(&self, w: &mut ::noproto::encoding::ByteWriter) -> Result<(), ::noproto::WriteError> {
111                #(#write)*
112                Ok(())
113            }
114
115            fn read_raw(&mut self, r: &mut ::noproto::encoding::ByteReader) -> Result<(), ::noproto::ReadError> {
116                for r in r.read_fields() {
117                    let r = r?;
118                    match r.tag() {
119                        #(#read)*
120                        _ => {}
121                    }
122                }
123                Ok(())
124            }
125        }
126    };
127
128    Ok(expanded.into())
129}
130
131#[proc_macro_derive(Message, attributes(noproto))]
132pub fn message(input: TokenStream) -> TokenStream {
133    try_message(input).unwrap()
134}
135
136fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
137    let input: DeriveInput = syn::parse(input)?;
138    let ident = input.ident;
139
140    let generics = &input.generics;
141    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
142
143    let punctuated_variants = match input.data {
144        Data::Enum(DataEnum { variants, .. }) => variants,
145        Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
146        Data::Union(..) => bail!("Enumeration can not be derived for a union"),
147    };
148
149    // Map the variants into 'fields'.
150    let mut variants: Vec<(Ident, Expr)> = Vec::new();
151    for Variant {
152        ident,
153        fields,
154        discriminant,
155        ..
156    } in punctuated_variants
157    {
158        match fields {
159            Fields::Unit => (),
160            Fields::Named(_) | Fields::Unnamed(_) => {
161                bail!("Enumeration variants may not have fields")
162            }
163        }
164
165        match discriminant {
166            Some((_, expr)) => variants.push((ident, expr)),
167            None => bail!("Enumeration variants must have a discriminant"),
168        }
169    }
170
171    if variants.is_empty() {
172        panic!("Enumeration must have at least one variant");
173    }
174
175    let _default = variants[0].0.clone();
176
177    let _is_valid = variants.iter().map(|&(_, ref value)| quote!(#value => true));
178
179    let write = variants
180        .iter()
181        .map(|(variant, value)| quote!(#ident::#variant => #value));
182
183    let read = variants
184        .iter()
185        .map(|(variant, value)| quote!(#value => #ident::#variant ));
186
187    let expanded = quote! {
188        impl #impl_generics  ::noproto::Message for #ident #ty_generics #where_clause {
189
190            const WIRE_TYPE: ::noproto::WireType = ::noproto::WireType::Varint;
191
192            fn write_raw(&self, w: &mut ::noproto::encoding::ByteWriter) -> Result<(), ::noproto::WriteError> {
193                let val = match self {
194                    #(#write,)*
195                };
196                w.write_varuint32(*self as _)
197            }
198
199            fn read_raw(&mut self, r: &mut ::noproto::encoding::ByteReader) -> Result<(), ::noproto::ReadError> {
200                *self = match r.read_varuint32()? {
201                    #(#read,)*
202                    _ => return Err(::noproto::ReadError),
203                };
204                Ok(())
205            }
206        }
207    };
208
209    Ok(expanded.into())
210}
211
212#[proc_macro_derive(Enumeration, attributes(noproto))]
213pub fn enumeration(input: TokenStream) -> TokenStream {
214    try_enumeration(input).unwrap()
215}
216
217fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
218    let input: DeriveInput = syn::parse(input)?;
219
220    let ident = input.ident;
221
222    let variants = match input.data {
223        Data::Enum(DataEnum { variants, .. }) => variants,
224        Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
225        Data::Union(..) => bail!("Oneof can not be derived for a union"),
226    };
227
228    let generics = &input.generics;
229    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
230
231    // Map the variants
232    let mut oneof_variants: Vec<(Ident, OneofVariant)> = Vec::new();
233    for Variant {
234        attrs,
235        ident: variant_ident,
236        fields: variant_fields,
237        ..
238    } in variants
239    {
240        let variant_fields = match variant_fields {
241            Fields::Unit => Punctuated::new(),
242            Fields::Named(FieldsNamed { named: fields, .. })
243            | Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) => fields,
244        };
245        if variant_fields.len() != 1 {
246            bail!("Oneof enum variants must have a single field");
247        }
248
249        match OneofVariant::new(attrs) {
250            Ok(variant) => oneof_variants.push((variant_ident, variant)),
251            Err(err) => bail!("invalid oneof variant {}.{}: {}", ident, variant_ident, err),
252        }
253    }
254
255    let mut tags = oneof_variants.iter().map(|(_, v)| v.tag).collect::<Vec<_>>();
256    tags.sort_unstable();
257    tags.dedup();
258    if tags.len() != oneof_variants.len() {
259        panic!("invalid oneof {}: variants have duplicate tags", ident);
260    }
261
262    let write = oneof_variants.iter().map(|(variant_ident, variant)| {
263        let tag = variant.tag;
264        quote!(#ident::#variant_ident(value) => { w.write_field(#tag, value)?; })
265    });
266
267    let read = oneof_variants.iter().map(|(variant_ident, variant)| {
268        let tag = variant.tag;
269        quote!(#tag => {
270            *self = #ident::#variant_ident(r.read_oneof_variant()?);
271        })
272    });
273
274    let read_option = oneof_variants.iter().map(|(variant_ident, variant)| {
275        let tag = variant.tag;
276        quote!(#tag => {
277            *this = Some(#ident::#variant_ident(r.read_oneof_variant()?));
278        })
279    });
280
281    let expanded = quote! {
282        impl #impl_generics ::noproto::Oneof for #ident #ty_generics #where_clause {
283            fn write_raw(&self, w: &mut ::noproto::encoding::ByteWriter) -> Result<(), ::noproto::WriteError> {
284                match self {
285                    #(#write)*
286                }
287                Ok(())
288            }
289
290            fn read_raw(&mut self, r: ::noproto::encoding::FieldReader) -> Result<(), ::noproto::ReadError> {
291                match r.tag() {
292                    #(#read)*
293                    _ => return Err(::noproto::ReadError),
294                }
295                Ok(())
296            }
297
298            fn read_raw_option(this: &mut Option<Self>, r: ::noproto::encoding::FieldReader) -> Result<(), ::noproto::ReadError> {
299                match r.tag() {
300                    #(#read_option)*
301                    _ => return Err(::noproto::ReadError),
302                }
303                Ok(())
304            }
305        }
306    };
307
308    Ok(expanded.into())
309}
310
311#[proc_macro_derive(Oneof, attributes(noproto))]
312pub fn oneof(input: TokenStream) -> TokenStream {
313    try_oneof(input).unwrap()
314}