libp2prs_websocket/
connection.rs1use futures::{
2 prelude::*,
3 stream::{BoxStream, IntoAsyncRead, TryStreamExt},
4};
5use libp2prs_core::{either::EitherOutput, multiaddr::Multiaddr, transport::ConnectionInfo};
6use quicksink::Action;
7use soketto::connection;
8use std::{
9 io,
10 pin::Pin,
11 task::{Context, Poll},
12};
13
14pub type TlsOrPlain<T> = EitherOutput<EitherOutput<TlsClientStream<T>, TlsServerStream<T>>, T>;
15
16#[pin_project::pin_project]
17pub struct Connection<T> {
18 #[pin]
19 reader: IntoAsyncRead<BoxStream<'static, io::Result<Vec<u8>>>>,
20 #[pin]
21 writer: Pin<Box<dyn Sink<Vec<u8>, Error = io::Error> + Send>>,
22
23 local_addr: Multiaddr,
24 remote_addr: Multiaddr,
25
26 _mark: std::marker::PhantomData<T>,
27}
28
29impl<T> Connection<T>
30where
31 T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
32{
33 #[allow(clippy::needless_return)]
34 pub fn new(builder: connection::Builder<T>, local_addr: Multiaddr, remote_addr: Multiaddr) -> Self {
35 let (tx, rx) = builder.finish();
36
37 let stream = futures::stream::unfold(rx, move |mut rx| async move {
38 let mut buf = Vec::with_capacity(1024);
39 log::debug!("receiving data");
40 match rx.receive_data(&mut buf).await {
41 Ok(data) => match data {
42 soketto::Data::Binary(n) | soketto::Data::Text(n) => {
43 buf.truncate(n);
44 log::debug!("receive data ok: {:?}", buf);
45 return Some((Ok(buf), rx));
46 }
47 },
48 Err(e) => {
49 log::debug!("receive data err: {:?}", e);
50 match e {
51 connection::Error::Io(ioe) => return Some((Err(ioe), rx)),
52 connection::Error::Closed => return None,
53 _ => return Some((Err(io::Error::new(io::ErrorKind::Other, e)), rx)),
54 }
55 }
56 }
57 });
58 let stream: BoxStream<'static, io::Result<Vec<u8>>> = stream.boxed();
59 let reader = stream.into_async_read();
60
61 let sink = quicksink::make_sink(tx, move |mut tx, action: Action<Vec<u8>>| async move {
62 match action {
63 Action::Send(data) => {
64 log::debug!("send data: {:?}", data);
65 tx.send_binary(data).await.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
66 }
67 Action::Flush => {
68 tx.flush().await.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
69 }
70 Action::Close => {
71 tx.close().await.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
72 }
73 }
74 Ok(tx)
75 });
76
77 Connection {
78 reader,
79 writer: Box::pin(sink),
80 local_addr,
81 remote_addr,
82 _mark: std::marker::PhantomData,
83 }
84 }
85}
86
87impl<T> AsyncRead for Connection<T>
88where
89 T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
90{
91 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
92 self.project().reader.poll_read(cx, buf)
93 }
94}
95
96impl<T> AsyncWrite for Connection<T>
97where
98 T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
99{
100 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
101 let mut this = self.project();
102 futures::ready!(this.writer.as_mut().poll_ready(cx))?;
103 let n = buf.len();
104 if let Err(e) = this.writer.as_mut().start_send(buf.to_vec()) {
105 return Poll::Ready(Err(e));
106 }
107 Poll::Ready(Ok(n))
108 }
109
110 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
111 self.project().writer.poll_flush(cx)
112 }
113
114 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
115 self.project().writer.poll_close(cx)
116 }
117}
118
119impl<T> ConnectionInfo for Connection<T>
120where
121 T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
122{
123 fn local_multiaddr(&self) -> Multiaddr {
124 self.local_addr.clone()
125 }
126
127 fn remote_multiaddr(&self) -> Multiaddr {
128 self.remote_addr.clone()
129 }
130}
131
132pub struct TlsClientStream<T>(pub(crate) async_tls::client::TlsStream<T>);
133
134impl<T> AsyncRead for TlsClientStream<T>
135where
136 T: AsyncRead + AsyncWrite + Unpin,
137{
138 fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
139 AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
140 }
141}
142
143impl<T> AsyncWrite for TlsClientStream<T>
144where
145 T: AsyncRead + AsyncWrite + Unpin,
146{
147 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
148 Pin::new(&mut self.0).poll_write(cx, buf)
149 }
150
151 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
152 Pin::new(&mut self.0).poll_flush(cx)
153 }
154
155 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
156 Pin::new(&mut self.0).poll_close(cx)
157 }
158}
159
160pub struct TlsServerStream<T>(pub(crate) async_tls::server::TlsStream<T>);
161
162impl<T> AsyncRead for TlsServerStream<T>
163where
164 T: AsyncRead + AsyncWrite + Unpin,
165{
166 fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
167 AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
168 }
169}
170
171impl<T> AsyncWrite for TlsServerStream<T>
172where
173 T: AsyncRead + AsyncWrite + Unpin,
174{
175 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
176 Pin::new(&mut self.0).poll_write(cx, buf)
177 }
178
179 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
180 Pin::new(&mut self.0).poll_flush(cx)
181 }
182
183 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
184 Pin::new(&mut self.0).poll_close(cx)
185 }
186}