hydro_deploy_integration/
lib.rs

1use std::cell::RefCell;
2use std::collections::HashMap;
3use std::marker::PhantomData;
4use std::net::SocketAddr;
5use std::path::PathBuf;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use async_recursion::async_recursion;
11use async_trait::async_trait;
12use bytes::{Bytes, BytesMut};
13use futures::sink::Buffer;
14use futures::stream::{FuturesUnordered, SplitSink, SplitStream};
15use futures::{Future, Sink, SinkExt, Stream, StreamExt, stream};
16use serde::{Deserialize, Serialize};
17use tempfile::TempDir;
18use tokio::io;
19use tokio::net::{TcpListener, TcpStream};
20#[cfg(unix)]
21use tokio::net::{UnixListener, UnixStream};
22use tokio_stream::wrappers::TcpListenerStream;
23use tokio_util::codec::{Framed, LengthDelimitedCodec};
24
25pub mod multi_connection;
26pub mod single_connection;
27
/// Initial configuration handed to a deployed program: the server ports it
/// must bind (keyed by port name) plus an optional string payload.
///
/// NOTE(review): the meaning of the `Option<String>` component is not visible
/// in this file (presumably extra metadata supplied by Hydro Deploy) — confirm.
pub type InitConfig = (HashMap<String, ServerBindConfig>, Option<String>);
29
/// Contains runtime information passed by Hydro Deploy to a program,
/// describing how to connect to other services and metadata about them.
pub struct DeployPorts<T = Option<()>> {
    /// Connections keyed by port name. Entries are removed as they are
    /// claimed through [`DeployPorts::port`], so each can be taken once.
    pub ports: RefCell<HashMap<String, Connection>>,
    /// Arbitrary metadata supplied alongside the ports (`Option<()>` by default).
    pub meta: T,
}
36
37impl<T> DeployPorts<T> {
38    pub fn port(&self, name: &str) -> Connection {
39        self.ports
40            .try_borrow_mut()
41            .unwrap()
42            .remove(name)
43            .unwrap_or_else(|| panic!("port {} not found", name))
44    }
45}
46
// Platforms without Unix-domain sockets swap the Unix socket types for the
// uninhabited `Infallible`. The `UnixSocket` variants that hold these types
// still type-check, but they can never be constructed there.
#[cfg(not(unix))]
type UnixStream = std::convert::Infallible;

