use async_trait::async_trait;
use base64::DecodeError;
use headers::Header;
use jwt_simple::prelude::*;
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use reqwest::header::{HeaderMap, CACHE_CONTROL};
use serde::Deserialize;
use std::collections::HashMap;
use std::time::SystemTime;
use thiserror::Error;
/// A source of RS256 public keys, looked up by key id (`kid`).
///
/// Declared with `async_trait(?Send)`, so the futures returned by `get_key` are not
/// required to be `Send`, while implementors themselves must be `Send`.
#[async_trait(?Send)]
pub trait KeyProvider: Send {
    /// Returns the public key registered under `kid`.
    ///
    /// `now` is the current time as supplied by the caller, so implementations do not
    /// have to read the clock themselves. NOTE(review): `GooglePublicKeyProvider` uses it
    /// to decide whether its cached key set is stale; other implementations may differ.
    ///
    /// Takes `&mut self`, so an implementation may update internal state during lookup.
    async fn get_key(
        &mut self,
        kid: &str,
        now: SystemTime,
    ) -> Result<RS256PublicKey, ProviderError>;
}
/// Errors returned by [`KeyProvider`] implementations.
#[derive(Error, Debug)]
pub enum ProviderError {
    /// The requested `kid` is not present in the key set.
    #[error("Key not found.")]
    KeyNotFound,
    /// Fetching the key set failed. In this module the payload is the `Debug`
    /// rendering of the underlying HTTP client error.
    #[error("Fetch error - {0}.")]
    FetchError(String),
    /// The fetched body could not be parsed into a key set. In this module the payload
    /// is the `Debug` rendering of the underlying parse error.
    #[error("Parse error - {0}.")]
    ParseError(String),
    /// Catch-all error. NOTE(review): not constructed by the code visible here — confirm
    /// whether any caller still relies on it.
    #[error("Unknown error.")]
    UnknownError,
    /// A base64url key component (`n` or `e`) could not be decoded.
    #[error("Decode error - {0}.")]
    DecodingError(DecodeError),
    /// `jwt_simple` rejected the decoded components when building the RSA public key.
    #[error("Key creation error - {0}.")]
    CreateKeyError(jwt_simple::Error),
}
/// Deserialized body of the public-key endpoint: a JSON object with a `keys` array.
///
/// Only `keys` is read; serde ignores any other top-level members by default.
#[derive(Deserialize, Clone)]
pub struct GoogleKeys {
    /// The published keys.
    keys: Vec<GoogleKey>,
}
/// One RSA key entry from the key set.
///
/// Only `kid`, `n` and `e` are read; other members present in the JSON (the tests in
/// this file send `kty`, `use` and `alg`) are ignored by serde.
#[derive(Deserialize, Clone, Debug)]
pub struct GoogleKey {
    /// Key id the key is looked up by.
    kid: String,
    /// RSA modulus, base64url-encoded without padding.
    n: String,
    /// RSA public exponent, base64url-encoded without padding.
    e: String,
}
impl GoogleKey {
    /// Decodes one key component encoded as base64url without padding.
    fn decode_component(encoded: &str) -> Result<Vec<u8>, ProviderError> {
        // `decode` accepts any `AsRef<[u8]>`, so the stored string is borrowed, not cloned.
        URL_SAFE_NO_PAD
            .decode(encoded)
            .map_err(ProviderError::DecodingError)
    }

    /// Returns the raw bytes of the RSA modulus (`n`).
    ///
    /// # Errors
    /// [`ProviderError::DecodingError`] if `n` is not valid unpadded base64url.
    fn n(&self) -> Result<Vec<u8>, ProviderError> {
        Self::decode_component(&self.n)
    }

    /// Returns the raw bytes of the RSA public exponent (`e`).
    ///
    /// # Errors
    /// [`ProviderError::DecodingError`] if `e` is not valid unpadded base64url.
    fn e(&self) -> Result<Vec<u8>, ProviderError> {
        Self::decode_component(&self.e)
    }
}
/// [`KeyProvider`] that downloads a key set over HTTP and caches it, keyed by `kid`,
/// until the expiration time derived from the response's `Cache-Control` header.
pub struct GooglePublicKeyProvider {
    /// URL the key set is fetched from.
    url: String,
    /// Cached keys, indexed by their `kid`.
    keys: HashMap<String, GoogleKey>,
    /// Time after which the cache is considered stale; `None` means no expiry is known
    /// and the next lookup reloads.
    expiration_time: Option<SystemTime>,
}
impl GooglePublicKeyProvider {
pub fn new(public_key_url: &str) -> Self {
Self {
url: public_key_url.to_owned(),
keys: Default::default(),
expiration_time: None,
}
}
pub async fn reload(&mut self, now: SystemTime) -> Result<(), ProviderError> {
let r = reqwest::get(&self.url).await;
let r = r.map_err(|e| ProviderError::FetchError(format!("{:?}", e)))?;
let expiration_time = GooglePublicKeyProvider::parse_expiration_time(r.headers(), now);
let google_keys = r.json::<GoogleKeys>().await;
let google_keys = google_keys.map_err(|e| ProviderError::ParseError(format!("{:?}", e)))?;
self.keys.clear();
for key in google_keys.keys.into_iter() {
self.keys.insert(key.kid.clone(), key);
}
self.expiration_time = expiration_time;
Result::Ok(())
}
fn parse_expiration_time(header_map: &HeaderMap, now: SystemTime) -> Option<SystemTime> {
headers::CacheControl::decode(&mut header_map.get_all(CACHE_CONTROL).iter())
.ok()
.and_then(|h| h.max_age())
.map(|d| now + d)
}
pub fn is_expire(&self, now: SystemTime) -> bool {
self.expiration_time.map_or_else(|| true, |t| now > t)
}
}
#[async_trait(?Send)]
impl KeyProvider for GooglePublicKeyProvider {
    /// Looks up `kid` in the cached key set, reloading it first if it is stale at `now`.
    ///
    /// A missing `kid` does not by itself trigger a reload; only cache expiry does.
    ///
    /// # Errors
    /// * Any error from [`GooglePublicKeyProvider::reload`] when a reload is needed.
    /// * [`ProviderError::KeyNotFound`] if `kid` is not in the (possibly refreshed) key set.
    /// * [`ProviderError::DecodingError`] if the key's `n` or `e` is not valid base64url.
    /// * [`ProviderError::CreateKeyError`] if `jwt_simple` rejects the decoded components.
    async fn get_key(
        &mut self,
        kid: &str,
        now: SystemTime,
    ) -> Result<RS256PublicKey, ProviderError> {
        if self.is_expire(now) {
            self.reload(now).await?;
        }
        // A `HashMap<String, _>` can be queried with `&str` directly; no owned copy needed.
        let key_data = self.keys.get(kid).ok_or(ProviderError::KeyNotFound)?;
        RS256PublicKey::from_components(&key_data.n()?, &key_data.e()?)
            .map_err(ProviderError::CreateKeyError)
    }
}
#[cfg(test)]
mod tests {
    use crate::keys::{GooglePublicKeyProvider, KeyProvider, ProviderError};
    use httpmock::MockServer;
    use std::time::{Duration, SystemTime};

