1use crate::{
8 server::{self, Channel},
9 util::Compact,
10};
11use fnv::FnvHashMap;
12use futures::{prelude::*, ready, stream::Fuse, task::*};
13use pin_project::pin_project;
14use std::sync::{Arc, Weak};
15use std::{
16 collections::hash_map::Entry, convert::TryFrom, fmt, hash::Hash, marker::Unpin, pin::Pin,
17};
18use tokio::sync::mpsc;
19use tracing::{debug, info, trace};
20
21#[pin_project]
27#[derive(Debug)]
28pub struct MaxChannelsPerKey<S, K, F>
29where
30 K: Eq + Hash,
31{
32 #[pin]
33 listener: Fuse<S>,
34 channels_per_key: u32,
35 dropped_keys: mpsc::UnboundedReceiver<K>,
36 dropped_keys_tx: mpsc::UnboundedSender<K>,
37 key_counts: FnvHashMap<K, Weak<Tracker<K>>>,
38 keymaker: F,
39}
40
/// A channel admitted by `MaxChannelsPerKey`.
///
/// It forwards all stream, sink, and channel operations to `inner`. It also holds a strong
/// reference to the tracker shared by every open channel with the same key, so the per-key
/// count falls when this value is dropped.
#[pin_project]
#[derive(Debug)]
pub struct TrackedChannel<C, K> {
    /// The wrapped channel.
    #[pin]
    inner: C,
    /// Held for its drop side effect: the last strong reference reports the key as closed.
    tracker: Arc<Tracker<K>>,
}
49
/// Shared by all open channels with the same key. When the last strong reference is dropped,
/// its key is sent on `dropped_keys`.
#[derive(Debug)]
struct Tracker<K> {
    /// Always `Some` after construction; `Drop` takes it out to send it.
    key: Option<K>,
    /// Channel on which the key is reported once every channel for it has been dropped.
    dropped_keys: mpsc::UnboundedSender<K>,
}
55
56impl<K> Drop for Tracker<K> {
57 fn drop(&mut self) {
58 let _ = self.dropped_keys.send(self.key.take().unwrap());
60 }
61}
62
63impl<C, K> Stream for TrackedChannel<C, K>
64where
65 C: Stream,
66{
67 type Item = <C as Stream>::Item;
68
69 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
70 self.inner_pin_mut().poll_next(cx)
71 }
72}
73
74impl<C, I, K> Sink<I> for TrackedChannel<C, K>
75where
76 C: Sink<I>,
77{
78 type Error = C::Error;
79
80 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
81 self.inner_pin_mut().poll_ready(cx)
82 }
83
84 fn start_send(mut self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
85 self.inner_pin_mut().start_send(item)
86 }
87
88 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
89 self.inner_pin_mut().poll_flush(cx)
90 }
91
92 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
93 self.inner_pin_mut().poll_close(cx)
94 }
95}
96
97impl<C, K> AsRef<C> for TrackedChannel<C, K> {
98 fn as_ref(&self) -> &C {
99 &self.inner
100 }
101}
102
103impl<C, K> Channel for TrackedChannel<C, K>
104where
105 C: Channel,
106{
107 type Req = C::Req;
108 type Resp = C::Resp;
109 type Transport = C::Transport;
110
111 fn config(&self) -> &server::Config {
112 self.inner.config()
113 }
114
115 fn in_flight_requests(&self) -> usize {
116 self.inner.in_flight_requests()
117 }
118
119 fn transport(&self) -> &Self::Transport {
120 self.inner.transport()
121 }
122}
123
124impl<C, K> TrackedChannel<C, K> {
125 pub fn get_ref(&self) -> &C {
127 &self.inner
128 }
129
130 fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut C> {
132 self.as_mut().project().inner
133 }
134}
135
136impl<S, K, F> MaxChannelsPerKey<S, K, F>
137where
138 K: Eq + Hash,
139 S: Stream,
140 F: Fn(&S::Item) -> K,
141{
142 pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
144 let (dropped_keys_tx, dropped_keys) = mpsc::unbounded_channel();
145 MaxChannelsPerKey {
146 listener: listener.fuse(),
147 channels_per_key,
148 dropped_keys,
149 dropped_keys_tx,
150 key_counts: FnvHashMap::default(),
151 keymaker,
152 }
153 }
154}
155
156impl<S, K, F> MaxChannelsPerKey<S, K, F>
157where
158 S: Stream,
159 K: fmt::Display + Eq + Hash + Clone + Unpin,
160 F: Fn(&S::Item) -> K,
161{
162 fn listener_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<S>> {
163 self.as_mut().project().listener
164 }
165
166 fn handle_new_channel(
167 mut self: Pin<&mut Self>,
168 stream: S::Item,
169 ) -> Result<TrackedChannel<S::Item, K>, K> {
170 let key = (self.as_mut().keymaker)(&stream);
171 let tracker = self.as_mut().increment_channels_for_key(key.clone())?;
172
173 trace!(
174 channel_filter_key = %key,
175 open_channels = Arc::strong_count(&tracker),
176 max_open_channels = self.channels_per_key,
177 "Opening channel");
178
179 Ok(TrackedChannel {
180 tracker,
181 inner: stream,
182 })
183 }
184
185 fn increment_channels_for_key(self: Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
186 let self_ = self.project();
187 let dropped_keys = self_.dropped_keys_tx;
188 match self_.key_counts.entry(key.clone()) {
189 Entry::Vacant(vacant) => {
190 let tracker = Arc::new(Tracker {
191 key: Some(key),
192 dropped_keys: dropped_keys.clone(),
193 });
194
195 vacant.insert(Arc::downgrade(&tracker));
196 Ok(tracker)
197 }
198 Entry::Occupied(mut o) => {
199 let count = o.get().strong_count();
200 if count >= TryFrom::try_from(*self_.channels_per_key).unwrap() {
201 info!(
202 channel_filter_key = %key,
203 open_channels = count,
204 max_open_channels = *self_.channels_per_key,
205 "At open channel limit");
206 Err(key)
207 } else {
208 Ok(o.get().upgrade().unwrap_or_else(|| {
209 let tracker = Arc::new(Tracker {
210 key: Some(key),
211 dropped_keys: dropped_keys.clone(),
212 });
213
214 *o.get_mut() = Arc::downgrade(&tracker);
215 tracker
216 }))
217 }
218 }
219 }
220 }
221
222 fn poll_listener(
223 mut self: Pin<&mut Self>,
224 cx: &mut Context<'_>,
225 ) -> Poll<Option<Result<TrackedChannel<S::Item, K>, K>>> {
226 match ready!(self.listener_pin_mut().poll_next_unpin(cx)) {
227 Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))),
228 None => Poll::Ready(None),
229 }
230 }
231
232 fn poll_closed_channels(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
233 let self_ = self.project();
234 match ready!(self_.dropped_keys.poll_recv(cx)) {
235 Some(key) => {
236 debug!(
237 channel_filter_key = %key,
238 "All channels dropped");
239 self_.key_counts.remove(&key);
240 self_.key_counts.compact(0.1);
241 Poll::Ready(())
242 }
243 None => unreachable!("Holding a copy of closed_channels and didn't close it."),
244 }
245 }
246}
247
248impl<S, K, F> Stream for MaxChannelsPerKey<S, K, F>
249where
250 S: Stream,
251 K: fmt::Display + Eq + Hash + Clone + Unpin,
252 F: Fn(&S::Item) -> K,
253{
254 type Item = TrackedChannel<S::Item, K>;
255
256 fn poll_next(
257 mut self: Pin<&mut Self>,
258 cx: &mut Context<'_>,
259 ) -> Poll<Option<TrackedChannel<S::Item, K>>> {
260 loop {
261 match (
262 self.as_mut().poll_listener(cx),
263 self.as_mut().poll_closed_channels(cx),
264 ) {
265 (Poll::Ready(Some(Ok(channel))), _) => {
266 return Poll::Ready(Some(channel));
267 }
268 (Poll::Ready(Some(Err(_))), _) => {
269 continue;
270 }
271 (_, Poll::Ready(())) => continue,
272 (Poll::Pending, Poll::Pending) => return Poll::Pending,
273 (Poll::Ready(None), Poll::Pending) => {
274 trace!("Shutting down listener.");
275 return Poll::Ready(None);
276 }
277 }
278 }
279 }
280}
/// Builds a task context with a no-op waker, for polling streams and sinks by hand in tests.
#[cfg(test)]
fn ctx() -> Context<'static> {
    let waker = noop_waker_ref();
    Context::from_waker(waker)
}
287
/// Dropping a `Tracker` must send its key on the dropped-keys channel.
#[test]
fn tracker_drop() {
    use assert_matches::assert_matches;

    let (dropped_keys, mut dropped_keys_rx) = mpsc::unbounded_channel();
    drop(Tracker {
        key: Some(1),
        dropped_keys,
    });
    let received = dropped_keys_rx.poll_recv(&mut ctx());
    assert_matches!(received, Poll::Ready(Some(1)));
}
299
/// Items produced by the inner stream must come out of the `TrackedChannel` unchanged.
#[test]
fn tracked_channel_stream() {
    use assert_matches::assert_matches;
    use pin_utils::pin_mut;

    let (chan_tx, chan) = futures::channel::mpsc::unbounded();
    // The receiver is discarded; this test does not inspect dropped-key notifications.
    let (dropped_keys, _) = mpsc::unbounded_channel();
    let tracker = Arc::new(Tracker {
        key: Some(1),
        dropped_keys,
    });
    let channel = TrackedChannel {
        inner: chan,
        tracker,
    };

    chan_tx.unbounded_send("test").unwrap();
    pin_mut!(channel);
    let polled = channel.poll_next(&mut ctx());
    assert_matches!(polled, Poll::Ready(Some("test")));
}
319
/// Items written through the `TrackedChannel`'s `Sink` impl must reach the inner sink.
#[test]
fn tracked_channel_sink() {
    use assert_matches::assert_matches;
    use pin_utils::pin_mut;

    let (chan, mut chan_rx) = futures::channel::mpsc::unbounded();
    // The receiver is discarded; this test does not inspect dropped-key notifications.
    let (dropped_keys, _) = mpsc::unbounded_channel();
    let channel = TrackedChannel {
        inner: chan,
        tracker: Arc::new(Tracker {
            key: Some(1),
            dropped_keys,
        }),
    };

    pin_mut!(channel);
    // Drive the full ready -> send -> flush sequence, then check the item arrived.
    assert_matches!(channel.as_mut().poll_ready(&mut ctx()), Poll::Ready(Ok(())));
    assert_matches!(channel.as_mut().start_send("test"), Ok(()));
    assert_matches!(channel.as_mut().poll_flush(&mut ctx()), Poll::Ready(Ok(())));
    assert_matches!(chan_rx.try_next(), Ok(Some("test")));
}
341
/// With a limit of 2, two increments for the same key share one tracker and a third is rejected.
#[test]
fn channel_filter_increment_channels_for_key() {
    use assert_matches::assert_matches;
    use pin_utils::pin_mut;

    struct TestChannel {
        key: &'static str,
    }
    // The sender is dropped immediately; the listener is never polled in this test.
    let (_, listener) = futures::channel::mpsc::unbounded();
    let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
    pin_mut!(filter);
    let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap();
    assert_eq!(Arc::strong_count(&tracker1), 1);
    // The second increment must hand out the same tracker, bumping its count.
    let tracker2 = filter.as_mut().increment_channels_for_key("key").unwrap();
    assert_eq!(Arc::strong_count(&tracker1), 2);
    // At the limit, the key is handed back as the error.
    assert_matches!(filter.increment_channels_for_key("key"), Err("key"));
    drop(tracker2);
    assert_eq!(Arc::strong_count(&tracker1), 1);
}
361
/// `handle_new_channel` wraps admitted channels and rejects the one that would exceed the limit.
#[test]
fn channel_filter_handle_new_channel() {
    use assert_matches::assert_matches;
    use pin_utils::pin_mut;

    #[derive(Debug)]
    struct TestChannel {
        key: &'static str,
    }
    let (_, listener) = futures::channel::mpsc::unbounded();
    let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
    pin_mut!(filter);
    let channel1 = filter
        .as_mut()
        .handle_new_channel(TestChannel { key: "key" })
        .unwrap();
    assert_eq!(Arc::strong_count(&channel1.tracker), 1);

    // Both channels share one tracker, so its strong count is the open-channel count.
    let channel2 = filter
        .as_mut()
        .handle_new_channel(TestChannel { key: "key" })
        .unwrap();
    assert_eq!(Arc::strong_count(&channel1.tracker), 2);

    assert_matches!(
        filter.handle_new_channel(TestChannel { key: "key" }),
        Err("key")
    );
    // Dropping an admitted channel releases its share of the tracker.
    drop(channel2);
    assert_eq!(Arc::strong_count(&channel1.tracker), 1);
}
393
/// Channels read from the listener are admitted up to the limit; the next one yields `Err(key)`.
#[test]
fn channel_filter_poll_listener() {
    use assert_matches::assert_matches;
    use pin_utils::pin_mut;

    #[derive(Debug)]
    struct TestChannel {
        key: &'static str,
    }
    let (new_channels, listener) = futures::channel::mpsc::unbounded();
    let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
    pin_mut!(filter);

    new_channels
        .unbounded_send(TestChannel { key: "key" })
        .unwrap();
    let channel1 =
        assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
    assert_eq!(Arc::strong_count(&channel1.tracker), 1);

    new_channels
        .unbounded_send(TestChannel { key: "key" })
        .unwrap();
    // Kept alive (not `_`) so it continues to count toward the limit.
    let _channel2 =
        assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
    assert_eq!(Arc::strong_count(&channel1.tracker), 2);

    new_channels
        .unbounded_send(TestChannel { key: "key" })
        .unwrap();
    let key =
        assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Err(k))) => k);
    assert_eq!(key, "key");
    // The rejected channel must not have touched the shared tracker.
    assert_eq!(Arc::strong_count(&channel1.tracker), 2);
}
429
/// Dropping the only channel for a key must, once processed, remove the key from `key_counts`.
#[test]
fn channel_filter_poll_closed_channels() {
    use assert_matches::assert_matches;
    use pin_utils::pin_mut;

    #[derive(Debug)]
    struct TestChannel {
        key: &'static str,
    }
    let (new_channels, listener) = futures::channel::mpsc::unbounded();
    let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
    pin_mut!(filter);

    new_channels
        .unbounded_send(TestChannel { key: "key" })
        .unwrap();
    let channel =
        assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
    assert_eq!(filter.key_counts.len(), 1);

    // Dropping the last channel drops the tracker, which queues the key as closed.
    drop(channel);
    assert_matches!(
        filter.as_mut().poll_closed_channels(&mut ctx()),
        Poll::Ready(())
    );
    assert!(filter.key_counts.is_empty());
}
457
/// End-to-end check through `Stream::poll_next`: a channel is yielded and, once dropped, its key
/// is cleaned up on the next poll.
#[test]
fn channel_filter_stream() {
    use assert_matches::assert_matches;
    use pin_utils::pin_mut;

    #[derive(Debug)]
    struct TestChannel {
        key: &'static str,
    }
    let (new_channels, listener) = futures::channel::mpsc::unbounded();
    let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
    pin_mut!(filter);

    new_channels
        .unbounded_send(TestChannel { key: "key" })
        .unwrap();
    let channel = assert_matches!(filter.as_mut().poll_next(&mut ctx()), Poll::Ready(Some(c)) => c);
    assert_eq!(filter.key_counts.len(), 1);

    drop(channel);
    // `new_channels` is still alive, so the listener is pending. The poll only processes the
    // dropped-key notification and then returns Pending.
    assert_matches!(filter.as_mut().poll_next(&mut ctx()), Poll::Pending);
    assert!(filter.key_counts.is_empty());
}