use crate::types::{AuthContext, AuthProvider, AuthType};
use anyhow::{Result, anyhow};
use base64::Engine;
use chrono::{DateTime, TimeDelta, Utc};
use jsonwebtoken::{Algorithm, Validation, decode, decode_header};
use moka::future::Cache;
use openidconnect::core::{CoreJsonWebKeySet, CoreProviderMetadata};
use openidconnect::{IssuerUrl, JsonWebKey};
use rsa::pkcs1::EncodeRsaPublicKey;
use rsa::{BigUint, RsaPublicKey};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
/// Builds the HTTP client used for OIDC discovery and JWKS downloads.
///
/// The client is configured to never follow HTTP redirects.
///
/// # Errors
/// Returns an error when `reqwest` fails to build the client.
pub fn create_http_client() -> Result<reqwest::Client> {
    let redirect_policy = reqwest::redirect::Policy::none();
    let builder = reqwest::ClientBuilder::new().redirect(redirect_policy);
    builder
        .build()
        .map_err(|e| anyhow!("Failed to create HTTP client: {e:?}"))
}
async fn fetch_jwks(issuer_url: &IssuerUrl) -> Result<Arc<CoreJsonWebKeySet>> {
let http_client = create_http_client()?;
let metadata = CoreProviderMetadata::discover_async(issuer_url.clone(), &http_client)
.await
.map_err(|e| {
anyhow!(
"Failed to discover OIDC metadata from {}: {e:?}",
issuer_url
)
})?;
let jwks_uri = metadata.jwks_uri();
let jwks: CoreJsonWebKeySet = http_client
.get(jwks_uri.url().as_str())
.send()
.await
.map_err(|e| anyhow!("Failed to fetch JWKS from {}: {e:?}", jwks_uri))?
.json()
.await
.map_err(|e| anyhow!("Failed to parse JWKS: {e:?}"))?;
Ok(Arc::new(jwks))
}
/// Caches the JSON Web Key Set of a single OIDC issuer.
///
/// The key set is stored under one fixed cache key. It expires `ttl` after
/// insertion, after which the next [`JwksCache::get`] downloads it again.
struct JwksCache {
    /// Issuer whose discovery document locates the JWKS endpoint.
    issuer_url: IssuerUrl,
    /// Single-entry cache holding the most recently fetched key set.
    cache: Cache<String, Arc<CoreJsonWebKeySet>>,
}

impl JwksCache {
    /// Creates an empty cache for `issuer_url` whose entry lives for `ttl`.
    fn new(issuer_url: IssuerUrl, ttl: Duration) -> Self {
        Self {
            issuer_url,
            cache: Cache::builder().time_to_live(ttl).build(),
        }
    }

    /// Returns the cached key set, fetching it via `fetch_jwks` when absent or expired.
    ///
    /// # Errors
    /// Propagates the fetch failure, prefixed with "Failed to fetch JWKS".
    async fn get(&self) -> Result<Arc<CoreJsonWebKeySet>> {
        let url = self.issuer_url.clone();
        let loader = async move { fetch_jwks(&url).await };
        let outcome = self.cache.try_get_with(String::from("jwks"), loader).await;
        outcome.map_err(|e| anyhow!("Failed to fetch JWKS: {e:?}"))
    }
}
/// One accepted token issuer, together with the audience its tokens must carry.
#[derive(Clone, Debug, Deserialize)]
pub struct OidcIssuer {
    /// Issuer identifier.
    ///
    /// It is compared with the token's `iss` claim and used as the URL for
    /// OIDC discovery.
    pub issuer: String,
    /// Value that must appear in the token's `aud` claim.
    pub audience: String,
}
/// Default lifetime of a cached JWKS, in seconds (one hour).
const DEFAULT_JWKS_REFRESH_INTERVAL_SECS: u64 = 60 * 60;
/// Default maximum number of cached token validations.
const DEFAULT_TOKEN_CACHE_SIZE: u64 = 1_000;
/// Default lifetime of a cached token validation, in seconds (five minutes).
const DEFAULT_TOKEN_CACHE_TTL_SECS: u64 = 5 * 60;
/// OIDC authentication settings, e.g. parsed from the `MICROMEGAS_OIDC_CONFIG` JSON.
///
/// Every field may be omitted from the JSON input (`#[serde(default)]`).
/// Missing fields take their values from the [`Default`] implementation.
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct OidcConfig {
    /// Accepted token issuers.
    pub issuers: Vec<OidcIssuer>,
    /// Lifetime of a cached JWKS, in seconds.
    pub jwks_refresh_interval_secs: u64,
    /// Maximum number of validated tokens kept in the token cache.
    pub token_cache_size: u64,
    /// Lifetime of a cached token validation, in seconds.
    pub token_cache_ttl_secs: u64,
}

