chrony_candm_derive/
lib.rs

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: GPL-2.0-only

//! All the derive macros defined in this crate are for traits internal to
//! the `chrony-candm` crate. There is no need for any other crate to depend
//! directly on this one.

8use proc_macro::TokenStream;
9use quote::quote;
10use syn::{parse_macro_input, parse_quote, DeriveInput};
11
12#[doc(hidden)]
13#[proc_macro_derive(ChronySerialize, attributes(pad))]
14pub fn derive_chrony_serialize(item: TokenStream) -> TokenStream {
15    let input = parse_macro_input!(item as DeriveInput);
16    let name = input.ident;
17
18    match input.data {
19        syn::Data::Struct(ds) => derive_chrony_serialize_struct(name, ds),
20        syn::Data::Enum(ds) => derive_chrony_serialize_enum(name, input.attrs.as_slice(), ds),
21        _ => panic!("Cannot derive ChronySerialize for a non-struct"),
22    }
23}
24
25fn derive_chrony_serialize_struct(name: proc_macro2::Ident, ds: syn::DataStruct) -> TokenStream {
26    let mut length = quote!(0usize);
27    let mut serialize = quote!();
28    let mut deserialize = quote!();
29
30    for field in ds.fields.iter() {
31        let ty = &field.ty;
32        let ident = field
33            .ident
34            .as_ref()
35            .expect("Deriving ChronySerialize for tuple structs is not supported.");
36        let pad = parse_pad_attr(field.attrs.as_ref());
37        length
38            .extend(quote! { + <#ty as ::chrony_candm::common::ChronySerialize>::length() + #pad });
39        serialize.extend(quote! {
40            ::chrony_candm::common::ChronySerialize::serialize(&self.#ident, buf);
41            if #pad != 0 {
42                buf.put_bytes(0, #pad)
43            }
44        });
45        deserialize.extend(quote! {
46            #ident: {
47                let field = <#ty as ::chrony_candm::common::ChronySerialize>::deserialize_unchecked(buf)?;
48                if #pad != 0 {
49                    buf.advance(#pad)
50                }
51                field
52            },
53        })
54    }
55
56    let expanded = quote! {
57        impl ::chrony_candm::common::ChronySerialize for #name {
58            fn length() -> usize {
59                #length
60            }
61
62            fn serialize<B: ::bytes::BufMut>(&self, buf: &mut B) {
63                #serialize
64            }
65
66            fn deserialize_unchecked<B: ::bytes::Buf>(buf: &mut B) -> ::std::result::Result<Self, ::chrony_candm::common::DeserializationError> {
67                ::std::result::Result::Ok(#name {
68                    #deserialize
69                })
70            }
71
72        }
73    };
74
75    TokenStream::from(expanded)
76}
77
78fn derive_chrony_serialize_enum(
79    name: proc_macro2::Ident,
80    attrs: &[syn::Attribute],
81    _ds: syn::DataEnum,
82) -> TokenStream {
83    let repr = attrs
84        .iter()
85        .find_map(|attr| {
86            if let syn::Meta::List(meta_list) = attr.parse_meta().ok()? {
87                if meta_list.path.get_ident()? == "repr" {
88                    let repr = meta_list.nested.iter().next()?;
89                    let repr: syn::Ident = parse_quote! { #repr };
90                    Some(repr)
91                } else {
92                    None
93                }
94            } else {
95                None
96            }
97        })
98        .expect("Must specify a #[repr] attribute to derive ChronySerialize for an enum.");
99
100    let expanded = quote! {
101        impl ::chrony_candm::common::ChronySerialize for #name {
102            fn length() -> usize {
103                ::std::mem::size_of::<#repr>()
104            }
105
106            fn serialize<B: ::bytes::BufMut>(&self, buf: &mut B) {
107                buf.put_slice((&<#repr>::from(*self).to_be_bytes()) as &[u8]);
108            }
109
110            fn deserialize_unchecked<B: ::bytes::Buf>(buf: &mut B) -> ::std::result::Result<Self, ::chrony_candm::common::DeserializationError> {
111                let mut dst = [0u8; ::std::mem::size_of::<#repr>()];
112                buf.copy_to_slice(&mut dst);
113                <Self as ::std::convert::TryFrom<#repr>>::try_from(<#repr>::from_be_bytes(dst)).map_err(|_| ::chrony_candm::common::DeserializationError::new("value outside of enum range"))
114            }
115        }
116    };
117
118    TokenStream::from(expanded)
119}
120
121#[doc(hidden)]
122#[proc_macro_derive(ChronyMessage, attributes(pad, cmd))]
123pub fn derive_chrony_message(item: TokenStream) -> TokenStream {
124    let input = parse_macro_input!(item as DeriveInput);
125    let name = input.ident;
126
127    let ds = match input.data {
128        syn::Data::Enum(ds) => ds,
129        _ => panic!("Cannot derive ChronyMessage for a non-enum"),
130    };
131
132    let mut length = quote!();
133    let mut cmd = quote!();
134    let mut serialize = quote!();
135    let mut deserialize = quote!();
136
137    let mut index = 0u16;
138    for variant in ds.variants.iter() {
139        let ident = &variant.ident;
140        let mut iter = variant.fields.iter();
141        let arg = iter.next();
142        if arg.is_some() {
143            if iter.next().is_some() {
144                panic!("ChronyMessage variants must have at most a single field.")
145            }
146        }
147
148        let pad = parse_pad_attr(variant.attrs.as_ref());
149        if let Some(cmd) = parse_cmd_attr(variant.attrs.as_ref()) {
150            if index > cmd {
151                panic!("Command numbers must be strictly increasing.")
152            }
153            index = cmd;
154        }
155
156        match arg {
157            Some(field) => {
158                let ty = &field.ty;
159                length.extend(quote! { Self::#ident(_) => <#ty as ::chrony_candm::common::ChronySerialize>::length() + #pad, });
160                cmd.extend(quote! { Self::#ident(_) => #index, });
161                serialize.extend(quote! {
162                    Self::#ident(x) => {
163                        if #pad != 0 {
164                            buf.put_bytes(0, #pad)
165                        }
166                        ::chrony_candm::common::ChronySerialize::serialize(x, buf);
167                    },
168                });
169                deserialize.extend(quote! {
170                    #index => {
171                        if #pad != 0 {
172                            buf.advance(#pad)
173                        }
174                        let body = <#ty as ::chrony_candm::common::ChronySerialize>::deserialize(buf)?;
175                        Ok(Self::#ident(body))
176                    },
177                });
178            }
179            None => {
180                length.extend(quote! { Self::#ident => #pad, });
181                cmd.extend(quote! { Self::#ident => #index, });
182                serialize.extend(quote! {
183                    Self::#ident => {
184                        if #pad != 0 {
185                            buf.put_bytes(0, #pad)
186                        }
187                    },
188                });
189                deserialize.extend(quote! {
190                    #index => {
191                        if #pad != 0 {
192                            buf.advance(#pad)
193                        }
194                        Ok(Self::#ident)
195                    },
196                });
197            }
198        };
199
200        index += 1;
201    }
202
203    let expanded = quote! {
204        impl ::chrony_candm::common::ChronyMessage for #name {
205            fn body_length(&self) -> usize {
206                match self {
207                    #length
208                }
209            }
210
211            fn cmd(&self) -> u16 {
212                match self {
213                    #cmd
214                }
215            }
216
217            fn serialize_body<B: ::bytes::BufMut>(&self, buf: &mut B) {
218                match self {
219                    #serialize
220                }
221            }
222
223            fn deserialize_body<B: ::bytes::Buf>(cmd: u16, buf: &mut B) -> ::std::result::Result<Self, ::chrony_candm::common::DeserializationError> {
224                match cmd {
225                    #deserialize
226                    _ => ::std::result::Result::Err(::chrony_candm::common::DeserializationError::new("unsupported command number"))
227                }
228            }
229
230        }
231    };
232
233    TokenStream::from(expanded)
234}
235
236fn parse_pad_attr(attrs: &[syn::Attribute]) -> usize {
237    for attr in attrs.iter() {
238        if let Ok(syn::Meta::NameValue(meta_namevalue)) = attr.parse_meta() {
239            if meta_namevalue.path.is_ident("pad") {
240                if let syn::Lit::Int(i) = meta_namevalue.lit {
241                    match i.base10_parse() {
242                        Ok(size) => return size,
243                        Err(e) => panic!("{}", e),
244                    }
245                } else {
246                    panic!("Argument to pad attribute must be an integer literal")
247                }
248            }
249        }
250    }
251
252    0
253}
254
255fn parse_cmd_attr(attrs: &[syn::Attribute]) -> Option<u16> {
256    for attr in attrs.iter() {
257        if let Ok(syn::Meta::NameValue(meta_namevalue)) = attr.parse_meta() {
258            if meta_namevalue.path.is_ident("cmd") {
259                if let syn::Lit::Int(i) = meta_namevalue.lit {
260                    match i.base10_parse() {
261                        Ok(cmd) => return Some(cmd),
262                        Err(e) => panic!("{}", e),
263                    }
264                } else {
265                    panic!("Argument to cmd attribute must be an integer literal")
266                }
267            }
268        }
269    }
270
271    None
272}