ntex_util/services/
buffer.rs

//! Service that buffers incoming requests.
2use std::cell::{Cell, RefCell};
3use std::task::{Poll, Waker, ready};
4use std::{collections::VecDeque, fmt, future::poll_fn, marker::PhantomData};
5
6use ntex_service::{Middleware, Pipeline, PipelineBinding, Service, ServiceCtx};
7
8use crate::channel::oneshot;
9
/// Buffer - middleware factory that wraps a service in a [`BufferService`].
///
/// While the wrapped service is not ready, up to `buf_size` incoming requests
/// are queued instead of being rejected. The default buffer size is 16
/// (see [`Buffer::buf_size`] to change it).
pub struct Buffer<R> {
    /// Maximum number of queued requests.
    buf_size: usize,
    /// Drop queued requests during shutdown instead of flushing them.
    cancel_on_shutdown: bool,
    _t: PhantomData<R>,
}
18
19impl<R> Buffer<R> {
20    pub fn buf_size(mut self, size: usize) -> Self {
21        self.buf_size = size;
22        self
23    }
24
25    /// Cancel all buffered requests on shutdown
26    ///
27    /// By default buffered requests are flushed during poll_shutdown
28    pub fn cancel_on_shutdown(mut self) -> Self {
29        self.cancel_on_shutdown = true;
30        self
31    }
32}
33
34impl<R> Default for Buffer<R> {
35    fn default() -> Self {
36        Self {
37            buf_size: 16,
38            cancel_on_shutdown: false,
39            _t: PhantomData,
40        }
41    }
42}
43
44impl<R> fmt::Debug for Buffer<R> {
45    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46        f.debug_struct("Buffer")
47            .field("buf_size", &self.buf_size)
48            .field("cancel_on_shutdown", &self.cancel_on_shutdown)
49            .finish()
50    }
51}
52
53impl<R> Clone for Buffer<R> {
54    fn clone(&self) -> Self {
55        Self {
56            buf_size: self.buf_size,
57            cancel_on_shutdown: self.cancel_on_shutdown,
58            _t: PhantomData,
59        }
60    }
61}
62
63impl<R, S, C> Middleware<S, C> for Buffer<R>
64where
65    S: Service<R> + 'static,
66    R: 'static,
67{
68    type Service = BufferService<R, S>;
69
70    fn create(&self, service: S, _: C) -> Self::Service {
71        BufferService::new(self.buf_size, service)
72    }
73}
74
/// Error produced by [`BufferService`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BufferServiceError<E> {
    /// The wrapped service returned an error.
    Service(E),
    /// A queued request was discarded (for example by `cancel_on_shutdown`)
    /// before it was forwarded to the wrapped service.
    RequestCanceled,
}

impl<E> From<E> for BufferServiceError<E> {
    fn from(err: E) -> Self {
        Self::Service(err)
    }
}

impl<E: fmt::Display> fmt::Display for BufferServiceError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::RequestCanceled => f.write_str("buffer service request canceled"),
            // delegate so the inner service's message is shown unchanged
            Self::Service(inner) => fmt::Display::fmt(inner, f),
        }
    }
}

