ntex_service/
ctx.rs

1use std::task::{Context, Poll, Waker};
2use std::{cell, fmt, future::Future, marker, pin::Pin, rc::Rc};
3
4use crate::Service;
5
/// Context value passed to a service's `ready` and `call` methods.
///
/// It carries the index of the pipeline on whose behalf the service is being
/// driven, plus a reference to the waiter bookkeeping used to coordinate
/// readiness checks.
pub struct ServiceCtx<'a, S: ?Sized> {
    /// Slot index of the pipeline inside the waiter state (see `id()`).
    idx: u32,
    /// Readiness bookkeeping consulted by `ready`.
    waiters: &'a WaitersRef,
    /// Associates the context with service type `S` without storing a value.
    _t: marker::PhantomData<Rc<S>>,
}
11
/// Bookkeeping used to let only one pipeline at a time drive a service's
/// readiness check while the others wait to be woken.
#[derive(Debug)]
pub(crate) struct WaitersRef {
    /// `true` while an outermost `run` call is executing its closure.
    running: cell::Cell<bool>,
    /// Index of the pipeline that currently owns the readiness check;
    /// `u32::MAX` means nobody owns it.
    cur: cell::Cell<u32>,
    /// Set by `shutdown()` and reported by `is_shutdown()`.
    shutdown: cell::Cell<bool>,
    /// Indexes whose stored waker should be woken by the next `notify()`.
    wakers: cell::UnsafeCell<Vec<u32>>,
    /// One slot per registered index, holding that index's waker (if any).
    indexes: cell::UnsafeCell<slab::Slab<Option<Waker>>>,
}
20
21impl WaitersRef {
22    pub(crate) fn new() -> (u32, Self) {
23        let mut waiters = slab::Slab::new();
24
25        (
26            waiters.insert(Default::default()) as u32,
27            WaitersRef {
28                running: cell::Cell::new(false),
29                cur: cell::Cell::new(u32::MAX),
30                shutdown: cell::Cell::new(false),
31                indexes: cell::UnsafeCell::new(waiters),
32                wakers: cell::UnsafeCell::new(Vec::default()),
33            },
34        )
35    }
36
37    #[allow(clippy::mut_from_ref)]
38    pub(crate) fn get(&self) -> &mut slab::Slab<Option<Waker>> {
39        unsafe { &mut *self.indexes.get() }
40    }
41
42    #[allow(clippy::mut_from_ref)]
43    pub(crate) fn get_wakers(&self) -> &mut Vec<u32> {
44        unsafe { &mut *self.wakers.get() }
45    }
46
47    pub(crate) fn insert(&self) -> u32 {
48        self.get().insert(None) as u32
49    }
50
51    pub(crate) fn remove(&self, idx: u32) {
52        self.get().remove(idx as usize);
53
54        if self.cur.get() == idx {
55            self.notify();
56        }
57    }
58
59    pub(crate) fn notify(&self) {
60        let wakers = self.get_wakers();
61        if !wakers.is_empty() {
62            let indexes = self.get();
63            for idx in wakers.drain(..) {
64                if let Some(item) = indexes.get_mut(idx as usize) {
65                    if let Some(waker) = item.take() {
66                        waker.wake();
67                    }
68                }
69            }
70        }
71
72        self.cur.set(u32::MAX);
73    }
74
75    pub(crate) fn run<F, R>(&self, idx: u32, cx: &mut Context<'_>, f: F) -> Poll<R>
76    where
77        F: FnOnce(&mut Context<'_>) -> Poll<R>,
78    {
79        // calculate owner for readiness check
80        let cur = self.cur.get();
81        let can_check = if cur == idx {
82            true
83        } else if cur == u32::MAX {
84            self.cur.set(idx);
85            true
86        } else {
87            false
88        };
89
90        if can_check {
91            // only one readiness check can manage waiters
92            let initial_run = !self.running.get();
93            if initial_run {
94                self.running.set(true);
95            }
96
97            let result = f(cx);
98
99            if initial_run {
100                if result.is_pending() {
101                    self.get_wakers().push(idx);
102                    self.get()[idx as usize] = Some(cx.waker().clone());
103                } else {
104                    self.notify();
105                }
106                self.running.set(false);
107            }
108            result
109        } else {
110            // other pipeline ownes readiness check process
111            self.get_wakers().push(idx);
112            self.get()[idx as usize] = Some(cx.waker().clone());
113            Poll::Pending
114        }
115    }
116
117    pub(crate) fn shutdown(&self) {
118        self.shutdown.set(true);
119    }
120
121    pub(crate) fn is_shutdown(&self) -> bool {
122        self.shutdown.get()
123    }
124}
125
126impl<'a, S> ServiceCtx<'a, S> {
127    pub(crate) fn new(idx: u32, waiters: &'a WaitersRef) -> Self {
128        Self {
129            idx,
130            waiters,
131            _t: marker::PhantomData,
132        }
133    }
134
135    pub(crate) fn inner(self) -> (u32, &'a WaitersRef) {
136        (self.idx, self.waiters)
137    }
138
139    #[inline]
140    /// Unique id for this pipeline
141    pub fn id(&self) -> u32 {
142        self.idx
143    }
144
145    /// Returns when the service is able to process requests.
146    pub async fn ready<T, R>(&self, svc: &'a T) -> Result<(), T::Error>
147    where
148        T: Service<R>,
149    {
150        // check readiness and notify waiters
151        ReadyCall {
152            completed: false,
153            fut: svc.ready(ServiceCtx {
154                idx: self.idx,
155                waiters: self.waiters,
156                _t: marker::PhantomData,
157            }),
158            ctx: *self,
159        }
160        .await
161    }
162
163    #[inline]
164    /// Wait for service readiness and then call service
165    pub async fn call<T, R>(&self, svc: &'a T, req: R) -> Result<T::Response, T::Error>
166    where
167        T: Service<R>,
168        R: 'a,
169    {
170        self.ready(svc).await?;
171
172        svc.call(
173            req,
174            ServiceCtx {
175                idx: self.idx,
176                waiters: self.waiters,
177                _t: marker::PhantomData,
178            },
179        )
180        .await
181    }
182
183    #[inline]
184    /// Call service, do not check service readiness
185    pub async fn call_nowait<T, R>(
186        &self,
187        svc: &'a T,
188        req: R,
189    ) -> Result<T::Response, T::Error>
190    where
191        T: Service<R>,
192        R: 'a,
193    {
194        svc.call(
195            req,
196            ServiceCtx {
197                idx: self.idx,
198                waiters: self.waiters,
199                _t: marker::PhantomData,
200            },
201        )
202        .await
203    }
204}
205
// Written by hand instead of `#[derive(Copy)]`: the derive would require
// `S: Copy`, while the struct only holds an index, a shared reference and a
// `PhantomData` marker.
impl<S> Copy for ServiceCtx<'_, S> {}
207
// Hand-written so that no `S: Clone` bound is required.
impl<S> Clone for ServiceCtx<'_, S> {
    /// Returns a copy of the context; the type is `Copy`, so this is `*self`.
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}
214
215impl<S> fmt::Debug for ServiceCtx<'_, S> {
216    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
217        f.debug_struct("ServiceCtx")
218            .field("idx", &self.idx)
219            .field("waiters", &self.waiters.get().len())
220            .finish()
221    }
222}
223
/// Future awaited by `ServiceCtx::ready`: polls the service's readiness
/// future through `WaitersRef::run`.
struct ReadyCall<'a, S: ?Sized, F: Future> {
    /// Set once `fut` has resolved; when set, `Drop` does not notify waiters.
    completed: bool,
    /// The readiness future produced by the service.
    fut: F,
    /// Context identifying the index that performs the check.
    ctx: ServiceCtx<'a, S>,
}
229
230impl<S: ?Sized, F: Future> Drop for ReadyCall<'_, S, F> {
231    fn drop(&mut self) {
232        if !self.completed && self.ctx.waiters.cur.get() == self.ctx.idx {
233            self.ctx.waiters.notify();
234        }
235    }
236}
237
// `ReadyCall` is declared `Unpin` regardless of `F`.
// NOTE(review): `poll` pins `fut` with `Pin::new_unchecked`, so this relies on
// a `ReadyCall` never being moved between polls — confirm every use site
// upholds that.
impl<S: ?Sized, F: Future> Unpin for ReadyCall<'_, S, F> {}
239
240impl<S: ?Sized, F: Future> Future for ReadyCall<'_, S, F> {
241    type Output = F::Output;
242
243    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
244        self.ctx.waiters.run(self.ctx.idx, cx, |cx| {
245            // SAFETY: `fut` never moves
246            let result = unsafe { Pin::new_unchecked(&mut self.as_mut().fut).poll(cx) };
247            if result.is_ready() {
248                self.completed = true;
249            }
250            result
251        })
252    }
253}
254
#[cfg(test)]
mod tests {
    use std::{cell::Cell, cell::RefCell, future::poll_fn};

