Skip to main content

alloy_primitives/utils/
keccak_cache.rs

1//! A minimalistic one-way set associative cache for Keccak256 values.
2//!
3//! This cache has a fixed size to allow fast access and minimize per-call overhead.
4
5use super::{hint::unlikely, keccak256_impl as keccak256};
6use crate::{B256, KECCAK256_EMPTY};
7use std::mem::MaybeUninit;
8
9/// Maximum input length that can be cached.
10pub(super) const MAX_INPUT_LEN: usize =
11    128 - size_of::<B256>() - size_of::<u8>() - size_of::<usize>();
12
13const COUNT: usize = 1 << 17; // ~131k entries * 128 bytes = 16MiB
14static CACHE: fixed_cache::Cache<Key, B256, BuildHasher> =
15    fixed_cache::static_cache!(Key, B256, COUNT, BuildHasher::new());
16
17pub(super) fn compute(input: &[u8], imp: impl FnOnce(&[u8]) -> B256) -> B256 {
18    if unlikely(input.is_empty() | (input.len() > MAX_INPUT_LEN)) {
19        return if input.is_empty() { KECCAK256_EMPTY } else { keccak256(input) };
20    }
21
22    CACHE.get_or_insert_with_ref(input, imp, |input| {
23        let mut data = [MaybeUninit::uninit(); MAX_INPUT_LEN];
24        unsafe {
25            std::ptr::copy_nonoverlapping(input.as_ptr(), data.as_mut_ptr().cast(), input.len())
26        };
27        Key { len: input.len() as u8, data }
28    })
29}
30
/// `BuildHasher` used by [`CACHE`]; produces a fresh `Hasher::default()` for each hash.
type BuildHasher = std::hash::BuildHasherDefault<Hasher>;

/// Hasher specialized for the cache's short byte-string keys.
///
/// [`write`](std::hash::Hasher::write) *overwrites* the state instead of mixing into it, so
/// only the most recent `write` call determines the result. That is sufficient here because
/// the values hashed with it issue a single `write` of their bytes, and the length prefix
/// (delivered through `write_usize` / `write_length_prefix`) is deliberately ignored.
#[derive(Default)]
struct Hasher(u64);

impl std::hash::Hasher for Hasher {
    /// Returns the hash stored by the most recent `write` (0 if nothing was written).
    #[inline]
    fn finish(&self) -> u64 {
        self.0
    }

    /// Hashes `bytes` with rapidhash v3 (seed 0, secrets computed at compile time) and stores
    /// the result as the entire hasher state.
    #[inline]
    fn write(&mut self, bytes: &[u8]) {
        // This is tricky because our most common inputs are medium length: 16..=88
        // `foldhash` and `rapidhash` have a fast-path for ..16 bytes and outline the rest,
        // but really we want the opposite, or at least the 16.. path to be inlined.

        // SAFETY: `bytes.len()` is checked to be within the bounds of `MAX_INPUT_LEN` by caller.
        // Concretely: `compute` never touches the cache with inputs longer than
        // `MAX_INPUT_LEN`, and a `Key` cannot hold more than `MAX_INPUT_LEN` bytes. `write` is
        // a safe fn, so this relies on `Hasher` staying private to this module; hashing a
        // longer slice with it would be undefined behavior.
        unsafe { core::hint::assert_unchecked(bytes.len() <= MAX_INPUT_LEN) };
        // Short inputs are the uncommon case here (see above), so mark them cold to keep the
        // longer-input path laid out as the fast one.
        if bytes.len() <= 16 {
            super::hint::cold_path();
        }
        self.0 = rapidhash::v3::rapidhash_v3_micro_inline::<false, false>(
            bytes,
            const { &rapidhash::v3::RapidSecrets::seed(0) },
        );
    }

    // We can just skip hashing the length prefix entirely since we know it's always
    // `<=MAX_INPUT_LEN`, and the hash is good enough.

