1use apalis_core::{
2 backend::{BackendExt, TaskSinkError, codec::Codec},
3 error::BoxDynError,
4 task::{Task, metadata::MetadataExt, task_id::TaskId},
5};
6use futures::SinkExt;
7use futures::{FutureExt, Sink, future::BoxFuture};
8use std::{
9 collections::{HashMap, VecDeque},
10 marker::PhantomData,
11 task::{Context, Poll},
12};
13use tower::Service;
14
15use crate::{
16 SteppedService,
17 context::{StepContext, WorkflowContext},
18 id_generator::GenerateId,
19 router::{GoTo, StepResult},
20};
21
22#[derive(Debug)]
24pub struct WorkflowService<B, Input>
25where
26 B: BackendExt,
27{
28 services: HashMap<usize, SteppedService<B::Compact, B::Context, B::IdType>>,
29 not_ready: VecDeque<usize>,
30 backend: B,
31 _marker: PhantomData<Input>,
32}
33impl<B, Input> WorkflowService<B, Input>
34where
35 B: BackendExt,
36{
37 pub fn new(
39 services: HashMap<usize, SteppedService<B::Compact, B::Context, B::IdType>>,
40 backend: B,
41 ) -> Self {
42 Self {
43 services,
44 not_ready: VecDeque::new(),
45 backend,
46 _marker: PhantomData,
47 }
48 }
49}
50
51impl<B, Err, Input> Service<Task<B::Compact, B::Context, B::IdType>> for WorkflowService<B, Input>
52where
53 B::Compact: Send + 'static,
54 B: Sync,
55 B::Context: Send + Default + MetadataExt<WorkflowContext>,
56 Err: std::error::Error + Send + Sync + 'static,
57 B::IdType: GenerateId + Send + 'static,
58 <B::Context as MetadataExt<WorkflowContext>>::Error: Into<BoxDynError>,
59 B: Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err> + Unpin,
60 B: Clone + Send + Sync + 'static + BackendExt<Error = Err>,
61{
62 type Response = GoTo<StepResult<B::Compact, B::IdType>>;
63 type Error = BoxDynError;
64 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
65
66 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
67 loop {
68 if self.not_ready.is_empty() {
71 return Poll::Ready(Ok(()));
72 } else {
73 if self
74 .services
75 .get_mut(&self.not_ready[0])
76 .unwrap()
77 .poll_ready(cx)?
78 .is_pending()
79 {
80 return Poll::Pending;
81 }
82
83 self.not_ready.pop_front();
84 }
85 }
86 }
87
88 fn call(&mut self, mut req: Task<B::Compact, B::Context, B::IdType>) -> Self::Future {
89 assert!(
90 self.not_ready.is_empty(),
91 "Workflow must wait for all services to be ready. Did you forget to call poll_ready()?"
92 );
93 let meta: WorkflowContext = req.parts.ctx.extract().unwrap_or_default();
94 let idx = meta.step_index;
95
96 let has_next = self.services.contains_key(&(idx + 1));
97 let ctx: StepContext<B> = StepContext::new(self.backend.clone(), idx, has_next);
98
99 let svc = self
100 .services
101 .get_mut(&idx)
102 .expect("Attempted to run a step that doesn't exist");
103
104 req.parts.data.insert(ctx);
106
107 self.not_ready.push_back(idx);
108 svc.call(req).boxed()
109 }
110}
111
112pub async fn handle_step_result<N, Compact, B, Err>(
114 ctx: &mut StepContext<B>,
115 result: GoTo<N>,
116) -> Result<GoTo<StepResult<B::Compact, B::IdType>>, TaskSinkError<Err>>
117where
118 B: Sink<Task<Compact, B::Context, B::IdType>, Error = Err>
119 + BackendExt<Error = Err, Compact = Compact>
120 + Send
121 + Unpin,
122 B::Context: MetadataExt<WorkflowContext>,
123 B::Codec: Codec<N, Compact = Compact>,
124 <B::Codec as Codec<N>>::Error: Into<BoxDynError>,
125 Compact: 'static,
126 N: 'static,
127 B::IdType: GenerateId + Send + 'static,
128{
129 match result {
130 GoTo::Next(next) if ctx.has_next => {
131 let task_id = B::IdType::generate();
132 let task_id = TaskId::new(task_id);
133 let task = Task::builder(
134 B::Codec::encode(&next).map_err(|e| TaskSinkError::CodecError(e.into()))?,
135 )
136 .with_task_id(task_id.clone())
137 .meta(WorkflowContext {
138 step_index: ctx.current_step + 1,
139 })
140 .build();
141 ctx.backend.send(task).await?;
142 Ok(GoTo::Next(StepResult {
143 result: B::Codec::encode(&next).map_err(|e| TaskSinkError::CodecError(e.into()))?,
144 next_task_id: Some(task_id),
145 }))
146 }
147 GoTo::DelayFor(delay, next) if ctx.has_next => {
148 let task_id = B::IdType::generate();
149 let task_id = TaskId::new(task_id);
150 let task = Task::builder(
151 B::Codec::encode(&next).map_err(|e| TaskSinkError::CodecError(e.into()))?,
152 )
153 .run_after(delay)
154 .with_task_id(task_id.clone())
155 .meta(WorkflowContext {
156 step_index: ctx.current_step + 1,
157 })
158 .build();
159 ctx.backend.send(task).await?;
160 Ok(GoTo::DelayFor(
161 delay,
162 StepResult {
163 result: B::Codec::encode(&next)
164 .map_err(|e| TaskSinkError::CodecError(e.into()))?,
165 next_task_id: Some(task_id),
166 },
167 ))
168 }
169 #[allow(clippy::match_same_arms)]
170 GoTo::Done => Ok(GoTo::Done),
171 GoTo::Break(res) => Ok(GoTo::Break(StepResult {
172 result: B::Codec::encode(&res).map_err(|e| TaskSinkError::CodecError(e.into()))?,
173 next_task_id: None,
174 })),
175 _ => Ok(GoTo::Done),
176 }
177}