forked_tarpc/server/limits/
channels_per_key.rs

1// Copyright 2018 Google LLC
2//
3// Use of this source code is governed by an MIT-style
4// license that can be found in the LICENSE file or at
5// https://opensource.org/licenses/MIT.
6
7use crate::{
8    server::{self, Channel},
9    util::Compact,
10};
11use fnv::FnvHashMap;
12use futures::{prelude::*, ready, stream::Fuse, task::*};
13use pin_project::pin_project;
14use std::sync::{Arc, Weak};
15use std::{
16    collections::hash_map::Entry, convert::TryFrom, fmt, hash::Hash, marker::Unpin, pin::Pin,
17};
18use tokio::sync::mpsc;
19use tracing::{debug, info, trace};
20
/// An [`Incoming`](crate::server::incoming::Incoming) stream that drops new channels based on
/// per-key limits.
///
/// The decision to drop a Channel is made once at the time the Channel materializes. Once a
/// Channel is yielded, it will not be prematurely dropped.
///
/// Each incoming channel is mapped to a key by `keymaker`; at most `channels_per_key` channels
/// sharing a key may be open at the same time. Channels over the limit are dropped without
/// being yielded.
#[pin_project]
#[derive(Debug)]
pub struct MaxChannelsPerKey<S, K, F>
where
    K: Eq + Hash,
{
    /// Source of new channels, fused so that polling it after it ends keeps returning `None`.
    #[pin]
    listener: Fuse<S>,
    /// Maximum number of simultaneously open channels allowed for a single key.
    channels_per_key: u32,
    /// Receives the key of every tracker that is dropped, i.e. every key whose last open
    /// channel (at the time) has closed.
    dropped_keys: mpsc::UnboundedReceiver<K>,
    /// Sender cloned into every tracker. Holding it here also keeps `dropped_keys` from ever
    /// reporting a closed channel.
    dropped_keys_tx: mpsc::UnboundedSender<K>,
    /// Per-key weak handle to the tracker shared by that key's open channels; the tracker's
    /// strong count is the number of open channels for the key.
    key_counts: FnvHashMap<K, Weak<Tracker<K>>>,
    /// Derives the limiting key from a newly accepted channel.
    keymaker: F,
}
40
/// A channel that is tracked by [`MaxChannelsPerKey`].
///
/// While it is alive, it counts toward the limit of the key it was admitted under.
#[pin_project]
#[derive(Debug)]
pub struct TrackedChannel<C, K> {
    /// The wrapped channel; `Stream`, `Sink`, and `Channel` calls are forwarded to it.
    #[pin]
    inner: C,
    /// Shared with every other open channel for the same key. Dropping the last clone reports
    /// the key back to the owning [`MaxChannelsPerKey`].
    tracker: Arc<Tracker<K>>,
}
49
/// Per-key state shared (via `Arc`) by all open channels with that key; the `Arc` strong count
/// is the number of open channels for the key.
#[derive(Debug)]
struct Tracker<K> {
    /// The tracked key. Constructed as `Some` everywhere in this file; `Drop` takes it out so
    /// it can be sent by value.
    key: Option<K>,
    /// Channel on which `key` is reported when this tracker is dropped.
    dropped_keys: mpsc::UnboundedSender<K>,
}
55
56impl<K> Drop for Tracker<K> {
57    fn drop(&mut self) {
58        // Don't care if the listener is dropped.
59        let _ = self.dropped_keys.send(self.key.take().unwrap());
60    }
61}
62
63impl<C, K> Stream for TrackedChannel<C, K>
64where
65    C: Stream,
66{
67    type Item = <C as Stream>::Item;
68
69    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
70        self.inner_pin_mut().poll_next(cx)
71    }
72}
73
74impl<C, I, K> Sink<I> for TrackedChannel<C, K>
75where
76    C: Sink<I>,
77{
78    type Error = C::Error;
79
80    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
81        self.inner_pin_mut().poll_ready(cx)
82    }
83
84    fn start_send(mut self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
85        self.inner_pin_mut().start_send(item)
86    }
87
88    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
89        self.inner_pin_mut().poll_flush(cx)
90    }
91
92    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
93        self.inner_pin_mut().poll_close(cx)
94    }
95}
96
97impl<C, K> AsRef<C> for TrackedChannel<C, K> {
98    fn as_ref(&self) -> &C {
99        &self.inner
100    }
101}
102
103impl<C, K> Channel for TrackedChannel<C, K>
104where
105    C: Channel,
106{
107    type Req = C::Req;
108    type Resp = C::Resp;
109    type Transport = C::Transport;
110
111    fn config(&self) -> &server::Config {
112        self.inner.config()
113    }
114
115    fn in_flight_requests(&self) -> usize {
116        self.inner.in_flight_requests()
117    }
118
119    fn transport(&self) -> &Self::Transport {
120        self.inner.transport()
121    }
122}
123
124impl<C, K> TrackedChannel<C, K> {
125    /// Returns the inner channel.
126    pub fn get_ref(&self) -> &C {
127        &self.inner
128    }
129
130    /// Returns the pinned inner channel.
131    fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut C> {
132        self.as_mut().project().inner
133    }
134}
135
136impl<S, K, F> MaxChannelsPerKey<S, K, F>
137where
138    K: Eq + Hash,
139    S: Stream,
140    F: Fn(&S::Item) -> K,
141{
142    /// Sheds new channels to stay under configured limits.
143    pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
144        let (dropped_keys_tx, dropped_keys) = mpsc::unbounded_channel();
145        MaxChannelsPerKey {
146            listener: listener.fuse(),
147            channels_per_key,
148            dropped_keys,
149            dropped_keys_tx,
150            key_counts: FnvHashMap::default(),
151            keymaker,
152        }
153    }
154}
155
156impl<S, K, F> MaxChannelsPerKey<S, K, F>
157where
158    S: Stream,
159    K: fmt::Display + Eq + Hash + Clone + Unpin,
160    F: Fn(&S::Item) -> K,
161{
162    fn listener_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<S>> {
163        self.as_mut().project().listener
164    }
165
166    fn handle_new_channel(
167        mut self: Pin<&mut Self>,
168        stream: S::Item,
169    ) -> Result<TrackedChannel<S::Item, K>, K> {
170        let key = (self.as_mut().keymaker)(&stream);
171        let tracker = self.as_mut().increment_channels_for_key(key.clone())?;
172
173        trace!(
174            channel_filter_key = %key,
175            open_channels = Arc::strong_count(&tracker),
176            max_open_channels = self.channels_per_key,
177            "Opening channel");
178
179        Ok(TrackedChannel {
180            tracker,
181            inner: stream,
182        })
183    }
184
185    fn increment_channels_for_key(self: Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
186        let self_ = self.project();
187        let dropped_keys = self_.dropped_keys_tx;
188        match self_.key_counts.entry(key.clone()) {
189            Entry::Vacant(vacant) => {
190                let tracker = Arc::new(Tracker {
191                    key: Some(key),
192                    dropped_keys: dropped_keys.clone(),
193                });
194
195                vacant.insert(Arc::downgrade(&tracker));
196                Ok(tracker)
197            }
198            Entry::Occupied(mut o) => {
199                let count = o.get().strong_count();
200                if count >= TryFrom::try_from(*self_.channels_per_key).unwrap() {
201                    info!(
202                        channel_filter_key = %key,
203                        open_channels = count,
204                        max_open_channels = *self_.channels_per_key,
205                        "At open channel limit");
206                    Err(key)
207                } else {
208                    Ok(o.get().upgrade().unwrap_or_else(|| {
209                        let tracker = Arc::new(Tracker {
210                            key: Some(key),
211                            dropped_keys: dropped_keys.clone(),
212                        });
213
214                        *o.get_mut() = Arc::downgrade(&tracker);
215                        tracker
216                    }))
217                }
218            }
219        }
220    }
221
222    fn poll_listener(
223        mut self: Pin<&mut Self>,
224        cx: &mut Context<'_>,
225    ) -> Poll<Option<Result<TrackedChannel<S::Item, K>, K>>> {
226        match ready!(self.listener_pin_mut().poll_next_unpin(cx)) {
227            Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))),
228            None => Poll::Ready(None),
229        }
230    }
231
232    fn poll_closed_channels(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
233        let self_ = self.project();
234        match ready!(self_.dropped_keys.poll_recv(cx)) {
235            Some(key) => {
236                debug!(
237                    channel_filter_key = %key,
238                    "All channels dropped");
239                self_.key_counts.remove(&key);
240                self_.key_counts.compact(0.1);
241                Poll::Ready(())
242            }
243            None => unreachable!("Holding a copy of closed_channels and didn't close it."),
244        }
245    }
246}
247
248impl<S, K, F> Stream for MaxChannelsPerKey<S, K, F>
249where
250    S: Stream,
251    K: fmt::Display + Eq + Hash + Clone + Unpin,
252    F: Fn(&S::Item) -> K,
253{
254    type Item = TrackedChannel<S::Item, K>;
255
256    fn poll_next(
257        mut self: Pin<&mut Self>,
258        cx: &mut Context<'_>,
259    ) -> Poll<Option<TrackedChannel<S::Item, K>>> {
260        loop {
261            match (
262                self.as_mut().poll_listener(cx),
263                self.as_mut().poll_closed_channels(cx),
264            ) {
265                (Poll::Ready(Some(Ok(channel))), _) => {
266                    return Poll::Ready(Some(channel));
267                }
268                (Poll::Ready(Some(Err(_))), _) => {
269                    continue;
270                }
271                (_, Poll::Ready(())) => continue,
272                (Poll::Pending, Poll::Pending) => return Poll::Pending,
273                (Poll::Ready(None), Poll::Pending) => {
274                    trace!("Shutting down listener.");
275                    return Poll::Ready(None);
276                }
277            }
278        }
279    }
280}
/// Builds a task context whose waker does nothing, for polling futures by hand in tests.
#[cfg(test)]
fn ctx() -> Context<'static> {
    let waker = futures::task::noop_waker_ref();
    Context::from_waker(waker)
}
287
#[test]
fn tracker_drop() {
    use assert_matches::assert_matches;

    // Dropping a tracker must report its key on the dropped-keys channel.
    let (dropped_keys, mut dropped_keys_rx) = mpsc::unbounded_channel();
    drop(Tracker {
        key: Some(1),
        dropped_keys,
    });
    assert_matches!(dropped_keys_rx.poll_recv(&mut ctx()), Poll::Ready(Some(1)));
}
299
#[test]
fn tracked_channel_stream() {
    use assert_matches::assert_matches;

    // A tracked channel yields exactly what its inner stream yields.
    let (sender, receiver) = futures::channel::mpsc::unbounded();
    let (dropped_keys, _) = mpsc::unbounded_channel();
    let tracker = Arc::new(Tracker {
        key: Some(1),
        dropped_keys,
    });
    let mut channel = Box::pin(TrackedChannel {
        inner: receiver,
        tracker,
    });

    sender.unbounded_send("test").unwrap();
    assert_matches!(
        channel.as_mut().poll_next(&mut ctx()),
        Poll::Ready(Some("test"))
    );
}
319
#[test]
fn tracked_channel_sink() {
    use assert_matches::assert_matches;

    // Items sent through a tracked channel reach the inner sink.
    let (sender, mut receiver) = futures::channel::mpsc::unbounded();
    let (dropped_keys, _) = mpsc::unbounded_channel();
    let mut channel = Box::pin(TrackedChannel {
        inner: sender,
        tracker: Arc::new(Tracker {
            key: Some(1),
            dropped_keys,
        }),
    });

    assert_matches!(channel.as_mut().poll_ready(&mut ctx()), Poll::Ready(Ok(())));
    assert_matches!(channel.as_mut().start_send("test"), Ok(()));
    assert_matches!(channel.as_mut().poll_flush(&mut ctx()), Poll::Ready(Ok(())));
    assert_matches!(receiver.try_next(), Ok(Some("test")));
}
341
#[test]
fn channel_filter_increment_channels_for_key() {
    use assert_matches::assert_matches;

    struct TestChannel {
        key: &'static str,
    }
    let (_, listener) = futures::channel::mpsc::unbounded();
    let mut filter = Box::pin(MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| {
        chan.key
    }));

    // Two channels fit under the limit and share one tracker.
    let first = filter.as_mut().increment_channels_for_key("key").unwrap();
    assert_eq!(Arc::strong_count(&first), 1);
    let second = filter.as_mut().increment_channels_for_key("key").unwrap();
    assert_eq!(Arc::strong_count(&first), 2);

    // A third is refused and the key is handed back.
    assert_matches!(filter.as_mut().increment_channels_for_key("key"), Err("key"));

    // Releasing one channel lowers the shared count.
    drop(second);
    assert_eq!(Arc::strong_count(&first), 1);
}
361
#[test]
fn channel_filter_handle_new_channel() {
    use assert_matches::assert_matches;

    #[derive(Debug)]
    struct TestChannel {
        key: &'static str,
    }
    let (_, listener) = futures::channel::mpsc::unbounded();
    let mut filter = Box::pin(MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| {
        chan.key
    }));
    let new_channel = || TestChannel { key: "key" };

    let first = filter.as_mut().handle_new_channel(new_channel()).unwrap();
    assert_eq!(Arc::strong_count(&first.tracker), 1);

    let second = filter.as_mut().handle_new_channel(new_channel()).unwrap();
    assert_eq!(Arc::strong_count(&first.tracker), 2);

    // Over the limit of 2: the channel is refused and its key returned.
    assert_matches!(filter.as_mut().handle_new_channel(new_channel()), Err("key"));

    drop(second);
    assert_eq!(Arc::strong_count(&first.tracker), 1);
}
393
#[test]
fn channel_filter_poll_listener() {
    use assert_matches::assert_matches;

    #[derive(Debug)]
    struct TestChannel {
        key: &'static str,
    }
    let (new_channels, listener) = futures::channel::mpsc::unbounded();
    let mut filter = Box::pin(MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| {
        chan.key
    }));
    let send_channel = || {
        new_channels
            .unbounded_send(TestChannel { key: "key" })
            .unwrap()
    };

    send_channel();
    let first = assert_matches!(
        filter.as_mut().poll_listener(&mut ctx()),
        Poll::Ready(Some(Ok(c))) => c
    );
    assert_eq!(Arc::strong_count(&first.tracker), 1);

    send_channel();
    let _second = assert_matches!(
        filter.as_mut().poll_listener(&mut ctx()),
        Poll::Ready(Some(Ok(c))) => c
    );
    assert_eq!(Arc::strong_count(&first.tracker), 2);

    // The third channel exceeds the limit and is rejected without touching the count.
    send_channel();
    let rejected_key = assert_matches!(
        filter.as_mut().poll_listener(&mut ctx()),
        Poll::Ready(Some(Err(k))) => k
    );
    assert_eq!(rejected_key, "key");
    assert_eq!(Arc::strong_count(&first.tracker), 2);
}
429
#[test]
fn channel_filter_poll_closed_channels() {
    use assert_matches::assert_matches;

    #[derive(Debug)]
    struct TestChannel {
        key: &'static str,
    }
    let (new_channels, listener) = futures::channel::mpsc::unbounded();
    let mut filter = Box::pin(MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| {
        chan.key
    }));

    new_channels
        .unbounded_send(TestChannel { key: "key" })
        .unwrap();
    let channel = assert_matches!(
        filter.as_mut().poll_listener(&mut ctx()),
        Poll::Ready(Some(Ok(c))) => c
    );
    assert_eq!(filter.key_counts.len(), 1);

    // Once the only channel for the key closes, its bookkeeping entry is removed.
    drop(channel);
    assert_matches!(filter.as_mut().poll_closed_channels(&mut ctx()), Poll::Ready(()));
    assert!(filter.key_counts.is_empty());
}
457
#[test]
fn channel_filter_stream() {
    use assert_matches::assert_matches;

    #[derive(Debug)]
    struct TestChannel {
        key: &'static str,
    }
    let (new_channels, listener) = futures::channel::mpsc::unbounded();
    let mut filter = Box::pin(MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| {
        chan.key
    }));

    new_channels
        .unbounded_send(TestChannel { key: "key" })
        .unwrap();
    let channel = assert_matches!(
        filter.as_mut().poll_next(&mut ctx()),
        Poll::Ready(Some(c)) => c
    );
    assert_eq!(filter.key_counts.len(), 1);

    // The next poll notices the dropped channel, cleans up, and has nothing to yield.
    drop(channel);
    assert_matches!(filter.as_mut().poll_next(&mut ctx()), Poll::Pending);
    assert!(filter.key_counts.is_empty());
}