1use std::cell::RefCell;
2use std::collections::HashMap;
3use std::marker::PhantomData;
4use std::net::SocketAddr;
5use std::path::PathBuf;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use async_recursion::async_recursion;
11use async_trait::async_trait;
12use bytes::{Bytes, BytesMut};
13use futures::sink::Buffer;
14use futures::stream::{FuturesUnordered, SplitSink, SplitStream};
15use futures::{Future, Sink, SinkExt, Stream, StreamExt, stream};
16use serde::{Deserialize, Serialize};
17use tempfile::TempDir;
18use tokio::io;
19use tokio::net::{TcpListener, TcpStream};
20#[cfg(unix)]
21use tokio::net::{UnixListener, UnixStream};
22use tokio_stream::wrappers::TcpListenerStream;
23use tokio_util::codec::{Framed, LengthDelimitedCodec};
24
25pub mod multi_connection;
26pub mod single_connection;
27
/// Initialization data: a bind configuration per name (presumably port names — confirm
/// against callers) plus an optional string whose meaning is not visible in this file.
pub type InitConfig = (HashMap<String, ServerBindConfig>, Option<String>);
29
/// A set of named connections, each of which can be claimed once via [`DeployPorts::port`],
/// together with caller-supplied metadata of type `T`.
pub struct DeployPorts<T = Option<()>> {
    /// Unclaimed connections keyed by port name; `port` removes entries from this map.
    pub ports: RefCell<HashMap<String, Connection>>,
    /// Arbitrary metadata carried alongside the ports.
    pub meta: T,
}
36
37impl<T> DeployPorts<T> {
38 pub fn port(&self, name: &str) -> Connection {
39 self.ports
40 .try_borrow_mut()
41 .unwrap()
42 .remove(name)
43 .unwrap_or_else(|| panic!("port {} not found", name))
44 }
45}
46
// On non-Unix targets the Unix socket types are replaced by the uninhabited `Infallible`, so
// the `UnixSocket` variants below still type-check but can never be constructed.
#[cfg(not(unix))]
type UnixStream = std::convert::Infallible;

