apalis_workflow/dag/
executor.rs1use std::{
2 collections::{HashMap, VecDeque},
3 fmt::Debug,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use apalis_core::{
9 backend::{BackendExt, codec::RawDataBackend},
10 error::BoxDynError,
11 task::{
12 Task,
13 metadata::{Meta, MetadataExt},
14 },
15 worker::builder::{IntoWorkerService, WorkerService},
16};
17use petgraph::graph::{DiGraph, NodeIndex};
18use tower::Service;
19
20use crate::{
21 DagFlow, DagService,
22 dag::{DagFlowContext, RootDagService, error::DagFlowError},
23 id_generator::GenerateId,
24};
25
26#[derive(Debug)]
28pub struct DagExecutor<B>
29where
30 B: BackendExt,
31{
32 pub(super) graph: DiGraph<DagService<B::Compact, B::Context, B::IdType>, ()>,
33 pub(super) node_mapping: HashMap<String, NodeIndex>,
34 pub(super) topological_order: Vec<NodeIndex>,
35 pub(super) start_nodes: Vec<NodeIndex>,
36 pub(super) end_nodes: Vec<NodeIndex>,
37 pub(super) not_ready: VecDeque<NodeIndex>,
38}
39
40impl<B> Clone for DagExecutor<B>
41where
42 B: BackendExt,
43{
44 fn clone(&self) -> Self {
45 Self {
46 graph: self.graph.clone(),
47 node_mapping: self.node_mapping.clone(),
48 topological_order: self.topological_order.clone(),
49 start_nodes: self.start_nodes.clone(),
50 end_nodes: self.end_nodes.clone(),
51 not_ready: self.not_ready.clone(),
52 }
53 }
54}
55
56impl<B> DagExecutor<B>
57where
58 B: BackendExt,
59{
60 pub fn get_node_by_name_mut(
62 &mut self,
63 name: &str,
64 ) -> Option<&mut DagService<B::Compact, B::Context, B::IdType>> {
65 self.node_mapping
66 .get(name)
67 .and_then(|&idx| self.graph.node_weight_mut(idx))
68 }
69}
70
71impl<B, MetaError> Service<Task<B::Compact, B::Context, B::IdType>> for DagExecutor<B>
72where
73 B: BackendExt,
74 B::Context:
75 Send + Sync + 'static + MetadataExt<DagFlowContext<B::IdType>, Error = MetaError> + Default,
76 B::IdType: Clone + Send + Sync + 'static + GenerateId + Debug,
77 B::Compact: Send + Sync + 'static,
78 MetaError: Into<BoxDynError>,
79{
80 type Response = B::Compact;
81 type Error = DagFlowError;
82 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
83
84 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
85 loop {
86 if self.not_ready.is_empty() {
89 return Poll::Ready(Ok(()));
90 } else {
91 if self
92 .graph
93 .node_weight_mut(self.not_ready[0])
94 .unwrap()
95 .poll_ready(cx)
96 .map_err(DagFlowError::Service)?
97 .is_pending()
98 {
99 return Poll::Pending;
100 }
101
102 self.not_ready.pop_front();
103 }
104 }
105 }
106
107 fn call(&mut self, req: Task<B::Compact, B::Context, B::IdType>) -> Self::Future {
108 let mut graph = self.graph.clone();
109
110 Box::pin(async move {
111 let context = req
112 .extract::<Meta<DagFlowContext<B::IdType>>>()
113 .await
114 .map_err(|e| DagFlowError::Metadata(e.into()))?
115 .0;
116
117 let service = graph
119 .node_weight_mut(context.current_node)
120 .ok_or_else(|| DagFlowError::MissingService(context.current_node))?;
121
122 let result = service.call(req).await.map_err(DagFlowError::Node)?;
123
124 Ok(result)
125 })
126 }
127}
128
129impl<B, Compact, Err> IntoWorkerService<B, RootDagService<B>, B::Compact, B::Context> for DagFlow<B>
130where
131 B: BackendExt<Compact = Compact, Args = Compact, Error = Err> + Clone,
132 Err: std::error::Error + Send + Sync + 'static,
133 B::Context: MetadataExt<DagFlowContext<B::IdType>> + Send + Sync + 'static,
134 B::IdType: Send + Sync + 'static + Default + GenerateId + PartialEq + Debug,
135 B::Compact: Send + Sync + 'static + Clone,
136 RootDagService<B>: Service<Task<Compact, B::Context, B::IdType>>,
137{
138 type Backend = RawDataBackend<B>;
139 fn into_service(self, b: B) -> WorkerService<RawDataBackend<B>, RootDagService<B>> {
140 let executor = self.build().expect("Execution should be valid");
141 WorkerService {
142 backend: RawDataBackend::new(b.clone()),
143 service: RootDagService::new(executor, b),
144 }
145 }
146}