1use std::{
2 collections::{HashMap, VecDeque},
3 task::{Context, Poll},
4};
5
6use apalis_core::{
7 backend::{Backend, TaskSinkError, codec::Codec},
8 error::BoxDynError,
9 task::{Task, metadata::MetadataExt},
10};
11use futures::{FutureExt, Sink, TryFutureExt, future::BoxFuture};
12use serde::{Deserialize, Serialize, Serializer};
13use serde_json::Value;
14use tower::Service;
15
16use crate::{CompositeService, GenerateId, GoTo, StepContext, WorkflowRequest};
17
18#[derive(Debug, Clone, Deserialize)]
19pub struct StepResult<T>(pub T);
20
21impl Serialize for StepResult<Vec<u8>> {
22 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
23 where
24 S: Serializer,
25 {
26 match serde_json::from_slice::<serde_json::Value>(&self.0) {
28 Ok(value) => value.serialize(serializer),
29 Err(e) => {
30 use serde::ser::Error;
32 Err(S::Error::custom(e.to_string()))
33 }
34 }
35 }
36}
37
38impl Serialize for StepResult<serde_json::Value> {
39 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
40 where
41 S: Serializer,
42 {
43 match Value::deserialize(&self.0) {
45 Ok(value) => value.serialize(serializer),
46 Err(e) => {
47 use serde::ser::Error;
49 Err(S::Error::custom(e.to_string()))
50 }
51 }
52 }
53}
54
55pub struct WorkFlowService<FlowSink, Encode, Compact, Context, IdType> {
56 services: HashMap<usize, CompositeService<FlowSink, Encode, Compact, Context, IdType>>,
57 not_ready: VecDeque<usize>,
58 backend: FlowSink,
59}
60impl<FlowSink, Encode, Compact, Context, IdType>
61 WorkFlowService<FlowSink, Encode, Compact, Context, IdType>
62{
63 pub(crate) fn new(
64 services: HashMap<usize, CompositeService<FlowSink, Encode, Compact, Context, IdType>>,
65 backend: FlowSink,
66 ) -> Self {
67 Self {
68 services,
69 not_ready: VecDeque::new(),
70 backend,
71 }
72 }
73}
74
75impl<FlowSink: Clone + Send + Sync + 'static + Backend<Error = Err>, Encode, Compact, Err>
76 Service<Task<Compact, FlowSink::Context, FlowSink::IdType>>
77 for WorkFlowService<FlowSink, Encode, Compact, FlowSink::Context, FlowSink::IdType>
78where
79 FlowSink::Context: MetadataExt<WorkflowRequest>,
80 Encode: Send + Sync + 'static,
81 Compact: Send + 'static + Clone,
82 FlowSink: Sync,
83 Compact: Send + Sync,
84 FlowSink::Context: Send + Default + MetadataExt<WorkflowRequest>,
85 Err: std::error::Error + Send + Sync + 'static,
86 FlowSink::IdType: GenerateId + Send + 'static,
87 Encode: Codec<Compact, Compact = Compact>,
88 <FlowSink::Context as MetadataExt<WorkflowRequest>>::Error: Into<BoxDynError>,
89 Encode::Error: Into<BoxDynError>,
90 FlowSink: Sink<Task<Compact, FlowSink::Context, FlowSink::IdType>, Error = Err> + Unpin,
91{
92 type Response = StepResult<Compact>;
93 type Error = BoxDynError;
94 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
95
96 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
97 loop {
98 if self.not_ready.is_empty() {
101 return Poll::Ready(Ok(()));
102 } else {
103 if self
104 .services
105 .get_mut(&self.not_ready[0])
106 .unwrap()
107 .svc
108 .poll_ready(cx)?
109 .is_pending()
110 {
111 return Poll::Pending;
112 }
113
114 self.not_ready.pop_front();
115 }
116 }
117 }
118
119 fn call(
120 &mut self,
121 mut req: Task<Compact, FlowSink::Context, FlowSink::IdType>,
122 ) -> Self::Future {
123 assert!(
124 self.not_ready.is_empty(),
125 "Workflow must wait for all services to be ready. Did you forget to call poll_ready()?"
126 );
127 let meta: WorkflowRequest = req.parts.ctx.extract().unwrap_or_default();
128 let idx = meta.step_index;
129
130 let has_next = self.services.get(&(idx + 1)).is_some();
131 let ctx: StepContext<FlowSink, Encode> =
132 StepContext::new(self.backend.clone(), idx, has_next);
133
134 let cl = self
135 .services
136 .get_mut(&idx)
137 .expect("Attempted to run a step that doesn't exist");
138
139 let svc = &mut cl.svc;
140
141 req.parts.data.insert(ctx);
143
144 self.not_ready.push_back(idx);
145 svc.call(req).map_ok(|res| StepResult(res)).boxed()
146 }
147}
148
149pub async fn handle_workflow_result<N, Compact, FlowSink, Err>(
150 ctx: &mut StepContext<FlowSink, FlowSink::Codec>,
151 result: &GoTo<N>,
152) -> Result<(), TaskSinkError<Err>>
153where
154 FlowSink: Sink<Task<Compact, FlowSink::Context, FlowSink::IdType>, Error = Err>
155 + Backend<Error = Err>
156 + Send
157 + Unpin,
158 FlowSink::Context: MetadataExt<WorkflowRequest>,
159 FlowSink::Codec: Codec<N, Compact = Compact>,
160 <FlowSink::Codec as Codec<N>>::Error: Into<BoxDynError>,
161{
162 use futures::SinkExt;
163 match result {
164 GoTo::Next(next) if ctx.has_next => {
165 let task = Task::builder(
166 FlowSink::Codec::encode(next).map_err(|e| TaskSinkError::CodecError(e.into()))?,
167 )
168 .meta(WorkflowRequest {
169 step_index: ctx.current_step + 1,
170 })
171 .build();
172 ctx.sink.send(task).await?;
173 }
174 GoTo::DelayFor(delay, next) if ctx.has_next => {
175 let task = Task::builder(
176 FlowSink::Codec::encode(next).map_err(|e| TaskSinkError::CodecError(e.into()))?,
177 )
178 .run_after(*delay)
179 .meta(WorkflowRequest {
180 step_index: ctx.current_step + 1,
181 })
182 .build();
183 ctx.sink.send(task).await?;
184 }
185 _ => {}
186 }
187 Ok(())
188}