use crate::jwt::{self, TokenData};
use bytes::Bytes;
use chrono::Utc;
use reqwest;
use serde_json::{from_str, to_string, Map, Value};
use std::cmp::max;
use std::collections::{HashMap, HashSet};
use std::env;
use std::sync::{Arc, RwLock};
use std::time::{Duration, SystemTime};
use crate::claims::Claims;
use crate::errors::{Result, ResultExt, ValidatorError};
use crate::util::{extract_aud_from_claims, extract_claim};
/// In-memory cache of public keys, keyed by the token's `kid`.
///
/// Each value pairs the `SystemTime` at which the key was stored with the raw
/// key bytes, so an entry's age can be compared against the configured cache
/// duration.
type PublicKeyCache = HashMap<String, (SystemTime, Bytes)>;
/// How long a cached public key stays usable when the builder is not given an
/// explicit `cache_duration` (600 seconds).
pub const DEFAULT_CACHE_DURATION: Duration = Duration::from_secs(600);
/// Largest accepted `exp - iat` span when the builder is not given an explicit
/// `max_lifespan` (3600, in the same unit as the token's timestamp claims).
pub const DEFAULT_MAX_LIFESPAN: i64 = 60 * 60;
/// Collects the settings used to construct a `Validator`.
///
/// Every field holds only standard-library data, so the builder derives
/// `Debug` (for diagnostics) and `Clone` (so one configuration can be used as
/// a template for several validators).
#[derive(Debug, Clone)]
pub struct ValidatorBuilder {
    /// Audience value that incoming tokens must list in their `aud` claim.
    resource_server_audience: Option<String>,
    /// Keyserver base URLs, tried in insertion order when fetching a key.
    keyserver_urls: Option<Vec<String>>,
    /// Tolerance applied to the `nbf`/`exp` time checks; `None` means 0.
    leeway: Option<i64>,
    /// Largest accepted `exp - iat`; `None` means `DEFAULT_MAX_LIFESPAN`.
    max_lifespan: Option<i64>,
    /// Whether a previously seen `jti` causes the token to be rejected.
    validate_jti: bool,
    /// Whether the `kid` must start with `"<iss>/"`.
    validate_kid: bool,
    /// Lifetime of cached public keys; `None` means `DEFAULT_CACHE_DURATION`.
    cache_duration: Option<Duration>,
}
impl ValidatorBuilder {
    /// Creates a builder with a single keyserver and the audience this
    /// resource server answers to.
    ///
    /// Defaults: `kid` validation on, `jti` validation off, and no explicit
    /// leeway, maximum lifespan or cache duration (see [`build`](Self::build)
    /// for the values substituted in that case).
    pub fn new(keyserver_url: String, resource_server_audience: String) -> ValidatorBuilder {
        ValidatorBuilder {
            resource_server_audience: Some(resource_server_audience),
            keyserver_urls: Some(vec![keyserver_url]),
            leeway: None,
            max_lifespan: None,
            validate_kid: true,
            validate_jti: false,
            cache_duration: None,
        }
    }

    /// Sets the tolerance applied to the `nbf` and `exp` time checks.
    pub fn leeway(&mut self, leeway: i64) -> &mut ValidatorBuilder {
        self.leeway = Some(leeway);
        self
    }

    /// Sets the largest accepted `exp - iat` span. Negative values are
    /// clamped to 0 when the validator is built.
    pub fn max_lifespan(&mut self, max_lifespan: i64) -> &mut ValidatorBuilder {
        self.max_lifespan = Some(max_lifespan);
        self
    }

    /// Sets how long a fetched public key may be served from the cache.
    pub fn cache_duration(&mut self, cache_duration: Duration) -> &mut ValidatorBuilder {
        self.cache_duration = Some(cache_duration);
        self
    }

    /// Enables or disables the check that the `kid` starts with `"<iss>/"`.
    pub fn validate_kid(&mut self, validate_kid: bool) -> &mut ValidatorBuilder {
        self.validate_kid = validate_kid;
        self
    }

    /// Enables or disables rejection of tokens whose `jti` was already seen.
    pub fn validate_jti(&mut self, validate_jti: bool) -> &mut ValidatorBuilder {
        self.validate_jti = validate_jti;
        self
    }

