1use pin_project::pin_project;
2use sea_streamer_types::{export::futures::Stream, Message, StreamKey};
3use std::{
4 collections::{BTreeMap, VecDeque},
5 pin::Pin,
6 task::Poll,
7};
8
9type Keys<M> = BTreeMap<StreamKey, VecDeque<M>>;
10
11#[pin_project]
34pub struct StreamJoin<S, M, E>
35where
36 S: Stream<Item = Result<M, E>>,
37 M: Message,
38 E: std::error::Error,
39{
40 #[pin]
41 muxed: S,
42 keys: Keys<M>,
43 key_keys: Vec<StreamKey>,
44 ended: bool,
45 err: Option<E>,
46}
47
48impl<S, M, E> StreamJoin<S, M, E>
49where
50 S: Stream<Item = Result<M, E>>,
51 M: Message,
52 E: std::error::Error,
53{
54 pub fn muxed(muxed: S) -> Self {
56 Self {
57 muxed,
58 keys: Default::default(),
59 key_keys: Default::default(),
60 ended: false,
61 err: None,
62 }
63 }
64
65 pub fn align(&mut self, stream_key: StreamKey) {
67 self.keys.insert(stream_key.clone(), Default::default());
68 self.key_keys.push(stream_key);
69 }
70
71 fn next(keys: &mut Keys<M>) -> Option<M> {
72 let mut min_key = None;
73 let mut min_ts = None;
74 for (k, ms) in keys.iter() {
75 if let Some(m) = ms.front() {
76 let m_ts = m.timestamp();
77 if min_ts.is_none() || m_ts < min_ts.unwrap() {
78 min_ts = Some(m_ts);
79 min_key = Some(k.clone());
80 }
81 }
82 }
83 if let Some(min_key) = min_key {
84 Some(
85 keys.get_mut(&min_key)
86 .unwrap()
87 .pop_front()
88 .expect("Checked above"),
89 )
90 } else {
91 None
93 }
94 }
95
96 fn check(keys: &Keys<M>, key_keys: &[StreamKey]) -> bool {
97 for kk in key_keys {
99 if keys.get(kk).expect("Already inserted").is_empty() {
100 return false;
101 }
102 }
103 keys.values().any(|ms| !ms.is_empty())
105 }
106}
107
108impl<S, M, E> Stream for StreamJoin<S, M, E>
109where
110 S: Stream<Item = Result<M, E>>,
111 M: Message,
112 E: std::error::Error,
113{
114 type Item = Result<M, E>;
115
116 fn poll_next(
117 self: Pin<&mut Self>,
118 cx: &mut std::task::Context<'_>,
119 ) -> Poll<Option<Self::Item>> {
120 let mut this = self.project();
121 while !*this.ended {
122 match this.muxed.as_mut().poll_next(cx) {
123 Poll::Ready(Some(Ok(mes))) => {
124 let key = mes.stream_key();
125 this.keys.entry(key).or_default().push_back(mes);
126 if Self::check(&this.keys, &this.key_keys) {
127 break;
129 }
130 }
132 Poll::Ready(Some(Err(err))) => {
133 *this.ended = true;
134 *this.err = Some(err);
135 break;
136 }
137 Poll::Ready(None) => {
138 *this.ended = true;
139 break;
140 }
141 Poll::Pending => {
142 break;
144 }
145 }
146 }
147 if *this.ended || Self::check(&this.keys, &this.key_keys) {
148 Poll::Ready(match Self::next(this.keys) {
149 Some(item) => Some(Ok(item)),
150 None => this.err.take().map(Err),
151 })
152 } else {
153 Poll::Pending
154 }
155 }
156}
157
#[cfg(test)]
mod test {
    use super::*;
    use sea_streamer_socket::{BackendErr, SeaMessage, SeaMessageStream};
    use sea_streamer_types::{
        export::futures::{self, TryStreamExt},
        MessageHeader, OwnedMessage, StreamErr, Timestamp,
    };

