use std::sync::Arc;
use std::time::Duration;
use redis::aio::MultiplexedConnection;
use redis::cluster::{ClusterClient, ClusterConfig};
use redis::cluster_async::ClusterConnection;
use tokio_util::sync::CancellationToken;
use crate::error::{Result, ShoveError};
use crate::retry::Backoff;
use super::constants::{BLOCK_MS, DEFAULT_GROUP};
/// Which Redis deployment topology to connect to.
pub enum RedisMode {
    /// A single server addressed by one connection URL.
    Standalone { url: String },
    /// A cluster addressed by one or more seed-node URLs (must be non-empty when connecting).
    Cluster { urls: Vec<String> },
}
/// Default per-command response timeout used by [`RedisConfig::default`].
pub const DEFAULT_RESPONSE_TIMEOUT: Duration = Duration::from_secs(30);
/// Default connection-establishment timeout used by [`RedisConfig::default`].
pub const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
/// Settings consumed by `RedisClient::connect`.
///
/// Start from [`RedisConfig::new`] or [`Default`] and adjust with the `with_*`
/// builders. The timeout fields are crate-private, so code outside this crate
/// can only set them through the validating builders.
pub struct RedisConfig {
    /// Standalone URL or cluster seed URLs.
    pub mode: RedisMode,
    /// Explicit group name; `None` falls back to `DEFAULT_GROUP` (see `resolved_group`).
    pub group: Option<String>,
    /// Response timeout applied to every connection this config produces.
    pub(crate) response_timeout: Duration,
    /// Timeout for establishing each connection.
    pub(crate) connection_timeout: Duration,
}
impl Default for RedisConfig {
fn default() -> Self {
Self {
mode: RedisMode::Standalone {
url: "redis://127.0.0.1:6379/".to_string(),
},
group: None,
response_timeout: DEFAULT_RESPONSE_TIMEOUT,
connection_timeout: DEFAULT_CONNECTION_TIMEOUT,
}
}
}
impl RedisConfig {
    /// Creates a config for `mode`, with every other field taken from [`Default`].
    #[must_use]
    pub fn new(mode: RedisMode) -> Self {
        Self {
            mode,
            ..Self::default()
        }
    }

    /// Sets an explicit group name, overriding `DEFAULT_GROUP`.
    #[must_use]
    pub fn with_group(mut self, group: impl Into<String>) -> Self {
        self.group = Some(group.into());
        self
    }

    /// Sets the per-command response timeout.
    ///
    /// The timeout must be strictly longer than `BLOCK_MS`. Presumably this keeps
    /// blocking commands issued with that block duration from being cut off by
    /// the response timeout — NOTE(review): confirm against callers.
    ///
    /// # Panics
    ///
    /// Panics if `timeout <= BLOCK_MS` milliseconds.
    #[must_use]
    pub fn with_response_timeout(mut self, timeout: Duration) -> Self {
        assert!(
            timeout > Duration::from_millis(BLOCK_MS),
            "response_timeout ({} ms) must exceed BLOCK_MS ({} ms)",
            timeout.as_millis(),
            BLOCK_MS,
        );
        self.response_timeout = timeout;
        self
    }

    /// Sets the timeout for establishing a connection.
    ///
    /// # Panics
    ///
    /// Panics if `timeout` is zero.
    #[must_use]
    pub fn with_connection_timeout(mut self, timeout: Duration) -> Self {
        assert!(!timeout.is_zero(), "connection_timeout must be positive");
        self.connection_timeout = timeout;
        self
    }

