1use std::str::FromStr;
6
7use darling::{
8 ast::{self, Fields},
9 FromDeriveInput, FromField, FromVariant,
10};
11use proc_macro2::{Literal, TokenStream};
12use quote::quote;
13use syn::{
14 parse_macro_input, parse_quote, Attribute, DeriveInput, Expr, Generics, Ident, Index, Lit, Path,
15};
16
/// The codec operation a derive macro generates code for.
///
/// Each variant maps to one `mls_rs_codec` trait (see `Operation::path`) and one trait
/// method (see `Operation::call`).
enum Operation {
    /// `mls_rs_codec::MlsSize::mls_encoded_len`
    Size,
    /// `mls_rs_codec::MlsEncode::mls_encode`
    Encode,
    /// `mls_rs_codec::MlsDecode::mls_decode`
    Decode,
}
22
23impl Operation {
24 fn path(&self) -> Path {
25 match self {
26 Operation::Size => parse_quote! { mls_rs_codec::MlsSize },
27 Operation::Encode => parse_quote! { mls_rs_codec::MlsEncode },
28 Operation::Decode => parse_quote! { mls_rs_codec::MlsDecode },
29 }
30 }
31
32 fn call(&self) -> TokenStream {
33 match self {
34 Operation::Size => quote! { mls_encoded_len },
35 Operation::Encode => quote! { mls_encode },
36 Operation::Decode => quote! { mls_decode },
37 }
38 }
39
40 fn extras(&self) -> TokenStream {
41 match self {
42 Operation::Size => quote! {},
43 Operation::Encode => quote! { , writer },
44 Operation::Decode => quote! { reader },
45 }
46 }
47
48 fn is_result(&self) -> bool {
49 match self {
50 Operation::Size => false,
51 Operation::Encode => true,
52 Operation::Decode => true,
53 }
54 }
55}
56
/// darling receiver for one field, either of a struct or of an enum variant.
///
/// Accepts the optional `#[mls_codec(with = ...)]` field attribute.
#[derive(Debug, FromField)]
#[darling(attributes(mls_codec))]
struct MlsFieldReceiver {
    /// Field name. `None` for tuple fields, which are then addressed by position.
    ident: Option<Ident>,
    /// Optional path used instead of the `mls_rs_codec` trait path when generating
    /// `path::mls_encode` / `path::mls_decode` / `path::mls_encoded_len` calls for this field.
    with: Option<Path>,
}
63
64impl MlsFieldReceiver {
65 pub fn call_tokens(&self, index: Index) -> TokenStream {
66 if let Some(ref ident) = self.ident {
67 quote! { &self.#ident }
68 } else {
69 quote! { &self.#index }
70 }
71 }
72
73 pub fn name(&self, index: Index) -> TokenStream {
74 if let Some(ref ident) = self.ident {
75 quote! {#ident: }
76 } else {
77 quote! { #index: }
78 }
79 }
80}
81
/// darling receiver for one enum variant.
#[derive(Debug, FromVariant)]
#[darling(attributes(mls_codec))]
struct MlsVariantReceiver {
    /// Variant name.
    ident: Ident,
    /// Explicit `= <expr>` discriminant, if written. Code generation requires it to be present.
    discriminant: Option<Expr>,
    /// The variant's fields. The generator supports at most one.
    fields: ast::Fields<MlsFieldReceiver>,
}
89
/// darling receiver for the whole type a codec trait is derived on.
#[derive(FromDeriveInput)]
#[darling(attributes(mls_codec), forward_attrs(repr))]
struct MlsInputReceiver {
    /// Attributes forwarded by darling. Because of `forward_attrs(repr)`, only `repr` ones.
    attrs: Vec<Attribute>,
    /// Name of the deriving type.
    ident: Ident,
    /// Generics of the deriving type, reused verbatim on the generated impl.
    generics: Generics,
    /// Enum variants or struct fields of the deriving type.
    data: ast::Data<MlsVariantReceiver, MlsFieldReceiver>,
}
98
99impl MlsInputReceiver {
100 fn handle_input(&self, operation: Operation) -> TokenStream {
101 match self.data {
102 ast::Data::Struct(ref s) => struct_impl(s, operation),
103 ast::Data::Enum(ref e) => enum_impl(&self.ident, &self.attrs, e, operation),
104 }
105 }
106}
107
108fn repr_ident(attrs: &[Attribute]) -> Option<Ident> {
109 let repr_path = attrs
110 .iter()
111 .filter(|attr| matches!(attr.style, syn::AttrStyle::Outer))
112 .find(|attr| attr.path().is_ident("repr"))
113 .map(|repr| repr.parse_args())
114 .transpose()
115 .ok()
116 .flatten();
117
118 let Some(Expr::Path(path)) = repr_path else {
119 return None;
120 };
121
122 path.path
123 .segments
124 .iter()
125 .find(|s| s.ident != "C")
126 .map(|path| path.ident.clone())
127}
128
129fn discriminant_for_variant(
132 variant: &MlsVariantReceiver,
133 repr_ident: &Option<Ident>,
134) -> TokenStream {
135 let discriminant = variant
136 .discriminant
137 .clone()
138 .expect("Enum discriminants must be explicitly defined");
139
140 let Expr::Lit(lit_expr) = &discriminant else {
141 return quote! {#discriminant};
142 };
143
144 let Lit::Int(lit_int) = &lit_expr.lit else {
145 return quote! {#discriminant};
146 };
147
148 if lit_int.suffix().is_empty() {
149 let str = format!(
153 "{}{}",
154 lit_int.base10_digits(),
155 &repr_ident.clone().expect("Expected a repr(u*) to be provided or for the variant's discriminant to be defined with suffixed literals.")
156 );
157 Literal::from_str(&str)
158 .map(|l| quote! {#l})
159 .ok()
160 .unwrap_or_else(|| quote! {#discriminant})
161 } else {
162 quote! {#discriminant}
163 }
164}
165
166fn enum_impl(
167 ident: &Ident,
168 attrs: &[Attribute],
169 variants: &[MlsVariantReceiver],
170 operation: Operation,
171) -> TokenStream {
172 let handle_error = operation.is_result().then_some(quote! { ? });
173 let path = operation.path();
174 let call = operation.call();
175 let extras = operation.extras();
176 let enum_name = &ident;
177 let repr_ident = repr_ident(attrs);
178 if matches!(operation, Operation::Decode) {
179 let cases = variants.iter().map(|variant| {
180 let variant_name = &variant.ident;
181
182 let discriminant = discriminant_for_variant(variant, &repr_ident);
183
184 match variant.fields.len() {
186 0 => quote! { #discriminant => Ok(#enum_name::#variant_name), },
187 1 =>{
188 let path = variant.fields.fields[0].with.as_ref().unwrap_or(&path);
189 quote! { #discriminant => Ok(#enum_name::#variant_name(#path::#call(#extras) #handle_error)), }
190 },
191 _ => panic!("Enum discriminants with more than 1 field are not currently supported")
192 }
193 });
194
195 return quote! {
196 let discriminant = #path::#call(#extras)#handle_error;
197
198 match discriminant {
199 #(#cases)*
200 _ => Err(mls_rs_codec::Error::UnsupportedEnumDiscriminant),
201 }
202 };
203 }
204
205 let cases = variants.iter().map(|variant| {
206 let variant_name = &variant.ident;
207
208 let discriminant = discriminant_for_variant(variant, &repr_ident);
209
210 let (parameter, field) = if variant.fields.is_empty() {
211 (None, None)
212 } else {
213 let path = variant.fields.fields[0].with.as_ref().unwrap_or(&path);
214
215 let start = match operation {
216 Operation::Size => Some(quote! { + }),
217 Operation::Encode => Some(quote! {;}),
218 Operation::Decode => None,
219 };
220
221 (
222 Some(quote! {(ref val)}),
223 Some(quote! { #start #path::#call (val #extras) #handle_error }),
224 )
225 };
226
227 let discrim = quote! { #path::#call (&#discriminant #extras) #handle_error };
228
229 quote! { #enum_name::#variant_name #parameter => { #discrim #field }}
230 });
231
232 let enum_impl = quote! {
233 match self {
234 #(#cases)*
235 }
236 };
237
238 if operation.is_result() {
239 quote! {
240 Ok(#enum_impl)
241 }
242 } else {
243 enum_impl
244 }
245}
246
247fn struct_impl(s: &Fields<MlsFieldReceiver>, operation: Operation) -> TokenStream {
248 let recurse = s.fields.iter().enumerate().map(|(index, field)| {
249 let (call_tokens, field_name) = match operation {
250 Operation::Size | Operation::Encode => {
251 (field.call_tokens(Index::from(index)), quote! {})
252 }
253 Operation::Decode => (quote! {}, field.name(Index::from(index))),
254 };
255
256 let handle_error = operation.is_result().then_some(quote! { ? });
257 let path = field.with.clone().unwrap_or(operation.path());
258 let call = operation.call();
259 let extras = operation.extras();
260
261 quote! {
262 #field_name #path::#call (#call_tokens #extras) #handle_error
263 }
264 });
265
266 match operation {
267 Operation::Size => quote! { 0 #(+ #recurse)* },
268 Operation::Encode => quote! { #(#recurse;)* Ok(()) },
269 Operation::Decode => quote! { Ok(Self { #(#recurse,)* }) },
270 }
271}
272
273fn derive_impl<F>(
274 input: proc_macro::TokenStream,
275 trait_name: TokenStream,
276 function_def: TokenStream,
277 internals: F,
278) -> proc_macro::TokenStream
279where
280 F: FnOnce(&MlsInputReceiver) -> TokenStream,
281{
282 let input = parse_macro_input!(input as DeriveInput);
283
284 let input = MlsInputReceiver::from_derive_input(&input).unwrap();
285
286 let name = &input.ident;
287
288 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
289
290 let function_impl = internals(&input);
292
293 let expanded = quote! {
294 impl #impl_generics #trait_name for #name #ty_generics #where_clause {
296 #function_def {
297 #function_impl
298 }
299 }
300 };
301
302 proc_macro::TokenStream::from(expanded)
304}
305
306#[proc_macro_derive(MlsSize, attributes(mls_codec))]
307pub fn derive_size(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
308 let trait_name = quote! { mls_rs_codec::MlsSize };
309 let function_def = quote! {fn mls_encoded_len(&self) -> usize };
310
311 derive_impl(input, trait_name, function_def, |input| {
312 input.handle_input(Operation::Size)
313 })
314}
315
316#[proc_macro_derive(MlsEncode, attributes(mls_codec))]
317pub fn derive_encode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
318 let trait_name = quote! { mls_rs_codec::MlsEncode };
319
320 let function_def = quote! { fn mls_encode(&self, writer: &mut mls_rs_codec::Vec<u8>) -> Result<(), mls_rs_codec::Error> };
321
322 derive_impl(input, trait_name, function_def, |input| {
323 input.handle_input(Operation::Encode)
324 })
325}
326
327#[proc_macro_derive(MlsDecode, attributes(mls_codec))]
328pub fn derive_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
329 let trait_name = quote! { mls_rs_codec::MlsDecode };
330
331 let function_def =
332 quote! { fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> };
333
334 derive_impl(input, trait_name, function_def, |input| {
335 input.handle_input(Operation::Decode)
336 })
337}