chik_bls/
bls_cache.rs

1use std::borrow::Borrow;
2use std::num::NonZeroUsize;
3
4use chik_sha2::Sha256;
5use linked_hash_map::LinkedHashMap;
6use std::sync::Mutex;
7
8use crate::{aggregate_verify_gt, hash_to_g2};
9use crate::{GTElement, PublicKey, Signature};
10
// This is a cache of pairings of public keys and their corresponding messages.
// It accelerates aggregate verification when some public keys have already
// been paired and are found in the cache.
// We use it to cache pairings when validating transactions inserted into the
// mempool, as many of those transactions are likely to show up in a full block
// later. This makes it a lot cheaper to validate the full block.
// However, when validating a signature for which we have no cached GT elements,
// the aggregate_verify() primitive is faster. When long-syncing, that's
// preferable.
20
21#[derive(Debug, Clone)]
22struct BlsCacheData {
23    // sha256(pubkey + message) -> GTElement
24    items: LinkedHashMap<[u8; 32], GTElement>,
25    capacity: NonZeroUsize,
26}
27
28impl BlsCacheData {
29    pub fn put(&mut self, hash: [u8; 32], pairing: GTElement) {
30        // If the cache is full, remove the oldest item.
31        if self.items.len() == self.capacity.get() {
32            if let Some((oldest_key, _)) = self.items.pop_front() {
33                self.items.remove(&oldest_key);
34            }
35        }
36        self.items.insert(hash, pairing);
37    }
38}
39
/// Cache of BLS pairings keyed by `sha256(pubkey || message)`, used to speed
/// up aggregate signature verification.
///
/// The cache holds at most a fixed number of entries; when full, the oldest
/// entry is dropped to make room. All access goes through an internal `Mutex`.
#[cfg_attr(feature = "py-bindings", pyo3::pyclass(name = "BLSCache"))]
#[derive(Debug)]
pub struct BlsCache {
    // Cached pairings plus the capacity limit, guarded by a mutex so that
    // methods can take `&self`.
    cache: Mutex<BlsCacheData>,
}
45
46impl Default for BlsCache {
47    fn default() -> Self {
48        Self::new(NonZeroUsize::new(50_000).unwrap())
49    }
50}
51
52impl Clone for BlsCache {
53    fn clone(&self) -> Self {
54        Self {
55            cache: Mutex::new(self.cache.lock().expect("cache").clone()),
56        }
57    }
58}
59
60impl BlsCache {
61    pub fn new(capacity: NonZeroUsize) -> Self {
62        Self {
63            cache: Mutex::new(BlsCacheData {
64                items: LinkedHashMap::new(),
65                capacity,
66            }),
67        }
68    }
69
70    pub fn len(&self) -> usize {
71        self.cache.lock().expect("cache").items.len()
72    }
73
74    pub fn is_empty(&self) -> bool {
75        self.cache.lock().expect("cache").items.is_empty()
76    }
77
78    pub fn aggregate_verify<Pk: Borrow<PublicKey>, Msg: AsRef<[u8]>>(
79        &self,
80        pks_msgs: impl IntoIterator<Item = (Pk, Msg)>,
81        sig: &Signature,
82    ) -> bool {
83        let iter = pks_msgs.into_iter().map(|(pk, msg)| -> GTElement {
84            // Hash pubkey + message
85            let mut hasher = Sha256::new();
86            let mut aug_msg = pk.borrow().to_bytes().to_vec();
87            aug_msg.extend_from_slice(msg.as_ref());
88            hasher.update(&aug_msg);
89            let hash: [u8; 32] = hasher.finalize();
90
91            // If the pairing is in the cache, we don't need to recalculate it.
92            if let Some(pairing) = self.cache.lock().expect("cache").items.get(&hash).cloned() {
93                return pairing;
94            }
95
96            // Otherwise, we need to calculate the pairing and add it to the cache.
97            let aug_hash = hash_to_g2(&aug_msg);
98
99            let pairing = aug_hash.pair(pk.borrow());
100            self.cache.lock().expect("cache").put(hash, pairing.clone());
101            pairing
102        });
103
104        aggregate_verify_gt(sig, iter)
105    }
106
107    pub fn update(&self, aug_msg: &[u8], gt: GTElement) {
108        let mut hasher = Sha256::new();
109        hasher.update(aug_msg.as_ref());
110        let hash: [u8; 32] = hasher.finalize();
111        self.cache.lock().expect("cache").put(hash, gt);
112    }
113
114    pub fn evict<Pk, Msg>(&self, pks_msgs: impl IntoIterator<Item = (Pk, Msg)>)
115    where
116        Pk: Borrow<PublicKey>,
117        Msg: AsRef<[u8]>,
118    {
119        let mut c = self.cache.lock().expect("cache");
120        for (pk, msg) in pks_msgs {
121            let mut hasher = Sha256::new();
122            let mut aug_msg = pk.borrow().to_bytes().to_vec();
123            aug_msg.extend_from_slice(msg.as_ref());
124            hasher.update(&aug_msg);
125            let hash: [u8; 32] = hasher.finalize();
126            c.items.remove(&hash);
127        }
128    }
129}
130
131#[cfg(feature = "py-bindings")]
132use pyo3::{
133    exceptions::PyValueError,
134    pybacked::PyBackedBytes,
135    types::{PyAnyMethods, PyList, PySequence},
136    Bound, PyObject, PyResult,
137};
138
/// Extracts parallel Python lists of public keys and messages into pairs.
///
/// Raises `ValueError` if the lists differ in length: zipping them would
/// silently drop the surplus entries of the longer list.
#[cfg(feature = "py-bindings")]
fn extract_pks_msgs(
    pks: &Bound<'_, PyList>,
    msgs: &Bound<'_, PyList>,
) -> PyResult<Vec<(PublicKey, PyBackedBytes)>> {
    let pks = pks
        .try_iter()?
        .map(|item| item?.extract())
        .collect::<PyResult<Vec<PublicKey>>>()?;

    let msgs = msgs
        .try_iter()?
        .map(|item| item?.extract())
        .collect::<PyResult<Vec<PyBackedBytes>>>()?;

    if pks.len() != msgs.len() {
        return Err(PyValueError::new_err(
            "the number of public keys must match the number of messages",
        ));
    }

    Ok(pks.into_iter().zip(msgs).collect())
}

