1use std::borrow::Borrow;
2use std::num::NonZeroUsize;
3
4use chik_sha2::Sha256;
5use linked_hash_map::LinkedHashMap;
6use std::sync::Mutex;
7
8use crate::{aggregate_verify_gt, hash_to_g2};
9use crate::{GTElement, PublicKey, Signature};
10
11#[derive(Debug, Clone)]
22struct BlsCacheData {
23 items: LinkedHashMap<[u8; 32], GTElement>,
25 capacity: NonZeroUsize,
26}
27
28impl BlsCacheData {
29 pub fn put(&mut self, hash: [u8; 32], pairing: GTElement) {
30 if self.items.len() == self.capacity.get() {
32 if let Some((oldest_key, _)) = self.items.pop_front() {
33 self.items.remove(&oldest_key);
34 }
35 }
36 self.items.insert(hash, pairing);
37 }
38}
39
/// Bounded cache of pairings (`GTElement`s) computed while verifying aggregate
/// BLS signatures, keyed by the SHA-256 digest of `pk || msg`.
///
/// All cached state sits behind a `Mutex`, so every method takes `&self`.
/// Exposed to Python as `BLSCache` when the `py-bindings` feature is enabled.
#[cfg_attr(feature = "py-bindings", pyo3::pyclass(name = "BLSCache"))]
#[derive(Debug)]
pub struct BlsCache {
    // Bounded pairing store; every read and write goes through this lock.
    cache: Mutex<BlsCacheData>,
}
45
46impl Default for BlsCache {
47 fn default() -> Self {
48 Self::new(NonZeroUsize::new(50_000).unwrap())
49 }
50}
51
52impl Clone for BlsCache {
53 fn clone(&self) -> Self {
54 Self {
55 cache: Mutex::new(self.cache.lock().expect("cache").clone()),
56 }
57 }
58}
59
60impl BlsCache {
61 pub fn new(capacity: NonZeroUsize) -> Self {
62 Self {
63 cache: Mutex::new(BlsCacheData {
64 items: LinkedHashMap::new(),
65 capacity,
66 }),
67 }
68 }
69
70 pub fn len(&self) -> usize {
71 self.cache.lock().expect("cache").items.len()
72 }
73
74 pub fn is_empty(&self) -> bool {
75 self.cache.lock().expect("cache").items.is_empty()
76 }
77
78 pub fn aggregate_verify<Pk: Borrow<PublicKey>, Msg: AsRef<[u8]>>(
79 &self,
80 pks_msgs: impl IntoIterator<Item = (Pk, Msg)>,
81 sig: &Signature,
82 ) -> bool {
83 let iter = pks_msgs.into_iter().map(|(pk, msg)| -> GTElement {
84 let mut hasher = Sha256::new();
86 let mut aug_msg = pk.borrow().to_bytes().to_vec();
87 aug_msg.extend_from_slice(msg.as_ref());
88 hasher.update(&aug_msg);
89 let hash: [u8; 32] = hasher.finalize();
90
91 if let Some(pairing) = self.cache.lock().expect("cache").items.get(&hash).cloned() {
93 return pairing;
94 }
95
96 let aug_hash = hash_to_g2(&aug_msg);
98
99 let pairing = aug_hash.pair(pk.borrow());
100 self.cache.lock().expect("cache").put(hash, pairing.clone());
101 pairing
102 });
103
104 aggregate_verify_gt(sig, iter)
105 }
106
107 pub fn update(&self, aug_msg: &[u8], gt: GTElement) {
108 let mut hasher = Sha256::new();
109 hasher.update(aug_msg.as_ref());
110 let hash: [u8; 32] = hasher.finalize();
111 self.cache.lock().expect("cache").put(hash, gt);
112 }
113
114 pub fn evict<Pk, Msg>(&self, pks_msgs: impl IntoIterator<Item = (Pk, Msg)>)
115 where
116 Pk: Borrow<PublicKey>,
117 Msg: AsRef<[u8]>,
118 {
119 let mut c = self.cache.lock().expect("cache");
120 for (pk, msg) in pks_msgs {
121 let mut hasher = Sha256::new();
122 let mut aug_msg = pk.borrow().to_bytes().to_vec();
123 aug_msg.extend_from_slice(msg.as_ref());
124 hasher.update(&aug_msg);
125 let hash: [u8; 32] = hasher.finalize();
126 c.items.remove(&hash);
127 }
128 }
129}
130
131#[cfg(feature = "py-bindings")]
132use pyo3::{
133 exceptions::PyValueError,
134 pybacked::PyBackedBytes,
135 types::{PyAnyMethods, PyList, PySequence},
136 Bound, PyObject, PyResult,
137};
138
/// Extracts two Python lists into a vector of public keys and a vector of
/// message byte buffers. Lengths are not compared here; callers decide how a
/// mismatch is handled.
#[cfg(feature = "py-bindings")]
fn extract_pks_msgs(
    pks: &Bound<'_, PyList>,
    msgs: &Bound<'_, PyList>,
) -> PyResult<(Vec<PublicKey>, Vec<PyBackedBytes>)> {
    let pks = pks
        .try_iter()?
        .map(|item| item?.extract())
        .collect::<PyResult<Vec<PublicKey>>>()?;

    let msgs = msgs
        .try_iter()?
        .map(|item| item?.extract())
        .collect::<PyResult<Vec<PyBackedBytes>>>()?;

    Ok((pks, msgs))
}

