twists 0.2.1

Example WebSocket Echo Server implemented with twist
// Copyright (c) 2016 twist developers
//
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. All files in the project carrying such notice may not be copied,
// modified, or distributed except according to those terms.
//! websocket echo server using `twist`
#![deny(missing_docs)]
#![cfg_attr(feature = "cargo-clippy", allow(unseparated_literal_suffix))]
extern crate base64;
extern crate chrono;
extern crate clap;
extern crate env_logger;
extern crate futures;
extern crate native_tls;
extern crate serde;
extern crate slog_atomic;
extern crate slog_term;
extern crate tokio_proto;
extern crate tokio_service;
extern crate tokio_tls;
extern crate twist_jwt;
extern crate twist_lz4;

#[macro_use]
extern crate serde_derive;
#[macro_use]
extern crate slog;
#[macro_use]
extern crate twist;

mod claims;
mod service;

use clap::{App, Arg};
use native_tls::{Pkcs12, TlsAcceptor};
use service::Echo;
use slog::{DrainExt, Level, LevelFilter, Logger};
use slog_atomic::{AtomicSwitch, AtomicSwitchCtrl};
use std::env;
use std::fmt;
use std::fs::File;
use std::io::{self, Read, Write};
use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf;
use std::str::FromStr;
use std::thread::{self, JoinHandle};
use tokio_proto::TcpServer;
use tokio_tls::proto;
use twist::server::WebSocketProtocol;
use twist_lz4::ServerLz4;

/// The set of errors that can be generated by `twists`.
///
/// Each variant wraps the underlying library error; `From` conversions let
/// setup code propagate them with `?`.
#[derive(Debug)]
enum TwistError {
    /// Thrown if the given address cannot be parsed.
    AddrParse(std::net::AddrParseError),
    /// Thrown if any IO error occurs.
    Io(std::io::Error),
    /// Thrown on any errors setting up TLS.
    Tls(native_tls::Error),
    /// Thrown if PFX_PWD is not set properly.
    Var(std::env::VarError),
}

impl fmt::Display for TwistError {
    /// Renders the wrapped error's own message, with no extra prefix.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Every variant carries exactly one error value; pick it out as a
        // trait object so the formatting call is written only once.
        let wrapped: &dyn fmt::Display = match *self {
            TwistError::AddrParse(ref inner) => inner,
            TwistError::Io(ref inner) => inner,
            TwistError::Tls(ref inner) => inner,
            TwistError::Var(ref inner) => inner,
        };
        write!(f, "{}", wrapped)
    }
}

impl From<native_tls::Error> for TwistError {
    fn from(err: native_tls::Error) -> TwistError {
        TwistError::Tls(err)
    }
}

impl From<std::env::VarError> for TwistError {
    fn from(err: std::env::VarError) -> TwistError {
        TwistError::Var(err)
    }
}

impl From<std::io::Error> for TwistError {
    fn from(err: std::io::Error) -> TwistError {
        TwistError::Io(err)
    }
}

impl From<std::net::AddrParseError> for TwistError {
    fn from(err: std::net::AddrParseError) -> TwistError {
        TwistError::AddrParse(err)
    }
}

/// `twist` Configuration.
///
/// Listener addresses/ports, TLS settings, and the slog loggers shared by the
/// protocol, compression extension, and echo service. Built via `Default` and
/// then adjusted with the `set_*` methods from command-line options.
///
/// NOTE(review): fields are dropped in declaration order; the logger/drain
/// fields (`stdout_ctrl`, `stdout`, `stderr`) may rely on this for flushing
/// async drains — confirm before reordering.
struct Config {
    /// The IP address of the plaintext (non-TLS) listener.
    address: String,
    /// The port of the plaintext (non-TLS) listener.
    port: u16,
    /// Whether the TLS listener should also be started.
    tls_enabled: bool,
    /// The IP address of the TLS listener.
    tls_address: String,
    /// The port of the TLS listener.
    tls_port: u16,
    /// Path to the PKCS #12 (PFX) file holding the TLS identity.
    pfx_file_path: String,
    /// Control handle for the atomic switch behind `stdout`'s drain; used to
    /// swap in a drain with a different level filter at runtime.
    stdout_ctrl: AtomicSwitchCtrl<io::Error>,
    /// The slog `Logger` that writes to stdout.
    stdout: Logger,
    /// The slog `Logger` that writes to stderr.
    stderr: Logger,
}

impl Config {
    // ---- plaintext listener -------------------------------------------

    /// Replace the IP address of the plaintext (non-TLS) listener.
    pub fn set_address(&mut self, address: String) -> &mut Config {
        self.address = address;
        self
    }

    /// Replace the port of the plaintext (non-TLS) listener.
    pub fn set_port(&mut self, port: u16) -> &mut Config {
        self.port = port;
        self
    }

    // ---- TLS listener -------------------------------------------------

    /// Turn the TLS listener on or off.
    pub fn set_tls_enabled(&mut self, tls_enabled: bool) -> &mut Config {
        self.tls_enabled = tls_enabled;
        self
    }

