Skip to main content

fastmcp_server/
oidc.rs

1//! OpenID Connect (OIDC) Provider for MCP.
2//!
3//! This module extends the OAuth 2.0/2.1 server with OpenID Connect identity
4//! layer features:
5//!
6//! - **ID Token Issuance**: JWT tokens containing user identity claims
7//! - **UserInfo Endpoint**: Standard endpoint for retrieving user claims
8//! - **Discovery Document**: `.well-known/openid-configuration` metadata
9//! - **Standard Claims**: OpenID Connect standard claim types
10//!
11//! # Architecture
12//!
13//! The OIDC provider builds on top of [`OAuthServer`] by:
14//!
15//! 1. Adding the `openid` scope to enable OIDC flows
16//! 2. Issuing ID tokens alongside access tokens
17//! 3. Providing standard endpoints for identity operations
18//!
19//! # Example
20//!
21//! ```ignore
//! use std::sync::Arc;
//!
//! use fastmcp::oidc::{OidcProvider, OidcProviderConfig, UserClaims};
//! use fastmcp::oauth::{OAuthServer, OAuthServerConfig};
//!
//! // Create the OAuth server first.
//! let oauth = Arc::new(OAuthServer::new(OAuthServerConfig::default()));
//!
//! // Layer the OIDC provider on top of it.
//! let oidc = OidcProvider::new(oauth, OidcProviderConfig::default());
//!
//! // Register a closure as the claims provider. `set_claims_fn` wraps it in a
//! // `FnClaimsProvider`; returning `None` means "unknown subject".
//! oidc.set_claims_fn(|subject| {
//!     Some(
//!         UserClaims::new(subject)
//!             .with_name("John Doe")
//!             .with_email("john@example.com"),
//!     )
//! });
37//! ```
38
39use std::collections::HashMap;
40use std::sync::{Arc, RwLock};
41use std::time::{Duration, SystemTime, UNIX_EPOCH};
42
43use crate::oauth::{OAuthError, OAuthServer, OAuthToken};
44
45// =============================================================================
46// Configuration
47// =============================================================================
48
/// Settings that control how the OIDC provider issues and advertises ID tokens.
#[derive(Debug, Clone)]
pub struct OidcProviderConfig {
    /// Issuer identifier placed in the `iss` claim and the discovery document.
    ///
    /// Should match the issuer used by the underlying OAuth server.
    pub issuer: String,
    /// How long an issued ID token stays valid; added to "now" to form `exp`.
    pub id_token_lifetime: Duration,
    /// Algorithm named in the JWT `alg` header and advertised via discovery.
    pub signing_algorithm: SigningAlgorithm,
    /// Optional `kid` header value for issued tokens.
    pub key_id: Option<String>,
    /// Claim names advertised as `claims_supported` in the discovery document.
    pub supported_claims: Vec<String>,
    /// Scope names advertised as `scopes_supported` in the discovery document.
    ///
    /// This list replaces the discovery defaults wholesale, so it should itself
    /// contain `openid` (the default value does).
    pub supported_scopes: Vec<String>,
}

impl Default for OidcProviderConfig {
    fn default() -> Self {
        // Converts a list of string literals into owned `String`s.
        fn owned(names: &[&str]) -> Vec<String> {
            names.iter().map(|name| (*name).to_string()).collect()
        }

        Self {
            issuer: String::from("fastmcp"),
            // One hour.
            id_token_lifetime: Duration::from_secs(60 * 60),
            signing_algorithm: SigningAlgorithm::HS256,
            key_id: None,
            supported_claims: owned(&[
                "sub",
                "name",
                "email",
                "email_verified",
                "preferred_username",
                "picture",
                "updated_at",
            ]),
            supported_scopes: owned(&["openid", "profile", "email"]),
        }
    }
}

/// Algorithms that may be named in an ID token's JWT header.
///
/// NOTE(review): the provider in this file only stores HMAC keys; `RS256` has
/// no corresponding key type yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SigningAlgorithm {
    /// HMAC with SHA-256 using a shared secret (symmetric).
    HS256,
    /// RSA signature with SHA-256 (asymmetric); needs an RSA key pair.
    RS256,
}

