// scorex_crypto_avltree/batch_avl_prover.rs

1use crate::authenticated_tree_ops::*;
2use crate::batch_node::*;
3use crate::operation::*;
4use anyhow::Result;
5use bytes::{BufMut, Bytes, BytesMut};
6use rand::prelude::*;
7use rand::RngCore;
8use std::cmp::Ordering;
9
10///
11/// Implements the batch AVL prover from https://eprint.iacr.org/2016/994
12/// Not thread safe if you use with ThreadUnsafeHash
13///
14/// @param keyLength           - length of keys in tree
15/// @param valueLengthOpt      - length of values in tree. None if it is not fixed
16/// @param oldRootAndHeight    - option root node and height of old tree. Tree should contain new nodes only
17///                            WARNING if you pass it, all isNew and visited flags should be set correctly and height should be correct
18/// @param collectChangedNodes - changed nodes will be collected to a separate buffer during tree modifications if `true`
19/// @param hf                  - hash function
20///
21pub struct BatchAVLProver {
22    pub base: AuthenticatedTreeOpsBase,
23
24    // Directions are just a bit string representing booleans
25    directions: Vec<u8>,
26    directions_bit_length: usize,
27
28    // Keeps track of where we are when replaying directions
29    // a second time; needed for deletions
30    replay_index: usize,
31
32    // Keeps track of the last time we took a right step
33    // when going down the tree; needed for deletions
34    last_right_step: usize,
35
36    old_top_node: Option<NodeId>,
37
38    // operation has already been found in the tree
39    // (if so, we know how to get to the leaf without
40    //  any further comparisons)
41    found: bool, // keeps track of whether the key for the current
42}
43
44impl BatchAVLProver {
45    pub fn new(tree: AVLTree, collect_changed_nodes: bool) -> BatchAVLProver {
46        let mut prover = BatchAVLProver {
47            base: AuthenticatedTreeOpsBase::new(tree, collect_changed_nodes),
48            directions: Vec::new(),
49            directions_bit_length: 0,
50            replay_index: 0,
51            last_right_step: 0,
52            old_top_node: None,
53            found: false,
54        };
55        if prover.base.tree.root.is_none() {
56            let t = LeafNode::new(
57                &prover.base.tree.negative_infinity_key(),
58                &Bytes::from(vec![0u8; prover.base.tree.value_length.unwrap_or(0)]),
59                &prover.base.tree.positive_infinity_key(),
60            );
61            prover.base.tree.root = Some(t);
62            prover.base.tree.height = 1;
63            prover.base.tree.reset();
64        }
65        prover.old_top_node = prover.base.tree.root.clone();
66        prover
67    }
68
69    ///
70    /// If operation.key exists in the tree and the operation succeeds,
71    /// returns Success(Some(v)), where v is the value associated with operation.key
72    /// before the operation.
73    /// If operation.key does not exists in the tree and the operation succeeds, returns Success(None).
74    /// Returns Failure if the operation fails.
75    /// Does not modify the tree or the proof in case return is Failure.
76    ///
77    /// @param operation
78    /// @return - Success(Some(old value)), Success(None), or Failure
79    ////
80    pub fn perform_one_operation(&mut self, operation: &Operation) -> Result<Option<ADValue>> {
81        self.replay_index = self.directions_bit_length;
82        let res = self.return_result_of_one_operation(operation, &self.top_node());
83        if res.is_err() {
84            // take the bit length before fail and divide by 8 with rounding up
85            let old_directions_byte_length = (self.replay_index + 7) / 8;
86            // undo the changes to the directions array
87            self.directions.truncate(old_directions_byte_length);
88            self.directions_bit_length = self.replay_index;
89            if (self.directions_bit_length & 7) > 0 {
90                // 0 out the bits of the last element of the directions array
91                // that are above directionsBitLength
92                let mask = (1u8 << (self.directions_bit_length & 7)) - 1;
93                *self.directions.last_mut().unwrap() &= mask;
94            }
95        }
96        res
97    }
98
99    ///
100    /// @return nodes, that where presented in old tree (starting form oldTopNode, but are not presented in new tree
101    ///
102    pub fn removed_nodes(&mut self) -> Vec<NodeId> {
103        for cn in &self.base.changed_nodes_buffer_to_check {
104            if !self.contains(cn) {
105                self.base.changed_nodes_buffer.push(cn.clone())
106            }
107        }
108        self.base.changed_nodes_buffer.clone()
109    }
110
111    ///
112    /// Generates the proof for all the operations in the list.
113    /// Does NOT modify the tree
114    ////
115    pub fn generate_proof_for_operations(
116        &self,
117        operations: &Vec<Operation>,
118    ) -> Result<(SerializedAdProof, ADDigest)> {
119        let mut new_prover = BatchAVLProver::new(self.base.tree.clone(), false);
120        for op in operations.iter() {
121            new_prover.perform_one_operation(op)?;
122        }
123        Ok((new_prover.generate_proof(), new_prover.digest().unwrap()))
124    }
125
126    /* TODO Possible optimizations:
127     * - Don't put in the key if it's in the modification stream somewhere
128     *   (savings ~32 bytes per proof for transactions with existing key; 0 for insert)
129     *   (problem is that then verifier logic has to change --
130     *   can't verify tree immediately)
131     * - Condense a sequence of balances and other non-full-byte info using
132     *   bit-level stuff and maybe even "changing base without losing space"
133     *   by Dodis-Patrascu-Thorup STOC 2010 (expected savings: 5-15 bytes
134     *   per proof for depth 20, based on experiments with gzipping the array
135     *   that contains only this info)
136     * - Condense the sequence of values if they are mostly not randomly distributed
137     */
138    fn pack_tree(
139        &self,
140        r_node: &NodeId,
141        packaged_tree: &mut BytesMut,
142        previous_leaf_available: &mut bool,
143    ) {
144        // Post order traversal to pack up the tree
145        if !self.base.tree.visited(r_node) {
146            packaged_tree.put_u8(LABEL_IN_PACKAGED_PROOF);
147            let label = self.base.tree.label(r_node);
148            packaged_tree.extend_from_slice(&label);
149            assert!(label.len() == DIGEST_LENGTH);
150            *previous_leaf_available = false;
151        } else {
152            self.base.tree.mark_visited(r_node, false);
153            match self.base.tree.copy(r_node) {
154                Node::Leaf(leaf) => {
155                    packaged_tree.put_u8(LEAF_IN_PACKAGED_PROOF);
156                    if !*previous_leaf_available {
157                        packaged_tree.extend_from_slice(&leaf.hdr.key.unwrap());
158                    }
159                    packaged_tree.extend_from_slice(&leaf.next_node_key);
160                    if self.base.tree.value_length.is_none() {
161                        packaged_tree.put_u32(leaf.value.len() as u32);
162                    }
163                    packaged_tree.extend_from_slice(&leaf.value);
164                    *previous_leaf_available = true;
165                }
166                Node::Internal(node) => {
167                    self.pack_tree(&node.left, packaged_tree, previous_leaf_available);
168                    self.pack_tree(&node.right, packaged_tree, previous_leaf_available);
169                    packaged_tree.put_u8(node.balance as u8);
170                }
171                _ => {
172                    panic!("Node is not resolved");
173                }
174            }
175        }
176    }
177
178    ///
179    /// Generates the proof for all the operations performed (except the ones that failed)
180    /// since the last generateProof call
181    ///
182    /// @return - the proof
183    ///
184    pub fn generate_proof(&mut self) -> SerializedAdProof {
185        self.base.changed_nodes_buffer.clear();
186        self.base.changed_nodes_buffer_to_check.clear();
187        let mut packaged_tree = BytesMut::new();
188        let mut previous_leaf_available = false;
189        self.pack_tree(
190            &self.old_top_node.as_ref().unwrap().clone(),
191            &mut packaged_tree,
192            &mut previous_leaf_available,
193        );
194        packaged_tree.put_u8(END_OF_TREE_IN_PACKAGED_PROOF);
195        packaged_tree.extend_from_slice(&self.directions);
196
197        // prepare for the next time proof
198        self.base.tree.reset();
199        self.directions = Vec::new();
200        self.directions_bit_length = 0;
201        self.old_top_node = self.base.tree.root.clone();
202
203        packaged_tree.freeze()
204    }
205
206    fn walk<IR, LR>(
207        &self,
208        r_node: &NodeId,
209        ir: IR,
210        internal_node_fn: &mut dyn FnMut(&InternalNode, IR) -> (NodeId, IR),
211        leaf_fn: &mut dyn FnMut(&LeafNode, IR) -> LR,
212    ) -> LR {
213        match self.base.tree.copy(r_node) {
214            Node::Leaf(leaf) => leaf_fn(&leaf, ir),
215            Node::Internal(r) => {
216                let i = internal_node_fn(&r, ir);
217                self.walk(&i.0, i.1, internal_node_fn, leaf_fn)
218            }
219            _ => {
220                panic!("Node is not resolved");
221            }
222        }
223    }
224
225    ///
226    /// Walk from tree to a leaf.
227    ///
228    /// @param internalNodeFn - function applied to internal nodes. Takes current internal node and current IR, returns
229    ///                       new internal nod and new IR
230    /// @param leafFn         - function applied to leafss. Takes current leaf and current IR, returns result of walk LR
231    /// @param initial        - initial value of IR
232    /// @tparam IR - result of applying internalNodeFn to internal node. E.g. some accumutalor of previous results
233    /// @tparam LR - result of applying leafFn to a leaf. Result of all walk application
234    /// @return
235    ///
236    pub fn tree_walk<IR, LR>(
237        &self,
238        internal_node_fn: &mut dyn FnMut(&InternalNode, IR) -> (NodeId, IR),
239        leaf_fn: &mut dyn FnMut(&LeafNode, IR) -> LR,
240        initial: IR,
241    ) -> LR {
242        self.walk(&self.top_node(), initial, internal_node_fn, leaf_fn)
243    }
244
245    ///
246    ///
247    /// @param rand - source of randomness
248    /// @return Random leaf from the tree that is not positive or negative infinity
249    ////
250    pub fn random_walk(&self, rand: &mut dyn RngCore) -> Option<KeyValue> {
251        let mut internal_node_fn = |r: &InternalNode, _dummy: ()| -> (NodeId, ()) {
252            if rand.gen::<bool>() {
253                (r.right.clone(), ())
254            } else {
255                (r.left.clone(), ())
256            }
257        };
258        let mut leaf_fn = |leaf: &LeafNode, _dummy: ()| -> Option<KeyValue> {
259            let key = leaf.hdr.key.as_ref().unwrap().clone();
260            if key == self.base.tree.positive_infinity_key() {
261                None
262            } else if key == self.base.tree.negative_infinity_key() {
263                None
264            } else {
265                let value = leaf.value.clone();
266                Some(KeyValue { key, value })
267            }
268        };
269
270        self.tree_walk(&mut internal_node_fn, &mut leaf_fn, ())
271    }
272
273    ///
274    /// A simple non-modifying non-proof-generating lookup.
275    /// Does not mutate the data structure
276    ///
277    /// @return Some(value) for value associated with the given key if key is in the tree, and None otherwise
278    ///
279    pub fn unauthenticated_lookup(&self, key: &ADKey) -> Option<ADValue> {
280        let mut internal_node_fn = |r: &InternalNode, found: bool| {
281            if found {
282                // left all the way to the leaf
283                (r.left.clone(), true)
284            } else {
285                match (*key).cmp(r.hdr.key.as_ref().unwrap()) {
286                    Ordering::Equal =>
287                    // found in the tree -- go one step right, then left to the leaf
288                    {
289                        (r.right.clone(), true)
290                    }
291                    Ordering::Less =>
292                    // going left, not yet found
293                    {
294                        (r.left.clone(), false)
295                    }
296                    Ordering::Greater =>
297                    // going right, not yet found
298                    {
299                        (r.right.clone(), false)
300                    }
301                }
302            }
303        };
304
305        let mut leaf_fn = |leaf: &LeafNode, found: bool| -> Option<ADValue> {
306            if found {
307                Some(leaf.value.clone())
308            } else {
309                None
310            }
311        };
312
313        self.tree_walk(&mut internal_node_fn, &mut leaf_fn, false)
314    }
315
316    fn check_tree_helper(&self, r_node: &NodeId, post_proof: bool) -> (NodeId, NodeId, usize) {
317        let node = self.base.tree.copy(r_node);
318        assert!(!post_proof || (!node.visited() && !node.is_new()));
319        match node {
320            Node::Internal(r) => {
321                let key = r.hdr.key.unwrap();
322                if let Node::Internal(rl) = &*r.left.borrow() {
323                    assert!(*rl.hdr.key.as_ref().unwrap() < key);
324                }
325                if let Node::Internal(rr) = &*r.right.borrow() {
326                    assert!(*rr.hdr.key.as_ref().unwrap() > key);
327                }
328                let (min_left, max_left, left_height) = self.check_tree_helper(&r.left, post_proof);
329                let (min_right, max_right, right_height) =
330                    self.check_tree_helper(&r.right, post_proof);
331                assert_eq!(max_left.borrow().next_node_key(), min_right.borrow().key());
332                assert_eq!(min_right.borrow().key(), key);
333                assert!(
334                    r.balance >= -1
335                        && r.balance <= 1
336                        && r.balance == (right_height as i8 - left_height as i8)
337                );
338                let height = std::cmp::max(left_height, right_height) + 1;
339                (min_left, max_right, height)
340            }
341            _ => (r_node.clone(), r_node.clone(), 1),
342        }
343    }
344
345    ///
346    /// Is for debug only
347    ///
348    /// Checks the BST order, AVL balance, correctness of leaf positions, correctness of first and last
349    /// leaf, correctness of nextLeafKey fields
350    /// If postProof, then also checks for visited and isNew fields being false
351    /// Warning: slow -- takes linear time in tree size
352    /// Throws exception if something is wrong
353    ///
354    pub fn check_tree(&self, post_proof: bool) {
355        let (min_tree, max_tree, tree_height) =
356            self.check_tree_helper(&self.top_node(), post_proof);
357        assert_eq!(
358            min_tree.borrow().key(),
359            self.base.tree.negative_infinity_key()
360        );
361        assert_eq!(
362            max_tree.borrow().next_node_key(),
363            self.base.tree.positive_infinity_key()
364        );
365        assert_eq!(tree_height, self.base.tree.height);
366    }
367}
368
369impl AuthenticatedTreeOps for BatchAVLProver {
370    fn get_state<'a>(&'a self) -> &'a AuthenticatedTreeOpsBase {
371        return &self.base;
372    }
373
374    fn state<'a>(&'a mut self) -> &'a mut AuthenticatedTreeOpsBase {
375        return &mut self.base;
376    }
377
378    ///
379    /// Figures out whether to go left or right when from node r when searching for the key,
380    /// using the appropriate bit in the directions bit string from the proof
381    ///
382    /// @param key
383    /// @param r
384    /// @return - true if to go left, false if to go right in the search
385    ///
386    fn next_direction_is_left(&mut self, key: &ADKey, r: &InternalNode) -> bool {
387        let ret = if self.found {
388            true
389        } else {
390            match (*key).cmp(r.hdr.key.as_ref().unwrap()) {
391                Ordering::Equal => {
392                    // found in the tree -- go one step right, then left to the leaf
393                    self.found = true;
394                    self.last_right_step = self.directions_bit_length;
395                    false
396                }
397                Ordering::Less =>
398                // going left
399                {
400                    true
401                }
402                Ordering::Greater =>
403                // going right
404                {
405                    false
406                }
407            }
408        };
409
410        // encode Booleans as bits
411        if (self.directions_bit_length & 7) == 0 {
412            // new byte needed
413            self.directions.push(if ret { 1u8 } else { 0u8 });
414        } else {
415            if ret {
416                let i = self.directions_bit_length >> 3;
417                self.directions[i] |= 1 << (self.directions_bit_length & 7);
418                // change last byte
419            }
420        }
421        self.directions_bit_length += 1;
422        ret
423    }
424
425    ///
426    /// Determines if the leaf r contains the key
427    ///
428    /// @param key
429    /// @param r
430    /// @return
431    ////
432    fn key_matches_leaf(&mut self, _key: &ADKey, _leaf: &LeafNode) -> Result<bool> {
433        // The prover doesn't actually need to look at the leaf key,
434        // because the prover would have already seen this key on the way
435        // down the to leaf if and only if the leaf matches the key that is being sought
436        let ret = self.found;
437        self.found = false; // reset for next time
438        Ok(ret)
439    }
440
441    ///
442    /// Deletions go down the tree twice -- once to find the leaf and realize
443    /// that it needs to be deleted, and the second time to actually perform the deletion.
444    /// This method will re-create comparison results using directions array and lastRightStep
445    /// variable. Each time it's called, it will give the next comparison result of
446    /// key and node.key, where node starts at the root and progresses down the tree
447    /// according to the comparison results.
448    ///
449    /// @return - result of previous comparison of key and relevant node's key
450    ///
451    fn replay_comparison(&mut self) -> i32 {
452        let ret = if self.replay_index == self.last_right_step {
453            0
454        } else if (self.directions[self.replay_index >> 3] & (1 << (self.replay_index & 7))) == 0 {
455            1
456        } else {
457            -1
458        };
459        self.replay_index += 1;
460        ret
461    }
462}