1use std::task::{Context, Poll, Waker};
2use std::{cell, fmt, future::Future, marker, pin::Pin, rc::Rc};
3
4use crate::Service;
5
/// Per-pipeline context passed to a service's `ready` and `call` methods.
///
/// Identifies one pipeline slot (`idx`) inside a shared waiter registry that
/// coordinates readiness checks between pipelines.
pub struct ServiceCtx<'a, S: ?Sized> {
    /// Slot index of this pipeline inside `waiters`.
    idx: u32,
    /// Shared readiness-coordination state.
    waiters: &'a WaitersRef,
    /// Ties the context to the service type `S` without storing one.
    /// Because `Rc` is `!Send`/`!Sync`, this also makes the context `!Send`/`!Sync`.
    _t: marker::PhantomData<Rc<S>>,
}
11
/// Shared state used to serialize readiness checks across pipeline slots.
///
/// At most one slot (`cur`) drives a readiness check at a time. Other slots
/// store a waker and are woken by `notify` once the owner finishes.
#[derive(Debug)]
pub(crate) struct WaitersRef {
    /// Set while the outermost `run` call is executing its callback.
    /// Lets re-entrant `run` calls skip waiter bookkeeping.
    running: cell::Cell<bool>,
    /// Slot that currently owns the readiness check, or `u32::MAX` when none does.
    cur: cell::Cell<u32>,
    /// Set by `shutdown`.
    shutdown: cell::Cell<bool>,
    /// Slot indexes queued for wake-up by the next `notify`, in registration order.
    wakers: cell::UnsafeCell<Vec<u32>>,
    /// Per-slot storage; a parked slot holds the waker to wake on `notify`.
    indexes: cell::UnsafeCell<slab::Slab<Option<Waker>>>,
}
20
21impl WaitersRef {
22 pub(crate) fn new() -> (u32, Self) {
23 let mut waiters = slab::Slab::new();
24
25 (
26 waiters.insert(Default::default()) as u32,
27 WaitersRef {
28 running: cell::Cell::new(false),
29 cur: cell::Cell::new(u32::MAX),
30 shutdown: cell::Cell::new(false),
31 indexes: cell::UnsafeCell::new(waiters),
32 wakers: cell::UnsafeCell::new(Vec::default()),
33 },
34 )
35 }
36
37 #[allow(clippy::mut_from_ref)]
38 pub(crate) fn get(&self) -> &mut slab::Slab<Option<Waker>> {
39 unsafe { &mut *self.indexes.get() }
40 }
41
42 #[allow(clippy::mut_from_ref)]
43 pub(crate) fn get_wakers(&self) -> &mut Vec<u32> {
44 unsafe { &mut *self.wakers.get() }
45 }
46
47 pub(crate) fn insert(&self) -> u32 {
48 self.get().insert(None) as u32
49 }
50
51 pub(crate) fn remove(&self, idx: u32) {
52 self.get().remove(idx as usize);
53
54 if self.cur.get() == idx {
55 self.notify();
56 }
57 }
58
59 pub(crate) fn notify(&self) {
60 let wakers = self.get_wakers();
61 if !wakers.is_empty() {
62 let indexes = self.get();
63 for idx in wakers.drain(..) {
64 if let Some(item) = indexes.get_mut(idx as usize) {
65 if let Some(waker) = item.take() {
66 waker.wake();
67 }
68 }
69 }
70 }
71
72 self.cur.set(u32::MAX);
73 }
74
75 pub(crate) fn run<F, R>(&self, idx: u32, cx: &mut Context<'_>, f: F) -> Poll<R>
76 where
77 F: FnOnce(&mut Context<'_>) -> Poll<R>,
78 {
79 let cur = self.cur.get();
81 let can_check = if cur == idx {
82 true
83 } else if cur == u32::MAX {
84 self.cur.set(idx);
85 true
86 } else {
87 false
88 };
89
90 if can_check {
91 let initial_run = !self.running.get();
93 if initial_run {
94 self.running.set(true);
95 }
96
97 let result = f(cx);
98
99 if initial_run {
100 if result.is_pending() {
101 self.get_wakers().push(idx);
102 self.get()[idx as usize] = Some(cx.waker().clone());
103 } else {
104 self.notify();
105 }
106 self.running.set(false);
107 }
108 result
109 } else {
110 self.get_wakers().push(idx);
112 self.get()[idx as usize] = Some(cx.waker().clone());
113 Poll::Pending
114 }
115 }
116
117 pub(crate) fn shutdown(&self) {
118 self.shutdown.set(true);
119 }
120
121 pub(crate) fn is_shutdown(&self) -> bool {
122 self.shutdown.get()
123 }
124}
125
126impl<'a, S> ServiceCtx<'a, S> {
127 pub(crate) fn new(idx: u32, waiters: &'a WaitersRef) -> Self {
128 Self {
129 idx,
130 waiters,
131 _t: marker::PhantomData,
132 }
133 }
134
135 pub(crate) fn inner(self) -> (u32, &'a WaitersRef) {
136 (self.idx, self.waiters)
137 }
138
139 #[inline]
140 pub fn id(&self) -> u32 {
142 self.idx
143 }
144
145 pub async fn ready<T, R>(&self, svc: &'a T) -> Result<(), T::Error>
147 where
148 T: Service<R>,
149 {
150 ReadyCall {
152 completed: false,
153 fut: svc.ready(ServiceCtx {
154 idx: self.idx,
155 waiters: self.waiters,
156 _t: marker::PhantomData,
157 }),
158 ctx: *self,
159 }
160 .await
161 }
162
163 #[inline]
164 pub async fn call<T, R>(&self, svc: &'a T, req: R) -> Result<T::Response, T::Error>
166 where
167 T: Service<R>,
168 R: 'a,
169 {
170 self.ready(svc).await?;
171
172 svc.call(
173 req,
174 ServiceCtx {
175 idx: self.idx,
176 waiters: self.waiters,
177 _t: marker::PhantomData,
178 },
179 )
180 .await
181 }
182
183 #[inline]
184 pub async fn call_nowait<T, R>(
186 &self,
187 svc: &'a T,
188 req: R,
189 ) -> Result<T::Response, T::Error>
190 where
191 T: Service<R>,
192 R: 'a,
193 {
194 svc.call(
195 req,
196 ServiceCtx {
197 idx: self.idx,
198 waiters: self.waiters,
199 _t: marker::PhantomData,
200 },
201 )
202 .await
203 }
204}
205
// `ServiceCtx` only holds a `u32`, a shared reference and a `PhantomData`, so it
// is freely copyable. These impls are hand-written rather than derived because
// `#[derive(Copy, Clone)]` would add `S: Copy` / `S: Clone` bounds that service
// types need not satisfy.
impl<S> Copy for ServiceCtx<'_, S> {}

impl<S> Clone for ServiceCtx<'_, S> {
    /// Copies the context (canonical `Clone` for a `Copy` type).
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}
214
215impl<S> fmt::Debug for ServiceCtx<'_, S> {
216 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
217 f.debug_struct("ServiceCtx")
218 .field("idx", &self.idx)
219 .field("waiters", &self.waiters.get().len())
220 .finish()
221 }
222}
223
/// Future produced by [`ServiceCtx::ready`].
///
/// Wraps a service's readiness future and routes each poll through
/// `WaitersRef::run`, so readiness checks are serialized across pipelines.
struct ReadyCall<'a, S: ?Sized, F: Future> {
    /// Set once `fut` has resolved; read by `Drop` to decide whether to notify.
    completed: bool,
    /// The service's own readiness future.
    fut: F,
    /// Context of the pipeline slot performing the check.
    ctx: ServiceCtx<'a, S>,
}
229
230impl<S: ?Sized, F: Future> Drop for ReadyCall<'_, S, F> {
231 fn drop(&mut self) {
232 if !self.completed && self.ctx.waiters.cur.get() == self.ctx.idx {
233 self.ctx.waiters.notify();
234 }
235 }
236}
237
238impl<S: ?Sized, F: Future> Unpin for ReadyCall<'_, S, F> {}
239
240impl<S: ?Sized, F: Future> Future for ReadyCall<'_, S, F> {
241 type Output = F::Output;
242
243 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
244 self.ctx.waiters.run(self.ctx.idx, cx, |cx| {
245 let result = unsafe { Pin::new_unchecked(&mut self.as_mut().fut).poll(cx) };
247 if result.is_ready() {
248 self.completed = true;
249 }
250 result
251 })
252 }
253}
254
#[cfg(test)]
mod tests {
    use std::{cell::Cell, cell::RefCell, future::poll_fn};

