1use std::collections::{HashMap, HashSet};
54use std::sync::{Arc, RwLock};
55use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
56
57use fastmcp_core::{AccessToken, AuthContext, McpContext, McpError, McpErrorCode, McpResult};
58
59use crate::auth::{AuthRequest, TokenVerifier};
60
/// Tunable settings for an `OAuthServer`.
#[derive(Debug, Clone)]
pub struct OAuthServerConfig {
    /// Issuer identifier; the token verifier reports it as the `iss` claim.
    pub issuer: String,
    /// Lifetime of issued access tokens (added to the issue instant to get `expires_at`).
    pub access_token_lifetime: Duration,
    /// Lifetime of issued refresh tokens.
    pub refresh_token_lifetime: Duration,
    /// Lifetime of authorization codes created by `authorize`.
    pub authorization_code_lifetime: Duration,
    /// When `false`, registering a client without a secret (a public client) is rejected.
    pub allow_public_clients: bool,
    /// Minimum accepted `code_verifier` length (checked on its byte length).
    pub min_code_verifier_length: usize,
    /// Maximum accepted `code_verifier` length (checked on its byte length).
    pub max_code_verifier_length: usize,
    /// Number of random bytes used when generating codes and tokens.
    pub token_entropy_bytes: usize,
}
85
86impl Default for OAuthServerConfig {
87 fn default() -> Self {
88 Self {
89 issuer: "fastmcp".to_string(),
90 access_token_lifetime: Duration::from_secs(3600), refresh_token_lifetime: Duration::from_secs(86400 * 30), authorization_code_lifetime: Duration::from_secs(600), allow_public_clients: true,
94 min_code_verifier_length: 43,
95 max_code_verifier_length: 128,
96 token_entropy_bytes: 32,
97 }
98 }
99}
100
/// Whether a client can authenticate with a secret.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientType {
    /// Registered with a secret; token and revocation requests must present it.
    Confidential,
    /// Registered without a secret; authenticates with no secret at all.
    Public,
}
113
/// A registered OAuth client. Construct via `OAuthClient::builder`.
#[derive(Debug, Clone)]
pub struct OAuthClient {
    /// Unique client identifier (non-empty, enforced by the builder).
    pub client_id: String,
    /// Shared secret for confidential clients; `None` for public clients.
    pub client_secret: Option<String>,
    /// Derived by the builder from whether a secret was provided.
    pub client_type: ClientType,
    /// Registered redirect URIs (at least one, enforced by the builder).
    pub redirect_uris: Vec<String>,
    /// Scopes this client may request.
    pub allowed_scopes: HashSet<String>,
    /// Optional human-readable name.
    pub name: Option<String>,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Wall-clock time at which the builder produced this client.
    pub registered_at: SystemTime,
}
134
135impl OAuthClient {
136 #[must_use]
138 pub fn builder(client_id: impl Into<String>) -> OAuthClientBuilder {
139 OAuthClientBuilder::new(client_id)
140 }
141
142 #[must_use]
144 pub fn validate_redirect_uri(&self, uri: &str) -> bool {
145 if self.redirect_uris.contains(&uri.to_string()) {
147 return true;
148 }
149
150 for allowed in &self.redirect_uris {
152 if is_localhost_redirect(allowed) && is_localhost_redirect(uri) {
153 if localhost_match(allowed, uri) {
155 return true;
156 }
157 }
158 }
159
160 false
161 }
162
163 #[must_use]
165 pub fn validate_scopes(&self, scopes: &[String]) -> bool {
166 scopes.iter().all(|s| self.allowed_scopes.contains(s))
167 }
168
169 #[must_use]
171 pub fn authenticate(&self, secret: Option<&str>) -> bool {
172 match (&self.client_secret, secret) {
173 (Some(expected), Some(provided)) => constant_time_eq(expected, provided),
174 (None, None) => self.client_type == ClientType::Public,
175 _ => false,
176 }
177 }
178}
179
/// Builder for `OAuthClient`; see `OAuthClientBuilder::build` for validation rules.
#[derive(Debug)]
pub struct OAuthClientBuilder {
    client_id: String,
    // Presence of a secret decides Confidential vs Public in `build`.
    client_secret: Option<String>,
    redirect_uris: Vec<String>,
    allowed_scopes: HashSet<String>,
    name: Option<String>,
    description: Option<String>,
}
190
191impl OAuthClientBuilder {
192 fn new(client_id: impl Into<String>) -> Self {
194 Self {
195 client_id: client_id.into(),
196 client_secret: None,
197 redirect_uris: Vec::new(),
198 allowed_scopes: HashSet::new(),
199 name: None,
200 description: None,
201 }
202 }
203
204 #[must_use]
206 pub fn secret(mut self, secret: impl Into<String>) -> Self {
207 self.client_secret = Some(secret.into());
208 self
209 }
210
211 #[must_use]
213 pub fn redirect_uri(mut self, uri: impl Into<String>) -> Self {
214 self.redirect_uris.push(uri.into());
215 self
216 }
217
218 #[must_use]
220 pub fn redirect_uris<I, S>(mut self, uris: I) -> Self
221 where
222 I: IntoIterator<Item = S>,
223 S: Into<String>,
224 {
225 self.redirect_uris.extend(uris.into_iter().map(Into::into));
226 self
227 }
228
229 #[must_use]
231 pub fn scope(mut self, scope: impl Into<String>) -> Self {
232 self.allowed_scopes.insert(scope.into());
233 self
234 }
235
236 #[must_use]
238 pub fn scopes<I, S>(mut self, scopes: I) -> Self
239 where
240 I: IntoIterator<Item = S>,
241 S: Into<String>,
242 {
243 self.allowed_scopes
244 .extend(scopes.into_iter().map(Into::into));
245 self
246 }
247
248 #[must_use]
250 pub fn name(mut self, name: impl Into<String>) -> Self {
251 self.name = Some(name.into());
252 self
253 }
254
255 #[must_use]
257 pub fn description(mut self, description: impl Into<String>) -> Self {
258 self.description = Some(description.into());
259 self
260 }
261
262 pub fn build(self) -> Result<OAuthClient, OAuthError> {
270 if self.client_id.is_empty() {
271 return Err(OAuthError::InvalidRequest(
272 "client_id cannot be empty".to_string(),
273 ));
274 }
275
276 if self.redirect_uris.is_empty() {
277 return Err(OAuthError::InvalidRequest(
278 "at least one redirect_uri is required".to_string(),
279 ));
280 }
281
282 let client_type = if self.client_secret.is_some() {
283 ClientType::Confidential
284 } else {
285 ClientType::Public
286 };
287
288 Ok(OAuthClient {
289 client_id: self.client_id,
290 client_secret: self.client_secret,
291 client_type,
292 redirect_uris: self.redirect_uris,
293 allowed_scopes: self.allowed_scopes,
294 name: self.name,
295 description: self.description,
296 registered_at: SystemTime::now(),
297 })
298 }
299}
300
/// PKCE `code_challenge_method` values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CodeChallengeMethod {
    /// The challenge is compared directly against the verifier.
    Plain,
    /// The challenge is compared against `compute_s256_challenge(verifier)`.
    S256,
}
313
314impl CodeChallengeMethod {
315 #[must_use]
317 pub fn parse(s: &str) -> Option<Self> {
318 match s {
319 "plain" => Some(Self::Plain),
320 "S256" => Some(Self::S256),
321 _ => None,
322 }
323 }
324
325 #[must_use]
327 pub fn as_str(&self) -> &'static str {
328 match self {
329 Self::Plain => "plain",
330 Self::S256 => "S256",
331 }
332 }
333}
334
/// A pending authorization code created by `OAuthServer::authorize`.
#[derive(Debug, Clone)]
pub struct AuthorizationCode {
    /// The opaque code value handed to the client.
    pub code: String,
    /// Client the code was issued to.
    pub client_id: String,
    /// Redirect URI from the authorization request; the token request must repeat it.
    pub redirect_uri: String,
    /// Scopes granted with this code.
    pub scopes: Vec<String>,
    /// PKCE challenge from the authorization request.
    pub code_challenge: String,
    /// How `code_challenge` relates to the eventual `code_verifier`.
    pub code_challenge_method: CodeChallengeMethod,
    /// Monotonic creation time.
    pub issued_at: Instant,
    /// Monotonic expiry time.
    pub expires_at: Instant,
    /// Resource owner the code was issued for, if any.
    pub subject: Option<String>,
    /// `state` value from the authorization request, if any.
    pub state: Option<String>,
}
359
360impl AuthorizationCode {
361 #[must_use]
363 pub fn is_expired(&self) -> bool {
364 Instant::now() >= self.expires_at
365 }
366
367 #[must_use]
369 pub fn validate_code_verifier(&self, verifier: &str) -> bool {
370 match self.code_challenge_method {
371 CodeChallengeMethod::Plain => constant_time_eq(&self.code_challenge, verifier),
372 CodeChallengeMethod::S256 => {
373 let computed = compute_s256_challenge(verifier);
374 constant_time_eq(&self.code_challenge, &computed)
375 }
376 }
377 }
378}
379
/// Kind of token issued. Only bearer tokens are supported.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenType {
    /// Bearer token; serialized as `"bearer"` by `TokenType::as_str`.
    Bearer,
}
390
391impl TokenType {
392 #[must_use]
394 pub fn as_str(&self) -> &'static str {
395 match self {
396 Self::Bearer => "bearer",
397 }
398 }
399}
400
/// An issued access or refresh token together with its grant metadata.
#[derive(Debug, Clone)]
pub struct OAuthToken {
    /// The opaque token value.
    pub token: String,
    /// Token kind (always bearer).
    pub token_type: TokenType,
    /// Client the token was issued to.
    pub client_id: String,
    /// Scopes granted to the token.
    pub scopes: Vec<String>,
    /// Monotonic issue time.
    pub issued_at: Instant,
    /// Monotonic expiry time.
    pub expires_at: Instant,
    /// Resource owner, if any.
    pub subject: Option<String>,
    /// `true` for refresh tokens, `false` for access tokens.
    pub is_refresh_token: bool,
}
421
422impl OAuthToken {
423 #[must_use]
425 pub fn is_expired(&self) -> bool {
426 Instant::now() >= self.expires_at
427 }
428
429 #[must_use]
431 pub fn expires_in_secs(&self) -> u64 {
432 self.expires_at
433 .saturating_duration_since(Instant::now())
434 .as_secs()
435 }
436}
437
/// Successful token-endpoint response body, serialized with serde.
#[derive(Debug, Clone, serde::Serialize)]
pub struct TokenResponse {
    /// The newly issued access token.
    pub access_token: String,
    /// Token type string (from `TokenType::as_str`).
    pub token_type: String,
    /// Seconds until the access token expires.
    pub expires_in: u64,
    /// Refresh token, omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<String>,
    /// Space-separated granted scopes, omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
}
454
/// Parsed authorization-endpoint request.
#[derive(Debug, Clone)]
pub struct AuthorizationRequest {
    /// Must be `"code"`; anything else is rejected by `authorize`.
    pub response_type: String,
    /// Requesting client.
    pub client_id: String,
    /// Where to send the code; must be acceptable to `OAuthClient::validate_redirect_uri`.
    pub redirect_uri: String,
    /// Requested scopes; must all be allowed for the client.
    pub scopes: Vec<String>,
    /// Opaque client state echoed back on the redirect.
    pub state: Option<String>,
    /// PKCE challenge; must be non-empty.
    pub code_challenge: String,
    /// PKCE challenge method.
    pub code_challenge_method: CodeChallengeMethod,
}
477
/// Parsed token-endpoint request.
#[derive(Debug, Clone)]
pub struct TokenRequest {
    /// `"authorization_code"` or `"refresh_token"`; other values are rejected.
    pub grant_type: String,
    /// Authorization code (required for `authorization_code`).
    pub code: Option<String>,
    /// Redirect URI; must equal the one used to obtain the code.
    pub redirect_uri: Option<String>,
    /// Requesting client.
    pub client_id: String,
    /// Client secret; checked only for confidential clients.
    pub client_secret: Option<String>,
    /// PKCE verifier (required for `authorization_code`).
    pub code_verifier: Option<String>,
    /// Refresh token (required for `refresh_token`).
    pub refresh_token: Option<String>,
    /// Optional scope narrowing for `refresh_token`; must be a subset of the original grant.
    pub scopes: Option<Vec<String>>,
}
498
/// OAuth error; each variant maps to an `error` code via `OAuthError::error_code` and
/// carries a human-readable description.
#[derive(Debug, Clone)]
pub enum OAuthError {
    /// `invalid_request`
    InvalidRequest(String),
    /// `invalid_client`
    InvalidClient(String),
    /// `invalid_grant`
    InvalidGrant(String),
    /// `unauthorized_client`
    UnauthorizedClient(String),
    /// `unsupported_grant_type`
    UnsupportedGrantType(String),
    /// `invalid_scope`
    InvalidScope(String),
    /// `server_error`
    ServerError(String),
    /// `temporarily_unavailable`
    TemporarilyUnavailable(String),
    /// `access_denied`
    AccessDenied(String),
    /// `unsupported_response_type`
    UnsupportedResponseType(String),
}
527
528impl OAuthError {
529 #[must_use]
531 pub fn error_code(&self) -> &'static str {
532 match self {
533 Self::InvalidRequest(_) => "invalid_request",
534 Self::InvalidClient(_) => "invalid_client",
535 Self::InvalidGrant(_) => "invalid_grant",
536 Self::UnauthorizedClient(_) => "unauthorized_client",
537 Self::UnsupportedGrantType(_) => "unsupported_grant_type",
538 Self::InvalidScope(_) => "invalid_scope",
539 Self::ServerError(_) => "server_error",
540 Self::TemporarilyUnavailable(_) => "temporarily_unavailable",
541 Self::AccessDenied(_) => "access_denied",
542 Self::UnsupportedResponseType(_) => "unsupported_response_type",
543 }
544 }
545
546 #[must_use]
548 pub fn description(&self) -> &str {
549 match self {
550 Self::InvalidRequest(s)
551 | Self::InvalidClient(s)
552 | Self::InvalidGrant(s)
553 | Self::UnauthorizedClient(s)
554 | Self::UnsupportedGrantType(s)
555 | Self::InvalidScope(s)
556 | Self::ServerError(s)
557 | Self::TemporarilyUnavailable(s)
558 | Self::AccessDenied(s)
559 | Self::UnsupportedResponseType(s) => s,
560 }
561 }
562}
563
564impl std::fmt::Display for OAuthError {
565 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
566 write!(f, "{}: {}", self.error_code(), self.description())
567 }
568}
569
// `OAuthError` wraps no underlying error, so the default `source()` (None) is correct.
impl std::error::Error for OAuthError {}
571
572impl From<OAuthError> for McpError {
573 fn from(err: OAuthError) -> Self {
574 match &err {
575 OAuthError::InvalidClient(_) | OAuthError::UnauthorizedClient(_) => {
576 McpError::new(McpErrorCode::ResourceForbidden, err.to_string())
577 }
578 OAuthError::AccessDenied(_) => {
579 McpError::new(McpErrorCode::ResourceForbidden, err.to_string())
580 }
581 _ => McpError::new(McpErrorCode::InvalidRequest, err.to_string()),
582 }
583 }
584}
585
/// Mutable server state, guarded by the `RwLock` in `OAuthServer`.
pub(crate) struct OAuthServerState {
    /// Registered clients keyed by `client_id`.
    pub(crate) clients: HashMap<String, OAuthClient>,
    /// Outstanding authorization codes keyed by code value (removed when redeemed).
    pub(crate) authorization_codes: HashMap<String, AuthorizationCode>,
    /// Live access tokens keyed by token value.
    pub(crate) access_tokens: HashMap<String, OAuthToken>,
    /// Live refresh tokens keyed by token value.
    pub(crate) refresh_tokens: HashMap<String, OAuthToken>,
    /// Values of tokens that were revoked (explicitly or by client unregistration).
    pub(crate) revoked_tokens: HashSet<String>,
}
603
604impl OAuthServerState {
605 fn new() -> Self {
606 Self {
607 clients: HashMap::new(),
608 authorization_codes: HashMap::new(),
609 access_tokens: HashMap::new(),
610 refresh_tokens: HashMap::new(),
611 revoked_tokens: HashSet::new(),
612 }
613 }
614}
615
/// In-memory OAuth authorization server (authorization-code + PKCE and refresh grants).
///
/// All mutable state lives behind an `RwLock`, so methods take `&self`.
pub struct OAuthServer {
    config: OAuthServerConfig,
    pub(crate) state: RwLock<OAuthServerState>,
}
624
625impl OAuthServer {
626 #[must_use]
628 pub fn new(config: OAuthServerConfig) -> Self {
629 Self {
630 config,
631 state: RwLock::new(OAuthServerState::new()),
632 }
633 }
634
635 #[must_use]
637 pub fn with_defaults() -> Self {
638 Self::new(OAuthServerConfig::default())
639 }
640
641 #[must_use]
643 pub fn config(&self) -> &OAuthServerConfig {
644 &self.config
645 }
646
647 pub fn register_client(&self, client: OAuthClient) -> Result<(), OAuthError> {
659 if client.client_type == ClientType::Public && !self.config.allow_public_clients {
660 return Err(OAuthError::InvalidClient(
661 "public clients are not allowed".to_string(),
662 ));
663 }
664
665 let mut state = self
666 .state
667 .write()
668 .map_err(|_| OAuthError::ServerError("failed to acquire write lock".to_string()))?;
669
670 if state.clients.contains_key(&client.client_id) {
671 return Err(OAuthError::InvalidClient(format!(
672 "client '{}' already exists",
673 client.client_id
674 )));
675 }
676
677 state.clients.insert(client.client_id.clone(), client);
678 Ok(())
679 }
680
681 pub fn unregister_client(&self, client_id: &str) -> Result<(), OAuthError> {
685 let mut state = self
686 .state
687 .write()
688 .map_err(|_| OAuthError::ServerError("failed to acquire write lock".to_string()))?;
689
690 if state.clients.remove(client_id).is_none() {
691 return Err(OAuthError::InvalidClient(format!(
692 "client '{}' not found",
693 client_id
694 )));
695 }
696
697 let access_tokens: Vec<_> = state
699 .access_tokens
700 .iter()
701 .filter(|(_, t)| t.client_id == client_id)
702 .map(|(k, _)| k.clone())
703 .collect();
704 for token in access_tokens {
705 state.access_tokens.remove(&token);
706 state.revoked_tokens.insert(token);
707 }
708
709 let refresh_tokens: Vec<_> = state
710 .refresh_tokens
711 .iter()
712 .filter(|(_, t)| t.client_id == client_id)
713 .map(|(k, _)| k.clone())
714 .collect();
715 for token in refresh_tokens {
716 state.refresh_tokens.remove(&token);
717 state.revoked_tokens.insert(token);
718 }
719
720 let codes: Vec<_> = state
722 .authorization_codes
723 .iter()
724 .filter(|(_, c)| c.client_id == client_id)
725 .map(|(k, _)| k.clone())
726 .collect();
727 for code in codes {
728 state.authorization_codes.remove(&code);
729 }
730
731 Ok(())
732 }
733
734 #[must_use]
736 pub fn get_client(&self, client_id: &str) -> Option<OAuthClient> {
737 self.state
738 .read()
739 .ok()
740 .and_then(|s| s.clients.get(client_id).cloned())
741 }
742
743 #[must_use]
745 pub fn list_clients(&self) -> Vec<OAuthClient> {
746 self.state
747 .read()
748 .map(|s| s.clients.values().cloned().collect())
749 .unwrap_or_default()
750 }
751
752 pub fn authorize(
770 &self,
771 request: &AuthorizationRequest,
772 subject: Option<String>,
773 ) -> Result<(String, String), OAuthError> {
774 if request.response_type != "code" {
776 return Err(OAuthError::UnsupportedResponseType(
777 "only 'code' response_type is supported".to_string(),
778 ));
779 }
780
781 let client = self.get_client(&request.client_id).ok_or_else(|| {
783 OAuthError::InvalidClient(format!("client '{}' not found", request.client_id))
784 })?;
785
786 if !client.validate_redirect_uri(&request.redirect_uri) {
788 return Err(OAuthError::InvalidRequest(
789 "invalid redirect_uri".to_string(),
790 ));
791 }
792
793 if !client.validate_scopes(&request.scopes) {
795 return Err(OAuthError::InvalidScope(
796 "requested scope not allowed".to_string(),
797 ));
798 }
799
800 if request.code_challenge.is_empty() {
802 return Err(OAuthError::InvalidRequest(
803 "code_challenge is required (PKCE)".to_string(),
804 ));
805 }
806
807 let code_value = generate_token(self.config.token_entropy_bytes);
809 let now = Instant::now();
810 let code = AuthorizationCode {
811 code: code_value.clone(),
812 client_id: request.client_id.clone(),
813 redirect_uri: request.redirect_uri.clone(),
814 scopes: request.scopes.clone(),
815 code_challenge: request.code_challenge.clone(),
816 code_challenge_method: request.code_challenge_method,
817 issued_at: now,
818 expires_at: now + self.config.authorization_code_lifetime,
819 subject,
820 state: request.state.clone(),
821 };
822
823 {
825 let mut state = self
826 .state
827 .write()
828 .map_err(|_| OAuthError::ServerError("failed to acquire write lock".to_string()))?;
829 state.authorization_codes.insert(code_value.clone(), code);
830 }
831
832 let mut redirect = request.redirect_uri.clone();
834 let separator = if redirect.contains('?') { '&' } else { '?' };
835 redirect.push(separator);
836 redirect.push_str("code=");
837 redirect.push_str(&url_encode(&code_value));
838 if let Some(state) = &request.state {
839 redirect.push_str("&state=");
840 redirect.push_str(&url_encode(state));
841 }
842
843 Ok((code_value, redirect))
844 }
845
846 pub fn token(&self, request: &TokenRequest) -> Result<TokenResponse, OAuthError> {
852 match request.grant_type.as_str() {
853 "authorization_code" => self.token_authorization_code(request),
854 "refresh_token" => self.token_refresh_token(request),
855 other => Err(OAuthError::UnsupportedGrantType(format!(
856 "grant_type '{}' is not supported",
857 other
858 ))),
859 }
860 }
861
862 fn token_authorization_code(
863 &self,
864 request: &TokenRequest,
865 ) -> Result<TokenResponse, OAuthError> {
866 let code_value = request
868 .code
869 .as_ref()
870 .ok_or_else(|| OAuthError::InvalidRequest("code is required".to_string()))?;
871 let redirect_uri = request
872 .redirect_uri
873 .as_ref()
874 .ok_or_else(|| OAuthError::InvalidRequest("redirect_uri is required".to_string()))?;
875 let code_verifier = request.code_verifier.as_ref().ok_or_else(|| {
876 OAuthError::InvalidRequest("code_verifier is required (PKCE)".to_string())
877 })?;
878
879 if code_verifier.len() < self.config.min_code_verifier_length
881 || code_verifier.len() > self.config.max_code_verifier_length
882 {
883 return Err(OAuthError::InvalidRequest(format!(
884 "code_verifier must be between {} and {} characters",
885 self.config.min_code_verifier_length, self.config.max_code_verifier_length
886 )));
887 }
888
889 let auth_code = {
891 let mut state = self
892 .state
893 .write()
894 .map_err(|_| OAuthError::ServerError("failed to acquire write lock".to_string()))?;
895
896 state
898 .authorization_codes
899 .remove(code_value)
900 .ok_or_else(|| {
901 OAuthError::InvalidGrant(
902 "authorization code not found or already used".to_string(),
903 )
904 })?
905 };
906
907 if auth_code.is_expired() {
909 return Err(OAuthError::InvalidGrant(
910 "authorization code has expired".to_string(),
911 ));
912 }
913 if auth_code.client_id != request.client_id {
914 return Err(OAuthError::InvalidGrant("client_id mismatch".to_string()));
915 }
916 if auth_code.redirect_uri != *redirect_uri {
917 return Err(OAuthError::InvalidGrant(
918 "redirect_uri mismatch".to_string(),
919 ));
920 }
921
922 if !auth_code.validate_code_verifier(code_verifier) {
924 return Err(OAuthError::InvalidGrant(
925 "code_verifier validation failed".to_string(),
926 ));
927 }
928
929 let client = self.get_client(&request.client_id).ok_or_else(|| {
931 OAuthError::InvalidClient(format!("client '{}' not found", request.client_id))
932 })?;
933
934 if client.client_type == ClientType::Confidential {
935 if !client.authenticate(request.client_secret.as_deref()) {
936 return Err(OAuthError::InvalidClient(
937 "client authentication failed".to_string(),
938 ));
939 }
940 }
941
942 self.issue_tokens(
944 &auth_code.client_id,
945 &auth_code.scopes,
946 auth_code.subject.as_deref(),
947 )
948 }
949
950 fn token_refresh_token(&self, request: &TokenRequest) -> Result<TokenResponse, OAuthError> {
951 let refresh_token_value = request
952 .refresh_token
953 .as_ref()
954 .ok_or_else(|| OAuthError::InvalidRequest("refresh_token is required".to_string()))?;
955
956 let refresh_token = {
958 let state = self
959 .state
960 .read()
961 .map_err(|_| OAuthError::ServerError("failed to acquire read lock".to_string()))?;
962
963 if state.revoked_tokens.contains(refresh_token_value) {
965 return Err(OAuthError::InvalidGrant(
966 "refresh token has been revoked".to_string(),
967 ));
968 }
969
970 state
971 .refresh_tokens
972 .get(refresh_token_value)
973 .cloned()
974 .ok_or_else(|| OAuthError::InvalidGrant("refresh token not found".to_string()))?
975 };
976
977 if refresh_token.is_expired() {
978 return Err(OAuthError::InvalidGrant(
979 "refresh token has expired".to_string(),
980 ));
981 }
982 if refresh_token.client_id != request.client_id {
983 return Err(OAuthError::InvalidGrant("client_id mismatch".to_string()));
984 }
985
986 let client = self.get_client(&request.client_id).ok_or_else(|| {
988 OAuthError::InvalidClient(format!("client '{}' not found", request.client_id))
989 })?;
990
991 if client.client_type == ClientType::Confidential {
992 if !client.authenticate(request.client_secret.as_deref()) {
993 return Err(OAuthError::InvalidClient(
994 "client authentication failed".to_string(),
995 ));
996 }
997 }
998
999 let scopes = if let Some(requested) = &request.scopes {
1001 for scope in requested {
1003 if !refresh_token.scopes.contains(scope) {
1004 return Err(OAuthError::InvalidScope(format!(
1005 "scope '{}' was not in original grant",
1006 scope
1007 )));
1008 }
1009 }
1010 requested.clone()
1011 } else {
1012 refresh_token.scopes.clone()
1013 };
1014
1015 let now = Instant::now();
1017 let access_token_value = generate_token(self.config.token_entropy_bytes);
1018 let access_token = OAuthToken {
1019 token: access_token_value.clone(),
1020 token_type: TokenType::Bearer,
1021 client_id: request.client_id.clone(),
1022 scopes: scopes.clone(),
1023 issued_at: now,
1024 expires_at: now + self.config.access_token_lifetime,
1025 subject: refresh_token.subject.clone(),
1026 is_refresh_token: false,
1027 };
1028
1029 {
1031 let mut state = self
1032 .state
1033 .write()
1034 .map_err(|_| OAuthError::ServerError("failed to acquire write lock".to_string()))?;
1035 state
1036 .access_tokens
1037 .insert(access_token_value.clone(), access_token.clone());
1038 }
1039
1040 Ok(TokenResponse {
1041 access_token: access_token_value,
1042 token_type: access_token.token_type.as_str().to_string(),
1043 expires_in: access_token.expires_in_secs(),
1044 refresh_token: None, scope: if scopes.is_empty() {
1046 None
1047 } else {
1048 Some(scopes.join(" "))
1049 },
1050 })
1051 }
1052
1053 fn issue_tokens(
1054 &self,
1055 client_id: &str,
1056 scopes: &[String],
1057 subject: Option<&str>,
1058 ) -> Result<TokenResponse, OAuthError> {
1059 let now = Instant::now();
1060
1061 let access_token_value = generate_token(self.config.token_entropy_bytes);
1063 let access_token = OAuthToken {
1064 token: access_token_value.clone(),
1065 token_type: TokenType::Bearer,
1066 client_id: client_id.to_string(),
1067 scopes: scopes.to_vec(),
1068 issued_at: now,
1069 expires_at: now + self.config.access_token_lifetime,
1070 subject: subject.map(String::from),
1071 is_refresh_token: false,
1072 };
1073
1074 let refresh_token_value = generate_token(self.config.token_entropy_bytes);
1076 let refresh_token = OAuthToken {
1077 token: refresh_token_value.clone(),
1078 token_type: TokenType::Bearer,
1079 client_id: client_id.to_string(),
1080 scopes: scopes.to_vec(),
1081 issued_at: now,
1082 expires_at: now + self.config.refresh_token_lifetime,
1083 subject: subject.map(String::from),
1084 is_refresh_token: true,
1085 };
1086
1087 {
1089 let mut state = self
1090 .state
1091 .write()
1092 .map_err(|_| OAuthError::ServerError("failed to acquire write lock".to_string()))?;
1093 state
1094 .access_tokens
1095 .insert(access_token_value.clone(), access_token.clone());
1096 state
1097 .refresh_tokens
1098 .insert(refresh_token_value.clone(), refresh_token);
1099 }
1100
1101 Ok(TokenResponse {
1102 access_token: access_token_value,
1103 token_type: access_token.token_type.as_str().to_string(),
1104 expires_in: access_token.expires_in_secs(),
1105 refresh_token: Some(refresh_token_value),
1106 scope: if scopes.is_empty() {
1107 None
1108 } else {
1109 Some(scopes.join(" "))
1110 },
1111 })
1112 }
1113
1114 pub fn revoke(
1122 &self,
1123 token: &str,
1124 client_id: &str,
1125 client_secret: Option<&str>,
1126 ) -> Result<(), OAuthError> {
1127 let client = self.get_client(client_id).ok_or_else(|| {
1129 OAuthError::InvalidClient(format!("client '{}' not found", client_id))
1130 })?;
1131
1132 if client.client_type == ClientType::Confidential {
1133 if !client.authenticate(client_secret) {
1134 return Err(OAuthError::InvalidClient(
1135 "client authentication failed".to_string(),
1136 ));
1137 }
1138 }
1139
1140 let mut state = self
1141 .state
1142 .write()
1143 .map_err(|_| OAuthError::ServerError("failed to acquire write lock".to_string()))?;
1144
1145 let found_access = state.access_tokens.remove(token);
1147 let found_refresh = state.refresh_tokens.remove(token);
1148
1149 if let Some(ref t) = found_access {
1151 if t.client_id != client_id {
1152 return Ok(());
1154 }
1155 }
1156 if let Some(ref t) = found_refresh {
1157 if t.client_id != client_id {
1158 return Ok(());
1159 }
1160 }
1161
1162 if found_access.is_some() || found_refresh.is_some() {
1164 state.revoked_tokens.insert(token.to_string());
1165 }
1166
1167 Ok(())
1168 }
1169
1170 pub fn validate_access_token(&self, token: &str) -> Option<OAuthToken> {
1178 let state = self.state.read().ok()?;
1179
1180 if state.revoked_tokens.contains(token) {
1182 return None;
1183 }
1184
1185 let token_info = state.access_tokens.get(token)?;
1186
1187 if token_info.is_expired() {
1188 return None;
1189 }
1190
1191 Some(token_info.clone())
1192 }
1193
1194 #[must_use]
1200 pub fn token_verifier(self: &Arc<Self>) -> OAuthTokenVerifier {
1201 OAuthTokenVerifier {
1202 server: Arc::clone(self),
1203 }
1204 }
1205
1206 pub fn cleanup_expired(&self) {
1214 let Ok(mut state) = self.state.write() else {
1215 return;
1216 };
1217
1218 state.authorization_codes.retain(|_, c| !c.is_expired());
1220
1221 state.access_tokens.retain(|_, t| !t.is_expired());
1223
1224 state.refresh_tokens.retain(|_, t| !t.is_expired());
1226 }
1227
1228 #[must_use]
1230 pub fn stats(&self) -> OAuthServerStats {
1231 let state = self.state.read().unwrap();
1232 OAuthServerStats {
1233 clients: state.clients.len(),
1234 authorization_codes: state.authorization_codes.len(),
1235 access_tokens: state.access_tokens.len(),
1236 refresh_tokens: state.refresh_tokens.len(),
1237 revoked_tokens: state.revoked_tokens.len(),
1238 }
1239 }
1240}
1241
/// Snapshot of entity counts returned by `OAuthServer::stats`.
#[derive(Debug, Clone, Default)]
pub struct OAuthServerStats {
    /// Registered clients.
    pub clients: usize,
    /// Outstanding (unredeemed) authorization codes.
    pub authorization_codes: usize,
    /// Stored access tokens (may include expired ones not yet cleaned up).
    pub access_tokens: usize,
    /// Stored refresh tokens (may include expired ones not yet cleaned up).
    pub refresh_tokens: usize,
    /// Recorded revoked token values.
    pub revoked_tokens: usize,
}
1256
/// `TokenVerifier` adapter that checks bearer tokens against an `OAuthServer`.
///
/// Obtained from `OAuthServer::token_verifier`.
pub struct OAuthTokenVerifier {
    server: Arc<OAuthServer>,
}
1268
1269impl TokenVerifier for OAuthTokenVerifier {
1270 fn verify(
1271 &self,
1272 _ctx: &McpContext,
1273 _request: AuthRequest<'_>,
1274 token: &AccessToken,
1275 ) -> McpResult<AuthContext> {
1276 if !token.scheme.eq_ignore_ascii_case("Bearer") {
1278 return Err(McpError::new(
1279 McpErrorCode::ResourceForbidden,
1280 "unsupported auth scheme",
1281 ));
1282 }
1283
1284 let token_info = self
1286 .server
1287 .validate_access_token(&token.token)
1288 .ok_or_else(|| {
1289 McpError::new(McpErrorCode::ResourceForbidden, "invalid or expired token")
1290 })?;
1291
1292 Ok(AuthContext {
1293 subject: token_info.subject,
1294 scopes: token_info.scopes,
1295 token: Some(token.clone()),
1296 claims: Some(serde_json::json!({
1297 "client_id": token_info.client_id,
1298 "iss": self.server.config.issuer,
1299 "iat": token_info.issued_at.elapsed().as_secs(),
1300 })),
1301 })
1302 }
1303}
1304
1305fn generate_token(bytes: usize) -> String {
1311 use std::collections::hash_map::RandomState;
1312 use std::hash::{BuildHasher, Hasher};
1313
1314 let mut result = Vec::with_capacity(bytes * 2);
1316 let state = RandomState::new();
1317
1318 for i in 0..bytes {
1319 let mut hasher = state.build_hasher();
1320 hasher.write_usize(i);
1321 hasher.write_u128(
1322 SystemTime::now()
1323 .duration_since(UNIX_EPOCH)
1324 .unwrap_or_default()
1325 .as_nanos(),
1326 );
1327 let hash = hasher.finish();
1328 result.extend_from_slice(&hash.to_le_bytes()[..2]);
1329 }
1330
1331 base64url_encode(&result[..bytes])
1333}
1334
/// Encodes bytes as URL-safe base64 (`-` and `_` alphabet) with no `=` padding.
fn base64url_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

    let mut encoded = String::with_capacity((data.len() * 4).div_ceil(3));
    for group in data.chunks(3) {
        let first = group[0];
        let second = group.get(1).copied().unwrap_or(0);
        let third = group.get(2).copied().unwrap_or(0);
        let bits = (u32::from(first) << 16) | (u32::from(second) << 8) | u32::from(third);

        // A group of k input bytes yields k + 1 output characters (no padding).
        for position in 0..=group.len() {
            let sextet = (bits >> (18 - 6 * position)) & 0x3F;
            encoded.push(char::from(ALPHABET[sextet as usize]));
        }
    }
    encoded
}
1364
1365fn compute_s256_challenge(verifier: &str) -> String {
1367 let hash = simple_sha256(verifier.as_bytes());
1370 base64url_encode(&hash)
1371}
1372
/// Computes the SHA-256 digest (FIPS 180-4) of `data`.
///
/// This is a real, deterministic SHA-256. The previous hasher-based stand-in used a
/// freshly keyed `RandomState` on each call, so its output changed from call to call.
/// It also never matched the SHA-256 that PKCE clients compute, which made `S256`
/// verification impossible.
fn simple_sha256(data: &[u8]) -> [u8; 32] {
    const K: [u32; 64] = [
        0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4,
        0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe,
        0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f,
        0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
        0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc,
        0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
        0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116,
        0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
        0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7,
        0xc67178f2,
    ];

    let mut state: [u32; 8] = [
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
        0x5be0cd19,
    ];

    // Padding: a 1 bit, zeros until the length is 56 mod 64, then the 64-bit big-endian
    // message length in bits. (usize -> u64 is lossless on all supported targets.)
    let bit_len = (data.len() as u64).wrapping_mul(8);
    let mut message = Vec::with_capacity(data.len() + 72);
    message.extend_from_slice(data);
    message.push(0x80);
    while message.len() % 64 != 56 {
        message.push(0);
    }
    message.extend_from_slice(&bit_len.to_be_bytes());

    let mut w = [0u32; 64];
    for block in message.chunks_exact(64) {
        // Message schedule: 16 words from the block, 48 derived words.
        for (slot, word) in w.iter_mut().zip(block.chunks_exact(4)) {
            *slot = u32::from_be_bytes([word[0], word[1], word[2], word[3]]);
        }
        for t in 16..64 {
            let s0 = w[t - 15].rotate_right(7) ^ w[t - 15].rotate_right(18) ^ (w[t - 15] >> 3);
            let s1 = w[t - 2].rotate_right(17) ^ w[t - 2].rotate_right(19) ^ (w[t - 2] >> 10);
            w[t] = w[t - 16]
                .wrapping_add(s0)
                .wrapping_add(w[t - 7])
                .wrapping_add(s1);
        }

        // 64 compression rounds.
        let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = state;
        for (&k, &wt) in K.iter().zip(w.iter()) {
            let big_s1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25);
            let ch = (e & f) ^ (!e & g);
            let temp1 = h
                .wrapping_add(big_s1)
                .wrapping_add(ch)
                .wrapping_add(k)
                .wrapping_add(wt);
            let big_s0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22);
            let maj = (a & b) ^ (a & c) ^ (b & c);
            let temp2 = big_s0.wrapping_add(maj);

            h = g;
            g = f;
            f = e;
            e = d.wrapping_add(temp1);
            d = c;
            c = b;
            b = a;
            a = temp1.wrapping_add(temp2);
        }

        for (word, add) in state.iter_mut().zip([a, b, c, d, e, f, g, h]) {
            *word = word.wrapping_add(add);
        }
    }

    let mut digest = [0u8; 32];
    for (out, word) in digest.chunks_exact_mut(4).zip(state) {
        out.copy_from_slice(&word.to_be_bytes());
    }
    digest
}
1394
/// Percent-encodes every byte of `s` except RFC 3986 unreserved characters
/// (`A-Z a-z 0-9 - _ . ~`).
///
/// Escapes use uppercase hex digits.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len() * 3);
    for byte in s.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
