1use std::cell::{Cell, RefCell};
3use std::task::{Poll, Waker, ready};
4use std::{collections::VecDeque, fmt, future::poll_fn, marker::PhantomData};
5
6use ntex_service::{Middleware, Pipeline, PipelineBinding, Service, ServiceCtx};
7
8use crate::channel::oneshot;
9
/// Configuration for the buffering middleware that produces [`BufferService`].
///
/// `R` is the request type; it is only carried as a marker.
pub struct Buffer<R> {
    /// Number of requests that may wait in the buffer.
    buf_size: usize,
    /// Whether queued requests should be cancelled on shutdown.
    cancel_on_shutdown: bool,
    _t: PhantomData<R>,
}

impl<R> Buffer<R> {
    /// Returns the configuration with the buffer capacity replaced by `size`.
    #[must_use]
    pub fn buf_size(self, size: usize) -> Self {
        Self {
            buf_size: size,
            ..self
        }
    }

    /// Returns the configuration with cancel-on-shutdown switched on.
    #[must_use]
    pub fn cancel_on_shutdown(self) -> Self {
        Self {
            cancel_on_shutdown: true,
            ..self
        }
    }
}
38
39impl<R> Default for Buffer<R> {
40 fn default() -> Self {
41 Self {
42 buf_size: 16,
43 cancel_on_shutdown: false,
44 _t: PhantomData,
45 }
46 }
47}
48
49impl<R> fmt::Debug for Buffer<R> {
50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51 f.debug_struct("Buffer")
52 .field("buf_size", &self.buf_size)
53 .field("cancel_on_shutdown", &self.cancel_on_shutdown)
54 .finish()
55 }
56}
57
58impl<R> Clone for Buffer<R> {
59 fn clone(&self) -> Self {
60 Self {
61 buf_size: self.buf_size,
62 cancel_on_shutdown: self.cancel_on_shutdown,
63 _t: PhantomData,
64 }
65 }
66}
67
68impl<R, S, C> Middleware<S, C> for Buffer<R>
69where
70 S: Service<R> + 'static,
71 R: 'static,
72{
73 type Service = BufferService<R, S>;
74
75 fn create(&self, service: S, _: C) -> Self::Service {
76 BufferService::new(self.buf_size, service)
77 }
78}
79
/// Error type returned by the buffering service.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BufferServiceError<E> {
    /// Error reported by the wrapped service.
    Service(E),
    /// The queued request was dropped before it was released for dispatch.
    RequestCanceled,
}

