1use std::cell::{Cell, RefCell};
3use std::task::{ready, Poll, Waker};
4use std::{collections::VecDeque, fmt, future::poll_fn, marker::PhantomData};
5
6use ntex_service::{Middleware, Pipeline, PipelineBinding, Service, ServiceCtx};
7
8use crate::channel::oneshot;
9
/// Middleware that places a bounded request buffer in front of a service.
///
/// While the wrapped service is not ready, up to `buf_size` requests are accepted and
/// parked; they are released to the service in FIFO order as it becomes ready.
pub struct Buffer<R> {
    /// Whether parked requests are canceled (instead of flushed) on shutdown.
    cancel_on_shutdown: bool,
    /// Maximum number of parked requests.
    buf_size: usize,
    _t: PhantomData<R>,
}
18
19impl<R> Buffer<R> {
20 pub fn buf_size(mut self, size: usize) -> Self {
21 self.buf_size = size;
22 self
23 }
24
25 pub fn cancel_on_shutdown(mut self) -> Self {
29 self.cancel_on_shutdown = true;
30 self
31 }
32}
33
34impl<R> Default for Buffer<R> {
35 fn default() -> Self {
36 Self {
37 buf_size: 16,
38 cancel_on_shutdown: false,
39 _t: PhantomData,
40 }
41 }
42}
43
44impl<R> fmt::Debug for Buffer<R> {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46 f.debug_struct("Buffer")
47 .field("buf_size", &self.buf_size)
48 .field("cancel_on_shutdown", &self.cancel_on_shutdown)
49 .finish()
50 }
51}
52
53impl<R> Clone for Buffer<R> {
54 fn clone(&self) -> Self {
55 Self {
56 buf_size: self.buf_size,
57 cancel_on_shutdown: self.cancel_on_shutdown,
58 _t: PhantomData,
59 }
60 }
61}
62
63impl<R, S> Middleware<S> for Buffer<R>
64where
65 S: Service<R> + 'static,
66 R: 'static,
67{
68 type Service = BufferService<R, S>;
69
70 fn create(&self, service: S) -> Self::Service {
71 BufferService {
72 service: Pipeline::new(service).bind(),
73 size: self.buf_size,
74 ready: Cell::new(false),
75 buf: RefCell::new(VecDeque::with_capacity(self.buf_size)),
76 next_call: RefCell::default(),
77 cancel_on_shutdown: self.cancel_on_shutdown,
78 readiness: Cell::new(None),
79 _t: PhantomData,
80 }
81 }
82}
83
/// Error returned by `BufferService`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferServiceError<E> {
    /// The wrapped service reported an error.
    Service(E),
    /// The buffered request was canceled (its buffer slot was dropped) before it was
    /// released to the wrapped service.
    RequestCanceled,
}
89
90impl<E> From<E> for BufferServiceError<E> {
91 fn from(err: E) -> Self {
92 BufferServiceError::Service(err)
93 }
94}
95
96impl<E: std::fmt::Display> std::fmt::Display for BufferServiceError<E> {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 match self {
99 BufferServiceError::Service(e) => std::fmt::Display::fmt(e, f),
100 BufferServiceError::RequestCanceled => {
101 f.write_str("buffer service request canceled")
102 }
103 }
104 }
105}
106
107impl<E: std::fmt::Display + std::fmt::Debug> std::error::Error for BufferServiceError<E> {}
108
/// Service that buffers requests while the wrapped service is not ready.
///
/// Created by the `Buffer` middleware or directly with `BufferService::new`.
pub struct BufferService<R, S: Service<R>> {
    /// Maximum number of requests that may wait in `buf`.
    size: usize,
    /// Set by `ready` when the wrapped service is ready and nothing is buffered, so the
    /// next `call` goes straight to the wrapped service instead of being buffered.
    ready: Cell<bool>,
    service: PipelineBinding<S, R>,
    /// FIFO of parked requests. Sending on an entry releases the waiting `call`,
    /// handing it a guard sender that it keeps for the rest of its call.
    buf: RefCell<VecDeque<oneshot::Sender<oneshot::Sender<()>>>>,
    /// Receiving side of the guard handed to the most recently released request;
    /// it completes once that request drops its guard.
    next_call: RefCell<Option<oneshot::Receiver<()>>>,
    /// Drop buffered requests on shutdown instead of flushing them.
    cancel_on_shutdown: bool,
    /// NOTE(review): nothing in this file stores a waker here; `ready` only takes and
    /// wakes it when the buffer is full — confirm whether it is still needed.
    readiness: Cell<Option<Waker>>,
    _t: PhantomData<R>,
}
122
123impl<R, S> BufferService<R, S>
124where
125 S: Service<R> + 'static,
126 R: 'static,
127{
128 pub fn new(size: usize, service: S) -> Self {
129 Self {
130 size,
131 service: Pipeline::new(service).bind(),
132 ready: Cell::new(false),
133 buf: RefCell::new(VecDeque::with_capacity(size)),
134 next_call: RefCell::default(),
135 cancel_on_shutdown: false,
136 readiness: Cell::new(None),
137 _t: PhantomData,
138 }
139 }
140
141 pub fn cancel_on_shutdown(self) -> Self {
142 Self {
143 cancel_on_shutdown: true,
144 ..self
145 }
146 }
147}
148
149impl<R, S> Clone for BufferService<R, S>
150where
151 S: Service<R> + Clone,
152{
153 fn clone(&self) -> Self {
154 Self {
155 size: self.size,
156 ready: Cell::new(false),
157 service: self.service.clone(),
158 buf: RefCell::new(VecDeque::with_capacity(self.size)),
159 next_call: RefCell::default(),
160 cancel_on_shutdown: self.cancel_on_shutdown,
161 readiness: Cell::new(None),
162 _t: PhantomData,
163 }
164 }
165}
166
167impl<R, S> fmt::Debug for BufferService<R, S>
168where
169 S: Service<R> + fmt::Debug,
170{
171 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172 f.debug_struct("BufferService")
173 .field("size", &self.size)
174 .field("cancel_on_shutdown", &self.cancel_on_shutdown)
175 .field("ready", &self.ready)
176 .field("service", &self.service)
177 .field("buf", &self.buf)
178 .field("next_call", &self.next_call)
179 .finish()
180 }
181}
182
183impl<R, S> Service<R> for BufferService<R, S>
184where
185 S: Service<R> + 'static,
186 R: 'static,
187{
188 type Response = S::Response;
189 type Error = BufferServiceError<S::Error>;
190
191 async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
192 let next_call = self.next_call.borrow_mut().take();
194 if let Some(next_call) = next_call {
195 let _ = next_call.recv().await;
196 }
197
198 poll_fn(|cx| {
199 let mut buffer = self.buf.borrow_mut();
200
201 if self.service.poll_ready(cx)?.is_pending() {
203 if buffer.len() < self.size {
204 self.ready.set(false);
206 Poll::Ready(Ok(()))
207 } else {
208 log::trace!("Buffer limit exceeded");
209 let _ = self.readiness.take().map(|w| w.wake());
211 Poll::Pending
212 }
213 } else {
214 while let Some(sender) = buffer.pop_front() {
215 let (next_call_tx, next_call_rx) = oneshot::channel();
216 if sender.send(next_call_tx).is_err()
217 || next_call_rx.poll_recv(cx).is_ready()
218 {
219 continue;
221 }
222 self.next_call.borrow_mut().replace(next_call_rx);
223 self.ready.set(false);
224 return Poll::Ready(Ok(()));
225 }
226
227 self.ready.set(true);
228 Poll::Ready(Ok(()))
229 }
230 })
231 .await
232 }
233
234 async fn shutdown(&self) {
235 let next_call = self.next_call.borrow_mut().take();
237 if let Some(next_call) = next_call {
238 let _ = next_call.recv().await;
239 }
240
241 poll_fn(|cx| {
242 let mut buffer = self.buf.borrow_mut();
243 if self.cancel_on_shutdown {
244 buffer.clear();
245 }
246
247 if !buffer.is_empty() {
248 if ready!(self.service.poll_ready(cx)).is_err() {
249 log::error!(
250 "Buffered inner service failed while buffer flushing on shutdown"
251 );
252 return Poll::Ready(());
253 }
254
255 while let Some(sender) = buffer.pop_front() {
256 let (next_call_tx, next_call_rx) = oneshot::channel();
257 if sender.send(next_call_tx).is_err()
258 || next_call_rx.poll_recv(cx).is_ready()
259 {
260 continue;
262 }
263 self.next_call.borrow_mut().replace(next_call_rx);
264 if buffer.is_empty() {
265 break;
266 }
267 return Poll::Pending;
268 }
269 }
270 Poll::Ready(())
271 })
272 .await;
273
274 self.service.shutdown().await;
275 }
276
277 async fn call(
278 &self,
279 req: R,
280 _: ServiceCtx<'_, Self>,
281 ) -> Result<Self::Response, Self::Error> {
282 if self.ready.get() {
283 self.ready.set(false);
284 Ok(self.service.call_nowait(req).await?)
285 } else {
286 let (tx, rx) = oneshot::channel();
287 self.buf.borrow_mut().push_back(tx);
288
289 let _task_guard = rx.recv().await.map_err(|_| {
291 log::trace!("Buffered service request canceled");
292 BufferServiceError::RequestCanceled
293 })?;
294
295 Ok(self.service.call(req).await?)
297 }
298 }
299
300 ntex_service::forward_poll!(service);
301}
302
#[cfg(test)]
mod tests {
    use ntex_service::{apply, fn_factory, Pipeline, ServiceFactory};
    use std::{rc::Rc, time::Duration};

