libcrux_macros/
lib.rs

1//! This is a collection of libcrux internal proc macros.
2
3use proc_macro::{Delimiter, TokenStream, TokenTree};
4use quote::{format_ident, quote, ToTokens};
5use syn::{parse::Parser, parse_macro_input, ItemFn, ItemMod, LitInt, Token};
6
/// Consumes the next token of `ts`, requiring it to be a `,`.
///
/// # Panics
/// Panics with "Expected comma" when the stream is exhausted or the next token is
/// not punctuation, and via `assert_eq!` when it is punctuation other than `,`.
/// Inside a proc macro such a panic is reported as a compile error.
fn skip_comma<T: Iterator<Item = TokenTree>>(ts: &mut T) {
    if let Some(TokenTree::Punct(punct)) = ts.next() {
        assert_eq!(punct.as_char(), ',');
    } else {
        panic!("Expected comma");
    }
}
13
/// Takes the next token from `ts` and returns it.
///
/// # Panics
/// Panics with "early end" when the stream has no more tokens.
fn accept_token<T: Iterator<Item = TokenTree>>(ts: &mut T) -> TokenTree {
    ts.next().expect("early end")
}
20
/// Wraps `ts` in a `{ ... }` group and returns that group as a single token tree.
fn brace(ts: TokenStream) -> TokenTree {
    let group = proc_macro::Group::new(Delimiter::Brace, ts);
    TokenTree::from(group)
}
24
25#[proc_macro]
26pub fn unroll_for(ts: TokenStream) -> TokenStream {
27    let mut i = ts.into_iter();
28    let n_loops = accept_token(&mut i).to_string().parse::<u32>().unwrap();
29    skip_comma(&mut i);
30    let var = accept_token(&mut i).to_string();
31    let var = &var[1..var.len() - 1];
32    skip_comma(&mut i);
33    let start = accept_token(&mut i).to_string();
34    skip_comma(&mut i);
35    let increment = accept_token(&mut i).to_string();
36    skip_comma(&mut i);
37    let grouped_body = brace(TokenStream::from_iter(i));
38    let chunks = (0..n_loops).map(|i| {
39        let chunks = [
40            format!("const {}: u32 = {} + {} * {};", var, start, i, increment)
41                .parse()
42                .unwrap(),
43            TokenStream::from(grouped_body.clone()),
44            ";".parse().unwrap(),
45        ];
46        TokenStream::from(brace(TokenStream::from_iter(chunks)))
47    });
48    TokenStream::from(brace(TokenStream::from_iter(chunks.into_iter().flatten())))
49    // "{ let i = 0; println!(\"FROM MACRO{}\", i); }".parse().unwrap()
50}
51
52/// Annotation for a generic ML-DSA implementation, which pulls in
53/// parameter-set specific constants.
54///
55/// Given a list of parameter set identifiers, i.e. `44,65,87`, for
56/// each identifier $id a feature-gated module `ml_dsa_$id` is generated, which
57/// pulls in the parameter specific constants, assumed to be specified
58/// in `crate::constants::ml_dsa_$id`. Further, type aliases for for
59/// signing, and verification keys, whole keypairs and signatures are
60/// created.
61#[proc_macro_attribute]
62pub fn ml_dsa_parameter_sets(args: TokenStream, item: TokenStream) -> TokenStream {
63    let ItemMod {
64        attrs,
65        vis,
66        content,
67        semi,
68        ..
69    } = parse_macro_input!(item as ItemMod);
70
71    let variants_vec = syn::punctuated::Punctuated::<LitInt, Token![,]>::parse_terminated
72        .parse(args)
73        .unwrap();
74    let mut expanded = quote! {};
75
76    for parameter_set in variants_vec {
77        let parameter_set_string = quote! {#parameter_set}.to_string();
78        let feature_name = format!("mldsa{}", parameter_set_string);
79        let modpath = format_ident!("ml_dsa_{}", parameter_set_string);
80
81        let sk_ident = format_ident!("MLDSA{}SigningKey", parameter_set_string);
82        let vk_ident = format_ident!("MLDSA{}VerificationKey", parameter_set_string);
83        let keypair_ident = format_ident!("MLDSA{}KeyPair", parameter_set_string);
84        let sig_ident = format_ident!("MLDSA{}Signature", parameter_set_string);
85
86        // add the variant at the end of the function name
87        if let Some((_, ref content)) = content {
88            let this_content = content.clone();
89            let fun = quote! {
90                #(#attrs)*
91                #[cfg(feature = #feature_name)]
92                #vis mod #modpath {
93                    use crate::constants::#modpath::*;
94
95                    pub type #sk_ident = MLDSASigningKey<SIGNING_KEY_SIZE>;
96                    pub type #vk_ident = MLDSAVerificationKey<VERIFICATION_KEY_SIZE>;
97                    pub type #keypair_ident = MLDSAKeyPair<VERIFICATION_KEY_SIZE, SIGNING_KEY_SIZE>;
98                    pub type #sig_ident = MLDSASignature<SIGNATURE_SIZE>;
99
100                    #(#this_content)*
101                } #semi
102            };
103            expanded.extend(fun);
104        }
105    }
106    expanded.into()
107}
108
109/// Emits span events (of types `EventType::SpanOpen` and `EventType::SpanClose`) with the
110/// provided label into the provided trace. Requires that the caller depends on the
111/// libcrux-test-utils crate.
112#[proc_macro_attribute]
113pub fn trace_span(args: TokenStream, item: TokenStream) -> TokenStream {
114    let args = syn::punctuated::Punctuated::<syn::Expr, Token![,]>::parse_terminated
115        .parse(args)
116        .unwrap();
117
118    let label = args[0].to_token_stream();
119    let trace = args[1].to_token_stream();
120
121    let use_stmt_ts = quote! { use ::libcrux_test_utils::tracing::Trace as _; }.into();
122    let use_stmt = parse_macro_input!(use_stmt_ts as syn::Stmt);
123
124    let assign_stmt_ts =
125        quote! { let __libcrux_trace_macro_span_handle = #trace .emit_span( #label ); }.into();
126    let assign_stmt = parse_macro_input!(assign_stmt_ts as syn::Stmt);
127
128    let mut item_fn = parse_macro_input!(item as ItemFn);
129    match item_fn.block.as_mut() {
130        syn::Block { stmts, .. } => {
131            let mut new_stmts = Vec::with_capacity(stmts.len() + 2);
132            new_stmts.push(use_stmt);
133            new_stmts.push(assign_stmt);
134            new_stmts.append(stmts);
135
136            *stmts = new_stmts
137        }
138    }
139
140    item_fn.to_token_stream().into()
141}