impl Default for OidcConfig {
    /// No issuers, and the `DEFAULT_*` cache settings.
    fn default() -> Self {
        OidcConfig {
            token_cache_ttl_secs: DEFAULT_TOKEN_CACHE_TTL_SECS,
            token_cache_size: DEFAULT_TOKEN_CACHE_SIZE,
            jwks_refresh_interval_secs: DEFAULT_JWKS_REFRESH_INTERVAL_SECS,
            issuers: vec![],
        }
    }
}
impl OidcConfig {
    /// Loads the configuration from the `MICROMEGAS_OIDC_CONFIG` environment variable.
    ///
    /// The variable must hold JSON that deserializes into [`OidcConfig`].
    ///
    /// # Errors
    /// Fails when the variable cannot be read (unset or not valid Unicode), or
    /// when its JSON cannot be parsed.
    pub fn from_env() -> Result<Self> {
        let Ok(json) = std::env::var("MICROMEGAS_OIDC_CONFIG") else {
            return Err(anyhow!("MICROMEGAS_OIDC_CONFIG environment variable not set"));
        };
        serde_json::from_str::<OidcConfig>(&json)
            .map_err(|e| anyhow!("Failed to parse MICROMEGAS_OIDC_CONFIG: {e:?}"))
    }
}
/// The JWT `aud` claim: either a single string or an array of strings.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum Audience {
    Single(String),
    Multiple(Vec<String>),
}

impl Audience {
    /// Returns `true` when `aud` exactly equals the audience value, or one of them.
    fn contains(&self, aud: &str) -> bool {
        let values: &[String] = match self {
            Audience::Single(value) => std::slice::from_ref(value),
            Audience::Multiple(values) => values,
        };
        values.iter().any(|value| value == aud)
    }
}
/// Claims read from an OIDC JWT payload.
///
/// `iss`, `sub`, `aud` and `exp` (seconds since the Unix epoch) are mandatory.
/// The other fields are optional identity claims used to derive the caller's email.
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
    iss: String,
    sub: String,
    aud: Audience,
    exp: i64,
    #[serde(skip_serializing_if = "Option::is_none")]
    email: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    verified_primary_email: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    preferred_username: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    upn: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    unique_name: Option<String>,
    #[serde(
        rename = "https://micromegas.io/email",
        skip_serializing_if = "Option::is_none"
    )]
    namespaced_email: Option<String>,
    #[serde(
        rename = "https://micromegas.io/name",
        skip_serializing_if = "Option::is_none"
    )]
    namespaced_name: Option<String>,
}

impl Claims {
    /// Returns the caller's email from the first claim present, in this order:
    /// `verified_primary_email`, `email`, `https://micromegas.io/email`,
    /// `preferred_username`, `upn`, `unique_name`.
    ///
    /// NOTE(review): the last three are not necessarily email addresses, and no
    /// `email_verified`-style claim is consulted. The result is matched against the
    /// admin list, so confirm the configured IdPs do not let end users choose these values.
    fn get_email(&self) -> Option<String> {
        let candidates = [
            &self.verified_primary_email,
            &self.email,
            &self.namespaced_email,
            &self.preferred_username,
            &self.upn,
            &self.unique_name,
        ];
        candidates.into_iter().find_map(|claim| claim.clone())
    }
}
/// Validation settings for one (issuer, audience) pair, plus the issuer's JWKS cache.
struct OidcIssuerClient {
    issuer: String,
    audience: String,
    jwks_cache: JwksCache,
}

