1use std::collections::HashMap;
40use std::sync::{Arc, RwLock};
41use std::time::{Duration, SystemTime, UNIX_EPOCH};
42
43use crate::oauth::{OAuthError, OAuthServer, OAuthToken};
44
/// Configuration for [`OidcProvider`].
#[derive(Debug, Clone)]
pub struct OidcProviderConfig {
    /// Issuer identifier; emitted as the `iss` claim and as `issuer` in the discovery document.
    pub issuer: String,
    /// Validity period of issued ID tokens (`exp` is the issue time plus this duration).
    pub id_token_lifetime: Duration,
    /// Algorithm name written into the JWT header and advertised in the discovery document.
    pub signing_algorithm: SigningAlgorithm,
    /// Optional key identifier for the JWT `kid` header; `"default"` is used when unset.
    pub key_id: Option<String>,
    /// Claim names advertised as `claims_supported` in the discovery document.
    pub supported_claims: Vec<String>,
    /// Scope names advertised as `scopes_supported` in the discovery document.
    pub supported_scopes: Vec<String>,
}
65
66impl Default for OidcProviderConfig {
67 fn default() -> Self {
68 Self {
69 issuer: "fastmcp".to_string(),
70 id_token_lifetime: Duration::from_secs(3600), signing_algorithm: SigningAlgorithm::HS256,
72 key_id: None,
73 supported_claims: vec![
74 "sub".to_string(),
75 "name".to_string(),
76 "email".to_string(),
77 "email_verified".to_string(),
78 "preferred_username".to_string(),
79 "picture".to_string(),
80 "updated_at".to_string(),
81 ],
82 supported_scopes: vec![
83 "openid".to_string(),
84 "profile".to_string(),
85 "email".to_string(),
86 ],
87 }
88 }
89}
90
/// Signing algorithm identifiers that can be named in an ID token header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SigningAlgorithm {
    /// Identified by the string `"HS256"`.
    HS256,
    /// Identified by the string `"RS256"`.
    RS256,
}

