use async_trait::async_trait;
use bytes::Bytes;
use redis::aio::ConnectionManager;
use redis::{Client, RedisError, Value};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use crate::adapter::{Adapter, ShardPollResult};
use crate::config::RedisAdapterConfig;
use crate::error::AdapterError;
use crate::event::{Batch, InternalEvent, StoredEvent};
/// Adapter that writes event batches to per-shard Redis streams and reads
/// them back with `XRANGE`.
///
/// The adapter is created disconnected by `new`; `init` establishes the
/// connection and `shutdown` marks it unusable again.
pub struct RedisAdapter {
/// Client opened from `config.url`; `init` asks it for the connection manager.
client: Client,
/// Connection manager shared by all commands; `None` until `init` succeeds.
conn: Option<ConnectionManager>,
/// Adapter settings read by this file: `url`, `prefix`, `command_timeout`
/// and the optional `max_stream_len`.
config: RedisAdapterConfig,
/// Set to `true` by `init` and back to `false` by `shutdown`; checked by
/// `get_conn` and `is_healthy` before the connection is used.
initialized: AtomicBool,
/// Cache of formatted `"{prefix}:shard:{id}"` stream keys, keyed by shard id.
stream_keys: parking_lot::RwLock<HashMap<u16, Arc<str>>>,
}
impl RedisAdapter {
/// Builds a not-yet-initialized adapter from `config`.
///
/// The connection manager is obtained later by `init`; until then `conn` is
/// `None`, the `initialized` flag is `false` and the stream-key cache is empty.
///
/// # Errors
/// Returns `AdapterError::Connection` when `Client::open` rejects `config.url`.
pub fn new(config: RedisAdapterConfig) -> Result<Self, AdapterError> {
    let client = match Client::open(config.url.as_str()) {
        Ok(client) => client,
        Err(err) => return Err(AdapterError::Connection(err.to_string())),
    };
    let adapter = Self {
        client,
        conn: None,
        config,
        initialized: AtomicBool::new(false),
        stream_keys: parking_lot::RwLock::new(HashMap::new()),
    };
    Ok(adapter)
}
/// Returns the Redis stream key for `shard_id`, formatted as
/// `"{prefix}:shard:{shard_id}"`.
///
/// Keys are cached per shard so repeated calls hand out a clone of the same
/// `Arc<str>` instead of re-formatting the string.
#[inline]
fn stream_key(&self, shard_id: u16) -> Arc<str> {
    // Fast path under the shared lock. The read guard is a temporary of this
    // statement, so it is released before the write lock is requested below.
    let cached = self.stream_keys.read().get(&shard_id).cloned();
    if let Some(key) = cached {
        return key;
    }
    // Slow path: `entry` re-checks under the exclusive lock, so a key inserted
    // by a racing caller is reused rather than overwritten.
    let mut keys = self.stream_keys.write();
    let slot = keys.entry(shard_id).or_insert_with(|| {
        let formatted = format!("{}:shard:{}", self.config.prefix, shard_id);
        Arc::from(formatted.as_str())
    });
    slot.clone()
}
/// Encodes `event` as the stored envelope
/// `{"r":<raw>,"t":<insertion_ts>,"s":<shard_id>}`.
///
/// `event.raw` is spliced into the buffer verbatim, without re-encoding or
/// validation (presumably it already holds valid JSON; confirm against the
/// producer of `InternalEvent`).
///
/// # Errors
/// Returns `AdapterError::Fatal` if formatting into the buffer fails.
fn serialize_event(event: &InternalEvent) -> Result<Vec<u8>, AdapterError> {
    use std::io::Write as _;
    let mut out = Vec::with_capacity(event.raw.len() + 32);
    out.extend_from_slice(b"{\"r\":");
    out.extend_from_slice(&event.raw);
    // Timestamp, shard id and the closing brace are emitted in one formatted write.
    write!(
        &mut out,
        ",\"t\":{},\"s\":{}}}",
        event.insertion_ts, event.shard_id
    )
    .map_err(|e| AdapterError::Fatal(format!("serialize_event: {e}")))?;
    Ok(out)
}
/// Decodes a stored envelope back into a `StoredEvent` carrying the stream
/// entry `id`.
///
/// The `r` member is kept as raw JSON text and copied into a `Bytes` without
/// being re-parsed. `t` and `s` fall back to `0` when absent.
///
/// # Errors
/// Propagates the `serde_json` error (converted into `AdapterError` by `?`)
/// when `data` is not a valid envelope.
fn deserialize_event(id: &str, data: &[u8]) -> Result<StoredEvent, AdapterError> {
    #[derive(serde::Deserialize)]
    struct StoredFormat<'a> {
        #[serde(borrow)]
        r: &'a serde_json::value::RawValue,
        #[serde(default)]
        t: u64,
        #[serde(default)]
        s: u16,
    }
    let StoredFormat { r, t, s } = serde_json::from_slice::<StoredFormat<'_>>(data)?;
    let payload = Bytes::copy_from_slice(r.get().as_bytes());
    Ok(StoredEvent::new(id.to_string(), payload, t, s))
}
/// Hands out a clone of the shared connection manager.
///
/// # Errors
/// * `AdapterError::Shutdown` when the `initialized` flag is not set.
/// * `AdapterError::Transient` when the flag is set but no connection is stored.
async fn get_conn(&self) -> Result<ConnectionManager, AdapterError> {
    let ready = self.initialized.load(Ordering::Acquire);
    if !ready {
        return Err(AdapterError::Shutdown);
    }
    match &self.conn {
        Some(manager) => Ok(manager.clone()),
        None => Err(AdapterError::Transient(
            "redis: re-init race; conn not yet set".into(),
        )),
    }
}
/// Converts a raw `XRANGE` reply into a `ShardPollResult`.
///
/// * At most `limit` entries are examined. `has_more` is `true` when the reply
///   contains more than `limit` entries (`poll_shard` requests `limit + 1` so
///   that the extra entry acts as the "more data" probe).
/// * Each entry is expected to be `[id, [field, value, ...]]`. The `d` field
///   carries the serialized envelope and the optional `dedup_id` field is
///   attached to the resulting event. Field order does not matter.
/// * Entries whose payload fails to deserialize are logged and skipped, but
///   they still advance `next_id`, so a corrupt entry is never polled twice.
/// * Entries that are structurally malformed before an id can be read are
///   skipped without moving the cursor.
///
/// A reply that is not an array yields `ShardPollResult::empty()`.
fn parse_xrange_response(results: Value, limit: usize, stream_key: &str) -> ShardPollResult {
    use std::borrow::Cow;

    /// Textual view of a bulk or simple string reply; `None` for other reply types.
    fn as_text(value: &Value) -> Option<Cow<'_, str>> {
        match value {
            Value::BulkString(bytes) => Some(String::from_utf8_lossy(bytes)),
            Value::SimpleString(s) => Some(Cow::Borrowed(s.as_str())),
            _ => None,
        }
    }

    /// Raw bytes of a bulk or simple string reply, used to compare field names.
    fn as_bytes(value: &Value) -> Option<&[u8]> {
        match value {
            Value::BulkString(bytes) => Some(bytes.as_slice()),
            Value::SimpleString(s) => Some(s.as_bytes()),
            _ => None,
        }
    }

    let Value::Array(entries) = results else {
        return ShardPollResult::empty();
    };

    // `limit` comes from the caller and may be far larger than the reply (or
    // even `usize::MAX`); never reserve more slots than entries we can yield.
    let mut events = Vec::with_capacity(limit.min(entries.len()));
    // Id part of the last entry whose id could be read, successful or not.
    let mut last_seen: Option<&Value> = None;

    for entry in entries.iter().take(limit) {
        let Value::Array(parts) = entry else { continue };
        let [id_part, fields_part, ..] = parts.as_slice() else {
            continue;
        };
        let Some(id) = as_text(id_part) else { continue };
        // Record the cursor before looking at the payload so that corrupt
        // entries are stepped over on the next poll.
        last_seen = Some(id_part);
        let Value::Array(fields) = fields_part else {
            continue;
        };

        let mut data_bytes: Option<&[u8]> = None;
        let mut dedup_id: Option<String> = None;
        // Fields arrive as a flat name/value list; a trailing unpaired name is ignored.
        for pair in fields.chunks_exact(2) {
            let (name, value) = (&pair[0], &pair[1]);
            match as_bytes(name) {
                Some(b"d") => {
                    // The payload is only accepted as a bulk string.
                    if let Value::BulkString(data) = value {
                        data_bytes = Some(data.as_slice());
                    }
                }
                Some(b"dedup_id") => {
                    if let Some(text) = as_text(value) {
                        dedup_id = Some(text.into_owned());
                    }
                }
                _ => {}
            }
        }

        if let Some(data) = data_bytes {
            match Self::deserialize_event(&id, data) {
                Ok(event) => events.push(event.with_dedup_id(dedup_id)),
                Err(e) => {
                    tracing::warn!(
                        stream = %stream_key,
                        id = %id,
                        error = %e,
                        "Failed to deserialize event, skipping"
                    );
                }
            }
        }
    }

    let has_more = entries.len() > limit;
    // An event is only pushed after `last_seen` was set for its entry, so the
    // cursor always covers every returned event and needs no separate fallback.
    let next_id = last_seen.and_then(as_text).map(Cow::into_owned);
    ShardPollResult {
        events,
        next_id,
        has_more,
    }
}
}
/// Debug output shows the redacted URL, the key prefix and the initialization
/// flag. The client, connection and key cache are left out.
impl std::fmt::Debug for RedisAdapter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The URL goes through `redact_url` before it is printed.
        let redacted_url = crate::adapter::redact_url(&self.config.url);
        let is_initialized = self.initialized.load(Ordering::Relaxed);
        let mut out = f.debug_struct("RedisAdapter");
        out.field("url", &redacted_url);
        out.field("prefix", &self.config.prefix);
        out.field("initialized", &is_initialized);
        out.finish()
    }
}
#[async_trait]
impl Adapter for RedisAdapter {
async fn init(&mut self) -> Result<(), AdapterError> {
if self.initialized.load(Ordering::Acquire) {
tracing::warn!(
adapter = "redis",
"Redis adapter::init called twice; ignoring"
);
return Ok(());
}
let conn = self
.client
.get_connection_manager()
.await
.map_err(|e| AdapterError::Connection(e.to_string()))?;
let mut test_conn = conn.clone();
redis::cmd("PING")
.query_async::<String>(&mut test_conn)
.await
.map_err(|e| AdapterError::Connection(e.to_string()))?;
self.conn = Some(conn);
self.initialized.store(true, Ordering::Release);
tracing::info!(
adapter = "redis",
url = %crate::adapter::redact_url(&self.config.url),
prefix = %self.config.prefix,
"Redis adapter initialized"
);
static DEDUP_WARN_FIRED: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);
if !DEDUP_WARN_FIRED.swap(true, Ordering::AcqRel) {
tracing::warn!(
adapter = "redis",
helper = "net_sdk::RedisStreamDedup",
"Redis adapter emits a `dedup_id` field on every XADD that is \
stable across producer retries. Consumers MUST filter on this \
field (e.g. via `net_sdk::RedisStreamDedup`) or accept \
at-least-once delivery with retry duplicates. See the \
`adapter::redis` module docs for the full contract."
);
}
Ok(())
}
/// Appends every event of `batch` to the shard's stream in one atomic pipeline.
///
/// Each event becomes one `XADD <key> [MAXLEN ~ n] * d <payload> dedup_id <id>`
/// command. The dedup id is
/// `"{process_nonce:x}:{shard_id}:{sequence_start}:{index}"`, where `index` is
/// the event's position inside the batch. An empty batch is a no-op.
///
/// # Errors
/// * Whatever `get_conn` or `serialize_event` return.
/// * `AdapterError::Transient` on command timeout or when
///   `is_transient_error` classifies the Redis error as retryable.
/// * `AdapterError::Fatal` for every other Redis error.
async fn on_batch(&self, batch: std::sync::Arc<Batch>) -> Result<(), AdapterError> {
    use std::fmt::Write as _;

    if batch.is_empty() {
        return Ok(());
    }
    let mut conn = self.get_conn().await?;
    let stream_key = self.stream_key(batch.shard_id);

    // Serialize everything up front so a bad event aborts before any command is queued.
    let mut payloads: Vec<Vec<u8>> = Vec::with_capacity(batch.events.len());
    for event in batch.events.iter() {
        payloads.push(Self::serialize_event(event)?);
    }

    // Shared dedup-id prefix; only the per-event suffix is rewritten in the loop.
    let mut dedup_id = format!(
        "{:x}:{}:{}",
        batch.process_nonce, batch.shard_id, batch.sequence_start
    );
    let base_len = dedup_id.len();

    let mut pipe = redis::pipe();
    pipe.atomic();
    for (index, payload) in payloads.iter().enumerate() {
        dedup_id.truncate(base_len);
        let _ = write!(dedup_id, ":{index}");

        let mut xadd = redis::cmd("XADD");
        xadd.arg(&*stream_key);
        if let Some(cap) = self.config.max_stream_len {
            xadd.arg("MAXLEN").arg("~").arg(cap);
        }
        xadd.arg("*")
            .arg("d")
            .arg(payload.as_slice())
            .arg("dedup_id")
            .arg(dedup_id.as_str());
        pipe.add_command(xadd);
    }

    let outcome = tokio::time::timeout(
        self.config.command_timeout,
        pipe.query_async::<()>(&mut conn),
    )
    .await;
    match outcome {
        Err(_elapsed) => {
            return Err(AdapterError::Transient("Redis command timeout".into()));
        }
        Ok(Err(e)) => {
            let message = e.to_string();
            return Err(if is_transient_error(&e) {
                AdapterError::Transient(message)
            } else {
                AdapterError::Fatal(message)
            });
        }
        Ok(Ok(())) => {}
    }

    tracing::trace!(
        shard_id = batch.shard_id,
        event_count = batch.events.len(),
        "Batch written to Redis"
    );
    Ok(())
}
/// No-op that always returns `Ok(())`; `on_batch` already awaits the result of
/// its pipeline, so this adapter keeps nothing buffered to flush.
async fn flush(&self) -> Result<(), AdapterError> {
Ok(())
}
/// Marks the adapter as shut down.
///
/// Only the `initialized` flag is cleared, which makes `get_conn` return
/// `AdapterError::Shutdown` and `is_healthy` return `false`. The stored
/// connection manager itself is left in place. Always returns `Ok(())`.
async fn shutdown(&self) -> Result<(), AdapterError> {
self.initialized.store(false, Ordering::Release);
tracing::info!(adapter = "redis", "Redis adapter shut down");
Ok(())
}
async fn poll_shard(
&self,
shard_id: u16,
from_id: Option<&str>,
limit: usize,
) -> Result<ShardPollResult, AdapterError> {
let mut conn = self.get_conn().await?;
let stream_key = self.stream_key(shard_id);
let start = from_id
.map(|id| format!("({}", id)) .unwrap_or_else(|| "-".to_string());
let fetch_limit = limit.saturating_add(1);
let mut cmd = redis::cmd("XRANGE");
cmd.arg(&*stream_key)
.arg(&start)
.arg("+") .arg("COUNT")
.arg(fetch_limit);
let fut = cmd.query_async::<Value>(&mut conn);
let results = tokio::time::timeout(self.config.command_timeout, fut)
.await
.map_err(|_| AdapterError::Transient("Redis XRANGE timeout".into()))?
.map_err(|e| AdapterError::Transient(e.to_string()))?;
Ok(Self::parse_xrange_response(results, limit, &stream_key))
}
/// Static adapter identifier: always `"redis"`, the same value this file uses
/// for the `adapter` field of its log records.
fn name(&self) -> &'static str {
"redis"
}
/// Liveness probe: `true` only when the adapter is initialized, a connection
/// is stored, and a `PING` succeeds within `config.command_timeout`.
async fn is_healthy(&self) -> bool {
    let ready = self.initialized.load(Ordering::Acquire);
    let mut conn = match (ready, self.conn.as_ref()) {
        (true, Some(manager)) => manager.clone(),
        _ => return false,
    };
    let ping = redis::cmd("PING");
    let reply = tokio::time::timeout(
        self.config.command_timeout,
        ping.query_async::<String>(&mut conn),
    )
    .await;
    // A timeout (outer `Err`) and a Redis error (inner `Err`) both count as unhealthy.
    matches!(reply, Ok(Ok(_)))
}
}
/// Decides whether a Redis error is worth retrying.
///
/// Returns `true` for:
/// * I/O errors and missing cluster connections,
/// * the typed server kinds listed in the match below (loading, redirects,
///   cluster or master down, read-only replica, try-again),
/// * any other server or extension error whose error code is one of
///   `TRANSIENT_CODES`.
///
/// The code is compared as a whole token, not as a string prefix, so
/// permanent errors that merely share a prefix (for example `BUSYGROUP`
/// versus `BUSY`) are not misclassified. It is looked up both in
/// `RedisError::code()` and in the first word of `RedisError::detail()`,
/// because errors constructed by hand keep the code inside the detail text.
/// NOTE(review): `code()` is assumed to be where redis-rs reports the code of
/// a server reply; confirm against the crate version in use.
///
/// Everything else is treated as fatal.
fn is_transient_error(e: &RedisError) -> bool {
    use redis::{ErrorKind, ServerErrorKind};

    /// Error codes of server replies that indicate a temporary condition.
    const TRANSIENT_CODES: &[&str] = &[
        "LOADING",
        "BUSY",
        "TRYAGAIN",
        "MASTERDOWN",
        "MOVED",
        "ASK",
        "READONLY",
        "CLUSTERDOWN",
        "NOREPLICAS",
    ];

    match e.kind() {
        ErrorKind::Io => true,
        ErrorKind::ClusterConnectionNotFound => true,
        ErrorKind::Server(ServerErrorKind::BusyLoading)
        | ErrorKind::Server(ServerErrorKind::Moved)
        | ErrorKind::Server(ServerErrorKind::Ask)
        | ErrorKind::Server(ServerErrorKind::TryAgain)
        | ErrorKind::Server(ServerErrorKind::ClusterDown)
        | ErrorKind::Server(ServerErrorKind::MasterDown)
        | ErrorKind::Server(ServerErrorKind::ReadOnly) => true,
        ErrorKind::Server(_) | ErrorKind::Extension => {
            // First whitespace-delimited word of the detail, if any.
            let detail_code = e.detail().and_then(|d| d.split_whitespace().next());
            e.code()
                .into_iter()
                .chain(detail_code)
                .any(|code| TRANSIENT_CODES.contains(&code))
        }
        _ => false,
    }
}
// Unit tests for the pure parts of the adapter: envelope (de)serialization,
// stream-key caching, XRANGE reply parsing and error classification. None of
// them talks to a Redis server.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// The envelope produced by `serialize_event` must be valid JSON carrying the
// raw payload under `r`, the timestamp under `t` and the shard id under `s`.
#[test]
fn test_serialize_event() {
let event =
InternalEvent::from_value(json!({"token": "hello", "index": 42}), 1702123456789, 3);
let buffer = RedisAdapter::serialize_event(&event).unwrap();
let parsed: serde_json::Value = serde_json::from_slice(&buffer).unwrap();
assert_eq!(parsed["t"], 1702123456789u64);
assert_eq!(parsed["s"], 3);
assert_eq!(parsed["r"]["token"], "hello");
assert_eq!(parsed["r"]["index"], 42);
}
// Decoding an envelope restores id, timestamp, shard id and the raw payload.
#[test]
fn test_deserialize_event() {
let data = br#"{"r":{"token":"world"},"t":9999,"s":7}"#;
let event = RedisAdapter::deserialize_event("1702123456789-0", data).unwrap();
assert_eq!(event.id, "1702123456789-0");
assert_eq!(event.insertion_ts, 9999);
assert_eq!(event.shard_id, 7);
let raw: serde_json::Value = serde_json::from_slice(&event.raw).unwrap();
assert_eq!(raw["token"], "world");
}
// Keys follow `"{prefix}:shard:{id}"`; the repeated lookup of shard 0 must
// return the same text.
#[test]
fn test_stream_key() {
let config = RedisAdapterConfig::new("redis://localhost:6379").with_prefix("myapp");
let adapter = RedisAdapter::new(config).unwrap();
assert_eq!(&*adapter.stream_key(0), "myapp:shard:0");
assert_eq!(&*adapter.stream_key(15), "myapp:shard:15");
assert_eq!(&*adapter.stream_key(0), "myapp:shard:0");
}
// The cache is keyed by shard id, so widely spaced ids only create the
// entries that were actually requested.
#[test]
fn test_stream_key_sparse_shard_ids() {
let config = RedisAdapterConfig::new("redis://localhost:6379").with_prefix("myapp");
let adapter = RedisAdapter::new(config).unwrap();
assert_eq!(&*adapter.stream_key(65535), "myapp:shard:65535");
assert_eq!(&*adapter.stream_key(7), "myapp:shard:7");
assert_eq!(adapter.stream_keys.read().len(), 2);
}
// Builds one XRANGE entry `[id, ["d", payload]]` the way the server replies.
fn xrange_entry(id: &str, payload: &[u8]) -> Value {
Value::Array(vec![
Value::BulkString(id.as_bytes().to_vec()),
Value::Array(vec![
Value::BulkString(b"d".to_vec()),
Value::BulkString(payload.to_vec()),
]),
])
}
// When every entry is corrupt, no events are returned but the cursor must
// still move to the last entry id so the poller does not re-read them.
#[test]
fn test_poll_shard_advances_cursor_on_all_corrupt_entries() {
let response = Value::Array(vec![
xrange_entry("1-0", b"not json"),
xrange_entry("2-0", b"{also not"),
xrange_entry("3-0", b"][broken"),
]);
let result = RedisAdapter::parse_xrange_response(response, 10, "myapp:shard:0");
assert!(
result.events.is_empty(),
"all corrupt entries should be skipped"
);
assert_eq!(
result.next_id.as_deref(),
Some("3-0"),
"next_id must advance to the last raw entry id, not None"
);
}
// Corrupt entries after a good one must not pin the cursor to the good entry.
#[test]
fn test_poll_shard_advances_past_trailing_corrupt_entries() {
let good = br#"{"r":{"k":"v"},"t":1,"s":0}"#;
let response = Value::Array(vec![
xrange_entry("1-0", good),
xrange_entry("2-0", b"corrupt"),
xrange_entry("3-0", b"also corrupt"),
]);
let result = RedisAdapter::parse_xrange_response(response, 10, "myapp:shard:0");
assert_eq!(result.events.len(), 1);
assert_eq!(result.events[0].id, "1-0");
assert_eq!(result.next_id.as_deref(), Some("3-0"));
}
// An empty reply produces no events, no cursor and no "more data" hint.
#[test]
fn test_poll_shard_empty_response_has_no_cursor() {
let result = RedisAdapter::parse_xrange_response(Value::Array(vec![]), 10, "myapp:shard:0");
assert!(result.events.is_empty());
assert!(result.next_id.is_none());
assert!(!result.has_more);
}
// Like `xrange_entry`, with an additional `dedup_id` field after the payload.
fn xrange_entry_with_dedup(id: &str, payload: &[u8], dedup_id: &str) -> Value {
Value::Array(vec![
Value::BulkString(id.as_bytes().to_vec()),
Value::Array(vec![
Value::BulkString(b"d".to_vec()),
Value::BulkString(payload.to_vec()),
Value::BulkString(b"dedup_id".to_vec()),
Value::BulkString(dedup_id.as_bytes().to_vec()),
]),
])
}
// The `dedup_id` wire field is surfaced on the event; entries without the
// field yield `None`.
#[test]
fn parse_xrange_response_surfaces_dedup_id() {
let good = br#"{"r":{"k":"v"},"t":1,"s":0}"#;
let response = Value::Array(vec![
xrange_entry_with_dedup("1-0", good, "nonce123:0:42:0"),
xrange_entry("2-0", good),
]);
let result = RedisAdapter::parse_xrange_response(response, 10, "myapp:shard:0");
assert_eq!(result.events.len(), 2);
assert_eq!(
result.events[0].dedup_id.as_deref(),
Some("nonce123:0:42:0"),
"dedup_id must round-trip from wire field through StoredEvent"
);
assert!(
result.events[1].dedup_id.is_none(),
"missing dedup_id field must surface as None, not panic or fabricate"
);
}
// Field order inside an entry must not matter: `dedup_id` before `d` works too.
#[test]
fn parse_xrange_response_dedup_id_order_independent() {
let good = br#"{"r":{"k":"v"},"t":1,"s":0}"#;
let entry = Value::Array(vec![
Value::BulkString(b"5-0".to_vec()),
Value::Array(vec![
Value::BulkString(b"dedup_id".to_vec()),
Value::BulkString(b"abc:0:1:0".to_vec()),
Value::BulkString(b"d".to_vec()),
Value::BulkString(good.to_vec()),
]),
]);
let result =
RedisAdapter::parse_xrange_response(Value::Array(vec![entry]), 10, "myapp:shard:0");
assert_eq!(result.events.len(), 1);
assert_eq!(result.events[0].dedup_id.as_deref(), Some("abc:0:1:0"));
}
// Classification table for `is_transient_error`: typed recoverable kinds and
// the NOREPLICAS extension error are transient, the listed kinds are fatal.
#[test]
fn is_transient_error_recognizes_cluster_recoverables() {
use redis::{ErrorKind, ServerErrorKind};
// Kinds that are classified by their typed `ErrorKind` alone.
let typed_transient: &[(ErrorKind, &str)] = &[
(ErrorKind::Server(ServerErrorKind::Moved), "MOVED redirect"),
(ErrorKind::Server(ServerErrorKind::Ask), "ASK redirect"),
(
ErrorKind::Server(ServerErrorKind::ClusterDown),
"cluster down",
),
(
ErrorKind::Server(ServerErrorKind::MasterDown),
"master down",
),
(
ErrorKind::Server(ServerErrorKind::ReadOnly),
"read-only replica",
),
(ErrorKind::Server(ServerErrorKind::BusyLoading), "loading"),
(ErrorKind::Server(ServerErrorKind::TryAgain), "try again"),
(ErrorKind::Io, "I/O error"),
(
ErrorKind::ClusterConnectionNotFound,
"no cluster connection",
),
];
for (kind, label) in typed_transient {
let err = RedisError::from((*kind, "test"));
assert!(
is_transient_error(&err),
"{} ({:?}) must classify as transient",
label,
kind,
);
}
// Extension errors are recognized by the error code at the start of the detail text.
let extension_transient: &[&str] = &["NOREPLICAS Not enough good replicas to write"];
for msg in extension_transient {
let err = RedisError::from((ErrorKind::Extension, "test", msg.to_string()));
assert!(
is_transient_error(&err),
"extension `{}` must classify as transient",
msg,
);
}
// Kinds that must never be retried.
let fatal: &[ErrorKind] = &[
ErrorKind::AuthenticationFailed,
ErrorKind::UnexpectedReturnType,
ErrorKind::InvalidClientConfig,
ErrorKind::Client,
ErrorKind::Server(ServerErrorKind::ExecAbort),
ErrorKind::Server(ServerErrorKind::NoScript),
ErrorKind::Server(ServerErrorKind::CrossSlot),
ErrorKind::Server(ServerErrorKind::NoPerm),
];
for kind in fatal {
let err = RedisError::from((*kind, "test"));
assert!(
!is_transient_error(&err),
"{:?} must classify as fatal (non-transient)",
kind,
);
}
}
}