Skip to main content

apalis_workflow/dag/
executor.rs

1use std::{
2    collections::{HashMap, VecDeque},
3    fmt::Debug,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use apalis_core::{
9    backend::{BackendExt, codec::RawDataBackend},
10    error::BoxDynError,
11    task::{
12        Task,
13        metadata::{Meta, MetadataExt},
14    },
15    worker::builder::{IntoWorkerService, WorkerService},
16};
17use petgraph::graph::{DiGraph, NodeIndex};
18use tower::Service;
19
20use crate::{
21    DagFlow, DagService,
22    dag::{
23        DagFlowContext, RootDagService,
24        error::{DagFlowError, DagServiceError},
25    },
26    id_generator::GenerateId,
27};
28
29/// Executor for DAG workflows
30#[derive(Debug)]
31pub struct DagExecutor<B>
32where
33    B: BackendExt,
34{
35    pub(super) graph: DiGraph<DagService<B::Compact, B::Context, B::IdType>, ()>,
36    pub(super) node_mapping: HashMap<String, NodeIndex>,
37    pub(super) topological_order: Vec<NodeIndex>,
38    pub(super) start_nodes: Vec<NodeIndex>,
39    pub(super) end_nodes: Vec<NodeIndex>,
40    pub(super) not_ready: VecDeque<NodeIndex>,
41}
42
43impl<B> Clone for DagExecutor<B>
44where
45    B: BackendExt,
46{
47    fn clone(&self) -> Self {
48        Self {
49            graph: self.graph.clone(),
50            node_mapping: self.node_mapping.clone(),
51            topological_order: self.topological_order.clone(),
52            start_nodes: self.start_nodes.clone(),
53            end_nodes: self.end_nodes.clone(),
54            not_ready: self.not_ready.clone(),
55        }
56    }
57}
58
59impl<B> DagExecutor<B>
60where
61    B: BackendExt,
62{
63    /// Get a node by name
64    pub fn get_node_by_name_mut(
65        &mut self,
66        name: &str,
67    ) -> Option<&mut DagService<B::Compact, B::Context, B::IdType>> {
68        self.node_mapping
69            .get(name)
70            .and_then(|&idx| self.graph.node_weight_mut(idx))
71    }
72}
73
74impl<B, MetaError> Service<Task<B::Compact, B::Context, B::IdType>> for DagExecutor<B>
75where
76    B: BackendExt,
77    B::Context:
78        Send + Sync + 'static + MetadataExt<DagFlowContext<B::IdType>, Error = MetaError> + Default,
79    B::IdType: Clone + Send + Sync + 'static + GenerateId + Debug,
80    B::Compact: Send + Sync + 'static,
81    MetaError: Into<BoxDynError>,
82{
83    type Response = B::Compact;
84    type Error = DagFlowError;
85    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
86
87    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
88        loop {
89            // must wait for *all* services to be ready.
90            // this will cause head-of-line blocking unless the underlying services are always ready.
91            if self.not_ready.is_empty() {
92                return Poll::Ready(Ok(()));
93            } else {
94                if self
95                    .graph
96                    .node_weight_mut(self.not_ready[0])
97                    .ok_or(DagFlowError::MissingService(self.not_ready[0]))?
98                    .poll_ready(cx)
99                    .map_err(DagServiceError::PollError)
100                    .map_err(DagFlowError::Service)?
101                    .is_pending()
102                {
103                    return Poll::Pending;
104                }
105
106                self.not_ready.pop_front();
107            }
108        }
109    }
110
111    fn call(&mut self, req: Task<B::Compact, B::Context, B::IdType>) -> Self::Future {
112        let mut graph = self.graph.clone();
113
114        Box::pin(async move {
115            let context = req
116                .extract::<Meta<DagFlowContext<B::IdType>>>()
117                .await
118                .map_err(|e| DagFlowError::Metadata(e.into()))?
119                .0;
120            // Get the service for this node
121            let service = graph
122                .node_weight_mut(context.current_node)
123                .ok_or_else(|| DagFlowError::MissingService(context.current_node))?;
124
125            let result = service
126                .call(req)
127                .await
128                .map_err(DagFlowError::NodeExecutionError)?;
129
130            Ok(result)
131        })
132    }
133}
134
135impl<B, Compact, Err> IntoWorkerService<B, RootDagService<B>, B::Compact, B::Context> for DagFlow<B>
136where
137    B: BackendExt<Compact = Compact, Args = Compact, Error = Err> + Clone,
138    Err: std::error::Error + Send + Sync + 'static,
139    B::Context: MetadataExt<DagFlowContext<B::IdType>> + Send + Sync + 'static,
140    B::IdType: Send + Sync + 'static + Default + GenerateId + PartialEq + Debug,
141    B::Compact: Send + Sync + 'static + Clone,
142    RootDagService<B>: Service<Task<Compact, B::Context, B::IdType>>,
143{
144    type Backend = RawDataBackend<B>;
145    fn into_service(self, b: B) -> WorkerService<RawDataBackend<B>, RootDagService<B>> {
146        let executor = self.build().expect("Execution should be valid");
147        WorkerService {
148            backend: RawDataBackend::new(b.clone()),
149            service: RootDagService::new(executor, b),
150        }
151    }
152}