use std::time::Duration;
use async_trait::async_trait;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use hitbox::{BackendLabel, CacheKey, CacheValue, Raw};
use hitbox_backend::{
Backend, BackendError, BackendResult, CacheKeyFormat, Compressor, DeleteStatus,
PassthroughCompressor,
format::{BincodeFormat, Format},
};
use redis::Client;
use redis::aio::ConnectionManager;
#[cfg(feature = "cluster")]
use redis::cluster_async::ClusterConnection;
use tokio::sync::OnceCell;
use crate::error::Error;
/// Connection settings for a single (non-clustered) Redis server.
#[derive(Debug, Clone)]
pub struct SingleConfig {
    /// Connection URL; parsed into a `redis::ConnectionInfo` when the backend connects.
    pub(crate) url: String,
    /// Value handed to `ConnectionManagerConfig::set_exponent_base` when connecting.
    pub(crate) exponent_base: f32,
}

impl SingleConfig {
    /// Exponent base used unless the caller overrides it.
    const DEFAULT_EXPONENT_BASE: f32 = 2.0;

    /// Creates settings for the server at `url`, with an exponent base of `2.0`.
    pub fn new(url: impl Into<String>) -> Self {
        let url: String = url.into();
        Self {
            url,
            exponent_base: Self::DEFAULT_EXPONENT_BASE,
        }
    }
}
/// Connection settings for a Redis Cluster deployment.
#[cfg(feature = "cluster")]
#[cfg_attr(docsrs, doc(cfg(feature = "cluster")))]
#[derive(Debug, Clone)]
pub struct ClusterConfig {
    /// Seed node addresses passed to `ClusterClientBuilder::new` when connecting.
    pub(crate) nodes: Vec<String>,
    /// When `true`, the cluster client builder is told to read from replicas.
    pub(crate) read_from_replicas: bool,
}

#[cfg(feature = "cluster")]
impl ClusterConfig {
    /// Creates settings from the given seed nodes; replica reads start disabled.
    pub fn new<I, S>(nodes: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let seeds: Vec<String> = nodes.into_iter().map(|node| node.into()).collect();
        Self {
            read_from_replicas: false,
            nodes: seeds,
        }
    }
}
/// Selects the Redis topology the backend connects to.
#[derive(Debug, Clone)]
pub enum ConnectionMode {
    /// One Redis server, reached through a `redis::aio::ConnectionManager`.
    Single(SingleConfig),
    /// A Redis Cluster, reached through a `redis::cluster_async::ClusterConnection`.
    #[cfg(feature = "cluster")]
    #[cfg_attr(docsrs, doc(cfg(feature = "cluster")))]
    Cluster(ClusterConfig),
}
impl ConnectionMode {
    /// Single-server mode for `url`, using the defaults of `SingleConfig::new`.
    pub fn single(url: impl Into<String>) -> Self {
        let config = SingleConfig::new(url);
        Self::Single(config)
    }

    /// Cluster mode seeded with the given node addresses.
    #[cfg(feature = "cluster")]
    #[cfg_attr(docsrs, doc(cfg(feature = "cluster")))]
    pub fn cluster<I, S>(nodes: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let config = ClusterConfig::new(nodes);
        Self::Cluster(config)
    }

    /// Sets the exponent base used by single-server mode; has no effect in cluster mode.
    pub fn exponent_base(mut self, base: f32) -> Self {
        // An exhaustive match is valid with or without the `cluster` variant compiled in,
        // so no `irrefutable_let_patterns` allowance is required.
        match &mut self {
            Self::Single(config) => config.exponent_base = base,
            #[cfg(feature = "cluster")]
            Self::Cluster(_) => {}
        }
        self
    }

