mmr_macro/
lib.rs

//! # Merkle Mountain Range macro
//! Enable the `mmr_macro` feature of the `merkle-heapless` dependency to use this macro.
//! ### Necessary compiler features
//! The following attributes are required at the top of the crate root for the macro
//! expansion to compile:
//! ```rust
//! #![allow(incomplete_features)]
//! #![feature(generic_const_exprs)]
//! ```
//! ### Declaration and instantiation
//! ```rust
//! use merkle_heapless::mmr_macro;
//! // declaration with an explicit type name for your MMR
//! mmr_macro::mmr!(Type = FooMMR, BranchFactor = 2, Peaks = 3, Hash = StdHash, MaxInputWordLength = 10);
//! let mmr = FooMMR::default();
//! // without `Type`, a type named `MerkleMountainRange` is declared
//! mmr_macro::mmr!(BranchFactor = 2, Peaks = 5, Hash = StdHash, MaxInputWordLength = 10);
//! // create with the default current peak of height 0
//! let mmr = MerkleMountainRange::default();
//! // or create with a current peak of height 2
//! let mmr = MerkleMountainRange::from_peak(MerkleMountainRangePeak::Peak3(Default::default()));
//! assert_eq!(mmr.peaks()[0].height(), 5 - 3);
//! ```
//! ### Functionality
//! The functionality of a Mountain Range is similar to that of a Merkle tree.
//! ```rust
//! let mut mmr = MerkleMountainRange::default();
//! mmr.try_append(b"apple").unwrap();
//! // peak leaf numbers: [1, 0, 0, 0, 0]
//! assert_eq!(mmr.peaks()[0].height(), 0);
//! assert_eq!(mmr.peaks()[0].num_of_leaves(), 1);
//! assert_eq!(mmr.peaks()[1].num_of_leaves(), 0);
//! let proof = mmr.generate_proof(0);
//! assert!(proof.validate(b"apple"));
//! ```
33
34use convert_case::{Case, Casing};
35use quote::quote;
36use syn::parse::{Parse, ParseStream, Result};
37use syn::{Error, Ident, LitInt, Token};
38
/// Parsed arguments of the `mmr!` macro, listed in the order they are written.
struct MMRInput {
    /// Name of the generated mountain-range type.
    mmr_type: String,
    /// Value of `BranchFactor`, forwarded as the arity const parameter of the generated trees.
    arity: usize,
    /// Value of `Peaks`: how many peak variants / peak slots are generated.
    num_of_peaks: usize,
    /// Value of `Hash`: name of the hash type used by the generated trees.
    hash_type: String,
    /// Value of `MaxInputWordLength`, forwarded as a const parameter of the generated trees
    /// and proofs.
    max_input_len: usize,
}

