//! twisty 0.3.1
//!
//! Example WebSocket echo client implemented with twist.
//!
//! twisty run
use base64::encode;
use blake2::{Blake2b, Digest};
use byteorder::{BigEndian, WriteBytesExt};
use clap::{App, Arg};
use client::{other, Client, Tls};
use config::Config;
use env_logger;
use futures::Future;
use futures::future::{self, Loop};
use rand::{self, Rng};
use sha1::Sha1;
use std::collections::HashMap;
use std::error::Error;
use std::io::{self, Write};
use std::net::{SocketAddr, ToSocketAddrs};
use term;
use tokio_core::reactor::Core;
use tokio_service::Service;
use twist::client::{HandshakeRequestFrame, HandshakeResponseFrame, WebSocketFrame};
use url::Url;

/// The GUID defined in RFC 6455 that is appended to the `Sec-WebSocket-Key` nonce before
/// SHA-1 hashing to derive the `Sec-WebSocket-Accept` value. The server uses it to build its
/// handshake response; this client uses it to recompute and check that value.
static KEY: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

/// Print the twisty banner in bright green.
///
/// The banner contains the crate name and version, the hint that typing `exit` stops the
/// client, and the address about to be connected to. Terminal colours are reset afterwards.
///
/// Returns an error if stdout is not a usable terminal or if any terminal write fails.
fn header(socket_addr: &SocketAddr) -> io::Result<()> {
    let name = env!("CARGO_PKG_NAME");
    let mut out = term::stdout().ok_or_else(|| other("invalid term"))?;

    out.fg(term::color::BRIGHT_GREEN)?;
    writeln!(out, "{} {}", name, env!("CARGO_PKG_VERSION"))?;
    writeln!(out, "Enter 'exit' to stop {}", name)?;
    writeln!(out)?;
    writeln!(out, "Connecting to {}", socket_addr)?;
    writeln!(out)?;
    out.reset()?;

    out.flush()
}

/// Print a server reply in blue, prefixed with `twisty < `.
///
/// `resp` is written verbatim and no trailing newline is added. Terminal colours are reset
/// afterwards. Returns an error if stdout is not a usable terminal or a write fails.
fn response(resp: &str) -> io::Result<()> {
    let mut out = term::stdout().ok_or_else(|| other("invalid term"))?;

    out.fg(term::color::BLUE)?;
    write!(out, "twisty < {}", resp)?;
    out.reset()?;

    out.flush()
}

/// Generate a nonce for a client handshake request.
fn gen_nonce() -> String {
    let mut rng = rand::thread_rng();
    let mut nonce_vec = Vec::with_capacity(2);
    let nonce = rng.gen::<u16>();
    let mut hasher = Blake2b::default();

    if nonce_vec.write_u16::<BigEndian>(nonce).is_ok() {
        hasher.input(&nonce_vec);
        encode(&hasher.result())
    } else {
        nonce_vec.clear();
        nonce_vec.push(rng.gen::<u8>());
        nonce_vec.push(rng.gen::<u8>());
        hasher.input(&nonce_vec);
        encode(&hasher.result())
    }
}

/// Check the server's `Sec-WebSocket-Accept` value against the one expected for `nonce`.
///
/// The expected value is `base64(SHA-1(nonce ++ KEY))`. The whole server response is traced
/// through `config`'s logger before the comparison. Returns `true` only on an exact match.
fn validate(nonce: &str, resp: &HandshakeResponseFrame, config: &Config) -> bool {
    trace!(config.dual(), "server response\n{}", resp);
    let accept_key = resp.ws_accept();

    let expected = {
        let mut input = String::with_capacity(nonce.len() + KEY.len());
        input.push_str(nonce);
        input.push_str(KEY);

        let mut sha = Sha1::new();
        sha.reset();
        sha.update(input.as_bytes());
        encode(&sha.digest().bytes())
    };

    expected == accept_key
}

