Skip to main content

chia_bls/
bls_cache.rs

1use std::borrow::Borrow;
2use std::num::NonZeroUsize;
3
4use crate::{GTElement, PublicKey, Signature};
5use crate::{aggregate_verify_gt, hash_to_g2};
6use chia_sha2::Sha256;
7use linked_hash_map::LinkedHashMap;
8use std::sync::Mutex;
9
10/// This is a cache of pairings of public keys and their corresponding message.
11/// It accelerates aggregate verification when some public keys have already
12/// been paired, and found in the cache.
13/// We use it to cache pairings when validating transactions inserted into the
14/// mempool, as many of those transactions are likely to show up in a full block
15/// later. This makes it a lot cheaper to validate the full block.
16/// However, validating a signature where we have no cached GT elements, the
17/// aggregate_verify() primitive is faster. When long-syncing, that's
18/// preferable.
19
20#[derive(Debug, Clone)]
21struct BlsCacheData {
22    // sha256(pubkey + message) -> GTElement
23    items: LinkedHashMap<[u8; 32], GTElement>,
24    capacity: NonZeroUsize,
25}
26
27impl BlsCacheData {
28    pub fn put(&mut self, hash: [u8; 32], pairing: GTElement) {
29        // If the cache is full, remove the oldest item.
30        if self.items.len() == self.capacity.get() {
31            if let Some((oldest_key, _)) = self.items.pop_front() {
32                self.items.remove(&oldest_key);
33            }
34        }
35        self.items.insert(hash, pairing);
36    }
37}
38
/// A bounded cache of BLS pairings, keyed by `sha256(pubkey || message)`.
///
/// It accelerates aggregate verification when some public key/message pairs
/// have already been paired and are found in the cache. It is meant for
/// caching pairings while validating transactions inserted into the mempool,
/// since many of them are likely to show up in a full block later, making
/// that block cheaper to validate. When no GT elements are cached (e.g. while
/// long-syncing), the plain `aggregate_verify()` primitive is faster.
///
/// The state lives behind a [`Mutex`], so methods take `&self`; methods that
/// lock it panic if the mutex has been poisoned.
#[cfg_attr(feature = "py-bindings", pyo3::pyclass(name = "BLSCache"))]
#[derive(Debug)]
pub struct BlsCache {
    cache: Mutex<BlsCacheData>,
}
44
45impl Default for BlsCache {
46    fn default() -> Self {
47        Self::new(NonZeroUsize::new(50_000).unwrap())
48    }
49}
50
51impl Clone for BlsCache {
52    fn clone(&self) -> Self {
53        Self {
54            cache: Mutex::new(self.cache.lock().expect("cache").clone()),
55        }
56    }
57}
58
59impl BlsCache {
60    pub fn new(capacity: NonZeroUsize) -> Self {
61        Self {
62            cache: Mutex::new(BlsCacheData {
63                items: LinkedHashMap::new(),
64                capacity,
65            }),
66        }
67    }
68
69    pub fn len(&self) -> usize {
70        self.cache.lock().expect("cache").items.len()
71    }
72
73    pub fn is_empty(&self) -> bool {
74        self.cache.lock().expect("cache").items.is_empty()
75    }
76
77    pub fn aggregate_verify<Pk: Borrow<PublicKey>, Msg: AsRef<[u8]>>(
78        &self,
79        pks_msgs: impl IntoIterator<Item = (Pk, Msg)>,
80        sig: &Signature,
81    ) -> bool {
82        let iter = pks_msgs.into_iter().map(|(pk, msg)| -> GTElement {
83            // Hash pubkey + message
84            let mut hasher = Sha256::new();
85            let mut aug_msg = pk.borrow().to_bytes().to_vec();
86            aug_msg.extend_from_slice(msg.as_ref());
87            hasher.update(&aug_msg);
88            let hash: [u8; 32] = hasher.finalize();
89
90            // If the pairing is in the cache, we don't need to recalculate it.
91            if let Some(pairing) = self.cache.lock().expect("cache").items.get(&hash).cloned() {
92                return pairing;
93            }
94
95            // Otherwise, we need to calculate the pairing and add it to the cache.
96            let aug_hash = hash_to_g2(&aug_msg);
97
98            let pairing = aug_hash.pair(pk.borrow());
99            self.cache.lock().expect("cache").put(hash, pairing.clone());
100            pairing
101        });
102
103        aggregate_verify_gt(sig, iter)
104    }
105
106    pub fn update(&self, aug_msg: &[u8], gt: GTElement) {
107        let mut hasher = Sha256::new();
108        hasher.update(aug_msg.as_ref());
109        let hash: [u8; 32] = hasher.finalize();
110        self.cache.lock().expect("cache").put(hash, gt);
111    }
112
113    pub fn evict<Pk, Msg>(&self, pks_msgs: impl IntoIterator<Item = (Pk, Msg)>)
114    where
115        Pk: Borrow<PublicKey>,
116        Msg: AsRef<[u8]>,
117    {
118        let mut c = self.cache.lock().expect("cache");
119        for (pk, msg) in pks_msgs {
120            let mut hasher = Sha256::new();
121            let mut aug_msg = pk.borrow().to_bytes().to_vec();
122            aug_msg.extend_from_slice(msg.as_ref());
123            hasher.update(&aug_msg);
124            let hash: [u8; 32] = hasher.finalize();
125            c.items.remove(&hash);
126        }
127    }
128}
129
130#[cfg(feature = "py-bindings")]
131use pyo3::{
132    Bound, Py, PyResult,
133    exceptions::PyValueError,
134    pybacked::PyBackedBytes,
135    types::{PyAnyMethods, PyList, PySequence},
136};
137
#[cfg(feature = "py-bindings")]
#[pyo3::pymethods]
impl BlsCache {
    /// Creates a cache. With no `cache_size`, the default capacity is used.
    /// Raises `ValueError` if `cache_size` is 0.
    #[new]
    #[pyo3(signature = (cache_size=None))]
    pub fn init(cache_size: Option<u32>) -> PyResult<Self> {
        let Some(size) = cache_size else {
            return Ok(Self::default());
        };

        let Some(size) = NonZeroUsize::new(size as usize) else {
            return Err(PyValueError::new_err(
                "Cannot have a cache size less than one.",
            ));
        };

        Ok(Self::new(size))
    }