/// Keywords expected on the left-hand side of each `Key = value` macro argument.
impl MMRInput {
    const TYPE_IDENT: &'static str = "Type";
    const BRANCH_FACTOR_IDENT: &'static str = "BranchFactor";
    const NUM_OF_PEAKS: &'static str = "Peaks";
    const HASH_TYPE_IDENT: &'static str = "Hash";
    const MAX_INPUT_LEN_IDENT: &'static str = "MaxInputWordLength";
}
54
55impl Parse for MMRInput {
56    fn parse(input: ParseStream) -> Result<Self> {
57        let mut with_type = false;
58        let maybe_type_ident = input.parse::<Ident>()?;
59        let mmr_type: String;
60
61        if maybe_type_ident == Self::TYPE_IDENT {
62            with_type = true;
63            input.parse::<Token![=]>()?;
64            mmr_type = input.parse::<Ident>()?.to_string();
65            input.parse::<Token![,]>()?;
66        } else {
67            mmr_type = "MerkleMountainRange".to_owned();
68        }
69
70        let err_msg = "error while parsing 'BranchFactor = <power of 2 number>' section";
71        let branch_factor_ident = if with_type {
72            input.parse::<Ident>().expect(err_msg)
73        } else {
74            maybe_type_ident
75        };
76
77        if branch_factor_ident != Self::BRANCH_FACTOR_IDENT {
78            return Err(Error::new(
79                branch_factor_ident.span(),
80                format!("expected {}", Self::BRANCH_FACTOR_IDENT),
81            ));
82        }
83        input.parse::<Token![=]>()?;
84        let arity: LitInt = input.parse().expect(err_msg);
85
86        let err_msg = "error while parsing 'Peaks = <peak number>' section";
87        input.parse::<Token![,]>().expect(err_msg);
88
89        let num_of_peaks_ident = input.parse::<Ident>().expect(err_msg);
90        if num_of_peaks_ident != Self::NUM_OF_PEAKS {
91            return Err(Error::new(
92                num_of_peaks_ident.span(),
93                format!("expected {}", Self::NUM_OF_PEAKS),
94            ));
95        }
96        input.parse::<Token![=]>().expect(err_msg);
97        let num_of_peaks: LitInt = input.parse().expect(err_msg);
98
99        let err_msg = "error while parsing 'Hash = <hash impl>' section";
100        input.parse::<Token![,]>().expect(err_msg);
101
102        let hash_type_ident = input.parse::<Ident>().expect(err_msg);
103        if hash_type_ident != Self::HASH_TYPE_IDENT {
104            return Err(Error::new(
105                hash_type_ident.span(),
106                format!("{}, expected {}", err_msg, Self::HASH_TYPE_IDENT),
107            ));
108        }
109        input.parse::<Token![=]>().expect(err_msg);
110        let hash_type = input.parse::<Ident>().expect(err_msg).to_string();
111
112        let err_msg = "error while parsing 'MaxInputWordLength = <usize>' section";
113        input.parse::<Token![,]>().expect(err_msg);
114
115        let max_input_len_ident = input.parse::<Ident>().expect(err_msg);
116        if max_input_len_ident != Self::MAX_INPUT_LEN_IDENT {
117            return Err(Error::new(
118                max_input_len_ident.span(),
119                format!("{}, expected {}", err_msg, Self::MAX_INPUT_LEN_IDENT),
120            ));
121        }
122        input.parse::<Token![=]>().expect(err_msg);
123        let max_input_len: LitInt = input.parse().expect(err_msg);
124        let max_input_len = max_input_len.base10_parse::<usize>()?;
125
126        Ok(Self {
127            mmr_type,
128            num_of_peaks: num_of_peaks.base10_parse::<usize>()?,
129            arity: arity.base10_parse::<usize>()?,
130            hash_type,
131            max_input_len,
132        })
133    }
134}
135
136#[proc_macro]
137pub fn mmr(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
138    let input = syn::parse_macro_input!(input as MMRInput);
139
140    if input.num_of_peaks < 2 {
141        panic!("Number of peaks must be greater than 1");
142    }
143
144    let peak_height = input.num_of_peaks;
145    let summit_height = (8 * core::mem::size_of::<usize>() as u32
146        - input.num_of_peaks.leading_zeros()) as usize
147        + 1;
148    let total_height = summit_height + peak_height;
149
150    let total_height = LitInt::new(&total_height.to_string(), proc_macro2::Span::call_site());
151    let summit_height = LitInt::new(&summit_height.to_string(), proc_macro2::Span::call_site());
152    let num_of_peaks = LitInt::new(
153        &input.num_of_peaks.to_string(),
154        proc_macro2::Span::call_site(),
155    );
156    let mmr_type = syn::Ident::new(&input.mmr_type, proc_macro2::Span::call_site());
157    let mmr_peak_type = syn::Ident::new(
158        &format!("{}Peak", input.mmr_type),
159        proc_macro2::Span::call_site(),
160    );
161    let mmr_peak_proof_type = syn::Ident::new(
162        &format!("{}PeakProof", input.mmr_type),
163        proc_macro2::Span::call_site(),
164    );
165    let mmr_proof_type = syn::Ident::new(
166        &format!("{}MMRProof", input.mmr_type),
167        proc_macro2::Span::call_site(),
168    );
169    let hash_type = syn::Ident::new(&input.hash_type, proc_macro2::Span::call_site());
170    let max_input_len = LitInt::new(
171        &input.max_input_len.to_string(),
172        proc_macro2::Span::call_site(),
173    );
174
175    let mod_ident = syn::Ident::new(
176        &input.mmr_type.to_case(Case::Snake),
177        proc_macro2::Span::call_site(),
178    );
179
180    let peak_variant_def_idents = (0..input.num_of_peaks)
181        .map(|i| {
182            (
183                syn::Ident::new(&format!("Peak{i}"), proc_macro2::Span::call_site()),
184                LitInt::new(
185                    &(input.num_of_peaks - i).to_string(),
186                    proc_macro2::Span::call_site(),
187                ),
188            )
189        })
190        .collect::<Vec<(syn::Ident, LitInt)>>();
191
192    let arity = LitInt::new(&input.arity.to_string(), proc_macro2::Span::call_site());
193
194    let peak_variant_def_tokens = peak_variant_def_idents.iter()
195        .map(|(peak_lit, peak_height)| {
196
197            quote! {
198                #peak_lit(AugmentableTree<#arity, #peak_height, #hash_type, #max_input_len, #mmr_peak_proof_type>)
199            }
200        })
201        .collect::<Vec<_>>();
202
203    let clone_impl_variant_def_tokens = peak_variant_def_idents
204        .iter()
205        .map(|(peak_lit, _)| {
206            quote! {
207                #peak_lit(tree) => #peak_lit(tree.clone())
208            }
209        })
210        .collect::<Vec<_>>();
211
212    let default_variant_def_token = peak_variant_def_idents
213        .iter()
214        .last()
215        .map(|(peak_lit, _)| {
216            quote! {
217                #peak_lit(AugmentableTree::default())
218            }
219        })
220        .expect("variant list is not empty. qed");
221
222    let mut it1 = peak_variant_def_idents.iter().map(|(peak_lit, _)| peak_lit);
223    let it2 = peak_variant_def_idents.iter().map(|(peak_lit, _)| peak_lit);
224    it1.next();
225    let augment_and_merge_variant_def_tokens = it1.zip(it2)
226        .map(|(peak_lit1, peak_lit2)| {
227            quote! {
228                (#peak_lit1(this), #peak_lit1(other)) => Ok(#peak_lit2(this.augment_and_merge(other)))
229            }
230        })
231        .collect::<Vec<_>>();
232
233    let mut it1 = peak_variant_def_idents.iter().map(|(peak_lit, _)| peak_lit);
234    let it2 = peak_variant_def_idents.iter().map(|(peak_lit, _)| peak_lit);
235    it1.next();
236    let augment_variant_def_tokens = it1
237        .zip(it2)
238        .map(|(peak_lit1, peak_lit2)| {
239            quote! {
240                #peak_lit1(this) => Ok(#peak_lit2(this.augment()))
241            }
242        })
243        .collect::<Vec<_>>();
244
245    let as_dyn_tree_variant_def_token = peak_variant_def_idents.iter()
246        .map(|(peak_lit, _)| {
247            quote! {
248                #peak_lit(tree) => tree as &dyn merkle_heapless::traits::StaticTreeTrait<#arity, #hash_type, #mmr_peak_proof_type>
249            }
250        })
251        .collect::<Vec<_>>();
252
253    let as_mut_dyn_tree_variant_def_token = peak_variant_def_idents.iter()
254        .map(|(peak_lit, _)| {
255            quote! {
256                #peak_lit(tree) => tree as &mut dyn merkle_heapless::traits::StaticTreeTrait<#arity, #hash_type, #mmr_peak_proof_type>
257            }
258        })
259        .collect::<Vec<_>>();
260
261    let as_append_only_variant_def_token = peak_variant_def_idents
262        .iter()
263        .map(|(peak_lit, _)| {
264            quote! {
265                #peak_lit(tree) => tree as &dyn merkle_heapless::traits::AppendOnly
266            }
267        })
268        .collect::<Vec<_>>();
269
270    let as_mut_append_only_variant_def_token = peak_variant_def_idents
271        .iter()
272        .map(|(peak_lit, _)| {
273            quote! {
274                #peak_lit(tree) => tree as &mut dyn merkle_heapless::traits::AppendOnly
275            }
276        })
277        .collect::<Vec<_>>();
278
279    let impl_method_body_token = quote! {
280        use #mmr_peak_type::*;
281        match self {
282            #(#as_dyn_tree_variant_def_token),*
283        }
284    };
285    let impl_mut_method_body_token = quote! {
286        use #mmr_peak_type::*;
287        match self {
288            #(#as_mut_dyn_tree_variant_def_token),*
289        }
290    };
291
292    let impl_append_only_method_body_token = quote! {
293        use #mmr_peak_type::*;
294        match self {
295            #(#as_append_only_variant_def_token),*
296        }
297    };
298
299    let impl_mut_append_only_method_body_token = quote! {
300        use #mmr_peak_type::*;
301        match self {
302            #(#as_mut_append_only_variant_def_token),*
303        }
304    };
305
306    let output = quote! {
307            mod #mod_ident {
308                use merkle_heapless::{StaticTree, Error};
309                use merkle_heapless::augmentable::{AugmentableTree};
310                use merkle_heapless::traits::{HashT, StaticTreeTrait, AppendOnly};
311                use merkle_heapless::proof::{Proof, chain_proofs};
312                use merkle_heapless::prefixed::{Prefixed};
313                use super::#hash_type;
314
315                type #mmr_peak_proof_type = Proof<#arity, #peak_height, #hash_type, #max_input_len>;
316                type #mmr_proof_type = Proof<#arity, #total_height, #hash_type, #max_input_len>;
317
318                #[derive(Debug)]
319                pub enum #mmr_peak_type {
320                    #(#peak_variant_def_tokens),*
321                }
322
323                impl Clone for #mmr_peak_type {
324                    fn clone(&self) -> Self {
325                        use #mmr_peak_type::*;
326                        match self {
327                            #(#clone_impl_variant_def_tokens),*
328                        }
329                    }
330                }
331
332                impl Default for #mmr_peak_type {
333                    fn default() -> Self {
334                        use #mmr_peak_type::*;
335                        #default_variant_def_token
336                    }
337                }
338
339                impl Copy for #mmr_peak_type {}
340
341                impl #mmr_peak_type {
342                    pub fn try_augment_and_merge(self, other: Self) -> Result<Self, Error> {
343                        use #mmr_peak_type::{*};
344                        match (self, other) {
345                            #(#augment_and_merge_variant_def_tokens),*,
346                            _ => Err(Error::Merge),
347                        }
348                    }
349                    pub fn try_augment(self) -> Result<Self, Error> {
350                        use #mmr_peak_type::{*};
351                        match self {
352                            #(#augment_variant_def_tokens),*,
353                            _ => Err(Error::Merge),
354                        }
355                    }
356                }
357
358                impl StaticTreeTrait<#arity, #hash_type, #mmr_peak_proof_type> for #mmr_peak_type {
359                    fn generate_proof(&self, index: usize) -> #mmr_peak_proof_type {
360                        #impl_method_body_token.generate_proof(index)
361                        // #impl_mut_method_body_token.generate_proof(index)
362                    }
363                    fn replace(&mut self, index: usize, input: &[u8]) {
364                        #impl_mut_method_body_token.replace(index, input)
365                    }
366                    fn replace_leaf(&mut self, index: usize, leaf: <#hash_type as HashT>::Output) {
367                        #impl_mut_method_body_token.replace_leaf(index, leaf)
368                    }
369                    fn root(&self) -> <#hash_type as HashT>::Output {
370                        #impl_method_body_token.root()
371                    }
372                    fn leaves(&self) -> &[Prefixed<#arity, #hash_type>] {
373                        #impl_method_body_token.leaves()
374                    }
375                    fn base_layer_size(&self) -> usize {
376                        #impl_method_body_token.base_layer_size()
377                    }
378                    // fn arity(&self) -> usize {
379                    //     #impl_method_body_token.arity()
380                    // }
381                    fn height(&self) -> usize {
382                        #impl_method_body_token.height()
383                    }
384                }
385
386                impl AppendOnly for #mmr_peak_type {
387                    fn try_append(&mut self, input: &[u8]) -> Result<(), Error> {
388                        #impl_mut_append_only_method_body_token.try_append(input)
389                    }
390                    fn num_of_leaves(&self) -> usize {
391                        #impl_append_only_method_body_token.num_of_leaves()
392                    }
393                }
394
395                pub struct #mmr_type
396                where
397                    [(); #num_of_peaks]: Sized,
398                {
399                    // the tree that generates the entire proof by chaining a peak's proof
400                    summit_tree: StaticTree<#arity, #summit_height, #hash_type, #max_input_len>,
401                    peaks: [#mmr_peak_type; #num_of_peaks],
402                    curr_peak_index: usize,
403                    num_of_leaves: usize,
404                }
405
406                impl #mmr_type
407                where
408                    [(); #num_of_peaks]: Sized,
409                {
410                    pub fn from_peak(peak: #mmr_peak_type) -> Self {
411                        let mut this = Self {
412                            summit_tree: StaticTree::<#arity, #summit_height, #hash_type, #max_input_len>::default(),
413                            peaks: [#mmr_peak_type::default(); #num_of_peaks],
414                            curr_peak_index: 0,
415                            num_of_leaves: 0,
416                        };
417                        this.peaks[0] = peak;
418                        this
419                    }
420
421                    fn merge_collapse(&mut self) {
422                        let mut i = self.curr_peak_index;
423                        // back propagate and merge peaks while possible
424                        // the indicator that two peaks can merge is that they have the same rank (can check height or num_of_leaves)
425                        while i > 0
426    //                        && self.peaks[i].height() == self.peaks[i - 1].height()
427                            && self.peaks[i].num_of_leaves() == self.peaks[i - 1].num_of_leaves() {
428
429                            match self.peaks[i - 1].try_augment_and_merge(self.peaks[i]) {
430                                Ok(peak) => { self.peaks[i - 1] = peak; },
431                                Err(_) => break,
432                            }
433                            self.peaks[i] = Default::default();
434                            i -= 1;
435                        }
436                        self.curr_peak_index = i;
437                    }
438
439                    pub fn try_append(&mut self, input: &[u8]) -> Result<(), Error> {
440                        self.peaks[self.curr_peak_index]
441                            // try to append item to the current peak
442                            .try_append(input)
443                            // if couldn't append, it's because the underlying tree is full
444                            .or_else(|_| {
445                                // so if the current peak is not last...
446                                if self.curr_peak_index < #num_of_peaks - 1 {
447                                    // move to the next peak and set it the new current one
448                                    self.curr_peak_index += 1;
449                                } else {
450                                    // try to augment the last peak
451                                    self.peaks[self.curr_peak_index] = self.peaks[self.curr_peak_index].try_augment()?;
452                                }
453                                // now try append the item to the new peak
454                                self.peaks[self.curr_peak_index].try_append(input)
455                            })
456                            .map(|_| {
457                                // now back propagate the peaks and merge them if necessary
458                                self.merge_collapse();
459
460                                let root = self.peaks[self.curr_peak_index].root();
461                                self.summit_tree.replace_leaf(self.curr_peak_index, root);
462
463                                self.num_of_leaves += 1;
464                            })
465                    }
466
467                    // panics if the index is out of bounds
468                    pub fn generate_proof(&self, index: usize) -> #mmr_proof_type {
469                        let mut accrue_len = 0;
470                        let mut i = 0;
471                        // find the peak corresponding to the index
472                        while accrue_len + self.peaks[i].num_of_leaves() <= index {
473                            accrue_len += self.peaks[i].num_of_leaves();
474                            i += 1;
475                        }
476                        // make thy entire proof by chaining the peak proof
477                        // to the upper tree proof
478                        chain_proofs(
479                            self.peaks[i].generate_proof(index - accrue_len),
480                            self.summit_tree.generate_proof(i)
481                        )
482                    }
483
484                    pub fn base_layer_size(&self) -> usize {
485                        self.peaks.iter().map(|peak| peak.base_layer_size()).sum()
486                    }
487                    pub fn num_of_leaves(&self) -> usize {
488                        self.num_of_leaves
489                    }
490                    pub fn curr_peak_index(&self) -> usize {
491                        self.curr_peak_index
492                    }
493                    pub fn peaks(&self) -> &[#mmr_peak_type] {
494                        &self.peaks[..]
495                    }
496                }
497
498                impl Default for #mmr_type
499                where
500                    [(); #num_of_peaks]: Sized,
501                    {
502                        fn default() -> Self {
503                            Self::from_peak(Default::default())
504                        }
505                    }
506            }
507
508            use #mod_ident::#mmr_type as #mmr_type;
509            use #mod_ident::#mmr_peak_type as #mmr_peak_type;
510        };
511
512    proc_macro::TokenStream::from(output)
513}