1use std::marker::PhantomData;
2
3use apalis_core::{
4 backend::{BackendExt, codec::Codec},
5 error::BoxDynError,
6 task::{Task, metadata::MetadataExt},
7 task_fn::{TaskFn, task_fn},
8};
9use futures::{
10 FutureExt, Sink,
11 future::{BoxFuture, ready},
12};
13use tower::{Service, ServiceBuilder, layer::layer_fn};
14
15use crate::{
16 SteppedService,
17 id_generator::GenerateId,
18 sequential::context::{StepContext, WorkflowContext},
19 sequential::router::{GoTo, StepResult, WorkflowRouter},
20 sequential::service::handle_step_result,
21 sequential::step::{Layer, Stack, Step},
22 sequential::workflow::Workflow,
23};
24
/// Layer describing one "and then" step of a sequential workflow.
///
/// It only stores the function/service to run (`then_fn`). Applying the layer to an inner
/// step yields an `AndThenStep` that carries a clone of this function.
#[derive(Debug, Clone)]
pub struct AndThen<F> {
    /// The function/service executed by the step built from this layer.
    then_fn: F,
}

impl<F> AndThen<F> {
    /// Builds the layer from the function/service the step should run.
    pub fn new(then_fn: F) -> Self {
        AndThen { then_fn: then_fn }
    }
}
37
/// Step produced by applying the `AndThen` layer to an inner step `S`.
///
/// `then_fn` is run for this step; `step` is the wrapped step, whose input is the response
/// of `then_fn` (see the `Step` impl's `S: Step<F::Response, B>` bound).
#[derive(Debug, Clone)]
pub struct AndThenStep<F, S> {
    /// The function/service executed for this step.
    then_fn: F,
    /// The wrapped step, registered after this one.
    step: S,
}
44
45impl<S, F> Layer<S> for AndThen<F>
46where
47 F: Clone,
48{
49 type Step = AndThenStep<F, S>;
50
51 fn layer(&self, step: S) -> Self::Step {
52 AndThenStep {
53 then_fn: self.then_fn.clone(),
54 step,
55 }
56 }
57}
58
59impl<F, Input, S, B, CodecError, SinkError> Step<Input, B> for AndThenStep<F, S>
60where
61 B: BackendExt<Error = SinkError>
62 + Send
63 + Sync
64 + 'static
65 + Clone
66 + Sink<Task<B::Compact, B::Context, B::IdType>, Error = SinkError>
67 + Unpin,
68 F: Service<Task<Input, B::Context, B::IdType>, Error = BoxDynError>
69 + Send
70 + Sync
71 + 'static
72 + Clone,
73 S: Step<F::Response, B>,
74 Input: Send + Sync + 'static,
75 F::Future: Send + 'static,
76 F::Error: Into<BoxDynError> + Send + 'static,
77 B::Codec: Codec<F::Response, Error = CodecError, Compact = B::Compact>
78 + Codec<Input, Error = CodecError, Compact = B::Compact>
79 + Codec<S::Response, Error = CodecError, Compact = B::Compact>
80 + 'static,
81 CodecError: std::error::Error + Send + Sync + 'static,
82 B::IdType: GenerateId + Send + 'static,
83 S::Response: Send + 'static,
84 B::Compact: Send + 'static,
85 B::Context: Send + MetadataExt<WorkflowContext> + 'static,
86 SinkError: std::error::Error + Send + Sync + 'static,
87 F::Response: Send + 'static,
88{
89 type Response = F::Response;
90 type Error = F::Error;
91 fn register(&mut self, ctx: &mut WorkflowRouter<B>) -> Result<(), BoxDynError> {
92 let svc = ServiceBuilder::new()
93 .layer(layer_fn(|s| AndThenService {
94 service: s,
95 _marker: PhantomData::<(B, Input)>,
96 }))
97 .map_response(|res: F::Response| GoTo::Next(res))
98 .service(self.then_fn.clone());
99 let svc = SteppedService::<B::Compact, B::Context, B::IdType>::new(svc);
100 let count = ctx.steps.len();
101 ctx.steps.insert(count, svc);
102 self.step.register(ctx)
103 }
104}
105
/// Tower service adapter used for an "and then" workflow step.
///
/// Its `Service` impl accepts a task with a compact payload and decodes it into `Cur` with the
/// backend codec. It then calls the wrapped service `Svc` and hands the resulting `GoTo` to
/// `handle_step_result`. `Backend` and `Cur` are type-level markers only (held in
/// `PhantomData`).
pub struct AndThenService<Svc, Backend, Cur> {
    service: Svc,
    _marker: PhantomData<(Backend, Cur)>,
}

