use std::ops::Deref;
use std::sync::Arc;
use std::time::Duration;
use serde::de::DeserializeOwned;
use serde::Serialize;
use sha2::{Digest, Sha256};
mod backend;
mod config;
pub use self::backend::*;
pub use self::config::Config;
/// Byte-level storage backend used by [`HttpCache`].
///
/// `HttpCache` encodes values with `serde_json` before calling [`set`](Self::set) and decodes
/// whatever [`get`](Self::get) returns, so implementations only handle opaque byte buffers.
#[async_trait::async_trait]
pub trait HttpCacheStorage {
    /// Looks up the bytes stored under `key`, returning `None` when nothing is available.
    async fn get(&self, key: &HttpCacheKey) -> Option<Vec<u8>>;

    /// Stores `value` under `key`.
    async fn set(&self, key: HttpCacheKey, value: Vec<u8>);

    /// Configures expiration for stored entries.
    ///
    /// `HttpCache::new` calls this once, before the storage is shared. How `cache_ttl` and
    /// `cache_tti` are enforced is up to the implementation (presumably time-to-live and
    /// time-to-idle respectively).
    fn set_expiration_times(&mut self, cache_ttl: Duration, cache_tti: Duration);
}
/// JSON-encoding cache facade over a shared [`HttpCacheStorage`] backend.
///
/// `ttl` and `tti` record the expiration times given to `HttpCache::new`, which forwards
/// them to the storage once at construction. Nothing here forwards them again, so assigning
/// to these public fields afterwards presumably does not change how the storage expires
/// entries.
pub struct HttpCache {
    /// Expiration time handed to the storage as `cache_ttl`.
    pub ttl: Duration,
    /// Expiration time handed to the storage as `cache_tti`.
    pub tti: Duration,
    /// Backend chosen at runtime, shared behind an `Arc`.
    storage: Arc<Box<dyn HttpCacheStorage + Sync + Send>>,
}
impl Default for HttpCache {
fn default() -> Self {
Self::new(
Duration::from_secs(DEFAULT_TTL_SECS),
Duration::from_secs(DEFAULT_TTI_SECS),
None,
)
}
}
/// Largest serialized key, in bytes, that `HttpCache::calculate_key` will hash (10 MiB).
const MAX_PAYLOAD_SIZE: usize = 10 << 20;
/// Seconds used for the `ttl` expiration time when none is configured.
const DEFAULT_TTL_SECS: u64 = 60;
/// Seconds used for the `tti` expiration time when none is configured.
const DEFAULT_TTI_SECS: u64 = 60;
/// Opaque 32-byte cache key; `HttpCache::calculate_key` fills it with a double SHA-256
/// digest of the key's JSON encoding.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct HttpCacheKey([u8; 32]);
impl Deref for HttpCacheKey {
type Target = [u8; 32];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<config::Config> for HttpCache {
    /// Builds a cache from a [`config::Config`].
    ///
    /// Unset `ttl` / `tti` fall back to [`DEFAULT_TTL_SECS`] / [`DEFAULT_TTI_SECS`] seconds.
    /// `Backend::Memory` uses the storage `HttpCache::new` picks when given `None`;
    /// `Backend::Redis` (only with the `redis` feature) uses an `HttpCacheRedis` whose key
    /// prefix is `key_prefix`, or empty when unset.
    ///
    /// # Panics
    ///
    /// With the `redis` feature enabled, panics if `redis::Client::open` rejects the
    /// configured connection string.
    fn from(config: config::Config) -> Self {
        // Both backends resolve expiration times the same way.
        let ttl = Duration::from_secs(config.ttl.unwrap_or(DEFAULT_TTL_SECS));
        let tti = Duration::from_secs(config.tti.unwrap_or(DEFAULT_TTI_SECS));

        match config.backend {
            config::Backend::Memory => Self::new(ttl, tti, None),
            #[cfg(feature = "redis")]
            config::Backend::Redis(redis_config) => {
                let client = redis::Client::open(redis_config.connection_string)
                    .expect("Failed to create Redis client");
                let prefix = redis_config
                    .key_prefix
                    .unwrap_or_default()
                    .as_bytes()
                    .to_vec();
                let redis_storage = HttpCacheRedis::new(client).set_prefix(prefix);
                Self::new(ttl, tti, Some(Box::new(redis_storage)))
            }
        }
    }
}
impl HttpCache {
    /// Creates a cache whose storage is configured with the `ttl` and `tti` expiration times.
    ///
    /// When `storage` is `None`, a default-constructed [`InMemoryHttpCache`] is used. The
    /// expiration times are pushed into the storage exactly once, via
    /// [`HttpCacheStorage::set_expiration_times`], before it is wrapped in an `Arc`.
    pub fn new(
        ttl: Duration,
        tti: Duration,
        storage: Option<Box<dyn HttpCacheStorage + Send + Sync + 'static>>,
    ) -> Self {
        let mut storage = storage.unwrap_or_else(|| Box::<InMemoryHttpCache>::default());
        // Must happen before the `Arc` wrap: afterwards there is no `&mut` access to it.
        storage.set_expiration_times(ttl, tti);
        Self {
            ttl,
            tti,
            storage: Arc::new(storage),
        }
    }

