1use async_trait::async_trait;
23use futures::future::BoxFuture;
24use futures::{
25 io::{IoSlice, IoSliceMut},
26 prelude::*,
27};
28use pin_project::pin_project;
29use std::{io::Error, pin::Pin, task::Context, task::Poll};
30
31use crate::identity::Keypair;
32use crate::muxing::{IReadWrite, IStreamMuxer, StreamInfo, StreamMuxer, StreamMuxerEx};
33use crate::secure_io::SecureInfo;
34use crate::transport::{ConnectionInfo, TransportError};
35use crate::upgrade::ProtocolName;
36use crate::{Multiaddr, PeerId, PublicKey};
37
38#[pin_project(project = EitherOutputProj)]
39#[derive(Debug, Copy, Clone)]
40pub enum EitherOutput<A, B> {
41 A(#[pin] A),
42 B(#[pin] B),
43}
44
45impl<A, B> AsyncRead for EitherOutput<A, B>
46where
47 A: AsyncRead,
48 B: AsyncRead,
49{
50 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize, Error>> {
51 match self.project() {
52 EitherOutputProj::A(a) => AsyncRead::poll_read(a, cx, buf),
53 EitherOutputProj::B(b) => AsyncRead::poll_read(b, cx, buf),
54 }
55 }
56
57 fn poll_read_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [IoSliceMut<'_>]) -> Poll<Result<usize, Error>> {
58 match self.project() {
59 EitherOutputProj::A(a) => AsyncRead::poll_read_vectored(a, cx, bufs),
60 EitherOutputProj::B(b) => AsyncRead::poll_read_vectored(b, cx, bufs),
61 }
62 }
63}
64
65impl<A, B> AsyncWrite for EitherOutput<A, B>
66where
67 A: AsyncWrite,
68 B: AsyncWrite,
69{
70 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
71 match self.project() {
72 EitherOutputProj::A(a) => AsyncWrite::poll_write(a, cx, buf),
73 EitherOutputProj::B(b) => AsyncWrite::poll_write(b, cx, buf),
74 }
75 }
76
77 fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<Result<usize, Error>> {
78 match self.project() {
79 EitherOutputProj::A(a) => AsyncWrite::poll_write_vectored(a, cx, bufs),
80 EitherOutputProj::B(b) => AsyncWrite::poll_write_vectored(b, cx, bufs),
81 }
82 }
83
84 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
85 match self.project() {
86 EitherOutputProj::A(a) => AsyncWrite::poll_flush(a, cx),
87 EitherOutputProj::B(b) => AsyncWrite::poll_flush(b, cx),
88 }
89 }
90
91 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
92 match self.project() {
93 EitherOutputProj::A(a) => AsyncWrite::poll_close(a, cx),
94 EitherOutputProj::B(b) => AsyncWrite::poll_close(b, cx),
95 }
96 }
97}
98
99impl<A, B> SecureInfo for EitherOutput<A, B>
100where
101 A: SecureInfo,
102 B: SecureInfo,
103{
104 fn local_peer(&self) -> PeerId {
105 match self {
106 EitherOutput::A(a) => a.local_peer(),
107 EitherOutput::B(b) => b.local_peer(),
108 }
109 }
110
111 fn remote_peer(&self) -> PeerId {
112 match self {
113 EitherOutput::A(a) => a.remote_peer(),
114 EitherOutput::B(b) => b.remote_peer(),
115 }
116 }
117
118 fn local_priv_key(&self) -> Keypair {
119 match self {
120 EitherOutput::A(a) => a.local_priv_key(),
121 EitherOutput::B(b) => b.local_priv_key(),
122 }
123 }
124
125 fn remote_pub_key(&self) -> PublicKey {
126 match self {
127 EitherOutput::A(a) => a.remote_pub_key(),
128 EitherOutput::B(b) => b.remote_pub_key(),
129 }
130 }
131}
132
133impl<A, B> StreamInfo for EitherOutput<A, B>
134where
135 A: StreamInfo,
136 B: StreamInfo,
137{
138 fn id(&self) -> usize {
139 match self {
140 EitherOutput::A(a) => a.id(),
141 EitherOutput::B(b) => b.id(),
142 }
143 }
144}
145
146#[async_trait]
147impl<A, B> StreamMuxer for EitherOutput<A, B>
148where
149 A: StreamMuxer + Send,
150 B: StreamMuxer + Send,
151{
152 async fn open_stream(&mut self) -> Result<IReadWrite, TransportError> {
153 match self {
154 EitherOutput::A(a) => Ok(a.open_stream().await?),
155 EitherOutput::B(b) => Ok(b.open_stream().await?),
156 }
157 }
158
159 async fn accept_stream(&mut self) -> Result<IReadWrite, TransportError> {
160 match self {
161 EitherOutput::A(a) => Ok(a.accept_stream().await?),
162 EitherOutput::B(b) => Ok(b.accept_stream().await?),
163 }
164 }
165
166 async fn close(&mut self) -> Result<(), TransportError> {
167 match self {
168 EitherOutput::A(a) => a.close().await,
169 EitherOutput::B(b) => b.close().await,
170 }
171 }
172
173 fn task(&mut self) -> Option<BoxFuture<'static, ()>> {
174 match self {
175 EitherOutput::A(a) => a.task(),
176 EitherOutput::B(b) => b.task(),
177 }
178 }
179
180 fn box_clone(&self) -> IStreamMuxer {
181 match self {
182 EitherOutput::A(a) => a.box_clone(),
183 EitherOutput::B(b) => b.box_clone(),
184 }
185 }
186}
187
188impl<A, B> ConnectionInfo for EitherOutput<A, B>
189where
190 A: ConnectionInfo,
191 B: ConnectionInfo,
192{
193 fn local_multiaddr(&self) -> Multiaddr {
194 match self {
195 EitherOutput::A(a) => a.local_multiaddr(),
196 EitherOutput::B(b) => b.local_multiaddr(),
197 }
198 }
199
200 fn remote_multiaddr(&self) -> Multiaddr {
201 match self {
202 EitherOutput::A(a) => a.remote_multiaddr(),
203 EitherOutput::B(b) => b.remote_multiaddr(),
204 }
205 }
206}
207
208impl<A, B> StreamMuxerEx for EitherOutput<A, B>
209where
210 A: StreamMuxer + ConnectionInfo + SecureInfo + std::fmt::Debug,
211 B: StreamMuxer + ConnectionInfo + SecureInfo + std::fmt::Debug,
212{
213}
214
/// A protocol name that is one of two alternative name types.
#[derive(Clone, Debug)]
pub enum EitherName<A, B> {
    /// A name of the first type.
    A(A),
    /// A name of the second type.
    B(B),
}
220
221impl<A: ProtocolName, B: ProtocolName> ProtocolName for EitherName<A, B> {
222 fn protocol_name(&self) -> &[u8] {
223 match self {
224 EitherName::A(a) => a.protocol_name(),
225 EitherName::B(b) => b.protocol_name(),
226 }
227 }
228}
229
230