// ntex_service/ctx.rs

1use std::task::{Context, Poll, Waker};
2use std::{cell, fmt, future::Future, marker, pin::Pin, rc::Rc};
3
4use crate::Service;
5
/// Context passed to a service's `ready` and `call` methods.
///
/// It pairs the caller's waiter index with a borrowed reference to the shared
/// [`WaitersRef`] registry. `S` is only a type-level tag; holding it through
/// `PhantomData<Rc<S>>` also makes the context neither `Send` nor `Sync`.
pub struct ServiceCtx<'a, S: ?Sized> {
    /// Index of this caller's slot in the waiters slab.
    idx: u32,
    /// Shared readiness-coordination state.
    waiters: &'a WaitersRef,
    _t: marker::PhantomData<Rc<S>>,
}
11
/// Shared state used to let only one caller at a time check a service's
/// readiness, and to wake the callers that had to wait.
#[derive(Debug)]
pub(crate) struct WaitersRef {
    /// Index of the caller currently allowed to check readiness, or
    /// `u32::MAX` when no caller holds that slot.
    cur: cell::Cell<u32>,
    /// Set by `shutdown()`, read by `is_shutdown()`.
    shutdown: cell::Cell<bool>,
    /// Slab indexes whose wakers are to be woken by the next `notify()`.
    wakers: cell::UnsafeCell<Vec<u32>>,
    /// One slot per caller, holding the waker that caller last registered (if any).
    indexes: cell::UnsafeCell<slab::Slab<Option<Waker>>>,
}
19
20impl WaitersRef {
21    pub(crate) fn new() -> (u32, Self) {
22        let mut waiters = slab::Slab::new();
23
24        (
25            waiters.insert(Default::default()) as u32,
26            WaitersRef {
27                cur: cell::Cell::new(u32::MAX),
28                shutdown: cell::Cell::new(false),
29                indexes: cell::UnsafeCell::new(waiters),
30                wakers: cell::UnsafeCell::new(Vec::default()),
31            },
32        )
33    }
34
35    #[allow(clippy::mut_from_ref)]
36    pub(crate) fn get(&self) -> &mut slab::Slab<Option<Waker>> {
37        unsafe { &mut *self.indexes.get() }
38    }
39
40    #[allow(clippy::mut_from_ref)]
41    pub(crate) fn get_wakers(&self) -> &mut Vec<u32> {
42        unsafe { &mut *self.wakers.get() }
43    }
44
45    pub(crate) fn insert(&self) -> u32 {
46        self.get().insert(None) as u32
47    }
48
49    pub(crate) fn remove(&self, idx: u32) {
50        self.get().remove(idx as usize);
51
52        if self.cur.get() == idx {
53            self.notify();
54        }
55    }
56
57    pub(crate) fn register(&self, idx: u32, cx: &mut Context<'_>) {
58        let wakers = self.get_wakers();
59        if let Some(last) = wakers.last() {
60            if idx == *last {
61                return;
62            }
63        }
64        wakers.push(idx);
65        self.get()[idx as usize] = Some(cx.waker().clone());
66    }
67
68    pub(crate) fn notify(&self) {
69        let wakers = self.get_wakers();
70        if !wakers.is_empty() {
71            let indexes = self.get();
72            for idx in wakers.drain(..) {
73                if let Some(item) = indexes.get_mut(idx as usize) {
74                    if let Some(waker) = item.take() {
75                        waker.wake();
76                    }
77                }
78            }
79        }
80
81        self.cur.set(u32::MAX);
82    }
83
84    pub(crate) fn can_check(&self, idx: u32, cx: &mut Context<'_>) -> bool {
85        let cur = self.cur.get();
86        if cur == idx {
87            true
88        } else if cur == u32::MAX {
89            self.cur.set(idx);
90            true
91        } else {
92            self.register(idx, cx);
93            false
94        }
95    }
96
97    pub(crate) fn shutdown(&self) {
98        self.shutdown.set(true);
99    }
100
101    pub(crate) fn is_shutdown(&self) -> bool {
102        self.shutdown.get()
103    }
104}
105
106impl<'a, S> ServiceCtx<'a, S> {
107    pub(crate) fn new(idx: u32, waiters: &'a WaitersRef) -> Self {
108        Self {
109            idx,
110            waiters,
111            _t: marker::PhantomData,
112        }
113    }
114
115    pub(crate) fn inner(self) -> (u32, &'a WaitersRef) {
116        (self.idx, self.waiters)
117    }
118
119    /// Returns when the service is able to process requests.
120    pub async fn ready<T, R>(&self, svc: &'a T) -> Result<(), T::Error>
121    where
122        T: Service<R>,
123    {
124        // check readiness and notify waiters
125        ReadyCall {
126            completed: false,
127            fut: svc.ready(ServiceCtx {
128                idx: self.idx,
129                waiters: self.waiters,
130                _t: marker::PhantomData,
131            }),
132            ctx: *self,
133        }
134        .await
135    }
136
137    #[inline]
138    /// Wait for service readiness and then call service
139    pub async fn call<T, R>(&self, svc: &'a T, req: R) -> Result<T::Response, T::Error>
140    where
141        T: Service<R>,
142        R: 'a,
143    {
144        self.ready(svc).await?;
145
146        svc.call(
147            req,
148            ServiceCtx {
149                idx: self.idx,
150                waiters: self.waiters,
151                _t: marker::PhantomData,
152            },
153        )
154        .await
155    }
156
157    #[inline]
158    /// Call service, do not check service readiness
159    pub async fn call_nowait<T, R>(
160        &self,
161        svc: &'a T,
162        req: R,
163    ) -> Result<T::Response, T::Error>
164    where
165        T: Service<R>,
166        R: 'a,
167    {
168        svc.call(
169            req,
170            ServiceCtx {
171                idx: self.idx,
172                waiters: self.waiters,
173                _t: marker::PhantomData,
174            },
175        )
176        .await
177    }
178}
179
// Implemented by hand rather than derived: `#[derive(Copy, Clone)]` would add
// `S: Copy` / `S: Clone` bounds, although `S` is only a `PhantomData` tag.
impl<S> Copy for ServiceCtx<'_, S> {}

impl<S> Clone for ServiceCtx<'_, S> {
    #[inline]
    fn clone(&self) -> Self {
        // The type is `Copy`, so cloning is a plain copy.
        *self
    }
}
188
189impl<S> fmt::Debug for ServiceCtx<'_, S> {
190    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191        f.debug_struct("ServiceCtx")
192            .field("idx", &self.idx)
193            .field("waiters", &self.waiters.get().len())
194            .finish()
195    }
196}
197
198struct ReadyCall<'a, S: ?Sized, F: Future> {
199    completed: bool,
200    fut: F,
201    ctx: ServiceCtx<'a, S>,
202}
203
204impl<S: ?Sized, F: Future> Drop for ReadyCall<'_, S, F> {
205    fn drop(&mut self) {
206        if !self.completed && self.ctx.waiters.cur.get() == self.ctx.idx {
207            self.ctx.waiters.notify();
208        }
209    }
210}
211
212impl<S: ?Sized, F: Future> Unpin for ReadyCall<'_, S, F> {}
213
214impl<S: ?Sized, F: Future> Future for ReadyCall<'_, S, F> {
215    type Output = F::Output;
216
217    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
218        if self.ctx.waiters.can_check(self.ctx.idx, cx) {
219            // SAFETY: `fut` never moves
220            let result = unsafe { Pin::new_unchecked(&mut self.as_mut().fut).poll(cx) };
221            match result {
222                Poll::Pending => {
223                    self.ctx.waiters.register(self.ctx.idx, cx);
224                    Poll::Pending
225                }
226                Poll::Ready(res) => {
227                    self.completed = true;
228                    self.ctx.waiters.notify();
229                    Poll::Ready(res)
230                }
231            }
232        } else {
233            Poll::Pending
234        }
235    }
236}
237
#[cfg(test)]
mod tests {
    use std::{cell::Cell, cell::RefCell, future::poll_fn};

