extern crate rustc_serialize;
extern crate time;
extern crate openssl;
pub mod error;
use rustc_serialize::base64::{self, ToBase64, FromBase64};
use rustc_serialize::json::{self, ToJson, Json};
use std::collections::BTreeMap;
use std::fs::File;
use std::io::Read;
use std::str;
use openssl::hash::MessageDigest;
use openssl::pkey::PKey;
use openssl::rsa::Rsa;
use openssl::sign::{Signer, Verifier};
use openssl::ec::EcKey;
use error::Error;
/// Claims carried by a token: every claim name maps to a string value.
pub type Payload = BTreeMap<String, String>;
/// Value placed in the `typ` field of every header built by `Header::new`.
const STANDARD_HEADER_TYPE: &str = "JWT";
/// A JWT header: the signing algorithm plus the `typ` value.
///
/// Fields are private; use [`Header::algorithm`] and [`Header::ttype`] to
/// inspect a header (for example the one returned by `decode`).
#[derive(Clone)]
pub struct Header {
    algorithm: Algorithm,
    ttype: String,
}

impl Header {
    /// Creates a header for `alg` whose `typ` is `STANDARD_HEADER_TYPE` ("JWT").
    pub fn new(alg: Algorithm) -> Header {
        Header { algorithm: alg, ttype: STANDARD_HEADER_TYPE.to_string() }
    }

    /// The signing algorithm recorded in this header.
    pub fn algorithm(&self) -> Algorithm {
        self.algorithm
    }

    /// The `typ` value held by this header.
    ///
    /// Headers built with `Header::new` always hold `STANDARD_HEADER_TYPE`.
    pub fn ttype(&self) -> &str {
        &self.ttype
    }
}
/// Signing algorithms supported by this crate, named by their JWA `alg` identifiers.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Algorithm {
    /// HMAC with SHA-256.
    HS256,
    /// HMAC with SHA-384.
    HS384,
    /// HMAC with SHA-512.
    HS512,
    /// RSA signature with SHA-256.
    RS256,
    /// RSA signature with SHA-384.
    RS384,
    /// RSA signature with SHA-512.
    RS512,
    /// ECDSA signature with SHA-256.
    ES256,
    /// ECDSA signature with SHA-384.
    ES384,
    /// ECDSA signature with SHA-512.
    ES512,
}

