ntex_util/services/
buffer.rs

//! Service that buffers incoming requests.
2use std::cell::{Cell, RefCell};
3use std::task::{Poll, Waker, ready};
4use std::{collections::VecDeque, fmt, future::poll_fn, marker::PhantomData};
5
6use ntex_service::{
7    Middleware, Middleware2, Pipeline, PipelineBinding, Service, ServiceCtx,
8};
9
10use crate::channel::oneshot;
11
/// Buffer - service factory for a service that can buffer incoming requests.
///
/// Default number of buffered requests is 16.
pub struct Buffer<R> {
    /// Maximum number of requests held back while the inner service is not ready.
    buf_size: usize,
    /// Drop buffered requests on shutdown instead of flushing them.
    cancel_on_shutdown: bool,
    _t: PhantomData<R>,
}

impl<R> Buffer<R> {
    /// Set the maximum number of buffered requests.
    pub fn buf_size(self, size: usize) -> Self {
        Self {
            buf_size: size,
            ..self
        }
    }

    /// Cancel all buffered requests on shutdown.
    ///
    /// By default buffered requests are flushed during shutdown.
    pub fn cancel_on_shutdown(self) -> Self {
        Self {
            cancel_on_shutdown: true,
            ..self
        }
    }
}

impl<R> Default for Buffer<R> {
    /// A factory with a 16-request buffer that flushes on shutdown.
    fn default() -> Self {
        Buffer {
            _t: PhantomData,
            cancel_on_shutdown: false,
            buf_size: 16,
        }
    }
}

impl<R> fmt::Debug for Buffer<R> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Buffer");
        out.field("buf_size", &self.buf_size);
        out.field("cancel_on_shutdown", &self.cancel_on_shutdown);
        out.finish()
    }
}

impl<R> Clone for Buffer<R> {
    // Implemented by hand so that cloning does not require `R: Clone`;
    // every field is `Copy`, so destructuring `*self` just copies them.
    fn clone(&self) -> Self {
        let Self {
            buf_size,
            cancel_on_shutdown,
            _t,
        } = *self;
        Self {
            buf_size,
            cancel_on_shutdown,
            _t,
        }
    }
}
64
65impl<R, S> Middleware<S> for Buffer<R>
66where
67    S: Service<R> + 'static,
68    R: 'static,
69{
70    type Service = BufferService<R, S>;
71
72    fn create(&self, service: S) -> Self::Service {
73        BufferService::new(self.buf_size, service)
74    }
75}
76
77impl<R, S, C> Middleware2<S, C> for Buffer<R>
78where
79    S: Service<R> + 'static,
80    R: 'static,
81{
82    type Service = BufferService<R, S>;
83
84    fn create(&self, service: S, _: C) -> Self::Service {
85        BufferService::new(self.buf_size, service)
86    }
87}
88
/// Error produced by [`BufferService`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BufferServiceError<E> {
    /// The inner service reported an error (during readiness or a call).
    Service(E),
    /// A buffered request's slot was dropped before it was released to the
    /// inner service.
    RequestCanceled,
}

impl<E> From<E> for BufferServiceError<E> {
    fn from(err: E) -> Self {
        Self::Service(err)
    }
}

impl<E: fmt::Display> fmt::Display for BufferServiceError<E> {
    /// Inner errors are displayed transparently, with formatter flags forwarded.
    /// Cancellation uses a fixed message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::RequestCanceled => f.write_str("buffer service request canceled"),
            Self::Service(inner) => fmt::Display::fmt(inner, f),
        }
    }
}

