Skip to main content

stygian_graph/adapters/
distributed.rs

1//! Distributed execution adapters
2//!
3//! Provides [`LocalWorkQueue`](distributed::LocalWorkQueue) (in-process, for single-node and testing) and
4//! [`DistributedDagExecutor`](distributed::DistributedDagExecutor) (wraps any [`WorkQueuePort`](crate::ports::work_queue::WorkQueuePort) to distribute DAG
5//! waves across workers).
6//!
7//! # Design
8//!
9//! ```text
10//! DistributedDagExecutor
11//!    │
12//!    ├─ resolve wave N (topological sort already done by DagExecutor)
13//!    ├─ enqueue every node in the wave as a WorkTask
14//!    ├─ spawn worker tasks that call try_dequeue + service.execute
15//!    └─ collect_results when all tasks in wave are Completed
16//! ```
17
18use crate::domain::error::{Result, ServiceError, StygianError};
19use crate::ports::work_queue::{TaskStatus, WorkQueuePort, WorkTask};
20use crate::ports::{ScrapingService, ServiceInput};
21use async_trait::async_trait;
22use dashmap::DashMap;
23use std::collections::VecDeque;
24use std::sync::Arc;
25use tokio::sync::Mutex;
26use tracing::{debug, error, info, warn};
27
28// ─────────────────────────────────────────────────────────────────────────────
29// LocalWorkQueue
30// ─────────────────────────────────────────────────────────────────────────────
31
32/// In-memory work queue for single-node deployments and unit tests.
33///
34/// All state is stored in `Arc`-wrapped structures so the queue can be cheaply
35/// cloned and shared across worker tasks.
36///
37/// # Example
38///
39/// ```
40/// use stygian_graph::adapters::distributed::LocalWorkQueue;
41/// use stygian_graph::ports::work_queue::{WorkQueuePort, WorkTask};
42/// use serde_json::json;
43///
44/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
45/// let queue = LocalWorkQueue::new();
46/// assert_eq!(queue.pending_count().await.unwrap(), 0);
47///
48/// let task = WorkTask {
49///     id: "t-1".to_string(),
50///     pipeline_id: "p-1".to_string(),
51///     node_name: "fetch".to_string(),
52///     input: json!({"url": "https://example.com"}),
53///     wave: 0,
54///     attempt: 0,
55///     idempotency_key: "ik-t1".to_string(),
56/// };
57/// queue.enqueue(task).await.unwrap();
58/// assert_eq!(queue.pending_count().await.unwrap(), 1);
59///
60/// let dequeued = queue.try_dequeue().await.unwrap().unwrap();
61/// assert_eq!(dequeued.node_name, "fetch");
62/// # });
63/// ```
64#[derive(Clone)]
65pub struct LocalWorkQueue {
66    pending: Arc<Mutex<VecDeque<WorkTask>>>,
67    state: Arc<DashMap<String, TaskStatus>>,
68    /// Max retries before a task moves to the dead-letter state
69    max_retries: u32,
70}
71
72impl LocalWorkQueue {
73    /// Create a new `LocalWorkQueue` with default settings (`max_retries = 3`).
74    #[must_use]
75    pub fn new() -> Self {
76        Self {
77            pending: Arc::new(Mutex::new(VecDeque::new())),
78            state: Arc::new(DashMap::new()),
79            max_retries: 3,
80        }
81    }
82
83    /// Create a `LocalWorkQueue` with a custom retry limit.
84    ///
85    /// # Example
86    ///
87    /// ```
88    /// use stygian_graph::adapters::distributed::LocalWorkQueue;
89    ///
90    /// let queue = LocalWorkQueue::with_max_retries(5);
91    /// ```
92    #[must_use]
93    pub fn with_max_retries(max_retries: u32) -> Self {
94        Self {
95            pending: Arc::new(Mutex::new(VecDeque::new())),
96            state: Arc::new(DashMap::new()),
97            max_retries,
98        }
99    }
100}
101
102impl Default for LocalWorkQueue {
103    fn default() -> Self {
104        Self::new()
105    }
106}
107
108#[async_trait]
109impl WorkQueuePort for LocalWorkQueue {
110    async fn enqueue(&self, task: WorkTask) -> Result<()> {
111        debug!(task_id = %task.id, node = %task.node_name, "enqueuing task");
112        self.state.insert(task.id.clone(), TaskStatus::Pending);
113        self.pending.lock().await.push_back(task);
114        Ok(())
115    }
116
117    async fn try_dequeue(&self) -> Result<Option<WorkTask>> {
118        let task = self.pending.lock().await.pop_front();
119        if let Some(ref t) = task {
120            debug!(task_id = %t.id, "dequeued task");
121            self.state.insert(
122                t.id.clone(),
123                TaskStatus::InProgress {
124                    worker_id: "local".to_string(),
125                },
126            );
127        }
128        Ok(task)
129    }
130
131    async fn acknowledge(&self, task_id: &str, output: serde_json::Value) -> Result<()> {
132        info!(task_id = %task_id, "task acknowledged (completed)");
133        self.state
134            .insert(task_id.to_string(), TaskStatus::Completed { output });
135        Ok(())
136    }
137
138    async fn fail(&self, task_id: &str, error: &str) -> Result<()> {
139        let attempt = self
140            .state
141            .get(task_id)
142            .map_or(0, |status| match status.value() {
143                TaskStatus::Failed { attempt, .. } => *attempt,
144                _ => 0,
145            });
146
147        if attempt >= self.max_retries {
148            warn!(task_id = %task_id, %error, "task dead-lettered after max retries");
149            self.state.insert(
150                task_id.to_string(),
151                TaskStatus::DeadLetter {
152                    error: error.to_string(),
153                },
154            );
155        } else {
156            error!(task_id = %task_id, attempt, %error, "task failed, will retry");
157            self.state.insert(
158                task_id.to_string(),
159                TaskStatus::Failed {
160                    error: error.to_string(),
161                    attempt: attempt + 1,
162                },
163            );
164        }
165        Ok(())
166    }
167
168    async fn status(&self, task_id: &str) -> Result<Option<TaskStatus>> {
169        Ok(self.state.get(task_id).map(|s| s.value().clone()))
170    }
171
172    async fn collect_results(&self, pipeline_id: &str) -> Result<Vec<(String, serde_json::Value)>> {
173        // We need to find tasks by pipeline_id — the state map is keyed by
174        // task_id so we collect all Completed entries whose pipeline_id matches.
175        // LocalWorkQueue stores the task in the pending queue; once dequeued
176        // we lose the pipeline_id mapping. We use a secondary index maintained
177        // in the pipeline_tasks map instead.
178        //
179        // For simplicity in the local adapter, we scan all state entries and
180        // match on pipeline_id encoded in the task_id prefix convention
181        // "pipeline_id::node_name::task_id".
182        let mut results = Vec::new();
183        for entry in self.state.iter() {
184            let key = entry.key();
185            // Convention: task_id == "{pipeline_id}::{node_name}::{ulid}"
186            if !key.starts_with(pipeline_id) {
187                continue;
188            }
189            if let TaskStatus::Completed { ref output } = *entry.value() {
190                // Extract node_name from the middle segment
191                let node_name = key.split("::").nth(1).unwrap_or(key).to_string();
192                results.push((node_name, output.clone()));
193            }
194        }
195        Ok(results)
196    }
197
198    async fn pending_count(&self) -> Result<usize> {
199        Ok(self.pending.lock().await.len())
200    }
201}
202
203// ─────────────────────────────────────────────────────────────────────────────
204// DistributedDagExecutor
205// ─────────────────────────────────────────────────────────────────────────────
206
207/// Executes a DAG wave using a [`WorkQueuePort`] to distribute node-level tasks
208/// across workers.
209///
210/// Workers are spawned as Tokio tasks that pull from the queue, call the
211/// appropriate service, and acknowledge results.  For local development the
212/// [`LocalWorkQueue`] is used; in production any queue backend can be plugged
213/// in without changing this executor.
214///
215/// # Example
216///
217/// ```
218/// use stygian_graph::adapters::distributed::{DistributedDagExecutor, LocalWorkQueue};
219/// use stygian_graph::ports::work_queue::WorkTask;
220///
221/// use stygian_graph::adapters::noop::NoopService;
222/// use serde_json::json;
223/// use std::sync::Arc;
224/// use std::collections::HashMap;
225///
226/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
227/// let queue = Arc::new(LocalWorkQueue::new());
228/// let executor = DistributedDagExecutor::new(queue, 4);
229///
230/// let mut services: HashMap<String, Arc<dyn stygian_graph::ports::ScrapingService>> =
231///     HashMap::new();
232/// services.insert("noop".to_string(), Arc::new(NoopService));
233///
234/// let tasks = vec![WorkTask {
235///     id: "p1::fetch::01".to_string(),
236///     pipeline_id: "p1".to_string(),
237///     node_name: "fetch".to_string(),
238///     input: json!({"url": "https://example.com"}),
239///     wave: 0,
240///     attempt: 0,
241///     idempotency_key: "ik-01".to_string(),
242/// }];
243///
244/// let results = executor.execute_wave("p1", tasks, &services).await.unwrap();
245/// assert!(!results.is_empty() || results.is_empty()); // noop returns empty
246/// # });
247/// ```
248pub struct DistributedDagExecutor<Q: WorkQueuePort> {
249    queue: Arc<Q>,
250    worker_concurrency: usize,
251}
252
253impl<Q: WorkQueuePort + 'static> DistributedDagExecutor<Q> {
254    /// Create a new executor with the given work queue and worker concurrency.
255    ///
256    /// `worker_concurrency` controls how many parallel worker tasks drain the
257    /// queue.
258    pub fn new(queue: Arc<Q>, worker_concurrency: usize) -> Self {
259        Self {
260            queue,
261            worker_concurrency: worker_concurrency.max(1),
262        }
263    }
264
265    /// Execute a single wave of tasks, distributing them across workers.
266    ///
267    /// Returns `(node_name, output)` pairs for all tasks in the wave.
268    ///
269    /// # Panics
270    ///
271    /// Panics if an internal `Mutex` is poisoned (i.e. another thread panicked
272    /// while holding the lock). Treat this as unrecoverable.
273    ///
274    /// # Errors
275    ///
276    /// Returns [`StygianError`] when a service reports a failure, the executor
277    /// is shut down, or a worker task cannot be enqueued.
278    pub async fn execute_wave(
279        &self,
280        pipeline_id: &str,
281        tasks: Vec<WorkTask>,
282        services: &std::collections::HashMap<String, Arc<dyn ScrapingService>>,
283    ) -> Result<Vec<(String, serde_json::Value)>> {
284        let expected = tasks.len();
285        if expected == 0 {
286            return Ok(Vec::new());
287        }
288
289        // Enqueue all tasks in this wave
290        for task in tasks {
291            self.queue.enqueue(task).await?;
292        }
293
294        // Spawn workers to drain the queue
295        let queue = Arc::clone(&self.queue);
296        let services: Arc<std::collections::HashMap<String, Arc<dyn ScrapingService>>> =
297            Arc::new(services.clone());
298
299        let concurrency = self.worker_concurrency.min(expected);
300        let mut handles = tokio::task::JoinSet::new();
301
302        for _ in 0..concurrency {
303            let q = Arc::clone(&queue);
304            let svcs = Arc::clone(&services);
305            handles.spawn(async move {
306                // Each worker drains the queue until it finds nothing
307                let mut worked = 0usize;
308                loop {
309                    match q.try_dequeue().await {
310                        Ok(Some(task)) => {
311                            let service_input = ServiceInput {
312                                url: task
313                                    .input
314                                    .get("url")
315                                    .and_then(serde_json::Value::as_str)
316                                    .unwrap_or("")
317                                    .to_string(),
318                                params: task.input.clone(),
319                            };
320                            let output = match svcs.get(&task.node_name) {
321                                Some(svc) => svc.execute(service_input.clone()).await,
322                                None => {
323                                    // Fallback: look for a service named "default"
324                                    match svcs.get("default") {
325                                        Some(svc) => svc.execute(service_input).await,
326                                        None => Err(StygianError::Service(
327                                            ServiceError::Unavailable(format!(
328                                                "service '{}' not registered",
329                                                task.node_name
330                                            )),
331                                        )),
332                                    }
333                                }
334                            };
335                            match output {
336                                Ok(out) => {
337                                    // codeql[rust/unused-variable] - `out` is consumed by the `json!` macro below.
338                                    let val = serde_json::json!({
339                                        "data": out.data,
340                                        "metadata": out.metadata,
341                                    });
342                                    let _ = q.acknowledge(&task.id, val).await;
343                                }
344                                Err(e) => {
345                                    let _ = q.fail(&task.id, &e.to_string()).await;
346                                }
347                            }
348                            worked += 1;
349                        }
350                        Ok(None) => break, // queue empty
351                        Err(e) => {
352                            // codeql[rust/unused-variable] - `e` is used via the structured field below.
353                            error!(error = %e, "worker dequeue error");
354                            break;
355                        }
356                    }
357                }
358                worked
359            });
360        }
361
362        // Wait for all workers
363        while handles.join_next().await.is_some() {}
364
365        // Collect results
366        self.queue.collect_results(pipeline_id).await
367    }
368}
369
370// ─────────────────────────────────────────────────────────────────────────────
371// Tests
372// ─────────────────────────────────────────────────────────────────────────────
373
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a task whose id follows the `{pipeline_id}::{node_name}::{seq}`
    /// convention (sequence number zero-padded to four digits).
    fn make_task(pipeline_id: &str, node_name: &str, seq: u32) -> WorkTask {
        let id = format!("{pipeline_id}::{node_name}::{seq:04}");
        let idempotency_key = format!("ik-{seq}");
        WorkTask {
            id,
            pipeline_id: pipeline_id.to_string(),
            node_name: node_name.to_string(),
            input: json!({"url": "https://example.com"}),
            wave: 0,
            attempt: 0,
            idempotency_key,
        }
    }

    /// Tasks come back in the order they were enqueued, and the pending count
    /// tracks every enqueue and dequeue.
    #[tokio::test]
    async fn enqueue_dequeue_roundtrip() {
        let q = LocalWorkQueue::new();
        assert_eq!(q.pending_count().await.unwrap(), 0);

        q.enqueue(make_task("p1", "fetch", 1)).await.unwrap();
        q.enqueue(make_task("p1", "parse", 2)).await.unwrap();
        assert_eq!(q.pending_count().await.unwrap(), 2);

        let first = q.try_dequeue().await.unwrap().unwrap();
        assert_eq!(first.node_name, "fetch");
        assert_eq!(q.pending_count().await.unwrap(), 1);

        let second = q.try_dequeue().await.unwrap().unwrap();
        assert_eq!(second.node_name, "parse");
        assert_eq!(q.pending_count().await.unwrap(), 0);

        // Nothing left: dequeue reports `None` rather than an error.
        let nothing = q.try_dequeue().await.unwrap();
        assert!(nothing.is_none());
    }

    /// Acknowledging a dequeued task flips its status to `Completed`.
    #[tokio::test]
    async fn acknowledge_records_completed_status() {
        let q = LocalWorkQueue::new();
        q.enqueue(make_task("p1", "fetch", 1)).await.unwrap();

        let dequeued = q.try_dequeue().await.unwrap().unwrap();
        let payload = json!({"data": "hello", "status": 200});
        q.acknowledge(&dequeued.id, payload).await.unwrap();

        let recorded = q.status(&dequeued.id).await.unwrap().unwrap();
        assert!(matches!(recorded, TaskStatus::Completed { .. }));
    }

    /// With `max_retries = 2`, the third recorded failure dead-letters the task.
    #[tokio::test]
    async fn fail_dead_letters_after_max_retries() {
        let q = LocalWorkQueue::with_max_retries(2);
        q.enqueue(make_task("p1", "fetch", 1)).await.unwrap();
        let dequeued = q.try_dequeue().await.unwrap().unwrap();

        // The first two failures bump the attempt counter to 2; on the third
        // call the counter equals `max_retries`, so the task is dead-lettered.
        for message in ["err 1", "err 2", "err 3"] {
            q.fail(&dequeued.id, message).await.unwrap();
        }

        let recorded = q.status(&dequeued.id).await.unwrap().unwrap();
        assert!(matches!(recorded, TaskStatus::DeadLetter { .. }));
    }

    /// Results of one pipeline never leak into another pipeline's results.
    #[tokio::test]
    async fn collect_results_filters_by_pipeline_id() {
        let q = LocalWorkQueue::new();

        // One task per pipeline.
        q.enqueue(make_task("pipeline-A", "node1", 1)).await.unwrap();
        q.enqueue(make_task("pipeline-B", "node1", 2)).await.unwrap();

        // Dequeue in FIFO order, then complete both.
        let from_a = q.try_dequeue().await.unwrap().unwrap();
        let from_b = q.try_dequeue().await.unwrap().unwrap();

        q.acknowledge(&from_a.id, json!({"data": "A-result"}))
            .await
            .unwrap();
        q.acknowledge(&from_b.id, json!({"data": "B-result"}))
            .await
            .unwrap();

        let only_a = q.collect_results("pipeline-A").await.unwrap();
        assert_eq!(only_a.len(), 1);
        assert_eq!(only_a[0].1["data"], "A-result");

        let only_b = q.collect_results("pipeline-B").await.unwrap();
        assert_eq!(only_b.len(), 1);
        assert_eq!(only_b[0].1["data"], "B-result");
    }

    /// Running a wave through the executor succeeds and never yields more
    /// results than tasks submitted.
    #[tokio::test]
    async fn distributed_executor_runs_tasks() {
        use crate::adapters::noop::NoopService;
        use std::collections::HashMap;

        let q = Arc::new(LocalWorkQueue::new());
        let executor = DistributedDagExecutor::new(Arc::clone(&q), 2);

        let mut registry: HashMap<String, Arc<dyn ScrapingService>> = HashMap::new();
        registry.insert("noop".to_string(), Arc::new(NoopService));

        let wave: Vec<WorkTask> = (1..=3).map(|seq| make_task("p1", "noop", seq)).collect();

        // The call itself must succeed without panic/error; only completed
        // tasks show up in the results, so the count is bounded by the wave size.
        let results = executor.execute_wave("p1", wave, &registry).await.unwrap();
        assert!(results.len() <= 3);
    }
}