use std::cmp;
use std::collections::BTreeMap;
use std::collections::btree_map;
use std::sync::OnceLock;
use std::sync::RwLock;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use crate::HashAlgorithm;
use crate::Result;
use crate::packet::Key;
use crate::packet::Signature;
use crate::packet::key;
/// Controls tracing via the `tracer!`/`t!` macros in this module.
const TRACE: bool = false;

/// The process-wide signature verification cache.
static SIGNATURE_VERIFICATION_CACHE: SignatureVerificationCache
    = SignatureVerificationCache::empty();

/// The hash algorithm used to compute cache entry values.
const HASH_ALGO: HashAlgorithm = HashAlgorithm::SHA512;

/// SHA-512 produces 64 bytes of digest.
const HASH_BYTES_UNTRUNCATED: usize = 512 / 8;
/// Cache entries keep only the first half of the digest.
const HASH_BYTES_TRUNCATED: usize = HASH_BYTES_UNTRUNCATED / 2;

/// The size of a cache entry's value in bytes.
const VALUE_BYTES: usize = HASH_BYTES_TRUNCATED;
/// A cache entry's value: a truncated hash.
type Value = [u8; VALUE_BYTES];
/// An all-zeros value, used as an initializer.
const VALUE_NULL: Value = [0u8; VALUE_BYTES];
/// Bookkeeping attached to each cache entry.
#[derive(Debug)]
pub struct Metadata {
    // Whether the entry was inserted during this session (as opposed
    // to being restored via `SignatureVerificationCache::restore`).
    inserted: bool,
    // Whether the entry has been looked up and found; atomic so it
    // can be flipped through a shared reference.
    accessed: AtomicBool,
}
impl Clone for Metadata {
fn clone(&self) -> Metadata {
Self {
inserted: self.inserted,
accessed: AtomicBool::from(self.accessed.load(Ordering::Relaxed)),
}
}
}
impl Metadata {
fn new(inserted: bool) -> Self {
Metadata {
inserted,
accessed: false.into(),
}
}
pub fn inserted(&self) -> bool {
self.inserted
}
pub fn accessed(&self) -> bool {
self.accessed.load(Ordering::Relaxed)
}
}
/// A signature verification cache entry: a truncated hash computed
/// over the signature, the data's digest, and the key, plus
/// bookkeeping metadata.
#[derive(Clone)]
pub struct Entry {
    // The truncated hash identifying the verification.
    value: Value,
    // Bookkeeping (inserted / accessed flags).
    metadata: Metadata,
}
/// Entries are compared solely by their hash `value`; the attached
/// `Metadata` never participates in ordering or equality.
impl Ord for Entry {
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        self.value.cmp(&other.value)
    }
}

impl PartialOrd for Entry {
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl PartialEq for Entry {
    fn eq(&self, other: &Self) -> bool {
        // Equivalent to `self.cmp(other) == Ordering::Equal`.
        self.value == other.value
    }
}

impl Eq for Entry {}
impl Entry {
    /// Computes the cache entry for a signature, the digest of the
    /// data it covers, and the key that allegedly made it.
    ///
    /// The entry's value is the truncated SHA-512 digest of: a
    /// leading zero octet (presumably a format version octet —
    /// confirm before changing the hashed layout), the signature's
    /// MPIs prefixed by their little-endian 32-bit length, the
    /// signature's hash algorithm, `computed_digest`, and the key's
    /// MPIs.
    ///
    /// Returns an error if hashing or MPI serialization fails.
    pub(super) fn new(sig: &Signature,
                      computed_digest: &[u8],
                      key: &Key<key::PublicParts, key::UnspecifiedRole>)
        -> Result<Self>
    {
        use crate::serialize::Marshal;
        use crate::serialize::MarshalInto;

        let mut context = HASH_ALGO.context()?.for_digest();

        // Leading zero octet.
        context.update(&[ 0u8 ]);

        // Length-prefix the MPIs (little-endian, low 32 bits of the
        // length) so the following fields cannot be confused with
        // them.
        let mpis_len = sig.mpis.serialized_len();
        context.update(&[
            (mpis_len & 0xFF) as u8,
            ((mpis_len >> 8) & 0xFF) as u8,
            ((mpis_len >> 16) & 0xFF) as u8,
            ((mpis_len >> 24) & 0xFF) as u8,
        ]);
        sig.mpis.export(&mut context)?;

        context.update(&[
            u8::from(sig.hash_algo())
        ]);
        context.update(computed_digest);

        key.mpis().export(&mut context)?;

        // Keep only the first VALUE_BYTES bytes of the digest.
        let context_hash = context.into_digest()?;
        let mut value = VALUE_NULL;
        value.copy_from_slice(&context_hash[..VALUE_BYTES]);

        Ok(Entry {
            value,
            metadata: Metadata::new(true),
        })
    }

