Skip to main content

linux_cec_macros/
lib.rs

1/*
2 * Copyright © 2024 Valve Software
3 * SPDX-License-Identifier: LGPL-2.1-or-later
4 */
5
6use heck::AsSnakeCase;
7use proc_macro::TokenStream;
8use proc_macro2::{Punct, TokenStream as TokenStream2};
9use quote::{format_ident, quote};
10use std::collections::HashSet;
11use syn::parse::{self, Parse, ParseStream};
12use syn::{
13    parse_macro_input, parse_str, Data, DataEnum, DeriveInput, Expr, ExprArray, ExprLit, ExprPath,
14    Field, Fields, FieldsUnnamed, Ident, Lit, LitInt, Meta, Type, TypeArray,
15};
16
/// Abort the enclosing proc-macro function by returning a `compile_error!`
/// invocation as its output.
///
/// Accepted forms:
/// * `bail!("literal")` — the literal is emitted verbatim as the message
///   (it is *not* run through `format!`).
/// * `bail!("format {}", ident, ...)` — the literal and the listed
///   identifiers are passed to `format!` to build the message.
/// * `bail!(expr)` — the value of `expr` is interpolated by `quote!` as the
///   message.
macro_rules! bail {
    ($msg:literal) => {
        return quote! {
            compile_error!($msg);
        }
        .into()
    };
    ($fmt:literal $(, $arg:ident)*) => {{
        let message = format!($fmt $(, $arg)*);
        return quote! {
            compile_error!(#message);
        }
        .into()
    }};
    ($msg:expr) => {{
        let message = $msg;
        return quote! {
            compile_error!(#message);
        }
        .into()
    }};
}
39
/// Working state for the `MessageEnum` derive.
///
/// `add_message` pushes generated fragments into the vectors below for each
/// enum variant, and `process` splices them into the final output.
struct MessageEnum {
    /// Identifier of the enum the derive was applied to.
    message: Ident,
    /// Identifier of the opcode enum emitted by `process`.
    opcode: Ident,
    /// Match arms for the generated `try_from_bytes`.
    from_bytes: Vec<TokenStream2>,
    /// Match arms for the generated `to_bytes`.
    to_bytes: Vec<TokenStream2>,
    /// Match arms for the generated `len`.
    len: Vec<TokenStream2>,
    /// Match arms for the generated `expected_len`.
    expected_len: Vec<TokenStream2>,
    /// Set when at least one variant carries an `addressing = "..."`
    /// attribute; gates emission of the `addressing_type` method.
    has_addressing_type: bool,
    /// Match arms for the generated `addressing_type`.
    addressing_type: Vec<TokenStream2>,
    /// Generated `#[cfg(test)]` modules (pushed for unit variants).
    tests: Vec<TokenStream2>,
}
51
52impl MessageEnum {
53    fn add_message(
54        &mut self,
55        ident: &Ident,
56        fields: Fields,
57        addressing: &Ident,
58    ) -> Result<(), String> {
59        let message = &self.message;
60        let opcode = &self.opcode;
61        let mut from_params = Vec::new();
62        let mut names = Vec::new();
63
64        let testname: Ident = format_ident!("test_{}", AsSnakeCase(ident.to_string()).to_string());
65
66        self.addressing_type
67            .push(quote!(#opcode::#ident => AddressingType::#addressing));
68
69        match fields {
70            Fields::Named(_) => {
71                let mut sizes = Vec::new();
72                let mut params = Vec::new();
73                let mut types = Vec::new();
74                for field in fields {
75                    let Some(name) = field.ident else {
76                        return Err(format!("Variant {ident} cannot have unnamed fields"));
77                    };
78                    let typename = field.ty;
79
80                    match typename {
81                        Type::Path(ref path) if path.path.get_ident().is_none() => {
82                            sizes.push(quote!(#name.len()));
83                        }
84                        _ => sizes.push(quote!(::core::mem::size_of::<#typename>())),
85                    }
86
87                    params.push(quote! {
88                        crate::operand::OperandEncodable::to_bytes(#name, &mut out_params);
89                    });
90                    from_params.push(quote! {
91                        let #name = <#typename as crate::operand::OperandEncodable>::try_from_bytes(&bytes[offset..])
92                        .map_err(crate::Error::add_offset(offset))?;
93
94                        let offset = offset + #name.len();
95                    });
96
97                    names.push(name);
98                    types.push(typename);
99                }
100
101                self.to_bytes.push(quote! {
102                    #message::#ident { #(#names,)* } => {
103                        let mut out_params = vec![#opcode::#ident as u8];
104
105                        #(#params)*
106
107                        out_params
108                    }
109                });
110
111                self.from_bytes.push(quote! {
112                    #opcode::#ident => {
113                        let offset = 1;
114
115                        #(#from_params)*
116
117                        #message::#ident {
118                            #(#names),*
119                        }
120                    }
121                });
122
123                self.len.push(quote! {
124                    #message::#ident { #(#names,)* } => {
125                        #(let _ = #names;)*
126                        1#( + #sizes)*
127                    }
128                });
129
130                self.expected_len.push(quote! {
131                    #opcode::#ident => {
132                        [#(<#types as crate::operand::OperandEncodable>::expected_len()),*]
133                            .into_iter()
134                            .fold(crate::Range::AtLeast(1), |accum, new| {
135                                match (accum, new) {
136                                    (crate::Range::AtLeast(x), crate::Range::AtLeast(y)) =>
137                                        crate::Range::AtLeast(x + y),
138                                    (crate::Range::AtLeast(x), crate::Range::Only(ys)) =>
139                                        crate::Range::Only(ys.into_iter().map(|y| x + y).collect()),
140                                    (crate::Range::AtLeast(_), y) => todo!("Unimplemented opcode length: {y:?}"),
141                                    (x, _) => todo!("Unimplemented opcode following length: {x:?}"),
142                                }
143                            })
144                    }
145                });
146            }
147            Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
148                if unnamed.len() > 1 {
149                    return Err(format!(
150                        "Variant {ident} cannot have more than one unnamed fields"
151                    ));
152                }
153                let field = unnamed.first().unwrap();
154                let typename = &field.ty;
155
156                self.to_bytes.push(quote! {
157                    #message::#ident(ref x) => {
158                        let mut out_params = vec![#opcode::#ident as u8];
159                        crate::operand::OperandEncodable::to_bytes(x, &mut out_params);
160                        out_params
161                    }
162                });
163
164                self.from_bytes.push(quote! {
165                    #opcode::#ident => {
166                        let x = <#typename as crate::operand::OperandEncodable>::try_from_bytes(&bytes[1..])
167                        .map_err(crate::Error::add_offset(1))?;
168                        #message::#ident(x)
169                    }
170                });
171
172                self.len.push(quote! {
173                    #message::#ident(ref x) => {
174                        1 + <_ as crate::operand::OperandEncodable>::len(x)
175                    }
176                });
177
178                self.expected_len.push(quote! {
179                    #opcode::#ident => {
180                        <#typename as crate::operand::OperandEncodable>::expected_len() + 1
181                    }
182                });
183            }
184            Fields::Unit => {
185                self.to_bytes
186                    .push(quote!(#message::#ident => vec![#opcode::#ident as u8]));
187
188                self.from_bytes.push(quote! {
189                    #opcode::#ident => {
190                        let offset = 1;
191
192                        #(#from_params)*
193
194                        #message::#ident {
195                            #(#names),*
196                        }
197                    }
198                });
199
200                self.len.push(quote!(#message::#ident => 1));
201
202                self.expected_len
203                    .push(quote!(#opcode::#ident => crate::Range::AtLeast(1)));
204
205                self.tests.push(quote! {
206                    #[cfg(test)]
207                    mod #testname {
208                        use super::*;
209
210                        #[test]
211                        fn test_len() {
212                            assert_eq!(#message::#ident {}.len(), 1);
213                        }
214
215                        #[test]
216                        fn test_opcode() {
217                            assert_eq!(
218                                #message::#ident {}.opcode(),
219                                #opcode::#ident
220                            );
221                        }
222
223                        #[test]
224                        fn test_encoding() {
225                            assert_eq!(
226                                &#message::#ident {}.to_bytes(),
227                                &[#opcode::#ident as u8]
228                            );
229                        }
230
231                        #[test]
232                        fn test_decoding() {
233                            assert_eq!(
234                                #message::try_from_bytes(&[#opcode::#ident as u8]),
235                                Ok(#message::#ident {})
236                            );
237                        }
238
239                        #[test]
240                        fn test_decoding_overfull() {
241                            assert_eq!(
242                                #message::try_from_bytes(&[
243                                    #opcode::#ident as u8,
244                                    0,
245                                    0,
246                                    0,
247                                    0,
248                                    0,
249                                    0,
250                                    0,
251                                    0,
252                                    0,
253                                    0,
254                                    0,
255                                    0,
256                                    0
257                                ]),
258                                Ok(#message::#ident {})
259                            );
260                        }
261                    }
262                });
263            }
264        }
265        Ok(())
266    }
267
268    fn process(mut self, data: DataEnum) -> Result<TokenStream2, String> {
269        let mut opcodes = Vec::new();
270
271        let addressing = format_ident!("addressing");
272        let broadcast = format_ident!("Broadcast");
273        let direct = format_ident!("Direct");
274        let either = format_ident!("Either");
275
276        for variant in data.variants {
277            let ident = variant.ident;
278            let Some((_, discriminant)) = variant.discriminant else {
279                return Err(format!("Variant {ident} missing discriminant"));
280            };
281            opcodes.push(quote!(#ident = #discriminant));
282
283            let mut addressing_type = &direct;
284            for attr in variant.attrs {
285                let Meta::NameValue(meta) = attr.meta else {
286                    continue;
287                };
288                if meta.path.get_ident() != Some(&addressing) {
289                    continue;
290                }
291                let Expr::Lit(ExprLit {
292                    lit: Lit::Str(lit), ..
293                }) = meta.value
294                else {
295                    return Err(format!("Invalid addressing type for {ident}"));
296                };
297
298                self.has_addressing_type = true;
299                addressing_type = match lit.value().as_str() {
300                    "direct" => &direct,
301                    "broadcast" => &broadcast,
302                    "either" => &either,
303                    _ => return Err(format!("Unknown addressing type `{}`", lit.value())),
304                };
305            }
306
307            self.add_message(&ident, variant.fields, addressing_type)?;
308        }
309
310        let message = self.message;
311        let opcode = self.opcode;
312        let from_bytes = self.from_bytes;
313        let to_bytes = self.to_bytes;
314        let len = self.len;
315        let expected_len = self.expected_len;
316        let addressing_type = self.addressing_type;
317        let tests = self.tests;
318
319        let addressing_type = if self.has_addressing_type {
320            Some(quote! {
321
322                /// Get the [`AddressingType`] that reports whether this `#opcode` can be
323                /// addressed directly to a specific logical address, broadcast to all
324                /// logical addresses, or both.
325                impl #opcode {
326                    pub fn addressing_type(&self) -> AddressingType {
327                        match self {
328                            #(#addressing_type,)*
329                        }
330                    }
331                }
332            })
333        } else {
334            None
335        };
336
337        Ok(quote! {
338            #[derive(
339                Debug, Copy, Clone, PartialEq, Eq, Hash, IntoPrimitive, TryFromPrimitive, Operand,
340            )]
341            #[repr(u8)]
342            pub enum #opcode {
343                #(#opcodes,)*
344            }
345
346            impl PartialEq<u8> for #opcode {
347                fn eq(&self, rhs: &u8) -> bool {
348                    *self as u8 == *rhs
349                }
350            }
351
352            #addressing_type
353
354            impl #message {
355                pub fn try_from_bytes(bytes: &[u8]) -> Result<#message> {
356                    if bytes.is_empty() {
357                        return Err(crate::Error::OutOfRange {
358                            expected: crate::Range::AtLeast(1),
359                            got: 0,
360                            quantity: "bytes",
361                        })
362                    }
363                    let opcode = #opcode::try_from_primitive(bytes[0])?;
364                    #message::expected_len(opcode).check(bytes.len(), "bytes")?;
365                    Ok(match opcode {
366                        #(#from_bytes)*
367                    })
368                }
369
370                pub fn to_bytes(&self) -> Vec<u8> {
371                    match self {
372                        #(#to_bytes,)*
373                    }
374                }
375
376                pub fn len(&self) -> usize {
377                    match self {
378                        #(#len,)*
379                    }
380                }
381
382                pub fn expected_len(opcode: #opcode) -> crate::Range<usize> {
383                    match opcode {
384                        #(#expected_len,)*
385                    }
386                }
387            }
388
389            #(#tests)*
390        })
391    }
392}
393
394#[proc_macro_derive(MessageEnum, attributes(addressing))]
395pub fn message_enum(input: TokenStream) -> TokenStream {
396    let DeriveInput {
397        ident: message,
398        data: Data::Enum(data),
399        ..
400    } = parse_macro_input!(input as DeriveInput)
401    else {
402        bail!("This macro only works on the Message enum");
403    };
404    let opcode: Ident = parse_str(match &message {
405        x if x == "Message" => "Opcode",
406        _ => bail!("This macro only works on the Message enum"),
407    })
408    .unwrap();
409
410    let work = MessageEnum {
411        message,
412        opcode,
413        from_bytes: Vec::new(),
414        to_bytes: Vec::new(),
415        len: Vec::new(),
416        expected_len: Vec::new(),
417        has_addressing_type: false,
418        addressing_type: Vec::new(),
419        tests: Vec::new(),
420    };
421
422    match work.process(data) {
423        Ok(tokens) => tokens.into(),
424        Err(error) => bail!(error),
425    }
426}
427
428fn bits_u8_encodable(ident: &Ident) -> TokenStream {
429    quote! {
430        impl crate::operand::OperandEncodable for #ident {
431            fn to_bytes(&self, buf: &mut impl Extend<u8>) {
432                let prim: u8 = self.bits();
433                <u8 as crate::operand::OperandEncodable>::to_bytes(&prim, buf);
434            }
435
436            fn try_from_bytes(bytes: &[u8]) -> crate::Result<Self> {
437                Self::expected_len().check(bytes.len(), "bytes")?;
438                Ok(#ident::from_bits_retain(bytes[0]))
439            }
440
441            fn len(&self) -> usize {
442                1
443            }
444
445            fn expected_len() -> crate::Range<usize> {
446                crate::Range::AtLeast(1)
447            }
448        }
449    }
450    .into()
451}
452
453fn try_into_u8_encodable(ident: &Ident) -> TokenStream {
454    quote! {
455        impl crate::operand::OperandEncodable for #ident {
456            fn to_bytes(&self, buf: &mut impl Extend<u8>) {
457                let prim = <Self as Into<u8>>::into(*self);
458                <u8 as crate::operand::OperandEncodable>::to_bytes(&prim, buf);
459            }
460
461            fn try_from_bytes(bytes: &[u8]) -> crate::Result<Self> {
462                Self::expected_len().check(bytes.len(), "bytes")?;
463                Ok(#ident::try_from(bytes[0])?)
464            }
465
466            fn len(&self) -> usize {
467                1
468            }
469
470            fn expected_len() -> crate::Range<usize> {
471                crate::Range::AtLeast(1)
472            }
473        }
474    }
475    .into()
476}
477
478fn into_u8_encodable(ident: &Ident) -> TokenStream {
479    quote! {
480        impl crate::operand::OperandEncodable for #ident {
481            fn to_bytes(&self, buf: &mut impl Extend<u8>) {
482                let prim = <Self as Into<u8>>::into(*self);
483                <u8 as crate::operand::OperandEncodable>::to_bytes(&prim, buf);
484            }
485
486            fn try_from_bytes(bytes: &[u8]) -> crate::Result<Self> {
487                if bytes.is_empty() {
488                    Err(crate::Error::OutOfRange {
489                        expected: crate::Range::AtLeast(1),
490                        got: bytes.len(),
491                        quantity: "bytes",
492                    })
493                } else {
494                    Ok(#ident::from(bytes[0]))
495                }
496            }
497
498            fn len(&self) -> usize {
499                1
500            }
501
502            fn expected_len() -> crate::Range<usize> {
503                crate::Range::AtLeast(1)
504            }
505        }
506    }
507    .into()
508}
509
510#[proc_macro_derive(Operand)]
511pub fn operand(input: TokenStream) -> TokenStream {
512    let DeriveInput { ident, data, .. } = parse_macro_input!(input as DeriveInput);
513
514    match data {
515        Data::Enum(_) => try_into_u8_encodable(&ident),
516        Data::Struct(data) => match data.fields {
517            Fields::Named(_) => {
518                let mut to = Vec::new();
519                let mut from = Vec::new();
520                let mut len = Vec::new();
521                let mut fields = Vec::new();
522                for field in data.fields {
523                    let Some(name) = field.ident else {
524                        todo!("Operand field has no name: {field:#?}");
525                    };
526                    to.push(quote! {
527                        self.#name.to_bytes(buf);
528                    });
529                    let typename = field.ty;
530                    match typename {
531                        Type::Path(_) => from.push(quote! {
532                            let #name = <#typename as OperandEncodable>::try_from_bytes(&bytes[offset..])
533                            .map_err(crate::Error::add_offset(offset))?;
534
535                            let offset = offset + #name.len();
536                        }),
537                        Type::Array(_) => (),
538                        _ => todo!("Unimplemented named operand type: {typename:#?}"),
539                    }
540                    fields.push(name);
541                    len.push(quote!(::core::mem::size_of::<#typename>()));
542                }
543                let q = quote! {
544                    impl crate::operand::OperandEncodable for #ident {
545                        fn to_bytes(&self, buf: &mut impl Extend<u8>) {
546                            #(#to)*
547                        }
548
549                        fn try_from_bytes(bytes: &[u8]) -> crate::Result<Self> {
550                            Self::expected_len().check(bytes.len(), "bytes")?;
551                            let mut offset = 0;
552                            #(#from)*
553                            Ok(Self {
554                                #(#fields),*
555                            })
556                        }
557
558                        fn len(&self) -> usize {
559                            #(#len)+*
560                        }
561
562                        fn expected_len() -> crate::Range<usize> {
563                            crate::Range::AtLeast(::core::mem::size_of::<#ident>())
564                        }
565                    }
566                };
567                q.into()
568            }
569            Fields::Unnamed(data) => match data.unnamed.first() {
570                Some(Field {
571                    ty: Type::Path(ty), ..
572                }) => {
573                    if ty.qself.is_some() {
574                        bits_u8_encodable(&ident)
575                    } else {
576                        into_u8_encodable(&ident)
577                    }
578                }
579                Some(Field {
580                    ty: Type::Array(TypeArray { elem, len, .. }),
581                    ..
582                }) => quote! {
583                    impl crate::operand::OperandEncodable for #ident {
584                        fn to_bytes(&self, buf: &mut impl Extend<u8>) {
585                            <[#elem; #len] as crate::operand::OperandEncodable>::to_bytes(&self.0, buf);
586                        }
587
588                        fn try_from_bytes(bytes: &[u8]) -> crate::Result<Self> {
589                            Self::expected_len().check(bytes.len(), "bytes")?;
590                            let buf = bytes[..#len].first_chunk::<#len>();
591                            Ok(#ident(*buf.unwrap()))
592                        }
593
594                        fn len(&self) -> usize {
595                            #len
596                        }
597
598                        fn expected_len() -> crate::Range<usize> {
599                            crate::Range::AtLeast(#len)
600                        }
601                    }
602                }
603                .into(),
604                _ => todo!("Unimplemented unnamed field operand type: {data:#?}"),
605            },
606            Fields::Unit => todo!("Unimplemented unit field operand type: {data:#?}"),
607        },
608        _ => todo!("Unimplemented operand type: {data:#?}"),
609    }
610}
611
612#[proc_macro_derive(BitfieldSpecifier, attributes(bits, default))]
613pub fn bitfield_specifier(input: TokenStream) -> TokenStream {
614    let DeriveInput {
615        attrs,
616        ident,
617        data: Data::Enum(data),
618        ..
619    } = parse_macro_input!(input as DeriveInput)
620    else {
621        bail!("This macro only works on enums");
622    };
623
624    let mut ty: Option<Type> = None;
625    let mut bits: Option<LitInt> = None;
626    let mut into_patterns = Vec::new();
627    let mut from_patterns = Vec::new();
628    let mut default = None;
629
630    // Scan enum attrs for #[repr(..)] and #[bits = ..]
631    // Reject invalid repr attributes and ignore all else
632    for attr in attrs {
633        match attr.meta {
634            Meta::List(list) => {
635                match list.path.get_ident() {
636                    Some(ident) if ident == "repr" => (),
637                    _ => continue,
638                }
639                match list.parse_args() {
640                    Ok(parsed_ty) => ty = Some(parsed_ty),
641                    Err(e) => {
642                        let e = e.to_string();
643                        bail!("Invalid repr: {}", e);
644                    }
645                }
646            }
647            Meta::NameValue(nv) => {
648                match nv.path.get_ident() {
649                    Some(ident) if ident == "bits" => (),
650                    _ => continue,
651                }
652                bits = match nv.value {
653                    Expr::Lit(ExprLit {
654                        lit: Lit::Int(lit), ..
655                    }) => Some(lit),
656                    _ => bail!("`bits` must be an integer literal"),
657                };
658            }
659            _ => continue,
660        }
661    }
662    let Some(ty) = ty else {
663        bail!("Type repr is required");
664    };
665    let Some(bits) = bits else {
666        bail!("Bits attribute is required");
667    };
668
669    for variant in &data.variants {
670        let var_ident = &variant.ident;
671        match &variant.fields {
672            Fields::Unit => (),
673            Fields::Unnamed(fields) => {
674                for attr in &variant.attrs {
675                    let Meta::Path(ref path) = attr.meta else {
676                        continue;
677                    };
678                    match fields.unnamed.first() {
679                        Some(field) if ty == field.ty => (),
680                        Some(_) => bail!("Default must have type matching repr"),
681                        _ => continue,
682                    }
683                    match path.get_ident() {
684                        Some(attr_ident) if attr_ident == "default" => default = Some(var_ident),
685                        _ => (),
686                    }
687                }
688                if fields.unnamed.len() != 1 || default.is_none() {
689                    bail!("Variant contains fields, which is unsupported");
690                }
691                continue;
692            }
693            _ => bail!("Variant contains fields, which is unsupported"),
694        }
695        let Some((_, ref expr)) = variant.discriminant else {
696            bail!("Variant has no explicit value");
697        };
698        into_patterns.push(quote!(#ident::#var_ident => #expr));
699        match expr {
700            Expr::Path(_) => from_patterns.push(quote!(#expr => #ident::#var_ident)),
701            _ => from_patterns.push(quote!(x if x == #expr => #ident::#var_ident)),
702        }
703    }
704
705    if default.is_some() {
706        quote! {
707            impl #ident {
708                pub const fn into_bits(self) -> #ty {
709                    match self {
710                        #(#into_patterns,)*
711                        #ident::#default(x) => x,
712                    }
713                }
714
715                pub const fn from_bits(bits: #ty) -> #ident {
716                    match bits & ((1 << (#bits)) - 1) {
717                        #(#from_patterns,)*
718                        x => #ident::#default(x),
719                    }
720                }
721            }
722        }
723    } else {
724        let panic = if from_patterns.len() == 1 << bits.base10_parse::<usize>().unwrap() {
725            quote!(unreachable!())
726        } else {
727            quote!(panic!("Unknown value {x}"))
728        };
729        quote! {
730            impl #ident {
731                pub const fn into_bits(self) -> #ty {
732                    match self {
733                        #(#into_patterns,)*
734                    }
735                }
736
737                pub const fn from_bits(bits: #ty) -> #ident {
738                    match bits & ((1 << (#bits)) - 1) {
739                        #(#from_patterns,)*
740                        x => #panic,
741                    }
742                }
743            }
744        }
745    }
746    .into()
747}
748
/// Parsed arguments of a codec test macro invocation, written as
/// `key: value,` pairs.
struct CodecTest {
    /// Optional `name:` identifier; when present it is appended to the names
    /// of the generated test functions.
    name: Option<Ident>,
    /// Required `ty:` — the type whose encoding is under test.
    ty: Type,
    /// Required `instance:` — expression producing the value under test.
    instance: Expr,
    /// Required `bytes:` — array expression for the encoded form.
    bytes: ExprArray,
    /// Identifiers listed in the optional `extra: [..]` array, stored as
    /// strings (for example `Overfull`).
    extra: HashSet<String>,
}
756
757impl Parse for CodecTest {
758    fn parse(input: ParseStream<'_>) -> parse::Result<CodecTest> {
759        let mut name = None;
760        let mut ty = None;
761        let mut instance = None;
762        let mut bytes = None;
763        let mut extra = HashSet::new();
764
765        let span = input.span();
766
767        while !input.is_empty() {
768            let ident: Ident = input.parse()?;
769
770            match ident {
771                x if x == "name" => {
772                    if name.is_some() {
773                        return Err(parse::Error::new(input.span(), "Duplicate field `name`"));
774                    }
775                    if input.parse::<Punct>()?.as_char() != ':' {
776                        return Err(parse::Error::new(input.span(), "Expected `:`"));
777                    }
778                    name = Some(input.parse()?);
779                }
780                x if x == "ty" => {
781                    if ty.is_some() {
782                        return Err(parse::Error::new(input.span(), "Duplicate field `ty`"));
783                    }
784                    if input.parse::<Punct>()?.as_char() != ':' {
785                        return Err(parse::Error::new(input.span(), "Expected `:`"));
786                    }
787                    ty = Some(input.parse()?);
788                }
789                x if x == "instance" => {
790                    if instance.is_some() {
791                        return Err(parse::Error::new(
792                            input.span(),
793                            "Duplicate field `instance`",
794                        ));
795                    }
796                    if input.parse::<Punct>()?.as_char() != ':' {
797                        return Err(parse::Error::new(input.span(), "Expected `:`"));
798                    }
799                    instance = Some(input.parse()?);
800                }
801                x if x == "bytes" => {
802                    if bytes.is_some() {
803                        return Err(parse::Error::new(input.span(), "Duplicate field `bytes`"));
804                    }
805                    if input.parse::<Punct>()?.as_char() != ':' {
806                        return Err(parse::Error::new(input.span(), "Expected `:`"));
807                    }
808                    bytes = Some(input.parse()?);
809                }
810                x if x == "extra" => {
811                    if !extra.is_empty() {
812                        return Err(parse::Error::new(input.span(), "Duplicate field `extra`"));
813                    }
814                    if input.parse::<Punct>()?.as_char() != ':' {
815                        return Err(parse::Error::new(input.span(), "Expected `:`"));
816                    }
817                    let extras = input.parse::<ExprArray>()?.elems;
818                    for elem in extras {
819                        match elem {
820                            Expr::Path(ExprPath { path, .. }) => {
821                                if let Some(ident) = path.get_ident() {
822                                    extra.insert(ident.to_string());
823                                } else {
824                                    return Err(parse::Error::new(
825                                        input.span(),
826                                        "Extras must be an identifier",
827                                    ));
828                                }
829                            }
830                            _ => todo!("Extras must be an identifier"),
831                        }
832                    }
833                }
834                _ => {
835                    return Err(parse::Error::new(
836                        input.span(),
837                        format!("Invalid field `{ident}`"),
838                    ))
839                }
840            }
841            if input.parse::<Punct>()?.as_char() != ',' {
842                return Err(parse::Error::new(input.span(), "Expected `:`"));
843            }
844        }
845        let Some(ty) = ty else {
846            return Err(parse::Error::new(span, "Missing field `ty`"));
847        };
848        let Some(instance) = instance else {
849            return Err(parse::Error::new(span, "Missing field `instance`"));
850        };
851        let Some(bytes) = bytes else {
852            return Err(parse::Error::new(span, "Missing field `bytes`"));
853        };
854        Ok(CodecTest {
855            name,
856            ty,
857            instance,
858            bytes,
859            extra,
860        })
861    }
862}
863
864#[proc_macro]
865pub fn opcode_test(input: TokenStream) -> TokenStream {
866    let CodecTest {
867        name,
868        ty,
869        instance,
870        bytes,
871        mut extra,
872        ..
873    } = parse_macro_input!(input as CodecTest);
874    let encode_name: Ident;
875    let decode_name: Ident;
876    let len_name: Ident;
877    let overfull_name: Ident;
878
879    if let Some(name) = name {
880        encode_name = format_ident!("test_encode{name}");
881        decode_name = format_ident!("test_decode{name}");
882        len_name = format_ident!("test_len{name}");
883        overfull_name = format_ident!("test_decode_overfull{name}");
884    } else {
885        encode_name = format_ident!("test_encode");
886        decode_name = format_ident!("test_decode");
887        len_name = format_ident!("test_len");
888        overfull_name = format_ident!("test_decode_overfull");
889    }
890
891    let test_overfull = if extra.take("Overfull").is_some() {
892        Some(quote! {
893            #[test]
894            fn #overfull_name() {
895                let mut bytes = Vec::from(&#bytes);
896                bytes.resize(14, 0);
897                assert_eq!(<#ty as OperandEncodable>::try_from_bytes(&bytes), Ok(#instance));
898            }
899        })
900    } else {
901        None
902    };
903
904    if !extra.is_empty() {
905        bail!("Unknown elements in `extra`: {:?}", extra);
906    }
907
908    quote! {
909        #[test]
910        fn #encode_name() {
911            let mut buf = Vec::new();
912            <#ty as OperandEncodable>::to_bytes(&#instance, &mut buf);
913            let bytes: &[u8] = &#bytes;
914            assert_eq!(buf, bytes);
915        }
916
917        #[test]
918        fn #decode_name() {
919            assert_eq!(<#ty as OperandEncodable>::try_from_bytes(&#bytes), Ok(#instance));
920        }
921
922        #[test]
923        fn #len_name() {
924            assert_eq!(<#ty as OperandEncodable>::len(&#instance), #bytes.len());
925        }
926
927        #test_overfull
928    }
929    .into()
930}
931
932#[proc_macro]
933pub fn message_test(input: TokenStream) -> TokenStream {
934    let CodecTest {
935        name,
936        ty,
937        instance,
938        bytes,
939        mut extra,
940        ..
941    } = parse_macro_input!(input as CodecTest);
942    let encode_name: Ident;
943    let decode_name: Ident;
944    let len_name: Ident;
945    let overfull_name: Ident;
946    let test_opcode;
947
948    if let Some(ref name) = name {
949        encode_name = format_ident!("test_encode{name}");
950        decode_name = format_ident!("test_decode{name}");
951        len_name = format_ident!("test_len{name}");
952        overfull_name = format_ident!("test_decode_overfull{name}");
953        test_opcode = None;
954    } else {
955        encode_name = format_ident!("test_encode");
956        decode_name = format_ident!("test_decode");
957        len_name = format_ident!("test_len");
958        overfull_name = format_ident!("test_decode_overfull");
959        test_opcode = Some(quote! {
960            #[test]
961            fn test_opcode() {
962                assert_eq!(#instance.opcode(), Opcode::#ty);
963            }
964        });
965    }
966
967    let test_overfull = if extra.take("Overfull").is_some() {
968        Some(quote! {
969            #[test]
970            fn #overfull_name() {
971                let mut vec = vec![Opcode::#ty as u8];
972                vec.extend(&#bytes);
973                vec.resize(14, 0);
974                assert_eq!(Message::try_from_bytes(&vec), Ok(#instance));
975            }
976        })
977    } else {
978        None
979    };
980
981    let test_empty = if extra.take("Empty").is_some() {
982        if name.is_some() {
983            bail!("Named tests cannot have `Empty` extra");
984        }
985        Some(quote! {
986            #[test]
987            fn test_decoding_missing_operands() {
988                assert_eq!(
989                    Message::try_from_bytes(&[Opcode::#ty as u8]),
990                    Err(crate::Error::OutOfRange {
991                        expected: Message::expected_len(Opcode::#ty),
992                        got: 1,
993                        quantity: "bytes",
994                    })
995                );
996            }
997        })
998    } else {
999        None
1000    };
1001
1002    if !extra.is_empty() {
1003        bail!("Unknown elements in `extra`: {:?}", extra);
1004    }
1005
1006    quote! {
1007        #test_opcode
1008
1009        #[test]
1010        fn #encode_name() {
1011            let mut vec = vec![Opcode::#ty as u8];
1012            vec.extend(&#bytes);
1013            assert_eq!(#instance.to_bytes(), vec);
1014        }
1015
1016        #[test]
1017        fn #decode_name() {
1018            let mut vec = vec![Opcode::#ty as u8];
1019            vec.extend(&#bytes);
1020            assert_eq!(Message::try_from_bytes(&vec), Ok(#instance));
1021        }
1022
1023        #[test]
1024        fn #len_name() {
1025            assert_eq!(#instance.len(), #bytes.len() + 1);
1026        }
1027
1028        #test_overfull
1029
1030        #test_empty
1031    }
1032    .into()
1033}