1use std::cell::{Cell, RefCell};
3use std::task::{Poll, Waker, ready};
4use std::{collections::VecDeque, fmt, future::poll_fn, marker::PhantomData};
5
6use ntex_service::{Middleware, Pipeline, PipelineBinding, Service, ServiceCtx};
7
8use crate::channel::oneshot;
9
/// Middleware that buffers requests while the wrapped service is not ready.
///
/// The service it creates is a [`BufferService`].
pub struct Buffer<R> {
    /// Maximum number of requests that may wait in the buffer.
    buf_size: usize,
    /// Whether shutdown drops buffered requests instead of flushing them.
    cancel_on_shutdown: bool,
    _t: PhantomData<R>,
}

impl<R> Buffer<R> {
    /// Sets the maximum number of requests that may wait in the buffer
    /// (defaults to 16).
    pub fn buf_size(self, size: usize) -> Self {
        Self {
            buf_size: size,
            ..self
        }
    }

    /// Makes shutdown drop pending buffered requests instead of flushing
    /// them through the inner service.
    pub fn cancel_on_shutdown(self) -> Self {
        Self {
            cancel_on_shutdown: true,
            ..self
        }
    }
}

impl<R> Default for Buffer<R> {
    /// Buffer of 16 requests, flushed (not canceled) on shutdown.
    fn default() -> Self {
        const DEFAULT_BUF_SIZE: usize = 16;

        Self {
            buf_size: DEFAULT_BUF_SIZE,
            cancel_on_shutdown: false,
            _t: PhantomData,
        }
    }
}

impl<R> fmt::Debug for Buffer<R> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            buf_size,
            cancel_on_shutdown,
            ..
        } = self;

        f.debug_struct("Buffer")
            .field("buf_size", buf_size)
            .field("cancel_on_shutdown", cancel_on_shutdown)
            .finish()
    }
}

// Written by hand: `#[derive(Clone)]` would add an unnecessary `R: Clone` bound.
impl<R> Clone for Buffer<R> {
    fn clone(&self) -> Self {
        // The remaining fields are `Copy`, so they are copied out of `*self`.
        Self {
            _t: PhantomData,
            ..*self
        }
    }
}
62
63impl<R, S, C> Middleware<S, C> for Buffer<R>
64where
65 S: Service<R> + 'static,
66 R: 'static,
67{
68 type Service = BufferService<R, S>;
69
70 fn create(&self, service: S, _: C) -> Self::Service {
71 BufferService::new(self.buf_size, service)
72 }
73}
74
/// Error returned by [`BufferService`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BufferServiceError<E> {
    /// The wrapped service reported an error.
    Service(E),
    /// A buffered request was dropped before it was released to the wrapped
    /// service (for example, because the buffer was cleared on shutdown).
    RequestCanceled,
}

impl<E> From<E> for BufferServiceError<E> {
    fn from(err: E) -> Self {
        Self::Service(err)
    }
}

impl<E: fmt::Display> fmt::Display for BufferServiceError<E> {
    /// Displays the wrapped service error as-is, or a fixed message for a
    /// canceled request.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::RequestCanceled => f.write_str("buffer service request canceled"),
            Self::Service(inner) => fmt::Display::fmt(inner, f),
        }
    }
}