    /// Returns the entry's value (the truncated hash).
    pub fn value(&self) -> &[u8] {
        &self.value
    }

    /// Returns whether this entry is present in the global cache.
    pub(super) fn present(&self) -> bool {
        SIGNATURE_VERIFICATION_CACHE.present(&self.value)
    }

    /// Inserts the entry into the global cache, but only if the
    /// signature actually verified; failures are never cached.
    pub(super) fn insert(self, verified: bool) {
        if verified {
            SIGNATURE_VERIFICATION_CACHE.insert(self.value);
        }
    }

    /// Returns whether the entry was inserted during this session.
    pub fn inserted(&self) -> bool {
        // Delegate to Metadata's accessor rather than duplicating the
        // field access.
        self.metadata.inserted()
    }

    /// Returns whether the entry has been looked up and found.
    pub fn accessed(&self) -> bool {
        // Delegate to Metadata's accessor rather than duplicating the
        // atomic load.
        self.metadata.accessed()
    }
}
/// Entries are sharded over `2^BUCKETS_BITS` buckets, each behind its
/// own lock, keyed on the top bits of the value's first byte.
const BUCKETS_BITS: usize = 4;
/// The number of buckets.
const BUCKETS: usize = 1 << BUCKETS_BITS;
/// Right shift that maps a value's first byte to its bucket index.
const BUCKETS_SHIFT: usize = 8 - BUCKETS_BITS;
/// A cache of successful signature verifications, accessed through
/// the `SIGNATURE_VERIFICATION_CACHE` static.
pub struct SignatureVerificationCache {
    // Entries restored in bulk, sorted by value; set at most once.
    list: OnceLock<Vec<Entry>>,
    // Run-time insertions, sharded by the top bits of the value's
    // first byte; each shard has its own lock.
    buckets: [
        RwLock<BTreeMap<Value, Metadata>>;
        BUCKETS
    ],
    // Statistics counters.
    hits: AtomicUsize,
    misses: AtomicUsize,
    preloads: AtomicUsize,
    insertions: AtomicUsize,
}
impl SignatureVerificationCache {
const fn empty() -> Self {
SignatureVerificationCache {
list: OnceLock::new(),
buckets: [
RwLock::new(BTreeMap::new()),
RwLock::new(BTreeMap::new()),
RwLock::new(BTreeMap::new()),
RwLock::new(BTreeMap::new()),
RwLock::new(BTreeMap::new()),
RwLock::new(BTreeMap::new()),
RwLock::new(BTreeMap::new()),
RwLock::new(BTreeMap::new()),
RwLock::new(BTreeMap::new()),
RwLock::new(BTreeMap::new()),
RwLock::new(BTreeMap::new()),
RwLock::new(BTreeMap::new()),
RwLock::new(BTreeMap::new()),
RwLock::new(BTreeMap::new()),
RwLock::new(BTreeMap::new()),
RwLock::new(BTreeMap::new()),
],
hits: AtomicUsize::new(0),
misses: AtomicUsize::new(0),
preloads: AtomicUsize::new(0),
insertions: AtomicUsize::new(0),
}
}
fn bucket(value: &[u8]) -> usize {
(value[0] >> BUCKETS_SHIFT) as usize
}
fn present(&self, value: &[u8]) -> bool {
assert_eq!(value.len(), HASH_BYTES_TRUNCATED);
if let Some(list) = self.list.get() {
if let Ok(i) = list.binary_search_by(|e| e.value[..].cmp(value)) {
list[i].metadata.accessed.store(true, Ordering::Relaxed);
self.hits.fetch_add(1, Ordering::Relaxed);
return true;
}
}
let i = Self::bucket(value);
let entries = self.buckets[i].read().unwrap();
if let Some(metadata) = entries.get(value) {
metadata.accessed.store(true, Ordering::Relaxed);
self.hits.fetch_add(1, Ordering::Relaxed);
true
} else {
self.misses.fetch_add(1, Ordering::Relaxed);
false
}
}
fn insert(&self, value: [u8; HASH_BYTES_TRUNCATED]) {
let i = Self::bucket(&value);
let mut entries = self.buckets[i].write().unwrap();
match entries.entry(value) {
btree_map::Entry::Vacant(e) => {
self.insertions.fetch_add(1, Ordering::Relaxed);
e.insert(Metadata::new(true));
}
btree_map::Entry::Occupied(_e) => {
}
}
}
pub fn cache_hits() -> usize {
SIGNATURE_VERIFICATION_CACHE.hits.load(Ordering::Relaxed)
}
pub fn cache_misses() -> usize {
SIGNATURE_VERIFICATION_CACHE.misses.load(Ordering::Relaxed)
}
pub fn insertions() -> usize {
SIGNATURE_VERIFICATION_CACHE.insertions.load(Ordering::Relaxed)
}
pub fn clear_insertions() {
SIGNATURE_VERIFICATION_CACHE.insertions.store(0, Ordering::Relaxed);
}
pub fn restore<'a, F>(
entries: impl Iterator<Item=Vec<u8>> + Send + Sync + 'static,
finished: F)
where F: FnOnce() + Send + Sync + 'static
{
tracer!(TRACE, "SignatureVerificationCache::restore");
assert_eq!(HASH_ALGO.context().expect("have SHA-512")
.for_digest().digest_size(),
HASH_BYTES_UNTRUNCATED);
assert!(HASH_BYTES_TRUNCATED <= HASH_BYTES_UNTRUNCATED);
assert!(BUCKETS_BITS <= 8);
assert_eq!(BUCKETS, 1 << BUCKETS_BITS);
std::thread::spawn(move || {
let mut items: Vec<Entry> = Vec::with_capacity(32 * 1024);
let mut bad = 0;
let mut count = 0;
for entry in entries {
count += 1;
if entry.len() != VALUE_BYTES {
bad += 1;
continue;
}
let mut value = VALUE_NULL;
value.copy_from_slice(&entry[..VALUE_BYTES]);
items.push(Entry {
value,
metadata: Metadata::new(false),
});
}
if bad > 0 {
t!("Warning: {} of {} cache entries could not be read",
bad, count);
}
t!("Restored {} entries", count);
SIGNATURE_VERIFICATION_CACHE.preloads
.fetch_add(items.len(), Ordering::Relaxed);
items.sort();
if let Err(items) = SIGNATURE_VERIFICATION_CACHE.list.set(items) {
let mut bucket_i = 0;
let mut bucket = SIGNATURE_VERIFICATION_CACHE
.buckets[bucket_i].write().unwrap();
for item in items.into_iter() {
let i = Self::bucket(&item.value);
if i != bucket_i {
assert!(i > bucket_i);
bucket = SIGNATURE_VERIFICATION_CACHE
.buckets[i].write().unwrap();
bucket_i = i;
}
bucket.insert(item.value, item.metadata);
}
}
finished();
});
}
pub fn dump<'a>() -> impl IntoIterator<Item=Entry> {
tracer!(TRACE, "SignatureVerificationCache::dump");
if TRACE {
let preloads = SIGNATURE_VERIFICATION_CACHE
.preloads.load(Ordering::Relaxed);
let insertions = SIGNATURE_VERIFICATION_CACHE
.insertions.load(Ordering::Relaxed);
t!("{} entries: {} restored, {} inserted",
preloads + insertions,
preloads, insertions);
let hits = SIGNATURE_VERIFICATION_CACHE
.hits.load(Ordering::Relaxed);
let misses = SIGNATURE_VERIFICATION_CACHE
.misses.load(Ordering::Relaxed);
let lookups = hits + misses;
if lookups > 0 {
t!("{} cache lookups, {} hits ({}%), {} misses ({}%)",
lookups,
hits, (100 * hits) / lookups,
misses, (100 * misses) / lookups);
} else {
t!("0 cache lookups");
}
}
DumpIter {
bucket: 0,
iter: None,
list: SIGNATURE_VERIFICATION_CACHE.list.get()
.map(|list| list.clone())
.unwrap_or(Vec::new()),
}
}
}
/// The iterator returned by `SignatureVerificationCache::dump`.
///
/// Yields the contents of each bucket in turn, then the preloaded
/// list; an entry present in both is yielded only once, from the
/// list.
struct DumpIter {
    // The chunk (one bucket, or the list) currently being drained.
    iter: Option<std::vec::IntoIter<Entry>>,
    // Index of the next bucket to dump; BUCKETS means all buckets
    // are done and the list is next.
    bucket: usize,
    // Clone of the preloaded list; emptied once it has been dumped.
    list: Vec<Entry>,
}
impl Iterator for DumpIter {
    type Item = Entry;