    /// Appends another keyserver, tried after the ones already registered.
    pub fn fallback_keyserver(&mut self, keyserver: String) -> &mut ValidatorBuilder {
        // `get_or_insert_with` keeps this panic-free even if the list is absent.
        self.keyserver_urls
            .get_or_insert_with(Vec::new)
            .push(keyserver);
        self
    }

    /// Builds a `Validator` from the current settings.
    ///
    /// The configuration is cloned rather than moved out, so the builder stays
    /// usable: `build` may be called repeatedly and each call yields a
    /// validator with its own empty `jti` set and key cache.
    pub fn build(&mut self) -> Validator {
        // The JWT layer only verifies the RS256 signature; its own time and
        // iss/sub/aud checks are switched off here.
        let jwt_validator = jwt::Validation {
            leeway: 0,
            validate_exp: false,
            validate_nbf: false,
            iss: None,
            sub: None,
            aud: None,
            algorithms: vec![jwt::Algorithm::RS256],
        };
        Validator {
            leeway: self.leeway.unwrap_or(0),
            max_lifespan: max(0, self.max_lifespan.unwrap_or(DEFAULT_MAX_LIFESPAN)),
            jwt_validator,
            keyserver_urls: self.keyserver_urls.clone().unwrap_or_default(),
            resource_server_audience: self
                .resource_server_audience
                .clone()
                .unwrap_or_default(),
            validate_kid: self.validate_kid,
            validate_jti: self.validate_jti,
            jti_seen: Arc::new(RwLock::new(HashSet::new())),
            key_cache: Arc::new(RwLock::new(HashMap::new())),
            key_cache_duration: self.cache_duration.unwrap_or(DEFAULT_CACHE_DURATION),
        }
    }
}
/// Validates tokens: fetches (and caches) the signing key named by the
/// token's `kid`, verifies the signature and then checks the claims.
pub struct Validator {
/// When true, a token whose `jti` is already in `jti_seen` is rejected.
pub validate_jti: bool,
/// When true, the `kid` must start with `"<iss>/"`.
pub validate_kid: bool,
// Settings handed to `jwt::decode` for signature verification.
jwt_validator: jwt::Validation,
// Audience value that must appear in a token's `aud` claim.
resource_server_audience: String,
// Tolerance added/subtracted when comparing `nbf`/`exp` against "now".
leeway: i64,
// Largest accepted `exp - iat`.
max_lifespan: i64,
// Keyserver base URLs, tried in order until one returns a key.
keyserver_urls: Vec<String>,
// `jti` values recorded so far (only consulted when `validate_jti` is set).
jti_seen: Arc<RwLock<HashSet<String>>>,
// Public keys by `kid`, each stored with its insertion time.
key_cache: Arc<RwLock<PublicKeyCache>>,
// Maximum age of a `key_cache` entry before it is discarded and refetched.
key_cache_duration: Duration,
}
impl Validator {
    /// Shorthand for `ValidatorBuilder::new`.
    pub fn builder(keyserver: String, resource_server_audience: String) -> ValidatorBuilder {
        ValidatorBuilder::new(keyserver, resource_server_audience)
    }

    /// Creates a builder from environment variables.
    ///
    /// * `ASAP_KEYSERVER_URL` - required, primary keyserver.
    /// * `ASAP_SERVER_AUDIENCE` - required, audience of this resource server.
    /// * `ASAP_FALLBACK_KEYSERVER_URL` - optional, registered as a fallback
    ///   keyserver only when it is set.
    ///
    /// # Errors
    /// Returns an error when one of the required variables cannot be read.
    pub fn from_env() -> Result<ValidatorBuilder> {
        let get_env_var = |x| {
            env::var(x).map_err(|_| format_err!("Could not find '{:?}' environment variable", x))
        };
        let keyserver_url = get_env_var("ASAP_KEYSERVER_URL")?;
        let resource_server_audience = get_env_var("ASAP_SERVER_AUDIENCE")?;
        let mut vb = ValidatorBuilder::new(keyserver_url, resource_server_audience);
        // A fallback is an extra, so its absence must not fail construction.
        if let Ok(fallback_keyserver_url) = env::var("ASAP_FALLBACK_KEYSERVER_URL") {
            vb.fallback_keyserver(fallback_keyserver_url);
        }
        Ok(vb)
    }