#[cfg(feature = "py-bindings")]
#[pyo3::pymethods]
impl BlsCache {
    /// Creates a cache holding at most `size` pairings, or the default
    /// capacity (50,000) when `size` is omitted.
    ///
    /// Raises `ValueError` when `size` is 0.
    #[new]
    #[pyo3(signature = (size=None))]
    pub fn init(size: Option<u32>) -> PyResult<Self> {
        let Some(size) = size else {
            return Ok(Self::default());
        };

        let Some(size) = NonZeroUsize::new(size as usize) else {
            return Err(PyValueError::new_err(
                "Cannot have a cache size less than one.",
            ));
        };

        Ok(Self::new(size))
    }

    /// Verifies `sig` over the pairs `(pks[i], msgs[i])`, using the cache.
    ///
    /// Returns `False` when `pks` and `msgs` differ in length. Zipping them
    /// would silently drop the unmatched tail and verify fewer pairs than
    /// were supplied.
    #[pyo3(name = "aggregate_verify")]
    pub fn py_aggregate_verify(
        &self,
        pks: &Bound<'_, PyList>,
        msgs: &Bound<'_, PyList>,
        sig: &Signature,
    ) -> PyResult<bool> {
        let (pks, msgs) = extract_pks_msgs(pks, msgs)?;
        if pks.len() != msgs.len() {
            return Ok(false);
        }
        Ok(self.aggregate_verify(pks.into_iter().zip(msgs), sig))
    }

    /// Returns the number of cached pairings.
    #[pyo3(name = "len")]
    pub fn py_len(&self) -> PyResult<usize> {
        Ok(self.len())
    }

    /// Returns the cache contents as a list of `(key_bytes, pairing)` tuples,
    /// in the cache's internal (insertion) order.
    #[pyo3(name = "items")]
    pub fn py_items(&self, py: pyo3::Python<'_>) -> PyResult<PyObject> {
        use pyo3::prelude::*;
        use pyo3::types::PyBytes;
        let ret = PyList::empty(py);
        let c = self.cache.lock().expect("cache");
        for (key, value) in &c.items {
            ret.append((
                PyBytes::new(py, key),
                value.clone().into_pyobject(py)?.into_any(),
            ))?;
        }
        Ok(ret.into())
    }

    /// Inserts `(key, pairing)` entries from a sequence; each key must be
    /// exactly 32 bytes.
    ///
    /// All items are extracted and validated before the cache is locked. A
    /// bad item raises without modifying the cache, and no Python code runs
    /// while the lock is held.
    #[pyo3(name = "update")]
    pub fn py_update(&self, other: &Bound<'_, PySequence>) -> PyResult<()> {
        let mut entries: Vec<([u8; 32], GTElement)> = Vec::new();
        for item in other.try_iter()? {
            let (key, value): (Vec<u8>, GTElement) = item?.extract()?;
            let key: [u8; 32] = key
                .try_into()
                .map_err(|_| PyValueError::new_err("invalid key"))?;
            entries.push((key, value));
        }

        let mut c = self.cache.lock().expect("cache");
        for (key, value) in entries {
            c.put(key, value);
        }
        Ok(())
    }

    /// Evicts the cached pairings for the pairs `(pks[i], msgs[i])`.
    ///
    /// Raises `ValueError` when `pks` and `msgs` differ in length.
    #[pyo3(name = "evict")]
    pub fn py_evict(&self, pks: &Bound<'_, PyList>, msgs: &Bound<'_, PyList>) -> PyResult<()> {
        let (pks, msgs) = extract_pks_msgs(pks, msgs)?;
        if pks.len() != msgs.len() {
            return Err(PyValueError::new_err(
                "pks and msgs must have the same length",
            ));
        }
        self.evict(pks.into_iter().zip(msgs));
        Ok(())
    }
}
226
#[cfg(test)]
pub mod tests {
    use super::*;

