winter_crypto/merkle/
proofs.rs1use alloc::{collections::BTreeMap, vec::Vec};
7
8use utils::{ByteReader, Deserializable, DeserializationError, Serializable};
9
10use super::MerkleTreeOpening;
11use crate::{errors::MerkleTreeError, Hasher};
12
13#[derive(Debug, Clone, PartialEq, Eq)]
23pub struct BatchMerkleProof<H: Hasher> {
24 pub nodes: Vec<Vec<H::Digest>>,
26 pub depth: u8,
28}
29
30impl<H: Hasher> BatchMerkleProof<H> {
31 pub fn from_single_proofs(
39 proofs: &[MerkleTreeOpening<H>],
40 indexes: &[usize],
41 ) -> BatchMerkleProof<H> {
42 assert!(!proofs.is_empty(), "at least one proof must be provided");
44 assert_eq!(proofs.len(), indexes.len(), "number of proofs must equal number of indexes");
45
46 let depth = proofs[0].1.len();
47
48 let mut proof_map = BTreeMap::new();
50 for (&index, proof) in indexes.iter().zip(proofs.iter().cloned()) {
51 assert_eq!(depth, proof.1.len(), "not all proofs have the same length");
52 proof_map.insert(index, proof);
53 }
54 let indexes = proof_map.keys().cloned().collect::<Vec<_>>();
55 let proofs = proof_map.values().cloned().collect::<Vec<_>>();
56 proof_map.clear();
57
58 let mut leaves = vec![H::Digest::default(); indexes.len()];
59 let mut nodes: Vec<Vec<H::Digest>> = Vec::with_capacity(indexes.len());
60
61 let mut i = 0;
63 while i < indexes.len() {
64 leaves[i] = proofs[i].0;
65
66 if indexes.len() > i + 1 && are_siblings(indexes[i], indexes[i + 1]) {
67 leaves[i + 1] = proofs[i].1[0];
68 nodes.push(vec![]);
69 i += 1;
70 } else {
71 nodes.push(vec![proofs[i].1[0]]);
72 }
73 proof_map.insert(indexes[i] >> 1, proofs[i].clone());
74 i += 1;
75 }
76
77 for d in 1..depth {
79 let indexes = proof_map.keys().cloned().collect::<Vec<_>>();
80 let mut next_proof_map = BTreeMap::new();
81
82 let mut i = 0;
83 while i < indexes.len() {
84 let index = indexes[i];
85 let proof = proof_map.get(&index).unwrap();
86 if indexes.len() > i + 1 && are_siblings(index, indexes[i + 1]) {
87 i += 1;
88 } else {
89 nodes[i].push(proof.1[d]);
90 }
91 next_proof_map.insert(index >> 1, proof.clone());
92 i += 1;
93 }
94
95 core::mem::swap(&mut proof_map, &mut next_proof_map);
96 }
97
98 BatchMerkleProof { nodes, depth: (depth) as u8 }
99 }
100
101 pub fn get_root(
111 &self,
112 indexes: &[usize],
113 leaves: &[H::Digest],
114 ) -> Result<H::Digest, MerkleTreeError> {
115 if indexes.is_empty() {
116 return Err(MerkleTreeError::TooFewLeafIndexes);
117 }
118
119 let mut buf = [H::Digest::default(); 2];
120 let mut v = BTreeMap::new();
121
122 let index_map = super::map_indexes(indexes, self.depth as usize)?;
124 let indexes = super::normalize_indexes(indexes);
125 if indexes.len() != self.nodes.len() {
126 return Err(MerkleTreeError::InvalidProof);
127 }
128
129 let offset = 2usize.pow(self.depth as u32);
131 let mut next_indexes: Vec<usize> = Vec::new();
132 let mut proof_pointers: Vec<usize> = Vec::with_capacity(indexes.len());
133 for (i, index) in indexes.into_iter().enumerate() {
134 match index_map.get(&index) {
136 Some(&index1) => {
137 if leaves.len() <= index1 {
138 return Err(MerkleTreeError::InvalidProof);
139 }
140 buf[0] = leaves[index1];
141 match index_map.get(&(index + 1)) {
142 Some(&index2) => {
143 if leaves.len() <= index2 {
144 return Err(MerkleTreeError::InvalidProof);
145 }
146 buf[1] = leaves[index2];
147 proof_pointers.push(0);
148 },
149 None => {
150 if self.nodes[i].is_empty() {
151 return Err(MerkleTreeError::InvalidProof);
152 }
153 buf[1] = self.nodes[i][0];
154 proof_pointers.push(1);
155 },
156 }
157 },
158 None => {
159 if self.nodes[i].is_empty() {
160 return Err(MerkleTreeError::InvalidProof);
161 }
162 buf[0] = self.nodes[i][0];
163 match index_map.get(&(index + 1)) {
164 Some(&index2) => {
165 if leaves.len() <= index2 {
166 return Err(MerkleTreeError::InvalidProof);
167 }
168 buf[1] = leaves[index2];
169 },
170 None => return Err(MerkleTreeError::InvalidProof),
171 }
172 proof_pointers.push(1);
173 },
174 }
175
176 let parent = H::merge(&buf);
178
179 let parent_index = (offset + index) >> 1;
180 v.insert(parent_index, parent);
181 next_indexes.push(parent_index);
182 }
183
184 for _ in 1..self.depth {
186 let indexes = next_indexes.clone();
187 next_indexes.truncate(0);
188
189 let mut i = 0;
190 while i < indexes.len() {
191 let node_index = indexes[i];
192 let sibling_index = node_index ^ 1;
193
194 let sibling: H::Digest;
196 if i + 1 < indexes.len() && indexes[i + 1] == sibling_index {
197 sibling = match v.get(&sibling_index) {
198 Some(sibling) => *sibling,
199 None => return Err(MerkleTreeError::InvalidProof),
200 };
201 i += 1;
202 } else {
203 let pointer = proof_pointers[i];
204 if self.nodes[i].len() <= pointer {
205 return Err(MerkleTreeError::InvalidProof);
206 }
207 sibling = self.nodes[i][pointer];
208 proof_pointers[i] += 1;
209 }
210
211 let node = match v.get(&node_index) {
213 Some(node) => node,
214 None => return Err(MerkleTreeError::InvalidProof),
215 };
216
217 if node_index & 1 != 0 {
219 buf[0] = sibling;
220 buf[1] = *node;
221 } else {
222 buf[0] = *node;
223 buf[1] = sibling;
224 }
225 let parent = H::merge(&buf);
226
227 let parent_index = node_index >> 1;
229 v.insert(parent_index, parent);
230 next_indexes.push(parent_index);
231
232 i += 1;
233 }
234 }
235 v.remove(&1).ok_or(MerkleTreeError::InvalidProof)
236 }
237
238 pub fn into_openings(
245 self,
246 leaves: &[H::Digest],
247 indexes: &[usize],
248 ) -> Result<Vec<MerkleTreeOpening<H>>, MerkleTreeError> {
249 if indexes.is_empty() {
250 return Err(MerkleTreeError::TooFewLeafIndexes);
251 }
252 if indexes.len() != leaves.len() {
253 return Err(MerkleTreeError::InvalidProof);
254 }
255
256 let mut partial_tree_map = BTreeMap::new();
257
258 for (&i, leaf) in indexes.iter().zip(leaves.iter()) {
259 partial_tree_map.insert(i + (1 << (self.depth)), *leaf);
260 }
261
262 let mut buf = [H::Digest::default(); 2];
263 let mut v = BTreeMap::new();
264
265 let original_indexes = indexes;
267 let index_map = super::map_indexes(indexes, self.depth as usize)?;
268 let indexes = super::normalize_indexes(indexes);
269 if indexes.len() != self.nodes.len() {
270 return Err(MerkleTreeError::InvalidProof);
271 }
272
273 let offset = 2usize.pow(self.depth as u32);
275 let mut next_indexes: Vec<usize> = Vec::new();
276 let mut proof_pointers: Vec<usize> = Vec::with_capacity(indexes.len());
277 for (i, index) in indexes.into_iter().enumerate() {
278 match index_map.get(&index) {
280 Some(&index1) => {
281 if leaves.len() <= index1 {
282 return Err(MerkleTreeError::InvalidProof);
283 }
284 buf[0] = leaves[index1];
285 match index_map.get(&(index + 1)) {
286 Some(&index2) => {
287 if leaves.len() <= index2 {
288 return Err(MerkleTreeError::InvalidProof);
289 }
290 buf[1] = leaves[index2];
291 proof_pointers.push(0);
292 },
293 None => {
294 if self.nodes[i].is_empty() {
295 return Err(MerkleTreeError::InvalidProof);
296 }
297 buf[1] = self.nodes[i][0];
298 proof_pointers.push(1);
299 },
300 }
301 },
302 None => {
303 if self.nodes[i].is_empty() {
304 return Err(MerkleTreeError::InvalidProof);
305 }
306 buf[0] = self.nodes[i][0];
307 match index_map.get(&(index + 1)) {
308 Some(&index2) => {
309 if leaves.len() <= index2 {
310 return Err(MerkleTreeError::InvalidProof);
311 }
312 buf[1] = leaves[index2];
313 },
314 None => return Err(MerkleTreeError::InvalidProof),
315 }
316 proof_pointers.push(1);
317 },
318 }
319
320 let parent = H::merge(&buf);
322 partial_tree_map.insert(offset + index, buf[0]);
323 partial_tree_map.insert((offset + index) ^ 1, buf[1]);
324 let parent_index = (offset + index) >> 1;
325 v.insert(parent_index, parent);
326 next_indexes.push(parent_index);
327 partial_tree_map.insert(parent_index, parent);
328 }
329
330 for _ in 1..self.depth {
332 let indexes = next_indexes.clone();
333 next_indexes.clear();
334
335 let mut i = 0;
336 while i < indexes.len() {
337 let node_index = indexes[i];
338 let sibling_index = node_index ^ 1;
339
340 let sibling = if i + 1 < indexes.len() && indexes[i + 1] == sibling_index {
342 i += 1;
343 match v.get(&sibling_index) {
344 Some(sibling) => *sibling,
345 None => return Err(MerkleTreeError::InvalidProof),
346 }
347 } else {
348 let pointer = proof_pointers[i];
349 if self.nodes[i].len() <= pointer {
350 return Err(MerkleTreeError::InvalidProof);
351 }
352 proof_pointers[i] += 1;
353 self.nodes[i][pointer]
354 };
355
356 let node = match v.get(&node_index) {
358 Some(node) => node,
359 None => return Err(MerkleTreeError::InvalidProof),
360 };
361
362 partial_tree_map.insert(node_index ^ 1, sibling);
364 let parent = if node_index & 1 != 0 {
365 H::merge(&[sibling, *node])
366 } else {
367 H::merge(&[*node, sibling])
368 };
369
370 let parent_index = node_index >> 1;
372 v.insert(parent_index, parent);
373 next_indexes.push(parent_index);
374 partial_tree_map.insert(parent_index, parent);
375
376 i += 1;
377 }
378 }
379
380 original_indexes
381 .iter()
382 .map(|&i| get_proof::<H>(i, &partial_tree_map, self.depth as usize))
383 .collect()
384 }
385}
386
387impl<H: Hasher> Serializable for BatchMerkleProof<H> {
391 fn write_into<W: utils::ByteWriter>(&self, target: &mut W) {
393 target.write_u8(self.depth);
394 target.write_usize(self.nodes.len());
395
396 for nodes in self.nodes.iter() {
397 nodes.write_into(target);
399 }
400 }
401}
402
403impl<H: Hasher> Deserializable for BatchMerkleProof<H> {
404 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
411 let depth = source.read_u8()?;
412 let num_node_vectors = source.read_usize()?;
413
414 let mut nodes = Vec::with_capacity(num_node_vectors);
415 for _ in 0..num_node_vectors {
416 let digests = Vec::<_>::read_from(source)?;
418 nodes.push(digests);
419 }
420
421 Ok(BatchMerkleProof { nodes, depth })
422 }
423}
424
/// Returns `true` when `left` and `right` are the left and right children of the same parent,
/// i.e. `left` is even and `right == left + 1`.
fn are_siblings(left: usize, right: usize) -> bool {
    // Compare via `left + 1` instead of `right - 1`: the latter underflows (and panics in debug
    // builds) when `right == 0`. `left + 1` cannot overflow because an even `left` is never
    // `usize::MAX`, and `&&` short-circuits before the addition for odd values.
    left & 1 == 0 && left + 1 == right
}
433
434pub fn get_proof<H: Hasher>(
436 index: usize,
437 tree: &BTreeMap<usize, <H as Hasher>::Digest>,
438 depth: usize,
439) -> Result<MerkleTreeOpening<H>, MerkleTreeError> {
440 let mut index = index + (1 << depth);
441 let leaf = if let Some(leaf) = tree.get(&index) {
442 *leaf
443 } else {
444 return Err(MerkleTreeError::InvalidProof);
445 };
446
447 let mut proof = vec![];
448 while index > 1 {
449 let leaf = if let Some(leaf) = tree.get(&(index ^ 1)) {
450 *leaf
451 } else {
452 return Err(MerkleTreeError::InvalidProof);
453 };
454
455 proof.push(leaf);
456 index >>= 1;
457 }
458
459 Ok((leaf, proof))
460}