use std::collections::HashSet;
use std::sync::{
RwLock,
atomic::{AtomicU64, Ordering},
};
use rand::{rngs::OsRng, RngCore};
#[must_use]
pub fn fresh_nonce() -> [u8; 16] {
let mut nonce = [0u8; 16];
OsRng.fill_bytes(&mut nonce);
nonce
}
/// Backend that records which 32-byte fingerprints have been revoked.
///
/// Implementors must be shareable across threads. Both operations report backend failures as a
/// human-readable `String`.
pub trait RevocationStore: Sync + Send {
    /// Reports whether `fingerprint` is currently recorded as revoked.
    ///
    /// # Errors
    /// Returns `Err` with a description when the backend cannot answer.
    fn is_revoked(
        &self,
        fingerprint: &[u8; 32],
    ) -> Result<bool, String>;

    /// Records `fingerprint` as revoked.
    ///
    /// # Errors
    /// Returns `Err` with a description when the revocation could not be recorded.
    fn revoke(
        &self,
        fingerprint: &[u8; 32],
    ) -> Result<(), String>;
}
/// Backend that remembers which 16-byte nonces have already been used.
///
/// Implementors must be shareable across threads. Both operations report backend failures as a
/// human-readable `String`.
///
/// Note that calling [`NonceStore::is_consumed`] and then [`NonceStore::mark_consumed`] is two
/// separate operations; nothing in this trait makes the pair atomic.
pub trait NonceStore: Sync + Send {
    /// Reports whether `nonce` has already been recorded as consumed.
    ///
    /// # Errors
    /// Returns `Err` with a description when the backend cannot answer.
    fn is_consumed(
        &self,
        nonce: &[u8; 16],
    ) -> Result<bool, String>;

    /// Records `nonce` as consumed.
    ///
    /// # Errors
    /// Returns `Err` with a description when the nonce could not be recorded.
    fn mark_consumed(
        &self,
        nonce: &[u8; 16],
    ) -> Result<(), String>;
}
/// In-memory [`RevocationStore`] for 32-byte fingerprints.
///
/// Lookups first consult a single-hash Bloom filter of `BLOOM_WORDS` × 64 bits held in atomics.
/// A clear bit answers "not revoked" without taking any lock. Only when the bit is set does the
/// lookup fall through to one of `SHARD_COUNT` `RwLock`-guarded sets, which hold the
/// authoritative answer.
///
/// Nothing in this type removes entries: the sets and the filter only ever grow.
pub struct MemoryRevocationStore {
    /// Bloom filter words; a set bit means "possibly revoked".
    bloom: Vec<AtomicU64>,
    /// Authoritative revoked fingerprints, partitioned by `shard_index`.
    shards: Vec<RwLock<HashSet<[u8; 32]>>>,
}

impl MemoryRevocationStore {
    /// Number of 64-bit words in the Bloom filter (4096 bits in total).
    const BLOOM_WORDS: usize = 64;
    /// Number of independently locked fingerprint sets.
    const SHARD_COUNT: usize = 64;

    /// Creates an empty store with a cleared filter and empty shards.
    #[must_use]
    pub fn new() -> Self {
        let bloom = (0..Self::BLOOM_WORDS).map(|_| AtomicU64::new(0)).collect();
        let shards = (0..Self::SHARD_COUNT)
            .map(|_| RwLock::new(HashSet::new()))
            .collect();
        Self { bloom, shards }
    }

    /// Maps a fingerprint to its Bloom filter `(word index, single-bit mask)`.
    ///
    /// The four little-endian 64-bit lanes are folded together with XOR. The word index comes from
    /// that value directly, and the bit position from a rotated copy.
    #[inline]
    fn bloom_indices(fp: &[u8; 32]) -> (usize, u64) {
        let entropy = fp
            .chunks_exact(8)
            .map(|lane| {
                u64::from_le_bytes(lane.try_into().expect("chunks_exact(8) yields 8-byte lanes"))
            })
            .fold(0u64, |acc, lane| acc ^ lane);
        let word = (entropy % Self::BLOOM_WORDS as u64) as usize;
        let bit = 1u64 << (entropy.rotate_right(13) % 64);
        (word, bit)
    }

