// ntex_util/services/buffer.rs

//! Service middleware that buffers incoming requests while the inner service is not ready.
2use std::cell::{Cell, RefCell};
3use std::task::{Poll, Waker, ready};
4use std::{collections::VecDeque, fmt, future::poll_fn, marker::PhantomData};
5
6use ntex_service::{Middleware, Pipeline, PipelineBinding, Service, ServiceCtx};
7
8use crate::channel::oneshot;
9
/// Middleware factory producing a [`BufferService`] around an inner service.
///
/// While the inner service is not ready, up to `buf_size` requests are parked
/// in a queue instead of being refused. The default buffer size is 16.
pub struct Buffer<R> {
    buf_size: usize,
    cancel_on_shutdown: bool,
    _t: PhantomData<R>,
}

impl<R> Buffer<R> {
    /// Sets how many requests may be buffered at most.
    ///
    /// The default is 16.
    #[must_use]
    pub fn buf_size(self, size: usize) -> Self {
        Self {
            buf_size: size,
            ..self
        }
    }

    /// Requests that all buffered requests be canceled on shutdown.
    ///
    /// Without this, buffered requests are flushed during `poll_shutdown()`.
    #[must_use]
    pub fn cancel_on_shutdown(self) -> Self {
        Self {
            cancel_on_shutdown: true,
            ..self
        }
    }
}

impl<R> Default for Buffer<R> {
    /// Buffer of 16 requests, flushing (not canceling) on shutdown.
    fn default() -> Self {
        Buffer {
            _t: PhantomData,
            cancel_on_shutdown: false,
            buf_size: 16,
        }
    }
}

// Written by hand so that `R` is not required to implement `Debug`.
impl<R> fmt::Debug for Buffer<R> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            buf_size,
            cancel_on_shutdown,
            ..
        } = self;
        f.debug_struct("Buffer")
            .field("buf_size", buf_size)
            .field("cancel_on_shutdown", cancel_on_shutdown)
            .finish()
    }
}

// Written by hand so that `R` is not required to implement `Clone`.
impl<R> Clone for Buffer<R> {
    fn clone(&self) -> Self {
        // both remaining fields are `Copy`, so they are copied out of `*self`
        Self {
            _t: PhantomData,
            ..*self
        }
    }
}
67
68impl<R, S, C> Middleware<S, C> for Buffer<R>
69where
70    S: Service<R> + 'static,
71    R: 'static,
72{
73    type Service = BufferService<R, S>;
74
75    fn create(&self, service: S, _: C) -> Self::Service {
76        BufferService::new(self.buf_size, service)
77    }
78}
79
/// Error produced by [`BufferService`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BufferServiceError<E> {
    /// The inner service reported an error.
    Service(E),
    /// A buffered request was dropped before it was dispatched, e.g. because
    /// buffered requests were canceled on shutdown.
    RequestCanceled,
}

/// Any inner-service error converts into [`BufferServiceError::Service`].
impl<E> From<E> for BufferServiceError<E> {
    fn from(err: E) -> Self {
        Self::Service(err)
    }
}

/// Inner errors are displayed transparently; cancellation has a fixed message.
impl<E: std::fmt::Display> std::fmt::Display for BufferServiceError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::RequestCanceled => f.write_str("buffer service request canceled"),
            Self::Service(inner) => std::fmt::Display::fmt(inner, f),
        }
    }
}

