ntex-mqtt 0.8.10

Client and Server framework for MQTT v5 and v3.1.1 protocols
Documentation
//! Service that limits the number (and total size) of in-flight async requests.
use std::{cell::Cell, future::Future, marker, pin::Pin, rc::Rc, task::Context, task::Poll};

use ntex::{service::Service, task::LocalWaker};

/// A request that can report how much of the in-flight size budget it uses.
///
/// `InFlightService` only asks for the size when a size limit is configured, and
/// keeps the returned amount charged until the request's response future is dropped.
pub(crate) trait SizedRequest {
    /// Size of this request, in the same unit as the service's `max_size` limit.
    fn size(&self) -> u32;
}

/// Middleware around an inner service that caps how many requests, and how many
/// request-size units, may be in flight at the same time.
pub(crate) struct InFlightService<S> {
    /// Shared in-flight request/size accounting.
    count: Counter,
    /// The wrapped inner service that actually handles requests.
    service: S,
}

impl<S> InFlightService<S> {
    pub fn new(max_cap: u16, max_size: usize, service: S) -> Self {
        Self { service, count: Counter::new(max_cap, max_size) }
    }
}

impl<T, R> Service<R> for InFlightService<T>
where
    T: Service<R>,
    R: SizedRequest,
{
    type Response = T::Response;
    type Error = T::Error;
    type Future = InFlightServiceResponse<T, R>;

    /// Ready only when the inner service is ready *and* the in-flight limits leave
    /// room for another request.
    ///
    /// Errors from the inner service are passed through unchanged. When the limits
    /// are the blocker, the current task is registered to be woken once a running
    /// request finishes and frees capacity.
    #[inline]
    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        match self.service.poll_ready(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Ready(Ok(())) if self.count.available(cx) => Poll::Ready(Ok(())),
            Poll::Ready(Ok(())) => {
                log::trace!("InFlight limit exceeded");
                Poll::Pending
            }
        }
    }

    /// Shutdown is forwarded straight to the inner service.
    #[inline]
    fn poll_shutdown(&self, cx: &mut Context<'_>, is_error: bool) -> Poll<()> {
        self.service.poll_shutdown(cx, is_error)
    }

    /// Charges the request against the limits, then dispatches it to the inner
    /// service. The charge is released when the returned future is dropped.
    #[inline]
    fn call(&self, req: R) -> Self::Future {
        // Skip computing the size entirely when no size limit is configured.
        let size_limited = self.count.0.max_size > 0;
        let size = if size_limited { req.size() } else { 0 };
        let guard = self.count.get(size);
        let fut = self.service.call(req);

        InFlightServiceResponse { fut, _guard: guard, _t: marker::PhantomData }
    }
}

pin_project_lite::pin_project! {
    // Response future returned by `InFlightService::call`.
    //
    // It polls the inner service's future and holds a `CounterGuard`. Dropping the
    // guard releases this request's share of the in-flight count and size. Fields
    // drop in declaration order, so `fut` is dropped before `_guard` releases it.
    #[doc(hidden)]
    pub struct InFlightServiceResponse<T: Service<R>, R> {
        // The inner service's response future (structurally pinned).
        #[pin]
        fut: T::Future,
        // Keeps this request counted as in flight until the future is dropped.
        _guard: CounterGuard,
        // Carries the request type `R`, which no other field mentions.
        _t: marker::PhantomData<R>
    }
}

impl<T: Service<R>, R> Future for InFlightServiceResponse<T, R> {
    type Output = Result<T::Response, T::Error>;

    /// Delegates to the inner service's future. The in-flight charge is released
    /// when this future is dropped, not merely when it completes.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        this.fut.poll(cx)
    }
}

/// Handle to the reference-counted in-flight accounting state.
struct Counter(Rc<CounterInner>);

/// Configured limits and current usage, shared by a `Counter` and its guards.
struct CounterInner {
    /// Maximum number of concurrent requests; `0` means unlimited.
    max_cap: u16,
    /// Number of requests currently charged.
    cur_cap: Cell<u16>,
    /// Maximum total request size; `0` means unlimited.
    max_size: usize,
    /// Sum of the sizes of requests currently charged.
    cur_size: Cell<usize>,
    /// Waker registered when a limit blocks readiness; woken when capacity frees up.
    task: LocalWaker,
}

impl Counter {
    fn new(max_cap: u16, max_size: usize) -> Self {
        Counter(Rc::new(CounterInner {
            max_cap,
            max_size,
            cur_cap: Cell::new(0),
            cur_size: Cell::new(0),
            task: LocalWaker::new(),
        }))
    }

    fn get(&self, size: u32) -> CounterGuard {
        CounterGuard::new(size, self.0.clone())
    }

    fn available(&self, cx: &mut Context<'_>) -> bool {
        self.0.available(cx)
    }
}

/// RAII token for one in-flight request.
///
/// `.0` is the request's size and `.1` is the shared state it was charged against.
struct CounterGuard(u32, Rc<CounterInner>);

