apalis_core/
layers.rs

1use crate::error::{BoxDynError, Error};
2use crate::request::Request;
3use crate::response::Response;
4use futures::channel::mpsc::{SendError, Sender};
5use futures::SinkExt;
6use futures::{future::BoxFuture, Future, FutureExt};
7use serde::Serialize;
8use std::marker::PhantomData;
9use std::{fmt, sync::Arc};
10pub use tower::{
11    layer::layer_fn, layer::util::Identity, util::BoxCloneService, Layer, Service, ServiceBuilder,
12};
13
14/// A generic layer that has been stripped off types.
15/// This is returned by a [crate::Backend] and can be used to customize the middleware of the service consuming tasks
16pub struct CommonLayer<In, T, U, E> {
17    boxed: Arc<dyn Layer<In, Service = BoxCloneService<T, U, E>>>,
18}
19
20impl<In, T, U, E> CommonLayer<In, T, U, E> {
21    /// Create a new [`CommonLayer`].
22    pub fn new<L>(inner_layer: L) -> Self
23    where
24        L: Layer<In> + 'static,
25        L::Service: Service<T, Response = U, Error = E> + Send + 'static + Clone,
26        <L::Service as Service<T>>::Future: Send + 'static,
27        E: std::error::Error,
28    {
29        let layer = layer_fn(move |inner: In| {
30            let out = inner_layer.layer(inner);
31            BoxCloneService::new(out)
32        });
33
34        Self {
35            boxed: Arc::new(layer),
36        }
37    }
38}
39
40impl<In, T, U, E> Layer<In> for CommonLayer<In, T, U, E> {
41    type Service = BoxCloneService<T, U, E>;
42
43    fn layer(&self, inner: In) -> Self::Service {
44        self.boxed.layer(inner)
45    }
46}
47
48impl<In, T, U, E> Clone for CommonLayer<In, T, U, E> {
49    fn clone(&self) -> Self {
50        Self {
51            boxed: Arc::clone(&self.boxed),
52        }
53    }
54}
55
56impl<In, T, U, E> fmt::Debug for CommonLayer<In, T, U, E> {
57    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
58        fmt.debug_struct("CommonLayer").finish()
59    }
60}
61
62/// Extension data for tasks.
63pub mod extensions {
64    use std::{
65        ops::Deref,
66        task::{Context, Poll},
67    };
68    use tower::Service;
69
70    use crate::request::Request;
71
72    /// Extension data for tasks.
73    /// This is commonly used to share state across tasks. or across layers within the same tasks
74    ///
75    /// ```rust
76    /// # use std::sync::Arc;
77    /// # struct Email;
78    /// # use apalis_core::layers::extensions::Data;
79    /// # use apalis_core::service_fn::service_fn;
80    /// # use crate::apalis_core::builder::WorkerFactory;
81    /// # use apalis_core::builder::WorkerBuilder;
82    /// # use apalis_core::memory::MemoryStorage;
83    /// // Some shared state used throughout our application
84    /// struct State {
85    ///     // ...
86    /// }
87    ///
88    /// async fn email_service(email: Email, state: Data<Arc<State>>) {
89    ///     
90    /// }
91    ///
92    /// let state = Arc::new(State { /* ... */ });
93    ///
94    /// let worker = WorkerBuilder::new("tasty-avocado")
95    ///     .data(state)
96    ///     .backend(MemoryStorage::new())
97    ///     .build(service_fn(email_service));
98    /// ```
99
100    #[derive(Debug, Clone, Copy)]
101    pub struct Data<T>(T);
102    impl<T> Data<T> {
103        /// Build a new data entry
104        pub fn new(inner: T) -> Data<T> {
105            Data(inner)
106        }
107    }
108
109    impl<T> Deref for Data<T> {
110        type Target = T;
111        fn deref(&self) -> &Self::Target {
112            &self.0
113        }
114    }
115
116    impl<S, T> tower::Layer<S> for Data<T>
117    where
118        T: Clone + Send + Sync + 'static,
119    {
120        type Service = AddExtension<S, T>;
121
122        fn layer(&self, inner: S) -> Self::Service {
123            AddExtension {
124                inner,
125                value: self.0.clone(),
126            }
127        }
128    }
129
130    /// Middleware for adding some shareable value to [request data].
131    #[derive(Clone, Copy, Debug)]
132    pub struct AddExtension<S, T> {
133        inner: S,
134        value: T,
135    }
136
137    impl<S, T, Req, Ctx> Service<Request<Req, Ctx>> for AddExtension<S, T>
138    where
139        S: Service<Request<Req, Ctx>>,
140        T: Clone + Send + Sync + 'static,
141    {
142        type Response = S::Response;
143        type Error = S::Error;
144        type Future = S::Future;
145
146        #[inline]
147        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
148            self.inner.poll_ready(cx)
149        }
150
151        fn call(&mut self, mut req: Request<Req, Ctx>) -> Self::Future {
152            req.parts.data.insert(self.value.clone());
153            self.inner.call(req)
154        }
155    }
156}
157
158/// A trait for acknowledging successful processing
159/// This trait is called even when a task fails.
160/// This is a way of a [`Backend`] to save the result of a job or message
161pub trait Ack<Task, Res> {
162    /// The data to fetch from context to allow acknowledgement
163    type Context;
164    /// The error returned by the ack
165    type AckError: std::error::Error;
166
167    /// Acknowledges successful processing of the given request
168    fn ack(
169        &mut self,
170        ctx: &Self::Context,
171        response: &Response<Res>,
172    ) -> impl Future<Output = Result<(), Self::AckError>> + Send;
173}
174
175impl<T, Res: Clone + Send + Sync, Ctx: Clone + Send + Sync> Ack<T, Res>
176    for Sender<(Ctx, Response<Res>)>
177{
178    type AckError = SendError;
179    type Context = Ctx;
180    async fn ack(
181        &mut self,
182        ctx: &Self::Context,
183        result: &Response<Res>,
184    ) -> Result<(), Self::AckError> {
185        let ctx = ctx.clone();
186        self.send((ctx, result.clone())).await.unwrap();
187        Ok(())
188    }
189}
190
191/// A layer that acknowledges a job completed successfully
192#[derive(Debug)]
193pub struct AckLayer<A, Req, Ctx, Res> {
194    ack: A,
195    job_type: PhantomData<Request<Req, Ctx>>,
196    res: PhantomData<Res>,
197}
198
199impl<A, Req, Ctx, Res> AckLayer<A, Req, Ctx, Res> {
200    /// Build a new [AckLayer] for a job
201    pub fn new(ack: A) -> Self {
202        Self {
203            ack,
204            job_type: PhantomData,
205            res: PhantomData,
206        }
207    }
208}
209
210impl<A, Req, Ctx, S, Res> Layer<S> for AckLayer<A, Req, Ctx, Res>
211where
212    S: Service<Request<Req, Ctx>> + Send + 'static,
213    S::Error: std::error::Error + Send + Sync + 'static,
214    S::Future: Send + 'static,
215    A: Ack<Req, S::Response> + Clone + Send + Sync + 'static,
216{
217    type Service = AckService<S, A, Req, Ctx, S::Response>;
218
219    fn layer(&self, service: S) -> Self::Service {
220        AckService {
221            service,
222            ack: self.ack.clone(),
223            job_type: PhantomData,
224            res: PhantomData,
225        }
226    }
227}
228
229/// The underlying service for an [AckLayer]
230#[derive(Debug)]
231pub struct AckService<SV, A, Req, Ctx, Res> {
232    service: SV,
233    ack: A,
234    job_type: PhantomData<Request<Req, Ctx>>,
235    res: PhantomData<Res>,
236}
237
238impl<Sv: Clone, A: Clone, Req, Ctx, Res> Clone for AckService<Sv, A, Req, Ctx, Res> {
239    fn clone(&self) -> Self {
240        Self {
241            ack: self.ack.clone(),
242            job_type: PhantomData,
243            service: self.service.clone(),
244            res: PhantomData,
245        }
246    }
247}
248
249impl<SV, A, Req, Res, Ctx> Service<Request<Req, Ctx>> for AckService<SV, A, Req, Ctx, Res>
250where
251    SV: Service<Request<Req, Ctx>> + Send + Sync + 'static,
252    <SV as Service<Request<Req, Ctx>>>::Error: Into<BoxDynError> + Send + Sync + 'static,
253    <SV as Service<Request<Req, Ctx>>>::Future: std::marker::Send + 'static,
254    A: Ack<Req, <SV as Service<Request<Req, Ctx>>>::Response, Context = Ctx>
255        + Send
256        + 'static
257        + Clone
258        + Send
259        + Sync,
260    Req: 'static + Send,
261    <SV as Service<Request<Req, Ctx>>>::Response: std::marker::Send + fmt::Debug + Sync + Serialize,
262    <A as Ack<Req, SV::Response>>::Context: Sync + Send + Clone,
263    <A as Ack<Req, <SV as Service<Request<Req, Ctx>>>::Response>>::Context: 'static,
264    Ctx: Clone,
265{
266    type Response = SV::Response;
267    type Error = Error;
268    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
269
270    fn poll_ready(
271        &mut self,
272        cx: &mut std::task::Context<'_>,
273    ) -> std::task::Poll<Result<(), Self::Error>> {
274        self.service
275            .poll_ready(cx)
276            .map_err(|e| Error::Failed(Arc::new(e.into())))
277    }
278
279    fn call(&mut self, request: Request<Req, Ctx>) -> Self::Future {
280        let mut ack = self.ack.clone();
281        let ctx = request.parts.context.clone();
282        let attempt = request.parts.attempt.clone();
283        let task_id = request.parts.task_id.clone();
284        let fut = self.service.call(request);
285        let fut_with_ack = async move {
286            let res = fut.await.map_err(|err| {
287                let e: BoxDynError = err.into();
288                // Try to downcast the error to see if it is already of type `Error`
289                if let Some(custom_error) = e.downcast_ref::<Error>() {
290                    return custom_error.clone();
291                }
292                Error::Failed(Arc::new(e))
293            });
294            let response = Response {
295                attempt,
296                inner: res,
297                task_id,
298                _priv: (),
299            };
300            if let Err(_e) = ack.ack(&ctx, &response).await {
301                // TODO: Implement tracing in apalis core
302                // tracing::error!("Acknowledgement Failed: {}", e);
303            }
304            response.inner
305        };
306        fut_with_ack.boxed()
307    }
308}