use crate::domain::error::{Result, ServiceError, StygianError};
use crate::ports::work_queue::{TaskStatus, WorkQueuePort, WorkTask};
use crate::ports::{ScrapingService, ServiceInput};
use async_trait::async_trait;
use dashmap::DashMap;
use std::collections::VecDeque;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{debug, error, info, warn};
/// In-process, in-memory implementation of [`WorkQueuePort`].
///
/// Pending tasks wait in a FIFO (`VecDeque`) behind an async mutex; per-task
/// status lives in a concurrent map keyed by task id. Cloning is cheap: all
/// state is behind `Arc`s, so clones share the same queue and status map.
#[derive(Clone)]
pub struct LocalWorkQueue {
    // FIFO of tasks awaiting dequeue; pushed at the back, popped at the front.
    pending: Arc<Mutex<VecDeque<WorkTask>>>,
    // task id -> last recorded status (Pending/InProgress/Completed/Failed/DeadLetter).
    state: Arc<DashMap<String, TaskStatus>>,
    // Failure count at which `fail` dead-letters a task instead of marking it Failed.
    max_retries: u32,
}
impl LocalWorkQueue {
    /// Retry budget used by [`LocalWorkQueue::new`].
    const DEFAULT_MAX_RETRIES: u32 = 3;

    /// Creates an empty queue with the default retry budget (3).
    pub fn new() -> Self {
        Self::with_max_retries(Self::DEFAULT_MAX_RETRIES)
    }

    /// Creates an empty queue that dead-letters a task once it has failed
    /// `max_retries` times.
    pub fn with_max_retries(max_retries: u32) -> Self {
        Self {
            pending: Arc::new(Mutex::new(VecDeque::new())),
            state: Arc::new(DashMap::new()),
            max_retries,
        }
    }
}
impl Default for LocalWorkQueue {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl WorkQueuePort for LocalWorkQueue {
    /// Appends `task` to the back of the FIFO and records it as `Pending`.
    async fn enqueue(&self, task: WorkTask) -> Result<()> {
        debug!(task_id = %task.id, node = %task.node_name, "enqueuing task");
        self.state.insert(task.id.clone(), TaskStatus::Pending);
        self.pending.lock().await.push_back(task);
        Ok(())
    }

    /// Pops the oldest pending task, if any, and marks it `InProgress`.
    async fn try_dequeue(&self) -> Result<Option<WorkTask>> {
        let task = self.pending.lock().await.pop_front();
        if let Some(ref t) = task {
            debug!(task_id = %t.id, "dequeued task");
            self.state.insert(
                t.id.clone(),
                TaskStatus::InProgress {
                    // Single-process adapter: there is only one logical worker pool.
                    worker_id: "local".to_string(),
                },
            );
        }
        Ok(task)
    }

    /// Records successful completion of `task_id` with its `output` payload.
    async fn acknowledge(&self, task_id: &str, output: serde_json::Value) -> Result<()> {
        info!(task_id = %task_id, "task acknowledged (completed)");
        self.state
            .insert(task_id.to_string(), TaskStatus::Completed { output });
        Ok(())
    }

    /// Records a failure for `task_id`, dead-lettering it once the recorded
    /// failure count reaches `max_retries`.
    ///
    /// Note: this is status bookkeeping only — a `Failed` task is NOT
    /// re-enqueued here; callers must re-enqueue if they want a retry to run.
    async fn fail(&self, task_id: &str, error: &str) -> Result<()> {
        // Prior failure count; any non-Failed status counts as zero attempts.
        let attempt = self
            .state
            .get(task_id)
            .map_or(0, |status| match status.value() {
                TaskStatus::Failed { attempt, .. } => *attempt,
                _ => 0,
            });
        if attempt >= self.max_retries {
            warn!(task_id = %task_id, %error, "task dead-lettered after max retries");
            self.state.insert(
                task_id.to_string(),
                TaskStatus::DeadLetter {
                    error: error.to_string(),
                },
            );
        } else {
            error!(task_id = %task_id, attempt, %error, "task failed, will retry");
            self.state.insert(
                task_id.to_string(),
                TaskStatus::Failed {
                    error: error.to_string(),
                    attempt: attempt + 1,
                },
            );
        }
        Ok(())
    }