1411
/// Compares two strings without short-circuiting on the first differing byte.
///
/// Unequal lengths return `false` immediately, so only the length is observable
/// through timing.
fn constant_time_eq(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    if lhs.len() != rhs.len() {
        return false;
    }
    let diff = lhs
        .iter()
        .zip(rhs)
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    diff == 0
}
1424
/// Returns `true` if `uri` is an `http://` URI whose host is exactly `localhost`,
/// `127.0.0.1` or `[::1]`. Any port is allowed.
///
/// The host is parsed rather than prefix-matched, so lookalikes are rejected. Examples:
/// `http://localhost.evil.com/` and `http://localhost@evil.com/`.
fn is_localhost_redirect(uri: &str) -> bool {
    let Some(rest) = uri.strip_prefix("http://") else {
        return false;
    };

    // The authority ends at the first path, query or fragment delimiter.
    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());
    let authority = &rest[..authority_end];

    let host = if authority.starts_with('[') {
        // Bracketed IPv6 literal: anything after `]` must be empty or a `:port`.
        let Some(close) = authority.find(']') else {
            return false;
        };
        let (host, tail) = authority.split_at(close + 1);
        if !(tail.is_empty() || tail.starts_with(':')) {
            return false;
        }
        host
    } else {
        authority.split(':').next().unwrap_or(authority)
    };

    matches!(host, "localhost" | "127.0.0.1" | "[::1]")
}
1431
1432fn localhost_match(a: &str, b: &str) -> bool {
1434 fn extract_parts(uri: &str) -> Option<(String, String)> {
1436 let after_scheme = uri.strip_prefix("http://")?;
1437 let path_start = after_scheme.find('/').unwrap_or(after_scheme.len());
1439 let host_port = &after_scheme[..path_start];
1440 let path = &after_scheme[path_start..];
1441
1442 let host = host_port.rsplit_once(':').map_or(host_port, |(h, _)| h);
1444 Some((host.to_string(), path.to_string()))
1445 }
1446
1447 match (extract_parts(a), extract_parts(b)) {
1448 (Some((host_a, path_a)), Some((host_b, path_b))) => {
1449 normalize_localhost(&host_a) == normalize_localhost(&host_b) && path_a == path_b
1450 }
1451 _ => false,
1452 }
1453}
1454
/// Collapses the loopback aliases (`localhost`, `127.0.0.1`, `[::1]`) to `"localhost"`.
///
/// Every other host becomes `"other"`.
fn normalize_localhost(host: &str) -> &'static str {
    const LOOPBACK_HOSTS: [&str; 3] = ["localhost", "127.0.0.1", "[::1]"];
    if LOOPBACK_HOSTS.contains(&host) {
        "localhost"
    } else {
        "other"
    }
}
1462
#[cfg(test)]
mod tests {
    use super::*;