    use ntex_util::channel::{condition, oneshot};
    use ntex_util::{future::lazy, future::select, spawn, time};

    use super::*;
    use crate::Pipeline;

    /// Test service: field 0 counts how many times `ready` was entered,
    /// field 1 is the condition waiter that `ready` awaits before succeeding.
    struct Srv(Rc<Cell<usize>>, condition::Waiter);

    impl Service<&'static str> for Srv {
        type Response = &'static str;
        type Error = ();

        // Bumps the counter, then waits for the condition.
        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            self.0.set(self.0.get() + 1);
            self.1.ready().await;
            Ok(())
        }

        // Echoes the request; also exercises `Debug` and `Clone` of the context.
        async fn call(
            &self,
            req: &'static str,
            ctx: ServiceCtx<'_, Self>,
        ) -> Result<Self::Response, Self::Error> {
            let _ = format!("{ctx:?}");
            #[allow(clippy::clone_on_copy)]
            let _ = ctx.clone();
            Ok(req)
        }
    }

    /// Two bound clones of one pipeline take turns driving the readiness check.
    #[ntex::test]
    async fn test_ready() {
        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();

        let srv1 = Pipeline::from(Srv(cnt.clone(), con.wait())).bind();
        let srv2 = srv1.clone();

        // srv1 starts the check: `Srv::ready` is entered once and stays pending.
        let res = lazy(|cx| srv1.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);
        assert_eq!(cnt.get(), 1);

        // srv2 is pending too, but `Srv::ready` was not entered again.
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);
        assert_eq!(cnt.get(), 1);

        // Once the condition fires, srv1's check completes.
        con.notify();
        let res = lazy(|cx| srv1.poll_ready(cx)).await;
        assert_eq!(res, Poll::Ready(Ok(())));
        assert_eq!(cnt.get(), 1);

        // srv2 now runs a check of its own (counter goes to 2) and waits again.
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);
        assert_eq!(cnt.get(), 2);

        con.notify();
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Ready(Ok(())));
        assert_eq!(cnt.get(), 2);

        // A further poll of srv1 enters `Srv::ready` a third time.
        let res = lazy(|cx| srv1.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);
        assert_eq!(cnt.get(), 3);
    }

    /// A pending `ready()` future that gets dropped must not block the other
    /// pipeline from eventually becoming ready.
    #[ntex::test]
    async fn test_ready_on_drop() {
        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();
        let srv = Pipeline::from(Srv(cnt.clone(), con.wait()));

        let srv1 = srv.clone();
        let srv2 = srv1.clone().bind();

        let (tx, rx) = oneshot::channel();
        spawn(async move {
            // `select` returns when `rx` fires, which drops the `ready()` future.
            select(rx, srv1.ready()).await;
            time::sleep(time::Millis(25000)).await;
            drop(srv1);
        });
        // Give the spawned task time to start its readiness check.
        time::sleep(time::Millis(250)).await;

        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);

        // Cancel the spawned task's readiness check.
        let _ = tx.send(());
        time::sleep(time::Millis(250)).await;

        // The condition has not fired yet, so srv2 is still pending.
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);

        con.notify();
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Ready(Ok(())));
    }

    /// Same scenario as `test_ready_on_drop`, but the spawned task also polls
    /// shutdown on its binding after abandoning the readiness check.
    #[ntex::test]
    async fn test_ready_after_shutdown() {
        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();
        let srv = Pipeline::from(Srv(cnt.clone(), con.wait()));

        let srv1 = srv.clone().bind();
        let srv2 = srv1.clone();

        let (tx, rx) = oneshot::channel();
        spawn(async move {
            select(rx, poll_fn(|cx| srv1.poll_ready(cx))).await;
            poll_fn(|cx| srv1.poll_shutdown(cx)).await;
            time::sleep(time::Millis(25000)).await;
            drop(srv1);
        });
        // Give the spawned task time to start its readiness check.
        time::sleep(time::Millis(250)).await;

        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);

        // Let the spawned task leave `select` and run shutdown.
        let _ = tx.send(());
        time::sleep(time::Millis(250)).await;

        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);

        con.notify();
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Ready(Ok(())));
    }

    /// Polling readiness on a binding after it was shut down is expected to
    /// panic.
    #[ntex::test]
    #[should_panic]
    async fn test_pipeline_binding_after_shutdown() {
        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();
        let srv = Pipeline::from(Srv(cnt.clone(), con.wait())).bind();
        let _ = poll_fn(|cx| srv.poll_shutdown(cx)).await;
        let _ = poll_fn(|cx| srv.poll_ready(cx)).await;
    }

    /// Two tasks call the same service through separate bindings; their calls
    /// complete one after another as the condition is notified.
    #[ntex::test]
    async fn test_shared_call() {
        let data = Rc::new(RefCell::new(Vec::new()));

        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();

        let srv1 = Pipeline::from(Srv(cnt.clone(), con.wait())).bind();
        let srv2 = srv1.clone();
        let _: Pipeline<_> = srv1.pipeline();

        // Task 1: wait for readiness explicitly, then call without re-checking.
        let data1 = data.clone();
        ntex::rt::spawn(async move {
            let _ = poll_fn(|cx| srv1.poll_ready(cx)).await;
            let fut = srv1.call_nowait("srv1");
            assert!(format!("{:?}", fut).contains("PipelineCall"));
            let i = fut.await.unwrap();
            data1.borrow_mut().push(i);
        });

        // Task 2: `call` performs the readiness check itself.
        let data2 = data.clone();
        ntex::rt::spawn(async move {
            let i = srv2.call("srv2").await.unwrap();
            data2.borrow_mut().push(i);
        });
        time::sleep(time::Millis(50)).await;

        // First notification: only the srv1 call has completed so far.
        con.notify();
        time::sleep(time::Millis(150)).await;

        assert_eq!(cnt.get(), 2);
        assert_eq!(&*data.borrow(), &["srv1"]);

        // Second notification: the srv2 call completes as well.
        con.notify();
        time::sleep(time::Millis(150)).await;

        assert_eq!(cnt.get(), 2);
        assert_eq!(&*data.borrow(), &["srv1", "srv2"]);
    }
}