twist-jwt 0.3.1

An implementation of RFC7519 JSON Web Token (JWT)
//! Client-side implementation of the twist JWT per-message extension.
use {PMJWT, RSV3, SWE};
use byteorder::{BigEndian, WriteBytesExt};
use slog::Logger;
use std::io;
use twist::client::BaseFrame;
use twist::extension::{Header, PerMessage};
use util;

/// Configuration and state for the client side of the jwt extension.
#[derive(Default)]
pub struct Jwt {
    /// Whether the extension is currently enabled.
    enabled: bool,
    /// Raw token bytes that are prepended to outgoing rsv3 frames.
    token: Vec<u8>,
    /// Optional slog `Logger` used for trace output.
    stdout: Option<Logger>,
    /// Optional slog `Logger` used for error output.
    stderr: Option<Logger>,
}

impl Jwt {
    /// Set the `enabled` flag.
    pub fn set_enabled(&mut self, enabled: bool) -> &mut Jwt {
        self.enabled = enabled;
        self
    }

    /// Set the `token` value.
    // TODO: Remove this when decode is implemented.
    pub fn set_token(&mut self, token: &[u8]) -> &mut Jwt {
        self.token = token.into();
        self
    }

    /// Add a stdout slog `Logger` to this protocol.
    pub fn stdout(&mut self, logger: Logger) -> &mut Jwt {
        let stdout = logger.new(o!("extension" => "jwt", "module" => "client"));
        self.stdout = Some(stdout);
        self
    }

    /// Add a stderr slog `Logger` to this protocol.
    pub fn stderr(&mut self, logger: Logger) -> &mut Jwt {
        let stderr = logger.new(o!("extension" => "jwt", "module" => "client"));
        self.stderr = Some(stderr);
        self
    }
}

impl Header for Jwt {
    /// Enable the extension iff the negotiated extensions `header` mentions `PMJWT`.
    ///
    /// Always returns `Ok(())`; an absent extension simply disables it.
    fn from_header(&mut self, header: &str) -> Result<(), io::Error> {
        try_trace!(self.stdout, "from_header");
        self.enabled = header.contains(PMJWT);
        // Log the actual extension name (previously mislabelled as lz4).
        if self.enabled {
            try_trace!(self.stdout, "{} is enabled", PMJWT);
        } else {
            try_trace!(self.stdout, "{} is disabled", PMJWT);
        }
        Ok(())
    }

    /// Produce the header line requesting this extension, or `None` when disabled.
    ///
    /// The line is `SWE` immediately followed by `PMJWT`.
    fn into_header(&mut self) -> io::Result<Option<String>> {
        try_trace!(self.stdout, "into_header");
        if self.enabled {
            let mut resp = String::with_capacity(SWE.len() + PMJWT.len());
            resp.push_str(SWE);
            resp.push_str(PMJWT);
            Ok(Some(resp))
        } else {
            Ok(None)
        }
    }
}

impl PerMessage for Jwt {
    /// Report whether the extension is enabled.
    fn enabled(&self) -> bool {
        try_trace!(self.stdout, "enabled");
        self.enabled
    }

    /// Claim the RSV3 bit when the extension is enabled.
    ///
    /// Returns `reserved` unchanged when disabled, `reserved | RSV3` when the bit
    /// is free, and an error when RSV3 is already set in `reserved`.
    fn reserve_rsv(&self, reserved: u8) -> Result<u8, io::Error> {
        try_trace!(self.stdout, "reserve_rsv");
        if !self.enabled {
            return Ok(reserved);
        }
        if reserved & RSV3 != 0 {
            try_error!(self.stderr, "rsv3 bit is already reserved");
            return Err(util::other("rsv3 bit is already reserved"));
        }
        Ok(reserved | RSV3)
    }

    /// Inspect an incoming frame for a server-supplied token.
    ///
    /// Currently a no-op that always returns `Ok(())`.
    fn decode(&self, frame: &mut BaseFrame) -> Result<(), io::Error> {
        try_trace!(self.stdout, "decode");
        // If the server sends an rsv3 frame, it carries the token to use.
        if frame.rsv3() {
            // TODO: Parse the token from the application data. It is expected to
            // be prefixed with the same variable-width length that
            // `write_token_len` emits, then strip it and update the payload length.
        }
        Ok(())
    }

    /// Prepend the length-prefixed token to the application data of rsv3 frames.
    ///
    /// Frames without RSV3 set are left untouched. Errors come only from
    /// writing the length prefix.
    fn encode(&self, frame: &mut BaseFrame) -> Result<(), io::Error> {
        try_trace!(self.stdout, "encode");
        // TODO: Only encode TEXT and BINARY frames.
        if frame.rsv3() {
            // 9 bytes is the largest possible prefix (marker + u64 length).
            let mut ext_app_data = Vec::with_capacity(9 + self.token.len());
            write_token_len(&mut ext_app_data, self.token.len())?;
            ext_app_data.extend_from_slice(&self.token);
            try_trace!(self.stdout,
                       "encoded token: {}",
                       util::as_hex(&ext_app_data));
            ext_app_data.extend(frame.application_data());
            frame.set_payload_length(ext_app_data.len() as u64);
            frame.set_application_data(ext_app_data);
        }
        Ok(())
    }
}

/// Append the variable-width length prefix for a token of `len` bytes to `buf`.
///
/// The smallest encoding that fits is chosen:
/// * `len < 253`: a single byte holding `len`;
/// * `len <= u16::MAX`: marker `253`, then `len` as a big-endian `u16`;
/// * `len <= u32::MAX`: marker `254`, then `len` as a big-endian `u32`;
/// * otherwise: marker `255`, then `len` as a big-endian `u64`.
fn write_token_len(buf: &mut Vec<u8>, len: usize) -> io::Result<()> {
    if len < 253 {
        #[cfg_attr(feature = "cargo-clippy", allow(cast_possible_truncation))]
        let short_len = len as u8;
        buf.push(short_len);
    } else if len <= ::std::u16::MAX as usize {
        #[cfg_attr(feature = "cargo-clippy", allow(cast_possible_truncation))]
        let cast_len = len as u16;
        buf.push(253);
        buf.write_u16::<BigEndian>(cast_len)?;
    } else if len <= ::std::u32::MAX as usize {
        #[cfg_attr(feature = "cargo-clippy", allow(cast_possible_truncation))]
        let cast_len = len as u32;
        buf.push(254);
        buf.write_u32::<BigEndian>(cast_len)?;
    } else {
        #[cfg_attr(feature = "cargo-clippy", allow(cast_possible_truncation))]
        let cast_len = len as u64;
        buf.push(255);
        buf.write_u64::<BigEndian>(cast_len)?;
    }
    Ok(())
}