apalis_workflow/dag/
executor.rs1use std::{
2 collections::{HashMap, VecDeque},
3 fmt::Debug,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use apalis_core::{
9 backend::{BackendExt, codec::RawDataBackend},
10 error::BoxDynError,
11 task::{
12 Task,
13 metadata::{Meta, MetadataExt},
14 },
15 worker::builder::{IntoWorkerService, WorkerService},
16};
17use petgraph::graph::{DiGraph, NodeIndex};
18use tower::Service;
19
20use crate::{
21 DagFlow, DagService,
22 dag::{
23 DagFlowContext, RootDagService,
24 error::{DagFlowError, DagServiceError},
25 },
26 id_generator::GenerateId,
27};
28
29#[derive(Debug)]
31pub struct DagExecutor<B>
32where
33 B: BackendExt,
34{
35 pub(super) graph: DiGraph<DagService<B::Compact, B::Context, B::IdType>, ()>,
36 pub(super) node_mapping: HashMap<String, NodeIndex>,
37 pub(super) topological_order: Vec<NodeIndex>,
38 pub(super) start_nodes: Vec<NodeIndex>,
39 pub(super) end_nodes: Vec<NodeIndex>,
40 pub(super) not_ready: VecDeque<NodeIndex>,
41}
42
43impl<B> Clone for DagExecutor<B>
44where
45 B: BackendExt,
46{
47 fn clone(&self) -> Self {
48 Self {
49 graph: self.graph.clone(),
50 node_mapping: self.node_mapping.clone(),
51 topological_order: self.topological_order.clone(),
52 start_nodes: self.start_nodes.clone(),
53 end_nodes: self.end_nodes.clone(),
54 not_ready: self.not_ready.clone(),
55 }
56 }
57}
58
59impl<B> DagExecutor<B>
60where
61 B: BackendExt,
62{
63 pub fn get_node_by_name_mut(
65 &mut self,
66 name: &str,
67 ) -> Option<&mut DagService<B::Compact, B::Context, B::IdType>> {
68 self.node_mapping
69 .get(name)
70 .and_then(|&idx| self.graph.node_weight_mut(idx))
71 }
72}
73
74impl<B, MetaError> Service<Task<B::Compact, B::Context, B::IdType>> for DagExecutor<B>
75where
76 B: BackendExt,
77 B::Context:
78 Send + Sync + 'static + MetadataExt<DagFlowContext<B::IdType>, Error = MetaError> + Default,
79 B::IdType: Clone + Send + Sync + 'static + GenerateId + Debug,
80 B::Compact: Send + Sync + 'static,
81 MetaError: Into<BoxDynError>,
82{
83 type Response = B::Compact;
84 type Error = DagFlowError;
85 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
86
87 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
88 loop {
89 if self.not_ready.is_empty() {
92 return Poll::Ready(Ok(()));
93 } else {
94 if self
95 .graph
96 .node_weight_mut(self.not_ready[0])
97 .ok_or(DagFlowError::MissingService(self.not_ready[0]))?
98 .poll_ready(cx)
99 .map_err(DagServiceError::PollError)
100 .map_err(DagFlowError::Service)?
101 .is_pending()
102 {
103 return Poll::Pending;
104 }
105
106 self.not_ready.pop_front();
107 }
108 }
109 }
110
111 fn call(&mut self, req: Task<B::Compact, B::Context, B::IdType>) -> Self::Future {
112 let mut graph = self.graph.clone();
113
114 Box::pin(async move {
115 let context = req
116 .extract::<Meta<DagFlowContext<B::IdType>>>()
117 .await
118 .map_err(|e| DagFlowError::Metadata(e.into()))?
119 .0;
120 let service = graph
122 .node_weight_mut(context.current_node)
123 .ok_or_else(|| DagFlowError::MissingService(context.current_node))?;
124
125 let result = service
126 .call(req)
127 .await
128 .map_err(DagFlowError::NodeExecutionError)?;
129
130 Ok(result)
131 })
132 }
133}
134
135impl<B, Compact, Err> IntoWorkerService<B, RootDagService<B>, B::Compact, B::Context> for DagFlow<B>
136where
137 B: BackendExt<Compact = Compact, Args = Compact, Error = Err> + Clone,
138 Err: std::error::Error + Send + Sync + 'static,
139 B::Context: MetadataExt<DagFlowContext<B::IdType>> + Send + Sync + 'static,
140 B::IdType: Send + Sync + 'static + Default + GenerateId + PartialEq + Debug,
141 B::Compact: Send + Sync + 'static + Clone,
142 RootDagService<B>: Service<Task<Compact, B::Context, B::IdType>>,
143{
144 type Backend = RawDataBackend<B>;
145 fn into_service(self, b: B) -> WorkerService<RawDataBackend<B>, RootDagService<B>> {
146 let executor = self.build().expect("Execution should be valid");
147 WorkerService {
148 backend: RawDataBackend::new(b.clone()),
149 service: RootDagService::new(executor, b),
150 }
151 }
152}