1use crate::auth::{AuthError, AuthResult, JwtManager};
12use dashmap::DashMap;
13use serde::{Deserialize, Serialize};
14use std::sync::Arc;
15use std::time::{Duration, SystemTime, UNIX_EPOCH};
16use uuid::Uuid;
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
20#[serde(rename_all = "snake_case")]
21pub enum GrantType {
22 AuthorizationCode,
24 ClientCredentials,
26 RefreshToken,
28 Implicit,
30}
31
32#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
34#[serde(rename_all = "snake_case")]
35pub enum TokenType {
36 Bearer,
37}
38
39#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
41#[serde(rename_all = "snake_case")]
42pub enum ResponseType {
43 Code,
44 Token,
45}
46
47#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
49pub struct Scope(String);
50
51impl Scope {
52 pub fn new(scope: impl Into<String>) -> Self {
53 Self(scope.into())
54 }
55
56 pub fn as_str(&self) -> &str {
57 &self.0
58 }
59
60 pub fn parse_scopes(scopes: &str) -> Vec<Scope> {
62 scopes.split_whitespace().map(Scope::new).collect()
63 }
64
65 pub fn join_scopes(scopes: &[Scope]) -> String {
67 scopes
68 .iter()
69 .map(|s| s.as_str())
70 .collect::<Vec<_>>()
71 .join(" ")
72 }
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct OAuth2Client {
78 pub client_id: String,
79 pub client_secret: String,
80 pub redirect_uris: Vec<String>,
81 pub grant_types: Vec<GrantType>,
82 pub scopes: Vec<Scope>,
83 pub name: String,
84 pub created_at: u64,
85}
86
87impl OAuth2Client {
88 pub fn new(
89 name: String,
90 redirect_uris: Vec<String>,
91 grant_types: Vec<GrantType>,
92 scopes: Vec<Scope>,
93 ) -> Self {
94 Self {
95 client_id: Uuid::new_v4().to_string(),
96 client_secret: Uuid::new_v4().to_string(),
97 redirect_uris,
98 grant_types,
99 scopes,
100 name,
101 created_at: SystemTime::now()
102 .duration_since(UNIX_EPOCH)
103 .unwrap()
104 .as_secs(),
105 }
106 }
107
108 pub fn verify_secret(&self, secret: &str) -> bool {
110 self.client_secret == secret
111 }
112
113 pub fn is_redirect_uri_allowed(&self, uri: &str) -> bool {
115 self.redirect_uris.iter().any(|u| u == uri)
116 }
117
118 pub fn is_grant_type_allowed(&self, grant_type: GrantType) -> bool {
120 self.grant_types.contains(&grant_type)
121 }
122
123 pub fn is_scope_allowed(&self, scope: &Scope) -> bool {
125 self.scopes.contains(scope)
126 }
127}
128
129#[derive(Debug, Clone)]
131pub struct AuthorizationCode {
132 pub code: String,
133 pub client_id: String,
134 pub redirect_uri: String,
135 pub scopes: Vec<Scope>,
136 pub user_id: String,
137 pub expires_at: u64,
138 pub code_challenge: Option<String>,
140 pub code_challenge_method: Option<CodeChallengeMethod>,
142}
143
144impl AuthorizationCode {
145 pub fn new(
146 client_id: String,
147 redirect_uri: String,
148 scopes: Vec<Scope>,
149 user_id: String,
150 ttl: Duration,
151 code_challenge: Option<String>,
152 code_challenge_method: Option<CodeChallengeMethod>,
153 ) -> Self {
154 let expires_at = SystemTime::now()
155 .duration_since(UNIX_EPOCH)
156 .unwrap()
157 .as_secs()
158 + ttl.as_secs();
159
160 Self {
161 code: Uuid::new_v4().to_string(),
162 client_id,
163 redirect_uri,
164 scopes,
165 user_id,
166 expires_at,
167 code_challenge,
168 code_challenge_method,
169 }
170 }
171
172 pub fn is_expired(&self) -> bool {
173 let now = SystemTime::now()
174 .duration_since(UNIX_EPOCH)
175 .unwrap()
176 .as_secs();
177 now > self.expires_at
178 }
179
180 pub fn verify_code_verifier(&self, verifier: &str) -> bool {
182 match (&self.code_challenge, &self.code_challenge_method) {
183 (Some(challenge), Some(method)) => {
184 let computed_challenge = method.compute_challenge(verifier);
185 &computed_challenge == challenge
186 }
187 (None, None) => true, _ => false, }
190 }
191}
192
193#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
195pub enum CodeChallengeMethod {
196 #[serde(rename = "plain")]
197 Plain,
198 #[serde(rename = "S256")]
199 S256,
200}
201
202impl CodeChallengeMethod {
203 pub fn compute_challenge(&self, verifier: &str) -> String {
205 match self {
206 Self::Plain => verifier.to_string(),
207 Self::S256 => {
208 use sha2::{Digest, Sha256};
209 let hash = Sha256::digest(verifier.as_bytes());
210 base64::Engine::encode(&base64::engine::general_purpose::URL_SAFE_NO_PAD, hash)
211 }
212 }
213 }
214}
215
216#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct AccessToken {
219 pub token: String,
220 pub token_type: TokenType,
221 pub expires_in: u64,
222 pub scopes: Vec<Scope>,
223 pub user_id: String,
224 pub created_at: u64,
225}
226
227impl AccessToken {
228 pub fn new(token: String, scopes: Vec<Scope>, user_id: String, ttl: Duration) -> Self {
229 Self {
230 token,
231 token_type: TokenType::Bearer,
232 expires_in: ttl.as_secs(),
233 scopes,
234 user_id,
235 created_at: SystemTime::now()
236 .duration_since(UNIX_EPOCH)
237 .unwrap()
238 .as_secs(),
239 }
240 }
241
242 pub fn is_expired(&self) -> bool {
243 let now = SystemTime::now()
244 .duration_since(UNIX_EPOCH)
245 .unwrap()
246 .as_secs();
247 now > self.created_at + self.expires_in
248 }
249}
250
251#[derive(Debug, Clone)]
253pub struct RefreshToken {
254 pub token: String,
255 pub client_id: String,
256 pub user_id: String,
257 pub scopes: Vec<Scope>,
258 pub created_at: u64,
259}
260
261impl RefreshToken {
262 pub fn new(client_id: String, user_id: String, scopes: Vec<Scope>) -> Self {
263 Self {
264 token: Uuid::new_v4().to_string(),
265 client_id,
266 user_id,
267 scopes,
268 created_at: SystemTime::now()
269 .duration_since(UNIX_EPOCH)
270 .unwrap()
271 .as_secs(),
272 }
273 }
274}
275
276#[derive(Debug, Serialize, Deserialize)]
278pub struct TokenResponse {
279 pub access_token: String,
280 pub token_type: String,
281 pub expires_in: u64,
282 #[serde(skip_serializing_if = "Option::is_none")]
283 pub refresh_token: Option<String>,
284 #[serde(skip_serializing_if = "Option::is_none")]
285 pub scope: Option<String>,
286}
287
288#[derive(Debug, Serialize, Deserialize)]
290pub struct ErrorResponse {
291 pub error: String,
292 #[serde(skip_serializing_if = "Option::is_none")]
293 pub error_description: Option<String>,
294}
295
296pub struct OAuth2Server {
298 clients: Arc<DashMap<String, OAuth2Client>>,
299 authorization_codes: Arc<DashMap<String, AuthorizationCode>>,
300 access_tokens: Arc<DashMap<String, AccessToken>>,
301 refresh_tokens: Arc<DashMap<String, RefreshToken>>,
302 jwt_manager: Arc<JwtManager>,
303 access_token_ttl: Duration,
305 #[allow(dead_code)]
307 refresh_token_ttl: Duration,
308 code_ttl: Duration,
310}
311
312impl OAuth2Server {
313 pub fn new(jwt_secret: &[u8]) -> Self {
314 Self {
315 clients: Arc::new(DashMap::new()),
316 authorization_codes: Arc::new(DashMap::new()),
317 access_tokens: Arc::new(DashMap::new()),
318 refresh_tokens: Arc::new(DashMap::new()),
319 jwt_manager: Arc::new(JwtManager::new(jwt_secret)),
320 access_token_ttl: Duration::from_secs(3600), refresh_token_ttl: Duration::from_secs(86400 * 30), code_ttl: Duration::from_secs(600), }
324 }
325
326 pub fn register_client(
328 &self,
329 name: String,
330 redirect_uris: Vec<String>,
331 grant_types: Vec<GrantType>,
332 scopes: Vec<Scope>,
333 ) -> OAuth2Client {
334 let client = OAuth2Client::new(name, redirect_uris, grant_types, scopes);
335 self.clients
336 .insert(client.client_id.clone(), client.clone());
337 client
338 }
339
340 pub fn get_client(&self, client_id: &str) -> Option<OAuth2Client> {
342 self.clients.get(client_id).map(|c| c.clone())
343 }
344
345 #[allow(clippy::too_many_arguments)]
347 pub fn authorize(
348 &self,
349 client_id: &str,
350 redirect_uri: &str,
351 response_type: ResponseType,
352 scopes: Vec<Scope>,
353 user_id: String,
354 code_challenge: Option<String>,
355 code_challenge_method: Option<CodeChallengeMethod>,
356 ) -> AuthResult<AuthorizationCode> {
357 let client = self
359 .get_client(client_id)
360 .ok_or(AuthError::InvalidCredentials)?;
361
362 if !client.is_redirect_uri_allowed(redirect_uri) {
364 return Err(AuthError::InvalidCredentials);
365 }
366
367 if !client.is_grant_type_allowed(GrantType::AuthorizationCode) {
369 return Err(AuthError::InvalidCredentials);
370 }
371
372 for scope in &scopes {
374 if !client.is_scope_allowed(scope) {
375 return Err(AuthError::InsufficientPermissions);
376 }
377 }
378
379 if response_type != ResponseType::Code {
381 return Err(AuthError::InvalidCredentials);
382 }
383
384 let auth_code = AuthorizationCode::new(
386 client_id.to_string(),
387 redirect_uri.to_string(),
388 scopes,
389 user_id,
390 self.code_ttl,
391 code_challenge,
392 code_challenge_method,
393 );
394
395 self.authorization_codes
396 .insert(auth_code.code.clone(), auth_code.clone());
397
398 Ok(auth_code)
399 }
400
401 pub fn exchange_code(
403 &self,
404 client_id: &str,
405 client_secret: &str,
406 code: &str,
407 redirect_uri: &str,
408 code_verifier: Option<&str>,
409 ) -> AuthResult<(AccessToken, RefreshToken)> {
410 let client = self
412 .get_client(client_id)
413 .ok_or(AuthError::InvalidCredentials)?;
414
415 if !client.verify_secret(client_secret) {
416 return Err(AuthError::InvalidCredentials);
417 }
418
419 let auth_code = self
421 .authorization_codes
422 .remove(code)
423 .ok_or(AuthError::InvalidToken("Invalid code".to_string()))?
424 .1;
425
426 if auth_code.is_expired() {
428 return Err(AuthError::TokenExpired);
429 }
430
431 if auth_code.client_id != client_id {
432 return Err(AuthError::InvalidCredentials);
433 }
434
435 if auth_code.redirect_uri != redirect_uri {
436 return Err(AuthError::InvalidCredentials);
437 }
438
439 if let Some(verifier) = code_verifier {
441 if !auth_code.verify_code_verifier(verifier) {
442 return Err(AuthError::InvalidCredentials);
443 }
444 } else if auth_code.code_challenge.is_some() {
445 return Err(AuthError::InvalidCredentials);
447 }
448
449 let access_token_jwt = self
451 .jwt_manager
452 .generate_token_with_scopes(
453 &auth_code.user_id,
454 &Scope::join_scopes(&auth_code.scopes),
455 (self.access_token_ttl.as_secs() / 3600) as usize,
456 )
457 .map_err(|_| AuthError::InvalidToken("Failed to generate token".to_string()))?;
458
459 let access_token = AccessToken::new(
460 access_token_jwt,
461 auth_code.scopes.clone(),
462 auth_code.user_id.clone(),
463 self.access_token_ttl,
464 );
465
466 let refresh_token = RefreshToken::new(
468 client_id.to_string(),
469 auth_code.user_id.clone(),
470 auth_code.scopes.clone(),
471 );
472
473 self.access_tokens
475 .insert(access_token.token.clone(), access_token.clone());
476 self.refresh_tokens
477 .insert(refresh_token.token.clone(), refresh_token.clone());
478
479 Ok((access_token, refresh_token))
480 }
481
482 pub fn client_credentials(
484 &self,
485 client_id: &str,
486 client_secret: &str,
487 scopes: Vec<Scope>,
488 ) -> AuthResult<AccessToken> {
489 let client = self
491 .get_client(client_id)
492 .ok_or(AuthError::InvalidCredentials)?;
493
494 if !client.verify_secret(client_secret) {
495 return Err(AuthError::InvalidCredentials);
496 }
497
498 if !client.is_grant_type_allowed(GrantType::ClientCredentials) {
500 return Err(AuthError::InvalidCredentials);
501 }
502
503 for scope in &scopes {
505 if !client.is_scope_allowed(scope) {
506 return Err(AuthError::InsufficientPermissions);
507 }
508 }
509
510 let access_token_jwt = self
512 .jwt_manager
513 .generate_token_with_scopes(
514 client_id,
515 &Scope::join_scopes(&scopes),
516 (self.access_token_ttl.as_secs() / 3600) as usize,
517 )
518 .map_err(|_| AuthError::InvalidToken("Failed to generate token".to_string()))?;
519
520 let access_token = AccessToken::new(
521 access_token_jwt,
522 scopes,
523 client_id.to_string(),
524 self.access_token_ttl,
525 );
526
527 self.access_tokens
528 .insert(access_token.token.clone(), access_token.clone());
529
530 Ok(access_token)
531 }
532
533 pub fn refresh_token(
535 &self,
536 client_id: &str,
537 client_secret: &str,
538 refresh_token: &str,
539 ) -> AuthResult<AccessToken> {
540 let client = self
542 .get_client(client_id)
543 .ok_or(AuthError::InvalidCredentials)?;
544
545 if !client.verify_secret(client_secret) {
546 return Err(AuthError::InvalidCredentials);
547 }
548
549 let rt = self
551 .refresh_tokens
552 .get(refresh_token)
553 .ok_or(AuthError::InvalidToken("Invalid refresh token".to_string()))?;
554
555 if rt.client_id != client_id {
556 return Err(AuthError::InvalidCredentials);
557 }
558
559 let access_token_jwt = self
561 .jwt_manager
562 .generate_token_with_scopes(
563 &rt.user_id,
564 &Scope::join_scopes(&rt.scopes),
565 (self.access_token_ttl.as_secs() / 3600) as usize,
566 )
567 .map_err(|_| AuthError::InvalidToken("Failed to generate token".to_string()))?;
568
569 let access_token = AccessToken::new(
570 access_token_jwt,
571 rt.scopes.clone(),
572 rt.user_id.clone(),
573 self.access_token_ttl,
574 );
575
576 self.access_tokens
577 .insert(access_token.token.clone(), access_token.clone());
578
579 Ok(access_token)
580 }
581
582 pub fn validate_token(&self, token: &str) -> AuthResult<AccessToken> {
584 let access_token = self
585 .access_tokens
586 .get(token)
587 .ok_or(AuthError::InvalidToken("Token not found".to_string()))?;
588
589 if access_token.is_expired() {
590 drop(access_token);
592 self.access_tokens.remove(token);
593 return Err(AuthError::TokenExpired);
594 }
595
596 Ok(access_token.clone())
597 }
598
599 pub fn revoke_access_token(&self, token: &str) -> bool {
601 self.access_tokens.remove(token).is_some()
602 }
603
604 pub fn revoke_refresh_token(&self, token: &str) -> bool {
606 self.refresh_tokens.remove(token).is_some()
607 }
608
609 pub fn cleanup_expired(&self) {
611 self.authorization_codes
613 .retain(|_, code| !code.is_expired());
614
615 self.access_tokens.retain(|_, token| !token.is_expired());
617 }
618}
619
620impl Default for OAuth2Server {
621 fn default() -> Self {
622 Self::new(b"default-secret-change-in-production")
623 }
624}
625
626#[derive(Debug, Clone, Serialize, Deserialize)]
628pub struct OAuth2ProviderConfig {
629 pub name: String,
630 pub client_id: String,
631 pub client_secret: String,
632 pub authorization_endpoint: String,
633 pub token_endpoint: String,
634 pub redirect_uri: String,
635 pub scopes: Vec<Scope>,
636}
637
638impl OAuth2ProviderConfig {
639 pub fn google(client_id: String, client_secret: String, redirect_uri: String) -> Self {
641 Self {
642 name: "google".to_string(),
643 client_id,
644 client_secret,
645 authorization_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
646 token_endpoint: "https://oauth2.googleapis.com/token".to_string(),
647 redirect_uri,
648 scopes: vec![
649 Scope::new("openid"),
650 Scope::new("email"),
651 Scope::new("profile"),
652 ],
653 }
654 }
655
656 pub fn github(client_id: String, client_secret: String, redirect_uri: String) -> Self {
658 Self {
659 name: "github".to_string(),
660 client_id,
661 client_secret,
662 authorization_endpoint: "https://github.com/login/oauth/authorize".to_string(),
663 token_endpoint: "https://github.com/login/oauth/access_token".to_string(),
664 redirect_uri,
665 scopes: vec![Scope::new("user:email"), Scope::new("read:user")],
666 }
667 }
668
669 pub fn build_auth_url(&self, state: &str) -> String {
671 let scope = Scope::join_scopes(&self.scopes);
672 format!(
673 "{}?client_id={}&redirect_uri={}&scope={}&response_type=code&state={}",
674 self.authorization_endpoint,
675 urlencoding::encode(&self.client_id),
676 urlencoding::encode(&self.redirect_uri),
677 urlencoding::encode(&scope),
678 state
679 )
680 }
681}
682
// Unit tests for scope handling, client checks, PKCE, code expiry and the
// in-memory server flows.
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{SystemTime, UNIX_EPOCH};

    // A space-separated scope string splits into individual scopes, in order.
    #[test]
    fn test_scope_parsing() {
        let scopes = Scope::parse_scopes("read write admin");
        assert_eq!(scopes.len(), 3);
        assert_eq!(scopes[0].as_str(), "read");
        assert_eq!(scopes[1].as_str(), "write");
        assert_eq!(scopes[2].as_str(), "admin");
    }

    // Joining uses a single space as the separator.
    #[test]
    fn test_scope_joining() {
        let scopes = vec![Scope::new("read"), Scope::new("write")];
        let joined = Scope::join_scopes(&scopes);
        assert_eq!(joined, "read write");
    }

    // A new client gets non-empty generated credentials and keeps its name.
    #[test]
    fn test_client_creation() {
        let client = OAuth2Client::new(
            "test-client".to_string(),
            vec!["http://localhost:3000/callback".to_string()],
            vec![GrantType::AuthorizationCode],
            vec![Scope::new("read")],
        );

        assert!(!client.client_id.is_empty());
        assert!(!client.client_secret.is_empty());
        assert_eq!(client.name, "test-client");
    }

    // Secret and redirect-URI checks accept the registered values and reject
    // others.
    #[test]
    fn test_client_verification() {
        let client = OAuth2Client::new(
            "test".to_string(),
            vec!["http://localhost/callback".to_string()],
            vec![GrantType::AuthorizationCode],
            vec![Scope::new("read")],
        );

        assert!(client.verify_secret(&client.client_secret));
        assert!(!client.verify_secret("wrong-secret"));
        assert!(client.is_redirect_uri_allowed("http://localhost/callback"));
        assert!(!client.is_redirect_uri_allowed("http://evil.com/callback"));
    }

    // The `plain` PKCE method returns the verifier unchanged.
    #[test]
    fn test_pkce_plain() {
        let method = CodeChallengeMethod::Plain;
        let verifier = "test-verifier";
        let challenge = method.compute_challenge(verifier);
        assert_eq!(challenge, verifier);
    }

    // `S256` transforms the verifier and is deterministic.
    #[test]
    fn test_pkce_s256() {
        let method = CodeChallengeMethod::S256;
        let verifier = "test-verifier-with-sufficient-entropy";
        let challenge = method.compute_challenge(verifier);
        assert_ne!(challenge, verifier);
        assert!(!challenge.is_empty());

        // Same verifier must yield the same challenge.
        let challenge2 = method.compute_challenge(verifier);
        assert_eq!(challenge, challenge2);
    }

    // A code whose expiry is one second in the past reports expired.
    #[test]
    fn test_authorization_code_expiry() {
        let code = AuthorizationCode {
            code: "test-code".to_string(),
            client_id: "client-id".to_string(),
            redirect_uri: "http://localhost/callback".to_string(),
            scopes: vec![Scope::new("read")],
            user_id: "user-id".to_string(),
            expires_at: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_secs()
                - 1, // already expired
            code_challenge: None,
            code_challenge_method: None,
        };

        assert!(code.is_expired());
    }

    // A registered client can be retrieved by its generated id.
    #[test]
    fn test_oauth2_server_client_registration() {
        let server = OAuth2Server::default();
        let client = server.register_client(
            "test-client".to_string(),
            vec!["http://localhost/callback".to_string()],
            vec![GrantType::AuthorizationCode],
            vec![Scope::new("read")],
        );

        let retrieved = server.get_client(&client.client_id);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name, "test-client");
    }

    // A valid authorization request without PKCE yields a code bound to the user.
    #[test]
    fn test_oauth2_server_authorization() {
        let server = OAuth2Server::default();
        let client = server.register_client(
            "test".to_string(),
            vec!["http://localhost/callback".to_string()],
            vec![GrantType::AuthorizationCode],
            vec![Scope::new("read")],
        );

        let auth_code = server
            .authorize(
                &client.client_id,
                "http://localhost/callback",
                ResponseType::Code,
                vec![Scope::new("read")],
                "user-123".to_string(),
                None,
                None,
            )
            .unwrap();

        assert!(!auth_code.code.is_empty());
        assert_eq!(auth_code.user_id, "user-123");
    }

    // The client-credentials grant issues a non-empty bearer token.
    #[test]
    fn test_oauth2_server_client_credentials() {
        let server = OAuth2Server::default();
        let client = server.register_client(
            "test".to_string(),
            vec![],
            vec![GrantType::ClientCredentials],
            vec![Scope::new("read")],
        );

        let token = server
            .client_credentials(
                &client.client_id,
                &client.client_secret,
                vec![Scope::new("read")],
            )
            .unwrap();

        assert!(!token.token.is_empty());
        assert_eq!(token.token_type, TokenType::Bearer);
    }

    // The Google preset targets Google, and its URL carries client_id and state.
    #[test]
    fn test_provider_config_google() {
        let config = OAuth2ProviderConfig::google(
            "client-id".to_string(),
            "client-secret".to_string(),
            "http://localhost/callback".to_string(),
        );

        assert_eq!(config.name, "google");
        assert!(config.authorization_endpoint.contains("google"));

        let url = config.build_auth_url("random-state");
        assert!(url.contains("client_id=client-id"));
        assert!(url.contains("state=random-state"));
    }

    // The GitHub preset targets GitHub.
    #[test]
    fn test_provider_config_github() {
        let config = OAuth2ProviderConfig::github(
            "client-id".to_string(),
            "client-secret".to_string(),
            "http://localhost/callback".to_string(),
        );

        assert_eq!(config.name, "github");
        assert!(config.authorization_endpoint.contains("github"));
    }

    // An issued token validates until it is revoked, then fails validation.
    #[test]
    fn test_token_validation() {
        let server = OAuth2Server::default();
        let client = server.register_client(
            "test".to_string(),
            vec![],
            vec![GrantType::ClientCredentials],
            vec![Scope::new("read")],
        );

        let token = server
            .client_credentials(
                &client.client_id,
                &client.client_secret,
                vec![Scope::new("read")],
            )
            .unwrap();

        // Freshly issued token is valid.
        let validated = server.validate_token(&token.token);
        assert!(validated.is_ok());

        // Revocation removes the token from the store.
        server.revoke_access_token(&token.token);

        // Revoked token is no longer found.
        let validated = server.validate_token(&token.token);
        assert!(validated.is_err());
    }
}