impl OidcIssuerClient {
    /// Creates a client whose cached JWKS lives for `jwks_ttl`. No network access happens here.
    ///
    /// # Errors
    /// Fails when `issuer` is not a valid issuer URL.
    fn new(issuer: String, audience: String, jwks_ttl: Duration) -> Result<Self> {
        let jwks_cache = match IssuerUrl::new(issuer.clone()) {
            Ok(url) => JwksCache::new(url, jwks_ttl),
            Err(e) => return Err(anyhow!("Invalid issuer URL '{}': {e:?}", issuer)),
        };
        Ok(Self {
            issuer,
            audience,
            jwks_cache,
        })
    }
}
fn load_admin_users() -> Vec<String> {
match std::env::var("MICROMEGAS_ADMINS") {
Ok(json) => serde_json::from_str::<Vec<String>>(&json).unwrap_or_default(),
Err(_) => vec![],
}
}
/// Converts an RSA JWK into a `jsonwebtoken` decoding key.
///
/// The base64url `n` (modulus) and `e` (exponent) members are decoded. They are
/// assembled into an RSA public key and handed to `jsonwebtoken` as a PKCS#1 PEM.
///
/// # Errors
/// Fails when:
/// - the key lacks `n` or `e` (e.g. a non-RSA key),
/// - either member is not valid base64url, or
/// - the RSA key cannot be built or encoded.
fn jwk_to_decoding_key(
    jwk: &openidconnect::core::CoreJsonWebKey,
) -> Result<jsonwebtoken::DecodingKey> {
    let jwk_json =
        serde_json::to_value(jwk).map_err(|e| anyhow!("Failed to serialize JWK: {e:?}"))?;
    let read_param = |name: &str| jwk_json.get(name).and_then(serde_json::Value::as_str);
    let modulus_b64 = read_param("n").ok_or_else(|| anyhow!("JWK missing 'n' parameter"))?;
    let exponent_b64 = read_param("e").ok_or_else(|| anyhow!("JWK missing 'e' parameter"))?;

    let engine = &base64::engine::general_purpose::URL_SAFE_NO_PAD;
    let modulus = engine
        .decode(modulus_b64)
        .map_err(|e| anyhow!("Failed to decode 'n': {e:?}"))?;
    let exponent = engine
        .decode(exponent_b64)
        .map_err(|e| anyhow!("Failed to decode 'e': {e:?}"))?;

    let public_key = RsaPublicKey::new(
        BigUint::from_bytes_be(&modulus),
        BigUint::from_bytes_be(&exponent),
    )
    .map_err(|e| anyhow!("Failed to create RSA public key: {e:?}"))?;
    let pem = public_key
        .to_pkcs1_pem(rsa::pkcs1::LineEnding::LF)
        .map_err(|e| anyhow!("Failed to encode public key as PEM: {e:?}"))?;
    jsonwebtoken::DecodingKey::from_rsa_pem(pem.as_bytes())
        .map_err(|e| anyhow!("Failed to create decoding key: {e:?}"))
}
/// Validates OIDC bearer tokens against one or more configured issuers.
pub struct OidcAuthProvider {
    /// Issuer clients grouped by issuer string (one client per configured audience).
    clients: HashMap<String, Vec<Arc<OidcIssuerClient>>>,
    /// Validated auth contexts, keyed by raw bearer token.
    token_cache: Cache<String, Arc<AuthContext>>,
    /// Subjects or emails granted admin rights.
    admin_users: Vec<String>,
}

// Hand-written so that admin identities and cached tokens stay out of `Debug` output.
impl std::fmt::Debug for OidcAuthProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // NOTE: `num_clients` is the number of distinct issuers (map entries),
        // not the total number of issuer clients.
        let issuer_count = self.clients.len();
        let mut out = f.debug_struct("OidcAuthProvider");
        out.field("num_clients", &issuer_count);
        out.field("admin_users", &"(not printed)");
        out.finish()
    }
}
impl OidcAuthProvider {
    /// Builds a provider from `config`.
    ///
    /// - One [`OidcIssuerClient`] is created per entry of `config.issuers`.
    /// - Entries sharing the same issuer string are grouped under that issuer,
    ///   one client per audience.
    /// - The admin list is read from `MICROMEGAS_ADMINS`.
    ///
    /// No network request is made here: each issuer's JWKS is fetched on first use.
    ///
    /// # Errors
    /// Fails when no issuer is configured or when an issuer is not a valid URL.
    pub async fn new(config: OidcConfig) -> Result<Self> {
        if config.issuers.is_empty() {
            return Err(anyhow!("At least one OIDC issuer must be configured"));
        }
        micromegas_tracing::info!("Configuring OIDC with {} issuer(s)", config.issuers.len());
        for (idx, issuer_config) in config.issuers.iter().enumerate() {
            micromegas_tracing::info!(
                " Issuer {}: {} (audience: {})",
                idx + 1,
                issuer_config.issuer,
                issuer_config.audience
            );
        }
        let jwks_ttl = Duration::from_secs(config.jwks_refresh_interval_secs);
        let mut clients: HashMap<String, Vec<Arc<OidcIssuerClient>>> = HashMap::new();
        for issuer_config in config.issuers {
            let client = OidcIssuerClient::new(
                issuer_config.issuer.clone(),
                issuer_config.audience,
                jwks_ttl,
            )?;
            clients
                .entry(issuer_config.issuer)
                .or_default()
                .push(Arc::new(client));
        }
        let token_cache = Cache::builder()
            .max_capacity(config.token_cache_size)
            .time_to_live(Duration::from_secs(config.token_cache_ttl_secs))
            .build();
        Ok(Self {
            clients,
            token_cache,
            admin_users: load_admin_users(),
        })
    }

