use std::collections::HashMap;
use std::hash::Hash;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use chrono::{DateTime, Duration, Utc};
use flexbuffers::{FlexbufferSerializer, Reader};
use serde::Serialize;
pub use crate::result::Error;
/// A value held in the in-memory cache layer.
pub struct CacheEntry<V> {
    /// Shared handle to the cached value; lookups hand out clones of this `Arc`.
    pub value: Arc<V>,
    /// Expiration as a Unix timestamp in seconds. Atomic so that a memory hit can push
    /// the deadline forward while holding only the map's read lock.
    pub expiration: AtomicU64,
}
/// A two-level cache: an in-memory map in front of a persistent `sled` store.
///
/// Keys and values are written to disk in serialized form, so both must implement serde's
/// `Serialize` and `Deserialize`.
pub struct Cache<
    K: Send + Clone + Hash + Eq + for<'de> serde::Deserialize<'de> + serde::Serialize,
    V: Send + Clone + for<'de> serde::Deserialize<'de> + serde::Serialize,
> {
    /// Memory layer: key -> shared value plus its expiration (Unix seconds).
    pub(crate) in_memory: Arc<RwLock<HashMap<K, CacheEntry<V>>>>,
    /// Disk layer: serialized key -> serialized value.
    pub(crate) content: sled::Tree,
    /// Expiry index: keys start with a big-endian Unix-second expiration timestamp and
    /// values are the matching `content` keys.
    pub(crate) expiry: sled::Tree,
    /// How long an entry stays in the memory layer; renewed on every memory hit.
    pub(crate) memory_duration: Duration,
    /// How long an entry stays in the disk layer, measured from when it was written.
    pub(crate) disk_duration: Duration,
}
impl<K, V> Cache<K, V>
where
K: Send + Clone + Hash + Eq + for<'de> serde::Deserialize<'de> + serde::Serialize,
V: Send + Clone + for<'de> serde::Deserialize<'de> + serde::Serialize,
{
pub async fn get_or_insert_infallible<F>(&self, key: &K, thunk: F) -> Result<Arc<V>, Error<()>>
where
F: std::future::Future<Output = V>,
{
self.get_or_insert::<_, ()>(key, async { Ok(thunk.await) })
.await
}
pub async fn get_or_insert<F, E>(&self, key: &K, thunk: F) -> Result<Arc<V>, Error<E>>
where
F: std::future::Future<Output = Result<V, E>>,
{
let memory_expiration = Utc::now() + self.memory_duration;
let mut key_serializer = FlexbufferSerializer::new();
key.serialize(&mut key_serializer).unwrap(); let key_bin = key_serializer.take_buffer();
match self.get_at(key, &key_bin, memory_expiration) {
Ok(Some(found)) => return Ok(found),
Ok(None) => {}
Err(Error::Database(err)) => return Err(Error::Database(err)),
Err(e) => panic!("We shouldn't have any other error here {:?}", e),
}
let data = thunk.await.map_err(Error::Client)?;
let result = Arc::new(data);
self.store_in_memory_cache(key, &result, memory_expiration);
let disk_expiration = Utc::now() + self.disk_duration;
self.store_in_disk_cache(&key_bin, &result, disk_expiration)
.map_err(Error::Database)?;
Ok(result)
}
pub fn get(&self, key: &K) -> Result<Option<Arc<V>>, Error<()>> {
let mut key_serializer = FlexbufferSerializer::new();
key.serialize(&mut key_serializer).unwrap(); let key_bin = key_serializer.take_buffer();
self.get_at(key, &key_bin, Utc::now() + self.memory_duration)
}
fn get_at(
&self,
key: &K,
key_bin: &[u8],
memory_expiration: DateTime<Utc>,
) -> Result<Option<Arc<V>>, Error<()>> {
{
let read_lock = self.in_memory.read().unwrap();
if let Some(found) = read_lock.get(key) {
found
.expiration
.store(memory_expiration.timestamp() as u64, Ordering::Relaxed);
return Ok(Some(found.value.clone()));
}
}
debug!(target: "disk-cache", "Value not found in memory");
{
if let Some(value_bin) = self.content.get(&key_bin).map_err(Error::Database)? {
debug!(target: "disk-cache", "Value was in disk cache");
let reader = Reader::get_root(&value_bin).unwrap();
if let Ok(value) = V::deserialize(reader) {
debug!(target: "disk-cache", "Value deserialized");
let result = Arc::new(value);
self.store_in_memory_cache(key, &result, memory_expiration);
return Ok(Some(result));
}
}
}
debug!(target: "disk-cache", "Value not found on disk");
Ok(None)
}
fn store_in_memory_cache(&self, key: &K, value: &Arc<V>, expiration: DateTime<Utc>) {
debug!(target: "disk-cache", "Adding value to memory cache");
let mut write_lock = self.in_memory.write().unwrap();
let entry = CacheEntry {
value: value.clone(),
expiration: AtomicU64::new(expiration.timestamp() as u64),
};
write_lock.insert(key.clone(), entry);
}
fn store_in_disk_cache(
&self,
key: &[u8],
value: &Arc<V>,
expiration: DateTime<Utc>,
) -> Result<(), sled::Error> {
debug!(target: "disk-", "Adding value to disk cache");
let mut value_serializer = FlexbufferSerializer::new();
value.serialize(&mut value_serializer).unwrap();
let entry_bin = value_serializer.take_buffer();
self.content.insert(key, entry_bin)?;
self.expiry
.insert(u64_to_bytes(expiration.timestamp() as u64), key)?;
Ok(())
}
pub fn cleanup_expired_from_memory_cache(&self) {
cleanup_memory_cache(&self.in_memory)
}
pub fn cleanup_expired_disk_cache(&self) {
cleanup_disk_cache::<K, V>(&self.expiry, &self.content)
}
}
/// Evicts every in-memory entry whose expiration is not strictly after the current time.
///
/// Holds the map's write lock for the whole sweep.
///
/// # Panics
///
/// Panics if the lock has been poisoned.
pub fn cleanup_memory_cache<K, V>(memory_cache: &Arc<RwLock<HashMap<K, CacheEntry<V>>>>)
where
    K: Eq + Hash + Clone,
{
    let cutoff = Utc::now().timestamp() as u64;
    let mut entries = memory_cache.write().unwrap();
    entries.retain(|_, entry| cutoff < entry.expiration.load(Ordering::Relaxed));
}
/// Removes every disk-cache entry whose expiration timestamp lies strictly before now.
///
/// `expiry` is an index whose keys start with a big-endian `u64` Unix-second timestamp and
/// whose values are the corresponding `content` keys. Each expired `content` entry is
/// removed together with its index record. This keeps the index from growing without bound,
/// and it stops a stale record from deleting a key that is cached again later.
///
/// The type parameters are unused. They are kept so existing `cleanup_disk_cache::<K, V>`
/// call sites keep compiling.
///
/// # Panics
///
/// Panics if `sled` reports an error while scanning or updating either tree.
pub fn cleanup_disk_cache<K, V>(expiry: &sled::Tree, content: &sled::Tree)
where
    K: Send + Clone + Hash + Eq + for<'de> serde::Deserialize<'de> + serde::Serialize,
{
    let now = Utc::now().timestamp() as u64;
    let mut content_batch = sled::Batch::default();
    let mut expiry_batch = sled::Batch::default();
    // The upper bound is exclusive, so only records with a timestamp prefix < now match.
    for record in expiry.range(u64_to_bytes(0)..u64_to_bytes(now)) {
        let (expiry_key, content_key) = record.expect("failed to scan disk cache expiry index");
        debug_assert!(bytes_to_u64(&expiry_key) < now);
        content_batch.remove(content_key);
        expiry_batch.remove(expiry_key);
    }
    // Remove the content first. If the process stops between the two batches, the leftover
    // index records are processed again on the next run.
    content
        .apply_batch(content_batch)
        .expect("failed to remove expired disk cache entries");
    expiry
        .apply_batch(expiry_batch)
        .expect("failed to prune disk cache expiry index");
}
/// Decodes the big-endian `u64` stored in the first eight bytes of `bytes`.
///
/// Bytes after the eighth are ignored, so this also reads the timestamp prefix of a longer
/// index key.
///
/// # Panics
///
/// Panics if `bytes` is shorter than eight bytes.
fn bytes_to_u64(bytes: &[u8]) -> u64 {
    let mut prefix = [0u8; 8];
    prefix.copy_from_slice(&bytes[..8]);
    u64::from_be_bytes(prefix)
}
/// Encodes `value` as eight big-endian bytes.
///
/// Big-endian order makes byte-wise (lexicographic) comparison agree with numeric order.
/// This is what lets expiration timestamps be range-scanned in the `sled` index.
fn u64_to_bytes(value: u64) -> [u8; 8] {
    value.to_be_bytes()
}
#[test]
fn test_bytes_to_u64() {
    // Sample the u64 range sparsely: 0, 7, 56, 399, ... (next = (prev + 1) * 7), stopping once
    // the sequence passes u64::MAX. u128 arithmetic keeps the final steps from overflowing.
    let samples = std::iter::successors(Some(0u128), |&prev| Some((prev + 1) * 7))
        .take_while(|&candidate| candidate <= u64::MAX as u128);
    for sample in samples {
        let expected = sample as u64;
        let round_tripped = bytes_to_u64(&u64_to_bytes(expected));
        assert_eq!(round_tripped, expected);
    }
}