1use std::borrow::Borrow;
2use std::num::NonZeroUsize;
3
4use crate::{GTElement, PublicKey, Signature};
5use crate::{aggregate_verify_gt, hash_to_g2};
6use chia_sha2::Sha256;
7use linked_hash_map::LinkedHashMap;
8use std::sync::Mutex;
9
/// Interior state of [`BlsCache`]: the cached pairings and the size limit.
#[derive(Debug, Clone)]
struct BlsCacheData {
    /// Insertion-ordered map from a 32-byte key to a pairing.
    ///
    /// The Rust entry points in this file derive the key as the SHA-256 digest
    /// of the augmented message (public-key bytes followed by the message).
    items: LinkedHashMap<[u8; 32], GTElement>,
    /// Entry count at which `put` evicts the front (oldest) entry before
    /// inserting.
    capacity: NonZeroUsize,
}
26
27impl BlsCacheData {
28 pub fn put(&mut self, hash: [u8; 32], pairing: GTElement) {
29 if self.items.len() == self.capacity.get() {
31 if let Some((oldest_key, _)) = self.items.pop_front() {
32 self.items.remove(&oldest_key);
33 }
34 }
35 self.items.insert(hash, pairing);
36 }
37}
38
/// Bounded cache of BLS pairings that can be used through a shared reference.
///
/// When the `py-bindings` feature is enabled the type is exposed to Python
/// under the name `BLSCache`.
#[cfg_attr(feature = "py-bindings", pyo3::pyclass(name = "BLSCache"))]
#[derive(Debug)]
pub struct BlsCache {
    /// All mutable state sits behind one mutex so every method can take `&self`.
    cache: Mutex<BlsCacheData>,
}
44
45impl Default for BlsCache {
46 fn default() -> Self {
47 Self::new(NonZeroUsize::new(50_000).unwrap())
48 }
49}
50
51impl Clone for BlsCache {
52 fn clone(&self) -> Self {
53 Self {
54 cache: Mutex::new(self.cache.lock().expect("cache").clone()),
55 }
56 }
57}
58
59impl BlsCache {
60 pub fn new(capacity: NonZeroUsize) -> Self {
61 Self {
62 cache: Mutex::new(BlsCacheData {
63 items: LinkedHashMap::new(),
64 capacity,
65 }),
66 }
67 }
68
69 pub fn len(&self) -> usize {
70 self.cache.lock().expect("cache").items.len()
71 }
72
73 pub fn is_empty(&self) -> bool {
74 self.cache.lock().expect("cache").items.is_empty()
75 }
76
77 pub fn aggregate_verify<Pk: Borrow<PublicKey>, Msg: AsRef<[u8]>>(
78 &self,
79 pks_msgs: impl IntoIterator<Item = (Pk, Msg)>,
80 sig: &Signature,
81 ) -> bool {
82 let iter = pks_msgs.into_iter().map(|(pk, msg)| -> GTElement {
83 let mut hasher = Sha256::new();
85 let mut aug_msg = pk.borrow().to_bytes().to_vec();
86 aug_msg.extend_from_slice(msg.as_ref());
87 hasher.update(&aug_msg);
88 let hash: [u8; 32] = hasher.finalize();
89
90 if let Some(pairing) = self.cache.lock().expect("cache").items.get(&hash).cloned() {
92 return pairing;
93 }
94
95 let aug_hash = hash_to_g2(&aug_msg);
97
98 let pairing = aug_hash.pair(pk.borrow());
99 self.cache.lock().expect("cache").put(hash, pairing.clone());
100 pairing
101 });
102
103 aggregate_verify_gt(sig, iter)
104 }
105
106 pub fn update(&self, aug_msg: &[u8], gt: GTElement) {
107 let mut hasher = Sha256::new();
108 hasher.update(aug_msg.as_ref());
109 let hash: [u8; 32] = hasher.finalize();
110 self.cache.lock().expect("cache").put(hash, gt);
111 }
112
113 pub fn evict<Pk, Msg>(&self, pks_msgs: impl IntoIterator<Item = (Pk, Msg)>)
114 where
115 Pk: Borrow<PublicKey>,
116 Msg: AsRef<[u8]>,
117 {
118 let mut c = self.cache.lock().expect("cache");
119 for (pk, msg) in pks_msgs {
120 let mut hasher = Sha256::new();
121 let mut aug_msg = pk.borrow().to_bytes().to_vec();
122 aug_msg.extend_from_slice(msg.as_ref());
123 hasher.update(&aug_msg);
124 let hash: [u8; 32] = hasher.finalize();
125 c.items.remove(&hash);
126 }
127 }
128}
129
130#[cfg(feature = "py-bindings")]
131use pyo3::{
132 Bound, Py, PyResult,
133 exceptions::PyValueError,
134 pybacked::PyBackedBytes,
135 types::{PyAnyMethods, PyList, PySequence},
136};
137
#[cfg(feature = "py-bindings")]
#[pyo3::pymethods]
impl BlsCache {
    /// Creates a cache. Without `cache_size` the default capacity is used.
    ///
    /// Raises `ValueError` when `cache_size` is zero.
    #[new]
    #[pyo3(signature = (cache_size=None))]
    pub fn init(cache_size: Option<u32>) -> PyResult<Self> {
        let Some(size) = cache_size else {
            return Ok(Self::default());
        };

        let Some(size) = NonZeroUsize::new(size as usize) else {
            return Err(PyValueError::new_err(
                "Cannot have a cache size less than one.",
            ));
        };

        Ok(Self::new(size))
    }

