1use proc_macro::{Delimiter, TokenStream, TokenTree};
4use quote::{format_ident, quote};
5use syn::{parse::Parser, parse_macro_input, ItemMod, LitInt, Token};
6
/// Consumes the next token and checks that it is a `,` punctuation character.
///
/// Used between the comma-separated arguments of `unroll_for!`.
///
/// # Panics
/// Panics (surfacing as a compile error at the macro call site) if the input is
/// exhausted or the next token is anything other than `,`. The message names the
/// token that was actually found.
fn skip_comma<T: Iterator<Item = TokenTree>>(ts: &mut T) {
    match ts.next() {
        Some(TokenTree::Punct(p)) if p.as_char() == ',' => {}
        Some(other) => panic!("Expected comma, found `{}`", other),
        None => panic!("Expected comma, found end of input"),
    }
}
13
/// Takes the next token tree from `ts` and returns it.
///
/// # Panics
/// Panics with the message `"early end"` if `ts` has no more tokens.
fn accept_token<T: Iterator<Item = TokenTree>>(ts: &mut T) -> TokenTree {
    ts.next().expect("early end")
}
20
/// Wraps `ts` in a single `{ ... }` group and returns that group as one token tree.
fn brace(ts: TokenStream) -> TokenTree {
    let group = proc_macro::Group::new(Delimiter::Brace, ts);
    TokenTree::from(group)
}
24
25#[proc_macro]
26pub fn unroll_for(ts: TokenStream) -> TokenStream {
27 let mut i = ts.into_iter();
28 let n_loops = accept_token(&mut i).to_string().parse::<u32>().unwrap();
29 skip_comma(&mut i);
30 let var = accept_token(&mut i).to_string();
31 let var = &var[1..var.len() - 1];
32 skip_comma(&mut i);
33 let start = accept_token(&mut i).to_string();
34 skip_comma(&mut i);
35 let increment = accept_token(&mut i).to_string();
36 skip_comma(&mut i);
37 let grouped_body = brace(TokenStream::from_iter(i));
38 let chunks = (0..n_loops).map(|i| {
39 let chunks = [
40 format!("const {}: u32 = {} + {} * {};", var, start, i, increment)
41 .parse()
42 .unwrap(),
43 TokenStream::from(grouped_body.clone()),
44 ";".parse().unwrap(),
45 ];
46 TokenStream::from(brace(TokenStream::from_iter(chunks)))
47 });
48 TokenStream::from(brace(TokenStream::from_iter(chunks.into_iter().flatten())))
49 }
51
52#[proc_macro_attribute]
62pub fn ml_dsa_parameter_sets(args: TokenStream, item: TokenStream) -> TokenStream {
63 let ItemMod {
64 attrs,
65 vis,
66 content,
67 semi,
68 ..
69 } = parse_macro_input!(item as ItemMod);
70
71 let variants_vec = syn::punctuated::Punctuated::<LitInt, Token![,]>::parse_terminated
72 .parse(args)
73 .unwrap();
74 let mut expanded = quote! {};
75
76 for parameter_set in variants_vec {
77 let parameter_set_string = quote! {#parameter_set}.to_string();
78 let feature_name = format!("mldsa{}", parameter_set_string);
79 let modpath = format_ident!("ml_dsa_{}", parameter_set_string);
80
81 let sk_ident = format_ident!("MLDSA{}SigningKey", parameter_set_string);
82 let vk_ident = format_ident!("MLDSA{}VerificationKey", parameter_set_string);
83 let keypair_ident = format_ident!("MLDSA{}KeyPair", parameter_set_string);
84 let sig_ident = format_ident!("MLDSA{}Signature", parameter_set_string);
85
86 if let Some((_, ref content)) = content {
88 let this_content = content.clone();
89 let fun = quote! {
90 #(#attrs)*
91 #[cfg(feature = #feature_name)]
92 #vis mod #modpath {
93 use crate::constants::#modpath::*;
94
95 pub type #sk_ident = MLDSASigningKey<SIGNING_KEY_SIZE>;
96 pub type #vk_ident = MLDSAVerificationKey<VERIFICATION_KEY_SIZE>;
97 pub type #keypair_ident = MLDSAKeyPair<VERIFICATION_KEY_SIZE, SIGNING_KEY_SIZE>;
98 pub type #sig_ident = MLDSASignature<SIGNATURE_SIZE>;
99
100 #(#this_content)*
101 } #semi
102 };
103 expanded.extend(fun);
104 }
105 }
106 expanded.into()
107}