jets 0.4.3

A Rust rule-based tunnel
Documentation
use crate::app::config::WsSettings;
use crate::common::{invalid_data_error, invalid_input_error, Address};
use crate::proxy::{LocalAddr, ProxyStream};
use bytes::{Bytes, BytesMut};
use futures::sink::Sink;
use futures::stream::Stream;
use futures::{
    ready,
    task::{Context, Poll},
};
use hyper::http::{HeaderName, HeaderValue};
use std::cmp::min;
use std::io::{Error, ErrorKind, Result};
use std::net::SocketAddr;
use std::pin::Pin;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_tungstenite::{client_async_with_config, WebSocketStream};
use tungstenite::client::IntoClientRequest;
use tungstenite::handshake::client::Request;
use tungstenite::protocol::WebSocketConfig;
use tungstenite::Message;

#[derive(Clone, Debug)]
pub struct Ws {
    request: Request,
    ws_config: WebSocketConfig,
}

impl Ws {
    pub fn new(ws_settings: WsSettings, addr: &Address, secure: bool) -> Result<Self> {
        let host = if !ws_settings.host.is_empty() {
            &ws_settings.host
        } else if let Some(val) = ws_settings.headers.get("HOST") {
            val
        } else {
            &addr.host()
        };
        let protocol = if secure { "wss" } else { "ws" };
        let url = format!("{}://{}{}", protocol, host, ws_settings.path);
        let mut request = url
            .as_str()
            .into_client_request()
            .map_err(|_| invalid_input_error(format!("invalid ws url {}", url)))?;
        for (k, v) in ws_settings.headers.iter() {
            if k.to_uppercase() != "HOST" {
                request.headers_mut().insert(
                    HeaderName::try_from(k).map_err(|_| {
                        invalid_input_error(format!("invalid ws header name: {}", k))
                    })?,
                    HeaderValue::from_str(v).map_err(|_| {
                        invalid_input_error(format!("invalid ws header value: {}", v))
                    })?,
                );
            }
        }
        request.headers_mut().insert(
            "Host",
            HeaderValue::from_str(host)
                .map_err(|_| invalid_input_error(format!("invalid ws host: {}", host)))?,
        );
        let ws_config = WebSocketConfig::default().write_buffer_size(0);
        Ok(Self { request, ws_config })
    }

    pub async fn connect<S>(&self, stream: S) -> Result<WsStream<S>>
    where
        S: ProxyStream,
    {
        log::debug!("Sending ws request to {}", self.request.uri());
        let (s, _) = client_async_with_config(self.request.clone(), stream, Some(self.ws_config))
            .await
            .map_err(|e| {
                Error::other(format!("connect ws {} failed: {}", self.request.uri(), e))
            })?;

        Ok(WsStream::new(s))
    }
}

pub struct WsStream<S> {
    buf: BytesMut,
    inner: WebSocketStream<S>,
}

impl<S> WsStream<S> {
    pub fn new(stream: WebSocketStream<S>) -> Self {
        Self {
            buf: BytesMut::new(),
            inner: stream,
        }
    }
}

fn broken_pipe() -> Error {
    Error::new(ErrorKind::Interrupted, "broken pipe")
}

fn invalid_frame() -> Error {
    Error::new(ErrorKind::Interrupted, "invalid frame")
}

impl<S: ProxyStream> AsyncRead for WsStream<S> {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut ReadBuf,
    ) -> Poll<Result<()>> {
        if !self.buf.is_empty() {
            let to_read = min(buf.remaining(), self.buf.len());
            let for_read = self.buf.split_to(to_read);
            buf.put_slice(&for_read[..to_read]);
            return Poll::Ready(Ok(()));
        }
        Poll::Ready(ready!(Pin::new(&mut self.inner).poll_next(cx)).map_or(
            Err(broken_pipe()),
            |item| {
                item.map_or(Err(broken_pipe()), |msg| match msg {
                    Message::Binary(data) => {
                        let to_read = min(buf.remaining(), data.len());
                        buf.put_slice(&data[..to_read]);
                        if data.len() > to_read {
                            self.buf.extend_from_slice(&data[to_read..]);
                        }
                        log::trace!("poll_read {} bytes", buf.filled().len());
                        Ok(())
                    }
                    Message::Close(_) => Ok(()),
                    _ => Err(invalid_frame()),
                })
            },
        ))
    }
}

impl<S: ProxyStream> AsyncWrite for WsStream<S> {
    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize>> {
        log::trace!("poll_write {} bytes", buf.len());
        ready!(Pin::new(&mut self.inner)
            .poll_ready(cx)
            .map_err(|_| broken_pipe()))?;

        let msg = Message::Binary(Bytes::copy_from_slice(buf));
        Pin::new(&mut self.inner)
            .start_send(msg)
            .map_err(|_| broken_pipe())?;

        let _ = Pin::new(&mut self.inner).poll_flush(cx);

        Poll::Ready(Ok(buf.len()))
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
        Pin::new(&mut self.inner)
            .poll_flush(cx)
            .map_err(|_| broken_pipe())
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
        Pin::new(&mut self.inner)
            .poll_close(cx)
            .map_err(invalid_data_error)
    }
}

impl<S: ProxyStream> LocalAddr for WsStream<S> {
    fn local_addr(&self) -> Result<SocketAddr> {
        self.inner.get_ref().local_addr()
    }
}