use jsonwebtoken::{Algorithm, DecodingKey, Validation};
use openid::biscuit::jwk::{AlgorithmParameters, KeyType, PublicKeyUse};
use serde::Deserialize;
use tokio::sync::{RwLock, RwLockReadGuard};
use tokio::time::{Duration, Instant};
use url::Url;
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use super::Authenticator;
use crate::authz::Authorizable;
/// Minimum time that must pass after a key-set refresh before an unknown `kid`
/// may trigger another refresh from the issuer.
const ONE_HOUR: Duration = Duration::from_secs(60 * 60);
/// JWS signature algorithms the token validator accepts: the RSA family only.
const SUPPORTED_ALGORITHMS: &[Algorithm] = &[
    Algorithm::RS256,
    Algorithm::RS384,
    Algorithm::RS512,
];
/// Validates bearer JWTs issued by a single OpenID Connect provider.
///
/// The key cache and its refresh timestamp sit behind `Arc`, so every clone
/// of an authenticator shares (and refreshes) the same key set.
#[derive(Clone)]
pub struct OidcAuthenticator {
    /// OAuth client id, used for discovery and returned by `client_id()`.
    client_id: String,
    /// Issuer base URL that discovery is performed against.
    authority_url: Url,
    /// Device authorization endpoint, returned by `auth_url()`.
    device_auth_url: String,
    /// Token endpoint taken from the discovery document.
    token_url: String,
    /// Validation rules (issuer, `nbf`, algorithms) applied to every token.
    validator: Validation,
    /// Decoding keys from the issuer's JWKS, keyed by JWK `kid`.
    key_cache: Arc<RwLock<HashMap<String, DecodingKey>>>,
    /// When the key cache was last successfully refreshed.
    last_refresh: Arc<RwLock<Instant>>,
}
impl OidcAuthenticator {
pub async fn new(
oidc_authority_url: &str,
device_auth_url: &str,
client_id: &str,
) -> anyhow::Result<Self> {
let authority_url: Url = oidc_authority_url
.trim_end_matches('/')
.to_owned()
.parse()?;
let discovery = openid::DiscoveredClient::discover(
client_id.to_owned(),
String::new(),
None,
authority_url.clone(),
)
.await
.map_err(|e| anyhow::anyhow!("Unable to fetch discovery data: {}", e))?;
let mut issuers = HashSet::with_capacity(1);
issuers.insert(authority_url.to_string());
let mut validator = Validation::default();
validator.validate_nbf = true;
validator.iss = Some(issuers);
validator.algorithms = SUPPORTED_ALGORITHMS.to_owned();
let me = OidcAuthenticator {
authority_url: authority_url.clone(),
device_auth_url: device_auth_url.to_owned(),
client_id: client_id.to_owned(),
token_url: discovery.config().token_endpoint.to_string(),
key_cache: Arc::new(RwLock::new(HashMap::new())),
last_refresh: Arc::new(RwLock::new(Instant::now())),
validator,
};
me.update_keys(discovery).await?;
Ok(me)
}
async fn find_key(&self, kid: &str) -> anyhow::Result<RwLockReadGuard<'_, DecodingKey>> {
if let Ok(k) = self.lookup_key(kid).await {
return Ok(k);
}
if self.last_refresh.read().await.elapsed() >= ONE_HOUR {
let discovery = openid::DiscoveredClient::discover(
self.client_id.clone(),
String::new(),
None,
self.authority_url.clone(),
)
.await
.map_err(|e| anyhow::anyhow!("Unable to fetch key set from issuer on update: {}", e))?;
self.update_keys(discovery).await?;
} else {
anyhow::bail!("A key with a key_id of {} was not found", kid)
}
self.lookup_key(kid).await
}
async fn lookup_key(&self, kid: &str) -> anyhow::Result<RwLockReadGuard<'_, DecodingKey>> {
let keys = self.key_cache.read().await;
RwLockReadGuard::try_map(keys, |k| k.get(kid))
.map_err(|_| anyhow::anyhow!("A key with a key_id of {} was not found", kid))
}
async fn update_keys(&self, discovery: openid::DiscoveredClient) -> anyhow::Result<()> {
let mut key_cache = self.key_cache.write().await;
*key_cache = discovery
.jwks
.ok_or_else(|| anyhow::anyhow!("Issuer has no keys available"))?
.keys
.into_iter()
.filter_map(|k| {
if !k
.common
.public_key_use
.map_or(false, |u| matches!(u, PublicKeyUse::Signature))
{
return None;
}
let key = match k.algorithm.key_type() {
KeyType::EllipticCurve => {
let raw =
base64::decode(k.common.x509_chain.unwrap_or_default().pop()?).ok()?;
DecodingKey::from_ec_der(&raw)
}
KeyType::RSA => {
if let AlgorithmParameters::RSA(rsa) = k.algorithm {
DecodingKey::from_rsa_components(
&base64::encode_config(
rsa.n.to_bytes_be(),
base64::URL_SAFE_NO_PAD,
),
&base64::encode_config(
rsa.e.to_bytes_be(),
base64::URL_SAFE_NO_PAD,
),
).map_or_else(|e| {
tracing::error!(error = %e, "Unable to parse decoding key from discovery client, skipping");
None
}, Some)?
} else {
return None;
}
}
_ => return None,
};
Some((k.common.key_id.unwrap_or_default(), key))
})
.collect();
let mut last_refresh = self.last_refresh.write().await;
*last_refresh = Instant::now();
Ok(())
}
}
/// Claims read from a validated token; converted into an `OidcUser`.
#[derive(Debug, Deserialize)]
struct Claims {
    /// Used as the principal name whenever present.
    preferred_username: Option<String>,
    /// Used as the principal name only when `preferred_username` is absent
    /// and `email_verified` is `true`.
    email: Option<String>,
    /// Whether the provider verified `email`; a missing value counts as `false`.
    email_verified: Option<bool>,
    /// Subject identifier; the fallback principal name.
    sub: String,
    /// Issuer; appended to the principal as `<name>@<iss>`.
    iss: String,
    /// Group memberships; a missing claim means no groups.
    groups: Option<Vec<String>>,
}
#[async_trait::async_trait]
impl Authenticator for OidcAuthenticator {
    type Item = OidcUser;