impl SigningAlgorithm {
    /// Returns the algorithm's name as it appears in the JWT `alg` header.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            SigningAlgorithm::HS256 => "HS256",
            SigningAlgorithm::RS256 => "RS256",
        }
    }
}
110
/// Claims describing an end user.
///
/// Every optional claim is omitted from the serialized form when it is `None`.
/// The field groups below mirror the scopes checked by
/// [`UserClaims::filter_by_scopes`].
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct UserClaims {
    /// Subject identifier; the only mandatory claim.
    pub sub: String,

    // ----- Released with the `profile` scope -----
    /// `name` claim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// `given_name` claim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub given_name: Option<String>,
    /// `family_name` claim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub family_name: Option<String>,
    /// `middle_name` claim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub middle_name: Option<String>,
    /// `nickname` claim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nickname: Option<String>,
    /// `preferred_username` claim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub preferred_username: Option<String>,
    /// `profile` claim (presumably a profile page URL — not validated here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub profile: Option<String>,
    /// `picture` claim (presumably an image URL — not validated here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub picture: Option<String>,
    /// `website` claim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub website: Option<String>,
    /// `gender` claim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gender: Option<String>,
    /// `birthdate` claim, stored as an unvalidated string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub birthdate: Option<String>,
    /// `zoneinfo` claim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub zoneinfo: Option<String>,
    /// `locale` claim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub locale: Option<String>,
    /// `updated_at` claim (presumably Unix seconds — the unit is not enforced here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub updated_at: Option<i64>,

    // ----- Released with the `email` scope -----
    /// `email` claim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub email: Option<String>,
    /// `email_verified` claim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub email_verified: Option<bool>,

    // ----- Released with the `phone` scope -----
    /// `phone_number` claim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub phone_number: Option<String>,
    /// `phone_number_verified` claim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub phone_number_verified: Option<bool>,

    // ----- Released with the `address` scope -----
    /// Structured `address` claim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub address: Option<AddressClaim>,

    /// Additional claims, flattened into the same JSON object as the fields
    /// above. `filter_by_scopes` never copies these into its result.
    #[serde(flatten)]
    pub custom: HashMap<String, serde_json::Value>,
}
193
194impl UserClaims {
195 #[must_use]
197 pub fn new(sub: impl Into<String>) -> Self {
198 Self {
199 sub: sub.into(),
200 ..Default::default()
201 }
202 }
203
204 #[must_use]
206 pub fn with_name(mut self, name: impl Into<String>) -> Self {
207 self.name = Some(name.into());
208 self
209 }
210
211 #[must_use]
213 pub fn with_email(mut self, email: impl Into<String>) -> Self {
214 self.email = Some(email.into());
215 self
216 }
217
218 #[must_use]
220 pub fn with_email_verified(mut self, verified: bool) -> Self {
221 self.email_verified = Some(verified);
222 self
223 }
224
225 #[must_use]
227 pub fn with_preferred_username(mut self, username: impl Into<String>) -> Self {
228 self.preferred_username = Some(username.into());
229 self
230 }
231
232 #[must_use]
234 pub fn with_picture(mut self, url: impl Into<String>) -> Self {
235 self.picture = Some(url.into());
236 self
237 }
238
239 #[must_use]
241 pub fn with_given_name(mut self, name: impl Into<String>) -> Self {
242 self.given_name = Some(name.into());
243 self
244 }
245
246 #[must_use]
248 pub fn with_family_name(mut self, name: impl Into<String>) -> Self {
249 self.family_name = Some(name.into());
250 self
251 }
252
253 #[must_use]
255 pub fn with_phone_number(mut self, phone: impl Into<String>) -> Self {
256 self.phone_number = Some(phone.into());
257 self
258 }
259
260 #[must_use]
262 pub fn with_updated_at(mut self, timestamp: i64) -> Self {
263 self.updated_at = Some(timestamp);
264 self
265 }
266
267 #[must_use]
269 pub fn with_custom(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
270 self.custom.insert(key.into(), value);
271 self
272 }
273
274 #[must_use]
278 #[allow(clippy::assigning_clones)]
279 pub fn filter_by_scopes(&self, scopes: &[String]) -> UserClaims {
280 let mut filtered = UserClaims::new(&self.sub);
281
282 if scopes.iter().any(|s| s == "profile") {
284 filtered.name = self.name.clone();
285 filtered.given_name = self.given_name.clone();
286 filtered.family_name = self.family_name.clone();
287 filtered.middle_name = self.middle_name.clone();
288 filtered.nickname = self.nickname.clone();
289 filtered.preferred_username = self.preferred_username.clone();
290 filtered.profile = self.profile.clone();
291 filtered.picture = self.picture.clone();
292 filtered.website = self.website.clone();
293 filtered.gender = self.gender.clone();
294 filtered.birthdate = self.birthdate.clone();
295 filtered.zoneinfo = self.zoneinfo.clone();
296 filtered.locale = self.locale.clone();
297 filtered.updated_at = self.updated_at;
298 }
299
300 if scopes.iter().any(|s| s == "email") {
302 filtered.email = self.email.clone();
303 filtered.email_verified = self.email_verified;
304 }
305
306 if scopes.iter().any(|s| s == "phone") {
308 filtered.phone_number = self.phone_number.clone();
309 filtered.phone_number_verified = self.phone_number_verified;
310 }
311
312 if scopes.iter().any(|s| s == "address") {
314 filtered.address = self.address.clone();
315 }
316
317 filtered
318 }
319}
320
/// Structured postal address, used as the value of [`UserClaims::address`].
///
/// Every component is optional and omitted from the serialized form when `None`.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct AddressClaim {
    /// Full address as a single display string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub formatted: Option<String>,
    /// Street address component.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub street_address: Option<String>,
    /// Locality component (presumably city or town).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub locality: Option<String>,
    /// Region component (presumably state or province).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub region: Option<String>,
    /// Postal or ZIP code component.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub postal_code: Option<String>,
    /// Country component.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub country: Option<String>,
}
343
/// Payload of an ID token.
///
/// Optional claims are omitted from the serialized form when `None`; the
/// user's claims are flattened into the same JSON object.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct IdTokenClaims {
    /// Issuer; filled from the provider's configured issuer.
    pub iss: String,
    /// Subject the token was issued for.
    pub sub: String,
    /// Audience; filled with the client id of the access token.
    pub aud: String,
    /// Expiry time in seconds since the Unix epoch.
    pub exp: i64,
    /// Issue time in seconds since the Unix epoch.
    pub iat: i64,
    /// Authentication time; `issue_id_token` sets it to the issue time.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auth_time: Option<i64>,
    /// Nonce supplied by the caller of `issue_id_token`, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nonce: Option<String>,
    /// `acr` claim; never populated by `issue_id_token`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub acr: Option<String>,
    /// `amr` claim; never populated by `issue_id_token`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub amr: Option<Vec<String>>,
    /// Authorized party; filled with the client id of the access token.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub azp: Option<String>,
    /// Hash of the access token the ID token was issued alongside.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub at_hash: Option<String>,
    /// `c_hash` claim; never populated by `issue_id_token`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub c_hash: Option<String>,
    /// Scope-filtered user claims, serialized inline with the claims above.
    #[serde(flatten)]
    pub user_claims: UserClaims,
}
386
/// An issued ID token: the signed compact string plus its decoded claims.
#[derive(Debug, Clone)]
pub struct IdToken {
    /// Serialized token in `header.payload.signature` form.
    pub raw: String,
    /// Claims that were serialized into `raw`.
    pub claims: IdTokenClaims,
}
395
/// Provider metadata served as the OIDC discovery document.
///
/// Optional entries are omitted from the serialized form when `None`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DiscoveryDocument {
    /// Issuer identifier.
    pub issuer: String,
    /// URL of the authorization endpoint.
    pub authorization_endpoint: String,
    /// URL of the token endpoint.
    pub token_endpoint: String,
    /// URL of the userinfo endpoint.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub userinfo_endpoint: Option<String>,
    /// URL of the JSON Web Key Set document.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks_uri: Option<String>,
    /// URL of the dynamic client registration endpoint.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub registration_endpoint: Option<String>,
    /// URL of the token revocation endpoint.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub revocation_endpoint: Option<String>,
    /// Scopes the provider supports.
    pub scopes_supported: Vec<String>,
    /// Supported `response_type` values.
    pub response_types_supported: Vec<String>,
    /// Supported `response_mode` values.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_modes_supported: Option<Vec<String>>,
    /// Supported grant types.
    pub grant_types_supported: Vec<String>,
    /// Supported subject identifier types.
    pub subject_types_supported: Vec<String>,
    /// Algorithms that may be used to sign ID tokens.
    pub id_token_signing_alg_values_supported: Vec<String>,
    /// Client authentication methods accepted at the token endpoint.
    pub token_endpoint_auth_methods_supported: Vec<String>,
    /// Claim names the provider may return.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub claims_supported: Option<Vec<String>>,
    /// Supported PKCE code challenge methods.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code_challenge_methods_supported: Option<Vec<String>>,
}
445
446impl DiscoveryDocument {
447 #[must_use]
449 pub fn new(issuer: impl Into<String>, base_url: impl Into<String>) -> Self {
450 let issuer = issuer.into();
451 let base = base_url.into();
452
453 Self {
454 issuer: issuer.clone(),
455 authorization_endpoint: format!("{}/authorize", base),
456 token_endpoint: format!("{}/token", base),
457 userinfo_endpoint: Some(format!("{}/userinfo", base)),
458 jwks_uri: Some(format!("{}/.well-known/jwks.json", base)),
459 registration_endpoint: None,
460 revocation_endpoint: Some(format!("{}/revoke", base)),
461 scopes_supported: vec![
462 "openid".to_string(),
463 "profile".to_string(),
464 "email".to_string(),
465 ],
466 response_types_supported: vec!["code".to_string()],
467 response_modes_supported: Some(vec!["query".to_string()]),
468 grant_types_supported: vec![
469 "authorization_code".to_string(),
470 "refresh_token".to_string(),
471 ],
472 subject_types_supported: vec!["public".to_string()],
473 id_token_signing_alg_values_supported: vec!["HS256".to_string()],
474 token_endpoint_auth_methods_supported: vec![
475 "client_secret_post".to_string(),
476 "client_secret_basic".to_string(),
477 ],
478 claims_supported: Some(vec![
479 "sub".to_string(),
480 "iss".to_string(),
481 "aud".to_string(),
482 "exp".to_string(),
483 "iat".to_string(),
484 "name".to_string(),
485 "email".to_string(),
486 "email_verified".to_string(),
487 "preferred_username".to_string(),
488 "picture".to_string(),
489 ]),
490 code_challenge_methods_supported: Some(vec!["plain".to_string(), "S256".to_string()]),
491 }
492 }
493}
494
/// Source of user claims, looked up by subject identifier.
///
/// Implementations must be `Send + Sync`.
pub trait ClaimsProvider: Send + Sync {
    /// Returns the claims for `subject`, or `None` when there are none.
    fn get_claims(&self, subject: &str) -> Option<UserClaims>;
}
506
/// [`ClaimsProvider`] backed by an in-memory map keyed by subject.
#[derive(Debug, Default)]
pub struct InMemoryClaimsProvider {
    /// Stored claims, keyed by each entry's `sub`.
    claims: RwLock<HashMap<String, UserClaims>>,
}
512
513impl InMemoryClaimsProvider {
514 #[must_use]
516 pub fn new() -> Self {
517 Self::default()
518 }
519
520 pub fn set_claims(&self, claims: UserClaims) {
522 if let Ok(mut guard) = self.claims.write() {
523 guard.insert(claims.sub.clone(), claims);
524 }
525 }
526
527 pub fn remove_claims(&self, subject: &str) {
529 if let Ok(mut guard) = self.claims.write() {
530 guard.remove(subject);
531 }
532 }
533}
534
535impl ClaimsProvider for InMemoryClaimsProvider {
536 fn get_claims(&self, subject: &str) -> Option<UserClaims> {
537 self.claims
538 .read()
539 .ok()
540 .and_then(|guard| guard.get(subject).cloned())
541 }
542}
543
/// [`ClaimsProvider`] that delegates every lookup to a closure or function.
pub struct FnClaimsProvider<F>
where
    F: Fn(&str) -> Option<UserClaims> + Send + Sync,
{
    /// Callback invoked with the subject on each lookup.
    func: F,
}
551
552impl<F> FnClaimsProvider<F>
553where
554 F: Fn(&str) -> Option<UserClaims> + Send + Sync,
555{
556 #[must_use]
558 pub fn new(func: F) -> Self {
559 Self { func }
560 }
561}
562
563impl<F> ClaimsProvider for FnClaimsProvider<F>
564where
565 F: Fn(&str) -> Option<UserClaims> + Send + Sync,
566{
567 fn get_claims(&self, subject: &str) -> Option<UserClaims> {
568 (self.func)(subject)
569 }
570}
571
572impl ClaimsProvider for Arc<dyn ClaimsProvider> {
573 fn get_claims(&self, subject: &str) -> Option<UserClaims> {
574 (**self).get_claims(subject)
575 }
576}
577
/// Errors produced by the OIDC provider.
#[derive(Debug, Clone)]
pub enum OidcError {
    /// An error from the underlying OAuth layer.
    OAuth(OAuthError),
    /// The access token was not granted the `openid` scope.
    MissingOpenIdScope,
    /// No claims could be resolved; the payload is the subject or a reason.
    ClaimsNotFound(String),
    /// The ID token could not be serialized or signed.
    SigningError(String),
    /// An ID token was rejected as invalid; the payload describes why.
    InvalidIdToken(String),
}
596
597impl std::fmt::Display for OidcError {
598 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
599 match self {
600 Self::OAuth(e) => write!(f, "OAuth error: {}", e),
601 Self::MissingOpenIdScope => write!(f, "missing 'openid' scope"),
602 Self::ClaimsNotFound(s) => write!(f, "claims not found for subject: {}", s),
603 Self::SigningError(s) => write!(f, "signing error: {}", s),
604 Self::InvalidIdToken(s) => write!(f, "invalid ID token: {}", s),
605 }
606 }
607}
608
// Marker impl: relies on the `Debug` and `Display` impls; no `source()` is exposed.
impl std::error::Error for OidcError {}
610
611impl From<OAuthError> for OidcError {
612 fn from(err: OAuthError) -> Self {
613 Self::OAuth(err)
614 }
615}
616
/// OpenID Connect provider layered on top of an [`OAuthServer`].
pub struct OidcProvider {
    /// Underlying OAuth server, used to validate access tokens.
    oauth: Arc<OAuthServer>,
    /// Provider configuration.
    config: OidcProviderConfig,
    /// Key used to sign ID tokens; generated on first use when none was set.
    signing_key: RwLock<SigningKey>,
    /// Optional source of user claims.
    claims_provider: RwLock<Option<Arc<dyn ClaimsProvider>>>,
    /// Issued ID tokens, keyed by the access token string they were issued for.
    id_tokens: RwLock<HashMap<String, IdToken>>,
}
636
/// Key material for signing ID tokens.
#[derive(Clone, Default)]
enum SigningKey {
    /// Shared secret for HMAC-based signing.
    Hmac(Vec<u8>),
    /// No key configured yet.
    #[default]
    None,
}
646
647impl OidcProvider {
648 #[must_use]
650 pub fn new(oauth: Arc<OAuthServer>, config: OidcProviderConfig) -> Self {
651 Self {
652 oauth,
653 config,
654 signing_key: RwLock::new(SigningKey::None),
655 claims_provider: RwLock::new(None),
656 id_tokens: RwLock::new(HashMap::new()),
657 }
658 }
659
660 #[must_use]
662 pub fn with_defaults(oauth: Arc<OAuthServer>) -> Self {
663 Self::new(oauth, OidcProviderConfig::default())
664 }
665
666 #[must_use]
668 pub fn config(&self) -> &OidcProviderConfig {
669 &self.config
670 }
671
672 #[must_use]
674 pub fn oauth(&self) -> &Arc<OAuthServer> {
675 &self.oauth
676 }
677
678 pub fn set_hmac_key(&self, key: impl AsRef<[u8]>) {
680 if let Ok(mut guard) = self.signing_key.write() {
681 *guard = SigningKey::Hmac(key.as_ref().to_vec());
682 }
683 }
684
685 pub fn set_claims_provider<P: ClaimsProvider + 'static>(&self, provider: P) {
687 if let Ok(mut guard) = self.claims_provider.write() {
688 *guard = Some(Arc::new(provider));
689 }
690 }
691
692 pub fn set_claims_fn<F>(&self, func: F)
694 where
695 F: Fn(&str) -> Option<UserClaims> + Send + Sync + 'static,
696 {
697 self.set_claims_provider(FnClaimsProvider::new(func));
698 }
699
700 #[must_use]
702 pub fn discovery_document(&self, base_url: impl Into<String>) -> DiscoveryDocument {
703 let mut doc = DiscoveryDocument::new(&self.config.issuer, base_url);
704 doc.scopes_supported = self.config.supported_scopes.clone();
705 doc.claims_supported = Some(self.config.supported_claims.clone());
706 doc.id_token_signing_alg_values_supported =
707 vec![self.config.signing_algorithm.as_str().to_string()];
708 doc
709 }
710
711 pub fn issue_id_token(
720 &self,
721 access_token: &OAuthToken,
722 nonce: Option<&str>,
723 ) -> Result<IdToken, OidcError> {
724 if !access_token.scopes.iter().any(|s| s == "openid") {
726 return Err(OidcError::MissingOpenIdScope);
727 }
728
729 let subject = access_token
730 .subject
731 .as_ref()
732 .ok_or_else(|| OidcError::ClaimsNotFound("no subject in access token".to_string()))?;
733
734 let user_claims = self.get_user_claims(subject, &access_token.scopes)?;
736
737 let now = SystemTime::now()
739 .duration_since(UNIX_EPOCH)
740 .unwrap_or_default()
741 .as_secs() as i64;
742
743 let claims = IdTokenClaims {
744 iss: self.config.issuer.clone(),
745 sub: subject.clone(),
746 aud: access_token.client_id.clone(),
747 exp: now + self.config.id_token_lifetime.as_secs() as i64,
748 iat: now,
749 auth_time: Some(now),
750 nonce: nonce.map(String::from),
751 acr: None,
752 amr: None,
753 azp: Some(access_token.client_id.clone()),
754 at_hash: Some(self.compute_at_hash(&access_token.token)),
755 c_hash: None,
756 user_claims,
757 };
758
759 let raw = self.sign_id_token(&claims)?;
761
762 let id_token = IdToken { raw, claims };
763
764 if let Ok(mut guard) = self.id_tokens.write() {
766 guard.insert(access_token.token.clone(), id_token.clone());
767 }
768
769 Ok(id_token)
770 }
771
772 #[must_use]
774 pub fn get_id_token(&self, access_token: &str) -> Option<IdToken> {
775 self.id_tokens
776 .read()
777 .ok()
778 .and_then(|guard| guard.get(access_token).cloned())
779 }
780
781 pub fn userinfo(&self, access_token: &str) -> Result<UserClaims, OidcError> {
789 let token = self
791 .oauth
792 .validate_access_token(access_token)
793 .ok_or_else(|| {
794 OidcError::OAuth(OAuthError::InvalidGrant(
795 "invalid or expired access token".to_string(),
796 ))
797 })?;
798
799 if !token.scopes.iter().any(|s| s == "openid") {
801 return Err(OidcError::MissingOpenIdScope);
802 }
803
804 let subject = token
805 .subject
806 .as_ref()
807 .ok_or_else(|| OidcError::ClaimsNotFound("no subject in access token".to_string()))?;
808
809 self.get_user_claims(subject, &token.scopes)
810 }
811
812 fn get_user_claims(&self, subject: &str, scopes: &[String]) -> Result<UserClaims, OidcError> {
817 let provider = self
818 .claims_provider
819 .read()
820 .ok()
821 .and_then(|guard| guard.clone());
822
823 let claims = match provider {
824 Some(p) => p
825 .get_claims(subject)
826 .ok_or_else(|| OidcError::ClaimsNotFound(subject.to_string()))?,
827 None => {
828 UserClaims::new(subject)
830 }
831 };
832
833 Ok(claims.filter_by_scopes(scopes))
834 }
835
836 fn sign_id_token(&self, claims: &IdTokenClaims) -> Result<String, OidcError> {
837 let key = self.get_or_generate_signing_key()?;
838
839 let header = serde_json::json!({
841 "alg": self.config.signing_algorithm.as_str(),
842 "typ": "JWT",
843 "kid": self.config.key_id.as_deref().unwrap_or("default"),
844 });
845
846 let header_b64 =
847 base64url_encode(&serde_json::to_vec(&header).map_err(|e| {
848 OidcError::SigningError(format!("failed to serialize header: {}", e))
849 })?);
850
851 let claims_b64 =
852 base64url_encode(&serde_json::to_vec(claims).map_err(|e| {
853 OidcError::SigningError(format!("failed to serialize claims: {}", e))
854 })?);
855
856 let signing_input = format!("{}.{}", header_b64, claims_b64);
857
858 let signature = match &key {
859 SigningKey::Hmac(secret) => hmac_sha256(&signing_input, secret),
860 SigningKey::None => {
861 return Err(OidcError::SigningError(
862 "no signing key configured".to_string(),
863 ));
864 }
865 };
866
867 let signature_b64 = base64url_encode(&signature);
868
869 Ok(format!("{}.{}", signing_input, signature_b64))
870 }
871
872 fn get_or_generate_signing_key(&self) -> Result<SigningKey, OidcError> {
873 let guard = self
874 .signing_key
875 .read()
876 .map_err(|_| OidcError::SigningError("failed to acquire read lock".to_string()))?;
877
878 match &*guard {
879 SigningKey::None => {
880 drop(guard);
882 let mut write_guard = self.signing_key.write().map_err(|_| {
883 OidcError::SigningError("failed to acquire write lock".to_string())
884 })?;
885
886 if matches!(&*write_guard, SigningKey::None) {
888 let key = generate_random_bytes(32);
889 *write_guard = SigningKey::Hmac(key.clone());
890 Ok(SigningKey::Hmac(key))
891 } else {
892 Ok(write_guard.clone())
893 }
894 }
895 key => Ok(key.clone()),
896 }
897 }
898
899 fn compute_at_hash(&self, access_token: &str) -> String {
900 let hash = simple_sha256(access_token.as_bytes());
902 base64url_encode(&hash[..16])
903 }
904
905 pub fn cleanup_expired(&self) {
907 let now = SystemTime::now()
908 .duration_since(UNIX_EPOCH)
909 .unwrap_or_default()
910 .as_secs() as i64;
911
912 if let Ok(mut guard) = self.id_tokens.write() {
913 guard.retain(|_, token| token.claims.exp > now);
914 }
915 }
916}
917
/// Encodes `data` as base64url (the `-`/`_` alphabet) without `=` padding.
fn base64url_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

    // Maps the 6-bit group found at `shift` within `group` to its character.
    let sextet = |group: u32, shift: u32| ALPHABET[((group >> shift) & 0x3F) as usize] as char;

    let mut encoded = String::with_capacity((data.len() * 4 + 2) / 3);

    for chunk in data.chunks(3) {
        // Pack up to three bytes into the top 24 bits, most significant first.
        let group = chunk
            .iter()
            .enumerate()
            .fold(0u32, |acc, (idx, &byte)| {
                acc | (u32::from(byte) << (16 - 8 * idx as u32))
            });

        // n input bytes yield n + 1 output characters (no padding emitted).
        for position in 0..=chunk.len() as u32 {
            encoded.push(sextet(group, 18 - 6 * position));
        }
    }

    encoded
}
951
/// Computes the SHA-256 digest of `data`.
///
/// This is a self-contained implementation of the standard algorithm. The
/// result is deterministic and matches the published SHA-256 test vectors,
/// so other parties can recompute values derived from it.
fn simple_sha256(data: &[u8]) -> [u8; 32] {
    // Round constants.
    const K: [u32; 64] = [
        0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4,
        0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe,
        0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f,
        0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
        0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc,
        0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
        0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116,
        0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
        0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7,
        0xc67178f2,
    ];

    // Initial hash value.
    let mut state: [u32; 8] = [
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
        0x5be0cd19,
    ];

    // Padding: 0x80, zeros up to 56 mod 64, then the bit length as big-endian u64.
    let bit_len = (data.len() as u64).wrapping_mul(8);
    let mut padded = Vec::with_capacity(data.len() + 72);
    padded.extend_from_slice(data);
    padded.push(0x80);
    while padded.len() % 64 != 56 {
        padded.push(0);
    }
    padded.extend_from_slice(&bit_len.to_be_bytes());

    for block in padded.chunks_exact(64) {
        // Message schedule.
        let mut w = [0u32; 64];
        for (word, bytes) in w.iter_mut().zip(block.chunks_exact(4)) {
            *word = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        }
        for t in 16..64 {
            let s0 = w[t - 15].rotate_right(7) ^ w[t - 15].rotate_right(18) ^ (w[t - 15] >> 3);
            let s1 = w[t - 2].rotate_right(17) ^ w[t - 2].rotate_right(19) ^ (w[t - 2] >> 10);
            w[t] = w[t - 16]
                .wrapping_add(s0)
                .wrapping_add(w[t - 7])
                .wrapping_add(s1);
        }

        // Compression rounds.
        let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = state;
        for t in 0..64 {
            let big_s1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25);
            let ch = (e & f) ^ (!e & g);
            let t1 = h
                .wrapping_add(big_s1)
                .wrapping_add(ch)
                .wrapping_add(K[t])
                .wrapping_add(w[t]);
            let big_s0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22);
            let maj = (a & b) ^ (a & c) ^ (b & c);
            let t2 = big_s0.wrapping_add(maj);

            h = g;
            g = f;
            f = e;
            e = d.wrapping_add(t1);
            d = c;
            c = b;
            b = a;
            a = t1.wrapping_add(t2);
        }

        let working = [a, b, c, d, e, f, g, h];
        for (slot, value) in state.iter_mut().zip(working.iter()) {
            *slot = slot.wrapping_add(*value);
        }
    }

    let mut digest = [0u8; 32];
    for (chunk, word) in digest.chunks_exact_mut(4).zip(state.iter()) {
        chunk.copy_from_slice(&word.to_be_bytes());
    }
    digest
}