    /// Client identifier used by every fixture client in this module.
    const CLIENT_ID: &str = "test-client";

    /// Loopback redirect URI registered on the fixture clients.
    const REDIRECT_URI: &str = "http://localhost:3000/callback";

    /// Writes a bearer access token straight into the server's token table,
    /// skipping the authorization-code exchange.
    ///
    /// The seeded token belongs to `CLIENT_ID`, carries the `read` scope and
    /// subject `user123`, and expires one hour after it is inserted.
    fn seed_access_token(server: &OAuthServer, value: &str) {
        let mut state = server.state.write().unwrap();
        let issued_at = Instant::now();
        let token = OAuthToken {
            token: value.to_string(),
            token_type: TokenType::Bearer,
            client_id: CLIENT_ID.to_string(),
            scopes: vec!["read".to_string()],
            issued_at,
            expires_at: issued_at + Duration::from_secs(3600),
            subject: Some("user123".to_string()),
            is_refresh_token: false,
        };
        state.access_tokens.insert(value.to_string(), token);
    }

    #[test]
    fn test_client_builder() {
        let built = OAuthClient::builder(CLIENT_ID)
            .redirect_uri(REDIRECT_URI)
            .scope("read")
            .scope("write")
            .name("Test Client")
            .build()
            .unwrap();

        assert_eq!(built.client_id, CLIENT_ID);
        // No secret was supplied, so the builder yields a public client.
        assert_eq!(built.client_type, ClientType::Public);
        assert_eq!(built.redirect_uris.len(), 1);
        for scope in ["read", "write"] {
            assert!(built.allowed_scopes.contains(scope));
        }
    }

