weighted_select/
lib.rs

1//! An adapter for merging the output of several streams with priority.
2//!
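//! The merged stream produces items from any of the underlying streams as they become available.
//! Streams are polled in the order they were appended, and each one may yield up to its weight in
//! consecutive items per round before the following streams get a turn.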
5//!
6//! Example:
7//! ```
8//! # use futures::{prelude::*, stream::iter_ok};
9//! use weighted_select::{self, IncompleteSelect};
10//!
11//! let select = weighted_select::new()
12//!     .append(iter_ok::<_, ()>(vec![1u32, 1]), 1)
13//!     .append(iter_ok(vec![2, 2, 2, 2, 2]), 3)
14//!     .append(iter_ok(vec![3, 3, 3, 3]), 2)
15//!     .build();
16//!
17//! let actual = select.wait().collect::<Result<Vec<_>, _>>().unwrap();
18//!
19//! assert_eq!(actual, vec![1, 2, 2, 2, 3, 3, 1, 2, 2, 3, 3]);
20//! ```
21
22use std::marker::PhantomData;
23
24use futures::{prelude::*, stream::Fuse};
25
26#[cfg(test)]
27mod tests;
28
/// An adapter for merging the output of several streams with priority.
///
/// The merged stream produces items from any of the underlying streams as they become available,
/// and the streams are polled according to priority (their weights).
///
/// Created by [`IncompleteSelect::build`].
#[must_use = "streams do nothing unless polled"]
#[derive(Debug)]
pub struct Select<N> {
    /// The most recently appended link of the stream chain (or the empty terminal).
    head: N,
    /// Position within the current round; kept in `0..limit` by `poll`.
    cursor: u64,
    /// Length of a round (total weight of all streams); `1` for an empty select so that
    /// `cursor % limit` never divides by zero.
    limit: u64,
}
40
41impl<N> Stream for Select<N>
42where
43    N: IncompleteSelect,
44{
45    type Item = N::Item;
46    type Error = N::Error;
47
48    fn poll(&mut self) -> Poll<Option<N::Item>, N::Error> {
49        let (cnt, res) = match self.head.poll_chain(self.cursor) {
50            (_, Ok(Async::NotReady)) | (_, Ok(Async::Ready(None))) if self.cursor > 0 => {
51                self.head.poll_chain(0)
52            }
53            res => res,
54        };
55
56        self.cursor = cnt % self.limit;
57        res
58    }
59}
60
/// One link in the chain of streams built up by [`IncompleteSelect::append`].
///
/// Each link owns the round positions `start_at..prev_start_at` and wraps, in `next`, the links
/// for the streams that were appended before it.
#[derive(Debug)]
pub struct SelectPart<S, N> {
    /// The wrapped stream, fused so it keeps returning `Ready(None)` once it has finished.
    stream: Fuse<S>,
    /// Maximum number of consecutive items this stream yields per round (asserted non-zero in
    /// `append`).
    weight: u32,
    /// First round position owned by this stream: the sum of the weights appended before it.
    start_at: u64,
    /// One past the last round position owned by this stream (`start_at + weight`). It equals the
    /// `start_at` of the stream appended after this one, and the round length when this link is
    /// the head passed to `build`.
    prev_start_at: u64,
    /// The streams appended before this one.
    next: N,
}
69
/// A weighted select under construction.
///
/// Streams are added with [`append`](IncompleteSelect::append), and the finished stream is
/// obtained with [`build`](IncompleteSelect::build).
pub trait IncompleteSelect: Sized {
    /// Item type shared by every stream in the select.
    type Item;
    /// Error type shared by every stream in the select.
    type Error;

    /// Adds `stream` after all previously appended streams. Per round, it may yield up to
    /// `weight` consecutive items before the following streams are polled.
    ///
    /// # Panics
    ///
    /// The implementations in this module panic if `weight` is zero.
    fn append<NS>(self, stream: NS, weight: u32) -> SelectPart<NS, Self>
    where
        NS: Stream<Item = Self::Item, Error = Self::Error>;

    /// Finishes construction, returning a stream that merges all appended streams.
    fn build(self) -> Select<Self>;

    /// Polls the chain starting at round position `cursor` (an implementation detail of
    /// [`Select`]).
    ///
    /// Returns the round position to resume from, together with the poll result: an item, an
    /// error, `NotReady`, or `Ready(None)` only when every stream in the chain is known to be
    /// finished.
    fn poll_chain(&mut self, cursor: u64) -> (u64, Poll<Option<Self::Item>, Self::Error>);
}
82
83impl<S, N> IncompleteSelect for SelectPart<S, N>
84where
85    S: Stream,
86    N: IncompleteSelect<Item = S::Item, Error = S::Error>,
87{
88    type Item = S::Item;
89    type Error = S::Error;
90
91    fn append<NS>(self, stream: NS, weight: u32) -> SelectPart<NS, Self>
92    where
93        NS: Stream<Item = S::Item, Error = S::Error>,
94    {
95        assert!(weight > 0);
96
97        let start_at = self.start_at + u64::from(self.weight);
98
99        SelectPart {
100            stream: stream.fuse(),
101            weight,
102            start_at,
103            prev_start_at: start_at + u64::from(weight),
104            next: self,
105        }
106    }
107
108    fn build(self) -> Select<Self> {
109        Select {
110            limit: self.prev_start_at,
111            head: self,
112            cursor: 0,
113        }
114    }
115
116    fn poll_chain(&mut self, cursor: u64) -> (u64, Poll<Option<Self::Item>, Self::Error>) {
117        let (cursor, next_done) = if cursor < self.start_at {
118            match self.next.poll_chain(cursor) {
119                (cursor, Ok(Async::Ready(None))) => (cursor, true),
120                (cursor, Ok(Async::NotReady)) => (cursor, false),
121                result => return result,
122            }
123        } else {
124            (cursor, cursor == 0)
125        };
126
127        debug_assert!(cursor >= self.start_at);
128
129        match self.stream.poll() {
130            Ok(Async::Ready(None)) if next_done => (self.prev_start_at, Ok(Async::Ready(None))),
131            Ok(Async::NotReady) | Ok(Async::Ready(None)) => {
132                (self.prev_start_at, Ok(Async::NotReady))
133            }
134            Err(err) => (self.prev_start_at, Err(err)),
135            x => (cursor + 1, x),
136        }
137    }
138}
139
/// The empty end of the stream chain, returned by [`new`].
///
/// It never holds an `I` or an `E`, so the marker is `PhantomData<fn() -> (I, E)>` rather than
/// `PhantomData<(I, E)>`. This keeps the marker covariant while leaving the type `Send` and
/// `Sync` regardless of the item and error types. The hand-written `Debug` impl below likewise
/// avoids requiring `I: Debug` and `E: Debug`.
struct Terminal<I, E>(PhantomData<fn() -> (I, E)>);

impl<I, E> std::fmt::Debug for Terminal<I, E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("Terminal").finish()
    }
}
142
143impl<I, E> IncompleteSelect for Terminal<I, E> {
144    type Item = I;
145    type Error = E;
146
147    fn append<NS>(self, stream: NS, weight: u32) -> SelectPart<NS, Self>
148    where
149        NS: Stream<Item = I, Error = E>,
150    {
151        assert!(weight > 0);
152
153        SelectPart {
154            stream: stream.fuse(),
155            weight,
156            start_at: 0,
157            prev_start_at: u64::from(weight),
158            next: self,
159        }
160    }
161
162    fn build(self) -> Select<Self> {
163        Select {
164            limit: 1, // Avoid calculating the remainder with a divisor of zero.
165            head: self,
166            cursor: 0,
167        }
168    }
169
170    #[inline]
171    fn poll_chain(&mut self, cursor: u64) -> (u64, Poll<Option<Self::Item>, Self::Error>) {
172        debug_assert_eq!(cursor, 0);
173        (0, Ok(Async::Ready(None)))
174    }
175}
176
177pub fn new<I, E>() -> impl IncompleteSelect<Item = I, Error = E> {
178    Terminal(PhantomData)
179}