1use std::mem;
2use std::pin::Pin;
3use std::task::{Context, Poll, Waker};
4use std::time::Duration;
5
6use futures_util::stream::{BoxStream, SelectAll};
7use futures_util::{stream, FutureExt, Stream, StreamExt};
8
9use crate::{Delay, PushError, Timeout};
10
/// A bounded collection of [`Stream`]s, each registered under an `ID` and guarded by a timeout.
///
/// At most `capacity` entries are held at once, and every pushed stream is wrapped together with
/// a fresh [`Delay`] produced by `make_delay`.
pub struct StreamMap<ID, O> {
    /// Factory invoked once per pushed stream to create that stream's timeout.
    make_delay: Box<dyn Fn() -> Delay + Send + Sync>,
    /// Maximum number of entries `inner` may hold before pushes are rejected.
    capacity: usize,
    /// All registered streams, each tagged with its `ID` and wrapped with its timeout.
    inner: SelectAll<TaggedStream<ID, TimeoutStream<BoxStream<'static, O>>>>,
    /// Waker stored when the map was polled while empty.
    empty_waker: Option<Waker>,
    /// Waker stored when readiness was polled while the map was at capacity.
    full_waker: Option<Waker>,
}
21
22impl<ID, O> StreamMap<ID, O>
23where
24 ID: Clone + Unpin,
25{
26 pub fn new(make_delay: impl Fn() -> Delay + Send + Sync + 'static, capacity: usize) -> Self {
27 Self {
28 make_delay: Box::new(make_delay),
29 capacity,
30 inner: Default::default(),
31 empty_waker: None,
32 full_waker: None,
33 }
34 }
35}
36
37impl<ID, O> StreamMap<ID, O>
38where
39 ID: Clone + PartialEq + Send + Unpin + 'static,
40 O: Send + 'static,
41{
42 pub fn try_push<F>(&mut self, id: ID, stream: F) -> Result<(), PushError<BoxStream<O>>>
44 where
45 F: Stream<Item = O> + Send + 'static,
46 {
47 if self.inner.len() >= self.capacity {
48 return Err(PushError::BeyondCapacity(stream.boxed()));
49 }
50
51 if let Some(waker) = self.empty_waker.take() {
52 waker.wake();
53 }
54
55 let old = self.remove(id.clone());
56 self.inner.push(TaggedStream::new(
57 id,
58 TimeoutStream {
59 inner: stream.boxed(),
60 timeout: (self.make_delay)(),
61 },
62 ));
63
64 match old {
65 None => Ok(()),
66 Some(old) => Err(PushError::Replaced(old)),
67 }
68 }
69
70 pub fn remove(&mut self, id: ID) -> Option<BoxStream<'static, O>> {
71 let tagged = self.inner.iter_mut().find(|s| s.key == id)?;
72
73 let inner = mem::replace(&mut tagged.inner.inner, stream::pending().boxed());
74 tagged.exhausted = true; Some(inner)
77 }
78
79 pub fn len(&self) -> usize {
80 self.inner.len()
81 }
82
83 pub fn is_empty(&self) -> bool {
84 self.inner.is_empty()
85 }
86
87 #[allow(unknown_lints, clippy::needless_pass_by_ref_mut)] pub fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<()> {
89 if self.inner.len() < self.capacity {
90 return Poll::Ready(());
91 }
92
93 self.full_waker = Some(cx.waker().clone());
94
95 Poll::Pending
96 }
97
98 pub fn poll_next_unpin(
99 &mut self,
100 cx: &mut Context<'_>,
101 ) -> Poll<(ID, Option<Result<O, Timeout>>)> {
102 match futures_util::ready!(self.inner.poll_next_unpin(cx)) {
103 None => {
104 self.empty_waker = Some(cx.waker().clone());
105 Poll::Pending
106 }
107 Some((id, Some(Ok(output)))) => Poll::Ready((id, Some(Ok(output)))),
108 Some((id, Some(Err(dur)))) => {
109 self.remove(id.clone()); Poll::Ready((id, Some(Err(Timeout::new(dur)))))
112 }
113 Some((id, None)) => Poll::Ready((id, None)),
114 }
115 }
116}
117
/// Pairs a stream with the [`Delay`] bounding how long it may run.
struct TimeoutStream<S> {
    /// The wrapped stream whose items are forwarded as `Ok`.
    inner: S,
    /// Checked before every poll of `inner`; once it completes, its output is yielded as `Err`.
    timeout: Delay,
}
122
123impl<F> Stream for TimeoutStream<F>
124where
125 F: Stream + Unpin,
126{
127 type Item = Result<F::Item, Duration>;
128
129 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
130 if let Poll::Ready(dur) = self.timeout.poll_unpin(cx) {
131 return Poll::Ready(Some(Err(dur)));
132 }
133
134 self.inner.poll_next_unpin(cx).map(|a| a.map(Ok))
135 }
136}
137
/// A stream whose items are tagged with a `key`, plus an explicit termination marker.
struct TaggedStream<K, S> {
    /// Tag cloned into every yielded item.
    key: K,
    /// The wrapped stream.
    inner: S,