    /// Replace the IP address of the TLS listener.
    pub fn set_tls_address(&mut self, tls_address: String) -> &mut Config {
        self.tls_address = tls_address;
        self
    }

    /// Replace the port of the TLS listener.
    pub fn set_tls_port(&mut self, tls_port: u16) -> &mut Config {
        self.tls_port = tls_port;
        self
    }

    /// Replace the path of the PFX (PKCS #12) file used for TLS.
    pub fn set_pfx_file_path(&mut self, pfx_file_path: String) -> &mut Config {
        self.pfx_file_path = pfx_file_path;
        self
    }

    // ---- logging ------------------------------------------------------

    /// Swap the stdout logger's drain for a fresh compact terminal drain
    /// filtered at `level`.
    pub fn set_level(&mut self, level: Level) -> &mut Config {
        let term = slog_term::streamer().async().compact().build();
        self.stdout_ctrl.set(LevelFilter::new(term, level));
        self
    }
}

impl Default for Config {
    /// Build the default configuration.
    ///
    /// * plaintext listener on `127.0.0.1:11579`;
    /// * TLS disabled, with its listener defaulting to `127.0.0.1:32276`;
    /// * a stderr logger that only passes `Error` records;
    /// * a stdout logger at `Info` behind an atomic switch, so its level can
    ///   be changed later through `stdout_ctrl`.
    ///
    /// Both loggers carry the same `executable`/`version` metadata, taken
    /// from the Cargo package.
    fn default() -> Config {
        let stderr_term = slog_term::streamer().async().stderr().compact().build();
        let stderr_drain = LevelFilter::new(stderr_term, Level::Error).fuse();
        let stderr = Logger::root(stderr_drain,
                                  o!(
            "executable" => env!("CARGO_PKG_NAME"),
            "version" => env!("CARGO_PKG_VERSION")
        ));

        // The stdout drain sits behind an AtomicSwitch so `set_level` can
        // replace it after the logger has been handed out.
        let stdout_term = slog_term::streamer().async().compact().build();
        let stdout_drain = LevelFilter::new(stdout_term, Level::Info);
        let stdout_ctrl = AtomicSwitch::new(stdout_drain).ctrl();
        let stdout = Logger::root(stdout_ctrl.drain().fuse(),
                                  o!(
            // Use the package name here too so stdout and stderr records
            // identify the executable the same way.
            "executable" => env!("CARGO_PKG_NAME"),
            "version" => env!("CARGO_PKG_VERSION")
        ));

        Config {
            address: String::from("127.0.0.1"),
            port: 11579,
            tls_enabled: false,
            tls_address: String::from("127.0.0.1"),
            tls_port: 32276,
            pfx_file_path: String::from(".env/jasonozias.com.pfx"),
            stdout_ctrl: stdout_ctrl,
            stdout: stdout,
            stderr: stderr,
        }
    }
}

/// Run the unsecure `Echo` service.
fn run_unsecure(config: &Config) -> Result<JoinHandle<()>, TwistError> {
    let addr = IpAddr::from_str(&config.address)?;
    let unenc_socket_addr = SocketAddr::new(addr, config.port);
    let unenc_stdout = config.stdout.clone();
    let mut lz4: ServerLz4 = Default::default();
    lz4.stdout(config.stdout.clone()).stderr(config.stderr.clone());
    let mut ws_proto: WebSocketProtocol = Default::default();
    ws_proto.stdout(config.stdout.clone());
    ws_proto.stderr(config.stderr.clone());
    ws_proto.per_message(lz4);
    let mut server = TcpServer::new(ws_proto, unenc_socket_addr);
    server.threads(4);
    let mut service: Echo = Default::default();
    service.add_stdout(config.stdout.clone()).add_stderr(config.stderr.clone());
    let unenc = thread::spawn(move || {
                                  info!(unenc_stdout,
                                        "Listening for websocket connections on {}",
                                        unenc_socket_addr);
                                  server.serve(move || Ok(service.clone()));
                              });
    Ok(unenc)
}

