Skip to main content

alloy_primitives/utils/
keccak_cache.rs

//! A minimalistic one-way set-associative (i.e. direct-mapped) cache for Keccak-256 hashes.
//!
//! The cache has a fixed size to allow fast access and to minimize per-call overhead.
4
5use super::{hint::unlikely, keccak256_impl as keccak256};
6use crate::{B256, KECCAK256_EMPTY};
7use std::mem::MaybeUninit;
8
9/// Maximum input length that can be cached.
10pub(super) const MAX_INPUT_LEN: usize =
11    128 - size_of::<B256>() - size_of::<u8>() - size_of::<usize>();
12
13const COUNT: usize = 1 << 17; // ~131k entries * 128 bytes = 16MiB
14static CACHE: fixed_cache::Cache<Key, B256, BuildHasher, CacheConfig> =
15    fixed_cache::static_cache!(Key, B256, COUNT, BuildHasher::new());
16
17struct CacheConfig {}
18impl fixed_cache::CacheConfig for CacheConfig {
19    const STATS: bool = false;
20    const EPOCHS: bool = false;
21}
22
23pub(super) fn compute(input: &[u8], imp: impl FnOnce(&[u8]) -> B256) -> B256 {
24    if unlikely(input.is_empty() | (input.len() > MAX_INPUT_LEN)) {
25        return if input.is_empty() { KECCAK256_EMPTY } else { keccak256(input) };
26    }
27
28    CACHE.get_or_insert_with_ref(input, imp, |input| {
29        let mut data = [MaybeUninit::uninit(); MAX_INPUT_LEN];
30        unsafe {
31            std::ptr::copy_nonoverlapping(input.as_ptr(), data.as_mut_ptr().cast(), input.len())
32        };
33        Key { len: input.len() as u8, data }
34    })
35}
36
/// Hasher builder for [`CACHE`]; each hasher it builds starts from `Hasher::default()` (state `0`).
type BuildHasher = std::hash::BuildHasherDefault<Hasher>;

/// Hasher for cache keys, tuned for the short-to-medium byte strings this cache stores.
///
/// It is single-shot. Each `write` call *replaces* the state instead of mixing into it, so
/// `finish` returns the hash of the most recent `write` only. That is enough here:
/// - `Key`'s `Hash` impl performs exactly one `write` of the key bytes.
/// - std's `Hash` for `[u8]` emits a length prefix, which `write_usize` below ignores, followed
///   by one `write` of the bytes.
#[derive(Default)]
struct Hasher(u64);

impl std::hash::Hasher for Hasher {
    #[inline]
    fn finish(&self) -> u64 {
        self.0
    }

    #[inline]
    fn write(&mut self, bytes: &[u8]) {
        // This is tricky because our most common inputs are medium length: 16..=MAX_INPUT_LEN.
        // `foldhash` and `rapidhash` have a fast-path for ..16 bytes and outline the rest,
        // but really we want the opposite, or at least the 16.. path to be inlined.

        // SAFETY: `bytes.len()` is checked to be within the bounds of `MAX_INPUT_LEN` by caller.
        // In this file, `compute` only hands the cache inputs of length 1..=MAX_INPUT_LEN, and
        // every `Key` it builds holds at most that many bytes. Feeding this hasher a longer slice
        // would be undefined behavior, so it must stay private to this module.
        unsafe { core::hint::assert_unchecked(bytes.len() <= MAX_INPUT_LEN) };
        // Short inputs are the uncommon case here. Mark this branch cold so the 16.. path is
        // laid out as the hot one.
        if bytes.len() <= 16 {
            super::hint::cold_path();
        }
        // Overwrites (does not mix into) any previous state; see the type-level docs.
        self.0 = rapidhash::v3::rapidhash_v3_micro_inline::<false, false>(
            bytes,
            const { &rapidhash::v3::RapidSecrets::seed(0) },
        );
    }

    // We can just skip hashing the length prefix entirely since we know it's always
    // `<=MAX_INPUT_LEN`, and the hash is good enough. Skipping it also avoids a wasted hash
    // call: `write` overwrites the state, so a hashed prefix would be discarded anyway, and
    // `Key`'s `Hash` impl never emits one.