    // `write_length_prefix` calls `write_usize` by default.
    // Making this a no-op also means a `&[u8]` lookup (length prefix + bytes) hashes to the
    // same value as a `Key`, whose `Hash` impl only writes its bytes.
    #[inline]
    fn write_usize(&mut self, i: usize) {
        debug_assert!(i <= MAX_INPUT_LEN, "{i} > {MAX_INPUT_LEN}")
    }

    /// Nightly-only override: ignore the length prefix directly (same reasoning as
    /// `write_usize` above).
    #[cfg(feature = "nightly")]
    #[inline]
    fn write_length_prefix(&mut self, len: usize) {
        debug_assert!(len <= MAX_INPUT_LEN, "{len} > {MAX_INPUT_LEN}")
    }
}
73
74#[derive(Clone, Copy)]
75struct Key {
76    len: u8,
77    data: [MaybeUninit<u8>; MAX_INPUT_LEN],
78}
79
80impl PartialEq for Key {
81    #[inline]
82    fn eq(&self, other: &Self) -> bool {
83        self.get() == other.get()
84    }
85}
86impl Eq for Key {}
87
88impl std::borrow::Borrow<[u8]> for Key {
89    #[inline]
90    fn borrow(&self) -> &[u8] {
91        self.get()
92    }
93}
94
95impl std::hash::Hash for Key {
96    #[inline]
97    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
98        state.write(self.get());
99    }
100}
101
102impl Key {
103    #[inline]
104    const fn get(&self) -> &[u8] {
105        unsafe { std::slice::from_raw_parts(self.data.as_ptr().cast(), self.len as usize) }
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    /// Uncached reference implementation, passed as `imp` where call counting is not needed.
    fn reference(input: &[u8]) -> B256 {
        keccak256(input)
    }

    #[test]
    fn sizes() {
        // `MAX_INPUT_LEN` is derived so that a whole bucket fits in 128 bytes.
        assert_eq!(size_of::<Key>(), MAX_INPUT_LEN + 1);
        assert_eq!(size_of::<fixed_cache::Bucket<(Key, B256)>>(), 128);
    }

    #[test]
    fn caching() {
        let mut count: usize = 0;
        let mut compute = |input: &[u8]| {
            compute(input, |x| {
                count += 1;
                keccak256(x)
            })
        };

        // The cache is a process-wide static shared with every other test in this binary, so
        // use inputs that are unlikely to be hashed anywhere else; otherwise a prior insert
        // would make the miss count below flaky.
        let input: &[u8] = b"keccak_cache::tests::caching #1";
        let input2: &[u8] = b"keccak_cache::tests::caching #2";

        let a = compute(input);
        let b = compute(input);
        let c = compute(input);
        assert_eq!(a, b);
        assert_eq!(a, c);

        let d = compute(input2);
        let e = compute(input2);
        assert_ne!(a, d);
        assert_eq!(d, e);

        assert_eq!(count, 2);
        assert_eq!(a, keccak256(input));
        assert_eq!(d, keccak256(input2));
    }

    #[test]
    fn max_len_input_is_cached() {
        let input = [0xA5u8; MAX_INPUT_LEN];
        let mut count: usize = 0;
        let mut compute = |input: &[u8]| {
            compute(input, |x| {
                count += 1;
                keccak256(x)
            })
        };

        let a = compute(&input[..]);
        let b = compute(&input[..]);
        assert_eq!(a, b);
        assert_eq!(count, 1);
        assert_eq!(a, keccak256(&input[..]));
    }

    #[test]
    fn empty_and_oversized_inputs() {
        // Empty input short-circuits to the precomputed constant without invoking `imp`.
        let mut count: usize = 0;
        let empty = compute(b"", |x| {
            count += 1;
            keccak256(x)
        });
        assert_eq!(empty, KECCAK256_EMPTY);
        assert_eq!(count, 0);

        // Inputs past the cacheable limit still produce the correct hash.
        let long = [0x5Au8; MAX_INPUT_LEN + 1];
        assert_eq!(compute(&long, reference), keccak256(&long[..]));
    }
}
145}