    /// Set once `inner` has finished, or when `StreamMap::remove` detaches it.
    /// While set, polling yields `None` without touching `inner`.
    exhausted: bool,
}
144
145impl<K, S> TaggedStream<K, S> {
146 fn new(key: K, inner: S) -> Self {
147 Self {
148 key,
149 inner,
150 exhausted: false,
151 }
152 }
153}
154
155impl<K, S> Stream for TaggedStream<K, S>
156where
157 K: Clone + Unpin,
158 S: Stream + Unpin,
159{
160 type Item = (K, Option<S::Item>);
161
162 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
163 if self.exhausted {
164 return Poll::Ready(None);
165 }
166
167 match futures_util::ready!(self.inner.poll_next_unpin(cx)) {
168 Some(item) => Poll::Ready(Some((self.key.clone(), Some(item)))),
169 None => {
170 self.exhausted = true;
171
172 Poll::Ready(Some((self.key.clone(), None)))
173 }
174 }
175 }
176}
177
#[cfg(all(test, feature = "futures-timer"))]
mod tests {
    use futures::channel::mpsc;
    use futures_util::stream::{once, pending};
    use futures_util::SinkExt;
    use std::future::{poll_fn, ready, Future};
    use std::pin::Pin;
    use std::time::Instant;

    use super::*;

    #[test]
    fn cannot_push_more_than_capacity_tasks() {
        let mut streams = StreamMap::new(|| Delay::futures_timer(Duration::from_secs(10)), 1);

        assert!(streams.try_push("ID_1", once(ready(()))).is_ok());
        // `matches!` only produces a bool; it has to be asserted to check anything.
        assert!(matches!(
            streams.try_push("ID_2", once(ready(()))),
            Err(PushError::BeyondCapacity(_))
        ));
    }

    #[test]
    fn cannot_push_the_same_id_few_times() {
        let mut streams = StreamMap::new(|| Delay::futures_timer(Duration::from_secs(10)), 5);

        assert!(streams.try_push("ID", once(ready(()))).is_ok());
        // `matches!` only produces a bool; it has to be asserted to check anything.
        assert!(matches!(
            streams.try_push("ID", once(ready(()))),
            Err(PushError::Replaced(_))
        ));
    }

    #[tokio::test]
    async fn streams_timeout() {
        let mut streams = StreamMap::new(|| Delay::futures_timer(Duration::from_millis(100)), 1);

        let _ = streams.try_push("ID", pending::<()>());
        futures_timer::Delay::new(Duration::from_millis(150)).await;
        let (_, result) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;

        assert!(result.unwrap().is_err())
    }

