use super::error::{TransportError, TransportResult};
use super::traits::{CommitToken, TransportBase, TransportReceiver, TransportSender};
use super::types::{Message, PayloadFormat, SendResult};
use redis::AsyncCommands;
use redis::streams::{StreamMaxlen, StreamReadOptions, StreamReadReply};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::Mutex;
/// Commit token identifying a single Redis Stream entry that was delivered
/// by `recv()` and can later be acknowledged via `commit()` (XACK).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RedisToken {
    /// Stream the entry was read from. `Arc<str>` so every message/token from
    /// one `recv()` batch shares a single allocation of the stream name.
    pub stream: Arc<str>,
    /// Redis entry ID in `<millis>-<sequence>` form, e.g. "1711432800000-0".
    pub entry_id: String,
}
// Marker impl: lets RedisToken be used as the receiver's commit token type.
impl CommitToken for RedisToken {}
impl std::fmt::Display for RedisToken {
    /// Renders the token as `redis:<stream>:<entry_id>`, the form used in
    /// logs and diagnostics.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("redis:")?;
        f.write_str(&self.stream)?;
        f.write_str(":")?;
        f.write_str(&self.entry_id)
    }
}
/// Serde default: local Redis instance on the standard port.
fn default_url() -> String {
    String::from("redis://127.0.0.1:6379")
}

/// Serde default: consumer-group name.
fn default_group() -> String {
    String::from("dfe")
}

/// Serde default: consumer name within the group.
fn default_consumer() -> String {
    String::from("consumer-1")
}