/// Formats the algorithm as its `alg` header identifier (e.g. `"HS256"`).
///
/// Implementing `Display` rather than `ToString` directly still provides
/// `to_string()` through the standard blanket impl. It also makes the value
/// usable with `format!`/`write!`, including width and alignment flags.
impl std::fmt::Display for Algorithm {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let name = match *self {
            Algorithm::HS256 => "HS256",
            Algorithm::HS384 => "HS384",
            Algorithm::HS512 => "HS512",
            Algorithm::RS256 => "RS256",
            Algorithm::RS384 => "RS384",
            Algorithm::RS512 => "RS512",
            Algorithm::ES256 => "ES256",
            Algorithm::ES384 => "ES384",
            Algorithm::ES512 => "ES512",
        };
        f.pad(name)
    }
}
impl ToJson for Header {
fn to_json(&self) -> json::Json {
let mut map = BTreeMap::new();
map.insert("typ".to_string(), self.ttype.to_json());
map.insert("alg".to_string(), self.algorithm.to_string().to_json());
Json::Object(map)
}
}
/// Builds a signed compact token `header.payload.signature`.
///
/// Only `header.algorithm` is consulted; the header that gets encoded is rebuilt
/// from that algorithm by `get_signing_input`.
///
/// `key` is the HMAC secret for `HS*` algorithms, or the filesystem path of a
/// PEM private key for `RS*` / `ES*` algorithms.
///
/// # Panics
///
/// Panics if the key cannot be read or loaded, or if OpenSSL signing fails.
pub fn encode(header: Header, key: String, payload: Payload) -> String {
    let alg = header.algorithm;
    let signing_input = get_signing_input(payload, &alg);
    let signature = match alg {
        Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => {
            sign_hmac(&signing_input, key, alg)
        }
        Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
            sign_rsa(&signing_input, key, alg)
        }
        Algorithm::ES256 | Algorithm::ES384 | Algorithm::ES512 => {
            sign_es(&signing_input, key, alg)
        }
    };
    [signing_input, signature].join(".")
}
pub fn decode(encoded_token: String, key: String, algorithm: Algorithm) -> Result<(Header, Payload), Error> {
match decode_segments(encoded_token) {
Some((header, payload, signature, signing_input)) => {
if !verify_signature(algorithm, signing_input, &signature, key.to_string()) {
return Err(Error::SignatureInvalid)
}
Ok((header, payload))
},
None => Err(Error::JWTInvalid)
}
}
/// Number of dot-separated segments in a compact token.
fn segments_count() -> usize {
    ["header", "payload", "signature"].len()
}
/// Builds the signing input `base64url(header JSON) + "." + base64url(payload JSON)`.
///
/// The header is always `Header::new(*algorithm)`. Every payload value is
/// emitted as a JSON string.
fn get_signing_input(payload: Payload, algorithm: &Algorithm) -> String {
    let header_json = Header::new(*algorithm).to_json();
    // `base64_url_encode` already returns an owned String; no extra copy needed.
    let encoded_header = base64_url_encode(header_json.to_string().as_bytes());
    // Move each owned value straight into the JSON tree (`to_json()` would clone it).
    let payload_json = Json::Object(
        payload.into_iter().map(|(k, v)| (k, Json::String(v))).collect(),
    );
    let encoded_payload = base64_url_encode(payload_json.to_string().as_bytes());
    format!("{}.{}", encoded_header, encoded_payload)
}
fn sign_hmac(data: &str, key: String, algorithm: Algorithm) -> String {
let stp = match algorithm {
Algorithm::HS256 => MessageDigest::sha256(),
Algorithm::HS384 => MessageDigest::sha384(),
Algorithm::HS512 => MessageDigest::sha512(),
_ => panic!("Invalid hmac algorithm")
};
let key = PKey::hmac(key.as_bytes()).unwrap();
let mut signer = Signer::new(stp, &key).unwrap();
signer.update(data.as_bytes()).unwrap();
let hmac = signer.finish().unwrap();
base64_url_encode(&hmac)
}
/// Signs `data` with the RSA private key stored (PEM) at `private_key_path`.
///
/// Returns the base64url-encoded signature.
///
/// # Panics
///
/// Panics if `algorithm` is not an `RS*` variant, if the key file cannot be
/// read or parsed, or if OpenSSL signing fails.
fn sign_rsa(data: &str, private_key_path: String, algorithm: Algorithm) -> String {
    let stp = match algorithm {
        Algorithm::RS256 => MessageDigest::sha256(),
        Algorithm::RS384 => MessageDigest::sha384(),
        Algorithm::RS512 => MessageDigest::sha512(),
        Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512
        | Algorithm::ES256 | Algorithm::ES384 | Algorithm::ES512 => panic!("Invalid RSA algorithm"),
    };
    let buffer = read_pem(&private_key_path);
    let rsa = Rsa::private_key_from_pem(&buffer).unwrap();
    let key = PKey::from_rsa(rsa).unwrap();
    sign(data, key, stp)
}
/// Signs `data` with the EC private key stored (PEM) at `private_key_path`.
///
/// Returns the base64url-encoded signature bytes produced by OpenSSL's `Signer`.
///
/// NOTE(review): OpenSSL emits ECDSA signatures DER-encoded, whereas RFC 7518
/// §3.4 specifies the fixed-width `R || S` concatenation for ES256/384/512, so
/// these tokens are likely not interoperable with other JWT libraries. Also, the
/// curve of the key is not checked against the algorithm (ES256 ↔ P-256, etc.).
/// Confirm both before relying on ES* interop.
///
/// # Panics
///
/// Panics if `algorithm` is not an `ES*` variant, if the key file cannot be
/// read or parsed, or if OpenSSL signing fails.
fn sign_es(data: &str, private_key_path: String, algorithm: Algorithm) -> String {
    // Validate the algorithm first so a misuse fails before touching the filesystem.
    let stp = match algorithm {
        Algorithm::ES256 => MessageDigest::sha256(),
        Algorithm::ES384 => MessageDigest::sha384(),
        Algorithm::ES512 => MessageDigest::sha512(),
        Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512
        | Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => panic!("Invalid EC algorithm"),
    };
    let raw_key = read_pem(&private_key_path);
    let ec_key = EcKey::private_key_from_pem(&raw_key).expect("could not convert to EC private key");
    let key = PKey::from_ec_key(ec_key).expect("could not convert EC private key");
    sign(data, key, stp)
}
/// Signs `data` with `private_key` using `digest` and base64url-encodes the result.
///
/// # Panics
///
/// Panics if any OpenSSL signing step fails.
fn sign(data: &str, private_key: PKey, digest: MessageDigest) -> String {
    let raw_signature = {
        let mut signer = Signer::new(digest, &private_key).unwrap();
        signer.update(data.as_bytes()).unwrap();
        signer.finish().unwrap()
    };
    base64_url_encode(&raw_signature)
}
/// Reads the whole file at `private_key_path` (despite the name, used for
/// public keys too) and returns its raw bytes.
///
/// # Panics
///
/// Panics if the file cannot be opened or read. The message includes the path
/// and the underlying I/O error, unlike a bare `unwrap()`.
fn read_pem(private_key_path: &str) -> Vec<u8> {
    let mut file = File::open(private_key_path)
        .unwrap_or_else(|e| panic!("could not open key file {}: {}", private_key_path, e));
    let mut buffer: Vec<u8> = Vec::new();
    file.read_to_end(&mut buffer)
        .unwrap_or_else(|e| panic!("could not read key file {}: {}", private_key_path, e));
    buffer
}
/// Splits a compact token into its decoded parts.
///
/// Returns `(header, payload, raw signature bytes, signing input)`. The signing
/// input is the original `header.payload` text, exactly as it must be verified.
/// Returns `None` if the token does not have exactly `segments_count()` parts or
/// if the signature segment is not valid base64.
///
/// Header/payload decoding is delegated to `decode_header_and_payload`, which
/// can still panic on malformed segments.
fn decode_segments(encoded_token: String) -> Option<(Header, Payload, Vec<u8>, String)> {
    let raw_segments: Vec<&str> = encoded_token.split('.').collect();
    if raw_segments.len() != segments_count() {
        return None;
    }
    let header_segment = raw_segments[0];
    let payload_segment = raw_segments[1];
    let crypto_segment = raw_segments[2];
    // The signature segment is attacker-controlled: reject bad base64 rather than panic.
    let signature = match crypto_segment.as_bytes().from_base64() {
        Ok(bytes) => bytes,
        Err(_) => return None,
    };
    let (header, payload) = decode_header_and_payload(header_segment, payload_segment);
    let signing_input = format!("{}.{}", header_segment, payload_segment);
    Some((header, payload, signature, signing_input))
}
fn decode_header_and_payload<'a>(header_segment: &str, payload_segment: &str) -> (Header, Payload) {
fn base64_to_json(input: &str) -> Json {
let bytes = input.as_bytes().from_base64().unwrap();
let s = str::from_utf8(&bytes).unwrap();
Json::from_str(s).unwrap()
};
let header_json = base64_to_json(header_segment);
let header_tree = json_to_tree(header_json);
let alg = header_tree.get("alg").unwrap();
let header = Header::new(parse_algorithm(alg));
let payload_json = base64_to_json(payload_segment);
let payload = json_to_tree(payload_json);
(header, payload)
}
/// Maps a JWA `alg` identifier to an `Algorithm`.
///
/// Every variant that `Algorithm`'s string form produces is accepted. The
/// original omitted `RS384` and `RS512`, so decoding any token using them panicked.
///
/// # Panics
///
/// Panics with "Unknown algorithm" for any other identifier.
fn parse_algorithm(alg: &str) -> Algorithm {
    match alg {
        "HS256" => Algorithm::HS256,
        "HS384" => Algorithm::HS384,
        "HS512" => Algorithm::HS512,
        "RS256" => Algorithm::RS256,
        "RS384" => Algorithm::RS384,
        "RS512" => Algorithm::RS512,
        "ES256" => Algorithm::ES256,
        "ES384" => Algorithm::ES384,
        "ES512" => Algorithm::ES512,
        _ => panic!("Unknown algorithm"),
    }
}
/// Computes the raw (not base64-encoded) HMAC of `data` keyed with `key`.
///
/// Used on the verification side, where the bytes are compared against the
/// decoded signature.
///
/// # Panics
///
/// Panics if `algorithm` is not an `HS*` variant, or if OpenSSL fails.
fn sign_hmac2(data: &str, key: String, algorithm: Algorithm) -> Vec<u8> {
    let digest = match algorithm {
        Algorithm::HS256 => MessageDigest::sha256(),
        Algorithm::HS384 => MessageDigest::sha384(),
        Algorithm::HS512 => MessageDigest::sha512(),
        Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512
        | Algorithm::ES256 | Algorithm::ES384 | Algorithm::ES512 => panic!("Invalid HMAC algorithm"),
    };
    let hmac_key = PKey::hmac(key.as_bytes()).unwrap();
    let mut mac = Signer::new(digest, &hmac_key).unwrap();
    mac.update(data.as_bytes()).unwrap();
    mac.finish().unwrap()
}
/// Checks `signature` over `signing_input` with `algorithm`.
///
/// `public_key` is the HMAC secret for `HS*`, or the path of a PEM public key
/// for `RS*` / `ES*`. HMAC tags are recomputed and compared with
/// `secure_compare`. RSA/EC signatures go through OpenSSL's `Verifier`.
///
/// Returns `false` both for a mismatching signature and when OpenSSL reports
/// an error while verifying. A garbage, attacker-supplied signature therefore
/// cannot panic the verifier.
///
/// # Panics
///
/// Panics if the key file cannot be read or parsed.
fn verify_signature(algorithm: Algorithm, signing_input: String, signature: &[u8], public_key: String) -> bool {
    match algorithm {
        Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => {
            let expected = sign_hmac2(&signing_input, public_key, algorithm);
            secure_compare(signature, &expected)
        }
        Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
            let buffer = read_pem(&public_key);
            let rsa = Rsa::public_key_from_pem(&buffer).unwrap();
            let key = PKey::from_rsa(rsa).unwrap();
            let mut verifier = Verifier::new(get_sha_algorithm(algorithm), &key).unwrap();
            verifier.update(signing_input.as_bytes()).unwrap();
            // A malformed signature may surface as an OpenSSL error rather than a
            // mismatch; either way it is not a valid signature (fail closed).
            verifier.finish(signature).unwrap_or(false)
        }
        Algorithm::ES256 | Algorithm::ES384 | Algorithm::ES512 => {
            let raw_pem = read_pem(&public_key);
            let key = PKey::public_key_from_pem(&raw_pem).expect("could not convert ec key to pkey");
            let mut verifier = Verifier::new(get_sha_algorithm(algorithm), &key).unwrap();
            verifier.update(signing_input.as_bytes()).unwrap();
            // See the RSA branch: verification errors count as "invalid".
            verifier.finish(signature).unwrap_or(false)
        }
    }
}
/// Returns the SHA digest matching an `RS*` or `ES*` algorithm's bit size.
///
/// # Panics
///
/// Panics for `HS*` algorithms, which are handled separately by `sign_hmac2`.
fn get_sha_algorithm(alg: Algorithm) -> MessageDigest {
    match alg {
        Algorithm::ES256 | Algorithm::RS256 => MessageDigest::sha256(),
        Algorithm::ES384 | Algorithm::RS384 => MessageDigest::sha384(),
        Algorithm::ES512 | Algorithm::RS512 => MessageDigest::sha512(),
        Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => panic!("Invalid RSA algorithm"),
    }
}
/// Compares two byte strings for equality without stopping at the first difference.
///
/// Slices of different lengths compare unequal immediately. Otherwise every
/// byte pair is XORed, and the results are ORed together and checked for zero.
fn secure_compare(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let diff = a.iter().zip(b).fold(0_u8, |acc, (x, y)| acc | (x ^ y));
        diff == 0
    } else {
        false
    }
}
fn base64_url_encode(bytes: &[u8]) -> String {
bytes.to_base64(base64::URL_SAFE)
}
/// Flattens a JSON object into a string-to-string map.
///
/// String members are kept as-is. Any other member value (number, boolean,
/// null, array, object) is kept as its compact JSON text, e.g. `"exp": 1516239022`
/// becomes `"1516239022"`. Previously such values hit an `unreachable!()` and
/// panicked, even though they routinely appear in tokens (RFC 7519 numeric claims).
///
/// # Panics
///
/// Panics if `input` is not a JSON object.
fn json_to_tree(input: Json) -> BTreeMap<String, String> {
    match input {
        Json::Object(json_tree) => json_tree
            .into_iter()
            .map(|(k, v)| {
                let value = match v {
                    Json::String(s) => s,
                    other => other.to_string(),
                };
                (k, value)
            })
            .collect(),
        _ => panic!("JWT segment is not a JSON object"),
    }
}
#[cfg(test)]
mod tests {
    use super::{Header, Payload, Algorithm};
    use super::encode;
    use super::decode;
    use super::secure_compare;
    use std::env;

