libp2prs_websocket/
connection.rs

1use futures::{
2    prelude::*,
3    stream::{BoxStream, IntoAsyncRead, TryStreamExt},
4};
5use libp2prs_core::{either::EitherOutput, multiaddr::Multiaddr, transport::ConnectionInfo};
6use quicksink::Action;
7use soketto::connection;
8use std::{
9    io,
10    pin::Pin,
11    task::{Context, Poll},
12};
13
14pub type TlsOrPlain<T> = EitherOutput<EitherOutput<TlsClientStream<T>, TlsServerStream<T>>, T>;
15
16#[pin_project::pin_project]
17pub struct Connection<T> {
18    #[pin]
19    reader: IntoAsyncRead<BoxStream<'static, io::Result<Vec<u8>>>>,
20    #[pin]
21    writer: Pin<Box<dyn Sink<Vec<u8>, Error = io::Error> + Send>>,
22
23    local_addr: Multiaddr,
24    remote_addr: Multiaddr,
25
26    _mark: std::marker::PhantomData<T>,
27}
28
29impl<T> Connection<T>
30where
31    T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
32{
33    #[allow(clippy::needless_return)]
34    pub fn new(builder: connection::Builder<T>, local_addr: Multiaddr, remote_addr: Multiaddr) -> Self {
35        let (tx, rx) = builder.finish();
36
37        let stream = futures::stream::unfold(rx, move |mut rx| async move {
38            let mut buf = Vec::with_capacity(1024);
39            log::debug!("receiving data");
40            match rx.receive_data(&mut buf).await {
41                Ok(data) => match data {
42                    soketto::Data::Binary(n) | soketto::Data::Text(n) => {
43                        buf.truncate(n);
44                        log::debug!("receive data ok: {:?}", buf);
45                        return Some((Ok(buf), rx));
46                    }
47                },
48                Err(e) => {
49                    log::debug!("receive data err: {:?}", e);
50                    match e {
51                        connection::Error::Io(ioe) => return Some((Err(ioe), rx)),
52                        connection::Error::Closed => return None,
53                        _ => return Some((Err(io::Error::new(io::ErrorKind::Other, e)), rx)),
54                    }
55                }
56            }
57        });
58        let stream: BoxStream<'static, io::Result<Vec<u8>>> = stream.boxed();
59        let reader = stream.into_async_read();
60
61        let sink = quicksink::make_sink(tx, move |mut tx, action: Action<Vec<u8>>| async move {
62            match action {
63                Action::Send(data) => {
64                    log::debug!("send data: {:?}", data);
65                    tx.send_binary(data).await.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
66                }
67                Action::Flush => {
68                    tx.flush().await.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
69                }
70                Action::Close => {
71                    tx.close().await.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
72                }
73            }
74            Ok(tx)
75        });
76
77        Connection {
78            reader,
79            writer: Box::pin(sink),
80            local_addr,
81            remote_addr,
82            _mark: std::marker::PhantomData,
83        }
84    }
85}
86
87impl<T> AsyncRead for Connection<T>
88where
89    T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
90{
91    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
92        self.project().reader.poll_read(cx, buf)
93    }
94}
95
96impl<T> AsyncWrite for Connection<T>
97where
98    T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
99{
100    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
101        let mut this = self.project();
102        futures::ready!(this.writer.as_mut().poll_ready(cx))?;
103        let n = buf.len();
104        if let Err(e) = this.writer.as_mut().start_send(buf.to_vec()) {
105            return Poll::Ready(Err(e));
106        }
107        Poll::Ready(Ok(n))
108    }
109
110    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
111        self.project().writer.poll_flush(cx)
112    }
113
114    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
115        self.project().writer.poll_close(cx)
116    }
117}
118
119impl<T> ConnectionInfo for Connection<T>
120where
121    T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
122{
123    fn local_multiaddr(&self) -> Multiaddr {
124        self.local_addr.clone()
125    }
126
127    fn remote_multiaddr(&self) -> Multiaddr {
128        self.remote_addr.clone()
129    }
130}
131
132pub struct TlsClientStream<T>(pub(crate) async_tls::client::TlsStream<T>);
133
134impl<T> AsyncRead for TlsClientStream<T>
135where
136    T: AsyncRead + AsyncWrite + Unpin,
137{
138    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
139        AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
140    }
141}
142
143impl<T> AsyncWrite for TlsClientStream<T>
144where
145    T: AsyncRead + AsyncWrite + Unpin,
146{
147    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
148        Pin::new(&mut self.0).poll_write(cx, buf)
149    }
150
151    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
152        Pin::new(&mut self.0).poll_flush(cx)
153    }
154
155    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
156        Pin::new(&mut self.0).poll_close(cx)
157    }
158}
159
160pub struct TlsServerStream<T>(pub(crate) async_tls::server::TlsStream<T>);
161
162impl<T> AsyncRead for TlsServerStream<T>
163where
164    T: AsyncRead + AsyncWrite + Unpin,
165{
166    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
167        AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
168    }
169}
170
171impl<T> AsyncWrite for TlsServerStream<T>
172where
173    T: AsyncRead + AsyncWrite + Unpin,
174{
175    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
176        Pin::new(&mut self.0).poll_write(cx, buf)
177    }
178
179    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
180        Pin::new(&mut self.0).poll_flush(cx)
181    }
182
183    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
184        Pin::new(&mut self.0).poll_close(cx)
185    }
186}