    #[tokio::test]
    async fn timed_out_stream_gets_removed() {
        let mut streams = StreamMap::new(|| Delay::futures_timer(Duration::from_millis(100)), 1);

        let _ = streams.try_push("ID", pending::<()>());
        futures_timer::Delay::new(Duration::from_millis(150)).await;
        poll_fn(|cx| streams.poll_next_unpin(cx)).await;

        let poll = streams.poll_next_unpin(&mut Context::from_waker(
            futures_util::task::noop_waker_ref(),
        ));
        assert!(poll.is_pending())
    }

    #[test]
    fn removing_stream() {
        let mut streams = StreamMap::new(|| Delay::futures_timer(Duration::from_millis(100)), 1);

        let _ = streams.try_push("ID", stream::once(ready(())));

        {
            let cancelled_stream = streams.remove("ID");
            assert!(cancelled_stream.is_some());
        }

        let poll = streams.poll_next_unpin(&mut Context::from_waker(
            futures_util::task::noop_waker_ref(),
        ));

        assert!(poll.is_pending());
        assert_eq!(
            streams.len(),
            0,
            "resources of cancelled streams are cleaned up properly"
        );
    }

    #[tokio::test]
    async fn replaced_stream_is_still_registered() {
        let mut streams = StreamMap::new(|| Delay::futures_timer(Duration::from_millis(100)), 3);

        let (mut tx1, rx1) = mpsc::channel(5);
        let (mut tx2, rx2) = mpsc::channel(5);

        let _ = streams.try_push("ID1", rx1);
        let _ = streams.try_push("ID2", rx2);

        let _ = tx2.send(2).await;
        let _ = tx1.send(1).await;
        let _ = tx2.send(3).await;
        let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;
        assert_eq!(id, "ID1");
        assert_eq!(res.unwrap().unwrap(), 1);
        let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;
        assert_eq!(id, "ID2");
        assert_eq!(res.unwrap().unwrap(), 2);
        let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;
        assert_eq!(id, "ID2");
        assert_eq!(res.unwrap().unwrap(), 3);

        let (mut new_tx1, new_rx1) = mpsc::channel(5);
        let replaced = streams.try_push("ID1", new_rx1);
        assert!(matches!(replaced.unwrap_err(), PushError::Replaced(_)));

        let _ = new_tx1.send(4).await;
        let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;

        assert_eq!(id, "ID1");
        assert_eq!(res.unwrap().unwrap(), 4);
    }

    #[tokio::test]
    async fn backpressure() {
        const DELAY: Duration = Duration::from_millis(100);
        const NUM_STREAMS: u32 = 10;

        let start = Instant::now();
        Task::new(DELAY, NUM_STREAMS, 1).await;
        let duration = start.elapsed();

        assert!(duration >= DELAY * NUM_STREAMS);
    }

    /// Test driver that keeps pushing single-item streams as soon as capacity allows, until
    /// `num_streams` items have been processed.
    struct Task {
        item_delay: Duration,
        num_streams: usize,
        num_processed: usize,
        inner: StreamMap<u8, ()>,
    }

    impl Task {
        fn new(item_delay: Duration, num_streams: u32, capacity: usize) -> Self {
            Self {
                item_delay,
                num_streams: num_streams as usize,
                num_processed: 0,
                inner: StreamMap::new(|| Delay::futures_timer(Duration::from_secs(60)), capacity),
            }
        }
    }

    impl Future for Task {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let this = self.get_mut();

            while this.num_processed < this.num_streams {
                match this.inner.poll_next_unpin(cx) {
                    Poll::Ready((_, Some(result))) => {
                        if result.is_err() {
                            panic!("Stream timed out: item delay exceeds the timeout")
                        }

                        this.num_processed += 1;
                        continue;
                    }
                    Poll::Ready((_, None)) => {
                        continue;
                    }
                    _ => {}
                }

                if let Poll::Ready(()) = this.inner.poll_ready_unpin(cx) {
                    let maybe_future = this
                        .inner
                        .try_push(1u8, once(futures_timer::Delay::new(this.item_delay)));
                    assert!(maybe_future.is_ok(), "we polled for readiness");

                    continue;
                }

                return Poll::Pending;
            }

            Poll::Ready(())
        }
    }
}