use crate::domain::error::{CacheError, Result, StygianError};
use crate::ports::work_queue::{TaskStatus, WorkQueuePort, WorkTask};
use async_trait::async_trait;
use deadpool_redis::{Config as PoolConfig, Pool, Runtime};
use redis::AsyncCommands;
use tracing::{debug, error, info, warn};
/// Configuration for [`RedisWorkQueue`]: connection, stream identity, and
/// retry/timeout tuning.
#[derive(Debug, Clone)]
pub struct RedisWorkQueueConfig {
    /// Redis connection URL, e.g. `redis://127.0.0.1:6379`.
    pub url: String,
    /// Name of the stream tasks are published to; also the prefix for
    /// the per-task meta, result, and DLQ keys.
    pub stream_name: String,
    /// Consumer-group name shared by all workers on this stream.
    pub group_name: String,
    /// Identity of this worker within the group (host + pid by default).
    pub consumer_name: String,
    /// Maximum number of pooled connections.
    pub pool_size: usize,
    /// Attempts allowed before a failed task is dead-lettered.
    pub max_retries: u32,
    /// BLOCK timeout (milliseconds) passed to XREADGROUP in `try_dequeue`.
    pub block_timeout_ms: usize,
    /// Pending-entry idle time (milliseconds) after which a task is
    /// considered stuck and eligible for reclaiming via XCLAIM.
    pub idle_threshold_ms: usize,
}
impl Default for RedisWorkQueueConfig {
fn default() -> Self {
let host = std::env::var("HOSTNAME").unwrap_or_else(|_| "local".to_string());
let consumer_name = format!("{}:{}", host, std::process::id());
Self {
url: "redis://127.0.0.1:6379".into(),
stream_name: "stygian:tasks".into(),
group_name: "stygian-workers".into(),
consumer_name,
pool_size: 8,
max_retries: 3,
block_timeout_ms: 1000,
idle_threshold_ms: 30_000,
}
}
}
/// Redis-Streams-backed implementation of [`WorkQueuePort`].
///
/// Tasks are published to a stream and consumed through a consumer group;
/// per-task metadata and results live in plain string keys derived from
/// the stream name (see `task_meta_key` / `result_key`).
pub struct RedisWorkQueue {
    // Shared deadpool connection pool; each method checks out one connection.
    pool: Pool,
    // Stream/group/consumer names plus timeout and retry tuning.
    config: RedisWorkQueueConfig,
}
impl RedisWorkQueue {
    /// Connects to Redis, sizes a deadpool connection pool from `config`,
    /// and ensures the consumer group exists before returning.
    ///
    /// # Errors
    /// Returns `CacheError::WriteFailed` if the pool cannot be built or
    /// the consumer group cannot be created.
    pub async fn new(config: RedisWorkQueueConfig) -> Result<Self> {
        let pool_cfg = PoolConfig::from_url(&config.url);
        let pool = pool_cfg
            .builder()
            // builder() is fallible; set the pool size first, then map the
            // builder error into our cache-error type.
            .map(|b| b.max_size(config.pool_size))
            .map_err(|e| {
                StygianError::Cache(CacheError::WriteFailed(format!(
                    "failed to build Redis pool: {e}"
                )))
            })?
            .runtime(Runtime::Tokio1)
            .build()
            .map_err(|e| {
                StygianError::Cache(CacheError::WriteFailed(format!(
                    "failed to build Redis pool: {e}"
                )))
            })?;
        let queue = Self { pool, config };
        queue.ensure_consumer_group().await?;
        Ok(queue)
    }

    /// Wraps an externally constructed pool (e.g. one shared with other
    /// adapters) and ensures the consumer group exists.
    pub async fn from_pool(pool: Pool, config: RedisWorkQueueConfig) -> Result<Self> {
        let queue = Self { pool, config };
        queue.ensure_consumer_group().await?;
        Ok(queue)
    }