#[cfg(not(unix))]
type UnixListener = std::convert::Infallible;
52
/// Describes how to connect to a service which is listening on some port.
///
/// This is the client-side counterpart of a bound listener; see
/// [`BoundServer::server_port`] and [`ServerPort::connect`].
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum ServerPort {
    /// Path of a Unix-domain socket to connect to.
    UnixSocket(PathBuf),
    /// TCP address to connect to.
    TcpPort(SocketAddr),
    /// Independent ports keyed by a `u32` id, each connected separately.
    Demux(HashMap<u32, ServerPort>),
    /// Several ports whose incoming data is merged into a single stream.
    Merge(Vec<ServerPort>),
    /// A port paired with a `u32` tag that identifies its data once merged.
    Tagged(Box<ServerPort>, u32),
    /// A placeholder with no underlying connection.
    Null,
}
63
64impl ServerPort {
65    #[async_recursion]
66    pub async fn connect(&self) -> ClientConnection {
67        match self {
68            ServerPort::UnixSocket(path) => {
69                #[cfg(unix)]
70                {
71                    let bound = UnixStream::connect(path.clone());
72                    ClientConnection::UnixSocket(bound.await.unwrap())
73                }
74
75                #[cfg(not(unix))]
76                {
77                    let _ = path;
78                    panic!("Unix sockets are not supported on this platform")
79                }
80            }
81            ServerPort::TcpPort(addr) => {
82                let addr_clone = *addr;
83                let stream = async_retry(
84                    move || TcpStream::connect(addr_clone),
85                    10,
86                    Duration::from_secs(1),
87                )
88                .await
89                .unwrap();
90                ClientConnection::TcpPort(stream)
91            }
92            ServerPort::Demux(bindings) => ClientConnection::Demux(
93                bindings
94                    .iter()
95                    .map(|(k, v)| async move { (*k, v.connect().await) })
96                    .collect::<FuturesUnordered<_>>()
97                    .collect::<Vec<_>>()
98                    .await
99                    .into_iter()
100                    .collect(),
101            ),
102            ServerPort::Merge(ports) => ClientConnection::Merge(
103                ports
104                    .iter()
105                    .map(|p| p.connect())
106                    .collect::<FuturesUnordered<_>>()
107                    .collect::<Vec<_>>()
108                    .await
109                    .into_iter()
110                    .collect(),
111            ),
112            ServerPort::Tagged(port, tag) => {
113                ClientConnection::Tagged(Box::new(port.as_ref().connect().await), *tag)
114            }
115            ServerPort::Null => ClientConnection::Null,
116        }
117    }
118
119    pub async fn instantiate(&self) -> Connection {
120        Connection::AsClient(self.connect().await)
121    }
122}
123
/// An established client-side connection with the same shape as the
/// [`ServerPort`] it was created from by [`ServerPort::connect`].
#[derive(Debug)]
pub enum ClientConnection {
    /// Connected Unix-domain socket.
    UnixSocket(UnixStream),
    /// Connected TCP stream.
    TcpPort(TcpStream),
    /// Connections keyed by `u32` id.
    Demux(HashMap<u32, ClientConnection>),
    /// Connections whose incoming data is to be merged.
    Merge(Vec<ClientConnection>),
    /// A connection together with its `u32` tag.
    Tagged(Box<ClientConnection>, u32),
    /// No underlying connection.
    Null,
}
133
/// Describes what a service should listen on. It is turned into a
/// [`BoundServer`] by [`ServerBindConfig::bind`].
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum ServerBindConfig {
    /// Listen on a Unix-domain socket created inside a fresh temporary directory.
    UnixSocket,
    /// Listen on a TCP host/port.
    TcpPort(
        /// The host the port should be bound on.
        String,
        /// The port the service should listen on.
        ///
        /// If `None`, the port will be chosen automatically.
        Option<u16>,
    ),
    /// Bind each entry independently, keyed by `u32` id.
    Demux(HashMap<u32, ServerBindConfig>),
    /// Bind several listeners whose incoming data is merged.
    Merge(Vec<ServerBindConfig>),
    /// Bind the wrapped config and attach a `u32` tag.
    Tagged(Box<ServerBindConfig>, u32),
    /// Bind the wrapped config without accepting a connection up front
    /// (`accept_bound` passes it through unaccepted).
    /// NOTE(review): presumably consumed by the `multi_connection` module — confirm.
    MultiConnection(Box<ServerBindConfig>),
    /// No listener.
    Null,
}
151
152impl ServerBindConfig {
153    #[async_recursion]
154    pub async fn bind(self) -> BoundServer {
155        match self {
156            ServerBindConfig::UnixSocket => {
157                #[cfg(unix)]
158                {
159                    let dir = tempfile::tempdir().unwrap();
160                    let socket_path = dir.path().join("socket");
161                    let bound = UnixListener::bind(socket_path).unwrap();
162                    BoundServer::UnixSocket(bound, dir)
163                }
164
165                #[cfg(not(unix))]
166                {
167                    panic!("Unix sockets are not supported on this platform")
168                }
169            }
170            ServerBindConfig::TcpPort(host, port) => {
171                let listener = TcpListener::bind((host, port.unwrap_or(0))).await.unwrap();
172                let addr = listener.local_addr().unwrap();
173                BoundServer::TcpPort(TcpListenerStream::new(listener), addr)
174            }
175            ServerBindConfig::Demux(bindings) => {
176                let mut demux = HashMap::new();
177                for (key, bind) in bindings {
178                    demux.insert(key, bind.bind().await);
179                }
180                BoundServer::Demux(demux)
181            }
182            ServerBindConfig::Merge(bindings) => {
183                let mut merge = Vec::new();
184                for bind in bindings {
185                    merge.push(bind.bind().await);
186                }
187                BoundServer::Merge(merge)
188            }
189            ServerBindConfig::Tagged(underlying, id) => {
190                BoundServer::Tagged(Box::new(underlying.bind().await), id)
191            }
192            ServerBindConfig::MultiConnection(underlying) => {
193                BoundServer::MultiConnection(Box::new(underlying.bind().await))
194            }
195            ServerBindConfig::Null => BoundServer::Null,
196        }
197    }
198}
199
/// A connection handed to a program. It is turned into a concrete
/// stream/sink type via [`Connection::connect`].
#[derive(Debug)]
pub enum Connection {
    /// This side initiated the connection (see [`ServerPort::instantiate`]).
    AsClient(ClientConnection),
    /// This side accepted the connection on a bound listener.
    AsServer(AcceptedServer),
}
205
206impl Connection {
207    pub fn connect<T: Connected>(self) -> T {
208        T::from_defn(self)
209    }
210}
211
/// A boxed, pinned stream of received `BytesMut` items (or I/O errors).
pub type DynStream = Pin<Box<dyn Stream<Item = Result<BytesMut, io::Error>> + Send + Sync>>;