    use crate::sign;
    use crate::SecretKey;

    /// Recomputes the cache key for `(pk, msg)`: SHA-256 over the public key
    /// bytes followed by the message.
    fn key_of(pk: &PublicKey, msg: &[u8]) -> [u8; 32] {
        let mut aug_msg = pk.to_bytes().to_vec();
        aug_msg.extend_from_slice(msg);
        let mut hasher = Sha256::new();
        hasher.update(aug_msg);
        hasher.finalize()
    }

    /// Reports whether `cache` currently holds an entry for `(pk, msg)`.
    fn holds(cache: &BlsCache, pk: &PublicKey, msg: &[u8]) -> bool {
        let key = key_of(pk, msg);
        let guard = cache.cache.lock().expect("cache");
        guard.items.contains_key(&key)
    }

    #[test]
    fn test_aggregate_verify() {
        let cache = BlsCache::default();
        let secret = SecretKey::from_seed(&[0; 32]);
        let public = secret.public_key();
        let message = [106; 32];
        let signature = sign(&secret, message);
        let pairs = [(public, message)];

        assert!(cache.is_empty());

        // First pass fills the cache; the second is served from it.
        for _ in 0..2 {
            assert!(cache.aggregate_verify(pairs, &signature));
            assert_eq!(cache.len(), 1);
        }
    }

    #[test]
    fn test_cache() {
        let cache = BlsCache::default();

        let sk1 = SecretKey::from_seed(&[0; 32]);
        let pk1 = sk1.public_key();
        let msg1 = [106; 32];
        let mut agg_sig = sign(&sk1, msg1);
        let mut pairs = vec![(pk1, msg1)];

        assert!(cache.is_empty());
        assert!(cache.aggregate_verify(pairs.clone(), &agg_sig));
        assert_eq!(cache.len(), 1);

        let sk2 = SecretKey::from_seed(&[1; 32]);
        let pk2 = sk2.public_key();

        // Every new (pk, msg) pair contributes exactly one fresh cache entry.
        for (msg, expected_len) in [([107; 32], 2), ([108; 32], 3)] {
            agg_sig += &sign(&sk2, msg);
            pairs.push((pk2, msg));
            assert!(cache.aggregate_verify(pairs.clone(), &agg_sig));
            assert_eq!(cache.len(), expected_len);
        }
    }

    #[test]
    fn test_cache_limit() {
        let cache = BlsCache::new(NonZeroUsize::new(3).unwrap());
        assert!(cache.is_empty());

        let msg = [106; 32];
        let public_keys: Vec<PublicKey> = (1..=5u8)
            .map(|seed| {
                let sk = SecretKey::from_seed(&[seed; 32]);
                let pk = sk.public_key();
                let sig = sign(&sk, msg);
                assert!(cache.aggregate_verify([(pk, msg)], &sig));
                pk
            })
            .collect();

        assert_eq!(cache.len(), 3);

        // The two earliest entries (seeds 1 and 2) must have been pushed out.
        for pk in &public_keys[..2] {
            assert!(!holds(&cache, pk, &msg));
        }
    }

    #[test]
    fn test_empty_sig() {
        let cache = BlsCache::default();
        let no_pairs: [(&PublicKey, &[u8]); 0] = [];
        assert!(cache.aggregate_verify(no_pairs, &Signature::default()));
    }

    #[test]
    fn test_evict() {
        let cache = BlsCache::new(NonZeroUsize::new(5).unwrap());
        let msg = [42; 32];
        let pairs: Vec<(PublicKey, [u8; 32])> = (1..=5u8)
            .map(|seed| {
                let sk = SecretKey::from_seed(&[seed; 32]);
                let pk = sk.public_key();
                let sig = sign(&sk, msg);
                assert!(cache.aggregate_verify([(pk, msg)], &sig));
                (pk, msg)
            })
            .collect();
        assert_eq!(cache.len(), 5);

        // Drop the first and third pairs; the other three must survive.
        let evicted = [pairs[0], pairs[2]];
        cache.evict(evicted.iter().copied());
        assert_eq!(cache.len(), 3);

        for (pk, msg) in &evicted {
            assert!(!holds(&cache, pk, msg));
        }
        for (pk, msg) in [pairs[1], pairs[3], pairs[4]].iter() {
            assert!(holds(&cache, pk, msg));
        }
    }
}