#[cfg(not(unix))]
type UnixListener = std::convert::Infallible;
52
/// A serializable description of how to reach a server, as reported by
/// [`BoundServer::server_port`].
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum ServerPort {
    /// Path of a Unix domain socket.
    UnixSocket(PathBuf),
    /// TCP address to connect to.
    TcpPort(SocketAddr),
    /// One port per `u32` key.
    Demux(HashMap<u32, ServerPort>),
    /// Several ports; `ConnectedDirect` merges their inputs into a single source.
    Merge(Vec<ServerPort>),
    /// A port whose connection is labelled with the given `u32` tag.
    Tagged(Box<ServerPort>, u32),
    /// No connection at all.
    Null,
}
63
64impl ServerPort {
65 #[async_recursion]
66 pub async fn connect(&self) -> ClientConnection {
67 match self {
68 ServerPort::UnixSocket(path) => {
69 #[cfg(unix)]
70 {
71 let bound = UnixStream::connect(path.clone());
72 ClientConnection::UnixSocket(bound.await.unwrap())
73 }
74
75 #[cfg(not(unix))]
76 {
77 let _ = path;
78 panic!("Unix sockets are not supported on this platform")
79 }
80 }
81 ServerPort::TcpPort(addr) => {
82 let addr_clone = *addr;
83 let stream = async_retry(
84 move || TcpStream::connect(addr_clone),
85 10,
86 Duration::from_secs(1),
87 )
88 .await
89 .unwrap();
90 ClientConnection::TcpPort(stream)
91 }
92 ServerPort::Demux(bindings) => ClientConnection::Demux(
93 bindings
94 .iter()
95 .map(|(k, v)| async move { (*k, v.connect().await) })
96 .collect::<FuturesUnordered<_>>()
97 .collect::<Vec<_>>()
98 .await
99 .into_iter()
100 .collect(),
101 ),
102 ServerPort::Merge(ports) => ClientConnection::Merge(
103 ports
104 .iter()
105 .map(|p| p.connect())
106 .collect::<FuturesUnordered<_>>()
107 .collect::<Vec<_>>()
108 .await
109 .into_iter()
110 .collect(),
111 ),
112 ServerPort::Tagged(port, tag) => {
113 ClientConnection::Tagged(Box::new(port.as_ref().connect().await), *tag)
114 }
115 ServerPort::Null => ClientConnection::Null,
116 }
117 }
118
119 pub async fn instantiate(&self) -> Connection {
120 Connection::AsClient(self.connect().await)
121 }
122}
123
/// An open client-side connection, produced by [`ServerPort::connect`] with the same shape as
/// the port it came from.
#[derive(Debug)]
pub enum ClientConnection {
    UnixSocket(UnixStream),
    TcpPort(TcpStream),
    /// Connections keyed by the same `u32` keys as the originating `ServerPort::Demux`.
    Demux(HashMap<u32, ClientConnection>),
    /// Child connections, in the order their connects completed.
    Merge(Vec<ClientConnection>),
    /// A connection plus its tag.
    Tagged(Box<ClientConnection>, u32),
    /// Counterpart of `ServerPort::Null`; no socket is opened.
    Null,
}
133
/// Serializable instructions for binding a server; see [`ServerBindConfig::bind`].
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum ServerBindConfig {
    /// Bind a Unix domain socket named `socket` inside a fresh temporary directory.
    UnixSocket,
    /// Bind a TCP listener.
    TcpPort(
        /// Host to bind to.
        String,
        /// Port to bind to; `None` binds port 0 (an OS-assigned port).
        Option<u16>,
    ),
    /// One server per `u32` key.
    Demux(HashMap<u32, ServerBindConfig>),
    /// Several servers whose inputs are merged on the receiving side.
    Merge(Vec<ServerBindConfig>),
    /// A server whose connection is labelled with the given `u32` tag.
    Tagged(Box<ServerBindConfig>, u32),
    /// A server that `accept_bound` passes through without accepting a connection.
    MultiConnection(Box<ServerBindConfig>),
    /// Bind nothing.
    Null,
}
151
152impl ServerBindConfig {
153 #[async_recursion]
154 pub async fn bind(self) -> BoundServer {
155 match self {
156 ServerBindConfig::UnixSocket => {
157 #[cfg(unix)]
158 {
159 let dir = tempfile::tempdir().unwrap();
160 let socket_path = dir.path().join("socket");
161 let bound = UnixListener::bind(socket_path).unwrap();
162 BoundServer::UnixSocket(bound, dir)
163 }
164
165 #[cfg(not(unix))]
166 {
167 panic!("Unix sockets are not supported on this platform")
168 }
169 }
170 ServerBindConfig::TcpPort(host, port) => {
171 let listener = TcpListener::bind((host, port.unwrap_or(0))).await.unwrap();
172 let addr = listener.local_addr().unwrap();
173 BoundServer::TcpPort(TcpListenerStream::new(listener), addr)
174 }
175 ServerBindConfig::Demux(bindings) => {
176 let mut demux = HashMap::new();
177 for (key, bind) in bindings {
178 demux.insert(key, bind.bind().await);
179 }
180 BoundServer::Demux(demux)
181 }
182 ServerBindConfig::Merge(bindings) => {
183 let mut merge = Vec::new();
184 for bind in bindings {
185 merge.push(bind.bind().await);
186 }
187 BoundServer::Merge(merge)
188 }
189 ServerBindConfig::Tagged(underlying, id) => {
190 BoundServer::Tagged(Box::new(underlying.bind().await), id)
191 }
192 ServerBindConfig::MultiConnection(underlying) => {
193 BoundServer::MultiConnection(Box::new(underlying.bind().await))
194 }
195 ServerBindConfig::Null => BoundServer::Null,
196 }
197 }
198}
199
/// One end of a connection: the client side (see [`ServerPort::instantiate`]) or the server
/// side (an [`AcceptedServer`], as produced by [`accept_bound`]).
#[derive(Debug)]
pub enum Connection {
    AsClient(ClientConnection),
    AsServer(AcceptedServer),
}
205
206impl Connection {
207 pub fn connect<T: Connected>(self) -> T {
208 T::from_defn(self)
209 }
210}
211
/// Boxed, type-erased stream of received byte buffers (or I/O errors).
pub type DynStream = Pin<Box<dyn Stream<Item = Result<BytesMut, io::Error>> + Send + Sync>>;