/// Serde default: XREADGROUP BLOCK timeout in milliseconds.
fn default_block_ms() -> usize {
    5_000
}
/// Configuration for [`RedisTransport`], deserialisable from the config
/// cascade under `transport.redis`. Every field has a serde default, so an
/// empty document yields a working localhost configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RedisTransportConfig {
    /// Redis connection URL (defaults to `redis://127.0.0.1:6379`).
    #[serde(default = "default_url")]
    pub url: String,
    /// Default stream name. Required for `recv()`; for `send()` it is the
    /// fallback used when the per-call key is empty.
    #[serde(default)]
    pub stream: Option<String>,
    /// Consumer-group name used by XREADGROUP/XACK (defaults to "dfe").
    #[serde(default = "default_group")]
    pub group: String,
    /// Consumer name within the group (defaults to "consumer-1").
    #[serde(default = "default_consumer")]
    pub consumer: String,
    /// When set, sends use `XADD ... MAXLEN ~ <n>` to approximately cap the
    /// stream length; `None` leaves the stream unbounded.
    #[serde(default)]
    pub max_stream_len: Option<usize>,
    /// XREADGROUP BLOCK timeout in milliseconds (defaults to 5000).
    #[serde(default = "default_block_ms")]
    pub block_ms: usize,
    /// Inbound (receive-side) transport filter rules.
    #[serde(default)]
    pub filters_in: Vec<super::filter::FilterRule>,
    /// Outbound (send-side) transport filter rules.
    #[serde(default)]
    pub filters_out: Vec<super::filter::FilterRule>,
}
impl Default for RedisTransportConfig {
fn default() -> Self {
Self {
url: default_url(),
stream: None,
group: default_group(),
consumer: default_consumer(),
max_stream_len: None,
block_ms: default_block_ms(),
filters_in: Vec::new(),
filters_out: Vec::new(),
}
}
}
impl RedisTransportConfig {
#[must_use]
pub fn from_cascade() -> Self {
#[cfg(feature = "config")]
{
if let Some(cfg) = crate::config::try_get()
&& let Ok(tc) = cfg.unmarshal_key_registered::<Self>("transport.redis")
{
return tc;
}
}
Self::default()
}
}
/// Redis Streams transport: sends via XADD and receives via XREADGROUP with
/// at-least-once delivery (entries stay pending until XACKed by `commit()`).
pub struct RedisTransport {
    // Multiplexed connection shared by send/recv/commit; the tokio Mutex is
    // held across awaits while a command is in flight.
    conn: Mutex<redis::aio::MultiplexedConnection>,
    config: RedisTransportConfig,
    // Set by close(); checked by send/recv and the health probe.
    closed: Arc<AtomicBool>,
    // Streams for which the consumer group has already been created (or
    // found to exist) by this instance; avoids repeated XGROUP CREATE.
    group_created: Mutex<std::collections::HashSet<String>>,
    filter_engine: super::filter::TransportFilterEngine,
    // Messages diverted to the DLQ by inbound filters, staged until drained
    // via take_filtered_dlq_entries(). Sync parking_lot mutex: only held for
    // short, non-await critical sections.
    filtered_dlq_buffer: parking_lot::Mutex<Vec<super::filter::FilteredDlqEntry>>,
}
impl RedisTransport {
    /// Opens a multiplexed async connection to the Redis instance described
    /// by `config`, builds the transport filter engine, and (with the
    /// `health` feature) registers a health probe derived from the closed
    /// flag — the probe does not actively ping Redis.
    ///
    /// # Errors
    /// - `TransportError::Config` when the URL cannot be parsed.
    /// - `TransportError::Connection` when the connection cannot be opened.
    /// - Filter-engine construction errors for invalid filter rules.
    pub async fn new(config: &RedisTransportConfig) -> TransportResult<Self> {
        let client = redis::Client::open(config.url.as_str()).map_err(|e| {
            TransportError::Config(format!("invalid Redis URL '{}': {e}", config.url))
        })?;
        let conn = client
            .get_multiplexed_async_connection()
            .await
            .map_err(|e| {
                TransportError::Connection(format!(
                    "failed to connect to Redis at '{}': {e}",
                    config.url
                ))
            })?;
        #[cfg(feature = "logger")]
        tracing::info!(
            url = %config.url,
            stream = ?config.stream,
            group = %config.group,
            "Redis transport opened"
        );
        let filter_engine = super::filter::TransportFilterEngine::new(
            &config.filters_in,
            &config.filters_out,
            &crate::transport::filter::TransportFilterTierConfig::default(),
        )?;
        let closed = Arc::new(AtomicBool::new(false));
        #[cfg(feature = "health")]
        {
            // Health is purely "has close() been called"; the probe shares
            // the closed flag via a second Arc handle.
            let h = Arc::clone(&closed);
            crate::health::HealthRegistry::register("transport:redis", move || {
                if h.load(Ordering::Relaxed) {
                    crate::health::HealthStatus::Unhealthy
                } else {
                    crate::health::HealthStatus::Healthy
                }
            });
        }
        Ok(Self {
            conn: Mutex::new(conn),
            config: config.clone(),
            closed,
            group_created: Mutex::new(std::collections::HashSet::new()),
            filter_engine,
            filtered_dlq_buffer: parking_lot::Mutex::new(Vec::new()),
        })
    }

    /// Picks the target stream for a send: a non-empty `key` wins, otherwise
    /// `config.stream`; errors when neither names a stream.
    fn resolve_stream<'a>(&'a self, key: &'a str) -> Result<&'a str, TransportError> {
        if !key.is_empty() {
            return Ok(key);
        }
        self.config.stream.as_deref().ok_or_else(|| {
            TransportError::Config(
                "no stream name: key is empty and config.stream is not set".into(),
            )
        })
    }

