#[macro_use]
extern crate serde_derive;
extern crate base64;
extern crate openssl;
extern crate serde;
extern crate serde_json;
extern crate time;
pub mod errors;
use errors::Error;
use openssl::bn::BigNum;
use openssl::hash::MessageDigest;
use openssl::pkey::{PKey, Public};
use openssl::rsa::Rsa;
use openssl::sign::Verifier;
use std::collections::HashMap;
use std::io::Read;
use time::{Duration, OffsetDateTime};
/// Shorthand for a parsed JSON value from `serde_json`.
type JsonValue = serde_json::value::Value;
/// Shorthand for a JSON object (string keys mapped to JSON values) from `serde_json`.
type JsonObject = serde_json::map::Map<String, JsonValue>;
/// The fields of a token's header segment that are used for verification.
///
/// Deserialized from the decoded first segment of the token; both fields
/// must be present as JSON strings or deserialization fails.
#[derive(Deserialize, Debug)]
struct Header {
    /// Signing algorithm name; must match the `alg` of the key selected by `kid`.
    pub alg: String,
    /// Key ID used to look up the signing key in the context's key map.
    pub kid: String,
}
/// The token claims that this crate extracts from the payload and validates.
struct Payload {
    /// Subject claim; returned to the caller when verification succeeds.
    pub sub: String,
    /// Issuer claim; checked against the accepted Google issuer strings.
    pub iss: String,
    /// Audience claim; must equal the context's client ID.
    pub aud: String,
    /// Expiry claim, converted from an integer count of seconds since the Unix epoch.
    pub exp: OffsetDateTime,
}
/// One entry of a JSON Web Key Set, as deserialized from the key document.
///
/// Every field is required: an entry missing any of them makes
/// deserialization of the whole key set fail.
#[derive(Deserialize)]
struct JsonKey {
    /// Key type; deserialized but not inspected by the code in this file.
    pub kty: String,
    /// Algorithm the key is published for; only `"RS256"` can be used.
    pub alg: String,
    /// Intended key use. The JSON field is `use`, which is a Rust keyword,
    /// hence the rename. Only `"sig"` keys are loaded.
    #[serde(rename = "use")]
    pub use_: String,
    /// Key ID; becomes the lookup key in the loaded key map.
    pub kid: String,
    /// RSA modulus as base64url-encoded big-endian bytes.
    pub n: String,
    /// RSA public exponent as base64url-encoded big-endian bytes.
    pub e: String,
}
/// Top-level JSON Web Key Set document, shaped like `{"keys": [...]}`.
#[derive(Deserialize)]
struct JsonKeys {
    /// The key entries; filtered and converted into usable keys by `Ctx`.
    pub keys: Vec<JsonKey>,
}
/// A signing key that is ready to verify signatures.
struct Key {
    /// Algorithm the key was published for; must match the token header's `alg`.
    alg: String,
    /// OpenSSL public key used by the signature verifier.
    pkey: PKey<Public>,
}
/// Loaded signing keys indexed by key ID (`kid`).
type KeysMap = HashMap<String, Key>;
/// Verification context: the expected audience plus the currently loaded key set.
pub struct Ctx {
    /// Client ID that tokens must carry as their audience (`aud`).
    client_id: String,
    /// Signing keys by key ID; replaced wholesale whenever a key set is loaded.
    keys: KeysMap,
}
impl Ctx {
pub fn new(client_id: String) -> Ctx {
Ctx {
client_id: client_id,
keys: HashMap::new(),
}
}
fn set_keys_from_json_keys(&mut self, jsonkeys: JsonKeys) -> Result<(), Error> {
let mut map: KeysMap = HashMap::new();
for key in jsonkeys.keys {
if key.use_ != "sig" {
continue;
}
match key.alg.as_ref() {
"RS256" => {
let n_decoded = base64_decode_url(&key.n)?;
let n = BigNum::from_slice(&n_decoded)?;
let e_decoded = base64_decode_url(&key.e)?;
let e = BigNum::from_slice(&e_decoded)?;
let rsa = Rsa::from_public_components(n, e)?;
let pkey: PKey<Public> = PKey::from_rsa(rsa)?;
let k = Key {
alg: key.alg.clone(),
pkey: pkey,
};
map.insert(key.kid, k);
}
_ => return Err(Error::UnsupportedAlgorithm),
}
}
if map.len() == 0 {
return Err(Error::NoKeys);
}
self.keys = map;
Ok(())
}
pub fn set_keys_from_reader<R>(&mut self, reader: R) -> Result<(), Error>
where
R: Read,
{
let jsonkeys: JsonKeys = serde_json::from_reader(reader)?;
return self.set_keys_from_json_keys(jsonkeys);
}
pub fn set_keys_from_str<'a>(&mut self, s: &'a str) -> Result<(), Error> {
let jsonkeys: JsonKeys = serde_json::from_str(s)?;
return self.set_keys_from_json_keys(jsonkeys);
}
pub fn google_signin_from_str(&self, token: &str) -> Result<String, Error> {
let arr: Vec<&str> = token.split(".").collect();
if arr.len() != 3 {
return Err(Error::InvalidToken);
}
let hdr_base64 = arr[0];
let payload_base64 = arr[1];
let sig_base64 = arr[2];
let hdr = decode_header(hdr_base64)?;
let payload = decode_payload(payload_base64)?;
let sig = base64_decode_url(sig_base64)?;
let sig_slice: &[u8] = &sig;
verify_payload(self, &payload)?;
verify_signature(self, &hdr, &hdr_base64, &payload_base64, sig_slice)?;
Ok(payload.sub)
}
}
/// Decodes one base64url-encoded token segment (header, payload, signature
/// or JWK number) into raw bytes, using the URL-safe alphabet.
fn base64_decode_url(msg: &str) -> Result<Vec<u8>, base64::DecodeError> {
    let url_safe = base64::URL_SAFE;
    base64::decode_config(msg, url_safe)
}
fn decode_header(base64_hdr: &str) -> Result<Header, Error> {
let hdr = base64_decode_url(base64_hdr)?;
let hdr: Header = serde_json::from_slice(&hdr)?;
Ok(hdr)
}
fn json_get_str<'a>(obj: &'a JsonObject, name: &'static str) -> Result<&'a str, Error> {
let o: Option<&JsonValue> = obj.get(name);
if let Some(v) = o {
if !v.is_string() {
return Err(Error::InvalidTypeField(name));
}
return Ok(v.as_str().unwrap());
} else {
return Err(Error::MissingField(name));
}
}
fn json_get_numeric_date(obj: &JsonObject, name: &'static str) -> Result<OffsetDateTime, Error> {
let o: Option<&JsonValue> = obj.get(name);
if let Some(v) = o {
if !v.is_i64() {
return Err(Error::InvalidTypeField(name));
}
let sec = v.as_i64().unwrap();
return Ok(OffsetDateTime::from_unix_timestamp(sec));
} else {
return Err(Error::MissingField(name));
}
}
fn decode_payload(base64_payload: &str) -> Result<Payload, Error> {
let payload_json = base64_decode_url(base64_payload)?;
let obj: JsonValue = serde_json::from_slice(&payload_json)?;
if !obj.is_object() {
return Err(Error::InvalidTypeField(""));
}
let obj = obj.as_object().unwrap();
let sub = json_get_str(obj, "sub")?;
let iss = json_get_str(obj, "iss")?;
let aud = json_get_str(obj, "aud")?;
let exp = json_get_numeric_date(obj, "exp")?;
Ok(Payload {
sub: sub.to_string(),
iss: iss.to_string(),
aud: aud.to_string(),
exp: exp,
})
}
/// Validates the claims of a decoded token payload against `ctx`.
///
/// Checks, in order:
/// - `aud` equals the context's client ID;
/// - `iss` is one of the two accepted Google issuer strings;
/// - `exp` has not passed, allowing a small clock-skew leeway.
///
/// # Errors
///
/// Returns `Error::InvalidAudience`, `Error::InvalidIssuer` or
/// `Error::Expired` for the first check that fails.
fn verify_payload(ctx: &Ctx, payload: &Payload) -> Result<(), Error> {
    // Tolerance for clock skew between this host and the token issuer. Keep it
    // to a few minutes: a large allowance (this used to be one hour) lets an
    // expired token keep being accepted long after its `exp`.
    const EXP_LEEWAY_SECS: i64 = 300;

    if payload.aud != ctx.client_id {
        return Err(Error::InvalidAudience);
    }
    if payload.iss != "accounts.google.com" && payload.iss != "https://accounts.google.com" {
        return Err(Error::InvalidIssuer);
    }
    let now = OffsetDateTime::now_utc();
    if payload.exp + Duration::seconds(EXP_LEEWAY_SECS) < now {
        return Err(Error::Expired);
    }
    Ok(())
}
fn verify_rs256(txt: &str, key: &Key, sig: &[u8]) -> Result<(), Error> {
let digest = MessageDigest::sha256();
let mut verifier = Verifier::new(digest, &key.pkey)?;
verifier.update(txt.as_bytes())?;
let res = verifier.verify(sig);
match res {
Ok(true) => Ok(()),
Ok(false) => Err(Error::InvalidSignature),
Err(_) => Err(Error::InvalidSignature),
}
}
fn verify_signature(
ctx: &Ctx,
hdr: &Header,
hdr_base64: &str,
payload_base64: &str,
sig: &[u8],
) -> Result<(), Error> {
let txt = format!("{}.{}", hdr_base64, payload_base64);
let key = ctx.keys.get(&hdr.kid);
if key.is_none() {
return Err(Error::NoMatchingSigningKey);
}
let key = key.unwrap();
if key.alg != hdr.alg {
return Err(Error::NoMatchingSigningKey);
}
match key.alg.as_ref() {
"RS256" => verify_rs256(&txt, key, sig),
_ => Err(Error::UnsupportedAlgorithm),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;

    /// Reads a fixture file and strips trailing whitespace (such as a final
    /// `\n` or `\r\n`), so the token and client ID compare exactly. Unlike
    /// popping one character, this never chops a real character off a file
    /// that has no trailing newline.
    fn content_from_file(filename: &str) -> String {
        let buf = std::fs::read_to_string(filename)
            .unwrap_or_else(|e| panic!("cannot read fixture {}: {}", filename, e));
        buf.trim_end().to_string()
    }

    /// End-to-end check against fixture files in the working directory:
    /// `token` (an ID token), `client_id`, and `google_keys.json` (a key set).
    #[test]
    fn from_token_file() {
        let token = content_from_file("token");
        let client_id = content_from_file("client_id");
        let mut ctx = Ctx::new(client_id);
        let keys = File::open("google_keys.json").unwrap();
        assert!(ctx.set_keys_from_reader(keys).is_ok());
        let res = ctx.google_signin_from_str(&token);
        assert!(res.is_ok());
    }
}