jetstream_p9_wire_format_derive/
lib.rs

1// Copyright 2018 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Derives a 9P wire format encoding for a struct by recursively calling
6//! `WireFormat::encode` or `WireFormat::decode` on the fields of the struct.
7//! This is only intended to be used from within the `p9` crate.
8
9#![recursion_limit = "256"]
10
11extern crate proc_macro;
12extern crate proc_macro2;
13
14#[macro_use]
15extern crate quote;
16
17#[macro_use]
18extern crate syn;
19
20use proc_macro2::Span;
21use proc_macro2::TokenStream;
22use syn::spanned::Spanned;
23use syn::Data;
24use syn::DeriveInput;
25use syn::Fields;
26use syn::Ident;
27
28/// The function that derives the actual implementation.
29#[proc_macro_derive(P9WireFormat)]
30pub fn p9_wire_format(
31    input: proc_macro::TokenStream,
32) -> proc_macro::TokenStream {
33    let input = parse_macro_input!(input as DeriveInput);
34    p9_wire_format_inner(input).into()
35}
36
37fn p9_wire_format_inner(input: DeriveInput) -> TokenStream {
38    if !input.generics.params.is_empty() {
39        return quote! {
40            compile_error!("derive(P9WireFormat) does not support generic parameters");
41        };
42    }
43
44    let container = input.ident;
45
46    let byte_size_impl = byte_size_sum(&input.data);
47    let encode_impl = encode_wire_format(&input.data);
48    let decode_impl = decode_wire_format(&input.data, &container);
49
50    let scope = format!("wire_format_{}", container).to_lowercase();
51    let scope = Ident::new(&scope, Span::call_site());
52    quote! {
53        mod #scope {
54            extern crate std;
55            use self::std::io;
56            use self::std::result::Result::Ok;
57
58            use super::#container;
59
60            use protocol::WireFormat;
61
62            impl WireFormat for #container {
63                fn byte_size(&self) -> u32 {
64                    #byte_size_impl
65                }
66
67                fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
68                    #encode_impl
69                }
70
71                fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
72                    #decode_impl
73                }
74            }
75        }
76    }
77}
78
79// Generate code that recursively calls byte_size on every field in the struct.
80fn byte_size_sum(data: &Data) -> TokenStream {
81    if let Data::Struct(ref data) = *data {
82        if let Fields::Named(ref fields) = data.fields {
83            let fields = fields.named.iter().map(|f| {
84                let field = &f.ident;
85                let span = field.span();
86                quote_spanned! {span=>
87                    WireFormat::byte_size(&self.#field)
88                }
89            });
90
91            quote! {
92                0 #(+ #fields)*
93            }
94        } else if let Fields::Unnamed(unnamed) = &data.fields {
95            let fields = unnamed.unnamed.iter().enumerate().map(|(i, _f)| {
96                let index = syn::Index::from(i);
97                quote! {
98                    WireFormat::byte_size(&self.#index)
99                }
100            });
101
102            quote! {
103                0 #(+ #fields)*
104            }
105        } else {
106            unimplemented!();
107        }
108    } else {
109        unimplemented!();
110    }
111}
112
113// Generate code that recursively calls encode on every field in the struct.
114fn encode_wire_format(data: &Data) -> TokenStream {
115    if let Data::Struct(ref data) = *data {
116        if let Fields::Named(ref fields) = data.fields {
117            let fields = fields.named.iter().map(|f| {
118                let field = &f.ident;
119                let span = field.span();
120                quote_spanned! {span=>
121                    WireFormat::encode(&self.#field, _writer)?;
122                }
123            });
124
125            quote! {
126                #(#fields)*
127
128                Ok(())
129            }
130        } else if let Fields::Unnamed(unnamed) = &data.fields {
131            let fields = unnamed.unnamed.iter().enumerate().map(|(i, _f)| {
132                let index = syn::Index::from(i);
133                quote! {
134                    WireFormat::encode(&self.#index, _writer)?;
135                }
136            });
137
138            quote! {
139                 #(#fields)*
140
141                Ok(())
142            }
143        } else {
144            unimplemented!();
145        }
146    } else {
147        unimplemented!();
148    }
149}
150
151// Generate code that recursively calls decode on every field in the struct.
152fn decode_wire_format(data: &Data, container: &Ident) -> TokenStream {
153    if let Data::Struct(ref data) = *data {
154        if let Fields::Named(ref fields) = data.fields {
155            let values = fields.named.iter().map(|f| {
156                let field = &f.ident;
157                let span = field.span();
158                quote_spanned! {span=>
159                    let #field = WireFormat::decode(_reader)?;
160                }
161            });
162
163            let members = fields.named.iter().map(|f| {
164                let field = &f.ident;
165                quote! {
166                    #field: #field,
167                }
168            });
169            
170
171            quote! {
172                #(#values)*
173
174                Ok(#container {
175                    #(#members)*
176                })
177            }
178        } else if let Fields::Unnamed(unnamed) = &data.fields {
179            let values = unnamed.unnamed.iter().enumerate().map(|(i, f)| {
180                let index = syn::Index::from(i);
181                // create a new ident that s __{index}
182                let ident = Ident::new(
183                    &format!("__{}", index.index),
184                    Span::call_site(),
185                );
186                quote! {
187                    let #ident: #f = WireFormat::decode(_reader)?;
188                }
189            });
190
191            let members = unnamed.unnamed.iter().enumerate().map(|(i, _f)| {
192                let index = syn::Index::from(i);
193                let ident = Ident::new(
194                    &format!("__{}", index.index),
195                    Span::call_site(),
196                );
197                quote! {
198                    #ident
199                }
200            });
201
202            quote! {
203                #(#values)*
204                Ok(#container(
205                    #(#members)*
206                ))
207            }
208        } else {
209            unimplemented!();
210        }
211    } else {
212        unimplemented!();
213    }
214}
215
#[cfg(test)]
mod tests {
    extern crate pretty_assertions;
    use self::pretty_assertions::assert_eq;
    use super::*;

    /// Struct with three named fields shared by the per-generator tests.
    fn sample_item() -> DeriveInput {
        parse_quote! {
            struct Item {
                ident: u32,
                with_underscores: String,
                other: u8,
            }
        }
    }

    /// Compares token streams by their rendered text, which ignores spans.
    fn assert_tokens_eq(actual: TokenStream, expected: TokenStream) {
        assert_eq!(actual.to_string(), expected.to_string());
    }

    #[test]
    fn byte_size() {
        let expected = quote! {
            0
                + WireFormat::byte_size(&self.ident)
                + WireFormat::byte_size(&self.with_underscores)
                + WireFormat::byte_size(&self.other)
        };

        assert_tokens_eq(byte_size_sum(&sample_item().data), expected);
    }

    #[test]
    fn encode() {
        let expected = quote! {
            WireFormat::encode(&self.ident, _writer)?;
            WireFormat::encode(&self.with_underscores, _writer)?;
            WireFormat::encode(&self.other, _writer)?;
            Ok(())
        };

        assert_tokens_eq(encode_wire_format(&sample_item().data), expected);
    }

    #[test]
    fn decode() {
        let container = Ident::new("Item", Span::call_site());
        let expected = quote! {
            let ident = WireFormat::decode(_reader)?;
            let with_underscores = WireFormat::decode(_reader)?;
            let other = WireFormat::decode(_reader)?;
            Ok(Item {
                ident: ident,
                with_underscores: with_underscores,
                other: other,
            })
        };

        assert_tokens_eq(
            decode_wire_format(&sample_item().data, &container),
            expected,
        );
    }

    /// Full expansion, including a non-ASCII container name that must be
    /// lowercased into the module name.
    #[test]
    fn end_to_end() {
        let input: DeriveInput = parse_quote! {
            struct Niijima_先輩 {
                a: u8,
                b: u16,
                c: u32,
                d: u64,
                e: String,
                f: Vec<String>,
                g: Nested,
            }
        };

        let expected = quote! {
            mod wire_format_niijima_先輩 {
                extern crate std;
                use self::std::io;
                use self::std::result::Result::Ok;

                use super::Niijima_先輩;

                use protocol::WireFormat;

                impl WireFormat for Niijima_先輩 {
                    fn byte_size(&self) -> u32 {
                        0
                        + WireFormat::byte_size(&self.a)
                        + WireFormat::byte_size(&self.b)
                        + WireFormat::byte_size(&self.c)
                        + WireFormat::byte_size(&self.d)
                        + WireFormat::byte_size(&self.e)
                        + WireFormat::byte_size(&self.f)
                        + WireFormat::byte_size(&self.g)
                    }

                    fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
                        WireFormat::encode(&self.a, _writer)?;
                        WireFormat::encode(&self.b, _writer)?;
                        WireFormat::encode(&self.c, _writer)?;
                        WireFormat::encode(&self.d, _writer)?;
                        WireFormat::encode(&self.e, _writer)?;
                        WireFormat::encode(&self.f, _writer)?;
                        WireFormat::encode(&self.g, _writer)?;
                        Ok(())
                    }
                    fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
                        let a = WireFormat::decode(_reader)?;
                        let b = WireFormat::decode(_reader)?;
                        let c = WireFormat::decode(_reader)?;
                        let d = WireFormat::decode(_reader)?;
                        let e = WireFormat::decode(_reader)?;
                        let f = WireFormat::decode(_reader)?;
                        let g = WireFormat::decode(_reader)?;
                        Ok(Niijima_先輩 {
                            a: a,
                            b: b,
                            c: c,
                            d: d,
                            e: e,
                            f: f,
                            g: g,
                        })
                    }
                }
            }
        };

        assert_tokens_eq(p9_wire_format_inner(input), expected);
    }
}