1use crate::errors::{AuthError, Result};
7use crate::oauth2_enhanced_storage::{
8 EnhancedAuthorizationCode, EnhancedClientCredentials, EnhancedTokenStorage, RefreshToken,
9};
10use crate::security::secure_utils::constant_time_compare;
11use crate::tokens::{AuthToken, TokenManager};
12use crate::user_context::{SessionStore, UserContext};
13use serde::{Deserialize, Serialize};
14use sha2::{Digest, Sha256};
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::sync::RwLock;
18
19#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
21pub enum GrantType {
22 AuthorizationCode,
23 RefreshToken,
24 ClientCredentials,
25 DeviceCode,
26 TokenExchange,
27}
28
29impl std::fmt::Display for GrantType {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 match self {
32 GrantType::AuthorizationCode => write!(f, "authorization_code"),
33 GrantType::RefreshToken => write!(f, "refresh_token"),
34 GrantType::ClientCredentials => write!(f, "client_credentials"),
35 GrantType::DeviceCode => write!(f, "urn:ietf:params:oauth:grant-type:device_code"),
36 GrantType::TokenExchange => {
37 write!(f, "urn:ietf:params:oauth:grant-type:token-exchange")
38 }
39 }
40 }
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
45pub enum ResponseType {
46 Code,
47 Token,
48 IdToken,
49}
50
/// Server-wide settings for the OAuth 2.0 authorization server.
#[derive(Debug, Clone)]
pub struct OAuth2Config {
    /// Issuer identifier of this server (default `https://auth.example.com`).
    pub issuer: String,
    /// Lifetime given to newly created authorization codes (default 10 minutes).
    pub authorization_code_lifetime: Duration,
    /// Lifetime of issued access tokens (default 1 hour).
    pub access_token_lifetime: Duration,
    /// Lifetime of issued refresh tokens (default 7 days).
    pub refresh_token_lifetime: Duration,
    /// Lifetime for device codes (default 10 minutes).
    pub device_code_lifetime: Duration,
    /// Scope string used when a request omits `scope` (default `read`).
    pub default_scope: Option<String>,
    /// Upper bound on scope lifetime (default 30 days).
    pub max_scope_lifetime: Duration,
    /// Flag asking that authorization requests carry a PKCE `code_challenge` (default on).
    pub require_pkce: bool,
    /// Feature flag for token introspection (default on).
    pub enable_introspection: bool,
    /// Feature flag for token revocation (default on).
    pub enable_revocation: bool,
}