    /// Downloads the key for `kid` from one keyserver.
    ///
    /// The request URL is `server_url` immediately followed by `kid`; no
    /// separator is inserted between them.
    ///
    /// # Errors
    /// Fails when the request itself fails, when the body cannot be read, or
    /// with `ValidatorError::KeyserverError` carrying the HTTP status when the
    /// response is not a success.
    async fn get_key_from_server(&self, server_url: &str, kid: &str) -> Result<Bytes> {
        let response = reqwest::get(&format!("{}{}", server_url, kid)).await?;
        if response.status().is_success() {
            Ok(response.bytes().await?)
        } else {
            Err(ValidatorError::KeyserverError(response.status().to_string()).into())
        }
    }

    /// Returns the public key for `kid`, from the cache when a usable entry
    /// exists and otherwise from the first keyserver that answers.
    ///
    /// An unusable cache entry (for example an expired one) is removed before
    /// the keyservers are queried. This method does not insert into the cache.
    ///
    /// # Errors
    /// `ValidatorError::KeyserverError` when no keyserver returned a key.
    ///
    /// # Panics
    /// Panics if the key cache lock is poisoned.
    async fn get_public_key(&self, kid: &str) -> Result<Bytes> {
        // The guard lives only inside this block: a std `RwLock` guard must not
        // be held across `.await`, otherwise every other user of the cache is
        // blocked for the duration of the network round trip.
        {
            let mut key_cache = self
                .key_cache
                .write()
                .expect("failed to acquire lock on public key cache");
            if key_cache.contains_key(kid) {
                match get_key_from_cache(&mut key_cache, self.key_cache_duration, kid) {
                    Ok(cached_key) => return Ok(cached_key),
                    Err(reason) => {
                        eprintln!(
                            "Error fetching from cache, reason: {}. Trying keyserver...",
                            reason
                        );
                        key_cache.remove(kid);
                    }
                }
            }
        }
        for url in &self.keyserver_urls {
            if let Ok(key) = self.get_key_from_server(url, kid).await {
                return Ok(key);
            }
        }
        Err(
            ValidatorError::KeyserverError("Failed to fetch a key from any keyserver".to_string())
                .into(),
        )
    }

    /// Verifies `token` and returns its decoded data.
    ///
    /// Steps: read the `kid` from the header, obtain the matching public key,
    /// verify the signature, run the claim checks in `validate`, and finally
    /// cache the key (only once the token passed every check).
    ///
    /// `whitelisted_issuers` lists the accepted values for the token's `sub`
    /// claim, or for `iss` when `sub` is absent.
    ///
    /// # Errors
    /// Fails when the header cannot be decoded or has no `kid`, when no key
    /// can be obtained, when signature verification fails, or when any claim
    /// check fails.
    ///
    /// # Panics
    /// Panics if the key cache or jti lock is poisoned.
    pub async fn decode(
        &self,
        token: &str,
        whitelisted_issuers: &[&str],
    ) -> Result<TokenData<Claims>> {
        let header = jwt::decode_header(token).sync()?;
        let kid = header
            .kid
            .clone()
            .ok_or_else(|| ValidatorError::NoKIDFound(header))?;
        let public_key = self.get_public_key(&kid).await?;
        let data = jwt::decode::<Claims>(
            token,
            &jwt::DecodingKey::from_rsa_der(&public_key),
            &self.jwt_validator,
        )
        .sync()?;
        // Round-trip the typed claims through JSON to get a generic map for
        // the claim extraction helpers.
        self.validate(
            &kid,
            &from_str(&to_string(&data.claims)?)?,
            whitelisted_issuers,
        )?;
        self.key_cache
            .write()
            .expect("failed to acquire lock on public key cache")
            .entry(kid)
            .or_insert((SystemTime::now(), public_key));
        Ok(data)
    }

    /// Decodes `token` through `jwt::dangerous_insecure_decode`.
    ///
    /// None of this validator's checks are applied: no key is fetched and
    /// `validate` is not called. Do not use the result for authorization.
    pub fn dangerous_unsafe_decode(&self, token: &str) -> Result<TokenData<Claims>> {
        Ok(jwt::dangerous_insecure_decode::<Claims>(token).sync()?)
    }

