Skip to main content

kellnr_auth/
oauth2.rs

1//! OAuth2/OpenID Connect authentication handler
2//!
3//! This module provides OAuth2/OIDC authentication support using the authorization
4//! code flow with PKCE.
5
6use std::future::Future;
7use std::sync::Arc;
8
9use kellnr_settings::OAuth2 as OAuth2Settings;
10use openidconnect::core::{
11    CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreClient, CoreErrorResponseType,
12    CoreGenderClaim, CoreIdTokenClaims, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm,
13    CoreProviderMetadata, CoreRevocationErrorResponse, CoreTokenIntrospectionResponse,
14    CoreTokenResponse,
15};
16use openidconnect::{
17    AuthorizationCode, ClientId, ClientSecret, CsrfToken, EmptyAdditionalClaims, EndpointMaybeSet,
18    EndpointNotSet, EndpointSet, IssuerUrl, Nonce, PkceCodeChallenge, PkceCodeVerifier,
19    RedirectUrl, Scope, TokenResponse, reqwest,
20};
21use serde::{Deserialize, Serialize};
22use thiserror::Error;
23use tracing::{info, warn};
24use url::Url;
25
26/// Type alias for the configured OIDC client after provider discovery
27type ConfiguredCoreClient = openidconnect::Client<
28    EmptyAdditionalClaims,
29    CoreAuthDisplay,
30    CoreGenderClaim,
31    CoreJweContentEncryptionAlgorithm,
32    CoreJsonWebKey,
33    CoreAuthPrompt,
34    openidconnect::StandardErrorResponse<CoreErrorResponseType>,
35    CoreTokenResponse,
36    CoreTokenIntrospectionResponse,
37    openidconnect::core::CoreRevocableToken,
38    CoreRevocationErrorResponse,
39    EndpointSet,      // HasAuthUrl - set by from_provider_metadata
40    EndpointNotSet,   // HasDeviceAuthUrl
41    EndpointNotSet,   // HasIntrospectionUrl
42    EndpointNotSet,   // HasRevocationUrl
43    EndpointMaybeSet, // HasTokenUrl - maybe set by provider
44    EndpointMaybeSet, // HasUserInfoUrl - maybe set by provider
45>;
46
47/// Errors that can occur during OAuth2/OIDC operations
48#[derive(Debug, Error)]
49pub enum OAuth2Error {
50    #[error("OAuth2 is not enabled")]
51    NotEnabled,
52
53    #[error("OAuth2 configuration is invalid: {0}")]
54    ConfigurationError(String),
55
56    #[error("Failed to discover OIDC provider: {0}")]
57    DiscoveryError(String),
58
59    #[error("Failed to parse URL: {0}")]
60    UrlParseError(#[from] url::ParseError),
61
62    #[error("Failed to exchange authorization code: {0}")]
63    TokenExchangeError(String),
64
65    #[error("Failed to verify ID token: {0}")]
66    TokenVerificationError(String),
67
68    #[error("Missing required claim: {0}")]
69    MissingClaim(String),
70
71    #[error("HTTP request failed: {0}")]
72    HttpError(String),
73}
74
75/// Information extracted from the OIDC ID token
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct UserInfo {
78    /// The subject claim (unique user identifier from the provider)
79    pub subject: String,
80    /// Email address (if available)
81    pub email: Option<String>,
82    /// Preferred username (if available)
83    pub preferred_username: Option<String>,
84    /// Groups the user belongs to (from the configured claim)
85    pub groups: Vec<String>,
86    /// Whether the user should be an admin (derived from group claims)
87    pub is_admin: bool,
88    /// Whether the user should be read-only (derived from group claims)
89    pub is_read_only: bool,
90}
91
92/// Request data for initiating `OAuth2` authentication
93#[derive(Debug)]
94pub struct AuthRequest {
95    /// The URL to redirect the user to for authentication
96    pub auth_url: Url,
97    /// CSRF protection state (to be stored in database)
98    pub state: String,
99    /// PKCE verifier (to be stored in database for code exchange)
100    pub pkce_verifier: String,
101    /// Nonce for ID token verification (to be stored in database)
102    pub nonce: String,
103}
104
105/// Token response from the `OAuth2` provider
106#[derive(Debug)]
107pub struct TokenResult {
108    /// The standard ID token claims
109    pub claims: CoreIdTokenClaims,
110    /// Raw JWT payload for extracting additional claims
111    pub raw_payload: serde_json::Value,
112}
113
114/// OAuth2/OIDC authentication handler
115pub struct OAuth2Handler {
116    client: ConfiguredCoreClient,
117    settings: Arc<OAuth2Settings>,
118    issuer_url: IssuerUrl,
119    http_client: reqwest::Client,
120}
121
122impl OAuth2Handler {
123    /// Create a new `OAuth2Handler` using OIDC discovery
124    ///
125    /// This performs automatic discovery of the OIDC provider's configuration
126    /// using the well-known endpoint.
127    pub async fn from_discovery(
128        settings: &OAuth2Settings,
129        redirect_url: &str,
130    ) -> Result<Self, OAuth2Error> {
131        if !settings.enabled {
132            return Err(OAuth2Error::NotEnabled);
133        }
134
135        // Validate configuration
136        settings
137            .validate()
138            .map_err(OAuth2Error::ConfigurationError)?;
139
140        let issuer_url_str = settings
141            .issuer_url
142            .as_ref()
143            .ok_or_else(|| OAuth2Error::ConfigurationError("Missing issuer_url".to_string()))?;
144
145        let issuer_url = IssuerUrl::new(issuer_url_str.clone())
146            .map_err(|e| OAuth2Error::ConfigurationError(format!("Invalid issuer URL: {e}")))?;
147
148        info!("Discovering OIDC provider at: {}", issuer_url_str);
149
150        // Create HTTP client
151        let http_client = reqwest::ClientBuilder::new()
152            .redirect(reqwest::redirect::Policy::none())
153            .build()
154            .map_err(|e| OAuth2Error::HttpError(e.to_string()))?;
155
156        // Perform OIDC discovery
157        let provider_metadata =
158            CoreProviderMetadata::discover_async(issuer_url.clone(), &http_client)
159                .await
160                .map_err(|e| OAuth2Error::DiscoveryError(e.to_string()))?;
161
162        let client_id = ClientId::new(
163            settings
164                .client_id
165                .clone()
166                .ok_or_else(|| OAuth2Error::ConfigurationError("Missing client_id".to_string()))?,
167        );
168
169        let client_secret = settings.client_secret.clone().map(ClientSecret::new);
170
171        let redirect_url = RedirectUrl::new(redirect_url.to_string())?;
172
173        // Build the client from provider metadata
174        let client =
175            CoreClient::from_provider_metadata(provider_metadata, client_id, client_secret)
176                .set_redirect_uri(redirect_url);
177
178        Ok(Self {
179            client,
180            settings: Arc::new(settings.clone()),
181            issuer_url,
182            http_client,
183        })
184    }
185
186    /// Generate an authorization URL for the `OAuth2` flow
187    ///
188    /// Returns an `AuthRequest` containing the URL to redirect the user to,
189    /// along with the state, PKCE verifier, and nonce that must be stored
190    /// for later verification.
191    pub fn generate_auth_url(&self) -> AuthRequest {
192        let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
193
194        // Build the authorization request with configured scopes
195        let mut auth_request = self
196            .client
197            .authorize_url(
198                CoreAuthenticationFlow::AuthorizationCode,
199                CsrfToken::new_random,
200                Nonce::new_random,
201            )
202            .set_pkce_challenge(pkce_challenge);
203
204        // Add configured scopes
205        for scope in &self.settings.scopes {
206            auth_request = auth_request.add_scope(Scope::new(scope.clone()));
207        }
208
209        let (auth_url, csrf_state, nonce) = auth_request.url();
210
211        AuthRequest {
212            auth_url,
213            state: csrf_state.secret().clone(),
214            pkce_verifier: pkce_verifier.secret().clone(),
215            nonce: nonce.secret().clone(),
216        }
217    }
218
219    /// Exchange an authorization code for tokens and validate the ID token
220    ///
221    /// # Arguments
222    /// * `code` - The authorization code received from the provider
223    /// * `pkce_verifier` - The PKCE verifier stored during `generate_auth_url`
224    /// * `nonce` - The nonce stored during `generate_auth_url`
225    ///
226    /// # Returns
227    /// The validated ID token claims
228    pub async fn exchange_and_validate(
229        &self,
230        code: &str,
231        pkce_verifier: &str,
232        nonce: &str,
233    ) -> Result<TokenResult, OAuth2Error> {
234        let code = AuthorizationCode::new(code.to_string());
235        let verifier = PkceCodeVerifier::new(pkce_verifier.to_string());
236
237        let token_request = self
238            .client
239            .exchange_code(code)
240            .map_err(|e| OAuth2Error::TokenExchangeError(e.to_string()))?;
241
242        let token_response: CoreTokenResponse = token_request
243            .set_pkce_verifier(verifier)
244            .request_async(&self.http_client)
245            .await
246            .map_err(|e| OAuth2Error::TokenExchangeError(e.to_string()))?;
247
248        // Get and validate the ID token
249        let id_token = token_response
250            .id_token()
251            .ok_or_else(|| OAuth2Error::MissingClaim("id_token".to_string()))?;
252
253        let nonce = Nonce::new(nonce.to_string());
254        let verifier = self.client.id_token_verifier();
255
256        let claims = id_token
257            .claims(&verifier, &nonce)
258            .map_err(|e| OAuth2Error::TokenVerificationError(e.to_string()))?
259            .clone();
260
261        // Also extract raw payload for additional claims
262        let raw_payload = extract_jwt_payload(id_token.to_string().as_str())
263            .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new()));
264
265        Ok(TokenResult {
266            claims,
267            raw_payload,
268        })
269    }
270
271    /// Extract user information from ID token claims
272    ///
273    /// This extracts the subject, email, `preferred_username`, and group membership
274    /// based on the `OAuth2` settings configuration.
275    pub fn extract_user_info(&self, result: &TokenResult) -> UserInfo {
276        let subject = result.claims.subject().as_str().to_string();
277
278        let email = result.claims.email().map(|e| e.as_str().to_string());
279
280        let preferred_username = result
281            .claims
282            .preferred_username()
283            .map(|u| u.as_str().to_string());
284
285        // Extract groups from raw JWT payload
286        let groups = self.extract_groups(&result.raw_payload);
287
288        // Determine admin status from group claims
289        let is_admin = self.check_group_membership(
290            &groups,
291            &result.raw_payload,
292            self.settings.admin_group_claim.as_deref(),
293            self.settings.admin_group_value.as_deref(),
294        );
295
296        // Determine read-only status from group claims
297        let is_read_only = self.check_group_membership(
298            &groups,
299            &result.raw_payload,
300            self.settings.read_only_group_claim.as_deref(),
301            self.settings.read_only_group_value.as_deref(),
302        );
303
304        UserInfo {
305            subject,
306            email,
307            preferred_username,
308            groups,
309            is_admin,
310            is_read_only,
311        }
312    }
313
314    /// Extract groups from the raw JWT payload
315    fn extract_groups(&self, payload: &serde_json::Value) -> Vec<String> {
316        // First try the configured admin group claim if it exists
317        if let Some(claim_name) = &self.settings.admin_group_claim
318            && let Some(groups) = get_string_array_from_json(payload, claim_name)
319        {
320            return groups;
321        }
322
323        // Then try the configured read-only group claim if different
324        if let Some(claim_name) = &self.settings.read_only_group_claim
325            && self.settings.admin_group_claim.as_ref() != Some(claim_name)
326            && let Some(groups) = get_string_array_from_json(payload, claim_name)
327        {
328            return groups;
329        }
330
331        // Try common group claim names
332        for claim_name in &["groups", "roles", "group"] {
333            if let Some(groups) = get_string_array_from_json(payload, claim_name) {
334                return groups;
335            }
336        }
337
338        Vec::new()
339    }
340
341    /// Check if the user belongs to a specific group based on claims
342    #[allow(clippy::unused_self)]
343    fn check_group_membership(
344        &self,
345        groups: &[String],
346        payload: &serde_json::Value,
347        claim_name: Option<&str>,
348        claim_value: Option<&str>,
349    ) -> bool {
350        let (Some(claim_name), Some(claim_value)) = (claim_name, claim_value) else {
351            return false;
352        };
353
354        // First check in the extracted groups
355        if groups.iter().any(|g| g == claim_value) {
356            return true;
357        }
358
359        // Also check directly in the claim (in case it's a different claim than groups)
360        if let Some(values) = get_string_array_from_json(payload, claim_name) {
361            return values.iter().any(|v| v == claim_value);
362        }
363
364        // Check if it's a boolean claim
365        if let Some(value) = payload.get(claim_name)
366            && let Some(b) = value.as_bool()
367        {
368            // If the claim is a boolean and we're checking for "true"
369            return b && claim_value.eq_ignore_ascii_case("true");
370        }
371
372        false
373    }
374
375    /// Generate a unique username for auto-provisioning
376    ///
377    /// Priority:
378    /// 1. `preferred_username` claim
379    /// 2. Local part of email (before @)
380    /// 3. Subject claim
381    pub fn generate_username(user_info: &UserInfo) -> String {
382        if let Some(username) = &user_info.preferred_username
383            && !username.is_empty()
384        {
385            return sanitize_username_with_dots(username);
386        }
387
388        if let Some(email) = &user_info.email
389            && let Some(local_part) = email.split('@').next()
390            && !local_part.is_empty()
391        {
392            return sanitize_username_with_dots(local_part);
393        }
394
395        sanitize_username(&user_info.subject)
396    }
397
398    /// Get the issuer URL string
399    pub fn issuer_url(&self) -> &str {
400        self.issuer_url.as_str()
401    }
402
403    /// Get a reference to the settings
404    pub fn settings(&self) -> &OAuth2Settings {
405        &self.settings
406    }
407}
408
409/// Extract and decode the payload from a JWT
410fn extract_jwt_payload(jwt: &str) -> Result<serde_json::Value, OAuth2Error> {
411    use base64::Engine;
412
413    let parts: Vec<&str> = jwt.split('.').collect();
414    if parts.len() != 3 {
415        return Err(OAuth2Error::TokenVerificationError(
416            "Invalid JWT format".to_string(),
417        ));
418    }
419
420    let payload_b64 = parts[1];
421    let payload_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
422        .decode(payload_b64)
423        .map_err(|e| OAuth2Error::TokenVerificationError(format!("Base64 decode error: {e}")))?;
424
425    serde_json::from_slice(&payload_bytes)
426        .map_err(|e| OAuth2Error::TokenVerificationError(format!("JSON parse error: {e}")))
427}
428
429/// Get a string array from a JSON value
430fn get_string_array_from_json(payload: &serde_json::Value, name: &str) -> Option<Vec<String>> {
431    let value = payload.get(name)?;
432
433    // Try to parse as array of strings
434    if let Some(arr) = value.as_array() {
435        let strings: Vec<String> = arr
436            .iter()
437            .filter_map(serde_json::Value::as_str)
438            .map(String::from)
439            .collect();
440        if !strings.is_empty() {
441            return Some(strings);
442        }
443    }
444
445    // Try to parse as single string (some providers return single group as string)
446    if let Some(s) = value.as_str() {
447        return Some(vec![s.to_string()]);
448    }
449
450    None
451}
452
453/// Sanitize a username to be valid for Kellnr
454///
455/// - Converts to lowercase
456/// - Replaces invalid characters with underscores
457/// - Ensures it starts with a letter
458fn sanitize_username(input: &str) -> String {
459    sanitize_username_impl(input, false)
460}
461
462/// Same as `sanitize_username`, but preserves dots.
463fn sanitize_username_with_dots(input: &str) -> String {
464    sanitize_username_impl(input, true)
465}
466
467fn sanitize_username_impl(input: &str, allow_dot: bool) -> String {
468    let mut result: String = input
469        .chars()
470        .map(|c| {
471            if c.is_ascii_alphanumeric() || c == '_' || c == '-' || (allow_dot && c == '.') {
472                c.to_ascii_lowercase()
473            } else {
474                '_'
475            }
476        })
477        .collect();
478
479    // Ensure it starts with a letter
480    if result
481        .chars()
482        .next()
483        .is_none_or(|c| !c.is_ascii_alphabetic())
484    {
485        result = format!("u_{result}");
486    }
487
488    // Truncate if too long (max 64 chars is reasonable)
489    if result.len() > 64 {
490        result.truncate(64);
491    }
492
493    result
494}
495
496/// Generate a unique username with collision handling
497///
498/// If the base username is taken, appends _2, _3, etc.
499pub async fn generate_unique_username<F, Fut>(user_info: &UserInfo, is_available: F) -> String
500where
501    F: Fn(String) -> Fut,
502    Fut: Future<Output = bool>,
503{
504    let base = OAuth2Handler::generate_username(user_info);
505
506    // Try the base username first
507    if is_available(base.clone()).await {
508        return base;
509    }
510
511    // Try with numeric suffixes
512    for i in 2..=100 {
513        let candidate = format!("{base}_{i}");
514        if is_available(candidate.clone()).await {
515            return candidate;
516        }
517    }
518
519    // Fallback: use subject with timestamp (very unlikely to reach here)
520    warn!("Could not find unique username after 100 attempts, using fallback");
521    format!(
522        "{}_{:x}",
523        sanitize_username(&user_info.subject),
524        std::time::SystemTime::now()
525            .duration_since(std::time::UNIX_EPOCH)
526            .map(|d| d.as_secs())
527            .unwrap_or(0)
528    )
529}
530
531#[cfg(test)]
532mod tests {
533    use super::*;
534
535    #[test]
536    fn test_sanitize_username() {
537        assert_eq!(sanitize_username("JohnDoe"), "johndoe");
538        assert_eq!(sanitize_username("john.doe"), "john_doe");
539        assert_eq!(sanitize_username("john@example.com"), "john_example_com");
540        assert_eq!(sanitize_username("123user"), "u_123user");
541        assert_eq!(sanitize_username("_user"), "u__user");
542        assert_eq!(sanitize_username("user-name"), "user-name");
543    }
544
545    #[test]
546    fn test_sanitize_username_with_dots() {
547        assert_eq!(sanitize_username_with_dots("john.doe"), "john.doe");
548        assert_eq!(
549            sanitize_username_with_dots("john@example.com"),
550            "john_example.com"
551        );
552    }
553
554    #[test]
555    fn test_generate_username_preferred() {
556        let user_info = UserInfo {
557            subject: "sub123".to_string(),
558            email: Some("john@example.com".to_string()),
559            preferred_username: Some("johndoe".to_string()),
560            groups: vec![],
561            is_admin: false,
562            is_read_only: false,
563        };
564        assert_eq!(OAuth2Handler::generate_username(&user_info), "johndoe");
565    }
566
567    #[test]
568    fn test_generate_username_preferred_preserves_dot() {
569        let user_info = UserInfo {
570            subject: "sub123".to_string(),
571            email: Some("john@example.com".to_string()),
572            preferred_username: Some("john.doe".to_string()),
573            groups: vec![],
574            is_admin: false,
575            is_read_only: false,
576        };
577        assert_eq!(OAuth2Handler::generate_username(&user_info), "john.doe");
578    }
579
580    #[test]
581    fn test_generate_username_email() {
582        let user_info = UserInfo {
583            subject: "sub123".to_string(),
584            email: Some("john@example.com".to_string()),
585            preferred_username: None,
586            groups: vec![],
587            is_admin: false,
588            is_read_only: false,
589        };
590        assert_eq!(OAuth2Handler::generate_username(&user_info), "john");
591    }
592
593    #[test]
594    fn test_generate_username_email_preserves_dot_in_local_part() {
595        let user_info = UserInfo {
596            subject: "sub123".to_string(),
597            email: Some("john.doe@example.com".to_string()),
598            preferred_username: None,
599            groups: vec![],
600            is_admin: false,
601            is_read_only: false,
602        };
603        assert_eq!(OAuth2Handler::generate_username(&user_info), "john.doe");
604    }
605
606    #[test]
607    fn test_generate_username_subject() {
608        let user_info = UserInfo {
609            subject: "sub123".to_string(),
610            email: None,
611            preferred_username: None,
612            groups: vec![],
613            is_admin: false,
614            is_read_only: false,
615        };
616        // "sub123" starts with 's' which is a letter, so no prefix needed
617        assert_eq!(OAuth2Handler::generate_username(&user_info), "sub123");
618    }
619
620    #[tokio::test]
621    async fn test_generate_unique_username() {
622        let user_info = UserInfo {
623            subject: "sub123".to_string(),
624            email: Some("john@example.com".to_string()),
625            preferred_username: Some("johndoe".to_string()),
626            groups: vec![],
627            is_admin: false,
628            is_read_only: false,
629        };
630
631        // First username is available
632        let username = generate_unique_username(&user_info, |_| async { true }).await;
633        assert_eq!(username, "johndoe");
634
635        // First username is taken, second is available
636        let username =
637            generate_unique_username(&user_info, |name| async move { name != "johndoe" }).await;
638        assert_eq!(username, "johndoe_2");
639
640        // First two are taken
641        let username = generate_unique_username(&user_info, |name| async move {
642            name != "johndoe" && name != "johndoe_2"
643        })
644        .await;
645        assert_eq!(username, "johndoe_3");
646    }
647
648    #[test]
649    fn test_extract_jwt_payload() {
650        // A test JWT with payload: {"sub":"1234567890","name":"John Doe","iat":1516239022}
651        let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";
652        let payload = extract_jwt_payload(jwt).unwrap();
653        assert_eq!(
654            payload.get("sub").and_then(|v| v.as_str()),
655            Some("1234567890")
656        );
657        assert_eq!(
658            payload.get("name").and_then(|v| v.as_str()),
659            Some("John Doe")
660        );
661    }
662
663    #[test]
664    fn test_get_string_array_from_json() {
665        let payload = serde_json::json!({
666            "groups": ["admin", "users"],
667            "single_group": "single",
668            "number": 42
669        });
670
671        assert_eq!(
672            get_string_array_from_json(&payload, "groups"),
673            Some(vec!["admin".to_string(), "users".to_string()])
674        );
675        assert_eq!(
676            get_string_array_from_json(&payload, "single_group"),
677            Some(vec!["single".to_string()])
678        );
679        assert_eq!(get_string_array_from_json(&payload, "number"), None);
680        assert_eq!(get_string_array_from_json(&payload, "missing"), None);
681    }
682}