/// A boxed, pinned sink that accepts `Input` items and fails with I/O errors.
pub type DynSink<Input> = Pin<Box<dyn Sink<Input, Error = io::Error> + Send + Sync>>;

/// A duplex byte channel: a stream of received `BytesMut` items plus a sink
/// for outgoing `Bytes`.
///
/// The blanket impl below implements it for every type that has both capabilities.
pub trait StreamSink:
    Stream<Item = Result<BytesMut, io::Error>> + Sink<Bytes, Error = io::Error>
{
}
impl<T: Stream<Item = Result<BytesMut, io::Error>> + Sink<Bytes, Error = io::Error>> StreamSink
    for T
{
}

/// A boxed, pinned [`StreamSink`].
pub type DynStreamSink = Pin<Box<dyn StreamSink + Send + Sync>>;
226
/// Types that can be built from a [`Connection`].
pub trait Connected: Send {
    /// Builds `Self` from `pipe`. The implementations in this file panic when
    /// the connection has a shape they do not support.
    fn from_defn(pipe: Connection) -> Self;
}

/// Connected types that can be consumed into a sink.
pub trait ConnectedSink {
    /// Item type accepted by the sink.
    type Input: Send;
    /// Concrete sink type.
    type Sink: Sink<Self::Input, Error = io::Error> + Send + Sync;

    /// Consumes `self`, returning its sink.
    fn into_sink(self) -> Self::Sink;
}