    #[test]
    fn test_confidential_client() {
        let confidential = OAuthClient::builder(CLIENT_ID)
            .secret("super-secret")
            .redirect_uri(REDIRECT_URI)
            .build()
            .unwrap();

        assert_eq!(confidential.client_type, ClientType::Confidential);
        assert!(confidential.authenticate(Some("super-secret")));
        assert!(!confidential.authenticate(Some("wrong-secret")));
        assert!(!confidential.authenticate(None));
    }

    #[test]
    fn test_redirect_uri_validation() {
        let client = OAuthClient::builder(CLIENT_ID)
            .redirect_uri(REDIRECT_URI)
            .redirect_uri("https://example.com/oauth/callback")
            .build()
            .unwrap();

        let accepted = [
            // Exact matches against the registered URIs.
            REDIRECT_URI,
            "https://example.com/oauth/callback",
            // Loopback redirects match on any port, and 127.0.0.1 stands in for localhost.
            "http://localhost:8080/callback",
            "http://127.0.0.1:9000/callback",
        ];
        for uri in accepted {
            assert!(client.validate_redirect_uri(uri), "expected {uri} to be accepted");
        }

        let rejected = ["http://localhost:3000/other", "https://evil.com/callback"];
        for uri in rejected {
            assert!(!client.validate_redirect_uri(uri), "expected {uri} to be rejected");
        }
    }

