use std::{
collections::HashMap,
hash::Hash,
ops::Deref,
sync::Arc,
time::{Duration, Instant},
};
use async_trait::async_trait;
use tokio::sync::RwLock;
use crate::Client;
/// A cached value together with the `Instant` at which it was stored.
type ValueWithLastUpdate<T> = (Arc<T>, Instant);
/// A keyed cache: each entry remembers when it was stored and is served from
/// memory only while it is younger than the configured cache duration.
pub(crate) struct CachedMap<T: MapCacheable> {
// Cached values keyed by `T::Key`, each paired with the time it was stored.
cached_value: RwLock<HashMap<T::Key, ValueWithLastUpdate<T>>>,
// Freshness window shared by all entries; changed at runtime via `set_cache_time`.
duration: RwLock<Duration>,
}
#[async_trait]
pub(crate) trait MapCacheable: Clone {
type Key: Eq + Hash + Clone;
type Error;
async fn fetch_uncached(
client: &Arc<Client>,
key: &Self::Key,
) -> Result<Option<Self>, Self::Error>;
}
impl<T: MapCacheable> CachedMap<T> {
    /// Creates an empty cache whose entries are served for `duration` after
    /// they were stored.
    pub(crate) fn new(duration: Duration) -> Self {
        Self {
            cached_value: RwLock::new(HashMap::new()),
            duration: RwLock::new(duration),
        }
    }

    /// Replaces the freshness window. Entries keep their original timestamps,
    /// so the new window applies to every entry on its next lookup.
    pub(crate) async fn set_cache_time(&self, duration: Duration) {
        *self.duration.write().await = duration;
    }

    /// Returns the value for `key`.
    ///
    /// A cached entry younger than the configured cache time is returned
    /// directly. Otherwise the value is fetched through [`Self::refresh`].
    /// Concurrent callers that all observe a missing or stale entry each
    /// perform their own fetch.
    ///
    /// # Errors
    /// Propagates any error returned by `T::fetch_uncached`.
    pub(crate) async fn get(
        &self,
        client: &Arc<Client>,
        key: &T::Key,
    ) -> Result<Option<Arc<T>>, T::Error> {
        // Read the window first so the map's read lock is never held while
        // waiting on the duration lock.
        let max_age = *self.duration.read().await;
        {
            let map = self.cached_value.read().await;
            if let Some((cached, last_update)) = map.get(key) {
                if last_update.elapsed() < max_age {
                    return Ok(Some(Arc::clone(cached)));
                }
            }
        }
        self.refresh(client, key).await
    }

    /// Evicts `key`, forcing the next [`Self::get`] to fetch it again.
    pub(crate) async fn make_dirty(&self, key: &T::Key) {
        self.cached_value.write().await.remove(key);
    }

    /// Fetches `key` bypassing the cache and records the outcome.
    ///
    /// - `Ok(Some(_))`: the entry is stored with the current time.
    /// - `Ok(None)`: any previously cached entry for `key` is evicted, so a
    ///   value that no longer exists is not served later from a still-fresh
    ///   entry.
    /// - `Err`: the cache is left untouched.
    ///
    /// # Errors
    /// Propagates any error returned by `T::fetch_uncached`.
    pub(crate) async fn refresh(
        &self,
        client: &Arc<Client>,
        key: &T::Key,
    ) -> Result<Option<Arc<T>>, T::Error> {
        let fetched = T::fetch_uncached(client, key).await?.map(Arc::new);
        let mut map = self.cached_value.write().await;
        match &fetched {
            Some(value) => {
                map.insert(key.clone(), (Arc::clone(value), Instant::now()));
            }
            None => {
                map.remove(key);
            }
        }
        Ok(fetched)
    }
}
#[async_trait]
pub(crate) trait BatchCacheable: MapCacheable {
async fn fetch_uncached_batch(
client: &Arc<Client>,
keys: &[Self::Key],
) -> Result<Vec<(Self::Key, Self)>, Self::Error>;
}
impl<T: BatchCacheable> CachedMap<T> {
    /// Returns the values for `keys`, keyed by their key.
    ///
    /// Fresh cached entries are served from memory. Every missing or stale key
    /// is fetched once, in a single batch, through [`Self::refresh_batch`].
    /// Keys for which no value exists are absent from the returned map.
    ///
    /// # Errors
    /// Propagates any error returned by `T::fetch_uncached_batch`.
    pub(crate) async fn get_batch(
        &self,
        client: &Arc<Client>,
        keys: &[T::Key],
    ) -> Result<HashMap<T::Key, Arc<T>>, T::Error> {
        // The window is loop-invariant: read it once, and before taking the
        // map lock, instead of once per key while the map lock is held.
        let max_age = *self.duration.read().await;
        let mut result = HashMap::new();
        let mut uncached_keys = Vec::new();
        {
            let map = self.cached_value.read().await;
            // Duplicate keys in `keys` are fetched only once.
            let mut queued = std::collections::HashSet::new();
            for key in keys {
                match map.get(key) {
                    Some((cached, last_update)) if last_update.elapsed() < max_age => {
                        result.insert(key.clone(), Arc::clone(cached));
                    }
                    _ => {
                        if queued.insert(key) {
                            uncached_keys.push(key.clone());
                        }
                    }
                }
            }
        }
        if !uncached_keys.is_empty() {
            result.extend(self.refresh_batch(client, &uncached_keys).await?);
        }
        Ok(result)
    }