impl<E: fmt::Display + fmt::Debug> std::error::Error for BufferServiceError<E> {}
99
100pub struct BufferService<R, S: Service<R>> {
104 size: usize,
105 ready: Cell<bool>,
106 service: PipelineBinding<S, R>,
107 buf: RefCell<VecDeque<oneshot::Sender<oneshot::Sender<()>>>>,
108 next_call: RefCell<Option<oneshot::Receiver<()>>>,
109 cancel_on_shutdown: bool,
110 readiness: Cell<Option<Waker>>,
111 _t: PhantomData<R>,
112}
113
114impl<R, S> BufferService<R, S>
115where
116 S: Service<R> + 'static,
117 R: 'static,
118{
119 pub fn new(size: usize, service: S) -> Self {
120 Self {
121 size,
122 service: Pipeline::new(service).bind(),
123 ready: Cell::new(false),
124 buf: RefCell::new(VecDeque::with_capacity(size)),
125 next_call: RefCell::default(),
126 cancel_on_shutdown: false,
127 readiness: Cell::new(None),
128 _t: PhantomData,
129 }
130 }
131
132 pub fn cancel_on_shutdown(self) -> Self {
133 Self {
134 cancel_on_shutdown: true,
135 ..self
136 }
137 }
138}
139
140impl<R, S> Clone for BufferService<R, S>
141where
142 S: Service<R> + Clone,
143{
144 fn clone(&self) -> Self {
145 Self {
146 size: self.size,
147 ready: Cell::new(false),
148 service: self.service.clone(),
149 buf: RefCell::new(VecDeque::with_capacity(self.size)),
150 next_call: RefCell::default(),
151 cancel_on_shutdown: self.cancel_on_shutdown,
152 readiness: Cell::new(None),
153 _t: PhantomData,
154 }
155 }
156}
157
158impl<R, S> fmt::Debug for BufferService<R, S>
159where
160 S: Service<R> + fmt::Debug,
161{
162 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163 f.debug_struct("BufferService")
164 .field("size", &self.size)
165 .field("cancel_on_shutdown", &self.cancel_on_shutdown)
166 .field("ready", &self.ready)
167 .field("service", &self.service)
168 .field("buf", &self.buf)
169 .field("next_call", &self.next_call)
170 .finish()
171 }
172}
173
impl<R, S> Service<R> for BufferService<R, S>
where
    S: Service<R> + 'static,
    R: 'static,
{
    type Response = S::Response;
    type Error = BufferServiceError<S::Error>;

    /// Resolves when this service can accept another request.
    ///
    /// First waits for the completion signal of the most recently released
    /// buffered request, if there is one. Then:
    /// - inner service not ready, buffer has room: returns `Ok` with `ready`
    ///   cleared, so the next `call` is parked in the buffer;
    /// - inner service not ready, buffer full: stays pending (presumably
    ///   woken by the waker handed to the inner service's `poll_ready`);
    /// - inner service ready: releases the oldest still-live buffered request
    ///   and returns `Ok` with `ready` cleared. If nothing is buffered, it sets
    ///   `ready` so the next `call` goes straight to the inner service.
    ///
    /// A readiness error from the inner service is returned as
    /// `BufferServiceError::Service`.
    async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
        // Take the receiver out first so the `RefCell` borrow is not held across `.await`.
        let next_call = self.next_call.borrow_mut().take();
        if let Some(next_call) = next_call {
            // Wait until the previously released request drops its guard.
            let _ = next_call.recv().await;
        }

        poll_fn(|cx| {
            let mut buffer = self.buf.borrow_mut();

            if self.service.poll_ready(cx)?.is_pending() {
                if buffer.len() < self.size {
                    // Room left: accept the request; `call` will park it.
                    self.ready.set(false);
                    Poll::Ready(Ok(()))
                } else {
                    log::trace!("Buffer limit exceeded");
                    // NOTE(review): `readiness` is never stored in the visible code,
                    // so this take/wake is currently a no-op — confirm intent.
                    let _ = self.readiness.take().map(|w| w.wake());
                    Poll::Pending
                }
            } else {
                // Inner service is ready: hand its readiness to the oldest parked request.
                while let Some(sender) = buffer.pop_front() {
                    let (next_call_tx, next_call_rx) = oneshot::channel();
                    // Skip requests whose task is gone: either the send fails, or
                    // the guard already resolved (it was dropped).
                    if sender.send(next_call_tx).is_err()
                        || next_call_rx.poll_recv(cx).is_ready()
                    {
                        continue;
                    }
                    // Remember the guard so the next `ready`/`shutdown` waits on it.
                    self.next_call.borrow_mut().replace(next_call_rx);
                    self.ready.set(false);
                    return Poll::Ready(Ok(()));
                }

                // Nothing parked: let the next `call` bypass the buffer.
                self.ready.set(true);
                Poll::Ready(Ok(()))
            }
        })
        .await
    }

    /// Flushes buffered requests (or drops them when `cancel_on_shutdown` is
    /// set), then shuts down the inner service.
    ///
    /// Buffered requests are released one at a time as the inner service
    /// reports readiness. If the inner readiness check fails, flushing stops
    /// and an error is logged.
    ///
    /// NOTE(review): after releasing the last buffered request the loop
    /// breaks. The inner service is then shut down without waiting for that
    /// request's guard — confirm this is intended.
    async fn shutdown(&self) {
        let next_call = self.next_call.borrow_mut().take();
        if let Some(next_call) = next_call {
            let _ = next_call.recv().await;
        }

        poll_fn(|cx| {
            let mut buffer = self.buf.borrow_mut();
            if self.cancel_on_shutdown {
                // Dropping the senders makes each parked `call` fail with `RequestCanceled`.
                buffer.clear();
            }

            if !buffer.is_empty() {
                if ready!(self.service.poll_ready(cx)).is_err() {
                    log::error!(
                        "Buffered inner service failed while buffer flushing on shutdown"
                    );
                    return Poll::Ready(());
                }

                while let Some(sender) = buffer.pop_front() {
                    let (next_call_tx, next_call_rx) = oneshot::channel();
                    if sender.send(next_call_tx).is_err()
                        || next_call_rx.poll_recv(cx).is_ready()
                    {
                        continue;
                    }
                    self.next_call.borrow_mut().replace(next_call_rx);
                    if buffer.is_empty() {
                        break;
                    }
                    // More requests remain. `poll_recv(cx)` above is expected to have
                    // registered our waker, so the released request's guard drop re-polls us.
                    return Poll::Pending;
                }
            }
            Poll::Ready(())
        })
        .await;

        self.service.shutdown().await;
    }

    /// Forwards `req` to the inner service, parking it in the buffer first
    /// unless the last `ready` left the `ready` flag set.
    ///
    /// A parked request whose buffer slot is dropped before release (e.g. the
    /// buffer is cleared on shutdown) fails with
    /// `BufferServiceError::RequestCanceled`. Inner errors are wrapped in
    /// `BufferServiceError::Service`.
    async fn call(
        &self,
        req: R,
        _: ServiceCtx<'_, Self>,
    ) -> Result<Self::Response, Self::Error> {
        if self.ready.get() {
            // Consume the one-shot fast path granted by `ready`.
            self.ready.set(false);
            Ok(self.service.call_nowait(req).await?)
        } else {
            let (tx, rx) = oneshot::channel();
            self.buf.borrow_mut().push_back(tx);

            // Wait to be released. The received guard is kept alive until this
            // function returns; dropping it resolves the `next_call` receiver
            // that gates the next `ready`/`shutdown` step.
            let _task_guard = rx.recv().await.map_err(|_| {
                log::trace!("Buffered service request canceled");
                BufferServiceError::RequestCanceled
            })?;

            Ok(self.service.call(req).await?)
        }
    }

    ntex_service::forward_poll!(service);
}
293
#[cfg(test)]
mod tests {
    use ntex_service::{Pipeline, ServiceFactory, apply, fn_factory};
    use std::{rc::Rc, time::Duration};