    /// Known-good RS256 token over `{"key1":"val1","key2":"val2"}`, as produced by
    /// `encode` with `test/my_rsa_2048_key.pem`.
    const RS256_TOKEN: &str = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJrZXkxIjoidmFsMSIsImtleTIiOiJ2YWwyIn0.DFusERCFWCL3CkKBaoVKsi1Z3QO2NTTRDTGHPqm7ctzypKHxLslJXfS1p_8_aRX30V2osMAEfGzXO9U0S9J1Z7looIFNf5rWSEcqA3ah7b7YQ2iTn9LOiDWwzVG8rm_HQXkWq-TXqayA-IXeiX9pVPB9bnguKXy3YrLWhP9pxnhl2WmaE9ryn8WTleMiElwDq4xw5JDeopA-qFS-AyEwlc-CE7S_afBd5OQBRbvgtfv1a9soNW3KP_mBg0ucz5eUYg_ON17BG6bwpAwyFuPdDAXphG4hCsa7GlXea0f7DnYD5e5-CA6O7BPW_EvjaGhL_D9LNWHJuDiSDBwZ4-IEIg";

    /// Builds a payload from `(claim, value)` pairs.
    fn payload(pairs: &[(&str, &str)]) -> Payload {
        pairs.iter().map(|&(k, v)| (k.to_string(), v.to_string())).collect()
    }