    /// Ensures the consumer group exists on `stream`, creating it (with
    /// MKSTREAM, starting at ID 0 so pre-existing entries are readable) at
    /// most once per stream per instance; successes are cached in
    /// `group_created`.
    async fn ensure_group(&self, stream: &str) -> TransportResult<()> {
        {
            // Fast path: this instance already created (or found) the group.
            // The lock is released before taking the connection lock below.
            let created = self.group_created.lock().await;
            if created.contains(stream) {
                return Ok(());
            }
        }
        let mut conn = self.conn.lock().await;
        let result: redis::RedisResult<()> = conn
            .xgroup_create_mkstream(stream, &self.config.group, "0")
            .await;
        match result {
            Ok(()) => {}
            Err(e) => {
                // BUSYGROUP means the group already exists — treated as
                // success. NOTE(review): detected by substring match on the
                // error text; brittle if the server message ever changes.
                let msg = e.to_string();
                if !msg.contains("BUSYGROUP") {
                    return Err(TransportError::Connection(format!(
                        "failed to create consumer group '{}' on stream '{stream}': {e}",
                        self.config.group
                    )));
                }
            }
        }
        self.group_created.lock().await.insert(stream.to_string());
        Ok(())
    }
}
impl TransportBase for RedisTransport {
    /// Transport identifier used in logs and registries.
    fn name(&self) -> &'static str {
        "redis"
    }

    /// Marks the transport closed; subsequent send/recv calls are refused.
    /// The underlying connection is released when the transport is dropped.
    async fn close(&self) -> TransportResult<()> {
        self.closed.store(true, Ordering::Relaxed);
        Ok(())
    }

    /// Healthy until `close()` has been called.
    fn is_healthy(&self) -> bool {
        !self.closed.load(Ordering::Relaxed)
    }
}
impl TransportSender for RedisTransport {
    /// Appends `payload` as one stream entry: `XADD <stream> * payload <bytes>`.
    ///
    /// The target stream is `key` when non-empty, otherwise `config.stream`.
    /// Outbound filters run before any network I/O and may drop the message
    /// (reported as success) or divert it to the DLQ. When
    /// `config.max_stream_len` is set, the stream is capped with an
    /// approximate `MAXLEN ~` trim.
    async fn send(&self, key: &str, payload: &[u8]) -> SendResult {
        if self.closed.load(Ordering::Relaxed) {
            return SendResult::Fatal(TransportError::Closed);
        }

        // Outbound filtering happens first, so dropped/DLQed messages never
        // reach Redis.
        if self.filter_engine.has_outbound_filters() {
            match self.filter_engine.apply_outbound(payload) {
                super::filter::FilterDisposition::Pass => {}
                super::filter::FilterDisposition::Drop => return SendResult::Ok,
                super::filter::FilterDisposition::Dlq => return SendResult::FilteredDlq,
            }
        }

        let stream = match self.resolve_stream(key) {
            Ok(name) => name.to_string(),
            Err(err) => return SendResult::Fatal(err),
        };

        let fields = [("payload", payload)];
        let mut conn = self.conn.lock().await;
        let outcome: redis::RedisResult<String> = match self.config.max_stream_len {
            Some(cap) => {
                conn.xadd_maxlen(&stream, StreamMaxlen::Approx(cap), "*", &fields)
                    .await
            }
            None => conn.xadd(&stream, "*", &fields).await,
        };

        match outcome {
            Ok(_entry_id) => {
                #[cfg(feature = "logger")]
                tracing::debug!(stream = %stream, "Redis transport: XADD sent");
                #[cfg(feature = "metrics")]
                metrics::counter!("dfe_transport_sent_total", "transport" => "redis").increment(1);
                SendResult::Ok
            }
            Err(e) => {
                #[cfg(feature = "logger")]
                tracing::warn!(error = %e, stream = %stream, "Redis transport: XADD error");
                SendResult::Fatal(TransportError::Send(format!(
                    "XADD to stream '{stream}' failed: {e}"
                )))
            }
        }
    }
}
impl TransportReceiver for RedisTransport {
    type Token = RedisToken;