    #[test]
    fn test_scope_validation() {
        let scoped = OAuthClient::builder(CLIENT_ID)
            .redirect_uri(REDIRECT_URI)
            .scope("read")
            .scope("write")
            .build()
            .unwrap();

        assert!(scoped.validate_scopes(&["read".to_string()]));
        assert!(scoped.validate_scopes(&["read".to_string(), "write".to_string()]));
        assert!(!scoped.validate_scopes(&["admin".to_string()]));
    }

    #[test]
    fn test_oauth_server_client_registration() {
        let server = OAuthServer::with_defaults();
        let make_client = || {
            OAuthClient::builder(CLIENT_ID)
                .redirect_uri(REDIRECT_URI)
                .build()
                .unwrap()
        };

        server.register_client(make_client()).unwrap();
        // A second registration under the same client_id must be refused.
        assert!(server.register_client(make_client()).is_err());

        assert!(server.get_client(CLIENT_ID).is_some());
        assert!(server.get_client("nonexistent").is_none());
    }

    #[test]
    fn test_authorization_flow() {
        let server = OAuthServer::with_defaults();
        let client = OAuthClient::builder(CLIENT_ID)
            .redirect_uri(REDIRECT_URI)
            .scope("read")
            .build()
            .unwrap();
        server.register_client(client).unwrap();

        let request = AuthorizationRequest {
            response_type: "code".to_string(),
            client_id: CLIENT_ID.to_string(),
            redirect_uri: REDIRECT_URI.to_string(),
            scopes: vec!["read".to_string()],
            state: Some("xyz".to_string()),
            // S256 challenge taken from the RFC 7636 Appendix B example.
            code_challenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM".to_string(),
            code_challenge_method: CodeChallengeMethod::S256,
        };

        let (code, redirect) = server
            .authorize(&request, Some("user123".to_string()))
            .unwrap();

        assert!(!code.is_empty());
        for fragment in ["code=", "state=xyz"] {
            assert!(redirect.contains(fragment));
        }
    }