impl<E> std::error::Error for BufferServiceError<E>
where
    E: fmt::Display + fmt::Debug,
{
}
113
114/// Buffer service - service that can buffer incoming requests.
115///
116/// Default number of buffered requests is 16
117pub struct BufferService<R, S: Service<R>> {
118    size: usize,
119    ready: Cell<bool>,
120    service: PipelineBinding<S, R>,
121    buf: RefCell<VecDeque<oneshot::Sender<oneshot::Sender<()>>>>,
122    next_call: RefCell<Option<oneshot::Receiver<()>>>,
123    cancel_on_shutdown: bool,
124    readiness: Cell<Option<Waker>>,
125    _t: PhantomData<R>,
126}
127
128impl<R, S> BufferService<R, S>
129where
130    S: Service<R> + 'static,
131    R: 'static,
132{
133    pub fn new(size: usize, service: S) -> Self {
134        Self {
135            size,
136            service: Pipeline::new(service).bind(),
137            ready: Cell::new(false),
138            buf: RefCell::new(VecDeque::with_capacity(size)),
139            next_call: RefCell::default(),
140            cancel_on_shutdown: false,
141            readiness: Cell::new(None),
142            _t: PhantomData,
143        }
144    }
145
146    pub fn cancel_on_shutdown(self) -> Self {
147        Self {
148            cancel_on_shutdown: true,
149            ..self
150        }
151    }
152}
153
154impl<R, S> Clone for BufferService<R, S>
155where
156    S: Service<R> + Clone,
157{
158    fn clone(&self) -> Self {
159        Self {
160            size: self.size,
161            ready: Cell::new(false),
162            service: self.service.clone(),
163            buf: RefCell::new(VecDeque::with_capacity(self.size)),
164            next_call: RefCell::default(),
165            cancel_on_shutdown: self.cancel_on_shutdown,
166            readiness: Cell::new(None),
167            _t: PhantomData,
168        }
169    }
170}
171
172impl<R, S> fmt::Debug for BufferService<R, S>
173where
174    S: Service<R> + fmt::Debug,
175{
176    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177        f.debug_struct("BufferService")
178            .field("size", &self.size)
179            .field("cancel_on_shutdown", &self.cancel_on_shutdown)
180            .field("ready", &self.ready)
181            .field("service", &self.service)
182            .field("buf", &self.buf)
183            .field("next_call", &self.next_call)
184            .finish()
185    }
186}
187
// Requests go directly to the inner service when it is ready. Otherwise up to
// `size` callers are parked and released one at a time as it becomes ready.
impl<R, S> Service<R> for BufferService<R, S>
where
    S: Service<R> + 'static,
    R: 'static,
{
    type Response = S::Response;
    type Error = BufferServiceError<S::Error>;

    /// Resolve once a request may be submitted.
    ///
    /// If the inner service is not ready:
    /// - the next request is accepted for buffering while fewer than `size`
    ///   callers are parked;
    /// - otherwise this stays pending.
    ///
    /// If the inner service is ready:
    /// - the oldest still-live parked caller is released first;
    /// - only when the queue is empty is `ready` set, which lets the next
    ///   `call` bypass the queue.
    ///
    /// Inner readiness errors are returned as `BufferServiceError::Service`.
    async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
        // hold advancement until the last released task either makes a call or is dropped
        let next_call = self.next_call.borrow_mut().take();
        if let Some(next_call) = next_call {
            let _ = next_call.recv().await;
        }

        poll_fn(|cx| {
            let mut buffer = self.buf.borrow_mut();

            // handle inner service readiness
            if self.service.poll_ready(cx)?.is_pending() {
                if buffer.len() < self.size {
                    // buffer next request
                    self.ready.set(false);
                    Poll::Ready(Ok(()))
                } else {
                    log::trace!("Buffer limit exceeded");
                    // service is not ready
                    // NOTE(review): `readiness` is never assigned `Some` in this
                    // file, so this take/wake appears to be a no-op — confirm.
                    let _ = self.readiness.take().map(|w| w.wake());
                    Poll::Pending
                }
            } else {
                // Release parked callers in FIFO order, skipping callers whose
                // task is already gone.
                while let Some(sender) = buffer.pop_front() {
                    let (next_call_tx, next_call_rx) = oneshot::channel();
                    if sender.send(next_call_tx).is_err()
                        || next_call_rx.poll_recv(cx).is_ready()
                    {
                        // the task is gone
                        continue;
                    }
                    // Remember the released caller so the next `ready`/`shutdown`
                    // waits for its guard before advancing.
                    self.next_call.borrow_mut().replace(next_call_rx);
                    // The released caller takes this turn; the new request will
                    // itself be buffered by `call`.
                    self.ready.set(false);
                    return Poll::Ready(Ok(()));
                }

                // Queue drained: the next `call` may go straight to the inner service.
                self.ready.set(true);
                Poll::Ready(Ok(()))
            }
        })
        .await
    }

    /// Flush parked callers (or drop them when `cancel_on_shutdown` is set),
    /// then shut down the inner service.
    ///
    /// When flushing, parked callers are released one at a time as the inner
    /// service reports readiness. If the inner readiness check fails, the flush
    /// is abandoned with an error log.
    async fn shutdown(&self) {
        // hold advancement until the last released task either makes a call or is dropped
        let next_call = self.next_call.borrow_mut().take();
        if let Some(next_call) = next_call {
            let _ = next_call.recv().await;
        }

        poll_fn(|cx| {
            let mut buffer = self.buf.borrow_mut();
            // Dropping the parked senders makes their `call`s fail with
            // `RequestCanceled`.
            if self.cancel_on_shutdown {
                buffer.clear();
            }

            if !buffer.is_empty() {
                if ready!(self.service.poll_ready(cx)).is_err() {
                    // NOTE(review): callers still parked in `buf` are neither
                    // released nor canceled here; they stay pending until the
                    // buffer itself is dropped.
                    log::error!(
                        "Buffered inner service failed while buffer flushing on shutdown"
                    );
                    return Poll::Ready(());
                }

                while let Some(sender) = buffer.pop_front() {
                    let (next_call_tx, next_call_rx) = oneshot::channel();
                    if sender.send(next_call_tx).is_err()
                        || next_call_rx.poll_recv(cx).is_ready()
                    {
                        // the task is gone
                        continue;
                    }
                    self.next_call.borrow_mut().replace(next_call_rx);
                    // Last parked caller released: flushing is done.
                    // NOTE(review): its `next_call` is stored but not awaited
                    // before the inner shutdown below — confirm this is intended.
                    if buffer.is_empty() {
                        break;
                    }
                    // Release one caller per poll. Presumably re-polled through the
                    // waker registered by `poll_recv`/`poll_ready` above.
                    return Poll::Pending;
                }
            }
            Poll::Ready(())
        })
        .await;

        self.service.shutdown().await;
    }

    /// Submit `req`.
    ///
    /// If the preceding `ready` marked the service ready, the request is passed
    /// straight to the inner service with `call_nowait`.
    ///
    /// Otherwise the caller parks in the queue until `ready` or `shutdown`
    /// releases it, and then calls the inner service. If the caller's slot is
    /// dropped instead, it fails with `BufferServiceError::RequestCanceled`.
    async fn call(
        &self,
        req: R,
        _: ServiceCtx<'_, Self>,
    ) -> Result<Self::Response, Self::Error> {
        if self.ready.get() {
            self.ready.set(false);
            Ok(self.service.call_nowait(req).await?)
        } else {
            let (tx, rx) = oneshot::channel();
            self.buf.borrow_mut().push_back(tx);

            // release
            // The guard is held until the inner call below finishes. Dropping it
            // resolves the `next_call` receiver that `ready`/`shutdown` wait on.
            let _task_guard = rx.recv().await.map_err(|_| {
                log::trace!("Buffered service request canceled");
                BufferServiceError::RequestCanceled
            })?;

            // call service
            Ok(self.service.call(req).await?)
        }
    }

    // Presumably forwards `poll` to the inner `service` — see ntex_service.
    ntex_service::forward_poll!(service);
}
307
// Tests for `BufferService` and the `Buffer` middleware. Spawned callers are
// given time to park in the buffer via short sleeps.
#[cfg(test)]
mod tests {
    use ntex_service::{Pipeline, ServiceFactory, apply, apply2, fn_factory};
    use std::{rc::Rc, time::Duration};