    /// Idempotently creates the consumer group on the stream.
    ///
    /// Uses `XGROUP CREATE … $ MKSTREAM`: the stream is created if absent,
    /// and the group's start id `$` means it only sees entries added
    /// *after* group creation. A `BUSYGROUP` reply (group already exists)
    /// is treated as success; any other error is surfaced.
    async fn ensure_consumer_group(&self) -> Result<()> {
        let mut conn = self.pool.get().await.map_err(|e| {
            StygianError::Cache(CacheError::WriteFailed(format!("Redis pool error: {e}")))
        })?;
        let result: redis::RedisResult<String> = redis::cmd("XGROUP")
            .arg("CREATE")
            .arg(&self.config.stream_name)
            .arg(&self.config.group_name)
            .arg("$")
            .arg("MKSTREAM")
            .query_async(&mut *conn)
            .await;
        match result {
            Ok(_) => {
                debug!(
                    stream = %self.config.stream_name,
                    group = %self.config.group_name,
                    "created consumer group"
                );
            }
            // BUSYGROUP is Redis's "group already exists" error — benign.
            Err(e) if e.to_string().contains("BUSYGROUP") => {
                debug!(
                    stream = %self.config.stream_name,
                    group = %self.config.group_name,
                    "consumer group already exists"
                );
            }
            Err(e) => {
                return Err(StygianError::Cache(CacheError::WriteFailed(format!(
                    "XGROUP CREATE failed: {e}"
                ))));
            }
        }
        Ok(())
    }

    // Key for per-task JSON metadata: "<stream>:tasks:<task_id>".
    fn task_meta_key(&self, task_id: &str) -> String {
        format!("{}:tasks:{}", self.config.stream_name, task_id)
    }

    // Key for a completed task's JSON output: "<stream>:results:<task_id>".
    fn result_key(&self, task_id: &str) -> String {
        format!("{}:results:{}", self.config.stream_name, task_id)
    }

    // Dead-letter stream name: "<stream>:dlq".
    fn dlq_stream(&self) -> String {
        format!("{}:dlq", self.config.stream_name)
    }

    /// Scans the group's pending-entries list and claims, for this
    /// consumer, any message idle longer than `idle_threshold_ms`.
    ///
    /// Inspects at most 100 pending entries per call. Each XPENDING entry
    /// is assumed to follow the extended-form reply layout
    /// `[message_id, consumer, idle_ms, delivery_count]` — TODO confirm
    /// against the redis-rs value decoding in use. XCLAIM is called with
    /// the same threshold as its min-idle-time, so a message grabbed by
    /// another worker in the meantime is simply skipped (XCLAIM returns
    /// nothing for it). Claim errors are ignored; successfully claimed
    /// messages are parsed back into [`WorkTask`]s and returned.
    ///
    /// # Errors
    /// Returns `CacheError::ReadFailed` if the pool or XPENDING fails.
    pub async fn reclaim_stuck_tasks(&self) -> Result<Vec<WorkTask>> {
        let mut conn = self.pool.get().await.map_err(|e| {
            StygianError::Cache(CacheError::ReadFailed(format!("Redis pool error: {e}")))
        })?;
        let pending: Vec<Vec<redis::Value>> = redis::cmd("XPENDING")
            .arg(&self.config.stream_name)
            .arg(&self.config.group_name)
            .arg("-")
            .arg("+")
            .arg(100_i64)
            .query_async(&mut *conn)
            .await
            .map_err(|e| {
                StygianError::Cache(CacheError::ReadFailed(format!("XPENDING failed: {e}")))
            })?;
        let mut reclaimed = Vec::new();
        for entry in &pending {
            // Need at least [id, consumer, idle]; skip malformed entries.
            if entry.len() < 3 {
                continue;
            }
            let Some(raw_msg_id) = entry.first() else {
                continue;
            };
            let redis::Value::BulkString(b) = raw_msg_id else {
                continue;
            };
            let msg_id = String::from_utf8_lossy(b.as_slice()).to_string();
            // Third element is the idle time in milliseconds.
            let idle_ms: usize = match entry.get(2) {
                Some(redis::Value::Int(n)) => match usize::try_from(*n) {
                    Ok(v) => v,
                    Err(_) => continue,
                },
                _ => continue,
            };
            if idle_ms < self.config.idle_threshold_ms {
                continue;
            }
            let claimed: redis::RedisResult<Vec<redis::Value>> = redis::cmd("XCLAIM")
                .arg(&self.config.stream_name)
                .arg(&self.config.group_name)
                .arg(&self.config.consumer_name)
                // min-idle-time guard: no-op if someone else claimed it first.
                .arg(self.config.idle_threshold_ms)
                .arg(&msg_id)
                .query_async(&mut *conn)
                .await;
            // Best-effort: a failed claim just leaves the entry for next time.
            if let Ok(messages) = claimed {
                for msg in &messages {
                    if let Some(task) = Self::parse_stream_message(msg) {
                        info!(task_id = %task.id, idle_ms, "reclaimed stuck task");
                        reclaimed.push(task);
                    }
                }
            }
        }
        Ok(reclaimed)
    }

