// apalis_workflow/dag/executor.rs

use std::{
    collections::{HashMap, VecDeque},
    fmt::Debug,
    future::poll_fn,
    pin::Pin,
    task::{Context, Poll},
};

use apalis_core::{
    backend::{BackendExt, codec::RawDataBackend},
    error::BoxDynError,
    task::{
        Task,
        metadata::{Meta, MetadataExt},
    },
    worker::builder::{IntoWorkerService, WorkerService},
};
use petgraph::graph::{DiGraph, NodeIndex};
use tower::Service;

use crate::{
    DagFlow, DagService,
    dag::{DagFlowContext, RootDagService, error::DagFlowError},
    id_generator::GenerateId,
};

26/// Executor for DAG workflows
27#[derive(Debug)]
28pub struct DagExecutor<B>
29where
30    B: BackendExt,
31{
32    pub(super) graph: DiGraph<DagService<B::Compact, B::Context, B::IdType>, ()>,
33    pub(super) node_mapping: HashMap<String, NodeIndex>,
34    pub(super) topological_order: Vec<NodeIndex>,
35    pub(super) start_nodes: Vec<NodeIndex>,
36    pub(super) end_nodes: Vec<NodeIndex>,
37    pub(super) not_ready: VecDeque<NodeIndex>,
38}
39
40impl<B> Clone for DagExecutor<B>
41where
42    B: BackendExt,
43{
44    fn clone(&self) -> Self {
45        Self {
46            graph: self.graph.clone(),
47            node_mapping: self.node_mapping.clone(),
48            topological_order: self.topological_order.clone(),
49            start_nodes: self.start_nodes.clone(),
50            end_nodes: self.end_nodes.clone(),
51            not_ready: self.not_ready.clone(),
52        }
53    }
54}
55
56impl<B> DagExecutor<B>
57where
58    B: BackendExt,
59{
60    /// Get a node by name
61    pub fn get_node_by_name_mut(
62        &mut self,
63        name: &str,
64    ) -> Option<&mut DagService<B::Compact, B::Context, B::IdType>> {
65        self.node_mapping
66            .get(name)
67            .and_then(|&idx| self.graph.node_weight_mut(idx))
68    }
69}
70
71impl<B, MetaError> Service<Task<B::Compact, B::Context, B::IdType>> for DagExecutor<B>
72where
73    B: BackendExt,
74    B::Context:
75        Send + Sync + 'static + MetadataExt<DagFlowContext<B::IdType>, Error = MetaError> + Default,
76    B::IdType: Clone + Send + Sync + 'static + GenerateId + Debug,
77    B::Compact: Send + Sync + 'static,
78    MetaError: Into<BoxDynError>,
79{
80    type Response = B::Compact;
81    type Error = DagFlowError;
82    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
83
84    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
85        loop {
86            // must wait for *all* services to be ready.
87            // this will cause head-of-line blocking unless the underlying services are always ready.
88            if self.not_ready.is_empty() {
89                return Poll::Ready(Ok(()));
90            } else {
91                if self
92                    .graph
93                    .node_weight_mut(self.not_ready[0])
94                    .unwrap()
95                    .poll_ready(cx)
96                    .map_err(DagFlowError::Service)?
97                    .is_pending()
98                {
99                    return Poll::Pending;
100                }
101
102                self.not_ready.pop_front();
103            }
104        }
105    }
106
107    fn call(&mut self, req: Task<B::Compact, B::Context, B::IdType>) -> Self::Future {
108        let mut graph = self.graph.clone();
109
110        Box::pin(async move {
111            let context = req
112                .extract::<Meta<DagFlowContext<B::IdType>>>()
113                .await
114                .map_err(|e| DagFlowError::Metadata(e.into()))?
115                .0;
116
117            // Get the service for this node
118            let service = graph
119                .node_weight_mut(context.current_node)
120                .ok_or_else(|| DagFlowError::MissingService(context.current_node))?;
121
122            let result = service.call(req).await.map_err(DagFlowError::Node)?;
123
124            Ok(result)
125        })
126    }
127}
128
129impl<B, Compact, Err> IntoWorkerService<B, RootDagService<B>, B::Compact, B::Context> for DagFlow<B>
130where
131    B: BackendExt<Compact = Compact, Args = Compact, Error = Err> + Clone,
132    Err: std::error::Error + Send + Sync + 'static,
133    B::Context: MetadataExt<DagFlowContext<B::IdType>> + Send + Sync + 'static,
134    B::IdType: Send + Sync + 'static + Default + GenerateId + PartialEq + Debug,
135    B::Compact: Send + Sync + 'static + Clone,
136    RootDagService<B>: Service<Task<Compact, B::Context, B::IdType>>,
137{
138    type Backend = RawDataBackend<B>;
139    fn into_service(self, b: B) -> WorkerService<RawDataBackend<B>, RootDagService<B>> {
140        let executor = self.build().expect("Execution should be valid");
141        WorkerService {
142            backend: RawDataBackend::new(b.clone()),
143            service: RootDagService::new(executor, b),
144        }
145    }
146}