impl<E> From<E> for BufferServiceError<E> {
    /// Wraps an inner-service error, which lets `?` convert it automatically.
    fn from(err: E) -> Self {
        Self::Service(err)
    }
}
91
92impl<E: std::fmt::Display> std::fmt::Display for BufferServiceError<E> {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 match self {
95 BufferServiceError::Service(e) => std::fmt::Display::fmt(e, f),
96 BufferServiceError::RequestCanceled => {
97 f.write_str("buffer service request canceled")
98 }
99 }
100 }
101}
102
103impl<E: std::fmt::Display + std::fmt::Debug> std::error::Error for BufferServiceError<E> {}
104
105pub struct BufferService<R, S: Service<R>> {
109 size: usize,
110 ready: Cell<bool>,
111 service: PipelineBinding<S, R>,
112 buf: RefCell<VecDeque<oneshot::Sender<oneshot::Sender<()>>>>,
113 next_call: RefCell<Option<oneshot::Receiver<()>>>,
114 cancel_on_shutdown: bool,
115 readiness: Cell<Option<Waker>>,
116 _t: PhantomData<R>,
117}
118
119impl<R, S> BufferService<R, S>
120where
121 S: Service<R> + 'static,
122 R: 'static,
123{
124 #[must_use]
125 pub fn new(size: usize, service: S) -> Self {
126 Self {
127 size,
128 service: Pipeline::new(service).bind(),
129 ready: Cell::new(false),
130 buf: RefCell::new(VecDeque::with_capacity(size)),
131 next_call: RefCell::default(),
132 cancel_on_shutdown: false,
133 readiness: Cell::new(None),
134 _t: PhantomData,
135 }
136 }
137
138 #[must_use]
139 pub fn cancel_on_shutdown(self) -> Self {
140 Self {
141 cancel_on_shutdown: true,
142 ..self
143 }
144 }
145}
146
147impl<R, S> Clone for BufferService<R, S>
148where
149 S: Service<R> + Clone,
150{
151 fn clone(&self) -> Self {
152 Self {
153 size: self.size,
154 ready: Cell::new(false),
155 service: self.service.clone(),
156 buf: RefCell::new(VecDeque::with_capacity(self.size)),
157 next_call: RefCell::default(),
158 cancel_on_shutdown: self.cancel_on_shutdown,
159 readiness: Cell::new(None),
160 _t: PhantomData,
161 }
162 }
163}
164
165impl<R, S> fmt::Debug for BufferService<R, S>
166where
167 S: Service<R> + fmt::Debug,
168{
169 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170 f.debug_struct("BufferService")
171 .field("size", &self.size)
172 .field("cancel_on_shutdown", &self.cancel_on_shutdown)
173 .field("ready", &self.ready)
174 .field("service", &self.service)
175 .field("buf", &self.buf)
176 .field("next_call", &self.next_call)
177 .finish()
178 }
179}
180
181impl<R, S> Service<R> for BufferService<R, S>
182where
183 S: Service<R> + 'static,
184 R: 'static,
185{
186 type Response = S::Response;
187 type Error = BufferServiceError<S::Error>;
188
189 async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
190 let next_call = self.next_call.borrow_mut().take();
192 if let Some(next_call) = next_call {
193 let _ = next_call.recv().await;
194 }
195
196 poll_fn(|cx| {
197 let mut buffer = self.buf.borrow_mut();
198
199 if self.service.poll_ready(cx)?.is_pending() {
201 if buffer.len() < self.size {
202 self.ready.set(false);
204 Poll::Ready(Ok(()))
205 } else {
206 log::trace!("Buffer limit exceeded");
207 let _ = self.readiness.take().map(Waker::wake);
209 Poll::Pending
210 }
211 } else {
212 while let Some(sender) = buffer.pop_front() {
213 let (next_call_tx, next_call_rx) = oneshot::channel();
214 if sender.send(next_call_tx).is_err()
215 || next_call_rx.poll_recv(cx).is_ready()
216 {
217 continue;
219 }
220 self.next_call.borrow_mut().replace(next_call_rx);
221 self.ready.set(false);
222 return Poll::Ready(Ok(()));
223 }
224
225 self.ready.set(true);
226 Poll::Ready(Ok(()))
227 }
228 })
229 .await
230 }
231
232 async fn shutdown(&self) {
233 let next_call = self.next_call.borrow_mut().take();
235 if let Some(next_call) = next_call {
236 let _ = next_call.recv().await;
237 }
238
239 poll_fn(|cx| {
240 let mut buffer = self.buf.borrow_mut();
241 if self.cancel_on_shutdown {
242 buffer.clear();
243 }
244
245 if !buffer.is_empty() {
246 if ready!(self.service.poll_ready(cx)).is_err() {
247 log::error!(
248 "Buffered inner service failed while buffer flushing on shutdown"
249 );
250 return Poll::Ready(());
251 }
252
253 while let Some(sender) = buffer.pop_front() {
254 let (next_call_tx, next_call_rx) = oneshot::channel();
255 if sender.send(next_call_tx).is_err()
256 || next_call_rx.poll_recv(cx).is_ready()
257 {
258 continue;
260 }
261 self.next_call.borrow_mut().replace(next_call_rx);
262 if buffer.is_empty() {
263 break;
264 }
265 return Poll::Pending;
266 }
267 }
268 Poll::Ready(())
269 })
270 .await;
271
272 self.service.shutdown().await;
273 }
274
275 async fn call(
276 &self,
277 req: R,
278 _: ServiceCtx<'_, Self>,
279 ) -> Result<Self::Response, Self::Error> {
280 if self.ready.get() {
281 self.ready.set(false);
282 Ok(self.service.call_nowait(req).await?)
283 } else {
284 let (tx, rx) = oneshot::channel();
285 self.buf.borrow_mut().push_back(tx);
286
287 let _task_guard = rx.recv().await.map_err(|_| {
289 log::trace!("Buffered service request canceled");
290 BufferServiceError::RequestCanceled
291 })?;
292
293 Ok(self.service.call(req).await?)
295 }
296 }
297
298 ntex_service::forward_poll!(service);
299}
300
#[cfg(test)]
mod tests {
    use ntex_service::{Pipeline, apply, fn_factory};
    use std::{rc::Rc, time::Duration};

    use super::*;
    use crate::{future::lazy, task::LocalWaker};

    /// Inner service whose readiness is controlled by the test through `Inner`.
    #[derive(Debug, Clone)]
    struct TestService(Rc<Inner>);

    /// State shared between a test and its `TestService`.
    #[derive(Debug)]
    struct Inner {
        ready: Cell<bool>,
        waker: LocalWaker,
        count: Cell<usize>,
    }

    /// Builds shared state with the given initial readiness and a zero call count.
    fn shared_state(ready: bool) -> Rc<Inner> {
        Rc::new(Inner {
            ready: Cell::new(ready),
            waker: LocalWaker::default(),
            count: Cell::new(0),
        })
    }

    /// Marks the inner service ready and wakes the registered waker.
    fn mark_ready(state: &Inner) {
        state.ready.set(true);
        state.waker.wake();
    }

    /// Short pause that lets spawned calls make progress.
    async fn settle() {
        crate::time::sleep(Duration::from_millis(25)).await;
    }

    impl Service<()> for TestService {
        type Response = ();
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            let state = &self.0;
            poll_fn(|cx| {
                state.waker.register(cx.waker());
                match state.ready.get() {
                    true => Poll::Ready(Ok(())),
                    false => Poll::Pending,
                }
            })
            .await
        }