    use ntex_util::channel::{condition, oneshot};
    use ntex_util::{future::lazy, future::select, spawn, time};

    use super::*;
    use crate::Pipeline;

    // Test service. Field `0` counts how many times `ready` has been entered.
    // Field `1` is a condition waiter that `ready` awaits; the tests release it
    // with `con.notify()`.
    struct Srv(Rc<Cell<usize>>, condition::Waiter);

    impl Service<&'static str> for Srv {
        type Response = &'static str;
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            self.0.set(self.0.get() + 1);
            self.1.ready().await;
            Ok(())
        }

        async fn call(
            &self,
            req: &'static str,
            ctx: ServiceCtx<'_, Self>,
        ) -> Result<Self::Response, Self::Error> {
            // Exercise the `Debug` and `Clone` impls of `ServiceCtx`.
            let _ = format!("{:?}", ctx);
            #[allow(clippy::clone_on_copy)]
            let _ = ctx.clone();
            // Echo the request back.
            Ok(req)
        }
    }

    // Two bindings of one pipeline: only one of them runs `Srv::ready` at a
    // time, and the other is queued until the first becomes ready.
    #[ntex::test]
    async fn test_ready() {
        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();

        let srv1 = Pipeline::from(Srv(cnt.clone(), con.wait())).bind();
        let srv2 = srv1.clone();

        // srv1 enters `Srv::ready` and waits on the condition.
        let res = lazy(|cx| srv1.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);
        assert_eq!(cnt.get(), 1);

        // srv2 is queued behind srv1: `Srv::ready` is not entered again.
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);
        assert_eq!(cnt.get(), 1);

        // Released: srv1 completes without a new `Srv::ready` call.
        con.notify();
        let res = lazy(|cx| srv1.poll_ready(cx)).await;
        assert_eq!(res, Poll::Ready(Ok(())));
        assert_eq!(cnt.get(), 1);

        // The slot is free again, so srv2 now enters `Srv::ready`.
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);
        assert_eq!(cnt.get(), 2);

        con.notify();
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Ready(Ok(())));
        assert_eq!(cnt.get(), 2);

        // A fresh readiness check by srv1 starts another `Srv::ready` call.
        let res = lazy(|cx| srv1.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);
        assert_eq!(cnt.get(), 3);
    }

    // Cancelling an in-flight `ready()` must not leave other callers stuck:
    // srv2 has to become ready once the condition is released.
    #[ntex::test]
    async fn test_ready_on_drop() {
        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();
        let srv = Pipeline::from(Srv(cnt.clone(), con.wait()));

        let srv1 = srv.clone();
        let srv2 = srv1.clone().bind();

        let (tx, rx) = oneshot::channel();
        spawn(async move {
            // `ready()` runs until `rx` fires; `select` then presumably drops it.
            select(rx, srv1.ready()).await;
            // Presumably keeps srv1 alive for the rest of the test.
            time::sleep(time::Millis(25000)).await;
            drop(srv1);
        });
        time::sleep(time::Millis(250)).await;

        // srv1's spawned readiness check is in progress, so srv2 must wait.
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);

        // Cancel srv1's `ready()`; `Srv::ready` still waits on `con`.
        let _ = tx.send(());
        time::sleep(time::Millis(250)).await;

        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);

        con.notify();
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Ready(Ok(())));
    }

    // Like `test_ready_on_drop`, but srv1 also runs `poll_shutdown` after its
    // readiness check is cancelled; srv2 must still become ready afterwards.
    #[ntex::test]
    async fn test_ready_after_shutdown() {
        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();
        let srv = Pipeline::from(Srv(cnt.clone(), con.wait()));

        let srv1 = srv.clone().bind();
        let srv2 = srv1.clone();

        let (tx, rx) = oneshot::channel();
        spawn(async move {
            select(rx, poll_fn(|cx| srv1.poll_ready(cx))).await;
            poll_fn(|cx| srv1.poll_shutdown(cx)).await;
            time::sleep(time::Millis(25000)).await;
            drop(srv1);
        });
        time::sleep(time::Millis(250)).await;

        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);

        let _ = tx.send(());
        time::sleep(time::Millis(250)).await;

        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Pending);

        con.notify();
        let res = lazy(|cx| srv2.poll_ready(cx)).await;
        assert_eq!(res, Poll::Ready(Ok(())));
    }

    // Two tasks share one binding: srv1 checks readiness and then uses
    // `call_nowait`, while srv2 uses `call` (readiness + call). Each
    // `con.notify()` lets exactly one of them through, in order.
    #[ntex::test]
    async fn test_shared_call() {
        let data = Rc::new(RefCell::new(Vec::new()));

        let cnt = Rc::new(Cell::new(0));
        let con = condition::Condition::new();

        let srv1 = Pipeline::from(Srv(cnt.clone(), con.wait())).bind();
        let srv2 = srv1.clone();

        let data1 = data.clone();
        ntex::rt::spawn(async move {
            let _ = poll_fn(|cx| srv1.poll_ready(cx)).await;
            let i = srv1.call_nowait("srv1").await.unwrap();
            data1.borrow_mut().push(i);
        });

        let data2 = data.clone();
        ntex::rt::spawn(async move {
            let i = srv2.call("srv2").await.unwrap();
            data2.borrow_mut().push(i);
        });
        time::sleep(time::Millis(50)).await;

        // First release: srv1 finishes and srv2 enters `Srv::ready`.
        con.notify();
        time::sleep(time::Millis(150)).await;

        assert_eq!(cnt.get(), 2);
        assert_eq!(&*data.borrow(), &["srv1"]);

        // Second release: srv2's readiness completes and its call runs.
        con.notify();
        time::sleep(time::Millis(150)).await;

        assert_eq!(cnt.get(), 2);
        assert_eq!(&*data.borrow(), &["srv1", "srv2"]);
    }
}