ntex_util/services/
onerequest.rs

//! Service that limits the number of in-flight async requests to one.
2use std::{cell::Cell, future::poll_fn, task::Poll};
3
4use ntex_service::{Middleware, Middleware2, Service, ServiceCtx};
5
6use crate::task::LocalWaker;
7
/// `OneRequest` is a middleware factory: every service it creates lets at
/// most one async request be in flight at any moment.
#[derive(Debug, Default, Clone, Copy)]
pub struct OneRequest;
12
13impl<S> Middleware<S> for OneRequest {
14    type Service = OneRequestService<S>;
15
16    fn create(&self, service: S) -> Self::Service {
17        OneRequestService {
18            service,
19            ready: Cell::new(true),
20            waker: LocalWaker::new(),
21        }
22    }
23}
24
25impl<S, C> Middleware2<S, C> for OneRequest {
26    type Service = OneRequestService<S>;
27
28    fn create(&self, service: S, _: C) -> Self::Service {
29        OneRequestService {
30            service,
31            ready: Cell::new(true),
32            waker: LocalWaker::new(),
33        }
34    }
35}
36
/// Service wrapper that forwards requests to the inner service while
/// allowing at most one call to be in flight at a time.
///
/// Built by the [`OneRequest`] middleware or by [`OneRequestService::new`].
///
/// Note: the derived `Clone` copies the current value of the `ready` flag
/// into a new `Cell`, so a clone tracks its in-flight state independently;
/// clones do not share the one-request limit with the original.
#[derive(Clone, Debug)]
pub struct OneRequestService<S> {
    // Registered by a pending `ready()`; woken when an in-flight call ends.
    waker: LocalWaker,
    // The wrapped inner service.
    service: S,
    // `true` while no call is in flight.
    ready: Cell<bool>,
}
43
44impl<S> OneRequestService<S> {
45    pub fn new<R>(service: S) -> Self
46    where
47        S: Service<R>,
48    {
49        Self {
50            service,
51            ready: Cell::new(true),
52            waker: LocalWaker::new(),
53        }
54    }
55}
56
57impl<T, R> Service<R> for OneRequestService<T>
58where
59    T: Service<R>,
60{
61    type Response = T::Response;
62    type Error = T::Error;
63
64    #[inline]
65    async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
66        if !self.ready.get() {
67            poll_fn(|cx| {
68                self.waker.register(cx.waker());
69                if self.ready.get() {
70                    Poll::Ready(())
71                } else {
72                    Poll::Pending
73                }
74            })
75            .await
76        }
77        ctx.ready(&self.service).await
78    }
79
80    #[inline]
81    async fn call(
82        &self,
83        req: R,
84        ctx: ServiceCtx<'_, Self>,
85    ) -> Result<Self::Response, Self::Error> {
86        self.ready.set(false);
87
88        let result = ctx.call(&self.service, req).await;
89        self.ready.set(true);
90        self.waker.wake();
91        result
92    }
93
94    ntex_service::forward_poll!(service);
95    ntex_service::forward_shutdown!(service);
96}
97
98#[cfg(test)]
99mod tests {
100    use ntex_service::{Pipeline, ServiceFactory, apply, apply2, fn_factory};
101    use std::{cell::RefCell, time::Duration};
102
103    use super::*;
104    use crate::{channel::oneshot, future::lazy};
105
106    struct SleepService(oneshot::Receiver<()>);
107
108    impl Service<()> for SleepService {
109        type Response = ();
110        type Error = ();
111
112        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
113            let _ = self.0.recv().await;
114            Ok::<_, ()>(())
115        }
116    }
117
118    #[ntex::test]
119    async fn test_oneshot() {
120        let (tx, rx) = oneshot::channel();
121
122        let srv = Pipeline::new(OneRequestService::new(SleepService(rx))).bind();
123        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
124
125        let srv2 = srv.clone();
126        ntex::rt::spawn(async move {
127            let _ = srv2.call(()).await;
128        });
129        crate::time::sleep(Duration::from_millis(25)).await;
130        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);
131
132        let _ = tx.send(());
133        crate::time::sleep(Duration::from_millis(25)).await;
134        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
135        srv.shutdown().await;
136    }
137
138    #[ntex::test]
139    async fn test_middleware() {
140        assert_eq!(format!("{OneRequest:?}"), "OneRequest");
141
142        let (tx, rx) = oneshot::channel();
143        let rx = RefCell::new(Some(rx));
144        let srv = apply(
145            OneRequest,
146            fn_factory(move || {
147                let rx = rx.borrow_mut().take().unwrap();
148                async move { Ok::<_, ()>(SleepService(rx)) }
149            }),
150        );
151
152        let srv = srv.pipeline(&()).await.unwrap().bind();
153        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
154
155        let srv2 = srv.clone();
156        ntex::rt::spawn(async move {
157            let _ = srv2.call(()).await;
158        });
159        crate::time::sleep(Duration::from_millis(25)).await;
160        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);
161
162        let _ = tx.send(());
163        crate::time::sleep(Duration::from_millis(25)).await;
164        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
165    }
166
167    #[ntex::test]
168    async fn test_middleware2() {
169        assert_eq!(format!("{OneRequest:?}"), "OneRequest");
170
171        let (tx, rx) = oneshot::channel();
172        let rx = RefCell::new(Some(rx));
173        let srv = apply2(
174            OneRequest,
175            fn_factory(move || {
176                let rx = rx.borrow_mut().take().unwrap();
177                async move { Ok::<_, ()>(SleepService(rx)) }
178            }),
179        );
180
181        let srv = srv.pipeline(&()).await.unwrap().bind();
182        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
183
184        let srv2 = srv.clone();
185        ntex::rt::spawn(async move {
186            let _ = srv2.call(()).await;
187        });
188        crate::time::sleep(Duration::from_millis(25)).await;
189        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);
190
191        let _ = tx.send(());
192        crate::time::sleep(Duration::from_millis(25)).await;
193        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
194    }
195}