    #[test]
    fn test_encode_and_decode_jwt_hs256() {
        let p1 = payload(&[("key1", "val1"), ("key2", "val2"), ("key3", "val3")]);
        let secret = "secret123";
        let jwt1 = encode(Header::new(Algorithm::HS256), secret.to_string(), p1.clone());
        let (_, p2) = decode(jwt1, secret.to_string(), Algorithm::HS256)
            .ok()
            .expect("HS256 token should round-trip");
        assert_eq!(p1, p2);
    }

    #[test]
    fn test_decode_valid_jwt_hs256() {
        let p1 = payload(&[("key11", "val1"), ("key22", "val2")]);
        let secret = "secret123";
        let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJrZXkxMSI6InZhbDEiLCJrZXkyMiI6InZhbDIifQ.jrcoVcRsmQqDEzSW9qOhG1HIrzV_n3nMhykNPnGvp9c";
        let (_, p2) = decode(jwt.to_string(), secret.to_string(), Algorithm::HS256)
            .ok()
            .expect("known HS256 token should verify");
        assert_eq!(p1, p2);
    }

    #[test]
    fn test_decode_rejects_wrong_secret_hs256() {
        let jwt = encode(Header::new(Algorithm::HS256), "secret123".to_string(), payload(&[("key1", "val1")]));
        assert!(decode(jwt, "not-the-secret".to_string(), Algorithm::HS256).is_err());
    }