impl SigningAlgorithm {
    /// Returns the JWS `alg` identifier for this algorithm, e.g. `"HS256"`.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            SigningAlgorithm::HS256 => "HS256",
            SigningAlgorithm::RS256 => "RS256",
        }
    }
}
110
111// =============================================================================
112// User Claims
113// =============================================================================
114
/// Standard OpenID Connect user claims.
///
/// Describes a single end user. Values of this type are returned by a
/// [`ClaimsProvider`], returned from the userinfo handler, and flattened into
/// [`IdTokenClaims`] when an ID token is issued.
///
/// Only `sub` is mandatory. Every `Option` field is omitted from the serialized
/// JSON when it is `None`. The field groups below correspond to the scopes that
/// [`UserClaims::filter_by_scopes`] checks: `profile`, `email`, `phone` and
/// `address`.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct UserClaims {
    /// Subject identifier (required, unique user ID). Always serialized.
    pub sub: String,

    // Claims released by the `profile` scope.
    /// User's full name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// User's given/first name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub given_name: Option<String>,
    /// User's family/last name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub family_name: Option<String>,
    /// User's middle name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub middle_name: Option<String>,
    /// User's nickname (casual name).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nickname: Option<String>,
    /// User's preferred username.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub preferred_username: Option<String>,
    /// URL of the user's profile page.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub profile: Option<String>,
    /// URL of the user's profile picture.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub picture: Option<String>,
    /// URL of the user's website.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub website: Option<String>,
    /// User's gender.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gender: Option<String>,
    /// User's birthday (ISO 8601 date string).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub birthdate: Option<String>,
    /// User's time zone (IANA time zone name).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub zoneinfo: Option<String>,
    /// User's locale (BCP 47 language tag).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub locale: Option<String>,
    /// Time the user's info was last updated (Unix timestamp).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub updated_at: Option<i64>,

    // Claims released by the `email` scope.
    /// User's email address.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub email: Option<String>,
    /// Whether the email address has been verified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub email_verified: Option<bool>,

    // Claims released by the `phone` scope.
    /// User's phone number.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub phone_number: Option<String>,
    /// Whether the phone number has been verified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub phone_number_verified: Option<bool>,

    // Claims released by the `address` scope.
    /// User's postal address (serialized as a nested JSON object).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub address: Option<AddressClaim>,

    /// Additional custom claims.
    ///
    /// `#[serde(flatten)]` writes these as top-level members of the JSON object.
    /// On deserialization, unrecognized keys are collected here.
    /// `filter_by_scopes` never copies this map, so custom claims do not reach
    /// scope-filtered output.
    #[serde(flatten)]
    pub custom: HashMap<String, serde_json::Value>,
}
193
194impl UserClaims {
195    /// Creates new user claims with the given subject.
196    #[must_use]
197    pub fn new(sub: impl Into<String>) -> Self {
198        Self {
199            sub: sub.into(),
200            ..Default::default()
201        }
202    }
203
204    /// Sets the user's full name.
205    #[must_use]
206    pub fn with_name(mut self, name: impl Into<String>) -> Self {
207        self.name = Some(name.into());
208        self
209    }
210
211    /// Sets the user's email.
212    #[must_use]
213    pub fn with_email(mut self, email: impl Into<String>) -> Self {
214        self.email = Some(email.into());
215        self
216    }
217
218    /// Sets whether the email is verified.
219    #[must_use]
220    pub fn with_email_verified(mut self, verified: bool) -> Self {
221        self.email_verified = Some(verified);
222        self
223    }
224
225    /// Sets the user's preferred username.
226    #[must_use]
227    pub fn with_preferred_username(mut self, username: impl Into<String>) -> Self {
228        self.preferred_username = Some(username.into());
229        self
230    }
231
232    /// Sets the user's profile picture URL.
233    #[must_use]
234    pub fn with_picture(mut self, url: impl Into<String>) -> Self {
235        self.picture = Some(url.into());
236        self
237    }
238
239    /// Sets the user's given name.
240    #[must_use]
241    pub fn with_given_name(mut self, name: impl Into<String>) -> Self {
242        self.given_name = Some(name.into());
243        self
244    }
245
246    /// Sets the user's family name.
247    #[must_use]
248    pub fn with_family_name(mut self, name: impl Into<String>) -> Self {
249        self.family_name = Some(name.into());
250        self
251    }
252
253    /// Sets the user's phone number.
254    #[must_use]
255    pub fn with_phone_number(mut self, phone: impl Into<String>) -> Self {
256        self.phone_number = Some(phone.into());
257        self
258    }
259
260    /// Sets the updated_at timestamp.
261    #[must_use]
262    pub fn with_updated_at(mut self, timestamp: i64) -> Self {
263        self.updated_at = Some(timestamp);
264        self
265    }
266
267    /// Adds a custom claim.
268    #[must_use]
269    pub fn with_custom(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
270        self.custom.insert(key.into(), value);
271        self
272    }
273
274    /// Filters claims based on requested scopes.
275    ///
276    /// Only returns claims that are allowed by the given scopes.
277    #[must_use]
278    #[allow(clippy::assigning_clones)]
279    pub fn filter_by_scopes(&self, scopes: &[String]) -> UserClaims {
280        let mut filtered = UserClaims::new(&self.sub);
281
282        // Profile scope claims
283        if scopes.iter().any(|s| s == "profile") {
284            filtered.name = self.name.clone();
285            filtered.given_name = self.given_name.clone();
286            filtered.family_name = self.family_name.clone();
287            filtered.middle_name = self.middle_name.clone();
288            filtered.nickname = self.nickname.clone();
289            filtered.preferred_username = self.preferred_username.clone();
290            filtered.profile = self.profile.clone();
291            filtered.picture = self.picture.clone();
292            filtered.website = self.website.clone();
293            filtered.gender = self.gender.clone();
294            filtered.birthdate = self.birthdate.clone();
295            filtered.zoneinfo = self.zoneinfo.clone();
296            filtered.locale = self.locale.clone();
297            filtered.updated_at = self.updated_at;
298        }
299
300        // Email scope claims
301        if scopes.iter().any(|s| s == "email") {
302            filtered.email = self.email.clone();
303            filtered.email_verified = self.email_verified;
304        }
305
306        // Phone scope claims
307        if scopes.iter().any(|s| s == "phone") {
308            filtered.phone_number = self.phone_number.clone();
309            filtered.phone_number_verified = self.phone_number_verified;
310        }
311
312        // Address scope claims
313        if scopes.iter().any(|s| s == "address") {
314            filtered.address = self.address.clone();
315        }
316
317        filtered
318    }
319}
320
/// Address claim structure per the OpenID Connect spec.
///
/// Appears as the nested `address` object of [`UserClaims`]. Every field is
/// optional and omitted from the JSON when `None`.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct AddressClaim {
    /// Full mailing address, formatted for display.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub formatted: Option<String>,
    /// Street address component.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub street_address: Option<String>,
    /// City or locality.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub locality: Option<String>,
    /// State, province or region.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub region: Option<String>,
    /// Postal or zip code.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub postal_code: Option<String>,
    /// Country name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub country: Option<String>,
}
343
344// =============================================================================
345// ID Token
346// =============================================================================
347
/// ID Token claims (the JWT payload).
///
/// The registered JWT and OIDC claims are explicit fields. The end user's
/// identity claims come from the flattened `user_claims` field and appear as
/// top-level members of the serialized JSON.
///
/// NOTE(review): `UserClaims` also serializes its own `sub`. Serializing this
/// struct directly (e.g. `serde_json::to_vec`) therefore appears to emit the
/// `sub` key twice. Deserializing such output may be rejected as a duplicate
/// field. Confirm before relying on a JSON round-trip of this type.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct IdTokenClaims {
    /// Issuer identifier (`iss`).
    pub iss: String,
    /// Subject identifier (`sub`).
    pub sub: String,
    /// Audience (`aud`); `issue_id_token` sets it to the OAuth client ID.
    pub aud: String,
    /// Expiration time (`exp`, Unix timestamp).
    pub exp: i64,
    /// Issued-at time (`iat`, Unix timestamp).
    pub iat: i64,
    /// Authentication time (`auth_time`, Unix timestamp).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auth_time: Option<i64>,
    /// `nonce` value from the authorization request, echoed back to the client.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nonce: Option<String>,
    /// Authentication Context Class Reference (`acr`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub acr: Option<String>,
    /// Authentication Methods References (`amr`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub amr: Option<Vec<String>>,
    /// Authorized party (`azp`): the client ID the token was issued to.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub azp: Option<String>,
    /// Access token hash (`at_hash`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub at_hash: Option<String>,
    /// Authorization code hash (`c_hash`, used in hybrid flows).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub c_hash: Option<String>,
    /// End-user claims, flattened into the top level of the payload.
    #[serde(flatten)]
    pub user_claims: UserClaims,
}
386
/// A signed ID token together with the claims it was built from.
///
/// NOTE(review): the derived `Debug` prints the full `raw` token; avoid
/// logging values of this type.
#[derive(Debug, Clone)]
pub struct IdToken {
    /// The compact JWT string in `header.payload.signature` form, as produced
    /// by `OidcProvider::sign_id_token`.
    pub raw: String,
    /// The claims that were signed into `raw`.
    pub claims: IdTokenClaims,
}
395
396// =============================================================================
397// Discovery Document
398// =============================================================================
399
/// OpenID Connect Discovery Document (provider metadata).
///
/// This is served at `/.well-known/openid-configuration`. Field names are
/// serialized verbatim (no serde renames). `Option` fields are omitted from the
/// JSON when `None`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DiscoveryDocument {
    /// Issuer identifier URL.
    pub issuer: String,
    /// Authorization endpoint URL.
    pub authorization_endpoint: String,
    /// Token endpoint URL.
    pub token_endpoint: String,
    /// UserInfo endpoint URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub userinfo_endpoint: Option<String>,
    /// JSON Web Key Set URL for public-key retrieval.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks_uri: Option<String>,
    /// Dynamic client registration endpoint URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub registration_endpoint: Option<String>,
    /// Token revocation endpoint URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub revocation_endpoint: Option<String>,
    /// Scope values the provider supports.
    pub scopes_supported: Vec<String>,
    /// OAuth `response_type` values the provider supports.
    pub response_types_supported: Vec<String>,
    /// OAuth `response_mode` values the provider supports.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_modes_supported: Option<Vec<String>>,
    /// OAuth grant types the provider supports.
    pub grant_types_supported: Vec<String>,
    /// Subject identifier types the provider supports (e.g. `public`).
    pub subject_types_supported: Vec<String>,
    /// JWS `alg` values used to sign ID tokens.
    pub id_token_signing_alg_values_supported: Vec<String>,
    /// Client authentication methods accepted at the token endpoint.
    pub token_endpoint_auth_methods_supported: Vec<String>,
    /// Claim names the provider may supply.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub claims_supported: Option<Vec<String>>,
    /// PKCE `code_challenge_method` values the provider supports.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code_challenge_methods_supported: Option<Vec<String>>,
}
445
446impl DiscoveryDocument {
447    /// Creates a new discovery document with the given issuer and base URL.
448    #[must_use]
449    pub fn new(issuer: impl Into<String>, base_url: impl Into<String>) -> Self {
450        let issuer = issuer.into();
451        let base = base_url.into();
452
453        Self {
454            issuer: issuer.clone(),
455            authorization_endpoint: format!("{}/authorize", base),
456            token_endpoint: format!("{}/token", base),
457            userinfo_endpoint: Some(format!("{}/userinfo", base)),
458            jwks_uri: Some(format!("{}/.well-known/jwks.json", base)),
459            registration_endpoint: None,
460            revocation_endpoint: Some(format!("{}/revoke", base)),
461            scopes_supported: vec![
462                "openid".to_string(),
463                "profile".to_string(),
464                "email".to_string(),
465            ],
466            response_types_supported: vec!["code".to_string()],
467            response_modes_supported: Some(vec!["query".to_string()]),
468            grant_types_supported: vec![
469                "authorization_code".to_string(),
470                "refresh_token".to_string(),
471            ],
472            subject_types_supported: vec!["public".to_string()],
473            id_token_signing_alg_values_supported: vec!["HS256".to_string()],
474            token_endpoint_auth_methods_supported: vec![
475                "client_secret_post".to_string(),
476                "client_secret_basic".to_string(),
477            ],
478            claims_supported: Some(vec![
479                "sub".to_string(),
480                "iss".to_string(),
481                "aud".to_string(),
482                "exp".to_string(),
483                "iat".to_string(),
484                "name".to_string(),
485                "email".to_string(),
486                "email_verified".to_string(),
487                "preferred_username".to_string(),
488                "picture".to_string(),
489            ]),
490            code_challenge_methods_supported: Some(vec!["plain".to_string(), "S256".to_string()]),
491        }
492    }
493}
494
495// =============================================================================
496// Claims Provider
497// =============================================================================
498
/// Source of user claims for the OIDC provider.
///
/// Implementations look up an end user by subject identifier and return that
/// user's full claim set. `OidcProvider` filters the result by the access
/// token's scopes before using it. The `Send + Sync` bound allows the provider
/// to be stored as a shared `Arc<dyn ClaimsProvider>`.
pub trait ClaimsProvider: Send + Sync {
    /// Retrieves claims for a user by subject identifier.
    ///
    /// Returns `None` if the user is not found. `OidcProvider` reports that
    /// case as `OidcError::ClaimsNotFound`.
    fn get_claims(&self, subject: &str) -> Option<UserClaims>;
}
506
507/// Simple in-memory claims provider.
508#[derive(Debug, Default)]
509pub struct InMemoryClaimsProvider {
510    claims: RwLock<HashMap<String, UserClaims>>,
511}
512
513impl InMemoryClaimsProvider {
514    /// Creates a new empty claims provider.
515    #[must_use]
516    pub fn new() -> Self {
517        Self::default()
518    }
519
520    /// Adds or updates claims for a user.
521    pub fn set_claims(&self, claims: UserClaims) {
522        if let Ok(mut guard) = self.claims.write() {
523            guard.insert(claims.sub.clone(), claims);
524        }
525    }
526
527    /// Removes claims for a user.
528    pub fn remove_claims(&self, subject: &str) {
529        if let Ok(mut guard) = self.claims.write() {
530            guard.remove(subject);
531        }
532    }
533}
534
535impl ClaimsProvider for InMemoryClaimsProvider {
536    fn get_claims(&self, subject: &str) -> Option<UserClaims> {
537        self.claims
538            .read()
539            .ok()
540            .and_then(|guard| guard.get(subject).cloned())
541    }
542}
543
544/// Function-based claims provider.
545pub struct FnClaimsProvider<F>
546where
547    F: Fn(&str) -> Option<UserClaims> + Send + Sync,
548{
549    func: F,
550}
551
552impl<F> FnClaimsProvider<F>
553where
554    F: Fn(&str) -> Option<UserClaims> + Send + Sync,
555{
556    /// Creates a new function-based claims provider.
557    #[must_use]
558    pub fn new(func: F) -> Self {
559        Self { func }
560    }
561}
562
563impl<F> ClaimsProvider for FnClaimsProvider<F>
564where
565    F: Fn(&str) -> Option<UserClaims> + Send + Sync,
566{
567    fn get_claims(&self, subject: &str) -> Option<UserClaims> {
568        (self.func)(subject)
569    }
570}
571
572impl ClaimsProvider for Arc<dyn ClaimsProvider> {
573    fn get_claims(&self, subject: &str) -> Option<UserClaims> {
574        (**self).get_claims(subject)
575    }
576}
577
578// =============================================================================
579// OIDC Errors
580// =============================================================================
581
582/// OIDC-specific errors.
583#[derive(Debug, Clone)]
584pub enum OidcError {
585    /// Underlying OAuth error.
586    OAuth(OAuthError),
587    /// Missing openid scope.
588    MissingOpenIdScope,
589    /// User claims not found.
590    ClaimsNotFound(String),
591    /// Token signing failed.
592    SigningError(String),
593    /// Invalid ID token.
594    InvalidIdToken(String),
595}
596
597impl std::fmt::Display for OidcError {
598    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
599        match self {
600            Self::OAuth(e) => write!(f, "OAuth error: {}", e),
601            Self::MissingOpenIdScope => write!(f, "missing 'openid' scope"),
602            Self::ClaimsNotFound(s) => write!(f, "claims not found for subject: {}", s),
603            Self::SigningError(s) => write!(f, "signing error: {}", s),
604            Self::InvalidIdToken(s) => write!(f, "invalid ID token: {}", s),
605        }
606    }
607}
608
609impl std::error::Error for OidcError {}
610
611impl From<OAuthError> for OidcError {
612    fn from(err: OAuthError) -> Self {
613        Self::OAuth(err)
614    }
615}
616
617// =============================================================================
618// OIDC Provider
619// =============================================================================
620
/// OpenID Connect Provider.
///
/// Wraps an [`OAuthServer`] and adds the OIDC identity layer: ID-token
/// issuance, the userinfo handler, and discovery metadata. All mutable state
/// sits behind `RwLock`s, so the setters take `&self`.
pub struct OidcProvider {
    /// Underlying OAuth server; used to validate access tokens for userinfo.
    oauth: Arc<OAuthServer>,
    /// OIDC configuration.
    config: OidcProviderConfig,
    /// Key used to sign ID tokens. Starts as `SigningKey::None`; it is set by
    /// `set_hmac_key` or generated lazily on the first signature.
    signing_key: RwLock<SigningKey>,
    /// Optional source of user claims. When `None`, only `sub` is emitted.
    claims_provider: RwLock<Option<Arc<dyn ClaimsProvider>>>,
    /// Issued ID tokens keyed by the access-token string they accompany.
    /// Expired entries are pruned by `cleanup_expired`.
    id_tokens: RwLock<HashMap<String, IdToken>>,
}
636
/// Signing key material for ID tokens.
///
/// NOTE(review): this type does not derive `Debug`, which keeps the secret out
/// of `{:?}` output. This is presumably intentional; keep it that way.
#[derive(Clone, Default)]
enum SigningKey {
    /// HMAC-SHA256 secret bytes.
    Hmac(Vec<u8>),
    /// No key configured yet. `get_or_generate_signing_key` replaces this with
    /// a random 32-byte HMAC key on first use.
    #[default]
    None,
}
646
647impl OidcProvider {
648    /// Creates a new OIDC provider with the given OAuth server.
649    #[must_use]
650    pub fn new(oauth: Arc<OAuthServer>, config: OidcProviderConfig) -> Self {
651        Self {
652            oauth,
653            config,
654            signing_key: RwLock::new(SigningKey::None),
655            claims_provider: RwLock::new(None),
656            id_tokens: RwLock::new(HashMap::new()),
657        }
658    }
659
660    /// Creates a new OIDC provider with default configuration.
661    #[must_use]
662    pub fn with_defaults(oauth: Arc<OAuthServer>) -> Self {
663        Self::new(oauth, OidcProviderConfig::default())
664    }
665
666    /// Returns the OIDC configuration.
667    #[must_use]
668    pub fn config(&self) -> &OidcProviderConfig {
669        &self.config
670    }
671
672    /// Returns a reference to the underlying OAuth server.
673    #[must_use]
674    pub fn oauth(&self) -> &Arc<OAuthServer> {
675        &self.oauth
676    }
677
678    /// Sets the HMAC signing key.
679    pub fn set_hmac_key(&self, key: impl AsRef<[u8]>) {
680        if let Ok(mut guard) = self.signing_key.write() {
681            *guard = SigningKey::Hmac(key.as_ref().to_vec());
682        }
683    }
684
685    /// Sets the claims provider.
686    pub fn set_claims_provider<P: ClaimsProvider + 'static>(&self, provider: P) {
687        if let Ok(mut guard) = self.claims_provider.write() {
688            *guard = Some(Arc::new(provider));
689        }
690    }
691
692    /// Sets a function-based claims provider.
693    pub fn set_claims_fn<F>(&self, func: F)
694    where
695        F: Fn(&str) -> Option<UserClaims> + Send + Sync + 'static,
696    {
697        self.set_claims_provider(FnClaimsProvider::new(func));
698    }
699
700    /// Generates the discovery document.
701    #[must_use]
702    pub fn discovery_document(&self, base_url: impl Into<String>) -> DiscoveryDocument {
703        let mut doc = DiscoveryDocument::new(&self.config.issuer, base_url);
704        doc.scopes_supported = self.config.supported_scopes.clone();
705        doc.claims_supported = Some(self.config.supported_claims.clone());
706        doc.id_token_signing_alg_values_supported =
707            vec![self.config.signing_algorithm.as_str().to_string()];
708        doc
709    }
710
711    // -------------------------------------------------------------------------
712    // ID Token Issuance
713    // -------------------------------------------------------------------------
714
715    /// Issues an ID token for the given access token.
716    ///
717    /// This should be called after a successful token exchange when the
718    /// `openid` scope was requested.
719    pub fn issue_id_token(
720        &self,
721        access_token: &OAuthToken,
722        nonce: Option<&str>,
723    ) -> Result<IdToken, OidcError> {
724        // Verify openid scope
725        if !access_token.scopes.iter().any(|s| s == "openid") {
726            return Err(OidcError::MissingOpenIdScope);
727        }
728
729        let subject = access_token
730            .subject
731            .as_ref()
732            .ok_or_else(|| OidcError::ClaimsNotFound("no subject in access token".to_string()))?;
733
734        // Get user claims
735        let user_claims = self.get_user_claims(subject, &access_token.scopes)?;
736
737        // Build ID token claims
738        let now = SystemTime::now()
739            .duration_since(UNIX_EPOCH)
740            .unwrap_or_default()
741            .as_secs() as i64;
742
743        let claims = IdTokenClaims {
744            iss: self.config.issuer.clone(),
745            sub: subject.clone(),
746            aud: access_token.client_id.clone(),
747            exp: now + self.config.id_token_lifetime.as_secs() as i64,
748            iat: now,
749            auth_time: Some(now),
750            nonce: nonce.map(String::from),
751            acr: None,
752            amr: None,
753            azp: Some(access_token.client_id.clone()),
754            at_hash: Some(self.compute_at_hash(&access_token.token)),
755            c_hash: None,
756            user_claims,
757        };
758
759        // Sign the token
760        let raw = self.sign_id_token(&claims)?;
761
762        let id_token = IdToken { raw, claims };
763
764        // Cache the ID token
765        if let Ok(mut guard) = self.id_tokens.write() {
766            guard.insert(access_token.token.clone(), id_token.clone());
767        }
768
769        Ok(id_token)
770    }
771
772    /// Gets the ID token associated with an access token.
773    #[must_use]
774    pub fn get_id_token(&self, access_token: &str) -> Option<IdToken> {
775        self.id_tokens
776            .read()
777            .ok()
778            .and_then(|guard| guard.get(access_token).cloned())
779    }
780
781    // -------------------------------------------------------------------------
782    // UserInfo Endpoint
783    // -------------------------------------------------------------------------
784
785    /// Handles a userinfo request.
786    ///
787    /// Returns the user's claims filtered by the access token's scopes.
788    pub fn userinfo(&self, access_token: &str) -> Result<UserClaims, OidcError> {
789        // Validate access token
790        let token = self
791            .oauth
792            .validate_access_token(access_token)
793            .ok_or_else(|| {
794                OidcError::OAuth(OAuthError::InvalidGrant(
795                    "invalid or expired access token".to_string(),
796                ))
797            })?;
798
799        // Verify openid scope
800        if !token.scopes.iter().any(|s| s == "openid") {
801            return Err(OidcError::MissingOpenIdScope);
802        }
803
804        let subject = token
805            .subject
806            .as_ref()
807            .ok_or_else(|| OidcError::ClaimsNotFound("no subject in access token".to_string()))?;
808
809        self.get_user_claims(subject, &token.scopes)
810    }
811
812    // -------------------------------------------------------------------------
813    // Helper Methods
814    // -------------------------------------------------------------------------
815
816    fn get_user_claims(&self, subject: &str, scopes: &[String]) -> Result<UserClaims, OidcError> {
817        let provider = self
818            .claims_provider
819            .read()
820            .ok()
821            .and_then(|guard| guard.clone());
822
823        let claims = match provider {
824            Some(p) => p
825                .get_claims(subject)
826                .ok_or_else(|| OidcError::ClaimsNotFound(subject.to_string()))?,
827            None => {
828                // Default: just return subject
829                UserClaims::new(subject)
830            }
831        };
832
833        Ok(claims.filter_by_scopes(scopes))
834    }
835
836    fn sign_id_token(&self, claims: &IdTokenClaims) -> Result<String, OidcError> {
837        let key = self.get_or_generate_signing_key()?;
838
839        // Build JWT
840        let header = serde_json::json!({
841            "alg": self.config.signing_algorithm.as_str(),
842            "typ": "JWT",
843            "kid": self.config.key_id.as_deref().unwrap_or("default"),
844        });
845
846        let header_b64 =
847            base64url_encode(&serde_json::to_vec(&header).map_err(|e| {
848                OidcError::SigningError(format!("failed to serialize header: {}", e))
849            })?);
850
851        let claims_b64 =
852            base64url_encode(&serde_json::to_vec(claims).map_err(|e| {
853                OidcError::SigningError(format!("failed to serialize claims: {}", e))
854            })?);
855
856        let signing_input = format!("{}.{}", header_b64, claims_b64);
857
858        let signature = match &key {
859            SigningKey::Hmac(secret) => hmac_sha256(&signing_input, secret),
860            SigningKey::None => {
861                return Err(OidcError::SigningError(
862                    "no signing key configured".to_string(),
863                ));
864            }
865        };
866
867        let signature_b64 = base64url_encode(&signature);
868
869        Ok(format!("{}.{}", signing_input, signature_b64))
870    }
871
872    fn get_or_generate_signing_key(&self) -> Result<SigningKey, OidcError> {
873        let guard = self
874            .signing_key
875            .read()
876            .map_err(|_| OidcError::SigningError("failed to acquire read lock".to_string()))?;
877
878        match &*guard {
879            SigningKey::None => {
880                // Generate a random key
881                drop(guard);
882                let mut write_guard = self.signing_key.write().map_err(|_| {
883                    OidcError::SigningError("failed to acquire write lock".to_string())
884                })?;
885
886                // Double-check after acquiring write lock
887                if matches!(&*write_guard, SigningKey::None) {
888                    let key = generate_random_bytes(32);
889                    *write_guard = SigningKey::Hmac(key.clone());
890                    Ok(SigningKey::Hmac(key))
891                } else {
892                    Ok(write_guard.clone())
893                }
894            }
895            key => Ok(key.clone()),
896        }
897    }
898
899    fn compute_at_hash(&self, access_token: &str) -> String {
900        // at_hash is left half of hash of access token
901        let hash = simple_sha256(access_token.as_bytes());
902        base64url_encode(&hash[..16])
903    }
904
905    /// Removes expired ID tokens from cache.
906    pub fn cleanup_expired(&self) {
907        let now = SystemTime::now()
908            .duration_since(UNIX_EPOCH)
909            .unwrap_or_default()
910            .as_secs() as i64;
911
912        if let Ok(mut guard) = self.id_tokens.write() {
913            guard.retain(|_, token| token.claims.exp > now);
914        }
915    }
916}
917
918// =============================================================================
919// Helper Functions
920// =============================================================================
921
/// Encodes `data` as unpadded base64url (RFC 4648 §5 alphabet, `-` and `_`).
///
/// Each 3-byte group becomes 4 characters. A trailing group of 1 or 2 bytes becomes
/// 2 or 3 characters, with no `=` padding.
fn base64url_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

    // Select the 6-bit group of `packed` that starts `shift` bits from the bottom.
    let sextet = |packed: u32, shift: u32| ALPHABET[((packed >> shift) & 0x3F) as usize] as char;

    let mut encoded = String::with_capacity((data.len() * 4 + 2) / 3);
    for group in data.chunks(3) {
        // Pack up to three bytes big-endian into the low 24 bits; absent bytes stay zero.
        let packed = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (idx, &byte)| acc | (u32::from(byte) << (16 - 8 * idx as u32)));

        // n input bytes produce n + 1 output characters.
        let emitted = group.len() + 1;
        for &shift in [18u32, 12, 6, 0].iter().take(emitted) {
            encoded.push(sextet(packed, shift));
        }
    }
    encoded
}
951
/// Computes the SHA-256 digest of `data`, as specified in FIPS 180-4.
///
/// This is a self-contained implementation so the module needs no crypto dependency.
/// It is deterministic, which is required for two things:
///
/// - the OIDC `at_hash` claim, which a client recomputes;
/// - HS256 signature verification.
///
/// The previous placeholder hashed with a freshly seeded `RandomState` on every call,
/// so neither of those could work.
fn simple_sha256(data: &[u8]) -> [u8; 32] {
    // Round constants: first 32 bits of the fractional parts of the cube roots of the
    // first 64 primes.
    const K: [u32; 64] = [
        0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4,
        0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe,
        0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f,
        0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
        0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc,
        0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
        0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116,
        0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
        0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7,
        0xc67178f2,
    ];

    // Initial hash value: first 32 bits of the fractional parts of the square roots of
    // the first 8 primes.
    let mut state: [u32; 8] = [
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
        0x5be0cd19,
    ];

    // Padding: append 0x80, then zeros until the length is 56 mod 64, then the message
    // length in bits as a big-endian u64. The result is a whole number of 64-byte blocks.
    let bit_len = (data.len() as u64).wrapping_mul(8);
    let mut message = Vec::with_capacity(data.len() + 72);
    message.extend_from_slice(data);
    message.push(0x80);
    while message.len() % 64 != 56 {
        message.push(0);
    }
    message.extend_from_slice(&bit_len.to_be_bytes());

    let mut schedule = [0u32; 64];
    for block in message.chunks_exact(64) {
        // Message schedule: 16 big-endian words from the block, expanded to 64.
        for (slot, word) in schedule.iter_mut().zip(block.chunks_exact(4)) {
            *slot = u32::from_be_bytes([word[0], word[1], word[2], word[3]]);
        }
        for t in 16..64 {
            let w15 = schedule[t - 15];
            let w2 = schedule[t - 2];
            let s0 = w15.rotate_right(7) ^ w15.rotate_right(18) ^ (w15 >> 3);
            let s1 = w2.rotate_right(17) ^ w2.rotate_right(19) ^ (w2 >> 10);
            schedule[t] = schedule[t - 16]
                .wrapping_add(s0)
                .wrapping_add(schedule[t - 7])
                .wrapping_add(s1);
        }

        // Compression function: 64 rounds over the working variables.
        let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut hh] = state;
        for (&k, &w) in K.iter().zip(schedule.iter()) {
            let big_s1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25);
            let ch = (e & f) ^ (!e & g);
            let temp1 = hh
                .wrapping_add(big_s1)
                .wrapping_add(ch)
                .wrapping_add(k)
                .wrapping_add(w);
            let big_s0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22);
            let maj = (a & b) ^ (a & c) ^ (b & c);
            let temp2 = big_s0.wrapping_add(maj);

            hh = g;
            g = f;
            f = e;
            e = d.wrapping_add(temp1);
            d = c;
            c = b;
            b = a;
            a = temp1.wrapping_add(temp2);
        }

        let working = [a, b, c, d, e, f, g, hh];
        for (slot, value) in state.iter_mut().zip(working.iter()) {
            *slot = slot.wrapping_add(*value);
        }
    }

    let mut digest = [0u8; 32];
    for (chunk, word) in digest.chunks_exact_mut(4).zip(state.iter()) {
        chunk.copy_from_slice(&word.to_be_bytes());
    }
    digest
}