/// Run the secure `Echo` service.
fn run_secure(config: &Config) -> Result<JoinHandle<()>, TwistError> {
    // Setup the socket address.
    let addr = IpAddr::from_str(&config.tls_address)?;
    let socket_addr = SocketAddr::new(addr, config.tls_port);

    // Read the PFX file.
    let path = PathBuf::from(&config.pfx_file_path);
    let mut file = File::open(path)?;
    let mut pkcs12 = vec![];
    file.read_to_end(&mut pkcs12)?;

    // Get the PFX file password.
    let pfx_pwd = env::var("PFX_PWD")?;

    // Setup the TLS acceptor.
    let pkcs12 = Pkcs12::from_der(&pkcs12, &pfx_pwd)?;
    let builder = TlsAcceptor::builder(pkcs12)?;
    let acceptor = builder.build()?;

    /// Setup the lz4 compression extension.
    let mut lz4: ServerLz4 = Default::default();
    lz4.stdout(config.stdout.clone()).stderr(config.stderr.clone());

    /// Setup the tokio-proto protocol.
    let mut ws_proto: WebSocketProtocol = Default::default();
    ws_proto.stdout(config.stdout.clone());
    ws_proto.stderr(config.stderr.clone());
    ws_proto.per_message(lz4);

    /// Setup the tokio-proto TlsServer
    let tls_proto = proto::Server::new(ws_proto, acceptor);
    let mut server = TcpServer::new(tls_proto, socket_addr);
    server.threads(4);

    // Setup the Echo tokio-service
    let mut service: Echo = Default::default();
    service.add_stdout(config.stdout.clone()).add_stderr(config.stderr.clone());

    // Clone stdout for the thread.
    let stdout = config.stdout.clone();
    let enc = thread::spawn(move || {
                                info!(stdout,
                                      "Listening for secure websocket connections on {}",
                                      socket_addr);
                                server.serve(move || Ok(service.clone()));
                            });

    Ok(enc)
}

/// Write a formatted line (same arguments as `writeln!`) to stderr, panicking
/// if stderr cannot be written.
///
/// The expansion is wrapped in its own block so the macro is a `()`
/// expression: it works both as a statement (`err!("..");`) and in
/// expression position such as a `match` arm, without relying on a trailing
/// semicolon inside a macro expression (the future-incompatible
/// `semicolon_in_expressions_from_macros` lint).
macro_rules! err(
    ($($args:tt)+) => {{
        writeln!(io::stderr(), $($args)+).expect("Unable to write to stderr");
    }}
);

/// Join the given thread handle, panicking with "Failed to join child thread"
/// if that thread panicked.
///
/// Despite the name, nothing is propagated: failure is a panic, not a
/// returned error.
macro_rules! try_join {
    ($h:expr) => {{
        $h.join().expect("Failed to join child thread");
    }};
}

/// Run the twists websocket server
fn main() {
    if let Err(e) = env_logger::init() {
        let stdout = io::stdout();
        let mut handle = stdout.lock();
        writeln!(handle, "unable to initialize env_logger! {}", e)
            .expect("Unable to write to stdout");
    }

    let mut config: Config = Default::default();

    let matches = App::new("twist")
        .version(env!("CARGO_PKG_VERSION"))
        .author("Jason Ozias <jason.g.ozias@gmail.com>")
        .about("RUSTFul Server for ellmak")
        .arg(Arg::with_name("address")
                 .short("a")
                 .long("address")
                 .help("Set the unsecure address to listen on")
                 .takes_value(true))
        .arg(Arg::with_name("port")
                 .short("p")
                 .long("port")
                 .help("Set the unsecure port to listen on")
                 .takes_value(true))
        .arg(Arg::with_name("tls")
                 .short("s")
                 .long("with-tls")
                 .help("Enable tls listener"))
        .arg(Arg::with_name("tls_address")
                 .long("tlsaddr")
                 .help("Set the secure address to listen on")
                 .takes_value(true))
        .arg(Arg::with_name("tls_port")
                 .long("tlsport")
                 .help("Set the secure port to listen on")
                 .takes_value(true))
        .arg(Arg::with_name("verbose")
                 .short("v")
                 .multiple(true)
                 .help("Sets the output verbosity"))
        .arg(Arg::with_name("pfx_file_path")
                 .short("f")
                 .long("pfxpath")
                 .help("Set the path to the pfx file")
                 .takes_value(true))
        .get_matches();

    if let Some(addr_string) = matches.value_of("address") {
        config.set_address(addr_string.into());
    }

    if let Some(port_string) = matches.value_of("port") {
        if let Ok(port_val) = port_string.parse::<u16>() {
            config.set_port(port_val);
        }
    }

    if let Some(addr_string) = matches.value_of("tls_address") {
        config.set_tls_address(addr_string.into());
    }

    if let Some(port_string) = matches.value_of("tls_port") {
        if let Ok(port_val) = port_string.parse::<u16>() {
            config.set_tls_port(port_val);
        }
    }

    if let Some(pfx_file_path_string) = matches.value_of("pfx_file_path") {
        config.set_pfx_file_path(pfx_file_path_string.into());
    }

    config.set_level(match matches.occurrences_of("verbose") {
                         0 => Level::Warning,
                         1 => Level::Info,
                         2 => Level::Debug,
                         3 | _ => Level::Trace,
                     });

    config.set_tls_enabled(matches.is_present("tls"));

    let unenc = match run_unsecure(&config) {
        Ok(unenc) => unenc,
        Err(e) => {
            err!("{}", e);
            std::process::exit(1);
        }
    };

    if config.tls_enabled {
        let enc = match run_secure(&config) {
            Ok(enc) => enc,
            Err(e) => {
                match e {
                    TwistError::Var(_) => err!("PFX_PWD not set"),
                    _ => err!("{}", e),
                }
                std::process::exit(1);
            }
        };
        try_join!(enc);
    }
    try_join!(unenc);
}