    #[test]
    fn test_pkce_required() {
        let server = OAuthServer::with_defaults();
        let client = OAuthClient::builder(CLIENT_ID)
            .redirect_uri(REDIRECT_URI)
            .build()
            .unwrap();
        server.register_client(client).unwrap();

        // An empty code_challenge must be refused.
        let request = AuthorizationRequest {
            response_type: "code".to_string(),
            client_id: CLIENT_ID.to_string(),
            redirect_uri: REDIRECT_URI.to_string(),
            scopes: vec![],
            state: None,
            code_challenge: String::new(),
            code_challenge_method: CodeChallengeMethod::S256,
        };

        let outcome = server.authorize(&request, None);
        assert!(matches!(outcome, Err(OAuthError::InvalidRequest(_))));
    }

    #[test]
    fn test_token_generation() {
        let first = generate_token(32);
        let second = generate_token(32);

        assert_ne!(first, second);

        let url_safe = |c: char| c.is_ascii_alphanumeric() || matches!(c, '-' | '_');
        assert!(first.chars().all(url_safe));
    }

    #[test]
    fn test_base64url_encode() {
        // RFC 4648 test vectors, without padding.
        let vectors = [
            ("", ""),
            ("f", "Zg"),
            ("fo", "Zm8"),
            ("foo", "Zm9v"),
            ("foob", "Zm9vYg"),
            ("fooba", "Zm9vYmE"),
            ("foobar", "Zm9vYmFy"),
        ];
        for (input, expected) in vectors {
            assert_eq!(base64url_encode(input.as_bytes()), expected);
        }
    }

