Skip to main content

nectar_primitives/bmt/
hasher.rs

1//! Binary Merkle Tree hasher implementation
2//!
3//! This module provides an implementation of a BMT hasher that uses Keccak256
4//! for computing content-addressed hashes of arbitrary data.
5
6use alloy_primitives::{B256, Keccak256};
7use bytes::Bytes;
8use digest::{FixedOutput, FixedOutputReset, OutputSizeUser, Reset, Update};
9use hybrid_array::{Array, sizes::U32};
10use std::io::{self, Write};
11use std::sync::LazyLock;
12
13// Use rayon for parallel processing on non-WASM platforms
14#[cfg(not(target_arch = "wasm32"))]
15use rayon;
16
17use super::constants::*;
18
19/// Number of zero tree levels for the default body size.
20const ZERO_TREE_LEVELS: usize = zero_tree_levels(DEFAULT_BODY_SIZE);
21
22/// Pre-computed zero hashes for the default body size tree.
23static ZERO_HASHES: LazyLock<[B256; ZERO_TREE_LEVELS]> = LazyLock::new(|| {
24    let mut hashes = [B256::ZERO; ZERO_TREE_LEVELS];
25
26    // Level 0: hash of 64 zero bytes (one segment pair)
27    let mut hasher = Keccak256::new();
28    hasher.update([0u8; SEGMENT_PAIR_LENGTH]);
29    hashes[0] = B256::from_slice(hasher.finalize().as_slice());
30
31    // Each subsequent level: hash of two copies of previous level's hash
32    for i in 1..ZERO_TREE_LEVELS {
33        let mut hasher = Keccak256::new();
34        hasher.update(hashes[i - 1].as_slice());
35        hasher.update(hashes[i - 1].as_slice());
36        hashes[i] = B256::from_slice(hasher.finalize().as_slice());
37    }
38
39    hashes
40});
41
42/// BMT hasher with configurable body size.
43#[derive(Debug, Clone)]
44pub struct Hasher<const BODY_SIZE: usize = DEFAULT_BODY_SIZE> {
45    span: u64,
46    prefix: Option<Vec<u8>>,
47    buffer: [u8; BODY_SIZE],
48    cursor: usize,
49}
50
51impl<const BODY_SIZE: usize> Default for Hasher<BODY_SIZE> {
52    #[inline]
53    fn default() -> Self {
54        Self::new()
55    }
56}
57
58impl<const BODY_SIZE: usize> Hasher<BODY_SIZE> {
59    /// Create a new BMT hasher.
60    #[inline]
61    pub const fn new() -> Self {
62        Self {
63            span: 0,
64            prefix: None,
65            buffer: [0u8; BODY_SIZE],
66            cursor: 0,
67        }
68    }
69
70    /// Set the span of data to be hashed
71    #[inline]
72    pub const fn set_span(&mut self, span: u64) {
73        self.span = span;
74    }
75
76    /// Get the current span
77    #[inline(always)]
78    pub const fn span(&self) -> u64 {
79        self.span
80    }
81
82    /// Add a prefix to the hash calculation
83    #[inline]
84    pub fn prefix_with(&mut self, prefix: &[u8]) {
85        self.prefix = Some(prefix.to_vec());
86    }
87
88    /// Get the current prefix
89    #[inline(always)]
90    pub fn prefix(&self) -> &[u8] {
91        self.prefix.as_deref().unwrap_or(&[])
92    }
93
94    /// Get the current cursor position
95    #[inline(always)]
96    pub const fn position(&self) -> usize {
97        self.cursor
98    }
99
100    /// Get the amount of data currently in the buffer
101    #[inline(always)]
102    pub const fn len(&self) -> usize {
103        self.cursor
104    }
105
106    /// Check if the buffer is empty
107    #[inline(always)]
108    pub const fn is_empty(&self) -> bool {
109        self.cursor == 0
110    }
111
112    /// Update the hasher with more data (non-destructive)
113    #[inline]
114    pub fn update(&mut self, data: &[u8]) {
115        if data.is_empty() {
116            return;
117        }
118
119        // Calculate how much data we can actually copy
120        let available_space = BODY_SIZE - self.cursor;
121        let bytes_to_copy = data.len().min(available_space);
122
123        if bytes_to_copy > 0 {
124            // Copy data at cursor position
125            self.buffer[self.cursor..self.cursor + bytes_to_copy]
126                .copy_from_slice(&data[..bytes_to_copy]);
127
128            // Update cursor position
129            self.cursor += bytes_to_copy;
130        }
131    }
132
133    /// Compute the BMT hash and write to output buffer.
134    #[allow(clippy::should_implement_trait)] // BMT hash, not std::hash::Hash
135    #[inline]
136    pub fn hash(&self, out: &mut [u8]) {
137        let hash = self.sum();
138        out.copy_from_slice(hash.as_slice());
139    }
140
141    /// Compute the BMT hash and return the result (non-destructive)
142    #[inline]
143    #[must_use]
144    pub fn sum(&self) -> B256 {
145        self.finalize_with_prefix(self.hash_internal())
146    }
147
148    /// Check if a byte slice is all zeros.
149    /// Uses chunk-based iteration which LLVM optimizes to SIMD on supported platforms.
150    #[inline(always)]
151    fn is_all_zeros(data: &[u8]) -> bool {
152        // Fold with bitwise OR - any non-zero byte makes the result non-zero
153        // LLVM vectorizes this pattern into efficient SIMD code
154        data.iter().fold(0u8, |acc, &b| acc | b) == 0
155    }
156
157    /// Hash data using a binary merkle tree (internal implementation)
158    ///
159    /// This uses an optimized algorithm that:
160    /// 1. Finds the smallest power-of-2 subtree containing all data
161    /// 2. Hashes only that subtree
162    /// 3. Iteratively combines with pre-computed zero hashes to reach the root
163    #[inline(always)]
164    fn hash_internal(&self) -> B256 {
165        // Special case: no data means entire tree is zeros
166        if self.cursor == 0 {
167            return ZERO_HASHES[ZERO_TREE_LEVELS - 1];
168        }
169
170        // Fast path: if all data is zeros, return pre-computed zero tree root
171        // This avoids hashing entirely when the input is all zeros
172        if Self::is_all_zeros(&self.buffer[..self.cursor]) {
173            return ZERO_HASHES[ZERO_TREE_LEVELS - 1];
174        }
175
176        // Find the smallest power-of-2 subtree that contains all data
177        let effective_size = self
178            .cursor
179            .next_power_of_two()
180            .max(SEGMENT_PAIR_LENGTH)
181            .min(BODY_SIZE);
182
183        // Hash only the effective subtree (which contains all actual data)
184        #[cfg(not(target_arch = "wasm32"))]
185        let mut result = self.hash_subtree_parallel(&self.buffer[..effective_size], effective_size);
186
187        #[cfg(target_arch = "wasm32")]
188        let mut result =
189            self.hash_subtree_sequential(&self.buffer[..effective_size], effective_size);
190
191        // Roll up with zero hashes until we reach the full tree size
192        let mut current_size = effective_size;
193        while current_size < BODY_SIZE {
194            // The current result is a left child, combine with zero hash for right sibling
195            let sibling_level = Self::zero_tree_level(current_size);
196            let mut hasher = Keccak256::new();
197            hasher.update(result.as_slice());
198            hasher.update(ZERO_HASHES[sibling_level].as_slice());
199            result = B256::from_slice(hasher.finalize().as_slice());
200            current_size *= 2;
201        }
202
203        result
204    }
205
206    /// Hash a subtree of exactly `length` bytes (must be power of 2, >= 64)
207    ///
208    /// For sizes < BODY_SIZE: uses sequential hashing (no rayon overhead).
209    /// For BODY_SIZE (4096): uses recursive parallel hashing for maximum throughput.
210    #[cfg(not(target_arch = "wasm32"))]
211    #[inline(always)]
212    fn hash_subtree_parallel(&self, data: &[u8], length: usize) -> B256 {
213        debug_assert!(length.is_power_of_two());
214        debug_assert!(length >= SEGMENT_PAIR_LENGTH);
215
216        // For sizes < BODY_SIZE, use sequential (avoids rayon overhead for small/medium sizes)
217        if length < BODY_SIZE {
218            return self.hash_subtree_sequential(data, length);
219        }
220
221        // For BODY_SIZE (4096): use recursive parallel hashing
222        // Pass cursor as parameter to avoid self indirection in hot loop
223        Self::hash_subtree_recursive_parallel_inner(data, length, self.cursor)
224    }
225
226    /// Recursively hash a subtree using rayon for parallelism.
227    /// Only called for full BODY_SIZE chunks where parallelism pays off.
228    /// Takes cursor as parameter to avoid self indirection in recursive calls.
229    #[cfg(not(target_arch = "wasm32"))]
230    #[inline(always)]
231    fn hash_subtree_recursive_parallel_inner(data: &[u8], length: usize, cursor: usize) -> B256 {
232        debug_assert!(length.is_power_of_two());
233        debug_assert!(length >= SEGMENT_PAIR_LENGTH);
234
235        // Base case: 64 bytes (one segment pair)
236        if length == SEGMENT_PAIR_LENGTH {
237            let mut hasher = Keccak256::new();
238            hasher.update(data);
239            return B256::from_slice(hasher.finalize().as_slice());
240        }
241
242        let half = length / 2;
243        let (left, right) = data.split_at(half);
244
245        // Check if right half is entirely beyond cursor (all zeros in buffer)
246        // cursor is relative to the start of this subtree
247        let (left_hash, right_hash) = if half >= cursor {
248            // Right side is all zeros - compute left only, use precomputed right
249            let left_hash = Self::hash_subtree_recursive_parallel_inner(left, half, cursor);
250            let right_hash = ZERO_HASHES[Self::zero_tree_level(half)];
251            (left_hash, right_hash)
252        } else {
253            // Both sides have data, use parallel execution
254            // Left cursor is capped at half (can't exceed subtree size)
255            // Right cursor is adjusted by half (relative to right subtree start)
256            rayon::join(
257                || Self::hash_subtree_recursive_parallel_inner(left, half, half),
258                || Self::hash_subtree_recursive_parallel_inner(right, half, cursor - half),
259            )
260        };
261
262        let mut hasher = Keccak256::new();
263        hasher.update(left_hash.as_slice());
264        hasher.update(right_hash.as_slice());
265        B256::from_slice(hasher.finalize().as_slice())
266    }
267
268    /// Hash a subtree of exactly `length` bytes (must be power of 2, >= 64) - sequential version
269    #[inline(always)]
270    fn hash_subtree_sequential(&self, data: &[u8], length: usize) -> B256 {
271        debug_assert!(length.is_power_of_two());
272        debug_assert!(length >= SEGMENT_PAIR_LENGTH);
273
274        if length == SEGMENT_PAIR_LENGTH {
275            let mut hasher = Keccak256::new();
276            hasher.update(data);
277            return B256::from_slice(hasher.finalize().as_slice());
278        }
279
280        let half = length / 2;
281        let (left, right) = data.split_at(half);
282
283        // Check if right half is entirely beyond cursor (all zeros in buffer)
284        let (left_hash, right_hash) = if half >= self.cursor {
285            // Right side is all zeros
286            let left_hash = self.hash_subtree_sequential(left, half);
287            let right_hash = ZERO_HASHES[Self::zero_tree_level(half)];
288            (left_hash, right_hash)
289        } else {
290            let left_hash = self.hash_subtree_sequential(left, half);
291            let right_hash = self.hash_subtree_sequential(right, half);
292            (left_hash, right_hash)
293        };
294
295        let mut hasher = Keccak256::new();
296        hasher.update(left_hash.as_slice());
297        hasher.update(right_hash.as_slice());
298        B256::from_slice(hasher.finalize().as_slice())
299    }
300
301    /// Calculate the zero-tree level for a given subtree length.
302    /// Length must be a power of 2 between 64 and 4096.
303    #[inline(always)]
304    const fn zero_tree_level(length: usize) -> usize {
305        // length = 64 * 2^level, so level = log2(length) - log2(64) = log2(length) - 6
306        length.trailing_zeros() as usize - 6
307    }
308
309    /// Finalize with span and optional prefix
310    #[inline(always)]
311    fn finalize_with_prefix(&self, intermediate_hash: B256) -> B256 {
312        let mut hasher = Keccak256::new();
313
314        // Add prefix if present
315        if let Some(prefix) = &self.prefix {
316            hasher.update(prefix);
317        }
318
319        // Add span as little-endian bytes
320        hasher.update(self.span.to_le_bytes());
321
322        // Add the intermediate hash
323        hasher.update(intermediate_hash.as_slice());
324
325        // Finalize to get the result
326        B256::from_slice(hasher.finalize().as_slice())
327    }
328
329    /// Reset the hasher's internal state
330    #[inline(always)]
331    const fn reset_internal(&mut self) {
332        // Simply reset cursor - no need to clear the buffer as it will be overwritten
333        self.cursor = 0;
334        self.span = 0;
335        // Don't reset prefix, as it's considered a configuration parameter
336    }
337
338    /// Get the current data as Bytes (immutable reference)
339    #[inline]
340    #[must_use]
341    pub fn data(&self) -> Bytes {
342        if self.cursor == 0 {
343            return Bytes::new();
344        }
345
346        // Create Bytes from slice
347        Bytes::copy_from_slice(&self.buffer[..self.cursor])
348    }
349
350    /// Get segments for the current level of data
351    #[inline]
352    pub fn get_level_segments(&self, data: &[u8]) -> Vec<B256> {
353        let branches = branches_for_body_size(BODY_SIZE);
354
355        #[cfg(not(target_arch = "wasm32"))]
356        {
357            use rayon::prelude::*;
358            (0..branches)
359                .into_par_iter()
360                .map(|i| self.compute_segment_hash(data, i))
361                .collect()
362        }
363
364        #[cfg(target_arch = "wasm32")]
365        {
366            (0..branches)
367                .map(|i| self.compute_segment_hash(data, i))
368                .collect()
369        }
370    }
371
372    /// Compute the hash for a single segment at given index
373    #[inline(always)]
374    fn compute_segment_hash(&self, data: &[u8], i: usize) -> B256 {
375        let start = i << SEGMENT_SIZE_LOG2; // Equivalent to i * SEGMENT_SIZE
376        let mut hasher = Keccak256::new();
377
378        if start < data.len() {
379            let end = (start + SEGMENT_SIZE).min(data.len());
380            let segment_data = &data[start..end];
381
382            // Update with segment data
383            hasher.update(segment_data);
384
385            // If segment is shorter than SEGMENT_SIZE, the remaining bytes are zeros
386            if segment_data.len() < SEGMENT_SIZE {
387                hasher.update(&[0u8; SEGMENT_SIZE][..(SEGMENT_SIZE - segment_data.len())]);
388            }
389        } else {
390            // Empty segment (all zeros)
391            hasher.update([0u8; SEGMENT_SIZE]);
392        }
393
394        B256::from_slice(hasher.finalize().as_slice())
395    }
396}
397
398impl<const BODY_SIZE: usize> Write for Hasher<BODY_SIZE> {
399    #[inline]
400    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
401        let available = BODY_SIZE - self.cursor;
402        let to_write = buf.len().min(available);
403        if to_write > 0 {
404            self.buffer[self.cursor..self.cursor + to_write].copy_from_slice(&buf[..to_write]);
405            self.cursor += to_write;
406        }
407        Ok(to_write)
408    }
409
410    #[inline]
411    fn flush(&mut self) -> io::Result<()> {
412        Ok(())
413    }
414}
415
416impl<const BODY_SIZE: usize> OutputSizeUser for Hasher<BODY_SIZE> {
417    type OutputSize = U32;
418}
419
420impl<const BODY_SIZE: usize> Update for Hasher<BODY_SIZE> {
421    #[inline]
422    fn update(&mut self, data: &[u8]) {
423        self.update(data);
424    }
425}
426
427impl<const BODY_SIZE: usize> Reset for Hasher<BODY_SIZE> {
428    #[inline]
429    fn reset(&mut self) {
430        self.reset_internal();
431    }
432}
433
434impl<const BODY_SIZE: usize> FixedOutput for Hasher<BODY_SIZE> {
435    #[inline]
436    fn finalize_into(self, out: &mut Array<u8, Self::OutputSize>) {
437        let b256 = self.sum();
438        out.copy_from_slice(b256.as_slice());
439    }
440}
441
442impl<const BODY_SIZE: usize> FixedOutputReset for Hasher<BODY_SIZE> {
443    #[inline]
444    fn finalize_into_reset(&mut self, out: &mut Array<u8, Self::OutputSize>) {
445        let b256 = self.sum();
446        out.copy_from_slice(b256.as_slice());
447        self.reset_internal();
448    }
449}
450
451impl<const BODY_SIZE: usize> digest::HashMarker for Hasher<BODY_SIZE> {}
452
453/// Factory for creating BMT hashers.
454#[derive(Debug, Default, Clone)]
455pub struct HasherFactory<const BODY_SIZE: usize = DEFAULT_BODY_SIZE>;
456
457impl<const BODY_SIZE: usize> HasherFactory<BODY_SIZE> {
458    /// Create a new factory.
459    #[inline]
460    pub const fn new() -> Self {
461        Self
462    }
463
464    /// Create a new BMT hasher.
465    #[inline]
466    pub const fn create_hasher(&self) -> Hasher<BODY_SIZE> {
467        Hasher::new()
468    }
469}