use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use base64::{Engine, prelude::BASE64_URL_SAFE_NO_PAD};
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
use rsa::{
BigUint, RsaPublicKey,
pkcs8::{EncodePublicKey, LineEnding},
};
use serde::Deserialize;
use tokio::sync::RwLock;
use crate::{
api::auth::sasl::oauth::{OauthValidator, ValidatorModuleResult},
error::{PgWireError, PgWireResult},
};
/// The subset of an OpenID Connect discovery document that this validator needs.
///
/// Only `jwks_uri` is read; every other member of the document is ignored.
#[derive(Deserialize, Debug)]
struct SimpleOidcDiscovery {
    /// URL of the issuer's JSON Web Key Set.
    jwks_uri: String,
}
/// A JSON Web Key Set as served from the issuer's `jwks_uri`.
#[derive(Deserialize, Debug)]
struct Jwks {
    /// All keys published by the issuer; looked up by `kid` during validation.
    keys: Vec<Jwk>,
}
#[derive(Debug, Deserialize)]
struct Jwk {
kid: String,
kty: String,
n: Option<String>,
e: Option<String>,
}
/// The JWT claims this validator inspects.
///
/// Optional claims that are absent from the token deserialize to `None`.
#[derive(Deserialize, Debug)]
struct Claims {
    /// Subject identifier; required.
    sub: String,
    /// Granted scopes as a single whitespace-separated string.
    #[serde(default)]
    scope: Option<String>,
    /// Preferred identity when present.
    #[serde(default)]
    preferred_username: Option<String>,
    /// Fallback identity when `preferred_username` is missing.
    #[serde(default)]
    email: Option<String>,
}
/// An [`OauthValidator`] that verifies JWTs against a single OpenID Connect issuer.
///
/// The issuer's JWKS endpoint is found through OIDC discovery. Decoded signing keys
/// are cached by `kid`.
#[derive(Debug)]
pub struct SimpleOidcValidator {
    /// Issuer base URL; also the value required in the token's `iss` claim.
    issuer: String,
    /// HTTP client used for discovery and JWKS requests.
    client: reqwest::Client,
    /// Maps `kid` to a PEM-encoded RSA public key.
    key_cache: Arc<RwLock<HashMap<String, String>>>,
    /// Discovered JWKS endpoint; `None` means it has to be discovered again.
    jwks_uri: Arc<RwLock<Option<String>>>,
}
impl SimpleOidcValidator {
    /// Creates a validator for `issuer` and resolves its JWKS endpoint up front.
    ///
    /// This performs OIDC discovery against `<issuer>/.well-known/openid-configuration`
    /// and stores the advertised `jwks_uri`. Signing keys are fetched lazily, the first
    /// time a token with an unknown `kid` is seen.
    ///
    /// # Errors
    ///
    /// Returns `PgWireError::OAuthValidationError` if the discovery document cannot be
    /// fetched or parsed.
    pub async fn new(issuer: impl Into<String>) -> Result<Self, PgWireError> {
        let issuer = issuer.into();
        let client = reqwest::Client::new();
        let uri = Self::fetch_jwks_uri(&client, &issuer).await?;
        Ok(Self {
            issuer,
            client,
            key_cache: Arc::new(RwLock::new(HashMap::new())),
            jwks_uri: Arc::new(RwLock::new(Some(uri))),
        })
    }

    /// Downloads the issuer's discovery document and returns its `jwks_uri`.
    ///
    /// # Errors
    ///
    /// Returns `PgWireError::OAuthValidationError` on a transport failure, a non-success
    /// HTTP status, or a body that is not a valid discovery document.
    async fn fetch_jwks_uri(client: &reqwest::Client, issuer: &str) -> Result<String, PgWireError> {
        let url = format!(
            "{}/.well-known/openid-configuration",
            issuer.trim_end_matches('/')
        );
        let discovery: SimpleOidcDiscovery = client
            .get(&url)
            .send()
            .await
            // Turn 4xx/5xx into errors so the caller sees the HTTP status rather than
            // a misleading JSON decode failure on an error page.
            .and_then(reqwest::Response::error_for_status)
            .map_err(|e| PgWireError::OAuthValidationError(format!("Discovery failed: {}", e)))?
            .json()
            .await
            .map_err(|e| PgWireError::OAuthValidationError(format!("Invalid discovery: {}", e)))?;
        Ok(discovery.jwks_uri)
    }

