jetstream_wire_format_derive/
lib.rs

1// Copyright 2018 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Derives a 9P wire format encoding for a struct by recursively calling
6//! `WireFormat::encode` or `WireFormat::decode` on the fields of the struct.
7//! This is only intended to be used from within the `p9` crate.
8
9#![recursion_limit = "256"]
10extern crate proc_macro2;
11
12#[macro_use]
13extern crate quote;
14
15#[macro_use]
16extern crate syn;
17
18use proc_macro2::{TokenStream, Span};
19use syn::{DeriveInput, Fields, Ident, Data, spanned::Spanned};
20
21mod protocol;
22
23/// The function that derives the actual implementation.
24#[proc_macro_derive(JetStreamWireFormat)]
25pub fn p9_wire_format(
26    input: proc_macro::TokenStream,
27) -> proc_macro::TokenStream {
28    let input = parse_macro_input!(input as DeriveInput);
29    p9_wire_format_inner(input).into()
30}
31
32#[proc_macro_attribute]
33pub fn protocol(
34    _attr: proc_macro::TokenStream,
35    item: proc_macro::TokenStream,
36) -> proc_macro::TokenStream {
37    protocol::protocol_inner(_attr, item)
38}
39
40fn p9_wire_format_inner(input: DeriveInput) -> TokenStream {
41    if !input.generics.params.is_empty() {
42        return quote! {
43            compile_error!("derive(JetStreamWireFormat) does not support generic parameters");
44        };
45    }
46
47    let container = input.ident;
48
49    let byte_size_impl = byte_size_sum(&input.data);
50    let encode_impl = encode_wire_format(&input.data);
51    let decode_impl = decode_wire_format(&input.data, &container);
52
53    let scope = format!("wire_format_{}", container).to_lowercase();
54    let scope = Ident::new(&scope, Span::call_site());
55    quote! {
56        mod #scope {
57            extern crate std;
58            use self::std::io;
59            use self::std::result::Result::Ok;
60
61            use super::#container;
62            use super::WireFormat;
63
64            impl WireFormat for #container {
65                fn byte_size(&self) -> u32 {
66                    #byte_size_impl
67                }
68
69                fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
70                    #encode_impl
71                }
72
73                fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
74                    #decode_impl
75                }
76            }
77        }
78    }
79}
80
81// Generate code that recursively calls byte_size on every field in the struct.
82fn byte_size_sum(data: &Data) -> TokenStream {
83    if let Data::Struct(ref data) = *data {
84        if let Fields::Named(ref fields) = data.fields {
85            let fields = fields.named.iter().map(|f| {
86                let field = &f.ident;
87                let span = field.span();
88                quote_spanned! {span=>
89                    WireFormat::byte_size(&self.#field)
90                }
91            });
92
93            quote! {
94                0 #(+ #fields)*
95            }
96        } else if let Fields::Unnamed(unnamed) = &data.fields {
97            let fields = unnamed.unnamed.iter().enumerate().map(|(i, _f)| {
98                let index = syn::Index::from(i);
99                quote! {
100                    WireFormat::byte_size(&self.#index)
101                }
102            });
103
104            quote! {
105                0 #(+ #fields)*
106            }
107        } else {
108            unimplemented!("byte_size_sum for {:?}", data.struct_token.span);
109        }
110    } else {
111        unimplemented!("byte_size_sum for ");
112    }
113}
114
115// Generate code that recursively calls encode on every field in the struct.
116fn encode_wire_format(data: &Data) -> TokenStream {
117    if let Data::Struct(ref data) = *data {
118        if let Fields::Named(ref fields) = data.fields {
119            let fields = fields.named.iter().map(|f| {
120                let field = &f.ident;
121                let span = field.span();
122                quote_spanned! {span=>
123                    WireFormat::encode(&self.#field, _writer)?;
124                }
125            });
126
127            quote! {
128                #(#fields)*
129
130                Ok(())
131            }
132        } else if let Fields::Unnamed(unnamed) = &data.fields {
133            let fields = unnamed.unnamed.iter().enumerate().map(|(i, _f)| {
134                let index = syn::Index::from(i);
135                quote! {
136                    WireFormat::encode(&self.#index, _writer)?;
137                }
138            });
139
140            quote! {
141                 #(#fields)*
142
143                Ok(())
144            }
145        } else {
146            unimplemented!();
147        }
148    } else {
149        unimplemented!();
150    }
151}
152
153// Generate code that recursively calls decode on every field in the struct.
154fn decode_wire_format(data: &Data, container: &Ident) -> TokenStream {
155    if let Data::Struct(ref data) = *data {
156        if let Fields::Named(ref fields) = data.fields {
157            let values = fields.named.iter().map(|f| {
158                let field = &f.ident;
159                let span = field.span();
160                quote_spanned! {span=>
161                    let #field = WireFormat::decode(_reader)?;
162                }
163            });
164
165            let members = fields.named.iter().map(|f| {
166                let field = &f.ident;
167                quote! {
168                    #field: #field,
169                }
170            });
171
172            quote! {
173                #(#values)*
174
175                Ok(#container {
176                    #(#members)*
177                })
178            }
179        } else if let Fields::Unnamed(unnamed) = &data.fields {
180            let values = unnamed.unnamed.iter().enumerate().map(|(i, _f)| {
181                let index = syn::Index::from(i);
182                // create a new ident that s __{index}
183                let ident = Ident::new(
184                    &format!("__{}", index.index),
185                    Span::call_site(),
186                );
187                quote! {
188                    let #ident = WireFormat::decode(_reader)?;
189                }
190            });
191
192            let members = unnamed.unnamed.iter().enumerate().map(|(i, _f)| {
193                let index = syn::Index::from(i);
194                let ident = Ident::new(
195                    &format!("__{}", index.index),
196                    Span::call_site(),
197                );
198                quote! {
199                    #ident
200                }
201            });
202
203            quote! {
204                #(#values)*
205
206                Ok(#container(
207                    #(#members,)*
208                ))
209            }
210        } else {
211            unimplemented!();
212        }
213    } else {
214        unimplemented!();
215    }
216}
217
#[cfg(test)]
mod tests {
    extern crate pretty_assertions;
    use self::pretty_assertions::assert_eq;
    use super::*;

    // The three-field named struct exercised by the per-helper tests.
    fn sample_item() -> DeriveInput {
        parse_quote! {
            struct Item {
                ident: u32,
                with_underscores: String,
                other: u8,
            }
        }
    }

    // Token streams are compared through their printed form. This normalises
    // whitespace, so `expected` can be written as ordinary formatted Rust.
    fn assert_tokens_eq(actual: TokenStream, expected: TokenStream) {
        assert_eq!(actual.to_string(), expected.to_string());
    }

    #[test]
    fn byte_size() {
        let item = sample_item();

        assert_tokens_eq(
            byte_size_sum(&item.data),
            quote! {
                0
                    + WireFormat::byte_size(&self.ident)
                    + WireFormat::byte_size(&self.with_underscores)
                    + WireFormat::byte_size(&self.other)
            },
        );
    }

    #[test]
    fn encode() {
        let item = sample_item();

        assert_tokens_eq(
            encode_wire_format(&item.data),
            quote! {
                WireFormat::encode(&self.ident, _writer)?;
                WireFormat::encode(&self.with_underscores, _writer)?;
                WireFormat::encode(&self.other, _writer)?;
                Ok(())
            },
        );
    }

    #[test]
    fn decode() {
        let item = sample_item();

        // The parsed identifier prints as `Item`, matching the expected code.
        assert_tokens_eq(
            decode_wire_format(&item.data, &item.ident),
            quote! {
                let ident = WireFormat::decode(_reader)?;
                let with_underscores = WireFormat::decode(_reader)?;
                let other = WireFormat::decode(_reader)?;
                Ok(Item {
                    ident: ident,
                    with_underscores: with_underscores,
                    other: other,
                })
            },
        );
    }

    #[test]
    fn end_to_end() {
        let item: DeriveInput = parse_quote! {
            struct Niijima_先輩 {
                a: u8,
                b: u16,
                c: u32,
                d: u64,
                e: String,
                f: Vec<String>,
                g: Nested,
            }
        };

        assert_tokens_eq(
            p9_wire_format_inner(item),
            quote! {
                mod wire_format_niijima_先輩 {
                    extern crate std;
                    use self::std::io;
                    use self::std::result::Result::Ok;

                    use super::Niijima_先輩;

                    use super::WireFormat;

                    impl WireFormat for Niijima_先輩 {
                        fn byte_size(&self) -> u32 {
                            0
                            + WireFormat::byte_size(&self.a)
                            + WireFormat::byte_size(&self.b)
                            + WireFormat::byte_size(&self.c)
                            + WireFormat::byte_size(&self.d)
                            + WireFormat::byte_size(&self.e)
                            + WireFormat::byte_size(&self.f)
                            + WireFormat::byte_size(&self.g)
                        }

                        fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
                            WireFormat::encode(&self.a, _writer)?;
                            WireFormat::encode(&self.b, _writer)?;
                            WireFormat::encode(&self.c, _writer)?;
                            WireFormat::encode(&self.d, _writer)?;
                            WireFormat::encode(&self.e, _writer)?;
                            WireFormat::encode(&self.f, _writer)?;
                            WireFormat::encode(&self.g, _writer)?;
                            Ok(())
                        }
                        fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
                            let a = WireFormat::decode(_reader)?;
                            let b = WireFormat::decode(_reader)?;
                            let c = WireFormat::decode(_reader)?;
                            let d = WireFormat::decode(_reader)?;
                            let e = WireFormat::decode(_reader)?;
                            let f = WireFormat::decode(_reader)?;
                            let g = WireFormat::decode(_reader)?;
                            Ok(Niijima_先輩 {
                                a: a,
                                b: b,
                                c: c,
                                d: d,
                                e: e,
                                f: f,
                                g: g,
                            })
                        }
                    }
                }
            },
        );
    }

    #[test]
    fn end_to_end_unnamed() {
        let item: DeriveInput = parse_quote! {
            struct Niijima_先輩(u8, u16, u32, u64, String, Vec<String>, Nested);
        };

        assert_tokens_eq(
            p9_wire_format_inner(item),
            quote! {
                mod wire_format_niijima_先輩 {
                    extern crate std;
                    use self::std::io;
                    use self::std::result::Result::Ok;

                    use super::Niijima_先輩;

                    use super::WireFormat;

                    impl WireFormat for Niijima_先輩 {
                        fn byte_size(&self) -> u32 {
                            0
                            + WireFormat::byte_size(&self.0)
                            + WireFormat::byte_size(&self.1)
                            + WireFormat::byte_size(&self.2)
                            + WireFormat::byte_size(&self.3)
                            + WireFormat::byte_size(&self.4)
                            + WireFormat::byte_size(&self.5)
                            + WireFormat::byte_size(&self.6)
                        }

                        fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
                            WireFormat::encode(&self.0, _writer)?;
                            WireFormat::encode(&self.1, _writer)?;
                            WireFormat::encode(&self.2, _writer)?;
                            WireFormat::encode(&self.3, _writer)?;
                            WireFormat::encode(&self.4, _writer)?;
                            WireFormat::encode(&self.5, _writer)?;
                            WireFormat::encode(&self.6, _writer)?;
                            Ok(())
                        }
                        fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
                            let __0 = WireFormat::decode(_reader)?;
                            let __1 = WireFormat::decode(_reader)?;
                            let __2 = WireFormat::decode(_reader)?;
                            let __3 = WireFormat::decode(_reader)?;
                            let __4 = WireFormat::decode(_reader)?;
                            let __5 = WireFormat::decode(_reader)?;
                            let __6 = WireFormat::decode(_reader)?;
                            Ok(Niijima_先輩(
                                __0,
                                __1,
                                __2,
                                __3,
                                __4,
                                __5,
                                __6,
                            ))
                        }
                    }
                }
            },
        );
    }
}