    #[test]
    fn test_decode_rejects_wrong_segment_count() {
        assert!(decode("only.two".to_string(), "secret123".to_string(), Algorithm::HS256).is_err());
    }

    #[test]
    fn test_secure_compare_same_strings() {
        let str1 = "same same".as_bytes();
        let str2 = "same same".as_bytes();
        assert!(secure_compare(str1, str2));
    }

    #[test]
    fn test_fails_when_secure_compare_different_strings() {
        let str1 = "same same".as_bytes();
        let str2 = "same same but different".as_bytes();
        assert!(!secure_compare(str1, str2));
    }

    #[test]
    fn test_encode_and_decode_jwt_hs384() {
        let p1 = payload(&[("key1", "val1"), ("key2", "val2"), ("key3", "val3")]);
        let secret = "secret123";
        let jwt1 = encode(Header::new(Algorithm::HS384), secret.to_string(), p1);
        assert!(decode(jwt1, secret.to_string(), Algorithm::HS384).is_ok());
    }

    #[test]
    fn test_encode_and_decode_jwt_hs512() {
        let p1 = payload(&[("key12", "val1"), ("key22", "val2"), ("key33", "val3")]);
        let secret = "secret123456";
        let jwt1 = encode(Header::new(Algorithm::HS512), secret.to_string(), p1);
        assert!(decode(jwt1, secret.to_string(), Algorithm::HS512).is_ok());
    }