#[cfg(feature = "py-bindings")]
#[pyo3::pymethods]
impl BlsCache {
    /// Creates a cache holding at most `size` pairings; without `size`, the
    /// `BlsCache::default()` capacity is used. Raises `ValueError` for 0.
    #[new]
    #[pyo3(signature = (size=None))]
    pub fn init(size: Option<u32>) -> PyResult<Self> {
        let Some(size) = size else {
            return Ok(Self::default());
        };

        let Some(size) = NonZeroUsize::new(size as usize) else {
            return Err(PyValueError::new_err(
                "Cannot have a cache size less than one.",
            ));
        };

        Ok(Self::new(size))
    }

    /// Verifies `sig` against `zip(pks, msgs)` using (and filling) the cache.
    /// Raises `ValueError` if `pks` and `msgs` differ in length.
    #[pyo3(name = "aggregate_verify")]
    pub fn py_aggregate_verify(
        &self,
        pks: &Bound<'_, PyList>,
        msgs: &Bound<'_, PyList>,
        sig: &Signature,
    ) -> PyResult<bool> {
        let pairs = extract_pks_msgs(pks, msgs)?;
        Ok(self.aggregate_verify(pairs, sig))
    }

    /// Returns the number of cached pairings.
    #[pyo3(name = "len")]
    pub fn py_len(&self) -> PyResult<usize> {
        Ok(self.len())
    }

    /// Returns a list of `(key, GTElement)` tuples, where `key` is the 32-byte
    /// cache key as `bytes`, in the cache's internal order.
    #[pyo3(name = "items")]
    pub fn py_items(&self, py: pyo3::Python<'_>) -> PyResult<PyObject> {
        use pyo3::prelude::*;
        use pyo3::types::PyBytes;

        // Snapshot under the lock and release it before creating Python
        // objects: allocation can run Python code (e.g. GC finalizers) that
        // hands the GIL to another thread, which could then block on this
        // mutex while holding the GIL.
        let snapshot: Vec<([u8; 32], GTElement)> = {
            let c = self.cache.lock().expect("cache");
            c.items
                .iter()
                .map(|(key, value)| (*key, value.clone()))
                .collect()
        };

        let ret = PyList::empty(py);
        for (key, value) in snapshot {
            ret.append((
                PyBytes::new(py, &key),
                value.into_pyobject(py)?.into_any(),
            ))?;
        }
        Ok(ret.into())
    }

    /// Inserts `(key, GTElement)` pairs from `other`. Each key must be exactly
    /// 32 bytes, otherwise `ValueError` is raised; entries processed before an
    /// error remain in the cache.
    #[pyo3(name = "update")]
    pub fn py_update(&self, other: &Bound<'_, PySequence>) -> PyResult<()> {
        for item in other.try_iter()? {
            let (key, value): (Vec<u8>, GTElement) = item?.extract()?;
            let key: [u8; 32] = key
                .try_into()
                .map_err(|_| PyValueError::new_err("invalid key"))?;
            // Lock per entry so the mutex is never held while Python code (the
            // sequence's iterator, `extract`) runs; otherwise a GIL switch to a
            // thread that also uses this cache would deadlock.
            self.cache.lock().expect("cache").put(key, value);
        }
        Ok(())
    }

    /// Removes the cached pairings for `zip(pks, msgs)`.
    /// Raises `ValueError` if `pks` and `msgs` differ in length.
    #[pyo3(name = "evict")]
    pub fn py_evict(&self, pks: &Bound<'_, PyList>, msgs: &Bound<'_, PyList>) -> PyResult<()> {
        let pairs = extract_pks_msgs(pks, msgs)?;
        self.evict(pairs);
        Ok(())
    }
}
226
#[cfg(test)]
pub mod tests {
    use super::*;

