1use alloc::vec::Vec;
2
3use bitcoin::consensus::Encodable;
4use bitcoin::hashes::Hash;
5use bitcoin::hash_types::TxMerkleNode;
6
7pub struct IncrementalHasher {
9 stack: Vec<Option<TxMerkleNode>>,
10}
11
12fn hash(left: TxMerkleNode, right: TxMerkleNode) -> TxMerkleNode {
13 let mut encoder = TxMerkleNode::engine();
14 left.consensus_encode(&mut encoder)
15 .expect("in-memory writers don't error");
16 right
17 .consensus_encode(&mut encoder)
18 .expect("in-memory writers don't error");
19 TxMerkleNode::from_engine(encoder)
20}
21
22impl IncrementalHasher {
23 pub fn new() -> Self {
25 Self { stack: Vec::new() }
26 }
27
28 pub fn add(&mut self, mut node: TxMerkleNode) {
30 for height in 0.. {
31 if self.stack.len() <= height {
32 self.stack.push(Some(node));
33 break;
34 }
35 let left = self.stack[height].take();
36 if let Some(left) = left {
37 node = hash(left, node);
38 } else {
39 self.stack[height] = Some(node);
40 break;
41 }
42 }
43 }
44
45 pub fn finish(self) -> TxMerkleNode {
48 let height = self.stack.len();
49
50 let mut node = None;
51 for (h, left) in self.stack.into_iter().enumerate() {
52 let is_last = h == height - 1;
53 if let Some(left) = left {
54 if let Some(right) = node {
55 node = Some(hash(left, right));
56 } else {
57 node = Some(if is_last { left } else { hash(left, left) });
58 }
59 } else {
60 if let Some(single) = node {
61 node = Some(if is_last {
62 single
63 } else {
64 hash(single, single)
65 });
66 }
67 }
68 }
69 node.expect("empty merkle tree")
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use bitcoin::hash_types::TxMerkleNode;
    use bitcoin::hashes::Hash;

    /// Checks the incremental root against `bitcoin::merkle_tree::calculate_root`
    /// for a tree of `len` leaves, where leaf `i` is 32 copies of the byte `i`.
    fn run_one(len: usize) {
        let nodes: Vec<TxMerkleNode> = (0..len)
            .map(|i| TxMerkleNode::from_slice(&[i as u8; 32]).unwrap())
            .collect();

        // Reference root computed over the whole leaf list at once.
        let root = bitcoin::merkle_tree::calculate_root(nodes.iter().copied());

        // Same leaves, fed one by one.
        let mut incremental = IncrementalHasher::new();
        for leaf in nodes {
            incremental.add(leaf);
        }
        let incremental_root = incremental.finish();

        assert_eq!(root, Some(incremental_root), "mismatch for len={}", len);
    }

    /// Exercises every leaf count from 1 up to and including 62.
    #[test]
    fn test_merkle() {
        (1..63).for_each(run_one);
    }
}