    /// Verifies `sig` over the pairs `zip(pks, msgs)` using and filling the cache.
    /// Returns `False` if `pks` and `msgs` have different lengths.
    #[pyo3(name = "aggregate_verify")]
    pub fn py_aggregate_verify(
        &self,
        pks: &Bound<'_, PyList>,
        msgs: &Bound<'_, PyList>,
        sig: &Signature,
    ) -> PyResult<bool> {
        let pks = pks
            .try_iter()?
            .map(|item| Ok(item?.extract()?))
            .collect::<PyResult<Vec<PublicKey>>>()?;

        let msgs = msgs
            .try_iter()?
            .map(|item| Ok(item?.extract()?))
            .collect::<PyResult<Vec<PyBackedBytes>>>()?;

        // `zip` would silently drop the unmatched tail. A signature over only a
        // subset of the pairs would then verify. Every public key must be paired
        // with a message, so mismatched lengths are never a valid aggregate.
        if pks.len() != msgs.len() {
            return Ok(false);
        }

        Ok(self.aggregate_verify(pks.into_iter().zip(msgs), sig))
    }

    /// Returns the number of cached pairings.
    #[pyo3(name = "len")]
    pub fn py_len(&self) -> PyResult<usize> {
        Ok(self.len())
    }

    /// Returns a list of `(key: bytes, gt: GTElement)` tuples for every cached
    /// entry, in the cache's internal order (the next entry due for eviction
    /// comes first).
    #[pyo3(name = "items")]
    pub fn py_items(&self, py: pyo3::Python<'_>) -> PyResult<Py<pyo3::PyAny>> {
        use pyo3::prelude::*;
        use pyo3::types::PyBytes;
        let ret = PyList::empty(py);
        let c = self.cache.lock().expect("cache");
        for (key, value) in &c.items {
            ret.append((
                PyBytes::new(py, key),
                value.clone().into_pyobject(py)?.into_any(),
            ))?;
        }
        Ok(ret.into_any().unbind())
    }

    /// Inserts `(key, gt)` pairs (for example, as returned by `items()`).
    ///
    /// Each key must be exactly 32 bytes; otherwise `ValueError("invalid key")`
    /// is raised. All items are converted before the cache is touched, so a bad
    /// item leaves the cache unchanged.
    #[pyo3(name = "update")]
    pub fn py_update(&self, other: &Bound<'_, PySequence>) -> PyResult<()> {
        // Iterating and extracting may run arbitrary Python code. Do it before
        // taking the (non-reentrant) lock so such code cannot deadlock by
        // calling back into this cache, and so errors cannot leave a partial
        // update behind.
        let mut entries = Vec::new();
        for item in other.try_iter()? {
            let (key, value): (Vec<u8>, GTElement) = item?.extract()?;
            let key: [u8; 32] = key
                .try_into()
                .map_err(|_| PyValueError::new_err("invalid key"))?;
            entries.push((key, value));
        }

        let mut c = self.cache.lock().expect("cache");
        for (key, value) in entries {
            c.put(key, value);
        }
        Ok(())
    }

    /// Removes the cached pairings for the pairs `zip(pks, msgs)`.
    #[pyo3(name = "evict")]
    pub fn py_evict(&self, pks: &Bound<'_, PyList>, msgs: &Bound<'_, PyList>) -> PyResult<()> {
        let pks = pks
            .try_iter()?
            .map(|item| Ok(item?.extract()?))
            .collect::<PyResult<Vec<PublicKey>>>()?;
        let msgs = msgs
            .try_iter()?
            .map(|item| Ok(item?.extract()?))
            .collect::<PyResult<Vec<PyBackedBytes>>>()?;
        self.evict(pks.into_iter().zip(msgs));
        Ok(())
    }
}
225
#[cfg(test)]
pub mod tests {
    use super::*;