/// Boxed, type-erased sink of `Input` items that fails with I/O errors.
pub type DynSink<Input> = Pin<Box<dyn Sink<Input, Error = io::Error> + Send + Sync>>;

/// A duplex byte channel: a stream of received `BytesMut` buffers and a sink of `Bytes` to send.
pub trait StreamSink:
    Stream<Item = Result<BytesMut, io::Error>> + Sink<Bytes, Error = io::Error>
{
}
// Blanket impl: any type that is both such a stream and such a sink is a `StreamSink`.
impl<T: Stream<Item = Result<BytesMut, io::Error>> + Sink<Bytes, Error = io::Error>> StreamSink
    for T
{
}

/// Boxed, type-erased [`StreamSink`].
pub type DynStreamSink = Pin<Box<dyn StreamSink + Send + Sync>>;
226
/// An endpoint type that can be built from a raw [`Connection`].
pub trait Connected: Send {
    /// Builds `Self` from `pipe`; the implementations in this file panic on unsupported shapes.
    fn from_defn(pipe: Connection) -> Self;
}

/// An endpoint that can be turned into a sink of `Input` items.
pub trait ConnectedSink {
    /// Item type accepted by the sink.
    type Input: Send;
    /// Concrete sink type.
    type Sink: Sink<Self::Input, Error = io::Error> + Send + Sync;

    /// Consumes the endpoint and returns its sink.
    fn into_sink(self) -> Self::Sink;
}

