ntex_util/services/buffer.rs

//! Service that buffers incoming requests.
2use std::cell::{Cell, RefCell};
3use std::task::{ready, Poll, Waker};
4use std::{collections::VecDeque, fmt, future::poll_fn, marker::PhantomData};
5
6use ntex_service::{Middleware, Pipeline, PipelineBinding, Service, ServiceCtx};
7
8use crate::channel::oneshot;
9
/// Middleware factory that puts a request buffer in front of a service.
///
/// While the inner service is not ready, readiness is still reported for up to
/// `buf_size` additional requests; those requests are queued and handed to the
/// inner service in FIFO order as it becomes ready.
///
/// The default buffer size is 16 requests.
pub struct Buffer<R> {
    /// Maximum number of requests queued while the inner service is not ready.
    buf_size: usize,
    /// Drop queued requests on shutdown instead of flushing them.
    cancel_on_shutdown: bool,
    _t: PhantomData<R>,
}
18
19impl<R> Buffer<R> {
20    pub fn buf_size(mut self, size: usize) -> Self {
21        self.buf_size = size;
22        self
23    }
24
25    /// Cancel all buffered requests on shutdown
26    ///
27    /// By default buffered requests are flushed during poll_shutdown
28    pub fn cancel_on_shutdown(mut self) -> Self {
29        self.cancel_on_shutdown = true;
30        self
31    }
32}
33
34impl<R> Default for Buffer<R> {
35    fn default() -> Self {
36        Self {
37            buf_size: 16,
38            cancel_on_shutdown: false,
39            _t: PhantomData,
40        }
41    }
42}
43
44impl<R> fmt::Debug for Buffer<R> {
45    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46        f.debug_struct("Buffer")
47            .field("buf_size", &self.buf_size)
48            .field("cancel_on_shutdown", &self.cancel_on_shutdown)
49            .finish()
50    }
51}
52
53impl<R> Clone for Buffer<R> {
54    fn clone(&self) -> Self {
55        Self {
56            buf_size: self.buf_size,
57            cancel_on_shutdown: self.cancel_on_shutdown,
58            _t: PhantomData,
59        }
60    }
61}
62
63impl<R, S> Middleware<S> for Buffer<R>
64where
65    S: Service<R> + 'static,
66    R: 'static,
67{
68    type Service = BufferService<R, S>;
69
70    fn create(&self, service: S) -> Self::Service {
71        BufferService {
72            service: Pipeline::new(service).bind(),
73            size: self.buf_size,
74            ready: Cell::new(false),
75            buf: RefCell::new(VecDeque::with_capacity(self.buf_size)),
76            next_call: RefCell::default(),
77            cancel_on_shutdown: self.cancel_on_shutdown,
78            readiness: Cell::new(None),
79            _t: PhantomData,
80        }
81    }
82}
83
/// Error returned by `BufferService`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BufferServiceError<E> {
    /// The inner service returned an error.
    Service(E),
    /// The request's buffer slot was discarded before it was forwarded to the
    /// inner service (for example, when `cancel_on_shutdown` clears the buffer).
    RequestCanceled,
}