    /// Selects the shard for a fingerprint from its last eight bytes (read little-endian).
    #[inline]
    fn shard_index(fp: &[u8; 32]) -> usize {
        let tail: [u8; 8] = fp[24..32]
            .try_into()
            .expect("24..32 is an 8-byte range of a 32-byte array");
        (u64::from_le_bytes(tail) % Self::SHARD_COUNT as u64) as usize
    }
}

impl Default for MemoryRevocationStore {
    fn default() -> Self {
        Self::new()
    }
}

impl RevocationStore for MemoryRevocationStore {
    /// Returns whether `fingerprint` has been revoked.
    ///
    /// # Errors
    /// Returns `Err("revocation shard lock poisoned")` if the shard's lock is poisoned.
    fn is_revoked(&self, fingerprint: &[u8; 32]) -> Result<bool, String> {
        let (word, bit) = Self::bloom_indices(fingerprint);
        // Fast path: a clear bit means no `revoke` has set it, so the fingerprint cannot be in
        // any shard. `Acquire` pairs with the `Release` in `revoke`, as in `MemoryNonceStore`.
        if (self.bloom[word].load(Ordering::Acquire) & bit) == 0 {
            return Ok(false);
        }
        // The bit may be shared with other fingerprints; the shard gives the exact answer.
        let shard = self.shards[Self::shard_index(fingerprint)]
            .read()
            .map_err(|_| "revocation shard lock poisoned".to_owned())?;
        Ok(shard.contains(fingerprint))
    }

    /// Records `fingerprint` as revoked. Revoking the same fingerprint again is a no-op.
    ///
    /// # Errors
    /// Returns `Err("revocation shard lock poisoned")` if the shard's lock is poisoned. In that
    /// case the filter bit has already been set, which only costs a slow-path lookup later.
    fn revoke(&self, fingerprint: &[u8; 32]) -> Result<(), String> {
        let (word, bit) = Self::bloom_indices(fingerprint);
        // Set the filter bit before inserting, so the entry never becomes visible in a shard
        // while the filter would still short-circuit lookups to `false`.
        self.bloom[word].fetch_or(bit, Ordering::Release);
        self.shards[Self::shard_index(fingerprint)]
            .write()
            .map_err(|_| "revocation shard lock poisoned".to_owned())?
            .insert(*fingerprint);
        Ok(())
    }
}
/// In-memory [`NonceStore`] that remembers every consumed 16-byte nonce.
///
/// A single-hash Bloom filter of `BLOOM_WORDS` × 64 bits held in atomics answers "definitely not
/// consumed" without locking. `SHARD_COUNT` `RwLock`-guarded sets hold the authoritative record.
///
/// For replay protection, prefer [`MemoryNonceStore::try_consume`]. Calling `is_consumed` and then
/// `mark_consumed` is two separate operations, so two concurrent callers with the same nonce can
/// both observe "not consumed".
///
/// Nothing in this type evicts entries, so memory grows with the number of distinct nonces
/// consumed.
pub struct MemoryNonceStore {
    /// Bloom filter words; a set bit means "possibly consumed".
    bloom: Vec<AtomicU64>,
    /// Authoritative consumed nonces, partitioned by the shard index from `indices`.
    shards: Vec<RwLock<HashSet<[u8; 16]>>>,
}

impl MemoryNonceStore {
    /// Number of 64-bit words in the Bloom filter (65536 bits in total).
    const BLOOM_WORDS: usize = 1024;
    /// Number of independently locked nonce sets.
    const SHARD_COUNT: usize = 256;