/// Connected types that can be consumed into a (fallible) stream.
pub trait ConnectedSource {
    /// Item type produced on success.
    type Output: Send;
    /// Concrete stream type.
    type Stream: Stream<Item = Result<Self::Output, io::Error>> + Send + Sync;
    /// Consumes `self`, returning its stream.
    fn into_source(self) -> Self::Stream;
}
243
/// Listener(s) produced by [`ServerBindConfig::bind`] that have not accepted
/// a connection yet. [`accept_bound`] turns them into an [`AcceptedServer`].
#[derive(Debug)]
pub enum BoundServer {
    /// Unix listener plus the temporary directory containing its socket file.
    /// NOTE(review): the `TempDir` is presumably held so the directory is not
    /// removed while the listener is in use — confirm.
    UnixSocket(UnixListener, TempDir),
    /// TCP listener stream plus the local address it was bound to.
    TcpPort(TcpListenerStream, SocketAddr),
    /// Listeners keyed by `u32` id.
    Demux(HashMap<u32, BoundServer>),
    /// Listeners whose incoming data is to be merged.
    Merge(Vec<BoundServer>),
    /// A listener together with its `u32` tag.
    Tagged(Box<BoundServer>, u32),
    /// A listener that is not accepted up front by [`accept_bound`].
    MultiConnection(Box<BoundServer>),
    /// No listener.
    Null,
}
254
/// Server-side connection(s) produced by [`accept_bound`], with the same
/// shape as the [`BoundServer`] they were accepted from.
#[derive(Debug)]
pub enum AcceptedServer {
    /// Accepted Unix stream plus the temporary directory of its socket.
    UnixSocket(UnixStream, TempDir),
    /// Accepted TCP stream.
    TcpPort(TcpStream),
    /// Accepted connections keyed by `u32` id.
    Demux(HashMap<u32, AcceptedServer>),
    /// Accepted connections whose incoming data is to be merged.
    Merge(Vec<AcceptedServer>),
    /// An accepted connection together with its `u32` tag.
    Tagged(Box<AcceptedServer>, u32),
    /// A still-unaccepted listener, passed through as-is by [`accept_bound`].
    MultiConnection(Box<BoundServer>),
    /// No connection.
    Null,
}
265
266#[async_recursion]
267pub async fn accept_bound(bound: BoundServer) -> AcceptedServer {
268    match bound {
269        BoundServer::UnixSocket(listener, dir) => {
270            #[cfg(unix)]
271            {
272                let stream = listener.accept().await.unwrap().0;
273                AcceptedServer::UnixSocket(stream, dir)
274            }
275
276            #[cfg(not(unix))]
277            {
278                let _ = listener;
279                let _ = dir;
280                panic!("Unix sockets are not supported on this platform")
281            }
282        }
283        BoundServer::TcpPort(mut listener, _) => {
284            let stream = listener.next().await.unwrap().unwrap();
285            AcceptedServer::TcpPort(stream)
286        }
287        BoundServer::Demux(bindings) => AcceptedServer::Demux(
288            bindings
289                .into_iter()
290                .map(|(k, b)| async move { (k, accept_bound(b).await) })
291                .collect::<FuturesUnordered<_>>()
292                .collect::<Vec<_>>()
293                .await
294                .into_iter()
295                .collect(),
296        ),
297        BoundServer::Merge(merge) => AcceptedServer::Merge(
298            merge
299                .into_iter()
300                .map(|b| async move { accept_bound(b).await })
301                .collect::<FuturesUnordered<_>>()
302                .collect::<Vec<_>>()
303                .await,
304        ),
305        BoundServer::Tagged(underlying, id) => {
306            AcceptedServer::Tagged(Box::new(accept_bound(*underlying).await), id)
307        }
308        BoundServer::MultiConnection(underlying) => AcceptedServer::MultiConnection(underlying),
309        BoundServer::Null => AcceptedServer::Null,
310    }
311}
312
313impl BoundServer {
314    pub fn server_port(&self) -> ServerPort {
315        match self {
316            BoundServer::UnixSocket(_, tempdir) => {
317                #[cfg(unix)]
318                {
319                    ServerPort::UnixSocket(tempdir.path().join("socket"))
320                }
321
322                #[cfg(not(unix))]
323                {
324                    let _ = tempdir;
325                    panic!("Unix sockets are not supported on this platform")
326                }
327            }
328            BoundServer::TcpPort(_, addr) => {
329                ServerPort::TcpPort(SocketAddr::new(addr.ip(), addr.port()))
330            }
331
332            BoundServer::Demux(bindings) => {
333                let mut demux = HashMap::new();
334                for (key, bind) in bindings {
335                    demux.insert(*key, bind.server_port());
336                }
337                ServerPort::Demux(demux)
338            }
339
340            BoundServer::Merge(bindings) => {
341                let mut merge = Vec::new();
342                for bind in bindings {
343                    merge.push(bind.server_port());
344                }
345                ServerPort::Merge(merge)
346            }
347
348            BoundServer::Tagged(underlying, id) => {
349                ServerPort::Tagged(Box::new(underlying.server_port()), *id)
350            }
351
352            BoundServer::MultiConnection(underlying) => underlying.server_port(),
353
354            BoundServer::Null => ServerPort::Null,
355        }
356    }
357}
358
359fn accept(bound: AcceptedServer) -> ConnectedDirect {
360    match bound {
361        AcceptedServer::UnixSocket(stream, _dir) => {
362            #[cfg(unix)]
363            {
364                ConnectedDirect {
365                    stream_sink: Some(Box::pin(unix_bytes(stream))),
366                    source_only: None,
367                    sink_only: None,
368                }
369            }
370
371            #[cfg(not(unix))]
372            {
373                let _ = stream;
374                panic!("Unix sockets are not supported on this platform")
375            }
376        }
377        AcceptedServer::TcpPort(stream) => ConnectedDirect {
378            stream_sink: Some(Box::pin(tcp_bytes(stream))),
379            source_only: None,
380            sink_only: None,
381        },
382        AcceptedServer::Merge(merge) => {
383            let mut sources = vec![];
384            for bound in merge {
385                sources.push(Some(Box::pin(accept(bound).into_source())));
386            }
387
388            let merge_source: DynStream = Box::pin(MergeSource {
389                marker: PhantomData,
390                sources,
391                poll_cursor: 0,
392            });
393
394            ConnectedDirect {
395                stream_sink: None,
396                source_only: Some(merge_source),
397                sink_only: None,
398            }
399        }
400        AcceptedServer::Demux(_) => panic!("Cannot connect to a demux pipe directly"),
401        AcceptedServer::Tagged(_, _) => panic!("Cannot connect to a tagged pipe directly"),
402        AcceptedServer::MultiConnection(_) => {
403            panic!("Cannot connect to a multi-connection pipe directly")
404        }
405        AcceptedServer::Null => {
406            ConnectedDirect::from_defn(Connection::AsClient(ClientConnection::Null))
407        }
408    }
409}
410
411fn tcp_bytes(stream: TcpStream) -> impl StreamSink {
412    Framed::new(stream, LengthDelimitedCodec::new())
413}
414
415#[cfg(unix)]
416fn unix_bytes(stream: UnixStream) -> impl StreamSink {
417    Framed::new(stream, LengthDelimitedCodec::new())
418}
419
420struct IoErrorDrain<T> {
421    marker: PhantomData<T>,
422}
423
424impl<T> Sink<T> for IoErrorDrain<T> {
425    type Error = io::Error;
426
427    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
428        Poll::Ready(Ok(()))
429    }
430
431    fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
432        Ok(())
433    }
434
435    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
436        Poll::Ready(Ok(()))
437    }
438
439    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
440        Poll::Ready(Ok(()))
441    }
442}
443
444async fn async_retry<T, E, F: Future<Output = Result<T, E>>>(
445    thunk: impl Fn() -> F,
446    count: usize,
447    delay: Duration,
448) -> Result<T, E> {
449    for _ in 1..count {
450        let result = thunk().await;
451        if result.is_ok() {
452            return result;
453        } else {
454            tokio::time::sleep(delay).await;
455        }
456    }
457
458    thunk().await
459}
460
/// A byte-level connection: either a duplex `stream_sink`, or a separate
/// source and/or sink.
///
/// Which fields are populated depends on how it was built. Direct Unix/TCP
/// connections fill only `stream_sink`, merges fill only `source_only`, and
/// `Null` fills `source_only` and `sink_only`.
pub struct ConnectedDirect {
    /// Bidirectional framed connection, if any.
    stream_sink: Option<DynStreamSink>,
    /// Receive-only stream, used when there is no `stream_sink`.
    source_only: Option<DynStream>,
    /// Send-only sink, used when there is no `stream_sink`.
    sink_only: Option<DynSink<Bytes>>,
}
466
467impl ConnectedDirect {
468    pub fn into_source_sink(self) -> (SplitStream<DynStreamSink>, SplitSink<DynStreamSink, Bytes>) {
469        let (sink, stream) = self.stream_sink.unwrap().split();
470        (stream, sink)
471    }
472}
473
474impl Connected for ConnectedDirect {
475    fn from_defn(pipe: Connection) -> Self {
476        match pipe {
477            Connection::AsClient(ClientConnection::UnixSocket(stream)) => {
478                #[cfg(unix)]
479                {
480                    ConnectedDirect {
481                        stream_sink: Some(Box::pin(unix_bytes(stream))),
482                        source_only: None,
483                        sink_only: None,
484                    }
485                }
486
487                #[cfg(not(unix))]
488                {
489                    let _ = stream;
490                    panic!("Unix sockets are not supported on this platform");
491                }
492            }
493            Connection::AsClient(ClientConnection::TcpPort(stream)) => {
494                stream.set_nodelay(true).unwrap();
495                ConnectedDirect {
496                    stream_sink: Some(Box::pin(tcp_bytes(stream))),
497                    source_only: None,
498                    sink_only: None,
499                }
500            }
501            Connection::AsClient(ClientConnection::Merge(merge)) => {
502                let sources = merge
503                    .into_iter()
504                    .map(|port| {
505                        Some(Box::pin(
506                            ConnectedDirect::from_defn(Connection::AsClient(port)).into_source(),
507                        ))
508                    })
509                    .collect::<Vec<_>>();
510
511                let merged = MergeSource {
512                    marker: PhantomData,
513                    sources,
514                    poll_cursor: 0,
515                };
516
517                ConnectedDirect {
518                    stream_sink: None,
519                    source_only: Some(Box::pin(merged)),
520                    sink_only: None,
521                }
522            }
523            Connection::AsClient(ClientConnection::Demux(_)) => {
524                panic!("Cannot connect to a demux pipe directly")
525            }
526
527            Connection::AsClient(ClientConnection::Tagged(_, _)) => {
528                panic!("Cannot connect to a tagged pipe directly")
529            }
530
531            Connection::AsClient(ClientConnection::Null) => ConnectedDirect {
532                stream_sink: None,
533                source_only: Some(Box::pin(stream::empty())),
534                sink_only: Some(Box::pin(IoErrorDrain {
535                    marker: PhantomData,
536                })),
537            },
538
539            Connection::AsServer(bound) => accept(bound),
540        }
541    }
542}
543
544impl ConnectedSource for ConnectedDirect {
545    type Output = BytesMut;
546    type Stream = DynStream;
547
548    fn into_source(mut self) -> DynStream {
549        if let Some(s) = self.stream_sink.take() {
550            Box::pin(s)
551        } else {
552            self.source_only.take().unwrap()
553        }
554    }
555}
556
557impl ConnectedSink for ConnectedDirect {
558    type Input = Bytes;
559    type Sink = DynSink<Bytes>;
560
561    fn into_sink(mut self) -> DynSink<Self::Input> {
562        if let Some(s) = self.stream_sink.take() {
563            Box::pin(s)
564        } else {
565            self.sink_only.take().unwrap()
566        }
567    }
568}
569
/// A demultiplexing sink over per-key sinks, each wrapped in a [`Buffer`] and
/// keyed by `u32`. `ConnectedDemux` feeds it `(key, item)` pairs.
///
/// NOTE(review): the routing semantics come from `sinktools::demux_map`, which
/// is not visible here — confirm.
pub type BufferedDrain<S, I> = sinktools::demux_map::DemuxMap<u32, Pin<Box<Buffer<S, I>>>>;
571
/// A single sink that fans `(key, item)` pairs out to several keyed
/// connections.
pub struct ConnectedDemux<T: ConnectedSink>
where
    <T as ConnectedSink>::Input: Sync,
{
    /// Keys of the underlying connections, in arbitrary (hash map) order.
    pub keys: Vec<u32>,
    /// The demux sink. It is `Some` after construction and taken by `into_sink`.
    sink: Option<BufferedDrain<T::Sink, T::Input>>,
}
579
580#[async_trait]
581impl<T: Connected + ConnectedSink> Connected for ConnectedDemux<T>
582where
583    <T as ConnectedSink>::Input: 'static + Sync,
584{
585    fn from_defn(pipe: Connection) -> Self {
586        match pipe {
587            Connection::AsClient(ClientConnection::Demux(demux)) => {
588                let mut connected_demux = HashMap::new();
589                let keys = demux.keys().cloned().collect();
590                for (id, pipe) in demux {
591                    connected_demux.insert(
592                        id,
593                        Box::pin(
594                            T::from_defn(Connection::AsClient(pipe))
595                                .into_sink()
596                                .buffer(1024),
597                        ),
598                    );
599                }
600
601                let demuxer = sinktools::demux_map(connected_demux);
602
603                ConnectedDemux {
604                    keys,
605                    sink: Some(demuxer),
606                }
607            }
608
609            Connection::AsServer(AcceptedServer::Demux(demux)) => {
610                let mut connected_demux = HashMap::new();
611                let keys = demux.keys().cloned().collect();
612                for (id, bound) in demux {
613                    connected_demux.insert(
614                        id,
615                        Box::pin(
616                            T::from_defn(Connection::AsServer(bound))
617                                .into_sink()
618                                .buffer(1024),
619                        ),
620                    );
621                }
622
623                let demuxer = sinktools::demux_map(connected_demux);
624
625                ConnectedDemux {
626                    keys,
627                    sink: Some(demuxer),
628                }
629            }
630            _ => panic!("Cannot connect to a non-demux pipe as a demux"),
631        }
632    }
633}
634
635impl<T: ConnectedSink> ConnectedSink for ConnectedDemux<T>
636where
637    <T as ConnectedSink>::Input: 'static + Sync,
638{
639    type Input = (u32, T::Input);
640    type Sink = BufferedDrain<T::Sink, T::Input>;
641
642    fn into_sink(mut self) -> Self::Sink {
643        self.sink.take().unwrap()
644    }
645}
646
/// A stream that merges several sources by polling them round-robin, so that
/// one busy source cannot starve the others.
///
/// Exhausted sources are dropped, and the merged stream ends once all of them
/// have finished.
pub struct MergeSource<T: Unpin, S: Stream<Item = T> + Send + Sync + ?Sized> {
    /// Records the item type `T` without storing a value of it.
    marker: PhantomData<T>,
    /// Ordered list for fair polling, will never be `None` at the beginning of a poll
    sources: Vec<Option<Pin<Box<S>>>>,
    /// Cursor for fair round-robin polling
    poll_cursor: usize,
}
654
impl<T: Unpin, S: Stream<Item = T> + Send + Sync + ?Sized> Stream for MergeSource<T, S> {
    type Item = T;

