use std::future::Future;
use std::sync::Arc;
use kellnr_settings::OAuth2 as OAuth2Settings;
use openidconnect::core::{
CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreClient, CoreErrorResponseType,
CoreGenderClaim, CoreIdTokenClaims, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm,
CoreProviderMetadata, CoreRevocationErrorResponse, CoreTokenIntrospectionResponse,
CoreTokenResponse,
};
use openidconnect::{
AuthorizationCode, ClientId, ClientSecret, CsrfToken, EmptyAdditionalClaims, EndpointMaybeSet,
EndpointNotSet, EndpointSet, IssuerUrl, Nonce, PkceCodeChallenge, PkceCodeVerifier,
RedirectUrl, Scope, TokenResponse, reqwest,
};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::{info, warn};
use url::Url;
/// Concrete `openidconnect::Client` instantiation stored inside `OAuth2Handler`.
///
/// This is the type of the value produced in `OAuth2Handler::from_discovery` by
/// `CoreClient::from_provider_metadata(..).set_redirect_uri(..)`.
///
/// NOTE(review): the six trailing `Endpoint*` markers presumably follow openidconnect's
/// parameter order (auth, device auth, introspection, revocation, token, user info) —
/// confirm against the crate version in use.
type ConfiguredCoreClient = openidconnect::Client<
EmptyAdditionalClaims,
CoreAuthDisplay,
CoreGenderClaim,
CoreJweContentEncryptionAlgorithm,
CoreJsonWebKey,
CoreAuthPrompt,
openidconnect::StandardErrorResponse<CoreErrorResponseType>,
CoreTokenResponse,
CoreTokenIntrospectionResponse,
openidconnect::core::CoreRevocableToken,
CoreRevocationErrorResponse,
EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointMaybeSet, EndpointMaybeSet, >;
/// Errors produced while configuring the OAuth2/OIDC client or completing a login.
///
/// Most variants carry the underlying error rendered as a string.
#[derive(Debug, Error)]
pub enum OAuth2Error {
/// The settings have `enabled` set to false.
#[error("OAuth2 is not enabled")]
NotEnabled,
/// Settings validation failed, or a required setting is missing or invalid.
#[error("OAuth2 configuration is invalid: {0}")]
ConfigurationError(String),
/// OIDC provider discovery against the issuer URL failed.
#[error("Failed to discover OIDC provider: {0}")]
DiscoveryError(String),
/// A URL could not be parsed; converted automatically from `url::ParseError`.
#[error("Failed to parse URL: {0}")]
UrlParseError(#[from] url::ParseError),
/// The token request could not be built or the code exchange failed.
#[error("Failed to exchange authorization code: {0}")]
TokenExchangeError(String),
/// ID token verification failed, or the JWT payload could not be decoded.
#[error("Failed to verify ID token: {0}")]
TokenVerificationError(String),
/// An expected value is absent; used when the token response has no `id_token`.
#[error("Missing required claim: {0}")]
MissingClaim(String),
/// The HTTP client could not be built.
#[error("HTTP request failed: {0}")]
HttpError(String),
}
/// Identity and role information derived from a validated ID token.
///
/// Built by `OAuth2Handler::extract_user_info`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserInfo {
/// Value of the token's subject claim.
pub subject: String,
/// Email claim, when the token carries one.
pub email: Option<String>,
/// Preferred-username claim, when the token carries one.
pub preferred_username: Option<String>,
/// Group names found in the token payload; empty when no group claim was found.
pub groups: Vec<String>,
/// Whether the configured admin claim/value pair matched.
pub is_admin: bool,
/// Whether the configured read-only claim/value pair matched.
pub is_read_only: bool,
}
/// Everything produced when starting an authorization-code login.
///
/// Returned by `OAuth2Handler::generate_auth_url`. `pkce_verifier` and `nonce` are the
/// values `OAuth2Handler::exchange_and_validate` later takes as arguments.
///
/// NOTE(review): the derived `Debug` output includes `state`, `pkce_verifier` and `nonce`
/// in clear text — avoid logging this struct.
#[derive(Debug)]
pub struct AuthRequest {
/// URL of the provider's authorization endpoint including all request parameters.
pub auth_url: Url,
/// Secret of the CSRF token generated for this request.
pub state: String,
/// Secret of the PKCE code verifier matching the challenge embedded in `auth_url`.
pub pkce_verifier: String,
/// Secret of the nonce generated for this request.
pub nonce: String,
}
/// Outcome of a successful code exchange and ID token verification.
///
/// Returned by `OAuth2Handler::exchange_and_validate`.
#[derive(Debug)]
pub struct TokenResult {
/// Claims of the ID token after verification.
pub claims: CoreIdTokenClaims,
/// JSON payload decoded from the same ID token; an empty JSON object when the payload
/// could not be decoded.
pub raw_payload: serde_json::Value,
}
/// OAuth2/OIDC login helper bound to one discovered provider.
///
/// Constructed with `OAuth2Handler::from_discovery`.
pub struct OAuth2Handler {
/// Client built from the discovered provider metadata, with the redirect URI set.
client: ConfiguredCoreClient,
/// Copy of the settings the handler was created from.
settings: Arc<OAuth2Settings>,
/// Issuer URL used for discovery.
issuer_url: IssuerUrl,
/// HTTP client used for discovery and for the token exchange.
http_client: reqwest::Client,
}
impl OAuth2Handler {
/// Builds a handler by running OIDC discovery against the configured issuer.
///
/// `redirect_url` is registered on the resulting client as its redirect URI.
///
/// # Errors
///
/// * [`OAuth2Error::NotEnabled`] when `settings.enabled` is false.
/// * [`OAuth2Error::ConfigurationError`] when `settings.validate()` fails, when
///   `issuer_url` or `client_id` is missing, or when the issuer URL is rejected.
/// * [`OAuth2Error::HttpError`] when the HTTP client cannot be built.
/// * [`OAuth2Error::DiscoveryError`] when provider discovery fails.
/// * [`OAuth2Error::UrlParseError`] when `redirect_url` does not parse.
pub async fn from_discovery(
    settings: &OAuth2Settings,
    redirect_url: &str,
) -> Result<Self, OAuth2Error> {
    if !settings.enabled {
        return Err(OAuth2Error::NotEnabled);
    }
    if let Err(reason) = settings.validate() {
        return Err(OAuth2Error::ConfigurationError(reason));
    }

    let Some(issuer) = settings.issuer_url.as_ref() else {
        return Err(OAuth2Error::ConfigurationError(
            "Missing issuer_url".to_string(),
        ));
    };
    let issuer_url = match IssuerUrl::new(issuer.clone()) {
        Ok(url) => url,
        Err(e) => {
            return Err(OAuth2Error::ConfigurationError(format!(
                "Invalid issuer URL: {e}"
            )));
        }
    };
    info!("Discovering OIDC provider at: {}", issuer);

    // Redirects are never followed; the openidconnect documentation recommends this
    // to avoid SSRF through provider-controlled redirects.
    let http_client = reqwest::ClientBuilder::new()
        .redirect(reqwest::redirect::Policy::none())
        .build()
        .map_err(|e| OAuth2Error::HttpError(e.to_string()))?;

    let discovery = CoreProviderMetadata::discover_async(issuer_url.clone(), &http_client).await;
    let provider_metadata = discovery.map_err(|e| OAuth2Error::DiscoveryError(e.to_string()))?;

    let Some(id) = settings.client_id.clone() else {
        return Err(OAuth2Error::ConfigurationError(
            "Missing client_id".to_string(),
        ));
    };
    let client_id = ClientId::new(id);
    let client_secret = settings.client_secret.clone().map(ClientSecret::new);
    let redirect_uri = RedirectUrl::new(redirect_url.to_owned())?;

    let client = CoreClient::from_provider_metadata(provider_metadata, client_id, client_secret)
        .set_redirect_uri(redirect_uri);

    Ok(Self {
        client,
        settings: Arc::new(settings.clone()),
        issuer_url,
        http_client,
    })
}
/// Starts an authorization-code login by building the provider redirect URL.
///
/// Every call generates a fresh SHA-256 PKCE challenge/verifier pair, CSRF token and
/// nonce, and adds each scope from the settings to the request. The returned
/// [`AuthRequest`] holds the URL plus the three secrets.
pub fn generate_auth_url(&self) -> AuthRequest {
    let (challenge, verifier) = PkceCodeChallenge::new_random_sha256();

    let base_request = self
        .client
        .authorize_url(
            CoreAuthenticationFlow::AuthorizationCode,
            CsrfToken::new_random,
            Nonce::new_random,
        )
        .set_pkce_challenge(challenge);

    let scoped_request = self
        .settings
        .scopes
        .iter()
        .fold(base_request, |request, scope| {
            request.add_scope(Scope::new(scope.clone()))
        });

    let (auth_url, csrf_token, nonce) = scoped_request.url();
    AuthRequest {
        auth_url,
        state: csrf_token.secret().to_owned(),
        pkce_verifier: verifier.secret().to_owned(),
        nonce: nonce.secret().to_owned(),
    }
}
pub async fn exchange_and_validate(
&self,
code: &str,
pkce_verifier: &str,
nonce: &str,
) -> Result<TokenResult, OAuth2Error> {
let code = AuthorizationCode::new(code.to_string());
let verifier = PkceCodeVerifier::new(pkce_verifier.to_string());
let token_request = self
.client
.exchange_code(code)
.map_err(|e| OAuth2Error::TokenExchangeError(e.to_string()))?;
let token_response: CoreTokenResponse = token_request
.set_pkce_verifier(verifier)
.request_async(&self.http_client)
.await
.map_err(|e| OAuth2Error::TokenExchangeError(e.to_string()))?;
let id_token = token_response
.id_token()
.ok_or_else(|| OAuth2Error::MissingClaim("id_token".to_string()))?;
let nonce = Nonce::new(nonce.to_string());
let verifier = self.client.id_token_verifier();
let claims = id_token
.claims(&verifier, &nonce)
.map_err(|e| OAuth2Error::TokenVerificationError(e.to_string()))?
.clone();
let raw_payload = extract_jwt_payload(id_token.to_string().as_str())
.unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new()));
Ok(TokenResult {
claims,
raw_payload,
})
}
/// Turns a validated token into the [`UserInfo`] used by the rest of the application.
///
/// Subject, email and preferred username come from the verified claims; groups and the
/// admin / read-only flags are derived from the raw payload using the configured
/// claim names and values.
pub fn extract_user_info(&self, result: &TokenResult) -> UserInfo {
    let claims = &result.claims;
    let payload = &result.raw_payload;
    let settings = &self.settings;

    let groups = self.extract_groups(payload);
    let is_admin = self.check_group_membership(
        &groups,
        payload,
        settings.admin_group_claim.as_deref(),
        settings.admin_group_value.as_deref(),
    );
    let is_read_only = self.check_group_membership(
        &groups,
        payload,
        settings.read_only_group_claim.as_deref(),
        settings.read_only_group_value.as_deref(),
    );

    UserInfo {
        subject: claims.subject().as_str().to_owned(),
        email: claims.email().map(|address| address.as_str().to_owned()),
        preferred_username: claims
            .preferred_username()
            .map(|name| name.as_str().to_owned()),
        groups,
        is_admin,
        is_read_only,
    }
}
/// Collects the user's groups from the first claim in `payload` that yields string values.
///
/// Claims are tried in this order: the configured admin group claim, the configured
/// read-only group claim (skipped when it has the same name as the admin claim), then
/// the conventional names `groups`, `roles` and `group`. Returns an empty list when
/// none of them yields values.
fn extract_groups(&self, payload: &serde_json::Value) -> Vec<String> {
    let admin_claim = self.settings.admin_group_claim.as_deref();
    let read_only_claim = self
        .settings
        .read_only_group_claim
        .as_deref()
        .filter(|name| admin_claim != Some(*name));

    admin_claim
        .into_iter()
        .chain(read_only_claim)
        .chain(["groups", "roles", "group"])
        .find_map(|name| get_string_array_from_json(payload, name))
        .unwrap_or_default()
}
/// Decides whether the token grants the role described by `claim_name` / `claim_value`.
///
/// Returns `false` unless both are set. Otherwise the checks run in this order:
///
/// 1. `claim_value` is one of the already extracted `groups`;
/// 2. the claim `claim_name` yields string values (a string or an array containing
///    strings): the result is whether they contain `claim_value`;
/// 3. the claim is a JSON boolean: the result is `true` only when the claim is `true`
///    and `claim_value` equals `"true"` ignoring ASCII case.
///
/// NOTE(review): step 1 matches `groups` regardless of `claim_name`, so a value found
/// under a differently named claim also counts — confirm this is intended.
// The method reads nothing from `self`; the lint is silenced to keep it a method.
#[allow(clippy::unused_self)]
fn check_group_membership(
    &self,
    groups: &[String],
    payload: &serde_json::Value,
    claim_name: Option<&str>,
    claim_value: Option<&str>,
) -> bool {
    let (name, expected) = match (claim_name, claim_value) {
        (Some(name), Some(expected)) => (name, expected),
        _ => return false,
    };

    if groups.iter().map(String::as_str).any(|group| group == expected) {
        return true;
    }

    if let Some(values) = get_string_array_from_json(payload, name) {
        return values.iter().any(|value| value == expected);
    }

    payload
        .get(name)
        .and_then(serde_json::Value::as_bool)
        .is_some_and(|flag| flag && expected.eq_ignore_ascii_case("true"))
}
/// Derives a base username from the identity information.
///
/// The first non-empty source wins: the preferred username, then the part of the email
/// before the first `@`, then the subject. The first two keep dots when sanitized; the
/// subject does not.
pub fn generate_username(user_info: &UserInfo) -> String {
    let preferred = user_info
        .preferred_username
        .as_deref()
        .filter(|name| !name.is_empty());
    let email_local_part = || {
        user_info
            .email
            .as_deref()
            .and_then(|email| email.split('@').next())
            .filter(|local| !local.is_empty())
    };

    match preferred.or_else(email_local_part) {
        Some(name) => sanitize_username_with_dots(name),
        None => sanitize_username(&user_info.subject),
    }
}
/// Returns the issuer URL this handler was created with, as a string slice.
pub fn issuer_url(&self) -> &str {
self.issuer_url.as_str()
}
/// Returns the settings copy held by this handler.
pub fn settings(&self) -> &OAuth2Settings {
&self.settings
}
}
/// Decodes the payload (second segment) of a compact JWT into JSON.
///
/// No signature check happens here; the function only splits, base64url-decodes
/// (unpadded) and parses.
///
/// # Errors
///
/// Returns [`OAuth2Error::TokenVerificationError`] when `jwt` does not consist of exactly
/// three dot-separated segments, when the payload is not valid unpadded base64url, or
/// when the decoded bytes are not valid JSON.
fn extract_jwt_payload(jwt: &str) -> Result<serde_json::Value, OAuth2Error> {
    use base64::Engine;

    let mut segments = jwt.split('.');
    // Exactly three segments: the fourth `next()` must come back empty.
    let (Some(_header), Some(payload_b64), Some(_signature), None) = (
        segments.next(),
        segments.next(),
        segments.next(),
        segments.next(),
    ) else {
        return Err(OAuth2Error::TokenVerificationError(
            "Invalid JWT format".to_string(),
        ));
    };

    let decoded = match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload_b64) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(OAuth2Error::TokenVerificationError(format!(
                "Base64 decode error: {e}"
            )));
        }
    };

    serde_json::from_slice(&decoded)
        .map_err(|e| OAuth2Error::TokenVerificationError(format!("JSON parse error: {e}")))
}
/// Reads the claim `name` from `payload` as a list of strings.
///
/// * An array yields its string elements (other element types are skipped); an array
///   without any string element yields `None`.
/// * A plain string yields a one-element list.
/// * A missing claim or any other JSON type yields `None`.
fn get_string_array_from_json(payload: &serde_json::Value, name: &str) -> Option<Vec<String>> {
    match payload.get(name)? {
        serde_json::Value::Array(items) => {
            let strings: Vec<String> = items
                .iter()
                .filter_map(|item| item.as_str().map(str::to_owned))
                .collect();
            (!strings.is_empty()).then_some(strings)
        }
        serde_json::Value::String(single) => Some(vec![single.clone()]),
        _ => None,
    }
}
/// Sanitizes `input` into a username, replacing dots like any other disallowed character.
///
/// Delegates to `sanitize_username_impl` with `allow_dot` set to `false`.
fn sanitize_username(input: &str) -> String {
sanitize_username_impl(input, false)
}
/// Sanitizes `input` into a username while keeping `.` characters.
///
/// Delegates to `sanitize_username_impl` with `allow_dot` set to `true`.
fn sanitize_username_with_dots(input: &str) -> String {
sanitize_username_impl(input, true)
}
/// Normalizes `input` into a lowercase ASCII username of at most 64 characters.
///
/// ASCII letters and digits, `_` and `-` are kept (letters lowercased); `.` is kept only
/// when `allow_dot` is true. Every other character becomes `_`. If the result does not
/// start with an ASCII letter (including the empty case) it is prefixed with `u_`, and
/// finally it is cut to 64 characters.
fn sanitize_username_impl(input: &str, allow_dot: bool) -> String {
    let mut sanitized = String::with_capacity(input.len() + 2);
    for ch in input.chars() {
        let keep = match ch {
            '_' | '-' => true,
            '.' => allow_dot,
            other => other.is_ascii_alphanumeric(),
        };
        sanitized.push(if keep { ch.to_ascii_lowercase() } else { '_' });
    }

    if !sanitized.starts_with(|first: char| first.is_ascii_alphabetic()) {
        sanitized.insert_str(0, "u_");
    }

    // The string is pure ASCII at this point, so cutting at a byte index is safe.
    sanitized.truncate(64);
    sanitized
}
/// Picks a username for `user_info` that `is_available` reports as free.
///
/// The base name from [`OAuth2Handler::generate_username`] is tried first, followed by
/// `{base}_2` through `{base}_100`. If all of them are taken, a warning is logged and the
/// sanitized subject plus the current Unix time in hexadecimal seconds is returned; that
/// fallback is not checked against `is_available`.
///
/// Every returned name is at most 64 characters long: when a suffix is appended, the
/// stem is shortened so the result still fits the cap applied by `sanitize_username_impl`.
pub async fn generate_unique_username<F, Fut>(user_info: &UserInfo, is_available: F) -> String
where
    F: Fn(String) -> Fut,
    Fut: Future<Output = bool>,
{
    // Keep in sync with the length cap in `sanitize_username_impl`.
    const MAX_USERNAME_LEN: usize = 64;

    // Appends `suffix` to `stem`, trimming the stem so the result stays within the cap.
    fn with_suffix(stem: &str, suffix: &str) -> String {
        let keep = MAX_USERNAME_LEN.saturating_sub(suffix.len());
        let mut name: String = stem.chars().take(keep).collect();
        name.push_str(suffix);
        name
    }

    let base = OAuth2Handler::generate_username(user_info);
    if is_available(base.clone()).await {
        return base;
    }

    for i in 2..=100 {
        let candidate = with_suffix(&base, &format!("_{i}"));
        if is_available(candidate.clone()).await {
            return candidate;
        }
    }

    warn!("Could not find unique username after 100 attempts, using fallback");
    // A clock before the Unix epoch degrades to a `_0` suffix instead of failing.
    let secs = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0);
    with_suffix(
        &sanitize_username(&user_info.subject),
        &format!("_{secs:x}"),
    )
}
// Unit tests for the pure helpers in this file: username sanitizing and generation,
// unique-name selection, JWT payload decoding and JSON claim extraction.
#[cfg(test)]
mod tests {
use super::*;
// Without dot support: lowercasing, replacement of disallowed characters and the `u_`
// prefix for names that do not start with a letter.
#[test]
fn test_sanitize_username() {
assert_eq!(sanitize_username("JohnDoe"), "johndoe");
assert_eq!(sanitize_username("john.doe"), "john_doe");
assert_eq!(sanitize_username("john@example.com"), "john_example_com");
assert_eq!(sanitize_username("123user"), "u_123user");
assert_eq!(sanitize_username("_user"), "u__user");
assert_eq!(sanitize_username("user-name"), "user-name");
}
// With dot support: `.` survives while other disallowed characters are still replaced.
#[test]
fn test_sanitize_username_with_dots() {
assert_eq!(sanitize_username_with_dots("john.doe"), "john.doe");
assert_eq!(
sanitize_username_with_dots("john@example.com"),
"john_example.com"
);
}
// The preferred username wins over email and subject.
#[test]
fn test_generate_username_preferred() {
let user_info = UserInfo {
subject: "sub123".to_string(),
email: Some("john@example.com".to_string()),
preferred_username: Some("johndoe".to_string()),
groups: vec![],
is_admin: false,
is_read_only: false,
};
assert_eq!(OAuth2Handler::generate_username(&user_info), "johndoe");
}
// Dots in the preferred username are preserved.
#[test]
fn test_generate_username_preferred_preserves_dot() {
let user_info = UserInfo {
subject: "sub123".to_string(),
email: Some("john@example.com".to_string()),
preferred_username: Some("john.doe".to_string()),
groups: vec![],
is_admin: false,
is_read_only: false,
};
assert_eq!(OAuth2Handler::generate_username(&user_info), "john.doe");
}
// Without a preferred username the local part of the email is used.
#[test]
fn test_generate_username_email() {
let user_info = UserInfo {
subject: "sub123".to_string(),
email: Some("john@example.com".to_string()),
preferred_username: None,
groups: vec![],
is_admin: false,
is_read_only: false,
};
assert_eq!(OAuth2Handler::generate_username(&user_info), "john");
}
// Dots in the email local part are preserved.
#[test]
fn test_generate_username_email_preserves_dot_in_local_part() {
let user_info = UserInfo {
subject: "sub123".to_string(),
email: Some("john.doe@example.com".to_string()),
preferred_username: None,
groups: vec![],
is_admin: false,
is_read_only: false,
};
assert_eq!(OAuth2Handler::generate_username(&user_info), "john.doe");
}
// With neither preferred username nor email the subject is used.
#[test]
fn test_generate_username_subject() {
let user_info = UserInfo {
subject: "sub123".to_string(),
email: None,
preferred_username: None,
groups: vec![],
is_admin: false,
is_read_only: false,
};
assert_eq!(OAuth2Handler::generate_username(&user_info), "sub123");
}
// The base name is returned when free; otherwise the first free `_N` suffix from 2 up.
#[tokio::test]
async fn test_generate_unique_username() {
let user_info = UserInfo {
subject: "sub123".to_string(),
email: Some("john@example.com".to_string()),
preferred_username: Some("johndoe".to_string()),
groups: vec![],
is_admin: false,
is_read_only: false,
};
let username = generate_unique_username(&user_info, |_| async { true }).await;
assert_eq!(username, "johndoe");
let username =
generate_unique_username(&user_info, |name| async move { name != "johndoe" }).await;
assert_eq!(username, "johndoe_2");
let username = generate_unique_username(&user_info, |name| async move {
name != "johndoe" && name != "johndoe_2"
})
.await;
assert_eq!(username, "johndoe_3");
}
// Decodes the payload of a fixed three-segment token and checks two of its claims.
#[test]
fn test_extract_jwt_payload() {
let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";
let payload = extract_jwt_payload(jwt).unwrap();
assert_eq!(
payload.get("sub").and_then(|v| v.as_str()),
Some("1234567890")
);
assert_eq!(
payload.get("name").and_then(|v| v.as_str()),
Some("John Doe")
);
}
// Arrays and single strings are accepted; other JSON types and missing keys give `None`.
#[test]
fn test_get_string_array_from_json() {
let payload = serde_json::json!({
"groups": ["admin", "users"],
"single_group": "single",
"number": 42
});
assert_eq!(
get_string_array_from_json(&payload, "groups"),
Some(vec!["admin".to_string(), "users".to_string()])
);
assert_eq!(
get_string_array_from_json(&payload, "single_group"),
Some(vec!["single".to_string()])
);
assert_eq!(get_string_array_from_json(&payload, "number"), None);
assert_eq!(get_string_array_from_json(&payload, "missing"), None);
}
}