/// Computes HMAC-SHA-256 of `message` under `key`, as defined by RFC 2104.
///
/// Keys longer than the 64-byte block size are hashed first; shorter keys are
/// zero-padded. Simply hashing `key || message` is not an HMAC and would not
/// be accepted by a standard HS256 verifier.
fn hmac_sha256(message: &str, key: &[u8]) -> [u8; 32] {
    const BLOCK_SIZE: usize = 64;

    let mut key_block = [0u8; BLOCK_SIZE];
    if key.len() > BLOCK_SIZE {
        key_block[..32].copy_from_slice(&simple_sha256(key));
    } else {
        key_block[..key.len()].copy_from_slice(key);
    }

    // Inner hash: H((K ^ ipad) || message).
    let mut inner = Vec::with_capacity(BLOCK_SIZE + message.len());
    inner.extend(key_block.iter().map(|byte| byte ^ 0x36));
    inner.extend_from_slice(message.as_bytes());
    let inner_digest = simple_sha256(&inner);

    // Outer hash: H((K ^ opad) || inner_digest).
    let mut outer = Vec::with_capacity(BLOCK_SIZE + inner_digest.len());
    outer.extend(key_block.iter().map(|byte| byte ^ 0x5c));
    outer.extend_from_slice(&inner_digest);
    simple_sha256(&outer)
}
980
/// Produces `len` bytes derived from a freshly created `RandomState` hasher,
/// the byte index and the current clock reading.
///
/// NOTE(review): this is not a dedicated cryptographic random generator; its
/// unpredictability rests on how `RandomState` is seeded. Confirm that is
/// acceptable wherever the output is used as key material.
fn generate_random_bytes(len: usize) -> Vec<u8> {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let keyed = RandomState::new();

    (0..len)
        .map(|index| {
            let mut hasher = keyed.build_hasher();
            hasher.write_usize(index);

            let nanos = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_nanos();
            hasher.write_u128(nanos);

            // Keep only the low byte of the 64-bit hash.
            (hasher.finish() & 0xFF) as u8
        })
        .collect()
}
1003
// Unit tests for the claim types, claims providers, the base64url helper and
// the provider's ID token / userinfo flows.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oauth::{OAuthClient, OAuthServerConfig};
    use std::time::Instant;

    /// Builds a provider with default configuration over a default OAuth server.
    fn create_test_provider() -> OidcProvider {
        let oauth = Arc::new(OAuthServer::new(OAuthServerConfig::default()));
        OidcProvider::with_defaults(oauth)
    }

    /// Builder methods set exactly the claims they are named after.
    #[test]
    fn test_user_claims_builder() {
        let claims = UserClaims::new("user123")
            .with_name("John Doe")
            .with_email("john@example.com")
            .with_email_verified(true)
            .with_preferred_username("johnd");

        assert_eq!(claims.sub, "user123");
        assert_eq!(claims.name, Some("John Doe".to_string()));
        assert_eq!(claims.email, Some("john@example.com".to_string()));
        assert_eq!(claims.email_verified, Some(true));
        assert_eq!(claims.preferred_username, Some("johnd".to_string()));
    }

    /// Each scope releases only its own group of claims; `sub` is always kept.
    #[test]
    fn test_claims_filter_by_scopes() {
        let claims = UserClaims::new("user123")
            .with_name("John Doe")
            .with_email("john@example.com")
            .with_phone_number("+1234567890");

        // `openid` alone releases nothing beyond the subject.
        let filtered = claims.filter_by_scopes(&["openid".to_string()]);
        assert_eq!(filtered.sub, "user123");
        assert!(filtered.name.is_none());
        assert!(filtered.email.is_none());

        // `profile` releases the name but not the email.
        let filtered = claims.filter_by_scopes(&["openid".to_string(), "profile".to_string()]);
        assert_eq!(filtered.name, Some("John Doe".to_string()));
        assert!(filtered.email.is_none());

        // `email` releases the email but not the name.
        let filtered = claims.filter_by_scopes(&["openid".to_string(), "email".to_string()]);
        assert!(filtered.name.is_none());
        assert_eq!(filtered.email, Some("john@example.com".to_string()));

        // All scopes together release everything that was set.
        let filtered = claims.filter_by_scopes(&[
            "openid".to_string(),
            "profile".to_string(),
            "email".to_string(),
            "phone".to_string(),
        ]);
        assert_eq!(filtered.name, Some("John Doe".to_string()));
        assert_eq!(filtered.email, Some("john@example.com".to_string()));
        assert_eq!(filtered.phone_number, Some("+1234567890".to_string()));
    }

    /// The discovery document reflects the issuer and derives endpoints from the base URL.
    #[test]
    fn test_discovery_document() {
        let provider = create_test_provider();
        let doc = provider.discovery_document("https://example.com");

        assert_eq!(doc.issuer, "fastmcp");
        assert_eq!(doc.authorization_endpoint, "https://example.com/authorize");
        assert_eq!(doc.token_endpoint, "https://example.com/token");
        assert!(doc.scopes_supported.contains(&"openid".to_string()));
        assert!(doc.response_types_supported.contains(&"code".to_string()));
    }

    /// Claims can be stored, looked up and removed by subject.
    #[test]
    fn test_in_memory_claims_provider() {
        let provider = InMemoryClaimsProvider::new();

        let claims = UserClaims::new("user123")
            .with_name("John Doe")
            .with_email("john@example.com");

        provider.set_claims(claims);

        let retrieved = provider.get_claims("user123");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name, Some("John Doe".to_string()));

        assert!(provider.get_claims("nonexistent").is_none());

        provider.remove_claims("user123");
        assert!(provider.get_claims("user123").is_none());
    }

    /// A closure-backed provider returns whatever the closure returns.
    #[test]
    fn test_fn_claims_provider() {
        let provider = FnClaimsProvider::new(|subject| {
            if subject == "user123" {
                Some(UserClaims::new(subject).with_name("John Doe"))
            } else {
                None
            }
        });

        let claims = provider.get_claims("user123");
        assert!(claims.is_some());
        assert_eq!(claims.unwrap().name, Some("John Doe".to_string()));

        assert!(provider.get_claims("other").is_none());
    }

    /// Algorithm names match their JWT `alg` strings.
    #[test]
    fn test_signing_algorithm() {
        assert_eq!(SigningAlgorithm::HS256.as_str(), "HS256");
        assert_eq!(SigningAlgorithm::RS256.as_str(), "RS256");
    }

    /// Error messages include the relevant detail.
    #[test]
    fn test_oidc_error_display() {
        let err = OidcError::MissingOpenIdScope;
        assert_eq!(err.to_string(), "missing 'openid' scope");

        let err = OidcError::ClaimsNotFound("user123".to_string());
        assert!(err.to_string().contains("user123"));
    }

    /// Inputs of length 0 to 3 cover every remainder case of the encoder.
    #[test]
    fn test_base64url_encode() {
        assert_eq!(base64url_encode(b""), "");
        assert_eq!(base64url_encode(b"f"), "Zg");
        assert_eq!(base64url_encode(b"fo"), "Zm8");
        assert_eq!(base64url_encode(b"foo"), "Zm9v");
    }

    /// A token with the `openid` scope and a known subject yields a signed ID token.
    #[test]
    fn test_id_token_issuance() {
        let provider = create_test_provider();

        // Register claims for the subject carried by the access token.
        let claims_provider = InMemoryClaimsProvider::new();
        claims_provider.set_claims(
            UserClaims::new("user123")
                .with_name("John Doe")
                .with_email("john@example.com"),
        );
        provider.set_claims_provider(claims_provider);

        // Use a fixed key rather than the auto-generated one.
        provider.set_hmac_key(b"test-secret-key");

        let now = Instant::now();
        // Built by hand; issuance does not consult the OAuth server.
        let access_token = crate::oauth::OAuthToken {
            token: "test-access-token".to_string(),
            token_type: crate::oauth::TokenType::Bearer,
            client_id: "test-client".to_string(),
            scopes: vec![
                "openid".to_string(),
                "profile".to_string(),
                "email".to_string(),
            ],
            issued_at: now,
            expires_at: now + Duration::from_secs(3600),
            subject: Some("user123".to_string()),
            is_refresh_token: false,
        };

        let result = provider.issue_id_token(&access_token, Some("nonce123"));
        assert!(result.is_ok());

        let id_token = result.unwrap();
        assert!(!id_token.raw.is_empty());
        assert!(id_token.raw.contains('.'));
        assert_eq!(id_token.claims.sub, "user123");
        assert_eq!(id_token.claims.aud, "test-client");
        assert_eq!(id_token.claims.nonce, Some("nonce123".to_string()));
        assert_eq!(
            id_token.claims.user_claims.name,
            Some("John Doe".to_string())
        );
    }

    /// Issuance is refused when the access token lacks the `openid` scope.
    #[test]
    fn test_id_token_requires_openid_scope() {
        let provider = create_test_provider();

        let now = Instant::now();
        let access_token = crate::oauth::OAuthToken {
            token: "test-access-token".to_string(),
            token_type: crate::oauth::TokenType::Bearer,
            client_id: "test-client".to_string(),
            // Deliberately missing `openid`.
            scopes: vec!["profile".to_string()],
            issued_at: now,
            expires_at: now + Duration::from_secs(3600),
            subject: Some("user123".to_string()),
            is_refresh_token: false,
        };

        let result = provider.issue_id_token(&access_token, None);
        assert!(matches!(result, Err(OidcError::MissingOpenIdScope)));
    }

    /// `userinfo` validates the token via the OAuth server and returns filtered claims.
    #[test]
    fn test_userinfo() {
        let oauth = Arc::new(OAuthServer::new(OAuthServerConfig::default()));

        // Register the client the token will belong to.
        let client = OAuthClient::builder("test-client")
            .redirect_uri("http://localhost:3000/callback")
            .scope("openid")
            .scope("profile")
            .build()
            .unwrap();
        oauth.register_client(client).unwrap();

        {
            // Insert a token directly into the server state, bypassing the grant flow.
            let mut state = oauth.state.write().unwrap();
            let now = Instant::now();
            let token = crate::oauth::OAuthToken {
                token: "test-token".to_string(),
                token_type: crate::oauth::TokenType::Bearer,
                client_id: "test-client".to_string(),
                scopes: vec!["openid".to_string(), "profile".to_string()],
                issued_at: now,
                expires_at: now + Duration::from_secs(3600),
                subject: Some("user123".to_string()),
                is_refresh_token: false,
            };
            state.access_tokens.insert("test-token".to_string(), token);
        }

        let provider = OidcProvider::with_defaults(oauth);

        // Provide claims for the token's subject.
        let claims_store = InMemoryClaimsProvider::new();
        claims_store.set_claims(UserClaims::new("user123").with_name("John Doe"));
        provider.set_claims_provider(claims_store);

        let result = provider.userinfo("test-token");
        assert!(result.is_ok());

        let claims = result.unwrap();
        assert_eq!(claims.sub, "user123");
        assert_eq!(claims.name, Some("John Doe".to_string()));
    }

    /// A populated address serializes with its component field names.
    #[test]
    fn test_address_claim() {
        let address = AddressClaim {
            formatted: Some("123 Main St, City, ST 12345".to_string()),
            street_address: Some("123 Main St".to_string()),
            locality: Some("City".to_string()),
            region: Some("ST".to_string()),
            postal_code: Some("12345".to_string()),
            country: Some("US".to_string()),
        };

        let json = serde_json::to_string(&address).unwrap();
        assert!(json.contains("formatted"));
        assert!(json.contains("street_address"));
    }

    /// Custom claims are stored under their keys with the given JSON values.
    #[test]
    fn test_custom_claims() {
        let claims = UserClaims::new("user123")
            .with_custom("custom_field", serde_json::json!("custom_value"))
            .with_custom("roles", serde_json::json!(["admin", "user"]));

        assert_eq!(
            claims.custom.get("custom_field"),
            Some(&serde_json::json!("custom_value"))
        );
        assert_eq!(
            claims.custom.get("roles"),
            Some(&serde_json::json!(["admin", "user"]))
        );
    }
}