Skip to main content

ntex_service/
ctx.rs

1#![allow(clippy::cast_possible_truncation)]
2use std::task::{Context, Poll, Waker};
3use std::{cell, fmt, future::Future, marker, pin::Pin, rc::Rc};
4
5use crate::Service;
6
7pub struct ServiceCtx<'a, S: ?Sized> {
8    idx: u32,
9    waiters: &'a WaitersRef,
10    _t: marker::PhantomData<Rc<S>>,
11}
12
13#[derive(Debug)]
14pub(crate) struct WaitersRef {
15    running: cell::Cell<bool>,
16    cur: cell::Cell<u32>,
17    shutdown: cell::Cell<bool>,
18    wakers: cell::UnsafeCell<Vec<u32>>,
19    indexes: cell::UnsafeCell<slab::Slab<Option<Waker>>>,
20}
21
22impl WaitersRef {
23    pub(crate) fn new() -> (u32, Self) {
24        let mut waiters = slab::Slab::new();
25
26        (
27            waiters.insert(None) as u32,
28            WaitersRef {
29                running: cell::Cell::new(false),
30                cur: cell::Cell::new(u32::MAX),
31                shutdown: cell::Cell::new(false),
32                indexes: cell::UnsafeCell::new(waiters),
33                wakers: cell::UnsafeCell::new(Vec::default()),
34            },
35        )
36    }
37
38    #[allow(clippy::mut_from_ref)]
39    pub(crate) fn get(&self) -> &mut slab::Slab<Option<Waker>> {
40        unsafe { &mut *self.indexes.get() }
41    }
42
43    #[allow(clippy::mut_from_ref)]
44    pub(crate) fn get_wakers(&self) -> &mut Vec<u32> {
45        unsafe { &mut *self.wakers.get() }
46    }
47
48    pub(crate) fn insert(&self) -> u32 {
49        self.get().insert(None) as u32
50    }
51
52    pub(crate) fn remove(&self, idx: u32) {
53        self.get().remove(idx as usize);
54
55        if self.cur.get() == idx {
56            self.notify();
57        }
58    }
59
60    pub(crate) fn notify(&self) {
61        let wakers = self.get_wakers();
62        if !wakers.is_empty() {
63            let indexes = self.get();
64            for idx in wakers.drain(..) {
65                if let Some(item) = indexes.get_mut(idx as usize)
66                    && let Some(waker) = item.take()
67                {
68                    waker.wake();
69                }
70            }
71        }
72
73        self.cur.set(u32::MAX);
74    }
75
76    pub(crate) fn run<F, R>(&self, idx: u32, cx: &mut Context<'_>, f: F) -> Poll<R>
77    where
78        F: FnOnce(&mut Context<'_>) -> Poll<R>,
79    {
80        // calculate owner for readiness check
81        let cur = self.cur.get();
82        let can_check = if cur == idx {
83            true
84        } else if cur == u32::MAX {
85            self.cur.set(idx);
86            true
87        } else {
88            false
89        };
90
91        if can_check {
92            // only one readiness check can manage waiters
93            let initial_run = !self.running.get();
94            if initial_run {
95                self.running.set(true);
96            }
97
98            let result = f(cx);
99
100            if initial_run {
101                if result.is_pending() {
102                    self.get_wakers().push(idx);
103                    self.get()[idx as usize] = Some(cx.waker().clone());
104                } else {
105                    self.notify();
106                }
107                self.running.set(false);
108            }
109            result
110        } else {
111            // other pipeline ownes readiness check process
112            self.get_wakers().push(idx);
113            self.get()[idx as usize] = Some(cx.waker().clone());
114            Poll::Pending
115        }
116    }
117
118    pub(crate) fn shutdown(&self) {
119        self.shutdown.set(true);
120    }
121
122    pub(crate) fn is_shutdown(&self) -> bool {
123        self.shutdown.get()
124    }
125}
126
127impl<'a, S> ServiceCtx<'a, S> {
128    pub(crate) fn new(idx: u32, waiters: &'a WaitersRef) -> Self {
129        Self {
130            idx,
131            waiters,
132            _t: marker::PhantomData,
133        }
134    }
135
136    pub(crate) fn inner(self) -> (u32, &'a WaitersRef) {
137        (self.idx, self.waiters)
138    }
139
140    #[inline]
141    /// Unique id for this pipeline
142    pub fn id(&self) -> u32 {
143        self.idx
144    }
145
146    /// Returns when the service is able to process requests.
147    pub async fn ready<T, R>(&self, svc: &'a T) -> Result<(), T::Error>
148    where
149        T: Service<R>,
150    {
151        // check readiness and notify waiters
152        ReadyCall {
153            completed: false,
154            fut: svc.ready(ServiceCtx {
155                idx: self.idx,
156                waiters: self.waiters,
157                _t: marker::PhantomData,
158            }),
159            ctx: *self,
160        }
161        .await
162    }
163
164    #[inline]
165    /// Wait for service readiness and then call service
166    pub async fn call<T, R>(&self, svc: &'a T, req: R) -> Result<T::Response, T::Error>
167    where
168        T: Service<R>,
169        R: 'a,
170    {
171        self.ready(svc).await?;
172
173        svc.call(
174            req,
175            ServiceCtx {
176                idx: self.idx,
177                waiters: self.waiters,
178                _t: marker::PhantomData,
179            },
180        )
181        .await
182    }
183
184    #[inline]
185    /// Call service, do not check service readiness
186    pub async fn call_nowait<T, R>(
187        &self,
188        svc: &'a T,
189        req: R,
190    ) -> Result<T::Response, T::Error>
191    where
192        T: Service<R>,
193        R: 'a,
194    {
195        svc.call(
196            req,
197            ServiceCtx {
198                idx: self.idx,
199                waiters: self.waiters,
200                _t: marker::PhantomData,
201            },
202        )
203        .await
204    }
205}
206
207impl<S> Copy for ServiceCtx<'_, S> {}
208
209impl<S> Clone for ServiceCtx<'_, S> {
210    #[inline]
211    fn clone(&self) -> Self {
212        *self
213    }
214}
215
216impl<S> fmt::Debug for ServiceCtx<'_, S> {
217    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218        f.debug_struct("ServiceCtx")
219            .field("idx", &self.idx)
220            .field("waiters", &self.waiters.get().len())
221            .finish()
222    }
223}
224
225struct ReadyCall<'a, S: ?Sized, F: Future> {
226    completed: bool,
227    fut: F,
228    ctx: ServiceCtx<'a, S>,
229}
230
231impl<S: ?Sized, F: Future> Drop for ReadyCall<'_, S, F> {
232    fn drop(&mut self) {
233        if !self.completed && self.ctx.waiters.cur.get() == self.ctx.idx {
234            self.ctx.waiters.notify();
235        }
236    }
237}
238
239impl<S: ?Sized, F: Future> Unpin for ReadyCall<'_, S, F> {}
240
241impl<S: ?Sized, F: Future> Future for ReadyCall<'_, S, F> {
242    type Output = F::Output;
243
244    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
245        self.ctx.waiters.run(self.ctx.idx, cx, |cx| {
246            // SAFETY: `fut` never moves
247            let result = unsafe { Pin::new_unchecked(&mut self.as_mut().fut).poll(cx) };
248            if result.is_ready() {
249                self.completed = true;
250            }
251            result
252        })
253    }
254}
255
#[cfg(test)]
#[allow(clippy::should_panic_without_expect)]
mod tests {
    use std::{cell::Cell, cell::RefCell, future::poll_fn};