impl<E: fmt::Display + fmt::Debug> std::error::Error for BufferServiceError<E> {}
99
100/// Buffer service - service that can buffer incoming requests.
101///
102/// Default number of buffered requests is 16
103pub struct BufferService<R, S: Service<R>> {
104    size: usize,
105    ready: Cell<bool>,
106    service: PipelineBinding<S, R>,
107    buf: RefCell<VecDeque<oneshot::Sender<oneshot::Sender<()>>>>,
108    next_call: RefCell<Option<oneshot::Receiver<()>>>,
109    cancel_on_shutdown: bool,
110    readiness: Cell<Option<Waker>>,
111    _t: PhantomData<R>,
112}
113
114impl<R, S> BufferService<R, S>
115where
116    S: Service<R> + 'static,
117    R: 'static,
118{
119    pub fn new(size: usize, service: S) -> Self {
120        Self {
121            size,
122            service: Pipeline::new(service).bind(),
123            ready: Cell::new(false),
124            buf: RefCell::new(VecDeque::with_capacity(size)),
125            next_call: RefCell::default(),
126            cancel_on_shutdown: false,
127            readiness: Cell::new(None),
128            _t: PhantomData,
129        }
130    }
131
132    pub fn cancel_on_shutdown(self) -> Self {
133        Self {
134            cancel_on_shutdown: true,
135            ..self
136        }
137    }
138}
139
140impl<R, S> Clone for BufferService<R, S>
141where
142    S: Service<R> + Clone,
143{
144    fn clone(&self) -> Self {
145        Self {
146            size: self.size,
147            ready: Cell::new(false),
148            service: self.service.clone(),
149            buf: RefCell::new(VecDeque::with_capacity(self.size)),
150            next_call: RefCell::default(),
151            cancel_on_shutdown: self.cancel_on_shutdown,
152            readiness: Cell::new(None),
153            _t: PhantomData,
154        }
155    }
156}
157
158impl<R, S> fmt::Debug for BufferService<R, S>
159where
160    S: Service<R> + fmt::Debug,
161{
162    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163        f.debug_struct("BufferService")
164            .field("size", &self.size)
165            .field("cancel_on_shutdown", &self.cancel_on_shutdown)
166            .field("ready", &self.ready)
167            .field("service", &self.service)
168            .field("buf", &self.buf)
169            .field("next_call", &self.next_call)
170            .finish()
171    }
172}
173
174impl<R, S> Service<R> for BufferService<R, S>
175where
176    S: Service<R> + 'static,
177    R: 'static,
178{
179    type Response = S::Response;
180    type Error = BufferServiceError<S::Error>;
181
182    async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
183        // hold advancement until the last released task either makes a call or is dropped
184        let next_call = self.next_call.borrow_mut().take();
185        if let Some(next_call) = next_call {
186            let _ = next_call.recv().await;
187        }
188
189        poll_fn(|cx| {
190            let mut buffer = self.buf.borrow_mut();
191
192            // handle inner service readiness
193            if self.service.poll_ready(cx)?.is_pending() {
194                if buffer.len() < self.size {
195                    // buffer next request
196                    self.ready.set(false);
197                    Poll::Ready(Ok(()))
198                } else {
199                    log::trace!("Buffer limit exceeded");
200                    // service is not ready
201                    let _ = self.readiness.take().map(|w| w.wake());
202                    Poll::Pending
203                }
204            } else {
205                while let Some(sender) = buffer.pop_front() {
206                    let (next_call_tx, next_call_rx) = oneshot::channel();
207                    if sender.send(next_call_tx).is_err()
208                        || next_call_rx.poll_recv(cx).is_ready()
209                    {
210                        // the task is gone
211                        continue;
212                    }
213                    self.next_call.borrow_mut().replace(next_call_rx);
214                    self.ready.set(false);
215                    return Poll::Ready(Ok(()));
216                }
217
218                self.ready.set(true);
219                Poll::Ready(Ok(()))
220            }
221        })
222        .await
223    }
224
225    async fn shutdown(&self) {
226        // hold advancement until the last released task either makes a call or is dropped
227        let next_call = self.next_call.borrow_mut().take();
228        if let Some(next_call) = next_call {
229            let _ = next_call.recv().await;
230        }
231
232        poll_fn(|cx| {
233            let mut buffer = self.buf.borrow_mut();
234            if self.cancel_on_shutdown {
235                buffer.clear();
236            }
237
238            if !buffer.is_empty() {
239                if ready!(self.service.poll_ready(cx)).is_err() {
240                    log::error!(
241                        "Buffered inner service failed while buffer flushing on shutdown"
242                    );
243                    return Poll::Ready(());
244                }
245
246                while let Some(sender) = buffer.pop_front() {
247                    let (next_call_tx, next_call_rx) = oneshot::channel();
248                    if sender.send(next_call_tx).is_err()
249                        || next_call_rx.poll_recv(cx).is_ready()
250                    {
251                        // the task is gone
252                        continue;
253                    }
254                    self.next_call.borrow_mut().replace(next_call_rx);
255                    if buffer.is_empty() {
256                        break;
257                    }
258                    return Poll::Pending;
259                }
260            }
261            Poll::Ready(())
262        })
263        .await;
264
265        self.service.shutdown().await;
266    }
267
268    async fn call(
269        &self,
270        req: R,
271        _: ServiceCtx<'_, Self>,
272    ) -> Result<Self::Response, Self::Error> {
273        if self.ready.get() {
274            self.ready.set(false);
275            Ok(self.service.call_nowait(req).await?)
276        } else {
277            let (tx, rx) = oneshot::channel();
278            self.buf.borrow_mut().push_back(tx);
279
280            // release
281            let _task_guard = rx.recv().await.map_err(|_| {
282                log::trace!("Buffered service request canceled");
283                BufferServiceError::RequestCanceled
284            })?;
285
286            // call service
287            Ok(self.service.call(req).await?)
288        }
289    }
290
291    ntex_service::forward_poll!(service);
292}
293
#[cfg(test)]
mod tests {
    use ntex_service::{Pipeline, ServiceFactory, apply, fn_factory};
    use std::{rc::Rc, time::Duration};

