use crate::provider::HTTP_CLIENT;
use crate::rauthy_error::RauthyError;
use crate::{base64_url_no_pad_decode, base64_url_no_pad_decode_buf};
use cached::Cached;
use serde::Deserialize;
use std::borrow::Cow;
use std::sync::OnceLock;
use std::time::Duration;
use tokio::sync::{mpsc, oneshot};
use tracing::{debug, error, info, warn};
/// Sending half of the channel drained by the task spawned in `jwks_handler`.
/// It is set once, by `jwks_handler`; before that, `JwksMsg::send` fails.
static JWKS_TX: OnceLock<mpsc::UnboundedSender<JwksMsg>> = OnceLock::new();

/// Commands understood by the background JWKS handler task.
#[derive(Debug)]
pub(crate) enum JwksMsg {
    /// Look up the public key with the given `kid`. The result is delivered
    /// through the accompanying oneshot sender.
    Get((String, oneshot::Sender<Result<JwkPublicKey, RauthyError>>)),
    /// Re-fetch the JWKS from the currently configured URI.
    Update,
    /// Replace the configured JWKS URI and refresh the keys from it.
    NewJwksUri(String),
}

impl JwksMsg {
    /// Enqueues this message for the JWKS handler task.
    ///
    /// # Errors
    ///
    /// - `RauthyError::Init` if `JWKS_TX` has not been initialized yet.
    /// - `RauthyError::Internal` if the channel rejects the message (its
    ///   receiving side is closed).
    pub(crate) fn send(self) -> Result<(), RauthyError> {
        let Some(sender) = JWKS_TX.get() else {
            return Err(RauthyError::Init("JWKS_TX has not been initialized"));
        };
        match sender.send(self) {
            Ok(()) => Ok(()),
            Err(err) => Err(RauthyError::Internal(Cow::from(err.to_string()))),
        }
    }
}
/// Signature algorithm of a JWK.
///
/// Variant names are deserialized verbatim (no serde renames), so they must
/// match the JWK `alg` strings exactly.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq)]
pub enum JwkKeyPairAlg {
    /// RSA PKCS#1 v1.5 signature over a SHA-256 digest.
    RS256,
    /// RSA PKCS#1 v1.5 signature over a SHA-384 digest.
    RS384,
    /// RSA PKCS#1 v1.5 signature over a SHA-512 digest.
    RS512,
    /// Ed25519 signature; this is the `Default` variant.
    #[default]
    EdDSA,
}
/// Key type (`kty`) of a JWK.
///
/// The variants are spelled exactly like the `kty` strings so the derived
/// `Deserialize` matches them without serde renames. That is why clippy's
/// `upper_case_acronyms` lint is allowed here.
#[derive(Clone, Debug, Deserialize)]
#[allow(clippy::upper_case_acronyms)]
pub(crate) enum JwkKeyPairType {
    /// `kty` value `"RSA"`.
    RSA,
    /// `kty` value `"OKP"`.
    OKP,
}
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)] pub(crate) struct JwkPublicKey {
pub kty: JwkKeyPairType,
pub alg: JwkKeyPairAlg,
pub kid: String,
pub crv: Option<String>, pub n: Option<String>, pub n_bytes: Option<Vec<u8>>, pub e: Option<String>, pub e_bytes: Option<Vec<u8>>, pub x: Option<String>, pub x_bytes: Option<Vec<u8>>, }
impl JwkPublicKey {
#[inline]
pub(crate) async fn get_for_token(token: &str) -> Result<Self, RauthyError> {
let Some((metadata, _)) = token.split_once(".") else {
return Err(RauthyError::InvalidJwt(
"JWT token does not contain any metadata",
));
};
let json = base64_url_no_pad_decode(metadata)?;
let serde_json::Value::Object(meta) = serde_json::from_slice::<serde_json::Value>(&json)?
else {
return Err(RauthyError::InvalidClaims(
"JWT token metadata is no JSON object",
));
};
let Some(kid) = meta.get("kid") else {
return Err(RauthyError::InvalidClaims("No 'kid' in JWT token header"));
};
Self::get_for_kid(kid.as_str().unwrap_or_default()).await
}
#[inline]
pub(crate) async fn get_for_kid(kid: &str) -> Result<Self, RauthyError> {
let (tx, rx) = oneshot::channel();
JwksMsg::Get((kid.to_string(), tx)).send()?;
rx.await
.map_err(|err| RauthyError::Internal(Cow::from(err.to_string())))?
}
#[cfg(feature = "rsa")]
#[inline(always)]
fn e(&self) -> Result<rsa::BigUint, RauthyError> {
match &self.e_bytes {
None => Err(RauthyError::JWK("Missing 'e' in JWK".into())),
Some(bytes) => Ok(rsa::BigUint::from_bytes_be(bytes)),
}
}
#[cfg(feature = "rsa")]
#[inline(always)]
fn n(&self) -> Result<rsa::BigUint, RauthyError> {
match &self.n_bytes {
None => Err(RauthyError::JWK("Missing 'n' in JWK".into())),
Some(bytes) => Ok(rsa::BigUint::from_bytes_be(bytes)),
}
}
#[inline(always)]
fn x(&self) -> Result<&[u8], RauthyError> {
match &self.x_bytes {
None => Err(RauthyError::JWK("Missing 'x' in JWK".into())),
Some(bytes) => Ok(bytes),
}
}
#[inline(always)]
pub fn validate_token_signature(
&self,
token: &str,
buf: &mut Vec<u8>,
) -> Result<(), RauthyError> {
let (message, sig) = token
.rsplit_once('.')
.ok_or(RauthyError::MalformedJwt("Malformed token"))?;
buf.clear();
base64_url_no_pad_decode_buf(sig, buf)?;
match self.alg {
JwkKeyPairAlg::RS256 => {
#[cfg(feature = "rsa")]
{
let hash = hmac_sha256::Hash::hash(message.as_bytes());
let rsa_pk = rsa::RsaPublicKey::new(self.n()?, self.e()?)?;
if rsa_pk
.verify(
rsa::Pkcs1v15Sign::new::<sha2::Sha256>(),
hash.as_slice(),
buf,
)
.is_ok()
{
return Ok(());
}
}
#[cfg(not(feature = "rsa"))]
error!("Cannot validate RSA tokens without the `rsa` feature");
}
JwkKeyPairAlg::RS384 => {
#[cfg(feature = "rsa")]
{
let hash = hmac_sha512::sha384::Hash::hash(message.as_bytes());
let rsa_pk = rsa::RsaPublicKey::new(self.n()?, self.e()?)?;
if rsa_pk
.verify(
rsa::Pkcs1v15Sign::new::<sha2::Sha384>(),
hash.as_slice(),
buf,
)
.is_ok()
{
return Ok(());
}
}
#[cfg(not(feature = "rsa"))]
error!("Cannot validate RSA tokens without the `rsa` feature");
}
JwkKeyPairAlg::RS512 => {
#[cfg(feature = "rsa")]
{
let hash = hmac_sha512::Hash::hash(message.as_bytes());
let rsa_pk = rsa::RsaPublicKey::new(self.n()?, self.e()?)?;
if rsa_pk
.verify(
rsa::Pkcs1v15Sign::new::<sha2::Sha512>(),
hash.as_slice(),
buf,
)
.is_ok()
{
return Ok(());
}
}
#[cfg(not(feature = "rsa"))]
error!("Cannot validate RSA tokens without the `rsa` feature");
}
JwkKeyPairAlg::EdDSA => {
let pubkey = ed25519_compact::PublicKey::from_slice(self.x()?)?;
let signature = ed25519_compact::Signature::from_slice(buf)?;
if pubkey.verify(message, &signature).is_ok() {
return Ok(());
}
}
};
warn!("JWT Token validation error");
Err(RauthyError::InvalidJwt("Invalid JWT Token signature"))
}
}
/// Response body of the JWKS endpoint, i.e. `{"keys": [...]}`.
#[derive(Default, Debug, Deserialize)]
pub(crate) struct JwksCerts {
    /// Every public key the provider publishes.
    pub keys: Vec<JwkPublicKey>,
}
pub(crate) async fn jwks_handler() {
let (tx, mut rx) = mpsc::unbounded_channel();
if JWKS_TX.set(tx).is_err() {
error!("Error initializing JWKS_TX");
}
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(3600));
loop {
if let Err(err) = JwksMsg::Update.send() {
error!("Error Updating JWKS - this should never happen: {:?}", err);
}
interval.tick().await;
}
});
tokio::spawn(async move {
let mut jwks_uri = None;
let mut recently_looked_up = cached::TimedCache::with_lifespan(Duration::from_secs(300));
let mut jwks: Vec<JwkPublicKey> = Vec::with_capacity(4);
let update = |jwks_uri: Option<String>, curr_jwks: Vec<JwkPublicKey>| async {
let uri = if let Some(uri) = jwks_uri {
uri
} else {
debug!("Cannot update JWKS with no configured OIDC provider");
return curr_jwks;
};
info!("Updating JWKS from Rauthy");
let client = HTTP_CLIENT.get().expect("HTTP_CLIENT to be initialized");
match client.get(&uri).send().await {
Ok(res) => {
if !res.status().is_success() {
error!("Error fetching JWKS from {}", uri);
return curr_jwks;
}
let certs = match res.json::<JwksCerts>().await {
Ok(jwks) => jwks,
Err(err) => {
error!("Error deserializing JWKS from {}: {:?}", uri, err);
return curr_jwks;
}
};
certs
.keys
.into_iter()
.filter_map(|mut key| {
if key.alg == JwkKeyPairAlg::EdDSA {
if let Some(x) = &key.x {
match base64_url_no_pad_decode(x) {
Ok(bytes) => {
key.x_bytes = Some(bytes)
}
Err(err) => {
error!("Error pre-decoding given EdDSA 'x' pub key bytes: {}", err);
return None;
}
}
}
} else {
if let Some(e) = &key.e {
match base64_url_no_pad_decode(e) {
Ok(bytes) => {
key.e_bytes = Some(bytes)
}
Err(err) => {
error!("Error pre-decoding given RSA 'e' pub key bytes: {}", err);
return None;
}
}
}
if let Some(n) = &key.n {
match base64_url_no_pad_decode(n) {
Ok(bytes) => {
key.n_bytes = Some(bytes)
}
Err(err) => {
error!("Error pre-decoding given RSA 'e' pub key bytes: {}", err);
return None;
}
}
}
}
Some(key)
})
.collect()
}
Err(err) => {
error!("Error fetching JWKS from Rauthy {}: {:?}", uri, err);
curr_jwks
}
}
};
'main: while let Some(msg) = rx.recv().await {
match msg {
JwksMsg::Get((kid, tx_ack)) => {
for jwk in &jwks {
if jwk.kid == kid {
tx_ack.send(Ok(jwk.clone())).unwrap();
continue 'main;
}
}
if recently_looked_up.cache_get(&kid).is_some() {
tx_ack
.send(Err(RauthyError::InvalidClaims(
"'kid' not found and it has been recently looked up",
)))
.unwrap();
continue;
}
jwks = update(jwks_uri.clone(), jwks).await;
recently_looked_up.cache_set(kid.clone(), ());
for jwk in &jwks {
if jwk.kid == kid {
tx_ack.send(Ok(jwk.clone())).unwrap();
continue 'main;
}
}
tx_ack
.send(Err(RauthyError::InvalidClaims(
"'kid' not found after updating JWKs",
)))
.unwrap();
}
JwksMsg::Update => {
jwks = update(jwks_uri.clone(), jwks).await;
}
JwksMsg::NewJwksUri(uri) => {
info!("Received a new JWKS URI: {}", uri);
jwks_uri = Some(uri);
JwksMsg::Update.send().unwrap();
}
}
}
});
}