impl<E> From<E> for BufferServiceError<E> {
    /// Wraps an inner-service error.
    fn from(err: E) -> Self {
        Self::Service(err)
    }
}
95
96impl<E: std::fmt::Display> std::fmt::Display for BufferServiceError<E> {
97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98        match self {
99            BufferServiceError::Service(e) => std::fmt::Display::fmt(e, f),
100            BufferServiceError::RequestCanceled => {
101                f.write_str("buffer service request canceled")
102            }
103        }
104    }
105}
106
107impl<E: std::fmt::Display + std::fmt::Debug> std::error::Error for BufferServiceError<E> {}
108
109/// Buffer service - service that can buffer incoming requests.
110///
111/// Default number of buffered requests is 16
112pub struct BufferService<R, S: Service<R>> {
113    size: usize,
114    ready: Cell<bool>,
115    service: PipelineBinding<S, R>,
116    buf: RefCell<VecDeque<oneshot::Sender<oneshot::Sender<()>>>>,
117    next_call: RefCell<Option<oneshot::Receiver<()>>>,
118    cancel_on_shutdown: bool,
119    readiness: Cell<Option<Waker>>,
120    _t: PhantomData<R>,
121}
122
123impl<R, S> BufferService<R, S>
124where
125    S: Service<R> + 'static,
126    R: 'static,
127{
128    pub fn new(size: usize, service: S) -> Self {
129        Self {
130            size,
131            service: Pipeline::new(service).bind(),
132            ready: Cell::new(false),
133            buf: RefCell::new(VecDeque::with_capacity(size)),
134            next_call: RefCell::default(),
135            cancel_on_shutdown: false,
136            readiness: Cell::new(None),
137            _t: PhantomData,
138        }
139    }
140
141    pub fn cancel_on_shutdown(self) -> Self {
142        Self {
143            cancel_on_shutdown: true,
144            ..self
145        }
146    }
147}
148
149impl<R, S> Clone for BufferService<R, S>
150where
151    S: Service<R> + Clone,
152{
153    fn clone(&self) -> Self {
154        Self {
155            size: self.size,
156            ready: Cell::new(false),
157            service: self.service.clone(),
158            buf: RefCell::new(VecDeque::with_capacity(self.size)),
159            next_call: RefCell::default(),
160            cancel_on_shutdown: self.cancel_on_shutdown,
161            readiness: Cell::new(None),
162            _t: PhantomData,
163        }
164    }
165}
166
167impl<R, S> fmt::Debug for BufferService<R, S>
168where
169    S: Service<R> + fmt::Debug,
170{
171    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172        f.debug_struct("BufferService")
173            .field("size", &self.size)
174            .field("cancel_on_shutdown", &self.cancel_on_shutdown)
175            .field("ready", &self.ready)
176            .field("service", &self.service)
177            .field("buf", &self.buf)
178            .field("next_call", &self.next_call)
179            .finish()
180    }
181}
182
183impl<R, S> Service<R> for BufferService<R, S>
184where
185    S: Service<R> + 'static,
186    R: 'static,
187{
188    type Response = S::Response;
189    type Error = BufferServiceError<S::Error>;
190
191    async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
192        // hold advancement until the last released task either makes a call or is dropped
193        let next_call = self.next_call.borrow_mut().take();
194        if let Some(next_call) = next_call {
195            let _ = next_call.recv().await;
196        }
197
198        poll_fn(|cx| {
199            let mut buffer = self.buf.borrow_mut();
200
201            // handle inner service readiness
202            if self.service.poll_ready(cx)?.is_pending() {
203                if buffer.len() < self.size {
204                    // buffer next request
205                    self.ready.set(false);
206                    Poll::Ready(Ok(()))
207                } else {
208                    log::trace!("Buffer limit exceeded");
209                    // service is not ready
210                    let _ = self.readiness.take().map(|w| w.wake());
211                    Poll::Pending
212                }
213            } else {
214                while let Some(sender) = buffer.pop_front() {
215                    let (next_call_tx, next_call_rx) = oneshot::channel();
216                    if sender.send(next_call_tx).is_err()
217                        || next_call_rx.poll_recv(cx).is_ready()
218                    {
219                        // the task is gone
220                        continue;
221                    }
222                    self.next_call.borrow_mut().replace(next_call_rx);
223                    self.ready.set(false);
224                    return Poll::Ready(Ok(()));
225                }
226
227                self.ready.set(true);
228                Poll::Ready(Ok(()))
229            }
230        })
231        .await
232    }
233
234    async fn shutdown(&self) {
235        // hold advancement until the last released task either makes a call or is dropped
236        let next_call = self.next_call.borrow_mut().take();
237        if let Some(next_call) = next_call {
238            let _ = next_call.recv().await;
239        }
240
241        poll_fn(|cx| {
242            let mut buffer = self.buf.borrow_mut();
243            if self.cancel_on_shutdown {
244                buffer.clear();
245            }
246
247            if !buffer.is_empty() {
248                if ready!(self.service.poll_ready(cx)).is_err() {
249                    log::error!(
250                        "Buffered inner service failed while buffer flushing on shutdown"
251                    );
252                    return Poll::Ready(());
253                }
254
255                while let Some(sender) = buffer.pop_front() {
256                    let (next_call_tx, next_call_rx) = oneshot::channel();
257                    if sender.send(next_call_tx).is_err()
258                        || next_call_rx.poll_recv(cx).is_ready()
259                    {
260                        // the task is gone
261                        continue;
262                    }
263                    self.next_call.borrow_mut().replace(next_call_rx);
264                    if buffer.is_empty() {
265                        break;
266                    }
267                    return Poll::Pending;
268                }
269            }
270            Poll::Ready(())
271        })
272        .await;
273
274        self.service.shutdown().await;
275    }
276
277    async fn call(
278        &self,
279        req: R,
280        _: ServiceCtx<'_, Self>,
281    ) -> Result<Self::Response, Self::Error> {
282        if self.ready.get() {
283            self.ready.set(false);
284            Ok(self.service.call_nowait(req).await?)
285        } else {
286            let (tx, rx) = oneshot::channel();
287            self.buf.borrow_mut().push_back(tx);
288
289            // release
290            let _task_guard = rx.recv().await.map_err(|_| {
291                log::trace!("Buffered service request canceled");
292                BufferServiceError::RequestCanceled
293            })?;
294
295            // call service
296            Ok(self.service.call(req).await?)
297        }
298    }
299
300    ntex_service::forward_poll!(service);
301}
302
#[cfg(test)]
mod tests {
    use ntex_service::{apply, fn_factory, Pipeline, ServiceFactory};
    use std::{rc::Rc, time::Duration};

