Skip to main content

commonware_conformance_macros/
lib.rs

1//! Augment the development of [`commonware-conformance`](https://docs.rs/commonware-conformance) with procedural macros.
2
3#![doc(
4    html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
5    html_favicon_url = "https://commonware.xyz/favicon.ico"
6)]
7
8use proc_macro::TokenStream;
9use proc_macro2::Span;
10use quote::quote;
11use syn::{
12    parse::{Parse, ParseStream},
13    parse_macro_input,
14    punctuated::Punctuated,
15    Ident, Token, Type,
16};
17
18/// A single conformance test entry: `Type` or `Type => n_cases`
19struct ConformanceEntry {
20    ty: Type,
21    n_cases: Option<syn::Expr>,
22}
23
24impl Parse for ConformanceEntry {
25    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
26        let ty: Type = input.parse()?;
27
28        let n_cases = if input.peek(Token![=>]) {
29            input.parse::<Token![=>]>()?;
30            Some(input.parse()?)
31        } else {
32            None
33        };
34
35        Ok(Self { ty, n_cases })
36    }
37}
38
39/// The full input to conformance_tests!
40struct ConformanceInput {
41    entries: Punctuated<ConformanceEntry, Token![,]>,
42}
43
44impl Parse for ConformanceInput {
45    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
46        let entries = Punctuated::parse_terminated(input)?;
47        Ok(Self { entries })
48    }
49}
50
51/// Convert a type to a valid snake_case function name suffix.
52///
53/// Inserts underscores at PascalCase boundaries and replaces punctuation
54/// with underscores. Consecutive separators are collapsed.
55fn type_to_ident(ty: &Type) -> String {
56    let type_str = quote!(#ty).to_string();
57
58    let mut result = String::with_capacity(type_str.len());
59    let mut prev_was_separator = true;
60
61    for c in type_str.chars() {
62        match c {
63            'A'..='Z' => {
64                if !prev_was_separator && !result.is_empty() {
65                    result.push('_');
66                }
67                result.push(c.to_ascii_lowercase());
68                prev_was_separator = false;
69            }
70            'a'..='z' | '0'..='9' => {
71                result.push(c);
72                prev_was_separator = false;
73            }
74            '_' => {
75                if !prev_was_separator && !result.is_empty() {
76                    result.push('_');
77                }
78                prev_was_separator = true;
79            }
80            '<' | '>' | ',' | ' ' | ':' => {
81                if !prev_was_separator && !result.is_empty() {
82                    result.push('_');
83                }
84                prev_was_separator = true;
85            }
86            // Skip other characters
87            _ => {}
88        }
89    }
90
91    result.trim_end_matches("_").to_string()
92}
93
94/// Define tests for types implementing the
95/// [`Conformance`](https://docs.rs/commonware-conformance/latest/commonware_conformance/trait.Conformance.html) trait.
96///
97/// Generates test functions that verify implementations match expected digest
98/// values stored in `conformance.toml`.
99///
100/// # Usage
101///
102/// ```ignore
103/// conformance_tests! {
104///     Vec<u8>,                       // Uses default (65536 cases)
105///     Vec<u16> => 100,               // Explicit case count
106///     BTreeMap<u32, String> => 100,
107/// }
108/// ```
109///
110/// This generates test functions named after the type:
111/// - `test_vec_u8`
112/// - `test_vec_u16`
113/// - `test_b_tree_map_u32_string`
114///
115/// The type name is used as the key in the TOML file.
116#[proc_macro]
117pub fn conformance_tests(input: TokenStream) -> TokenStream {
118    let input = parse_macro_input!(input as ConformanceInput);
119
120    let tests = input.entries.iter().map(|entry| {
121        let ty = &entry.ty;
122        let n_cases = entry
123            .n_cases
124            .as_ref()
125            .map(|e| quote!(#e))
126            .unwrap_or_else(|| quote!(::commonware_conformance::DEFAULT_CASES));
127
128        let type_name_str = quote!(#ty).to_string().replace(' ', "");
129        let fn_name_suffix = type_to_ident(ty);
130        let fn_name = Ident::new(&format!("test_{fn_name_suffix}"), Span::call_site());
131
132        quote! {
133            #[::commonware_conformance::commonware_macros::test_group("conformance")]
134            #[test]
135            fn #fn_name() {
136                ::commonware_conformance::futures::executor::block_on(
137                    ::commonware_conformance::run_conformance_test::<#ty>(
138                        concat!(module_path!(), "::", #type_name_str),
139                        #n_cases,
140                        ::std::path::Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/conformance.toml")),
141                    )
142                );
143            }
144        }
145    });
146
147    let expanded = quote! {
148        #(#tests)*
149    };
150
151    expanded.into()
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses each type string and checks the identifier suffix produced by
    /// `type_to_ident` against the expected value.
    fn assert_idents(cases: &[(&str, &str)]) {
        for &(source, expected) in cases {
            let ty: Type = syn::parse_str(source).unwrap();
            assert_eq!(
                type_to_ident(&ty),
                expected,
                "unexpected identifier for `{source}`"
            );
        }
    }

    #[test]
    fn test_simple_types() {
        assert_idents(&[("u8", "u8"), ("u32", "u32"), ("String", "string")]);
    }

    #[test]
    fn test_generic_types() {
        assert_idents(&[
            ("Vec<u8>", "vec_u8"),
            ("Option<u32>", "option_u32"),
            ("Option<Vec<u8>>", "option_vec_u8"),
        ]);
    }

    #[test]
    fn test_pascal_case_splitting() {
        assert_idents(&[
            ("BTreeMap<u32, String>", "b_tree_map_u32_string"),
            ("HashMap<u32, u32>", "hash_map_u32_u32"),
        ]);
    }

    #[test]
    fn test_wrapper_types() {
        assert_idents(&[
            ("CodecConformance<Vec<u8>>", "codec_conformance_vec_u8"),
            (
                "CodecConformance<BTreeMap<u32, u32>>",
                "codec_conformance_b_tree_map_u32_u32",
            ),
        ]);
    }

    #[test]
    fn test_paths() {
        assert_idents(&[
            ("std::vec::Vec<u8>", "std_vec_vec_u8"),
            ("crate::Foo", "crate_foo"),
        ]);
    }

    #[test]
    fn test_tuples() {
        assert_idents(&[
            ("(u32, u32)", "u32_u32"),
            ("(u32, u32, u32)", "u32_u32_u32"),
        ]);
    }

    #[test]
    fn test_arrays() {
        assert_idents(&[("[u8; 32]", "u8_32")]);
    }

    #[test]
    fn test_underscores_in_names() {
        assert_idents(&[("my_type", "my_type"), ("My_Type", "my_type")]);
    }
}