    #[test]
    fn test_encode_and_decode_jwt_rs256() {
        let p1 = payload(&[("key12", "val1"), ("key22", "val2"), ("key33", "val3")]);
        let jwt1 = encode(Header::new(Algorithm::RS256), get_rsa_256_private_key_full_path(), p1);
        assert!(decode(jwt1, get_rsa_256_public_key_full_path(), Algorithm::RS256).is_ok());
    }

    #[test]
    fn test_encode_rs256_matches_known_token() {
        let p1 = payload(&[("key1", "val1"), ("key2", "val2")]);
        let jwt2 = encode(Header::new(Algorithm::RS256), get_rsa_256_private_key_full_path(), p1);
        assert_eq!(RS256_TOKEN, jwt2);
    }

    #[test]
    fn test_decode_valid_jwt_rs256_and_check_deeply() {
        let p1 = payload(&[("key1", "val1"), ("key2", "val2")]);
        let h1 = Header::new(Algorithm::RS256);
        let (h2, p2) = decode(RS256_TOKEN.to_string(), get_rsa_256_public_key_full_path(), Algorithm::RS256)
            .ok()
            .expect("known RS256 token should verify");
        assert_eq!(h1.ttype, h2.ttype);
        assert_eq!(h1.algorithm.to_string(), h2.algorithm.to_string());
        for (k, v) in &p1 {
            assert_eq!(Some(v), p2.get(k));
        }
    }

    #[test]
    fn test_encode_and_decode_jwt_es256() {
        // The bundled key is on prime256v1 (P-256), which RFC 7518 pairs with ES256.
        let p1 = payload(&[("key12", "val1"), ("key22", "val2"), ("key33", "val3")]);
        let jwt1 = encode(Header::new(Algorithm::ES256), get_ec_private_key_path(), p1);
        assert!(decode(jwt1, get_ec_public_key_path(), Algorithm::ES256).is_ok());
    }

    /// Absolute path of `file_name` inside the crate's `test/` directory.
    fn test_key_path(file_name: &str) -> String {
        let mut path = env::current_dir().unwrap();
        path.push("test");
        path.push(file_name);
        path.to_str().unwrap().to_string()
    }

    fn get_ec_private_key_path() -> String {
        test_key_path("ec_x9_62_prime256v1.private.key.pem")
    }

    fn get_ec_public_key_path() -> String {
        test_key_path("ec_x9_62_prime256v1.public.key.pem")
    }

    fn get_rsa_256_private_key_full_path() -> String {
        test_key_path("my_rsa_2048_key.pem")
    }

    fn get_rsa_256_public_key_full_path() -> String {
        test_key_path("my_rsa_public_2048_key.pem")
    }
}