    /// Returns the last recorded status for `task_id`, if the id is known.
    async fn status(&self, task_id: &str) -> Result<Option<TaskStatus>> {
        Ok(self.state.get(task_id).map(|s| s.value().clone()))
    }

    /// Collects `(node_name, output)` for every completed task of `pipeline_id`.
    ///
    /// Task ids follow the `"{pipeline_id}::{node_name}::{seq}"` scheme (see
    /// the tests' `make_task`), so we match on the full `"{pipeline_id}::"`
    /// prefix. A bare `starts_with(pipeline_id)` would wrongly include tasks
    /// from any pipeline whose id merely extends this one (e.g. collecting
    /// "p1" would also pick up "p10::…" tasks).
    async fn collect_results(&self, pipeline_id: &str) -> Result<Vec<(String, serde_json::Value)>> {
        let prefix = format!("{pipeline_id}::");
        let mut results = Vec::new();
        for entry in self.state.iter() {
            if let Some(rest) = entry.key().strip_prefix(&prefix) {
                if let TaskStatus::Completed { ref output } = *entry.value() {
                    // `rest` is "{node_name}::{seq}"; keep just the node name.
                    let node_name = rest.split("::").next().unwrap_or(rest).to_string();
                    results.push((node_name, output.clone()));
                }
            }
        }
        Ok(results)
    }