    #[test]
    fn test_url_encode() {
        let cases = [
            ("hello", "hello"),
            ("hello world", "hello%20world"),
            ("a=b&c=d", "a%3Db%26c%3Dd"),
        ];
        for (raw, encoded) in cases {
            assert_eq!(url_encode(raw), encoded);
        }
    }

    #[test]
    fn test_constant_time_eq() {
        assert!(constant_time_eq("hello", "hello"));
        for other in ["world", "hell"] {
            assert!(!constant_time_eq("hello", other));
        }
    }

    #[test]
    fn test_localhost_match() {
        let cases = [
            ("http://localhost:3000/callback", "http://localhost:8080/callback", true),
            ("http://127.0.0.1:3000/callback", "http://localhost:8080/callback", true),
            ("http://localhost:3000/callback", "http://localhost:3000/other", false),
        ];
        for (registered, candidate, expected) in cases {
            assert_eq!(
                localhost_match(registered, candidate),
                expected,
                "{registered} vs {candidate}"
            );
        }
    }

    #[test]
    fn test_oauth_server_stats() {
        let server = OAuthServer::with_defaults();

        let initial = server.stats();
        assert_eq!(initial.clients, 0);
        assert_eq!(initial.access_tokens, 0);

        let client = OAuthClient::builder(CLIENT_ID)
            .redirect_uri(REDIRECT_URI)
            .build()
            .unwrap();
        server.register_client(client).unwrap();

        assert_eq!(server.stats().clients, 1);
    }

