narrowlink_network/
ws.rs

1use bytes::{BufMut, BytesMut};
2use futures_util::{Future, FutureExt, SinkExt, StreamExt};
3use hyper::{client::conn, http::HeaderValue, Body, HeaderMap, Request, StatusCode};
4use narrowlink_types::ServiceType;
5use std::{
6    collections::HashMap,
7    io::{self, Error, ErrorKind},
8    net::{SocketAddr, SocketAddrV4},
9    pin::Pin,
10    task::{Context, Poll},
11};
12use tokio::{
13    io::{AsyncRead, AsyncWrite, ReadBuf},
14    task::JoinHandle,
15};
16use tokio_tungstenite::WebSocketStream;
17use tracing::{debug, trace, warn};
18use tungstenite::Message;
19
20use crate::{
21    error::NetworkError,
22    transport::{StreamType, TlsConfiguration, UnifiedSocket},
23    AsyncSocket,
24};
25
/// Period, in seconds, of the interval created for server-side connections;
/// each tick of that interval triggers a WebSocket ping from the read paths.
const KEEP_ALIVE_TIME: u64 = 20;
27
/// Role-specific state carried by a WebSocket connection.
pub enum WsMode {
    /// Server side: the interval whose ticks trigger keep-alive pings while
    /// the connection is being polled for incoming data.
    Server(tokio::time::Interval),
    /// Client side: the headers of the handshake response and the handle of
    /// the spawned task that drives the underlying HTTP connection.
    Client(HeaderMap, JoinHandle<()>),
}
32
/// WebSocket connection usable as a byte stream (`AsyncRead`/`AsyncWrite`,
/// carried in binary messages) and as a stream/sink of text messages.
pub struct WsConnection {
    /// Underlying WebSocket stream over a type-erased socket.
    ws_stream: WebSocketStream<Box<dyn AsyncSocket>>,
    /// Tail of a binary message that did not fit into the caller's read
    /// buffer; it is handed out first by the next `poll_read`.
    remaining_bytes: Option<BytesMut>,
    /// Whether this end is the server or the client, plus role-specific state.
    mode: WsMode,
    /// Local socket address (an unspecified placeholder for server-side
    /// connections built with `from`).
    local_addr: SocketAddr,
    /// Peer socket address (an unspecified placeholder for server-side
    /// connections built with `from`).
    peer_addr: SocketAddr,
}
40
41impl WsConnection {
42    pub async fn from(server_stream: impl AsyncSocket) -> Self {
43        // let x: Box<dyn AsyncSocket> = Box::new(server_stream);
44        let ws_stream = WebSocketStream::from_raw_socket(
45            Box::new(server_stream) as Box<dyn AsyncSocket>,
46            tungstenite::protocol::Role::Server,
47            None,
48        )
49        .await;
50
51        Self {
52            ws_stream,
53            remaining_bytes: None,
54            mode: WsMode::Server(tokio::time::interval(core::time::Duration::from_secs(
55                KEEP_ALIVE_TIME,
56            ))),
57            local_addr: SocketAddr::V4(SocketAddrV4::new(std::net::Ipv4Addr::UNSPECIFIED, 0)),
58            peer_addr: SocketAddr::V4(SocketAddrV4::new(std::net::Ipv4Addr::UNSPECIFIED, 0)),
59        }
60    }
61    pub async fn new(
62        host: &str,
63        headers: &HashMap<&'static str, String>,
64        service_type: &ServiceType,
65    ) -> Result<Self, NetworkError> {
66        let sni = if let Some(sni) = host.split(':').next() {
67            sni
68        } else {
69            host
70        };
71        let transport_type = if let ServiceType::Wss = service_type {
72            StreamType::Tls(TlsConfiguration {
73                sni: sni.to_owned(),
74            })
75        } else {
76            StreamType::Tcp
77        };
78        let stream = UnifiedSocket::new(host, transport_type).await?;
79        let local_addr = stream.local_addr();
80        let peer_addr = stream.peer_addr();
81        let (mut request_sender, connection) = conn::handshake(stream).await?;
82        let conn_handler = tokio::spawn(async move {
83            if let Err(e) = connection.await {
84                eprintln!("Error in connection: {}", e);
85            }
86        });
87        let mut request = Request::builder()
88            // .uri(uri)
89            .header(
90                "Host",
91                host.strip_suffix(":443")
92                    .or(host.strip_suffix(":80"))
93                    .unwrap_or(host),
94            )
95            .header("Connection", "Upgrade")
96            .header("Upgrade", "websocket")
97            .header("Sec-WebSocket-Version", "13")
98            .header(
99                "Sec-WebSocket-Key",
100                tungstenite::handshake::client::generate_key(),
101            )
102            .header("NL-VERSION", env!("CARGO_PKG_VERSION"));
103        for (key, value) in headers.iter() {
104            if let Ok(header_value) = HeaderValue::from_str(value) {
105                request
106                    .headers_mut()
107                    .and_then(|headers| headers.insert(*key, header_value));
108            }
109        }
110        let request = request.method("GET").body(Body::from(""))?;
111        let response = request_sender.send_request(request).await?;
112        let response_headers = response.headers().clone();
113        trace!("response status: {}", response.status().to_string());
114        if response.status() != StatusCode::SWITCHING_PROTOCOLS {
115            let status_code = response.status().as_u16();
116            trace!(
117                "response body: {}",
118                String::from_utf8_lossy(
119                    hyper::body::to_bytes(response.into_body()).await?.as_ref()
120                )
121            );
122
123            return Err(NetworkError::UnableToUpgrade(status_code));
124        }
125
126        let upgraded = hyper::upgrade::on(response).await?;
127        let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
128            Box::new(upgraded) as Box<dyn AsyncSocket>,
129            tungstenite::protocol::Role::Client,
130            None,
131        )
132        .await;
133        Ok(Self {
134            ws_stream,
135            remaining_bytes: None,
136            mode: WsMode::Client(response_headers, conn_handler),
137            local_addr,
138            peer_addr,
139        })
140    }
141    pub fn get_header(&self, key: &str) -> Option<&str> {
142        if let WsMode::Client(response_headers, _) = &self.mode {
143            response_headers.get(key).and_then(|v| v.to_str().ok())
144        } else {
145            None
146        }
147    }
148    pub fn drive_key(key: &[u8]) -> String {
149        tungstenite::handshake::derive_accept_key(key)
150    }
151    pub fn local_addr(&self) -> SocketAddr {
152        self.local_addr
153    }
154    pub fn peer_addr(&self) -> SocketAddr {
155        self.peer_addr
156    }
157}
158
159impl AsyncRead for WsConnection {
160    fn poll_read(
161        mut self: std::pin::Pin<&mut Self>,
162        cx: &mut Context<'_>,
163        buf: &mut ReadBuf<'_>,
164    ) -> Poll<io::Result<()>> {
165        loop {
166            if let Some(remaining_buf) = self.remaining_bytes.as_mut() {
167                if buf.remaining() < remaining_buf.len() {
168                    let buffer = remaining_buf.split_to(buf.remaining());
169                    buf.put_slice(&buffer);
170                } else {
171                    buf.put_slice(remaining_buf);
172                    self.remaining_bytes = None::<BytesMut>;
173                }
174                return Poll::Ready(Ok(()));
175            }
176
177            match self.ws_stream.poll_next_unpin(cx) {
178                Poll::Ready(d) => match d {
179                    Some(Ok(data)) => {
180                        if let Message::Binary(bin) = data {
181                            if buf.remaining() < bin.len() {
182                                // todo max size 64 << 20
183                                let mut bytes =
184                                    BytesMut::with_capacity(bin.len() - buf.remaining());
185                                bytes.put(&bin[buf.remaining()..]);
186                                self.remaining_bytes = Some(bytes);
187                                buf.put_slice(&bin[..buf.remaining()]);
188                            } else {
189                                buf.put_slice(&bin);
190                            }
191
192                            return Poll::Ready(Ok(()));
193                        } else {
194                            continue;
195                        }
196                    }
197                    Some(Err(_e)) => io::Error::from(io::ErrorKind::UnexpectedEof),
198                    None => return Poll::Ready(Ok(())),
199                },
200                Poll::Pending => {
201                    if let WsMode::Server(interval) = &mut self.mode {
202                        match interval.poll_tick(cx) {
203                            Poll::Ready(_) => {
204                                match self.ws_stream.send(Message::Ping(vec![0])).poll_unpin(cx) {
205                                    Poll::Ready(Ok(_)) => continue,
206                                    Poll::Ready(Err(_e)) => {
207                                        return Poll::Ready(Err(Error::new(
208                                            ErrorKind::Other,
209                                            "Ping Error!",
210                                        )))
211                                    }
212                                    Poll::Pending => return Poll::Pending,
213                                }
214                            }
215                            Poll::Pending => return Poll::Pending,
216                        }
217                    } else {
218                        return Poll::Pending;
219                    }
220                }
221            };
222        }
223    }
224}
225
226impl AsyncWrite for WsConnection {
227    fn poll_write(
228        mut self: std::pin::Pin<&mut Self>,
229        cx: &mut Context<'_>,
230        buf: &[u8],
231    ) -> Poll<Result<usize, io::Error>> {
232        match Pin::new(&mut self.ws_stream.send(Message::binary(buf)))
233            .poll(cx)
234            .map_err(|e| Error::new(ErrorKind::Other, e.to_string()))?
235        {
236            Poll::Ready(_) => Poll::Ready(Ok(buf.len())),
237            Poll::Pending => Poll::Pending,
238        }
239    }
240
241    fn poll_flush(
242        mut self: std::pin::Pin<&mut Self>,
243        cx: &mut Context<'_>,
244    ) -> Poll<Result<(), io::Error>> {
245        self.ws_stream
246            .poll_flush_unpin(cx)
247            .map_err(|_| io::Error::from(io::ErrorKind::UnexpectedEof))
248    }
249
250    fn poll_shutdown(
251        mut self: std::pin::Pin<&mut Self>,
252        cx: &mut Context<'_>,
253    ) -> Poll<Result<(), io::Error>> {
254        self.ws_stream
255            .poll_close_unpin(cx)
256            .map_err(|_| io::Error::from(io::ErrorKind::UnexpectedEof))
257    }
258}
259
260impl futures_util::Stream for WsConnection {
261    type Item = Result<String, NetworkError>;
262
263    fn poll_next(
264        mut self: std::pin::Pin<&mut Self>,
265        cx: &mut Context<'_>,
266    ) -> Poll<Option<Self::Item>> {
267        loop {
268            match self.ws_stream.poll_next_unpin(cx) {
269                Poll::Ready(Some(Ok(msg))) => {
270                    if let Message::Text(msg) = msg {
271                        return Poll::Ready(Some(Ok(msg)));
272                    } else {
273                        continue;
274                    }
275                }
276                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e.into()))),
277                Poll::Ready(None) => return Poll::Ready(None),
278                Poll::Pending => {
279                    if let WsMode::Server(interval) = &mut self.mode {
280                        match interval.poll_tick(cx) {
281                            Poll::Ready(_) => {
282                                match self.ws_stream.send(Message::Ping(vec![0])).poll_unpin(cx) {
283                                    Poll::Ready(Ok(_)) => continue,
284                                    Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
285                                    Poll::Pending => return Poll::Pending,
286                                }
287                            }
288                            Poll::Pending => return Poll::Pending,
289                        }
290                    } else {
291                        return Poll::Pending;
292                    }
293                }
294            }
295        }
296    }
297}
298impl futures_util::Sink<String> for WsConnection {
299    type Error = NetworkError;
300
301    fn poll_ready(
302        mut self: std::pin::Pin<&mut Self>,
303        cx: &mut Context<'_>,
304    ) -> Poll<Result<(), Self::Error>> {
305        self.ws_stream.poll_ready_unpin(cx).map_err(|e| e.into())
306    }
307
308    fn start_send(mut self: std::pin::Pin<&mut Self>, item: String) -> Result<(), Self::Error> {
309        self.ws_stream
310            .start_send_unpin(Message::Text(item))
311            .map_err(|e| e.into())
312    }
313
314    fn poll_flush(
315        mut self: std::pin::Pin<&mut Self>,
316        cx: &mut Context<'_>,
317    ) -> Poll<Result<(), Self::Error>> {
318        self.ws_stream.poll_flush_unpin(cx).map_err(|e| e.into())
319    }
320
321    fn poll_close(
322        mut self: std::pin::Pin<&mut Self>,
323        cx: &mut Context<'_>,
324    ) -> Poll<Result<(), Self::Error>> {
325        self.ws_stream.poll_close_unpin(cx).map_err(|e| e.into())
326    }
327}
328
/// WebSocket connection usable as a byte stream (`AsyncRead`/`AsyncWrite`)
/// and as a stream/sink of whole binary messages (`Vec<u8>`).
pub struct WsConnectionBinary {
    /// Underlying WebSocket stream over a type-erased socket.
    ws_stream: WebSocketStream<Box<dyn AsyncSocket>>,
    /// Tail of a binary message that did not fit into the caller's read
    /// buffer; it is handed out first by the next `poll_read`.
    remaining_bytes: Option<BytesMut>,
    /// Whether this end is the server or the client, plus role-specific state.
    mode: WsMode,
}
334
335impl WsConnectionBinary {
336    pub async fn from(server_stream: impl AsyncSocket) -> Self {
337        // let x: Box<dyn AsyncSocket> = Box::new(server_stream);
338        let ws_stream = WebSocketStream::from_raw_socket(
339            Box::new(server_stream) as Box<dyn AsyncSocket>,
340            tungstenite::protocol::Role::Server,
341            None,
342        )
343        .await;
344
345        Self {
346            ws_stream,
347            remaining_bytes: None,
348            mode: WsMode::Server(tokio::time::interval(core::time::Duration::from_secs(
349                KEEP_ALIVE_TIME,
350            ))),
351        }
352    }
353    pub async fn new(
354        host: &str,
355        // uri: &str,
356        headers: HashMap<&'static str, String>,
357        service_type: &ServiceType,
358    ) -> Result<Self, NetworkError> {
359        let sni = if let Some(sni) = host.split(':').next() {
360            sni
361        } else {
362            host
363        };
364        let transport_type = if let ServiceType::Wss = service_type {
365            StreamType::Tls(TlsConfiguration {
366                sni: sni.to_owned(),
367            })
368        } else {
369            StreamType::Tcp
370        };
371        let stream = UnifiedSocket::new(host, transport_type).await?;
372
373        let (mut request_sender, connection) = conn::handshake(stream).await?;
374        let conn_handler = tokio::spawn(async move {
375            if let Err(e) = connection.await {
376                warn!("Error in connection: {}", e);
377            }
378        });
379
380        let mut request = Request::builder()
381            // .uri(uri)
382            .header(
383                "Host",
384                host.strip_suffix(":443")
385                    .or(host.strip_suffix(":80"))
386                    .unwrap_or(host),
387            )
388            .header("Connection", "Upgrade")
389            .header("Upgrade", "websocket")
390            .header("Sec-WebSocket-Version", "13")
391            .header(
392                "Sec-WebSocket-Key",
393                tungstenite::handshake::client::generate_key(),
394            );
395        for (key, value) in headers.iter() {
396            if let Ok(header_value) = HeaderValue::from_str(value) {
397                request
398                    .headers_mut()
399                    .and_then(|headers| headers.insert(*key, header_value));
400            }
401        }
402        let request = request.method("GET").body(Body::from(""))?;
403        let response = request_sender.send_request(request).await?;
404        let response_headers = response.headers().clone();
405        debug!("ws connection status: {}", response.status());
406        if response.status() != StatusCode::SWITCHING_PROTOCOLS {
407            let status_code = response.status().as_u16();
408            trace!(
409                "response body: {}",
410                String::from_utf8_lossy(
411                    hyper::body::to_bytes(response.into_body()).await?.as_ref()
412                )
413            );
414            return Err(NetworkError::UnableToUpgrade(status_code));
415        }
416        let upgraded = hyper::upgrade::on(response).await?;
417        let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
418            Box::new(upgraded) as Box<dyn AsyncSocket>,
419            tungstenite::protocol::Role::Client,
420            None,
421        )
422        .await;
423        Ok(Self {
424            ws_stream,
425            remaining_bytes: None,
426            mode: WsMode::Client(response_headers, conn_handler),
427        })
428    }
429    pub fn get_header(&self, key: &str) -> Option<&str> {
430        if let WsMode::Client(response_headers, _) = &self.mode {
431            response_headers.get(key).and_then(|v| v.to_str().ok())
432        } else {
433            None
434        }
435    }
436    pub fn drive_key(key: &[u8]) -> String {
437        tungstenite::handshake::derive_accept_key(key)
438    }
439}
440
441impl futures_util::Stream for WsConnectionBinary {
442    type Item = Result<Vec<u8>, NetworkError>;
443
444    fn poll_next(
445        mut self: std::pin::Pin<&mut Self>,
446        cx: &mut Context<'_>,
447    ) -> Poll<Option<Self::Item>> {
448        loop {
449            match self.ws_stream.poll_next_unpin(cx) {
450                Poll::Ready(Some(Ok(msg))) => {
451                    if let Message::Binary(msg) = msg {
452                        return Poll::Ready(Some(Ok(msg)));
453                    } else {
454                        continue;
455                    }
456                }
457                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e.into()))),
458                Poll::Ready(None) => return Poll::Ready(None),
459                Poll::Pending => {
460                    if let WsMode::Server(interval) = &mut self.mode {
461                        match interval.poll_tick(cx) {
462                            Poll::Ready(_) => {
463                                match self.ws_stream.send(Message::Ping(vec![0])).poll_unpin(cx) {
464                                    Poll::Ready(Ok(_)) => continue,
465                                    Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
466                                    Poll::Pending => return Poll::Pending,
467                                }
468                            }
469                            Poll::Pending => return Poll::Pending,
470                        }
471                    } else {
472                        return Poll::Pending;
473                    }
474                }
475            }
476        }
477    }
478}
479impl futures_util::Sink<Vec<u8>> for WsConnectionBinary {
480    type Error = NetworkError;
481
482    fn poll_ready(
483        mut self: std::pin::Pin<&mut Self>,
484        cx: &mut Context<'_>,
485    ) -> Poll<Result<(), Self::Error>> {
486        self.ws_stream.poll_ready_unpin(cx).map_err(|e| e.into())
487    }
488
489    fn start_send(mut self: std::pin::Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
490        self.ws_stream
491            .start_send_unpin(Message::Binary(item))
492            .map_err(|e| e.into())
493    }
494
495    fn poll_flush(
496        mut self: std::pin::Pin<&mut Self>,
497        cx: &mut Context<'_>,
498    ) -> Poll<Result<(), Self::Error>> {
499        self.ws_stream.poll_flush_unpin(cx).map_err(|e| e.into())
500    }
501
502    fn poll_close(
503        mut self: std::pin::Pin<&mut Self>,
504        cx: &mut Context<'_>,
505    ) -> Poll<Result<(), Self::Error>> {
506        self.ws_stream.poll_close_unpin(cx).map_err(|e| e.into())
507    }
508}
509
510impl AsyncRead for WsConnectionBinary {
511    fn poll_read(
512        mut self: Pin<&mut Self>,
513        cx: &mut Context<'_>,
514        buf: &mut ReadBuf<'_>,
515    ) -> Poll<io::Result<()>> {
516        loop {
517            if let Some(remaining_buf) = self.remaining_bytes.as_mut() {
518                if buf.remaining() < remaining_buf.len() {
519                    let buffer = remaining_buf.split_to(buf.remaining());
520                    buf.put_slice(&buffer);
521                } else {
522                    buf.put_slice(remaining_buf);
523                    self.remaining_bytes = None::<BytesMut>;
524                }
525                return Poll::Ready(Ok(()));
526            }
527            match self.ws_stream.poll_next_unpin(cx) {
528                Poll::Ready(Some(Ok(msg))) => {
529                    if let Message::Binary(msg) = msg {
530                        if buf.remaining() < msg.len() {
531                            let mut bytes = BytesMut::with_capacity(msg.len() - buf.remaining());
532                            bytes.put(&msg[buf.remaining()..]);
533                            self.remaining_bytes = Some(bytes);
534                            buf.put_slice(&msg[..buf.remaining()]);
535                        } else {
536                            buf.put_slice(&msg);
537                        }
538                        return Poll::Ready(Ok(()));
539                    } else {
540                        continue;
541                    }
542                }
543                Poll::Ready(Some(Err(e))) => {
544                    return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e.to_string())))
545                }
546                Poll::Ready(None) => return Poll::Ready(Ok(())),
547                Poll::Pending => {
548                    if let WsMode::Server(interval) = &mut self.mode {
549                        match interval.poll_tick(cx) {
550                            Poll::Ready(_) => {
551                                match self.ws_stream.send(Message::Ping(vec![0])).poll_unpin(cx) {
552                                    Poll::Ready(Ok(_)) => continue,
553                                    Poll::Ready(Err(e)) => {
554                                        return Poll::Ready(Err(io::Error::new(
555                                            io::ErrorKind::Other,
556                                            e.to_string(),
557                                        )))
558                                    }
559                                    Poll::Pending => return Poll::Pending,
560                                }
561                            }
562                            Poll::Pending => return Poll::Pending,
563                        }
564                    } else {
565                        return Poll::Pending;
566                    }
567                }
568            }
569        }
570    }
571}
572
573impl AsyncWrite for WsConnectionBinary {
574    fn poll_write(
575        mut self: Pin<&mut Self>,
576        cx: &mut Context<'_>,
577        buf: &[u8],
578    ) -> Poll<Result<usize, io::Error>> {
579        match self
580            .ws_stream
581            .poll_ready_unpin(cx)
582            .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?
583        {
584            Poll::Ready(()) => {
585                self.ws_stream
586                    .start_send_unpin(Message::binary(buf))
587                    .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
588                Poll::Ready(Ok(buf.len()))
589            }
590            Poll::Pending => Poll::Pending,
591        }
592    }
593
594    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
595        self.ws_stream
596            .poll_flush_unpin(cx)
597            .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))
598    }
599
600    fn poll_shutdown(
601        mut self: Pin<&mut Self>,
602        cx: &mut Context<'_>,
603    ) -> Poll<Result<(), io::Error>> {
604        self.ws_stream
605            .poll_close_unpin(cx)
606            .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))
607    }
608}