1use crate::error::{Error, ErrorCode, Result};
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::sync::Arc;
8#[cfg(not(target_arch = "wasm32"))]
9use tokio::sync::RwLock;
10use uuid::Uuid;
11
12#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
14pub enum GrantType {
15 #[serde(rename = "authorization_code")]
17 AuthorizationCode,
18 #[serde(rename = "refresh_token")]
20 RefreshToken,
21 #[serde(rename = "client_credentials")]
23 ClientCredentials,
24 #[serde(rename = "password")]
26 Password,
27}
28
29#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
31pub enum ResponseType {
32 #[serde(rename = "code")]
34 Code,
35 #[serde(rename = "token")]
37 Token,
38}
39
40#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
42#[serde(rename_all = "lowercase")]
43pub enum TokenType {
44 Bearer,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct OAuthClient {
51 pub client_id: String,
53
54 #[serde(skip_serializing_if = "Option::is_none")]
56 pub client_secret: Option<String>,
57
58 pub client_name: String,
60
61 pub redirect_uris: Vec<String>,
63
64 pub grant_types: Vec<GrantType>,
66
67 pub response_types: Vec<ResponseType>,
69
70 pub scopes: Vec<String>,
72
73 #[serde(flatten)]
75 pub metadata: HashMap<String, serde_json::Value>,
76}
77
/// An issued, not-yet-redeemed authorization code and the context it was granted in.
///
/// `InMemoryOAuthProvider` sets `expires_at` to "now + code lifetime", where
/// "now" is Unix time in seconds.
#[derive(Debug, Clone)]
pub struct AuthorizationCode {
    /// Opaque code value handed to the client.
    pub code: String,
    /// Client the code was issued to; must match the client redeeming it.
    pub client_id: String,
    /// Identifier of the user the code was issued for.
    pub user_id: String,
    /// Redirect URI from the authorization request; must match at exchange time.
    pub redirect_uri: String,
    /// Scopes that the resulting access token will carry.
    pub scopes: Vec<String>,
    /// PKCE challenge, if one was supplied.
    pub code_challenge: Option<String>,
    /// PKCE method. `InMemoryOAuthProvider` treats `None` as `"plain"`.
    pub code_challenge_method: Option<String>,
    /// Absolute expiry as Unix time in seconds.
    pub expires_at: u64,
}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct AccessToken {
109 pub access_token: String,
111
112 pub token_type: TokenType,
114
115 #[serde(skip_serializing_if = "Option::is_none")]
117 pub expires_in: Option<u64>,
118
119 #[serde(skip_serializing_if = "Option::is_none")]
121 pub refresh_token: Option<String>,
122
123 #[serde(skip_serializing_if = "Option::is_none")]
125 pub scope: Option<String>,
126
127 #[serde(flatten)]
129 pub extra: HashMap<String, serde_json::Value>,
130}
131
132#[derive(Debug, Clone)]
134pub struct TokenInfo {
135 pub token: String,
137
138 pub client_id: String,
140
141 pub user_id: String,
143
144 pub scopes: Vec<String>,
146
147 pub expires_at: u64,
149
150 pub token_type: TokenType,
152}
153
154#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct OAuthError {
157 pub error: String,
159
160 #[serde(skip_serializing_if = "Option::is_none")]
162 pub error_description: Option<String>,
163
164 #[serde(skip_serializing_if = "Option::is_none")]
166 pub error_uri: Option<String>,
167}
168
169#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct OidcDiscoveryMetadata {
173 pub issuer: String,
175
176 pub authorization_endpoint: String,
178
179 pub token_endpoint: String,
181
182 #[serde(skip_serializing_if = "Option::is_none")]
184 pub jwks_uri: Option<String>,
185
186 #[serde(skip_serializing_if = "Option::is_none")]
188 pub userinfo_endpoint: Option<String>,
189
190 #[serde(skip_serializing_if = "Option::is_none")]
192 pub registration_endpoint: Option<String>,
193
194 #[serde(skip_serializing_if = "Option::is_none")]
196 pub revocation_endpoint: Option<String>,
197
198 #[serde(skip_serializing_if = "Option::is_none")]
200 pub introspection_endpoint: Option<String>,
201
202 pub response_types_supported: Vec<ResponseType>,
204
205 pub grant_types_supported: Vec<GrantType>,
207
208 pub scopes_supported: Vec<String>,
210
211 pub token_endpoint_auth_methods_supported: Vec<String>,
213
214 pub code_challenge_methods_supported: Vec<String>,
216}
217
218pub type OAuthMetadata = OidcDiscoveryMetadata;
220
221#[derive(Debug, Clone, Deserialize)]
223pub struct AuthorizationRequest {
224 pub response_type: ResponseType,
226
227 pub client_id: String,
229
230 pub redirect_uri: String,
232
233 #[serde(default)]
235 pub scope: String,
236
237 #[serde(skip_serializing_if = "Option::is_none")]
239 pub state: Option<String>,
240
241 #[serde(skip_serializing_if = "Option::is_none")]
243 pub code_challenge: Option<String>,
244
245 #[serde(skip_serializing_if = "Option::is_none")]
247 pub code_challenge_method: Option<String>,
248}
249
250#[derive(Debug, Clone, Deserialize)]
252pub struct TokenRequest {
253 pub grant_type: GrantType,
255
256 #[serde(skip_serializing_if = "Option::is_none")]
258 pub code: Option<String>,
259
260 #[serde(skip_serializing_if = "Option::is_none")]
262 pub redirect_uri: Option<String>,
263
264 #[serde(skip_serializing_if = "Option::is_none")]
266 pub client_id: Option<String>,
267
268 #[serde(skip_serializing_if = "Option::is_none")]
270 pub client_secret: Option<String>,
271
272 #[serde(skip_serializing_if = "Option::is_none")]
274 pub refresh_token: Option<String>,
275
276 #[serde(skip_serializing_if = "Option::is_none")]
278 pub username: Option<String>,
279
280 #[serde(skip_serializing_if = "Option::is_none")]
282 pub password: Option<String>,
283
284 #[serde(skip_serializing_if = "Option::is_none")]
286 pub scope: Option<String>,
287
288 #[serde(skip_serializing_if = "Option::is_none")]
290 pub code_verifier: Option<String>,
291}
292
293#[derive(Debug, Clone, Deserialize)]
295pub struct RevocationRequest {
296 pub token: String,
298
299 #[serde(skip_serializing_if = "Option::is_none")]
301 pub token_type_hint: Option<String>,
302
303 #[serde(skip_serializing_if = "Option::is_none")]
305 pub client_id: Option<String>,
306
307 #[serde(skip_serializing_if = "Option::is_none")]
309 pub client_secret: Option<String>,
310}
311
312#[async_trait]
314pub trait OAuthProvider: Send + Sync {
315 async fn register_client(&self, client: OAuthClient) -> Result<OAuthClient>;
317
318 async fn get_client(&self, client_id: &str) -> Result<Option<OAuthClient>>;
320
321 async fn validate_authorization(&self, request: &AuthorizationRequest) -> Result<()>;
323
324 async fn create_authorization_code(
326 &self,
327 client_id: &str,
328 user_id: &str,
329 redirect_uri: &str,
330 scopes: Vec<String>,
331 code_challenge: Option<String>,
332 code_challenge_method: Option<String>,
333 ) -> Result<String>;
334
335 async fn exchange_code(&self, request: &TokenRequest) -> Result<AccessToken>;
337
338 async fn create_access_token(
340 &self,
341 client_id: &str,
342 user_id: &str,
343 scopes: Vec<String>,
344 ) -> Result<AccessToken>;
345
346 async fn refresh_token(&self, refresh_token: &str) -> Result<AccessToken>;
348
349 async fn revoke_token(&self, token: &str) -> Result<()>;
351
352 async fn validate_token(&self, token: &str) -> Result<TokenInfo>;
354
355 async fn metadata(&self) -> Result<OAuthMetadata>;
357
358 async fn discover(&self, _issuer_url: &str) -> Result<OidcDiscoveryMetadata> {
362 Err(Error::protocol(
365 ErrorCode::METHOD_NOT_FOUND,
366 "OIDC discovery not implemented for this provider",
367 ))
368 }
369}
370
371#[derive(Debug)]
373pub struct InMemoryOAuthProvider {
374 base_url: String,
376
377 clients: Arc<RwLock<HashMap<String, OAuthClient>>>,
379
380 codes: Arc<RwLock<HashMap<String, AuthorizationCode>>>,
382
383 tokens: Arc<RwLock<HashMap<String, TokenInfo>>>,
385
386 refresh_tokens: Arc<RwLock<HashMap<String, String>>>,
388
389 token_expiration: u64,
391
392 code_expiration: u64,
394
395 supported_scopes: Vec<String>,
397}
398
399impl InMemoryOAuthProvider {
400 pub fn new(base_url: impl Into<String>) -> Self {
402 Self {
403 base_url: base_url.into(),
404 clients: Arc::new(RwLock::new(HashMap::new())),
405 codes: Arc::new(RwLock::new(HashMap::new())),
406 tokens: Arc::new(RwLock::new(HashMap::new())),
407 refresh_tokens: Arc::new(RwLock::new(HashMap::new())),
408 token_expiration: 3600, code_expiration: 600, supported_scopes: vec!["read".to_string(), "write".to_string()],
411 }
412 }
413
414 fn generate_token() -> String {
416 Uuid::new_v4().to_string()
417 }
418
419 fn now() -> u64 {
421 std::time::SystemTime::now()
422 .duration_since(std::time::UNIX_EPOCH)
423 .unwrap()
424 .as_secs()
425 }
426
427 fn verify_pkce(verifier: &str, challenge: &str, method: &str) -> bool {
429 match method {
430 "plain" => verifier == challenge,
431 "S256" => {
432 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
433 use sha2::{Digest, Sha256};
434
435 let mut hasher = Sha256::new();
436 hasher.update(verifier.as_bytes());
437 let result = hasher.finalize();
438 let encoded = URL_SAFE_NO_PAD.encode(result);
439 encoded == challenge
440 },
441 _ => false,
442 }
443 }
444}
445
446#[async_trait]
447impl OAuthProvider for InMemoryOAuthProvider {
448 async fn register_client(&self, mut client: OAuthClient) -> Result<OAuthClient> {
449 if client.client_id.is_empty() {
451 client.client_id = Self::generate_token();
452 }
453 if client.client_secret.is_none() {
454 client.client_secret = Some(Self::generate_token());
455 }
456
457 let mut clients = self.clients.write().await;
459 clients.insert(client.client_id.clone(), client.clone());
460
461 Ok(client)
462 }
463
464 async fn get_client(&self, client_id: &str) -> Result<Option<OAuthClient>> {
465 let clients = self.clients.read().await;
466 Ok(clients.get(client_id).cloned())
467 }
468
469 async fn validate_authorization(&self, request: &AuthorizationRequest) -> Result<()> {
470 let client = self
472 .get_client(&request.client_id)
473 .await?
474 .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Invalid client_id"))?;
475
476 if !client.redirect_uris.contains(&request.redirect_uri) {
478 return Err(Error::protocol(
479 ErrorCode::INVALID_REQUEST,
480 "Invalid redirect_uri",
481 ));
482 }
483
484 if !client.response_types.contains(&request.response_type) {
486 return Err(Error::protocol(
487 ErrorCode::INVALID_REQUEST,
488 "Unsupported response_type",
489 ));
490 }
491
492 let requested_scopes: Vec<&str> = request.scope.split_whitespace().collect();
494 for scope in &requested_scopes {
495 if !self.supported_scopes.iter().any(|s| s == scope) {
496 return Err(Error::protocol(ErrorCode::INVALID_REQUEST, "Invalid scope"));
497 }
498 }
499
500 Ok(())
501 }
502
503 async fn create_authorization_code(
504 &self,
505 client_id: &str,
506 user_id: &str,
507 redirect_uri: &str,
508 scopes: Vec<String>,
509 code_challenge: Option<String>,
510 code_challenge_method: Option<String>,
511 ) -> Result<String> {
512 let code = Self::generate_token();
513 let expires_at = Self::now() + self.code_expiration;
514
515 let auth_code = AuthorizationCode {
516 code: code.clone(),
517 client_id: client_id.to_string(),
518 user_id: user_id.to_string(),
519 redirect_uri: redirect_uri.to_string(),
520 scopes,
521 code_challenge,
522 code_challenge_method,
523 expires_at,
524 };
525
526 let mut codes = self.codes.write().await;
527 codes.insert(code.clone(), auth_code);
528
529 Ok(code)
530 }
531
532 async fn exchange_code(&self, request: &TokenRequest) -> Result<AccessToken> {
533 if request.grant_type != GrantType::AuthorizationCode {
535 return Err(Error::protocol(
536 ErrorCode::INVALID_REQUEST,
537 "Invalid grant_type",
538 ));
539 }
540
541 let code = request
543 .code
544 .as_ref()
545 .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Missing code"))?;
546
547 let mut codes = self.codes.write().await;
548 let auth_code = codes
549 .remove(code)
550 .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Invalid code"))?;
551
552 if auth_code.expires_at < Self::now() {
554 return Err(Error::protocol(ErrorCode::INVALID_REQUEST, "Code expired"));
555 }
556
557 let client_id = request
559 .client_id
560 .as_ref()
561 .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Missing client_id"))?;
562
563 if auth_code.client_id != *client_id {
564 return Err(Error::protocol(
565 ErrorCode::INVALID_REQUEST,
566 "Invalid client_id",
567 ));
568 }
569
570 let redirect_uri = request
572 .redirect_uri
573 .as_ref()
574 .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Missing redirect_uri"))?;
575
576 if auth_code.redirect_uri != *redirect_uri {
577 return Err(Error::protocol(
578 ErrorCode::INVALID_REQUEST,
579 "Invalid redirect_uri",
580 ));
581 }
582
583 if let Some(challenge) = &auth_code.code_challenge {
585 let verifier = request.code_verifier.as_ref().ok_or_else(|| {
586 Error::protocol(ErrorCode::INVALID_REQUEST, "Missing code_verifier")
587 })?;
588
589 let method = auth_code
590 .code_challenge_method
591 .as_deref()
592 .unwrap_or("plain");
593 if !Self::verify_pkce(verifier, challenge, method) {
594 return Err(Error::protocol(
595 ErrorCode::INVALID_REQUEST,
596 "Invalid code_verifier",
597 ));
598 }
599 }
600
601 self.create_access_token(&auth_code.client_id, &auth_code.user_id, auth_code.scopes)
603 .await
604 }
605
606 async fn create_access_token(
607 &self,
608 client_id: &str,
609 user_id: &str,
610 scopes: Vec<String>,
611 ) -> Result<AccessToken> {
612 let access_token = Self::generate_token();
613 let refresh_token = Self::generate_token();
614 let expires_at = Self::now() + self.token_expiration;
615
616 let token_info = TokenInfo {
618 token: access_token.clone(),
619 client_id: client_id.to_string(),
620 user_id: user_id.to_string(),
621 scopes: scopes.clone(),
622 expires_at,
623 token_type: TokenType::Bearer,
624 };
625
626 let mut tokens = self.tokens.write().await;
627 tokens.insert(access_token.clone(), token_info);
628
629 let mut refresh_tokens = self.refresh_tokens.write().await;
631 refresh_tokens.insert(refresh_token.clone(), access_token.clone());
632
633 Ok(AccessToken {
634 access_token,
635 token_type: TokenType::Bearer,
636 expires_in: Some(self.token_expiration),
637 refresh_token: Some(refresh_token),
638 scope: Some(scopes.join(" ")),
639 extra: HashMap::new(),
640 })
641 }
642
643 async fn refresh_token(&self, refresh_token: &str) -> Result<AccessToken> {
644 let refresh_tokens = self.refresh_tokens.read().await;
646 let old_token = refresh_tokens
647 .get(refresh_token)
648 .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Invalid refresh_token"))?
649 .clone();
650
651 let tokens = self.tokens.read().await;
653 let token_info = tokens
654 .get(&old_token)
655 .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Invalid refresh_token"))?;
656
657 let client_id = token_info.client_id.clone();
658 let user_id = token_info.user_id.clone();
659 let scopes = token_info.scopes.clone();
660
661 drop(tokens);
662 drop(refresh_tokens);
663
664 let mut tokens = self.tokens.write().await;
666 tokens.remove(&old_token);
667 drop(tokens);
668
669 let mut refresh_tokens = self.refresh_tokens.write().await;
670 refresh_tokens.remove(refresh_token);
671 drop(refresh_tokens);
672
673 self.create_access_token(&client_id, &user_id, scopes).await
675 }
676
677 async fn revoke_token(&self, token: &str) -> Result<()> {
678 let mut tokens = self.tokens.write().await;
680 if tokens.remove(token).is_some() {
681 return Ok(());
682 }
683 drop(tokens);
684
685 let mut refresh_tokens = self.refresh_tokens.write().await;
687 if let Some(access_token) = refresh_tokens.remove(token) {
688 let mut tokens = self.tokens.write().await;
689 tokens.remove(&access_token);
690 }
691
692 Ok(())
693 }
694
695 async fn validate_token(&self, token: &str) -> Result<TokenInfo> {
696 let tokens = self.tokens.read().await;
697 let token_info = tokens
698 .get(token)
699 .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Invalid token"))?;
700
701 if token_info.expires_at < Self::now() {
703 return Err(Error::protocol(ErrorCode::INVALID_REQUEST, "Token expired"));
704 }
705
706 Ok(token_info.clone())
707 }
708
709 async fn metadata(&self) -> Result<OAuthMetadata> {
710 Ok(OAuthMetadata {
711 issuer: self.base_url.clone(),
712 authorization_endpoint: format!("{}/oauth2/authorize", self.base_url),
713 token_endpoint: format!("{}/oauth2/token", self.base_url),
714 jwks_uri: Some(format!("{}/oauth2/jwks", self.base_url)),
715 userinfo_endpoint: Some(format!("{}/oauth2/userinfo", self.base_url)),
716 registration_endpoint: Some(format!("{}/oauth2/register", self.base_url)),
717 revocation_endpoint: Some(format!("{}/oauth2/revoke", self.base_url)),
718 introspection_endpoint: Some(format!("{}/oauth2/introspect", self.base_url)),
719 response_types_supported: vec![ResponseType::Code],
720 grant_types_supported: vec![GrantType::AuthorizationCode, GrantType::RefreshToken],
721 scopes_supported: self.supported_scopes.clone(),
722 token_endpoint_auth_methods_supported: vec![
723 "client_secret_basic".to_string(),
724 "client_secret_post".to_string(),
725 ],
726 code_challenge_methods_supported: vec!["plain".to_string(), "S256".to_string()],
727 })
728 }
729}
730
731#[derive(Debug)]
733pub struct ProxyOAuthProvider {
734 _upstream_url: String,
736
737 _token_cache: Arc<RwLock<HashMap<String, TokenInfo>>>,
739}
740
741impl ProxyOAuthProvider {
742 pub fn new(upstream_url: impl Into<String>) -> Self {
744 Self {
745 _upstream_url: upstream_url.into(),
746 _token_cache: Arc::new(RwLock::new(HashMap::new())),
747 }
748 }
749}
750
#[cfg(test)]
mod tests {
    use super::*;