    /// Decodes one stream entry (`[message_id, [field, value, …]]`) into a
    /// [`WorkTask`].
    ///
    /// Walks the flat field/value list looking for the `payload` field
    /// written by `enqueue` and JSON-decodes it. Returns `None` for any
    /// shape mismatch or deserialisation failure (e.g. a trimmed/deleted
    /// entry surfaced by XCLAIM as a nil value).
    fn parse_stream_message(msg: &redis::Value) -> Option<WorkTask> {
        let redis::Value::Array(arr) = msg else {
            return None;
        };
        if arr.len() < 2 {
            return None;
        }
        let Some(redis::Value::Array(fields)) = arr.get(1) else {
            return None;
        };
        let mut payload: Option<&[u8]> = None;
        let mut i = 0;
        // Fields arrive as alternating key/value bulk strings.
        while i + 1 < fields.len() {
            if let Some(redis::Value::BulkString(key)) = fields.get(i)
                && key == b"payload"
                && let Some(redis::Value::BulkString(val)) = fields.get(i + 1)
            {
                payload = Some(val);
            }
            i += 2;
        }
        let payload = payload?;
        serde_json::from_slice(payload).ok()
    }
}
#[async_trait]
impl WorkQueuePort for RedisWorkQueue {
async fn enqueue(&self, task: WorkTask) -> Result<()> {
let mut conn = self.pool.get().await.map_err(|e| {
StygianError::Cache(CacheError::WriteFailed(format!("Redis pool error: {e}")))
})?;
let payload = serde_json::to_string(&task).map_err(|e| {
StygianError::Cache(CacheError::WriteFailed(format!(
"task serialisation failed: {e}"
)))
})?;
let _msg_id: String = redis::cmd("XADD")
.arg(&self.config.stream_name)
.arg("*")
.arg("payload")
.arg(&payload)
.query_async(&mut *conn)
.await
.map_err(|e| {
StygianError::Cache(CacheError::WriteFailed(format!("XADD failed: {e}")))
})?;
let meta_key = self.task_meta_key(&task.id);
let meta = serde_json::json!({
"pipeline_id": task.pipeline_id,
"node_name": task.node_name,
"attempt": task.attempt,
"status": "pending",
});
conn.set::<_, _, ()>(&meta_key, meta.to_string())
.await
.map_err(|e| {
StygianError::Cache(CacheError::WriteFailed(format!(
"SET task meta failed: {e}"
)))
})?;
debug!(task_id = %task.id, node = %task.node_name, "enqueued task to Redis stream");
Ok(())
}
async fn try_dequeue(&self) -> Result<Option<WorkTask>> {
let mut conn = self.pool.get().await.map_err(|e| {
StygianError::Cache(CacheError::ReadFailed(format!("Redis pool error: {e}")))
})?;
let result: redis::RedisResult<redis::Value> = redis::cmd("XREADGROUP")
.arg("GROUP")
.arg(&self.config.group_name)
.arg(&self.config.consumer_name)
.arg("COUNT")
.arg(1_i64)
.arg("BLOCK")
.arg(self.config.block_timeout_ms)
.arg("STREAMS")
.arg(&self.config.stream_name)
.arg(">")
.query_async(&mut *conn)
.await;
let value = match result {
Ok(v) => v,
Err(e) => {
if e.to_string().contains("nil") {
return Ok(None);
}
return Err(StygianError::Cache(CacheError::ReadFailed(format!(
"XREADGROUP failed: {e}"
))));
}
};
let streams = match &value {
redis::Value::Array(s) if !s.is_empty() => s,
_ => return Ok(None),
};
let stream_data = match streams.first() {
Some(redis::Value::Array(s)) if s.len() >= 2 => s,
_ => return Ok(None),
};
let messages = match stream_data.get(1) {
Some(redis::Value::Array(m)) if !m.is_empty() => m,
_ => return Ok(None),
};
if let Some(first_message) = messages.first()
&& let Some(task) = Self::parse_stream_message(first_message)
{
let meta_key = self.task_meta_key(&task.id);
let meta = serde_json::json!({
"pipeline_id": task.pipeline_id,
"node_name": task.node_name,
"attempt": task.attempt,
"status": "in_progress",
"worker_id": self.config.consumer_name,
});
let _ = conn.set::<_, _, ()>(&meta_key, meta.to_string()).await;
debug!(task_id = %task.id, consumer = %self.config.consumer_name, "dequeued task");
return Ok(Some(task));
}
Ok(None)
}
async fn acknowledge(&self, task_id: &str, output: serde_json::Value) -> Result<()> {
let mut conn = self.pool.get().await.map_err(|e| {
StygianError::Cache(CacheError::WriteFailed(format!("Redis pool error: {e}")))
})?;
let result_key = self.result_key(task_id);
let output_str = output.to_string();
conn.set::<_, _, ()>(&result_key, &output_str)
.await
.map_err(|e| {
StygianError::Cache(CacheError::WriteFailed(format!("SET result failed: {e}")))
})?;
let meta_key = self.task_meta_key(task_id);
let meta_raw: Option<String> = conn.get(&meta_key).await.unwrap_or(None);
if let Some(raw) = meta_raw
&& let Ok(mut meta) = serde_json::from_str::<serde_json::Value>(&raw)
{
if let Some(obj) = meta.as_object_mut() {
obj.insert("status".to_string(), serde_json::json!("completed"));
}
let _ = conn.set::<_, _, ()>(&meta_key, meta.to_string()).await;
}
info!(task_id = %task_id, "task acknowledged (completed)");
Ok(())
}
async fn fail(&self, task_id: &str, error_msg: &str) -> Result<()> {
let mut conn = self.pool.get().await.map_err(|e| {
StygianError::Cache(CacheError::WriteFailed(format!("Redis pool error: {e}")))
})?;
let meta_key = self.task_meta_key(task_id);
let meta_raw: Option<String> = conn.get(&meta_key).await.unwrap_or(None);
let attempt = meta_raw
.as_ref()
.and_then(|raw| serde_json::from_str::<serde_json::Value>(raw).ok())
.and_then(|m| m.get("attempt").and_then(serde_json::Value::as_u64))
.and_then(|n| u32::try_from(n).ok())
.unwrap_or(0);
if attempt >= self.config.max_retries {
let dlq = self.dlq_stream();
let dlq_payload = serde_json::json!({
"task_id": task_id,
"error": error_msg,
"attempt": attempt,
});
let _: redis::RedisResult<String> = redis::cmd("XADD")
.arg(&dlq)
.arg("*")
.arg("payload")
.arg(dlq_payload.to_string())
.query_async(&mut *conn)
.await;
let meta = serde_json::json!({
"status": "dead_letter",
"error": error_msg,
"attempt": attempt,
});
let _ = conn.set::<_, _, ()>(&meta_key, meta.to_string()).await;
warn!(task_id = %task_id, %error_msg, attempt, "task dead-lettered after max retries");
} else {
let meta = serde_json::json!({
"status": "failed",
"error": error_msg,
"attempt": attempt + 1,
});
let _ = conn.set::<_, _, ()>(&meta_key, meta.to_string()).await;
error!(task_id = %task_id, attempt = attempt + 1, %error_msg, "task failed, will retry");
}
Ok(())
}
async fn status(&self, task_id: &str) -> Result<Option<TaskStatus>> {
let mut conn = self.pool.get().await.map_err(|e| {
StygianError::Cache(CacheError::ReadFailed(format!("Redis pool error: {e}")))
})?;
let meta_key = self.task_meta_key(task_id);
let meta_raw: Option<String> = conn.get(&meta_key).await.unwrap_or(None);
let Some(raw) = meta_raw else {
return Ok(None);
};
let meta: serde_json::Value = serde_json::from_str(&raw).map_err(|e| {
StygianError::Cache(CacheError::ReadFailed(format!(
"task meta deserialise failed: {e}"
)))
})?;
let status_str = meta
.get("status")
.and_then(serde_json::Value::as_str)
.unwrap_or("pending");
let status = match status_str {
"in_progress" => TaskStatus::InProgress {
worker_id: meta
.get("worker_id")
.and_then(serde_json::Value::as_str)
.unwrap_or("unknown")
.to_string(),
},
"completed" => {
let result_key = self.result_key(task_id);
let output_raw: Option<String> = conn.get(&result_key).await.unwrap_or(None);
let output = output_raw
.and_then(|r| serde_json::from_str(&r).ok())
.unwrap_or(serde_json::Value::Null);
TaskStatus::Completed { output }
}
"failed" => TaskStatus::Failed {
error: meta
.get("error")
.and_then(serde_json::Value::as_str)
.unwrap_or("")
.to_string(),
attempt: meta
.get("attempt")
.and_then(serde_json::Value::as_u64)
.and_then(|n| u32::try_from(n).ok())
.unwrap_or(0),
},
"dead_letter" => TaskStatus::DeadLetter {
error: meta
.get("error")
.and_then(serde_json::Value::as_str)
.unwrap_or("")
.to_string(),
},
_ => TaskStatus::Pending,
};
Ok(Some(status))
}
async fn collect_results(&self, pipeline_id: &str) -> Result<Vec<(String, serde_json::Value)>> {
let mut conn = self.pool.get().await.map_err(|e| {
StygianError::Cache(CacheError::ReadFailed(format!("Redis pool error: {e}")))
})?;
let pattern = format!("{}:tasks:*", self.config.stream_name);
let keys: Vec<String> = redis::cmd("KEYS")
.arg(&pattern)
.query_async(&mut *conn)
.await
.map_err(|e| {
StygianError::Cache(CacheError::ReadFailed(format!("KEYS scan failed: {e}")))
})?;
let mut results = Vec::new();
for key in &keys {
let meta_raw: Option<String> = conn.get(key).await.unwrap_or(None);
let Some(raw) = meta_raw else { continue };
let Ok(meta) = serde_json::from_str::<serde_json::Value>(&raw) else {
continue;
};
if meta.get("pipeline_id").and_then(serde_json::Value::as_str) != Some(pipeline_id) {
continue;
}
if meta.get("status").and_then(serde_json::Value::as_str) != Some("completed") {
continue;
}
let node_name = meta
.get("node_name")
.and_then(serde_json::Value::as_str)
.unwrap_or("")
.to_string();
let task_id = key.rsplit(':').next().unwrap_or("");
let result_key = self.result_key(task_id);
let output_raw: Option<String> = conn.get(&result_key).await.unwrap_or(None);
let output = output_raw
.and_then(|r| serde_json::from_str(&r).ok())
.unwrap_or(serde_json::Value::Null);
results.push((node_name, output));
}
Ok(results)
}
async fn pending_count(&self) -> Result<usize> {
let mut conn = self.pool.get().await.map_err(|e| {
StygianError::Cache(CacheError::ReadFailed(format!("Redis pool error: {e}")))
})?;
let len: usize = redis::cmd("XLEN")
.arg(&self.config.stream_name)
.query_async(&mut *conn)
.await
.map_err(|e| {
StygianError::Cache(CacheError::ReadFailed(format!("XLEN failed: {e}")))
})?;
Ok(len)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Fixture task shared by the serialisation/parsing tests.
    fn sample_task() -> WorkTask {
        WorkTask {
            id: "t-1".to_string(),
            pipeline_id: "p-1".to_string(),
            node_name: "fetch".to_string(),
            input: json!({"url": "https://example.com"}),
            wave: 0,
            attempt: 0,
            idempotency_key: "ik-1".to_string(),
        }
    }

    #[test]
    fn test_task_serialisation_roundtrip() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let original = sample_task();
        let wire = serde_json::to_string(&original)?;
        let restored: WorkTask = serde_json::from_str(&wire)?;
        // Field-by-field comparison: WorkTask may not derive PartialEq.
        assert_eq!(restored.id, original.id);
        assert_eq!(restored.pipeline_id, original.pipeline_id);
        assert_eq!(restored.node_name, original.node_name);
        assert_eq!(restored.input, original.input);
        assert_eq!(restored.wave, original.wave);
        assert_eq!(restored.attempt, original.attempt);
        assert_eq!(restored.idempotency_key, original.idempotency_key);
        Ok(())
    }