    /// Returns the cached JWKS URI, running discovery again if none is cached.
    async fn get_uri(&self) -> Result<String, PgWireError> {
        // Clone out and release the read guard at the end of this statement.
        // Holding it across the fetch and then calling `write()` on the same tokio
        // RwLock would deadlock, because the writer waits for every reader,
        // including this one.
        let cached = self.jwks_uri.read().await.clone();
        if let Some(uri) = cached {
            return Ok(uri);
        }
        let uri = Self::fetch_jwks_uri(&self.client, &self.issuer).await?;
        *self.jwks_uri.write().await = Some(uri.clone());
        Ok(uri)
    }

    /// Converts an RSA JWK (base64url-encoded `n`/`e`) into a `PUBLIC KEY` PEM string.
    ///
    /// # Errors
    ///
    /// Fails if the key type is not `"RSA"`, a component is missing or is not valid
    /// unpadded base64url, or the components do not form a valid RSA public key.
    fn jwk_to_pem(&self, jwk: &Jwk) -> Result<String, PgWireError> {
        if jwk.kty != "RSA" {
            return Err(PgWireError::OAuthValidationError(format!(
                "only RSA key type is supported. Got: {}",
                jwk.kty
            )));
        }
        let n = jwk
            .n
            .as_ref()
            .ok_or_else(|| PgWireError::OAuthValidationError("modulus is missing".to_string()))?;
        let e = jwk
            .e
            .as_ref()
            .ok_or_else(|| PgWireError::OAuthValidationError("exponent is missing".to_string()))?;
        let n_bytes = BASE64_URL_SAFE_NO_PAD
            .decode(n)
            .map_err(|e| PgWireError::OAuthValidationError(e.to_string()))?;
        let e_bytes = BASE64_URL_SAFE_NO_PAD
            .decode(e)
            .map_err(|e| PgWireError::OAuthValidationError(e.to_string()))?;
        let public_key = RsaPublicKey::new(
            BigUint::from_bytes_be(&n_bytes),
            BigUint::from_bytes_be(&e_bytes),
        )
        .map_err(|e| PgWireError::OAuthValidationError(e.to_string()))?;
        public_key
            .to_public_key_pem(LineEnding::LF)
            .map_err(|e| PgWireError::OAuthValidationError(e.to_string()))
    }

    /// Returns `true` when every whitespace-separated scope in `required` appears in
    /// `granted`.
    ///
    /// An empty `required` string is always satisfied. A missing `granted` value is
    /// treated as "no scopes granted".
    fn check_scopes(granted: Option<&str>, required: &str) -> bool {
        if required.is_empty() {
            return true;
        }
        let granted: Vec<&str> = granted.unwrap_or("").split_whitespace().collect();
        required.split_whitespace().all(|r| granted.contains(&r))
    }

    /// Parses a PEM-encoded RSA public key into a `jsonwebtoken` [`DecodingKey`].
    fn pem_to_decoding_key(pem: &str) -> Result<DecodingKey, PgWireError> {
        DecodingKey::from_rsa_pem(pem.as_bytes())
            .map_err(|err| PgWireError::OAuthValidationError(err.to_string()))
    }