impl CounterGuard {
    /// Charges `size` to `inner` via `CounterInner::inc` and returns the guard that
    /// refunds it. The charge happens before the guard exists, so a guard is never
    /// dropped without a matching charge.
    fn new(size: u32, inner: Rc<CounterInner>) -> Self {
        inner.inc(size);
        Self(size, inner)
    }
}

// Holds only a `u32` and an `Rc`, so moving it after pinning is always fine.
impl Unpin for CounterGuard {}

impl Drop for CounterGuard {
    /// Refunds this request's charge via `CounterInner::dec`.
    fn drop(&mut self) {
        let CounterGuard(size, inner) = self;
        inner.dec(*size);
    }
}

impl CounterInner {
    /// Charges one request of `size` units to the in-flight totals.
    ///
    /// The request count is only tracked when a cap is configured (`max_cap > 0`).
    /// With the cap disabled, the count can never affect `available` or trigger a
    /// wakeup. Incrementing it unconditionally let the `u16` overflow once more than
    /// `u16::MAX` requests were in flight: a panic in debug builds, a silent wrap in
    /// release builds.
    fn inc(&self, size: u32) {
        if self.max_cap > 0 {
            self.cur_cap.set(self.cur_cap.get() + 1);
        }
        // `u32 -> usize` is lossless on 32- and 64-bit targets.
        self.cur_size.set(self.cur_size.get() + size as usize);
    }

    /// Refunds a request previously charged by `inc` with the same `size`.
    ///
    /// Wakes the registered task when this refund brings a limit that was blocking
    /// readiness back within bounds.
    fn dec(&self, size: u32) {
        // The cap frees up exactly when the count drops from `max_cap` to below it.
        let cap_released = if self.max_cap > 0 {
            let num = self.cur_cap.get();
            self.cur_cap.set(num - 1);
            num == self.max_cap
        } else {
            false
        };

        let cur_size = self.cur_size.get();
        let new_size = cur_size - size as usize;
        self.cur_size.set(new_size);
        // Mirrors the `<=` test in `available`: the size limit is released only
        // when the total crosses from above `max_size` back to at most `max_size`.
        let size_released = cur_size > self.max_size && new_size <= self.max_size;

        if cap_released || size_released {
            self.task.wake();
        }
    }

    /// Returns `true` if another request may start.
    ///
    /// A limit of `0` disables that limit. The size limit is soft: a request is
    /// admitted while the current total is at most `max_size`, so the total may
    /// overshoot by one request. When returning `false`, `cx`'s waker is registered
    /// so that `dec` can wake it once capacity frees up.
    fn available(&self, cx: &mut Context<'_>) -> bool {
        if (self.max_cap == 0 || self.cur_cap.get() < self.max_cap)
            && (self.max_size == 0 || self.cur_size.get() <= self.max_size)
        {
            true
        } else {
            self.task.register(cx.waker());
            false
        }
    }
}

#[cfg(test)]
mod tests {
    use ntex::{service::Service, time::sleep, util::lazy};
    use std::{task::Context, task::Poll, time::Duration};

    use super::*;

    /// How long each `SleepService` call takes to complete.
    const WAIT: Duration = Duration::from_millis(50);

    /// Inner service whose calls complete only after sleeping for `.0`.
    struct SleepService(Duration);

    impl Service<()> for SleepService {
        type Response = ();
        type Error = ();
        type Future = Pin<Box<dyn Future<Output = Result<(), ()>>>>;

        fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&self, _: ()) -> Self::Future {
            let delay = sleep(self.0);
            let response = async move {
                let _ = delay.await;
                Ok::<_, ()>(())
            };
            Box::pin(response)
        }
    }

    /// Every unit request counts as 12 size units.
    impl SizedRequest for () {
        fn size(&self) -> u32 {
            12
        }
    }

    /// Polls `srv` for readiness exactly once.
    async fn readiness(srv: &InFlightService<SleepService>) -> Poll<Result<(), ()>> {
        lazy(|cx| <InFlightService<SleepService> as Service<()>>::poll_ready(srv, cx)).await
    }

    #[ntex::test]
    async fn test_inflight() {
        // Request-count limit of 1, size limit disabled.
        let srv = InFlightService::new(1, 0, SleepService(WAIT));
        assert_eq!(readiness(&srv).await, Poll::Ready(Ok(())));

        let in_flight = srv.call(());
        assert_eq!(readiness(&srv).await, Poll::Pending);

        let _ = in_flight.await;
        assert_eq!(readiness(&srv).await, Poll::Ready(Ok(())));
        assert!(lazy(|cx| srv.poll_shutdown(cx, false)).await.is_ready());
    }

    #[ntex::test]
    async fn test_inflight2() {
        // Count limit disabled; a size limit of 10 is exceeded by one 12-unit request.
        let srv = InFlightService::new(0, 10, SleepService(WAIT));
        assert_eq!(readiness(&srv).await, Poll::Ready(Ok(())));

        let in_flight = srv.call(());
        assert_eq!(readiness(&srv).await, Poll::Pending);

        let _ = in_flight.await;
        assert_eq!(readiness(&srv).await, Poll::Ready(Ok(())));
    }
}