        async fn call(&self, _r: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            // Each call consumes readiness and is counted.
            let state = &self.0;
            state.ready.set(false);
            state.count.set(state.count.get() + 1);
            Ok(())
        }
    }

    #[ntex::test]
    async fn test_service() {
        let state = shared_state(false);

        let svc =
            Pipeline::new(BufferService::new(2, TestService(state.clone())).clone()).bind();
        assert_eq!(lazy(|cx| svc.poll_ready(cx)).await, Poll::Ready(Ok(())));

        // First call is buffered because the inner service is not ready.
        let caller = svc.clone();
        ntex::rt::spawn(async move {
            let _ = caller.call(()).await;
        });
        settle().await;
        assert_eq!(state.count.get(), 0);
        assert_eq!(lazy(|cx| svc.poll_ready(cx)).await, Poll::Ready(Ok(())));

        // Second call fills the buffer, so readiness becomes pending.
        let caller = svc.clone();
        ntex::rt::spawn(async move {
            let _ = caller.call(()).await;
        });
        settle().await;
        assert_eq!(state.count.get(), 0);
        assert_eq!(lazy(|cx| svc.poll_ready(cx)).await, Poll::Pending);

        mark_ready(&state);
        assert_eq!(lazy(|cx| svc.poll_ready(cx)).await, Poll::Ready(Ok(())));

        settle().await;
        assert_eq!(state.count.get(), 1);

        mark_ready(&state);
        assert_eq!(lazy(|cx| svc.poll_ready(cx)).await, Poll::Ready(Ok(())));

        settle().await;
        assert_eq!(state.count.get(), 2);

        // With a ready inner service the call goes straight through.
        let ready_state = shared_state(true);

        let direct =
            Pipeline::new(BufferService::new(2, TestService(ready_state.clone()))).bind();
        assert_eq!(lazy(|cx| direct.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let _ = direct.call(()).await;
        assert_eq!(ready_state.count.get(), 1);
        assert_eq!(lazy(|cx| direct.poll_ready(cx)).await, Poll::Ready(Ok(())));
        assert!(lazy(|cx| direct.poll_shutdown(cx)).await.is_ready());

        let err = BufferServiceError::from("test");
        assert!(format!("{err}").contains("test"));
        assert!(format!("{direct:?}").contains("BufferService"));
        assert!(format!("{:?}", Buffer::<TestService>::default()).contains("Buffer"));
    }

    #[ntex::test]
    #[allow(clippy::redundant_clone)]
    async fn test_middleware() {
        let state = shared_state(false);

        let factory = apply(
            Buffer::default().buf_size(2).clone(),
            fn_factory(|| async { Ok::<_, ()>(TestService(state.clone())) }),
        );

        let svc = factory.pipeline(&()).await.unwrap().bind();
        assert_eq!(lazy(|cx| svc.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let caller = svc.clone();
        ntex::rt::spawn(async move {
            let _ = caller.call(()).await;
        });
        settle().await;
        assert_eq!(state.count.get(), 0);
        assert_eq!(lazy(|cx| svc.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let caller = svc.clone();
        ntex::rt::spawn(async move {
            let _ = caller.call(()).await;
        });
        settle().await;
        assert_eq!(state.count.get(), 0);
        assert_eq!(lazy(|cx| svc.poll_ready(cx)).await, Poll::Pending);

        mark_ready(&state);
        assert_eq!(lazy(|cx| svc.poll_ready(cx)).await, Poll::Ready(Ok(())));

        settle().await;
        assert_eq!(state.count.get(), 1);

        mark_ready(&state);
        assert_eq!(lazy(|cx| svc.poll_ready(cx)).await, Poll::Ready(Ok(())));

        settle().await;
        assert_eq!(state.count.get(), 2);
    }

    #[ntex::test]
    #[allow(clippy::redundant_clone)]
    async fn test_middleware2() {
        let state = shared_state(false);

        let factory = apply(
            Buffer::default().buf_size(2).clone(),
            fn_factory(|| async { Ok::<_, ()>(TestService(state.clone())) }),
        );

        let svc = factory.pipeline(&()).await.unwrap().bind();
        assert_eq!(lazy(|cx| svc.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let caller = svc.clone();
        ntex::rt::spawn(async move {
            let _ = caller.call(()).await;
        });
        settle().await;
        assert_eq!(state.count.get(), 0);
        assert_eq!(lazy(|cx| svc.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let caller = svc.clone();
        ntex::rt::spawn(async move {
            let _ = caller.call(()).await;
        });
        settle().await;
        assert_eq!(state.count.get(), 0);
        assert_eq!(lazy(|cx| svc.poll_ready(cx)).await, Poll::Pending);

        mark_ready(&state);
        assert_eq!(lazy(|cx| svc.poll_ready(cx)).await, Poll::Ready(Ok(())));

        settle().await;
        assert_eq!(state.count.get(), 1);

        mark_ready(&state);
        assert_eq!(lazy(|cx| svc.poll_ready(cx)).await, Poll::Ready(Ok(())));

        settle().await;
        assert_eq!(state.count.get(), 2);
    }
}