    /// Polls the merged sources round-robin and returns the first item any of
    /// them produces.
    ///
    /// Polling starts at `poll_cursor` and visits each remaining source at
    /// most once per call. When a source yields an item, the cursor is left on
    /// the source after it, so the next call starts there. Sources that report
    /// end-of-stream are removed at the end of the call, and the cursor is
    /// adjusted to keep referring to the same next source.
    ///
    /// Returns `Ready(Some(item))` on an item, `Ready(None)` once no sources
    /// remain, and `Pending` if every remaining source returned `Pending`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let me = self.get_mut();
        // Stays `Pending` unless some source yields an item below.
        let mut out = Poll::Pending;
        // Set when a source is exhausted, so the compaction pass below runs.
        let mut any_removed = false;

        if !me.sources.is_empty() {
            let start_cursor = me.poll_cursor;

            loop {
                let current_length = me.sources.len();
                let source = &mut me.sources[me.poll_cursor];

                // Move cursor to next source for next poll
                me.poll_cursor = (me.poll_cursor + 1) % current_length;

                // Each slot is visited at most once per call, and slots are only
                // set to `None` inside this loop (and compacted away before the
                // next call), so this `unwrap` cannot fail.
                match source.as_mut().unwrap().as_mut().poll_next(cx) {
                    Poll::Ready(Some(data)) => {
                        out = Poll::Ready(Some(data));
                        break;
                    }
                    Poll::Ready(None) => {
                        *source = None; // Mark source as removed
                        any_removed = true;
                    }
                    Poll::Pending => {}
                }

                // Check if we've completed a full round
                if me.poll_cursor == start_cursor {
                    break;
                }
            }
        }

