1use std::task::{Context, Poll, Waker};
2use std::{cell, fmt, future::Future, marker, pin::Pin, rc::Rc};
3
4use crate::Service;
5
6pub struct ServiceCtx<'a, S: ?Sized> {
7 idx: u32,
8 waiters: &'a WaitersRef,
9 _t: marker::PhantomData<Rc<S>>,
10}
11
12#[derive(Debug)]
13pub(crate) struct WaitersRef {
14 cur: cell::Cell<u32>,
15 shutdown: cell::Cell<bool>,
16 wakers: cell::UnsafeCell<Vec<u32>>,
17 indexes: cell::UnsafeCell<slab::Slab<Option<Waker>>>,
18}
19
20impl WaitersRef {
21 pub(crate) fn new() -> (u32, Self) {
22 let mut waiters = slab::Slab::new();
23
24 (
25 waiters.insert(Default::default()) as u32,
26 WaitersRef {
27 cur: cell::Cell::new(u32::MAX),
28 shutdown: cell::Cell::new(false),
29 indexes: cell::UnsafeCell::new(waiters),
30 wakers: cell::UnsafeCell::new(Vec::default()),
31 },
32 )
33 }
34
35 #[allow(clippy::mut_from_ref)]
36 pub(crate) fn get(&self) -> &mut slab::Slab<Option<Waker>> {
37 unsafe { &mut *self.indexes.get() }
38 }
39
40 #[allow(clippy::mut_from_ref)]
41 pub(crate) fn get_wakers(&self) -> &mut Vec<u32> {
42 unsafe { &mut *self.wakers.get() }
43 }
44
45 pub(crate) fn insert(&self) -> u32 {
46 self.get().insert(None) as u32
47 }
48
49 pub(crate) fn remove(&self, idx: u32) {
50 self.get().remove(idx as usize);
51
52 if self.cur.get() == idx {
53 self.notify();
54 }
55 }
56
57 pub(crate) fn register(&self, idx: u32, cx: &mut Context<'_>) {
58 let wakers = self.get_wakers();
59 if let Some(last) = wakers.last() {
60 if idx == *last {
61 return;
62 }
63 }
64 wakers.push(idx);
65 self.get()[idx as usize] = Some(cx.waker().clone());
66 }
67
68 pub(crate) fn notify(&self) {
69 let wakers = self.get_wakers();
70 if !wakers.is_empty() {
71 let indexes = self.get();
72 for idx in wakers.drain(..) {
73 if let Some(item) = indexes.get_mut(idx as usize) {
74 if let Some(waker) = item.take() {
75 waker.wake();
76 }
77 }
78 }
79 }
80
81 self.cur.set(u32::MAX);
82 }
83
84 pub(crate) fn can_check(&self, idx: u32, cx: &mut Context<'_>) -> bool {
85 let cur = self.cur.get();
86 if cur == idx {
87 true
88 } else if cur == u32::MAX {
89 self.cur.set(idx);
90 true
91 } else {
92 self.register(idx, cx);
93 false
94 }
95 }
96
97 pub(crate) fn shutdown(&self) {
98 self.shutdown.set(true);
99 }
100
101 pub(crate) fn is_shutdown(&self) -> bool {
102 self.shutdown.get()
103 }
104}
105
106impl<'a, S> ServiceCtx<'a, S> {
107 pub(crate) fn new(idx: u32, waiters: &'a WaitersRef) -> Self {
108 Self {
109 idx,
110 waiters,
111 _t: marker::PhantomData,
112 }
113 }
114
115 pub(crate) fn inner(self) -> (u32, &'a WaitersRef) {
116 (self.idx, self.waiters)
117 }
118
119 pub async fn ready<T, R>(&self, svc: &'a T) -> Result<(), T::Error>
121 where
122 T: Service<R>,
123 {
124 ReadyCall {
126 completed: false,
127 fut: svc.ready(ServiceCtx {
128 idx: self.idx,
129 waiters: self.waiters,
130 _t: marker::PhantomData,
131 }),
132 ctx: *self,
133 }
134 .await
135 }
136
137 #[inline]
138 pub async fn call<T, R>(&self, svc: &'a T, req: R) -> Result<T::Response, T::Error>
140 where
141 T: Service<R>,
142 R: 'a,
143 {
144 self.ready(svc).await?;
145
146 svc.call(
147 req,
148 ServiceCtx {
149 idx: self.idx,
150 waiters: self.waiters,
151 _t: marker::PhantomData,
152 },
153 )
154 .await
155 }
156
157 #[inline]
158 pub async fn call_nowait<T, R>(
160 &self,
161 svc: &'a T,
162 req: R,
163 ) -> Result<T::Response, T::Error>
164 where
165 T: Service<R>,
166 R: 'a,
167 {
168 svc.call(
169 req,
170 ServiceCtx {
171 idx: self.idx,
172 waiters: self.waiters,
173 _t: marker::PhantomData,
174 },
175 )
176 .await
177 }
178}
179
180impl<S> Copy for ServiceCtx<'_, S> {}
181
182impl<S> Clone for ServiceCtx<'_, S> {
183 #[inline]
184 fn clone(&self) -> Self {
185 *self
186 }
187}
188
189impl<S> fmt::Debug for ServiceCtx<'_, S> {
190 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191 f.debug_struct("ServiceCtx")
192 .field("idx", &self.idx)
193 .field("waiters", &self.waiters.get().len())
194 .finish()
195 }
196}
197
198struct ReadyCall<'a, S: ?Sized, F: Future> {
199 completed: bool,
200 fut: F,
201 ctx: ServiceCtx<'a, S>,
202}
203
204impl<S: ?Sized, F: Future> Drop for ReadyCall<'_, S, F> {
205 fn drop(&mut self) {
206 if !self.completed && self.ctx.waiters.cur.get() == self.ctx.idx {
207 self.ctx.waiters.notify();
208 }
209 }
210}
211
212impl<S: ?Sized, F: Future> Unpin for ReadyCall<'_, S, F> {}
213
214impl<S: ?Sized, F: Future> Future for ReadyCall<'_, S, F> {
215 type Output = F::Output;
216
217 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
218 if self.ctx.waiters.can_check(self.ctx.idx, cx) {
219 let result = unsafe { Pin::new_unchecked(&mut self.as_mut().fut).poll(cx) };
221 match result {
222 Poll::Pending => {
223 self.ctx.waiters.register(self.ctx.idx, cx);
224 Poll::Pending
225 }
226 Poll::Ready(res) => {
227 self.completed = true;
228 self.ctx.waiters.notify();
229 Poll::Ready(res)
230 }
231 }
232 } else {
233 Poll::Pending
234 }
235 }
236}
237
#[cfg(test)]
mod tests {
    use std::{cell::Cell, cell::RefCell, future::poll_fn};

