use futures::future::Future;
use futures::sink::{Sink, SinkExt};
use futures::stream::Stream;
use std::collections::HashMap;
use std::hash::Hash;
use std::pin::Pin;
use std::task::{Context, Poll};
mod rand;
mod tagger;
/// Where a single source currently is in its pull → tag → send → flush cycle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamState {
    /// Waiting for the next item from the source stream.
    StreamActive,
    /// An item was handed to the tagger; waiting for its `(item, tag)` result.
    TaggerActive,
    /// An item and its tag are parked, waiting for the tagged sink to become ready.
    SinkPending,
    /// The sink reported ready; the parked item is about to be sent.
    SinkActive,
    /// The item was sent; waiting for the sink to be flushed.
    SinkFlushing,
}
/// Per-source bookkeeping: the source stream, its tagger, and the item (if any)
/// currently being routed to a sink.
struct StreamManager<F, A, T> {
    /// Which step of the routing cycle this source is in.
    state: StreamState,
    /// The source that items are pulled from.
    stream: Box<dyn Stream<Item = A> + Unpin>,
    /// Turns each pulled item into an `(item, tag)` pair.
    tagger: tagger::StreamTagger<F, A>,
    /// Item waiting to be handed to a sink.
    pending_item: Option<A>,
    /// Tag identifying the sink that `pending_item` is destined for.
    pending_sink_tag: Option<T>,
}
impl<F, A, T> StreamManager<F, A, T> {
    /// Wraps a source stream and its tagger. The new manager starts out waiting on
    /// the stream, with no item or tag parked.
    fn new(
        tagger: tagger::StreamTagger<F, A>,
        stream: Box<dyn Stream<Item = A> + Unpin>,
    ) -> Self {
        Self {
            state: StreamState::StreamActive,
            stream,
            tagger,
            pending_item: None,
            pending_sink_tag: None,
        }
    }
}
/// Fans items from several source streams out to sinks chosen by tag.
///
/// Every item pulled from a source is passed through that source's transform, which
/// resolves to the item plus a tag. Items whose tag has a registered sink are sent to
/// that sink; all other items are yielded from the router's own `Stream` implementation.
#[must_use = "streams do nothing unless you `.await` or poll them"]
pub struct StreamRouter<F, T, A>
where
    T: Hash + Eq,
{
    /// Registered sources, each with its own routing state.
    streams: Vec<StreamManager<F, A, T>>,
    /// Sinks keyed by tag. The `usize` tracks how many sources are currently in the
    /// middle of sending to that sink; the routing logic uses it to decide which source
    /// performs the flush.
    sinks: HashMap<T, (usize, Box<dyn Sink<A, Error = ()> + Unpin>)>,
}
impl<F, T, A> StreamRouter<F, T, A>
where
T: Hash + Eq,
{
pub fn new() -> StreamRouter<F, T, A> {
StreamRouter {
streams: vec![],
sinks: HashMap::new(),
}
}
pub fn add_source<S, M>(&mut self, stream: S, transform: M)
where
S: Stream<Item = A> + Unpin + 'static,
M: Fn(A) -> F + 'static,
F: Future<Output = (A, T)>,
{
let tagger = tagger::StreamTagger::new(Box::new(transform));
self.streams
.push(StreamManager::new(tagger, Box::new(stream)));
}
pub fn add_sink<S>(&mut self, sink: S, tag: T)
where
S: Sink<A> + Unpin + Sized + 'static,
{
self.sinks
.insert(tag, (0, Box::new(sink.sink_map_err(|_| ()))));
}
}
impl<F, T, A> StreamRouter<F, T, A>
where
F: Future<Output = (A, T)> + Unpin,
T: Hash + Eq + Unpin,
A: Unpin,
{
fn poll_next_entry(&mut self, cx: &mut Context<'_>) -> Poll<Option<A>> {
use Poll::*;
let start = rand::thread_rng_n(self.streams.len() as u32) as usize;
let mut idx = start;
'outterLoop: for _ in 0..self.streams.len() {
'innerLoop: loop {
match self.streams[idx].state {
StreamState::StreamActive => {
match Pin::new(&mut self.streams[idx].stream).poll_next(cx) {
Ready(Some(val)) => {
self.streams[idx].state = StreamState::TaggerActive;
self.streams[idx].tagger.start_map(val);
continue 'innerLoop;
}
Ready(None) => {
self.streams.swap_remove(idx);
continue 'outterLoop;
}
Pending => {
break 'innerLoop;
}
}
}
StreamState::TaggerActive => {
match Pin::new(&mut self.streams[idx].tagger).poll(cx) {
Ready((val, tag)) => {
if let Some((ref_count, _sink)) = self.sinks.get_mut(&tag) {
self.streams[idx].pending_sink_tag = Some(tag);
self.streams[idx].pending_item = Some(val);
if *ref_count == 0 {
self.streams[idx].state = StreamState::SinkPending;
continue 'innerLoop;
} else {
self.streams[idx].state = StreamState::SinkActive;
*ref_count += 1;
continue 'innerLoop;
}
} else {
self.streams[idx].state = StreamState::StreamActive;
return Ready(Some(val));
}
}
Pending => {
break 'innerLoop;
}
}
}
StreamState::SinkPending => {
let tag = self.streams[idx].pending_sink_tag.take().unwrap();
if let Some((ref_count, sink)) = self.sinks.get_mut(&tag) {
if *ref_count != 0 {
self.streams[idx].pending_sink_tag = Some(tag);
self.streams[idx].state = StreamState::SinkActive;
*ref_count += 1;
continue 'innerLoop;
}
match Pin::new(sink).poll_ready(cx) {
Ready(Ok(())) => {
self.streams[idx].pending_sink_tag = Some(tag);
self.streams[idx].state = StreamState::SinkActive;
*ref_count += 1;
continue 'innerLoop;
}
Ready(Err(_)) => {
self.streams[idx].pending_item = None;
self.streams[idx].state = StreamState::StreamActive;
break 'innerLoop;
}
Pending => {
self.streams[idx].pending_sink_tag = Some(tag);
break 'innerLoop;
}
}
} else {
self.streams[idx].state = StreamState::StreamActive;
break 'innerLoop;
}
}
StreamState::SinkActive => {
let tag = self.streams[idx].pending_sink_tag.take().unwrap();
if let Some((ref_count, sink)) = self.sinks.get_mut(&tag) {
if Pin::new(sink)
.start_send(self.streams[idx].pending_item.take().unwrap())
.is_ok()
{
self.streams[idx].pending_sink_tag = Some(tag);
self.streams[idx].state = StreamState::SinkFlushing;
continue 'innerLoop;
} else {
self.streams[idx].state = StreamState::StreamActive;
*ref_count -= 1;
break 'innerLoop;
}
}
}
StreamState::SinkFlushing => {
let tag = self.streams[idx].pending_sink_tag.take().unwrap();
if let Some((ref_count, sink)) = self.sinks.get_mut(&tag) {
if *ref_count > 1 {
*ref_count -= 1;
self.streams[idx].state = StreamState::StreamActive;
continue 'innerLoop;
} else {
match Pin::new(sink).poll_flush(cx) {
Ready(Ok(())) => {
self.streams[idx].state = StreamState::StreamActive;
*ref_count -= 1;
continue 'innerLoop;
}
Ready(Err(_)) => {
self.streams[idx].state = StreamState::StreamActive;
*ref_count -= 1;
continue 'innerLoop;
}
Pending => {
self.streams[idx].pending_sink_tag = Some(tag);
break 'innerLoop;
}
}
}
}
}
}
}
idx = idx.wrapping_add(1) % self.streams.len();
}
if self.streams.is_empty() {
Ready(None)
} else {
Pending
}
}
}
#[must_use = "streams do nothing unless you `.await` or poll them"]
impl<F, T, A> Stream for StreamRouter<F, T, A>
where
F: Future<Output = (A, T)> + Unpin,
T: Hash + Eq + Unpin,
A: Unpin,
{
type Item = A;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.poll_next_entry(cx) {
Poll::Ready(Some(val)) => Poll::Ready(Some(val)),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
let mut ret = (0, Some(0));
for stream_manager in &self.streams {
let hint = stream_manager.stream.size_hint();
ret.0 += hint.0;
match (ret.1, hint.1) {
(Some(a), Some(b)) => ret.1 = Some(a + b),
(Some(_), None) => ret.1 = None,
_ => {}
}
}
ret
}
}