/// Run an unsecure client.
fn run_unsecure(config: &Config, socket_addr: SocketAddr) -> i32 {
    match Core::new() {
        Ok(mut core) => {
            let handle = core.handle();
            let cfg = config.clone();
            let conn_future = Client::connect(config, socket_addr, &handle);
            let client_future = conn_future.and_then(|client| {
                let nonce = gen_nonce();
                let cl = nonce.clone();
                let mut handshake_frame: WebSocketFrame = Default::default();

                let auth = encode(b"twisty:twisty");
                let val = format!("Twist {}", auth);
                let mut other_headers = HashMap::new();
                other_headers.insert("Authorization".to_string(), val);

                let mut client_handshake: HandshakeRequestFrame = Default::default();
                client_handshake.set_user_agent("twisty 0.1.0".to_string());
                client_handshake.set_path(config.path().to_string());
                client_handshake.set_query(config.query().to_string());
                client_handshake.set_origin(config.origin_header().to_string());
                client_handshake.set_host(config.host_header().to_string());
                client_handshake.set_sec_websocket_key(nonce);
                client_handshake.set_others(other_headers);
                handshake_frame.set_clientside_handshake_request(client_handshake);
                client.call(handshake_frame).and_then(move |resp| if let Some(resp) =
                    resp.clientside_handshake_response() {
                                                          if validate(&cl, resp, &cfg) {
                                                              Ok(client)
                                                          } else {
                                                              Err(other("invalid response"))
                                                          }
                                                      } else {
                                                          Err(other("die"))
                                                      })
            });

            let chain = client_future.and_then(|client| {
                future::loop_fn((), move |_| {
                    client.read_line()
                        .and_then(|resp| -> Result<Loop<(), ()>, io::Error> {
                            let resp_str = if let Some(base) = resp.base() {
                                String::from_utf8_lossy(base.application_data()).into_owned()
                            } else {
                                String::new()
                            };
                            let _ = response(&resp_str);
                            Ok(Loop::Continue(()))
                        })
                        .or_else(|e| -> Result<Loop<(), ()>, io::Error> { Err(e) })
                })
            });
            if let Err(e) = core.run(chain) {
                if e.description() == "exit" {
                    info!(config.dual(), "shutting down twisty");
                } else {
                    error!(config.dual(), "error running core: {}", e);
                }
                1
            } else {
                0
            }
        }
        Err(e) => {
            error!(config.dual(), "{}", e);
            1
        }
    }
}

/// Run a secure client over TLS.
fn run_secure(config: &Config, socket_addr: SocketAddr) -> i32 {
    match Core::new() {
        Ok(mut core) => {
            let handle = core.handle();
            let cfg = config.clone();
            let conn_future = Tls::connect(config, socket_addr, &handle);
            let client_future = conn_future.and_then(|client| {
                let nonce = gen_nonce();
                let cl = nonce.clone();
                let mut handshake_frame: WebSocketFrame = Default::default();

                let auth = encode(b"twisty:twisty");
                let val = format!("Twist {}", auth);
                let mut other_headers = HashMap::new();
                other_headers.insert("Authorization".to_string(), val);

                let mut client_handshake: HandshakeRequestFrame = Default::default();
                client_handshake.set_user_agent("twisty 0.1.0".to_string());
                client_handshake.set_path(config.path().to_string());
                client_handshake.set_query(config.query().to_string());
                client_handshake.set_origin(config.origin_header().to_string());
                client_handshake.set_host(config.host_header().to_string());
                client_handshake.set_sec_websocket_key(nonce);
                client_handshake.set_others(other_headers);
                handshake_frame.set_clientside_handshake_request(client_handshake);
                client.call(handshake_frame).and_then(move |resp| if let Some(resp) =
                    resp.clientside_handshake_response() {
                                                          if validate(&cl, resp, &cfg) {
                                                              Ok(client)
                                                          } else {
                                                              Err(other("invalid response"))
                                                          }
                                                      } else {
                                                          Err(other("die"))
                                                      })
            });

            let chain = client_future.and_then(|client| {
                future::loop_fn((), move |_| {
                    client.read_line()
                        .and_then(|resp| -> Result<Loop<(), ()>, io::Error> {
                            let resp_str = if let Some(base) = resp.base() {
                                String::from_utf8_lossy(base.application_data()).into_owned()
                            } else {
                                String::new()
                            };
                            let _ = response(&resp_str);
                            Ok(Loop::Continue(()))
                        })
                        .or_else(|e| -> Result<Loop<(), ()>, io::Error> { Err(e) })
                })
            });
            if let Err(e) = core.run(chain) {
                if e.description() == "exit" {
                    info!(config.dual(), "shutting down twisty");
                } else {
                    error!(config.dual(), "error running core: {}", e);
                }
                1
            } else {
                0
            }
        }
        Err(e) => {
            error!(config.dual(), "{}", e);
            1
        }
    }
}