/// An endpoint that can be turned into a stream of `Output` items.
pub trait ConnectedSource {
    /// Item type produced by the stream (wrapped in `Result`).
    type Output: Send;
    /// Concrete stream type.
    type Stream: Stream<Item = Result<Self::Output, io::Error>> + Send + Sync;
    /// Consumes the endpoint and returns its stream.
    fn into_source(self) -> Self::Stream;
}
243
/// A server bound by [`ServerBindConfig::bind`] that has not yet accepted a connection.
#[derive(Debug)]
pub enum BoundServer {
    /// Unix listener plus the temporary directory that contains its `socket` file.
    UnixSocket(UnixListener, TempDir),
    /// TCP listener and the local address it was bound to.
    TcpPort(TcpListenerStream, SocketAddr),
    Demux(HashMap<u32, BoundServer>),
    Merge(Vec<BoundServer>),
    Tagged(Box<BoundServer>, u32),
    /// Wrapped server; `accept_bound` forwards it without accepting.
    MultiConnection(Box<BoundServer>),
    Null,
}
254
/// A server-side connection produced by [`accept_bound`], mirroring the [`BoundServer`] shape.
#[derive(Debug)]
pub enum AcceptedServer {
    /// Accepted stream plus the temporary directory that held the socket file.
    UnixSocket(UnixStream, TempDir),
    TcpPort(TcpStream),
    Demux(HashMap<u32, AcceptedServer>),
    /// Accepted children, in the order their accepts completed.
    Merge(Vec<AcceptedServer>),
    Tagged(Box<AcceptedServer>, u32),
    /// Still-unaccepted server, passed through unchanged by `accept_bound`.
    MultiConnection(Box<BoundServer>),
    Null,
}
265
266#[async_recursion]
267pub async fn accept_bound(bound: BoundServer) -> AcceptedServer {
268 match bound {
269 BoundServer::UnixSocket(listener, dir) => {
270 #[cfg(unix)]
271 {
272 let stream = listener.accept().await.unwrap().0;
273 AcceptedServer::UnixSocket(stream, dir)
274 }
275
276 #[cfg(not(unix))]
277 {
278 let _ = listener;
279 let _ = dir;
280 panic!("Unix sockets are not supported on this platform")
281 }
282 }
283 BoundServer::TcpPort(mut listener, _) => {
284 let stream = listener.next().await.unwrap().unwrap();
285 AcceptedServer::TcpPort(stream)
286 }
287 BoundServer::Demux(bindings) => AcceptedServer::Demux(
288 bindings
289 .into_iter()
290 .map(|(k, b)| async move { (k, accept_bound(b).await) })
291 .collect::<FuturesUnordered<_>>()
292 .collect::<Vec<_>>()
293 .await
294 .into_iter()
295 .collect(),
296 ),
297 BoundServer::Merge(merge) => AcceptedServer::Merge(
298 merge
299 .into_iter()
300 .map(|b| async move { accept_bound(b).await })
301 .collect::<FuturesUnordered<_>>()
302 .collect::<Vec<_>>()
303 .await,
304 ),
305 BoundServer::Tagged(underlying, id) => {
306 AcceptedServer::Tagged(Box::new(accept_bound(*underlying).await), id)
307 }
308 BoundServer::MultiConnection(underlying) => AcceptedServer::MultiConnection(underlying),
309 BoundServer::Null => AcceptedServer::Null,
310 }
311}
312
313impl BoundServer {
314 pub fn server_port(&self) -> ServerPort {
315 match self {
316 BoundServer::UnixSocket(_, tempdir) => {
317 #[cfg(unix)]
318 {
319 ServerPort::UnixSocket(tempdir.path().join("socket"))
320 }
321
322 #[cfg(not(unix))]
323 {
324 let _ = tempdir;
325 panic!("Unix sockets are not supported on this platform")
326 }
327 }
328 BoundServer::TcpPort(_, addr) => {
329 ServerPort::TcpPort(SocketAddr::new(addr.ip(), addr.port()))
330 }
331
332 BoundServer::Demux(bindings) => {
333 let mut demux = HashMap::new();
334 for (key, bind) in bindings {
335 demux.insert(*key, bind.server_port());
336 }
337 ServerPort::Demux(demux)
338 }
339
340 BoundServer::Merge(bindings) => {
341 let mut merge = Vec::new();
342 for bind in bindings {
343 merge.push(bind.server_port());
344 }
345 ServerPort::Merge(merge)
346 }
347
348 BoundServer::Tagged(underlying, id) => {
349 ServerPort::Tagged(Box::new(underlying.server_port()), *id)
350 }
351
352 BoundServer::MultiConnection(underlying) => underlying.server_port(),
353
354 BoundServer::Null => ServerPort::Null,
355 }
356 }
357}
358
359fn accept(bound: AcceptedServer) -> ConnectedDirect {
360 match bound {
361 AcceptedServer::UnixSocket(stream, _dir) => {
362 #[cfg(unix)]
363 {
364 ConnectedDirect {
365 stream_sink: Some(Box::pin(unix_bytes(stream))),
366 source_only: None,
367 sink_only: None,
368 }
369 }
370
371 #[cfg(not(unix))]
372 {
373 let _ = stream;
374 panic!("Unix sockets are not supported on this platform")
375 }
376 }
377 AcceptedServer::TcpPort(stream) => ConnectedDirect {
378 stream_sink: Some(Box::pin(tcp_bytes(stream))),
379 source_only: None,
380 sink_only: None,
381 },
382 AcceptedServer::Merge(merge) => {
383 let mut sources = vec![];
384 for bound in merge {
385 sources.push(Some(Box::pin(accept(bound).into_source())));
386 }
387
388 let merge_source: DynStream = Box::pin(MergeSource {
389 marker: PhantomData,
390 sources,
391 poll_cursor: 0,
392 });
393
394 ConnectedDirect {
395 stream_sink: None,
396 source_only: Some(merge_source),
397 sink_only: None,
398 }
399 }
400 AcceptedServer::Demux(_) => panic!("Cannot connect to a demux pipe directly"),
401 AcceptedServer::Tagged(_, _) => panic!("Cannot connect to a tagged pipe directly"),
402 AcceptedServer::MultiConnection(_) => {
403 panic!("Cannot connect to a multi-connection pipe directly")
404 }
405 AcceptedServer::Null => {
406 ConnectedDirect::from_defn(Connection::AsClient(ClientConnection::Null))
407 }
408 }
409}
410
411fn tcp_bytes(stream: TcpStream) -> impl StreamSink {
412 Framed::new(stream, LengthDelimitedCodec::new())
413}
414
415#[cfg(unix)]
416fn unix_bytes(stream: UnixStream) -> impl StreamSink {
417 Framed::new(stream, LengthDelimitedCodec::new())
418}
419
420struct IoErrorDrain<T> {
421 marker: PhantomData<T>,
422}
423
424impl<T> Sink<T> for IoErrorDrain<T> {
425 type Error = io::Error;
426
427 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
428 Poll::Ready(Ok(()))
429 }
430
431 fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
432 Ok(())
433 }
434
435 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
436 Poll::Ready(Ok(()))
437 }
438
439 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
440 Poll::Ready(Ok(()))
441 }
442}
443
444async fn async_retry<T, E, F: Future<Output = Result<T, E>>>(
445 thunk: impl Fn() -> F,
446 count: usize,
447 delay: Duration,
448) -> Result<T, E> {
449 for _ in 1..count {
450 let result = thunk().await;
451 if result.is_ok() {
452 return result;
453 } else {
454 tokio::time::sleep(delay).await;
455 }
456 }
457
458 thunk().await
459}
460
/// A directly connected endpoint. The constructors in this file populate one of three shapes:
/// a duplex `stream_sink` (Unix/TCP), `source_only` alone (merges), or `source_only` plus
/// `sink_only` (null connections).
pub struct ConnectedDirect {
    /// Duplex framed stream, when the connection is a single socket.
    stream_sink: Option<DynStreamSink>,
    /// Read-only stream, used when there is no duplex stream.
    source_only: Option<DynStream>,
    /// Write-only sink, used when there is no duplex stream.
    sink_only: Option<DynSink<Bytes>>,
}
466
467impl ConnectedDirect {
468 pub fn into_source_sink(self) -> (SplitStream<DynStreamSink>, SplitSink<DynStreamSink, Bytes>) {
469 let (sink, stream) = self.stream_sink.unwrap().split();
470 (stream, sink)
471 }
472}
473
474impl Connected for ConnectedDirect {
475 fn from_defn(pipe: Connection) -> Self {
476 match pipe {
477 Connection::AsClient(ClientConnection::UnixSocket(stream)) => {
478 #[cfg(unix)]
479 {
480 ConnectedDirect {
481 stream_sink: Some(Box::pin(unix_bytes(stream))),
482 source_only: None,
483 sink_only: None,
484 }
485 }
486
487 #[cfg(not(unix))]
488 {
489 let _ = stream;
490 panic!("Unix sockets are not supported on this platform");
491 }
492 }
493 Connection::AsClient(ClientConnection::TcpPort(stream)) => {
494 stream.set_nodelay(true).unwrap();
495 ConnectedDirect {
496 stream_sink: Some(Box::pin(tcp_bytes(stream))),
497 source_only: None,
498 sink_only: None,
499 }
500 }
501 Connection::AsClient(ClientConnection::Merge(merge)) => {
502 let sources = merge
503 .into_iter()
504 .map(|port| {
505 Some(Box::pin(
506 ConnectedDirect::from_defn(Connection::AsClient(port)).into_source(),
507 ))
508 })
509 .collect::<Vec<_>>();
510
511 let merged = MergeSource {
512 marker: PhantomData,
513 sources,
514 poll_cursor: 0,
515 };
516
517 ConnectedDirect {
518 stream_sink: None,
519 source_only: Some(Box::pin(merged)),
520 sink_only: None,
521 }
522 }
523 Connection::AsClient(ClientConnection::Demux(_)) => {
524 panic!("Cannot connect to a demux pipe directly")
525 }
526
527 Connection::AsClient(ClientConnection::Tagged(_, _)) => {
528 panic!("Cannot connect to a tagged pipe directly")
529 }
530
531 Connection::AsClient(ClientConnection::Null) => ConnectedDirect {
532 stream_sink: None,
533 source_only: Some(Box::pin(stream::empty())),
534 sink_only: Some(Box::pin(IoErrorDrain {
535 marker: PhantomData,
536 })),
537 },
538
539 Connection::AsServer(bound) => accept(bound),
540 }
541 }
542}
543
544impl ConnectedSource for ConnectedDirect {
545 type Output = BytesMut;
546 type Stream = DynStream;
547
548 fn into_source(mut self) -> DynStream {
549 if let Some(s) = self.stream_sink.take() {
550 Box::pin(s)
551 } else {
552 self.source_only.take().unwrap()
553 }
554 }
555}
556
557impl ConnectedSink for ConnectedDirect {
558 type Input = Bytes;
559 type Sink = DynSink<Bytes>;
560
561 fn into_sink(mut self) -> DynSink<Self::Input> {
562 if let Some(s) = self.stream_sink.take() {
563 Box::pin(s)
564 } else {
565 self.sink_only.take().unwrap()
566 }
567 }
568}
569
/// Demultiplexing sink over per-key buffered sinks, built with `sinktools::demux_map`.
/// `ConnectedDemux` feeds it `(u32, item)` pairs; presumably each item is routed to the sink
/// registered under its key (behavior lives in `sinktools` — confirm there).
pub type BufferedDrain<S, I> = sinktools::demux_map::DemuxMap<u32, Pin<Box<Buffer<S, I>>>>;
571
/// Sink side of a demux connection: items are sent as `(key, item)` pairs, with one
/// 1024-item buffered sink of `T` per key.
pub struct ConnectedDemux<T: ConnectedSink>
where
    <T as ConnectedSink>::Input: Sync,
{
    /// Keys of the underlying demux, in the order they were iterated at construction.
    pub keys: Vec<u32>,
    /// Demultiplexing sink; always `Some` after `from_defn`.
    sink: Option<BufferedDrain<T::Sink, T::Input>>,
}
579
580#[async_trait]
581impl<T: Connected + ConnectedSink> Connected for ConnectedDemux<T>
582where
583 <T as ConnectedSink>::Input: 'static + Sync,
584{
585 fn from_defn(pipe: Connection) -> Self {
586 match pipe {
587 Connection::AsClient(ClientConnection::Demux(demux)) => {
588 let mut connected_demux = HashMap::new();
589 let keys = demux.keys().cloned().collect();
590 for (id, pipe) in demux {
591 connected_demux.insert(
592 id,
593 Box::pin(
594 T::from_defn(Connection::AsClient(pipe))
595 .into_sink()
596 .buffer(1024),
597 ),
598 );
599 }
600
601 let demuxer = sinktools::demux_map(connected_demux);
602
603 ConnectedDemux {
604 keys,
605 sink: Some(demuxer),
606 }
607 }
608
609 Connection::AsServer(AcceptedServer::Demux(demux)) => {
610 let mut connected_demux = HashMap::new();
611 let keys = demux.keys().cloned().collect();
612 for (id, bound) in demux {
613 connected_demux.insert(
614 id,
615 Box::pin(
616 T::from_defn(Connection::AsServer(bound))
617 .into_sink()
618 .buffer(1024),
619 ),
620 );
621 }
622
623 let demuxer = sinktools::demux_map(connected_demux);
624
625 ConnectedDemux {
626 keys,
627 sink: Some(demuxer),
628 }
629 }
630 _ => panic!("Cannot connect to a non-demux pipe as a demux"),
631 }
632 }
633}
634
635impl<T: ConnectedSink> ConnectedSink for ConnectedDemux<T>
636where
637 <T as ConnectedSink>::Input: 'static + Sync,
638{
639 type Input = (u32, T::Input);
640 type Sink = BufferedDrain<T::Sink, T::Input>;
641
642 fn into_sink(mut self) -> Self::Sink {
643 self.sink.take().unwrap()
644 }
645}
646
/// Stream that merges several sources by polling them round-robin, dropping each source once
/// it ends; it ends when every source has ended.
pub struct MergeSource<T: Unpin, S: Stream<Item = T> + Send + Sync + ?Sized> {
    marker: PhantomData<T>,
    /// Live sources. A slot is set to `None` when its source ends and is compacted away
    /// before `poll_next` returns.
    sources: Vec<Option<Pin<Box<S>>>>,
    /// Index of the source to poll first on the next `poll_next` call.
    poll_cursor: usize,
}
654
655impl<T: Unpin, S: Stream<Item = T> + Send + Sync + ?Sized> Stream for MergeSource<T, S> {
656 type Item = T;
657
658 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
659 let me = self.get_mut();
660 let mut out = Poll::Pending;
661 let mut any_removed = false;
662
663 if !me.sources.is_empty() {
664 let start_cursor = me.poll_cursor;
665
666 loop {
667 let current_length = me.sources.len();
668 let source = &mut me.sources[me.poll_cursor];
669
670 me.poll_cursor = (me.poll_cursor + 1) % current_length;
672
673 match source.as_mut().unwrap().as_mut().poll_next(cx) {
674 Poll::Ready(Some(data)) => {
675 out = Poll::Ready(Some(data));
676 break;
677 }
678 Poll::Ready(None) => {
679 *source = None; any_removed = true;
681 }
682 Poll::Pending => {}
683 }
684
685 if me.poll_cursor == start_cursor {
687 break;
688 }
689 }
690 }
691
692 let mut current_index = 0;
694 let original_cursor = me.poll_cursor;
695
696 if any_removed {
697 me.sources.retain(|source| {
698 if source.is_none() && current_index < original_cursor {
699 me.poll_cursor -= 1;
700 }
701 current_index += 1;
702 source.is_some()
703 });
704 }
705
706 if me.poll_cursor == me.sources.len() {
707 me.poll_cursor = 0;
708 }
709
710 if me.sources.is_empty() {
711 Poll::Ready(None)
712 } else {
713 out
714 }
715 }
716}
717
/// Stream adapter that pairs every successful item from `source` with the fixed tag `id`.
pub struct TaggedSource<T: Unpin, S: Stream<Item = Result<T, io::Error>> + Send + Sync + ?Sized> {
    marker: PhantomData<T>,
    /// Tag attached to every successful item.
    id: u32,
    /// Underlying stream.
    source: Pin<Box<S>>,
}
723
724impl<T: Unpin, S: Stream<Item = Result<T, io::Error>> + Send + Sync + ?Sized> Stream
725 for TaggedSource<T, S>
726{
727 type Item = Result<(u32, T), io::Error>;
728
729 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
730 let id = self.as_ref().id;
731 let source = &mut self.get_mut().source;
732 match source.as_mut().poll_next(cx) {
733 Poll::Ready(Some(v)) => Poll::Ready(Some(v.map(|d| (id, d)))),
734 Poll::Ready(None) => Poll::Ready(None),
735 Poll::Pending => Poll::Pending,
736 }
737 }
738}
739
/// Merged stream of `(tag, item)` results built from several tagged sources of `T`.
type MergedMux<T> = MergeSource<
    Result<(u32, <T as ConnectedSource>::Output), io::Error>,
    TaggedSource<<T as ConnectedSource>::Output, <T as ConnectedSource>::Stream>,
