use crate::horizontal_adapter::{BroadcastMessage, RequestBody, ResponseBody};
use crate::horizontal_transport::{HorizontalTransport, TransportConfig, TransportHandlers};
use async_trait::async_trait;
use futures::StreamExt;
use redis::AsyncCommands;
use sockudo_core::error::{Error, Result};
use sockudo_core::metrics::MetricsInterface;
use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use tokio::sync::Notify;
use tracing::{debug, error, warn};
#[derive(Debug, Clone)]
pub struct RedisAdapterConfig {
pub url: String,
pub prefix: String,
pub request_timeout_ms: u64,
pub cluster_mode: bool,
}
impl Default for RedisAdapterConfig {
fn default() -> Self {
Self {
url: "redis://127.0.0.1:6379/".to_string(),
prefix: "sockudo".to_string(),
request_timeout_ms: 5000,
cluster_mode: false,
}
}
}
impl TransportConfig for RedisAdapterConfig {
fn request_timeout_ms(&self) -> u64 {
self.request_timeout_ms
}
fn prefix(&self) -> &str {
&self.prefix
}
}
pub struct RedisTransport {
client: redis::Client,
connection: redis::aio::ConnectionManager,
events_connection: redis::aio::ConnectionManager,
broadcast_channel: String,
request_channel: String,
response_channel: String,
metrics: Arc<OnceLock<Arc<dyn MetricsInterface + Send + Sync>>>,
shutdown: Arc<Notify>,
is_running: Arc<AtomicBool>,
owner_count: Arc<AtomicUsize>,
}
#[async_trait]
impl HorizontalTransport for RedisTransport {
type Config = RedisAdapterConfig;
async fn new(config: Self::Config) -> Result<Self> {
let client = redis::Client::open(&*config.url)
.map_err(|e| Error::Redis(format!("Failed to create Redis client: {e}")))?;
let connection_manager_config = redis::aio::ConnectionManagerConfig::new()
.set_number_of_retries(5)
.set_exponent_base(2.0)
.set_max_delay(std::time::Duration::from_millis(5000));
let connection = client
.get_connection_manager_with_config(connection_manager_config.clone())
.await
.map_err(|e| Error::Redis(format!("Failed to connect to Redis: {e}")))?;
let events_connection = client
.get_connection_manager_with_config(connection_manager_config)
.await
.map_err(|e| {
Error::Redis(format!(
"Failed to create Redis connection manager for events: {e}"
))
})?;
let broadcast_channel = format!("{}:#broadcast", config.prefix);
let request_channel = format!("{}:#requests", config.prefix);
let response_channel = format!("{}:#responses", config.prefix);
Ok(Self {
client,
connection,
events_connection,
broadcast_channel,
request_channel,
response_channel,
metrics: Arc::new(OnceLock::new()),
shutdown: Arc::new(Notify::new()),
is_running: Arc::new(AtomicBool::new(true)),
owner_count: Arc::new(AtomicUsize::new(1)),
})
}
async fn publish_broadcast(&self, message: &BroadcastMessage) -> Result<()> {
let broadcast_json = sonic_rs::to_string(message)?;
let mut retry_delay = 100u64; const MAX_RETRIES: u32 = 3;
const MAX_RETRY_DELAY: u64 = 1000;
for attempt in 0..=MAX_RETRIES {
let mut conn = self.events_connection.clone();
match conn
.publish::<_, _, i32>(&self.broadcast_channel, &broadcast_json)
.await
{
Ok(_subscriber_count) => {
if attempt > 0 {
debug!("Broadcast succeeded on retry attempt {}", attempt);
}
return Ok(());
}
Err(e) => {
if attempt == MAX_RETRIES {
return Err(Error::Redis(format!(
"Failed to publish broadcast after {} attempts: {}",
MAX_RETRIES + 1,
e
)));
}
warn!(
"Broadcast attempt {} failed: {}, retrying in {}ms",
attempt + 1,
e,
retry_delay
);
tokio::time::sleep(tokio::time::Duration::from_millis(retry_delay)).await;
retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY);
}
}
}
Err(Error::Redis(
"All retry attempts failed unexpectedly".to_string(),
))
}
async fn publish_request(&self, request: &RequestBody) -> Result<()> {
let request_json = sonic_rs::to_string(request)
.map_err(|e| Error::Other(format!("Failed to serialize request: {e}")))?;
let mut conn = self.connection.clone();
let subscriber_count: i32 = conn
.publish(&self.request_channel, &request_json)
.await
.map_err(|e| Error::Redis(format!("Failed to publish request: {e}")))?;
debug!(
"Broadcasted request {} to {} subscribers",
request.request_id, subscriber_count
);
Ok(())
}
async fn publish_response(&self, response: &ResponseBody) -> Result<()> {
let response_json = sonic_rs::to_string(response)
.map_err(|e| Error::Other(format!("Failed to serialize response: {e}")))?;
let mut conn = self.connection.clone();
let _: () = conn
.publish(&self.response_channel, response_json)
.await
.map_err(|e| Error::Redis(format!("Failed to publish response: {e}")))?;
Ok(())
}
async fn start_listeners(&self, handlers: TransportHandlers) -> Result<()> {
let sub_client = self.client.clone();
let pub_connection = self.connection.clone();
let broadcast_channel = self.broadcast_channel.clone();
let request_channel = self.request_channel.clone();
let response_channel = self.response_channel.clone();
let metrics = self.metrics.clone();
let shutdown = self.shutdown.clone();
let is_running = self.is_running.clone();
tokio::spawn(async move {
let mut retry_delay = 500u64; const MAX_RETRY_DELAY: u64 = 10_000;
loop {
if !is_running.load(Ordering::Relaxed) {
break;
}
debug!("Attempting to establish pub/sub connection...");
let mut pubsub = match sub_client.get_async_pubsub().await {
Ok(pubsub) => {
retry_delay = 500; debug!("Pub/sub connection established successfully");
pubsub
}
Err(e) => {
error!(
"Failed to get pubsub connection: {}, retrying in {}ms",
e, retry_delay
);
tokio::select! {
_ = shutdown.notified() => break,
_ = tokio::time::sleep(tokio::time::Duration::from_millis(retry_delay)) => {}
}
retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY);
continue;
}
};
if let Err(e) = pubsub
.subscribe(&[&broadcast_channel, &request_channel, &response_channel])
.await
{
error!(
"Failed to subscribe to channels: {}, retrying in {}ms",
e, retry_delay
);
tokio::select! {
_ = shutdown.notified() => break,
_ = tokio::time::sleep(tokio::time::Duration::from_millis(retry_delay)) => {}
}
retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY);
continue;
}
debug!(
"Redis transport listening on channels: {}, {}, {}",
broadcast_channel, request_channel, response_channel
);
let mut message_stream = pubsub.on_message();
let mut connection_broken = false;
loop {
if !is_running.load(Ordering::Relaxed) {
break;
}
let next_msg = tokio::select! {
_ = shutdown.notified() => break,
msg = message_stream.next() => msg,
};
let Some(msg) = next_msg else {
break;
};
let channel = msg.get_channel_name();
let payload_result: redis::RedisResult<String> = msg.get_payload();
if let Ok(payload) = payload_result {
if channel == broadcast_channel.as_str() {
let broadcast_handler = handlers.on_broadcast.clone();
let metrics_clone = metrics.clone();
tokio::spawn(async move {
if let Ok(broadcast) =
sonic_rs::from_str::<BroadcastMessage>(&payload)
{
broadcast_handler(broadcast).await;
} else if let Some(metrics) = metrics_clone.get() {
metrics.mark_horizontal_transport_message_dropped("redis");
}
});
} else if channel == request_channel.as_str() {
let request_handler = handlers.on_request.clone();
let pub_connection_clone = pub_connection.clone();
let response_channel_clone = response_channel.clone();
let metrics_clone = metrics.clone();
tokio::spawn(async move {
if let Ok(request) = sonic_rs::from_str::<RequestBody>(&payload) {
let response_result = request_handler(request).await;
if let Ok(response) = response_result
&& let Ok(response_json) = sonic_rs::to_string(&response)
{
let mut conn = pub_connection_clone.clone();
let _ = conn
.publish::<_, _, ()>(
&response_channel_clone,
response_json,
)
.await;
}
} else if let Some(metrics) = metrics_clone.get() {
metrics.mark_horizontal_transport_message_dropped("redis");
}
});
} else if channel == response_channel.as_str() {
let response_handler = handlers.on_response.clone();
let metrics_clone = metrics.clone();
tokio::spawn(async move {
if let Ok(response) = sonic_rs::from_str::<ResponseBody>(&payload) {
response_handler(response).await;
} else {
if let Some(metrics) = metrics_clone.get() {
metrics.mark_horizontal_transport_message_dropped("redis");
}
warn!("Failed to parse response message: {}", payload);
}
});
}
} else {
warn!("Error getting message payload: {:?}", payload_result);
connection_broken = true;
break;
}
}
if connection_broken {
if let Some(metrics) = metrics.get() {
metrics.mark_horizontal_transport_reconnection("redis");
}
warn!(
"Pub/sub connection broken, reconnecting in {}ms...",
retry_delay
);
tokio::select! {
_ = shutdown.notified() => break,
_ = tokio::time::sleep(tokio::time::Duration::from_millis(retry_delay)) => {}
}
retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY);
} else {
if let Some(metrics) = metrics.get() {
metrics.mark_horizontal_transport_reconnection("redis");
}
warn!("Pub/sub message stream ended unexpectedly, reconnecting...");
}
}
});
Ok(())
}
async fn get_node_count(&self) -> Result<usize> {
let mut conn = self.connection.clone();
let result: redis::RedisResult<Vec<redis::Value>> = redis::cmd("PUBSUB")
.arg("NUMSUB")
.arg(&self.request_channel)
.query_async(&mut conn)
.await;
match result {
Ok(values) => {
if values.len() >= 2 {
if let redis::Value::Int(count) = values[1] {
let node_count = (count as usize).max(1);
debug!(
"Detected {} application instances via PUBSUB NUMSUB",
node_count
);
Ok(node_count)
} else {
warn!("PUBSUB NUMSUB returned non-integer count: {:?}", values[1]);
Ok(1)
}
} else {
warn!("PUBSUB NUMSUB returned unexpected format: {:?}", values);
Ok(1)
}
}
Err(e) => {
error!("Failed to execute PUBSUB NUMSUB: {}", e);
Ok(1)
}
}
}
fn node_count_is_real_time(&self) -> bool {
true
}
async fn check_health(&self) -> Result<()> {
let mut conn = self
.client
.get_multiplexed_async_connection()
.await
.map_err(|e| Error::Redis(format!("Failed to acquire health check connection: {e}")))?;
let response = redis::cmd("PING")
.query_async::<String>(&mut conn)
.await
.map_err(|e| Error::Redis(format!("Health check PING failed: {e}")))?;
if response == "PONG" {
Ok(())
} else {
Err(Error::Redis(format!(
"PING returned unexpected response: {response}"
)))
}
}
fn set_metrics(&self, metrics: Arc<dyn MetricsInterface + Send + Sync>) {
let _ = self.metrics.set(metrics);
}
}
impl Drop for RedisTransport {
fn drop(&mut self) {
if self.owner_count.fetch_sub(1, Ordering::AcqRel) == 1 {
self.is_running.store(false, Ordering::Relaxed);
self.shutdown.notify_waiters();
}
}
}
impl Clone for RedisTransport {
fn clone(&self) -> Self {
self.owner_count.fetch_add(1, Ordering::Relaxed);
Self {
client: self.client.clone(),
connection: self.connection.clone(),
events_connection: self.events_connection.clone(),
broadcast_channel: self.broadcast_channel.clone(),
request_channel: self.request_channel.clone(),
response_channel: self.response_channel.clone(),
metrics: self.metrics.clone(),
shutdown: self.shutdown.clone(),
is_running: self.is_running.clone(),
owner_count: self.owner_count.clone(),
}
}
}