    use ntex::channel::{condition, oneshot};
    use ntex::{rt::spawn, time, util::lazy, util::select};

    use super::*;
    use crate::Pipeline;

    /// Test service: counts `ready` invocations and blocks readiness on a condition.
    struct Srv(Rc<Cell<usize>>, condition::Waiter);

    impl Service<&'static str> for Srv {
        type Response = &'static str;
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            let Srv(counter, waiter) = self;
            counter.set(counter.get() + 1);
            waiter.ready().await;
            Ok(())
        }

        async fn call(
            &self,
            req: &'static str,
            ctx: ServiceCtx<'_, Self>,
        ) -> Result<Self::Response, Self::Error> {
            // exercise the Debug, id and Clone impls of ServiceCtx
            let _ = format!("{ctx:?}");
            let _ = format!("{:?}", ctx.id());
            #[allow(clippy::clone_on_copy)]
            let _ = ctx.clone();
            Ok(req)
        }
    }

    #[ntex::test]
    async fn test_ready() {
        let calls = Rc::new(Cell::new(0));
        let cond = condition::Condition::new();

        let first = Pipeline::from(Srv(calls.clone(), cond.wait())).bind();
        let second = first.clone();

        // `first` takes ownership of the readiness check
        assert_eq!(lazy(|cx| first.poll_ready(cx)).await, Poll::Pending);
        assert_eq!(calls.get(), 1);

        // `second` only parks; the service is not polled again
        assert_eq!(lazy(|cx| second.poll_ready(cx)).await, Poll::Pending);
        assert_eq!(calls.get(), 1);

        cond.notify();
        assert_eq!(lazy(|cx| first.poll_ready(cx)).await, Poll::Ready(Ok(())));
        assert_eq!(calls.get(), 1);

        assert_eq!(lazy(|cx| second.poll_ready(cx)).await, Poll::Pending);
        assert_eq!(calls.get(), 2);

        cond.notify();
        assert_eq!(lazy(|cx| second.poll_ready(cx)).await, Poll::Ready(Ok(())));
        assert_eq!(calls.get(), 2);

        assert_eq!(lazy(|cx| first.poll_ready(cx)).await, Poll::Pending);
        assert_eq!(calls.get(), 3);
    }

    #[ntex::test]
    async fn test_ready_on_drop() {
        let calls = Rc::new(Cell::new(0));
        let cond = condition::Condition::new();
        let pipeline = Pipeline::from(Srv(calls.clone(), cond.wait()));

        let background = pipeline.clone();
        let bound = background.clone().bind();

        let (stop_tx, stop_rx) = oneshot::channel();
        spawn(async move {
            // own the readiness check until told to stop, then drop it
            select(stop_rx, background.ready()).await;
            time::sleep(time::Millis(25000)).await;
        });
        time::sleep(time::Millis(250)).await;

        assert_eq!(lazy(|cx| bound.poll_ready(cx)).await, Poll::Pending);

        let _ = stop_tx.send(());
        time::sleep(time::Millis(250)).await;

        assert_eq!(lazy(|cx| bound.poll_ready(cx)).await, Poll::Pending);

        cond.notify();
        assert_eq!(lazy(|cx| bound.poll_ready(cx)).await, Poll::Ready(Ok(())));
    }

    #[ntex::test]
    async fn test_ready_after_shutdown() {
        let calls = Rc::new(Cell::new(0));
        let cond = condition::Condition::new();
        let pipeline = Pipeline::from(Srv(calls.clone(), cond.wait()));

        let background = pipeline.clone().bind();
        let bound = background.clone();

        let (stop_tx, stop_rx) = oneshot::channel();
        spawn(async move {
            select(stop_rx, poll_fn(|cx| background.poll_ready(cx))).await;
            poll_fn(|cx| background.poll_shutdown(cx)).await;
            time::sleep(time::Millis(25000)).await;
        });
        time::sleep(time::Millis(250)).await;

        assert_eq!(lazy(|cx| bound.poll_ready(cx)).await, Poll::Pending);

        let _ = stop_tx.send(());
        time::sleep(time::Millis(250)).await;

        assert_eq!(lazy(|cx| bound.poll_ready(cx)).await, Poll::Pending);

        cond.notify();
        assert_eq!(lazy(|cx| bound.poll_ready(cx)).await, Poll::Ready(Ok(())));
    }

    #[ntex::test]
    #[should_panic]
    async fn test_pipeline_binding_after_shutdown() {
        let calls = Rc::new(Cell::new(0));
        let cond = condition::Condition::new();
        let bound = Pipeline::from(Srv(calls.clone(), cond.wait())).bind();
        poll_fn(|cx| bound.poll_shutdown(cx)).await;
        // polling readiness after shutdown is expected to panic
        let _ = poll_fn(|cx| bound.poll_ready(cx)).await;
    }

    #[ntex::test]
    async fn test_shared_call() {
        let data = Rc::new(RefCell::new(Vec::new()));

        let calls = Rc::new(Cell::new(0));
        let cond = condition::Condition::new();

        let first = Pipeline::from(Srv(calls.clone(), cond.wait())).bind();
        let second = first.clone();
        let _: Pipeline<_> = first.pipeline();

        let sink = data.clone();
        spawn(async move {
            let _ = poll_fn(|cx| first.poll_ready(cx)).await;
            let fut = first.call_nowait("srv1");
            assert!(format!("{fut:?}").contains("PipelineCall"));
            let resp = fut.await.unwrap();
            sink.borrow_mut().push(resp);
        });

        let sink = data.clone();
        spawn(async move {
            let resp = second.call("srv2").await.unwrap();
            sink.borrow_mut().push(resp);
        });
        time::sleep(time::Millis(50)).await;

        cond.notify();
        time::sleep(time::Millis(150)).await;

        assert_eq!(calls.get(), 2);
        assert_eq!(&*data.borrow(), &["srv1"]);

        cond.notify();
        time::sleep(time::Millis(150)).await;

        assert_eq!(calls.get(), 2);
        assert_eq!(&*data.borrow(), &["srv1", "srv2"]);
    }
}