1use crate::network::adapter::{
2 Resource, Remote, Local, Adapter, SendStatus, AcceptedType, ReadStatus, ConnectionInfo,
3 ListeningInfo, PendingStatus,
4};
5use crate::network::{RemoteAddr, Readiness};
6use crate::util::thread::{OTHER_THREAD_ERR};
7use crate::network::{TransportConnect, TransportListen};
8
9use mio::event::{Source};
10use mio::net::{TcpStream, TcpListener};
11
12use tungstenite::protocol::{WebSocket, Message};
13use tungstenite::{accept as ws_accept};
14use tungstenite::client::{client as ws_connect};
15use tungstenite::handshake::{
16 HandshakeError, MidHandshake,
17 server::{ServerHandshake, NoCallback},
18 client::{ClientHandshake},
19};
20use tungstenite::error::{Error};
21
22use url::Url;
23
24use std::sync::{Mutex, Arc};
25use std::net::{SocketAddr};
26use std::io::{self, ErrorKind};
27use std::ops::{DerefMut};
28
/// Maximum payload length for the websocket transport: 32 MiB (equal to `32 << 20` bytes).
///
/// NOTE(review): not referenced in this file; presumably consumed by the transport layer to
/// bound message sizes — confirm at the use site.
pub const MAX_PAYLOAD_LEN: usize = 32 * 1024 * 1024;
32
/// Websocket adapter: binds together the remote (connection) and local (listener) resource
/// types defined in this module.
pub(crate) struct WsAdapter;
impl Adapter for WsAdapter {
    type Remote = RemoteResource;
    type Local = LocalResource;
}
38
/// Progress of a websocket connection that is not yet usable.
enum PendingHandshake {
    /// Outgoing connection: the TCP stream was created by `TcpStream::connect`, but it may not
    /// be connected yet. The client handshake towards the URL has not started.
    Connect(Url, ArcTcpStream),
    /// Incoming connection returned by the listener; the server handshake has not started.
    Accept(ArcTcpStream),
    /// Client handshake that was interrupted (would block) and must be resumed.
    Client(MidHandshake<ClientHandshake<ArcTcpStream>>),
    /// Server handshake that was interrupted (would block) and must be resumed.
    Server(MidHandshake<ServerHandshake<ArcTcpStream, NoCallback>>),
}
45
/// Lifecycle state of a websocket connection.
// The lint fires because the variants differ a lot in size (presumably the `WebSocket` variant
// is the large one). Each value lives once per connection inside a `Mutex`, so the extra size is
// accepted rather than boxing the common, fully-established case.
#[allow(clippy::large_enum_variant)]
enum RemoteState {
    /// Handshake finished; messages can be sent and received.
    WebSocket(WebSocket<ArcTcpStream>),
    /// Handshake in progress. The `Option` is only `None` transiently, while `pending()` has
    /// taken the handshake out to advance it.
    Handshake(Option<PendingHandshake>),
    /// Handshake failed. The stream is kept so `source()` can still expose it.
    Error(ArcTcpStream),
}
52
/// One websocket connection (outgoing or accepted).
///
/// The state sits behind a `Mutex` because `receive`, `send` and `pending` take `&self` and
/// lock it to mutate the connection.
pub(crate) struct RemoteResource {
    state: Mutex<RemoteState>,
}
56
57impl Resource for RemoteResource {
58 fn source(&mut self) -> &mut dyn Source {
59 match self.state.get_mut().unwrap() {
60 RemoteState::WebSocket(web_socket) => {
61 Arc::get_mut(&mut web_socket.get_mut().0).unwrap()
62 }
63 RemoteState::Handshake(Some(handshake)) => match handshake {
64 PendingHandshake::Connect(_, stream) => Arc::get_mut(&mut stream.0).unwrap(),
65 PendingHandshake::Accept(stream) => Arc::get_mut(&mut stream.0).unwrap(),
66 PendingHandshake::Client(handshake) => {
67 Arc::get_mut(&mut handshake.get_mut().get_mut().0).unwrap()
68 }
69 PendingHandshake::Server(handshake) => {
70 Arc::get_mut(&mut handshake.get_mut().get_mut().0).unwrap()
71 }
72 },
73 RemoteState::Handshake(None) => unreachable!(),
74 RemoteState::Error(stream) => Arc::get_mut(&mut stream.0).unwrap(),
75 }
76 }
77}
78
79impl Remote for RemoteResource {
80 fn connect_with(
81 _: TransportConnect,
82 remote_addr: RemoteAddr,
83 ) -> io::Result<ConnectionInfo<Self>> {
84 let (peer_addr, url) = match remote_addr {
85 RemoteAddr::Socket(addr) => {
86 (addr, Url::parse(&format!("ws://{addr}/message-io-default")).unwrap())
87 }
88 RemoteAddr::Str(path) => {
89 let url = Url::parse(&path).expect("A valid URL");
90 let addr = url
91 .socket_addrs(|| match url.scheme() {
92 "ws" => Some(80), "wss" => Some(443), _ => None,
95 })
96 .unwrap()[0];
97 (addr, url)
98 }
99 };
100
101 let stream = TcpStream::connect(peer_addr)?;
102 let local_addr = stream.local_addr()?;
103
104 Ok(ConnectionInfo {
105 remote: RemoteResource {
106 state: Mutex::new(RemoteState::Handshake(Some(PendingHandshake::Connect(
107 url,
108 stream.into(),
109 )))),
110 },
111 local_addr,
112 peer_addr,
113 })
114 }
115
116 fn receive(&self, mut process_data: impl FnMut(&[u8])) -> ReadStatus {
117 loop {
118 let mut state = self.state.lock().expect(OTHER_THREAD_ERR);
120 let deref_state = state.deref_mut();
121
122 match deref_state {
123 RemoteState::WebSocket(web_socket) => match web_socket.read() {
124 Ok(message) => match message {
125 Message::Binary(data) => {
126 #[cfg(not(target_os = "windows"))]
133 let _peek_result = web_socket.get_ref().0.peek(&mut [0; 0]);
134
135 drop(state);
138 process_data(&data);
139
140 #[cfg(not(target_os = "windows"))]
141 if let Err(err) = _peek_result {
142 break Self::io_error_to_read_status(&err);
143 }
144 }
145 Message::Close(_) => break ReadStatus::Disconnected,
146 _ => continue,
147 },
148 Err(Error::Io(ref err)) => break Self::io_error_to_read_status(err),
149 Err(err) => {
150 log::error!("WS receive error: {}", err);
151 break ReadStatus::Disconnected; }
153 },
154 RemoteState::Handshake(_) => unreachable!(),
155 RemoteState::Error(_) => unreachable!(),
156 }
157 }
158 }
159
160 fn send(&self, data: &[u8]) -> SendStatus {
161 let mut state = self.state.lock().expect(OTHER_THREAD_ERR);
162 let deref_state = state.deref_mut();
163 match deref_state {
164 RemoteState::WebSocket(web_socket) => {
165 let message = Message::Binary(data.to_vec().into());
166
167 let mut result = web_socket.send(message);
168 loop {
169 match result {
170 Ok(_) => break SendStatus::Sent,
171 Err(Error::Io(ref err)) if err.kind() == ErrorKind::WouldBlock => {
172 result = web_socket.flush();
173 }
174 Err(Error::Capacity(_)) => break SendStatus::MaxPacketSizeExceeded,
175 Err(err) => {
176 log::error!("WS send error: {}", err);
177 break SendStatus::ResourceNotFound; }
179 }
180 }
181 }
182 RemoteState::Handshake(_) => unreachable!(),
183 RemoteState::Error(_) => unreachable!(),
184 }
185 }
186
187 fn pending(&self, _readiness: Readiness) -> PendingStatus {
188 let mut state = self.state.lock().expect(OTHER_THREAD_ERR);
189 let deref_state = state.deref_mut();
190 match deref_state {
191 RemoteState::WebSocket(_) => PendingStatus::Ready,
192 RemoteState::Handshake(pending) => match pending.take().unwrap() {
193 PendingHandshake::Connect(url, stream) => {
194 let tcp_status = super::tcp::check_stream_ready(&stream.0);
195 if tcp_status != PendingStatus::Ready {
196 *pending = Some(PendingHandshake::Connect(url, stream));
198 return tcp_status;
199 }
200 let stream_backup = stream.clone();
201 match ws_connect(url, stream) {
202 Ok((web_socket, _)) => {
203 *state = RemoteState::WebSocket(web_socket);
204 PendingStatus::Ready
205 }
206 Err(HandshakeError::Interrupted(mid_handshake)) => {
207 *pending = Some(PendingHandshake::Client(mid_handshake));
208 PendingStatus::Incomplete
209 }
210 Err(HandshakeError::Failure(Error::Io(_))) => {
211 *state = RemoteState::Error(stream_backup);
212 PendingStatus::Disconnected
213 }
214 Err(HandshakeError::Failure(err)) => {
215 *state = RemoteState::Error(stream_backup);
216 log::error!("WS connect handshake error: {}", err);
217 PendingStatus::Disconnected }
219 }
220 }
221 PendingHandshake::Accept(stream) => {
222 let stream_backup = stream.clone();
223 match ws_accept(stream) {
224 Ok(web_socket) => {
225 *state = RemoteState::WebSocket(web_socket);
226 PendingStatus::Ready
227 }
228 Err(HandshakeError::Interrupted(mid_handshake)) => {
229 *pending = Some(PendingHandshake::Server(mid_handshake));
230 PendingStatus::Incomplete
231 }
232 Err(HandshakeError::Failure(Error::Io(_))) => {
233 *state = RemoteState::Error(stream_backup);
234 PendingStatus::Disconnected
235 }
236 Err(HandshakeError::Failure(err)) => {
237 *state = RemoteState::Error(stream_backup);
238 log::error!("WS accept handshake error: {}", err);
239 PendingStatus::Disconnected
240 }
241 }
242 }
243 PendingHandshake::Client(mid_handshake) => {
244 let stream_backup = mid_handshake.get_ref().get_ref().clone();
245 match mid_handshake.handshake() {
246 Ok((web_socket, _)) => {
247 *state = RemoteState::WebSocket(web_socket);
248 PendingStatus::Ready
249 }
250 Err(HandshakeError::Interrupted(mid_handshake)) => {
251 *pending = Some(PendingHandshake::Client(mid_handshake));
252 PendingStatus::Incomplete
253 }
254 Err(HandshakeError::Failure(Error::Io(_))) => {
255 *state = RemoteState::Error(stream_backup);
256 PendingStatus::Disconnected
257 }
258 Err(HandshakeError::Failure(err)) => {
259 *state = RemoteState::Error(stream_backup);
260 log::error!("WS client handshake error: {}", err);
261 PendingStatus::Disconnected }
263 }
264 }
265 PendingHandshake::Server(mid_handshake) => {
266 let stream_backup = mid_handshake.get_ref().get_ref().clone();
267 match mid_handshake.handshake() {
268 Ok(web_socket) => {
269 *state = RemoteState::WebSocket(web_socket);
270 PendingStatus::Ready
271 }
272 Err(HandshakeError::Interrupted(mid_handshake)) => {
273 *pending = Some(PendingHandshake::Server(mid_handshake));
274 PendingStatus::Incomplete
275 }
276 Err(HandshakeError::Failure(Error::Io(_))) => {
277 *state = RemoteState::Error(stream_backup);
278 PendingStatus::Disconnected
279 }
280 Err(HandshakeError::Failure(err)) => {
281 *state = RemoteState::Error(stream_backup);
282 log::error!("WS server handshake error: {}", err);
283 PendingStatus::Disconnected }
285 }
286 }
287 },
288 RemoteState::Error(_) => unreachable!(),
289 }
290 }
291
292 fn ready_to_write(&self) -> bool {
293 true
294 }
307}
308
309impl RemoteResource {
310 fn io_error_to_read_status(err: &io::Error) -> ReadStatus {
311 if err.kind() == io::ErrorKind::WouldBlock {
312 ReadStatus::WaitNextEvent
313 }
314 else if err.kind() == io::ErrorKind::ConnectionReset {
315 ReadStatus::Disconnected
316 }
317 else {
318 log::error!("WS receive error: {}", err);
319 ReadStatus::Disconnected }
321 }
322}
323
/// Listening side of the websocket transport. Each connection accepted from the TCP listener
/// becomes a `RemoteResource` that still has to complete the server handshake.
pub(crate) struct LocalResource {
    listener: TcpListener,
}
327
impl Resource for LocalResource {
    /// Exposes the TCP listener as the pollable source.
    fn source(&mut self) -> &mut dyn Source {
        &mut self.listener
    }
}
333
334impl Local for LocalResource {
335 type Remote = RemoteResource;
336
337 fn listen_with(_: TransportListen, addr: SocketAddr) -> io::Result<ListeningInfo<Self>> {
338 let listener = TcpListener::bind(addr)?;
339 let local_addr = listener.local_addr().unwrap();
340 Ok(ListeningInfo { local: LocalResource { listener }, local_addr })
341 }
342
343 fn accept(&self, mut accept_remote: impl FnMut(AcceptedType<'_, Self::Remote>)) {
344 loop {
345 match self.listener.accept() {
346 Ok((stream, addr)) => {
347 let remote = RemoteResource {
348 state: Mutex::new(RemoteState::Handshake(Some(PendingHandshake::Accept(
349 stream.into(),
350 )))),
351 };
352 accept_remote(AcceptedType::Remote(addr, remote));
353 }
354 Err(ref err) if err.kind() == ErrorKind::WouldBlock => break,
355 Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
356 Err(err) => break log::error!("WS accept error: {}", err), }
358 }
359 }
360}
361
/// Reference-counted TCP stream handed to tungstenite as its transport.
///
/// tungstenite takes the stream by value during a handshake. Cloning it beforehand (as
/// `pending()` does) keeps the socket alive when a handshake fails. The socket can then be
/// stored in `RemoteState::Error` and still be exposed through `source()`.
struct ArcTcpStream(Arc<TcpStream>);
367
368impl From<TcpStream> for ArcTcpStream {
369 fn from(stream: TcpStream) -> Self {
370 Self(Arc::new(stream))
371 }
372}
373
374impl io::Read for ArcTcpStream {
375 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
376 (&*self.0).read(buf)
377 }
378}
379
380impl io::Write for ArcTcpStream {
381 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
382 (&*self.0).write(buf)
383 }
384
385 fn flush(&mut self) -> io::Result<()> {
386 (&*self.0).flush()
387 }
388}
389
390impl Clone for ArcTcpStream {
391 fn clone(&self) -> Self {
392 Self(self.0.clone())
393 }
394}