>;
744
/// Source side of a tagged (or merge-of-tagged) connection: yields `(tag, item)` pairs from
/// all underlying sources, merged round-robin.
pub struct ConnectedTagged<T: ConnectedSource>
where
    <T as ConnectedSource>::Output: 'static + Sync + Unpin,
{
    source: MergedMux<T>,
}
751
752#[async_trait]
753impl<T: Connected + ConnectedSource> Connected for ConnectedTagged<T>
754where
755 <T as ConnectedSource>::Output: 'static + Sync + Unpin,
756{
757 fn from_defn(pipe: Connection) -> Self {
758 let sources = match pipe {
759 Connection::AsClient(ClientConnection::Tagged(pipe, id)) => {
760 vec![(
761 Box::pin(T::from_defn(Connection::AsClient(*pipe)).into_source()),
762 id,
763 )]
764 }
765
766 Connection::AsClient(ClientConnection::Merge(m)) => {
767 let mut sources = Vec::new();
768 for port in m {
769 if let ClientConnection::Tagged(pipe, id) = port {
770 sources.push((
771 Box::pin(T::from_defn(Connection::AsClient(*pipe)).into_source()),
772 id,
773 ));
774 } else {
775 panic!("Merge port must be tagged");
776 }
777 }
778
779 sources
780 }
781
782 Connection::AsServer(AcceptedServer::Tagged(pipe, id)) => {
783 vec![(
784 Box::pin(T::from_defn(Connection::AsServer(*pipe)).into_source()),
785 id,
786 )]
787 }
788
789 Connection::AsServer(AcceptedServer::Merge(m)) => {
790 let mut sources = Vec::new();
791 for port in m {
792 if let AcceptedServer::Tagged(pipe, id) = port {
793 sources.push((
794 Box::pin(T::from_defn(Connection::AsServer(*pipe)).into_source()),
795 id,
796 ));
797 } else {
798 panic!("Merge port must be tagged");
799 }
800 }
801
802 sources
803 }
804
805 _ => panic!("Cannot connect to a non-tagged pipe as a tagged"),
806 };
807
808 let mut connected_mux = Vec::new();
809 for (pipe, id) in sources {
810 connected_mux.push(Some(Box::pin(TaggedSource {
811 marker: PhantomData,
812 id,
813 source: pipe,
814 })));
815 }
816
817 let muxer = MergeSource {
818 marker: PhantomData,
819 sources: connected_mux,
820 poll_cursor: 0,
821 };
822
823 ConnectedTagged { source: muxer }
824 }
825}
826
827impl<T: ConnectedSource> ConnectedSource for ConnectedTagged<T>
828where
829 <T as ConnectedSource>::Output: 'static + Sync + Unpin,
830{
831 type Output = (u32, T::Output);
832 type Stream = MergeSource<Result<Self::Output, io::Error>, TaggedSource<T::Output, T::Stream>>;
833
834 fn into_source(self) -> Self::Stream {
835 self.source
836 }
837}
838
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::task::{Context, Poll};

    use futures::stream;

    use super::*;

    /// Waker that does nothing; the in-memory streams under test never return `Pending`.
    struct TestWaker;
    impl std::task::Wake for TestWaker {
        fn wake(self: Arc<Self>) {}
    }

    type IterStream = stream::Iter<std::vec::IntoIter<i32>>;

    /// Builds a `MergeSource` over one in-memory stream per input vector.
    fn merge_of(inputs: Vec<Vec<i32>>) -> MergeSource<i32, IterStream> {
        MergeSource {
            marker: PhantomData,
            sources: inputs
                .into_iter()
                .map(|values| Some(Box::pin(stream::iter(values))))
                .collect(),
            poll_cursor: 0,
        }
    }

    /// Polls `merge_source` to completion, collecting every item.
    ///
    /// Panics on `Pending` rather than stopping early, so that case can't masquerade as end
    /// of stream.
    fn drain(mut merge_source: MergeSource<i32, IterStream>) -> Vec<i32> {
        let waker = Arc::new(TestWaker).into();
        let mut cx = Context::from_waker(&waker);

        let mut results = Vec::new();
        loop {
            match Pin::new(&mut merge_source).poll_next(&mut cx) {
                Poll::Ready(Some(value)) => results.push(value),
                Poll::Ready(None) => break,
                Poll::Pending => panic!("in-memory streams should never be pending"),
            }
        }
        results
    }

    #[test]
    fn test_merge_source_fair_polling() {
        let merged = merge_of(vec![vec![1, 4, 7], vec![2, 5, 8], vec![3, 6, 9]]);
        assert_eq!(drain(merged), vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn test_merge_source_keeps_fairness_after_removal() {
        // A source ending mid-stream must not skip or repeat the remaining sources.
        let merged = merge_of(vec![vec![1], vec![2, 4, 6], vec![3, 5]]);
        assert_eq!(drain(merged), vec![1, 2, 3, 4, 5, 6]);

        let merged = merge_of(vec![vec![1, 4], vec![2], vec![3, 5, 6]]);
        assert_eq!(drain(merged), vec![1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_merge_source_empty() {
        assert_eq!(drain(merge_of(vec![])), Vec::<i32>::new());
        assert_eq!(drain(merge_of(vec![vec![], vec![]])), Vec::<i32>::new());
    }
}