1use std::future::Future;
7use std::sync::Arc;
8
9use kellnr_settings::OAuth2 as OAuth2Settings;
10use openidconnect::core::{
11 CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreClient, CoreErrorResponseType,
12 CoreGenderClaim, CoreIdTokenClaims, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm,
13 CoreProviderMetadata, CoreRevocationErrorResponse, CoreTokenIntrospectionResponse,
14 CoreTokenResponse,
15};
16use openidconnect::{
17 AuthorizationCode, ClientId, ClientSecret, CsrfToken, EmptyAdditionalClaims, EndpointMaybeSet,
18 EndpointNotSet, EndpointSet, IssuerUrl, Nonce, PkceCodeChallenge, PkceCodeVerifier,
19 RedirectUrl, Scope, TokenResponse, reqwest,
20};
21use serde::{Deserialize, Serialize};
22use thiserror::Error;
23use tracing::{info, warn};
24use url::Url;
25
26type ConfiguredCoreClient = openidconnect::Client<
28 EmptyAdditionalClaims,
29 CoreAuthDisplay,
30 CoreGenderClaim,
31 CoreJweContentEncryptionAlgorithm,
32 CoreJsonWebKey,
33 CoreAuthPrompt,
34 openidconnect::StandardErrorResponse<CoreErrorResponseType>,
35 CoreTokenResponse,
36 CoreTokenIntrospectionResponse,
37 openidconnect::core::CoreRevocableToken,
38 CoreRevocationErrorResponse,
39 EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointMaybeSet, EndpointMaybeSet, >;
46
47#[derive(Debug, Error)]
49pub enum OAuth2Error {
50 #[error("OAuth2 is not enabled")]
51 NotEnabled,
52
53 #[error("OAuth2 configuration is invalid: {0}")]
54 ConfigurationError(String),
55
56 #[error("Failed to discover OIDC provider: {0}")]
57 DiscoveryError(String),
58
59 #[error("Failed to parse URL: {0}")]
60 UrlParseError(#[from] url::ParseError),
61
62 #[error("Failed to exchange authorization code: {0}")]
63 TokenExchangeError(String),
64
65 #[error("Failed to verify ID token: {0}")]
66 TokenVerificationError(String),
67
68 #[error("Missing required claim: {0}")]
69 MissingClaim(String),
70
71 #[error("HTTP request failed: {0}")]
72 HttpError(String),
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct UserInfo {
78 pub subject: String,
80 pub email: Option<String>,
82 pub preferred_username: Option<String>,
84 pub groups: Vec<String>,
86 pub is_admin: bool,
88 pub is_read_only: bool,
90}
91
92#[derive(Debug)]
94pub struct AuthRequest {
95 pub auth_url: Url,
97 pub state: String,
99 pub pkce_verifier: String,
101 pub nonce: String,
103}
104
105#[derive(Debug)]
107pub struct TokenResult {
108 pub claims: CoreIdTokenClaims,
110 pub raw_payload: serde_json::Value,
112}
113
114pub struct OAuth2Handler {
116 client: ConfiguredCoreClient,
117 settings: Arc<OAuth2Settings>,
118 issuer_url: IssuerUrl,
119 http_client: reqwest::Client,
120}
121
122impl OAuth2Handler {
123 pub async fn from_discovery(
128 settings: &OAuth2Settings,
129 redirect_url: &str,
130 ) -> Result<Self, OAuth2Error> {
131 if !settings.enabled {
132 return Err(OAuth2Error::NotEnabled);
133 }
134
135 settings
137 .validate()
138 .map_err(OAuth2Error::ConfigurationError)?;
139
140 let issuer_url_str = settings
141 .issuer_url
142 .as_ref()
143 .ok_or_else(|| OAuth2Error::ConfigurationError("Missing issuer_url".to_string()))?;
144
145 let issuer_url = IssuerUrl::new(issuer_url_str.clone())
146 .map_err(|e| OAuth2Error::ConfigurationError(format!("Invalid issuer URL: {e}")))?;
147
148 info!("Discovering OIDC provider at: {}", issuer_url_str);
149
150 let http_client = reqwest::ClientBuilder::new()
152 .redirect(reqwest::redirect::Policy::none())
153 .build()
154 .map_err(|e| OAuth2Error::HttpError(e.to_string()))?;
155
156 let provider_metadata =
158 CoreProviderMetadata::discover_async(issuer_url.clone(), &http_client)
159 .await
160 .map_err(|e| OAuth2Error::DiscoveryError(e.to_string()))?;
161
162 let client_id = ClientId::new(
163 settings
164 .client_id
165 .clone()
166 .ok_or_else(|| OAuth2Error::ConfigurationError("Missing client_id".to_string()))?,
167 );
168
169 let client_secret = settings.client_secret.clone().map(ClientSecret::new);
170
171 let redirect_url = RedirectUrl::new(redirect_url.to_string())?;
172
173 let client =
175 CoreClient::from_provider_metadata(provider_metadata, client_id, client_secret)
176 .set_redirect_uri(redirect_url);
177
178 Ok(Self {
179 client,
180 settings: Arc::new(settings.clone()),
181 issuer_url,
182 http_client,
183 })
184 }
185
186 pub fn generate_auth_url(&self) -> AuthRequest {
192 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
193
194 let mut auth_request = self
196 .client
197 .authorize_url(
198 CoreAuthenticationFlow::AuthorizationCode,
199 CsrfToken::new_random,
200 Nonce::new_random,
201 )
202 .set_pkce_challenge(pkce_challenge);
203
204 for scope in &self.settings.scopes {
206 auth_request = auth_request.add_scope(Scope::new(scope.clone()));
207 }
208
209 let (auth_url, csrf_state, nonce) = auth_request.url();
210
211 AuthRequest {
212 auth_url,
213 state: csrf_state.secret().clone(),
214 pkce_verifier: pkce_verifier.secret().clone(),
215 nonce: nonce.secret().clone(),
216 }
217 }
218
219 pub async fn exchange_and_validate(
229 &self,
230 code: &str,
231 pkce_verifier: &str,
232 nonce: &str,
233 ) -> Result<TokenResult, OAuth2Error> {
234 let code = AuthorizationCode::new(code.to_string());
235 let verifier = PkceCodeVerifier::new(pkce_verifier.to_string());
236
237 let token_request = self
238 .client
239 .exchange_code(code)
240 .map_err(|e| OAuth2Error::TokenExchangeError(e.to_string()))?;
241
242 let token_response: CoreTokenResponse = token_request
243 .set_pkce_verifier(verifier)
244 .request_async(&self.http_client)
245 .await
246 .map_err(|e| OAuth2Error::TokenExchangeError(e.to_string()))?;
247
248 let id_token = token_response
250 .id_token()
251 .ok_or_else(|| OAuth2Error::MissingClaim("id_token".to_string()))?;
252
253 let nonce = Nonce::new(nonce.to_string());
254 let verifier = self.client.id_token_verifier();
255
256 let claims = id_token
257 .claims(&verifier, &nonce)
258 .map_err(|e| OAuth2Error::TokenVerificationError(e.to_string()))?
259 .clone();
260
261 let raw_payload = extract_jwt_payload(id_token.to_string().as_str())
263 .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new()));
264
265 Ok(TokenResult {
266 claims,
267 raw_payload,
268 })
269 }
270
271 pub fn extract_user_info(&self, result: &TokenResult) -> UserInfo {
276 let subject = result.claims.subject().as_str().to_string();
277
278 let email = result.claims.email().map(|e| e.as_str().to_string());
279
280 let preferred_username = result
281 .claims
282 .preferred_username()
283 .map(|u| u.as_str().to_string());
284
285 let groups = self.extract_groups(&result.raw_payload);
287
288 let is_admin = self.check_group_membership(
290 &groups,
291 &result.raw_payload,
292 self.settings.admin_group_claim.as_deref(),
293 self.settings.admin_group_value.as_deref(),
294 );
295
296 let is_read_only = self.check_group_membership(
298 &groups,
299 &result.raw_payload,
300 self.settings.read_only_group_claim.as_deref(),
301 self.settings.read_only_group_value.as_deref(),
302 );
303
304 UserInfo {
305 subject,
306 email,
307 preferred_username,
308 groups,
309 is_admin,
310 is_read_only,
311 }
312 }
313
314 fn extract_groups(&self, payload: &serde_json::Value) -> Vec<String> {
316 if let Some(claim_name) = &self.settings.admin_group_claim
318 && let Some(groups) = get_string_array_from_json(payload, claim_name)
319 {
320 return groups;
321 }
322
323 if let Some(claim_name) = &self.settings.read_only_group_claim
325 && self.settings.admin_group_claim.as_ref() != Some(claim_name)
326 && let Some(groups) = get_string_array_from_json(payload, claim_name)
327 {
328 return groups;
329 }
330
331 for claim_name in &["groups", "roles", "group"] {
333 if let Some(groups) = get_string_array_from_json(payload, claim_name) {
334 return groups;
335 }
336 }
337
338 Vec::new()
339 }
340
341 #[allow(clippy::unused_self)]
343 fn check_group_membership(
344 &self,
345 groups: &[String],
346 payload: &serde_json::Value,
347 claim_name: Option<&str>,
348 claim_value: Option<&str>,
349 ) -> bool {
350 let (Some(claim_name), Some(claim_value)) = (claim_name, claim_value) else {
351 return false;
352 };
353
354 if groups.iter().any(|g| g == claim_value) {
356 return true;
357 }
358
359 if let Some(values) = get_string_array_from_json(payload, claim_name) {
361 return values.iter().any(|v| v == claim_value);
362 }
363
364 if let Some(value) = payload.get(claim_name)
366 && let Some(b) = value.as_bool()
367 {
368 return b && claim_value.eq_ignore_ascii_case("true");
370 }
371
372 false
373 }
374
375 pub fn generate_username(user_info: &UserInfo) -> String {
382 if let Some(username) = &user_info.preferred_username
383 && !username.is_empty()
384 {
385 return sanitize_username_with_dots(username);
386 }
387
388 if let Some(email) = &user_info.email
389 && let Some(local_part) = email.split('@').next()
390 && !local_part.is_empty()
391 {
392 return sanitize_username_with_dots(local_part);
393 }
394
395 sanitize_username(&user_info.subject)
396 }
397
398 pub fn issuer_url(&self) -> &str {
400 self.issuer_url.as_str()
401 }
402
403 pub fn settings(&self) -> &OAuth2Settings {
405 &self.settings
406 }
407}
408
409fn extract_jwt_payload(jwt: &str) -> Result<serde_json::Value, OAuth2Error> {
411 use base64::Engine;
412
413 let parts: Vec<&str> = jwt.split('.').collect();
414 if parts.len() != 3 {
415 return Err(OAuth2Error::TokenVerificationError(
416 "Invalid JWT format".to_string(),
417 ));
418 }
419
420 let payload_b64 = parts[1];
421 let payload_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
422 .decode(payload_b64)
423 .map_err(|e| OAuth2Error::TokenVerificationError(format!("Base64 decode error: {e}")))?;
424
425 serde_json::from_slice(&payload_bytes)
426 .map_err(|e| OAuth2Error::TokenVerificationError(format!("JSON parse error: {e}")))
427}
428
429fn get_string_array_from_json(payload: &serde_json::Value, name: &str) -> Option<Vec<String>> {
431 let value = payload.get(name)?;
432
433 if let Some(arr) = value.as_array() {
435 let strings: Vec<String> = arr
436 .iter()
437 .filter_map(serde_json::Value::as_str)
438 .map(String::from)
439 .collect();
440 if !strings.is_empty() {
441 return Some(strings);
442 }
443 }
444
445 if let Some(s) = value.as_str() {
447 return Some(vec![s.to_string()]);
448 }
449
450 None
451}
452
/// Sanitizes `input` into a username. `.` is replaced like any other
/// disallowed character; see [`sanitize_username_impl`] for the rules.
fn sanitize_username(input: &str) -> String {
    sanitize_username_impl(input, /* allow_dot */ false)
}

