use std::{
collections::HashMap,
hash::{BuildHasher, RandomState},
num::NonZeroUsize,
sync::Arc,
time::{Duration, Instant},
};
use cosmian_kmip::{
KmipError,
kmip_2_1::kmip_objects::Object,
ttlv::{KmipFlavor, to_ttlv},
};
use cosmian_logger::{debug, trace, warn};
use lru::LruCache;
#[cfg(test)]
use tokio::sync::RwLockReadGuard;
use tokio::sync::{
RwLock,
mpsc::{self, Receiver, Sender},
oneshot,
};
use crate::{DbError, error::DbResult};
/// One cache entry: an unwrapped object tagged with a fingerprint of the
/// wrapped object it was derived from.
///
/// A lookup is only a hit when the caller's wrapped object still produces the
/// same fingerprint, so an entry is never served for a wrapped object that has
/// since changed.
#[derive(Clone)]
pub struct CachedObject {
    /// Fingerprint of the wrapped object at insertion time.
    fingerprint: u64,
    /// The unwrapped object returned on a cache hit.
    unwrapped_object: Object,
}

impl CachedObject {
    /// Build an entry from a precomputed fingerprint (`key_signature`) and the
    /// unwrapped object it stands for.
    #[must_use]
    pub const fn new(key_signature: u64, unwrapped_object: Object) -> Self {
        let fingerprint = key_signature;
        Self {
            unwrapped_object,
            fingerprint,
        }
    }

    /// The fingerprint recorded when this entry was created.
    #[must_use]
    pub const fn fingerprint(&self) -> u64 {
        let Self { fingerprint, .. } = self;
        *fingerprint
    }

