1use bytes::{BufMut, BytesMut};
2use futures_util::{Future, FutureExt, SinkExt, StreamExt};
3use hyper::{client::conn, http::HeaderValue, Body, HeaderMap, Request, StatusCode};
4use narrowlink_types::ServiceType;
5use std::{
6 collections::HashMap,
7 io::{self, Error, ErrorKind},
8 net::{SocketAddr, SocketAddrV4},
9 pin::Pin,
10 task::{Context, Poll},
11};
12use tokio::{
13 io::{AsyncRead, AsyncWrite, ReadBuf},
14 task::JoinHandle,
15};
16use tokio_tungstenite::WebSocketStream;
17use tracing::{debug, trace, warn};
18use tungstenite::Message;
19
20use crate::{
21 error::NetworkError,
22 transport::{StreamType, TlsConfiguration, UnifiedSocket},
23 AsyncSocket,
24};
25
/// Seconds between keep-alive pings on server-side connections; it is the
/// period of the `tokio::time::interval` stored in `WsMode::Server`.
const KEEP_ALIVE_TIME: u64 = 20;
27
28pub enum WsMode {
29 Server(tokio::time::Interval),
30 Client(HeaderMap, JoinHandle<()>),
31}
32
33pub struct WsConnection {
34 ws_stream: WebSocketStream<Box<dyn AsyncSocket>>,
35 remaining_bytes: Option<BytesMut>,
36 mode: WsMode,
37 local_addr: SocketAddr,
38 peer_addr: SocketAddr,
39}
40
41impl WsConnection {
42 pub async fn from(server_stream: impl AsyncSocket) -> Self {
43 let ws_stream = WebSocketStream::from_raw_socket(
45 Box::new(server_stream) as Box<dyn AsyncSocket>,
46 tungstenite::protocol::Role::Server,
47 None,
48 )
49 .await;
50
51 Self {
52 ws_stream,
53 remaining_bytes: None,
54 mode: WsMode::Server(tokio::time::interval(core::time::Duration::from_secs(
55 KEEP_ALIVE_TIME,
56 ))),
57 local_addr: SocketAddr::V4(SocketAddrV4::new(std::net::Ipv4Addr::UNSPECIFIED, 0)),
58 peer_addr: SocketAddr::V4(SocketAddrV4::new(std::net::Ipv4Addr::UNSPECIFIED, 0)),
59 }
60 }
61 pub async fn new(
62 host: &str,
63 headers: &HashMap<&'static str, String>,
64 service_type: &ServiceType,
65 ) -> Result<Self, NetworkError> {
66 let sni = if let Some(sni) = host.split(':').next() {
67 sni
68 } else {
69 host
70 };
71 let transport_type = if let ServiceType::Wss = service_type {
72 StreamType::Tls(TlsConfiguration {
73 sni: sni.to_owned(),
74 })
75 } else {
76 StreamType::Tcp
77 };
78 let stream = UnifiedSocket::new(host, transport_type).await?;
79 let local_addr = stream.local_addr();
80 let peer_addr = stream.peer_addr();
81 let (mut request_sender, connection) = conn::handshake(stream).await?;
82 let conn_handler = tokio::spawn(async move {
83 if let Err(e) = connection.await {
84 eprintln!("Error in connection: {}", e);
85 }
86 });
87 let mut request = Request::builder()
88 .header(
90 "Host",
91 host.strip_suffix(":443")
92 .or(host.strip_suffix(":80"))
93 .unwrap_or(host),
94 )
95 .header("Connection", "Upgrade")
96 .header("Upgrade", "websocket")
97 .header("Sec-WebSocket-Version", "13")
98 .header(
99 "Sec-WebSocket-Key",
100 tungstenite::handshake::client::generate_key(),
101 )
102 .header("NL-VERSION", env!("CARGO_PKG_VERSION"));
103 for (key, value) in headers.iter() {
104 if let Ok(header_value) = HeaderValue::from_str(value) {
105 request
106 .headers_mut()
107 .and_then(|headers| headers.insert(*key, header_value));
108 }
109 }
110 let request = request.method("GET").body(Body::from(""))?;
111 let response = request_sender.send_request(request).await?;
112 let response_headers = response.headers().clone();
113 trace!("response status: {}", response.status().to_string());
114 if response.status() != StatusCode::SWITCHING_PROTOCOLS {
115 let status_code = response.status().as_u16();
116 trace!(
117 "response body: {}",
118 String::from_utf8_lossy(
119 hyper::body::to_bytes(response.into_body()).await?.as_ref()
120 )
121 );
122
123 return Err(NetworkError::UnableToUpgrade(status_code));
124 }
125
126 let upgraded = hyper::upgrade::on(response).await?;
127 let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
128 Box::new(upgraded) as Box<dyn AsyncSocket>,
129 tungstenite::protocol::Role::Client,
130 None,
131 )
132 .await;
133 Ok(Self {
134 ws_stream,
135 remaining_bytes: None,
136 mode: WsMode::Client(response_headers, conn_handler),
137 local_addr,
138 peer_addr,
139 })
140 }
141 pub fn get_header(&self, key: &str) -> Option<&str> {
142 if let WsMode::Client(response_headers, _) = &self.mode {
143 response_headers.get(key).and_then(|v| v.to_str().ok())
144 } else {
145 None
146 }
147 }
148 pub fn drive_key(key: &[u8]) -> String {
149 tungstenite::handshake::derive_accept_key(key)
150 }
151 pub fn local_addr(&self) -> SocketAddr {
152 self.local_addr
153 }
154 pub fn peer_addr(&self) -> SocketAddr {
155 self.peer_addr
156 }
157}
158
159impl AsyncRead for WsConnection {
160 fn poll_read(
161 mut self: std::pin::Pin<&mut Self>,
162 cx: &mut Context<'_>,
163 buf: &mut ReadBuf<'_>,
164 ) -> Poll<io::Result<()>> {
165 loop {
166 if let Some(remaining_buf) = self.remaining_bytes.as_mut() {
167 if buf.remaining() < remaining_buf.len() {
168 let buffer = remaining_buf.split_to(buf.remaining());
169 buf.put_slice(&buffer);
170 } else {
171 buf.put_slice(remaining_buf);
172 self.remaining_bytes = None::<BytesMut>;
173 }
174 return Poll::Ready(Ok(()));
175 }
176
177 match self.ws_stream.poll_next_unpin(cx) {
178 Poll::Ready(d) => match d {
179 Some(Ok(data)) => {
180 if let Message::Binary(bin) = data {
181 if buf.remaining() < bin.len() {
182 let mut bytes =
184 BytesMut::with_capacity(bin.len() - buf.remaining());
185 bytes.put(&bin[buf.remaining()..]);
186 self.remaining_bytes = Some(bytes);
187 buf.put_slice(&bin[..buf.remaining()]);
188 } else {
189 buf.put_slice(&bin);
190 }
191
192 return Poll::Ready(Ok(()));
193 } else {
194 continue;
195 }
196 }
197 Some(Err(_e)) => io::Error::from(io::ErrorKind::UnexpectedEof),
198 None => return Poll::Ready(Ok(())),
199 },
200 Poll::Pending => {
201 if let WsMode::Server(interval) = &mut self.mode {
202 match interval.poll_tick(cx) {
203 Poll::Ready(_) => {
204 match self.ws_stream.send(Message::Ping(vec![0])).poll_unpin(cx) {
205 Poll::Ready(Ok(_)) => continue,
206 Poll::Ready(Err(_e)) => {
207 return Poll::Ready(Err(Error::new(
208 ErrorKind::Other,
209 "Ping Error!",
210 )))
211 }
212 Poll::Pending => return Poll::Pending,
213 }
214 }
215 Poll::Pending => return Poll::Pending,
216 }
217 } else {
218 return Poll::Pending;
219 }
220 }
221 };
222 }
223 }
224}
225
226impl AsyncWrite for WsConnection {
227 fn poll_write(
228 mut self: std::pin::Pin<&mut Self>,
229 cx: &mut Context<'_>,
230 buf: &[u8],
231 ) -> Poll<Result<usize, io::Error>> {
232 match Pin::new(&mut self.ws_stream.send(Message::binary(buf)))
233 .poll(cx)
234 .map_err(|e| Error::new(ErrorKind::Other, e.to_string()))?
235 {
236 Poll::Ready(_) => Poll::Ready(Ok(buf.len())),
237 Poll::Pending => Poll::Pending,
238 }
239 }
240
241 fn poll_flush(
242 mut self: std::pin::Pin<&mut Self>,
243 cx: &mut Context<'_>,
244 ) -> Poll<Result<(), io::Error>> {
245 self.ws_stream
246 .poll_flush_unpin(cx)
247 .map_err(|_| io::Error::from(io::ErrorKind::UnexpectedEof))
248 }
249
250 fn poll_shutdown(
251 mut self: std::pin::Pin<&mut Self>,
252 cx: &mut Context<'_>,
253 ) -> Poll<Result<(), io::Error>> {
254 self.ws_stream
255 .poll_close_unpin(cx)
256 .map_err(|_| io::Error::from(io::ErrorKind::UnexpectedEof))
257 }
258}
259
260impl futures_util::Stream for WsConnection {
261 type Item = Result<String, NetworkError>;
262
263 fn poll_next(
264 mut self: std::pin::Pin<&mut Self>,
265 cx: &mut Context<'_>,
266 ) -> Poll<Option<Self::Item>> {
267 loop {
268 match self.ws_stream.poll_next_unpin(cx) {
269 Poll::Ready(Some(Ok(msg))) => {
270 if let Message::Text(msg) = msg {
271 return Poll::Ready(Some(Ok(msg)));
272 } else {
273 continue;
274 }
275 }
276 Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e.into()))),
277 Poll::Ready(None) => return Poll::Ready(None),
278 Poll::Pending => {
279 if let WsMode::Server(interval) = &mut self.mode {
280 match interval.poll_tick(cx) {
281 Poll::Ready(_) => {
282 match self.ws_stream.send(Message::Ping(vec![0])).poll_unpin(cx) {
283 Poll::Ready(Ok(_)) => continue,
284 Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
285 Poll::Pending => return Poll::Pending,
286 }
287 }
288 Poll::Pending => return Poll::Pending,
289 }
290 } else {
291 return Poll::Pending;
292 }
293 }
294 }
295 }
296 }
297}
298impl futures_util::Sink<String> for WsConnection {
299 type Error = NetworkError;
300
301 fn poll_ready(
302 mut self: std::pin::Pin<&mut Self>,
303 cx: &mut Context<'_>,
304 ) -> Poll<Result<(), Self::Error>> {
305 self.ws_stream.poll_ready_unpin(cx).map_err(|e| e.into())
306 }
307
308 fn start_send(mut self: std::pin::Pin<&mut Self>, item: String) -> Result<(), Self::Error> {
309 self.ws_stream
310 .start_send_unpin(Message::Text(item))
311 .map_err(|e| e.into())
312 }
313
314 fn poll_flush(
315 mut self: std::pin::Pin<&mut Self>,
316 cx: &mut Context<'_>,
317 ) -> Poll<Result<(), Self::Error>> {
318 self.ws_stream.poll_flush_unpin(cx).map_err(|e| e.into())
319 }
320
321 fn poll_close(
322 mut self: std::pin::Pin<&mut Self>,
323 cx: &mut Context<'_>,
324 ) -> Poll<Result<(), Self::Error>> {
325 self.ws_stream.poll_close_unpin(cx).map_err(|e| e.into())
326 }
327}
328
329pub struct WsConnectionBinary {
330 ws_stream: WebSocketStream<Box<dyn AsyncSocket>>,
331 remaining_bytes: Option<BytesMut>,
332 mode: WsMode,
333}
334
335impl WsConnectionBinary {
336 pub async fn from(server_stream: impl AsyncSocket) -> Self {
337 let ws_stream = WebSocketStream::from_raw_socket(
339 Box::new(server_stream) as Box<dyn AsyncSocket>,
340 tungstenite::protocol::Role::Server,
341 None,
342 )
343 .await;
344
345 Self {
346 ws_stream,
347 remaining_bytes: None,
348 mode: WsMode::Server(tokio::time::interval(core::time::Duration::from_secs(
349 KEEP_ALIVE_TIME,
350 ))),
351 }
352 }
353 pub async fn new(
354 host: &str,
355 headers: HashMap<&'static str, String>,
357 service_type: &ServiceType,
358 ) -> Result<Self, NetworkError> {
359 let sni = if let Some(sni) = host.split(':').next() {
360 sni
361 } else {
362 host
363 };
364 let transport_type = if let ServiceType::Wss = service_type {
365 StreamType::Tls(TlsConfiguration {
366 sni: sni.to_owned(),
367 })
368 } else {
369 StreamType::Tcp
370 };
371 let stream = UnifiedSocket::new(host, transport_type).await?;
372
373 let (mut request_sender, connection) = conn::handshake(stream).await?;
374 let conn_handler = tokio::spawn(async move {
375 if let Err(e) = connection.await {
376 warn!("Error in connection: {}", e);
377 }
378 });
379
380 let mut request = Request::builder()
381 .header(
383 "Host",
384 host.strip_suffix(":443")
385 .or(host.strip_suffix(":80"))
386 .unwrap_or(host),
387 )
388 .header("Connection", "Upgrade")
389 .header("Upgrade", "websocket")
390 .header("Sec-WebSocket-Version", "13")
391 .header(
392 "Sec-WebSocket-Key",
393 tungstenite::handshake::client::generate_key(),
394 );
395 for (key, value) in headers.iter() {
396 if let Ok(header_value) = HeaderValue::from_str(value) {
397 request
398 .headers_mut()
399 .and_then(|headers| headers.insert(*key, header_value));
400 }
401 }
402 let request = request.method("GET").body(Body::from(""))?;
403 let response = request_sender.send_request(request).await?;
404 let response_headers = response.headers().clone();
405 debug!("ws connection status: {}", response.status());
406 if response.status() != StatusCode::SWITCHING_PROTOCOLS {
407 let status_code = response.status().as_u16();
408 trace!(
409 "response body: {}",
410 String::from_utf8_lossy(
411 hyper::body::to_bytes(response.into_body()).await?.as_ref()
412 )
413 );
414 return Err(NetworkError::UnableToUpgrade(status_code));
415 }
416 let upgraded = hyper::upgrade::on(response).await?;
417 let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
418 Box::new(upgraded) as Box<dyn AsyncSocket>,
419 tungstenite::protocol::Role::Client,
420 None,
421 )
422 .await;
423 Ok(Self {
424 ws_stream,
425 remaining_bytes: None,
426 mode: WsMode::Client(response_headers, conn_handler),
427 })
428 }
429 pub fn get_header(&self, key: &str) -> Option<&str> {
430 if let WsMode::Client(response_headers, _) = &self.mode {
431 response_headers.get(key).and_then(|v| v.to_str().ok())
432 } else {
433 None
434 }
435 }
436 pub fn drive_key(key: &[u8]) -> String {
437 tungstenite::handshake::derive_accept_key(key)
438 }
439}
440
441impl futures_util::Stream for WsConnectionBinary {
442 type Item = Result<Vec<u8>, NetworkError>;
443
444 fn poll_next(
445 mut self: std::pin::Pin<&mut Self>,
446 cx: &mut Context<'_>,
447 ) -> Poll<Option<Self::Item>> {
448 loop {
449 match self.ws_stream.poll_next_unpin(cx) {
450 Poll::Ready(Some(Ok(msg))) => {
451 if let Message::Binary(msg) = msg {
452 return Poll::Ready(Some(Ok(msg)));
453 } else {
454 continue;
455 }
456 }
457 Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e.into()))),
458 Poll::Ready(None) => return Poll::Ready(None),
459 Poll::Pending => {
460 if let WsMode::Server(interval) = &mut self.mode {
461 match interval.poll_tick(cx) {
462 Poll::Ready(_) => {
463 match self.ws_stream.send(Message::Ping(vec![0])).poll_unpin(cx) {
464 Poll::Ready(Ok(_)) => continue,
465 Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
466 Poll::Pending => return Poll::Pending,
467 }
468 }
469 Poll::Pending => return Poll::Pending,
470 }
471 } else {
472 return Poll::Pending;
473 }
474 }
475 }
476 }
477 }
478}
479impl futures_util::Sink<Vec<u8>> for WsConnectionBinary {
480 type Error = NetworkError;
481
482 fn poll_ready(
483 mut self: std::pin::Pin<&mut Self>,
484 cx: &mut Context<'_>,
485 ) -> Poll<Result<(), Self::Error>> {
486 self.ws_stream.poll_ready_unpin(cx).map_err(|e| e.into())
487 }
488
489 fn start_send(mut self: std::pin::Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
490 self.ws_stream
491 .start_send_unpin(Message::Binary(item))
492 .map_err(|e| e.into())
493 }
494
495 fn poll_flush(
496 mut self: std::pin::Pin<&mut Self>,
497 cx: &mut Context<'_>,
498 ) -> Poll<Result<(), Self::Error>> {
499 self.ws_stream.poll_flush_unpin(cx).map_err(|e| e.into())
500 }
501
502 fn poll_close(
503 mut self: std::pin::Pin<&mut Self>,
504 cx: &mut Context<'_>,
505 ) -> Poll<Result<(), Self::Error>> {
506 self.ws_stream.poll_close_unpin(cx).map_err(|e| e.into())
507 }
508}
509
510impl AsyncRead for WsConnectionBinary {
511 fn poll_read(
512 mut self: Pin<&mut Self>,
513 cx: &mut Context<'_>,
514 buf: &mut ReadBuf<'_>,
515 ) -> Poll<io::Result<()>> {
516 loop {
517 if let Some(remaining_buf) = self.remaining_bytes.as_mut() {
518 if buf.remaining() < remaining_buf.len() {
519 let buffer = remaining_buf.split_to(buf.remaining());
520 buf.put_slice(&buffer);
521 } else {
522 buf.put_slice(remaining_buf);
523 self.remaining_bytes = None::<BytesMut>;
524 }
525 return Poll::Ready(Ok(()));
526 }
527 match self.ws_stream.poll_next_unpin(cx) {
528 Poll::Ready(Some(Ok(msg))) => {
529 if let Message::Binary(msg) = msg {
530 if buf.remaining() < msg.len() {
531 let mut bytes = BytesMut::with_capacity(msg.len() - buf.remaining());
532 bytes.put(&msg[buf.remaining()..]);
533 self.remaining_bytes = Some(bytes);
534 buf.put_slice(&msg[..buf.remaining()]);
535 } else {
536 buf.put_slice(&msg);
537 }
538 return Poll::Ready(Ok(()));
539 } else {
540 continue;
541 }
542 }
543 Poll::Ready(Some(Err(e))) => {
544 return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e.to_string())))
545 }
546 Poll::Ready(None) => return Poll::Ready(Ok(())),
547 Poll::Pending => {
548 if let WsMode::Server(interval) = &mut self.mode {
549 match interval.poll_tick(cx) {
550 Poll::Ready(_) => {
551 match self.ws_stream.send(Message::Ping(vec![0])).poll_unpin(cx) {
552 Poll::Ready(Ok(_)) => continue,
553 Poll::Ready(Err(e)) => {
554 return Poll::Ready(Err(io::Error::new(
555 io::ErrorKind::Other,
556 e.to_string(),
557 )))
558 }
559 Poll::Pending => return Poll::Pending,
560 }
561 }
562 Poll::Pending => return Poll::Pending,
563 }
564 } else {
565 return Poll::Pending;
566 }
567 }
568 }
569 }
570 }
571}
572
573impl AsyncWrite for WsConnectionBinary {
574 fn poll_write(
575 mut self: Pin<&mut Self>,
576 cx: &mut Context<'_>,
577 buf: &[u8],
578 ) -> Poll<Result<usize, io::Error>> {
579 match self
580 .ws_stream
581 .poll_ready_unpin(cx)
582 .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?
583 {
584 Poll::Ready(()) => {
585 self.ws_stream
586 .start_send_unpin(Message::binary(buf))
587 .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
588 Poll::Ready(Ok(buf.len()))
589 }
590 Poll::Pending => Poll::Pending,
591 }
592 }
593
594 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
595 self.ws_stream
596 .poll_flush_unpin(cx)
597 .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))
598 }
599
600 fn poll_shutdown(
601 mut self: Pin<&mut Self>,
602 cx: &mut Context<'_>,
603 ) -> Poll<Result<(), io::Error>> {
604 self.ws_stream
605 .poll_close_unpin(cx)
606 .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))
607 }
608}