alloy_primitives/utils/
keccak_cache.rs1use super::{
6 hint::{likely, unlikely},
7 keccak256_impl as keccak256,
8};
9use crate::{B256, KECCAK256_EMPTY};
10use core::{
11 cell::UnsafeCell,
12 sync::atomic::{AtomicUsize, Ordering},
13};
14
15const ENABLE_STATS: bool = false || option_env!("KECCAK_CACHE_STATS").is_some();
16
17const COUNT: usize = 1 << 17; const INDEX_MASK: usize = COUNT - 1;
21const HASH_MASK: usize = !INDEX_MASK;
22
23const LOCKED_BIT: usize = 0x0000_8000;
24
25pub(super) const MAX_INPUT_LEN: usize = 128 - 32 - size_of::<usize>();
27
28static CACHE: [Entry; COUNT] = [const { Entry::new() }; COUNT];
32
33pub(super) fn compute(input: &[u8]) -> B256 {
34 if unlikely(input.is_empty() | (input.len() > MAX_INPUT_LEN)) {
35 return if input.is_empty() {
36 stats::hit(0);
37 KECCAK256_EMPTY
38 } else {
39 stats::out_of_range(input.len());
40 keccak256(input)
41 };
42 }
43
44 let hash = hash_bytes(input);
45 let entry = &CACHE[hash & INDEX_MASK];
46
47 let combined = (hash & HASH_MASK) | input.len();
50
51 if entry.try_lock(Some(combined)) {
52 let EntryData { value, keccak256: result } = unsafe { *entry.data.get() };
54
55 entry.unlock(combined);
56
57 if likely(value[..input.len()] == input[..]) {
58 stats::hit(input.len());
60 return result;
61 }
62 stats::collision(input, &value[..input.len()]);
65 }
66 stats::miss(input.len());
67
68 let result = keccak256(input);
70
71 if entry.try_lock(None) {
73 unsafe {
75 let data = &mut *entry.data.get();
76 data.value[..input.len()].copy_from_slice(input);
77 data.keccak256 = result;
78 }
79
80 entry.unlock(combined);
81 }
82
83 result
84}
85
86#[repr(C, align(128))]
88struct Entry {
89 combined: AtomicUsize,
90 data: UnsafeCell<EntryData>,
91}
92
93#[repr(C, align(4))]
94#[derive(Clone, Copy)]
95struct EntryData {
96 value: [u8; MAX_INPUT_LEN],
97 keccak256: B256,
98}
99
100impl Entry {
101 #[inline]
102 const fn new() -> Self {
103 unsafe { core::mem::zeroed() }
105 }
106
107 #[inline]
108 fn try_lock(&self, expected: Option<usize>) -> bool {
109 let state = self.combined.load(Ordering::Relaxed);
110 if let Some(expected) = expected {
111 if state != expected {
112 return false;
113 }
114 } else if state & LOCKED_BIT != 0 {
115 return false;
116 }
117 self.combined
118 .compare_exchange(state, state | LOCKED_BIT, Ordering::Acquire, Ordering::Relaxed)
119 .is_ok()
120 }
121
122 #[inline]
123 fn unlock(&self, combined: usize) {
124 self.combined.store(combined, Ordering::Release);
125 }
126}
127
128unsafe impl Send for Entry {}
130unsafe impl Sync for Entry {}
131
132#[inline(always)]
133#[allow(clippy::missing_const_for_fn)]
134fn hash_bytes(input: &[u8]) -> usize {
135 unsafe { core::hint::assert_unchecked(input.len() <= MAX_INPUT_LEN) };
141 if input.len() <= 16 {
142 super::hint::cold_path();
143 }
144 let hash = rapidhash::v3::rapidhash_v3_micro_inline::<false, false>(
145 input,
146 const { &rapidhash::v3::RapidSecrets::seed(0) },
147 );
148
149 if cfg!(target_pointer_width = "32") {
150 ((hash >> 32) as usize) ^ (hash as usize)
151 } else {
152 hash as usize
153 }
154}
155
156pub(super) mod stats {
158 use super::*;
159 use std::{collections::HashMap, sync::Mutex};
160
161 type BuildHasher = std::hash::BuildHasherDefault<rapidhash::fast::RapidHasher<'static>>;
162
163 static STATS: KeccakCacheStats = KeccakCacheStats {
164 hits: [const { AtomicUsize::new(0) }; MAX_INPUT_LEN + 1],
165 misses: [const { AtomicUsize::new(0) }; MAX_INPUT_LEN + 1],
166 out_of_range: Mutex::new(HashMap::with_hasher(BuildHasher::new())),
167 collisions: Mutex::new(Vec::new()),
168 };
169
170 struct KeccakCacheStats {
171 hits: [AtomicUsize; MAX_INPUT_LEN + 1],
172 misses: [AtomicUsize; MAX_INPUT_LEN + 1],
173 out_of_range: Mutex<HashMap<usize, usize, BuildHasher>>,
174 collisions: Mutex<Vec<(String, String)>>,
175 }
176
177 #[inline]
178 pub(super) fn hit(len: usize) {
179 if !ENABLE_STATS {
180 return;
181 }
182 STATS.hits[len].fetch_add(1, Ordering::Relaxed);
183 }
184
185 #[inline]
186 pub(super) fn miss(len: usize) {
187 if !ENABLE_STATS {
188 return;
189 }
190 STATS.misses[len].fetch_add(1, Ordering::Relaxed);
191 }
192
193 #[inline(never)]
194 pub(super) fn out_of_range(len: usize) {
195 if !ENABLE_STATS {
196 return;
197 }
198 *STATS.out_of_range.lock().unwrap().entry(len).or_insert(0) += 1;
199 }
200
201 #[inline(never)]
202 pub(super) fn collision(input: &[u8], cached: &[u8]) {
203 if !ENABLE_STATS {
204 return;
205 }
206 let input_hex = crate::hex::encode(input);
207 let cached_hex = crate::hex::encode(cached);
208 STATS.collisions.lock().unwrap().push((input_hex, cached_hex));
209 }
210
211 #[doc(hidden)]
212 pub fn format() -> String {
213 use core::fmt::Write;
214 let mut out = String::new();
215
216 if !ENABLE_STATS {
217 out.push_str("keccak cache stats: DISABLED");
218 return out;
219 }
220
221 let mut total_hits = 0usize;
222 let mut total_misses = 0usize;
223 let mut entries: Vec<(usize, usize, usize)> = Vec::new();
224 for len in 0..=MAX_INPUT_LEN {
225 let hits = STATS.hits[len].load(Ordering::Relaxed);
226 let misses = STATS.misses[len].load(Ordering::Relaxed);
227 if hits > 0 || misses > 0 {
228 entries.push((len, hits, misses));
229 total_hits += hits;
230 total_misses += misses;
231 }
232 }
233 for (&len, &misses) in STATS.out_of_range.lock().unwrap().iter() {
234 entries.push((len, 0, misses));
235 total_misses += misses;
236 }
237 entries.sort_by_key(|(len, _, _)| *len);
238
239 writeln!(out, "keccak cache stats by length:").unwrap();
240 writeln!(out, "{:>6} {:>12} {:>12} {:>8}", "len", "hits", "misses", "hit%").unwrap();
241 for (len, hits, misses) in entries {
242 let total = hits + misses;
243 let hit_rate = (hits as f64 / total as f64) * 100.0;
244 writeln!(out, "{len:>6} {hits:>12} {misses:>12} {hit_rate:>7.1}%").unwrap();
245 }
246 let total = total_hits + total_misses;
247 if total > 0 {
248 let hit_rate = (total_hits as f64 / total as f64) * 100.0;
249 writeln!(
250 out,
251 "{:>6} {:>12} {:>12} {:>7.1}%",
252 "all", total_hits, total_misses, hit_rate
253 )
254 .unwrap();
255 }
256
257 let collisions = STATS.collisions.lock().unwrap();
258 if !collisions.is_empty() {
259 writeln!(out, "\nhash collisions ({}):", collisions.len()).unwrap();
260 for (input, cached) in collisions.iter() {
261 writeln!(out, " input: 0x{input}").unwrap();
262 writeln!(out, " cached: 0x{cached}").unwrap();
263 writeln!(out).unwrap();
264 }
265 }
266
267 out
268 }
269}