1#![allow(clippy::cast_possible_truncation)]
2use std::task::{Context, Poll, Waker};
3use std::{cell, fmt, future::Future, marker, pin::Pin, rc::Rc};
4
5use crate::Service;
6
7pub struct ServiceCtx<'a, S: ?Sized> {
8 idx: u32,
9 waiters: &'a WaitersRef,
10 _t: marker::PhantomData<Rc<S>>,
11}
12
13#[derive(Debug)]
14pub(crate) struct WaitersRef {
15 running: cell::Cell<bool>,
16 cur: cell::Cell<u32>,
17 shutdown: cell::Cell<bool>,
18 wakers: cell::UnsafeCell<Vec<u32>>,
19 indexes: cell::UnsafeCell<slab::Slab<Option<Waker>>>,
20}
21
22impl WaitersRef {
23 pub(crate) fn new() -> (u32, Self) {
24 let mut waiters = slab::Slab::new();
25
26 (
27 waiters.insert(None) as u32,
28 WaitersRef {
29 running: cell::Cell::new(false),
30 cur: cell::Cell::new(u32::MAX),
31 shutdown: cell::Cell::new(false),
32 indexes: cell::UnsafeCell::new(waiters),
33 wakers: cell::UnsafeCell::new(Vec::default()),
34 },
35 )
36 }
37
38 #[allow(clippy::mut_from_ref)]
39 pub(crate) fn get(&self) -> &mut slab::Slab<Option<Waker>> {
40 unsafe { &mut *self.indexes.get() }
41 }
42
43 #[allow(clippy::mut_from_ref)]
44 pub(crate) fn get_wakers(&self) -> &mut Vec<u32> {
45 unsafe { &mut *self.wakers.get() }
46 }
47
48 pub(crate) fn insert(&self) -> u32 {
49 self.get().insert(None) as u32
50 }
51
52 pub(crate) fn remove(&self, idx: u32) {
53 self.get().remove(idx as usize);
54
55 if self.cur.get() == idx {
56 self.notify();
57 }
58 }
59
60 pub(crate) fn notify(&self) {
61 let wakers = self.get_wakers();
62 if !wakers.is_empty() {
63 let indexes = self.get();
64 for idx in wakers.drain(..) {
65 if let Some(item) = indexes.get_mut(idx as usize)
66 && let Some(waker) = item.take()
67 {
68 waker.wake();
69 }
70 }
71 }
72
73 self.cur.set(u32::MAX);
74 }
75
76 pub(crate) fn run<F, R>(&self, idx: u32, cx: &mut Context<'_>, f: F) -> Poll<R>
77 where
78 F: FnOnce(&mut Context<'_>) -> Poll<R>,
79 {
80 let cur = self.cur.get();
82 let can_check = if cur == idx {
83 true
84 } else if cur == u32::MAX {
85 self.cur.set(idx);
86 true
87 } else {
88 false
89 };
90
91 if can_check {
92 let initial_run = !self.running.get();
94 if initial_run {
95 self.running.set(true);
96 }
97
98 let result = f(cx);
99
100 if initial_run {
101 if result.is_pending() {
102 self.get_wakers().push(idx);
103 self.get()[idx as usize] = Some(cx.waker().clone());
104 } else {
105 self.notify();
106 }
107 self.running.set(false);
108 }
109 result
110 } else {
111 self.get_wakers().push(idx);
113 self.get()[idx as usize] = Some(cx.waker().clone());
114 Poll::Pending
115 }
116 }
117
118 pub(crate) fn shutdown(&self) {
119 self.shutdown.set(true);
120 }
121
122 pub(crate) fn is_shutdown(&self) -> bool {
123 self.shutdown.get()
124 }
125}
126
127impl<'a, S> ServiceCtx<'a, S> {
128 pub(crate) fn new(idx: u32, waiters: &'a WaitersRef) -> Self {
129 Self {
130 idx,
131 waiters,
132 _t: marker::PhantomData,
133 }
134 }
135
136 pub(crate) fn inner(self) -> (u32, &'a WaitersRef) {
137 (self.idx, self.waiters)
138 }
139
140 #[inline]
141 pub fn id(&self) -> u32 {
143 self.idx
144 }
145
146 pub async fn ready<T, R>(&self, svc: &'a T) -> Result<(), T::Error>
148 where
149 T: Service<R>,
150 {
151 ReadyCall {
153 completed: false,
154 fut: svc.ready(ServiceCtx {
155 idx: self.idx,
156 waiters: self.waiters,
157 _t: marker::PhantomData,
158 }),
159 ctx: *self,
160 }
161 .await
162 }
163
164 #[inline]
165 pub async fn call<T, R>(&self, svc: &'a T, req: R) -> Result<T::Response, T::Error>
167 where
168 T: Service<R>,
169 R: 'a,
170 {
171 self.ready(svc).await?;
172
173 svc.call(
174 req,
175 ServiceCtx {
176 idx: self.idx,
177 waiters: self.waiters,
178 _t: marker::PhantomData,
179 },
180 )
181 .await
182 }
183
184 #[inline]
185 pub async fn call_nowait<T, R>(
187 &self,
188 svc: &'a T,
189 req: R,
190 ) -> Result<T::Response, T::Error>
191 where
192 T: Service<R>,
193 R: 'a,
194 {
195 svc.call(
196 req,
197 ServiceCtx {
198 idx: self.idx,
199 waiters: self.waiters,
200 _t: marker::PhantomData,
201 },
202 )
203 .await
204 }
205}
206
207impl<S> Copy for ServiceCtx<'_, S> {}
208
209impl<S> Clone for ServiceCtx<'_, S> {
210 #[inline]
211 fn clone(&self) -> Self {
212 *self
213 }
214}
215
216impl<S> fmt::Debug for ServiceCtx<'_, S> {
217 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218 f.debug_struct("ServiceCtx")
219 .field("idx", &self.idx)
220 .field("waiters", &self.waiters.get().len())
221 .finish()
222 }
223}
224
225struct ReadyCall<'a, S: ?Sized, F: Future> {
226 completed: bool,
227 fut: F,
228 ctx: ServiceCtx<'a, S>,
229}
230
231impl<S: ?Sized, F: Future> Drop for ReadyCall<'_, S, F> {
232 fn drop(&mut self) {
233 if !self.completed && self.ctx.waiters.cur.get() == self.ctx.idx {
234 self.ctx.waiters.notify();
235 }
236 }
237}
238
239impl<S: ?Sized, F: Future> Unpin for ReadyCall<'_, S, F> {}
240
241impl<S: ?Sized, F: Future> Future for ReadyCall<'_, S, F> {
242 type Output = F::Output;
243
244 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
245 self.ctx.waiters.run(self.ctx.idx, cx, |cx| {
246 let result = unsafe { Pin::new_unchecked(&mut self.as_mut().fut).poll(cx) };
248 if result.is_ready() {
249 self.completed = true;
250 }
251 result
252 })
253 }
254}
255
#[cfg(test)]
#[allow(clippy::should_panic_without_expect)]
mod tests {
    use std::{cell::Cell, cell::RefCell, future::poll_fn};