        // Clean up None entries and adjust cursor
        let mut current_index = 0;
        let original_cursor = me.poll_cursor;

        if any_removed {
            me.sources.retain(|source| {
                // Every removed slot before the cursor shifts the cursor's target
                // source one position to the left.
                if source.is_none() && current_index < original_cursor {
                    me.poll_cursor -= 1;
                }
                current_index += 1;
                source.is_some()
            });
        }

        // If every source from the cursor onward was removed, the cursor now
        // points one past the end; wrap it back to the front.
        if me.poll_cursor == me.sources.len() {
            me.poll_cursor = 0;
        }

        if me.sources.is_empty() {
            Poll::Ready(None)
        } else {
            out
        }
    }
}
717
/// Wraps a fallible source and pairs every successful item with a fixed
/// `u32` tag. Errors pass through unchanged.
pub struct TaggedSource<T: Unpin, S: Stream<Item = Result<T, io::Error>> + Send + Sync + ?Sized> {
    /// Records the item type `T` without storing a value of it.
    marker: PhantomData<T>,
    /// Tag attached to every successful item.
    id: u32,
    /// The wrapped source.
    source: Pin<Box<S>>,
}
723
724impl<T: Unpin, S: Stream<Item = Result<T, io::Error>> + Send + Sync + ?Sized> Stream
725    for TaggedSource<T, S>
726{
727    type Item = Result<(u32, T), io::Error>;
728
729    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
730        let id = self.as_ref().id;
731        let source = &mut self.get_mut().source;
732        match source.as_mut().poll_next(cx) {
733            Poll::Ready(Some(v)) => Poll::Ready(Some(v.map(|d| (id, d)))),
734            Poll::Ready(None) => Poll::Ready(None),
735            Poll::Pending => Poll::Pending,
736        }
737    }
738}
739
/// Round-robin merge of tagged sources, yielding `(tag, item)` results. This is
/// the stream type stored in [`ConnectedTagged`].
type MergedMux<T> = MergeSource<
    Result<(u32, <T as ConnectedSource>::Output), io::Error>,
    TaggedSource<<T as ConnectedSource>::Output, <T as ConnectedSource>::Stream>,
