// message_io/adapters/ws.rs

1use crate::network::adapter::{
2    Resource, Remote, Local, Adapter, SendStatus, AcceptedType, ReadStatus, ConnectionInfo,
3    ListeningInfo, PendingStatus,
4};
5use crate::network::{RemoteAddr, Readiness};
6use crate::util::thread::{OTHER_THREAD_ERR};
7use crate::network::{TransportConnect, TransportListen};
8
9use mio::event::{Source};
10use mio::net::{TcpStream, TcpListener};
11
12use tungstenite::protocol::{WebSocket, Message};
13use tungstenite::{accept as ws_accept};
14use tungstenite::client::{client as ws_connect};
15use tungstenite::handshake::{
16    HandshakeError, MidHandshake,
17    server::{ServerHandshake, NoCallback},
18    client::{ClientHandshake},
19};
20use tungstenite::error::{Error};
21
22use url::Url;
23
24use std::sync::{Mutex, Arc};
25use std::net::{SocketAddr};
26use std::io::{self, ErrorKind};
27use std::ops::{DerefMut};
28
/// Maximum payload size, in bytes (32 MiB), accepted for a single WebSocket message
/// under the default configuration.
// From https://docs.rs/tungstenite/0.13.0/src/tungstenite/protocol/mod.rs.html#65
// NOTE(review): the link points at tungstenite 0.13.0; confirm this value still matches the
// default of the tungstenite version actually in use.
pub const MAX_PAYLOAD_LEN: usize = 32 * 1024 * 1024;
32
33pub(crate) struct WsAdapter;
34impl Adapter for WsAdapter {
35    type Remote = RemoteResource;
36    type Local = LocalResource;
37}
38
39enum PendingHandshake {
40    Connect(Url, ArcTcpStream),
41    Accept(ArcTcpStream),
42    Client(MidHandshake<ClientHandshake<ArcTcpStream>>),
43    Server(MidHandshake<ServerHandshake<ArcTcpStream, NoCallback>>),
44}
45
46#[allow(clippy::large_enum_variant)]
47enum RemoteState {
48    WebSocket(WebSocket<ArcTcpStream>),
49    Handshake(Option<PendingHandshake>),
50    Error(ArcTcpStream),
51}
52
53pub(crate) struct RemoteResource {
54    state: Mutex<RemoteState>,
55}
56
57impl Resource for RemoteResource {
58    fn source(&mut self) -> &mut dyn Source {
59        match self.state.get_mut().unwrap() {
60            RemoteState::WebSocket(web_socket) => {
61                Arc::get_mut(&mut web_socket.get_mut().0).unwrap()
62            }
63            RemoteState::Handshake(Some(handshake)) => match handshake {
64                PendingHandshake::Connect(_, stream) => Arc::get_mut(&mut stream.0).unwrap(),
65                PendingHandshake::Accept(stream) => Arc::get_mut(&mut stream.0).unwrap(),
66                PendingHandshake::Client(handshake) => {
67                    Arc::get_mut(&mut handshake.get_mut().get_mut().0).unwrap()
68                }
69                PendingHandshake::Server(handshake) => {
70                    Arc::get_mut(&mut handshake.get_mut().get_mut().0).unwrap()
71                }
72            },
73            RemoteState::Handshake(None) => unreachable!(),
74            RemoteState::Error(stream) => Arc::get_mut(&mut stream.0).unwrap(),
75        }
76    }
77}
78
79impl Remote for RemoteResource {
80    fn connect_with(
81        _: TransportConnect,
82        remote_addr: RemoteAddr,
83    ) -> io::Result<ConnectionInfo<Self>> {
84        let (peer_addr, url) = match remote_addr {
85            RemoteAddr::Socket(addr) => {
86                (addr, Url::parse(&format!("ws://{addr}/message-io-default")).unwrap())
87            }
88            RemoteAddr::Str(path) => {
89                let url = Url::parse(&path).expect("A valid URL");
90                let addr = url
91                    .socket_addrs(|| match url.scheme() {
92                        "ws" => Some(80),   // Plain
93                        "wss" => Some(443), //Tls
94                        _ => None,
95                    })
96                    .unwrap()[0];
97                (addr, url)
98            }
99        };
100
101        let stream = TcpStream::connect(peer_addr)?;
102        let local_addr = stream.local_addr()?;
103
104        Ok(ConnectionInfo {
105            remote: RemoteResource {
106                state: Mutex::new(RemoteState::Handshake(Some(PendingHandshake::Connect(
107                    url,
108                    stream.into(),
109                )))),
110            },
111            local_addr,
112            peer_addr,
113        })
114    }
115
116    fn receive(&self, mut process_data: impl FnMut(&[u8])) -> ReadStatus {
117        loop {
118            // "emulates" full duplex for the websocket case locking here and not outside the loop.
119            let mut state = self.state.lock().expect(OTHER_THREAD_ERR);
120            let deref_state = state.deref_mut();
121
122            match deref_state {
123                RemoteState::WebSocket(web_socket) => match web_socket.read() {
124                    Ok(message) => match message {
125                        Message::Binary(data) => {
126                            // As an optimization.
127                            // Fast check to know if there is more data to avoid call
128                            // WebSocket::read_message() again.
129                            // TODO: investigate why this code doesn't work in windows.
130                            // Seems like windows consume the `WouldBlock` notification
131                            // at peek() when it happens, and the poll never wakes it again.
132                            #[cfg(not(target_os = "windows"))]
133                            let _peek_result = web_socket.get_ref().0.peek(&mut [0; 0]);
134
135                            // We can not call process_data while the socket is blocked.
136                            // The user could lock it again if sends from the callback.
137                            drop(state);
138                            process_data(&data);
139
140                            #[cfg(not(target_os = "windows"))]
141                            if let Err(err) = _peek_result {
142                                break Self::io_error_to_read_status(&err);
143                            }
144                        }
145                        Message::Close(_) => break ReadStatus::Disconnected,
146                        _ => continue,
147                    },
148                    Err(Error::Io(ref err)) => break Self::io_error_to_read_status(err),
149                    Err(err) => {
150                        log::error!("WS receive error: {}", err);
151                        break ReadStatus::Disconnected; // should not happen
152                    }
153                },
154                RemoteState::Handshake(_) => unreachable!(),
155                RemoteState::Error(_) => unreachable!(),
156            }
157        }
158    }
159
160    fn send(&self, data: &[u8]) -> SendStatus {
161        let mut state = self.state.lock().expect(OTHER_THREAD_ERR);
162        let deref_state = state.deref_mut();
163        match deref_state {
164            RemoteState::WebSocket(web_socket) => {
165                let message = Message::Binary(data.to_vec().into());
166
167                let mut result = web_socket.send(message);
168                loop {
169                    match result {
170                        Ok(_) => break SendStatus::Sent,
171                        Err(Error::Io(ref err)) if err.kind() == ErrorKind::WouldBlock => {
172                            result = web_socket.flush();
173                        }
174                        Err(Error::Capacity(_)) => break SendStatus::MaxPacketSizeExceeded,
175                        Err(err) => {
176                            log::error!("WS send error: {}", err);
177                            break SendStatus::ResourceNotFound; // should not happen
178                        }
179                    }
180                }
181            }
182            RemoteState::Handshake(_) => unreachable!(),
183            RemoteState::Error(_) => unreachable!(),
184        }
185    }
186
187    fn pending(&self, _readiness: Readiness) -> PendingStatus {
188        let mut state = self.state.lock().expect(OTHER_THREAD_ERR);
189        let deref_state = state.deref_mut();
190        match deref_state {
191            RemoteState::WebSocket(_) => PendingStatus::Ready,
192            RemoteState::Handshake(pending) => match pending.take().unwrap() {
193                PendingHandshake::Connect(url, stream) => {
194                    let tcp_status = super::tcp::check_stream_ready(&stream.0);
195                    if tcp_status != PendingStatus::Ready {
196                        // TCP handshake not ready yet.
197                        *pending = Some(PendingHandshake::Connect(url, stream));
198                        return tcp_status;
199                    }
200                    let stream_backup = stream.clone();
201                    match ws_connect(url, stream) {
202                        Ok((web_socket, _)) => {
203                            *state = RemoteState::WebSocket(web_socket);
204                            PendingStatus::Ready
205                        }
206                        Err(HandshakeError::Interrupted(mid_handshake)) => {
207                            *pending = Some(PendingHandshake::Client(mid_handshake));
208                            PendingStatus::Incomplete
209                        }
210                        Err(HandshakeError::Failure(Error::Io(_))) => {
211                            *state = RemoteState::Error(stream_backup);
212                            PendingStatus::Disconnected
213                        }
214                        Err(HandshakeError::Failure(err)) => {
215                            *state = RemoteState::Error(stream_backup);
216                            log::error!("WS connect handshake error: {}", err);
217                            PendingStatus::Disconnected // should not happen
218                        }
219                    }
220                }
221                PendingHandshake::Accept(stream) => {
222                    let stream_backup = stream.clone();
223                    match ws_accept(stream) {
224                        Ok(web_socket) => {
225                            *state = RemoteState::WebSocket(web_socket);
226                            PendingStatus::Ready
227                        }
228                        Err(HandshakeError::Interrupted(mid_handshake)) => {
229                            *pending = Some(PendingHandshake::Server(mid_handshake));
230                            PendingStatus::Incomplete
231                        }
232                        Err(HandshakeError::Failure(Error::Io(_))) => {
233                            *state = RemoteState::Error(stream_backup);
234                            PendingStatus::Disconnected
235                        }
236                        Err(HandshakeError::Failure(err)) => {
237                            *state = RemoteState::Error(stream_backup);
238                            log::error!("WS accept handshake error: {}", err);
239                            PendingStatus::Disconnected
240                        }
241                    }
242                }
243                PendingHandshake::Client(mid_handshake) => {
244                    let stream_backup = mid_handshake.get_ref().get_ref().clone();
245                    match mid_handshake.handshake() {
246                        Ok((web_socket, _)) => {
247                            *state = RemoteState::WebSocket(web_socket);
248                            PendingStatus::Ready
249                        }
250                        Err(HandshakeError::Interrupted(mid_handshake)) => {
251                            *pending = Some(PendingHandshake::Client(mid_handshake));
252                            PendingStatus::Incomplete
253                        }
254                        Err(HandshakeError::Failure(Error::Io(_))) => {
255                            *state = RemoteState::Error(stream_backup);
256                            PendingStatus::Disconnected
257                        }
258                        Err(HandshakeError::Failure(err)) => {
259                            *state = RemoteState::Error(stream_backup);
260                            log::error!("WS client handshake error: {}", err);
261                            PendingStatus::Disconnected // should not happen
262                        }
263                    }
264                }
265                PendingHandshake::Server(mid_handshake) => {
266                    let stream_backup = mid_handshake.get_ref().get_ref().clone();
267                    match mid_handshake.handshake() {
268                        Ok(web_socket) => {
269                            *state = RemoteState::WebSocket(web_socket);
270                            PendingStatus::Ready
271                        }
272                        Err(HandshakeError::Interrupted(mid_handshake)) => {
273                            *pending = Some(PendingHandshake::Server(mid_handshake));
274                            PendingStatus::Incomplete
275                        }
276                        Err(HandshakeError::Failure(Error::Io(_))) => {
277                            *state = RemoteState::Error(stream_backup);
278                            PendingStatus::Disconnected
279                        }
280                        Err(HandshakeError::Failure(err)) => {
281                            *state = RemoteState::Error(stream_backup);
282                            log::error!("WS server handshake error: {}", err);
283                            PendingStatus::Disconnected // should not happen
284                        }
285                    }
286                }
287            },
288            RemoteState::Error(_) => unreachable!(),
289        }
290    }
291
292    fn ready_to_write(&self) -> bool {
293        true
294        /* Is this needed?
295        match self.state.lock().expect(OTHER_THREAD_ERR).deref_mut() {
296            RemoteState::WebSocket(web_socket) => match web_socket.write_pending() {
297                Ok(_) => true,
298                Err(Error::Io(ref err)) if err.kind() == ErrorKind::WouldBlock => true,
299                Err(_) => false, // Will be disconnected,
300            },
301            // This function is only call on ready resources.
302            RemoteState::Handshake(_) => unreachable!(),
303            RemoteState::Error(_) => unreachable!(),
304        }
305        */
306    }
307}
308
309impl RemoteResource {
310    fn io_error_to_read_status(err: &io::Error) -> ReadStatus {
311        if err.kind() == io::ErrorKind::WouldBlock {
312            ReadStatus::WaitNextEvent
313        }
314        else if err.kind() == io::ErrorKind::ConnectionReset {
315            ReadStatus::Disconnected
316        }
317        else {
318            log::error!("WS receive error: {}", err);
319            ReadStatus::Disconnected // should not happen
320        }
321    }
322}
323
324pub(crate) struct LocalResource {
325    listener: TcpListener,
326}
327
328impl Resource for LocalResource {
329    fn source(&mut self) -> &mut dyn Source {
330        &mut self.listener
331    }
332}
333
334impl Local for LocalResource {
335    type Remote = RemoteResource;
336
337    fn listen_with(_: TransportListen, addr: SocketAddr) -> io::Result<ListeningInfo<Self>> {
338        let listener = TcpListener::bind(addr)?;
339        let local_addr = listener.local_addr().unwrap();
340        Ok(ListeningInfo { local: LocalResource { listener }, local_addr })
341    }
342
343    fn accept(&self, mut accept_remote: impl FnMut(AcceptedType<'_, Self::Remote>)) {
344        loop {
345            match self.listener.accept() {
346                Ok((stream, addr)) => {
347                    let remote = RemoteResource {
348                        state: Mutex::new(RemoteState::Handshake(Some(PendingHandshake::Accept(
349                            stream.into(),
350                        )))),
351                    };
352                    accept_remote(AcceptedType::Remote(addr, remote));
353                }
354                Err(ref err) if err.kind() == ErrorKind::WouldBlock => break,
355                Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
356                Err(err) => break log::error!("WS accept error: {}", err), // Should not happen
357            }
358        }
359    }
360}
361
362/// This struct is used to avoid the tungstenite handshake to take the ownership of the stream
363/// an drop it without allow to the driver to deregister from the poll.
364/// It can be removed when this issue is resolved:
365/// https://github.com/snapview/tungstenite-rs/issues/51
366struct ArcTcpStream(Arc<TcpStream>);
367
368impl From<TcpStream> for ArcTcpStream {
369    fn from(stream: TcpStream) -> Self {
370        Self(Arc::new(stream))
371    }
372}
373
374impl io::Read for ArcTcpStream {
375    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
376        (&*self.0).read(buf)
377    }
378}
379
380impl io::Write for ArcTcpStream {
381    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
382        (&*self.0).write(buf)
383    }
384
385    fn flush(&mut self) -> io::Result<()> {
386        (&*self.0).flush()
387    }
388}
389
390impl Clone for ArcTcpStream {
391    fn clone(&self) -> Self {
392        Self(self.0.clone())
393    }
394}