    use ntex_util::channel::{condition, oneshot};
    use ntex_util::{future::lazy, future::select, spawn, time};

    use super::*;
    use crate::Pipeline;

    /// Test service.
    ///
    /// `.0` is incremented each time a `ready` future starts running. `.1` is a
    /// condition waiter that `ready` awaits; the tests release it with
    /// `con.notify()`.
    struct Srv(Rc<Cell<usize>>, condition::Waiter);

    impl Service<&'static str> for Srv {
        type Response = &'static str;
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            self.0.set(self.0.get() + 1);
            self.1.ready().await;
            Ok(())
        }

        async fn call(
            &self,
            req: &'static str,
            ctx: ServiceCtx<'_, Self>,
        ) -> Result<Self::Response, Self::Error> {
            // Exercise the `Debug` and `Clone` impls of `ServiceCtx`.
            let _ = format!("{ctx:?}");
            #[allow(clippy::clone_on_copy)]
            let _ = ctx.clone();
            Ok(req)
        }
    }

    /// Readiness checks are serialized between two clones of one binding.
    #[ntex::test]
    async fn test_ready() {
        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();

        // The assertions below rely on both clones sharing one waiter registry.
        let srv1 = Pipeline::from(Srv(cnt.clone(), con.wait())).bind();
        let srv2 = srv1.clone();

        // srv1 claims the readiness check and starts `Srv::ready`.
        let res = lazy(|cx| srv1.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);
        assert_eq!(cnt.get(), 1);

        // srv2 is parked behind srv1, so `Srv::ready` is not started again.
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);
        assert_eq!(cnt.get(), 1);

        // Releasing the condition completes srv1's in-flight check.
        con.notify();
        let res = lazy(|cx| srv1.poll_ready(cx)).await;
        assert_eq!(res, Poll::Ready(Ok(())));
        assert_eq!(cnt.get(), 1);

        // srv1 finished, so srv2 now owns the check and starts a new one.
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);
        assert_eq!(cnt.get(), 2);

        con.notify();
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Ready(Ok(())));
        assert_eq!(cnt.get(), 2);

        // Ownership moves back to srv1, which starts yet another check.
        let res = lazy(|cx| srv1.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);
        assert_eq!(cnt.get(), 3);
    }

    /// Dropping an unfinished readiness future hands the check to the waiters.
    #[ntex::test]
    async fn test_ready_on_drop() {
        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();
        let srv = Pipeline::from(Srv(cnt.clone(), con.wait()));

        let srv1 = srv.clone();
        let srv2 = srv1.clone().bind();

        // srv1 owns a readiness check until `rx` fires, which drops its
        // unfinished `ready()` future. The long sleep keeps srv1 itself alive
        // for the rest of the test.
        let (tx, rx) = oneshot::channel();
        spawn(async move {
            select(rx, srv1.ready()).await;
            time::sleep(time::Millis(25000)).await;
            drop(srv1);
        });
        time::sleep(time::Millis(250)).await;

        // srv2 is parked while srv1 owns the check.
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);

        // Completing `rx` drops srv1's readiness future, releasing ownership.
        let _ = tx.send(());
        time::sleep(time::Millis(250)).await;

        // srv2 now drives its own check, which waits on the condition.
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);

        con.notify();
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Ready(Ok(())));
    }

    /// Like `test_ready_on_drop`, but srv1 is also shut down after losing the race.
    #[ntex::test]
    async fn test_ready_after_shutdown() {
        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();
        let srv = Pipeline::from(Srv(cnt.clone(), con.wait()));

        let srv1 = srv.clone().bind();
        let srv2 = srv1.clone();

        let (tx, rx) = oneshot::channel();
        spawn(async move {
            select(rx, poll_fn(|cx| srv1.poll_ready(cx))).await;
            poll_fn(|cx| srv1.poll_shutdown(cx)).await;
            time::sleep(time::Millis(25000)).await;
            drop(srv1);
        });
        time::sleep(time::Millis(250)).await;

        // srv2 is parked while srv1 owns the check.
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);

        let _ = tx.send(());
        time::sleep(time::Millis(250)).await;

        // srv2 now drives its own check, which waits on the condition.
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);

        con.notify();
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Ready(Ok(())));
    }

    /// Polling readiness on a binding after it has been shut down is expected to panic.
    #[ntex::test]
    #[should_panic]
    async fn test_pipeline_binding_after_shutdown() {
        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();
        let srv = Pipeline::from(Srv(cnt.clone(), con.wait())).bind();
        let _ = poll_fn(|cx| srv.poll_shutdown(cx)).await;
        let _ = poll_fn(|cx| srv.poll_ready(cx)).await;
    }

    /// Two tasks share one service.
    ///
    /// srv1 checks readiness and then uses `call_nowait`; srv2 uses `call`, which
    /// checks readiness itself. Each condition release lets exactly one of them through.
    #[ntex::test]
    async fn test_shared_call() {
        let data = Rc::new(RefCell::new(Vec::new()));

        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();

        let srv1 = Pipeline::from(Srv(cnt.clone(), con.wait())).bind();
        let srv2 = srv1.clone();
        let _: Pipeline<_> = srv1.pipeline();

        let data1 = data.clone();
        ntex::rt::spawn(async move {
            let _ = poll_fn(|cx| srv1.poll_ready(cx)).await;
            let fut = srv1.call_nowait("srv1");
            assert!(format!("{:?}", fut).contains("PipelineCall"));
            let i = fut.await.unwrap();
            data1.borrow_mut().push(i);
        });

        let data2 = data.clone();
        ntex::rt::spawn(async move {
            let i = srv2.call("srv2").await.unwrap();
            data2.borrow_mut().push(i);
        });
        time::sleep(time::Millis(50)).await;

        // First release: srv1's check completes and its call runs. srv2 then
        // starts its own check (cnt == 2), which stays pending.
        con.notify();
        time::sleep(time::Millis(150)).await;

        assert_eq!(cnt.get(), 2);
        assert_eq!(&*data.borrow(), &["srv1"]);

        // Second release: srv2's check completes without starting another one.
        con.notify();
        time::sleep(time::Millis(150)).await;

        assert_eq!(cnt.get(), 2);
        assert_eq!(&*data.borrow(), &["srv1", "srv2"]);
    }
}
436}