    use super::*;
    use crate::{future::lazy, task::LocalWaker};

    /// Inner service whose readiness is toggled by the test.
    #[derive(Debug, Clone)]
    struct TestService(Rc<Inner>);

    /// Shared state behind `TestService`.
    #[derive(Debug)]
    struct Inner {
        // Reported by `ready`; cleared again by every `call`.
        ready: Cell<bool>,
        // Waker registered by `ready`, so the test can trigger a re-poll.
        waker: LocalWaker,
        // Number of calls that reached the inner service.
        count: Cell<usize>,
    }

    impl Service<()> for TestService {
        type Response = ();
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            poll_fn(|cx| {
                self.0.waker.register(cx.waker());
                if self.0.ready.get() {
                    Poll::Ready(Ok(()))
                } else {
                    Poll::Pending
                }
            })
            .await
        }

        // Each call consumes readiness, so one call proceeds per toggle of
        // `ready` to true.
        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            self.0.ready.set(false);
            self.0.count.set(self.0.count.get() + 1);
            Ok(())
        }
    }

    #[ntex::test]
    async fn test_service() {
        let inner = Rc::new(Inner {
            ready: Cell::new(false),
            waker: LocalWaker::default(),
            count: Cell::new(0),
        });

        // Buffer of size 2 around a not-ready inner service.
        let srv =
            Pipeline::new(BufferService::new(2, TestService(inner.clone())).clone()).bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        // First request is parked.
        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        // Second request is parked; the buffer is now full, so readiness is pending.
        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

        // Inner service becomes ready: the first parked request is released.
        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 1);

        // Ready again: the second parked request is released.
        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 2);

        // With an always-ready inner service, calls bypass the buffer.
        let inner = Rc::new(Inner {
            ready: Cell::new(true),
            waker: LocalWaker::default(),
            count: Cell::new(0),
        });

        let srv = Pipeline::new(BufferService::new(2, TestService(inner.clone()))).bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let _ = srv.call(()).await;
        assert_eq!(inner.count.get(), 1);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
        assert!(lazy(|cx| srv.poll_shutdown(cx)).await.is_ready());

        // Display/Debug smoke checks.
        let err = BufferServiceError::from("test");
        assert!(format!("{err}").contains("test"));
        assert!(format!("{srv:?}").contains("BufferService"));
        assert!(format!("{:?}", Buffer::<TestService>::default()).contains("Buffer"));
    }

    // Same scenario as `test_service`, built through the `Middleware` factory.
    #[ntex::test]
    #[allow(clippy::redundant_clone)]
    async fn test_middleware() {
        let inner = Rc::new(Inner {
            ready: Cell::new(false),
            waker: LocalWaker::default(),
            count: Cell::new(0),
        });

        let srv = apply(
            Buffer::default().buf_size(2).clone(),
            fn_factory(|| async { Ok::<_, ()>(TestService(inner.clone())) }),
        );

        let srv = srv.pipeline(&()).await.unwrap().bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 1);

        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 2);
    }

    // Same scenario as `test_service`, built through the `Middleware2` factory.
    #[ntex::test]
    #[allow(clippy::redundant_clone)]
    async fn test_middleware2() {
        let inner = Rc::new(Inner {
            ready: Cell::new(false),
            waker: LocalWaker::default(),
            count: Cell::new(0),
        });

        let srv = apply2(
            Buffer::default().buf_size(2).clone(),
            fn_factory(|| async { Ok::<_, ()>(TestService(inner.clone())) }),
        );

        let srv = srv.pipeline(&()).await.unwrap().bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 1);

        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 2);
    }
}