    #[test]
    fn test_default_config() {
        let defaults = RedisWorkQueueConfig::default();
        assert_eq!(defaults.url, "redis://127.0.0.1:6379");
        assert_eq!(defaults.stream_name, "stygian:tasks");
        assert_eq!(defaults.group_name, "stygian-workers");
        assert_eq!(defaults.max_retries, 3);
        assert_eq!(defaults.block_timeout_ms, 1000);
        assert_eq!(defaults.idle_threshold_ms, 30_000);
        assert!(!defaults.consumer_name.is_empty());
    }

    #[test]
    fn test_key_generation() {
        // Mirrors the layouts used by task_meta_key / result_key / dlq_stream.
        let (stream, id) = ("stygian:tasks", "abc-123");
        assert_eq!(
            format!("{stream}:tasks:{id}"),
            "stygian:tasks:tasks:abc-123"
        );
        assert_eq!(
            format!("{stream}:results:{id}"),
            "stygian:tasks:results:abc-123"
        );
        assert_eq!(format!("{stream}:dlq"), "stygian:tasks:dlq");
    }

    #[test]
    fn test_parse_stream_message_empty() {
        // A nil value (e.g. a trimmed entry) must parse to None, not panic.
        assert!(RedisWorkQueue::parse_stream_message(&redis::Value::Nil).is_none());
    }

    #[test]
    fn test_parse_stream_message_valid() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let task = sample_task();
        let body = serde_json::to_vec(&task)?;
        // Shape of one stream entry: [message_id, [field, value, …]].
        let entry = redis::Value::Array(vec![
            redis::Value::BulkString(b"1234-0".to_vec()),
            redis::Value::Array(vec![
                redis::Value::BulkString(b"payload".to_vec()),
                redis::Value::BulkString(body),
            ]),
        ]);
        let decoded = RedisWorkQueue::parse_stream_message(&entry)
            .ok_or_else(|| std::io::Error::other("expected parse_stream_message to return task"))?;
        assert_eq!(decoded.id, "t-1");
        assert_eq!(decoded.node_name, "fetch");
        Ok(())
    }

    #[test]
    fn test_consumer_name_is_unique() {
        let pid = std::process::id().to_string();
        assert!(RedisWorkQueueConfig::default().consumer_name.contains(&pid));
    }
}