    /// Verifies the aggregate signature `sig` over the pairs formed from `pks`
    /// and `msgs`, using and filling the cache.
    #[pyo3(name = "aggregate_verify")]
    pub fn py_aggregate_verify(
        &self,
        pks: &Bound<'_, PyList>,
        msgs: &Bound<'_, PyList>,
        sig: &Signature,
    ) -> PyResult<bool> {
        let pks = pks
            .try_iter()?
            .map(|item| Ok(item?.extract()?))
            .collect::<PyResult<Vec<PublicKey>>>()?;

        let msgs = msgs
            .try_iter()?
            .map(|item| Ok(item?.extract()?))
            .collect::<PyResult<Vec<PyBackedBytes>>>()?;

        // NOTE(review): `zip` stops at the shorter list, so surplus keys or
        // messages are silently ignored. Confirm callers always pass lists of
        // equal length; behaviour is intentionally left unchanged here.
        Ok(self.aggregate_verify(pks.into_iter().zip(msgs), sig))
    }

    /// Returns the number of cached pairings.
    #[pyo3(name = "len")]
    pub fn py_len(&self) -> PyResult<usize> {
        Ok(self.len())
    }

    /// Returns a list of `(key, pairing)` tuples in cache order, oldest first.
    #[pyo3(name = "items")]
    pub fn py_items(&self, py: pyo3::Python<'_>) -> PyResult<Py<pyo3::PyAny>> {
        use pyo3::prelude::*;
        use pyo3::types::PyBytes;

        // Copy the entries out first so the mutex is not held while Python
        // objects are being allocated.
        let snapshot: Vec<([u8; 32], GTElement)> = {
            let c = self.cache.lock().expect("cache");
            c.items
                .iter()
                .map(|(key, value)| (*key, value.clone()))
                .collect()
        };

        let ret = PyList::empty(py);
        for (key, value) in snapshot {
            ret.append((
                PyBytes::new(py, &key),
                value.into_pyobject(py)?.into_any(),
            ))?;
        }
        Ok(ret.into_any().unbind())
    }

    /// Inserts every `(key, pairing)` tuple from `other` into the cache.
    ///
    /// Raises `ValueError("invalid key")` when a key is not exactly 32 bytes.
    /// All items are validated before the cache is touched, so a failing item
    /// leaves the cache unchanged.
    #[pyo3(name = "update")]
    pub fn py_update(&self, other: &Bound<'_, PySequence>) -> PyResult<()> {
        // Iterating and extracting can execute arbitrary Python code. Do that
        // before taking the (non-reentrant) mutex so such code can neither
        // re-enter the cache nor block other threads while the lock is held.
        let mut entries: Vec<([u8; 32], GTElement)> = Vec::new();
        for item in other.borrow().try_iter()? {
            let (key, value): (Vec<u8>, GTElement) = item?.extract()?;
            let key: [u8; 32] = key
                .try_into()
                .map_err(|_| PyValueError::new_err("invalid key"))?;
            entries.push((key, value));
        }

        let mut c = self.cache.lock().expect("cache");
        for (key, value) in entries {
            c.put(key, value);
        }
        Ok(())
    }

    /// Removes the cached pairings for the pairs formed from `pks` and `msgs`.
    #[pyo3(name = "evict")]
    pub fn py_evict(&self, pks: &Bound<'_, PyList>, msgs: &Bound<'_, PyList>) -> PyResult<()> {
        let pks = pks
            .try_iter()?
            .map(|item| Ok(item?.extract()?))
            .collect::<PyResult<Vec<PublicKey>>>()?;
        let msgs = msgs
            .try_iter()?
            .map(|item| Ok(item?.extract()?))
            .collect::<PyResult<Vec<PyBackedBytes>>>()?;
        // NOTE(review): as in `aggregate_verify`, `zip` ignores the surplus
        // elements of the longer list.
        self.evict(pks.into_iter().zip(msgs));
        Ok(())
    }
}
225
#[cfg(test)]
pub mod tests {
    use super::*;

