use std::sync::Arc;
use std::time::Duration;
use parking_lot::Mutex;
use redis::aio::{ConnectionManager, ConnectionManagerConfig, MultiplexedConnection, PubSub};
use redis::sentinel::{SentinelClient, SentinelClientBuilder, SentinelServerType};
use redis::{ClientTlsConfig, ConnectionAddr, TlsCertificates, TlsMode};
use sockudo_core::error::{Error, Result};
use sockudo_core::options::{RedisTlsOptions, SentinelSpec};
use tracing::warn;
/// Where `redis::Client` instances for the current master come from.
enum ClientSource {
    /// A single, fixed endpoint parsed from a connection URL.
    Standalone(redis::Client),
    /// Master resolution through Redis Sentinel.
    ///
    /// The async mutex serialises lookups, because the client is driven through a
    /// mutable guard while awaiting.
    Sentinel(tokio::sync::Mutex<SentinelClient>),
}
/// State shared by every clone of `RedisClient`.
struct Inner {
    /// Produces a `redis::Client` for the current master (fixed or Sentinel-resolved).
    source: ClientSource,
    /// Retry/backoff settings applied to every `ConnectionManager` built from `source`.
    cm_config: ConnectionManagerConfig,
    /// Lazily built manager handed out by `RedisClient::command_connection`.
    connection: Mutex<Option<ConnectionManager>>,
    /// Lazily built manager handed out by `RedisClient::events_connection`.
    events_connection: Mutex<Option<ConnectionManager>>,
}
/// Cheaply clonable handle to a Redis deployment, either standalone or behind Sentinel.
///
/// Cloning only bumps a reference count; all clones share the same cached connection
/// managers.
#[derive(Clone)]
pub struct RedisClient {
    inner: Arc<Inner>,
}
impl RedisClient {
pub async fn connect(url: &str, sentinel: Option<SentinelSpec>) -> Result<Self> {
let cm_config = ConnectionManagerConfig::new()
.set_number_of_retries(5)
.set_exponent_base(2.0)
.set_max_delay(Duration::from_millis(5000));
let source = match sentinel {
Some(spec) => {
ClientSource::Sentinel(tokio::sync::Mutex::new(build_sentinel_client(&spec).await?))
}
None => {
let client = redis::Client::open(url)
.map_err(|e| Error::Redis(format!("Failed to create Redis client: {e}")))?;
ClientSource::Standalone(client)
}
};
let client = Self {
inner: Arc::new(Inner {
source,
cm_config,
connection: Mutex::new(None),
events_connection: Mutex::new(None),
}),
};
let _ = client.command_connection().await?;
let _ = client.events_connection().await?;
Ok(client)
}
async fn master_client(&self) -> Result<redis::Client> {
match &self.inner.source {
ClientSource::Standalone(client) => Ok(client.clone()),
ClientSource::Sentinel(sentinel) => {
let mut guard = sentinel.lock().await;
guard
.async_get_client()
.await
.map_err(|e| Error::Redis(format!("Failed to resolve Sentinel master: {e}")))
}
}
}
async fn get_or_build(
&self,
slot: &Mutex<Option<ConnectionManager>>,
) -> Result<ConnectionManager> {
{
let guard = slot.lock();
if let Some(manager) = guard.as_ref() {
return Ok(manager.clone());
}
}
let client = self.master_client().await?;
let manager = client
.get_connection_manager_with_config(self.inner.cm_config.clone())
.await
.map_err(|e| Error::Redis(format!("Failed to connect to Redis: {e}")))?;
*slot.lock() = Some(manager.clone());
Ok(manager)
}
pub async fn command_connection(&self) -> Result<ConnectionManager> {
self.get_or_build(&self.inner.connection).await
}
pub async fn events_connection(&self) -> Result<ConnectionManager> {
self.get_or_build(&self.inner.events_connection).await
}
pub fn invalidate(&self) {
if matches!(self.inner.source, ClientSource::Sentinel(_)) {
*self.inner.connection.lock() = None;
*self.inner.events_connection.lock() = None;
}
}
pub async fn pubsub(&self) -> Result<PubSub> {
let client = self.master_client().await?;
client
.get_async_pubsub()
.await
.map_err(|e| Error::Redis(format!("Failed to get pubsub connection: {e}")))
}
pub async fn multiplexed(&self) -> Result<MultiplexedConnection> {
let client = self.master_client().await?;
client
.get_multiplexed_async_connection()
.await
.map_err(|e| Error::Redis(format!("Failed to acquire connection: {e}")))
}
}
/// Chooses the redis-rs [`TlsMode`] for a connection leg.
///
/// Returns `TlsMode::Insecure` when `accept_invalid_certs` is set, otherwise `TlsMode::Secure`.
fn tls_mode(tls: &RedisTlsOptions) -> TlsMode {
    match tls.accept_invalid_certs {
        true => TlsMode::Insecure,
        false => TlsMode::Secure,
    }
}
/// Reads the CA bundle and/or client certificate pair configured in `tls`.
///
/// `hop` names the connection leg (the callers pass `"sentinel"` or `"master"`) and only
/// appears in log and error messages.
///
/// Returns `Ok(None)` when no CA path is set and `tls.has_client_cert()` is false. A client
/// certificate is loaded only when both `client_cert_path` and `client_key_path` are set. A
/// half-configured pair triggers a warning and is otherwise ignored.
///
/// # Errors
///
/// Returns [`Error::Redis`] if any configured file cannot be read.
async fn load_tls_certificates(
    tls: &RedisTlsOptions,
    hop: &str,
) -> Result<Option<TlsCertificates>> {
    let has_cert_path = tls.client_cert_path.is_some();
    let has_key_path = tls.client_key_path.is_some();
    if has_cert_path != has_key_path {
        warn!(
            "Redis {hop} TLS: both client_cert_path and client_key_path are required for mutual TLS; ignoring the partial configuration"
        );
    }

    let nothing_configured = tls.ca_path.is_none() && !tls.has_client_cert();
    if nothing_configured {
        return Ok(None);
    }

    let root_cert = if let Some(path) = &tls.ca_path {
        let ca_bytes = tokio::fs::read(path).await.map_err(|e| {
            Error::Redis(format!(
                "Failed to read Redis {hop} TLS CA certificate {path}: {e}"
            ))
        })?;
        Some(ca_bytes)
    } else {
        None
    };

    let client_tls = if let (Some(cert_path), Some(key_path)) =
        (&tls.client_cert_path, &tls.client_key_path)
    {
        let client_cert = tokio::fs::read(cert_path).await.map_err(|e| {
            Error::Redis(format!(
                "Failed to read Redis {hop} TLS client certificate {cert_path}: {e}"
            ))
        })?;
        let client_key = tokio::fs::read(key_path).await.map_err(|e| {
            Error::Redis(format!(
                "Failed to read Redis {hop} TLS client key {key_path}: {e}"
            ))
        })?;
        Some(ClientTlsConfig {
            client_cert,
            client_key,
        })
    } else {
        None
    };

    Ok(Some(TlsCertificates {
        client_tls,
        root_cert,
    }))
}
async fn build_sentinel_client(spec: &SentinelSpec) -> Result<SentinelClient> {
if spec.hosts.is_empty() {
return Err(Error::Redis(
"Sentinel configured but no sentinel hosts were provided".to_string(),
));
}
let addrs: Vec<ConnectionAddr> = spec
.hosts
.iter()
.map(|(host, port)| {
if spec.sentinel_tls.enabled {
ConnectionAddr::TcpTls {
host: host.clone(),
port: *port,
insecure: spec.sentinel_tls.accept_invalid_certs,
tls_params: None,
}
} else {
ConnectionAddr::Tcp(host.clone(), *port)
}
})
.collect();
let mut builder =
SentinelClientBuilder::new(addrs, spec.master_name.clone(), SentinelServerType::Master)
.map_err(|e| {
Error::Redis(format!("Failed to initialize Sentinel client builder: {e}"))
})?;
if let Some(username) = &spec.sentinel_username {
builder = builder.set_client_to_sentinel_username(username);
}
if let Some(password) = &spec.sentinel_password {
builder = builder.set_client_to_sentinel_password(password);
}
if spec.sentinel_tls.enabled
&& let Some(certs) = load_tls_certificates(&spec.sentinel_tls, "sentinel").await?
{
builder = builder.set_client_to_sentinel_certificates(certs);
}
builder = builder.set_client_to_redis_db(spec.db);
if let Some(username) = &spec.redis_username {
builder = builder.set_client_to_redis_username(username);
}
if let Some(password) = &spec.redis_password {
builder = builder.set_client_to_redis_password(password);
}
if spec.master_tls.enabled {
builder = builder.set_client_to_redis_tls_mode(tls_mode(&spec.master_tls));
if let Some(certs) = load_tls_certificates(&spec.master_tls, "master").await? {
builder = builder.set_client_to_redis_certificates(certs);
}
}
builder
.build()
.map_err(|e| Error::Redis(format!("Failed to build Sentinel client: {e}")))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Unique path under the system temp dir, so parallel tests never collide.
    fn temp_path(name: &str) -> std::path::PathBuf {
        let mut path = std::env::temp_dir();
        path.push(format!("sockudo-redis-tls-{}-{name}", uuid::Uuid::new_v4()));
        path
    }

    /// Temporary file that is removed on drop.
    ///
    /// Cleanup therefore also happens when an assertion panics part-way through a test.
    struct TempFile(std::path::PathBuf);

    impl TempFile {
        async fn with_contents(name: &str, contents: &[u8]) -> Self {
            // Create the guard before writing so a partially written file is still removed.
            let file = Self(temp_path(name));
            tokio::fs::write(&file.0, contents)
                .await
                .expect("failed to write temp file");
            file
        }

        fn path_string(&self) -> String {
            self.0.to_string_lossy().into_owned()
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: the file may never have been created.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    fn base_spec() -> SentinelSpec {
        SentinelSpec {
            hosts: vec![("127.0.0.1".to_string(), 26379)],
            master_name: "mymaster".to_string(),
            db: 0,
            redis_username: None,
            redis_password: None,
            sentinel_username: None,
            sentinel_password: None,
            master_tls: RedisTlsOptions::default(),
            sentinel_tls: RedisTlsOptions::default(),
        }
    }

    #[tokio::test]
    async fn load_tls_certificates_returns_none_when_unconfigured() {
        let tls = RedisTlsOptions {
            enabled: true,
            ..Default::default()
        };
        let result = load_tls_certificates(&tls, "master").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn load_tls_certificates_reads_ca_and_client_pair() {
        let ca = TempFile::with_contents("ca.pem", b"-----CA-----").await;
        let cert = TempFile::with_contents("client.pem", b"-----CERT-----").await;
        let key = TempFile::with_contents("client.key", b"-----KEY-----").await;
        let tls = RedisTlsOptions {
            enabled: true,
            ca_path: Some(ca.path_string()),
            client_cert_path: Some(cert.path_string()),
            client_key_path: Some(key.path_string()),
            ..Default::default()
        };
        let certs = load_tls_certificates(&tls, "master")
            .await
            .unwrap()
            .expect("certificates should be loaded");
        assert_eq!(certs.root_cert.as_deref(), Some(&b"-----CA-----"[..]));
        let client = certs.client_tls.expect("client tls should be present");
        assert_eq!(client.client_cert, b"-----CERT-----");
        assert_eq!(client.client_key, b"-----KEY-----");
    }

    #[tokio::test]
    async fn load_tls_certificates_ignores_partial_client_pair() {
        let ca = TempFile::with_contents("ca.pem", b"-----CA-----").await;
        let cert = TempFile::with_contents("client.pem", b"-----CERT-----").await;
        let tls = RedisTlsOptions {
            enabled: true,
            ca_path: Some(ca.path_string()),
            client_cert_path: Some(cert.path_string()),
            ..Default::default()
        };
        let certs = load_tls_certificates(&tls, "master")
            .await
            .unwrap()
            .expect("CA should still be loaded");
        assert_eq!(certs.root_cert.as_deref(), Some(&b"-----CA-----"[..]));
        assert!(
            certs.client_tls.is_none(),
            "a cert without a key must not produce a client TLS config"
        );
    }

    #[tokio::test]
    async fn load_tls_certificates_errors_on_missing_file() {
        let tls = RedisTlsOptions {
            enabled: true,
            ca_path: Some(temp_path("missing-ca.pem").to_string_lossy().into_owned()),
            ..Default::default()
        };
        assert!(load_tls_certificates(&tls, "master").await.is_err());
    }

    #[tokio::test]
    async fn build_sentinel_client_succeeds_plaintext() {
        build_sentinel_client(&base_spec())
            .await
            .expect("plaintext sentinel client should build");
    }

    #[tokio::test]
    async fn build_sentinel_client_succeeds_with_tls_no_certs() {
        let mut spec = base_spec();
        spec.sentinel_tls = RedisTlsOptions {
            enabled: true,
            accept_invalid_certs: true,
            ..Default::default()
        };
        spec.master_tls = RedisTlsOptions {
            enabled: true,
            ..Default::default()
        };
        spec.redis_password = Some("secret".to_string());
        spec.sentinel_password = Some("sentinel-secret".to_string());
        build_sentinel_client(&spec)
            .await
            .expect("TLS sentinel client should build");
    }

    #[tokio::test]
    async fn build_sentinel_client_errors_without_hosts() {
        let mut spec = base_spec();
        spec.hosts.clear();
        assert!(build_sentinel_client(&spec).await.is_err());
    }
}