    use ntex_util::channel::{condition, oneshot};
    use ntex_util::{future::lazy, future::select, spawn, time};

    use super::*;
    use crate::Pipeline;

    /// Test service.
    ///
    /// `.0` counts how many times `ready()` has been entered. `.1` is a
    /// condition waiter that `ready()` blocks on until the condition is notified.
    struct Srv(Rc<Cell<usize>>, condition::Waiter);

    impl Service<&'static str> for Srv {
        type Response = &'static str;
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            self.0.set(self.0.get() + 1);
            self.1.ready().await;
            Ok(())
        }

        /// Echoes the request back. Also exercises `ServiceCtx`'s `Debug` and
        /// `Clone` impls.
        async fn call(
            &self,
            req: &'static str,
            ctx: ServiceCtx<'_, Self>,
        ) -> Result<Self::Response, Self::Error> {
            let _ = format!("{:?}", ctx);
            // Explicit `clone()` (instead of a copy) to cover the `Clone` impl.
            #[allow(clippy::clone_on_copy)]
            let _ = ctx.clone();
            Ok(req)
        }
    }

    /// Two bindings of one pipeline take turns: while one is checking
    /// readiness, the other does not enter `Srv::ready`.
    #[ntex::test]
    async fn test_ready() {
        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();

        let srv1 = Pipeline::from(Srv(cnt.clone(), con.wait())).bind();
        let srv2 = srv1.clone();

        // srv1 starts a readiness check and enters `Srv::ready`.
        let res = lazy(|cx| srv1.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);
        assert_eq!(cnt.get(), 1);

        // srv2 waits: `Srv::ready` is not entered a second time.
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);
        assert_eq!(cnt.get(), 1);

        // After the condition fires, srv1's in-flight check completes
        // without re-entering `Srv::ready`.
        con.notify();
        let res = lazy(|cx| srv1.poll_ready(cx)).await;
        assert_eq!(res, Poll::Ready(Ok(())));
        assert_eq!(cnt.get(), 1);

        // Now srv2 gets its turn and starts its own check.
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);
        assert_eq!(cnt.get(), 2);

        con.notify();
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Ready(Ok(())));
        assert_eq!(cnt.get(), 2);

        // srv1 has to check again after srv2 finished.
        let res = lazy(|cx| srv1.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);
        assert_eq!(cnt.get(), 3);
    }

    /// Dropping an in-flight `ready()` future must let the next waiter proceed.
    #[ntex::test]
    async fn test_ready_on_drop() {
        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();
        let srv = Pipeline::from(Srv(cnt.clone(), con.wait()));

        let srv1 = srv.clone();
        let srv2 = srv1.clone().bind();

        let (tx, rx) = oneshot::channel();
        spawn(async move {
            // `select` drops `srv1.ready()` as soon as `rx` resolves.
            select(rx, srv1.ready()).await;
            // Keeps srv1 alive beyond the test, so only the dropped future
            // (not dropping the pipeline) can let srv2 proceed.
            time::sleep(time::Millis(25000)).await;
            drop(srv1);
        });
        time::sleep(time::Millis(250)).await;

        // srv1's check is in flight, so srv2 waits.
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);

        // Cancel srv1's `ready()` future.
        let _ = tx.send(());
        time::sleep(time::Millis(250)).await;

        // srv2 now runs its own check, blocked on the condition.
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);

        con.notify();
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Ready(Ok(())));
    }

    /// After one binding abandons its readiness poll and is shut down, another
    /// binding can still become ready.
    #[ntex::test]
    async fn test_ready_after_shutdown() {
        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();
        let srv = Pipeline::from(Srv(cnt.clone(), con.wait()));

        let srv1 = srv.clone().bind();
        let srv2 = srv1.clone();

        let (tx, rx) = oneshot::channel();
        spawn(async move {
            select(rx, poll_fn(|cx| srv1.poll_ready(cx))).await;
            poll_fn(|cx| srv1.poll_shutdown(cx)).await;
            // Keeps srv1 alive for the rest of the test.
            time::sleep(time::Millis(25000)).await;
            drop(srv1);
        });
        time::sleep(time::Millis(250)).await;

        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);

        // Stop srv1's readiness poll and let it shut down.
        let _ = tx.send(());
        time::sleep(time::Millis(250)).await;

        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);

        con.notify();
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Ready(Ok(())));
    }

    /// `call_nowait` after an explicit readiness poll and `call` (which waits
    /// for readiness itself) run one after another on a shared pipeline.
    #[ntex::test]
    async fn test_shared_call() {
        let data = Rc::new(RefCell::new(Vec::new()));

        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();

        let srv1 = Pipeline::from(Srv(cnt.clone(), con.wait())).bind();
        let srv2 = srv1.clone();

        let data1 = data.clone();
        ntex::rt::spawn(async move {
            let _ = poll_fn(|cx| srv1.poll_ready(cx)).await;
            let i = srv1.call_nowait("srv1").await.unwrap();
            data1.borrow_mut().push(i);
        });

        let data2 = data.clone();
        ntex::rt::spawn(async move {
            let i = srv2.call("srv2").await.unwrap();
            data2.borrow_mut().push(i);
        });
        time::sleep(time::Millis(50)).await;

        // First notification: srv1 completes. srv2's check has been entered
        // (cnt == 2) but is still blocked on the condition.
        con.notify();
        time::sleep(time::Millis(150)).await;

        assert_eq!(cnt.get(), 2);
        assert_eq!(&*data.borrow(), &["srv1"]);

        // Second notification releases srv2's check; no new check starts.
        con.notify();
        time::sleep(time::Millis(150)).await;

        assert_eq!(cnt.get(), 2);
        assert_eq!(&*data.borrow(), &["srv1", "srv2"]);
    }
}