    use ntex::channel::{condition, oneshot};
    use ntex::{rt::spawn, time, util::lazy, util::select};

    use super::*;
    use crate::Pipeline;

    // Test service. `ready` increments the shared counter `.0` on every invocation
    // and then waits on the condition `.1`. `call` echoes the request back.
    struct Srv(Rc<Cell<usize>>, condition::Waiter);

    impl Service<&'static str> for Srv {
        type Response = &'static str;
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            self.0.set(self.0.get() + 1);
            self.1.ready().await;
            Ok(())
        }

        async fn call(
            &self,
            req: &'static str,
            ctx: ServiceCtx<'_, Self>,
        ) -> Result<Self::Response, Self::Error> {
            // Exercise the `Debug`, `id` and `Clone` impls of `ServiceCtx`.
            let _ = format!("{ctx:?}");
            let _ = format!("{:?}", ctx.id());
            #[allow(clippy::clone_on_copy)]
            let _ = ctx.clone();
            Ok(req)
        }
    }

    // Two clones of a bound pipeline must never run `Srv::ready` concurrently;
    // ownership of the readiness check passes between them.
    #[ntex::test]
    async fn test_ready() {
        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();

        let srv1 = Pipeline::from(Srv(cnt.clone(), con.wait())).bind();
        let srv2 = srv1.clone();

        // srv1 starts the readiness check (`Srv::ready` invoked once).
        let res = lazy(|cx| srv1.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);
        assert_eq!(cnt.get(), 1);

        // srv2 is parked behind srv1: `Srv::ready` is not invoked again.
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);
        assert_eq!(cnt.get(), 1);

        // srv1 resumes its existing check and completes it.
        con.notify();
        let res = lazy(|cx| srv1.poll_ready(cx)).await;
        assert_eq!(res, Poll::Ready(Ok(())));
        assert_eq!(cnt.get(), 1);

        // srv2 can now start its own check.
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);
        assert_eq!(cnt.get(), 2);

        con.notify();
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Ready(Ok(())));
        assert_eq!(cnt.get(), 2);

        // srv1 starts a fresh check after srv2 finished.
        let res = lazy(|cx| srv1.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);
        assert_eq!(cnt.get(), 3);
    }

    // Abandoning an in-flight `ready()` must release the readiness check so that a
    // parked binding can take it over.
    #[ntex::test]
    async fn test_ready_on_drop() {
        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();
        let srv = Pipeline::from(Srv(cnt.clone(), con.wait()));

        let srv1 = srv.clone();
        let srv2 = srv1.clone().bind();

        let (tx, rx) = oneshot::channel();
        // Drives `srv1.ready()` until `rx` resolves. The `select` is intended to
        // return then and drop the pending ready future; the task is then kept alive.
        spawn(async move {
            select(rx, srv1.ready()).await;
            time::sleep(time::Millis(25000)).await;
        });
        time::sleep(time::Millis(250)).await;

        // srv2 is parked behind the spawned task's readiness check.
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);

        // Abort the spawned check; srv2 should take over and start its own.
        let _ = tx.send(());
        time::sleep(time::Millis(250)).await;

        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);

        con.notify();
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Ready(Ok(())));
    }

    // Same hand-over as above, but the owning binding is polled via `poll_ready`
    // and then shut down instead of dropping a `ready()` future.
    #[ntex::test]
    async fn test_ready_after_shutdown() {
        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();
        let srv = Pipeline::from(Srv(cnt.clone(), con.wait()));

        let srv1 = srv.clone().bind();
        let srv2 = srv1.clone();

        let (tx, rx) = oneshot::channel();
        spawn(async move {
            select(rx, poll_fn(|cx| srv1.poll_ready(cx))).await;
            poll_fn(|cx| srv1.poll_shutdown(cx)).await;
            time::sleep(time::Millis(25000)).await;
        });
        time::sleep(time::Millis(250)).await;

        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);

        let _ = tx.send(());
        time::sleep(time::Millis(250)).await;

        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);

        con.notify();
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Ready(Ok(())));
    }

    // Polling readiness on a binding after it has been shut down is expected to panic.
    #[ntex::test]
    #[should_panic]
    async fn test_pipeline_binding_after_shutdown() {
        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();
        let srv = Pipeline::from(Srv(cnt.clone(), con.wait())).bind();
        poll_fn(|cx| srv.poll_shutdown(cx)).await;
        let _ = poll_fn(|cx| srv.poll_ready(cx)).await;
    }

    // Two concurrent callers share one service: their readiness checks run one
    // after the other, and each call completes once its check has succeeded.
    #[ntex::test]
    async fn test_shared_call() {
        let data = Rc::new(RefCell::new(Vec::new()));

        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();

        let srv1 = Pipeline::from(Srv(cnt.clone(), con.wait())).bind();
        let srv2 = srv1.clone();
        let _: Pipeline<_> = srv1.pipeline();

        let data1 = data.clone();
        ntex::rt::spawn(async move {
            let _ = poll_fn(|cx| srv1.poll_ready(cx)).await;
            let fut = srv1.call_nowait("srv1");
            assert!(format!("{fut:?}").contains("PipelineCall"));
            let i = fut.await.unwrap();
            data1.borrow_mut().push(i);
        });

        let data2 = data.clone();
        ntex::rt::spawn(async move {
            let i = srv2.call("srv2").await.unwrap();
            data2.borrow_mut().push(i);
        });
        time::sleep(time::Millis(50)).await;

        // First notification completes srv1's check; srv2 then starts its own
        // (second `Srv::ready` invocation) and stays pending.
        con.notify();
        time::sleep(time::Millis(150)).await;

        assert_eq!(cnt.get(), 2);
        assert_eq!(&*data.borrow(), &["srv1"]);

        // Second notification completes srv2's check and its call.
        con.notify();
        time::sleep(time::Millis(150)).await;

        assert_eq!(cnt.get(), 2);
        assert_eq!(&*data.borrow(), &["srv1", "srv2"]);
    }
}