    /// Returns `true` when `subject`, or `email` if provided, exactly matches an
    /// entry of the admin list.
    fn is_admin(&self, subject: &str, email: Option<&str>) -> bool {
        self.admin_users
            .iter()
            .any(|admin| admin == subject || email == Some(admin.as_str()))
    }

    /// Decodes the JWT payload WITHOUT verifying the signature.
    ///
    /// The result is only used to pick the issuer named in `iss`. It must not be
    /// trusted for anything else.
    ///
    /// # Errors
    /// Fails when the token does not have three dot-separated parts, or its payload
    /// is not unpadded base64url-encoded JSON matching [`Claims`].
    fn decode_payload_unsafe(&self, token: &str) -> Result<Claims> {
        let parts: Vec<&str> = token.split('.').collect();
        if parts.len() != 3 {
            return Err(anyhow!("Invalid JWT format"));
        }
        let payload_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(parts[1].as_bytes())
            .map_err(|e| anyhow!("Failed to decode JWT payload: {e:?}"))?;
        let claims: Claims = serde_json::from_slice(&payload_bytes)
            .map_err(|e| anyhow!("Failed to parse JWT claims: {e:?}"))?;
        Ok(claims)
    }

    /// Fully validates `token` and builds the resulting [`AuthContext`].
    ///
    /// Steps:
    /// 1. Read the header `kid`.
    /// 2. Peek (unverified) at `iss` to select that issuer's configured clients.
    /// 3. Load the issuer's cached JWKS.
    /// 4. Try each candidate key until one verifies the token. Candidates are the
    ///    key whose id equals `kid`, or every key when the header has no `kid`.
    ///
    /// NOTE(review): the JWKS is only re-fetched when its cache entry expires
    /// (`jwks_refresh_interval_secs`). Until then, a token signed with a key the
    /// issuer has only just started publishing fails with "Key with kid ... not found".
    ///
    /// # Errors
    /// Fails on a malformed token, an unknown issuer, a JWKS fetch failure, or a
    /// missing key. If no candidate key validates the token, the last error is returned.
    async fn validate_jwt_token(&self, token: &str) -> Result<AuthContext> {
        let header = decode_header(token).map_err(|e| anyhow!("Invalid JWT header: {e:?}"))?;
        let kid = header.kid;
        // Unverified: only used for routing; signature and issuer are checked later.
        let unverified_claims = self.decode_payload_unsafe(token)?;
        let issuer_clients = self
            .clients
            .get(&unverified_claims.iss)
            .ok_or_else(|| anyhow!("Unknown issuer: {}", unverified_claims.iss))?;
        // All clients of one issuer share the same issuer URL, so the first
        // client's JWKS serves them all.
        let first_client = issuer_clients
            .first()
            .ok_or_else(|| anyhow!("No clients configured for issuer"))?;
        // `JwksCache::get` already prefixes its error with "Failed to fetch JWKS".
        let jwks = first_client.jwks_cache.get().await?;

        let keys_to_try: Vec<_> = match kid.as_deref() {
            Some(wanted) => jwks
                .keys()
                .iter()
                .filter(|k| k.key_id().map(|id| id.as_str()) == Some(wanted))
                .collect(),
            None => jwks.keys().iter().collect(),
        };
        if keys_to_try.is_empty() {
            return Err(match kid {
                Some(kid_value) => anyhow!("Key with kid '{}' not found in JWKS", kid_value),
                None => anyhow!("No keys found in JWKS"),
            });
        }

        let mut key_error = anyhow!("No valid key found");
        for key in keys_to_try {
            let attempt = match jwk_to_decoding_key(key) {
                Ok(decoding_key) => {
                    self.try_validate_with_key(token, &decoding_key, issuer_clients)
                        .await
                }
                Err(e) => Err(e),
            };
            match attempt {
                Ok(auth_ctx) => return Ok(auth_ctx),
                Err(e) => key_error = e,
            }
        }
        Err(key_error)
    }

