use super::cache_trait::Cache;
use async_trait::async_trait;
use reinhardt_core::exception::Result;
/// A unit of cache-warming work that can be run against a shared cache handle.
///
/// Implementors must be `Send + Sync` (supertrait bounds), and the cache is
/// always handed over as an `Arc<C>` so a warmer can keep or share the handle.
#[async_trait]
pub trait CacheWarmer<C>: Send + Sync
where
    C: Cache,
{
    /// Runs this warmer against `cache`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports; this trait itself
    /// places no further constraints on when warming fails.
    async fn warm(&self, cache: std::sync::Arc<C>) -> Result<()>;
}
/// Adapter that wraps a closure producing a boxed, `Send` future.
///
/// The closure receives the cache as an `Arc<C>` and returns a pinned, boxed
/// future that resolves to `Result<()>`.
pub struct FunctionWarmer<C, F>
where
    C: Cache,
    F: Send
        + Sync
        + Fn(
            std::sync::Arc<C>,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>,
{
    /// The wrapped warming closure.
    func: F,
    /// No field stores a `C`; this marker only ties the cache type to the struct.
    _phantom: std::marker::PhantomData<C>,
}
impl<C, F> FunctionWarmer<C, F>
where
    C: Cache,
    F: Send
        + Sync
        + Fn(
            std::sync::Arc<C>,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>,
{
    /// Wraps `func` so it can be stored and invoked later.
    ///
    /// The closure is not called here; it is only moved into the new value.
    pub fn new(func: F) -> Self {
        let _phantom = std::marker::PhantomData;
        Self { func, _phantom }
    }
}
/// Lets a wrapped closure be used wherever a `CacheWarmer` is expected.
#[async_trait]
impl<C, F> CacheWarmer<C> for FunctionWarmer<C, F>
where
    C: Cache,
    F: Send
        + Sync
        + Fn(
            std::sync::Arc<C>,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>,
{
    /// Calls the wrapped closure with `cache` and awaits the future it returns.
    ///
    /// # Errors
    ///
    /// Propagates the `Err` produced by the closure's future unchanged.
    async fn warm(&self, cache: std::sync::Arc<C>) -> Result<()> {
        let warming = (self.func)(cache);
        warming.await
    }
}
/// An ordered collection of boxed warmers that share one cache type.
pub struct BatchWarmer<C>
where
    C: Cache,
{
    /// Registered warmers, kept in the order they were added.
    warmers: Vec<Box<dyn CacheWarmer<C>>>,
}
impl<C> BatchWarmer<C>
where
    C: Cache,
{
    /// Creates a batch with no warmers registered.
    pub fn new() -> Self {
        let warmers = Vec::new();
        Self { warmers }
    }

    /// Appends `warmer` to the end of the batch and returns the batch,
    /// allowing calls to be chained builder-style.
    pub fn with_warmer(self, warmer: Box<dyn CacheWarmer<C>>) -> Self {
        let mut batch = self;
        batch.warmers.push(warmer);
        batch
    }
}
/// Hand-written so that no `C: Default` bound is required (a derive would add one).
impl<C> Default for BatchWarmer<C>
where
    C: Cache,
{
    /// Returns an empty batch; identical to calling `new()`.
    fn default() -> Self {
        Self::new()
    }
}
/// Runs the registered warmers one after another.
#[async_trait]
impl<C> CacheWarmer<C> for BatchWarmer<C>
where
    C: Cache,
{
    /// Awaits each registered warmer in insertion order, handing each one its
    /// own clone of the `Arc` handle.
    ///
    /// # Errors
    ///
    /// Stops at the first warmer that fails and returns its error; warmers
    /// registered after it are not run.
    async fn warm(&self, cache: std::sync::Arc<C>) -> Result<()> {
        for step in self.warmers.iter() {
            let handle = std::sync::Arc::clone(&cache);
            step.warm(handle).await?;
        }
        Ok(())
    }
}
/// A collection of boxed warmers whose futures are awaited together.
pub struct ParallelWarmer<C>
where
    C: Cache,
{
    /// Registered warmers, kept in the order they were added.
    warmers: Vec<Box<dyn CacheWarmer<C>>>,
}
impl<C> ParallelWarmer<C>
where
    C: Cache,
{
    /// Creates a group with no warmers registered.
    pub fn new() -> Self {
        let warmers = Vec::new();
        Self { warmers }
    }

    /// Adds `warmer` to the group and returns the group, allowing calls to be
    /// chained builder-style.
    pub fn with_warmer(self, warmer: Box<dyn CacheWarmer<C>>) -> Self {
        let mut group = self;
        group.warmers.push(warmer);
        group
    }
}
/// Hand-written so that no `C: Default` bound is required (a derive would add one).
impl<C> Default for ParallelWarmer<C>
where
    C: Cache,
{
    /// Returns an empty group; identical to calling `new()`.
    fn default() -> Self {
        Self::new()
    }
}
/// Drives all registered warmers through a single `join_all`.
#[async_trait]
impl<C> CacheWarmer<C> for ParallelWarmer<C>
where
    C: Cache,
{
    /// Builds one future per registered warmer (each with its own clone of the
    /// `Arc` handle) and awaits them together via `futures::future::join_all`.
    ///
    /// The futures are polled within the calling task; nothing is spawned here.
    ///
    /// # Errors
    ///
    /// Every warmer is driven to completion even if some fail. Afterwards the
    /// results are inspected in registration order and the first `Err` found
    /// is returned; any later errors are discarded.
    async fn warm(&self, cache: std::sync::Arc<C>) -> Result<()> {
        let shared = &cache;
        let pending = self.warmers.iter().map(|warmer| {
            let handle = std::sync::Arc::clone(shared);
            async move { warmer.warm(handle).await }
        });

        // `join_all` yields results in input order; collecting `Result<()>` items
        // into `Result<()>` stops at, and returns, the first error.
        futures::future::join_all(pending)
            .await
            .into_iter()
            .collect::<Result<()>>()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::InMemoryCache;
    use std::sync::Arc;

    /// Minimal warmer that stores a single key/value pair in the cache.
    struct TestWarmer {
        key: String,
        value: String,
    }

    impl TestWarmer {
        /// Builds a warmer that will write `value` under `key`.
        fn new(key: &str, value: &str) -> Self {
            Self {
                key: key.to_string(),
                value: value.to_string(),
            }
        }
    }

    #[async_trait]
    impl CacheWarmer<InMemoryCache> for TestWarmer {
        async fn warm(&self, cache: Arc<InMemoryCache>) -> Result<()> {
            cache.set(&self.key, &self.value, None).await
        }
    }

    /// A hand-written warmer writes its entry into the cache.
    #[tokio::test]
    async fn test_basic_warmer() {
        let store = Arc::new(InMemoryCache::new());
        let single = TestWarmer::new("test_key", "test_value");

        single.warm(Arc::clone(&store)).await.unwrap();

        let found: Option<String> = store.get("test_key").await.unwrap();
        assert_eq!(found.as_deref(), Some("test_value"));
    }

    /// A closure wrapped in `FunctionWarmer` is invoked and awaited by `warm`.
    #[tokio::test]
    async fn test_function_warmer() {
        let store = Arc::new(InMemoryCache::new());
        let from_closure = FunctionWarmer::new(|cache: Arc<InMemoryCache>| {
            Box::pin(async move {
                cache.set("func_key", &"func_value", None).await?;
                Ok(())
            })
        });

        from_closure.warm(Arc::clone(&store)).await.unwrap();

        let found: Option<String> = store.get("func_key").await.unwrap();
        assert_eq!(found.as_deref(), Some("func_value"));
    }

    /// Every warmer registered on a `BatchWarmer` gets to write its entry.
    #[tokio::test]
    async fn test_batch_warmer() {
        let store = Arc::new(InMemoryCache::new());
        let sequence = BatchWarmer::new()
            .with_warmer(Box::new(TestWarmer::new("key1", "value1")))
            .with_warmer(Box::new(TestWarmer::new("key2", "value2")));

        sequence.warm(Arc::clone(&store)).await.unwrap();

        let first: Option<String> = store.get("key1").await.unwrap();
        let second: Option<String> = store.get("key2").await.unwrap();
        assert_eq!(first.as_deref(), Some("value1"));
        assert_eq!(second.as_deref(), Some("value2"));
    }

    /// Every warmer registered on a `ParallelWarmer` gets to write its entry.
    #[tokio::test]
    async fn test_parallel_warmer() {
        let store = Arc::new(InMemoryCache::new());
        let group = ParallelWarmer::new()
            .with_warmer(Box::new(TestWarmer::new("parallel1", "value1")))
            .with_warmer(Box::new(TestWarmer::new("parallel2", "value2")));

        group.warm(Arc::clone(&store)).await.unwrap();

        let first: Option<String> = store.get("parallel1").await.unwrap();
        let second: Option<String> = store.get("parallel2").await.unwrap();
        assert_eq!(first.as_deref(), Some("value1"));
        assert_eq!(second.as_deref(), Some("value2"));
    }
}