ntex_service/
pipeline.rs

1use std::{cell, fmt, future::Future, marker, pin::Pin, rc::Rc, task::Context, task::Poll};
2
3use crate::{IntoService, Service, ServiceCtx, ctx::WaitersRef};
4
5#[derive(Debug)]
6/// Container for a service.
7///
8/// Container allows to call enclosed service and adds support of shared readiness.
9pub struct Pipeline<S> {
10    index: u32,
11    state: Rc<PipelineState<S>>,
12}
13
14struct PipelineState<S> {
15    svc: S,
16    waiters: WaitersRef,
17}
18
19impl<S> PipelineState<S> {
20    pub(crate) fn waiters_ref(&self) -> &WaitersRef {
21        &self.waiters
22    }
23}
24
25impl<S> Pipeline<S> {
26    #[inline]
27    /// Construct new container instance.
28    pub fn new(svc: S) -> Self {
29        let (index, waiters) = WaitersRef::new();
30        Pipeline {
31            index,
32            state: Rc::new(PipelineState { svc, waiters }),
33        }
34    }
35
36    #[inline]
37    /// Return reference to enclosed service
38    pub fn get_ref(&self) -> &S {
39        &self.state.svc
40    }
41
42    #[inline]
43    /// Returns when the pipeline is able to process requests.
44    pub async fn ready<R>(&self) -> Result<(), S::Error>
45    where
46        S: Service<R>,
47    {
48        ServiceCtx::<'_, S>::new(self.index, self.state.waiters_ref())
49            .ready(&self.state.svc)
50            .await
51    }
52
53    #[inline]
54    /// Wait for service readiness and then create future object
55    /// that resolves to service result.
56    pub async fn call<R>(&self, req: R) -> Result<S::Response, S::Error>
57    where
58        S: Service<R>,
59    {
60        ServiceCtx::<'_, S>::new(self.index, self.state.waiters_ref())
61            .call(&self.state.svc, req)
62            .await
63    }
64
65    #[inline]
66    /// Wait for service readiness and then create future object
67    /// that resolves to service result.
68    pub fn call_static<R>(&self, req: R) -> PipelineCall<S, R>
69    where
70        S: Service<R> + 'static,
71        R: 'static,
72    {
73        let pl = self.clone();
74
75        PipelineCall {
76            fut: Box::pin(async move {
77                ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
78                    .call(&pl.state.svc, req)
79                    .await
80            }),
81        }
82    }
83
84    #[inline]
85    /// Call service and create future object that resolves to service result.
86    ///
87    /// Note, this call does not check service readiness.
88    pub fn call_nowait<R>(&self, req: R) -> PipelineCall<S, R>
89    where
90        S: Service<R> + 'static,
91        R: 'static,
92    {
93        let pl = self.clone();
94
95        PipelineCall {
96            fut: Box::pin(async move {
97                ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
98                    .call_nowait(&pl.state.svc, req)
99                    .await
100            }),
101        }
102    }
103
104    #[inline]
105    /// Check if shutdown is initiated.
106    pub fn is_shutdown(&self) -> bool {
107        self.state.waiters.is_shutdown()
108    }
109
110    #[inline]
111    /// Shutdown enclosed service.
112    pub async fn shutdown<R>(&self)
113    where
114        S: Service<R>,
115    {
116        self.state.svc.shutdown().await
117    }
118
119    #[inline]
120    pub fn poll<R>(&self, cx: &mut Context<'_>) -> Result<(), S::Error>
121    where
122        S: Service<R>,
123    {
124        self.state.svc.poll(cx)
125    }
126
127    #[inline]
128    /// Get current pipeline.
129    pub fn bind<R>(self) -> PipelineBinding<S, R>
130    where
131        S: Service<R> + 'static,
132        R: 'static,
133    {
134        PipelineBinding::new(self)
135    }
136}
137
138impl<S> From<S> for Pipeline<S> {
139    #[inline]
140    fn from(svc: S) -> Self {
141        Pipeline::new(svc)
142    }
143}
144
145impl<S> Clone for Pipeline<S> {
146    fn clone(&self) -> Self {
147        Pipeline {
148            index: self.state.waiters.insert(),
149            state: self.state.clone(),
150        }
151    }
152}
153
154impl<S> Drop for Pipeline<S> {
155    #[inline]
156    fn drop(&mut self) {
157        self.state.waiters.remove(self.index);
158    }
159}
160
161impl<S: fmt::Debug> fmt::Debug for PipelineState<S> {
162    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163        f.debug_struct("PipelineState")
164            .field("svc", &self.svc)
165            .field("waiters", &self.waiters.get().len())
166            .finish()
167    }
168}
169
170#[derive(Debug)]
171/// Service wrapper for Pipeline
172pub struct PipelineSvc<S> {
173    inner: Pipeline<S>,
174}
175
176impl<S> PipelineSvc<S> {
177    #[inline]
178    /// Construct new PipelineSvc
179    pub fn new(inner: Pipeline<S>) -> Self {
180        Self { inner }
181    }
182}
183
184impl<S, Req> Service<Req> for PipelineSvc<S>
185where
186    S: Service<Req>,
187{
188    type Response = S::Response;
189    type Error = S::Error;
190
191    #[inline]
192    async fn call(
193        &self,
194        req: Req,
195        _: ServiceCtx<'_, Self>,
196    ) -> Result<Self::Response, Self::Error> {
197        self.inner.call(req).await
198    }
199
200    #[inline]
201    async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
202        self.inner.ready().await
203    }
204
205    #[inline]
206    async fn shutdown(&self) {
207        self.inner.shutdown().await
208    }
209
210    #[inline]
211    fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
212        self.inner.poll(cx)
213    }
214}
215
216impl<S> From<S> for PipelineSvc<S> {
217    #[inline]
218    fn from(svc: S) -> Self {
219        PipelineSvc {
220            inner: Pipeline::new(svc),
221        }
222    }
223}
224
225impl<S> Clone for PipelineSvc<S> {
226    fn clone(&self) -> Self {
227        PipelineSvc {
228            inner: self.inner.clone(),
229        }
230    }
231}
232
233impl<S, R> IntoService<PipelineSvc<S>, R> for Pipeline<S>
234where
235    S: Service<R>,
236{
237    #[inline]
238    fn into_service(self) -> PipelineSvc<S> {
239        PipelineSvc::new(self)
240    }
241}
242
243/// Bound container for a service.
244pub struct PipelineBinding<S, R>
245where
246    S: Service<R>,
247{
248    pl: Pipeline<S>,
249    st: cell::UnsafeCell<State<S::Error>>,
250}
251
/// Polling progress of a [`PipelineBinding`].
enum State<E> {
    /// No readiness or shutdown future has been created yet.
    New,
    /// A boxed readiness-check future is in progress.
    Readiness(Pin<Box<dyn Future<Output = Result<(), E>> + 'static>>),
    /// A boxed shutdown future is in progress.
    Shutdown(Pin<Box<dyn Future<Output = ()> + 'static>>),
}
257
258impl<S, R> PipelineBinding<S, R>
259where
260    S: Service<R> + 'static,
261    R: 'static,
262{
263    fn new(pl: Pipeline<S>) -> Self {
264        PipelineBinding {
265            pl,
266            st: cell::UnsafeCell::new(State::New),
267        }
268    }
269
270    #[inline]
271    /// Return reference to enclosed service
272    pub fn get_ref(&self) -> &S {
273        &self.pl.state.svc
274    }
275
276    #[inline]
277    /// Get pipeline
278    pub fn pipeline(&self) -> Pipeline<S> {
279        self.pl.clone()
280    }
281
282    #[inline]
283    pub fn poll(&self, cx: &mut Context<'_>) -> Result<(), S::Error> {
284        self.pl.poll(cx)
285    }
286
287    #[inline]
288    /// Returns `Ready` when the pipeline is able to process requests.
289    ///
290    /// panics if .poll_shutdown() was called before.
291    pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
292        let st = unsafe { &mut *self.st.get() };
293
294        match st {
295            State::New => {
296                // SAFETY: `fut` has same lifetime same as lifetime of `self.pl`.
297                // Pipeline::svc is heap allocated(Rc<S>), and it is being kept alive until
298                // `self` is alive
299                let pl: &'static Pipeline<S> = unsafe { std::mem::transmute(&self.pl) };
300                let fut = Box::pin(CheckReadiness {
301                    fut: None,
302                    f: ready,
303                    _t: marker::PhantomData,
304                    pl,
305                });
306                *st = State::Readiness(fut);
307                self.poll_ready(cx)
308            }
309            State::Readiness(fut) => Pin::new(fut).poll(cx),
310            State::Shutdown(_) => panic!("Pipeline is shutding down"),
311        }
312    }
313
314    #[inline]
315    /// Returns `Ready` when the service is properly shutdowns.
316    pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> {
317        let st = unsafe { &mut *self.st.get() };
318
319        match st {
320            State::New | State::Readiness(_) => {
321                // SAFETY: `fut` has same lifetime same as lifetime of `self.pl`.
322                // Pipeline::svc is heap allocated(Rc<S>), and it is being kept alive until
323                // `self` is alive
324                let pl: &'static Pipeline<S> = unsafe { std::mem::transmute(&self.pl) };
325                *st = State::Shutdown(Box::pin(async move { pl.shutdown().await }));
326                pl.state.waiters.shutdown();
327                self.poll_shutdown(cx)
328            }
329            State::Shutdown(fut) => Pin::new(fut).poll(cx),
330        }
331    }
332
333    #[inline]
334    /// Wait for service readiness and then create future object
335    /// that resolves to service result.
336    pub fn call(&self, req: R) -> PipelineCall<S, R> {
337        let pl = self.pl.clone();
338
339        PipelineCall {
340            fut: Box::pin(async move {
341                ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
342                    .call(&pl.state.svc, req)
343                    .await
344            }),
345        }
346    }
347
348    #[inline]
349    /// Call service and create future object that resolves to service result.
350    ///
351    /// Note, this call does not check service readiness.
352    pub fn call_nowait(&self, req: R) -> PipelineCall<S, R> {
353        let pl = self.pl.clone();
354
355        PipelineCall {
356            fut: Box::pin(async move {
357                ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
358                    .call_nowait(&pl.state.svc, req)
359                    .await
360            }),
361        }
362    }
363
364    #[inline]
365    /// Check if shutdown is initiated.
366    pub fn is_shutdown(&self) -> bool {
367        self.pl.state.waiters.is_shutdown()
368    }
369
370    #[inline]
371    /// Shutdown enclosed service.
372    pub async fn shutdown(&self) {
373        self.pl.state.svc.shutdown().await
374    }
375}
376
377impl<S, R> Drop for PipelineBinding<S, R>
378where
379    S: Service<R>,
380{
381    fn drop(&mut self) {
382        self.st = cell::UnsafeCell::new(State::New);
383    }
384}
385
386impl<S, R> Clone for PipelineBinding<S, R>
387where
388    S: Service<R>,
389{
390    #[inline]
391    fn clone(&self) -> Self {
392        Self {
393            pl: self.pl.clone(),
394            st: cell::UnsafeCell::new(State::New),
395        }
396    }
397}
398
399impl<S, R> fmt::Debug for PipelineBinding<S, R>
400where
401    S: Service<R> + fmt::Debug,
402{
403    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
404        f.debug_struct("PipelineBinding")
405            .field("pipeline", &self.pl)
406            .finish()
407    }
408}
409
410#[must_use = "futures do nothing unless polled"]
411/// Pipeline call future
412pub struct PipelineCall<S, R>
413where
414    S: Service<R>,
415    R: 'static,
416{
417    fut: Call<S::Response, S::Error>,
418}
419
420type Call<R, E> = Pin<Box<dyn Future<Output = Result<R, E>> + 'static>>;
421
422impl<S, R> Future for PipelineCall<S, R>
423where
424    S: Service<R>,
425{
426    type Output = Result<S::Response, S::Error>;
427
428    #[inline]
429    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
430        Pin::new(&mut self.as_mut().fut).poll(cx)
431    }
432}
433
434impl<S, R> fmt::Debug for PipelineCall<S, R>
435where
436    S: Service<R>,
437{
438    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
439        f.debug_struct("PipelineCall").finish()
440    }
441}
442
443fn ready<S, R>(pl: &'static Pipeline<S>) -> impl Future<Output = Result<(), S::Error>>
444where
445    S: Service<R>,
446    R: 'static,
447{
448    pl.state
449        .svc
450        .ready(ServiceCtx::<'_, S>::new(pl.index, pl.state.waiters_ref()))
451}
452
453struct CheckReadiness<S: Service<R> + 'static, R, F, Fut> {
454    f: F,
455    fut: Option<Fut>,
456    pl: &'static Pipeline<S>,
457    _t: marker::PhantomData<R>,
458}
459
460impl<S: Service<R>, R, F, Fut> Unpin for CheckReadiness<S, R, F, Fut> {}
461
462impl<S: Service<R>, R, F, Fut> Drop for CheckReadiness<S, R, F, Fut> {
463    fn drop(&mut self) {
464        // future got dropped during polling, we must notify other waiters
465        if self.fut.is_some() {
466            self.pl.state.waiters.notify();
467        }
468    }
469}
470
471impl<S, R, F, Fut> Future for CheckReadiness<S, R, F, Fut>
472where
473    S: Service<R>,
474    F: Fn(&'static Pipeline<S>) -> Fut,
475    Fut: Future<Output = Result<(), S::Error>>,
476{
477    type Output = Result<(), S::Error>;
478
479    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
480        let mut slf = self.as_mut();
481
482        slf.pl.poll(cx)?;
483
484        slf.pl.state.waiters.run(slf.pl.index, cx, |cx| {
485            if slf.fut.is_none() {
486                slf.fut = Some((slf.f)(slf.pl));
487            }
488            let fut = slf.fut.as_mut().unwrap();
489            let result = unsafe { Pin::new_unchecked(fut) }.poll(cx);
490            if result.is_ready() {
491                let _ = slf.fut.take();
492            }
493            result
494        })
495    }
496}
497
#[cfg(test)]
mod tests {
    use std::{cell::Cell, rc::Rc};

    use super::*;

    /// Minimal service: always ready, always answers `Ok(())`, counts `shutdown` calls.
    #[derive(Debug, Default, Clone)]
    struct Srv(Rc<Cell<usize>>);

    impl Service<()> for Srv {
        type Response = ();
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            Ok(())
        }

        async fn shutdown(&self) {
            let counter = &self.0;
            counter.set(counter.get() + 1);
        }
    }

    #[ntex::test]
    async fn pipeline_service() {
        // Case 1: pipeline over a cloned `PipelineSvc` that owns the mapped service.
        let shutdowns = Rc::new(Cell::new(0));
        let inner = Pipeline::new(Srv(shutdowns.clone()).map(|_| "ok"));
        let srv = Pipeline::new(inner.into_service().clone());

        let resp = srv.call(()).await;
        assert!(resp.is_ok());
        assert_eq!(resp.unwrap(), "ok");

        let readiness = srv.ready().await;
        assert_eq!(readiness, Ok(()));

        srv.shutdown().await;
        assert_eq!(shutdowns.get(), 1);
        let _ = format!("{srv:?}");

        // Case 2: pipeline over a `PipelineSvc` built from a borrowed service.
        let shutdowns = Rc::new(Cell::new(0));
        let mapped = Srv(shutdowns.clone()).map(|_| "ok");
        let srv = Pipeline::new(PipelineSvc::from(&mapped));

        let resp = srv.call(()).await;
        assert!(resp.is_ok());
        assert_eq!(resp.unwrap(), "ok");

        let readiness = srv.ready().await;
        assert_eq!(readiness, Ok(()));

        srv.shutdown().await;
        assert_eq!(shutdowns.get(), 1);
        let _ = format!("{srv:?}");
    }
}