use futures::stream::{Fuse, FusedStream, Stream, StreamExt};
use futures::task::{Context, Poll};
use pin_project::pin_project;
use std::pin::Pin;
/// Which of the two underlying streams of a `SelectBiased` an item came from.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(super) enum Source {
    /// The item came from the preferred (`left`) stream.
    Left,
    /// The item came from the fallback (`right`) stream.
    Right,
}
/// A stream that merges two streams, always preferring `left` over `right`.
///
/// Created by `select_biased`. Both inner streams are fused, so either one may be
/// polled again after it has finished without violating the `Stream` contract.
#[pin_project]
#[must_use = "streams do nothing unless polled"]
pub(super) struct SelectBiased<S, T> {
    /// The preferred stream; polled first on every call to `poll_next`.
    #[pin]
    left: Fuse<S>,
    /// The fallback stream; polled only when `left` is pending.
    #[pin]
    right: Fuse<T>,
}
/// Combine `left` and `right` into a single stream that always prefers `left`.
///
/// Every item is tagged with the [`Source`] it came from. Whenever `left` has an item ready
/// it is yielded, even if `right` also has one. `right` is consulted only while `left` is
/// pending. The combined stream ends as soon as `left` ends, regardless of `right`.
pub(super) fn select_biased<S, T>(left: S, right: T) -> SelectBiased<S, T>
where
    S: Stream,
    T: Stream<Item = S::Item>,
{
    let (left, right) = (left.fuse(), right.fuse());
    SelectBiased { left, right }
}
impl<S, T> FusedStream for SelectBiased<S, T>
where
S: Stream,
T: Stream<Item = S::Item>,
{
fn is_terminated(&self) -> bool {
self.left.is_terminated()
}
}
/// Polls `left` first and falls back to `right` only when `left` is pending.
///
/// Each item is tagged with the [`Source`] it came from. The combined stream ends as soon
/// as `left` ends. Any items `right` could still have produced at that point are never
/// yielded.
impl<S, T> Stream for SelectBiased<S, T>
where
    S: Stream,
    T: Stream<Item = S::Item>,
{
    type Item = (Source, S::Item);

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();

        // `left` has priority: whatever it reports (an item or end-of-stream) is returned
        // immediately, without touching `right`.
        if let Poll::Ready(item) = this.left.poll_next(cx) {
            return Poll::Ready(item.map(|val| (Source::Left, val)));
        }

        match this.right.poll_next(cx) {
            Poll::Ready(Some(val)) => Poll::Ready(Some((Source::Right, val))),
            // An exhausted `right` does not end the combined stream. Returning `Pending`
            // is sound because `left` just returned `Pending`, so per the `Stream` contract
            // it has arranged for `cx`'s waker to be notified. `right` is fused, so polling
            // it again later is fine.
            Poll::Ready(None) | Poll::Pending => Poll::Pending,
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.left.is_terminated() {
            // `poll_next` will only ever return `Ready(None)` from now on.
            return (0, Some(0));
        }
        // Every remaining `left` item is yielded before the stream can end, so `left`'s
        // lower bound carries over unchanged. `right` may contribute anywhere from none to
        // all of its remaining items, so it only widens the upper bound.
        let (left_lower, left_upper) = self.left.size_hint();
        let (_, right_upper) = self.right.size_hint();
        let upper = match (left_upper, right_upper) {
            (Some(l), Some(r)) => l.checked_add(r),
            _ => None,
        };
        (left_lower, upper)
    }
}
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use futures_await_test::async_test;

    /// The combined stream ends as soon as `left` ends, whatever `right` still holds.
    #[async_test]
    async fn left_only() {
        use futures::stream::iter;
        use Source::Left as L;

        // (left input, right input, expected combined output)
        let cases: Vec<(Vec<usize>, Vec<usize>, Vec<(Source, usize)>)> = vec![
            (vec![1, 2, 3], vec![], vec![(L, 1), (L, 2), (L, 3)]),
            (vec![1, 2, 3], vec![4, 5, 6], vec![(L, 1), (L, 2), (L, 3)]),
            (vec![], vec![4, 5, 6], vec![]),
        ];
        for (left, right, expected) in cases {
            let collected: Vec<_> = select_biased(iter(left), iter(right)).collect().await;
            assert_eq!(collected, expected);
        }
    }

    /// While `left` stays pending, items are taken from `right`.
    #[async_test]
    async fn right_only() {
        use futures::stream::{iter, pending};
        use Source::Right as R;

        let mut s = select_biased(pending(), iter(vec![4_usize, 5, 6]));
        for expected in 4_usize..=6 {
            assert_eq!(s.next().await, Some((R, expected)));
        }
    }

    /// Interleaved sends: queued `left` items always win, `right` is drained while `left`
    /// is empty, and the stream ends only once `left`'s sender is gone.
    #[async_test]
    async fn multiplex() {
        use futures::channel::mpsc::channel;
        use futures::SinkExt;
        use Source::{Left as L, Right as R};

        let (mut left_tx, left_rx) = channel(5);
        let (mut right_tx, right_rx) = channel(5);
        let mut s = select_biased(left_rx, right_rx);

        left_tx.send(1_usize).await.unwrap();
        right_tx.send(4_usize).await.unwrap();
        left_tx.send(2_usize).await.unwrap();
        // Both left items come out before the right item sent between them.
        assert_eq!(s.next().await, Some((L, 1)));
        assert_eq!(s.next().await, Some((L, 2)));
        assert_eq!(s.next().await, Some((R, 4)));

        right_tx.send(5_usize).await.unwrap();
        left_tx.send(3_usize).await.unwrap();
        assert!(!s.is_terminated());

        // Closing `right` does not end the stream, and its buffered item is still delivered.
        drop(right_tx);
        assert_eq!(s.next().await, Some((L, 3)));
        assert_eq!(s.next().await, Some((R, 5)));

        // Closing `left` ends the combined stream.
        drop(left_tx);
        assert_eq!(s.next().await, None);
        assert!(s.is_terminated());
    }
}