    /// Reads up to `max` new entries from `config.stream` via
    /// `XREADGROUP GROUP <group> <consumer> COUNT <max> BLOCK <block_ms>`,
    /// creating the consumer group on first use.
    ///
    /// Delivered entries stay in the group's pending list until `commit()`
    /// XACKs their tokens. Inbound filters may drop messages or stage them
    /// in the filtered-DLQ buffer (drained via `take_filtered_dlq_entries`).
    ///
    /// # Errors
    /// `TransportError::Closed` after `close()`; `Config` when
    /// `config.stream` is unset; `Connection`/`Recv` on Redis failures.
    async fn recv(&self, max: usize) -> TransportResult<Vec<Message<Self::Token>>> {
        if self.closed.load(Ordering::Relaxed) {
            return Err(TransportError::Closed);
        }
        let stream_name = self
            .config
            .stream
            .as_deref()
            .ok_or_else(|| TransportError::Config("config.stream must be set for recv()".into()))?
            .to_string();
        self.ensure_group(&stream_name).await?;
        let opts = StreamReadOptions::default()
            .group(&self.config.group, &self.config.consumer)
            .count(max)
            .block(self.config.block_ms);
        let mut conn = self.conn.lock().await;
        // ">" requests entries never delivered to this consumer group.
        let reply: StreamReadReply = conn
            .xread_options(&[&stream_name], &[">"], &opts)
            .await
            .map_err(|e| {
                #[cfg(feature = "logger")]
                tracing::warn!(error = %e, stream = %stream_name, "Redis transport: XREADGROUP error");
                TransportError::Recv(format!("XREADGROUP on stream '{stream_name}' failed: {e}"))
            })?;
        // One shared allocation of the stream name for all messages/tokens.
        let stream_arc: Arc<str> = Arc::from(stream_name.as_str());
        let mut messages = Vec::new();
        for stream_key in &reply.keys {
            for stream_id in &stream_key.ids {
                // StreamId::get converts the "payload" field by reference —
                // no intermediate clone of the raw redis::Value (the previous
                // `from_redis_value(v.clone())` cloned every payload and
                // passed an owned Value where a reference is expected).
                // Entries without a "payload" field yield an empty payload.
                let payload: Vec<u8> = stream_id.get("payload").unwrap_or_default();
                let format = PayloadFormat::detect(&payload);
                let timestamp_ms = parse_entry_timestamp(&stream_id.id);
                messages.push(Message {
                    key: Some(Arc::clone(&stream_arc)),
                    payload,
                    token: RedisToken {
                        stream: Arc::clone(&stream_arc),
                        entry_id: stream_id.id.clone(),
                    },
                    timestamp_ms,
                    format,
                });
            }
        }
        if self.filter_engine.has_inbound_filters() {
            // NOTE(review): filtered-out messages are removed before the
            // caller ever sees their tokens, so they are never XACKed and
            // remain in the group's pending entries list — confirm intended.
            let mut staged_dlq: Vec<super::filter::FilteredDlqEntry> = Vec::new();
            messages.retain(|msg| match self.filter_engine.apply_inbound(&msg.payload) {
                super::filter::FilterDisposition::Pass => true,
                super::filter::FilterDisposition::Drop => false,
                super::filter::FilterDisposition::Dlq => {
                    staged_dlq.push(super::filter::FilteredDlqEntry {
                        payload: msg.payload.clone(),
                        key: msg.key.clone(),
                        reason: "transport filter".to_string(),
                    });
                    false
                }
            });
            if !staged_dlq.is_empty() {
                self.filtered_dlq_buffer.lock().extend(staged_dlq);
            }
        }
        #[cfg(feature = "logger")]
        if !messages.is_empty() {
            tracing::debug!(
                messages = messages.len(),
                "Redis transport: XREADGROUP received"
            );
        }
        #[cfg(feature = "metrics")]
        if !messages.is_empty() {
            metrics::counter!("dfe_transport_received_total", "transport" => "redis")
                .increment(messages.len() as u64);
        }
        Ok(messages)
    }

