1use std::cell::{Cell, RefCell};
3use std::task::{Poll, Waker, ready};
4use std::{collections::VecDeque, fmt, future::poll_fn, marker::PhantomData};
5
6use ntex_service::{
7 Middleware, Middleware2, Pipeline, PipelineBinding, Service, ServiceCtx,
8};
9
10use crate::channel::oneshot;
11
/// Configuration for the request-buffering middleware.
///
/// Adjust it with [`Buffer::buf_size`] and [`Buffer::cancel_on_shutdown`];
/// the default is a 16-slot buffer with shutdown cancellation disabled.
pub struct Buffer<R> {
    /// Marker tying the configuration to the request type `R`.
    _t: PhantomData<R>,
    /// Whether buffered requests should be canceled (rather than flushed) on shutdown.
    cancel_on_shutdown: bool,
    /// Maximum number of requests held while the inner service is not ready.
    buf_size: usize,
}
20
21impl<R> Buffer<R> {
22 pub fn buf_size(mut self, size: usize) -> Self {
23 self.buf_size = size;
24 self
25 }
26
27 pub fn cancel_on_shutdown(mut self) -> Self {
31 self.cancel_on_shutdown = true;
32 self
33 }
34}
35
36impl<R> Default for Buffer<R> {
37 fn default() -> Self {
38 Self {
39 buf_size: 16,
40 cancel_on_shutdown: false,
41 _t: PhantomData,
42 }
43 }
44}
45
46impl<R> fmt::Debug for Buffer<R> {
47 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48 f.debug_struct("Buffer")
49 .field("buf_size", &self.buf_size)
50 .field("cancel_on_shutdown", &self.cancel_on_shutdown)
51 .finish()
52 }
53}
54
55impl<R> Clone for Buffer<R> {
56 fn clone(&self) -> Self {
57 Self {
58 buf_size: self.buf_size,
59 cancel_on_shutdown: self.cancel_on_shutdown,
60 _t: PhantomData,
61 }
62 }
63}
64
65impl<R, S> Middleware<S> for Buffer<R>
66where
67 S: Service<R> + 'static,
68 R: 'static,
69{
70 type Service = BufferService<R, S>;
71
72 fn create(&self, service: S) -> Self::Service {
73 BufferService::new(self.buf_size, service)
74 }
75}
76
77impl<R, S, C> Middleware2<S, C> for Buffer<R>
78where
79 S: Service<R> + 'static,
80 R: 'static,
81{
82 type Service = BufferService<R, S>;
83
84 fn create(&self, service: S, _: C) -> Self::Service {
85 BufferService::new(self.buf_size, service)
86 }
87}
88
/// Error produced by the buffering service.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum BufferServiceError<E> {
    /// The wrapped service returned an error.
    Service(E),
    /// A buffered request was canceled before it was released to the wrapped service.
    RequestCanceled,
}
94
95impl<E> From<E> for BufferServiceError<E> {
96 fn from(err: E) -> Self {
97 BufferServiceError::Service(err)
98 }
99}
100
101impl<E: std::fmt::Display> std::fmt::Display for BufferServiceError<E> {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 match self {
104 BufferServiceError::Service(e) => std::fmt::Display::fmt(e, f),
105 BufferServiceError::RequestCanceled => {
106 f.write_str("buffer service request canceled")
107 }
108 }
109 }
110}
111
112impl<E: std::fmt::Display + std::fmt::Debug> std::error::Error for BufferServiceError<E> {}
113
/// Service wrapper that parks requests while the inner service is not ready.
///
/// Up to `size` requests can be parked. Once the inner service reports ready,
/// parked requests are released oldest-first, at most one per `ready` call.
pub struct BufferService<R, S: Service<R>> {
    /// Maximum number of parked requests. `ready` stays pending when the
    /// buffer holds this many and the inner service is not ready.
    size: usize,
    /// Set by `ready` when the inner service is ready and nothing is parked;
    /// the next `call` then bypasses the buffer and clears the flag.
    ready: Cell<bool>,
    /// The wrapped service, bound into a pipeline.
    service: PipelineBinding<S, R>,
    /// Parked requests, oldest first. Releasing one sends it a guard
    /// (`oneshot::Sender<()>`) that the request holds for the rest of its `call`.
    buf: RefCell<VecDeque<oneshot::Sender<oneshot::Sender<()>>>>,
    /// Receiving side of the guard handed to the most recently released
    /// request; `ready` and `shutdown` wait on it before doing anything else.
    next_call: RefCell<Option<oneshot::Receiver<()>>>,
    /// When `true`, `shutdown` drops all parked requests instead of flushing them.
    cancel_on_shutdown: bool,
    /// NOTE(review): this file only ever `take()`s this waker (in `ready`)
    /// and never stores one, so it appears unused — confirm before relying on it.
    readiness: Cell<Option<Waker>>,
    _t: PhantomData<R>,
}
127
128impl<R, S> BufferService<R, S>
129where
130 S: Service<R> + 'static,
131 R: 'static,
132{
133 pub fn new(size: usize, service: S) -> Self {
134 Self {
135 size,
136 service: Pipeline::new(service).bind(),
137 ready: Cell::new(false),
138 buf: RefCell::new(VecDeque::with_capacity(size)),
139 next_call: RefCell::default(),
140 cancel_on_shutdown: false,
141 readiness: Cell::new(None),
142 _t: PhantomData,
143 }
144 }
145
146 pub fn cancel_on_shutdown(self) -> Self {
147 Self {
148 cancel_on_shutdown: true,
149 ..self
150 }
151 }
152}
153
154impl<R, S> Clone for BufferService<R, S>
155where
156 S: Service<R> + Clone,
157{
158 fn clone(&self) -> Self {
159 Self {
160 size: self.size,
161 ready: Cell::new(false),
162 service: self.service.clone(),
163 buf: RefCell::new(VecDeque::with_capacity(self.size)),
164 next_call: RefCell::default(),
165 cancel_on_shutdown: self.cancel_on_shutdown,
166 readiness: Cell::new(None),
167 _t: PhantomData,
168 }
169 }
170}
171
172impl<R, S> fmt::Debug for BufferService<R, S>
173where
174 S: Service<R> + fmt::Debug,
175{
176 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177 f.debug_struct("BufferService")
178 .field("size", &self.size)
179 .field("cancel_on_shutdown", &self.cancel_on_shutdown)
180 .field("ready", &self.ready)
181 .field("service", &self.service)
182 .field("buf", &self.buf)
183 .field("next_call", &self.next_call)
184 .finish()
185 }
186}
187
impl<R, S> Service<R> for BufferService<R, S>
where
    S: Service<R> + 'static,
    R: 'static,
{
    type Response = S::Response;
    type Error = BufferServiceError<S::Error>;

    /// Resolves when another request can be accepted.
    ///
    /// This resolves whenever the inner service is ready or the buffer has free
    /// space. It stays pending only while the inner service is not ready *and*
    /// the buffer already holds `size` requests. When the inner service is
    /// ready, the oldest still-waiting parked request is released first, so
    /// parked requests keep FIFO priority over new ones.
    async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
        // Hold advancement until the previously released request drops its
        // guard sender (it keeps it for the rest of its `call`) or is gone.
        let next_call = self.next_call.borrow_mut().take();
        if let Some(next_call) = next_call {
            let _ = next_call.recv().await;
        }

        poll_fn(|cx| {
            let mut buffer = self.buf.borrow_mut();

            // Inner-service errors are converted to `BufferServiceError::Service` by `?`.
            if self.service.poll_ready(cx)?.is_pending() {
                if buffer.len() < self.size {
                    // Room left: accept the next request, but route it through the buffer.
                    self.ready.set(false);
                    Poll::Ready(Ok(()))
                } else {
                    log::trace!("Buffer limit exceeded");
                    // NOTE(review): nothing in this file stores `readiness`, so this
                    // take() appears to always yield `None`.
                    let _ = self.readiness.take().map(|w| w.wake());
                    Poll::Pending
                }
            } else {
                // Inner service is ready: release the oldest parked request that is still waiting.
                while let Some(sender) = buffer.pop_front() {
                    let (next_call_tx, next_call_rx) = oneshot::channel();
                    if sender.send(next_call_tx).is_err()
                        || next_call_rx.poll_recv(cx).is_ready()
                    {
                        // The parked call is gone (its receiver or the guard was already dropped).
                        continue;
                    }
                    // Keep the guard's receiver so the next `ready` waits for this request.
                    self.next_call.borrow_mut().replace(next_call_rx);
                    // The request released above takes this readiness slot, so a new
                    // `call` must still go through the buffer.
                    self.ready.set(false);
                    return Poll::Ready(Ok(()));
                }

                // Nothing parked: the next `call` may go straight to the inner service.
                self.ready.set(true);
                Poll::Ready(Ok(()))
            }
        })
        .await
    }

    /// Shuts the buffer down, then the inner service.
    ///
    /// If `cancel_on_shutdown` is set, every parked request is dropped and its
    /// `call` resolves with `RequestCanceled`. Otherwise parked requests are
    /// released one at a time while the inner service is ready. If the inner
    /// service reports an error, flushing stops.
    async fn shutdown(&self) {
        // Let the most recently released request finish (or be dropped) first.
        let next_call = self.next_call.borrow_mut().take();
        if let Some(next_call) = next_call {
            let _ = next_call.recv().await;
        }

        poll_fn(|cx| {
            let mut buffer = self.buf.borrow_mut();
            if self.cancel_on_shutdown {
                // Dropping the senders makes each parked `call` fail with `RequestCanceled`.
                buffer.clear();
            }

            if !buffer.is_empty() {
                if ready!(self.service.poll_ready(cx)).is_err() {
                    log::error!(
                        "Buffered inner service failed while buffer flushing on shutdown"
                    );
                    // Give up flushing; any remaining parked requests stay in `buf`.
                    return Poll::Ready(());
                }

                while let Some(sender) = buffer.pop_front() {
                    let (next_call_tx, next_call_rx) = oneshot::channel();
                    if sender.send(next_call_tx).is_err()
                        || next_call_rx.poll_recv(cx).is_ready()
                    {
                        // Skip calls that are no longer waiting.
                        continue;
                    }
                    self.next_call.borrow_mut().replace(next_call_rx);
                    if buffer.is_empty() {
                        // NOTE(review): the last released call is not awaited before
                        // `self.service.shutdown()` below — confirm this is intended.
                        break;
                    }
                    // More requests are parked: yield until this released call drops its
                    // guard. This presumably relies on `poll_recv(cx)` above having
                    // registered the task's waker — confirm against `oneshot` docs.
                    return Poll::Pending;
                }
            }
            Poll::Ready(())
        })
        .await;

        self.service.shutdown().await;
    }

    /// Handles one request.
    ///
    /// If the last `ready` found the inner service ready with nothing parked,
    /// the request goes straight to `call_nowait`. Otherwise it is parked in
    /// the buffer until `ready` or `shutdown` releases it, and then dispatched
    /// with `call`.
    ///
    /// # Errors
    /// `RequestCanceled` if the parked request is dropped before release
    /// (e.g. the buffer is cleared on shutdown). `Service(e)` for errors from
    /// the inner service.
    async fn call(
        &self,
        req: R,
        _: ServiceCtx<'_, Self>,
    ) -> Result<Self::Response, Self::Error> {
        if self.ready.get() {
            // Consume the readiness slot so only one call takes the direct path.
            self.ready.set(false);
            Ok(self.service.call_nowait(req).await?)
        } else {
            let (tx, rx) = oneshot::channel();
            self.buf.borrow_mut().push_back(tx);

            // Wait to be released. The received guard is held until this
            // function returns; dropping it lets the next `ready` proceed.
            let _task_guard = rx.recv().await.map_err(|_| {
                log::trace!("Buffered service request canceled");
                BufferServiceError::RequestCanceled
            })?;

            Ok(self.service.call(req).await?)
        }
    }

    ntex_service::forward_poll!(service);
}
307
#[cfg(test)]
mod tests {
    use ntex_service::{Pipeline, ServiceFactory, apply, apply2, fn_factory};
    use std::{rc::Rc, time::Duration};

