forked_tarpc/server/limits/
requests_per_channel.rs1use crate::{
8 server::{Channel, Config},
9 Response, ServerError,
10};
11use futures::{prelude::*, ready, task::*};
12use pin_project::pin_project;
13use std::{io, pin::Pin};
14
15#[pin_project]
21#[derive(Debug)]
22pub struct MaxRequests<C> {
23 max_in_flight_requests: usize,
24 #[pin]
25 inner: C,
26}
27
28impl<C> MaxRequests<C> {
29 pub fn get_ref(&self) -> &C {
31 &self.inner
32 }
33}
34
35impl<C> MaxRequests<C>
36where
37 C: Channel,
38{
39 pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
42 MaxRequests {
43 max_in_flight_requests,
44 inner,
45 }
46 }
47}
48
49impl<C> Stream for MaxRequests<C>
50where
51 C: Channel,
52{
53 type Item = <C as Stream>::Item;
54
55 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
56 while self.as_mut().in_flight_requests() >= *self.as_mut().project().max_in_flight_requests
57 {
58 ready!(self.as_mut().project().inner.poll_ready(cx)?);
59
60 match ready!(self.as_mut().project().inner.poll_next(cx)?) {
61 Some(r) => {
62 let _entered = r.span.enter();
63 tracing::info!(
64 in_flight_requests = self.as_mut().in_flight_requests(),
65 "ThrottleRequest",
66 );
67
68 self.as_mut().start_send(Response {
69 request_id: r.request.id,
70 message: Err(ServerError {
71 kind: io::ErrorKind::WouldBlock,
72 detail: "server throttled the request.".into(),
73 }),
74 })?;
75 }
76 None => return Poll::Ready(None),
77 }
78 }
79 self.project().inner.poll_next(cx)
80 }
81}
82
83impl<C> Sink<Response<<C as Channel>::Resp>> for MaxRequests<C>
84where
85 C: Channel,
86{
87 type Error = C::Error;
88
89 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
90 self.project().inner.poll_ready(cx)
91 }
92
93 fn start_send(
94 self: Pin<&mut Self>,
95 item: Response<<C as Channel>::Resp>,
96 ) -> Result<(), Self::Error> {
97 self.project().inner.start_send(item)
98 }
99
100 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
101 self.project().inner.poll_flush(cx)
102 }
103
104 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
105 self.project().inner.poll_close(cx)
106 }
107}
108
109impl<C> AsRef<C> for MaxRequests<C> {
110 fn as_ref(&self) -> &C {
111 &self.inner
112 }
113}
114
115impl<C> Channel for MaxRequests<C>
116where
117 C: Channel,
118{
119 type Req = <C as Channel>::Req;
120 type Resp = <C as Channel>::Resp;
121 type Transport = <C as Channel>::Transport;
122
123 fn in_flight_requests(&self) -> usize {
124 self.inner.in_flight_requests()
125 }
126
127 fn config(&self) -> &Config {
128 self.inner.config()
129 }
130
131 fn transport(&self) -> &Self::Transport {
132 self.inner.transport()
133 }
134}
135
136#[pin_project]
139#[derive(Debug)]
140pub struct MaxRequestsPerChannel<S> {
141 #[pin]
142 inner: S,
143 max_in_flight_requests: usize,
144}
145
146impl<S> MaxRequestsPerChannel<S>
147where
148 S: Stream,
149 <S as Stream>::Item: Channel,
150{
151 pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
152 Self {
153 inner,
154 max_in_flight_requests,
155 }
156 }
157}
158
159impl<S> Stream for MaxRequestsPerChannel<S>
160where
161 S: Stream,
162 <S as Stream>::Item: Channel,
163{
164 type Item = MaxRequests<<S as Stream>::Item>;
165
166 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
167 match ready!(self.as_mut().project().inner.poll_next(cx)) {
168 Some(channel) => Poll::Ready(Some(MaxRequests::new(
169 channel,
170 *self.project().max_in_flight_requests,
171 ))),
172 None => Poll::Ready(None),
173 }
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    use crate::server::{
        testing::{self, FakeChannel, PollExt},
        TrackedRequest,
    };
    use pin_utils::pin_mut;
    use std::{
        marker::PhantomData,
        time::{Duration, SystemTime},
    };
    use tracing::Span;

    // The throttler reports the inner channel's in-flight count.
    #[tokio::test]
    async fn throttler_in_flight_requests() {
        let throttler = MaxRequests::new(FakeChannel::default::<isize, isize>(), 0);
        pin_mut!(throttler);

        for request_id in 0..5 {
            let deadline = SystemTime::now() + Duration::from_secs(1);
            throttler
                .inner
                .in_flight_requests
                .start_request(request_id, deadline, Span::current())
                .unwrap();
        }

        assert_eq!(throttler.as_mut().in_flight_requests(), 5);
    }

    // An empty inner channel ends the throttled stream.
    #[test]
    fn throttler_poll_next_done() {
        let throttler = MaxRequests::new(FakeChannel::default::<isize, isize>(), 0);
        pin_mut!(throttler);

        let mut cx = testing::cx();
        assert!(throttler.as_mut().poll_next(&mut cx).is_done());
    }

    // Below the limit, requests are yielded unchanged.
    #[test]
    fn throttler_poll_next_some() -> io::Result<()> {
        let throttler = MaxRequests::new(FakeChannel::default::<isize, isize>(), 1);
        pin_mut!(throttler);
        throttler.inner.push_req(0, 1);

        let mut cx = testing::cx();
        assert!(throttler.as_mut().poll_ready(&mut cx).is_ready());

        let polled = throttler.as_mut().poll_next(&mut cx)?;
        let id_and_message =
            polled.map(|next| next.map(|req| (req.request.id, req.request.message)));
        assert_eq!(id_and_message, Poll::Ready(Some((0, 1))));
        Ok(())
    }

    // With a limit of zero, every request is rejected with an error response.
    #[test]
    fn throttler_poll_next_throttled() {
        let throttler = MaxRequests::new(FakeChannel::default::<isize, isize>(), 0);
        pin_mut!(throttler);
        throttler.inner.push_req(1, 1);

        assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());

        let sink = &throttler.inner.sink;
        assert_eq!(sink.len(), 1);
        let rejection = sink.get(0).unwrap();
        assert_eq!(rejection.request_id, 1);
        assert!(rejection.message.is_err());
    }

    // When saturated and the sink is not ready, polling stays pending without reading
    // a request (the stub's `poll_next` would panic if it were called).
    #[test]
    fn throttler_poll_next_throttled_sink_not_ready() {
        struct PendingSink<In, Out> {
            ghost: PhantomData<fn(Out) -> In>,
        }

        impl PendingSink<(), ()> {
            pub fn default<Req, Resp>(
            ) -> PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
                PendingSink { ghost: PhantomData }
            }
        }

        impl<In, Out> Stream for PendingSink<In, Out> {
            type Item = In;
            fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
                unimplemented!()
            }
        }

        impl<In, Out> Sink<Out> for PendingSink<In, Out> {
            type Error = io::Error;
            fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
                Poll::Pending
            }
            fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> {
                Err(io::Error::from(io::ErrorKind::WouldBlock))
            }
            fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
                Poll::Pending
            }
            fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
                Poll::Pending
            }
        }

        impl<Req, Resp> Channel for PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
            type Req = Req;
            type Resp = Resp;
            type Transport = ();
            fn config(&self) -> &Config {
                unimplemented!()
            }
            fn in_flight_requests(&self) -> usize {
                0
            }
            fn transport(&self) -> &() {
                &()
            }
        }

        let throttler = MaxRequests::new(PendingSink::default::<isize, isize>(), 0);
        pin_mut!(throttler);
        assert!(throttler.poll_next(&mut testing::cx()).is_pending());
    }

    // Responses pass through the throttler to the inner sink and complete the request.
    #[tokio::test]
    async fn throttler_start_send() {
        let throttler = MaxRequests::new(FakeChannel::default::<isize, isize>(), 0);
        pin_mut!(throttler);

        let deadline = SystemTime::now() + Duration::from_secs(1);
        throttler
            .inner
            .in_flight_requests
            .start_request(0, deadline, Span::current())
            .unwrap();

        let response = Response {
            request_id: 0,
            message: Ok(1),
        };
        throttler.as_mut().start_send(response).unwrap();

        assert_eq!(throttler.inner.in_flight_requests.len(), 0);
        assert_eq!(
            throttler.inner.sink.get(0),
            Some(&Response {
                request_id: 0,
                message: Ok(1),
            })
        );
    }
}