    /// Checks the claims of an already signature-verified token.
    ///
    /// Requires `iss`, `exp`, `iat`, `jti` and an audience, then in order:
    /// 1. (if `validate_jti`) rejects a `jti` seen before and records new ones;
    /// 2. (if `validate_kid`) requires `kid` to start with `"<iss>/"`;
    /// 3. requires `nbf` (falling back to `iat`) not to be in the future and
    ///    `exp` not to be in the past, both within `leeway`;
    /// 4. requires `exp - iat` not to exceed `max_lifespan`;
    /// 5. requires the audience to contain this server's audience;
    /// 6. requires `sub` (falling back to `iss`) to be in `whitelisted_issuers`.
    ///
    /// # Errors
    /// The `ValidatorError` variant matching the first failed check, or the
    /// extraction error for a missing required claim.
    ///
    /// # Panics
    /// Panics if the jti lock is poisoned.
    fn validate(
        &self,
        kid: &str,
        claims: &Map<String, Value>,
        whitelisted_issuers: &[&str],
    ) -> Result<()> {
        let now = Utc::now().timestamp();
        let iss = extract_claim::<String>(claims, "iss")?;
        let exp = extract_claim::<i64>(claims, "exp")?;
        let iat = extract_claim::<i64>(claims, "iat")?;
        let jti = extract_claim::<String>(claims, "jti")?;
        let aud = extract_aud_from_claims(claims)?;

        if self.validate_jti {
            // Check and insert under ONE write lock so two concurrent
            // validations of the same jti cannot both pass.
            let mut seen = self
                .jti_seen
                .write()
                .expect("failed to acquire lock on jti set");
            if seen.contains(&jti) {
                return Err(ValidatorError::DuplicateJTI(jti).into());
            }
            seen.insert(jti);
        }

        if self.validate_kid && !kid.starts_with(&format!("{}/", &iss)) {
            return Err(ValidatorError::InvalidKID(kid.to_string(), iss).into());
        }

        // Saturating arithmetic: extreme claim or leeway values must not wrap
        // around (in release builds) and slip past the comparisons.
        let nbf = extract_claim::<i64>(claims, "nbf").unwrap_or(iat);
        if nbf > now.saturating_add(self.leeway) {
            return Err(ValidatorError::ImmatureSignature(nbf, exp).into());
        } else if exp < now.saturating_sub(self.leeway) {
            return Err(ValidatorError::ExpiredSignature(nbf, exp).into());
        }

        if exp.saturating_sub(iat) > self.max_lifespan {
            return Err(ValidatorError::InvalidLifespan.into());
        }

        if !aud.contains(&self.resource_server_audience) {
            return Err(ValidatorError::UnrecognisedAudience(aud).into());
        }

        let sub = extract_claim::<String>(claims, "sub").unwrap_or(iss);
        if !whitelisted_issuers.contains(&sub.as_ref()) {
            let subjects = whitelisted_issuers.iter().map(|&x| x.to_owned()).collect();
            return Err(ValidatorError::UnauthorizedSubject(sub, subjects).into());
        }
        Ok(())
    }
}
/// Looks up `kid` in `key_cache` and returns a copy of its key bytes if the
/// entry is no older than `key_cache_duration`.
///
/// # Errors
/// * `ValidatorError::CacheError` when there is no entry for `kid`.
/// * `ValidatorError::ExpiredCache` when the entry is older than allowed.
/// * The error from `SystemTime::elapsed` when the entry's age cannot be
///   computed.
fn get_key_from_cache(
    key_cache: &mut PublicKeyCache,
    key_cache_duration: Duration,
    kid: &str,
) -> Result<Bytes> {
    match key_cache.get(kid) {
        None => Err(ValidatorError::CacheError.into()),
        Some((stored_at, key_bytes)) => {
            let age = stored_at.elapsed()?;
            if age > key_cache_duration {
                return Err(ValidatorError::ExpiredCache(kid.to_owned()).into());
            }
            Ok(key_bytes.clone())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only when `T` is `Sync`; there is nothing to do at runtime.
    fn assert_sync<T>()
    where
        T: Sync,
    {
    }

    /// A `Validator` must remain shareable between threads.
    #[test]
    fn is_sync() {
        assert_sync::<Validator>();
    }
}