1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
//! Implementation of websocket transport

use super::{Credentials, Request, Subscribe, TransportError};
use http::{Request as HttpRequest, Uri};
use log::{debug, error, trace};
use std::borrow::Cow;
use std::str::FromStr;
use thiserror::Error;
use tungstenite::client::AutoStream;
use tungstenite::protocol::frame::coding::CloseCode;
use tungstenite::protocol::CloseFrame;
use tungstenite::{connect as ws_connect, Message, WebSocket as WebSocketTungstenite};

/// Wraps a websocket connection
pub struct WebSocket {
    /// The endpoint of the websocket connection
    pub address: String,
    pub(crate) credentials: Option<Credentials>,
    ws: WebSocketTungstenite<AutoStream>,
}

impl WebSocket {
    pub(crate) fn new(
        address: String,
        credentials: Option<Credentials>,
    ) -> Result<WebSocket, WebSocketError> {
        debug!("Initiating websocket connection to {}", address);
        let uri = Uri::from_str(&address)?;
        let handshake_request = create_handshake_request(&uri, credentials.clone())?;
        let ws = ws_connect(handshake_request)?;
        trace!("Handshake Response: {:?}", ws.1);
        Ok(WebSocket {
            address,
            credentials,
            ws: ws.0,
        })
    }

    fn read_message(&mut self) -> Result<String, WebSocketError> {
        match self.read() {
            Ok(Message::Text(response)) => Ok(response),
            Ok(_) => self.read_message(),
            Err(err) => Err(err),
        }
    }

    fn read(&mut self) -> Result<Message, WebSocketError> {
        let message = self.ws.read_message()?;
        trace!("Reading from websocket: {}", &message);
        Ok(message)
    }

    fn write(&mut self, message: Message) -> Result<(), WebSocketError> {
        trace!("Writing to websocket: {}", &message);
        self.ws.write_message(message)?;
        Ok(())
    }

    fn close(&mut self) -> Result<(), WebSocketError> {
        debug!("Closing websocket connection");
        let close_frame = CloseFrame {
            code: CloseCode::Normal,
            reason: Cow::from("Finished"),
        };
        self.ws.close(Some(close_frame))?;
        self.ws.write_pending().map_err(WebSocketError::from)
    }
}

impl Request for WebSocket {
    fn request(&mut self, cmd: String) -> Result<String, TransportError> {
        let _write = self.write(Message::Text(cmd))?;
        self.read_message().map_err(TransportError::from)
    }
}

impl Subscribe for WebSocket {
    fn read_next(&mut self) -> Result<String, TransportError> {
        self.read_message().map_err(TransportError::from)
    }

    fn fork(&self) -> Result<Self, TransportError> {
        Self::new(self.address.clone(), self.credentials.clone()).map_err(TransportError::from)
    }
}

impl Drop for WebSocket {
    fn drop(&mut self) {
        let close = self.close();
        if let Err(err) = close {
            error!("{}", err);
        }
    }
}

fn create_handshake_request(
    uri: &Uri,
    credentials: Option<Credentials>,
) -> Result<HttpRequest<()>, WebSocketError> {
    let mut req_builder = HttpRequest::get(uri);
    if let Some(ref credentials) = credentials {
        let headers = req_builder.headers_mut().ok_or(WebSocketError::Handshake)?;
        headers.insert("Authorization", credentials.to_auth_string().parse()?);
    }

    let request = req_builder.body(())?;
    trace!("Built websocket handshake request: {:?}", &request);
    Ok(request)
}

/// An error type collecting what can go wrong with a websocket
#[derive(Debug, Error)]
pub enum WebSocketError {
    #[error("WebSocket Error: {0}")]
    Tungstenite(#[from] tungstenite::Error),
    #[error("WebSocket Invalid Handshake Request Error: {0}")]
    Http(#[from] http::Error),
    #[error("WebSocket Invalid Address Error: {0}")]
    Url(#[from] http::uri::InvalidUri),
    #[error("WebSocket Handshake Header Error")]
    Handshake,
    #[error("WebSocket Error. Unable to parse credentials {0}")]
    InvalidHeader(#[from] http::header::InvalidHeaderValue),
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{SocketAddr, TcpStream};
    use tungstenite::{accept, Message};

    fn spawn_websocket_server<F>(mut handle_ws_stream: F, port: u16)
    where
        F: FnMut(&mut WebSocketTungstenite<TcpStream>) + Send + 'static,
    {
        let tcp_listener =
            std::net::TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], port))).unwrap();

        let _thread = std::thread::Builder::new()
            .name("Websocket Server".to_string())
            .spawn(move || loop {
                match tcp_listener.accept() {
                    Ok((tcp_stream, _address)) => match accept(tcp_stream) {
                        Ok(mut websocket) => handle_ws_stream(&mut websocket),
                        Err(err) => panic!("{}", err),
                    },
                    Err(err) => panic!("{}", err),
                }
            })
            .is_ok();
    }

    fn ping_pong(ws_stream: &mut WebSocketTungstenite<TcpStream>) {
        match ws_stream.read_message() {
            Ok(message) => match message {
                Message::Text(echo) => ws_stream
                    .write_message(Message::Text(echo + " Pong"))
                    .unwrap(),
                _ => panic!("Received other message type."),
            },
            Err(err) => panic!(err),
        }
    }

    #[test]
    fn test_websocket_create_handshake_request_with_credentials() {
        let uri = Uri::from_static("localhost");
        let credentials = Credentials::Basic(String::from("YWJjOjEyMw=="));
        let request = create_handshake_request(&uri, Some(credentials)).unwrap();
        assert_eq!(
            request.headers().get("Authorization").unwrap(),
            "Basic YWJjOjEyMw=="
        );
    }

    #[test]
    fn test_websocket_create_handshake_request_without_credentials() {
        let uri = Uri::from_static("localhost");
        let request = create_handshake_request(&uri, None).unwrap();
        assert_eq!(request.method(), http::method::Method::GET);
        assert_eq!(request.uri(), &uri);
    }

    #[test]
    fn test_websocket_request() {
        spawn_websocket_server(ping_pong, 3001);
        let mut ws_client = WebSocket::new(String::from("ws://localhost:3001"), None).unwrap();
        let response = ws_client.request(String::from("Ping")).unwrap();
        assert_eq!(response, "Ping Pong");
    }
}