impl Default for OAuth2Config {
    fn default() -> Self {
        const MINUTE: u64 = 60;
        const HOUR: u64 = 60 * MINUTE;
        const DAY: u64 = 24 * HOUR;

        Self {
            issuer: String::from("https://auth.example.com"),
            authorization_code_lifetime: Duration::from_secs(10 * MINUTE),
            access_token_lifetime: Duration::from_secs(HOUR),
            refresh_token_lifetime: Duration::from_secs(7 * DAY),
            device_code_lifetime: Duration::from_secs(10 * MINUTE),
            default_scope: Some(String::from("read")),
            max_scope_lifetime: Duration::from_secs(30 * DAY),
            require_pkce: true,
            enable_introspection: true,
            enable_revocation: true,
        }
    }
}
92
93#[derive(Debug, Clone, Serialize, Deserialize, Default)]
95pub struct TokenRequest {
96 pub grant_type: String,
97 pub client_id: String,
98 pub client_secret: Option<String>,
99 pub code: Option<String>,
100 pub redirect_uri: Option<String>,
101 pub refresh_token: Option<String>,
102 pub scope: Option<String>,
103 pub code_verifier: Option<String>,
104 pub username: Option<String>,
105 pub password: Option<String>,
106 pub device_code: Option<String>,
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct TokenResponse {
112 pub access_token: String,
113 pub token_type: String,
114 pub expires_in: u64,
115 pub refresh_token: Option<String>,
116 pub scope: Option<String>,
117 pub id_token: Option<String>,
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct AuthorizationRequest {
123 pub client_id: String,
124 pub response_type: String,
125 pub redirect_uri: String,
126 pub scope: Option<String>,
127 pub state: Option<String>,
128 pub code_challenge: Option<String>,
129 pub code_challenge_method: Option<String>,
130 pub nonce: Option<String>,
131}
132
133pub struct OAuth2Server {
135 config: OAuth2Config,
136 token_storage: Arc<RwLock<EnhancedTokenStorage>>,
137 session_store: Arc<RwLock<SessionStore>>,
138 token_manager: Arc<TokenManager>,
139}
140
141impl OAuth2Server {
142 pub async fn new(config: OAuth2Config, token_manager: Arc<TokenManager>) -> Result<Self> {
143 Ok(Self {
144 config,
145 token_storage: Arc::new(RwLock::new(EnhancedTokenStorage::new())),
146 session_store: Arc::new(RwLock::new(SessionStore::new())),
147 token_manager,
148 })
149 }
150
151 pub async fn register_confidential_client(
153 &self,
154 client_id: String,
155 client_secret: &str,
156 redirect_uris: Vec<String>,
157 allowed_scopes: Vec<String>,
158 grant_types: Vec<String>,
159 ) -> Result<()> {
160 if client_secret.len() < 32 {
162 return Err(AuthError::auth_method(
163 "oauth2",
164 "Client secret must be at least 32 characters",
165 ));
166 }
167
168 let credentials = EnhancedClientCredentials::new_confidential(
169 client_id,
170 client_secret,
171 redirect_uris,
172 allowed_scopes,
173 grant_types,
174 )?;
175
176 let mut storage = self.token_storage.write().await;
177 storage.store_client_credentials(credentials).await?;
178
179 Ok(())
180 }
181
182 pub async fn register_public_client(
184 &self,
185 client_id: String,
186 redirect_uris: Vec<String>,
187 allowed_scopes: Vec<String>,
188 grant_types: Vec<String>,
189 ) -> Result<()> {
190 let credentials = EnhancedClientCredentials::new_public(
191 client_id,
192 redirect_uris,
193 allowed_scopes,
194 grant_types,
195 );
196
197 let mut storage = self.token_storage.write().await;
198 storage.store_client_credentials(credentials).await?;
199
200 Ok(())
201 }
202
203 pub async fn create_authorization_code(
205 &self,
206 request: AuthorizationRequest,
207 user_context: UserContext,
208 ) -> Result<EnhancedAuthorizationCode> {
209 let storage = self.token_storage.read().await;
211 let client = storage
212 .get_client_credentials(&request.client_id)
213 .await?
214 .ok_or_else(|| AuthError::auth_method("oauth2", "Invalid client_id"))?;
215
216 if !client.supports_grant_type("authorization_code") {
217 return Err(AuthError::auth_method(
218 "oauth2",
219 "Client does not support authorization code grant",
220 ));
221 }
222
223 if !client.redirect_uris.contains(&request.redirect_uri) {
224 return Err(AuthError::auth_method("oauth2", "Invalid redirect_uri"));
225 }
226
227 let requested_scopes = self.parse_scopes(request.scope.as_deref())?;
229 let authorized_scopes = self.authorize_scopes(&client, &user_context, &requested_scopes)?;
230
231 let auth_code = EnhancedAuthorizationCode::new(
233 client.client_id.clone(),
234 user_context.user_id.clone(), request.redirect_uri,
236 authorized_scopes,
237 request.code_challenge,
238 request.code_challenge_method,
239 self.config.authorization_code_lifetime,
240 );
241
242 drop(storage);
244 let mut storage = self.token_storage.write().await;
245 storage.store_authorization_code(auth_code.clone()).await?;
246
247 Ok(auth_code)
248 }
249
250 pub async fn token_exchange(&self, request: TokenRequest) -> Result<TokenResponse> {
252 match request.grant_type.as_str() {
253 "authorization_code" => self.handle_authorization_code_grant(request).await,
254 "refresh_token" => self.handle_refresh_token_grant(request).await,
255 "client_credentials" => self.handle_client_credentials_grant(request).await,
256 _ => Err(AuthError::auth_method("oauth2", "Unsupported grant type")),
257 }
258 }
259
260 async fn handle_authorization_code_grant(
262 &self,
263 request: TokenRequest,
264 ) -> Result<TokenResponse> {
265 let storage = self.token_storage.read().await;
267 let _client = storage
268 .get_client_credentials(&request.client_id)
269 .await?
270 .ok_or_else(|| AuthError::auth_method("oauth2", "Invalid client_id"))?;
271
272 if !storage
274 .validate_client_credentials(&request.client_id, request.client_secret.as_deref())
275 .await?
276 {
277 return Err(AuthError::auth_method(
278 "oauth2",
279 "Invalid client credentials",
280 ));
281 }
282
283 let code = request
285 .code
286 .ok_or_else(|| AuthError::auth_method("oauth2", "Missing authorization code"))?;
287
288 drop(storage);
289 let mut storage = self.token_storage.write().await;
290 let auth_code = storage
291 .consume_authorization_code(&code)
292 .await?
293 .ok_or_else(|| {
294 AuthError::auth_method("oauth2", "Invalid or expired authorization code")
295 })?;
296
297 if auth_code.client_id != request.client_id {
299 return Err(AuthError::auth_method(
300 "oauth2",
301 "Authorization code does not belong to client",
302 ));
303 }
304
305 if let Some(challenge) = &auth_code.code_challenge {
307 let verifier = request
308 .code_verifier
309 .ok_or_else(|| AuthError::auth_method("oauth2", "PKCE code verifier required"))?;
310
311 if !self.validate_pkce_challenge(
312 challenge,
313 &verifier,
314 &auth_code.code_challenge_method,
315 )? {
316 return Err(AuthError::auth_method(
317 "oauth2",
318 "Invalid PKCE code verifier",
319 ));
320 }
321 }
322
323 let access_token = self
325 .generate_access_token(
326 &auth_code.client_id,
327 Some(&auth_code.user_id), &auth_code.scopes,
329 )
330 .await?;
331
332 let refresh_token = RefreshToken::new(
334 auth_code.client_id.clone(),
335 auth_code.user_id.clone(), auth_code.scopes.clone(), self.config.refresh_token_lifetime,
338 );
339
340 let refresh_token_id = storage.store_refresh_token(refresh_token).await?;
341
342 Ok(TokenResponse {
343 access_token: access_token.access_token,
344 token_type: "Bearer".to_string(),
345 expires_in: self.config.access_token_lifetime.as_secs(),
346 refresh_token: Some(refresh_token_id),
347 scope: Some(auth_code.scopes.join(" ")),
348 id_token: None,
349 })
350 }
351
352 async fn handle_refresh_token_grant(&self, request: TokenRequest) -> Result<TokenResponse> {
354 let storage = self.token_storage.read().await;
356 if !storage
357 .validate_client_credentials(&request.client_id, request.client_secret.as_deref())
358 .await?
359 {
360 return Err(AuthError::auth_method(
361 "oauth2",
362 "Invalid client credentials",
363 ));
364 }
365
366 let refresh_token_id = request
368 .refresh_token
369 .ok_or_else(|| AuthError::auth_method("oauth2", "Missing refresh token"))?;
370
371 let stored_token = storage
373 .get_refresh_token(&refresh_token_id)
374 .await?
375 .ok_or_else(|| AuthError::auth_method("oauth2", "Invalid refresh token"))?;
376
377 if !stored_token.is_valid() {
378 return Err(AuthError::auth_method(
379 "oauth2",
380 "Refresh token is expired or revoked",
381 ));
382 }
383
384 if stored_token.client_id != request.client_id {
386 return Err(AuthError::auth_method(
387 "oauth2",
388 "Refresh token does not belong to client",
389 ));
390 }
391
392 let requested_scopes = self.parse_scopes(request.scope.as_deref())?;
394 let authorized_scopes = if requested_scopes.is_empty() {
395 stored_token.scopes.clone() } else {
397 self.validate_scope_subset(&stored_token.scopes, &requested_scopes)?
398 };
399
400 drop(storage);
401
402 let access_token = self
404 .generate_access_token(
405 &stored_token.client_id,
406 Some(&stored_token.user_id), &authorized_scopes,
408 )
409 .await?;
410
411 let mut storage = self.token_storage.write().await;
413 storage.revoke_refresh_token(&refresh_token_id).await?; let new_refresh_token = RefreshToken::new(
416 stored_token.client_id.clone(),
417 stored_token.user_id.clone(),
418 authorized_scopes.clone(),
419 self.config.refresh_token_lifetime,
420 );
421
422 let new_refresh_token_id = storage.store_refresh_token(new_refresh_token).await?;
423
424 Ok(TokenResponse {
425 access_token: access_token.access_token,
426 token_type: "Bearer".to_string(),
427 expires_in: self.config.access_token_lifetime.as_secs(),
428 refresh_token: Some(new_refresh_token_id),
429 scope: Some(authorized_scopes.join(" ")),
430 id_token: None,
431 })
432 }
433
434 async fn handle_client_credentials_grant(
436 &self,
437 request: TokenRequest,
438 ) -> Result<TokenResponse> {
439 let storage = self.token_storage.read().await;
441 let client = storage
442 .get_client_credentials(&request.client_id)
443 .await?
444 .ok_or_else(|| AuthError::auth_method("oauth2", "Invalid client_id"))?;
445
446 if !storage
448 .validate_client_credentials(&request.client_id, request.client_secret.as_deref())
449 .await?
450 {
451 return Err(AuthError::auth_method(
452 "oauth2",
453 "Invalid client credentials",
454 ));
455 }
456
457 if !client.supports_grant_type("client_credentials") {
458 return Err(AuthError::auth_method(
459 "oauth2",
460 "Client does not support client credentials grant",
461 ));
462 }
463
464 let requested_scopes = self.parse_scopes(request.scope.as_deref())?;
466 let authorized_scopes = requested_scopes
467 .iter()
468 .filter(|scope| client.has_scope(scope))
469 .cloned()
470 .collect::<Vec<_>>();
471
472 if authorized_scopes.is_empty() && !requested_scopes.is_empty() {
473 return Err(AuthError::auth_method("oauth2", "No authorized scopes"));
474 }
475
476 drop(storage);
477
478 let access_token = self
480 .generate_access_token(&request.client_id, None, &authorized_scopes)
481 .await?;
482
483 Ok(TokenResponse {
484 access_token: access_token.access_token,
485 token_type: "Bearer".to_string(),
486 expires_in: self.config.access_token_lifetime.as_secs(),
487 refresh_token: None, scope: Some(authorized_scopes.join(" ")),
489 id_token: None,
490 })
491 }
492
493 async fn generate_access_token(
495 &self,
496 client_id: &str,
497 user_id: Option<&str>,
498 scopes: &[String],
499 ) -> Result<AuthToken> {
500 let subject = user_id.unwrap_or(client_id);
501 let mut token = self.token_manager.create_auth_token(
502 subject,
503 scopes.iter().map(|s| s.to_string()).collect(),
504 "oauth2",
505 Some(self.config.access_token_lifetime),
506 )?;
507
508 token.add_custom_claim(
510 "client_id".to_string(),
511 serde_json::Value::String(client_id.to_string()),
512 );
513
514 if let Some(uid) = user_id {
516 token.add_custom_claim(
517 "user_id".to_string(),
518 serde_json::Value::String(uid.to_string()),
519 );
520 }
521
522 Ok(token)
523 }
524
525 fn parse_scopes(&self, scope_str: Option<&str>) -> Result<Vec<String>> {
527 match scope_str {
528 Some(scopes) => Ok(scopes.split_whitespace().map(|s| s.to_string()).collect()),
529 None => match &self.config.default_scope {
530 Some(default) => Ok(vec![default.clone()]),
531 None => Ok(vec![]),
532 },
533 }
534 }
535
536 fn authorize_scopes(
538 &self,
539 client: &EnhancedClientCredentials,
540 user_context: &UserContext,
541 requested_scopes: &[String],
542 ) -> Result<Vec<String>> {
543 let mut authorized = Vec::new();
544
545 for scope in requested_scopes {
546 if client.has_scope(scope) {
548 if user_context.has_scope(scope) {
550 authorized.push(scope.clone());
551 }
552 }
553 }
554
555 if authorized.is_empty() && !requested_scopes.is_empty() {
556 return Err(AuthError::auth_method("oauth2", "No authorized scopes"));
557 }
558
559 Ok(authorized)
560 }
561
562 fn validate_scope_subset(
564 &self,
565 original_scopes: &[String],
566 requested_scopes: &[String],
567 ) -> Result<Vec<String>> {
568 let mut validated = Vec::new();
569
570 for scope in requested_scopes {
571 if original_scopes.contains(scope) {
572 validated.push(scope.clone());
573 } else {
574 return Err(AuthError::auth_method(
575 "oauth2",
576 format!("Requested scope '{}' not in original grant", scope),
577 ));
578 }
579 }
580
581 Ok(validated)
582 }
583
584 fn validate_pkce_challenge(
586 &self,
587 challenge: &str,
588 verifier: &str,
589 method: &Option<String>,
590 ) -> Result<bool> {
591 let method = method.as_deref().unwrap_or("plain");
592
593 match method {
594 "plain" => Ok(constant_time_compare(
595 challenge.as_bytes(),
596 verifier.as_bytes(),
597 )),
598 "S256" => {
599 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
600 use sha2::{Digest, Sha256};
601
602 let hash = Sha256::digest(verifier.as_bytes());
603 let encoded = URL_SAFE_NO_PAD.encode(hash);
604 Ok(constant_time_compare(
605 challenge.as_bytes(),
606 encoded.as_bytes(),
607 ))
608 }
609 _ => Err(AuthError::auth_method("oauth2", "Unsupported PKCE method")),
610 }
611 }
612
613 pub async fn revoke_token(&self, token: &str, client_id: &str) -> Result<bool> {
615 let mut storage = self.token_storage.write().await;
616
617 if client_id.is_empty() {
619 return Err(AuthError::auth_method(
620 "oauth2",
621 "Client ID is required for token revocation",
622 ));
623 }
624
625 if storage.get_client_credentials(client_id).await.is_err() {
627 return Err(AuthError::auth_method("oauth2", "Invalid client"));
628 }
629
630 if storage.validate_refresh_token(token).await? {
632 return storage.revoke_refresh_token(token).await;
633 }
634
635 Ok(false)
638 }
639
640 pub async fn cleanup_expired_tokens(&self) -> Result<usize> {
642 let mut storage = self.token_storage.write().await;
643 storage.cleanup_expired_tokens().await
644 }
645
646 pub async fn authenticate_user(
648 &self,
649 username: &str,
650 password: &str,
651 scopes: Vec<String>,
652 ) -> Result<UserContext> {
653 let storage = self.token_storage.read().await;
655
656 if !self
658 .validate_user_credentials_against_storage(&storage, username, password)
659 .await?
660 {
661 return Err(AuthError::auth_method(
662 "oauth2",
663 "Invalid username or password",
664 ));
665 }
666
667 let authorized_scopes = self
669 .validate_user_scopes_against_storage(&storage, username, &scopes)
670 .await?;
671
672 drop(storage);
673
674 let user_context = UserContext::new(
676 self.generate_user_id(username).await?,
677 username.to_string(),
678 self.get_user_email(username).await?,
679 )
680 .with_scopes(authorized_scopes);
681
682 let mut session_store = self.session_store.write().await;
683 session_store.create_session(user_context.clone());
684
685 Ok(user_context)
686 }
687
688 async fn validate_user_credentials_against_storage(
690 &self,
691 storage: &EnhancedTokenStorage,
692 username: &str,
693 password: &str,
694 ) -> Result<bool> {
695 let is_empty = username.is_empty() || password.is_empty();
697 let is_too_short = password.len() < 8;
698
699 match storage.get_user_credentials(username).await {
701 Ok(Some(stored_credentials)) => {
702 use bcrypt::verify;
704 match verify(password, &stored_credentials.password_hash) {
705 Ok(is_valid) => {
706 Ok(is_valid && !is_empty && !is_too_short)
708 }
709 Err(_) => {
710 Ok(false)
712 }
713 }
714 }
715 Ok(None) => {
716 use bcrypt::verify;
718 let _dummy_result = verify(
719 password,
720 "$2b$12$K2CtDP7zMH7VgxScmHTa/.EUm5nd9.xnZM8Cl/p9RMb5QZaJUHgBm",
721 );
722 Ok(false)
723 }
724 Err(_) => {
725 use bcrypt::verify;
727 let _dummy_result = verify(
728 password,
729 "$2b$12$K2CtDP7zMH7VgxScmHTa/.EUm5nd9.xnZM8Cl/p9RMb5QZaJUHgBm",
730 );
731 Ok(false)
732 }
733 }
734 }
735
736 async fn validate_user_scopes_against_storage(
738 &self,
739 storage: &EnhancedTokenStorage,
740 username: &str,
741 requested_scopes: &[String],
742 ) -> Result<Vec<String>> {
743 let user_permissions = match storage.get_user_permissions(username).await {
745 Ok(Some(permissions)) => permissions.scopes,
746 Ok(None) => {
747 return Err(AuthError::auth_method(
748 "oauth2",
749 "User not found in permission store",
750 ));
751 }
752 Err(_) => {
753 return Err(AuthError::auth_method(
754 "oauth2",
755 "Failed to retrieve user permissions",
756 ));
757 }
758 };
759
760 let mut authorized = Vec::new();
761 for scope in requested_scopes {
762 if user_permissions.contains(scope) {
763 authorized.push(scope.clone());
764 }
765 }
766
767 if authorized.is_empty() && !requested_scopes.is_empty() {
769 return Err(AuthError::auth_method(
770 "oauth2",
771 "User not authorized for requested scopes",
772 ));
773 }
774
775 if authorized.is_empty() {
776 if user_permissions.contains(&"read".to_string()) {
778 authorized.push("read".to_string());
779 } else {
780 return Err(AuthError::auth_method(
781 "oauth2",
782 "User has no authorized scopes",
783 ));
784 }
785 }
786
787 Ok(authorized)
788 }
789
790 async fn generate_user_id(&self, username: &str) -> Result<String> {
792 let hash = Sha256::digest(format!("user_id_{}", username).as_bytes());
795 let hash_str = format!("{:x}", hash);
796 Ok(format!("user_{}", &hash_str[0..16]))
797 }
798
799 async fn get_user_email(&self, username: &str) -> Result<Option<String>> {
801 Ok(Some(format!("{}@example.com", username)))
803 }
804
805 pub async fn get_user_context(&self, session_id: &str) -> Result<Option<UserContext>> {
807 let session_store = self.session_store.read().await;
808 Ok(session_store.get_session(session_id).cloned())
809 }
810
811 pub async fn invalidate_session(&self, session_id: &str) -> Result<bool> {
813 let mut session_store = self.session_store.write().await;
814 Ok(session_store.invalidate_session(session_id))
815 }
816}
817
818