use crate::claims::{Aud, Claims, ClaimsBuilder, ExtraClaims};
use jwt;
use lru_time_cache::LruCache;
use std::env;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use crate::errors::{Result, ResultExt};
use crate::util::convert_pem_to_der;
/// Issues signed ASAP tokens (RS256 JWTs), optionally caching them in memory.
///
/// Mutable state lives behind `RwLock`s, so the generator's methods take `&self`.
pub struct Generator {
    // Signing material: set once in the constructor.
    /// DER-encoded RSA private key used to sign every token.
    private_key: Vec<u8>,
    /// Header attached to every token (signing algorithm and `kid`).
    header: jwt::Header,

    // Shared mutable state.
    /// Template for per-token claims; write-locked while claims are built.
    claims_builder: Arc<RwLock<ClaimsBuilder>>,
    /// Signed tokens keyed by their claims' cache key; `None` while caching is disabled.
    cache: Arc<RwLock<Option<LruCache<String, String>>>>,
}
impl Generator {
    /// Creates a generator that signs tokens with `private_key`.
    ///
    /// * `iss` - issuer used to seed the [`ClaimsBuilder`].
    /// * `kid` - key id written to the JWT header's `kid` field.
    /// * `private_key` - RSA private key in DER encoding (it is handed to
    ///   `jwt::EncodingKey::from_rsa_der` when signing).
    ///
    /// Tokens are signed with RS256. Token caching starts disabled; see
    /// [`Generator::enable_token_caching`].
    pub fn new(iss: String, kid: String, private_key: Vec<u8>) -> Generator {
        let mut header = jwt::Header::new(jwt::Algorithm::RS256);
        header.kid = Some(kid);
        let claims_builder = ClaimsBuilder::new(iss);
        Generator {
            header,
            private_key,
            claims_builder: Arc::new(RwLock::new(claims_builder)),
            cache: Arc::new(RwLock::new(None)),
        }
    }

    /// Turns on token caching and replaces any cache that already exists.
    ///
    /// The new cache starts empty, so tokens cached before this call are discarded. It is
    /// created with `LruCache::with_expiry_duration_and_capacity(ttl, max_count)`.
    ///
    /// NOTE(review): `ttl` should presumably be shorter than the token lifespan (see
    /// [`Generator::set_max_lifespan`]). Otherwise a cached token could be handed out after
    /// it has expired. Confirm against callers.
    ///
    /// # Panics
    /// Panics if the cache lock is poisoned.
    pub fn enable_token_caching(&self, max_count: usize, ttl: Duration) {
        let new_cache =
            LruCache::<String, String>::with_expiry_duration_and_capacity(ttl, max_count);
        *self.cache.write().expect("failed to acquire lock on cache") = Some(new_cache);
    }

    /// Turns off token caching and returns the cache that was active, if any.
    ///
    /// Returns `None` when caching was not enabled.
    ///
    /// # Panics
    /// Panics if the cache lock is poisoned.
    pub fn disable_token_caching(&self) -> Option<LruCache<String, String>> {
        // `take` moves the old cache out and leaves `None` behind. This avoids copying
        // every cached token, which a `clone` followed by an assignment would do.
        self.cache
            .write()
            .expect("failed to acquire lock on cache")
            .take()
    }

    /// Sets the lifespan used for claims built from now on.
    ///
    /// The value is forwarded to `ClaimsBuilder::lifespan`, which defines its unit
    /// (presumably seconds; confirm there). This does not clear the token cache.
    ///
    /// # Panics
    /// Panics if the claims-builder lock is poisoned.
    pub fn set_max_lifespan(&self, lifespan: i64) {
        self.claims_builder
            .write()
            .expect("failed to acquire lock on claims builder")
            .lifespan(lifespan);
    }

    /// Builds a generator from environment variables.
    ///
    /// Reads:
    /// * `ASAP_PRIVATE_KEY` - PEM-encoded private key, converted to DER here.
    /// * `ASAP_ISSUER` - token issuer.
    /// * `ASAP_KEY_ID` - key id for the JWT header.
    ///
    /// # Errors
    /// Returns an error in either of these cases:
    /// * Any of the variables cannot be read.
    /// * The PEM key cannot be converted to DER.
    pub fn from_env() -> Result<Generator> {
        // Use `{}` rather than `{:?}`. Debug-formatting a `&str` wraps it in quotes, which
        // produced messages like `Could not find '"ASAP_ISSUER"' environment variable`.
        let get_env_var = |name: &str| {
            env::var(name)
                .map_err(|_| format_err!("Could not find '{}' environment variable", name))
        };
        let pem_key = get_env_var("ASAP_PRIVATE_KEY")?;
        let der_key = convert_pem_to_der(pem_key.as_bytes())?;
        let iss = get_env_var("ASAP_ISSUER")?;
        let kid = get_env_var("ASAP_KEY_ID")?;
        Ok(Generator::new(iss, kid, der_key))
    }