/// Configure and run twisty.
///
/// Parses the command line, either from `opt_args` or from the process arguments when `None`.
/// It then builds a `Config` from the required URL plus the optional `--host`, `--origin`
/// and `--pfxpath` overrides, and resolves the first socket address for the URL's host and
/// port. Finally it prints the banner and runs the TLS client when the config has TLS
/// enabled, or the plain client otherwise.
///
/// Returns the process exit code. Configuration and address-resolution failures are reported
/// through `err!` and yield `1`; otherwise the client's own exit code is returned.
pub fn run(opt_args: Option<Vec<&str>>) -> i32 {
    if let Err(e) = env_logger::init() {
        err!("unable to initialize env_logger: {}", e);
    }

    let app = App::new("twisty")
        .version(env!("CARGO_PKG_VERSION"))
        .author("Jason Ozias <jason.g.ozias@gmail.com>")
        .about("tokio-proto twist client")
        .arg(Arg::with_name("URL")
                 .index(1)
                 .required(true)
                 .help("Set the url to connect to (ws:// or wss://)"))
        .arg(Arg::with_name("host")
                 .short("h")
                 .long("host")
                 .help("Set the value to use for the handshake host header")
                 .takes_value(true))
        .arg(Arg::with_name("origin")
                 .short("o")
                 .long("origin")
                 .help("Set the value to use for the handshake origin header")
                 .takes_value(true))
        .arg(Arg::with_name("pfx_file_path")
                 .short("f")
                 .long("pfxpath")
                 .help("Set the path to the pfx file")
                 .takes_value(true))
        .arg(Arg::with_name("verbose")
                 .short("v")
                 .multiple(true)
                 .help("Sets the output verbosity"));

    let matches = match opt_args {
        Some(args) => app.get_matches_from(args),
        None => app.get_matches(),
    };

    // clap enforces `required(true)`, but stay defensive rather than unwrap.
    let url = match matches.value_of("URL") {
        Some(url) => url,
        None => {
            err!("no websocket url specified");
            return 1;
        }
    };

    let parsed_url = match Url::parse(url) {
        Ok(parsed_url) => parsed_url,
        Err(e) => {
            err!("invalid websocket url: {} ({})", url, e);
            return 1;
        }
    };

    let mut config = match Config::new(parsed_url) {
        Ok(config) => config,
        Err(_) => {
            err!("invalid configuration");
            return 1;
        }
    };

    if let Some(host_string) = matches.value_of("host") {
        config.set_host_header(host_string.into());
    }

    if let Some(origin_string) = matches.value_of("origin") {
        config.set_origin_header(origin_string.into());
    }

    if let Some(pfx_file_path_string) = matches.value_of("pfx_file_path") {
        config.set_pfx_file_path(pfx_file_path_string.into());
    }

    let host_port = format!("{}:{}", config.host(), config.port());

    let socket_addr = match host_port.to_socket_addrs() {
        Ok(mut addr_iter) => {
            match addr_iter.next() {
                Some(socket_addr) => socket_addr,
                None => {
                    err!("no valid address found");
                    return 1;
                }
            }
        }
        Err(e) => {
            err!("unable to convert {} to valid socket addresses: {}", host_port, e);
            return 1;
        }
    };

    // The banner is cosmetic; failing to draw it should not prevent connecting.
    if let Err(e) = header(&socket_addr) {
        err!("unable to show header: {}", e);
    }

    if config.tls_enabled() {
        run_secure(&config, socket_addr)
    } else {
        run_unsecure(&config, socket_addr)
    }
}