    // Never called: it only proves at compile time that `StreamJoin` can wrap
    // a `SeaMessageStream` with its concrete message and error types.
    #[allow(dead_code)]
    fn wrap<'a>(
        s: SeaMessageStream<'a>,
    ) -> StreamJoin<SeaMessageStream<'a>, SeaMessage<'a>, StreamErr<BackendErr>> {
        StreamJoin::muxed(s)
    }

    /// Builds messages for `key` whose sequence number and unix-second
    /// timestamp both equal each entry of `items`.
    fn make_seq(key: StreamKey, items: &[u64]) -> Vec<Result<OwnedMessage, BackendErr>> {
        items
            .iter()
            .copied()
            .map(|i| {
                Ok(OwnedMessage::new(
                    MessageHeader::new(
                        key.clone(),
                        Default::default(),
                        i,
                        Timestamp::from_unix_timestamp(i as i64).unwrap(),
                    ),
                    Vec::new(),
                ))
            })
            .collect()
    }

    /// Asserts that `messages` matches `expected` as (stream name, sequence) pairs.
    fn compare(messages: Vec<OwnedMessage>, expected: &[(&str, u64)]) {
        assert_eq!(messages.len(), expected.len());
        for (i, m) in messages.iter().enumerate() {
            assert_eq!(m.stream_key().name(), expected[i].0);
            assert_eq!(m.sequence(), expected[i].1);
        }
    }

    #[tokio::test]
    async fn test_mux_streams_2() {
        let a = StreamKey::new("a").unwrap();
        let b = StreamKey::new("b").unwrap();
        let stream = futures::stream::iter(
            make_seq(a.clone(), &[1, 3, 5, 7, 9])
                .into_iter()
                .chain(make_seq(b.clone(), &[2, 4, 6, 8, 10]).into_iter()),
        );
        let mut join = StreamJoin::muxed(stream);
        join.align(a);
        join.align(b);
        let messages: Vec<_> = join.try_collect().await.unwrap();
        compare(
            messages,
            &[
                ("a", 1),
                ("b", 2),
                ("a", 3),
                ("b", 4),
                ("a", 5),
                ("b", 6),
                ("a", 7),
                ("b", 8),
                ("a", 9),
                ("b", 10),
            ],
        );
    }

    #[tokio::test]
    async fn test_mux_streams_2_2() {
        let a = StreamKey::new("a").unwrap();
        let b = StreamKey::new("b").unwrap();
        let stream = futures::stream::iter(
            make_seq(a.clone(), &[1, 2, 5, 8, 9])
                .into_iter()
                .chain(make_seq(b.clone(), &[3, 4, 6, 7, 10]).into_iter()),
        );
        let mut join = StreamJoin::muxed(stream);
        join.align(a);
        join.align(b);
        let messages: Vec<_> = join.try_collect().await.unwrap();
        compare(
            messages,
            &[
                ("a", 1),
                ("a", 2),
                ("b", 3),
                ("b", 4),
                ("a", 5),
                ("b", 6),
                ("b", 7),
                ("a", 8),
                ("a", 9),
                ("b", 10),
            ],
        );
    }

    #[tokio::test]
    async fn test_mux_streams_3() {
        let a = StreamKey::new("a").unwrap();
        let b = StreamKey::new("b").unwrap();
        let c = StreamKey::new("c").unwrap();
        let stream = futures::stream::iter(
            make_seq(a.clone(), &[1, 3, 5, 7, 9])
                .into_iter()
                .chain(make_seq(c.clone(), &[5]).into_iter())
                .chain(make_seq(b.clone(), &[2, 4, 6, 8, 10]).into_iter()),
        );
        let mut join = StreamJoin::muxed(stream);
        join.align(a);
        join.align(b);
        join.align(c);
        let messages: Vec<_> = join.try_collect().await.unwrap();
        compare(
            messages,
            &[
                ("a", 1),
                ("b", 2),
                ("a", 3),
                ("b", 4),
                ("a", 5),
                ("c", 5),
                ("b", 6),
                ("a", 7),
                ("b", 8),
                ("a", 9),
                ("b", 10),
            ],
        );
    }

    #[tokio::test]
    async fn test_mux_streams_4() {
        let a = StreamKey::new("a").unwrap();
        let b = StreamKey::new("b").unwrap();
        let c = StreamKey::new("c").unwrap();
        let d = StreamKey::new("d").unwrap();
        let stream = futures::stream::iter(
            make_seq(a.clone(), &[1, 3])
                .into_iter()
                .chain(make_seq(d.clone(), &[5]).into_iter())
                .chain(make_seq(b.clone(), &[2, 4]).into_iter())
                .chain(make_seq(c.clone(), &[3]).into_iter()),
        );
        let mut join = StreamJoin::muxed(stream);
        join.align(a);
        join.align(b);
        join.align(c);
        join.align(d);
        let messages: Vec<_> = join.try_collect().await.unwrap();
        compare(
            messages,
            &[("a", 1), ("b", 2), ("a", 3), ("c", 3), ("b", 4), ("d", 5)],
        );
    }

    #[tokio::test]
    async fn test_mux_empty() {
        let a = StreamKey::new("a").unwrap();
        let stream = futures::stream::iter(make_seq(a.clone(), &[]));
        let mut join = StreamJoin::muxed(stream);
        join.align(a);
        let messages: Vec<_> = join.try_collect().await.unwrap();
        compare(messages, &[]);
    }

    #[tokio::test]
    async fn test_mux_no_align_passthrough() {
        // With no aligned key, messages are emitted in arrival order.
        let a = StreamKey::new("a").unwrap();
        let b = StreamKey::new("b").unwrap();
        let stream = futures::stream::iter(
            make_seq(a, &[1, 3])
                .into_iter()
                .chain(make_seq(b, &[2]).into_iter()),
        );
        let join = StreamJoin::muxed(stream);
        let messages: Vec<_> = join.try_collect().await.unwrap();
        compare(messages, &[("a", 1), ("a", 3), ("b", 2)]);
    }

    #[tokio::test]
    async fn test_mux_aligned_stream_missing() {
        // An aligned key that never produces messages must not stall the
        // join: once the inner stream ends, the buffers are drained.
        let a = StreamKey::new("a").unwrap();
        let b = StreamKey::new("b").unwrap();
        let stream = futures::stream::iter(make_seq(a.clone(), &[1, 3]));
        let mut join = StreamJoin::muxed(stream);
        join.align(a);
        join.align(b);
        let messages: Vec<_> = join.try_collect().await.unwrap();
        compare(messages, &[("a", 1), ("a", 3)]);
    }
}