/// Like [`sanitize_username`], but keeps `.` characters.
fn sanitize_username_with_dots(input: &str) -> String {
    sanitize_username_impl(input, /* allow_dot */ true)
}

/// Shared username sanitizer.
///
/// 1. ASCII letters are lowercased. ASCII digits, `_`, `-` (and `.` when
///    `allow_dot`) are kept. Every other character becomes a single `_`.
/// 2. If the result would not start with an ASCII letter (including when
///    `input` is empty), `u_` is prepended.
/// 3. The result is capped at 64 characters. Every character is ASCII by
///    then, so this is also 64 bytes.
fn sanitize_username_impl(input: &str, allow_dot: bool) -> String {
    let is_allowed =
        |c: char| c.is_ascii_alphanumeric() || matches!(c, '_' | '-') || (allow_dot && c == '.');

    // Only a leading ASCII letter survives as a letter; anything else
    // becomes a digit, `_`, `-` or `.` and therefore needs the prefix.
    let needs_prefix = !input.starts_with(|c: char| c.is_ascii_alphabetic());

    let mut name = String::with_capacity(input.len() + 2);
    if needs_prefix {
        name.push_str("u_");
    }
    for c in input.chars() {
        name.push(if is_allowed(c) { c.to_ascii_lowercase() } else { '_' });
    }

    name.truncate(64);
    name
}
495
496pub async fn generate_unique_username<F, Fut>(user_info: &UserInfo, is_available: F) -> String
500where
501 F: Fn(String) -> Fut,
502 Fut: Future<Output = bool>,
503{
504 let base = OAuth2Handler::generate_username(user_info);
505
506 if is_available(base.clone()).await {
508 return base;
509 }
510
511 for i in 2..=100 {
513 let candidate = format!("{base}_{i}");
514 if is_available(candidate.clone()).await {
515 return candidate;
516 }
517 }
518
519 warn!("Could not find unique username after 100 attempts, using fallback");
521 format!(
522 "{}_{:x}",
523 sanitize_username(&user_info.subject),
524 std::time::SystemTime::now()
525 .duration_since(std::time::UNIX_EPOCH)
526 .map(|d| d.as_secs())
527 .unwrap_or(0)
528 )
529}
530
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`UserInfo`] with no groups and no role flags.
    fn user(subject: &str, email: Option<&str>, preferred_username: Option<&str>) -> UserInfo {
        UserInfo {
            subject: subject.to_owned(),
            email: email.map(str::to_owned),
            preferred_username: preferred_username.map(str::to_owned),
            groups: Vec::new(),
            is_admin: false,
            is_read_only: false,
        }
    }

    /// Converts borrowed strings into the owned form returned by the claim helpers.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| (*s).to_owned()).collect()
    }

    #[test]
    fn test_sanitize_username() {
        for (input, expected) in [
            ("JohnDoe", "johndoe"),
            ("john.doe", "john_doe"),
            ("john@example.com", "john_example_com"),
            ("123user", "u_123user"),
            ("_user", "u__user"),
            ("user-name", "user-name"),
        ] {
            assert_eq!(sanitize_username(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn test_sanitize_username_with_dots() {
        for (input, expected) in [
            ("john.doe", "john.doe"),
            ("john@example.com", "john_example.com"),
        ] {
            assert_eq!(sanitize_username_with_dots(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn test_generate_username_preferred() {
        let info = user("sub123", Some("john@example.com"), Some("johndoe"));
        assert_eq!(OAuth2Handler::generate_username(&info), "johndoe");
    }

    #[test]
    fn test_generate_username_preferred_preserves_dot() {
        let info = user("sub123", Some("john@example.com"), Some("john.doe"));
        assert_eq!(OAuth2Handler::generate_username(&info), "john.doe");
    }

    #[test]
    fn test_generate_username_email() {
        let info = user("sub123", Some("john@example.com"), None);
        assert_eq!(OAuth2Handler::generate_username(&info), "john");
    }

    #[test]
    fn test_generate_username_email_preserves_dot_in_local_part() {
        let info = user("sub123", Some("john.doe@example.com"), None);
        assert_eq!(OAuth2Handler::generate_username(&info), "john.doe");
    }

    #[test]
    fn test_generate_username_subject() {
        let info = user("sub123", None, None);
        assert_eq!(OAuth2Handler::generate_username(&info), "sub123");
    }

    #[tokio::test]
    async fn test_generate_unique_username() {
        let info = user("sub123", Some("john@example.com"), Some("johndoe"));

        let everything_free = generate_unique_username(&info, |_| async { true }).await;
        assert_eq!(everything_free, "johndoe");

        let base_taken =
            generate_unique_username(&info, |name| async move { name != "johndoe" }).await;
        assert_eq!(base_taken, "johndoe_2");

        let two_taken = generate_unique_username(&info, |name| async move {
            !matches!(name.as_str(), "johndoe" | "johndoe_2")
        })
        .await;
        assert_eq!(two_taken, "johndoe_3");
    }

    #[test]
    fn test_extract_jwt_payload() {
        // Widely used HS256 sample token; only the payload segment is inspected.
        let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";
        let payload = extract_jwt_payload(jwt).unwrap();

        for (key, expected) in [("sub", "1234567890"), ("name", "John Doe")] {
            assert_eq!(
                payload.get(key).and_then(serde_json::Value::as_str),
                Some(expected),
                "claim: {key}"
            );
        }
    }

    #[test]
    fn test_get_string_array_from_json() {
        let payload = serde_json::json!({
            "groups": ["admin", "users"],
            "single_group": "single",
            "number": 42
        });

        assert_eq!(
            get_string_array_from_json(&payload, "groups"),
            Some(strings(&["admin", "users"]))
        );
        assert_eq!(
            get_string_array_from_json(&payload, "single_group"),
            Some(strings(&["single"]))
        );
        for absent_or_wrong_type in ["number", "missing"] {
            assert_eq!(
                get_string_array_from_json(&payload, absent_or_wrong_type),
                None,
                "claim: {absent_or_wrong_type}"
            );
        }
    }
}