1use std::{
2 error,
3 marker::PhantomData,
4 net::SocketAddr,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use airio_stream_select::Negotiated;
10use futures::{AsyncRead, AsyncWrite, Stream, TryFuture, future, ready};
11
12use crate::{
13 ConnectedPoint, Endpoint, ListenerEvent, PeerId, StreamMuxer, Transport, Upgrade,
14 muxing::StreamMuxerBox,
15 transport::{Boxed, and_then::AndThen, boxed::boxed},
16 upgrade::{UpgradeApply, UpgradeError},
17};
18
19#[derive(Clone)]
20pub struct Builder<T> {
21 inner: T,
22}
23
24impl<T> Builder<T>
25where
26 T: Transport,
27 T::Error: 'static,
28{
29 pub fn new(inner: T) -> Builder<T> {
30 Builder { inner }
31 }
32
33 pub fn authenticate<C, D, U, E>(
37 self,
38 upgrade: U,
39 ) -> AuthenticatedBuilder<
40 AndThen<T, impl FnOnce(C, ConnectedPoint) -> Authenticate<C, U> + Clone>,
41 >
42 where
43 T: Transport<Output = C>,
44 C: AsyncRead + AsyncWrite + Unpin,
45 D: AsyncRead + AsyncWrite + Unpin,
46 U: Upgrade<Negotiated<C>, Output = (PeerId, D), Error = E> + Clone,
47 E: error::Error + 'static,
48 {
49 AuthenticatedBuilder(Builder::new(self.inner.and_then(move |io, endpoint| {
50 let inner = if endpoint.is_dialer() {
51 UpgradeApply::new_outbound(io, upgrade)
52 } else {
53 UpgradeApply::new_inbound(io, upgrade)
54 };
55
56 Authenticate { inner }
57 })))
58 }
59}
60
61#[derive(Clone)]
63pub struct AuthenticatedBuilder<T>(Builder<T>);
64
65impl<T> AuthenticatedBuilder<T>
66where
67 T: Transport,
68 T::Error: 'static,
69{
70 pub fn apply<C, D, U, E>(self, upgrade: U) -> AuthenticatedBuilder<WithUpgrade<T, U>>
74 where
75 T: Transport<Output = (PeerId, C)>,
76 C: AsyncRead + AsyncWrite + Unpin,
77 D: AsyncRead + AsyncWrite + Unpin,
78 U: Upgrade<Negotiated<C>, Output = D, Error = E> + Clone,
79 E: error::Error + 'static,
80 {
81 AuthenticatedBuilder(Builder::new(WithUpgrade::new(self.0.inner, upgrade)))
82 }
83
84 pub fn multiplex<C, M, U, E>(
89 self,
90 upgrade: U,
91 ) -> Multiplexed<AndThen<T, impl FnOnce((PeerId, C), ConnectedPoint) -> Multiplex<C, U> + Clone>>
92 where
93 T: Transport<Output = (PeerId, C)>,
94 M: StreamMuxer,
95 C: AsyncRead + AsyncWrite + Unpin,
96 U: Upgrade<Negotiated<C>, Output = M, Error = E> + Clone,
97 E: error::Error + 'static,
98 {
99 Multiplexed(self.0.inner.and_then(move |(id, io), endpoint| {
100 let upgrade = if endpoint.is_dialer() {
101 UpgradeApply::new_outbound(io, upgrade)
102 } else {
103 UpgradeApply::new_inbound(io, upgrade)
104 };
105 Multiplex {
106 peer_id: Some(id),
107 upgrade,
108 }
109 }))
110 }
111}
112
113#[pin_project::pin_project]
116pub struct Authenticate<C, U>
117where
118 C: AsyncRead + AsyncWrite + Unpin,
119 U: Upgrade<Negotiated<C>>,
120{
121 #[pin]
122 inner: UpgradeApply<C, U>,
123}
124
125impl<C, U> Future for Authenticate<C, U>
126where
127 C: AsyncRead + AsyncWrite + Unpin,
128 U: Upgrade<Negotiated<C>>,
129{
130 type Output = <UpgradeApply<C, U> as Future>::Output;
131
132 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
133 let this = self.project();
134 this.inner.poll(cx)
135 }
136}
137
138#[derive(Debug, Copy, Clone)]
139#[pin_project::pin_project]
140pub struct WithUpgrade<T, U> {
141 #[pin]
142 inner: T,
143 upgrade: U,
144}
145
146impl<T, U> WithUpgrade<T, U> {
147 pub fn new(inner: T, upgrade: U) -> Self {
148 WithUpgrade { inner, upgrade }
149 }
150}
151
152impl<T, C, D, U, E> Transport for WithUpgrade<T, U>
153where
154 T: Transport<Output = (PeerId, C)>,
155 T::Error: 'static,
156 C: AsyncRead + AsyncWrite + Unpin,
157 U: Upgrade<Negotiated<C>, Output = D, Error = E> + Clone,
158 E: error::Error + 'static,
159{
160 type Output = (PeerId, D);
161 type Error = TransportUpgradeError<T::Error, E>;
162 type Dialer = UpgradeFuture<T::Dialer, U, C>;
163 type ListenerUpgrade = UpgradeFuture<T::ListenerUpgrade, U, C>;
164 type Listener = MapListener<T, U>;
165
166 fn connect(&self, addr: SocketAddr) -> Result<Self::Dialer, Self::Error> {
167 let fut = self
168 .inner
169 .connect(addr)
170 .map_err(TransportUpgradeError::Transport)?;
171 Ok(UpgradeFuture {
172 inner_fut: Box::pin(fut),
173 role: Endpoint::Dialer,
174 upgrade: future::Either::Left(Some(self.upgrade.clone())),
175 })
176 }
177
178 fn listen(&self, addr: SocketAddr) -> Result<Self::Listener, Self::Error> {
179 let listener = self
180 .inner
181 .listen(addr)
182 .map_err(TransportUpgradeError::Transport)?;
183 Ok(MapListener {
184 inner: listener,
185 upgrade: self.upgrade.clone(),
186 _phantom: PhantomData,
187 })
188 }
189}
190
191#[pin_project::pin_project]
192#[derive(Clone, Debug)]
193pub struct MapListener<T, U>
194where
195 T: Transport,
196{
197 #[pin]
198 inner: T::Listener,
199 upgrade: U,
200 _phantom: PhantomData<T>,
201}
202
203impl<T, C, D, U, E> Stream for MapListener<T, U>
204where
205 T: Transport<Output = (PeerId, C)>,
206 T::Error: 'static,
207 C: AsyncRead + AsyncWrite + Unpin,
208 U: Upgrade<Negotiated<C>, Output = D, Error = E> + Clone,
209{
210 type Item =
211 ListenerEvent<UpgradeFuture<T::ListenerUpgrade, U, C>, TransportUpgradeError<T::Error, E>>;
212 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
213 let mut this = self.as_mut().project();
214 match Pin::new(&mut this.inner).poll_next(cx) {
215 Poll::Ready(Some(event)) => {
216 let upgrade = self.upgrade.clone();
217 let event = event
218 .map_upgrade(move |up| UpgradeFuture {
219 inner_fut: Box::pin(up),
220 role: Endpoint::Listener,
221 upgrade: future::Either::Left(Some(upgrade)),
222 })
223 .map_err(TransportUpgradeError::Transport);
224
225 Poll::Ready(Some(event))
226 }
227 Poll::Ready(None) => Poll::Ready(None),
228 Poll::Pending => Poll::Pending,
229 }
230 }
231}
232
233pub struct UpgradeFuture<Fut, U, C>
234where
235 C: AsyncRead + AsyncWrite + Unpin,
236 U: Upgrade<Negotiated<C>>,
237{
238 inner_fut: Pin<Box<Fut>>,
239 role: Endpoint,
240 upgrade: future::Either<Option<U>, (PeerId, UpgradeApply<C, U>)>,
241}
242
243impl<Fut, U, C> Unpin for UpgradeFuture<Fut, U, C>
244where
245 C: AsyncRead + AsyncWrite + Unpin,
246 U: Upgrade<Negotiated<C>>,
247{
248}
249
/// Two-phase future: it first waits for the inner transport future to yield `(PeerId, C)`,
/// then runs the upgrade on `C` in the direction selected by `role`.
///
/// Errors from the inner future are wrapped in [`TransportUpgradeError::Transport`]; errors
/// from the upgrade are converted through the `#[from]` conversion into
/// [`TransportUpgradeError::Upgrade`].
impl<Fut, U, C, D> Future for UpgradeFuture<Fut, U, C>
where
    Fut: TryFuture<Ok = (PeerId, C)>,
    C: AsyncRead + AsyncWrite + Unpin,
    U: Upgrade<Negotiated<C>, Output = D>,
    U::Error: error::Error,
{
    type Output = Result<(PeerId, D), TransportUpgradeError<Fut::Error, U::Error>>;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `UpgradeFuture` implements `Unpin`, so it is fine to work through a plain `&mut`.
        let this = &mut *self;
        // Loop so that an upgrade created in the `Left` arm is polled immediately in the
        // same call (the `Right` arm runs on the next iteration).
        loop {
            this.upgrade = match this.upgrade {
                // Phase 1: the upgrade has not started yet; wait for the inner connection.
                future::Either::Left(ref mut up) => {
                    let (peer_id, io) = ready!(this.inner_fut.as_mut().try_poll(cx))
                        .map_err(TransportUpgradeError::Transport)?;
                    // The state is switched to `Right` right after this arm, so the upgrade
                    // is still present here.
                    let upgrade = up.take().expect("upgrade should be set");
                    let upgrade = match this.role {
                        Endpoint::Dialer => UpgradeApply::new_outbound(io, upgrade),
                        Endpoint::Listener => UpgradeApply::new_inbound(io, upgrade),
                    };
                    future::Either::Right((peer_id, upgrade))
                }
                // Phase 2: drive the upgrade to completion and pair its output with the peer id.
                // NOTE(review): binding `i` by value next to `ref mut up` presumably relies on
                // `PeerId` being `Copy` — confirm.
                future::Either::Right((i, ref mut up)) => {
                    let output = ready!(Pin::new(up).try_poll(cx))?;
                    return Poll::Ready(Ok((i, output)));
                }
            }
        }
    }
}
282
283#[derive(Debug, thiserror::Error)]
284pub enum TransportUpgradeError<TE, UE> {
285 #[error("Transport error: {0}")]
286 Transport(TE),
287 #[error(transparent)]
288 Upgrade(#[from] UpgradeError<UE>),
289}
290
291#[derive(Clone)]
292#[pin_project::pin_project]
293pub struct Multiplexed<T>(#[pin] T);
294
295#[pin_project::pin_project]
296pub struct Multiplex<C, U>
297where
298 C: AsyncRead + AsyncWrite + Unpin,
299 U: Upgrade<Negotiated<C>>,
300{
301 peer_id: Option<PeerId>,
302 #[pin]
303 upgrade: UpgradeApply<C, U>,
304}
305
306impl<T> Multiplexed<T> {
307 pub fn boxed<M>(self) -> Boxed<(PeerId, StreamMuxerBox)>
308 where
309 T: Transport<Output = (PeerId, M)> + Sized + Send + Unpin + 'static,
310 T::Dialer: Send + 'static,
311 T::ListenerUpgrade: Send + 'static,
312 T::Listener: Send + 'static,
313 T::Error: Send + Sync,
314 M: StreamMuxer + Send + 'static,
315 M::Substream: Send + 'static,
316 M::Error: Send + Sync + 'static,
317 {
318 boxed(self.map(|(i, m), _| (i, StreamMuxerBox::new(m))))
319 }
320}
321
322impl<T> Transport for Multiplexed<T>
323where
324 T: Transport,
325{
326 type Output = T::Output;
327 type Error = T::Error;
328 type ListenerUpgrade = T::ListenerUpgrade;
329 type Dialer = T::Dialer;
330 type Listener = T::Listener;
331
332 fn connect(&self, addr: SocketAddr) -> Result<Self::Dialer, Self::Error> {
333 self.0.connect(addr)
334 }
335
336 fn listen(&self, addr: SocketAddr) -> Result<Self::Listener, Self::Error> {
337 self.0.listen(addr)
338 }
339}
340
341impl<C, U, M, E> Future for Multiplex<C, U>
342where
343 C: AsyncRead + AsyncWrite + Unpin,
344 U: Upgrade<Negotiated<C>, Output = M, Error = E>,
345{
346 type Output = Result<(PeerId, M), UpgradeError<E>>;
347
348 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
349 let mut this = self.project();
350 let m = match ready!(Pin::new(&mut this.upgrade).poll(cx)) {
351 Ok(m) => m,
352 Err(err) => return Poll::Ready(Err(err)),
353 };
354 let i = this
355 .peer_id
356 .take()
357 .expect("Multiplex future polled after completion.");
358 Poll::Ready(Ok((i, m)))
359 }
360}