use crate::Result;
use crate::errors::CredentialsError;
use jsonwebtoken::{Algorithm, DecodingKey, jwk::JwkSet};
use std::{
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::RwLock;
/// JWK set URL used when the caller gives no URL and the token is signed with
/// `ES256` (see `resolve_jwks_url`).
const IAP_JWK_URL: &str = "https://www.gstatic.com/iap/verify/public_key-jwk";
/// JWK set URL used when the caller gives no URL and the token is signed with
/// `RS256` (see `resolve_jwks_url`).
const OAUTH2_JWK_URL: &str = "https://www.googleapis.com/oauth2/v3/certs";
/// Lifetime (one hour) that `JwkClient::new` assigns to cached keys.
const CACHE_TTL: Duration = Duration::from_secs(3600);
/// A decoding key held in the cache, paired with the instant it goes stale.
#[derive(Clone, Debug)]
struct CacheEntry {
    /// Key parsed from the JWK that matched the requested key id.
    key: DecodingKey,
    /// The entry is served from the cache only while this instant is still in
    /// the future.
    expires_at: Instant,
}
/// Fetches JSON Web Keys and caches the resulting decoding keys by key id.
///
/// The cache sits behind an `Arc`, so clones of a client share one cache.
#[derive(Clone, Debug)]
pub struct JwkClient {
    /// Decoding keys indexed by key id (`kid`).
    cache: Arc<RwLock<HashMap<String, CacheEntry>>>,
    /// How long a freshly inserted entry stays valid.
    ttl: Duration,
}
impl JwkClient {
    /// Creates a client with an empty cache whose entries live for `CACHE_TTL`.
    pub fn new() -> Self {
        Self {
            cache: Arc::new(RwLock::new(HashMap::new())),
            ttl: CACHE_TTL,
        }
    }

    /// Returns the decoding key for `key_id`, fetching the JWK set on a cache miss.
    ///
    /// A cached, unexpired key is returned without any network access.
    /// Otherwise the JWK set is downloaded from `jwks_url` (or, when that is
    /// `None`, from the default URL for `alg`), the key whose `kid` equals
    /// `key_id` is parsed, stored in the cache for the configured TTL, and
    /// returned.
    ///
    /// The cache lock is never held while the JWK set is being fetched, so
    /// concurrent callers do not fail or stall on each other. The trade-off is
    /// that callers that miss at the same time may each fetch the set; the
    /// last one to finish overwrites the entry.
    ///
    /// # Errors
    ///
    /// Returns an error when no URL is given and `alg` is neither `RS256` nor
    /// `ES256`, when the JWK set cannot be fetched or parsed, when the set has
    /// no key with a matching `kid`, or when the matching JWK cannot be turned
    /// into a decoding key.
    pub async fn get_or_load_cert(
        &self,
        key_id: String,
        alg: Algorithm,
        jwks_url: Option<String>,
    ) -> Result<DecodingKey> {
        // Fast path: a shared lock is enough for a lookup, and it is released
        // before any network I/O starts.
        if let Some(key) = self.cached_key(&key_id).await {
            return Ok(key);
        }

        let jwks_url = self.resolve_jwks_url(alg, jwks_url)?;
        let jwk_set = self.fetch_certs(jwks_url).await?;
        let jwk = jwk_set.find(&key_id).ok_or_else(|| {
            CredentialsError::from_msg(false, "JWKS did not contain a matching `kid`")
        })?;
        let key = DecodingKey::from_jwk(jwk)
            .map_err(|e| CredentialsError::new(false, "failed to parse JWK", e))?;

        let entry = CacheEntry {
            key: key.clone(),
            expires_at: Instant::now() + self.ttl,
        };
        // The exclusive lock is taken only for the insert itself.
        self.cache.write().await.insert(key_id, entry);
        Ok(key)
    }

    /// Returns a clone of the cached key for `key_id` if its entry has not expired.
    async fn cached_key(&self, key_id: &str) -> Option<DecodingKey> {
        let cache = self.cache.read().await;
        cache
            .get(key_id)
            .filter(|entry| entry.expires_at > Instant::now())
            .map(|entry| entry.key.clone())
    }

    /// Picks the JWK set URL: the caller-supplied one if present, otherwise the
    /// default for `alg` (`RS256` uses the OAuth2 certs URL, `ES256` the IAP URL).
    ///
    /// # Errors
    ///
    /// Returns an error when no URL is supplied and `alg` is any other algorithm.
    fn resolve_jwks_url(&self, alg: Algorithm, jwks_url: Option<String>) -> Result<String> {
        if let Some(jwks_url) = jwks_url {
            return Ok(jwks_url);
        }
        match alg {
            Algorithm::RS256 => Ok(OAUTH2_JWK_URL.to_string()),
            Algorithm::ES256 => Ok(IAP_JWK_URL.to_string()),
            _ => Err(CredentialsError::from_msg(
                false,
                format!(
                    "unexpected signing algorithm: expected either RS256 or ES256: found {alg:?}"
                ),
            )),
        }
    }

    /// Downloads the JWK set at `jwks_url` and deserializes it.
    ///
    /// # Errors
    ///
    /// Returns an error when the request fails, when the response status is
    /// not a success, or when the body cannot be parsed as a JWK set.
    async fn fetch_certs(&self, jwks_url: String) -> Result<JwkSet> {
        let client = reqwest::Client::new();
        let response = client
            .get(jwks_url)
            .send()
            .await
            .map_err(|e| crate::errors::from_http_error(e, "failed to fetch JWK set"))?;
        if !response.status().is_success() {
            let err = crate::errors::from_http_response(response, "failed to fetch JWK set").await;
            return Err(err);
        }
        // NOTE(review): the first argument is presumably the "transient" flag,
        // so only non-decode failures are marked that way; confirm against
        // `CredentialsError::new`.
        let jwk_set: JwkSet = response
            .json()
            .await
            .map_err(|e| CredentialsError::new(!e.is_decode(), "failed to parse JWK set", e))?;
        Ok(jwk_set)
    }
}
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use base64::Engine;
    use httptest::matchers::{all_of, request};
    use httptest::responders::json_encoded;
    use httptest::{Expectation, Server};
    use jsonwebtoken::Algorithm;
    use p256::elliptic_curve::sec1::ToEncodedPoint;
    use rsa::traits::PublicKeyParts;
    use serial_test::parallel;

    type TestResult = anyhow::Result<()>;

    impl JwkClient {
        /// Test-only constructor: a fresh client whose cache entries live for `ttl`.
        fn with_ttl(ttl: Duration) -> Self {
            Self {
                ttl,
                ..Self::new()
            }
        }
    }

    /// Key id carried by every JWK produced by the helpers below.
    const TEST_KEY_ID: &str = "test-key-id";

    /// Starts a mock server that answers `/certs` with `body` exactly `hits`
    /// times, and returns it together with the JWKS URL pointing at it.
    ///
    /// The caller must keep the returned server alive for the whole test.
    fn serve_jwks(body: serde_json::Value, hits: usize) -> (Server, String) {
        let server = Server::run();
        server.expect(
            Expectation::matching(all_of![request::path("/certs"),])
                .times(hits)
                .respond_with(json_encoded(body)),
        );
        let jwks_url = format!("http://{}/certs", server.addr());
        (server, jwks_url)
    }

    /// Builds a JWK set holding the public half of the shared RSA test key.
    pub(crate) fn create_rsa256_jwk_set_response() -> serde_json::Value {
        let b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD;
        let pub_cert = crate::credentials::tests::RSA_PRIVATE_KEY.to_public_key();
        let exponent = b64.encode(pub_cert.e().to_bytes_be());
        let modulus = b64.encode(pub_cert.n().to_bytes_be());
        serde_json::json!({
            "keys": [
                {
                    "e": exponent,
                    "kid": TEST_KEY_ID,
                    "use": "sig",
                    "kty": "RSA",
                    "n": modulus,
                    "alg": "RS256"
                }
            ]
        })
    }

    /// Builds a JWK set holding the public half of the shared ES256 test key.
    pub(crate) fn create_es256_jwk_set_response() -> serde_json::Value {
        let b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD;
        let pk = crate::credentials::tests::ES256_PRIVATE_KEY.public_key();
        // Uncompressed form, so both coordinates are present.
        let encoded_point = pk.to_encoded_point(false);
        let x = b64.encode(encoded_point.x().unwrap());
        let y = b64.encode(encoded_point.y().unwrap());
        serde_json::json!({
            "keys": [
                {
                    "kid": TEST_KEY_ID,
                    "use": "sig",
                    "kty": "EC",
                    "crv": "P-256",
                    "x": x,
                    "y": y,
                    "alg": "ES256"
                }
            ]
        })
    }

    #[tokio::test]
    #[parallel]
    async fn test_get_or_load_cert_success() -> TestResult {
        let (_server, jwks_url) = serve_jwks(create_rsa256_jwk_set_response(), 1);
        let client = JwkClient::new();

        // Two lookups but only one permitted fetch: the second must hit the cache.
        for _ in 0..2 {
            let _key = client
                .get_or_load_cert(
                    TEST_KEY_ID.to_string(),
                    Algorithm::RS256,
                    Some(jwks_url.clone()),
                )
                .await?;
        }
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn test_get_or_load_cert_kid_not_found() -> TestResult {
        let (_server, jwks_url) = serve_jwks(create_rsa256_jwk_set_response(), 1);
        let client = JwkClient::new();

        let result = client
            .get_or_load_cert("unknown-kid".to_string(), Algorithm::RS256, Some(jwks_url))
            .await;

        assert!(result.is_err(), "{result:?}");
        let message = result.unwrap_err().to_string();
        assert!(message.contains("JWKS did not contain a matching `kid`"));
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn test_get_or_load_cert_fetch_error() -> TestResult {
        let server = Server::run();
        server.expect(
            Expectation::matching(all_of![request::path("/certs"),])
                .times(1)
                .respond_with(httptest::responders::status_code(500)),
        );
        let jwks_url = format!("http://{}/certs", server.addr());
        let client = JwkClient::new();

        let result = client
            .get_or_load_cert(TEST_KEY_ID.to_string(), Algorithm::RS256, Some(jwks_url))
            .await;

        assert!(result.is_err(), "{result:?}");
        let message = result.unwrap_err().to_string();
        assert!(message.contains("failed to fetch JWK set"));
        Ok(())
    }

    #[test]
    #[parallel]
    fn test_resolve_jwks_url() -> TestResult {
        let client = JwkClient::new();

        // An explicit URL wins regardless of the algorithm.
        let url = "https://example.com/jwks".to_string();
        let explicit = client
            .resolve_jwks_url(Algorithm::RS256, Some(url.clone()))
            .unwrap();
        assert_eq!(explicit, url);

        // Without one, the default depends on the algorithm.
        let rs256_default = client.resolve_jwks_url(Algorithm::RS256, None).unwrap();
        assert_eq!(rs256_default, OAUTH2_JWK_URL);
        let es256_default = client.resolve_jwks_url(Algorithm::ES256, None).unwrap();
        assert_eq!(es256_default, IAP_JWK_URL);

        // Any other algorithm has no default and is rejected.
        let result = client.resolve_jwks_url(Algorithm::HS256, None);
        assert!(result.is_err(), "{result:?}");
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn test_get_or_load_cert_cache_expiration() -> TestResult {
        let (_server, jwks_url) = serve_jwks(create_rsa256_jwk_set_response(), 2);
        let client = JwkClient::with_ttl(Duration::from_secs(1));

        // First fetch, followed by a cache hit while the entry is still fresh.
        for _ in 0..2 {
            let _key = client
                .get_or_load_cert(
                    TEST_KEY_ID.to_string(),
                    Algorithm::RS256,
                    Some(jwks_url.clone()),
                )
                .await?;
        }

        // Let the entry expire so the next lookup triggers the second fetch.
        tokio::time::sleep(Duration::from_secs(2)).await;
        let _key = client
            .get_or_load_cert(TEST_KEY_ID.to_string(), Algorithm::RS256, Some(jwks_url))
            .await?;
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn test_get_or_load_cert_es256_success() -> TestResult {
        let (_server, jwks_url) = serve_jwks(create_es256_jwk_set_response(), 1);
        let client = JwkClient::new();

        // Two lookups but only one permitted fetch: the second must hit the cache.
        for _ in 0..2 {
            let _key = client
                .get_or_load_cert(
                    TEST_KEY_ID.to_string(),
                    Algorithm::ES256,
                    Some(jwks_url.clone()),
                )
                .await?;
        }
        Ok(())
    }
}