    use super::*;
    use crate::{future::lazy, task::LocalWaker};

    /// Inner service stub whose readiness is driven from the test body.
    #[derive(Debug, Clone)]
    struct TestService(Rc<Inner>);

    /// State shared between a test and its `TestService`.
    #[derive(Debug)]
    struct Inner {
        /// Whether `TestService::ready` resolves; every call resets it to `false`.
        ready: Cell<bool>,
        /// Waker registered by `TestService::ready`; tests wake it after setting `ready`.
        waker: LocalWaker,
        /// Number of requests that reached `TestService::call`.
        count: Cell<usize>,
    }

    impl Service<()> for TestService {
        type Response = ();
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            poll_fn(|cx| {
                self.0.waker.register(cx.waker());
                if self.0.ready.get() {
                    Poll::Ready(Ok(()))
                } else {
                    Poll::Pending
                }
            })
            .await
        }

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            // Each call consumes readiness, so the next request has to wait again.
            self.0.ready.set(false);
            self.0.count.set(self.0.count.get() + 1);
            Ok(())
        }
    }

    /// Drives `BufferService` directly.
    ///
    /// With a buffer of 2 and an unready inner service, two calls are parked
    /// and the next readiness check stays pending. Each inner-readiness
    /// transition then releases one parked call. The second half covers the
    /// direct path (inner service ready), shutdown, and the `Display`/`Debug` impls.
    #[ntex::test]
    async fn test_service() {
        let inner = Rc::new(Inner {
            ready: Cell::new(false),
            waker: LocalWaker::default(),
            count: Cell::new(0),
        });

        // The extra `.clone()` also exercises `BufferService`'s `Clone` impl.
        let srv =
            Pipeline::new(BufferService::new(2, TestService(inner.clone())).clone()).bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        // First request is parked: the inner service is not ready.
        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        // Second request fills the buffer (size 2), so readiness is now pending.
        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

        // Inner service becomes ready: the first parked request is released.
        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 1);

        // Ready again: the second parked request is released.
        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 2);

        // Direct path: the inner service is ready from the start, nothing is parked.
        let inner = Rc::new(Inner {
            ready: Cell::new(true),
            waker: LocalWaker::default(),
            count: Cell::new(0),
        });

        let srv = Pipeline::new(BufferService::new(2, TestService(inner.clone()))).bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let _ = srv.call(()).await;
        assert_eq!(inner.count.get(), 1);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
        assert!(lazy(|cx| srv.poll_shutdown(cx)).await.is_ready());

        let err = BufferServiceError::from("test");
        assert!(format!("{err}").contains("test"));
        assert!(format!("{srv:?}").contains("BufferService"));
        assert!(format!("{:?}", Buffer::<TestService>::default()).contains("Buffer"));
    }

    /// Same parking/release scenario as `test_service`, built through the
    /// `Middleware` impl via `apply`.
    #[ntex::test]
    // The `.clone()` on the middleware is intentional: it exercises `Buffer`'s `Clone` impl.
    #[allow(clippy::redundant_clone)]
    async fn test_middleware() {
        let inner = Rc::new(Inner {
            ready: Cell::new(false),
            waker: LocalWaker::default(),
            count: Cell::new(0),
        });

        let srv = apply(
            Buffer::default().buf_size(2).clone(),
            fn_factory(|| async { Ok::<_, ()>(TestService(inner.clone())) }),
        );

        let srv = srv.pipeline(&()).await.unwrap().bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        // First request is parked.
        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        // Second request fills the buffer; readiness is now pending.
        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

        // Each inner-readiness transition releases one parked request.
        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 1);

        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 2);
    }

    /// Same parking/release scenario as `test_service`, built through the
    /// `Middleware2` impl via `apply2`.
    #[ntex::test]
    // The `.clone()` on the middleware is intentional: it exercises `Buffer`'s `Clone` impl.
    #[allow(clippy::redundant_clone)]
    async fn test_middleware2() {
        let inner = Rc::new(Inner {
            ready: Cell::new(false),
            waker: LocalWaker::default(),
            count: Cell::new(0),
        });

        let srv = apply2(
            Buffer::default().buf_size(2).clone(),
            fn_factory(|| async { Ok::<_, ()>(TestService(inner.clone())) }),
        );

        let srv = srv.pipeline(&()).await.unwrap().bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        // First request is parked.
        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        // Second request fills the buffer; readiness is now pending.
        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

        // Each inner-readiness transition releases one parked request.
        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 1);

        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 2);
    }
}