/// Computes HMAC-SHA256 of `message` under `key`, as specified in RFC 2104.
///
/// Keys longer than the 64-byte SHA-256 block are hashed first. Shorter keys are
/// zero-padded to the block size. The result is
/// `H((K ^ opad) || H((K ^ ipad) || message))`.
///
/// The previous version computed `H(key || message)`, which is not HMAC and is open to
/// length-extension attacks.
fn hmac_sha256(message: &str, key: &[u8]) -> [u8; 32] {
    const BLOCK_SIZE: usize = 64;
    const IPAD: u8 = 0x36;
    const OPAD: u8 = 0x5c;

    let mut block_key = [0u8; BLOCK_SIZE];
    if key.len() > BLOCK_SIZE {
        let hashed = simple_sha256(key);
        block_key[..hashed.len()].copy_from_slice(&hashed);
    } else {
        block_key[..key.len()].copy_from_slice(key);
    }

    let mut inner = Vec::with_capacity(BLOCK_SIZE + message.len());
    inner.extend(block_key.iter().map(|b| b ^ IPAD));
    inner.extend_from_slice(message.as_bytes());
    let inner_hash = simple_sha256(&inner);

    let mut outer = Vec::with_capacity(BLOCK_SIZE + inner_hash.len());
    outer.extend(block_key.iter().map(|b| b ^ OPAD));
    outer.extend_from_slice(&inner_hash);
    simple_sha256(&outer)
}
980
/// Generates `len` random bytes, e.g. for HMAC signing-key material.
///
/// Preferred source: if `/dev/urandom` can be opened and fully read (typically on
/// Unix-like systems), the bytes come directly from the operating system's random
/// device.
///
/// Fallback: otherwise, the bytes are produced by expanding keyed hashes from a
/// `RandomState`, whose hashers are built with random keys, over a counter and the
/// current time. All 8 bytes of each 64-bit hash are used.
///
/// NOTE(review): `RandomState` is not documented as a cryptographic RNG, so the
/// fallback is weaker than the OS source. Consider a dedicated CSPRNG crate if one
/// becomes available to this project.
fn generate_random_bytes(len: usize) -> Vec<u8> {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::io::Read;

    let mut result = vec![0u8; len];
    if len == 0 {
        return result;
    }

    // Preferred source: the OS random device. On any failure, fall through and
    // overwrite the whole buffer below.
    if let Ok(mut device) = std::fs::File::open("/dev/urandom") {
        if device.read_exact(&mut result).is_ok() {
            return result;
        }
    }

    // Fallback: keyed-hash expansion, 8 output bytes per hash.
    let state = RandomState::new();
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    for (counter, chunk) in result.chunks_mut(8).enumerate() {
        let mut hasher = state.build_hasher();
        hasher.write_usize(counter);
        hasher.write_u128(nanos);
        let block = hasher.finish().to_le_bytes();
        chunk.copy_from_slice(&block[..chunk.len()]);
    }

    result
}
1003
1004// =============================================================================
1005// Tests
1006// =============================================================================
1007
// Unit tests for the OIDC provider. They cover:
// - `UserClaims` building and scope filtering;
// - the discovery document;
// - the in-memory and closure-based claims providers;
// - small helpers (signing algorithm names, error display, base64url);
// - ID token issuance;
// - the UserInfo lookup.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oauth::{OAuthClient, OAuthServerConfig};
    use std::time::Instant;

    /// Builds an `OidcProvider` with default settings on top of a default `OAuthServer`.
    fn create_test_provider() -> OidcProvider {
        let oauth = Arc::new(OAuthServer::new(OAuthServerConfig::default()));
        OidcProvider::with_defaults(oauth)
    }

    // Each `with_*` builder method should populate its corresponding claim field.
    #[test]
    fn test_user_claims_builder() {
        let claims = UserClaims::new("user123")
            .with_name("John Doe")
            .with_email("john@example.com")
            .with_email_verified(true)
            .with_preferred_username("johnd");

        assert_eq!(claims.sub, "user123");
        assert_eq!(claims.name, Some("John Doe".to_string()));
        assert_eq!(claims.email, Some("john@example.com".to_string()));
        assert_eq!(claims.email_verified, Some(true));
        assert_eq!(claims.preferred_username, Some("johnd".to_string()));
    }

    // `filter_by_scopes` should keep only the claims each granted scope unlocks.
    // `sub` is kept for `openid` alone.
    #[test]
    fn test_claims_filter_by_scopes() {
        let claims = UserClaims::new("user123")
            .with_name("John Doe")
            .with_email("john@example.com")
            .with_phone_number("+1234567890");

        // Only openid scope - just sub
        let filtered = claims.filter_by_scopes(&["openid".to_string()]);
        assert_eq!(filtered.sub, "user123");
        assert!(filtered.name.is_none());
        assert!(filtered.email.is_none());

        // Profile scope
        let filtered = claims.filter_by_scopes(&["openid".to_string(), "profile".to_string()]);
        assert_eq!(filtered.name, Some("John Doe".to_string()));
        assert!(filtered.email.is_none());

        // Email scope
        let filtered = claims.filter_by_scopes(&["openid".to_string(), "email".to_string()]);
        assert!(filtered.name.is_none());
        assert_eq!(filtered.email, Some("john@example.com".to_string()));

        // All scopes
        let filtered = claims.filter_by_scopes(&[
            "openid".to_string(),
            "profile".to_string(),
            "email".to_string(),
            "phone".to_string(),
        ]);
        assert_eq!(filtered.name, Some("John Doe".to_string()));
        assert_eq!(filtered.email, Some("john@example.com".to_string()));
        assert_eq!(filtered.phone_number, Some("+1234567890".to_string()));
    }

    // Endpoint URLs in the discovery document are built from the supplied base URL.
    // The issuer is expected to be the default "fastmcp".
    #[test]
    fn test_discovery_document() {
        let provider = create_test_provider();
        let doc = provider.discovery_document("https://example.com");

        assert_eq!(doc.issuer, "fastmcp");
        assert_eq!(doc.authorization_endpoint, "https://example.com/authorize");
        assert_eq!(doc.token_endpoint, "https://example.com/token");
        assert!(doc.scopes_supported.contains(&"openid".to_string()));
        assert!(doc.response_types_supported.contains(&"code".to_string()));
    }

    // Round trip through `InMemoryClaimsProvider`: set, get, miss, then remove.
    #[test]
    fn test_in_memory_claims_provider() {
        let provider = InMemoryClaimsProvider::new();

        let claims = UserClaims::new("user123")
            .with_name("John Doe")
            .with_email("john@example.com");

        provider.set_claims(claims);

        let retrieved = provider.get_claims("user123");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name, Some("John Doe".to_string()));

        assert!(provider.get_claims("nonexistent").is_none());

        provider.remove_claims("user123");
        assert!(provider.get_claims("user123").is_none());
    }

    // `FnClaimsProvider` delegates lookups to the wrapped closure.
    #[test]
    fn test_fn_claims_provider() {
        let provider = FnClaimsProvider::new(|subject| {
            if subject == "user123" {
                Some(UserClaims::new(subject).with_name("John Doe"))
            } else {
                None
            }
        });

        let claims = provider.get_claims("user123");
        assert!(claims.is_some());
        assert_eq!(claims.unwrap().name, Some("John Doe".to_string()));

        assert!(provider.get_claims("other").is_none());
    }

    // JOSE `alg` names for the supported signing algorithms.
    #[test]
    fn test_signing_algorithm() {
        assert_eq!(SigningAlgorithm::HS256.as_str(), "HS256");
        assert_eq!(SigningAlgorithm::RS256.as_str(), "RS256");
    }

    // `Display` output of `OidcError` variants.
    #[test]
    fn test_oidc_error_display() {
        let err = OidcError::MissingOpenIdScope;
        assert_eq!(err.to_string(), "missing 'openid' scope");

        let err = OidcError::ClaimsNotFound("user123".to_string());
        assert!(err.to_string().contains("user123"));
    }

    // Unpadded base64url for inputs of length 0 to 3 (RFC 4648 vectors, padding stripped).
    #[test]
    fn test_base64url_encode() {
        assert_eq!(base64url_encode(b""), "");
        assert_eq!(base64url_encode(b"f"), "Zg");
        assert_eq!(base64url_encode(b"fo"), "Zm8");
        assert_eq!(base64url_encode(b"foo"), "Zm9v");
    }

    // Issues an ID token for a hand-built access token. Checks:
    // - the raw token is non-empty and contains '.';
    // - `sub` and `aud` come from the access token's subject and client_id;
    // - the nonce is echoed back;
    // - profile claims are present because the token carries the profile scope.
    #[test]
    fn test_id_token_issuance() {
        let provider = create_test_provider();

        // Set up claims provider
        let claims_provider = InMemoryClaimsProvider::new();
        claims_provider.set_claims(
            UserClaims::new("user123")
                .with_name("John Doe")
                .with_email("john@example.com"),
        );
        provider.set_claims_provider(claims_provider);

        // Set signing key
        provider.set_hmac_key(b"test-secret-key");

        // Create a mock access token with openid scope
        let now = Instant::now();
        let access_token = crate::oauth::OAuthToken {
            token: "test-access-token".to_string(),
            token_type: crate::oauth::TokenType::Bearer,
            client_id: "test-client".to_string(),
            scopes: vec![
                "openid".to_string(),
                "profile".to_string(),
                "email".to_string(),
            ],
            issued_at: now,
            expires_at: now + Duration::from_secs(3600),
            subject: Some("user123".to_string()),
            is_refresh_token: false,
        };

        let result = provider.issue_id_token(&access_token, Some("nonce123"));
        assert!(result.is_ok());

        let id_token = result.unwrap();
        assert!(!id_token.raw.is_empty());
        assert!(id_token.raw.contains('.'));
        assert_eq!(id_token.claims.sub, "user123");
        assert_eq!(id_token.claims.aud, "test-client");
        assert_eq!(id_token.claims.nonce, Some("nonce123".to_string()));
        assert_eq!(
            id_token.claims.user_claims.name,
            Some("John Doe".to_string())
        );
    }

    // ID token issuance must fail with `MissingOpenIdScope` when `openid` is not granted.
    #[test]
    fn test_id_token_requires_openid_scope() {
        let provider = create_test_provider();

        let now = Instant::now();
        let access_token = crate::oauth::OAuthToken {
            token: "test-access-token".to_string(),
            token_type: crate::oauth::TokenType::Bearer,
            client_id: "test-client".to_string(),
            scopes: vec!["profile".to_string()], // No openid scope
            issued_at: now,
            expires_at: now + Duration::from_secs(3600),
            subject: Some("user123".to_string()),
            is_refresh_token: false,
        };

        let result = provider.issue_id_token(&access_token, None);
        assert!(matches!(result, Err(OidcError::MissingOpenIdScope)));
    }

    // `userinfo` resolves a bearer token stored in the OAuth server's state to the subject's
    // claims. The token is inserted directly into `oauth.state.access_tokens` to avoid
    // running a full authorization flow.
    #[test]
    fn test_userinfo() {
        let oauth = Arc::new(OAuthServer::new(OAuthServerConfig::default()));

        // Register a client
        let client = OAuthClient::builder("test-client")
            .redirect_uri("http://localhost:3000/callback")
            .scope("openid")
            .scope("profile")
            .build()
            .unwrap();
        oauth.register_client(client).unwrap();

        // Create an access token manually
        {
            // Scoped so the write lock on the OAuth state is released before `userinfo` runs.
            let mut state = oauth.state.write().unwrap();
            let now = Instant::now();
            let token = crate::oauth::OAuthToken {
                token: "test-token".to_string(),
                token_type: crate::oauth::TokenType::Bearer,
                client_id: "test-client".to_string(),
                scopes: vec!["openid".to_string(), "profile".to_string()],
                issued_at: now,
                expires_at: now + Duration::from_secs(3600),
                subject: Some("user123".to_string()),
                is_refresh_token: false,
            };
            state.access_tokens.insert("test-token".to_string(), token);
        }

        let provider = OidcProvider::with_defaults(oauth);

        // Set up claims
        let claims_store = InMemoryClaimsProvider::new();
        claims_store.set_claims(UserClaims::new("user123").with_name("John Doe"));
        provider.set_claims_provider(claims_store);

        let result = provider.userinfo("test-token");
        assert!(result.is_ok());

        let claims = result.unwrap();
        assert_eq!(claims.sub, "user123");
        assert_eq!(claims.name, Some("John Doe".to_string()));
    }

    // `AddressClaim` serializes its fields under their snake_case member names.
    #[test]
    fn test_address_claim() {
        let address = AddressClaim {
            formatted: Some("123 Main St, City, ST 12345".to_string()),
            street_address: Some("123 Main St".to_string()),
            locality: Some("City".to_string()),
            region: Some("ST".to_string()),
            postal_code: Some("12345".to_string()),
            country: Some("US".to_string()),
        };

        let json = serde_json::to_string(&address).unwrap();
        assert!(json.contains("formatted"));
        assert!(json.contains("street_address"));
    }

    // `with_custom` stores arbitrary JSON values in the `custom` map under the given key.
    #[test]
    fn test_custom_claims() {
        let claims = UserClaims::new("user123")
            .with_custom("custom_field", serde_json::json!("custom_value"))
            .with_custom("roles", serde_json::json!(["admin", "user"]));

        assert_eq!(
            claims.custom.get("custom_field"),
            Some(&serde_json::json!("custom_value"))
        );
        assert_eq!(
            claims.custom.get("roles"),
            Some(&serde_json::json!(["admin", "user"]))
        );
    }
}