    /// Base64url-encoded RSA modulus of the test key.
    const N: &str = "3g46w4uRYBx8CXFauWh6c5yO4ax_VDu5y8ml_Jd4Gx711155PTdtLeRuwZOhJ6nRy8YvLFPXc_aXtHifnQsi9YuI_vo7LGG2v3CCxh6ndZBjIeFkxErMDg4ELt2DQ0PgJUQUAKCkl2_gkVV9vh3oxahv_BpIgv1kuYlyQQi5JWeF7zAIm0FaZ-LJT27NbsCugcZIDQg9sztTN18L3-P_kYwvAkKY2bGYNU19qLFM1gZkzccFEDZv3LzAz7qbdWkwCoK00TUUH8TNjqmK67bytYzgEgkfF9q9szEQ5TrRL0uFg9LxT3kSTLYqYOVaUIX3uaChwaa-bQvHuNmryu7i9w";
    /// Base64url-encoded RSA public exponent (65537).
    const E: &str = "AQAB";
    /// Key id under which the test key is published.
    const KID: &str = "some-kid";

    /// Builds a key-set response body containing only the test key.
    fn keys_response() -> String {
        format!(
            "{{\"keys\": [{{\"kty\": \"RSA\",\"use\": \"sig\",\"e\": \"{}\",\"n\": \"{}\",\"alg\": \"RS256\",\"kid\": \"{}\"}}]}}",
            E, N, KID
        )
    }

    #[tokio::test]
    async fn should_parse_keys() {
        let server = MockServer::start();
        let _server_mock = server.mock(|when, then| {
            when.method(httpmock::Method::GET).path("/");
            then.status(200)
                .header(
                    "cache-control",
                    "public, max-age=24920, must-revalidate, no-transform",
                )
                .header("Content-Type", "application/json; charset=UTF-8")
                .body(keys_response());
        });
        let mut provider = GooglePublicKeyProvider::new(server.url("/").as_str());
        let now = SystemTime::now();
        assert!(matches!(provider.get_key(KID, now).await, Ok(_)));
        // The cache is still fresh, so an unknown kid is answered from it.
        assert!(matches!(
            provider.get_key("missing-key", now).await,
            Err(ProviderError::KeyNotFound)
        ));
    }

    #[tokio::test]
    async fn should_expire_and_reload() {
        let server = MockServer::start();
        let mut empty_mock = server.mock(|when, then| {
            when.method(httpmock::Method::GET).path("/");
            then.status(200)
                .header(
                    "cache-control",
                    "public, max-age=3, must-revalidate, no-transform",
                )
                .header("Content-Type", "application/json; charset=UTF-8")
                .body("{\"keys\":[]}");
        });
        let mut provider = GooglePublicKeyProvider::new(server.url("/").as_str());
        let start = SystemTime::now();
        assert!(matches!(
            provider.get_key(KID, start).await,
            Err(ProviderError::KeyNotFound)
        ));

        empty_mock.delete();
        let _full_mock = server.mock(|when, then| {
            when.method(httpmock::Method::GET).path("/");
            then.status(200)
                .header(
                    "cache-control",
                    "public, max-age=3, must-revalidate, no-transform",
                )
                .header("Content-Type", "application/json; charset=UTF-8")
                .body(keys_response());
        });

        // The provider only compares against the `now` it is given. Advancing that value
        // simulates time passing without blocking the async runtime with a real sleep.
        // Before max-age elapses, the stale empty key set must still be served.
        assert!(matches!(
            provider.get_key(KID, start + Duration::from_secs(2)).await,
            Err(ProviderError::KeyNotFound)
        ));
        // After max-age elapses, the provider reloads and finds the key.
        assert!(matches!(
            provider.get_key(KID, start + Duration::from_secs(4)).await,
            Ok(_)
        ));
    }
}