1use convert_case::{Case, Casing};
35use quote::quote;
36use syn::parse::{Parse, ParseStream, Result};
37use syn::{Error, Ident, LitInt, Token};
38
/// Arguments accepted by the `mmr!` macro, as produced by its `Parse` impl.
struct MMRInput {
    /// Name of the generated MMR type (`Type = ...`, optional in the input).
    mmr_type: String,
    /// Name of the hash implementation type (`Hash = ...`).
    hash_type: String,
    /// Branching factor of every generated tree (`BranchFactor = ...`).
    arity: usize,
    /// Number of peak slots in the generated MMR (`Peaks = ...`).
    num_of_peaks: usize,
    /// Maximum input length forwarded to the trees (`MaxInputWordLength = ...`).
    max_input_len: usize,
}
46
47impl MMRInput {
48 const TYPE_IDENT: &'static str = "Type";
49 const BRANCH_FACTOR_IDENT: &'static str = "BranchFactor";
50 const NUM_OF_PEAKS: &'static str = "Peaks";
51 const HASH_TYPE_IDENT: &'static str = "Hash";
52 const MAX_INPUT_LEN_IDENT: &'static str = "MaxInputWordLength";
53}
54
55impl Parse for MMRInput {
56 fn parse(input: ParseStream) -> Result<Self> {
57 let mut with_type = false;
58 let maybe_type_ident = input.parse::<Ident>()?;
59 let mmr_type: String;
60
61 if maybe_type_ident == Self::TYPE_IDENT {
62 with_type = true;
63 input.parse::<Token![=]>()?;
64 mmr_type = input.parse::<Ident>()?.to_string();
65 input.parse::<Token![,]>()?;
66 } else {
67 mmr_type = "MerkleMountainRange".to_owned();
68 }
69
70 let err_msg = "error while parsing 'BranchFactor = <power of 2 number>' section";
71 let branch_factor_ident = if with_type {
72 input.parse::<Ident>().expect(err_msg)
73 } else {
74 maybe_type_ident
75 };
76
77 if branch_factor_ident != Self::BRANCH_FACTOR_IDENT {
78 return Err(Error::new(
79 branch_factor_ident.span(),
80 format!("expected {}", Self::BRANCH_FACTOR_IDENT),
81 ));
82 }
83 input.parse::<Token![=]>()?;
84 let arity: LitInt = input.parse().expect(err_msg);
85
86 let err_msg = "error while parsing 'Peaks = <peak number>' section";
87 input.parse::<Token![,]>().expect(err_msg);
88
89 let num_of_peaks_ident = input.parse::<Ident>().expect(err_msg);
90 if num_of_peaks_ident != Self::NUM_OF_PEAKS {
91 return Err(Error::new(
92 num_of_peaks_ident.span(),
93 format!("expected {}", Self::NUM_OF_PEAKS),
94 ));
95 }
96 input.parse::<Token![=]>().expect(err_msg);
97 let num_of_peaks: LitInt = input.parse().expect(err_msg);
98
99 let err_msg = "error while parsing 'Hash = <hash impl>' section";
100 input.parse::<Token![,]>().expect(err_msg);
101
102 let hash_type_ident = input.parse::<Ident>().expect(err_msg);
103 if hash_type_ident != Self::HASH_TYPE_IDENT {
104 return Err(Error::new(
105 hash_type_ident.span(),
106 format!("{}, expected {}", err_msg, Self::HASH_TYPE_IDENT),
107 ));
108 }
109 input.parse::<Token![=]>().expect(err_msg);
110 let hash_type = input.parse::<Ident>().expect(err_msg).to_string();
111
112 let err_msg = "error while parsing 'MaxInputWordLength = <usize>' section";
113 input.parse::<Token![,]>().expect(err_msg);
114
115 let max_input_len_ident = input.parse::<Ident>().expect(err_msg);
116 if max_input_len_ident != Self::MAX_INPUT_LEN_IDENT {
117 return Err(Error::new(
118 max_input_len_ident.span(),
119 format!("{}, expected {}", err_msg, Self::MAX_INPUT_LEN_IDENT),
120 ));
121 }
122 input.parse::<Token![=]>().expect(err_msg);
123 let max_input_len: LitInt = input.parse().expect(err_msg);
124 let max_input_len = max_input_len.base10_parse::<usize>()?;
125
126 Ok(Self {
127 mmr_type,
128 num_of_peaks: num_of_peaks.base10_parse::<usize>()?,
129 arity: arity.base10_parse::<usize>()?,
130 hash_type,
131 max_input_len,
132 })
133 }
134}
135
136#[proc_macro]
137pub fn mmr(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
138 let input = syn::parse_macro_input!(input as MMRInput);
139
140 if input.num_of_peaks < 2 {
141 panic!("Number of peaks must be greater than 1");
142 }
143
144 let peak_height = input.num_of_peaks;
145 let summit_height = (8 * core::mem::size_of::<usize>() as u32
146 - input.num_of_peaks.leading_zeros()) as usize
147 + 1;
148 let total_height = summit_height + peak_height;
149
150 let total_height = LitInt::new(&total_height.to_string(), proc_macro2::Span::call_site());
151 let summit_height = LitInt::new(&summit_height.to_string(), proc_macro2::Span::call_site());
152 let num_of_peaks = LitInt::new(
153 &input.num_of_peaks.to_string(),
154 proc_macro2::Span::call_site(),
155 );
156 let mmr_type = syn::Ident::new(&input.mmr_type, proc_macro2::Span::call_site());
157 let mmr_peak_type = syn::Ident::new(
158 &format!("{}Peak", input.mmr_type),
159 proc_macro2::Span::call_site(),
160 );
161 let mmr_peak_proof_type = syn::Ident::new(
162 &format!("{}PeakProof", input.mmr_type),
163 proc_macro2::Span::call_site(),
164 );
165 let mmr_proof_type = syn::Ident::new(
166 &format!("{}MMRProof", input.mmr_type),
167 proc_macro2::Span::call_site(),
168 );
169 let hash_type = syn::Ident::new(&input.hash_type, proc_macro2::Span::call_site());
170 let max_input_len = LitInt::new(
171 &input.max_input_len.to_string(),
172 proc_macro2::Span::call_site(),
173 );
174
175 let mod_ident = syn::Ident::new(
176 &input.mmr_type.to_case(Case::Snake),
177 proc_macro2::Span::call_site(),
178 );
179
180 let peak_variant_def_idents = (0..input.num_of_peaks)
181 .map(|i| {
182 (
183 syn::Ident::new(&format!("Peak{i}"), proc_macro2::Span::call_site()),
184 LitInt::new(
185 &(input.num_of_peaks - i).to_string(),
186 proc_macro2::Span::call_site(),
187 ),
188 )
189 })
190 .collect::<Vec<(syn::Ident, LitInt)>>();
191
192 let arity = LitInt::new(&input.arity.to_string(), proc_macro2::Span::call_site());
193
194 let peak_variant_def_tokens = peak_variant_def_idents.iter()
195 .map(|(peak_lit, peak_height)| {
196
197 quote! {
198 #peak_lit(AugmentableTree<#arity, #peak_height, #hash_type, #max_input_len, #mmr_peak_proof_type>)
199 }
200 })
201 .collect::<Vec<_>>();
202
203 let clone_impl_variant_def_tokens = peak_variant_def_idents
204 .iter()
205 .map(|(peak_lit, _)| {
206 quote! {
207 #peak_lit(tree) => #peak_lit(tree.clone())
208 }
209 })
210 .collect::<Vec<_>>();
211
212 let default_variant_def_token = peak_variant_def_idents
213 .iter()
214 .last()
215 .map(|(peak_lit, _)| {
216 quote! {
217 #peak_lit(AugmentableTree::default())
218 }
219 })
220 .expect("variant list is not empty. qed");
221
222 let mut it1 = peak_variant_def_idents.iter().map(|(peak_lit, _)| peak_lit);
223 let it2 = peak_variant_def_idents.iter().map(|(peak_lit, _)| peak_lit);
224 it1.next();
225 let augment_and_merge_variant_def_tokens = it1.zip(it2)
226 .map(|(peak_lit1, peak_lit2)| {
227 quote! {
228 (#peak_lit1(this), #peak_lit1(other)) => Ok(#peak_lit2(this.augment_and_merge(other)))
229 }
230 })
231 .collect::<Vec<_>>();
232
233 let mut it1 = peak_variant_def_idents.iter().map(|(peak_lit, _)| peak_lit);
234 let it2 = peak_variant_def_idents.iter().map(|(peak_lit, _)| peak_lit);
235 it1.next();
236 let augment_variant_def_tokens = it1
237 .zip(it2)
238 .map(|(peak_lit1, peak_lit2)| {
239 quote! {
240 #peak_lit1(this) => Ok(#peak_lit2(this.augment()))
241 }
242 })
243 .collect::<Vec<_>>();
244
245 let as_dyn_tree_variant_def_token = peak_variant_def_idents.iter()
246 .map(|(peak_lit, _)| {
247 quote! {
248 #peak_lit(tree) => tree as &dyn merkle_heapless::traits::StaticTreeTrait<#arity, #hash_type, #mmr_peak_proof_type>
249 }
250 })
251 .collect::<Vec<_>>();
252
253 let as_mut_dyn_tree_variant_def_token = peak_variant_def_idents.iter()
254 .map(|(peak_lit, _)| {
255 quote! {
256 #peak_lit(tree) => tree as &mut dyn merkle_heapless::traits::StaticTreeTrait<#arity, #hash_type, #mmr_peak_proof_type>
257 }
258 })
259 .collect::<Vec<_>>();
260
261 let as_append_only_variant_def_token = peak_variant_def_idents
262 .iter()
263 .map(|(peak_lit, _)| {
264 quote! {
265 #peak_lit(tree) => tree as &dyn merkle_heapless::traits::AppendOnly
266 }
267 })
268 .collect::<Vec<_>>();
269
270 let as_mut_append_only_variant_def_token = peak_variant_def_idents
271 .iter()
272 .map(|(peak_lit, _)| {
273 quote! {
274 #peak_lit(tree) => tree as &mut dyn merkle_heapless::traits::AppendOnly
275 }
276 })
277 .collect::<Vec<_>>();
278
279 let impl_method_body_token = quote! {
280 use #mmr_peak_type::*;
281 match self {
282 #(#as_dyn_tree_variant_def_token),*
283 }
284 };
285 let impl_mut_method_body_token = quote! {
286 use #mmr_peak_type::*;
287 match self {
288 #(#as_mut_dyn_tree_variant_def_token),*
289 }
290 };
291
292 let impl_append_only_method_body_token = quote! {
293 use #mmr_peak_type::*;
294 match self {
295 #(#as_append_only_variant_def_token),*
296 }
297 };
298
299 let impl_mut_append_only_method_body_token = quote! {
300 use #mmr_peak_type::*;
301 match self {
302 #(#as_mut_append_only_variant_def_token),*
303 }
304 };
305
306 let output = quote! {
307 mod #mod_ident {
308 use merkle_heapless::{StaticTree, Error};
309 use merkle_heapless::augmentable::{AugmentableTree};
310 use merkle_heapless::traits::{HashT, StaticTreeTrait, AppendOnly};
311 use merkle_heapless::proof::{Proof, chain_proofs};
312 use merkle_heapless::prefixed::{Prefixed};
313 use super::#hash_type;
314
315 type #mmr_peak_proof_type = Proof<#arity, #peak_height, #hash_type, #max_input_len>;
316 type #mmr_proof_type = Proof<#arity, #total_height, #hash_type, #max_input_len>;
317
318 #[derive(Debug)]
319 pub enum #mmr_peak_type {
320 #(#peak_variant_def_tokens),*
321 }
322
323 impl Clone for #mmr_peak_type {
324 fn clone(&self) -> Self {
325 use #mmr_peak_type::*;
326 match self {
327 #(#clone_impl_variant_def_tokens),*
328 }
329 }
330 }
331
332 impl Default for #mmr_peak_type {
333 fn default() -> Self {
334 use #mmr_peak_type::*;
335 #default_variant_def_token
336 }
337 }
338
339 impl Copy for #mmr_peak_type {}
340
341 impl #mmr_peak_type {
342 pub fn try_augment_and_merge(self, other: Self) -> Result<Self, Error> {
343 use #mmr_peak_type::{*};
344 match (self, other) {
345 #(#augment_and_merge_variant_def_tokens),*,
346 _ => Err(Error::Merge),
347 }
348 }
349 pub fn try_augment(self) -> Result<Self, Error> {
350 use #mmr_peak_type::{*};
351 match self {
352 #(#augment_variant_def_tokens),*,
353 _ => Err(Error::Merge),
354 }
355 }
356 }
357
358 impl StaticTreeTrait<#arity, #hash_type, #mmr_peak_proof_type> for #mmr_peak_type {
359 fn generate_proof(&self, index: usize) -> #mmr_peak_proof_type {
360 #impl_method_body_token.generate_proof(index)
361 }
363 fn replace(&mut self, index: usize, input: &[u8]) {
364 #impl_mut_method_body_token.replace(index, input)
365 }
366 fn replace_leaf(&mut self, index: usize, leaf: <#hash_type as HashT>::Output) {
367 #impl_mut_method_body_token.replace_leaf(index, leaf)
368 }
369 fn root(&self) -> <#hash_type as HashT>::Output {
370 #impl_method_body_token.root()
371 }
372 fn leaves(&self) -> &[Prefixed<#arity, #hash_type>] {
373 #impl_method_body_token.leaves()
374 }
375 fn base_layer_size(&self) -> usize {
376 #impl_method_body_token.base_layer_size()
377 }
378 fn height(&self) -> usize {
382 #impl_method_body_token.height()
383 }
384 }
385
386 impl AppendOnly for #mmr_peak_type {
387 fn try_append(&mut self, input: &[u8]) -> Result<(), Error> {
388 #impl_mut_append_only_method_body_token.try_append(input)
389 }
390 fn num_of_leaves(&self) -> usize {
391 #impl_append_only_method_body_token.num_of_leaves()
392 }
393 }
394
395 pub struct #mmr_type
396 where
397 [(); #num_of_peaks]: Sized,
398 {
399 summit_tree: StaticTree<#arity, #summit_height, #hash_type, #max_input_len>,
401 peaks: [#mmr_peak_type; #num_of_peaks],
402 curr_peak_index: usize,
403 num_of_leaves: usize,
404 }
405
406 impl #mmr_type
407 where
408 [(); #num_of_peaks]: Sized,
409 {
410 pub fn from_peak(peak: #mmr_peak_type) -> Self {
411 let mut this = Self {
412 summit_tree: StaticTree::<#arity, #summit_height, #hash_type, #max_input_len>::default(),
413 peaks: [#mmr_peak_type::default(); #num_of_peaks],
414 curr_peak_index: 0,
415 num_of_leaves: 0,
416 };
417 this.peaks[0] = peak;
418 this
419 }
420
421 fn merge_collapse(&mut self) {
422 let mut i = self.curr_peak_index;
423 while i > 0
426 && self.peaks[i].num_of_leaves() == self.peaks[i - 1].num_of_leaves() {
428
429 match self.peaks[i - 1].try_augment_and_merge(self.peaks[i]) {
430 Ok(peak) => { self.peaks[i - 1] = peak; },
431 Err(_) => break,
432 }
433 self.peaks[i] = Default::default();
434 i -= 1;
435 }
436 self.curr_peak_index = i;
437 }
438
439 pub fn try_append(&mut self, input: &[u8]) -> Result<(), Error> {
440 self.peaks[self.curr_peak_index]
441 .try_append(input)
443 .or_else(|_| {
445 if self.curr_peak_index < #num_of_peaks - 1 {
447 self.curr_peak_index += 1;
449 } else {
450 self.peaks[self.curr_peak_index] = self.peaks[self.curr_peak_index].try_augment()?;
452 }
453 self.peaks[self.curr_peak_index].try_append(input)
455 })
456 .map(|_| {
457 self.merge_collapse();
459
460 let root = self.peaks[self.curr_peak_index].root();
461 self.summit_tree.replace_leaf(self.curr_peak_index, root);
462
463 self.num_of_leaves += 1;
464 })
465 }
466
467 pub fn generate_proof(&self, index: usize) -> #mmr_proof_type {
469 let mut accrue_len = 0;
470 let mut i = 0;
471 while accrue_len + self.peaks[i].num_of_leaves() <= index {
473 accrue_len += self.peaks[i].num_of_leaves();
474 i += 1;
475 }
476 chain_proofs(
479 self.peaks[i].generate_proof(index - accrue_len),
480 self.summit_tree.generate_proof(i)
481 )
482 }
483
484 pub fn base_layer_size(&self) -> usize {
485 self.peaks.iter().map(|peak| peak.base_layer_size()).sum()
486 }
487 pub fn num_of_leaves(&self) -> usize {
488 self.num_of_leaves
489 }
490 pub fn curr_peak_index(&self) -> usize {
491 self.curr_peak_index
492 }
493 pub fn peaks(&self) -> &[#mmr_peak_type] {
494 &self.peaks[..]
495 }
496 }
497
498 impl Default for #mmr_type
499 where
500 [(); #num_of_peaks]: Sized,
501 {
502 fn default() -> Self {
503 Self::from_peak(Default::default())
504 }
505 }
506 }
507
508 use #mod_ident::#mmr_type as #mmr_type;
509 use #mod_ident::#mmr_peak_type as #mmr_peak_type;
510 };
511
512 proc_macro::TokenStream::from(output)
513}