    use super::*;
    use crate::future::lazy;
    use crate::task::LocalWaker;

    /// Test double; its state lives in a shared `Rc<Inner>` so the test can
    /// toggle readiness and observe the call count.
    #[derive(Debug, Clone)]
    struct TestService(Rc<Inner>);

    #[derive(Debug)]
    struct Inner {
        // readiness reported by `TestService::ready`; set by the test, cleared by `call`
        ready: Cell<bool>,
        // registered by `ready`; the test wakes it after toggling `ready`
        waker: LocalWaker,
        // number of calls that reached the service
        count: Cell<usize>,
    }

    impl Service<()> for TestService {
        type Response = ();
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            poll_fn(|cx| {
                self.0.waker.register(cx.waker());
                if self.0.ready.get() {
                    Poll::Ready(Ok(()))
                } else {
                    Poll::Pending
                }
            })
            .await
        }

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            // each call consumes readiness, so the test must re-arm it explicitly
            self.0.ready.set(false);
            self.0.count.set(self.0.count.get() + 1);
            Ok(())
        }
    }

    #[ntex_macros::rt_test2]
    async fn test_service() {
        // inner service starts out not ready
        let inner = Rc::new(Inner {
            ready: Cell::new(false),
            waker: LocalWaker::default(),
            count: Cell::new(0),
        });

        // buffer of 2; the extra `.clone()` also exercises `BufferService`'s `Clone` impl
        let srv =
            Pipeline::new(BufferService::new(2, TestService(inner.clone())).clone()).bind();
        // the empty buffer has room, so readiness is reported even though the inner
        // service is not ready
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        // first request must not reach the inner service yet
        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        // second request fills the buffer, so readiness becomes pending
        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

        // inner service becomes ready: the oldest buffered request is released
        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 1);

        // `TestService::call` cleared readiness; re-arm it to release the second request
        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 2);

        // with a ready inner service and an empty buffer the call is forwarded directly
        let inner = Rc::new(Inner {
            ready: Cell::new(true),
            waker: LocalWaker::default(),
            count: Cell::new(0),
        });

        let srv = Pipeline::new(BufferService::new(2, TestService(inner.clone()))).bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let _ = srv.call(()).await;
        assert_eq!(inner.count.get(), 1);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
        // nothing is buffered here, so shutdown is expected to finish on the first poll
        assert!(lazy(|cx| srv.poll_shutdown(cx)).await.is_ready());

        // Display / Debug formatting
        let err = BufferServiceError::from("test");
        assert!(format!("{err}").contains("test"));
        assert!(format!("{srv:?}").contains("BufferService"));
        assert!(format!("{:?}", Buffer::<TestService>::default()).contains("Buffer"));
    }

    // Same scenario as `test_service`, but built through the `Buffer` middleware.
    #[ntex_macros::rt_test2]
    // the `.clone()` on the `Buffer` factory below is intentional: it exercises `Buffer`'s `Clone` impl
    #[allow(clippy::redundant_clone)]
    async fn test_middleware() {
        let inner = Rc::new(Inner {
            ready: Cell::new(false),
            waker: LocalWaker::default(),
            count: Cell::new(0),
        });

        let srv = apply(
            Buffer::default().buf_size(2).clone(),
            fn_factory(|| async { Ok::<_, ()>(TestService(inner.clone())) }),
        );

        let srv = srv.pipeline(&()).await.unwrap().bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        // first request must not reach the inner service yet
        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        // second request fills the buffer, so readiness becomes pending
        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

        // inner service becomes ready: the oldest buffered request is released
        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 1);

        // re-arm readiness to release the second buffered request
        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 2);
    }
}