    use crate::sign;
    use crate::SecretKey;

    /// Reports whether `bls_cache` holds an entry for `(pk, msg)`, recomputing
    /// the `sha256(pk || msg)` key independently of the cache implementation.
    fn is_cached(bls_cache: &BlsCache, pk: &PublicKey, msg: &[u8]) -> bool {
        let mut hasher = Sha256::new();
        hasher.update([pk.to_bytes().as_slice(), msg].concat());
        let key: [u8; 32] = hasher.finalize();
        bls_cache
            .cache
            .lock()
            .expect("cache")
            .items
            .contains_key(&key)
    }

    #[test]
    fn test_aggregate_verify() {
        let bls_cache = BlsCache::default();

        let sk = SecretKey::from_seed(&[0; 32]);
        let pk = sk.public_key();
        let msg = [106; 32];

        let sig = sign(&sk, msg);
        let pks_msgs = [(pk, msg)];

        // Nothing is cached up front.
        assert!(bls_cache.is_empty());

        // The first verification populates the cache...
        assert!(bls_cache.aggregate_verify(pks_msgs, &sig));
        assert_eq!(bls_cache.len(), 1);

        // ...and repeating it reuses the entry instead of adding another.
        assert!(bls_cache.aggregate_verify(pks_msgs, &sig));
        assert_eq!(bls_cache.len(), 1);
    }

    #[test]
    fn test_cache() {
        let bls_cache = BlsCache::default();

        let first_sk = SecretKey::from_seed(&[0; 32]);
        let first_pk = first_sk.public_key();
        let first_msg = [106; 32];

        let mut agg_sig = sign(&first_sk, first_msg);
        let mut pairs = vec![(first_pk, first_msg)];

        assert!(bls_cache.is_empty());

        // A single pair gets cached.
        assert!(bls_cache.aggregate_verify(pairs.clone(), &agg_sig));
        assert_eq!(bls_cache.len(), 1);

        // Second pair: the first one is a hit, the new one a miss.
        let second_sk = SecretKey::from_seed(&[1; 32]);
        let second_pk = second_sk.public_key();
        let second_msg = [107; 32];
        agg_sig += &sign(&second_sk, second_msg);
        pairs.push((second_pk, second_msg));

        assert!(bls_cache.aggregate_verify(pairs.clone(), &agg_sig));
        assert_eq!(bls_cache.len(), 2);

        // The same public key with a new message is a distinct cache entry.
        let third_msg = [108; 32];
        agg_sig += &sign(&second_sk, third_msg);
        pairs.push((second_pk, third_msg));

        assert!(bls_cache.aggregate_verify(pairs, &agg_sig));
        assert_eq!(bls_cache.len(), 3);
    }

    #[test]
    fn test_cache_limit() {
        // Room for only three pairings.
        let bls_cache = BlsCache::new(NonZeroUsize::new(3).unwrap());
        assert!(bls_cache.is_empty());

        let msg = [106; 32];

        // Verify five distinct pairs, one call each.
        for seed in 1..=5 {
            let sk = SecretKey::from_seed(&[seed; 32]);
            let sig = sign(&sk, msg);
            assert!(bls_cache.aggregate_verify([(sk.public_key(), msg)], &sig));
        }

        // Only the capacity's worth of entries survive...
        assert_eq!(bls_cache.len(), 3);

        // ...and the two oldest are the ones that were dropped.
        for seed in 1..=2 {
            let pk = SecretKey::from_seed(&[seed; 32]).public_key();
            assert!(!is_cached(&bls_cache, &pk, &msg));
        }
    }

    #[test]
    fn test_empty_sig() {
        let bls_cache = BlsCache::default();

        let pks_msgs: [(&PublicKey, &[u8]); 0] = [];

        assert!(bls_cache.aggregate_verify(pks_msgs, &Signature::default()));
    }

    #[test]
    fn test_evict() {
        let bls_cache = BlsCache::new(NonZeroUsize::new(5).unwrap());
        let msg = [42; 32];

        // Fill the cache with five pairs, verifying each on its own.
        let mut pairs = Vec::new();
        for seed in 1..=5 {
            let sk = SecretKey::from_seed(&[seed; 32]);
            let pk = sk.public_key();
            assert!(bls_cache.aggregate_verify([(pk, msg)], &sign(&sk, msg)));
            pairs.push((pk, msg));
        }
        assert_eq!(bls_cache.len(), 5);

        // Drop the first and third entries.
        let evicted = [pairs[0], pairs[2]];
        bls_cache.evict(evicted.iter().copied());
        assert_eq!(bls_cache.len(), 3);

        // Evicted pairs are gone...
        for (pk, message) in &evicted {
            assert!(!is_cached(&bls_cache, pk, message));
        }
        // ...while the others are untouched.
        for (pk, message) in &[pairs[1], pairs[3], pairs[4]] {
            assert!(is_cached(&bls_cache, pk, message));
        }
    }
}