// tfserver/structures/transport.rs

1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4use futures_util::{StreamExt};
5use pin_project::pin_project;
6use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
7
8#[cfg(not(target_arch = "wasm32"))]
9use async_tungstenite::{ByteReader, ByteWriter};
10#[cfg(not(target_arch = "wasm32"))]
11use tokio::net::TcpStream;
12#[cfg(not(target_arch = "wasm32"))]
13use tokio_rustls::{client::TlsStream as ClientTlsStream, server::TlsStream as ServerTlsStream};
14#[cfg(not(target_arch = "wasm32"))]
15use tokio_tungstenite::{accept_async, connect_async};
16
17
18pub struct Transport {
19    inner: Box<dyn AsyncReadWrite>,
20}
21
22pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin + 'static {
23    #[cfg(not(target_arch = "wasm32"))]
24    fn is_send_sync(&self) where Self: Send + Sync {}
25}
26
27impl<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static> AsyncReadWrite for T {}
28
29
30
31#[pin_project]
32pub struct WsStreamCompat<R: futures_io::AsyncRead + Unpin, W: futures_io::AsyncWrite + Unpin> {
33    #[pin]
34    reader: R,
35    #[pin]
36    writer: W,
37}
38
39impl<R: futures_io::AsyncRead + Unpin, W: futures_io::AsyncWrite + Unpin> AsyncRead
40for WsStreamCompat<R, W>
41{
42    fn poll_read(
43        self: Pin<&mut Self>,
44        cx: &mut Context<'_>,
45        buf: &mut ReadBuf<'_>,
46    ) -> Poll<io::Result<()>> {
47        let unfilled = buf.initialize_unfilled();
48        match self.project().reader.poll_read(cx, unfilled) {
49            Poll::Ready(Ok(n)) => {
50                buf.advance(n);
51                Poll::Ready(Ok(()))
52            }
53            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
54            Poll::Pending => Poll::Pending,
55        }
56    }
57}
58
59impl<R: futures_io::AsyncRead + Unpin, W: futures_io::AsyncWrite + Unpin> AsyncWrite
60for WsStreamCompat<R, W>
61{
62    fn poll_write(
63        self: Pin<&mut Self>,
64        cx: &mut Context<'_>,
65        buf: &[u8],
66    ) -> Poll<io::Result<usize>> {
67        self.project().writer.poll_write(cx, buf)
68    }
69
70    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
71        self.project().writer.poll_flush(cx)
72    }
73
74    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
75        self.project().writer.poll_close(cx)
76    }
77}
78
79impl Transport {
80    #[cfg(not(target_arch = "wasm32"))]
81    pub fn plain(stream: TcpStream) -> Self {
82        Self { inner: Box::new(stream) }
83    }
84    
85
86    #[cfg(not(target_arch = "wasm32"))]
87    pub fn tls_server(stream: ServerTlsStream<TcpStream>) -> Self {
88        Self { inner: Box::new(stream) }
89    }
90
91    #[cfg(not(target_arch = "wasm32"))]
92    pub fn tls_client(stream: ClientTlsStream<TcpStream>) -> Self {
93        Self { inner: Box::new(stream) }
94    }
95
96    /// On WASM: connect via WebSocket, returns a Transport backed by ws_stream_wasm.
97    /// On native: not available — use plain/tls_client/tls_server + a WS proxy if needed.
98    #[cfg(target_arch = "wasm32")]
99    pub async fn connect(url: &str) -> io::Result<Self> {
100        use ws_stream_wasm::WsMeta;
101        use futures_util::AsyncReadExt;
102        let (_meta, ws_stream) = WsMeta::connect(url, None)
103            .await
104            .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e.to_string()))?;
105
106        let (reader, writer) = ws_stream.into_io().split();
107        Ok(Self {
108            inner: Box::new(WsStreamCompat { reader, writer }),
109        })
110    }
111
112    #[cfg(not(target_arch = "wasm32"))]
113    pub async fn connect(url: &str) -> io::Result<Self> {
114        let (ws_stream, _response) = connect_async(url)
115            .await
116            .map_err(|e| io::Error::new(io::ErrorKind::ConnectionAborted, e.to_string()))?;
117
118        let (write, read) = ws_stream.split();
119        let reader = ByteReader::new(read);
120        let writer = ByteWriter::new(write);
121
122        Ok(Self {
123            inner: Box::new(WsStreamCompat { reader, writer }),
124        })
125    }
126
127    #[cfg(not(target_arch = "wasm32"))]
128    pub async fn accept_websocket(stream: Transport) -> io::Result<Self> {
129        let ws_stream = accept_async(stream)
130            .await
131            .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e.to_string()))?;
132
133        let (write, read) = ws_stream.split();
134        let reader = ByteReader::new(read);
135        let writer = ByteWriter::new(write);
136
137        Ok(Self {
138            inner: Box::new(WsStreamCompat { reader, writer }),
139        })
140    }
141
142    pub fn inner(&mut self) -> &mut dyn AsyncReadWrite {
143        &mut *self.inner
144    }
145}
146
147impl AsyncRead for Transport {
148    fn poll_read(
149        mut self: Pin<&mut Self>,
150        cx: &mut Context<'_>,
151        buf: &mut ReadBuf<'_>,
152    ) -> Poll<io::Result<()>> {
153        Pin::new(&mut *self.inner).poll_read(cx, buf)
154    }
155}
156
157impl AsyncWrite for Transport {
158    fn poll_write(
159        mut self: Pin<&mut Self>,
160        cx: &mut Context<'_>,
161        buf: &[u8],
162    ) -> Poll<io::Result<usize>> {
163        Pin::new(&mut *self.inner).poll_write(cx, buf)
164    }
165
166    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
167        Pin::new(&mut *self.inner).poll_flush(cx)
168    }
169
170    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
171        Pin::new(&mut *self.inner).poll_shutdown(cx)
172    }
173}
174unsafe impl Send for Transport {
175
176}
177unsafe impl Sync for Transport {}