    /// Returns a signed token for `aud`, serving it from the cache when possible.
    ///
    /// Claims are always built first. If caching is enabled and a token is stored under the
    /// claims' cache key, that token is returned. Otherwise a new token is signed and, if
    /// caching is still enabled, stored under that key.
    ///
    /// # Errors
    /// Returns an error if signing fails.
    ///
    /// # Panics
    /// Panics if the claims-builder or cache lock is poisoned.
    pub fn token(&self, aud: Aud, extra_claims: Option<ExtraClaims>) -> Result<String> {
        let claims = self
            .claims_builder
            .write()
            .expect("failed to acquire lock on claims builder")
            .build(aud, extra_claims);

        // Cheap check under the read lock, so the uncached path never takes the write lock.
        let cache_enabled = self
            .cache
            .read()
            .expect("failed to acquire lock on cache")
            .is_some();
        if !cache_enabled {
            return Generator::generate_token(&self.header, &claims, &self.private_key);
        }

        let cache_key = claims.cache_key();
        {
            let mut cache_opt = self.cache.write().expect("failed to acquire lock on cache");
            // Another thread may have disabled caching since the check above. Treat that
            // as a cache miss instead of panicking.
            if let Some(cache) = cache_opt.as_mut() {
                // NOTE(review): if `get` also refreshes the entry's expiry time, a token
                // that is requested often could stay cached longer than `ttl`. A
                // non-refreshing lookup may be safer; confirm lru_time_cache semantics.
                if let Some(cached_token) = cache.get(&cache_key) {
                    return Ok(cached_token.to_string());
                }
            }
        }

        // Sign without holding the cache lock, so other threads can use the cache meanwhile.
        let token = Generator::generate_token(&self.header, &claims, &self.private_key)?;
        {
            let mut cache_opt = self.cache.write().expect("failed to acquire lock on cache");
            // Caching may have been disabled while this thread was signing. In that case
            // skip the insert.
            if let Some(cache) = cache_opt.as_mut() {
                cache.insert(cache_key, token.clone());
            }
        }
        Ok(token)
    }

    /// Signs `claims` with `header` and the DER-encoded RSA `private_key`.
    ///
    /// # Errors
    /// Returns an error if encoding or signing fails.
    fn generate_token(header: &jwt::Header, claims: &Claims, private_key: &[u8]) -> Result<String> {
        let token = jwt::encode(
            header,
            claims,
            &jwt::EncodingKey::from_rsa_der(private_key),
        )
        .sync()?;
        Ok(token)
    }

    /// Returns an HTTP `Authorization` header value of the form `Bearer <token>`.
    ///
    /// The token comes from [`Generator::token`], so caching applies.
    ///
    /// # Errors
    /// Propagates any error from [`Generator::token`].
    pub fn auth_header(&self, aud: Aud, extra_claims: Option<ExtraClaims>) -> Result<String> {
        Ok(format!("Bearer {}", self.token(aud, extra_claims)?))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::claims::Aud;

    const ISSUER: &str = "service01";
    const KEY_ID: &str = "service01/1530402390-public.der";

    /// Static assertion: compiles only when `T` is `Sync`.
    fn check_sync<T: Sync>() {}

    /// Number of entries in the generator's token cache. Panics if caching is disabled.
    fn cached_token_count(generator: &Generator) -> usize {
        let guard = generator.cache.read().unwrap();
        guard.as_ref().unwrap().len()
    }

    /// Requests one token for each of three distinct audiences, in a fixed order.
    fn request_three_distinct_tokens(generator: &Generator) {
        let audiences = vec![
            Aud::One(ISSUER.to_string()),
            Aud::One("foo".to_string()),
            Aud::Many(vec![ISSUER.to_string(), "foo".to_string()]),
        ];
        for aud in audiences {
            generator.token(aud, None).unwrap();
        }
    }

    #[test]
    fn is_sync() {
        check_sync::<Generator>();
    }

    #[test]
    fn it_does_not_cache_more_tokens_than_max_count() {
        let private_key = include_bytes!("../support/keys/service01/1530402390-private.der");
        let generator =
            Generator::new(ISSUER.to_string(), KEY_ID.to_string(), private_key.to_vec());
        let ttl = ::std::time::Duration::from_millis(1000);

        // With room for ten entries, every distinct audience is cached.
        generator.enable_token_caching(10, ttl);
        request_three_distinct_tokens(&generator);
        assert_eq!(cached_token_count(&generator), 3);

        // Re-enabling swaps in a fresh cache whose capacity bound keeps only two entries.
        generator.enable_token_caching(2, ttl);
        request_three_distinct_tokens(&generator);
        assert_eq!(cached_token_count(&generator), 2);
    }
}