    /// Drains and returns the messages diverted to the DLQ by inbound
    /// filters since the last call, leaving the buffer empty.
    fn take_filtered_dlq_entries(&self) -> Vec<super::filter::FilteredDlqEntry> {
        std::mem::take(&mut *self.filtered_dlq_buffer.lock())
    }

    /// Acknowledges `tokens` with XACK, grouped so each distinct stream gets
    /// a single XACK round trip.
    ///
    /// # Errors
    /// `TransportError::Commit` when any XACK call fails; earlier streams in
    /// the batch may already have been acknowledged at that point.
    async fn commit(&self, tokens: &[Self::Token]) -> TransportResult<()> {
        if tokens.is_empty() {
            return Ok(());
        }
        // Group entry IDs by stream name (tokens may span streams).
        let mut by_stream: std::collections::HashMap<&str, Vec<&str>> =
            std::collections::HashMap::new();
        for token in tokens {
            by_stream
                .entry(&token.stream)
                .or_default()
                .push(&token.entry_id);
        }
        let mut conn = self.conn.lock().await;
        for (stream, ids) in &by_stream {
            // XACK returns the number of entries actually acknowledged;
            // already-acked or unknown IDs are skipped by Redis, so the
            // count is intentionally ignored.
            let _acked: i32 = conn
                .xack(*stream, &self.config.group, ids.as_slice())
                .await
                .map_err(|e| {
                    #[cfg(feature = "logger")]
                    tracing::warn!(error = %e, stream = %stream, "Redis transport: XACK error");
                    TransportError::Commit(format!("XACK on stream '{stream}' failed: {e}"))
                })?;
        }
        #[cfg(feature = "logger")]
        tracing::debug!(count = tokens.len(), "Redis transport: XACK committed");
        Ok(())
    }
}
/// Extracts the millisecond timestamp from a Redis stream entry ID.
///
/// Entry IDs have the form `<millis>-<sequence>`. Returns `None` when the ID
/// contains no `-` separator or the millisecond part is not a valid `i64`.
fn parse_entry_timestamp(entry_id: &str) -> Option<i64> {
    let (millis, _sequence) = entry_id.split_once('-')?;
    millis.parse::<i64>().ok()
}
#[cfg(test)]
mod tests {
    use super::*;

    // Display renders "redis:<stream>:<entry_id>".
    #[test]
    fn token_display() {
        let token = RedisToken {
            stream: Arc::from("my_stream"),
            entry_id: "1711432800000-0".into(),
        };
        assert_eq!(format!("{token}"), "redis:my_stream:1711432800000-0");
    }

    // Clone/PartialEq round-trip on the token.
    #[test]
    fn token_clone() {
        let token = RedisToken {
            stream: Arc::from("s1"),
            entry_id: "100-0".into(),
        };
        let cloned = token.clone();
        assert_eq!(token, cloned);
    }

    // Default impl must match the serde field defaults.
    #[test]
    fn config_defaults() {
        let config = RedisTransportConfig::default();
        assert_eq!(config.url, "redis://127.0.0.1:6379");
        assert!(config.stream.is_none());
        assert_eq!(config.group, "dfe");
        assert!(config.max_stream_len.is_none());
        assert_eq!(config.block_ms, 5000);
    }

    // A minimal YAML document: unspecified fields fall back to defaults.
    // (Raw-string interiors are significant — the YAML keys stay unindented.)
    #[test]
    fn config_deserialise_minimal() {
        let yaml = r"
url: redis://myhost:6380
stream: events
";
        let config: RedisTransportConfig = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(config.url, "redis://myhost:6380");
        assert_eq!(config.stream.as_deref(), Some("events"));
        assert_eq!(config.group, "dfe");
        assert_eq!(config.block_ms, 5000);
    }