    #[test]
    fn test_code_challenge_method_parse() {
        let cases = [
            ("plain", Some(CodeChallengeMethod::Plain)),
            ("S256", Some(CodeChallengeMethod::S256)),
            ("unknown", None),
        ];
        for (raw, expected) in cases {
            assert_eq!(CodeChallengeMethod::parse(raw), expected);
        }
    }

    #[test]
    fn test_oauth_error_display() {
        let error = OAuthError::InvalidRequest("missing parameter".to_string());

        assert_eq!(error.error_code(), "invalid_request");
        assert_eq!(error.description(), "missing parameter");
        assert_eq!(error.to_string(), "invalid_request: missing parameter");
    }

    #[test]
    fn test_token_revocation() {
        const TOKEN: &str = "test-access-token";

        let server = Arc::new(OAuthServer::with_defaults());
        let client = OAuthClient::builder(CLIENT_ID)
            .redirect_uri(REDIRECT_URI)
            .scope("read")
            .build()
            .unwrap();
        server.register_client(client).unwrap();

        seed_access_token(&server, TOKEN);
        assert!(server.validate_access_token(TOKEN).is_some());

        server.revoke(TOKEN, CLIENT_ID, None).unwrap();
        assert!(server.validate_access_token(TOKEN).is_none());
    }

    #[test]
    fn test_client_unregistration() {
        let server = OAuthServer::with_defaults();
        let client = OAuthClient::builder(CLIENT_ID)
            .redirect_uri(REDIRECT_URI)
            .build()
            .unwrap();
        server.register_client(client).unwrap();
        assert!(server.get_client(CLIENT_ID).is_some());

        server.unregister_client(CLIENT_ID).unwrap();
        assert!(server.get_client(CLIENT_ID).is_none());

        // Removing an already-removed client is an error.
        assert!(server.unregister_client(CLIENT_ID).is_err());
    }

    #[test]
    fn test_token_verifier() {
        let server = Arc::new(OAuthServer::with_defaults());
        let client = OAuthClient::builder(CLIENT_ID)
            .redirect_uri(REDIRECT_URI)
            .scope("read")
            .build()
            .unwrap();
        server.register_client(client).unwrap();
        seed_access_token(&server, "valid-token");

        let verifier = server.token_verifier();
        let mcp_ctx = McpContext::new(asupersync::Cx::for_testing(), 1);
        let auth_request = AuthRequest {
            method: "test",
            params: None,
            request_id: 1,
        };
        let bearer = |token: &str| AccessToken {
            scheme: "Bearer".to_string(),
            token: token.to_string(),
        };

        // A live bearer token resolves to its subject and scopes.
        let valid = bearer("valid-token");
        let outcome = verifier.verify(&mcp_ctx, auth_request, &valid);
        assert!(outcome.is_ok());
        let auth = outcome.unwrap();
        assert_eq!(auth.subject, Some("user123".to_string()));
        assert_eq!(auth.scopes, vec!["read".to_string()]);

        // An unknown token is rejected.
        let unknown = bearer("invalid-token");
        assert!(verifier.verify(&mcp_ctx, auth_request, &unknown).is_err());

        // A known token presented under a non-Bearer scheme is rejected.
        let wrong_scheme = AccessToken {
            scheme: "Basic".to_string(),
            token: "valid-token".to_string(),
        };
        assert!(verifier.verify(&mcp_ctx, auth_request, &wrong_scheme).is_err());
    }
}