Skip to main content

nectar_primitives/bmt/
proof.rs

1//! Proof-related traits and structures for the Binary Merkle Tree
2//!
3//! This module provides functionality for generating and verifying inclusion proofs
4//! for specific segments within a binary merkle tree.
5
6use alloy_primitives::{B256, Keccak256};
7
8use crate::bmt::{Hasher, constants::*, error::BmtError};
9use crate::error::Result;
10
11/// Represents a proof for a specific segment in a Binary Merkle Tree
12#[derive(Clone, Debug)]
13pub struct Proof {
14    /// The segment index this proof is for
15    pub segment_index: usize,
16    /// The segment data being proven
17    pub segment: B256,
18    /// The proof segments (sibling hashes in the path to the root)
19    pub proof_segments: Vec<B256>,
20    /// The span of the data
21    pub span: u64,
22    /// Optional prefix (used during verification)
23    pub prefix: Option<Vec<u8>>,
24}
25
26impl Proof {
27    /// Create a new BMT proof
28    pub const fn new(
29        segment_index: usize,
30        segment: B256,
31        proof_segments: Vec<B256>,
32        span: u64,
33        prefix: Option<Vec<u8>>,
34    ) -> Self {
35        Self {
36            segment_index,
37            segment,
38            proof_segments,
39            span,
40            prefix,
41        }
42    }
43
44    /// Verify this proof against a root hash
45    pub fn verify(&self, root_hash: &[u8]) -> Result<bool> {
46        if self.proof_segments.len() != PROOF_LENGTH {
47            return Err(
48                BmtError::invalid_proof_length(PROOF_LENGTH, self.proof_segments.len()).into(),
49            );
50        }
51
52        // Start with the segment being proven
53        let mut current_hash = self.segment;
54        let mut current_index = self.segment_index;
55
56        // Apply each proof segment to compute the root
57        for proof_segment in &self.proof_segments {
58            let mut hasher = Keccak256::new();
59
60            // Order matters - left then right
61            if current_index.is_multiple_of(2) {
62                hasher.update(current_hash.as_slice());
63                hasher.update(proof_segment.as_slice());
64            } else {
65                hasher.update(proof_segment.as_slice());
66                hasher.update(current_hash.as_slice());
67            }
68
69            // Get hash for next level
70            current_hash = B256::from_slice(hasher.finalize().as_slice());
71            current_index /= 2;
72        }
73
74        // Final step: add prefix (if any) and span to compute the root hash
75        let mut hasher = Keccak256::new();
76
77        // Add prefix if present
78        if let Some(prefix) = &self.prefix {
79            hasher.update(prefix);
80        }
81
82        // Add span as little-endian bytes
83        hasher.update(self.span.to_le_bytes());
84
85        // Add the intermediate hash
86        hasher.update(current_hash.as_slice());
87
88        let computed_root = B256::from_slice(hasher.finalize().as_slice());
89
90        // Compare with provided root hash
91        Ok(computed_root.as_slice() == root_hash)
92    }
93}
94
95/// Extension trait to add proof-related functionality to BMTHasher
96pub trait Prover {
97    /// Generate a proof for a specific segment
98    fn generate_proof(&self, data: &[u8], segment_index: usize) -> Result<Proof>;
99
100    /// Verify a proof against a root hash
101    fn verify_proof(proof: &Proof, root_hash: &[u8]) -> Result<bool>;
102}
103
104impl Prover for Hasher {
105    fn generate_proof(&self, data: &[u8], segment_index: usize) -> Result<Proof> {
106        if segment_index >= BRANCHES {
107            return Err(self::BmtError::invalid_input_size(format!(
108                "Segment index {segment_index} out of bounds for BRANCHES"
109            ))
110            .into());
111        }
112
113        // Create segments from data, padding with zeros if needed
114        let data_len = data.len();
115
116        // Use platform-specific optimizations for segment generation
117        #[cfg(not(target_arch = "wasm32"))]
118        let segments = {
119            use rayon::prelude::*;
120            (0..BRANCHES)
121                .into_par_iter()
122                .map(|i| {
123                    let start = i * SEGMENT_SIZE;
124                    let mut segment = [0u8; SEGMENT_SIZE];
125
126                    if start < data_len {
127                        let end = (start + SEGMENT_SIZE).min(data_len);
128                        let copy_len = end - start;
129                        segment[..copy_len].copy_from_slice(&data[start..end]);
130                    }
131
132                    B256::from_slice(&segment)
133                })
134                .collect::<Vec<_>>()
135        };
136
137        #[cfg(target_arch = "wasm32")]
138        let segments = {
139            let mut segs = Vec::with_capacity(BRANCHES);
140            for i in 0..BRANCHES {
141                let start = i * SEGMENT_SIZE;
142                let mut segment = [0u8; SEGMENT_SIZE];
143
144                if start < data_len {
145                    let end = (start + SEGMENT_SIZE).min(data_len);
146                    let copy_len = end - start;
147                    segment[..copy_len].copy_from_slice(&data[start..end]);
148                }
149
150                segs.push(B256::from_slice(&segment));
151            }
152
153            segs
154        };
155
156        // Get the segment being proven
157        let segment = segments[segment_index];
158
159        // Generate proof segments
160        let mut proof_segments = Vec::with_capacity(PROOF_LENGTH);
161
162        // Build the Merkle tree level by level
163        let mut current_level = segments;
164        let mut current_index = segment_index;
165
166        // Continue until we reach the root (or until we have BMT_PROOF_LENGTH segments)
167        while proof_segments.len() < PROOF_LENGTH {
168            // Get sibling's index
169            let sibling_index = if current_index.is_multiple_of(2) {
170                current_index + 1
171            } else {
172                current_index - 1
173            };
174
175            // Add sibling to proof
176            if sibling_index < current_level.len() {
177                proof_segments.push(current_level[sibling_index]);
178            } else {
179                proof_segments.push(B256::ZERO);
180            }
181
182            // Compute the next level up in the tree
183            let mut next_level = Vec::with_capacity(current_level.len().div_ceil(2));
184
185            for i in (0..current_level.len()).step_by(2) {
186                let left = &current_level[i];
187                let right = if i + 1 < current_level.len() {
188                    &current_level[i + 1]
189                } else {
190                    &B256::ZERO
191                };
192
193                // Hash the pair to create the parent node
194                let mut hasher = Keccak256::new();
195                hasher.update(left.as_slice());
196                hasher.update(right.as_slice());
197
198                let parent = B256::from_slice(hasher.finalize().as_slice());
199                next_level.push(parent);
200            }
201
202            // Move up to the next level
203            current_level = next_level;
204            current_index /= 2;
205
206            // If we've reached the root or have only one node, break
207            if current_level.len() <= 1 {
208                break;
209            }
210        }
211
212        // Ensure we have exactly BMT_PROOF_LENGTH segments in our proof
213        while proof_segments.len() < PROOF_LENGTH {
214            proof_segments.push(B256::ZERO);
215        }
216
217        // Include the prefix in the proof if there is one
218        let prefix = if !self.prefix().is_empty() {
219            Some(self.prefix().to_vec())
220        } else {
221            None
222        };
223
224        Ok(Proof::new(
225            segment_index,
226            segment,
227            proof_segments,
228            self.span(),
229            prefix,
230        ))
231    }
232
233    fn verify_proof(proof: &Proof, root_hash: &[u8]) -> Result<bool> {
234        proof.verify(root_hash)
235    }
236}