    const REDIRECT_URI: &str = "http://localhost:3000/callback";

    /// A client allowed to use the authorization-code flow against `REDIRECT_URI`.
    fn test_client() -> OAuthClient {
        OAuthClient {
            client_id: String::new(),
            client_secret: None,
            client_name: "Test Client".to_string(),
            redirect_uris: vec![REDIRECT_URI.to_string()],
            grant_types: vec![GrantType::AuthorizationCode],
            response_types: vec![ResponseType::Code],
            scopes: vec!["read".to_string(), "write".to_string()],
            metadata: HashMap::new(),
        }
    }

    /// A token request redeeming `code` for `client`, with an optional PKCE verifier.
    fn code_exchange(
        client: &OAuthClient,
        code: String,
        code_verifier: Option<&str>,
    ) -> TokenRequest {
        TokenRequest {
            grant_type: GrantType::AuthorizationCode,
            code: Some(code),
            redirect_uri: Some(REDIRECT_URI.to_string()),
            client_id: Some(client.client_id.clone()),
            client_secret: client.client_secret.clone(),
            refresh_token: None,
            username: None,
            password: None,
            scope: None,
            code_verifier: code_verifier.map(str::to_string),
        }
    }

    #[tokio::test]
    async fn test_oauth_flow() {
        let provider = InMemoryOAuthProvider::new("http://localhost:8080");

        let registered = provider.register_client(test_client()).await.unwrap();
        assert!(!registered.client_id.is_empty());
        assert!(registered.client_secret.is_some());

        let auth_req = AuthorizationRequest {
            response_type: ResponseType::Code,
            client_id: registered.client_id.clone(),
            redirect_uri: REDIRECT_URI.to_string(),
            scope: "read write".to_string(),
            state: Some("test-state".to_string()),
            code_challenge: None,
            code_challenge_method: None,
        };

        provider.validate_authorization(&auth_req).await.unwrap();

        let code = provider
            .create_authorization_code(
                &registered.client_id,
                "user-123",
                &auth_req.redirect_uri,
                vec!["read".to_string(), "write".to_string()],
                None,
                None,
            )
            .await
            .unwrap();

        let token = provider
            .exchange_code(&code_exchange(&registered, code, None))
            .await
            .unwrap();
        assert_eq!(token.token_type, TokenType::Bearer);
        assert!(token.refresh_token.is_some());

        let token_info = provider.validate_token(&token.access_token).await.unwrap();
        assert_eq!(token_info.client_id, registered.client_id);
        assert_eq!(token_info.user_id, "user-123");
    }

