ntex_service/
pipeline.rs

1use std::{cell, fmt, future::Future, marker, pin::Pin, rc::Rc, task::Context, task::Poll};
2
3use crate::{ctx::WaitersRef, Service, ServiceCtx};
4
5#[derive(Debug)]
6/// Container for a service.
7///
8/// Container allows to call enclosed service and adds support of shared readiness.
9pub struct Pipeline<S> {
10    index: u32,
11    state: Rc<PipelineState<S>>,
12}
13
14struct PipelineState<S> {
15    svc: S,
16    waiters: WaitersRef,
17}
18
19impl<S> PipelineState<S> {
20    pub(crate) fn waiters_ref(&self) -> &WaitersRef {
21        &self.waiters
22    }
23}
24
25impl<S> Pipeline<S> {
26    #[inline]
27    /// Construct new container instance.
28    pub fn new(svc: S) -> Self {
29        let (index, waiters) = WaitersRef::new();
30        Pipeline {
31            index,
32            state: Rc::new(PipelineState { svc, waiters }),
33        }
34    }
35
36    #[inline]
37    /// Return reference to enclosed service
38    pub fn get_ref(&self) -> &S {
39        &self.state.svc
40    }
41
42    #[inline]
43    /// Returns when the pipeline is able to process requests.
44    pub async fn ready<R>(&self) -> Result<(), S::Error>
45    where
46        S: Service<R>,
47    {
48        ServiceCtx::<'_, S>::new(self.index, self.state.waiters_ref())
49            .ready(&self.state.svc)
50            .await
51    }
52
53    #[doc(hidden)]
54    #[deprecated]
55    /// Returns when the pipeline is not able to process requests.
56    pub async fn not_ready<R>(&self)
57    where
58        S: Service<R>,
59    {
60        std::future::pending().await
61    }
62
63    #[inline]
64    /// Wait for service readiness and then create future object
65    /// that resolves to service result.
66    pub async fn call<R>(&self, req: R) -> Result<S::Response, S::Error>
67    where
68        S: Service<R>,
69    {
70        ServiceCtx::<'_, S>::new(self.index, self.state.waiters_ref())
71            .call(&self.state.svc, req)
72            .await
73    }
74
75    #[inline]
76    /// Wait for service readiness and then create future object
77    /// that resolves to service result.
78    pub fn call_static<R>(&self, req: R) -> PipelineCall<S, R>
79    where
80        S: Service<R> + 'static,
81        R: 'static,
82    {
83        let pl = self.clone();
84
85        PipelineCall {
86            fut: Box::pin(async move {
87                ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
88                    .call(&pl.state.svc, req)
89                    .await
90            }),
91        }
92    }
93
94    #[inline]
95    /// Call service and create future object that resolves to service result.
96    ///
97    /// Note, this call does not check service readiness.
98    pub fn call_nowait<R>(&self, req: R) -> PipelineCall<S, R>
99    where
100        S: Service<R> + 'static,
101        R: 'static,
102    {
103        let pl = self.clone();
104
105        PipelineCall {
106            fut: Box::pin(async move {
107                ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
108                    .call_nowait(&pl.state.svc, req)
109                    .await
110            }),
111        }
112    }
113
114    #[inline]
115    /// Check if shutdown is initiated.
116    pub fn is_shutdown(&self) -> bool {
117        self.state.waiters.is_shutdown()
118    }
119
120    #[inline]
121    /// Shutdown enclosed service.
122    pub async fn shutdown<R>(&self)
123    where
124        S: Service<R>,
125    {
126        self.state.svc.shutdown().await
127    }
128
129    #[inline]
130    pub fn poll<R>(&self, cx: &mut Context<'_>) -> Result<(), S::Error>
131    where
132        S: Service<R>,
133    {
134        self.state.svc.poll(cx)
135    }
136
137    #[inline]
138    /// Get current pipeline.
139    pub fn bind<R>(self) -> PipelineBinding<S, R>
140    where
141        S: Service<R> + 'static,
142        R: 'static,
143    {
144        PipelineBinding::new(self)
145    }
146}
147
148impl<S> From<S> for Pipeline<S> {
149    #[inline]
150    fn from(svc: S) -> Self {
151        Pipeline::new(svc)
152    }
153}
154
155impl<S> Clone for Pipeline<S> {
156    fn clone(&self) -> Self {
157        Pipeline {
158            index: self.state.waiters.insert(),
159            state: self.state.clone(),
160        }
161    }
162}
163
164impl<S> Drop for Pipeline<S> {
165    #[inline]
166    fn drop(&mut self) {
167        self.state.waiters.remove(self.index);
168    }
169}
170
171impl<S: fmt::Debug> fmt::Debug for PipelineState<S> {
172    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173        f.debug_struct("PipelineState")
174            .field("svc", &self.svc)
175            .field("waiters", &self.waiters.get().len())
176            .finish()
177    }
178}
179
180/// Bound container for a service.
181pub struct PipelineBinding<S, R>
182where
183    S: Service<R>,
184{
185    pl: Pipeline<S>,
186    st: cell::UnsafeCell<State<S::Error>>,
187}
188
189enum State<E> {
190    New,
191    Readiness(Pin<Box<dyn Future<Output = Result<(), E>> + 'static>>),
192    Shutdown(Pin<Box<dyn Future<Output = ()> + 'static>>),
193}
194
195impl<S, R> PipelineBinding<S, R>
196where
197    S: Service<R> + 'static,
198    R: 'static,
199{
200    fn new(pl: Pipeline<S>) -> Self {
201        PipelineBinding {
202            pl,
203            st: cell::UnsafeCell::new(State::New),
204        }
205    }
206
207    #[inline]
208    /// Return reference to enclosed service
209    pub fn get_ref(&self) -> &S {
210        &self.pl.state.svc
211    }
212
213    #[inline]
214    /// Get pipeline
215    pub fn pipeline(&self) -> Pipeline<S> {
216        self.pl.clone()
217    }
218
219    #[inline]
220    pub fn poll(&self, cx: &mut Context<'_>) -> Result<(), S::Error> {
221        self.pl.poll(cx)
222    }
223
224    #[inline]
225    /// Returns `Ready` when the pipeline is able to process requests.
226    ///
227    /// panics if .poll_shutdown() was called before.
228    pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
229        let st = unsafe { &mut *self.st.get() };
230
231        match st {
232            State::New => {
233                // SAFETY: `fut` has same lifetime same as lifetime of `self.pl`.
234                // Pipeline::svc is heap allocated(Rc<S>), and it is being kept alive until
235                // `self` is alive
236                let pl: &'static Pipeline<S> = unsafe { std::mem::transmute(&self.pl) };
237                let fut = Box::pin(CheckReadiness {
238                    fut: None,
239                    f: ready,
240                    _t: marker::PhantomData,
241                    pl,
242                });
243                *st = State::Readiness(fut);
244                self.poll_ready(cx)
245            }
246            State::Readiness(ref mut fut) => Pin::new(fut).poll(cx),
247            State::Shutdown(_) => panic!("Pipeline is shutding down"),
248        }
249    }
250
251    #[doc(hidden)]
252    #[deprecated]
253    #[inline]
254    /// Returns when the pipeline is not able to process requests.
255    pub fn poll_not_ready(&self, _: &mut Context<'_>) -> Poll<()> {
256        Poll::Pending
257    }
258
259    #[inline]
260    /// Returns `Ready` when the service is properly shutdowns.
261    pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> {
262        let st = unsafe { &mut *self.st.get() };
263
264        match st {
265            State::New | State::Readiness(_) => {
266                // SAFETY: `fut` has same lifetime same as lifetime of `self.pl`.
267                // Pipeline::svc is heap allocated(Rc<S>), and it is being kept alive until
268                // `self` is alive
269                let pl: &'static Pipeline<S> = unsafe { std::mem::transmute(&self.pl) };
270                *st = State::Shutdown(Box::pin(async move { pl.shutdown().await }));
271                pl.state.waiters.shutdown();
272                self.poll_shutdown(cx)
273            }
274            State::Shutdown(ref mut fut) => Pin::new(fut).poll(cx),
275        }
276    }
277
278    #[inline]
279    /// Wait for service readiness and then create future object
280    /// that resolves to service result.
281    pub fn call(&self, req: R) -> PipelineCall<S, R> {
282        let pl = self.pl.clone();
283
284        PipelineCall {
285            fut: Box::pin(async move {
286                ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
287                    .call(&pl.state.svc, req)
288                    .await
289            }),
290        }
291    }
292
293    #[inline]
294    /// Call service and create future object that resolves to service result.
295    ///
296    /// Note, this call does not check service readiness.
297    pub fn call_nowait(&self, req: R) -> PipelineCall<S, R> {
298        let pl = self.pl.clone();
299
300        PipelineCall {
301            fut: Box::pin(async move {
302                ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
303                    .call_nowait(&pl.state.svc, req)
304                    .await
305            }),
306        }
307    }
308
309    #[inline]
310    /// Check if shutdown is initiated.
311    pub fn is_shutdown(&self) -> bool {
312        self.pl.state.waiters.is_shutdown()
313    }
314
315    #[inline]
316    /// Shutdown enclosed service.
317    pub async fn shutdown(&self) {
318        self.pl.state.svc.shutdown().await
319    }
320}
321
322impl<S, R> Drop for PipelineBinding<S, R>
323where
324    S: Service<R>,
325{
326    fn drop(&mut self) {
327        self.st = cell::UnsafeCell::new(State::New);
328    }
329}
330
331impl<S, R> Clone for PipelineBinding<S, R>
332where
333    S: Service<R>,
334{
335    #[inline]
336    fn clone(&self) -> Self {
337        Self {
338            pl: self.pl.clone(),
339            st: cell::UnsafeCell::new(State::New),
340        }
341    }
342}
343
344impl<S, R> fmt::Debug for PipelineBinding<S, R>
345where
346    S: Service<R> + fmt::Debug,
347{
348    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
349        f.debug_struct("PipelineBinding")
350            .field("pipeline", &self.pl)
351            .finish()
352    }
353}
354
355#[must_use = "futures do nothing unless polled"]
356/// Pipeline call future
357pub struct PipelineCall<S, R>
358where
359    S: Service<R>,
360    R: 'static,
361{
362    fut: Call<S::Response, S::Error>,
363}
364
/// Boxed, pinned, `'static` future resolving to `Result<T, E>`.
type Call<T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + 'static>>;
366
367impl<S, R> Future for PipelineCall<S, R>
368where
369    S: Service<R>,
370{
371    type Output = Result<S::Response, S::Error>;
372
373    #[inline]
374    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
375        Pin::new(&mut self.as_mut().fut).poll(cx)
376    }
377}
378
379impl<S, R> fmt::Debug for PipelineCall<S, R>
380where
381    S: Service<R>,
382{
383    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
384        f.debug_struct("PipelineCall").finish()
385    }
386}
387
388fn ready<S, R>(pl: &'static Pipeline<S>) -> impl Future<Output = Result<(), S::Error>>
389where
390    S: Service<R>,
391    R: 'static,
392{
393    pl.state
394        .svc
395        .ready(ServiceCtx::<'_, S>::new(pl.index, pl.state.waiters_ref()))
396}
397
398struct CheckReadiness<S: Service<R> + 'static, R, F, Fut> {
399    f: F,
400    fut: Option<Fut>,
401    pl: &'static Pipeline<S>,
402    _t: marker::PhantomData<R>,
403}
404
405impl<S: Service<R>, R, F, Fut> Unpin for CheckReadiness<S, R, F, Fut> {}
406
407impl<S: Service<R>, R, F, Fut> Drop for CheckReadiness<S, R, F, Fut> {
408    fn drop(&mut self) {
409        // future fot dropped during polling, we must notify other waiters
410        if self.fut.is_some() {
411            self.pl.state.waiters.notify();
412        }
413    }
414}
415
416impl<S, R, F, Fut> Future for CheckReadiness<S, R, F, Fut>
417where
418    S: Service<R>,
419    F: Fn(&'static Pipeline<S>) -> Fut,
420    Fut: Future<Output = Result<(), S::Error>>,
421{
422    type Output = Result<(), S::Error>;
423
424    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
425        let mut slf = self.as_mut();
426
427        slf.pl.poll(cx)?;
428
429        if slf.pl.state.waiters.can_check(slf.pl.index, cx) {
430            if slf.fut.is_none() {
431                slf.fut = Some((slf.f)(slf.pl));
432            }
433            let fut = slf.fut.as_mut().unwrap();
434            match unsafe { Pin::new_unchecked(fut) }.poll(cx) {
435                Poll::Pending => {
436                    slf.pl.state.waiters.register(slf.pl.index, cx);
437                    Poll::Pending
438                }
439                Poll::Ready(res) => {
440                    let _ = slf.fut.take();
441                    slf.pl.state.waiters.notify();
442                    Poll::Ready(res)
443                }
444            }
445        } else {
446            Poll::Pending
447        }
448    }
449}