1use std::{fmt::Debug, marker::PhantomData};
2
3use cosmwasm_std::{Order, Storage};
4use cw_storage_plus::{Item, Map};
5use serde::{de::DeserializeOwned, Serialize};
6
7use crate::{Hasher, MerkleTree, MerkleTreeError};
8
9pub struct SparseMerkleTree<
11 'a,
12 L: Serialize + DeserializeOwned + Clone + Debug + PartialEq,
13 H: Hasher<L>,
14> {
15 _l: PhantomData<L>,
16 _h: PhantomData<H>,
17 pub hashes: Item<'a, (Vec<L>, Vec<L>)>,
18 pub leafs: Map<'a, u64, L>,
19 pub level: Item<'a, u8>,
20 pub root: Item<'a, L>,
21}
22
23impl<'a, L: Serialize + DeserializeOwned + Clone + Debug + PartialEq, H: Hasher<L>>
24 SparseMerkleTree<'a, L, H>
25{
26 pub const fn new(
27 hashes_ns: &'a str,
28 leafs_ns: &'a str,
29 level_ns: &'a str,
30 root_ns: &'a str,
31 ) -> Self {
32 Self {
33 _l: PhantomData,
34 _h: PhantomData,
35 hashes: Item::new(hashes_ns),
36 leafs: Map::new(leafs_ns),
37 level: Item::new(level_ns),
38 root: Item::new(root_ns),
39 }
40 }
41}
42
43impl<'a, L: Serialize + DeserializeOwned + Clone + Debug + PartialEq, H: Hasher<L>> MerkleTree<L, H>
44 for SparseMerkleTree<'a, L, H>
45{
46 fn init(
47 &self,
48 storage: &mut dyn Storage,
49 level: u8,
50 default_leaf: L,
51 hasher: &H,
52 ) -> Result<(), MerkleTreeError> {
53 self.level
54 .may_load(storage)?
55 .is_none()
56 .then_some(())
57 .ok_or(MerkleTreeError::AlreadyInit)?;
58
59 self.level.save(storage, &level)?;
60
61 let mut hashes = vec![default_leaf];
62
63 for i in 1..level as usize {
64 let latest = &hashes[i - 1];
65 hashes.push(hasher.hash_two(latest, latest)?);
66 }
67
68 self.hashes.save(storage, &(hashes.clone(), hashes))?;
69
70 Ok(())
71 }
72
73 fn is_valid_root(&self, storage: &dyn Storage, root: &L) -> Result<bool, MerkleTreeError> {
74 Ok(self.root.may_load(storage)?.as_ref() == Some(root))
75 }
76
77 fn insert(
78 &self,
79 storage: &mut dyn Storage,
80 leaf: L,
81 hasher: &H,
82 ) -> Result<(u64, L), MerkleTreeError> {
83 let level = self.level.load(storage)?;
84 let index = {
85 self.leafs
86 .keys(storage, None, None, Order::Descending)
87 .next()
88 .transpose()?
89 .map(|e| e + 1)
90 .unwrap_or_default()
91 };
92
93 (index < 2u64.pow(level as u32))
94 .then_some(())
95 .ok_or(MerkleTreeError::ExceedMaxLeaf)?;
96
97 self.leafs.save(storage, index, &leaf)?;
98
99 let (mut hashes, zeros) = self.hashes.load(storage)?;
100 let mut cur_hash = leaf;
101 let mut cur_idx = index;
102
103 for i in 0..level as usize {
104 let (left, right) = match cur_idx % 2 == 0 {
105 true => {
106 hashes[i] = cur_hash.clone();
107 (&cur_hash, &zeros[i])
108 }
109 false => (&hashes[i], &cur_hash),
110 };
111
112 cur_hash = hasher.hash_two(left, right)?;
113 cur_idx /= 2;
114 }
115
116 self.hashes.save(storage, &(hashes, zeros))?;
117 self.root.save(storage, &cur_hash)?;
118
119 Ok((index, cur_hash))
120 }
121
122 fn get_latest_root(&self, storage: &dyn Storage) -> Result<L, MerkleTreeError> {
123 Ok(self
124 .root
125 .may_load(storage)?
126 .unwrap_or(self.hashes.load(storage)?.1.last().unwrap().clone()))
127 }
128}
129
130#[cfg(test)]
131mod tests {
132 use std::{error::Error, str::FromStr};
133
134 use cosmwasm_std::{testing::MockStorage, Uint256};
135
136 use crate::{test_utils::Blake2, Hasher, MerkleTree};
137
138 use super::SparseMerkleTree;
139
140 const TREE: SparseMerkleTree<Uint256, Blake2> =
141 SparseMerkleTree::new("hashes", "leafs", "level", "zeros");
142
143 #[test]
144 fn init() -> Result<(), Box<dyn Error>> {
145 let mut storage = MockStorage::new();
146
147 TREE.init(
148 &mut storage,
149 20,
150 Blake2.hash_two(&Uint256::zero(), &Uint256::zero())?,
151 &Blake2,
152 )?;
153
154 assert_eq!(
155 TREE.get_latest_root(&storage)?,
156 Uint256::from_str(
157 "9249403463272353962338525770558810268347485650856754165003644360089862036530"
158 )?
159 );
160
161 Ok(())
162 }
163
164 #[test]
165 fn insert() -> Result<(), Box<dyn Error>> {
166 let mut storage = MockStorage::new();
167
168 TREE.init(
169 &mut storage,
170 20,
171 Blake2.hash_two(&Uint256::zero(), &Uint256::zero())?,
172 &Blake2,
173 )?;
174
175 let leaf = Blake2.hash_two(&Uint256::one(), &Uint256::one())?;
176
177 let (index, new_root) = TREE.insert(&mut storage, leaf, &Blake2)?;
178
179 assert_eq!(index, 0);
180 assert_eq!(
181 new_root,
182 Uint256::from_str(
183 "65270348628983318905821145914244198139930176154042934882987463098115489862117"
184 )?
185 );
186 assert_eq!(new_root, TREE.get_latest_root(&storage)?);
187 assert!(TREE.is_valid_root(&storage, &new_root)?);
188
189 let (index, new_root) = TREE.insert(&mut storage, leaf, &Blake2)?;
190
191 assert_eq!(index, 1);
192 assert_eq!(
193 new_root,
194 Uint256::from_str(
195 "31390868241958093005646829964058364480768696680064791450319134920411060649604"
196 )?
197 );
198 assert_eq!(new_root, TREE.get_latest_root(&storage)?);
199 assert!(TREE.is_valid_root(&storage, &new_root)?);
200
201 Ok(())
202 }
203
204 #[test]
205 fn root_history() -> Result<(), Box<dyn Error>> {
206 let mut storage = MockStorage::new();
207
208 TREE.init(
209 &mut storage,
210 20,
211 Blake2.hash_two(&Uint256::zero(), &Uint256::zero())?,
212 &Blake2,
213 )?;
214
215 let leaf = Blake2.hash_two(&Uint256::from_u128(5), &Uint256::from_u128(5))?;
216
217 let (_, old_root) = TREE.insert(&mut storage, leaf, &Blake2)?;
218 let (_, new_root) = TREE.insert(&mut storage, leaf, &Blake2)?;
219
220 assert!(!TREE.is_valid_root(&storage, &old_root)?);
221 assert!(TREE.is_valid_root(&storage, &new_root)?);
222
223 Ok(())
224 }
225}