    // `write_length_prefix` calls `write_usize` by default.
    #[inline]
    fn write_usize(&mut self, i: usize) {
        debug_assert!(i <= MAX_INPUT_LEN, "{i} > {MAX_INPUT_LEN}")
    }

    // Same as `write_usize`, for builds with the `nightly` feature, where
    // `write_length_prefix` can be overridden directly.
    #[cfg(feature = "nightly")]
    #[inline]
    fn write_length_prefix(&mut self, len: usize) {
        debug_assert!(len <= MAX_INPUT_LEN, "{len} > {MAX_INPUT_LEN}")
    }
}
79
/// A cache key: an inline, fixed-capacity copy of an input of at most [`MAX_INPUT_LEN`] bytes.
///
/// Invariant (upheld by `compute`, the only place in this file that builds a `Key`):
/// - `len <= MAX_INPUT_LEN`;
/// - `data[..len]` is initialized.
///
/// Bytes past `len` may be uninitialized and are never read. `Key::get` relies on this.
#[derive(Clone, Copy)]
struct Key {
    /// Number of meaningful bytes at the start of `data`.
    len: u8,
    /// Key bytes; only the first `len` elements are initialized.
    data: [MaybeUninit<u8>; MAX_INPUT_LEN],
}
85
86impl PartialEq for Key {
87    #[inline]
88    fn eq(&self, other: &Self) -> bool {
89        self.get() == other.get()
90    }
91}
92impl Eq for Key {}
93
94impl std::borrow::Borrow<[u8]> for Key {
95    #[inline]
96    fn borrow(&self) -> &[u8] {
97        self.get()
98    }
99}
100
101impl std::hash::Hash for Key {
102    #[inline]
103    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
104        state.write(self.get());
105    }
106}
107
108impl Key {
109    #[inline]
110    const fn get(&self) -> &[u8] {
111        unsafe { std::slice::from_raw_parts(self.data.as_ptr().cast(), self.len as usize) }
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sizes() {
        // A `Key` is its length byte plus `MAX_INPUT_LEN` data bytes with no padding, and a
        // bucket holding a key/value pair fills exactly 128 bytes.
        assert_eq!(size_of::<Key>(), MAX_INPUT_LEN + 1);
        assert_eq!(size_of::<fixed_cache::Bucket<(Key, B256)>>(), 128);
    }

    #[test]
    fn caching() {
        let mut count: usize = 0;
        let mut compute = |input| {
            compute(input, |x| {
                count += 1;
                keccak256(x)
            })
        };

        // `CACHE` is a process-wide static shared with every other test in the crate. Use inputs
        // nothing else is likely to hash: a common string could already be cached, which would
        // make the miss count below come out lower than expected.
        let input = b"keccak_cache::tests::caching #1";
        let input2 = b"keccak_cache::tests::caching #2";

        let a = compute(input);
        let b = compute(input);
        let c = compute(input);
        assert_eq!(a, b);
        assert_eq!(a, c);

        let d = compute(input2);
        let e = compute(input2);
        assert_ne!(a, d);
        assert_eq!(d, e);

        assert_eq!(count, 2);
    }

    #[test]
    fn empty_input_uses_constant() {
        let mut called = false;
        let hash = compute(b"", |x| {
            called = true;
            keccak256(x)
        });
        assert_eq!(hash, KECCAK256_EMPTY);
        assert!(!called);
    }

    #[test]
    fn boundary_lengths() {
        // The longest cacheable input is computed once and then served from the cache.
        let max = [0x5a_u8; MAX_INPUT_LEN];
        let expected = keccak256(&max[..]);
        let mut count: usize = 0;
        for _ in 0..2 {
            let hash = compute(&max[..], |x| {
                count += 1;
                keccak256(x)
            });
            assert_eq!(hash, expected);
        }
        assert_eq!(count, 1);

        // One byte over the limit bypasses the cache but must still hash correctly.
        let long = [0x5a_u8; MAX_INPUT_LEN + 1];
        assert_eq!(compute(&long[..], |x| keccak256(x)), keccak256(&long[..]));
    }
}