>;
744
/// A source over one or more tagged connections whose items are yielded as
/// `(tag, item)` pairs.
pub struct ConnectedTagged<T: ConnectedSource>
where
    <T as ConnectedSource>::Output: 'static + Sync + Unpin,
{
    /// The merged, tagged sources.
    source: MergedMux<T>,
}
751
752#[async_trait]
753impl<T: Connected + ConnectedSource> Connected for ConnectedTagged<T>
754where
755    <T as ConnectedSource>::Output: 'static + Sync + Unpin,
756{
757    fn from_defn(pipe: Connection) -> Self {
758        let sources = match pipe {
759            Connection::AsClient(ClientConnection::Tagged(pipe, id)) => {
760                vec![(
761                    Box::pin(T::from_defn(Connection::AsClient(*pipe)).into_source()),
762                    id,
763                )]
764            }
765
766            Connection::AsClient(ClientConnection::Merge(m)) => {
767                let mut sources = Vec::new();
768                for port in m {
769                    if let ClientConnection::Tagged(pipe, id) = port {
770                        sources.push((
771                            Box::pin(T::from_defn(Connection::AsClient(*pipe)).into_source()),
772                            id,
773                        ));
774                    } else {
775                        panic!("Merge port must be tagged");
776                    }
777                }
778
779                sources
780            }
781
782            Connection::AsServer(AcceptedServer::Tagged(pipe, id)) => {
783                vec![(
784                    Box::pin(T::from_defn(Connection::AsServer(*pipe)).into_source()),
785                    id,
786                )]
787            }
788
789            Connection::AsServer(AcceptedServer::Merge(m)) => {
790                let mut sources = Vec::new();
791                for port in m {
792                    if let AcceptedServer::Tagged(pipe, id) = port {
793                        sources.push((
794                            Box::pin(T::from_defn(Connection::AsServer(*pipe)).into_source()),
795                            id,
796                        ));
797                    } else {
798                        panic!("Merge port must be tagged");
799                    }
800                }
801
802                sources
803            }
804
805            _ => panic!("Cannot connect to a non-tagged pipe as a tagged"),
806        };
807
808        let mut connected_mux = Vec::new();
809        for (pipe, id) in sources {
810            connected_mux.push(Some(Box::pin(TaggedSource {
811                marker: PhantomData,
812                id,
813                source: pipe,
814            })));
815        }
816
817        let muxer = MergeSource {
818            marker: PhantomData,
819            sources: connected_mux,
820            poll_cursor: 0,
821        };
822
823        ConnectedTagged { source: muxer }
824    }
825}
826
827impl<T: ConnectedSource> ConnectedSource for ConnectedTagged<T>
828where
829    <T as ConnectedSource>::Output: 'static + Sync + Unpin,
830{
831    type Output = (u32, T::Output);
832    type Stream = MergeSource<Result<Self::Output, io::Error>, TaggedSource<T::Output, T::Stream>>;
833
834    fn into_source(self) -> Self::Stream {
835        self.source
836    }
837}
838
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::task::{Context, Poll};

    use futures::stream;

    use super::*;

    struct TestWaker;
    impl std::task::Wake for TestWaker {
        fn wake(self: Arc<Self>) {}
    }

    /// Polls `merge_source` synchronously with a no-op waker and collects items
    /// until it reports end of stream. It also stops at the first `Pending`,
    /// which the in-memory test streams never return.
    fn collect_ready<S: Stream<Item = i32> + Send + Sync + ?Sized>(
        mut merge_source: MergeSource<i32, S>,
    ) -> Vec<i32> {
        let waker = Arc::new(TestWaker).into();
        let mut cx = Context::from_waker(&waker);

        let mut results = Vec::new();
        while let Poll::Ready(Some(value)) = Pin::new(&mut merge_source).poll_next(&mut cx) {
            results.push(value);
        }
        results
    }

    #[test]
    fn test_merge_source_fair_polling() {
        // Create test streams that yield values in a predictable pattern
        let stream1 = Box::pin(stream::iter(vec![1, 4, 7]));
        let stream2 = Box::pin(stream::iter(vec![2, 5, 8]));
        let stream3 = Box::pin(stream::iter(vec![3, 6, 9]));

        let merge_source = MergeSource {
            marker: PhantomData,
            sources: vec![Some(stream1), Some(stream2), Some(stream3)],
            poll_cursor: 0,
        };

        // With fair polling, we should get values in round-robin order: 1, 2, 3, 4, 5, 6, 7, 8, 9
        assert_eq!(collect_ready(merge_source), vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn test_merge_source_uneven_lengths() {
        // The first stream finishes while the cursor is past it, and the second
        // finishes later. This exercises source removal and cursor adjustment
        // while preserving round-robin order.
        let merge_source = MergeSource {
            marker: PhantomData,
            sources: vec![
                Some(Box::pin(stream::iter(vec![1]))),
                Some(Box::pin(stream::iter(vec![2, 4]))),
                Some(Box::pin(stream::iter(vec![3, 5, 6]))),
            ],
            poll_cursor: 0,
        };

        assert_eq!(collect_ready(merge_source), vec![1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_merge_source_empty() {
        let merge_source: MergeSource<i32, stream::Iter<std::vec::IntoIter<i32>>> = MergeSource {
            marker: PhantomData,
            sources: Vec::new(),
            poll_cursor: 0,
        };

        assert!(collect_ready(merge_source).is_empty());
    }
}