    use super::*;
    use crate::{future::lazy, task::LocalWaker};

    /// Inner service whose readiness is toggled manually by the tests.
    #[derive(Debug, Clone)]
    struct TestService(Rc<Inner>);

    /// Shared state of `TestService`.
    #[derive(Debug)]
    struct Inner {
        /// Readiness reported by `TestService::ready`; cleared by every `call`.
        ready: Cell<bool>,
        /// Waker registered by `ready`, used to re-poll after toggling `ready`.
        waker: LocalWaker,
        /// Number of requests that reached `TestService::call`.
        count: Cell<usize>,
    }

    impl Service<()> for TestService {
        type Response = ();
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            poll_fn(|cx| {
                self.0.waker.register(cx.waker());
                if self.0.ready.get() {
                    Poll::Ready(Ok(()))
                } else {
                    Poll::Pending
                }
            })
            .await
        }

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            self.0.ready.set(false);
            self.0.count.set(self.0.count.get() + 1);
            Ok(())
        }
    }

    #[ntex::test]
    async fn test_service() {
        let inner = Rc::new(Inner {
            ready: Cell::new(false),
            waker: LocalWaker::default(),
            count: Cell::new(0),
        });

        let srv =
            Pipeline::new(BufferService::new(2, TestService(inner.clone())).clone()).bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 1);

        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 2);

        let inner = Rc::new(Inner {
            ready: Cell::new(true),
            waker: LocalWaker::default(),
            count: Cell::new(0),
        });

        let srv = Pipeline::new(BufferService::new(2, TestService(inner.clone()))).bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let _ = srv.call(()).await;
        assert_eq!(inner.count.get(), 1);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
        assert!(lazy(|cx| srv.poll_shutdown(cx)).await.is_ready());

        let err = BufferServiceError::from("test");
        assert!(format!("{err}").contains("test"));
        assert!(format!("{srv:?}").contains("BufferService"));
        assert!(format!("{:?}", Buffer::<TestService>::default()).contains("Buffer"));
    }

    #[ntex::test]
    // The `.clone()` on the builder is intentional: it exercises `Buffer`'s manual `Clone` impl.
    #[allow(clippy::redundant_clone)]
    async fn test_middleware() {
        let inner = Rc::new(Inner {
            ready: Cell::new(false),
            waker: LocalWaker::default(),
            count: Cell::new(0),
        });

        let srv = apply(
            Buffer::default().buf_size(2).clone(),
            fn_factory(|| async { Ok::<_, ()>(TestService(inner.clone())) }),
        );

        let srv = srv.pipeline(&()).await.unwrap().bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 1);

        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 2);
    }

    /// With `cancel_on_shutdown`, shutdown must not wait for a not-ready inner
    /// service. Parked requests must fail with `RequestCanceled`.
    #[ntex::test]
    async fn test_cancel_on_shutdown() {
        let inner = Rc::new(Inner {
            ready: Cell::new(false),
            waker: LocalWaker::default(),
            count: Cell::new(0),
        });

        let srv = Pipeline::new(
            BufferService::new(2, TestService(inner.clone())).cancel_on_shutdown(),
        )
        .bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        // Park one request in the buffer (the inner service is not ready).
        let result = Rc::new(Cell::new(None));
        let srv1 = srv.clone();
        let result1 = result.clone();
        ntex::rt::spawn(async move {
            result1.set(Some(srv1.call(()).await));
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);

        // Shutdown completes immediately instead of waiting to flush the buffer.
        assert!(lazy(|cx| srv.poll_shutdown(cx)).await.is_ready());

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(result.take(), Some(Err(BufferServiceError::RequestCanceled)));
        assert_eq!(inner.count.get(), 0);
    }
}