alloy_primitives/utils/
keccak_cache.rs

//! A minimalistic one-way set-associative (i.e. direct-mapped) cache for Keccak256 values.
//!
//! This cache has a fixed size to allow fast access and minimize per-call overhead.
4
5use super::{
6    hint::{likely, unlikely},
7    keccak256_impl as keccak256,
8};
9use crate::{B256, KECCAK256_EMPTY};
10use core::{
11    cell::UnsafeCell,
12    sync::atomic::{AtomicUsize, Ordering},
13};
14
15const ENABLE_STATS: bool = false || option_env!("KECCAK_CACHE_STATS").is_some();
16
17/// Number of cache entries (must be a power of 2).
18const COUNT: usize = 1 << 17; // ~131k entries
19
20const INDEX_MASK: usize = COUNT - 1;
21const HASH_MASK: usize = !INDEX_MASK;
22
23const LOCKED_BIT: usize = 0x0000_8000;
24
25/// Maximum input length that can be cached.
26pub(super) const MAX_INPUT_LEN: usize = 128 - 32 - size_of::<usize>();
27
28/// Global cache storage.
29///
30/// This is sort of an open-coded flat `HashMap<&[u8], Mutex<EntryData>>`.
31static CACHE: [Entry; COUNT] = [const { Entry::new() }; COUNT];
32
33pub(super) fn compute(input: &[u8]) -> B256 {
34    if unlikely(input.is_empty() | (input.len() > MAX_INPUT_LEN)) {
35        return if input.is_empty() {
36            stats::hit(0);
37            KECCAK256_EMPTY
38        } else {
39            stats::out_of_range(input.len());
40            keccak256(input)
41        };
42    }
43
44    let hash = hash_bytes(input);
45    let entry = &CACHE[hash & INDEX_MASK];
46
47    // Combine hash bits and length.
48    // This acts as a cache key to quickly determine if the entry is valid in the next check.
49    let combined = (hash & HASH_MASK) | input.len();
50
51    if entry.try_lock(Some(combined)) {
52        // SAFETY: We hold the lock, so we have exclusive access.
53        let EntryData { value, keccak256: result } = unsafe { *entry.data.get() };
54
55        entry.unlock(combined);
56
57        if likely(value[..input.len()] == input[..]) {
58            // Cache hit!
59            stats::hit(input.len());
60            return result;
61        }
62        // Hash collision: same `combined` value but different input.
63        // This is extremely rare, but can still happen. For correctness we must still handle it.
64        stats::collision(input, &value[..input.len()]);
65    }
66    stats::miss(input.len());
67
68    // Cache miss or contention - compute hash.
69    let result = keccak256(input);
70
71    // Try to update cache entry if not locked.
72    if entry.try_lock(None) {
73        // SAFETY: We hold the lock, so we have exclusive access.
74        unsafe {
75            let data = &mut *entry.data.get();
76            data.value[..input.len()].copy_from_slice(input);
77            data.keccak256 = result;
78        }
79
80        entry.unlock(combined);
81    }
82
83    result
84}
85
86/// A cache entry.
87#[repr(C, align(128))]
88struct Entry {
89    combined: AtomicUsize,
90    data: UnsafeCell<EntryData>,
91}
92
93#[repr(C, align(4))]
94#[derive(Clone, Copy)]
95struct EntryData {
96    value: [u8; MAX_INPUT_LEN],
97    keccak256: B256,
98}
99
100impl Entry {
101    #[inline]
102    const fn new() -> Self {
103        // SAFETY: POD.
104        unsafe { core::mem::zeroed() }
105    }
106
107    #[inline]
108    fn try_lock(&self, expected: Option<usize>) -> bool {
109        let state = self.combined.load(Ordering::Relaxed);
110        if let Some(expected) = expected {
111            if state != expected {
112                return false;
113            }
114        } else if state & LOCKED_BIT != 0 {
115            return false;
116        }
117        self.combined
118            .compare_exchange(state, state | LOCKED_BIT, Ordering::Acquire, Ordering::Relaxed)
119            .is_ok()
120    }
121
122    #[inline]
123    fn unlock(&self, combined: usize) {
124        self.combined.store(combined, Ordering::Release);
125    }
126}
127
128// SAFETY: `Entry` is a specialized `Mutex<EntryData>` that never blocks.
129unsafe impl Send for Entry {}
130unsafe impl Sync for Entry {}
131
132#[inline(always)]
133#[allow(clippy::missing_const_for_fn)]
134fn hash_bytes(input: &[u8]) -> usize {
135    // This is tricky because our most common inputs are medium length: 16..=88
136    // `foldhash` and `rapidhash` have a fast-path for ..16 bytes and outline the rest,
137    // but really we want the opposite, or at least the 16.. path to be inlined.
138
139    // SAFETY: `input.len()` is checked to be within the bounds of `MAX_INPUT_LEN` by caller.
140    unsafe { core::hint::assert_unchecked(input.len() <= MAX_INPUT_LEN) };
141    if input.len() <= 16 {
142        super::hint::cold_path();
143    }
144    let hash = rapidhash::v3::rapidhash_v3_micro_inline::<false, false>(
145        input,
146        const { &rapidhash::v3::RapidSecrets::seed(0) },
147    );
148
149    if cfg!(target_pointer_width = "32") {
150        ((hash >> 32) as usize) ^ (hash as usize)
151    } else {
152        hash as usize
153    }
154}
155
156// NOT PUBLIC API.
157pub(super) mod stats {
158    use super::*;
159    use std::{collections::HashMap, sync::Mutex};
160
161    type BuildHasher = std::hash::BuildHasherDefault<rapidhash::fast::RapidHasher<'static>>;
162
163    static STATS: KeccakCacheStats = KeccakCacheStats {
164        hits: [const { AtomicUsize::new(0) }; MAX_INPUT_LEN + 1],
165        misses: [const { AtomicUsize::new(0) }; MAX_INPUT_LEN + 1],
166        out_of_range: Mutex::new(HashMap::with_hasher(BuildHasher::new())),
167        collisions: Mutex::new(Vec::new()),
168    };
169
170    struct KeccakCacheStats {
171        hits: [AtomicUsize; MAX_INPUT_LEN + 1],
172        misses: [AtomicUsize; MAX_INPUT_LEN + 1],
173        out_of_range: Mutex<HashMap<usize, usize, BuildHasher>>,
174        collisions: Mutex<Vec<(String, String)>>,
175    }
176
177    #[inline]
178    pub(super) fn hit(len: usize) {
179        if !ENABLE_STATS {
180            return;
181        }
182        STATS.hits[len].fetch_add(1, Ordering::Relaxed);
183    }
184
185    #[inline]
186    pub(super) fn miss(len: usize) {
187        if !ENABLE_STATS {
188            return;
189        }
190        STATS.misses[len].fetch_add(1, Ordering::Relaxed);
191    }
192
193    #[inline(never)]
194    pub(super) fn out_of_range(len: usize) {
195        if !ENABLE_STATS {
196            return;
197        }
198        *STATS.out_of_range.lock().unwrap().entry(len).or_insert(0) += 1;
199    }
200
201    #[inline(never)]
202    pub(super) fn collision(input: &[u8], cached: &[u8]) {
203        if !ENABLE_STATS {
204            return;
205        }
206        let input_hex = crate::hex::encode(input);
207        let cached_hex = crate::hex::encode(cached);
208        STATS.collisions.lock().unwrap().push((input_hex, cached_hex));
209    }
210
211    #[doc(hidden)]
212    pub fn format() -> String {
213        use core::fmt::Write;
214        let mut out = String::new();
215
216        if !ENABLE_STATS {
217            out.push_str("keccak cache stats: DISABLED");
218            return out;
219        }
220
221        let mut total_hits = 0usize;
222        let mut total_misses = 0usize;
223        let mut entries: Vec<(usize, usize, usize)> = Vec::new();
224        for len in 0..=MAX_INPUT_LEN {
225            let hits = STATS.hits[len].load(Ordering::Relaxed);
226            let misses = STATS.misses[len].load(Ordering::Relaxed);
227            if hits > 0 || misses > 0 {
228                entries.push((len, hits, misses));
229                total_hits += hits;
230                total_misses += misses;
231            }
232        }
233        for (&len, &misses) in STATS.out_of_range.lock().unwrap().iter() {
234            entries.push((len, 0, misses));
235            total_misses += misses;
236        }
237        entries.sort_by_key(|(len, _, _)| *len);
238
239        writeln!(out, "keccak cache stats by length:").unwrap();
240        writeln!(out, "{:>6} {:>12} {:>12} {:>8}", "len", "hits", "misses", "hit%").unwrap();
241        for (len, hits, misses) in entries {
242            let total = hits + misses;
243            let hit_rate = (hits as f64 / total as f64) * 100.0;
244            writeln!(out, "{len:>6} {hits:>12} {misses:>12} {hit_rate:>7.1}%").unwrap();
245        }
246        let total = total_hits + total_misses;
247        if total > 0 {
248            let hit_rate = (total_hits as f64 / total as f64) * 100.0;
249            writeln!(
250                out,
251                "{:>6} {:>12} {:>12} {:>7.1}%",
252                "all", total_hits, total_misses, hit_rate
253            )
254            .unwrap();
255        }
256
257        let collisions = STATS.collisions.lock().unwrap();
258        if !collisions.is_empty() {
259            writeln!(out, "\nhash collisions ({}):", collisions.len()).unwrap();
260            for (input, cached) in collisions.iter() {
261                writeln!(out, "  input:  0x{input}").unwrap();
262                writeln!(out, "  cached: 0x{cached}").unwrap();
263                writeln!(out).unwrap();
264            }
265        }
266
267        out
268    }
269}