impl<E> std::error::Error for BufferServiceError<E> where E: std::fmt::Display + std::fmt::Debug {}
104
/// Buffer service - service that can buffer incoming requests.
///
/// The maximum number of buffered requests is given to [`BufferService::new`];
/// when the service is created through the [`Buffer`] factory it defaults to 16.
pub struct BufferService<R, S: Service<R>> {
    /// Maximum number of requests parked in `buf` while the inner service is not ready.
    size: usize,
    /// Set by `ready()` when the inner service is ready and nothing had to be
    /// released from the buffer; the next `call` then goes straight to the inner service.
    ready: Cell<bool>,
    /// The wrapped inner service.
    service: PipelineBinding<S, R>,
    /// Parked calls, oldest first. Sending a guard sender through an entry releases
    /// that call; dropping the entry makes the call fail with `RequestCanceled`.
    buf: RefCell<VecDeque<oneshot::Sender<oneshot::Sender<()>>>>,
    /// Receiver paired with the guard handed to the most recently released call;
    /// it resolves once that call drops its guard (after its inner call, or when dropped).
    next_call: RefCell<Option<oneshot::Receiver<()>>>,
    /// Drop buffered requests on shutdown instead of flushing them.
    cancel_on_shutdown: bool,
    /// NOTE(review): only ever initialized to `None` and `take()`n in this file;
    /// no visible code stores a waker here.
    readiness: Cell<Option<Waker>>,
    _t: PhantomData<R>,
}
118
119impl<R, S> BufferService<R, S>
120where
121    S: Service<R> + 'static,
122    R: 'static,
123{
124    #[must_use]
125    pub fn new(size: usize, service: S) -> Self {
126        Self {
127            size,
128            service: Pipeline::new(service).bind(),
129            ready: Cell::new(false),
130            buf: RefCell::new(VecDeque::with_capacity(size)),
131            next_call: RefCell::default(),
132            cancel_on_shutdown: false,
133            readiness: Cell::new(None),
134            _t: PhantomData,
135        }
136    }
137
138    #[must_use]
139    pub fn cancel_on_shutdown(self) -> Self {
140        Self {
141            cancel_on_shutdown: true,
142            ..self
143        }
144    }
145}
146
147impl<R, S> Clone for BufferService<R, S>
148where
149    S: Service<R> + Clone,
150{
151    fn clone(&self) -> Self {
152        Self {
153            size: self.size,
154            ready: Cell::new(false),
155            service: self.service.clone(),
156            buf: RefCell::new(VecDeque::with_capacity(self.size)),
157            next_call: RefCell::default(),
158            cancel_on_shutdown: self.cancel_on_shutdown,
159            readiness: Cell::new(None),
160            _t: PhantomData,
161        }
162    }
163}
164
165impl<R, S> fmt::Debug for BufferService<R, S>
166where
167    S: Service<R> + fmt::Debug,
168{
169    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170        f.debug_struct("BufferService")
171            .field("size", &self.size)
172            .field("cancel_on_shutdown", &self.cancel_on_shutdown)
173            .field("ready", &self.ready)
174            .field("service", &self.service)
175            .field("buf", &self.buf)
176            .field("next_call", &self.next_call)
177            .finish()
178    }
179}
180
181impl<R, S> Service<R> for BufferService<R, S>
182where
183    S: Service<R> + 'static,
184    R: 'static,
185{
186    type Response = S::Response;
187    type Error = BufferServiceError<S::Error>;
188
189    async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
190        // hold advancement until the last released task either makes a call or is dropped
191        let next_call = self.next_call.borrow_mut().take();
192        if let Some(next_call) = next_call {
193            let _ = next_call.recv().await;
194        }
195
196        poll_fn(|cx| {
197            let mut buffer = self.buf.borrow_mut();
198
199            // handle inner service readiness
200            if self.service.poll_ready(cx)?.is_pending() {
201                if buffer.len() < self.size {
202                    // buffer next request
203                    self.ready.set(false);
204                    Poll::Ready(Ok(()))
205                } else {
206                    log::trace!("Buffer limit exceeded");
207                    // service is not ready
208                    let _ = self.readiness.take().map(Waker::wake);
209                    Poll::Pending
210                }
211            } else {
212                while let Some(sender) = buffer.pop_front() {
213                    let (next_call_tx, next_call_rx) = oneshot::channel();
214                    if sender.send(next_call_tx).is_err()
215                        || next_call_rx.poll_recv(cx).is_ready()
216                    {
217                        // the task is gone
218                        continue;
219                    }
220                    self.next_call.borrow_mut().replace(next_call_rx);
221                    self.ready.set(false);
222                    return Poll::Ready(Ok(()));
223                }
224
225                self.ready.set(true);
226                Poll::Ready(Ok(()))
227            }
228        })
229        .await
230    }
231
232    async fn shutdown(&self) {
233        // hold advancement until the last released task either makes a call or is dropped
234        let next_call = self.next_call.borrow_mut().take();
235        if let Some(next_call) = next_call {
236            let _ = next_call.recv().await;
237        }
238
239        poll_fn(|cx| {
240            let mut buffer = self.buf.borrow_mut();
241            if self.cancel_on_shutdown {
242                buffer.clear();
243            }
244
245            if !buffer.is_empty() {
246                if ready!(self.service.poll_ready(cx)).is_err() {
247                    log::error!(
248                        "Buffered inner service failed while buffer flushing on shutdown"
249                    );
250                    return Poll::Ready(());
251                }
252
253                while let Some(sender) = buffer.pop_front() {
254                    let (next_call_tx, next_call_rx) = oneshot::channel();
255                    if sender.send(next_call_tx).is_err()
256                        || next_call_rx.poll_recv(cx).is_ready()
257                    {
258                        // the task is gone
259                        continue;
260                    }
261                    self.next_call.borrow_mut().replace(next_call_rx);
262                    if buffer.is_empty() {
263                        break;
264                    }
265                    return Poll::Pending;
266                }
267            }
268            Poll::Ready(())
269        })
270        .await;
271
272        self.service.shutdown().await;
273    }
274
275    async fn call(
276        &self,
277        req: R,
278        _: ServiceCtx<'_, Self>,
279    ) -> Result<Self::Response, Self::Error> {
280        if self.ready.get() {
281            self.ready.set(false);
282            Ok(self.service.call_nowait(req).await?)
283        } else {
284            let (tx, rx) = oneshot::channel();
285            self.buf.borrow_mut().push_back(tx);
286
287            // release
288            let _task_guard = rx.recv().await.map_err(|_| {
289                log::trace!("Buffered service request canceled");
290                BufferServiceError::RequestCanceled
291            })?;
292
293            // call service
294            Ok(self.service.call(req).await?)
295        }
296    }
297
298    ntex_service::forward_poll!(service);
299}
300
301#[cfg(test)]
302mod tests {
303    use ntex_service::{Pipeline, apply, fn_factory};
304    use std::{rc::Rc, time::Duration};
305
306    use super::*;
307    use crate::{future::lazy, task::LocalWaker};
308
309    #[derive(Debug, Clone)]
310    struct TestService(Rc<Inner>);
311
312    #[derive(Debug)]
313    struct Inner {
314        ready: Cell<bool>,
315        waker: LocalWaker,
316        count: Cell<usize>,
317    }
318
319    impl Service<()> for TestService {
320        type Response = ();
321        type Error = ();
322
323        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
324            poll_fn(|cx| {
325                self.0.waker.register(cx.waker());
326                if self.0.ready.get() {
327                    Poll::Ready(Ok(()))
328                } else {
329                    Poll::Pending
330                }
331            })
332            .await
333        }
334
335        async fn call(&self, _r: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
336            self.0.ready.set(false);
337            self.0.count.set(self.0.count.get() + 1);
338            Ok(())
339        }
340    }
341
342    #[ntex::test]
343    async fn test_service() {
344        let inner = Rc::new(Inner {
345            ready: Cell::new(false),
346            waker: LocalWaker::default(),
347            count: Cell::new(0),
348        });
349
350        let srv =
351            Pipeline::new(BufferService::new(2, TestService(inner.clone())).clone()).bind();
352        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
353
354        let srv1 = srv.clone();
355        ntex::rt::spawn(async move {
356            let _ = srv1.call(()).await;
357        });
358        crate::time::sleep(Duration::from_millis(25)).await;
359        assert_eq!(inner.count.get(), 0);
360        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
361
362        let srv1 = srv.clone();
363        ntex::rt::spawn(async move {
364            let _ = srv1.call(()).await;
365        });
366        crate::time::sleep(Duration::from_millis(25)).await;
367        assert_eq!(inner.count.get(), 0);
368        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);
369
370        inner.ready.set(true);
371        inner.waker.wake();
372        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
373
374        crate::time::sleep(Duration::from_millis(25)).await;
375        assert_eq!(inner.count.get(), 1);
376
377        inner.ready.set(true);
378        inner.waker.wake();
379        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
380
381        crate::time::sleep(Duration::from_millis(25)).await;
382        assert_eq!(inner.count.get(), 2);
383
384        let inner = Rc::new(Inner {
385            ready: Cell::new(true),
386            waker: LocalWaker::default(),
387            count: Cell::new(0),
388        });
389
390        let srv = Pipeline::new(BufferService::new(2, TestService(inner.clone()))).bind();
391        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
392
393        let _ = srv.call(()).await;
394        assert_eq!(inner.count.get(), 1);
395        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
396        assert!(lazy(|cx| srv.poll_shutdown(cx)).await.is_ready());
397
398        let err = BufferServiceError::from("test");
399        assert!(format!("{err}").contains("test"));
400        assert!(format!("{srv:?}").contains("BufferService"));
401        assert!(format!("{:?}", Buffer::<TestService>::default()).contains("Buffer"));
402    }
403
404    #[ntex::test]
405    #[allow(clippy::redundant_clone)]
406    async fn test_middleware() {
407        let inner = Rc::new(Inner {
408            ready: Cell::new(false),
409            waker: LocalWaker::default(),
410            count: Cell::new(0),
411        });
412
413        let srv = apply(
414            Buffer::default().buf_size(2).clone(),
415            fn_factory(|| async { Ok::<_, ()>(TestService(inner.clone())) }),
416        );
417
418        let srv = srv.pipeline(&()).await.unwrap().bind();
419        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
420
421        let srv1 = srv.clone();
422        ntex::rt::spawn(async move {
423            let _ = srv1.call(()).await;
424        });
425        crate::time::sleep(Duration::from_millis(25)).await;
426        assert_eq!(inner.count.get(), 0);
427        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
428
429        let srv1 = srv.clone();
430        ntex::rt::spawn(async move {
431            let _ = srv1.call(()).await;
432        });
433        crate::time::sleep(Duration::from_millis(25)).await;
434        assert_eq!(inner.count.get(), 0);
435        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);
436
437        inner.ready.set(true);
438        inner.waker.wake();
439        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
440
441        crate::time::sleep(Duration::from_millis(25)).await;
442        assert_eq!(inner.count.get(), 1);
443
444        inner.ready.set(true);
445        inner.waker.wake();
446        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
447
448        crate::time::sleep(Duration::from_millis(25)).await;
449        assert_eq!(inner.count.get(), 2);
450    }
451
452    #[ntex::test]
453    #[allow(clippy::redundant_clone)]
454    async fn test_middleware2() {
455        let inner = Rc::new(Inner {
456            ready: Cell::new(false),
457            waker: LocalWaker::default(),
458            count: Cell::new(0),
459        });
460
461        let srv = apply(
462            Buffer::default().buf_size(2).clone(),
463            fn_factory(|| async { Ok::<_, ()>(TestService(inner.clone())) }),
464        );
465
466        let srv = srv.pipeline(&()).await.unwrap().bind();
467        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
468
469        let srv1 = srv.clone();
470        ntex::rt::spawn(async move {
471            let _ = srv1.call(()).await;
472        });
473        crate::time::sleep(Duration::from_millis(25)).await;
474        assert_eq!(inner.count.get(), 0);
475        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
476
477        let srv1 = srv.clone();
478        ntex::rt::spawn(async move {
479            let _ = srv1.call(()).await;
480        });
481        crate::time::sleep(Duration::from_millis(25)).await;
482        assert_eq!(inner.count.get(), 0);
483        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);
484
485        inner.ready.set(true);
486        inner.waker.wake();
487        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
488
489        crate::time::sleep(Duration::from_millis(25)).await;
490        assert_eq!(inner.count.get(), 1);
491
492        inner.ready.set(true);
493        inner.waker.wake();
494        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
495
496        crate::time::sleep(Duration::from_millis(25)).await;
497        assert_eq!(inner.count.get(), 2);
498    }
499}