    /// Number of tasks currently waiting to be dequeued.
    async fn pending_count(&self) -> Result<usize> {
        Ok(self.pending.lock().await.len())
    }
}
/// Executes one "wave" of DAG tasks by fanning them out to a pool of workers
/// that pull from a shared [`WorkQueuePort`].
pub struct DistributedDagExecutor<Q: WorkQueuePort> {
    // Shared queue workers pull tasks from and report results/failures to.
    queue: Arc<Q>,
    // Upper bound on concurrently running workers (clamped to >= 1 by `new`).
    worker_concurrency: usize,
}
impl<Q: WorkQueuePort + 'static> DistributedDagExecutor<Q> {
pub fn new(queue: Arc<Q>, worker_concurrency: usize) -> Self {
Self {
queue,
worker_concurrency: worker_concurrency.max(1),
}
}
pub async fn execute_wave(
&self,
pipeline_id: &str,
tasks: Vec<WorkTask>,
services: &std::collections::HashMap<String, Arc<dyn ScrapingService>>,
) -> Result<Vec<(String, serde_json::Value)>> {
let expected = tasks.len();
if expected == 0 {
return Ok(Vec::new());
}
for task in tasks {
self.queue.enqueue(task).await?;
}
let queue = Arc::clone(&self.queue);
let services: Arc<std::collections::HashMap<String, Arc<dyn ScrapingService>>> =
Arc::new(services.clone());
let concurrency = self.worker_concurrency.min(expected);
let mut handles = tokio::task::JoinSet::new();
for _ in 0..concurrency {
let q = Arc::clone(&queue);
let svcs = Arc::clone(&services);
handles.spawn(async move {
let mut worked = 0usize;
loop {
match q.try_dequeue().await {
Ok(Some(task)) => {
let service_input = ServiceInput {
url: task
.input
.get("url")
.and_then(serde_json::Value::as_str)
.unwrap_or("")
.to_string(),
params: task.input.clone(),
};
let output = match svcs.get(&task.node_name) {
Some(svc) => svc.execute(service_input.clone()).await,
None => {
match svcs.get("default") {
Some(svc) => svc.execute(service_input).await,
None => Err(StygianError::Service(
ServiceError::Unavailable(format!(
"service '{}' not registered",
task.node_name
)),
)),
}
}
};
match output {
Ok(out) => {
let val = serde_json::json!({
"data": out.data,
"metadata": out.metadata,
});
let _ = q.acknowledge(&task.id, val).await;
}
Err(e) => {
let _ = q.fail(&task.id, &e.to_string()).await;
}
}
worked += 1;
}
Ok(None) => break, Err(e) => {
error!(error = %e, "worker dequeue error");
break;
}
}
}
worked
});
}
while handles.join_next().await.is_some() {}
self.queue.collect_results(pipeline_id).await
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a task whose id follows the `"{pipeline}::{node}::{seq:04}"` scheme.
    fn make_task(pipeline_id: &str, node_name: &str, seq: u32) -> WorkTask {
        WorkTask {
            id: format!("{pipeline_id}::{node_name}::{seq:04}"),
            pipeline_id: pipeline_id.to_string(),
            node_name: node_name.to_string(),
            input: json!({"url": "https://example.com"}),
            wave: 0,
            attempt: 0,
            idempotency_key: format!("ik-{seq}"),
        }
    }

    #[tokio::test]
    async fn enqueue_dequeue_roundtrip() {
        let q = LocalWorkQueue::new();
        assert_eq!(q.pending_count().await.unwrap(), 0);

        for (node, seq) in [("fetch", 1), ("parse", 2)] {
            q.enqueue(make_task("p1", node, seq)).await.unwrap();
        }
        assert_eq!(q.pending_count().await.unwrap(), 2);

        // FIFO: "fetch" went in first, so it comes out first.
        let first = q.try_dequeue().await.unwrap().unwrap();
        assert_eq!(first.node_name, "fetch");
        assert_eq!(q.pending_count().await.unwrap(), 1);

        let second = q.try_dequeue().await.unwrap().unwrap();
        assert_eq!(second.node_name, "parse");
        assert_eq!(q.pending_count().await.unwrap(), 0);

        // Draining an empty queue yields None rather than an error.
        assert!(q.try_dequeue().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn acknowledge_records_completed_status() {
        let q = LocalWorkQueue::new();
        q.enqueue(make_task("p1", "fetch", 1)).await.unwrap();
        let task = q.try_dequeue().await.unwrap().unwrap();

        let payload = json!({"data": "hello", "status": 200});
        q.acknowledge(&task.id, payload).await.unwrap();

        let status = q.status(&task.id).await.unwrap().unwrap();
        assert!(matches!(status, TaskStatus::Completed { .. }));
    }

    #[tokio::test]
    async fn fail_dead_letters_after_max_retries() {
        let q = LocalWorkQueue::with_max_retries(2);
        q.enqueue(make_task("p1", "fetch", 1)).await.unwrap();
        let task = q.try_dequeue().await.unwrap().unwrap();

        // Two failures stay within the retry budget; the third dead-letters.
        for n in 1..=3 {
            q.fail(&task.id, &format!("err {n}")).await.unwrap();
        }

        let status = q.status(&task.id).await.unwrap().unwrap();
        assert!(matches!(status, TaskStatus::DeadLetter { .. }));
    }

    #[tokio::test]
    async fn collect_results_filters_by_pipeline_id() {
        let q = LocalWorkQueue::new();
        q.enqueue(make_task("pipeline-A", "node1", 1)).await.unwrap();
        q.enqueue(make_task("pipeline-B", "node1", 2)).await.unwrap();

        let first = q.try_dequeue().await.unwrap().unwrap();
        let second = q.try_dequeue().await.unwrap().unwrap();
        q.acknowledge(&first.id, json!({"data": "A-result"}))
            .await
            .unwrap();
        q.acknowledge(&second.id, json!({"data": "B-result"}))
            .await
            .unwrap();

        // Each pipeline sees exactly its own completed task.
        let results_a = q.collect_results("pipeline-A").await.unwrap();
        assert_eq!(results_a.len(), 1);
        assert_eq!(results_a[0].1["data"], "A-result");

        let results_b = q.collect_results("pipeline-B").await.unwrap();
        assert_eq!(results_b.len(), 1);
        assert_eq!(results_b[0].1["data"], "B-result");
    }

    #[tokio::test]
    async fn distributed_executor_runs_tasks() {
        use crate::adapters::noop::NoopService;
        use std::collections::HashMap;

        let queue = Arc::new(LocalWorkQueue::new());
        let executor = DistributedDagExecutor::new(Arc::clone(&queue), 2);

        let mut services: HashMap<String, Arc<dyn ScrapingService>> = HashMap::new();
        services.insert("noop".to_string(), Arc::new(NoopService));

        let tasks: Vec<WorkTask> = (1..=3).map(|seq| make_task("p1", "noop", seq)).collect();
        let results = executor.execute_wave("p1", tasks, &services).await.unwrap();
        assert!(results.len() <= 3);
    }
}