commonware_conformance_macros/
lib.rs

1//! Augment the development of [`commonware-conformance`](https://docs.rs/commonware-conformance) with procedural macros.
2
3use proc_macro::TokenStream;
4use proc_macro2::Span;
5use quote::quote;
6use syn::{
7    parse::{Parse, ParseStream},
8    parse_macro_input,
9    punctuated::Punctuated,
10    Ident, Token, Type,
11};
12
13/// A single conformance test entry: `Type` or `Type => n_cases`
14struct ConformanceEntry {
15    ty: Type,
16    n_cases: Option<syn::Expr>,
17}
18
19impl Parse for ConformanceEntry {
20    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
21        let ty: Type = input.parse()?;
22
23        let n_cases = if input.peek(Token![=>]) {
24            input.parse::<Token![=>]>()?;
25            Some(input.parse()?)
26        } else {
27            None
28        };
29
30        Ok(Self { ty, n_cases })
31    }
32}
33
34/// The full input to conformance_tests!
35struct ConformanceInput {
36    entries: Punctuated<ConformanceEntry, Token![,]>,
37}
38
39impl Parse for ConformanceInput {
40    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
41        let entries = Punctuated::parse_terminated(input)?;
42        Ok(Self { entries })
43    }
44}
45
46/// Convert a type to a valid snake_case function name suffix.
47///
48/// Inserts underscores at PascalCase boundaries and replaces punctuation
49/// with underscores. Consecutive separators are collapsed.
50fn type_to_ident(ty: &Type) -> String {
51    let type_str = quote!(#ty).to_string();
52
53    let mut result = String::with_capacity(type_str.len());
54    let mut prev_was_separator = true;
55
56    for c in type_str.chars() {
57        match c {
58            'A'..='Z' => {
59                if !prev_was_separator && !result.is_empty() {
60                    result.push('_');
61                }
62                result.push(c.to_ascii_lowercase());
63                prev_was_separator = false;
64            }
65            'a'..='z' | '0'..='9' => {
66                result.push(c);
67                prev_was_separator = false;
68            }
69            '_' => {
70                if !prev_was_separator && !result.is_empty() {
71                    result.push('_');
72                }
73                prev_was_separator = true;
74            }
75            '<' | '>' | ',' | ' ' | ':' => {
76                if !prev_was_separator && !result.is_empty() {
77                    result.push('_');
78                }
79                prev_was_separator = true;
80            }
81            // Skip other characters
82            _ => {}
83        }
84    }
85
86    result.trim_end_matches("_").to_string()
87}
88
89/// Define tests for types implementing the
90/// [`Conformance`](https://docs.rs/commonware-conformance/latest/commonware_conformance/trait.Conformance.html) trait.
91///
92/// Generates test functions that verify implementations match expected digest
93/// values stored in `conformance.toml`.
94///
95/// # Usage
96///
97/// ```ignore
98/// conformance_tests! {
99///     Vec<u8>,                       // Uses default (65536 cases)
100///     Vec<u16> => 100,               // Explicit case count
101///     BTreeMap<u32, String> => 100,
102/// }
103/// ```
104///
105/// This generates test functions named after the type:
106/// - `test_vec_u8`
107/// - `test_vec_u16`
108/// - `test_b_tree_map_u32_string`
109///
110/// The type name is used as the key in the TOML file.
111#[proc_macro]
112pub fn conformance_tests(input: TokenStream) -> TokenStream {
113    let input = parse_macro_input!(input as ConformanceInput);
114
115    let tests = input.entries.iter().map(|entry| {
116        let ty = &entry.ty;
117        let n_cases = entry
118            .n_cases
119            .as_ref()
120            .map(|e| quote!(#e))
121            .unwrap_or_else(|| quote!(::commonware_conformance::DEFAULT_CASES));
122
123        let type_name_str = quote!(#ty).to_string().replace(' ', "");
124        let fn_name_suffix = type_to_ident(ty);
125        let fn_name = Ident::new(&format!("test_{fn_name_suffix}"), Span::call_site());
126
127        quote! {
128            #[::commonware_conformance::commonware_macros::test_group("conformance")]
129            #[test]
130            fn #fn_name() {
131                ::commonware_conformance::futures::executor::block_on(
132                    ::commonware_conformance::run_conformance_test::<#ty>(
133                        concat!(module_path!(), "::", #type_name_str),
134                        #n_cases,
135                        ::std::path::Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/conformance.toml")),
136                    )
137                );
138            }
139        }
140    });
141
142    let expanded = quote! {
143        #(#tests)*
144    };
145
146    expanded.into()
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `type_str` as a [`Type`] and returns its [`type_to_ident`] suffix.
    fn ident_for(type_str: &str) -> String {
        let parsed = syn::parse_str::<Type>(type_str).unwrap();
        type_to_ident(&parsed)
    }

    /// Asserts every `(type, expected suffix)` pair in `cases`.
    fn assert_idents(cases: &[(&str, &str)]) {
        for &(type_str, expected) in cases {
            assert_eq!(ident_for(type_str), expected);
        }
    }

    #[test]
    fn test_simple_types() {
        assert_idents(&[("u8", "u8"), ("u32", "u32"), ("String", "string")]);
    }

    #[test]
    fn test_generic_types() {
        assert_idents(&[
            ("Vec<u8>", "vec_u8"),
            ("Option<u32>", "option_u32"),
            ("Option<Vec<u8>>", "option_vec_u8"),
        ]);
    }

    #[test]
    fn test_pascal_case_splitting() {
        assert_idents(&[
            ("BTreeMap<u32, String>", "b_tree_map_u32_string"),
            ("HashMap<u32, u32>", "hash_map_u32_u32"),
        ]);
    }

    #[test]
    fn test_wrapper_types() {
        assert_idents(&[
            ("CodecConformance<Vec<u8>>", "codec_conformance_vec_u8"),
            (
                "CodecConformance<BTreeMap<u32, u32>>",
                "codec_conformance_b_tree_map_u32_u32",
            ),
        ]);
    }

    #[test]
    fn test_paths() {
        assert_idents(&[
            ("std::vec::Vec<u8>", "std_vec_vec_u8"),
            ("crate::Foo", "crate_foo"),
        ]);
    }

    #[test]
    fn test_tuples() {
        assert_idents(&[
            ("(u32, u32)", "u32_u32"),
            ("(u32, u32, u32)", "u32_u32_u32"),
        ]);
    }

    #[test]
    fn test_arrays() {
        assert_idents(&[("[u8; 32]", "u8_32")]);
    }

    #[test]
    fn test_underscores_in_names() {
        assert_idents(&[("my_type", "my_type"), ("My_Type", "my_type")]);
    }
}