    use crate::SecretKey;
    use crate::sign;

    /// Cache key for a `(pk, msg)` pair: `sha256(pk || msg)`, the same key
    /// `BlsCache` derives internally.
    fn key_for(pk: &PublicKey, msg: &[u8]) -> [u8; 32] {
        let aug_msg = [&pk.to_bytes()[..], msg].concat();
        let mut hasher = Sha256::new();
        hasher.update(&aug_msg);
        hasher.finalize()
    }

    /// Whether `cache` currently holds an entry for `(pk, msg)`.
    fn is_cached(cache: &BlsCache, pk: &PublicKey, msg: &[u8]) -> bool {
        let key = key_for(pk, msg);
        cache.cache.lock().expect("cache").items.contains_key(&key)
    }

    #[test]
    fn test_aggregate_verify() {
        let bls_cache = BlsCache::default();

        let secret = SecretKey::from_seed(&[0; 32]);
        let public = secret.public_key();
        let message = [106; 32];
        let signature = sign(&secret, message);
        let pairs = [(public, message)];

        // Nothing has been cached yet.
        assert!(bls_cache.is_empty());

        // The first verification populates the cache...
        assert!(bls_cache.aggregate_verify(pairs, &signature));
        assert_eq!(bls_cache.len(), 1);

        // ...and a repeat verification reuses the entry instead of adding one.
        assert!(bls_cache.aggregate_verify(pairs, &signature));
        assert_eq!(bls_cache.len(), 1);
    }

    #[test]
    fn test_cache() {
        let bls_cache = BlsCache::default();
        assert!(bls_cache.is_empty());

        // Single pair: gets cached.
        let first_sk = SecretKey::from_seed(&[0; 32]);
        let first_pk = first_sk.public_key();
        let first_msg = [106; 32];
        let mut agg_sig = sign(&first_sk, first_msg);
        let mut pairs = vec![(first_pk, first_msg)];

        assert!(bls_cache.aggregate_verify(pairs.clone(), &agg_sig));
        assert_eq!(bls_cache.len(), 1);

        // Add a second pair: the first comes from the cache, the second is new.
        let second_sk = SecretKey::from_seed(&[1; 32]);
        let second_pk = second_sk.public_key();
        let second_msg = [107; 32];
        agg_sig += &sign(&second_sk, second_msg);
        pairs.push((second_pk, second_msg));

        assert!(bls_cache.aggregate_verify(pairs.clone(), &agg_sig));
        assert_eq!(bls_cache.len(), 2);

        // Reuse the second public key with a new message: a distinct cache entry.
        let third_msg = [108; 32];
        agg_sig += &sign(&second_sk, third_msg);
        pairs.push((second_pk, third_msg));

        assert!(bls_cache.aggregate_verify(pairs, &agg_sig));
        assert_eq!(bls_cache.len(), 3);
    }

    #[test]
    fn test_cache_limit() {
        // Room for only 3 entries.
        let bls_cache = BlsCache::new(NonZeroUsize::new(3).unwrap());
        assert!(bls_cache.is_empty());

        let message = [106; 32];

        // Verify 5 distinct pairs one at a time.
        for seed in 1..=5u8 {
            let secret = SecretKey::from_seed(&[seed; 32]);
            let signature = sign(&secret, message);
            assert!(bls_cache.aggregate_verify([(secret.public_key(), message)], &signature));
        }

        // Capped at capacity.
        assert_eq!(bls_cache.len(), 3);

        // The two oldest pairs were evicted.
        for seed in 1..=2u8 {
            let public = SecretKey::from_seed(&[seed; 32]).public_key();
            assert!(!is_cached(&bls_cache, &public, &message));
        }
    }

    #[test]
    fn test_empty_sig() {
        let bls_cache = BlsCache::default();
        let no_pairs = std::iter::empty::<(&PublicKey, &[u8])>();
        assert!(bls_cache.aggregate_verify(no_pairs, &Signature::default()));
    }

    #[test]
    fn test_evict() {
        let bls_cache = BlsCache::new(NonZeroUsize::new(5).unwrap());
        let message = [42; 32];

        // Cache 5 distinct pairs.
        let mut publics = Vec::new();
        for seed in 1..=5u8 {
            let secret = SecretKey::from_seed(&[seed; 32]);
            let public = secret.public_key();
            let signature = sign(&secret, message);
            assert!(bls_cache.aggregate_verify([(public, message)], &signature));
            publics.push(public);
        }
        assert_eq!(bls_cache.len(), 5);

        // Evict the first and third pairs.
        bls_cache.evict([(publics[0], message), (publics[2], message)]);
        assert_eq!(bls_cache.len(), 3);

        // Evicted pairs are gone...
        for idx in [0, 2] {
            assert!(!is_cached(&bls_cache, &publics[idx], &message));
        }

        // ...and the rest are untouched.
        for idx in [1, 3, 4] {
            assert!(is_cached(&bls_cache, &publics[idx], &message));
        }
    }
}