apalis_workflow/
service.rs

1use apalis_core::{
2    backend::{BackendExt, TaskSinkError, codec::Codec},
3    error::BoxDynError,
4    task::{Task, metadata::MetadataExt, task_id::TaskId},
5};
6use futures::SinkExt;
7use futures::{FutureExt, Sink, future::BoxFuture};
8use std::{
9    collections::{HashMap, VecDeque},
10    marker::PhantomData,
11    task::{Context, Poll},
12};
13use tower::Service;
14
15use crate::{
16    SteppedService,
17    context::{StepContext, WorkflowContext},
18    id_generator::GenerateId,
19    router::{GoTo, StepResult},
20};
21
22/// The main workflow service that orchestrates the execution of workflow steps.
23#[derive(Debug)]
24pub struct WorkflowService<B, Input>
25where
26    B: BackendExt,
27{
28    services: HashMap<usize, SteppedService<B::Compact, B::Context, B::IdType>>,
29    not_ready: VecDeque<usize>,
30    backend: B,
31    _marker: PhantomData<Input>,
32}
33impl<B, Input> WorkflowService<B, Input>
34where
35    B: BackendExt,
36{
37    /// Creates a new `WorkflowService` with the given services and backend.
38    pub fn new(
39        services: HashMap<usize, SteppedService<B::Compact, B::Context, B::IdType>>,
40        backend: B,
41    ) -> Self {
42        Self {
43            services,
44            not_ready: VecDeque::new(),
45            backend,
46            _marker: PhantomData,
47        }
48    }
49}
50
51impl<B, Err, Input> Service<Task<B::Compact, B::Context, B::IdType>> for WorkflowService<B, Input>
52where
53    B::Compact: Send + 'static,
54    B: Sync,
55    B::Context: Send + Default + MetadataExt<WorkflowContext>,
56    Err: std::error::Error + Send + Sync + 'static,
57    B::IdType: GenerateId + Send + 'static,
58    <B::Context as MetadataExt<WorkflowContext>>::Error: Into<BoxDynError>,
59    B: Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err> + Unpin,
60    B: Clone + Send + Sync + 'static + BackendExt<Error = Err>,
61{
62    type Response = GoTo<StepResult<B::Compact, B::IdType>>;
63    type Error = BoxDynError;
64    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
65
66    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
67        loop {
68            // must wait for *all* services to be ready.
69            // this will cause head-of-line blocking unless the underlying services are always ready.
70            if self.not_ready.is_empty() {
71                return Poll::Ready(Ok(()));
72            } else {
73                if self
74                    .services
75                    .get_mut(&self.not_ready[0])
76                    .unwrap()
77                    .poll_ready(cx)?
78                    .is_pending()
79                {
80                    return Poll::Pending;
81                }
82
83                self.not_ready.pop_front();
84            }
85        }
86    }
87
88    fn call(&mut self, mut req: Task<B::Compact, B::Context, B::IdType>) -> Self::Future {
89        assert!(
90            self.not_ready.is_empty(),
91            "Workflow must wait for all services to be ready. Did you forget to call poll_ready()?"
92        );
93        let meta: WorkflowContext = req.parts.ctx.extract().unwrap_or_default();
94        let idx = meta.step_index;
95
96        let has_next = self.services.contains_key(&(idx + 1));
97        let ctx: StepContext<B> = StepContext::new(self.backend.clone(), idx, has_next);
98
99        let svc = self
100            .services
101            .get_mut(&idx)
102            .expect("Attempted to run a step that doesn't exist");
103
104        // Prepare the context for the next step
105        req.parts.data.insert(ctx);
106
107        self.not_ready.push_back(idx);
108        svc.call(req).boxed()
109    }
110}
111
112/// Handle the result of a workflow step, scheduling the next step if necessary
113pub async fn handle_step_result<N, Compact, B, Err>(
114    ctx: &mut StepContext<B>,
115    result: GoTo<N>,
116) -> Result<GoTo<StepResult<B::Compact, B::IdType>>, TaskSinkError<Err>>
117where
118    B: Sink<Task<Compact, B::Context, B::IdType>, Error = Err>
119        + BackendExt<Error = Err, Compact = Compact>
120        + Send
121        + Unpin,
122    B::Context: MetadataExt<WorkflowContext>,
123    B::Codec: Codec<N, Compact = Compact>,
124    <B::Codec as Codec<N>>::Error: Into<BoxDynError>,
125    Compact: 'static,
126    N: 'static,
127    B::IdType: GenerateId + Send + 'static,
128{
129    match result {
130        GoTo::Next(next) if ctx.has_next => {
131            let task_id = B::IdType::generate();
132            let task_id = TaskId::new(task_id);
133            let task = Task::builder(
134                B::Codec::encode(&next).map_err(|e| TaskSinkError::CodecError(e.into()))?,
135            )
136            .with_task_id(task_id.clone())
137            .meta(WorkflowContext {
138                step_index: ctx.current_step + 1,
139            })
140            .build();
141            ctx.backend.send(task).await?;
142            Ok(GoTo::Next(StepResult {
143                result: B::Codec::encode(&next).map_err(|e| TaskSinkError::CodecError(e.into()))?,
144                next_task_id: Some(task_id),
145            }))
146        }
147        GoTo::DelayFor(delay, next) if ctx.has_next => {
148            let task_id = B::IdType::generate();
149            let task_id = TaskId::new(task_id);
150            let task = Task::builder(
151                B::Codec::encode(&next).map_err(|e| TaskSinkError::CodecError(e.into()))?,
152            )
153            .run_after(delay)
154            .with_task_id(task_id.clone())
155            .meta(WorkflowContext {
156                step_index: ctx.current_step + 1,
157            })
158            .build();
159            ctx.backend.send(task).await?;
160            Ok(GoTo::DelayFor(
161                delay,
162                StepResult {
163                    result: B::Codec::encode(&next)
164                        .map_err(|e| TaskSinkError::CodecError(e.into()))?,
165                    next_task_id: Some(task_id),
166                },
167            ))
168        }
169        #[allow(clippy::match_same_arms)]
170        GoTo::Done => Ok(GoTo::Done),
171        GoTo::Break(res) => Ok(GoTo::Break(StepResult {
172            result: B::Codec::encode(&res).map_err(|e| TaskSinkError::CodecError(e.into()))?,
173            next_task_id: None,
174        })),
175        _ => Ok(GoTo::Done),
176    }
177}