    use super::*;
    use crate::{future::lazy, task::LocalWaker};

    /// Inner service whose readiness is switched on by the test and off by every call.
    #[derive(Debug, Clone)]
    struct TestService(Rc<Inner>);

    /// State shared between a test and its `TestService`.
    #[derive(Debug)]
    struct Inner {
        ready: Cell<bool>,
        waker: LocalWaker,
        count: Cell<usize>,
    }

    impl Inner {
        /// Shared state with the given initial readiness and no recorded calls.
        fn shared(ready: bool) -> Rc<Self> {
            Rc::new(Self {
                ready: Cell::new(ready),
                waker: LocalWaker::default(),
                count: Cell::new(0),
            })
        }

        /// Flags the service ready and wakes the task that polled it last.
        fn make_ready(&self) {
            self.ready.set(true);
            self.waker.wake();
        }
    }

    impl Service<()> for TestService {
        type Response = ();
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            poll_fn(|cx| {
                self.0.waker.register(cx.waker());
                if self.0.ready.get() {
                    Poll::Ready(Ok(()))
                } else {
                    Poll::Pending
                }
            })
            .await
        }

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            self.0.ready.set(false);
            let served = self.0.count.get();
            self.0.count.set(served + 1);
            Ok(())
        }
    }

    /// Checks a two-slot buffer in three phases: fill-up, back-pressure once
    /// full, then one queued request served per inner readiness.
    async fn run_two_slot_scenario<S>(srv: &PipelineBinding<S, ()>, inner: &Inner)
    where
        S: Service<()> + 'static,
        S::Error: fmt::Debug + PartialEq,
    {
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        // Two requests fit into the buffer. The readiness check after the
        // second one must block.
        for expected in [Poll::Ready(Ok(())), Poll::Pending] {
            let caller = srv.clone();
            ntex::rt::spawn(async move {
                let _ = caller.call(()).await;
            });
            crate::time::sleep(Duration::from_millis(25)).await;
            assert_eq!(inner.count.get(), 0);
            assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, expected);
        }

        // each time the inner service becomes ready, exactly one queued request runs
        for served in 1..=2 {
            inner.make_ready();
            assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
            crate::time::sleep(Duration::from_millis(25)).await;
            assert_eq!(inner.count.get(), served);
        }
    }

    #[ntex::test]
    async fn test_service() {
        let inner = Inner::shared(false);
        let srv =
            Pipeline::new(BufferService::new(2, TestService(inner.clone())).clone()).bind();
        run_two_slot_scenario(&srv, &inner).await;

        // an inner service that is already ready is called without queueing
        let inner = Inner::shared(true);
        let srv = Pipeline::new(BufferService::new(2, TestService(inner.clone()))).bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let _ = srv.call(()).await;
        assert_eq!(inner.count.get(), 1);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
        assert!(lazy(|cx| srv.poll_shutdown(cx)).await.is_ready());

        let err = BufferServiceError::from("test");
        assert!(format!("{err}").contains("test"));
        assert!(format!("{srv:?}").contains("BufferService"));
        assert!(format!("{:?}", Buffer::<TestService>::default()).contains("Buffer"));
    }

    #[ntex::test]
    // the builder `clone()` is deliberate: it exercises `Buffer: Clone`
    #[allow(clippy::redundant_clone)]
    async fn test_middleware() {
        let inner = Inner::shared(false);
        let factory = apply(
            Buffer::default().buf_size(2).clone(),
            fn_factory(|| async { Ok::<_, ()>(TestService(inner.clone())) }),
        );

        let srv = factory.pipeline(&()).await.unwrap().bind();
        run_two_slot_scenario(&srv, &inner).await;
    }

    // NOTE(review): this runs exactly the same scenario as `test_middleware`.
    #[ntex::test]
    // the builder `clone()` is deliberate: it exercises `Buffer: Clone`
    #[allow(clippy::redundant_clone)]
    async fn test_middleware2() {
        let inner = Inner::shared(false);
        let factory = apply(
            Buffer::default().buf_size(2).clone(),
            fn_factory(|| async { Ok::<_, ()>(TestService(inner.clone())) }),
        );

        let srv = factory.pipeline(&()).await.unwrap().bind();
        run_two_slot_scenario(&srv, &inner).await;
    }
}