    /// Returns the configured group, or `DEFAULT_GROUP` when none was set.
    ///
    /// An explicitly configured empty string is returned as-is, not replaced.
    pub fn resolved_group(&self) -> &str {
        self.group.as_deref().unwrap_or(DEFAULT_GROUP)
    }
}
/// A live connection to either a standalone server or a cluster.
///
/// `Clone` is derived. Cloning clones the wrapped redis-rs connection handle for
/// whichever variant is active.
#[derive(Clone)]
pub(crate) enum RedisConnection {
    /// Multiplexed connection to a single server.
    Standalone(MultiplexedConnection),
    /// Cluster-aware connection that routes commands across nodes.
    Cluster(ClusterConnection),
}
impl RedisConnection {
    /// Runs `cmd` on whichever connection is active and decodes the reply as `T`.
    ///
    /// # Errors
    ///
    /// Any redis-rs error (I/O, timeout, server error, or a failed conversion to
    /// `T`) is turned into `ShoveError::Connection` carrying its display text.
    pub(crate) async fn query<T: redis::FromRedisValue + Send>(
        &mut self,
        cmd: &mut redis::Cmd,
    ) -> Result<T> {
        let outcome: redis::RedisResult<T> = match self {
            RedisConnection::Standalone(standalone) => cmd.query_async(standalone).await,
            RedisConnection::Cluster(cluster) => cmd.query_async(cluster).await,
        };
        outcome.map_err(|err| ShoveError::Connection(err.to_string()))
    }
}
/// The underlying redis-rs client that new connections are opened from.
enum ClientInner {
    /// Client for a single server.
    Standalone(redis::Client),
    /// Client for a cluster.
    Cluster(ClusterClient),
}
/// Handle used to open connections to the configured Redis deployment.
///
/// Clones share the same underlying client through an `Arc`. The timeouts are
/// captured from the originating [`RedisConfig`] and applied to every
/// connection this handle opens.
#[derive(Clone)]
pub struct RedisClient {
    /// Shared redis-rs client (standalone or cluster).
    inner: Arc<ClientInner>,
    /// Resolved group name (see `RedisConfig::resolved_group`).
    pub(super) group: String,
    /// Response timeout applied to each opened connection.
    response_timeout: Duration,
    /// Connection-establishment timeout applied to each opened connection.
    connection_timeout: Duration,
}
impl RedisClient {
pub(super) async fn connect(config: RedisConfig) -> Result<Self> {
let group = config.resolved_group().to_owned();
let response_timeout = config.response_timeout;
let connection_timeout = config.connection_timeout;
let inner = match config.mode {
RedisMode::Standalone { url } => {
let client = redis::Client::open(url.as_str())
.map_err(|e| ShoveError::Connection(e.to_string()))?;
client
.get_multiplexed_async_connection_with_config(&async_config(
response_timeout,
connection_timeout,
))
.await
.map_err(|e| ShoveError::Connection(format!("standalone ping failed: {e}")))?;
ClientInner::Standalone(client)
}
RedisMode::Cluster { ref urls } => {
if urls.is_empty() {
return Err(ShoveError::Connection(
"cluster URLs must not be empty".into(),
));
}
let nodes: Vec<&str> = urls.iter().map(String::as_str).collect();
let client =
ClusterClient::new(nodes).map_err(|e| ShoveError::Connection(e.to_string()))?;
client
.get_async_connection_with_config(cluster_config(
response_timeout,
connection_timeout,
))
.await
.map_err(|e| ShoveError::Connection(format!("cluster ping failed: {e}")))?;
ClientInner::Cluster(client)
}
};
let client = Self {
inner: Arc::new(inner),
group,
response_timeout,
connection_timeout,
};
let mut conn = client.multiplexed_conn().await?;
check_min_version(&mut conn).await?;
Ok(client)
}
pub(super) async fn multiplexed_conn(&self) -> Result<RedisConnection> {
match self.inner.as_ref() {
ClientInner::Standalone(client) => client
.get_multiplexed_async_connection_with_config(&async_config(
self.response_timeout,
self.connection_timeout,
))
.await
.map(RedisConnection::Standalone)
.map_err(|e| ShoveError::Connection(e.to_string())),
ClientInner::Cluster(client) => client
.get_async_connection_with_config(cluster_config(
self.response_timeout,
self.connection_timeout,
))
.await
.map(RedisConnection::Cluster)
.map_err(|e| ShoveError::Connection(e.to_string())),
}
}
pub(super) async fn dedicated_conn(&self) -> Result<RedisConnection> {
self.multiplexed_conn().await
}
pub(super) async fn ping(&self, timeout: Duration) -> Result<()> {
let fut = async {
let mut conn = self.multiplexed_conn().await?;
let reply: String = conn
.query::<String>(&mut redis::cmd("PING"))
.await
.map_err(|e| ShoveError::Connection(format!("redis ping failed: {e}")))?;
if reply != "PONG" {
return Err(ShoveError::Connection(format!(
"redis ping returned {reply:?}, expected PONG"
)));
}
Ok::<(), ShoveError>(())
};
tokio::time::timeout(timeout, fut).await.map_err(|_| {
ShoveError::Connection(format!("redis ping timed out after {timeout:?}"))
})?
}
pub(super) fn group(&self) -> &str {
&self.group
}
}
/// Standalone connection settings with both timeouts set explicitly.
fn async_config(response: Duration, connection: Duration) -> redis::AsyncConnectionConfig {
    let response_timeout = Some(response);
    let connection_timeout = Some(connection);
    let cfg = redis::AsyncConnectionConfig::new().set_response_timeout(response_timeout);
    cfg.set_connection_timeout(connection_timeout)
}
/// Cluster connection settings with both timeouts set explicitly.
fn cluster_config(response: Duration, connection: Duration) -> ClusterConfig {
    let cfg = ClusterConfig::new().set_response_timeout(response);
    cfg.set_connection_timeout(connection)
}
/// Validates the `redis_version:` line of an `INFO server` payload.
///
/// The check compares only the major and minor components against 6.2. The
/// patch component (after the second `.`) is never parsed, so it may contain
/// anything.
///
/// # Errors
///
/// Returns `ShoveError::Connection` if:
/// - the line is missing;
/// - the major or minor component is absent or not a `u32`;
/// - the version is older than 6.2.
fn check_version_info(info: &str) -> Result<()> {
    let Some(raw) = info
        .lines()
        .find_map(|line| line.strip_prefix("redis_version:"))
    else {
        return Err(ShoveError::Connection("could not determine Redis version".into()));
    };
    let version = raw.trim();
    let unparseable = || ShoveError::Connection(format!("unparseable Redis version: {version}"));
    let mut components = version.splitn(3, '.').map(|part| part.parse::<u32>().ok());
    let major = components.next().flatten().ok_or_else(unparseable)?;
    let minor = components.next().flatten().ok_or_else(unparseable)?;
    let too_old = (major, minor) < (6, 2);
    if too_old {
        Err(ShoveError::Connection(format!(
            "Redis {version} is not supported; shove requires Redis 6.2 or newer"
        )))
    } else {
        Ok(())
    }
}
/// Runs `INFO server` on `conn` and checks every returned payload.
///
/// In cluster mode the reply may hold one payload per node, and each must pass
/// [`check_version_info`].
async fn check_min_version(conn: &mut RedisConnection) -> Result<()> {
    let mut info_cmd = redis::cmd("INFO");
    info_cmd.arg("server");
    let reply: redis::Value = conn
        .query(&mut info_cmd)
        .await
        .map_err(|e| ShoveError::Connection(format!("INFO server failed: {e}")))?;
    let payloads = info_payloads(reply)?;
    payloads.iter().try_for_each(|info| check_version_info(info))
}
/// Flattens an `INFO server` reply into the text payload(s) it contains.
///
/// A single string-like reply (standalone) yields one payload. A map reply, as
/// produced when the command fans out across a cluster, yields one payload per
/// entry; the keys are ignored.
///
/// # Errors
///
/// Returns `ShoveError::Connection` for:
/// - an empty map;
/// - any other reply shape;
/// - a value that is not a string;
/// - bulk bytes that are not valid UTF-8.
fn info_payloads(value: redis::Value) -> Result<Vec<String>> {
    /// Decodes one string-like reply value into owned text.
    fn payload_text(raw: redis::Value) -> Result<String> {
        match raw {
            redis::Value::SimpleString(text) | redis::Value::VerbatimString { text, .. } => {
                Ok(text)
            }
            redis::Value::BulkString(bytes) => String::from_utf8(bytes)
                .map_err(|e| ShoveError::Connection(format!("INFO server not UTF-8: {e}"))),
            unexpected => Err(ShoveError::Connection(format!(
                "unexpected INFO payload: {unexpected:?}"
            ))),
        }
    }

    match value {
        redis::Value::Map(entries) if entries.is_empty() => {
            Err(ShoveError::Connection("INFO server: empty map".into()))
        }
        redis::Value::Map(entries) => entries
            .into_iter()
            .map(|(_node, payload)| payload_text(payload))
            .collect(),
        single @ (redis::Value::SimpleString(_)
        | redis::Value::BulkString(_)
        | redis::Value::VerbatimString { .. }) => payload_text(single).map(|text| vec![text]),
        other => Err(ShoveError::Connection(format!(
            "unexpected INFO server reply: {other:?}"
        ))),
    }
}
/// Keeps trying to open a connection until one succeeds or `shutdown` fires.
///
/// Between failed attempts it waits for the next `Backoff` delay and logs a
/// warning tagged with `task`. Every attempt and every wait is raced against
/// `shutdown`, and shutdown takes priority. As a result, a slow or hanging
/// connect cannot hold up shutdown, and an already-cancelled token returns
/// immediately.
///
/// Returns `Some(conn)` on success, or `None` once shutdown is requested.
pub(super) async fn acquire_conn_with_retry(
    client: &RedisClient,
    shutdown: &CancellationToken,
    task: &str,
) -> Option<RedisConnection> {
    let mut backoff = Backoff::default();
    loop {
        let attempt = tokio::select! {
            // Check shutdown first so cancellation wins any tie with a ready connect.
            biased;
            _ = shutdown.cancelled() => return None,
            res = client.multiplexed_conn() => res,
        };
        match attempt {
            Ok(c) => return Some(c),
            Err(e) => {
                let delay = backoff.next().expect("backoff is infinite");
                tracing::warn!(
                    "{task}: connection failed ({e}), retrying in {:.1}s",
                    delay.as_secs_f64()
                );
                tokio::select! {
                    biased;
                    _ = shutdown.cancelled() => return None,
                    _ = tokio::time::sleep(delay) => {}
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    // Offline unit tests for this module; none of them contact a Redis server.
    // They cover config defaults and builders, Redis version parsing from
    // `INFO server` text, and decoding of the raw `INFO server` reply shapes.
    use super::*;

    // --- RedisConfig: group resolution and mode preservation ---

    #[test]
    fn config_default_group() {
        let cfg = RedisConfig {
            mode: RedisMode::Standalone {
                url: "redis://127.0.0.1:6379/".to_string(),
            },
            group: None,
            ..RedisConfig::default()
        };
        // Assumes `DEFAULT_GROUP` (from `super::constants`) is "shove".
        assert_eq!(cfg.resolved_group(), "shove");
    }
    #[test]
    fn config_custom_group() {
        let cfg = RedisConfig {
            mode: RedisMode::Standalone {
                url: "redis://127.0.0.1:6379/".to_string(),
            },
            group: Some("myapp".to_string()),
            ..RedisConfig::default()
        };
        assert_eq!(cfg.resolved_group(), "myapp");
    }
    #[test]
    fn standalone_url_preserved() {
        // Credentials and a TLS scheme must pass through untouched.
        let url = "rediss://user:pass@myhost:6380/".to_string();
        let config = RedisConfig {
            mode: RedisMode::Standalone { url: url.clone() },
            group: None,
            ..RedisConfig::default()
        };
        match config.mode {
            RedisMode::Standalone { url: stored } => assert_eq!(stored, url),
            _ => panic!("expected Standalone"),
        }
    }
    #[test]
    fn cluster_urls_preserved() {
        let urls = vec![
            "redis://node1:6379/".to_string(),
            "redis://node2:6379/".to_string(),
            "redis://node3:6379/".to_string(),
        ];
        let config = RedisConfig {
            mode: RedisMode::Cluster { urls: urls.clone() },
            group: None,
            ..RedisConfig::default()
        };
        match config.mode {
            RedisMode::Cluster { urls: stored } => assert_eq!(stored, urls),
            _ => panic!("expected Cluster"),
        }
    }
    #[test]
    fn resolved_group_empty_string_preserved() {
        // An explicit empty group is honoured, not replaced by the default.
        let cfg = RedisConfig {
            mode: RedisMode::Standalone {
                url: "redis://127.0.0.1:6379/".to_string(),
            },
            group: Some(String::new()),
            ..RedisConfig::default()
        };
        assert_eq!(cfg.resolved_group(), "");
    }
    #[test]
    fn redis_mode_standalone_variant_matches() {
        let cfg = RedisConfig {
            mode: RedisMode::Standalone {
                url: "redis://localhost/".to_string(),
            },
            group: None,
            ..RedisConfig::default()
        };
        assert!(matches!(cfg.mode, RedisMode::Standalone { .. }));
    }
    #[test]
    fn redis_mode_cluster_variant_matches() {
        let cfg = RedisConfig {
            mode: RedisMode::Cluster {
                urls: vec!["redis://node1/".to_string()],
            },
            group: None,
            ..RedisConfig::default()
        };
        assert!(matches!(cfg.mode, RedisMode::Cluster { .. }));
    }

    // --- RedisConfig: timeout defaults and validating builders ---

    #[test]
    fn default_has_shove_tuned_timeouts() {
        let cfg = RedisConfig::default();
        assert_eq!(cfg.response_timeout, DEFAULT_RESPONSE_TIMEOUT);
        assert_eq!(cfg.connection_timeout, DEFAULT_CONNECTION_TIMEOUT);
        // The default must itself satisfy the `with_response_timeout` invariant.
        assert!(cfg.response_timeout > Duration::from_millis(BLOCK_MS));
    }
    #[test]
    fn new_constructor_seeds_defaults() {
        let cfg = RedisConfig::new(RedisMode::Standalone {
            url: "redis://127.0.0.1:6379/".to_string(),
        });
        assert_eq!(cfg.response_timeout, DEFAULT_RESPONSE_TIMEOUT);
        assert_eq!(cfg.connection_timeout, DEFAULT_CONNECTION_TIMEOUT);
        assert!(cfg.group.is_none());
    }
    #[test]
    fn with_response_timeout_round_trips() {
        let cfg = RedisConfig::default().with_response_timeout(Duration::from_secs(60));
        assert_eq!(cfg.response_timeout, Duration::from_secs(60));
    }
    #[test]
    #[should_panic(expected = "must exceed BLOCK_MS")]
    fn with_response_timeout_below_block_ms_panics() {
        // Despite the name, this passes exactly `BLOCK_MS`. The bound is strict
        // (`>`), so the equality boundary must panic too.
        let _ = RedisConfig::default().with_response_timeout(Duration::from_millis(BLOCK_MS));
    }
    #[test]
    fn with_response_timeout_at_block_ms_plus_one_accepted() {
        // Smallest accepted value: one millisecond above the bound.
        let cfg = RedisConfig::default().with_response_timeout(Duration::from_millis(BLOCK_MS + 1));
        assert_eq!(cfg.response_timeout, Duration::from_millis(BLOCK_MS + 1));
    }
    #[test]
    fn with_connection_timeout_round_trips() {
        let cfg = RedisConfig::default().with_connection_timeout(Duration::from_secs(5));
        assert_eq!(cfg.connection_timeout, Duration::from_secs(5));
    }
    #[test]
    #[should_panic(expected = "connection_timeout must be positive")]
    fn with_connection_timeout_zero_panics() {
        let _ = RedisConfig::default().with_connection_timeout(Duration::ZERO);
    }

    // --- check_version_info: parsing the `redis_version:` line ---

    // Builds a minimal CRLF-delimited `INFO server` payload reporting `version`.
    fn make_info(version: &str) -> String {
        format!("# Server\r\nredis_version:{version}\r\nredis_git_sha1:00000000\r\nos:Linux\r\n")
    }
    #[test]
    fn version_6_2_0_is_accepted() {
        // Lower bound: 6.2 is the minimum supported version.
        assert!(check_version_info(&make_info("6.2.0")).is_ok());
    }
    #[test]
    fn version_6_2_14_is_accepted() {
        assert!(check_version_info(&make_info("6.2.14")).is_ok());
    }
    #[test]
    fn version_7_0_0_is_accepted() {
        assert!(check_version_info(&make_info("7.0.0")).is_ok());
    }
    #[test]
    fn version_8_0_0_is_accepted() {
        assert!(check_version_info(&make_info("8.0.0")).is_ok());
    }
    #[test]
    fn version_5_0_0_is_rejected() {
        // The error names both the offending version and the requirement.
        let err = check_version_info(&make_info("5.0.0")).unwrap_err();
        assert!(err.to_string().contains("5.0.0"));
        assert!(err.to_string().contains("6.2"));
    }
    #[test]
    fn version_6_0_0_is_rejected() {
        let err = check_version_info(&make_info("6.0.0")).unwrap_err();
        assert!(err.to_string().contains("6.0.0"));
    }
    #[test]
    fn version_6_1_9_is_rejected() {
        // Same major as the minimum, but the minor is too low.
        let err = check_version_info(&make_info("6.1.9")).unwrap_err();
        assert!(err.to_string().contains("6.1.9"));
    }
    #[test]
    fn missing_version_line_is_an_error() {
        let info = "# Server\r\nredis_git_sha1:00000000\r\nos:Linux\r\n";
        let err = check_version_info(info).unwrap_err();
        assert!(
            err.to_string()
                .contains("could not determine Redis version")
        );
    }
    #[test]
    fn malformed_version_no_dots_is_an_error() {
        let err = check_version_info(&make_info("garbage")).unwrap_err();
        assert!(err.to_string().contains("unparseable Redis version"));
    }
    #[test]
    fn malformed_version_major_only_is_an_error() {
        // A minor component is required; a bare major does not parse.
        let err = check_version_info(&make_info("7")).unwrap_err();
        assert!(err.to_string().contains("unparseable Redis version"));
    }
    #[test]
    fn version_line_with_surrounding_noise_is_parsed_correctly() {
        let info = "# Server\r\nredis_version:7.2.5\r\nredis_git_sha1:00000000\r\nredis_git_dirty:0\r\nredis_build_id:abc\r\nredis_mode:standalone\r\nos:Linux\r\narch_bits:64\r\n";
        assert!(check_version_info(info).is_ok());
    }
    #[test]
    fn valkey_reports_compatible_redis_version() {
        // Only `redis_version:` is consulted; `valkey_version:` is ignored.
        let info = "# Server\r\nredis_version:7.2.4\r\nvalkey_version:8.0.0\r\n";
        assert!(check_version_info(info).is_ok());
    }

    // --- info_payloads: decoding reply shapes (standalone string vs cluster map) ---

    #[test]
    fn info_payloads_handles_standalone_bulk_string() {
        let value = redis::Value::BulkString(make_info("7.0.0").into_bytes());
        let payloads = info_payloads(value).unwrap();
        assert_eq!(payloads.len(), 1);
        assert!(payloads[0].contains("redis_version:7.0.0"));
    }
    #[test]
    fn info_payloads_handles_standalone_simple_string() {
        let value = redis::Value::SimpleString(make_info("7.0.0"));
        let payloads = info_payloads(value).unwrap();
        assert_eq!(payloads.len(), 1);
        assert!(payloads[0].contains("redis_version:7.0.0"));
    }
    #[test]
    fn info_payloads_handles_cluster_map_with_one_entry_per_master() {
        // Map keys are node addresses; only the values (payloads) are returned.
        let value = redis::Value::Map(vec![
            (
                redis::Value::BulkString(b"127.0.0.1:7001".to_vec()),
                redis::Value::BulkString(make_info("7.0.10").into_bytes()),
            ),
            (
                redis::Value::BulkString(b"127.0.0.1:7002".to_vec()),
                redis::Value::BulkString(make_info("7.0.10").into_bytes()),
            ),
            (
                redis::Value::BulkString(b"127.0.0.1:7000".to_vec()),
                redis::Value::BulkString(make_info("7.0.10").into_bytes()),
            ),
        ]);
        let payloads = info_payloads(value).unwrap();
        assert_eq!(payloads.len(), 3);
        for payload in &payloads {
            assert!(payload.contains("redis_version:7.0.10"));
        }
    }
    #[test]
    fn info_payloads_empty_map_is_an_error() {
        let err = info_payloads(redis::Value::Map(vec![])).unwrap_err();
        assert!(err.to_string().contains("empty map"));
    }
    #[test]
    fn info_payloads_unexpected_reply_is_an_error() {
        let err = info_payloads(redis::Value::Int(42)).unwrap_err();
        assert!(err.to_string().contains("unexpected INFO server reply"));
    }
}