forked_tarpc/server/limits/
requests_per_channel.rs

1// Copyright 2020 Google LLC
2//
3// Use of this source code is governed by an MIT-style
4// license that can be found in the LICENSE file or at
5// https://opensource.org/licenses/MIT.
6
7use crate::{
8    server::{Channel, Config},
9    Response, ServerError,
10};
11use futures::{prelude::*, ready, task::*};
12use pin_project::pin_project;
13use std::{io, pin::Pin};
14
15/// A [`Channel`] that limits the number of concurrent requests by throttling.
16///
17/// Note that this is a very basic throttling heuristic. It is easy to set a number that is too low
18/// for the resources available to the server. For production use cases, a more advanced throttler
19/// is likely needed.
20#[pin_project]
21#[derive(Debug)]
22pub struct MaxRequests<C> {
23    max_in_flight_requests: usize,
24    #[pin]
25    inner: C,
26}
27
28impl<C> MaxRequests<C> {
29    /// Returns the inner channel.
30    pub fn get_ref(&self) -> &C {
31        &self.inner
32    }
33}
34
35impl<C> MaxRequests<C>
36where
37    C: Channel,
38{
39    /// Returns a new `MaxRequests` that wraps the given channel and limits concurrent requests to
40    /// `max_in_flight_requests`.
41    pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
42        MaxRequests {
43            max_in_flight_requests,
44            inner,
45        }
46    }
47}
48
49impl<C> Stream for MaxRequests<C>
50where
51    C: Channel,
52{
53    type Item = <C as Stream>::Item;
54
55    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
56        while self.as_mut().in_flight_requests() >= *self.as_mut().project().max_in_flight_requests
57        {
58            ready!(self.as_mut().project().inner.poll_ready(cx)?);
59
60            match ready!(self.as_mut().project().inner.poll_next(cx)?) {
61                Some(r) => {
62                    let _entered = r.span.enter();
63                    tracing::info!(
64                        in_flight_requests = self.as_mut().in_flight_requests(),
65                        "ThrottleRequest",
66                    );
67
68                    self.as_mut().start_send(Response {
69                        request_id: r.request.id,
70                        message: Err(ServerError {
71                            kind: io::ErrorKind::WouldBlock,
72                            detail: "server throttled the request.".into(),
73                        }),
74                    })?;
75                }
76                None => return Poll::Ready(None),
77            }
78        }
79        self.project().inner.poll_next(cx)
80    }
81}
82
83impl<C> Sink<Response<<C as Channel>::Resp>> for MaxRequests<C>
84where
85    C: Channel,
86{
87    type Error = C::Error;
88
89    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
90        self.project().inner.poll_ready(cx)
91    }
92
93    fn start_send(
94        self: Pin<&mut Self>,
95        item: Response<<C as Channel>::Resp>,
96    ) -> Result<(), Self::Error> {
97        self.project().inner.start_send(item)
98    }
99
100    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
101        self.project().inner.poll_flush(cx)
102    }
103
104    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
105        self.project().inner.poll_close(cx)
106    }
107}
108
109impl<C> AsRef<C> for MaxRequests<C> {
110    fn as_ref(&self) -> &C {
111        &self.inner
112    }
113}
114
115impl<C> Channel for MaxRequests<C>
116where
117    C: Channel,
118{
119    type Req = <C as Channel>::Req;
120    type Resp = <C as Channel>::Resp;
121    type Transport = <C as Channel>::Transport;
122
123    fn in_flight_requests(&self) -> usize {
124        self.inner.in_flight_requests()
125    }
126
127    fn config(&self) -> &Config {
128        self.inner.config()
129    }
130
131    fn transport(&self) -> &Self::Transport {
132        self.inner.transport()
133    }
134}
135
136/// An [`Incoming`](crate::server::incoming::Incoming) stream of channels that enforce limits on
137/// the number of in-flight requests.
138#[pin_project]
139#[derive(Debug)]
140pub struct MaxRequestsPerChannel<S> {
141    #[pin]
142    inner: S,
143    max_in_flight_requests: usize,
144}
145
146impl<S> MaxRequestsPerChannel<S>
147where
148    S: Stream,
149    <S as Stream>::Item: Channel,
150{
151    pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
152        Self {
153            inner,
154            max_in_flight_requests,
155        }
156    }
157}
158
159impl<S> Stream for MaxRequestsPerChannel<S>
160where
161    S: Stream,
162    <S as Stream>::Item: Channel,
163{
164    type Item = MaxRequests<<S as Stream>::Item>;
165
166    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
167        match ready!(self.as_mut().project().inner.poll_next(cx)) {
168            Some(channel) => Poll::Ready(Some(MaxRequests::new(
169                channel,
170                *self.project().max_in_flight_requests,
171            ))),
172            None => Poll::Ready(None),
173        }
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    use crate::server::{
        testing::{self, FakeChannel, PollExt},
        TrackedRequest,
    };
    use pin_utils::pin_mut;
    use std::{
        marker::PhantomData,
        time::{Duration, SystemTime},
    };
    use tracing::Span;

    /// The wrapper reports the in-flight count tracked by the inner channel.
    #[tokio::test]
    async fn throttler_in_flight_requests() {
        let throttler = MaxRequests::new(FakeChannel::default::<isize, isize>(), 0);
        pin_mut!(throttler);

        for request_id in 0..5 {
            let deadline = SystemTime::now() + Duration::from_secs(1);
            throttler
                .inner
                .in_flight_requests
                .start_request(request_id, deadline, Span::current())
                .unwrap();
        }

        let in_flight = throttler.as_mut().in_flight_requests();
        assert_eq!(in_flight, 5);
    }

    /// An empty inner channel ends the throttled stream as well.
    #[test]
    fn throttler_poll_next_done() {
        let throttler = MaxRequests::new(FakeChannel::default::<isize, isize>(), 0);
        pin_mut!(throttler);

        let polled = throttler.as_mut().poll_next(&mut testing::cx());
        assert!(polled.is_done());
    }

    /// Below the limit, a queued request is yielded to the caller untouched.
    #[test]
    fn throttler_poll_next_some() -> io::Result<()> {
        let throttler = MaxRequests::new(FakeChannel::default::<isize, isize>(), 1);
        pin_mut!(throttler);

        throttler.inner.push_req(0, 1);
        assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready());

        let polled = throttler
            .as_mut()
            .poll_next(&mut testing::cx())?
            .map(|next| next.map(|tracked| (tracked.request.id, tracked.request.message)));
        assert_eq!(polled, Poll::Ready(Some((0, 1))));
        Ok(())
    }

    /// At the limit, a queued request is answered with an error response instead of being
    /// yielded.
    #[test]
    fn throttler_poll_next_throttled() {
        let throttler = MaxRequests::new(FakeChannel::default::<isize, isize>(), 0);
        pin_mut!(throttler);

        throttler.inner.push_req(1, 1);
        assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());

        assert_eq!(throttler.inner.sink.len(), 1);
        let rejection = throttler
            .inner
            .sink
            .get(0)
            .expect("a rejection response should have been sent");
        assert_eq!(rejection.request_id, 1);
        assert!(rejection.message.is_err());
    }

    /// At the limit, the stream stays pending while the sink cannot accept a response.
    #[test]
    fn throttler_poll_next_throttled_sink_not_ready() {
        /// A channel whose sink is never ready and whose stream must never be polled.
        struct PendingSink<In, Out> {
            ghost: PhantomData<fn(Out) -> In>,
        }

        impl PendingSink<(), ()> {
            pub fn default<Req, Resp>(
            ) -> PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
                PendingSink { ghost: PhantomData }
            }
        }

        impl<In, Out> Stream for PendingSink<In, Out> {
            type Item = In;

            fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
                unimplemented!()
            }
        }

        impl<In, Out> Sink<Out> for PendingSink<In, Out> {
            type Error = io::Error;

            fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
                Poll::Pending
            }

            fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> {
                Err(io::Error::from(io::ErrorKind::WouldBlock))
            }

            fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
                Poll::Pending
            }

            fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
                Poll::Pending
            }
        }

        impl<Req, Resp> Channel for PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
            type Req = Req;
            type Resp = Resp;
            type Transport = ();

            fn config(&self) -> &Config {
                unimplemented!()
            }

            fn in_flight_requests(&self) -> usize {
                0
            }

            fn transport(&self) -> &() {
                &()
            }
        }

        let throttler = MaxRequests::new(PendingSink::default::<isize, isize>(), 0);
        pin_mut!(throttler);

        let polled = throttler.poll_next(&mut testing::cx());
        assert!(polled.is_pending());
    }

    /// Sending a response through the wrapper reaches the inner channel, which completes the
    /// matching in-flight request.
    #[tokio::test]
    async fn throttler_start_send() {
        let throttler = MaxRequests::new(FakeChannel::default::<isize, isize>(), 0);
        pin_mut!(throttler);

        let deadline = SystemTime::now() + Duration::from_secs(1);
        throttler
            .inner
            .in_flight_requests
            .start_request(0, deadline, Span::current())
            .unwrap();

        let sent = throttler.as_mut().start_send(Response {
            request_id: 0,
            message: Ok(1),
        });
        sent.unwrap();

        assert_eq!(throttler.inner.in_flight_requests.len(), 0);

        let expected = Response {
            request_id: 0,
            message: Ok(1),
        };
        assert_eq!(throttler.inner.sink.get(0), Some(&expected));
    }
}
347        );
348    }
349}