    /// Verifies `token` with `decoding_key` and matches it to one of `issuer_clients`.
    ///
    /// The signature is checked once: RS256 only, with `jsonwebtoken`'s default
    /// `exp` validation, and `iss` must equal one of the clients' issuers.
    /// The first client with the token's issuer and an audience contained in the
    /// token's `aud` claim is then selected.
    ///
    /// # Errors
    /// Fails when:
    /// - `issuer_clients` is empty,
    /// - verification fails,
    /// - no client audience matches,
    /// - `exp` is not a representable timestamp, or
    /// - the token has expired.
    async fn try_validate_with_key(
        &self,
        token: &str,
        decoding_key: &jsonwebtoken::DecodingKey,
        issuer_clients: &[Arc<OidcIssuerClient>],
    ) -> Result<AuthContext> {
        if issuer_clients.is_empty() {
            return Err(anyhow!("No matching audience found"));
        }
        // Signature, algorithm, `exp` and `iss` checks do not depend on the
        // audience, so run the (RSA) verification once rather than once per client.
        let issuers: Vec<&str> = issuer_clients.iter().map(|c| c.issuer.as_str()).collect();
        let mut validation = Validation::new(Algorithm::RS256);
        // The audience is matched manually below, against each configured client.
        validation.validate_aud = false;
        validation.set_issuer(&issuers);
        let claims = decode::<Claims>(token, decoding_key, &validation)
            .map_err(|e| anyhow!("Token validation failed: {e:?}"))?
            .claims;

        let matching_client = issuer_clients
            .iter()
            .find(|client| client.issuer == claims.iss && claims.aud.contains(&client.audience));
        let Some(client) = matching_client else {
            let configured_audiences: Vec<&str> =
                issuer_clients.iter().map(|c| c.audience.as_str()).collect();
            let actual_audiences: Vec<&str> = match &claims.aud {
                Audience::Single(s) => vec![s.as_str()],
                Audience::Multiple(v) => v.iter().map(String::as_str).collect(),
            };
            return Err(anyhow!(
                "Token audience mismatch - configured audiences: {:?}, token audiences: {:?}",
                configured_audiences,
                actual_audiences
            ));
        };

        let expires_at = DateTime::from_timestamp(claims.exp, 0)
            .ok_or_else(|| anyhow!("Invalid expiration timestamp"))?;
        // `jsonwebtoken` accepts `exp` within its default leeway; this check is strict.
        if expires_at < Utc::now() {
            return Err(anyhow!("Token has expired"));
        }
        let email = claims.get_email();
        let is_admin = self.is_admin(&claims.sub, email.as_deref());
        Ok(AuthContext {
            subject: claims.sub,
            email,
            issuer: claims.iss,
            audience: Some(client.audience.clone()),
            expires_at: Some(expires_at),
            auth_type: AuthType::Oidc,
            is_admin,
            allow_delegation: false,
        })
    }
}
#[async_trait::async_trait]
impl AuthProvider for OidcAuthProvider {
    /// Authenticates a request from its bearer token.
    ///
    /// Successful validations are memoized in `token_cache`, keyed by the raw token.
    /// A context counts as stale once its expiry is at most 30 seconds away. A stale
    /// cached entry is evicted and the token is validated again. A freshly validated
    /// context that is already stale is returned but not cached, since the next lookup
    /// would only evict it again.
    ///
    /// # Errors
    /// Fails when the request carries no bearer token or the token does not validate.
    async fn validate_request(
        &self,
        parts: &dyn crate::types::RequestParts,
    ) -> Result<AuthContext> {
        let token = parts
            .bearer_token()
            .ok_or_else(|| anyhow!("missing bearer token"))?;
        let expiry_skew = TimeDelta::seconds(30);
        // A context without a known expiry is never considered stale.
        let is_stale = |ctx: &AuthContext| {
            ctx.expires_at
                .is_some_and(|exp| exp <= Utc::now() + expiry_skew)
        };
        if let Some(cached) = self.token_cache.get(token).await {
            if !is_stale(&*cached) {
                return Ok((*cached).clone());
            }
            self.token_cache.remove(token).await;
        }
        let auth_ctx = self.validate_jwt_token(token).await?;
        if !is_stale(&auth_ctx) {
            self.token_cache
                .insert(token.to_string(), Arc::new(auth_ctx.clone()))
                .await;
        }
        Ok(auth_ctx)
    }
}