// Hand-written (like `Clone` below) so that `Debug` only requires `Svc: Debug`: `Backend` and
// `Cur` live solely inside `PhantomData`, whose `Debug` impl has no bounds. The output matches
// what `#[derive(Debug)]` would print.
impl<Svc: std::fmt::Debug, Backend, Cur> std::fmt::Debug for AndThenService<Svc, Backend, Cur> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AndThenService")
            .field("service", &self.service)
            .field("_marker", &self._marker)
            .finish()
    }
}

// Hand-written so cloning only requires `Svc: Clone`, not `Backend: Clone` / `Cur: Clone`.
impl<Svc: Clone, Backend, Cur> Clone for AndThenService<Svc, Backend, Cur> {
    fn clone(&self) -> Self {
        Self {
            service: self.service.clone(),
            _marker: PhantomData,
        }
    }
}

impl<Svc, Backend, Cur> AndThenService<Svc, Backend, Cur> {
    /// Wraps `service`; the backend and input types are fixed by the type parameters.
    pub fn new(service: Svc) -> Self {
        Self {
            service,
            _marker: PhantomData,
        }
    }
}
131
132impl<S, B, Cur, Res, CodecErr, SinkError> Service<Task<B::Compact, B::Context, B::IdType>>
133 for AndThenService<S, B, Cur>
134where
135 S: Service<Task<Cur, B::Context, B::IdType>, Response = GoTo<Res>>,
136 S::Future: Send + 'static,
137 B: BackendExt<Error = SinkError>
138 + Sync
139 + Send
140 + 'static
141 + Clone
142 + Sink<Task<B::Compact, B::Context, B::IdType>, Error = SinkError>
143 + Unpin,
144 B::Codec: Codec<Cur, Compact = B::Compact, Error = CodecErr>
145 + Codec<Res, Compact = B::Compact, Error = CodecErr>,
146 S::Error: Into<BoxDynError> + Send + 'static,
147 CodecErr: Into<BoxDynError> + Send + 'static,
148 Cur: Send + 'static,
149 B::IdType: GenerateId + Send + 'static,
150 SinkError: std::error::Error + Send + Sync + 'static,
151 Res: Send + 'static,
152 B::Compact: Send + 'static,
153 B::Context: Send + MetadataExt<WorkflowContext> + 'static,
154{
155 type Response = GoTo<StepResult<B::Compact, B::IdType>>;
156 type Error = BoxDynError;
157 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
158
159 fn poll_ready(
160 &mut self,
161 cx: &mut std::task::Context<'_>,
162 ) -> std::task::Poll<Result<(), Self::Error>> {
163 self.service.poll_ready(cx).map_err(|e| e.into())
164 }
165
166 fn call(&mut self, request: Task<B::Compact, B::Context, B::IdType>) -> Self::Future {
167 let mut ctx = request.parts.data.get::<StepContext<B>>().cloned().unwrap();
168 let compacted = request.try_map(|t| B::Codec::decode(&t));
169 match compacted {
170 Ok(task) => {
171 let fut = self.service.call(task);
172 async move {
173 let res = fut.await.map_err(|e| e.into())?;
174 Ok(handle_step_result(&mut ctx, res).await?)
175 }
176 .boxed()
177 }
178 Err(e) => ready(Err(e.into())).boxed(),
179 }
180 }
181}
182
183impl<Start, Cur, B, L> Workflow<Start, Cur, B, L>
184where
185 B: BackendExt,
186{
187 pub fn and_then<F, O, FnArgs>(
201 self,
202 and_then: F,
203 ) -> Workflow<Start, O, B, Stack<AndThen<TaskFn<F, Cur, B::Context, FnArgs>>, L>>
204 where
205 TaskFn<F, Cur, B::Context, FnArgs>: Service<Task<Cur, B::Context, B::IdType>, Response = O>,
206 {
207 self.add_step(AndThen {
208 then_fn: task_fn(and_then),
209 })
210 }
211}