apalis_core/
layers.rs

1use crate::codec::Codec;
2use crate::error::{BoxDynError, Error};
3use crate::request::Request;
4use crate::response::Response;
5use futures::channel::mpsc::{SendError, Sender};
6use futures::SinkExt;
7use futures::{future::BoxFuture, Future, FutureExt};
8use serde::Serialize;
9use std::fmt::Debug;
10use std::marker::PhantomData;
11use std::{fmt, sync::Arc};
12pub use tower::{
13    layer::layer_fn, layer::util::Identity, util::BoxCloneService, Layer, Service, ServiceBuilder,
14};
15
16/// A generic layer that has been stripped off types.
17/// This is returned by a [crate::Backend] and can be used to customize the middleware of the service consuming tasks
18pub struct CommonLayer<In, T, U, E> {
19    boxed: Arc<dyn Layer<In, Service = BoxCloneService<T, U, E>>>,
20}
21
22impl<In, T, U, E> CommonLayer<In, T, U, E> {
23    /// Create a new [`CommonLayer`].
24    pub fn new<L>(inner_layer: L) -> Self
25    where
26        L: Layer<In> + 'static,
27        L::Service: Service<T, Response = U, Error = E> + Send + 'static + Clone,
28        <L::Service as Service<T>>::Future: Send + 'static,
29        E: std::error::Error,
30    {
31        let layer = layer_fn(move |inner: In| {
32            let out = inner_layer.layer(inner);
33            BoxCloneService::new(out)
34        });
35
36        Self {
37            boxed: Arc::new(layer),
38        }
39    }
40}
41
42impl<In, T, U, E> Layer<In> for CommonLayer<In, T, U, E> {
43    type Service = BoxCloneService<T, U, E>;
44
45    fn layer(&self, inner: In) -> Self::Service {
46        self.boxed.layer(inner)
47    }
48}
49
50impl<In, T, U, E> Clone for CommonLayer<In, T, U, E> {
51    fn clone(&self) -> Self {
52        Self {
53            boxed: Arc::clone(&self.boxed),
54        }
55    }
56}
57
58impl<In, T, U, E> fmt::Debug for CommonLayer<In, T, U, E> {
59    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
60        fmt.debug_struct("CommonLayer").finish()
61    }
62}
63
64/// Extension data for tasks.
65pub mod extensions {
66    use std::{
67        ops::Deref,
68        task::{Context, Poll},
69    };
70    use tower::Service;
71
72    use crate::request::Request;
73
74    /// Extension data for tasks.
75    /// This is commonly used to share state across tasks. or across layers within the same tasks
76    ///
77    /// ```rust
78    /// # use std::sync::Arc;
79    /// # struct Email;
80    /// # use apalis_core::layers::extensions::Data;
81    /// # use apalis_core::service_fn::service_fn;
82    /// # use crate::apalis_core::builder::WorkerFactory;
83    /// # use apalis_core::builder::WorkerBuilder;
84    /// # use apalis_core::memory::MemoryStorage;
85    /// // Some shared state used throughout our application
86    /// struct State {
87    ///     // ...
88    /// }
89    ///
90    /// async fn email_service(email: Email, state: Data<Arc<State>>) {
91    ///     
92    /// }
93    ///
94    /// let state = Arc::new(State { /* ... */ });
95    ///
96    /// let worker = WorkerBuilder::new("tasty-avocado")
97    ///     .data(state)
98    ///     .backend(MemoryStorage::new())
99    ///     .build(service_fn(email_service));
100    /// ```
101
102    #[derive(Debug, Clone, Copy)]
103    pub struct Data<T>(T);
104    impl<T> Data<T> {
105        /// Build a new data entry
106        pub fn new(inner: T) -> Data<T> {
107            Data(inner)
108        }
109    }
110
111    impl<T> Deref for Data<T> {
112        type Target = T;
113        fn deref(&self) -> &Self::Target {
114            &self.0
115        }
116    }
117
118    impl<S, T> tower::Layer<S> for Data<T>
119    where
120        T: Clone + Send + Sync + 'static,
121    {
122        type Service = AddExtension<S, T>;
123
124        fn layer(&self, inner: S) -> Self::Service {
125            AddExtension {
126                inner,
127                value: self.0.clone(),
128            }
129        }
130    }
131
132    /// Middleware for adding some shareable value to [request data].
133    #[derive(Clone, Copy, Debug)]
134    pub struct AddExtension<S, T> {
135        inner: S,
136        value: T,
137    }
138
139    impl<S, T, Req, Ctx> Service<Request<Req, Ctx>> for AddExtension<S, T>
140    where
141        S: Service<Request<Req, Ctx>>,
142        T: Clone + Send + Sync + 'static,
143    {
144        type Response = S::Response;
145        type Error = S::Error;
146        type Future = S::Future;
147
148        #[inline]
149        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
150            self.inner.poll_ready(cx)
151        }
152
153        fn call(&mut self, mut req: Request<Req, Ctx>) -> Self::Future {
154            req.parts.data.insert(self.value.clone());
155            self.inner.call(req)
156        }
157    }
158}
159
160/// A trait for acknowledging successful processing
161/// This trait is called even when a task fails.
162/// This is a way of a [`Backend`] to save the result of a job or message
163pub trait Ack<Task, Res, Codec> {
164    /// The data to fetch from context to allow acknowledgement
165    type Context;
166    /// The error returned by the ack
167    type AckError: std::error::Error;
168
169    /// Acknowledges successful processing of the given request
170    fn ack(
171        &mut self,
172        ctx: &Self::Context,
173        response: &Response<Res>,
174    ) -> impl Future<Output = Result<(), Self::AckError>> + Send;
175}
176
177impl<T, Res: Clone + Send + Sync + Serialize, Ctx: Clone + Send + Sync, Cdc: Codec> Ack<T, Res, Cdc>
178    for Sender<(Ctx, Response<Cdc::Compact>)>
179where
180    Cdc::Error: Debug,
181    Cdc::Compact: Send,
182{
183    type AckError = SendError;
184    type Context = Ctx;
185    async fn ack(
186        &mut self,
187        ctx: &Self::Context,
188        result: &Response<Res>,
189    ) -> Result<(), Self::AckError> {
190        let ctx = ctx.clone();
191        let res = result.map(|res| Cdc::encode(res).unwrap());
192        self.send((ctx, res)).await.unwrap();
193        Ok(())
194    }
195}
196
197/// A layer that acknowledges a job completed successfully
198#[derive(Debug)]
199pub struct AckLayer<A, Req, Ctx, Cdc> {
200    ack: A,
201    job_type: PhantomData<Request<Req, Ctx>>,
202    codec: PhantomData<Cdc>,
203}
204
205impl<A, Req, Ctx, Cdc> AckLayer<A, Req, Ctx, Cdc> {
206    /// Build a new [AckLayer] for a job
207    pub fn new(ack: A) -> Self {
208        Self {
209            ack,
210            job_type: PhantomData,
211            codec: PhantomData,
212        }
213    }
214}
215
216impl<A, Req, Ctx, S, Cdc> Layer<S> for AckLayer<A, Req, Ctx, Cdc>
217where
218    S: Service<Request<Req, Ctx>> + Send + 'static,
219    S::Error: std::error::Error + Send + Sync + 'static,
220    S::Future: Send + 'static,
221    A: Ack<Req, S::Response, Cdc> + Clone + Send + Sync + 'static,
222{
223    type Service = AckService<S, A, Req, Ctx, Cdc>;
224
225    fn layer(&self, service: S) -> Self::Service {
226        AckService {
227            service,
228            ack: self.ack.clone(),
229            job_type: PhantomData,
230            codec: PhantomData,
231        }
232    }
233}
234
235/// The underlying service for an [AckLayer]
236#[derive(Debug)]
237pub struct AckService<SV, A, Req, Ctx, Cdc> {
238    service: SV,
239    ack: A,
240    job_type: PhantomData<Request<Req, Ctx>>,
241    codec: PhantomData<Cdc>,
242}
243
244impl<Sv: Clone, A: Clone, Req, Ctx, Cdc> Clone for AckService<Sv, A, Req, Ctx, Cdc> {
245    fn clone(&self) -> Self {
246        Self {
247            ack: self.ack.clone(),
248            job_type: PhantomData,
249            service: self.service.clone(),
250            codec: PhantomData,
251        }
252    }
253}
254
255impl<SV, A, Req, Ctx, Cdc> Service<Request<Req, Ctx>> for AckService<SV, A, Req, Ctx, Cdc>
256where
257    SV: Service<Request<Req, Ctx>> + Send + 'static,
258    SV::Error: Into<BoxDynError> + Send + 'static,
259    SV::Future: Send + 'static,
260    A: Ack<Req, SV::Response, Cdc, Context = Ctx> + Send + 'static + Clone,
261    Req: 'static + Send,
262    SV::Response: std::marker::Send + Serialize,
263    <A as Ack<Req, SV::Response, Cdc>>::Context: Send + Clone,
264    <A as Ack<Req, SV::Response, Cdc>>::Context: 'static,
265    Ctx: Clone,
266{
267    type Response = SV::Response;
268    type Error = Error;
269    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
270
271    fn poll_ready(
272        &mut self,
273        cx: &mut std::task::Context<'_>,
274    ) -> std::task::Poll<Result<(), Self::Error>> {
275        self.service
276            .poll_ready(cx)
277            .map_err(|e| Error::Failed(Arc::new(e.into())))
278    }
279
280    fn call(&mut self, request: Request<Req, Ctx>) -> Self::Future {
281        let mut ack = self.ack.clone();
282        let ctx = request.parts.context.clone();
283        let attempt = request.parts.attempt.clone();
284        let task_id = request.parts.task_id.clone();
285        let fut = self.service.call(request);
286        let fut_with_ack = async move {
287            let res = fut.await.map_err(|err| {
288                let e: BoxDynError = err.into();
289                // Try to downcast the error to see if it is already of type `Error`
290                if let Some(custom_error) = e.downcast_ref::<Error>() {
291                    return custom_error.clone();
292                }
293                Error::Failed(Arc::new(e))
294            });
295            let response = Response {
296                attempt,
297                inner: res,
298                task_id,
299                _priv: (),
300            };
301            if let Err(_e) = ack.ack(&ctx, &response).await {
302                // TODO: Implement tracing in apalis core
303                // tracing::error!("Acknowledgement Failed: {}", e);
304            }
305            response.inner
306        };
307        fut_with_ack.boxed()
308    }
309}