apalis_workflow/
service.rs

1use std::{
2    collections::{HashMap, VecDeque},
3    task::{Context, Poll},
4};
5
6use apalis_core::{
7    backend::{Backend, TaskSinkError, codec::Codec},
8    error::BoxDynError,
9    task::{Task, metadata::MetadataExt},
10};
11use futures::{FutureExt, Sink, TryFutureExt, future::BoxFuture};
12use serde::{Deserialize, Serialize, Serializer};
13use serde_json::Value;
14use tower::Service;
15
16use crate::{CompositeService, GenerateId, GoTo, StepContext, WorkflowRequest};
17
18#[derive(Debug, Clone, Deserialize)]
19pub struct StepResult<T>(pub T);
20
21impl Serialize for StepResult<Vec<u8>> {
22    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
23    where
24        S: Serializer,
25    {
26        // Try to deserialize the bytes as JSON
27        match serde_json::from_slice::<serde_json::Value>(&self.0) {
28            Ok(value) => value.serialize(serializer),
29            Err(e) => {
30                // If deserialization fails, serialize the error
31                use serde::ser::Error;
32                Err(S::Error::custom(e.to_string()))
33            }
34        }
35    }
36}
37
38impl Serialize for StepResult<serde_json::Value> {
39    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
40    where
41        S: Serializer,
42    {
43        // Try to deserialize the bytes as JSON
44        match Value::deserialize(&self.0) {
45            Ok(value) => value.serialize(serializer),
46            Err(e) => {
47                // If deserialization fails, serialize the error
48                use serde::ser::Error;
49                Err(S::Error::custom(e.to_string()))
50            }
51        }
52    }
53}
54
55pub struct WorkFlowService<FlowSink, Encode, Compact, Context, IdType> {
56    services: HashMap<usize, CompositeService<FlowSink, Encode, Compact, Context, IdType>>,
57    not_ready: VecDeque<usize>,
58    backend: FlowSink,
59}
60impl<FlowSink, Encode, Compact, Context, IdType>
61    WorkFlowService<FlowSink, Encode, Compact, Context, IdType>
62{
63    pub(crate) fn new(
64        services: HashMap<usize, CompositeService<FlowSink, Encode, Compact, Context, IdType>>,
65        backend: FlowSink,
66    ) -> Self {
67        Self {
68            services,
69            not_ready: VecDeque::new(),
70            backend,
71        }
72    }
73}
74
75impl<FlowSink: Clone + Send + Sync + 'static + Backend<Error = Err>, Encode, Compact, Err>
76    Service<Task<Compact, FlowSink::Context, FlowSink::IdType>>
77    for WorkFlowService<FlowSink, Encode, Compact, FlowSink::Context, FlowSink::IdType>
78where
79    FlowSink::Context: MetadataExt<WorkflowRequest>,
80    Encode: Send + Sync + 'static,
81    Compact: Send + 'static + Clone,
82    FlowSink: Sync,
83    Compact: Send + Sync,
84    FlowSink::Context: Send + Default + MetadataExt<WorkflowRequest>,
85    Err: std::error::Error + Send + Sync + 'static,
86    FlowSink::IdType: GenerateId + Send + 'static,
87    Encode: Codec<Compact, Compact = Compact>,
88    <FlowSink::Context as MetadataExt<WorkflowRequest>>::Error: Into<BoxDynError>,
89    Encode::Error: Into<BoxDynError>,
90    FlowSink: Sink<Task<Compact, FlowSink::Context, FlowSink::IdType>, Error = Err> + Unpin,
91{
92    type Response = StepResult<Compact>;
93    type Error = BoxDynError;
94    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
95
96    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
97        loop {
98            // must wait for *all* services to be ready.
99            // this will cause head-of-line blocking unless the underlying services are always ready.
100            if self.not_ready.is_empty() {
101                return Poll::Ready(Ok(()));
102            } else {
103                if self
104                    .services
105                    .get_mut(&self.not_ready[0])
106                    .unwrap()
107                    .svc
108                    .poll_ready(cx)?
109                    .is_pending()
110                {
111                    return Poll::Pending;
112                }
113
114                self.not_ready.pop_front();
115            }
116        }
117    }
118
119    fn call(
120        &mut self,
121        mut req: Task<Compact, FlowSink::Context, FlowSink::IdType>,
122    ) -> Self::Future {
123        assert!(
124            self.not_ready.is_empty(),
125            "Workflow must wait for all services to be ready. Did you forget to call poll_ready()?"
126        );
127        let meta: WorkflowRequest = req.parts.ctx.extract().unwrap_or_default();
128        let idx = meta.step_index;
129
130        let has_next = self.services.get(&(idx + 1)).is_some();
131        let ctx: StepContext<FlowSink, Encode> =
132            StepContext::new(self.backend.clone(), idx, has_next);
133
134        let cl = self
135            .services
136            .get_mut(&idx)
137            .expect("Attempted to run a step that doesn't exist");
138
139        let svc = &mut cl.svc;
140
141        // Prepare the context for the next step
142        req.parts.data.insert(ctx);
143
144        self.not_ready.push_back(idx);
145        svc.call(req).map_ok(|res| StepResult(res)).boxed()
146    }
147}
148
149pub async fn handle_workflow_result<N, Compact, FlowSink, Err>(
150    ctx: &mut StepContext<FlowSink, FlowSink::Codec>,
151    result: &GoTo<N>,
152) -> Result<(), TaskSinkError<Err>>
153where
154    FlowSink: Sink<Task<Compact, FlowSink::Context, FlowSink::IdType>, Error = Err>
155        + Backend<Error = Err>
156        + Send
157        + Unpin,
158    FlowSink::Context: MetadataExt<WorkflowRequest>,
159    FlowSink::Codec: Codec<N, Compact = Compact>,
160    <FlowSink::Codec as Codec<N>>::Error: Into<BoxDynError>,
161{
162    use futures::SinkExt;
163    match result {
164        GoTo::Next(next) if ctx.has_next => {
165            let task = Task::builder(
166                FlowSink::Codec::encode(next).map_err(|e| TaskSinkError::CodecError(e.into()))?,
167            )
168            .meta(WorkflowRequest {
169                step_index: ctx.current_step + 1,
170            })
171            .build();
172            ctx.sink.send(task).await?;
173        }
174        GoTo::DelayFor(delay, next) if ctx.has_next => {
175            let task = Task::builder(
176                FlowSink::Codec::encode(next).map_err(|e| TaskSinkError::CodecError(e.into()))?,
177            )
178            .run_after(*delay)
179            .meta(WorkflowRequest {
180                step_index: ctx.current_step + 1,
181            })
182            .build();
183            ctx.sink.send(task).await?;
184        }
185        _ => {}
186    }
187    Ok(())
188}