Skip to main content

airio_core/transport/
upgrade.rs

1use std::{
2    error,
3    marker::PhantomData,
4    net::SocketAddr,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use airio_stream_select::Negotiated;
10use futures::{AsyncRead, AsyncWrite, Stream, TryFuture, future, ready};
11
12use crate::{
13    ConnectedPoint, Endpoint, ListenerEvent, PeerId, StreamMuxer, Transport, Upgrade,
14    muxing::StreamMuxerBox,
15    transport::{Boxed, and_then::AndThen, boxed::boxed},
16    upgrade::{UpgradeApply, UpgradeError},
17};
18
/// Entry point for layering upgrades on top of a raw [`Transport`].
///
/// Wrap a transport with [`Builder::new`], then authenticate its connections
/// with [`Builder::authenticate`].
#[derive(Debug, Clone)]
pub struct Builder<T> {
    /// The transport being built upon.
    inner: T,
}
23
24impl<T> Builder<T>
25where
26    T: Transport,
27    T::Error: 'static,
28{
29    pub fn new(inner: T) -> Builder<T> {
30        Builder { inner }
31    }
32
33    /// 使用一个 [`Upgrade`] 对 [`Transport::Output`] 进行身份协商
34    /// * I/O upgrade: `C -> (PeerId, D)`.
35    /// * 转化 Transport output: `C -> (PeerId, D)`
36    pub fn authenticate<C, D, U, E>(
37        self,
38        upgrade: U,
39    ) -> AuthenticatedBuilder<
40        AndThen<T, impl FnOnce(C, ConnectedPoint) -> Authenticate<C, U> + Clone>,
41    >
42    where
43        T: Transport<Output = C>,
44        C: AsyncRead + AsyncWrite + Unpin,
45        D: AsyncRead + AsyncWrite + Unpin,
46        U: Upgrade<Negotiated<C>, Output = (PeerId, D), Error = E> + Clone,
47        E: error::Error + 'static,
48    {
49        AuthenticatedBuilder(Builder::new(self.inner.and_then(move |io, endpoint| {
50            let inner = if endpoint.is_dialer() {
51                UpgradeApply::new_outbound(io, upgrade)
52            } else {
53                UpgradeApply::new_inbound(io, upgrade)
54            };
55
56            Authenticate { inner }
57        })))
58    }
59}
60
61/// 身份认证 Builder
62#[derive(Clone)]
63pub struct AuthenticatedBuilder<T>(Builder<T>);
64
65impl<T> AuthenticatedBuilder<T>
66where
67    T: Transport,
68    T::Error: 'static,
69{
70    /// 在一个 [`Transport`] 身份协商后的流应用一个 [`Upgrade`]。
71    /// 这将返回一个新的 [`AuthenticatedBuilder`],其中包含了升级后的流
72    /// 使用 [`Upgrade`] 作用 -> `(PeerId, C) -> (PeerId, D)`.
73    pub fn apply<C, D, U, E>(self, upgrade: U) -> AuthenticatedBuilder<WithUpgrade<T, U>>
74    where
75        T: Transport<Output = (PeerId, C)>,
76        C: AsyncRead + AsyncWrite + Unpin,
77        D: AsyncRead + AsyncWrite + Unpin,
78        U: Upgrade<Negotiated<C>, Output = D, Error = E> + Clone,
79        E: error::Error + 'static,
80    {
81        AuthenticatedBuilder(Builder::new(WithUpgrade::new(self.0.inner, upgrade)))
82    }
83
84    /// 在一个 [`Transport`] 身份协商后的流应用一个多路复用 [`Upgrade`]。
85    /// 实现了一个多路复用的连接升级。
86    /// 使用 [`Upgrade`] 作用 -> `(PeerId, C) -> (PeerId, M)`.
87    /// M 必须实现了 [`StreamMuxer`]
88    pub fn multiplex<C, M, U, E>(
89        self,
90        upgrade: U,
91    ) -> Multiplexed<AndThen<T, impl FnOnce((PeerId, C), ConnectedPoint) -> Multiplex<C, U> + Clone>>
92    where
93        T: Transport<Output = (PeerId, C)>,
94        M: StreamMuxer,
95        C: AsyncRead + AsyncWrite + Unpin,
96        U: Upgrade<Negotiated<C>, Output = M, Error = E> + Clone,
97        E: error::Error + 'static,
98    {
99        Multiplexed(self.0.inner.and_then(move |(id, io), endpoint| {
100            let upgrade = if endpoint.is_dialer() {
101                UpgradeApply::new_outbound(io, upgrade)
102            } else {
103                UpgradeApply::new_inbound(io, upgrade)
104            };
105            Multiplex {
106                peer_id: Some(id),
107                upgrade,
108            }
109        }))
110    }
111}
112
113/// 身份认证 Future
114/// 内部为 [`UpgradeApply`] 结构体
115#[pin_project::pin_project]
116pub struct Authenticate<C, U>
117where
118    C: AsyncRead + AsyncWrite + Unpin,
119    U: Upgrade<Negotiated<C>>,
120{
121    #[pin]
122    inner: UpgradeApply<C, U>,
123}
124
125impl<C, U> Future for Authenticate<C, U>
126where
127    C: AsyncRead + AsyncWrite + Unpin,
128    U: Upgrade<Negotiated<C>>,
129{
130    type Output = <UpgradeApply<C, U> as Future>::Output;
131
132    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
133        let this = self.project();
134        this.inner.poll(cx)
135    }
136}
137
138#[derive(Debug, Copy, Clone)]
139#[pin_project::pin_project]
140pub struct WithUpgrade<T, U> {
141    #[pin]
142    inner: T,
143    upgrade: U,
144}
145
146impl<T, U> WithUpgrade<T, U> {
147    pub fn new(inner: T, upgrade: U) -> Self {
148        WithUpgrade { inner, upgrade }
149    }
150}
151
152impl<T, C, D, U, E> Transport for WithUpgrade<T, U>
153where
154    T: Transport<Output = (PeerId, C)>,
155    T::Error: 'static,
156    C: AsyncRead + AsyncWrite + Unpin,
157    U: Upgrade<Negotiated<C>, Output = D, Error = E> + Clone,
158    E: error::Error + 'static,
159{
160    type Output = (PeerId, D);
161    type Error = TransportUpgradeError<T::Error, E>;
162    type Dialer = UpgradeFuture<T::Dialer, U, C>;
163    type ListenerUpgrade = UpgradeFuture<T::ListenerUpgrade, U, C>;
164    type Listener = MapListener<T, U>;
165
166    fn connect(&self, addr: SocketAddr) -> Result<Self::Dialer, Self::Error> {
167        let fut = self
168            .inner
169            .connect(addr)
170            .map_err(TransportUpgradeError::Transport)?;
171        Ok(UpgradeFuture {
172            inner_fut: Box::pin(fut),
173            role: Endpoint::Dialer,
174            upgrade: future::Either::Left(Some(self.upgrade.clone())),
175        })
176    }
177
178    fn listen(&self, addr: SocketAddr) -> Result<Self::Listener, Self::Error> {
179        let listener = self
180            .inner
181            .listen(addr)
182            .map_err(TransportUpgradeError::Transport)?;
183        Ok(MapListener {
184            inner: listener,
185            upgrade: self.upgrade.clone(),
186            _phantom: PhantomData,
187        })
188    }
189}
190
191#[pin_project::pin_project]
192#[derive(Clone, Debug)]
193pub struct MapListener<T, U>
194where
195    T: Transport,
196{
197    #[pin]
198    inner: T::Listener,
199    upgrade: U,
200    _phantom: PhantomData<T>,
201}
202
203impl<T, C, D, U, E> Stream for MapListener<T, U>
204where
205    T: Transport<Output = (PeerId, C)>,
206    T::Error: 'static,
207    C: AsyncRead + AsyncWrite + Unpin,
208    U: Upgrade<Negotiated<C>, Output = D, Error = E> + Clone,
209{
210    type Item =
211        ListenerEvent<UpgradeFuture<T::ListenerUpgrade, U, C>, TransportUpgradeError<T::Error, E>>;
212    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
213        let mut this = self.as_mut().project();
214        match Pin::new(&mut this.inner).poll_next(cx) {
215            Poll::Ready(Some(event)) => {
216                let upgrade = self.upgrade.clone();
217                let event = event
218                    .map_upgrade(move |up| UpgradeFuture {
219                        inner_fut: Box::pin(up),
220                        role: Endpoint::Listener,
221                        upgrade: future::Either::Left(Some(upgrade)),
222                    })
223                    .map_err(TransportUpgradeError::Transport);
224
225                Poll::Ready(Some(event))
226            }
227            Poll::Ready(None) => Poll::Ready(None),
228            Poll::Pending => Poll::Pending,
229        }
230    }
231}
232
233pub struct UpgradeFuture<Fut, U, C>
234where
235    C: AsyncRead + AsyncWrite + Unpin,
236    U: Upgrade<Negotiated<C>>,
237{
238    inner_fut: Pin<Box<Fut>>,
239    role: Endpoint,
240    upgrade: future::Either<Option<U>, (PeerId, UpgradeApply<C, U>)>,
241}
242
243impl<Fut, U, C> Unpin for UpgradeFuture<Fut, U, C>
244where
245    C: AsyncRead + AsyncWrite + Unpin,
246    U: Upgrade<Negotiated<C>>,
247{
248}
249
250impl<Fut, U, C, D> Future for UpgradeFuture<Fut, U, C>
251where
252    Fut: TryFuture<Ok = (PeerId, C)>,
253    C: AsyncRead + AsyncWrite + Unpin,
254    U: Upgrade<Negotiated<C>, Output = D>,
255    U::Error: error::Error,
256{
257    type Output = Result<(PeerId, D), TransportUpgradeError<Fut::Error, U::Error>>;
258    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
259        let this = &mut *self;
260        loop {
261            this.upgrade = match this.upgrade {
262                future::Either::Left(ref mut up) => {
263                    // 底层的进站升级中Poll出 `(PeerId, C)`.
264                    let (peer_id, io) = ready!(this.inner_fut.as_mut().try_poll(cx))
265                        .map_err(TransportUpgradeError::Transport)?;
266                    let upgrade = up.take().expect("upgrade should be set");
267                    // 使用 `UpgradeApply` 来应用升级。
268                    let upgrade = match this.role {
269                        Endpoint::Dialer => UpgradeApply::new_outbound(io, upgrade),
270                        Endpoint::Listener => UpgradeApply::new_inbound(io, upgrade),
271                    };
272                    future::Either::Right((peer_id, upgrade))
273                }
274                future::Either::Right((i, ref mut up)) => {
275                    let output = ready!(Pin::new(up).try_poll(cx))?;
276                    return Poll::Ready(Ok((i, output)));
277                }
278            }
279        }
280    }
281}
282
283#[derive(Debug, thiserror::Error)]
284pub enum TransportUpgradeError<TE, UE> {
285    #[error("Transport error: {0}")]
286    Transport(TE),
287    #[error(transparent)]
288    Upgrade(#[from] UpgradeError<UE>),
289}
290
291#[derive(Clone)]
292#[pin_project::pin_project]
293pub struct Multiplexed<T>(#[pin] T);
294
295#[pin_project::pin_project]
296pub struct Multiplex<C, U>
297where
298    C: AsyncRead + AsyncWrite + Unpin,
299    U: Upgrade<Negotiated<C>>,
300{
301    peer_id: Option<PeerId>,
302    #[pin]
303    upgrade: UpgradeApply<C, U>,
304}
305
306impl<T> Multiplexed<T> {
307    pub fn boxed<M>(self) -> Boxed<(PeerId, StreamMuxerBox)>
308    where
309        T: Transport<Output = (PeerId, M)> + Sized + Send + Unpin + 'static,
310        T::Dialer: Send + 'static,
311        T::ListenerUpgrade: Send + 'static,
312        T::Listener: Send + 'static,
313        T::Error: Send + Sync,
314        M: StreamMuxer + Send + 'static,
315        M::Substream: Send + 'static,
316        M::Error: Send + Sync + 'static,
317    {
318        boxed(self.map(|(i, m), _| (i, StreamMuxerBox::new(m))))
319    }
320}
321
322impl<T> Transport for Multiplexed<T>
323where
324    T: Transport,
325{
326    type Output = T::Output;
327    type Error = T::Error;
328    type ListenerUpgrade = T::ListenerUpgrade;
329    type Dialer = T::Dialer;
330    type Listener = T::Listener;
331
332    fn connect(&self, addr: SocketAddr) -> Result<Self::Dialer, Self::Error> {
333        self.0.connect(addr)
334    }
335
336    fn listen(&self, addr: SocketAddr) -> Result<Self::Listener, Self::Error> {
337        self.0.listen(addr)
338    }
339}
340
341impl<C, U, M, E> Future for Multiplex<C, U>
342where
343    C: AsyncRead + AsyncWrite + Unpin,
344    U: Upgrade<Negotiated<C>, Output = M, Error = E>,
345{
346    type Output = Result<(PeerId, M), UpgradeError<E>>;
347
348    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
349        let mut this = self.project();
350        let m = match ready!(Pin::new(&mut this.upgrade).poll(cx)) {
351            Ok(m) => m,
352            Err(err) => return Poll::Ready(Err(err)),
353        };
354        let i = this
355            .peer_id
356            .take()
357            .expect("Multiplex future polled after completion.");
358        Poll::Ready(Ok((i, m)))
359    }
360}