    #[tokio::test]
    async fn test_pkce_s256_exchange() {
        // Test vector from RFC 7636 Appendix B.
        const VERIFIER: &str = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
        const CHALLENGE: &str = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";

        let provider = InMemoryOAuthProvider::new("http://localhost:8080");
        let client = provider.register_client(test_client()).await.unwrap();

        let code = provider
            .create_authorization_code(
                &client.client_id,
                "user-123",
                REDIRECT_URI,
                vec!["read".to_string()],
                Some(CHALLENGE.to_string()),
                Some("S256".to_string()),
            )
            .await
            .unwrap();

        let bad = code_exchange(&client, code.clone(), Some("wrong-verifier"));
        assert!(provider.exchange_code(&bad).await.is_err());

        // The failed attempt consumed the code.
        let retry = code_exchange(&client, code, Some(VERIFIER));
        assert!(provider.exchange_code(&retry).await.is_err());

        let code = provider
            .create_authorization_code(
                &client.client_id,
                "user-123",
                REDIRECT_URI,
                vec!["read".to_string()],
                Some(CHALLENGE.to_string()),
                Some("S256".to_string()),
            )
            .await
            .unwrap();
        let token = provider
            .exchange_code(&code_exchange(&client, code, Some(VERIFIER)))
            .await
            .unwrap();
        assert!(provider.validate_token(&token.access_token).await.is_ok());
    }