    // A fully-specified YAML document overrides every default.
    #[test]
    fn config_deserialise_full() {
        let yaml = r"
url: rediss://secure.redis.io:6380
stream: audit_log
group: my_group
consumer: worker-3
max_stream_len: 100000
block_ms: 2000
";
        let config: RedisTransportConfig = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(config.url, "rediss://secure.redis.io:6380");
        assert_eq!(config.stream.as_deref(), Some("audit_log"));
        assert_eq!(config.group, "my_group");
        assert_eq!(config.consumer, "worker-3");
        assert_eq!(config.max_stream_len, Some(100_000));
        assert_eq!(config.block_ms, 2000);
    }

    #[test]
    fn parse_entry_timestamp_valid() {
        assert_eq!(
            parse_entry_timestamp("1711432800000-0"),
            Some(1_711_432_800_000)
        );
        assert_eq!(parse_entry_timestamp("0-0"), Some(0));
    }

    #[test]
    fn parse_entry_timestamp_invalid() {
        assert_eq!(parse_entry_timestamp("not-a-number"), None);
        assert_eq!(parse_entry_timestamp(""), None);
    }

    // The next two tests mirror resolve_stream()'s logic inline (the method
    // itself needs a live connection to construct a RedisTransport).
    #[test]
    fn resolve_stream_uses_key_when_non_empty() {
        let config = RedisTransportConfig {
            stream: Some("default_stream".into()),
            ..Default::default()
        };
        let key = "override_stream";
        let resolved = if key.is_empty() {
            config.stream.as_deref().unwrap_or("")
        } else {
            key
        };
        assert_eq!(resolved, "override_stream");
    }

    #[test]
    fn resolve_stream_falls_back_to_config() {
        let config = RedisTransportConfig {
            stream: Some("default_stream".into()),
            ..Default::default()
        };
        let key = "";
        let resolved = if key.is_empty() {
            config.stream.as_deref().unwrap_or("")
        } else {
            key
        };
        assert_eq!(resolved, "default_stream");
    }

    // End-to-end XADD -> XREADGROUP -> XACK against a real Redis; skipped
    // (passes vacuously) when REDIS_URL is not set in the environment.
    #[tokio::test]
    async fn redis_integration_xadd_xreadgroup_xack() {
        let Ok(url) = std::env::var("REDIS_URL") else {
            eprintln!("Skipping: REDIS_URL not set");
            return;
        };
        // Timestamped stream name avoids collisions between test runs.
        let stream = format!("test_stream_{}", chrono::Utc::now().timestamp_millis());
        let group = "test_group";
        let consumer = "test_consumer";
        let config = RedisTransportConfig {
            url: url.clone(),
            stream: Some(stream.clone()),
            group: group.into(),
            consumer: consumer.into(),
            max_stream_len: Some(1000),
            block_ms: 1000,
            ..Default::default()
        };
        let transport = RedisTransport::new(&config).await.unwrap();
        let r1 = transport.send("", b"{\"n\":1}").await;
        assert!(r1.is_ok(), "first send should succeed");
        let r2 = transport.send("", b"{\"n\":2}").await;
        assert!(r2.is_ok(), "second send should succeed");
        let messages = transport.recv(10).await.unwrap();
        assert_eq!(messages.len(), 2, "should receive 2 messages");
        assert_eq!(messages[0].payload, b"{\"n\":1}");
        assert_eq!(messages[1].payload, b"{\"n\":2}");
        let tokens: Vec<_> = messages.iter().map(|m| m.token.clone()).collect();
        transport.commit(&tokens).await.unwrap();
        // After XACK, ">" must yield nothing new.
        let more = transport.recv(10).await.unwrap();
        assert!(more.is_empty(), "no more messages after commit");
        // Best-effort cleanup of the test stream; errors are ignored.
        let mut conn = transport.conn.lock().await;
        let _: redis::RedisResult<()> =
            redis::cmd("DEL").arg(&stream).query_async(&mut *conn).await;
        transport.close().await.unwrap();
        assert!(!transport.is_healthy());
    }
}