    /// Evicts every key in `keys`, forcing them to be fetched again.
    pub(crate) async fn make_dirty_batch(&self, keys: &[T::Key]) {
        let mut map = self.cached_value.write().await;
        for key in keys {
            map.remove(key);
        }
    }

    /// Fetches `keys` bypassing the cache and records the outcome.
    ///
    /// Every returned pair is stored with the same timestamp. A requested key
    /// that the fetch did not return is evicted, so a value that no longer
    /// exists is not served later from a still-fresh entry. On `Err` the cache
    /// is left untouched.
    ///
    /// # Errors
    /// Propagates any error returned by `T::fetch_uncached_batch`.
    pub(crate) async fn refresh_batch(
        &self,
        client: &Arc<Client>,
        keys: &[T::Key],
    ) -> Result<HashMap<T::Key, Arc<T>>, T::Error> {
        let fetched: HashMap<T::Key, Arc<T>> = T::fetch_uncached_batch(client, keys)
            .await?
            .into_iter()
            .map(|(key, value)| (key, Arc::new(value)))
            .collect();
        let mut map = self.cached_value.write().await;
        for key in keys {
            if !fetched.contains_key(key) {
                map.remove(key);
            }
        }
        let now = Instant::now();
        map.extend(
            fetched
                .iter()
                .map(|(key, value)| (key.clone(), (Arc::clone(value), now))),
        );
        Ok(fetched)
    }
}
#[async_trait]
pub(crate) trait AllCacheable: MapCacheable {
async fn fetch_uncached_all(
client: &Arc<Client>,
) -> Result<Vec<(Self::Key, Self)>, Self::Error>;
}
impl<T: AllCacheable> CachedMap<T> {
    /// Refetches the entire key space and makes it the new cache contents.
    ///
    /// Entries for keys not returned by the fetch are discarded, because the
    /// old map is replaced wholesale. On `Err` the cache is left untouched.
    ///
    /// # Errors
    /// Propagates any error returned by `T::fetch_uncached_all`.
    pub(crate) async fn refresh_all(
        &self,
        client: &Arc<Client>,
    ) -> Result<HashMap<T::Key, Arc<T>>, T::Error> {
        let fetched = T::fetch_uncached_all(client).await?;
        let mut shared = HashMap::new();
        for (key, value) in fetched {
            shared.insert(key, Arc::new(value));
        }

        let mut guard = self.cached_value.write().await;
        let mut replacement = HashMap::new();
        for (key, value) in &shared {
            replacement.insert(key.clone(), (Arc::clone(value), Instant::now()));
        }
        *guard = replacement;
        Ok(shared)
    }

    /// Drops every cached entry, so each key is fetched again on next use.
    pub(crate) async fn make_dirty_all(&self) {
        *self.cached_value.write().await = HashMap::new();
    }
}
/// A single cached value that is served from memory only while it is younger
/// than the configured cache duration.
pub(crate) struct Cached<T: Cacheable> {
// The cached value and the time it was stored; `None` until first fetched or after `make_dirty`.
cached_value: RwLock<Option<ValueWithLastUpdate<T>>>,
// Freshness window; changed at runtime via `set_cache_time`.
duration: RwLock<Duration>,
}
/// A value with no key, cached as a whole in a [`Cached`].
#[async_trait]
pub(crate) trait Cacheable: Clone {
    /// Error produced when a fetch fails.
    type Error;

    /// Fetches the value straight from `client`, bypassing any cache.
    async fn fetch_uncached(
        client: &Arc<Client>,
    ) -> Result<Self, Self::Error>;
}
impl<T: Cacheable> Cached<T> {
    /// Creates an empty cache whose value is served for `duration` after it
    /// was stored.
    pub(crate) fn new(duration: Duration) -> Self {
        let cached_value = RwLock::new(None);
        let duration = RwLock::new(duration);
        Self {
            cached_value,
            duration,
        }
    }

    /// Replaces the freshness window used by subsequent lookups.
    pub(crate) async fn set_cache_time(&self, duration: Duration) {
        let mut window = self.duration.write().await;
        *window = duration;
    }

    /// Returns the cached value if it is younger than the cache time.
    /// Otherwise it fetches a new one through [`Self::refresh`].
    ///
    /// # Errors
    /// Propagates any error returned by `T::fetch_uncached`.
    pub(crate) async fn get(&self, client: &Arc<Client>) -> Result<Arc<T>, T::Error> {
        let guard = self.cached_value.read().await;
        let hit = match guard.deref() {
            Some((value, stored_at)) => {
                let age = stored_at.elapsed();
                if age < *self.duration.read().await {
                    Some(Arc::clone(value))
                } else {
                    None
                }
            }
            None => None,
        };
        drop(guard);

        match hit {
            Some(value) => Ok(value),
            None => self.refresh(client).await,
        }
    }

    /// Forgets the cached value, so the next [`Self::get`] fetches it again.
    pub(crate) async fn make_dirty(&self) {
        *self.cached_value.write().await = None;
    }

    /// Fetches the value bypassing the cache and stores it with the current
    /// time. On `Err` the cache is left untouched.
    ///
    /// # Errors
    /// Propagates any error returned by `T::fetch_uncached`.
    pub(crate) async fn refresh(&self, client: &Arc<Client>) -> Result<Arc<T>, T::Error> {
        let fetched = T::fetch_uncached(client).await?;
        let shared = Arc::new(fetched);
        let mut slot = self.cached_value.write().await;
        *slot = Some((Arc::clone(&shared), Instant::now()));
        Ok(shared)
    }
}