    use super::*;
    use crate::future::lazy;
    use crate::task::LocalWaker;

    /// Test double whose readiness is controlled by the test through `Inner::ready`.
    #[derive(Debug, Clone)]
    struct TestService(Rc<Inner>);

    /// State shared between a test and its `TestService`.
    #[derive(Debug)]
    struct Inner {
        // readiness reported by `TestService::ready`; every call clears it again
        ready: Cell<bool>,
        // registered by `TestService::ready` so the test can wake a pending readiness check
        waker: LocalWaker,
        // number of calls that reached the test service
        count: Cell<usize>,
    }

    impl Service<()> for TestService {
        type Response = ();
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            poll_fn(|cx| {
                self.0.waker.register(cx.waker());
                if self.0.ready.get() {
                    Poll::Ready(Ok(()))
                } else {
                    Poll::Pending
                }
            })
            .await
        }

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            // each call consumes readiness, so the test must re-enable it explicitly
            self.0.ready.set(false);
            self.0.count.set(self.0.count.get() + 1);
            Ok(())
        }
    }

    #[ntex_macros::rt_test2]
    async fn test_service() {
        let inner = Rc::new(Inner {
            ready: Cell::new(false),
            waker: LocalWaker::default(),
            count: Cell::new(0),
        });

        // buffer of 2 in front of a not-ready service; `.clone()` also exercises `Clone`
        let srv =
            Pipeline::new(BufferService::new(2, TestService(inner.clone())).clone()).bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        // first request is parked in the buffer; there is still room for one more
        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        // second request fills the buffer, so readiness is now pending
        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

        // service becomes ready: the oldest buffered request is released
        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 1);

        // ready again: the second buffered request is released
        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 2);

        // a service that is ready from the start: calls bypass the buffer
        let inner = Rc::new(Inner {
            ready: Cell::new(true),
            waker: LocalWaker::default(),
            count: Cell::new(0),
        });

        let srv = Pipeline::new(BufferService::new(2, TestService(inner.clone()))).bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let _ = srv.call(()).await;
        assert_eq!(inner.count.get(), 1);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
        // nothing is buffered, so shutdown completes on the first poll
        assert!(lazy(|cx| srv.poll_shutdown(cx)).await.is_ready());

        // Display/Debug smoke checks
        let err = BufferServiceError::from("test");
        assert!(format!("{err}").contains("test"));
        assert!(format!("{srv:?}").contains("BufferService"));
        assert!(format!("{:?}", Buffer::<TestService>::default()).contains("Buffer"));
    }

    /// Same scenario as `test_service`, but built through the `Buffer` middleware.
    #[ntex_macros::rt_test2]
    // the `.clone()` below exists only to exercise `Buffer`'s `Clone` impl
    #[allow(clippy::redundant_clone)]
    async fn test_middleware() {
        let inner = Rc::new(Inner {
            ready: Cell::new(false),
            waker: LocalWaker::default(),
            count: Cell::new(0),
        });

        let srv = apply(
            Buffer::default().buf_size(2).clone(),
            fn_factory(|| async { Ok::<_, ()>(TestService(inner.clone())) }),
        );

        let srv = srv.pipeline(&()).await.unwrap().bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        // first request is parked in the buffer
        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        // second request fills the buffer
        let srv1 = srv.clone();
        ntex::rt::spawn(async move {
            let _ = srv1.call(()).await;
        });
        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 0);
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

        // release the buffered requests one readiness cycle at a time
        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 1);

        inner.ready.set(true);
        inner.waker.wake();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        crate::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(inner.count.get(), 2);
    }
}