    use crate::SecretKey;
    use crate::sign;

    /// Reports whether the pairing for `(pk, msg)` is currently stored in
    /// `cache`, using the same key derivation as the cache itself.
    fn is_cached(cache: &BlsCache, pk: &PublicKey, msg: &[u8]) -> bool {
        let mut aug_msg = pk.to_bytes().to_vec();
        aug_msg.extend_from_slice(msg);
        let mut hasher = Sha256::new();
        hasher.update(aug_msg);
        let hash: [u8; 32] = hasher.finalize();
        cache.cache.lock().expect("cache").items.contains_key(&hash)
    }

    #[test]
    fn test_aggregate_verify() {
        let cache = BlsCache::default();

        let secret = SecretKey::from_seed(&[0; 32]);
        let public = secret.public_key();
        let message = [106; 32];

        let signature = sign(&secret, message);
        let pairs = [(public, message)];

        // Nothing has been verified yet.
        assert!(cache.is_empty());

        // The first verification computes and stores the pairing.
        assert!(cache.aggregate_verify(pairs, &signature));
        assert_eq!(cache.len(), 1);

        // The second verification is served from the cache.
        assert!(cache.aggregate_verify(pairs, &signature));
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_cache() {
        let cache = BlsCache::default();

        let first_sk = SecretKey::from_seed(&[0; 32]);
        let first_pk = first_sk.public_key();
        let first_msg = [106; 32];

        let mut aggregate = sign(&first_sk, first_msg);
        let mut pairs = vec![(first_pk, first_msg)];

        assert!(cache.is_empty());

        // One pair -> one cached pairing.
        assert!(cache.aggregate_verify(pairs.clone(), &aggregate));
        assert_eq!(cache.len(), 1);

        // A second key and message adds exactly one more entry.
        let second_sk = SecretKey::from_seed(&[1; 32]);
        let second_pk = second_sk.public_key();
        let second_msg = [107; 32];

        aggregate += &sign(&second_sk, second_msg);
        pairs.push((second_pk, second_msg));

        assert!(cache.aggregate_verify(pairs.clone(), &aggregate));
        assert_eq!(cache.len(), 2);

        // Same key, different message: still a new entry.
        let third_msg = [108; 32];

        aggregate += &sign(&second_sk, third_msg);
        pairs.push((second_pk, third_msg));

        assert!(cache.aggregate_verify(pairs, &aggregate));
        assert_eq!(cache.len(), 3);
    }

    #[test]
    fn test_cache_limit() {
        // Room for three pairings only.
        let cache = BlsCache::new(NonZeroUsize::new(3).unwrap());

        assert!(cache.is_empty());

        // Verify five distinct pairs; the two oldest must fall out.
        for seed in 1..=5 {
            let secret = SecretKey::from_seed(&[seed; 32]);
            let public = secret.public_key();
            let message = [106; 32];

            let signature = sign(&secret, message);
            let pairs = [(public, message)];

            assert!(cache.aggregate_verify(pairs, &signature));
        }

        assert_eq!(cache.len(), 3);

        // The entries for the first two seeds were evicted.
        for seed in 1..=2 {
            let secret = SecretKey::from_seed(&[seed; 32]);
            let public = secret.public_key();
            let message = [106; 32];
            assert!(!is_cached(&cache, &public, &message));
        }
    }

    #[test]
    fn test_empty_sig() {
        let cache = BlsCache::default();

        let no_pairs: [(&PublicKey, &[u8]); 0] = [];

        assert!(cache.aggregate_verify(no_pairs, &Signature::default()));
    }

    #[test]
    fn test_evict() {
        let cache = BlsCache::new(NonZeroUsize::new(5).unwrap());
        let mut pairs = Vec::new();

        // Fill the cache with five entries.
        for seed in 1..=5 {
            let secret = SecretKey::from_seed(&[seed; 32]);
            let public = secret.public_key();
            let message = [42; 32];
            let signature = sign(&secret, message);
            pairs.push((public, message));
            assert!(cache.aggregate_verify([(public, message)], &signature));
        }
        assert_eq!(cache.len(), 5);

        // Evict the first and third pair.
        let evicted = vec![pairs[0], pairs[2]];
        cache.evict(evicted.iter().copied());
        assert_eq!(cache.len(), 3);

        // The evicted pairs are gone...
        for (public, message) in &evicted {
            assert!(!is_cached(&cache, public, message));
        }

        // ...and the remaining ones are untouched.
        for (public, message) in &[pairs[1], pairs[3], pairs[4]] {
            assert!(is_cached(&cache, public, message));
        }
    }
}