Skip to main content

tower_acc/
service.rs

1use crate::Algorithm;
2use crate::controller::Controller;
3use crate::future::ResponseFuture;
4
5use tokio::sync::OwnedSemaphorePermit;
6use tokio_util::sync::PollSemaphore;
7use tower_service::Service;
8
9use std::{
10    sync::{Arc, Mutex},
11    task::{Context, Poll, ready},
12    time::Instant,
13};
14
15/// Enforces an adaptive limit on the concurrent number of requests the
16/// underlying service can handle.
17///
18/// Unlike a static concurrency limit, `ConcurrencyLimit` continuously observes
19/// request latency and adjusts the number of allowed in-flight requests using
20/// the configured [`Algorithm`].
21///
22/// Use [`ConcurrencyLimitLayer`](crate::ConcurrencyLimitLayer) to integrate
23/// with [`tower::ServiceBuilder`].
24pub struct ConcurrencyLimit<S, A> {
25    inner: S,
26    controller: Arc<Mutex<Controller<A>>>,
27    semaphore: PollSemaphore,
28    /// The currently acquired semaphore permit, if there is sufficient
29    /// concurrency to send a new request.
30    ///
31    /// The permit is acquired in `poll_ready`, and taken in `call` when sending
32    /// a new request.
33    permit: Option<OwnedSemaphorePermit>,
34}
35
36impl<S, A> ConcurrencyLimit<S, A>
37where
38    A: Algorithm,
39{
40    /// Creates a new concurrency limiter.
41    pub fn new(inner: S, algorithm: A) -> Self {
42        let controller = Controller::new(algorithm);
43        let semaphore = controller.semaphore();
44
45        Self {
46            inner,
47            controller: Arc::new(Mutex::new(controller)),
48            semaphore: PollSemaphore::new(semaphore),
49            permit: None,
50        }
51    }
52
53    /// Gets a reference to the inner service.
54    pub fn get_ref(&self) -> &S {
55        &self.inner
56    }
57
58    /// Gets a mutable reference to the inner service.
59    pub fn get_mut(&mut self) -> &mut S {
60        &mut self.inner
61    }
62
63    /// Consumes `self`, returning the inner service.
64    pub fn into_inner(self) -> S {
65        self.inner
66    }
67}
68
69impl<S, A, Request> Service<Request> for ConcurrencyLimit<S, A>
70where
71    S: Service<Request>,
72    A: Algorithm,
73{
74    type Response = S::Response;
75    type Error = S::Error;
76    type Future = ResponseFuture<S::Future, A>;
77
78    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
79        if self.permit.is_none() {
80            self.permit = ready!(self.semaphore.poll_acquire(cx));
81            debug_assert!(self.permit.is_some(), "semaphore should never be closed");
82        }
83        // Once we've acquired a permit (or if we already had one), poll the
84        // inner service.
85        self.inner.poll_ready(cx)
86    }
87
88    fn call(&mut self, request: Request) -> Self::Future {
89        let start = Instant::now();
90        // Take the permit
91        let permit = self
92            .permit
93            .take()
94            .expect("`poll_ready` should be called first");
95
96        // Call the inner service
97        let future = self.inner.call(request);
98        ResponseFuture::new(future, self.controller.clone(), permit, start)
99    }
100}
101
102impl<S: Clone, A> Clone for ConcurrencyLimit<S, A> {
103    fn clone(&self) -> Self {
104        // Since we hold an `OwnedSemaphorePermit`, we can't derive `Clone`.
105        // Instead, when cloning the service, create a new service with the
106        // same semaphore, but with the permit in the un-acquired state.
107        Self {
108            inner: self.inner.clone(),
109            controller: self.controller.clone(),
110            semaphore: self.semaphore.clone(),
111            permit: None,
112        }
113    }
114}
115
116impl<S: std::fmt::Debug, A> std::fmt::Debug for ConcurrencyLimit<S, A> {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        f.debug_struct("ConcurrencyLimit")
119            .field("inner", &self.inner)
120            .field("permit", &self.permit)
121            .finish_non_exhaustive()
122    }
123}