use crate::horizontal_adapter::{BroadcastMessage, RequestBody, ResponseBody};
use crate::horizontal_transport::{HorizontalTransport, TransportConfig, TransportHandlers};
use async_trait::async_trait;
use crossfire::mpsc;
use redis::cluster_read_routing::RandomReplicaStrategy;
use sockudo_core::error::{Error, Result};
use sockudo_core::metrics::MetricsInterface;
use sockudo_core::options::RedisClusterAdapterConfig;
use redis::AsyncCommands;
use redis::cluster::{ClusterClient, ClusterClientBuilder};
use redis::cluster_async::ClusterConnection;
use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::time::Duration;
use tokio::sync::Notify;
use tracing::{debug, error, info, warn};
/// Channel flavor for the queue of `redis::PushInfo` values that the
/// subscriber connection's push sender feeds in `start_listeners`.
type RedisPushChannelFlavor = mpsc::List<redis::PushInfo>;
/// Async receiving half of that push channel; the listener task reads
/// incoming pub/sub pushes from it.
type RedisPushReceiver = crossfire::AsyncRx<RedisPushChannelFlavor>;
fn value_to_string(v: &redis::Value) -> Option<String> {
match v {
redis::Value::BulkString(bytes) => String::from_utf8(bytes.clone()).ok(),
redis::Value::SimpleString(s) => Some(s.clone()),
redis::Value::VerbatimString { format: _, text } => Some(text.clone()),
_ => None,
}
}
/// Exposes the Redis Cluster adapter settings through the generic
/// `TransportConfig` accessors.
impl TransportConfig for RedisClusterAdapterConfig {
/// Returns the configured `request_timeout_ms` value unchanged.
fn request_timeout_ms(&self) -> u64 {
self.request_timeout_ms
}
/// Returns the configured prefix; `new` uses the same field to build the
/// `:#broadcast`, `:#requests` and `:#responses` channel names.
fn prefix(&self) -> &str {
&self.prefix
}
}
/// Horizontal transport backed by Redis Cluster pub/sub.
///
/// Clones share the same shutdown signal, running flag and owner counter; the
/// background listener is told to stop when the last owner is dropped.
pub struct RedisClusterTransport {
/// Cluster client built in `new`; both connections below are obtained from it.
#[allow(dead_code)]
client: ClusterClient,
/// Connection used for publishing and for the subscriber-count query.
publish_connection: ClusterConnection,
/// Separate connection used by `check_health` for PING.
health_check_connection: ClusterConnection,
/// `{prefix}:#broadcast`
broadcast_channel: String,
/// `{prefix}:#requests`
request_channel: String,
/// `{prefix}:#responses`
response_channel: String,
/// Adapter configuration; its node list is reused when the listener builds
/// its own subscriber client.
config: RedisClusterAdapterConfig,
/// Selects SPUBLISH/SSUBSCRIBE instead of PUBLISH/SUBSCRIBE.
use_sharded_pubsub: bool,
/// Metrics sink, set at most once through `set_metrics` and shared with the
/// listener task.
metrics: Arc<OnceLock<Arc<dyn MetricsInterface + Send + Sync>>>,
/// Notified when the last owner is dropped.
shutdown: Arc<Notify>,
/// Cleared when the last owner is dropped; polled by the listener task.
is_running: Arc<AtomicBool>,
/// Number of live owners: incremented by `Clone`, decremented by `Drop`.
owner_count: Arc<AtomicUsize>,
}
#[async_trait]
impl HorizontalTransport for RedisClusterTransport {
type Config = RedisClusterAdapterConfig;
async fn new(config: Self::Config) -> Result<Self> {
let client = ClusterClientBuilder::new(config.nodes.clone())
.retries(3)
.read_routing_strategy(RandomReplicaStrategy)
.build()
.map_err(|e| Error::Redis(format!("Failed to create Redis Cluster client: {e}")))?;
let publish_connection = client.get_async_connection().await.map_err(|e| {
Error::Redis(format!(
"Failed to create Redis Cluster publish connection: {e}"
))
})?;
let health_check_connection = client.get_async_connection().await.map_err(|e| {
Error::Redis(format!(
"Failed to create Redis Cluster health check connection: {e}"
))
})?;
let broadcast_channel = format!("{}:#broadcast", config.prefix);
let request_channel = format!("{}:#requests", config.prefix);
let response_channel = format!("{}:#responses", config.prefix);
let use_sharded_pubsub = config.use_sharded_pubsub;
if use_sharded_pubsub {
info!(
"Redis Cluster using sharded pub/sub (SSUBSCRIBE/SPUBLISH) for optimal performance"
);
} else {
debug!("Redis Cluster using standard pub/sub (SUBSCRIBE/PUBLISH)");
}
info!(
"Redis Cluster transport initialized with dedicated publish and health check connections"
);
Ok(Self {
client,
publish_connection,
health_check_connection,
broadcast_channel,
request_channel,
response_channel,
config,
use_sharded_pubsub,
metrics: Arc::new(OnceLock::new()),
shutdown: Arc::new(Notify::new()),
is_running: Arc::new(AtomicBool::new(true)),
owner_count: Arc::new(AtomicUsize::new(1)),
})
}
/// Serializes `message` and publishes it on the broadcast channel.
///
/// Uses SPUBLISH when sharded pub/sub is enabled and PUBLISH otherwise. A
/// failed publish is retried up to `MAX_RETRIES` times, sleeping 100 ms
/// before the first retry and doubling the delay (capped at
/// `MAX_RETRY_DELAY` ms) after each failure.
///
/// # Errors
///
/// Propagates the serialization error, or returns `Error::Redis` once the
/// initial attempt and all retries have failed.
async fn publish_broadcast(&self, message: &BroadcastMessage) -> Result<()> {
    const MAX_RETRIES: u32 = 3;
    const MAX_RETRY_DELAY: u64 = 1000;

    let payload = sonic_rs::to_string(message)?;
    let mut backoff_ms = 100u64;
    let mut attempt = 0u32;

    loop {
        let mut conn = self.publish_connection.clone();
        let outcome: redis::RedisResult<()> = if self.use_sharded_pubsub {
            redis::cmd("SPUBLISH")
                .arg(&self.broadcast_channel)
                .arg(&payload)
                .query_async(&mut conn)
                .await
        } else {
            conn.publish(&self.broadcast_channel, &payload).await
        };

        let Err(e) = outcome else {
            if attempt > 0 {
                debug!("Broadcast succeeded on retry attempt {}", attempt);
            }
            return Ok(());
        };

        // Attempt index MAX_RETRIES is the last retry; give up after it.
        if attempt == MAX_RETRIES {
            return Err(Error::Redis(format!(
                "Failed to publish broadcast after {} retries: {e}",
                MAX_RETRIES
            )));
        }

        warn!(
            "Failed to publish broadcast (attempt {}): {}",
            attempt + 1,
            e
        );
        tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
        backoff_ms = (backoff_ms * 2).min(MAX_RETRY_DELAY);
        attempt += 1;
    }
}
/// Serializes `request` and publishes it on the request channel, logging
/// the subscriber count reported by Redis.
///
/// Uses SPUBLISH when sharded pub/sub is enabled and PUBLISH otherwise. A
/// failed publish is retried up to `MAX_RETRIES` times with a delay that
/// starts at 100 ms and doubles up to `MAX_RETRY_DELAY` ms.
///
/// # Errors
///
/// Returns `Error::Other` if serialization fails, or `Error::Redis` once the
/// initial attempt and all retries have failed.
async fn publish_request(&self, request: &RequestBody) -> Result<()> {
    const MAX_RETRIES: u32 = 3;
    const MAX_RETRY_DELAY: u64 = 1000;

    let payload = sonic_rs::to_string(request)
        .map_err(|e| Error::Other(format!("Failed to serialize request: {e}")))?;
    let mut backoff_ms = 100u64;
    let mut attempt = 0u32;

    loop {
        let mut conn = self.publish_connection.clone();
        let outcome: redis::RedisResult<i32> = if self.use_sharded_pubsub {
            redis::cmd("SPUBLISH")
                .arg(&self.request_channel)
                .arg(&payload)
                .query_async(&mut conn)
                .await
        } else {
            conn.publish(&self.request_channel, &payload).await
        };

        let e = match outcome {
            Ok(subscriber_count) => {
                if attempt > 0 {
                    debug!("Request publish succeeded on retry attempt {}", attempt);
                }
                debug!(
                    "Broadcasted request {} to {} subscribers",
                    request.request_id, subscriber_count
                );
                return Ok(());
            }
            Err(e) => e,
        };

        // Attempt index MAX_RETRIES is the last retry; give up after it.
        if attempt == MAX_RETRIES {
            return Err(Error::Redis(format!(
                "Failed to publish request after {} retries: {e}",
                MAX_RETRIES
            )));
        }

        warn!("Failed to publish request (attempt {}): {}", attempt + 1, e);
        tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
        backoff_ms = (backoff_ms * 2).min(MAX_RETRY_DELAY);
        attempt += 1;
    }
}
/// Serializes `response` and publishes it on the response channel.
///
/// Uses SPUBLISH when sharded pub/sub is enabled and PUBLISH otherwise. A
/// failed publish is retried up to `MAX_RETRIES` times with a delay that
/// starts at 100 ms and doubles up to `MAX_RETRY_DELAY` ms.
///
/// # Errors
///
/// Returns `Error::Other` if serialization fails, or `Error::Redis` once the
/// initial attempt and all retries have failed.
async fn publish_response(&self, response: &ResponseBody) -> Result<()> {
    const MAX_RETRIES: u32 = 3;
    const MAX_RETRY_DELAY: u64 = 1000;

    let payload = sonic_rs::to_string(response)
        .map_err(|e| Error::Other(format!("Failed to serialize response: {e}")))?;
    let mut backoff_ms = 100u64;
    let mut attempt = 0u32;

    loop {
        let mut conn = self.publish_connection.clone();
        let outcome: redis::RedisResult<()> = if self.use_sharded_pubsub {
            redis::cmd("SPUBLISH")
                .arg(&self.response_channel)
                .arg(&payload)
                .query_async(&mut conn)
                .await
        } else {
            conn.publish(&self.response_channel, &payload).await
        };

        let Err(e) = outcome else {
            if attempt > 0 {
                debug!("Response publish succeeded on retry attempt {}", attempt);
            }
            return Ok(());
        };

        // Attempt index MAX_RETRIES is the last retry; give up after it.
        if attempt == MAX_RETRIES {
            return Err(Error::Redis(format!(
                "Failed to publish response after {} retries: {e}",
                MAX_RETRIES
            )));
        }

        warn!(
            "Failed to publish response (attempt {}): {}",
            attempt + 1,
            e
        );
        tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
        backoff_ms = (backoff_ms * 2).min(MAX_RETRY_DELAY);
        attempt += 1;
    }
}
/// Spawns the background task that subscribes to the broadcast, request and
/// response channels and dispatches every received message to `handlers`.
///
/// The task builds its own RESP3 subscriber client whose push messages are
/// forwarded into an in-process channel. If the client, connection or
/// subscription fails, it retries with a delay that starts at 500 ms and
/// doubles up to `MAX_RETRY_DELAY` ms. It stops once `is_running` is cleared
/// or `shutdown` is notified.
///
/// Each message is handled in its own spawned task. A request handler's
/// successful result is serialized and published on the response channel,
/// using SPUBLISH when sharded pub/sub is enabled so that peers subscribed
/// with SSUBSCRIBE actually receive it.
///
/// Returns `Ok(())` as soon as the task is spawned; connection problems are
/// handled inside the task and never surface through the return value.
async fn start_listeners(&self, handlers: TransportHandlers) -> Result<()> {
    let broadcast_channel = self.broadcast_channel.clone();
    let request_channel = self.request_channel.clone();
    let response_channel = self.response_channel.clone();
    let nodes = self.config.nodes.clone();
    let use_sharded_pubsub = self.use_sharded_pubsub;
    let publish_connection = self.publish_connection.clone();
    let metrics = self.metrics.clone();
    let shutdown = self.shutdown.clone();
    let is_running = self.is_running.clone();

    tokio::spawn(async move {
        const MAX_RETRY_DELAY: u64 = 10_000;
        let mut retry_delay = 500u64;
        let mut reconnection_count = 0u64;

        loop {
            if !is_running.load(Ordering::Relaxed) {
                break;
            }

            // Push messages from the subscriber connection are forwarded into `rx`.
            let (tx, rx): (crossfire::MTx<RedisPushChannelFlavor>, RedisPushReceiver) =
                mpsc::unbounded_async();
            let push_sender = move |msg| tx.send(msg).map_err(|_| redis::aio::SendError);

            let sub_client = match ClusterClientBuilder::new(nodes.clone())
                .use_protocol(redis::ProtocolVersion::RESP3)
                .push_sender(push_sender)
                .build()
            {
                Ok(client) => client,
                Err(e) => {
                    error!(
                        "Failed to create PubSub client: {}, retrying in {}ms",
                        e, retry_delay
                    );
                    tokio::select! {
                        _ = shutdown.notified() => break,
                        _ = tokio::time::sleep(Duration::from_millis(retry_delay)) => {}
                    }
                    retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY);
                    continue;
                }
            };

            let mut pubsub = match sub_client.get_async_connection().await {
                Ok(conn) => {
                    // A successful connection resets the backoff.
                    retry_delay = 500;
                    if reconnection_count > 0 {
                        debug!(
                            "Redis Cluster PubSub reconnected successfully after {} attempts",
                            reconnection_count
                        );
                    }
                    conn
                }
                Err(e) => {
                    reconnection_count += 1;
                    if let Some(metrics) = metrics.get() {
                        metrics.mark_horizontal_transport_reconnection("redis_cluster");
                    }
                    error!(
                        "Failed to get pubsub connection: {}, retrying in {}ms (attempt {})",
                        e, retry_delay, reconnection_count
                    );
                    tokio::select! {
                        _ = shutdown.notified() => break,
                        _ = tokio::time::sleep(Duration::from_millis(retry_delay)) => {}
                    }
                    retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY);
                    continue;
                }
            };

            let subscribe_result: redis::RedisResult<()> = if use_sharded_pubsub {
                redis::cmd("SSUBSCRIBE")
                    .arg(&broadcast_channel)
                    .arg(&request_channel)
                    .arg(&response_channel)
                    .query_async(&mut pubsub)
                    .await
            } else {
                pubsub
                    .subscribe(&[&broadcast_channel, &request_channel, &response_channel])
                    .await
            };
            if let Err(e) = subscribe_result {
                if let Some(metrics) = metrics.get() {
                    metrics.mark_horizontal_transport_reconnection("redis_cluster");
                }
                error!(
                    "Failed to subscribe to channels: {}, retrying in {}ms",
                    e, retry_delay
                );
                tokio::select! {
                    _ = shutdown.notified() => break,
                    _ = tokio::time::sleep(Duration::from_millis(retry_delay)) => {}
                }
                retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY);
                continue;
            }

            debug!(
                "Redis Cluster transport listening on channels: {}, {}, {}",
                broadcast_channel, request_channel, response_channel
            );
            reconnection_count = 0;

            loop {
                if !is_running.load(Ordering::Relaxed) {
                    break;
                }

                // The short timeout bounds how long a cleared `is_running` flag can go unnoticed.
                let recv_result = tokio::select! {
                    _ = shutdown.notified() => break,
                    result = tokio::time::timeout(Duration::from_millis(100), rx.recv()) => result,
                };
                let Ok(Ok(push_info)) = recv_result else {
                    // A receive error means the channel is closed: reconnect.
                    // A timeout just loops around to re-check the running flag.
                    if matches!(recv_result, Ok(Err(_))) {
                        break;
                    }
                    continue;
                };

                let is_message = matches!(
                    push_info.kind,
                    redis::PushKind::Message | redis::PushKind::SMessage
                );
                if !is_message {
                    continue;
                }

                // Message pushes are read as [channel, payload].
                if push_info.data.len() < 2 {
                    if let Some(metrics) = metrics.get() {
                        metrics.mark_horizontal_transport_message_dropped("redis_cluster");
                    }
                    error!("Invalid push message format: {:?}", push_info);
                    continue;
                }
                let channel = match value_to_string(&push_info.data[0]) {
                    Some(s) => s,
                    None => {
                        if let Some(metrics) = metrics.get() {
                            metrics.mark_horizontal_transport_message_dropped("redis_cluster");
                        }
                        error!("Failed to parse channel name: {:?}", push_info.data[0]);
                        continue;
                    }
                };
                let payload = match value_to_string(&push_info.data[1]) {
                    Some(s) => s,
                    None => {
                        if let Some(metrics) = metrics.get() {
                            metrics.mark_horizontal_transport_message_dropped("redis_cluster");
                        }
                        error!("Failed to parse payload: {:?}", push_info.data[1]);
                        continue;
                    }
                };

                let broadcast_handler = handlers.on_broadcast.clone();
                let request_handler = handlers.on_request.clone();
                let response_handler = handlers.on_response.clone();
                let publish_conn = publish_connection.clone();
                let metrics_clone = metrics.clone();
                let broadcast_channel_clone = broadcast_channel.clone();
                let request_channel_clone = request_channel.clone();
                let response_channel_clone = response_channel.clone();

                // Handle each message in its own task so a slow handler does not stall the receive loop.
                tokio::spawn(async move {
                    if channel == broadcast_channel_clone {
                        if let Ok(broadcast) = sonic_rs::from_str::<BroadcastMessage>(&payload)
                        {
                            broadcast_handler(broadcast).await;
                        } else if let Some(metrics) = metrics_clone.get() {
                            metrics.mark_horizontal_transport_message_dropped("redis_cluster");
                        }
                    } else if channel == request_channel_clone {
                        if let Ok(request) = sonic_rs::from_str::<RequestBody>(&payload) {
                            let response_result = request_handler(request).await;
                            if let Ok(response) = response_result
                                && let Ok(response_json) = sonic_rs::to_string(&response)
                            {
                                let mut conn = publish_conn.clone();
                                // Peers subscribed with SSUBSCRIBE only receive SPUBLISH
                                // messages, so the reply must use the same pub/sub mode as
                                // the subscription.
                                let publish_result: redis::RedisResult<()> =
                                    if use_sharded_pubsub {
                                        redis::cmd("SPUBLISH")
                                            .arg(&response_channel_clone)
                                            .arg(&response_json)
                                            .query_async(&mut conn)
                                            .await
                                    } else {
                                        conn.publish(&response_channel_clone, &response_json)
                                            .await
                                    };
                                if let Err(e) = publish_result {
                                    warn!("Failed to publish response to request: {}", e);
                                }
                            }
                        } else if let Some(metrics) = metrics_clone.get() {
                            metrics.mark_horizontal_transport_message_dropped("redis_cluster");
                        }
                    } else if channel == response_channel_clone {
                        if let Ok(response) = sonic_rs::from_str::<ResponseBody>(&payload) {
                            response_handler(response).await;
                        } else if let Some(metrics) = metrics_clone.get() {
                            metrics.mark_horizontal_transport_message_dropped("redis_cluster");
                        }
                    }
                });
            }

            // The receive loop also exits on shutdown; that is not a reconnection, so
            // stop here instead of recording a metric, warning and sleeping.
            if !is_running.load(Ordering::Relaxed) {
                break;
            }

            if let Some(metrics) = metrics.get() {
                metrics.mark_horizontal_transport_reconnection("redis_cluster");
            }
            warn!("Redis Cluster PubSub connection ended, reconnecting...");
            tokio::select! {
                _ = shutdown.notified() => break,
                _ = tokio::time::sleep(Duration::from_millis(retry_delay)) => {}
            }
            retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY);
        }
    });

    Ok(())
}
async fn get_node_count(&self) -> Result<usize> {
let mut conn = self.publish_connection.clone();
let result: redis::RedisResult<Vec<redis::Value>> = redis::cmd("PUBSUB")
.arg("NUMSUB")
.arg(&self.request_channel)
.query_async(&mut conn)
.await;
match result {
Ok(values) => {
if values.len() >= 2 {
if let redis::Value::Int(count) = values[1] {
Ok((count as usize).max(1))
} else {
Ok(1)
}
} else {
Ok(1)
}
}
Err(e) => {
error!("Failed to execute PUBSUB NUMSUB: {}", e);
Ok(1)
}
}
}
async fn check_health(&self) -> Result<()> {
let mut conn = self.health_check_connection.clone();
let response = redis::cmd("PING")
.query_async::<String>(&mut conn)
.await
.map_err(|e| Error::Redis(format!("Cluster health check PING failed: {e}")))?;
if response == "PONG" {
Ok(())
} else {
Err(Error::Redis(format!(
"Cluster PING returned unexpected response: {response}"
)))
}
}
/// Installs the metrics sink shared with the listener task.
///
/// The sink lives in a `OnceLock`, so only the first call takes effect; the
/// result of later calls is deliberately discarded.
fn set_metrics(&self, metrics: Arc<dyn MetricsInterface + Send + Sync>) {
let _ = self.metrics.set(metrics);
}
}
/// Signals the background listener to stop once the last owner goes away.
impl Drop for RedisClusterTransport {
    fn drop(&mut self) {
        // `fetch_sub` returns the previous value, so anything other than 1 means
        // other owners are still alive and the listener must keep running.
        let previous_owners = self.owner_count.fetch_sub(1, Ordering::AcqRel);
        if previous_owners != 1 {
            return;
        }

        self.is_running.store(false, Ordering::Relaxed);
        self.shutdown.notify_waiters();
    }
}
/// Produces another owner that shares the shutdown signal, running flag,
/// metrics slot and owner counter with the original.
impl Clone for RedisClusterTransport {
    fn clone(&self) -> Self {
        // Exhaustive destructuring makes a newly added field a compile error here
        // instead of a silently forgotten clone.
        let Self {
            client,
            publish_connection,
            health_check_connection,
            broadcast_channel,
            request_channel,
            response_channel,
            config,
            use_sharded_pubsub,
            metrics,
            shutdown,
            is_running,
            owner_count,
        } = self;

        // Register the new owner before handing it out; `Drop` decrements this.
        owner_count.fetch_add(1, Ordering::Relaxed);

        Self {
            client: client.clone(),
            publish_connection: publish_connection.clone(),
            health_check_connection: health_check_connection.clone(),
            broadcast_channel: broadcast_channel.clone(),
            request_channel: request_channel.clone(),
            response_channel: response_channel.clone(),
            config: config.clone(),
            use_sharded_pubsub: *use_sharded_pubsub,
            metrics: metrics.clone(),
            shutdown: shutdown.clone(),
            is_running: is_running.clone(),
            owner_count: owner_count.clone(),
        }
    }
}