1use crate::Algorithm;
2use crate::classifier::{Classifier, DefaultClassifier};
3use crate::controller::Controller;
4use crate::future::ResponseFuture;
5
6use tokio::sync::OwnedSemaphorePermit;
7use tokio_util::sync::PollSemaphore;
8use tower_service::Service;
9
10use std::{
11 sync::{Arc, Mutex},
12 task::{Context, Poll, ready},
13 time::Instant,
14};
15
/// Tower middleware that caps how many requests may be in flight to an inner
/// service at once.
///
/// The cap is enforced with a semaphore obtained from a `Controller` built
/// around the supplied [`Algorithm`]. Every response future is handed the
/// shared controller, the request's start time, and a copy of the classifier
/// — presumably so the controller can adjust the limit from observed
/// outcomes (NOTE(review): confirm against `Controller` / `ResponseFuture`).
///
/// `C` is the response classifier and defaults to [`DefaultClassifier`].
pub struct ConcurrencyLimit<S, A, C = DefaultClassifier> {
    /// The wrapped service that requests are forwarded to.
    inner: S,
    /// Classifier cloned into every response future created by `call`.
    classifier: C,
    /// Controller shared (via `Arc`) by all clones of this middleware and by
    /// all in-flight response futures.
    controller: Arc<Mutex<Controller<A>>>,
    /// Pollable wrapper around the controller's semaphore; `poll_ready`
    /// acquires permits from it.
    semaphore: PollSemaphore,
    /// Permit acquired by `poll_ready` and consumed by the next `call`.
    /// `None` until `poll_ready` acquires one, and always `None` in a freshly
    /// constructed or freshly cloned instance.
    permit: Option<OwnedSemaphorePermit>,
}
37
38impl<S, A> ConcurrencyLimit<S, A>
39where
40 A: Algorithm,
41{
42 pub fn new(inner: S, algorithm: A) -> Self {
44 Self::with_classifier(inner, algorithm, DefaultClassifier)
45 }
46}
47
48impl<S, A, C> ConcurrencyLimit<S, A, C>
49where
50 A: Algorithm,
51{
52 pub fn with_classifier(inner: S, algorithm: A, classifier: C) -> Self {
54 let controller = Controller::new(algorithm);
55 let semaphore = controller.semaphore();
56
57 Self {
58 inner,
59 classifier,
60 controller: Arc::new(Mutex::new(controller)),
61 semaphore: PollSemaphore::new(semaphore),
62 permit: None,
63 }
64 }
65
66 pub fn get_ref(&self) -> &S {
68 &self.inner
69 }
70
71 pub fn get_mut(&mut self) -> &mut S {
73 &mut self.inner
74 }
75
76 pub fn into_inner(self) -> S {
78 self.inner
79 }
80}
81
82impl<S, A, C, Request> Service<Request> for ConcurrencyLimit<S, A, C>
83where
84 S: Service<Request>,
85 A: Algorithm,
86 C: Classifier<S::Response, S::Error> + Clone,
87{
88 type Response = S::Response;
89 type Error = S::Error;
90 type Future = ResponseFuture<S::Future, A, C>;
91
92 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
93 if self.permit.is_none() {
94 self.permit = ready!(self.semaphore.poll_acquire(cx));
95 debug_assert!(self.permit.is_some(), "semaphore should never be closed");
96 }
97 self.inner.poll_ready(cx)
100 }
101
102 fn call(&mut self, request: Request) -> Self::Future {
103 let start = Instant::now();
104 let permit = self
106 .permit
107 .take()
108 .expect("`poll_ready` should be called first");
109
110 let future = self.inner.call(request);
112 ResponseFuture::new(
113 future,
114 self.controller.clone(),
115 permit,
116 start,
117 self.classifier.clone(),
118 )
119 }
120}
121
122impl<S: Clone, A, C: Clone> Clone for ConcurrencyLimit<S, A, C> {
123 fn clone(&self) -> Self {
124 Self {
128 inner: self.inner.clone(),
129 classifier: self.classifier.clone(),
130 controller: self.controller.clone(),
131 semaphore: self.semaphore.clone(),
132 permit: None,
133 }
134 }
135}
136
137impl<S: std::fmt::Debug, A, C> std::fmt::Debug for ConcurrencyLimit<S, A, C> {
138 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139 f.debug_struct("ConcurrencyLimit")
140 .field("inner", &self.inner)
141 .field("permit", &self.permit)
142 .finish_non_exhaustive()
143 }
144}