    /// Enables reading from replica nodes; has no effect in single-server mode.
    #[cfg(feature = "cluster")]
    #[cfg_attr(docsrs, doc(cfg(feature = "cluster")))]
    pub fn read_from_replicas(mut self) -> Self {
        if let Self::Cluster(config) = &mut self {
            config.read_from_replicas = true;
        }
        self
    }
}
/// Established connection held by the backend; each operation works on a clone of it.
#[derive(Clone)]
enum RedisConnection {
    /// Connection manager for a single server.
    Single(ConnectionManager),
    /// Async connection to a Redis Cluster.
    #[cfg(feature = "cluster")]
    Cluster(ClusterConnection),
}
impl RedisConnection {
async fn query_pipeline<T: redis::FromRedisValue>(
&mut self,
pipe: &redis::Pipeline,
) -> Result<T, redis::RedisError> {
match self {
Self::Single(conn) => pipe.query_async(conn).await,
#[cfg(feature = "cluster")]
Self::Cluster(conn) => pipe.query_async(conn).await,
}
}
async fn query_cmd<T: redis::FromRedisValue>(
&mut self,
cmd: &mut redis::Cmd,
) -> Result<T, redis::RedisError> {
match self {
Self::Single(conn) => cmd.query_async(conn).await,
#[cfg(feature = "cluster")]
Self::Cluster(conn) => cmd.query_async(conn).await,
}
}
}
/// Cache backend that stores entries in Redis.
///
/// The connection is opened lazily on first use rather than at construction.
/// NOTE(review): `tokio::sync::OnceCell` clones its current contents, so clones made before
/// the first request each open their own connection. Confirm this is intended.
#[derive(Clone)]
pub struct RedisBackend<S = BincodeFormat, C = PassthroughCompressor>
where
    S: Format,
    C: Compressor,
{
    /// Topology and per-mode settings used when connecting.
    mode: ConnectionMode,
    // Optional connection tuning; each is applied only when `Some`.
    connection_timeout: Option<Duration>,
    response_timeout: Option<Duration>,
    number_of_retries: Option<usize>,
    // Optional credentials; in single mode they override any credentials in the URL.
    username: Option<String>,
    password: Option<String>,
    /// Lazily initialized connection shared by all operations on this instance.
    connection: OnceCell<RedisConnection>,
    /// Value format exposed through `Backend::value_format`.
    serializer: S,
    /// Turns a `CacheKey` into the Redis key.
    key_format: CacheKeyFormat,
    /// Compressor exposed through `Backend::compressor`.
    compressor: C,
    /// Label returned by `Backend::label`.
    label: BackendLabel,
}
impl RedisBackend<BincodeFormat, PassthroughCompressor> {
#[must_use]
pub fn builder() -> RedisBackendBuilder<BincodeFormat, PassthroughCompressor> {
RedisBackendBuilder::default()
}
}
impl<S, C> RedisBackend<S, C>
where
S: Format,
C: Compressor,
{
async fn get_connection(&self) -> Result<&RedisConnection, Error> {
self.connection
.get_or_try_init(|| async {
match &self.mode {
ConnectionMode::Single(config) => {
let mut conn_info: redis::ConnectionInfo = config.url.as_str().parse()?;
let mut redis_info = conn_info.redis_settings().clone();
if let Some(ref username) = self.username {
redis_info = redis_info.set_username(username);
}
if let Some(ref password) = self.password {
redis_info = redis_info.set_password(password);
}
conn_info = conn_info.set_redis_settings(redis_info);
let client = Client::open(conn_info)?;
let mut manager_config = redis::aio::ConnectionManagerConfig::new()
.set_exponent_base(config.exponent_base);
if let Some(timeout) = self.connection_timeout {
manager_config = manager_config.set_connection_timeout(Some(timeout));
}
if let Some(timeout) = self.response_timeout {
manager_config = manager_config.set_response_timeout(Some(timeout));
}
if let Some(retries) = self.number_of_retries {
manager_config = manager_config.set_number_of_retries(retries);
}
let conn = client
.get_connection_manager_with_config(manager_config)
.await?;
Ok(RedisConnection::Single(conn))
}
#[cfg(feature = "cluster")]
ConnectionMode::Cluster(config) => {
let mut builder = redis::cluster::ClusterClientBuilder::new(
config.nodes.iter().map(|s| s.as_str()),
);
if config.read_from_replicas {
builder = builder.read_from_replicas();
}
if let Some(ref username) = self.username {
builder = builder.username(username.clone());
}
if let Some(ref password) = self.password {
builder = builder.password(password.clone());
}
if let Some(timeout) = self.connection_timeout {
builder = builder.connection_timeout(timeout);
}
if let Some(timeout) = self.response_timeout {
builder = builder.response_timeout(timeout);
}
if let Some(retries) = self.number_of_retries {
builder = builder.retries(retries as u32);
}
let client = builder.build()?;
let conn = client.get_async_connection().await?;
Ok(RedisConnection::Cluster(conn))
}
}
})
.await
}
}
/// Builder for `RedisBackend`; obtain one with `RedisBackend::builder()`.
///
/// `S` is the value format and `C` the compressor. Both can be replaced with
/// `value_format` / `compressor`, which change the builder's type parameters.
pub struct RedisBackendBuilder<S = BincodeFormat, C = PassthroughCompressor>
where
    S: Format,
    C: Compressor,
{
    /// Connection topology; `build` fails with `Error::MissingConnectionMode` if unset.
    mode: Option<ConnectionMode>,
    serializer: S,
    key_format: CacheKeyFormat,
    compressor: C,
    label: BackendLabel,
    // Optional connection tuning, copied into the backend by `build`.
    connection_timeout: Option<Duration>,
    response_timeout: Option<Duration>,
    number_of_retries: Option<usize>,
    // Optional credentials, copied into the backend by `build`.
    username: Option<String>,
    password: Option<String>,
}
impl Default for RedisBackendBuilder<BincodeFormat, PassthroughCompressor> {
    /// Bincode values, passthrough compression, the default key format and the `"redis"`
    /// label. No credentials and no connection overrides are set.
    /// A connection mode still has to be provided before `build`.
    fn default() -> Self {
        Self {
            label: BackendLabel::new_static("redis"),
            serializer: BincodeFormat,
            compressor: PassthroughCompressor,
            key_format: CacheKeyFormat::default(),
            mode: None,
            username: None,
            password: None,
            connection_timeout: None,
            response_timeout: None,
            number_of_retries: None,
        }
    }
}
impl<S, C> RedisBackendBuilder<S, C>
where
S: Format,
C: Compressor,
{
pub fn connection(mut self, mode: ConnectionMode) -> Self {
self.mode = Some(mode);
self
}
pub fn connection_timeout(mut self, timeout: Duration) -> Self {
self.connection_timeout = Some(timeout);
self
}
pub fn response_timeout(mut self, timeout: Duration) -> Self {
self.response_timeout = Some(timeout);
self
}
pub fn retries(mut self, count: usize) -> Self {
self.number_of_retries = Some(count);
self
}
pub fn username(mut self, username: impl Into<String>) -> Self {
self.username = Some(username.into());
self
}
pub fn password(mut self, password: impl Into<String>) -> Self {
self.password = Some(password.into());
self
}
pub fn value_format<NewS>(self, serializer: NewS) -> RedisBackendBuilder<NewS, C>
where
NewS: Format,
{
RedisBackendBuilder {
mode: self.mode,
serializer,
key_format: self.key_format,
compressor: self.compressor,
label: self.label,
connection_timeout: self.connection_timeout,
response_timeout: self.response_timeout,
number_of_retries: self.number_of_retries,
username: self.username,
password: self.password,
}
}
pub fn key_format(mut self, key_format: CacheKeyFormat) -> Self {
self.key_format = key_format;
self
}
pub fn label(mut self, label: impl Into<BackendLabel>) -> Self {
self.label = label.into();
self
}
pub fn compressor<NewC>(self, compressor: NewC) -> RedisBackendBuilder<S, NewC>
where
NewC: Compressor,
{
RedisBackendBuilder {
mode: self.mode,
serializer: self.serializer,
key_format: self.key_format,
compressor,
label: self.label,
connection_timeout: self.connection_timeout,
response_timeout: self.response_timeout,
number_of_retries: self.number_of_retries,
username: self.username,
password: self.password,
}
}
pub fn build(self) -> Result<RedisBackend<S, C>, Error> {
let mode = self.mode.ok_or(Error::MissingConnectionMode)?;
Ok(RedisBackend {
mode,
connection_timeout: self.connection_timeout,
response_timeout: self.response_timeout,
number_of_retries: self.number_of_retries,
username: self.username,
password: self.password,
connection: OnceCell::new(),
serializer: self.serializer,
key_format: self.key_format,
compressor: self.compressor,
label: self.label,
})
}
}
#[async_trait]
impl<S, C> Backend for RedisBackend<S, C>
where
    S: Format + Send + Sync,
    C: Compressor + Send + Sync,
{
    /// Reads the entry stored under `key`.
    ///
    /// Entries are Redis hashes. The payload is in field `d`, and an optional stale
    /// timestamp (Unix milliseconds) is in field `s`. The expiry is rebuilt from the key's
    /// remaining `PTTL`. Returns `Ok(None)` when the `d` field is absent.
    async fn read(&self, key: &CacheKey) -> BackendResult<Option<CacheValue<Raw>>> {
        let mut con = self.get_connection().await?.clone();
        let cache_key = self.key_format.serialize(key)?;
        // Fetch both fields and the remaining TTL in a single round trip.
        let ((data, stale_ms), pttl): ((Option<Vec<u8>>, Option<i64>), i64) = con
            .query_pipeline(
                redis::pipe()
                    .cmd("HMGET")
                    .arg(&cache_key)
                    .arg("d")
                    .arg("s")
                    .cmd("PTTL")
                    .arg(&cache_key),
            )
            .await
            .map_err(Error::from)?;
        let data = match data {
            Some(data) => Bytes::from(data),
            None => return Ok(None),
        };
        let stale = stale_ms.and_then(DateTime::from_timestamp_millis);
        // Only a positive PTTL is a real deadline. Redis reports negative values for keys
        // without an expiry or keys that do not exist.
        let expire = (pttl > 0).then(|| Utc::now() + chrono::Duration::milliseconds(pttl));
        Ok(Some(CacheValue::new(data, expire, stale)))
    }

    /// Stores `value` under `key`, replacing whatever entry was there before.
    ///
    /// `HSET` only writes the fields it is given and leaves the key's expiry alone. Without
    /// cleanup, a previous entry's `s` field and TTL would survive an overwrite and be
    /// returned by `read`. To prevent that, the `s` field is deleted when `value` has no
    /// stale timestamp, and the expiry is cleared with `PERSIST` when `value` has no TTL.
    ///
    /// # Errors
    /// Fails if the TTL in milliseconds does not fit in a `u64`, or on any Redis error.
    async fn write(&self, key: &CacheKey, value: CacheValue<Raw>) -> BackendResult<()> {
        let mut con = self.get_connection().await?.clone();
        let cache_key = self.key_format.serialize(key)?;
        let stale_ms = value.stale().map(|stale| stale.timestamp_millis());

        let mut hset = redis::cmd("HSET");
        hset.arg(&cache_key).arg("d").arg(value.data().as_ref());
        if let Some(stale_ms) = stale_ms {
            hset.arg("s").arg(stale_ms);
        }

        let mut pipe = redis::pipe();
        pipe.add_command(hset).ignore();
        if stale_ms.is_none() {
            // Remove a stale marker left behind by an earlier value for this key.
            pipe.cmd("HDEL").arg(&cache_key).arg("s").ignore();
        }
        match value.ttl() {
            Some(ttl_duration) => {
                let ttl_ms = u64::try_from(ttl_duration.as_millis()).map_err(|_| {
                    BackendError::InternalError(Box::new(std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        format!(
                            "TTL overflow: {}ms exceeds u64 range",
                            ttl_duration.as_millis()
                        ),
                    )))
                })?;
                pipe.cmd("PEXPIRE").arg(&cache_key).arg(ttl_ms).ignore();
            }
            None => {
                // An expiry set by an earlier write would otherwise still apply.
                pipe.cmd("PERSIST").arg(&cache_key).ignore();
            }
        }
        con.query_pipeline::<()>(&pipe).await.map_err(Error::from)?;
        Ok(())
    }

    /// Deletes the entry for `key`, reporting whether anything was removed.
    async fn remove(&self, key: &CacheKey) -> BackendResult<DeleteStatus> {
        let mut con = self.get_connection().await?.clone();
        let cache_key = self.key_format.serialize(key)?;
        let deleted: i64 = con
            .query_cmd(redis::cmd("DEL").arg(cache_key))
            .await
            .map_err(Error::from)?;
        // DEL replies with the number of keys removed. Use a checked conversion, not `as`.
        match u32::try_from(deleted) {
            Ok(count) if count > 0 => Ok(DeleteStatus::Deleted(count)),
            _ => Ok(DeleteStatus::Missing),
        }
    }

    /// Label configured on the builder (`"redis"` by default).
    fn label(&self) -> BackendLabel {
        self.label.clone()
    }

    /// Format used to encode cached values.
    fn value_format(&self) -> &dyn Format {
        &self.serializer
    }

    /// Format used to turn `CacheKey`s into Redis keys.
    fn key_format(&self) -> &CacheKeyFormat {
        &self.key_format
    }

    /// Compressor applied to cached values.
    fn compressor(&self) -> &dyn Compressor {
        &self.compressor
    }
}
/// Marks `RedisBackend` as a `CacheBackend`.
///
/// The impl body is empty, so it relies entirely on the trait's provided items.
impl<S: Format + Send + Sync, C: Compressor + Send + Sync> hitbox_backend::CacheBackend
    for RedisBackend<S, C>
{
}