    /// Returns the decoding key for `kid`.
    ///
    /// Keys are served from the cache when possible. Otherwise the issuer's JWKS is
    /// fetched and the matching key is cached.
    ///
    /// NOTE(review): every cache miss triggers a JWKS request, so tokens carrying
    /// arbitrary unknown `kid`s each cause network traffic. Cached keys are also never
    /// evicted, so a key the issuer has removed stays trusted. Consider rate limiting
    /// and a TTL.
    ///
    /// # Errors
    ///
    /// Fails if the JWKS cannot be fetched or parsed, no key has the requested `kid`,
    /// or the key cannot be converted to an RSA decoding key.
    async fn get_pk(&self, kid: &str) -> Result<DecodingKey, PgWireError> {
        {
            // Scoped so the read guard is released before any network I/O or write lock.
            let cache = self.key_cache.read().await;
            if let Some(pem) = cache.get(kid) {
                return Self::pem_to_decoding_key(pem);
            }
        }
        let uri = self.get_uri().await?;
        let jwks: Jwks = self
            .client
            .get(&uri)
            .send()
            .await
            // Surface HTTP failures as such instead of as a JSON decode error.
            .and_then(reqwest::Response::error_for_status)
            .map_err(|err| {
                PgWireError::OAuthValidationError(format!(
                    "failed to fetch jwks from uri: {uri}. Err: {}",
                    err
                ))
            })?
            .json()
            .await
            .map_err(|err| {
                PgWireError::OAuthValidationError(format!("invalid jwks format. Err {}", err))
            })?;
        let jwk: Jwk = jwks
            .keys
            .into_iter()
            .find(|k| k.kid == kid)
            .ok_or_else(|| PgWireError::OAuthValidationError(format!("key not found: {}", kid)))?;
        let pem = self.jwk_to_pem(&jwk)?;
        // Parse before caching so a PEM that cannot be used is never stored.
        let key = Self::pem_to_decoding_key(&pem)?;
        self.key_cache.write().await.insert(kid.to_string(), pem);
        Ok(key)
    }
}
#[async_trait]
impl OauthValidator for SimpleOidcValidator {
    /// Verifies `token` and decides whether it authorizes `username`.
    ///
    /// The token's `kid` header selects the issuer signing key, which is fetched from the
    /// JWKS endpoint on a cache miss. `jsonwebtoken` checks the signature and the `iss`
    /// claim, plus whatever time-based checks `Validation::new` enables by default.
    /// Any RSA-family algorithm (RS256/384/512, PS256/384/512) is accepted.
    ///
    /// The authentication identity is the first present of `preferred_username`,
    /// `email` and `sub`. The result is `authorized` only if `username` equals that
    /// identity or `sub`, AND every scope in `required_scopes` was granted.
    ///
    /// The `_issuer` argument is ignored; the issuer configured at construction is used.
    ///
    /// # Errors
    ///
    /// Returns `PgWireError::OAuthValidationError` if the header is malformed, lacks a
    /// `kid`, the signing key cannot be obtained, or signature/claim validation fails.
    async fn validate(
        &self,
        token: &str,
        username: &str,
        _issuer: &str,
        required_scopes: &str,
    ) -> PgWireResult<ValidatorModuleResult> {
        let header = decode_header(token).map_err(|e| {
            PgWireError::OAuthValidationError(format!("Invalid token header: {}", e))
        })?;
        let kid = header.kid.ok_or_else(|| {
            PgWireError::OAuthValidationError("Missing 'kid' in token".to_string())
        })?;
        let key = self.get_pk(&kid).await?;

        let mut validation = Validation::new(Algorithm::RS256);
        // The decoding key is always RSA. Allowing every RSA-family algorithm supports
        // issuers that sign with something other than RS256. There is no
        // algorithm-confusion risk: HMAC/EC algorithms remain rejected.
        validation.algorithms = vec![
            Algorithm::RS256,
            Algorithm::RS384,
            Algorithm::RS512,
            Algorithm::PS256,
            Algorithm::PS384,
            Algorithm::PS512,
        ];
        validation.set_issuer(&[&self.issuer]);
        // NOTE(review): the audience is not checked, so any token from this issuer (for
        // any client) is accepted. Confirm that this is intended.
        validation.validate_aud = false;

        let claims = decode::<Claims>(token, &key, &validation)
            .map_err(|e| {
                PgWireError::OAuthValidationError(format!("Token validation failed: {}", e))
            })?
            .claims;

        let authn_id = claims
            .preferred_username
            .or(claims.email)
            .unwrap_or_else(|| claims.sub.clone());

        // The user must match either the chosen identity or the raw subject, and hold
        // every required scope.
        let authorized = (username == authn_id || username == claims.sub)
            && Self::check_scopes(claims.scope.as_deref(), required_scopes);

        Ok(ValidatorModuleResult {
            authorized,
            authn_id: Some(authn_id),
            metadata: None,
        })
    }
}