1use crate::errors::{AuthError, Result};
10use crate::oauth2_server::OAuth2Server;
11use crate::server::core::client_registry::ClientRegistry;
12use crate::storage::AuthStorage;
13use crate::tokens::TokenManager;
14use jsonwebtoken::{Algorithm, Header};
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17use std::collections::HashMap;
18use std::fmt;
19use std::sync::Arc;
20use std::time::{Duration, SystemTime, UNIX_EPOCH};
21
/// Configuration for [`OidcProvider`].
///
/// Most fields are copied verbatim into the discovery document produced by
/// `OidcProvider::discovery_document`.
#[derive(Debug, Clone)]
pub struct OidcConfig {
    /// Issuer identifier. Used as the `iss` claim and as the base URL for the
    /// authorization, token and registration endpoints in the discovery document.
    pub issuer: String,

    /// Configuration handed to the embedded OAuth 2.0 server at construction time.
    pub oauth2_config: crate::oauth2_server::OAuth2Config,

    /// URL advertised as `jwks_uri` in the discovery document.
    pub jwks_uri: String,

    /// URL advertised as `userinfo_endpoint` in the discovery document.
    pub userinfo_endpoint: String,

    /// Accepted `response_type` values; authorization requests whose
    /// `response_type` is not in this list are rejected.
    pub response_types_supported: Vec<String>,

    /// Subject identifier types advertised in the discovery document.
    pub subject_types_supported: Vec<SubjectType>,

    /// ID token signing algorithms advertised in the discovery document.
    pub id_token_signing_alg_values_supported: Vec<Algorithm>,

    /// Scopes advertised in the discovery document.
    pub scopes_supported: Vec<String>,

    /// Claim names advertised in the discovery document; also consulted by
    /// `create_id_token` when copying caller-supplied claims.
    pub claims_supported: Vec<String>,

    /// Advertised as `claims_parameter_supported`.
    pub claims_parameter_supported: bool,

    /// Advertised as `request_parameter_supported`.
    pub request_parameter_supported: bool,

    /// Advertised as `request_uri_parameter_supported`.
    pub request_uri_parameter_supported: bool,

    /// Added to the current time to compute the `exp` value in `create_id_token`.
    pub id_token_expiry: Duration,

