ntex_service/
pipeline.rs

1use std::{cell, fmt, future::Future, marker, pin::Pin, rc::Rc, task::Context, task::Poll};
2
3use crate::{ctx::WaitersRef, Service, ServiceCtx};
4
5#[derive(Debug)]
6/// Container for a service.
7///
8/// Container allows to call enclosed service and adds support of shared readiness.
9pub struct Pipeline<S> {
10    index: u32,
11    state: Rc<PipelineState<S>>,
12}
13
14struct PipelineState<S> {
15    svc: S,
16    waiters: WaitersRef,
17}
18
19impl<S> PipelineState<S> {
20    pub(crate) fn waiters_ref(&self) -> &WaitersRef {
21        &self.waiters
22    }
23}
24
25impl<S> Pipeline<S> {
26    #[inline]
27    /// Construct new container instance.
28    pub fn new(svc: S) -> Self {
29        let (index, waiters) = WaitersRef::new();
30        Pipeline {
31            index,
32            state: Rc::new(PipelineState { svc, waiters }),
33        }
34    }
35
36    #[inline]
37    /// Return reference to enclosed service
38    pub fn get_ref(&self) -> &S {
39        &self.state.svc
40    }
41
42    #[inline]
43    /// Returns when the pipeline is able to process requests.
44    pub async fn ready<R>(&self) -> Result<(), S::Error>
45    where
46        S: Service<R>,
47    {
48        ServiceCtx::<'_, S>::new(self.index, self.state.waiters_ref())
49            .ready(&self.state.svc)
50            .await
51    }
52
53    #[inline]
54    /// Wait for service readiness and then create future object
55    /// that resolves to service result.
56    pub async fn call<R>(&self, req: R) -> Result<S::Response, S::Error>
57    where
58        S: Service<R>,
59    {
60        ServiceCtx::<'_, S>::new(self.index, self.state.waiters_ref())
61            .call(&self.state.svc, req)
62            .await
63    }
64
65    #[inline]
66    /// Wait for service readiness and then create future object
67    /// that resolves to service result.
68    pub fn call_static<R>(&self, req: R) -> PipelineCall<S, R>
69    where
70        S: Service<R> + 'static,
71        R: 'static,
72    {
73        let pl = self.clone();
74
75        PipelineCall {
76            fut: Box::pin(async move {
77                ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
78                    .call(&pl.state.svc, req)
79                    .await
80            }),
81        }
82    }
83
84    #[inline]
85    /// Call service and create future object that resolves to service result.
86    ///
87    /// Note, this call does not check service readiness.
88    pub fn call_nowait<R>(&self, req: R) -> PipelineCall<S, R>
89    where
90        S: Service<R> + 'static,
91        R: 'static,
92    {
93        let pl = self.clone();
94
95        PipelineCall {
96            fut: Box::pin(async move {
97                ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
98                    .call_nowait(&pl.state.svc, req)
99                    .await
100            }),
101        }
102    }
103
104    #[inline]
105    /// Check if shutdown is initiated.
106    pub fn is_shutdown(&self) -> bool {
107        self.state.waiters.is_shutdown()
108    }
109
110    #[inline]
111    /// Shutdown enclosed service.
112    pub async fn shutdown<R>(&self)
113    where
114        S: Service<R>,
115    {
116        self.state.svc.shutdown().await
117    }
118
119    #[inline]
120    pub fn poll<R>(&self, cx: &mut Context<'_>) -> Result<(), S::Error>
121    where
122        S: Service<R>,
123    {
124        self.state.svc.poll(cx)
125    }
126
127    #[inline]
128    /// Get current pipeline.
129    pub fn bind<R>(self) -> PipelineBinding<S, R>
130    where
131        S: Service<R> + 'static,
132        R: 'static,
133    {
134        PipelineBinding::new(self)
135    }
136}
137
138impl<S> From<S> for Pipeline<S> {
139    #[inline]
140    fn from(svc: S) -> Self {
141        Pipeline::new(svc)
142    }
143}
144
145impl<S> Clone for Pipeline<S> {
146    fn clone(&self) -> Self {
147        Pipeline {
148            index: self.state.waiters.insert(),
149            state: self.state.clone(),
150        }
151    }
152}
153
154impl<S> Drop for Pipeline<S> {
155    #[inline]
156    fn drop(&mut self) {
157        self.state.waiters.remove(self.index);
158    }
159}
160
161impl<S: fmt::Debug> fmt::Debug for PipelineState<S> {
162    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163        f.debug_struct("PipelineState")
164            .field("svc", &self.svc)
165            .field("waiters", &self.waiters.get().len())
166            .finish()
167    }
168}
169
170/// Bound container for a service.
171pub struct PipelineBinding<S, R>
172where
173    S: Service<R>,
174{
175    pl: Pipeline<S>,
176    st: cell::UnsafeCell<State<S::Error>>,
177}
178
/// Progress of the poll-based readiness or shutdown of a `PipelineBinding`.
enum State<E> {
    /// No future has been created yet.
    New,
    /// Boxed readiness future; fails with the service error `E`.
    Readiness(Pin<Box<dyn Future<Output = Result<(), E>> + 'static>>),
    /// Boxed shutdown future.
    Shutdown(Pin<Box<dyn Future<Output = ()> + 'static>>),
}
184
185impl<S, R> PipelineBinding<S, R>
186where
187    S: Service<R> + 'static,
188    R: 'static,
189{
190    fn new(pl: Pipeline<S>) -> Self {
191        PipelineBinding {
192            pl,
193            st: cell::UnsafeCell::new(State::New),
194        }
195    }
196
197    #[inline]
198    /// Return reference to enclosed service
199    pub fn get_ref(&self) -> &S {
200        &self.pl.state.svc
201    }
202
203    #[inline]
204    /// Get pipeline
205    pub fn pipeline(&self) -> Pipeline<S> {
206        self.pl.clone()
207    }
208
209    #[inline]
210    pub fn poll(&self, cx: &mut Context<'_>) -> Result<(), S::Error> {
211        self.pl.poll(cx)
212    }
213
214    #[inline]
215    /// Returns `Ready` when the pipeline is able to process requests.
216    ///
217    /// panics if .poll_shutdown() was called before.
218    pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
219        let st = unsafe { &mut *self.st.get() };
220
221        match st {
222            State::New => {
223                // SAFETY: `fut` has same lifetime same as lifetime of `self.pl`.
224                // Pipeline::svc is heap allocated(Rc<S>), and it is being kept alive until
225                // `self` is alive
226                let pl: &'static Pipeline<S> = unsafe { std::mem::transmute(&self.pl) };
227                let fut = Box::pin(CheckReadiness {
228                    fut: None,
229                    f: ready,
230                    _t: marker::PhantomData,
231                    pl,
232                });
233                *st = State::Readiness(fut);
234                self.poll_ready(cx)
235            }
236            State::Readiness(ref mut fut) => Pin::new(fut).poll(cx),
237            State::Shutdown(_) => panic!("Pipeline is shutding down"),
238        }
239    }
240
241    #[inline]
242    /// Returns `Ready` when the service is properly shutdowns.
243    pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> {
244        let st = unsafe { &mut *self.st.get() };
245
246        match st {
247            State::New | State::Readiness(_) => {
248                // SAFETY: `fut` has same lifetime same as lifetime of `self.pl`.
249                // Pipeline::svc is heap allocated(Rc<S>), and it is being kept alive until
250                // `self` is alive
251                let pl: &'static Pipeline<S> = unsafe { std::mem::transmute(&self.pl) };
252                *st = State::Shutdown(Box::pin(async move { pl.shutdown().await }));
253                pl.state.waiters.shutdown();
254                self.poll_shutdown(cx)
255            }
256            State::Shutdown(ref mut fut) => Pin::new(fut).poll(cx),
257        }
258    }
259
260    #[inline]
261    /// Wait for service readiness and then create future object
262    /// that resolves to service result.
263    pub fn call(&self, req: R) -> PipelineCall<S, R> {
264        let pl = self.pl.clone();
265
266        PipelineCall {
267            fut: Box::pin(async move {
268                ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
269                    .call(&pl.state.svc, req)
270                    .await
271            }),
272        }
273    }
274
275    #[inline]
276    /// Call service and create future object that resolves to service result.
277    ///
278    /// Note, this call does not check service readiness.
279    pub fn call_nowait(&self, req: R) -> PipelineCall<S, R> {
280        let pl = self.pl.clone();
281
282        PipelineCall {
283            fut: Box::pin(async move {
284                ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
285                    .call_nowait(&pl.state.svc, req)
286                    .await
287            }),
288        }
289    }
290
291    #[inline]
292    /// Check if shutdown is initiated.
293    pub fn is_shutdown(&self) -> bool {
294        self.pl.state.waiters.is_shutdown()
295    }
296
297    #[inline]
298    /// Shutdown enclosed service.
299    pub async fn shutdown(&self) {
300        self.pl.state.svc.shutdown().await
301    }
302}
303
304impl<S, R> Drop for PipelineBinding<S, R>
305where
306    S: Service<R>,
307{
308    fn drop(&mut self) {
309        self.st = cell::UnsafeCell::new(State::New);
310    }
311}
312
313impl<S, R> Clone for PipelineBinding<S, R>
314where
315    S: Service<R>,
316{
317    #[inline]
318    fn clone(&self) -> Self {
319        Self {
320            pl: self.pl.clone(),
321            st: cell::UnsafeCell::new(State::New),
322        }
323    }
324}
325
326impl<S, R> fmt::Debug for PipelineBinding<S, R>
327where
328    S: Service<R> + fmt::Debug,
329{
330    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
331        f.debug_struct("PipelineBinding")
332            .field("pipeline", &self.pl)
333            .finish()
334    }
335}
336
337#[must_use = "futures do nothing unless polled"]
338/// Pipeline call future
339pub struct PipelineCall<S, R>
340where
341    S: Service<R>,
342    R: 'static,
343{
344    fut: Call<S::Response, S::Error>,
345}
346
347type Call<R, E> = Pin<Box<dyn Future<Output = Result<R, E>> + 'static>>;
348
349impl<S, R> Future for PipelineCall<S, R>
350where
351    S: Service<R>,
352{
353    type Output = Result<S::Response, S::Error>;
354
355    #[inline]
356    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
357        Pin::new(&mut self.as_mut().fut).poll(cx)
358    }
359}
360
361impl<S, R> fmt::Debug for PipelineCall<S, R>
362where
363    S: Service<R>,
364{
365    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
366        f.debug_struct("PipelineCall").finish()
367    }
368}
369
370fn ready<S, R>(pl: &'static Pipeline<S>) -> impl Future<Output = Result<(), S::Error>>
371where
372    S: Service<R>,
373    R: 'static,
374{
375    pl.state
376        .svc
377        .ready(ServiceCtx::<'_, S>::new(pl.index, pl.state.waiters_ref()))
378}
379
380struct CheckReadiness<S: Service<R> + 'static, R, F, Fut> {
381    f: F,
382    fut: Option<Fut>,
383    pl: &'static Pipeline<S>,
384    _t: marker::PhantomData<R>,
385}
386
387impl<S: Service<R>, R, F, Fut> Unpin for CheckReadiness<S, R, F, Fut> {}
388
389impl<S: Service<R>, R, F, Fut> Drop for CheckReadiness<S, R, F, Fut> {
390    fn drop(&mut self) {
391        // future fot dropped during polling, we must notify other waiters
392        if self.fut.is_some() {
393            self.pl.state.waiters.notify();
394        }
395    }
396}
397
398impl<S, R, F, Fut> Future for CheckReadiness<S, R, F, Fut>
399where
400    S: Service<R>,
401    F: Fn(&'static Pipeline<S>) -> Fut,
402    Fut: Future<Output = Result<(), S::Error>>,
403{
404    type Output = Result<(), S::Error>;
405
406    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
407        let mut slf = self.as_mut();
408
409        slf.pl.poll(cx)?;
410
411        if slf.pl.state.waiters.can_check(slf.pl.index, cx) {
412            if slf.fut.is_none() {
413                slf.fut = Some((slf.f)(slf.pl));
414            }
415            let fut = slf.fut.as_mut().unwrap();
416            match unsafe { Pin::new_unchecked(fut) }.poll(cx) {
417                Poll::Pending => {
418                    slf.pl.state.waiters.register(slf.pl.index, cx);
419                    Poll::Pending
420                }
421                Poll::Ready(res) => {
422                    let _ = slf.fut.take();
423                    slf.pl.state.waiters.notify();
424                    Poll::Ready(res)
425                }
426            }
427        } else {
428            Poll::Pending
429        }
430    }
431}