    #[tokio::test]
    async fn test_refresh_rotates_tokens() {
        let provider = InMemoryOAuthProvider::new("http://localhost:8080");
        let first = provider
            .create_access_token("client", "user-123", vec!["read".to_string()])
            .await
            .unwrap();
        let refresh = first.refresh_token.clone().unwrap();

        let second = provider.refresh_token(&refresh).await.unwrap();
        assert_ne!(second.access_token, first.access_token);
        assert!(provider.validate_token(&first.access_token).await.is_err());

        let info = provider.validate_token(&second.access_token).await.unwrap();
        assert_eq!(info.user_id, "user-123");
        assert_eq!(info.scopes, vec!["read".to_string()]);

        // Refresh tokens are single-use.
        assert!(provider.refresh_token(&refresh).await.is_err());
    }

    #[tokio::test]
    async fn test_revoke_refresh_token_revokes_access_token() {
        let provider = InMemoryOAuthProvider::new("http://localhost:8080");
        let token = provider
            .create_access_token("client", "user-123", vec!["read".to_string()])
            .await
            .unwrap();

        provider
            .revoke_token(token.refresh_token.as_deref().unwrap())
            .await
            .unwrap();
        assert!(provider.validate_token(&token.access_token).await.is_err());

        // Unknown tokens are ignored rather than reported as errors.
        provider.revoke_token("no-such-token").await.unwrap();
    }
}
827}