    /// Not read anywhere in this file; presumably an upper bound for the
    /// `max_age` request parameter — TODO confirm against callers.
    pub max_age_supported: Option<Duration>,
}
67
68impl Default for OidcConfig {
69 fn default() -> Self {
70 Self {
71 issuer: "https://auth.example.com".to_string(),
72 oauth2_config: crate::oauth2_server::OAuth2Config::default(),
73 jwks_uri: "https://auth.example.com/.well-known/jwks.json".to_string(),
74 userinfo_endpoint: "https://auth.example.com/oidc/userinfo".to_string(),
75 response_types_supported: vec![
76 "code".to_string(),
77 "id_token".to_string(),
78 "id_token token".to_string(),
79 "code id_token".to_string(),
80 "code token".to_string(),
81 "code id_token token".to_string(),
82 ],
83 subject_types_supported: vec![SubjectType::Public],
84 id_token_signing_alg_values_supported: vec![
85 Algorithm::RS256,
86 Algorithm::ES256,
87 Algorithm::HS256,
88 ],
89 scopes_supported: vec![
90 "openid".to_string(),
91 "profile".to_string(),
92 "email".to_string(),
93 "address".to_string(),
94 "phone".to_string(),
95 "offline_access".to_string(),
96 ],
97 claims_supported: vec![
98 "sub".to_string(),
99 "name".to_string(),
100 "given_name".to_string(),
101 "family_name".to_string(),
102 "middle_name".to_string(),
103 "nickname".to_string(),
104 "preferred_username".to_string(),
105 "profile".to_string(),
106 "picture".to_string(),
107 "website".to_string(),
108 "email".to_string(),
109 "email_verified".to_string(),
110 "gender".to_string(),
111 "birthdate".to_string(),
112 "zoneinfo".to_string(),
113 "locale".to_string(),
114 "phone_number".to_string(),
115 "phone_number_verified".to_string(),
116 "address".to_string(),
117 "updated_at".to_string(),
118 ],
119 claims_parameter_supported: true,
120 request_parameter_supported: true,
121 request_uri_parameter_supported: true,
122 id_token_expiry: Duration::from_secs(3600), max_age_supported: Some(Duration::from_secs(86400)), }
125 }
126}
127
/// Subject identifier type advertised in `subject_types_supported`.
///
/// Serialized in lowercase (`"public"` / `"pairwise"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SubjectType {
    /// The `public` subject type; the only one enabled by `OidcConfig::default`.
    Public,
    /// The `pairwise` subject type. NOTE(review): nothing in this file derives
    /// pairwise identifiers — confirm support before advertising it.
    Pairwise,
}
137
/// OpenID Connect provider layered on top of an [`OAuth2Server`].
///
/// `S` is the storage backend used for user profiles and sessions.
pub struct OidcProvider<S: AuthStorage + ?Sized> {
    /// Provider configuration (issuer, advertised capabilities, token lifetime).
    config: OidcConfig,
    /// OAuth 2.0 server built from `config.oauth2_config` in `new`.
    oauth2_server: OAuth2Server,
    /// Issues and validates the JWTs used as ID tokens and access tokens.
    token_manager: Arc<TokenManager>,
    /// Backend for user profile lookups and session listing/deletion.
    storage: Arc<S>,
    /// Optional client registry; `None` until `set_client_registry` is called.
    client_registry: Option<Arc<ClientRegistry>>,
}
146
147impl<S: AuthStorage + ?Sized> fmt::Debug for OidcProvider<S> {
148 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149 f.debug_struct("OidcProvider")
150 .field("config", &self.config)
151 .field("oauth2_server", &"<OAuth2Server>")
152 .field("token_manager", &"<TokenManager>")
153 .field("storage", &"<AuthStorage>")
154 .field("client_registry", &self.client_registry.is_some())
155 .finish()
156 }
157}
158
159impl<S: ?Sized + AuthStorage> OidcProvider<S> {
160 pub async fn new(
162 config: OidcConfig,
163 token_manager: Arc<TokenManager>,
164 storage: Arc<S>,
165 ) -> Result<Self> {
166 let oauth2_server =
167 OAuth2Server::new(config.oauth2_config.clone(), token_manager.clone()).await?;
168
169 Ok(Self {
170 config,
171 oauth2_server,
172 token_manager,
173 storage,
174 client_registry: None,
175 })
176 }
177
    /// Returns the embedded OAuth 2.0 server.
    pub fn oauth2_server(&self) -> &OAuth2Server {
        &self.oauth2_server
    }
182
    /// Attaches a client registry, replacing any previously attached one.
    ///
    /// Once set, `validate_authorization_request` checks `client_id` and
    /// `redirect_uri` against this registry.
    pub fn set_client_registry(&mut self, client_registry: Arc<ClientRegistry>) {
        self.client_registry = Some(client_registry);
    }
187
    /// Returns the provider configuration.
    pub fn config(&self) -> &OidcConfig {
        &self.config
    }
192
193 pub fn discovery_document(&self) -> Result<OidcDiscoveryDocument> {
195 Ok(OidcDiscoveryDocument {
196 issuer: self.config.issuer.clone(),
197 authorization_endpoint: format!("{}/oidc/authorize", self.config.issuer),
198 token_endpoint: format!("{}/oidc/token", self.config.issuer),
199 userinfo_endpoint: self.config.userinfo_endpoint.clone(),
200 jwks_uri: self.config.jwks_uri.clone(),
201 registration_endpoint: Some(format!("{}/oidc/register", self.config.issuer)),
202 scopes_supported: self.config.scopes_supported.clone(),
203 response_types_supported: self.config.response_types_supported.clone(),
204 response_modes_supported: Some(vec![
205 "query".to_string(),
206 "fragment".to_string(),
207 "form_post".to_string(),
208 ]),
209 grant_types_supported: Some(vec![
210 "authorization_code".to_string(),
211 "implicit".to_string(),
212 "refresh_token".to_string(),
213 "client_credentials".to_string(),
214 ]),
215 subject_types_supported: self.config.subject_types_supported.clone(),
216 id_token_signing_alg_values_supported: self
217 .config
218 .id_token_signing_alg_values_supported
219 .iter()
220 .map(algorithm_to_string)
221 .collect(),
222 userinfo_signing_alg_values_supported: Some(vec![
223 "RS256".to_string(),
224 "ES256".to_string(),
225 "HS256".to_string(),
226 ]),
227 token_endpoint_auth_methods_supported: Some(vec![
228 "client_secret_basic".to_string(),
229 "client_secret_post".to_string(),
230 "client_secret_jwt".to_string(),
231 "private_key_jwt".to_string(),
232 "none".to_string(),
233 ]),
234 claims_supported: Some(self.config.claims_supported.clone()),
235 claims_parameter_supported: Some(self.config.claims_parameter_supported),
236 request_parameter_supported: Some(self.config.request_parameter_supported),
237 request_uri_parameter_supported: Some(self.config.request_uri_parameter_supported),
238 code_challenge_methods_supported: Some(vec!["S256".to_string(), "plain".to_string()]),
239 })
240 }
241
242 pub async fn create_id_token(
244 &self,
245 subject: &str,
246 client_id: &str,
247 nonce: Option<&str>,
248 auth_time: Option<SystemTime>,
249 claims: Option<&HashMap<String, Value>>,
250 ) -> Result<String> {
251 let now = SystemTime::now()
252 .duration_since(UNIX_EPOCH)
253 .map_err(|e| AuthError::auth_method("oidc", format!("Time error: {}", e)))?
254 .as_secs();
255
256 let exp = now + self.config.id_token_expiry.as_secs();
257
258 let mut id_token_claims = IdTokenClaims {
259 iss: self.config.issuer.clone(),
260 sub: subject.to_string(),
261 aud: vec![client_id.to_string()],
262 exp,
263 iat: now,
264 auth_time: auth_time
265 .and_then(|t| t.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())),
266 nonce: nonce.map(|n| n.to_string()),
267 additional_claims: claims.cloned().unwrap_or_default(),
268 };
269
270 if let Some(claims) = claims {
272 for (key, value) in claims {
273 if self.config.claims_supported.contains(key) {
274 id_token_claims
275 .additional_claims
276 .insert(key.clone(), value.clone());
277 }
278 }
279 }
280
281 let _header = Header::new(Algorithm::RS256);
283 let token = self
284 .token_manager
285 .create_jwt_token(
286 subject,
287 vec!["openid".to_string()],
288 Some(Duration::from_secs(3600)),
289 )
290 .map_err(|e| AuthError::auth_method("oidc", format!("JWT creation failed: {}", e)))?;
291
292 Ok(token)
293 }
294
295 pub async fn validate_authorization_request(
297 &self,
298 request: &OidcAuthorizationRequest,
299 ) -> Result<AuthorizationValidationResult> {
300 if !request.scope.split_whitespace().any(|s| s == "openid") {
302 return Err(AuthError::auth_method(
303 "oidc",
304 "Missing required 'openid' scope",
305 ));
306 }
307
308 if !self
310 .config
311 .response_types_supported
312 .contains(&request.response_type)
313 {
314 return Err(AuthError::auth_method(
315 "oidc",
316 format!("Unsupported response_type: {}", request.response_type),
317 ));
318 }
319
320 if request.client_id.is_empty() {
322 return Err(AuthError::auth_method("oidc", "Missing client_id"));
323 }
324
325 if let Some(client_registry) = &self.client_registry {
327 if client_registry
328 .get_client(&request.client_id)
329 .await?
330 .is_none()
331 {
332 return Err(AuthError::auth_method("oidc", "Invalid client_id"));
333 }
334
335 if !client_registry
337 .validate_redirect_uri(&request.client_id, &request.redirect_uri)
338 .await?
339 {
340 return Err(AuthError::auth_method(
341 "oidc",
342 "Invalid redirect_uri for client",
343 ));
344 }
345 } else {
346 if request.redirect_uri.is_empty() {
348 return Err(AuthError::auth_method("oidc", "Missing redirect_uri"));
349 }
350 }
351
352 Ok(AuthorizationValidationResult {
353 valid: true,
354 client_id: request.client_id.clone(),
355 redirect_uri: request.redirect_uri.clone(),
356 scope: request.scope.clone(),
357 state: request.state.clone(),
358 nonce: request.nonce.clone(),
359 max_age: request.max_age,
360 response_type: request.response_type.clone(),
361 })
362 }
363
364 pub async fn get_userinfo(&self, access_token: &str) -> Result<UserInfo> {
366 let token_claims = self
368 .token_manager
369 .validate_jwt_token(access_token)
370 .map_err(|e| AuthError::auth_method("oidc", format!("Invalid access token: {}", e)))?;
371
372 let subject = &token_claims.sub;
374
375 let user_key = format!("user:{}", subject);
377 if let Some(user_data) = self.storage.get_kv(&user_key).await? {
378 let user_str = std::str::from_utf8(&user_data).unwrap_or("{}");
379 let user_profile: HashMap<String, Value> =
380 serde_json::from_str(user_str).unwrap_or_default();
381
382 Ok(UserInfo {
383 sub: subject.clone(),
384 name: user_profile
385 .get("name")
386 .and_then(|v| v.as_str())
387 .map(|s| s.to_string()),
388 given_name: user_profile
389 .get("given_name")
390 .and_then(|v| v.as_str())
391 .map(|s| s.to_string()),
392 family_name: user_profile
393 .get("family_name")
394 .and_then(|v| v.as_str())
395 .map(|s| s.to_string()),
396 middle_name: user_profile
397 .get("middle_name")
398 .and_then(|v| v.as_str())
399 .map(|s| s.to_string()),
400 nickname: user_profile
401 .get("nickname")
402 .and_then(|v| v.as_str())
403 .map(|s| s.to_string()),
404 preferred_username: user_profile
405 .get("preferred_username")
406 .and_then(|v| v.as_str())
407 .map(|s| s.to_string()),
408 profile: user_profile
409 .get("profile")
410 .and_then(|v| v.as_str())
411 .map(|s| s.to_string()),
412 picture: user_profile
413 .get("picture")
414 .and_then(|v| v.as_str())
415 .map(|s| s.to_string()),
416 website: user_profile
417 .get("website")
418 .and_then(|v| v.as_str())
419 .map(|s| s.to_string()),
420 email: user_profile
421 .get("email")
422 .and_then(|v| v.as_str())
423 .map(|s| s.to_string()),
424 email_verified: user_profile.get("email_verified").and_then(|v| v.as_bool()),
425 gender: user_profile
426 .get("gender")
427 .and_then(|v| v.as_str())
428 .map(|s| s.to_string()),
429 birthdate: user_profile
430 .get("birthdate")
431 .and_then(|v| v.as_str())
432 .map(|s| s.to_string()),
433 zoneinfo: user_profile
434 .get("zoneinfo")
435 .and_then(|v| v.as_str())
436 .map(|s| s.to_string()),
437 locale: user_profile
438 .get("locale")
439 .and_then(|v| v.as_str())
440 .map(|s| s.to_string()),
441 phone_number: user_profile
442 .get("phone_number")
443 .and_then(|v| v.as_str())
444 .map(|s| s.to_string()),
445 phone_number_verified: user_profile
446 .get("phone_number_verified")
447 .and_then(|v| v.as_bool()),
448 address: user_profile
449 .get("address")
450 .and_then(|addr| addr.as_object())
451 .map(|addr_obj| Address {
452 formatted: addr_obj
453 .get("formatted")
454 .and_then(|v| v.as_str())
455 .map(|s| s.to_string()),
456 street_address: addr_obj
457 .get("street_address")
458 .and_then(|v| v.as_str())
459 .map(|s| s.to_string()),
460 locality: addr_obj
461 .get("locality")
462 .and_then(|v| v.as_str())
463 .map(|s| s.to_string()),
464 region: addr_obj
465 .get("region")
466 .and_then(|v| v.as_str())
467 .map(|s| s.to_string()),
468 postal_code: addr_obj
469 .get("postal_code")
470 .and_then(|v| v.as_str())
471 .map(|s| s.to_string()),
472 country: addr_obj
473 .get("country")
474 .and_then(|v| v.as_str())
475 .map(|s| s.to_string()),
476 }),
477 updated_at: user_profile.get("updated_at").and_then(|v| v.as_u64()),
478 additional_claims: user_profile
479 .into_iter()
480 .filter(|(k, _)| {
481 ![
482 "sub",
483 "name",
484 "given_name",
485 "family_name",
486 "middle_name",
487 "nickname",
488 "preferred_username",
489 "profile",
490 "picture",
491 "website",
492 "email",
493 "email_verified",
494 "gender",
495 "birthdate",
496 "zoneinfo",
497 "locale",
498 "phone_number",
499 "phone_number_verified",
500 "address",
501 "updated_at",
502 ]
503 .contains(&k.as_str())
504 })
505 .collect(),
506 })
507 } else {
508 Ok(UserInfo {
510 sub: subject.clone(),
511 name: Some("John Doe".to_string()),
512 given_name: Some("John".to_string()),
513 family_name: Some("Doe".to_string()),
514 middle_name: None,
515 nickname: None,
516 preferred_username: Some(subject.clone()),
517 profile: None,
518 picture: Some("https://example.com/avatar.jpg".to_string()),
519 website: None,
520 email: Some("john.doe@example.com".to_string()),
521 email_verified: Some(true),
522 gender: None,
523 birthdate: None,
524 zoneinfo: None,
525 locale: None,
526 phone_number: None,
527 phone_number_verified: None,
528 address: None,
529 updated_at: None,
530 additional_claims: HashMap::new(),
531 })
532 }
533 }
534
535 pub async fn handle_logout(
537 &self,
538 id_token_hint: Option<&str>,
539 post_logout_redirect_uri: Option<&str>,
540 state: Option<&str>,
541 ) -> Result<LogoutResponse> {
542 if let Some(id_token) = id_token_hint {
544 let claims = self
545 .token_manager
546 .validate_jwt_token(id_token)
547 .map_err(|e| AuthError::auth_method("oidc", format!("Invalid ID token: {}", e)))?;
548
549 let user_sessions = self
551 .storage
552 .list_user_sessions(&claims.sub)
553 .await
554 .map_err(|e| AuthError::internal(format!("Failed to list user sessions: {}", e)))?;
555
556 for session in user_sessions {
557 self.storage
558 .delete_session(&session.session_id)
559 .await
560 .map_err(|e| AuthError::internal(format!("Failed to delete session: {}", e)))?;
561 }
562 }
563
564 if let Some(post_logout_uri) = post_logout_redirect_uri {
566 if let Some(id_token) = id_token_hint {
568 let claims = self
569 .token_manager
570 .validate_jwt_token(id_token)
571 .map_err(|e| {
572 AuthError::auth_method("oidc", format!("Invalid ID token: {}", e))
573 })?;
574
575 if let Some(aud) = claims.aud.split_whitespace().next() {
576 if !self
578 .is_post_logout_uri_registered(aud, post_logout_uri)
579 .await?
580 {
581 return Err(AuthError::validation(
582 "post_logout_redirect_uri not registered for client",
583 ));
584 }
585 }
586 } else {
587 return Err(AuthError::validation(
590 "id_token_hint required for post_logout_redirect_uri validation",
591 ));
592 }
593 }
594
595 Ok(LogoutResponse {
596 post_logout_redirect_uri: post_logout_redirect_uri.map(|uri| uri.to_string()),
597 state: state.map(|s| s.to_string()),
598 })
599 }
600
601 async fn is_post_logout_uri_registered(&self, client_id: &str, uri: &str) -> Result<bool> {
603 if !uri.starts_with("https://")
607 && !uri.starts_with("http://localhost")
608 && !uri.starts_with("http://127.0.0.1")
609 {
610 tracing::warn!(
611 "Rejected post-logout redirect URI with invalid scheme: {}",
612 uri
613 );
614 return Ok(false);
615 }
616
617 match self.get_client_registered_post_logout_uris(client_id).await {
619 Ok(registered_uris) => {
620 let is_registered = registered_uris.contains(&uri.to_string());
621 if !is_registered {
622 tracing::warn!(
623 "Rejected unregistered post-logout redirect URI for client {}: {}",
624 client_id,
625 uri
626 );
627 }
628 Ok(is_registered)
629 }
630 Err(_) => {
631 let is_safe_fallback = uri.starts_with("http://localhost")
633 || uri.starts_with("http://127.0.0.1")
634 || (uri.starts_with("https://") && !uri.contains("..") && !uri.contains("@"));
635
636 if !is_safe_fallback {
637 tracing::error!("Rejected potentially unsafe redirect URI: {}", uri);
638 }
639
640 Ok(is_safe_fallback)
641 }
642 }
643 }
644
645 async fn get_client_registered_post_logout_uris(&self, client_id: &str) -> Result<Vec<String>> {
647 match client_id {
650 "test_client" => Ok(vec![
651 "https://example.com/logout".to_string(),
652 "http://localhost:8080/logout".to_string(),
653 ]),
654 _ => {
655 Ok(Vec::new())
658 }
659 }
660 }
661
662 pub fn generate_jwks(&self) -> Result<JwkSet> {
664 let jwk = Jwk {
669 kty: "RSA".to_string(),
670 use_: Some("sig".to_string()),
671 key_ops: Some(vec!["verify".to_string()]),
672 alg: Some("RS256".to_string()),
673 kid: Some(format!("rsa-key-{}", chrono::Utc::now().timestamp())),
674 n: "sRJjz2xJOzqz1nFXKmjE3sXiZhG8s_jZo2_5Z3XJ8aYzEd7Z8GlVMmF6kWzT8k7sRJjz2xJOzqz1nFXKmjE3sXiZhG8s_jZo2_5Z3XJ8aYzEd7Z8GlVMmF6kWzT8k7sRJjz2xJOzqz1nFXKmjE3sXiZhG8s_jZo2_5Z3XJ8aYzEd7Z8GlVMmF6kWzT8k".to_string(),
676 e: "AQAB".to_string(),
677 additional_params: {
678 let mut params = HashMap::new();
679 params.insert("x5t".to_string(), serde_json::Value::String("example-thumbprint".to_string()));
680 params
681 },
682 };
683
684 Ok(JwkSet { keys: vec![jwk] })
685 }
686}
687
/// Parameters of an OIDC authorization request, as consumed by
/// `OidcProvider::validate_authorization_request`.
///
/// Only `response_type`, `client_id`, `redirect_uri` and `scope` are
/// inspected by the validation in this file; `state`, `nonce` and `max_age`
/// are echoed into the validation result, and the remaining fields are
/// carried but not read here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcAuthorizationRequest {
    /// Requested response type; must match an entry of
    /// `OidcConfig::response_types_supported` exactly.
    pub response_type: String,
    /// Identifier of the requesting client; must not be empty.
    pub client_id: String,
    /// URI the client asks to be redirected to.
    pub redirect_uri: String,
    /// Whitespace-separated scope list; must contain `openid`.
    pub scope: String,
    pub state: Option<String>,
    pub nonce: Option<String>,
    pub max_age: Option<u64>,
    pub ui_locales: Option<String>,
    pub claims_locales: Option<String>,
    pub id_token_hint: Option<String>,
    pub login_hint: Option<String>,
    pub acr_values: Option<String>,
    pub claims: Option<String>,
    pub request: Option<String>,
    pub request_uri: Option<String>,
}
707
/// Outcome of a successful authorization request validation.
///
/// All fields other than `valid` are copies of the corresponding request
/// parameters.
#[derive(Debug, Clone)]
pub struct AuthorizationValidationResult {
    /// Always `true` when produced by
    /// `OidcProvider::validate_authorization_request`; failures are reported
    /// as errors instead.
    pub valid: bool,
    pub client_id: String,
    pub redirect_uri: String,
    pub scope: String,
    pub state: Option<String>,
    pub nonce: Option<String>,
    pub max_age: Option<u64>,
    pub response_type: String,
}
720
/// Claim set of an ID token.
///
/// `auth_time` and `nonce` are omitted from the serialized form when absent;
/// `additional_claims` is flattened into the top-level object.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdTokenClaims {
    /// Issuer; set from `OidcConfig::issuer`.
    pub iss: String,
    /// Subject identifier.
    pub sub: String,
    /// Audiences; `create_id_token` sets this to the single client id.
    pub aud: Vec<String>,
    /// Expiry, in seconds since the Unix epoch.
    pub exp: u64,
    /// Issue time, in seconds since the Unix epoch.
    pub iat: u64,
    /// Authentication time, in seconds since the Unix epoch.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auth_time: Option<u64>,
    /// Nonce supplied by the caller of `create_id_token`, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nonce: Option<String>,
    /// Extra claims, serialized alongside the fields above.
    #[serde(flatten)]
    pub additional_claims: HashMap<String, Value>,
}
744
/// Claims returned by `OidcProvider::get_userinfo`.
///
/// Every optional claim is omitted from the serialized form when `None`.
/// Keys without a dedicated field are carried in `additional_claims`, which
/// is flattened into the top-level object.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserInfo {
    /// Subject identifier; the only claim that is always present.
    pub sub: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub given_name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub family_name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub middle_name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nickname: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub preferred_username: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub profile: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub picture: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub website: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub email: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub email_verified: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gender: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub birthdate: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub zoneinfo: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub locale: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub phone_number: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub phone_number_verified: Option<bool>,
    /// Structured postal address, when the stored profile has an `address` object.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub address: Option<Address>,
    /// Read from the stored profile as an unsigned integer.
    /// NOTE(review): presumably seconds since the Unix epoch — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub updated_at: Option<u64>,
    /// Non-standard profile entries, serialized alongside the fields above.
    #[serde(flatten)]
    pub additional_claims: HashMap<String, Value>,
}
790
/// Structured postal address used in [`UserInfo`].
///
/// Each component is optional and omitted from the serialized form when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Address {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub formatted: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub street_address: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub locality: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub region: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub postal_code: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub country: Option<String>,
}
807
/// Provider metadata returned by `OidcProvider::discovery_document`.
///
/// Optional entries are omitted from the serialized form when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcDiscoveryDocument {
    /// Issuer identifier, copied from `OidcConfig::issuer`.
    pub issuer: String,
    /// `<issuer>/oidc/authorize`.
    pub authorization_endpoint: String,
    /// `<issuer>/oidc/token`.
    pub token_endpoint: String,
    /// Copied from `OidcConfig::userinfo_endpoint`.
    pub userinfo_endpoint: String,
    /// Copied from `OidcConfig::jwks_uri`.
    pub jwks_uri: String,
    /// `<issuer>/oidc/register` when produced by this provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub registration_endpoint: Option<String>,
    pub scopes_supported: Vec<String>,
    pub response_types_supported: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_modes_supported: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub grant_types_supported: Option<Vec<String>>,
    pub subject_types_supported: Vec<SubjectType>,
    /// Algorithm names such as `RS256`, converted from the configured algorithms.
    pub id_token_signing_alg_values_supported: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub userinfo_signing_alg_values_supported: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_endpoint_auth_methods_supported: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub claims_supported: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub claims_parameter_supported: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_parameter_supported: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_uri_parameter_supported: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code_challenge_methods_supported: Option<Vec<String>>,
}
841
/// A JSON Web Key Set: the list of keys returned by `OidcProvider::generate_jwks`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwkSet {
    /// The keys in the set.
    pub keys: Vec<Jwk>,
}
847
/// A single JSON Web Key with RSA public-key parameters.
///
/// Optional members are omitted from the serialized form when `None`;
/// `additional_params` is flattened into the top-level object.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Jwk {
    /// Key type, e.g. `RSA`.
    pub kty: String,
    /// Intended use; serialized under the JSON name `use`.
    #[serde(rename = "use", skip_serializing_if = "Option::is_none")]
    pub use_: Option<String>,
    /// Permitted key operations.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub key_ops: Option<Vec<String>>,
    /// Algorithm the key is intended for.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub alg: Option<String>,
    /// Key identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub kid: Option<String>,
    /// RSA modulus. NOTE(review): presumably base64url-encoded — confirm.
    pub n: String,
    /// RSA public exponent. NOTE(review): presumably base64url-encoded — confirm.
    pub e: String,
    /// Any further members, serialized alongside the fields above.
    #[serde(flatten)]
    pub additional_params: HashMap<String, Value>,
}
865
/// Result of `OidcProvider::handle_logout`.
#[derive(Debug, Clone)]
pub struct LogoutResponse {
    /// The redirect URI from the logout request, echoed back once accepted.
    pub post_logout_redirect_uri: Option<String>,
    /// The `state` value from the logout request, echoed back unchanged.
    pub state: Option<String>,
}
872
873fn algorithm_to_string(alg: &Algorithm) -> String {
875 match alg {
876 Algorithm::HS256 => "HS256".to_string(),
877 Algorithm::HS384 => "HS384".to_string(),
878 Algorithm::HS512 => "HS512".to_string(),
879 Algorithm::ES256 => "ES256".to_string(),
880 Algorithm::ES384 => "ES384".to_string(),
881 Algorithm::RS256 => "RS256".to_string(),
882 Algorithm::RS384 => "RS384".to_string(),
883 Algorithm::RS512 => "RS512".to_string(),
884 Algorithm::PS256 => "PS256".to_string(),
885 Algorithm::PS384 => "PS384".to_string(),
886 Algorithm::PS512 => "PS512".to_string(),
887 Algorithm::EdDSA => "EdDSA".to_string(),
888 }
889}
890
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryStorage;

    /// Builds a provider with the default configuration, an HMAC token
    /// manager and in-memory storage.
    async fn create_test_oidc_provider() -> OidcProvider<MemoryStorage> {
        let token_manager = TokenManager::new_hmac(
            b"test_secret_key_32_bytes_long!!!!",
            "test_issuer",
            "test_audience",
        );

        OidcProvider::new(
            OidcConfig::default(),
            Arc::new(token_manager),
            Arc::new(MemoryStorage::new()),
        )
        .await
        .unwrap()
    }

    /// Builds a `code` flow request for `test_client` with the given scope.
    fn authorization_request(scope: &str) -> OidcAuthorizationRequest {
        OidcAuthorizationRequest {
            response_type: "code".to_string(),
            client_id: "test_client".to_string(),
            redirect_uri: "https://client.example.com/callback".to_string(),
            scope: scope.to_string(),
            state: Some("abc123".to_string()),
            nonce: Some("xyz789".to_string()),
            max_age: None,
            ui_locales: None,
            claims_locales: None,
            id_token_hint: None,
            login_hint: None,
            acr_values: None,
            claims: None,
            request: None,
            request_uri: None,
        }
    }

    #[tokio::test]
    async fn test_oidc_provider_creation() {
        let provider = create_test_oidc_provider().await;

        assert_eq!(provider.config.issuer, "https://auth.example.com");

        let advertises_openid = provider
            .config
            .scopes_supported
            .iter()
            .any(|scope| scope == "openid");
        assert!(advertises_openid);
    }

    #[tokio::test]
    async fn test_discovery_document() {
        let provider = create_test_oidc_provider().await;
        let discovery = provider.discovery_document().unwrap();

        assert_eq!(discovery.issuer, "https://auth.example.com");
        assert_eq!(
            discovery.authorization_endpoint,
            "https://auth.example.com/oidc/authorize"
        );
        assert!(discovery.scopes_supported.iter().any(|s| s == "openid"));
        assert!(discovery.response_types_supported.iter().any(|r| r == "code"));
    }

    #[tokio::test]
    async fn test_authorization_request_validation() {
        let provider = create_test_oidc_provider().await;
        let valid_request = authorization_request("openid profile email");

        let result = provider
            .validate_authorization_request(&valid_request)
            .await
            .unwrap();

        assert!(result.valid);
        assert_eq!(result.client_id, "test_client");
        assert_eq!(result.scope, "openid profile email");
    }

    #[tokio::test]
    async fn test_authorization_request_missing_openid_scope() {
        let provider = create_test_oidc_provider().await;
        // No `openid` in the scope list.
        let invalid_request = authorization_request("profile email");

        let result = provider
            .validate_authorization_request(&invalid_request)
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_id_token_creation() {
        let provider = create_test_oidc_provider().await;

        let auth_time = SystemTime::now();
        let claims: HashMap<String, Value> = [("name", "John Doe"), ("email", "john@example.com")]
            .iter()
            .map(|(key, value)| (key.to_string(), Value::String(value.to_string())))
            .collect();

        let id_token = provider
            .create_id_token(
                "user123",
                "client456",
                Some("nonce789"),
                Some(auth_time),
                Some(&claims),
            )
            .await
            .unwrap();

        assert!(!id_token.is_empty());
        assert!(id_token.contains('.'));
    }

    #[tokio::test]
    async fn test_jwks_generation() {
        let provider = create_test_oidc_provider().await;
        let jwks = provider.generate_jwks().unwrap();

        assert!(!jwks.keys.is_empty());

        let first_key = &jwks.keys[0];
        assert_eq!(first_key.kty, "RSA");
        assert_eq!(first_key.alg, Some("RS256".to_string()));
    }

    #[tokio::test]
    async fn test_logout_handling() {
        let provider = create_test_oidc_provider().await;

        let logout_response = provider
            .handle_logout(None, None, Some("state123"))
            .await
            .unwrap();

        assert_eq!(logout_response.post_logout_redirect_uri, None);
        assert_eq!(logout_response.state, Some("state123".to_string()));
    }

    #[test]
    fn test_subject_type_serialization() {
        let cases = [
            (SubjectType::Public, "\"public\""),
            (SubjectType::Pairwise, "\"pairwise\""),
        ];

        for (subject_type, expected_json) in cases.iter() {
            let json = serde_json::to_string(subject_type).unwrap();
            assert_eq!(json, *expected_json);
        }
    }

    #[test]
    fn test_algorithm_to_string_conversion() {
        let cases = [
            (Algorithm::RS256, "RS256"),
            (Algorithm::ES256, "ES256"),
            (Algorithm::HS256, "HS256"),
            (Algorithm::EdDSA, "EdDSA"),
        ];

        for (algorithm, expected_name) in cases.iter() {
            assert_eq!(algorithm_to_string(algorithm), *expected_name);
        }
    }

    #[test]
    fn test_oidc_config_default() {
        let config = OidcConfig::default();

        assert_eq!(config.issuer, "https://auth.example.com");
        assert!(config.scopes_supported.iter().any(|s| s == "openid"));
        assert!(config.claims_supported.iter().any(|c| c == "sub"));
        assert_eq!(config.subject_types_supported, vec![SubjectType::Public]);
    }
}
1079
1080