use std::collections::HashMap;
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use jsonwebtoken::{
jwk::{self},
DecodingKey,
};
use serde::Deserialize;
use thiserror::Error;
/// A collection of JWT decoding keys indexed by key id (`kid`).
///
/// Instances are built by the `from_oidc_url*` / `from_jwks_url*` constructors,
/// which download a JWK set over HTTP and convert each entry into a
/// `DecodingKey`.
#[derive(Clone)]
// NOTE(review): `dead_code` is presumably allowed because `keys` is only read
// by the tests in this file — confirm before removing.
#[allow(dead_code)]
pub struct Jwks {
// Decoding keys, keyed by the `kid` of the JWK each one was built from.
keys: HashMap<String, Jwk>,
}
/// The part of an OIDC configuration document that this module reads.
///
/// Only `jwks_uri` is declared here; no `deny_unknown_fields` attribute is set,
/// so any other fields present in the document are skipped during
/// deserialization.
#[derive(Deserialize)]
struct OIDCConfig {
// URL from which the JWK set is subsequently fetched.
jwks_uri: String,
}
impl Jwks {
pub async fn from_oidc_url(oidc_url: impl Into<String>) -> Result<Self, JwksError> {
Self::from_oidc_url_with_client(&reqwest::Client::default(), oidc_url).await
}
pub async fn from_oidc_url_with_client(
client: &reqwest::Client,
oidc_url: impl Into<String>,
) -> Result<Self, JwksError> {
let oidc_config = client
.get(oidc_url.into())
.send()
.await?
.json::<OIDCConfig>()
.await?;
let jwks_uri = oidc_config.jwks_uri;
Self::from_jwks_url_with_client(&reqwest::Client::default(), &jwks_uri).await
}
pub async fn from_jwks_url(jwks_url: impl Into<String>) -> Result<Self, JwksError> {
Self::from_jwks_url_with_client(&reqwest::Client::default(), jwks_url.into()).await
}
pub async fn from_jwks_url_with_client(
client: &reqwest::Client,
jwks_url: impl Into<String>,
) -> Result<Self, JwksError> {
let jwks: jwk::JwkSet = client.get(jwks_url.into()).send().await?.json().await?;
let mut keys = HashMap::new();
for jwk in jwks.keys {
let kid = jwk.common.key_id.ok_or(JwkError::MissingKeyId)?;
match &jwk.algorithm {
jwk::AlgorithmParameters::RSA(params) => {
let decoding_key = DecodingKey::from_rsa_components(¶ms.n, ¶ms.e)
.map_err(|err| JwkError::DecodingError {
key_id: kid.clone(),
error: err,
})?;
keys.insert(
kid,
Jwk {
decoding_key: decoding_key,
},
);
}
jwk::AlgorithmParameters::EllipticCurve(params) => {
let decoding_key = DecodingKey::from_ec_components(¶ms.x, ¶ms.y)
.map_err(|err| JwkError::DecodingError {
key_id: kid.clone(),
error: err,
})?;
keys.insert(
kid,
Jwk {
decoding_key: decoding_key,
},
);
}
jwk::AlgorithmParameters::OctetKeyPair(params) => {
let decoding_key =
DecodingKey::from_ed_components(¶ms.x).map_err(|err| {
JwkError::DecodingError {
key_id: kid.clone(),
error: err,
}
})?;
keys.insert(
kid,
Jwk {
decoding_key: decoding_key,
},
);
}
jwk::AlgorithmParameters::OctetKey(params) => {
let base64_decoded = URL_SAFE_NO_PAD.decode(¶ms.value).map_err(|err| {
JwkError::DecodingError {
key_id: kid.clone(),
error: err.into(),
}
})?;
let decoding_key = DecodingKey::from_secret(&base64_decoded);
keys.insert(
kid,
Jwk {
decoding_key: decoding_key,
},
);
}
}
}
Ok(Self { keys })
}
}
/// One entry of a [`Jwks`] key set: the `DecodingKey` built from a single JWK.
#[derive(Clone)]
// NOTE(review): `dead_code` is presumably allowed because `decoding_key` is
// not read anywhere in this file yet — confirm before removing.
#[allow(dead_code)]
struct Jwk {
// Key material produced from the JWK's algorithm parameters.
decoding_key: DecodingKey,
}
/// Errors returned while loading a [`Jwks`] key set.
#[derive(Debug, Error)]
pub enum JwksError {
/// A `reqwest` error raised while requesting or deserializing the OIDC
/// configuration document or the JWK set.
#[error("could not fetch config from authority: {0}")]
FetchError(#[from] reqwest::Error),
/// A key in the fetched set could not be turned into a decoding key; the
/// wrapped [`JwkError`] says why.
#[error("there was an error with an individual key: {0}")]
KeyError(#[from] JwkError),
}
/// Errors concerning a single key of a JWK set.
#[derive(Debug, Error)]
pub enum JwkError {
/// Building the decoding key for the key identified by `key_id` failed;
/// `error` carries the underlying `jsonwebtoken` error.
#[error("could not construct a decoding key for {key_id:?}: {error:?}")]
DecodingError {
key_id: String,
error: jsonwebtoken::errors::Error,
},
/// The key identified by `key_id` does not specify an algorithm.
// NOTE(review): nothing in the code visible here constructs this variant —
// confirm whether it is still needed.
#[error("the key {key_id:?} does not specify an algorithm")]
MissingAlgorithm { key_id: String },
/// The key has no `kid`, so it cannot be indexed in the key set.
#[error("the key is missing the `kid` attribute")]
MissingKeyId,
}
#[cfg(test)]
mod tests {
    use super::*;
    use httpmock::prelude::*;
    use serde_json::json;

    /// Path under which the mock server publishes the JWK set.
    const JWKS_PATH: &str = "/oauth2/v3/certs";
    /// Path under which the mock server publishes the OIDC configuration.
    const OIDC_CONFIG_PATH: &str = "/.well-known/openid-configuration";
    /// Key ids of the two RSA keys in the fixture returned by `sample_jwks`.
    const FIRST_KID: &str = "91413cf4fa0cb92a3c3f5a054509132c47660937";
    const SECOND_KID: &str = "1f40f0a8ef3d880978dc82f25c3ec317c6a5b781";

    /// JWK set fixture holding two RS256 signing keys.
    fn sample_jwks() -> serde_json::Value {
        json!({
            "keys": [
                {
                    "use": "sig",
                    "n": "jb1Ps3fdt0oPYPbQlfZqKkCXrM1qJ5EkfBHSMrPXPzh9QLwa43WCLEdrTcf5vI8cNwbgSxDlCDS2BzHQC0hYPwFkJaD6y6NIIcwdSMcKlQPwk4-sqJbz55_gyUWjifcpXXKbXDdnd2QzSE2YipareOPJaBs3Ybuvf_EePnYoKEhXNeGm_T3546A56uOV2mNEe6e-RaIa76i8kcx_8JP3FjqxZSWRrmGYwZJhTGbeY5pfOS6v_EYpA4Up1kZANWReeC3mgh3O78f5nKEDxwPf99bIQ22fIC2779HbfzO-ybqR_EJ0zv8LlqfT7dMjZs25LH8Jw5wGWjP_9efP8emTOw",
                    "kty": "RSA",
                    "alg": "RS256",
                    "e": "AQAB",
                    "kid": FIRST_KID
                },
                {
                    "n": "tgkwz0K80MycaI2Dz_jHkErJ_IHUPTlx4LR_6wltAHQW_ZwhMzINNH8vbWo8P5F2YLDiIbuslF9y7Q3izsPX3XWQyt6LI8ZT4gmGXQBumYMKx2VtbmTYIysKY8AY7x5UCDO-oaAcBuKQvWc5E31kXm6d6vfaEZjrMc_KT3DsFdN0LcAkB-Q9oYcVl7YEgAN849ROKUs6onf7eukj1PHwDzIBgA9AExJaKen0wITvxQv3H_BRXB7m6hFkLbK5Jo18gl3UxJ7Em29peEwi8Psn7MuI7CwhFNchKhjZM9eaMX27tpDPqR15-I6CA5Zf94rabUGWYph5cFXKWPPr8dskQQ",
                    "alg": "RS256",
                    "use": "sig",
                    "kid": SECOND_KID,
                    "e": "AQAB",
                    "kty": "RSA"
                }
            ]
        })
    }

    /// Makes `server` answer `GET <path>` with `document` as a JSON body.
    fn serve_json(server: &MockServer, path: &str, document: &serde_json::Value) {
        let _ = server.mock(|when, then| {
            when.method(GET).path(path);
            then.status(200)
                .header("content-type", "application/json")
                .body(document.to_string());
        });
    }

    /// Checks that `jwks` holds exactly the two keys of the fixture.
    fn assert_holds_sample_keys(jwks: &Jwks) {
        assert_eq!(jwks.keys.len(), 2);
        let expectations = [
            (FIRST_KID, "key one should be found"),
            (SECOND_KID, "key two should be found"),
        ];
        for (kid, message) in expectations {
            let _ = jwks.keys.get(kid).expect(message);
        }
    }

    #[tokio::test]
    async fn can_fetch_and_parse_jwks_from_jwks_url() {
        let server = MockServer::start();
        serve_json(&server, JWKS_PATH, &sample_jwks());

        let jwks_url = server.url(JWKS_PATH);
        let jwks = Jwks::from_jwks_url(&jwks_url).await.unwrap();

        assert_holds_sample_keys(&jwks);
    }

    #[tokio::test]
    async fn can_fetch_and_parse_jwks_from_oidc_config_url() {
        let oidc_server = MockServer::start();
        let jwks_url = oidc_server.url(JWKS_PATH);

        // Configuration document whose `jwks_uri` points back at the mock server.
        let oidc_config = json!({
            "issuer": "https://accounts.google.com",
            "authorization_endpoint": "https://accounts.google.com/o/oauth2/v2/auth",
            "device_authorization_endpoint": "https://oauth2.googleapis.com/device/code",
            "token_endpoint": "https://oauth2.googleapis.com/token",
            "userinfo_endpoint": "https://openidconnect.googleapis.com/v1/userinfo",
            "revocation_endpoint": "https://oauth2.googleapis.com/revoke",
            "jwks_uri": jwks_url,
            "response_types_supported": [
                "code",
                "token",
                "id_token",
                "code token",
                "code id_token",
                "token id_token",
                "code token id_token",
                "none"
            ],
            "subject_types_supported": [
                "public"
            ],
            "id_token_signing_alg_values_supported": [
                "RS256"
            ],
            "scopes_supported": [
                "openid",
                "email",
                "profile"
            ],
            "token_endpoint_auth_methods_supported": [
                "client_secret_post",
                "client_secret_basic"
            ],
            "claims_supported": [
                "aud",
                "email",
                "email_verified",
                "exp",
                "family_name",
                "given_name",
                "iat",
                "iss",
                "locale",
                "name",
                "picture",
                "sub"
            ],
            "code_challenge_methods_supported": [
                "plain",
                "S256"
            ],
            "grant_types_supported": [
                "authorization_code",
                "refresh_token",
                "urn:ietf:params:oauth:grant-type:device_code",
                "urn:ietf:params:oauth:grant-type:jwt-bearer"
            ]
        });

        serve_json(&oidc_server, OIDC_CONFIG_PATH, &oidc_config);
        serve_json(&oidc_server, JWKS_PATH, &sample_jwks());

        let oidc_config_url = oidc_server.url(OIDC_CONFIG_PATH);
        let jwks = Jwks::from_oidc_url(&oidc_config_url).await.unwrap();

        assert_holds_sample_keys(&jwks);
    }
}