mls_rs_codec_derive/
lib.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use std::str::FromStr;
6
7use darling::{
8    ast::{self, Fields},
9    FromDeriveInput, FromField, FromVariant,
10};
11use proc_macro2::{Literal, TokenStream};
12use quote::quote;
13use syn::{
14    parse_macro_input, parse_quote, Attribute, DeriveInput, Expr, Generics, Ident, Index, Lit, Path,
15};
16
17enum Operation {
18    Size,
19    Encode,
20    Decode,
21}
22
23impl Operation {
24    fn path(&self) -> Path {
25        match self {
26            Operation::Size => parse_quote! { mls_rs_codec::MlsSize },
27            Operation::Encode => parse_quote! { mls_rs_codec::MlsEncode },
28            Operation::Decode => parse_quote! { mls_rs_codec::MlsDecode },
29        }
30    }
31
32    fn call(&self) -> TokenStream {
33        match self {
34            Operation::Size => quote! { mls_encoded_len },
35            Operation::Encode => quote! { mls_encode },
36            Operation::Decode => quote! { mls_decode },
37        }
38    }
39
40    fn extras(&self) -> TokenStream {
41        match self {
42            Operation::Size => quote! {},
43            Operation::Encode => quote! { , writer },
44            Operation::Decode => quote! { reader },
45        }
46    }
47
48    fn is_result(&self) -> bool {
49        match self {
50            Operation::Size => false,
51            Operation::Encode => true,
52            Operation::Decode => true,
53        }
54    }
55}
56
57#[derive(Debug, FromField)]
58#[darling(attributes(mls_codec))]
59struct MlsFieldReceiver {
60    ident: Option<Ident>,
61    with: Option<Path>,
62}
63
64impl MlsFieldReceiver {
65    pub fn call_tokens(&self, index: Index) -> TokenStream {
66        if let Some(ref ident) = self.ident {
67            quote! { &self.#ident }
68        } else {
69            quote! { &self.#index }
70        }
71    }
72
73    pub fn name(&self, index: Index) -> TokenStream {
74        if let Some(ref ident) = self.ident {
75            quote! {#ident: }
76        } else {
77            quote! { #index: }
78        }
79    }
80}
81
82#[derive(Debug, FromVariant)]
83#[darling(attributes(mls_codec))]
84struct MlsVariantReceiver {
85    ident: Ident,
86    discriminant: Option<Expr>,
87    fields: ast::Fields<MlsFieldReceiver>,
88}
89
90#[derive(FromDeriveInput)]
91#[darling(attributes(mls_codec), forward_attrs(repr))]
92struct MlsInputReceiver {
93    attrs: Vec<Attribute>,
94    ident: Ident,
95    generics: Generics,
96    data: ast::Data<MlsVariantReceiver, MlsFieldReceiver>,
97}
98
99impl MlsInputReceiver {
100    fn handle_input(&self, operation: Operation) -> TokenStream {
101        match self.data {
102            ast::Data::Struct(ref s) => struct_impl(s, operation),
103            ast::Data::Enum(ref e) => enum_impl(&self.ident, &self.attrs, e, operation),
104        }
105    }
106}
107
108fn repr_ident(attrs: &[Attribute]) -> Option<Ident> {
109    let repr_path = attrs
110        .iter()
111        .filter(|attr| matches!(attr.style, syn::AttrStyle::Outer))
112        .find(|attr| attr.path().is_ident("repr"))
113        .map(|repr| repr.parse_args())
114        .transpose()
115        .ok()
116        .flatten();
117
118    let Some(Expr::Path(path)) = repr_path else {
119        return None;
120    };
121
122    path.path
123        .segments
124        .iter()
125        .find(|s| s.ident != "C")
126        .map(|path| path.ident.clone())
127}
128
129/// Provides the discriminant for a given variant. If the variant does not specify a suffix
130/// and a `repr_ident` is provided, it will be appended to number.
131fn discriminant_for_variant(
132    variant: &MlsVariantReceiver,
133    repr_ident: &Option<Ident>,
134) -> TokenStream {
135    let discriminant = variant
136        .discriminant
137        .clone()
138        .expect("Enum discriminants must be explicitly defined");
139
140    let Expr::Lit(lit_expr) = &discriminant else {
141        return quote! {#discriminant};
142    };
143
144    let Lit::Int(lit_int) = &lit_expr.lit else {
145        return quote! {#discriminant};
146    };
147
148    if lit_int.suffix().is_empty() {
149        // This is dirty and there is probably a better way of doing this but I'm way too much of a noob at
150        // proc macros to pull it off...
151        // TODO: Add proper support for correctly ignoring transparent, packed and modifiers
152        let str = format!(
153            "{}{}",
154            lit_int.base10_digits(),
155            &repr_ident.clone().expect("Expected a repr(u*) to be provided or for the variant's discriminant to be defined with suffixed literals.")
156        );
157        Literal::from_str(&str)
158            .map(|l| quote! {#l})
159            .ok()
160            .unwrap_or_else(|| quote! {#discriminant})
161    } else {
162        quote! {#discriminant}
163    }
164}
165
166fn enum_impl(
167    ident: &Ident,
168    attrs: &[Attribute],
169    variants: &[MlsVariantReceiver],
170    operation: Operation,
171) -> TokenStream {
172    let handle_error = operation.is_result().then_some(quote! { ? });
173    let path = operation.path();
174    let call = operation.call();
175    let extras = operation.extras();
176    let enum_name = &ident;
177    let repr_ident = repr_ident(attrs);
178    if matches!(operation, Operation::Decode) {
179        let cases = variants.iter().map(|variant| {
180            let variant_name = &variant.ident;
181
182            let discriminant = discriminant_for_variant(variant, &repr_ident);
183
184            // TODO: Support more than 1 field
185            match variant.fields.len() {
186                0 => quote! { #discriminant => Ok(#enum_name::#variant_name), },
187                1 =>{
188                    let path = variant.fields.fields[0].with.as_ref().unwrap_or(&path);
189                    quote! { #discriminant => Ok(#enum_name::#variant_name(#path::#call(#extras) #handle_error)), }
190                },
191                _ => panic!("Enum discriminants with more than 1 field are not currently supported")
192            }
193        });
194
195        return quote! {
196            let discriminant = #path::#call(#extras)#handle_error;
197
198            match discriminant {
199                #(#cases)*
200                _ => Err(mls_rs_codec::Error::UnsupportedEnumDiscriminant),
201            }
202        };
203    }
204
205    let cases = variants.iter().map(|variant| {
206        let variant_name = &variant.ident;
207
208        let discriminant = discriminant_for_variant(variant, &repr_ident);
209
210        let (parameter, field) = if variant.fields.is_empty() {
211            (None, None)
212        } else {
213            let path = variant.fields.fields[0].with.as_ref().unwrap_or(&path);
214
215            let start = match operation {
216                Operation::Size => Some(quote! { + }),
217                Operation::Encode => Some(quote! {;}),
218                Operation::Decode => None,
219            };
220
221            (
222                Some(quote! {(ref val)}),
223                Some(quote! { #start #path::#call (val #extras) #handle_error }),
224            )
225        };
226
227        let discrim = quote! { #path::#call (&#discriminant #extras) #handle_error };
228
229        quote! { #enum_name::#variant_name #parameter => { #discrim #field }}
230    });
231
232    let enum_impl = quote! {
233        match self {
234            #(#cases)*
235        }
236    };
237
238    if operation.is_result() {
239        quote! {
240            Ok(#enum_impl)
241        }
242    } else {
243        enum_impl
244    }
245}
246
247fn struct_impl(s: &Fields<MlsFieldReceiver>, operation: Operation) -> TokenStream {
248    let recurse = s.fields.iter().enumerate().map(|(index, field)| {
249        let (call_tokens, field_name) = match operation {
250            Operation::Size | Operation::Encode => {
251                (field.call_tokens(Index::from(index)), quote! {})
252            }
253            Operation::Decode => (quote! {}, field.name(Index::from(index))),
254        };
255
256        let handle_error = operation.is_result().then_some(quote! { ? });
257        let path = field.with.clone().unwrap_or(operation.path());
258        let call = operation.call();
259        let extras = operation.extras();
260
261        quote! {
262           #field_name #path::#call (#call_tokens #extras) #handle_error
263        }
264    });
265
266    match operation {
267        Operation::Size => quote! { 0 #(+ #recurse)* },
268        Operation::Encode => quote! { #(#recurse;)* Ok(()) },
269        Operation::Decode => quote! { Ok(Self { #(#recurse,)* }) },
270    }
271}
272
273fn derive_impl<F>(
274    input: proc_macro::TokenStream,
275    trait_name: TokenStream,
276    function_def: TokenStream,
277    internals: F,
278) -> proc_macro::TokenStream
279where
280    F: FnOnce(&MlsInputReceiver) -> TokenStream,
281{
282    let input = parse_macro_input!(input as DeriveInput);
283
284    let input = MlsInputReceiver::from_derive_input(&input).unwrap();
285
286    let name = &input.ident;
287
288    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
289
290    // Generate an expression to sum up the heap size of each field.
291    let function_impl = internals(&input);
292
293    let expanded = quote! {
294        // The generated impl.
295        impl #impl_generics #trait_name for #name #ty_generics #where_clause {
296            #function_def {
297                #function_impl
298            }
299        }
300    };
301
302    // Hand the output tokens back to the compiler.
303    proc_macro::TokenStream::from(expanded)
304}
305
306#[proc_macro_derive(MlsSize, attributes(mls_codec))]
307pub fn derive_size(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
308    let trait_name = quote! { mls_rs_codec::MlsSize };
309    let function_def = quote! {fn mls_encoded_len(&self) -> usize };
310
311    derive_impl(input, trait_name, function_def, |input| {
312        input.handle_input(Operation::Size)
313    })
314}
315
316#[proc_macro_derive(MlsEncode, attributes(mls_codec))]
317pub fn derive_encode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
318    let trait_name = quote! { mls_rs_codec::MlsEncode };
319
320    let function_def = quote! { fn mls_encode(&self, writer: &mut mls_rs_codec::Vec<u8>) -> Result<(), mls_rs_codec::Error> };
321
322    derive_impl(input, trait_name, function_def, |input| {
323        input.handle_input(Operation::Encode)
324    })
325}
326
327#[proc_macro_derive(MlsDecode, attributes(mls_codec))]
328pub fn derive_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
329    let trait_name = quote! { mls_rs_codec::MlsDecode };
330
331    let function_def =
332        quote! { fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> };
333
334    derive_impl(input, trait_name, function_def, |input| {
335        input.handle_input(Operation::Decode)
336    })
337}