    fn next(&mut self) -> Option<Self::Item> {
        tracer!(TRACE, "DumpIter::next");

        loop {
            // Drain the current chunk, if any.
            if let Some(ref mut iter) = self.iter {
                if let Some(item) = iter.next() {
                    return Some(item);
                }
            }

            if self.bucket == BUCKETS {
                // All buckets have been dumped; finish with the
                // preloaded list.  Taking it empties `self.list`, so
                // the next round terminates with `None`.
                if self.list.is_empty() {
                    return None;
                }

                let list = std::mem::take(&mut self.list);
                t!("Dumping {} restored entries", list.len());
                self.iter = Some(list.into_iter());
            } else {
                // Dump the next bucket, skipping values that also
                // appear in the preloaded list (which is sorted, so a
                // binary search suffices); those are emitted later,
                // with the list's metadata.
                let bucket = &SIGNATURE_VERIFICATION_CACHE.buckets[self.bucket];
                self.bucket += 1;

                let bucket = bucket.read().unwrap();
                t!("Dumping {} entries from bucket {}",
                   bucket.len(), self.bucket - 1);
                // Collect eagerly so the read lock is released before
                // the caller processes the entries.
                self.iter = Some(
                    bucket.iter()
                        .filter_map(|(v, m)| {
                            if let Ok(_) = self.list.binary_search_by(|e| {
                                e.value[..].cmp(v)
                            })
                            {
                                None
                            } else {
                                Some(Entry {
                                    value: v.clone(),
                                    metadata: m.clone(),
                                })
                            }
                        })
                        .collect::<Vec<_>>()
                        .into_iter())
            }
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// The bucket function must partition the first-byte space into
    /// `BUCKETS` contiguous, equally sized ranges with monotonically
    /// increasing indices.
    #[test]
    fn bucket() {
        let mut expected = 0;
        let mut histogram = vec![0usize; BUCKETS];

        for byte in 0..=u8::MAX {
            let mut value = VALUE_NULL;
            value[0] = byte;

            let b = SignatureVerificationCache::bucket(&value);
            // Indices never decrease, and only ever step by one.
            if b != expected {
                assert_eq!(b, expected + 1);
                expected += 1;
            }
            histogram[b] += 1;
        }

        for (i, c) in histogram.iter().enumerate() {
            eprintln!("{}: {}", i, c);
        }

        // Every bucket covers the same number of first bytes...
        assert!(histogram.iter().all(|c| *c == histogram[0]));
        // ... and together they cover all 256 of them.
        assert_eq!(histogram.iter().sum::<usize>(),
                   u8::MAX as usize + 1);
    }
}