    /// Borrow the cached unwrapped object.
    #[must_use]
    pub const fn unwrapped_object(&self) -> &Object {
        let Self {
            unwrapped_object, ..
        } = self;
        unwrapped_object
    }
}
/// Bounded LRU cache of unwrapped objects keyed by uid, with time-based expiry.
///
/// Entries are evicted either by the LRU capacity limit or by a background
/// task that removes entries whose last recorded access is older than `max_age`.
pub struct UnwrappedCache {
    // Per-instance hasher seed used to fingerprint wrapped objects; fingerprints
    // are therefore only comparable within the same cache instance.
    seed: RandomState,
    // uid -> cached entry, shared with the garbage-collection task.
    cache: Arc<RwLock<LruCache<String, CachedObject>>>,
    // uid -> time of last insert/access, shared with the garbage-collection task.
    access_timestamps: Arc<RwLock<HashMap<String, Instant>>>,
    // Sends the uid of every cache hit to the GC task, which refreshes its timestamp.
    access_sender: Sender<String>,
    // Period between garbage-collection sweeps.
    gc_interval: Duration,
    // Entries whose last access is older than this are evicted by the sweep.
    max_age: Duration,
    // Fired (and taken) in `Drop` to stop the garbage-collection task.
    shutdown_sender: Option<oneshot::Sender<()>>,
}
impl UnwrappedCache {
#[allow(clippy::missing_panics_doc)]
#[must_use]
pub fn new(max_age: Duration) -> Self {
#[allow(clippy::expect_used)]
let max_size = NonZeroUsize::new(100).expect("100 is not zero. This will never trigger");
let (tx, rx) = mpsc::channel(100_000);
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let cache = Arc::new(RwLock::new(LruCache::new(max_size)));
let access_timestamps = Arc::new(RwLock::new(HashMap::new()));
let gc_interval = max_age + max_age / 2;
let unwrapped_cache = Self {
seed: RandomState::new(),
cache,
access_timestamps,
access_sender: tx,
gc_interval,
max_age,
shutdown_sender: Some(shutdown_tx),
};
unwrapped_cache.spawn_gc_thread(rx, shutdown_rx);
unwrapped_cache
}
fn spawn_gc_thread(&self, mut rx: Receiver<String>, mut shutdown_rx: oneshot::Receiver<()>) {
let timestamps = self.access_timestamps.clone();
let cache = self.cache.clone();
let interval = self.gc_interval;
let max_age = self.max_age;
tokio::spawn(async move {
let mut interval_timer = tokio::time::interval(interval);
loop {
tokio::select! {
shutdown = &mut shutdown_rx => {
if shutdown.is_ok() {
debug!("Cache garbage collection thread shutting down");
break;
}
}
Some(key) = rx.recv() => {
let mut timestamps_lock = timestamps.write().await;
timestamps_lock.insert(key, Instant::now());
}
_ = interval_timer.tick() => {
debug!("Running cache garbage collection");
let now = Instant::now();
let mut keys_to_remove = Vec::new();
{
let timestamps_lock = timestamps.read().await;
for (key, last_access) in timestamps_lock.iter() {
if now.duration_since(*last_access) > max_age {
keys_to_remove.push(key.clone());
}
}
}
if !keys_to_remove.is_empty() {
let mut timestamps_lock = timestamps.write().await;
let mut cache_lock = cache.write().await;
for key in &keys_to_remove {
timestamps_lock.remove(key);
cache_lock.pop(key);
}
debug!("Garbage collected {} stale cache entries", keys_to_remove.len());
}
}
}
}
debug!("Cache garbage collection thread terminated");
});
}
async fn record_access(&self, uid: &str) -> DbResult<()> {
if let Err(e) = self.access_sender.send(uid.to_owned()).await {
warn!("Failed to send cache access timestamp: {}", e);
return Err(DbError::UnwrappedCache(e.to_string()));
}
Ok(())
}
fn fingerprint(&self, object: &Object) -> DbResult<u64> {
to_ttlv(&object)
.and_then(|ttlv| ttlv.to_bytes(KmipFlavor::Kmip2))
.map_err(KmipError::from)
.map_err(DbError::from)
.map(|bytes| {
self.seed.hash_one(&bytes)
})
}
pub async fn validate_cache(&self, uid: &str, object: &Object) -> DbResult<()> {
let mut cache = self.cache.write().await;
if let Some(cached_object) = cache.peek(uid) {
if cached_object.fingerprint() != self.fingerprint(object)? {
trace!("Invalidating the cache for {}", uid);
cache.pop(uid);
}
}
Ok(())
}
pub async fn clear_cache(&self, uid: &str) {
self.cache.write().await.pop(uid);
self.access_timestamps.write().await.remove(uid);
}
pub async fn peek(&self, uid: &str, wrapped_object: &Object) -> DbResult<Option<Object>> {
let cache_read = self.cache.read();
match cache_read.await.peek(uid) {
Some(cached_object) => {
self.record_access(uid).await?;
if cached_object.fingerprint() == self.fingerprint(wrapped_object)? {
Ok(Some(cached_object.unwrapped_object().clone()))
} else {
Ok(None)
}
}
None => Ok(None),
}
}
pub async fn insert(
&self,
uid: String,
wrapped_object: &Object,
unwrapped_object: Object,
) -> DbResult<()> {
if wrapped_object == &unwrapped_object {
return Err(DbError::UnwrappedCache(
"wrapped and unwrapped objects should be different".to_owned(),
));
}
self.cache.write().await.put(
uid.clone(),
CachedObject {
fingerprint: self.fingerprint(wrapped_object)?,
unwrapped_object,
},
);
self.access_timestamps
.write()
.await
.insert(uid, Instant::now());
Ok(())
}
#[cfg(test)]
pub async fn get_cache(&self) -> RwLockReadGuard<'_, LruCache<String, CachedObject>> {
self.cache.read().await
}
}
/// Stops the background garbage-collection task when the cache goes away.
impl Drop for UnwrappedCache {
    fn drop(&mut self) {
        let Some(shutdown_tx) = self.shutdown_sender.take() else {
            return;
        };
        // The task may already have exited and dropped its receiver; that is fine.
        let _ = shutdown_tx.send(());
        debug!("Sent shutdown signal to cache garbage collection thread");
    }
}
// Tests for `UnwrappedCache`: its interaction with `Database`, time-based
// garbage collection, and shutdown of the background GC task on drop.
#[cfg(test)]
mod tests {
    #![allow(
        clippy::panic_in_result_fn,
        clippy::unwrap_in_result,
        clippy::assertions_on_result_states,
        clippy::assertions_on_constants
    )]
    use std::{
        collections::{HashMap, HashSet},
        time::Duration,
    };
    use cosmian_kmip::kmip_2_1::{
        extra::tagging::VENDOR_ID_COSMIAN, kmip_attributes::Attributes,
        kmip_types::CryptographicAlgorithm, requests::create_symmetric_key_kmip_object,
    };
    use cosmian_kms_crypto::reexport::cosmian_crypto_core::{
        CsRng,
        reexport::rand_core::{RngCore, SeedableRng},
    };
    use cosmian_logger::log_init;
    use tempfile::TempDir;
    use uuid::Uuid;
    use crate::{Database, core::main_db_params::MainDbParams, error::DbResult};

    // Creating an object through `Database` must not populate the unwrapped
    // cache, and retrieving it must not populate it either.
    #[tokio::test]
    async fn test_lru_cache() -> DbResult<()> {
        log_init(option_env!("RUST_LOG"));
        // Throwaway SQLite database in a temporary directory.
        let dir = TempDir::new()?;
        let main_db_params = MainDbParams::Sqlite(dir.path().to_owned(), None);
        let database = Database::instantiate(
            &main_db_params,
            true,
            HashMap::new(),
            Duration::from_millis(100),
        )
        .await?;
        // Random 256-bit AES key.
        let mut rng = CsRng::from_entropy();
        let mut symmetric_key_bytes = vec![0; 32];
        rng.fill_bytes(&mut symmetric_key_bytes);
        let symmetric_key = create_symmetric_key_kmip_object(
            VENDOR_ID_COSMIAN,
            &symmetric_key_bytes,
            &Attributes {
                cryptographic_algorithm: Some(CryptographicAlgorithm::AES),
                ..Attributes::default()
            },
        )?;
        let owner = "eyJhbGciOiJSUzI1Ni";
        let uid = Uuid::new_v4().to_string();
        let uid_ = database
            .create(
                Some(uid.clone()),
                owner,
                &symmetric_key,
                symmetric_key.attributes()?,
                &HashSet::new(),
            )
            .await?;
        assert_eq!(&uid, &uid_);
        // Nothing is cached after creation.
        assert!(
            database
                .unwrapped_cache()
                .peek(&uid, &symmetric_key)
                .await?
                .is_none()
        );
        let owm = database.retrieve_object(&uid).await?;
        match owm {
            Some(obj) => assert_eq!(obj.id(), &uid),
            // `assert!(false, ..)` is permitted by the module-level
            // `assertions_on_constants` allow.
            None => assert!(false, "expected object to be present"),
        }
        // Still nothing cached after a plain retrieval.
        {
            let cache = database.unwrapped_cache.get_cache();
            assert!(cache.await.peek(&uid).is_none());
        };
        Ok(())
    }

    // An entry that is not accessed for longer than `max_age` is evicted by the
    // background sweep.
    #[tokio::test]
    async fn test_garbage_collection() -> DbResult<()> {
        log_init(option_env!("RUST_LOG"));
        // max_age = 100 ms, so `UnwrappedCache::new` schedules a sweep every 150 ms.
        let cache = super::UnwrappedCache::new(
            Duration::from_millis(100),
        );
        let uid = "test_item".to_owned();
        // NOTE(review): `wrapped_object` and `unwrapped_object` are built from
        // identical inputs, and `insert` rejects equal objects. This test relies
        // on `create_symmetric_key_kmip_object` producing non-equal objects —
        // confirm.
        let unwrapped_object = create_symmetric_key_kmip_object(
            VENDOR_ID_COSMIAN,
            &[0; 32],
            &Attributes {
                cryptographic_algorithm: Some(CryptographicAlgorithm::AES),
                ..Attributes::default()
            },
        )?;
        let wrapped_object = create_symmetric_key_kmip_object(
            VENDOR_ID_COSMIAN,
            &[0; 32],
            &Attributes {
                cryptographic_algorithm: Some(CryptographicAlgorithm::AES),
                ..Attributes::default()
            },
        )?;
        cache
            .insert(uid.clone(), &wrapped_object, unwrapped_object.clone())
            .await?;
        // Immediately after insertion the entry is a hit.
        assert_eq!(
            cache.peek(&uid, &wrapped_object).await?,
            Some(unwrapped_object)
        );
        // Wait for more than two sweep periods so the entry is older than max_age
        // when a sweep runs.
        tokio::time::sleep(Duration::from_millis(350)).await;
        assert!(cache.peek(&uid, &wrapped_object).await?.is_none());
        Ok(())
    }

    // Dropping the cache sends the shutdown signal to the GC task. This test
    // only checks that use-then-drop completes without error; it does not
    // assert that the task actually exited.
    #[tokio::test]
    async fn test_gc_thread_shutdown() -> DbResult<()> {
        log_init(option_env!("RUST_LOG"));
        {
            let cache = super::UnwrappedCache::new(Duration::from_millis(100));
            let uid = "test_item".to_owned();
            // NOTE(review): same identical-inputs caveat as in
            // `test_garbage_collection` — confirm the two objects compare unequal.
            let wrapped_object = create_symmetric_key_kmip_object(
                VENDOR_ID_COSMIAN,
                &[0; 32],
                &Attributes {
                    cryptographic_algorithm: Some(CryptographicAlgorithm::AES),
                    ..Attributes::default()
                },
            )?;
            let unwrapped_object = create_symmetric_key_kmip_object(
                VENDOR_ID_COSMIAN,
                &[0; 32],
                &Attributes {
                    cryptographic_algorithm: Some(CryptographicAlgorithm::AES),
                    ..Attributes::default()
                },
            )?;
            cache
                .insert(uid.clone(), &wrapped_object, unwrapped_object.clone())
                .await?;
            assert_eq!(
                cache.peek(&uid, &wrapped_object).await?,
                Some(unwrapped_object),
            );
            // `cache` is dropped here, which fires the shutdown signal.
        };
        // Give the GC task a moment to observe the shutdown signal.
        tokio::time::sleep(Duration::from_millis(50)).await;
        Ok(())
    }
}