nectar_primitives/bmt/
proof.rs1use alloy_primitives::{B256, Keccak256};
7
8use crate::bmt::{Hasher, constants::*, error::BmtError};
9use crate::error::Result;
10
11#[derive(Clone, Debug)]
13pub struct Proof {
14 pub segment_index: usize,
16 pub segment: B256,
18 pub proof_segments: Vec<B256>,
20 pub span: u64,
22 pub prefix: Option<Vec<u8>>,
24}
25
26impl Proof {
27 pub const fn new(
29 segment_index: usize,
30 segment: B256,
31 proof_segments: Vec<B256>,
32 span: u64,
33 prefix: Option<Vec<u8>>,
34 ) -> Self {
35 Self {
36 segment_index,
37 segment,
38 proof_segments,
39 span,
40 prefix,
41 }
42 }
43
44 pub fn verify(&self, root_hash: &[u8]) -> Result<bool> {
46 if self.proof_segments.len() != PROOF_LENGTH {
47 return Err(
48 BmtError::invalid_proof_length(PROOF_LENGTH, self.proof_segments.len()).into(),
49 );
50 }
51
52 let mut current_hash = self.segment;
54 let mut current_index = self.segment_index;
55
56 for proof_segment in &self.proof_segments {
58 let mut hasher = Keccak256::new();
59
60 if current_index.is_multiple_of(2) {
62 hasher.update(current_hash.as_slice());
63 hasher.update(proof_segment.as_slice());
64 } else {
65 hasher.update(proof_segment.as_slice());
66 hasher.update(current_hash.as_slice());
67 }
68
69 current_hash = B256::from_slice(hasher.finalize().as_slice());
71 current_index /= 2;
72 }
73
74 let mut hasher = Keccak256::new();
76
77 if let Some(prefix) = &self.prefix {
79 hasher.update(prefix);
80 }
81
82 hasher.update(self.span.to_le_bytes());
84
85 hasher.update(current_hash.as_slice());
87
88 let computed_root = B256::from_slice(hasher.finalize().as_slice());
89
90 Ok(computed_root.as_slice() == root_hash)
92 }
93}
94
95pub trait Prover {
97 fn generate_proof(&self, data: &[u8], segment_index: usize) -> Result<Proof>;
99
100 fn verify_proof(proof: &Proof, root_hash: &[u8]) -> Result<bool>;
102}
103
104impl Prover for Hasher {
105 fn generate_proof(&self, data: &[u8], segment_index: usize) -> Result<Proof> {
106 if segment_index >= BRANCHES {
107 return Err(self::BmtError::invalid_input_size(format!(
108 "Segment index {segment_index} out of bounds for BRANCHES"
109 ))
110 .into());
111 }
112
113 let data_len = data.len();
115
116 #[cfg(not(target_arch = "wasm32"))]
118 let segments = {
119 use rayon::prelude::*;
120 (0..BRANCHES)
121 .into_par_iter()
122 .map(|i| {
123 let start = i * SEGMENT_SIZE;
124 let mut segment = [0u8; SEGMENT_SIZE];
125
126 if start < data_len {
127 let end = (start + SEGMENT_SIZE).min(data_len);
128 let copy_len = end - start;
129 segment[..copy_len].copy_from_slice(&data[start..end]);
130 }
131
132 B256::from_slice(&segment)
133 })
134 .collect::<Vec<_>>()
135 };
136
137 #[cfg(target_arch = "wasm32")]
138 let segments = {
139 let mut segs = Vec::with_capacity(BRANCHES);
140 for i in 0..BRANCHES {
141 let start = i * SEGMENT_SIZE;
142 let mut segment = [0u8; SEGMENT_SIZE];
143
144 if start < data_len {
145 let end = (start + SEGMENT_SIZE).min(data_len);
146 let copy_len = end - start;
147 segment[..copy_len].copy_from_slice(&data[start..end]);
148 }
149
150 segs.push(B256::from_slice(&segment));
151 }
152
153 segs
154 };
155
156 let segment = segments[segment_index];
158
159 let mut proof_segments = Vec::with_capacity(PROOF_LENGTH);
161
162 let mut current_level = segments;
164 let mut current_index = segment_index;
165
166 while proof_segments.len() < PROOF_LENGTH {
168 let sibling_index = if current_index.is_multiple_of(2) {
170 current_index + 1
171 } else {
172 current_index - 1
173 };
174
175 if sibling_index < current_level.len() {
177 proof_segments.push(current_level[sibling_index]);
178 } else {
179 proof_segments.push(B256::ZERO);
180 }
181
182 let mut next_level = Vec::with_capacity(current_level.len().div_ceil(2));
184
185 for i in (0..current_level.len()).step_by(2) {
186 let left = ¤t_level[i];
187 let right = if i + 1 < current_level.len() {
188 ¤t_level[i + 1]
189 } else {
190 &B256::ZERO
191 };
192
193 let mut hasher = Keccak256::new();
195 hasher.update(left.as_slice());
196 hasher.update(right.as_slice());
197
198 let parent = B256::from_slice(hasher.finalize().as_slice());
199 next_level.push(parent);
200 }
201
202 current_level = next_level;
204 current_index /= 2;
205
206 if current_level.len() <= 1 {
208 break;
209 }
210 }
211
212 while proof_segments.len() < PROOF_LENGTH {
214 proof_segments.push(B256::ZERO);
215 }
216
217 let prefix = if !self.prefix().is_empty() {
219 Some(self.prefix().to_vec())
220 } else {
221 None
222 };
223
224 Ok(Proof::new(
225 segment_index,
226 segment,
227 proof_segments,
228 self.span(),
229 prefix,
230 ))
231 }
232
233 fn verify_proof(proof: &Proof, root_hash: &[u8]) -> Result<bool> {
234 proof.verify(root_hash)
235 }
236}