    /// Derives a fixed-size cache key from any serializable value (sized or not).
    ///
    /// The value is serialized with `serde_json`. The key is `SHA-256(SHA-256(json))`.
    /// Returns `None` in two cases, logging a warning in each:
    /// - serialization fails;
    /// - the JSON is longer than [`MAX_PAYLOAD_SIZE`] bytes.
    ///
    /// NOTE: the hash covers the exact serialized bytes. Values whose serialization order
    /// can vary between instances (e.g. `std::collections::HashMap` fields) may therefore
    /// produce different keys for equal inputs.
    pub fn calculate_key<K>(&self, key: &K) -> Option<HttpCacheKey>
    where
        K: Serialize + ?Sized,
    {
        let json_value = match serde_json::to_vec(key) {
            Ok(value) => value,
            Err(err) => {
                tracing::warn!("Failed to serialize key: {:?}", err);
                return None;
            }
        };
        if json_value.len() > MAX_PAYLOAD_SIZE {
            tracing::warn!("Key size is too large: {}", json_value.len());
            return None;
        }
        let first_hash = Sha256::digest(json_value);
        let second_hash = Sha256::digest(first_hash);
        Some(HttpCacheKey(second_hash.into()))
    }

    /// Fetches the value stored under `key` and decodes it from JSON as `V`.
    ///
    /// Returns `None` in two cases:
    /// - the storage has nothing for `key`;
    /// - the stored bytes fail to deserialize as `V` (a warning is logged).
    ///
    /// Takes `&self` rather than `&Arc<Self>`, so callers that hold an `Arc<HttpCache>` still
    /// work through auto-deref, and callers without an `Arc` can use it too.
    pub async fn get<V>(&self, key: &HttpCacheKey) -> Option<V>
    where
        V: DeserializeOwned,
    {
        let bytes = self.storage.get(key).await?;
        match serde_json::from_slice(&bytes) {
            Ok(value) => Some(value),
            Err(e) => {
                tracing::warn!("Failed to deserialize value: {:?}", e);
                None
            }
        }
    }

    /// Encodes `value` as JSON and stores it under `key`.
    ///
    /// If serialization fails, a warning is logged and nothing is stored. This is a
    /// best-effort write, so no error is returned.
    pub async fn set<V: Serialize + ?Sized>(&self, key: HttpCacheKey, value: &V) {
        match serde_json::to_vec(value) {
            Ok(bytes) => self.storage.set(key, bytes).await,
            Err(e) => {
                tracing::warn!("Failed to serialize value: {:?}", e);
            }
        }
    }
}