    /// Creates an empty store with a cleared filter and empty shards.
    #[must_use]
    pub fn new() -> Self {
        let bloom = (0..Self::BLOOM_WORDS).map(|_| AtomicU64::new(0)).collect();
        let shards = (0..Self::SHARD_COUNT)
            .map(|_| RwLock::new(HashSet::new()))
            .collect();
        Self { bloom, shards }
    }

    /// Derives `(bloom word, single-bit mask, shard index)` for a nonce.
    ///
    /// Both 64-bit halves are mixed into one value `h` (a multiply-then-add). The three outputs then
    /// read different bit ranges of `h`:
    /// - word: low 10 bits;
    /// - bit position: bits 11..=16;
    /// - shard: bits 31..=38.
    #[inline]
    fn indices(nonce: &[u8; 16]) -> (usize, u64, usize) {
        let e1 = u64::from_le_bytes(nonce[0..8].try_into().expect("0..8 is an 8-byte range"));
        let e2 = u64::from_le_bytes(nonce[8..16].try_into().expect("8..16 is an 8-byte range"));
        let h = e1.wrapping_mul(0x9E3779B185EBCA87).wrapping_add(e2.rotate_left(23));
        let word = (h as usize) % Self::BLOOM_WORDS;
        let bit = 1u64 << (h.rotate_right(11) % 64);
        let shard = (h.rotate_right(31) as usize) % Self::SHARD_COUNT;
        (word, bit, shard)
    }

    /// Atomically records `nonce` as consumed and reports whether this call was the first to do so.
    ///
    /// Returns `Ok(true)` when the nonce had not been consumed before (the caller may accept it).
    /// Returns `Ok(false)` when it was already consumed (a replay).
    ///
    /// The membership test and the insertion happen under one shard write lock. So when several
    /// threads race with the same nonce, exactly one of them gets `Ok(true)`.
    ///
    /// # Errors
    /// Returns `Err("nonce shard lock poisoned")` if the shard's lock is poisoned.
    pub fn try_consume(&self, nonce: &[u8; 16]) -> Result<bool, String> {
        let (word, bit, shard) = Self::indices(nonce);
        // Set the filter bit before inserting, so the entry never becomes visible in a shard
        // while the filter would still short-circuit `is_consumed` to `false`.
        self.bloom[word].fetch_or(bit, Ordering::Release);
        // `HashSet::insert` returns `true` only if the value was not already present.
        let newly_inserted = self.shards[shard]
            .write()
            .map_err(|_| "nonce shard lock poisoned".to_owned())?
            .insert(*nonce);
        Ok(newly_inserted)
    }
}

impl Default for MemoryNonceStore {
    fn default() -> Self {
        Self::new()
    }
}

impl NonceStore for MemoryNonceStore {
    /// Returns whether `nonce` has been recorded as consumed.
    ///
    /// # Errors
    /// Returns `Err("nonce shard lock poisoned")` if the shard's lock is poisoned.
    fn is_consumed(&self, nonce: &[u8; 16]) -> Result<bool, String> {
        let (word, bit, shard) = Self::indices(nonce);
        // Fast path: a clear bit means no consume call has set it. `Acquire` pairs with the
        // `Release` in `try_consume`.
        if (self.bloom[word].load(Ordering::Acquire) & bit) == 0 {
            return Ok(false);
        }
        // The bit may be shared with other nonces; the shard gives the exact answer.
        let s = self.shards[shard]
            .read()
            .map_err(|_| "nonce shard lock poisoned".to_owned())?;
        Ok(s.contains(nonce))
    }

    /// Records `nonce` as consumed. Recording the same nonce again is a no-op.
    ///
    /// # Errors
    /// Returns `Err("nonce shard lock poisoned")` if the shard's lock is poisoned.
    fn mark_consumed(&self, nonce: &[u8; 16]) -> Result<(), String> {
        // Same effects as `try_consume`; only the "was it new?" answer is discarded.
        self.try_consume(nonce).map(|_| ())
    }
}