    /// Validates a JWT and maps its claims to an [`OidcUser`].
    ///
    /// `auth_data` may be the bare token or an `Authorization` value with a
    /// leading `Bearer` scheme, matched in any letter case. The token header's
    /// `kid` selects the verification key; a missing `kid` looks up the
    /// empty-string key id.
    ///
    /// # Errors
    ///
    /// Fails if the header cannot be decoded, no key matches the `kid`, or
    /// signature/claim validation fails.
    async fn authenticate(&self, auth_data: &str) -> anyhow::Result<Self::Item> {
        let raw_token = strip_bearer_scheme(auth_data);
        let header = jsonwebtoken::decode_header(raw_token)?;
        let key = self.find_key(&header.kid.unwrap_or_default()).await?;
        let token = jsonwebtoken::decode::<Claims>(raw_token, &key, &self.validator)?;
        Ok(token.claims.into())
    }

    /// OAuth client id this authenticator was configured with.
    fn client_id(&self) -> &str {
        &self.client_id
    }

    /// Device authorization endpoint this authenticator was configured with.
    fn auth_url(&self) -> &str {
        &self.device_auth_url
    }

    /// Token endpoint obtained from OIDC discovery.
    fn token_url(&self) -> &str {
        &self.token_url
    }
}

/// Returns the token portion of `auth_data`.
///
/// Surrounding whitespace and an optional leading `Bearer` auth scheme are
/// removed. Auth-scheme names are case-insensitive (RFC 7235 §2.1), so
/// `Bearer`, `bearer` and `BEARER` are all accepted.
fn strip_bearer_scheme(auth_data: &str) -> &str {
    const SCHEME: &str = "bearer";
    let trimmed = auth_data.trim_start();
    match trimmed.get(..SCHEME.len()) {
        // `get` succeeded, so `SCHEME.len()` is a char boundary and slicing is safe.
        Some(prefix) if prefix.eq_ignore_ascii_case(SCHEME) => trimmed[SCHEME.len()..].trim(),
        _ => trimmed.trim(),
    }
}
/// Identity produced by a successful `OidcAuthenticator::authenticate` call.
pub struct OidcUser {
    /// Group names taken from the token's `groups` claim.
    groups: Vec<String>,
    /// Principal in the form `<name>@<issuer>`.
    principal: String,
}
impl Authorizable for OidcUser {
    /// Returns an owned copy of the user's `<name>@<issuer>` principal.
    fn principal(&self) -> String {
        self.principal.to_owned()
    }

    /// Returns an owned copy of the user's group names.
    fn groups(&self) -> Vec<String> {
        self.groups.to_vec()
    }
}
impl From<Claims> for OidcUser {
fn from(c: Claims) -> Self {
let username = match (c.preferred_username, c.email) {
(Some(u), _) => u,
(None, Some(e)) if c.email_verified.unwrap_or_default() => e,
_ => c.sub,
};
OidcUser {
principal: format!("{}@{}", username, c.iss),
groups: c.groups.unwrap_or_default(),
}
}
}