1use std::collections::{HashMap, HashSet};
54use std::sync::{Arc, RwLock};
55use std::time::{Duration, Instant, SystemTime};
56
57use fastmcp_core::{AccessToken, AuthContext, McpContext, McpError, McpErrorCode, McpResult};
58
59use crate::auth::{AuthRequest, TokenVerifier};
60
/// Tunable settings for an [`OAuthServer`].
#[derive(Clone, Debug)]
pub struct OAuthServerConfig {
    /// Issuer identifier reported as the `iss` claim of verified tokens.
    pub issuer: String,
    /// Validity period of an issued access token.
    pub access_token_lifetime: Duration,
    /// Validity period of an issued refresh token.
    pub refresh_token_lifetime: Duration,
    /// How long an authorization code may wait before it is exchanged.
    pub authorization_code_lifetime: Duration,
    /// Whether clients registered without a secret are accepted.
    pub allow_public_clients: bool,
    /// Minimum accepted PKCE `code_verifier` length (checked in bytes).
    pub min_code_verifier_length: usize,
    /// Maximum accepted PKCE `code_verifier` length (checked in bytes).
    pub max_code_verifier_length: usize,
    /// Number of random bytes behind every generated code or token.
    pub token_entropy_bytes: usize,
}

impl Default for OAuthServerConfig {
    /// One-hour access tokens, 30-day refresh tokens, 10-minute authorization
    /// codes, public clients allowed, RFC 7636 verifier bounds (43..=128) and
    /// 32 bytes of token entropy.
    fn default() -> Self {
        const MINUTE: u64 = 60;
        const HOUR: u64 = 60 * MINUTE;
        const DAY: u64 = 24 * HOUR;

        Self {
            issuer: String::from("fastmcp"),
            access_token_lifetime: Duration::from_secs(HOUR),
            refresh_token_lifetime: Duration::from_secs(30 * DAY),
            authorization_code_lifetime: Duration::from_secs(10 * MINUTE),
            allow_public_clients: true,
            min_code_verifier_length: 43,
            max_code_verifier_length: 128,
            token_entropy_bytes: 32,
        }
    }
}
100
/// Whether a client holds a secret.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ClientType {
    /// Registered with a secret; it must authenticate when exchanging or revoking tokens.
    Confidential,
    /// Registered without a secret.
    Public,
}
113
114#[derive(Debug, Clone)]
116pub struct OAuthClient {
117 pub client_id: String,
119 pub client_secret: Option<String>,
121 pub client_type: ClientType,
123 pub redirect_uris: Vec<String>,
125 pub allowed_scopes: HashSet<String>,
127 pub name: Option<String>,
129 pub description: Option<String>,
131 pub registered_at: SystemTime,
133}
134
135impl OAuthClient {
136 #[must_use]
138 pub fn builder(client_id: impl Into<String>) -> OAuthClientBuilder {
139 OAuthClientBuilder::new(client_id)
140 }
141
142 #[must_use]
144 pub fn validate_redirect_uri(&self, uri: &str) -> bool {
145 if self.redirect_uris.contains(&uri.to_string()) {
147 return true;
148 }
149
150 for allowed in &self.redirect_uris {
152 if is_localhost_redirect(allowed) && is_localhost_redirect(uri) {
153 if localhost_match(allowed, uri) {
155 return true;
156 }
157 }
158 }
159
160 false
161 }
162
163 #[must_use]
165 pub fn validate_scopes(&self, scopes: &[String]) -> bool {
166 scopes.iter().all(|s| self.allowed_scopes.contains(s))
167 }
168
169 #[must_use]
171 pub fn authenticate(&self, secret: Option<&str>) -> bool {
172 match (&self.client_secret, secret) {
173 (Some(expected), Some(provided)) => constant_time_eq(expected, provided),
174 (None, None) => self.client_type == ClientType::Public,
175 _ => false,
176 }
177 }
178}
179
/// Builder for [`OAuthClient`], obtained from `OAuthClient::builder`.
#[derive(Debug)]
pub struct OAuthClientBuilder {
    /// Identifier of the client being built.
    client_id: String,
    /// Secret to attach; when present the built client is confidential.
    client_credential: Option<String>,
    /// Redirect URIs collected so far.
    redirect_uris: Vec<String>,
    /// Scopes the client will be allowed to request.
    allowed_scopes: HashSet<String>,
    /// Optional display name.
    name: Option<String>,
    /// Optional description.
    description: Option<String>,
}
190
191impl OAuthClientBuilder {
192 fn new(client_id: impl Into<String>) -> Self {
194 Self {
195 client_id: client_id.into(),
196 client_credential: None,
197 redirect_uris: Vec::new(),
198 allowed_scopes: HashSet::new(),
199 name: None,
200 description: None,
201 }
202 }
203
204 #[must_use]
206 pub fn secret(mut self, credential: impl Into<String>) -> Self {
207 self.client_credential = Some(credential.into());
208 self
209 }
210
211 #[must_use]
213 pub fn redirect_uri(mut self, uri: impl Into<String>) -> Self {
214 self.redirect_uris.push(uri.into());
215 self
216 }
217
218 #[must_use]
220 pub fn redirect_uris<I, S>(mut self, uris: I) -> Self
221 where
222 I: IntoIterator<Item = S>,
223 S: Into<String>,
224 {
225 self.redirect_uris.extend(uris.into_iter().map(Into::into));
226 self
227 }
228
229 #[must_use]
231 pub fn scope(mut self, scope: impl Into<String>) -> Self {
232 self.allowed_scopes.insert(scope.into());
233 self
234 }
235
236 #[must_use]
238 pub fn scopes<I, S>(mut self, scopes: I) -> Self
239 where
240 I: IntoIterator<Item = S>,
241 S: Into<String>,
242 {
243 self.allowed_scopes
244 .extend(scopes.into_iter().map(Into::into));
245 self
246 }
247
248 #[must_use]
250 pub fn name(mut self, name: impl Into<String>) -> Self {
251 self.name = Some(name.into());
252 self
253 }
254
255 #[must_use]
257 pub fn description(mut self, description: impl Into<String>) -> Self {
258 self.description = Some(description.into());
259 self
260 }
261
262 pub fn build(self) -> Result<OAuthClient, OAuthError> {
270 if self.client_id.is_empty() {
271 return Err(OAuthError::InvalidRequest(
272 "client_id cannot be empty".to_string(),
273 ));
274 }
275
276 if self.redirect_uris.is_empty() {
277 return Err(OAuthError::InvalidRequest(
278 "at least one redirect_uri is required".to_string(),
279 ));
280 }
281
282 let client_type = if self.client_credential.is_some() {
283 ClientType::Confidential
284 } else {
285 ClientType::Public
286 };
287
288 Ok(OAuthClient {
289 client_id: self.client_id,
290 client_secret: self.client_credential,
291 client_type,
292 redirect_uris: self.redirect_uris,
293 allowed_scopes: self.allowed_scopes,
294 name: self.name,
295 description: self.description,
296 registered_at: SystemTime::now(),
297 })
298 }
299}
300
/// PKCE `code_challenge_method` (RFC 7636).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CodeChallengeMethod {
    /// The challenge is the verifier itself.
    Plain,
    /// The challenge is the base64url-encoded SHA-256 of the verifier.
    S256,
}

impl CodeChallengeMethod {
    /// Parses the wire name (`"plain"` or `"S256"`, case-sensitive).
    ///
    /// Returns `None` for any other string.
    #[must_use]
    pub fn parse(s: &str) -> Option<Self> {
        [Self::Plain, Self::S256]
            .iter()
            .copied()
            .find(|method| method.as_str() == s)
    }

    /// Returns the wire name of this method.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        if let Self::S256 = self {
            "S256"
        } else {
            "plain"
        }
    }
}
334
335#[derive(Debug, Clone)]
337pub struct AuthorizationCode {
338 pub code: String,
340 pub client_id: String,
342 pub redirect_uri: String,
344 pub scopes: Vec<String>,
346 pub code_challenge: String,
348 pub code_challenge_method: CodeChallengeMethod,
350 pub issued_at: Instant,
352 pub expires_at: Instant,
354 pub subject: Option<String>,
356 pub state: Option<String>,
358}
359
360impl AuthorizationCode {
361 #[must_use]
363 pub fn is_expired(&self) -> bool {
364 Instant::now() >= self.expires_at
365 }
366
367 #[must_use]
369 pub fn validate_code_verifier(&self, verifier: &str) -> bool {
370 match self.code_challenge_method {
371 CodeChallengeMethod::Plain => constant_time_eq(&self.code_challenge, verifier),
372 CodeChallengeMethod::S256 => {
373 let computed = compute_s256_challenge(verifier);
374 constant_time_eq(&self.code_challenge, &computed)
375 }
376 }
377 }
378}
379
/// How an issued token is presented by the client.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TokenType {
    /// Bearer token (RFC 6750).
    Bearer,
}

impl TokenType {
    /// Returns the `token_type` value reported in token responses.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            TokenType::Bearer => "bearer",
        }
    }
}

/// An issued access or refresh token and the grant it represents.
#[derive(Clone, Debug)]
pub struct OAuthToken {
    /// The opaque token value.
    pub token: String,
    /// Presentation type of the token.
    pub token_type: TokenType,
    /// Client the token was issued to.
    pub client_id: String,
    /// Scopes granted to the token.
    pub scopes: Vec<String>,
    /// Monotonic time of issuance.
    pub issued_at: Instant,
    /// Monotonic time after which the token is expired.
    pub expires_at: Instant,
    /// Resource owner, if any.
    pub subject: Option<String>,
    /// `true` for refresh tokens, `false` for access tokens.
    pub is_refresh_token: bool,
}

impl OAuthToken {
    /// Returns `true` once `expires_at` has been reached.
    #[must_use]
    pub fn is_expired(&self) -> bool {
        self.expires_at <= Instant::now()
    }

    /// Whole seconds remaining until expiry (0 once expired).
    #[must_use]
    pub fn expires_in_secs(&self) -> u64 {
        let remaining = self.expires_at.saturating_duration_since(Instant::now());
        remaining.as_secs()
    }
}
437
438#[derive(Debug, Clone, serde::Serialize)]
440pub struct TokenResponse {
441 pub access_token: String,
443 pub token_type: String,
445 pub expires_in: u64,
447 #[serde(skip_serializing_if = "Option::is_none")]
449 pub refresh_token: Option<String>,
450 #[serde(skip_serializing_if = "Option::is_none")]
452 pub scope: Option<String>,
453}
454
455#[derive(Debug, Clone)]
461pub struct AuthorizationRequest {
462 pub response_type: String,
464 pub client_id: String,
466 pub redirect_uri: String,
468 pub scopes: Vec<String>,
470 pub state: Option<String>,
472 pub code_challenge: String,
474 pub code_challenge_method: CodeChallengeMethod,
476}
477
/// Parameters of a token-endpoint request.
#[derive(Clone, Debug)]
pub struct TokenRequest {
    /// `"authorization_code"` or `"refresh_token"`.
    pub grant_type: String,
    /// Authorization code (authorization_code grant).
    pub code: Option<String>,
    /// Redirect URI used when the code was requested (authorization_code grant).
    pub redirect_uri: Option<String>,
    /// Requesting client.
    pub client_id: String,
    /// Client secret, for confidential clients.
    pub client_secret: Option<String>,
    /// PKCE verifier (authorization_code grant).
    pub code_verifier: Option<String>,
    /// Refresh token (refresh_token grant).
    pub refresh_token: Option<String>,
    /// Optional narrowed scope set (refresh_token grant).
    pub scopes: Option<Vec<String>>,
}
498
/// OAuth 2.0 protocol error. Each variant carries a human-readable detail.
#[derive(Clone, Debug)]
pub enum OAuthError {
    /// `invalid_request`: missing or malformed parameter.
    InvalidRequest(String),
    /// `invalid_client`: unknown client or failed client authentication.
    InvalidClient(String),
    /// `invalid_grant`: bad, expired, revoked or mismatched code/token.
    InvalidGrant(String),
    /// `unauthorized_client`.
    UnauthorizedClient(String),
    /// `unsupported_grant_type`.
    UnsupportedGrantType(String),
    /// `invalid_scope`.
    InvalidScope(String),
    /// `server_error`: internal failure (e.g. poisoned lock, RNG failure).
    ServerError(String),
    /// `temporarily_unavailable`.
    TemporarilyUnavailable(String),
    /// `access_denied`.
    AccessDenied(String),
    /// `unsupported_response_type`.
    UnsupportedResponseType(String),
}

impl OAuthError {
    /// Returns the OAuth `error` code string for this error.
    #[must_use]
    pub fn error_code(&self) -> &'static str {
        match self {
            OAuthError::InvalidRequest(_) => "invalid_request",
            OAuthError::InvalidClient(_) => "invalid_client",
            OAuthError::InvalidGrant(_) => "invalid_grant",
            OAuthError::UnauthorizedClient(_) => "unauthorized_client",
            OAuthError::UnsupportedGrantType(_) => "unsupported_grant_type",
            OAuthError::InvalidScope(_) => "invalid_scope",
            OAuthError::ServerError(_) => "server_error",
            OAuthError::TemporarilyUnavailable(_) => "temporarily_unavailable",
            OAuthError::AccessDenied(_) => "access_denied",
            OAuthError::UnsupportedResponseType(_) => "unsupported_response_type",
        }
    }

    /// Returns the human-readable detail carried by the error.
    #[must_use]
    pub fn description(&self) -> &str {
        // Every variant wraps a single String, so one irrefutable or-pattern covers them all.
        let (OAuthError::InvalidRequest(detail)
        | OAuthError::InvalidClient(detail)
        | OAuthError::InvalidGrant(detail)
        | OAuthError::UnauthorizedClient(detail)
        | OAuthError::UnsupportedGrantType(detail)
        | OAuthError::InvalidScope(detail)
        | OAuthError::ServerError(detail)
        | OAuthError::TemporarilyUnavailable(detail)
        | OAuthError::AccessDenied(detail)
        | OAuthError::UnsupportedResponseType(detail)) = self;
        detail
    }
}

impl std::fmt::Display for OAuthError {
    /// Formats as `"<error_code>: <description>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.error_code())?;
        f.write_str(": ")?;
        f.write_str(self.description())
    }
}

impl std::error::Error for OAuthError {}
571
572impl From<OAuthError> for McpError {
573 fn from(err: OAuthError) -> Self {
574 match &err {
575 OAuthError::InvalidClient(_) | OAuthError::UnauthorizedClient(_) => {
576 McpError::new(McpErrorCode::ResourceForbidden, err.to_string())
577 }
578 OAuthError::AccessDenied(_) => {
579 McpError::new(McpErrorCode::ResourceForbidden, err.to_string())
580 }
581 _ => McpError::new(McpErrorCode::InvalidRequest, err.to_string()),
582 }
583 }
584}
585
586pub(crate) struct OAuthServerState {
592 pub(crate) clients: HashMap<String, OAuthClient>,
594 pub(crate) authorization_codes: HashMap<String, AuthorizationCode>,
596 pub(crate) access_tokens: HashMap<String, OAuthToken>,
598 pub(crate) refresh_tokens: HashMap<String, OAuthToken>,
600 pub(crate) revoked_tokens: HashSet<String>,
602}
603
604impl OAuthServerState {
605 fn new() -> Self {
606 Self {
607 clients: HashMap::new(),
608 authorization_codes: HashMap::new(),
609 access_tokens: HashMap::new(),
610 refresh_tokens: HashMap::new(),
611 revoked_tokens: HashSet::new(),
612 }
613 }
614}
615
616pub struct OAuthServer {
621 config: OAuthServerConfig,
622 pub(crate) state: RwLock<OAuthServerState>,
623}
624
625impl OAuthServer {
626 #[must_use]
628 pub fn new(config: OAuthServerConfig) -> Self {
629 Self {
630 config,
631 state: RwLock::new(OAuthServerState::new()),
632 }
633 }
634
635 #[must_use]
637 pub fn with_defaults() -> Self {
638 Self::new(OAuthServerConfig::default())
639 }
640
641 #[must_use]
643 pub fn config(&self) -> &OAuthServerConfig {
644 &self.config
645 }
646
647 pub fn register_client(&self, client: OAuthClient) -> Result<(), OAuthError> {
659 if client.client_type == ClientType::Public && !self.config.allow_public_clients {
660 return Err(OAuthError::InvalidClient(
661 "public clients are not allowed".to_string(),
662 ));
663 }
664
665 let mut state = self
666 .state
667 .write()
668 .map_err(|_| OAuthError::ServerError("failed to acquire write lock".to_string()))?;
669
670 if state.clients.contains_key(&client.client_id) {
671 return Err(OAuthError::InvalidClient(format!(
672 "client '{}' already exists",
673 client.client_id
674 )));
675 }
676
677 state.clients.insert(client.client_id.clone(), client);
678 Ok(())
679 }
680
681 pub fn unregister_client(&self, client_id: &str) -> Result<(), OAuthError> {
685 let mut state = self
686 .state
687 .write()
688 .map_err(|_| OAuthError::ServerError("failed to acquire write lock".to_string()))?;
689
690 if state.clients.remove(client_id).is_none() {
691 return Err(OAuthError::InvalidClient(format!(
692 "client '{}' not found",
693 client_id
694 )));
695 }
696
697 let access_tokens: Vec<_> = state
699 .access_tokens
700 .iter()
701 .filter(|(_, t)| t.client_id == client_id)
702 .map(|(k, _)| k.clone())
703 .collect();
704 for token in access_tokens {
705 state.access_tokens.remove(&token);
706 state.revoked_tokens.insert(token);
707 }
708
709 let refresh_tokens: Vec<_> = state
710 .refresh_tokens
711 .iter()
712 .filter(|(_, t)| t.client_id == client_id)
713 .map(|(k, _)| k.clone())
714 .collect();
715 for token in refresh_tokens {
716 state.refresh_tokens.remove(&token);
717 state.revoked_tokens.insert(token);
718 }
719
720 let codes: Vec<_> = state
722 .authorization_codes
723 .iter()
724 .filter(|(_, c)| c.client_id == client_id)
725 .map(|(k, _)| k.clone())
726 .collect();
727 for code in codes {
728 state.authorization_codes.remove(&code);
729 }
730
731 Ok(())
732 }
733
734 #[must_use]
736 pub fn get_client(&self, client_id: &str) -> Option<OAuthClient> {
737 self.state
738 .read()
739 .ok()
740 .and_then(|s| s.clients.get(client_id).cloned())
741 }
742
743 #[must_use]
745 pub fn list_clients(&self) -> Vec<OAuthClient> {
746 self.state
747 .read()
748 .map(|s| s.clients.values().cloned().collect())
749 .unwrap_or_default()
750 }
751
752 pub fn authorize(
770 &self,
771 request: &AuthorizationRequest,
772 subject: Option<String>,
773 ) -> Result<(String, String), OAuthError> {
774 if request.response_type != "code" {
776 return Err(OAuthError::UnsupportedResponseType(
777 "only 'code' response_type is supported".to_string(),
778 ));
779 }
780
781 let client = self.get_client(&request.client_id).ok_or_else(|| {
783 OAuthError::InvalidClient(format!("client '{}' not found", request.client_id))
784 })?;
785
786 if !client.validate_redirect_uri(&request.redirect_uri) {
788 return Err(OAuthError::InvalidRequest(
789 "invalid redirect_uri".to_string(),
790 ));
791 }
792
793 if !client.validate_scopes(&request.scopes) {
795 return Err(OAuthError::InvalidScope(
796 "requested scope not allowed".to_string(),
797 ));
798 }
799
800 if request.code_challenge.is_empty() {
802 return Err(OAuthError::InvalidRequest(
803 "code_challenge is required (PKCE)".to_string(),
804 ));
805 }
806
807 let code_value = generate_token(self.config.token_entropy_bytes)?;
809 let now = Instant::now();
810 let code = AuthorizationCode {
811 code: code_value.clone(),
812 client_id: request.client_id.clone(),
813 redirect_uri: request.redirect_uri.clone(),
814 scopes: request.scopes.clone(),
815 code_challenge: request.code_challenge.clone(),
816 code_challenge_method: request.code_challenge_method,
817 issued_at: now,
818 expires_at: now + self.config.authorization_code_lifetime,
819 subject,
820 state: request.state.clone(),
821 };
822
823 {
825 let mut state = self
826 .state
827 .write()
828 .map_err(|_| OAuthError::ServerError("failed to acquire write lock".to_string()))?;
829 state.authorization_codes.insert(code_value.clone(), code);
830 }
831
832 let mut redirect = request.redirect_uri.clone();
834 let separator = if redirect.contains('?') { '&' } else { '?' };
835 redirect.push(separator);
836 redirect.push_str("code=");
837 redirect.push_str(&url_encode(&code_value));
838 if let Some(state) = &request.state {
839 redirect.push_str("&state=");
840 redirect.push_str(&url_encode(state));
841 }
842
843 Ok((code_value, redirect))
844 }
845
846 pub fn token(&self, request: &TokenRequest) -> Result<TokenResponse, OAuthError> {
852 match request.grant_type.as_str() {
853 "authorization_code" => self.token_authorization_code(request),
854 "refresh_token" => self.token_refresh_token(request),
855 other => Err(OAuthError::UnsupportedGrantType(format!(
856 "grant_type '{}' is not supported",
857 other
858 ))),
859 }
860 }
861
862 fn token_authorization_code(
863 &self,
864 request: &TokenRequest,
865 ) -> Result<TokenResponse, OAuthError> {
866 let code_value = request
868 .code
869 .as_ref()
870 .ok_or_else(|| OAuthError::InvalidRequest("code is required".to_string()))?;
871 let redirect_uri = request
872 .redirect_uri
873 .as_ref()
874 .ok_or_else(|| OAuthError::InvalidRequest("redirect_uri is required".to_string()))?;
875 let code_verifier = request.code_verifier.as_ref().ok_or_else(|| {
876 OAuthError::InvalidRequest("code_verifier is required (PKCE)".to_string())
877 })?;
878
879 if code_verifier.len() < self.config.min_code_verifier_length
881 || code_verifier.len() > self.config.max_code_verifier_length
882 {
883 return Err(OAuthError::InvalidRequest(format!(
884 "code_verifier must be between {} and {} characters",
885 self.config.min_code_verifier_length, self.config.max_code_verifier_length
886 )));
887 }
888
889 let auth_code = {
891 let mut state = self
892 .state
893 .write()
894 .map_err(|_| OAuthError::ServerError("failed to acquire write lock".to_string()))?;
895
896 state
898 .authorization_codes
899 .remove(code_value)
900 .ok_or_else(|| {
901 OAuthError::InvalidGrant(
902 "authorization code not found or already used".to_string(),
903 )
904 })?
905 };
906
907 if auth_code.is_expired() {
909 return Err(OAuthError::InvalidGrant(
910 "authorization code has expired".to_string(),
911 ));
912 }
913 if auth_code.client_id != request.client_id {
914 return Err(OAuthError::InvalidGrant("client_id mismatch".to_string()));
915 }
916 if auth_code.redirect_uri != *redirect_uri {
917 return Err(OAuthError::InvalidGrant(
918 "redirect_uri mismatch".to_string(),
919 ));
920 }
921
922 if !auth_code.validate_code_verifier(code_verifier) {
924 return Err(OAuthError::InvalidGrant(
925 "code_verifier validation failed".to_string(),
926 ));
927 }
928
929 let client = self.get_client(&request.client_id).ok_or_else(|| {
931 OAuthError::InvalidClient(format!("client '{}' not found", request.client_id))
932 })?;
933
934 if client.client_type == ClientType::Confidential {
935 if !client.authenticate(request.client_secret.as_deref()) {
936 return Err(OAuthError::InvalidClient(
937 "client authentication failed".to_string(),
938 ));
939 }
940 }
941
942 self.issue_tokens(
944 &auth_code.client_id,
945 &auth_code.scopes,
946 auth_code.subject.as_deref(),
947 )
948 }
949
950 fn token_refresh_token(&self, request: &TokenRequest) -> Result<TokenResponse, OAuthError> {
951 let refresh_value = request
952 .refresh_token
953 .as_ref()
954 .ok_or_else(|| OAuthError::InvalidRequest("refresh_token is required".to_string()))?;
955
956 let stored_refresh = {
958 let state = self
959 .state
960 .read()
961 .map_err(|_| OAuthError::ServerError("failed to acquire read lock".to_string()))?;
962
963 if state.revoked_tokens.contains(refresh_value) {
965 return Err(OAuthError::InvalidGrant(
966 "refresh token has been revoked".to_string(),
967 ));
968 }
969
970 state
971 .refresh_tokens
972 .get(refresh_value)
973 .cloned()
974 .ok_or_else(|| OAuthError::InvalidGrant("refresh token not found".to_string()))?
975 };
976
977 if stored_refresh.is_expired() {
978 return Err(OAuthError::InvalidGrant(
979 "refresh token has expired".to_string(),
980 ));
981 }
982 if stored_refresh.client_id != request.client_id {
983 return Err(OAuthError::InvalidGrant("client_id mismatch".to_string()));
984 }
985
986 let client = self.get_client(&request.client_id).ok_or_else(|| {
988 OAuthError::InvalidClient(format!("client '{}' not found", request.client_id))
989 })?;
990
991 if client.client_type == ClientType::Confidential {
992 if !client.authenticate(request.client_secret.as_deref()) {
993 return Err(OAuthError::InvalidClient(
994 "client authentication failed".to_string(),
995 ));
996 }
997 }
998
999 let scopes = if let Some(requested) = &request.scopes {
1001 for scope in requested {
1003 if !stored_refresh.scopes.contains(scope) {
1004 return Err(OAuthError::InvalidScope(format!(
1005 "scope '{}' was not in original grant",
1006 scope
1007 )));
1008 }
1009 }
1010 requested.clone()
1011 } else {
1012 stored_refresh.scopes.clone()
1013 };
1014
1015 let now = Instant::now();
1017 let access_value = generate_token(self.config.token_entropy_bytes)?;
1018 let issued_access = OAuthToken {
1019 token: access_value.clone(),
1020 token_type: TokenType::Bearer,
1021 client_id: request.client_id.clone(),
1022 scopes: scopes.clone(),
1023 issued_at: now,
1024 expires_at: now + self.config.access_token_lifetime,
1025 subject: stored_refresh.subject.clone(),
1026 is_refresh_token: false,
1027 };
1028
1029 {
1031 let mut state = self
1032 .state
1033 .write()
1034 .map_err(|_| OAuthError::ServerError("failed to acquire write lock".to_string()))?;
1035 state
1036 .access_tokens
1037 .insert(access_value.clone(), issued_access.clone());
1038 }
1039
1040 Ok(TokenResponse {
1041 access_token: access_value,
1042 token_type: issued_access.token_type.as_str().to_string(),
1043 expires_in: issued_access.expires_in_secs(),
1044 refresh_token: None, scope: if scopes.is_empty() {
1046 None
1047 } else {
1048 Some(scopes.join(" "))
1049 },
1050 })
1051 }
1052
1053 fn issue_tokens(
1054 &self,
1055 client_id: &str,
1056 scopes: &[String],
1057 subject: Option<&str>,
1058 ) -> Result<TokenResponse, OAuthError> {
1059 let now = Instant::now();
1060
1061 let access_value = generate_token(self.config.token_entropy_bytes)?;
1063 let access_cred = OAuthToken {
1064 token: access_value.clone(),
1065 token_type: TokenType::Bearer,
1066 client_id: client_id.to_string(),
1067 scopes: scopes.to_vec(),
1068 issued_at: now,
1069 expires_at: now + self.config.access_token_lifetime,
1070 subject: subject.map(String::from),
1071 is_refresh_token: false,
1072 };
1073
1074 let refresh_value = generate_token(self.config.token_entropy_bytes)?;
1076 let refresh_cred = OAuthToken {
1077 token: refresh_value.clone(),
1078 token_type: TokenType::Bearer,
1079 client_id: client_id.to_string(),
1080 scopes: scopes.to_vec(),
1081 issued_at: now,
1082 expires_at: now + self.config.refresh_token_lifetime,
1083 subject: subject.map(String::from),
1084 is_refresh_token: true,
1085 };
1086
1087 {
1089 let mut state = self
1090 .state
1091 .write()
1092 .map_err(|_| OAuthError::ServerError("failed to acquire write lock".to_string()))?;
1093 state
1094 .access_tokens
1095 .insert(access_value.clone(), access_cred.clone());
1096 state
1097 .refresh_tokens
1098 .insert(refresh_value.clone(), refresh_cred);
1099 }
1100
1101 Ok(TokenResponse {
1102 access_token: access_value,
1103 token_type: access_cred.token_type.as_str().to_string(),
1104 expires_in: access_cred.expires_in_secs(),
1105 refresh_token: Some(refresh_value),
1106 scope: if scopes.is_empty() {
1107 None
1108 } else {
1109 Some(scopes.join(" "))
1110 },
1111 })
1112 }
1113
1114 pub fn revoke(
1122 &self,
1123 token: &str,
1124 client_id: &str,
1125 client_secret: Option<&str>,
1126 ) -> Result<(), OAuthError> {
1127 let client = self.get_client(client_id).ok_or_else(|| {
1129 OAuthError::InvalidClient(format!("client '{}' not found", client_id))
1130 })?;
1131
1132 if client.client_type == ClientType::Confidential {
1133 if !client.authenticate(client_secret) {
1134 return Err(OAuthError::InvalidClient(
1135 "client authentication failed".to_string(),
1136 ));
1137 }
1138 }
1139
1140 let mut state = self
1141 .state
1142 .write()
1143 .map_err(|_| OAuthError::ServerError("failed to acquire write lock".to_string()))?;
1144
1145 let found_access = state.access_tokens.remove(token);
1147 let found_refresh = state.refresh_tokens.remove(token);
1148
1149 if let Some(ref t) = found_access {
1151 if t.client_id != client_id {
1152 return Ok(());
1154 }
1155 }
1156 if let Some(ref t) = found_refresh {
1157 if t.client_id != client_id {
1158 return Ok(());
1159 }
1160 }
1161
1162 if found_access.is_some() || found_refresh.is_some() {
1164 state.revoked_tokens.insert(token.to_string());
1165 }
1166
1167 Ok(())
1168 }
1169
1170 pub fn validate_access_token(&self, token: &str) -> Option<OAuthToken> {
1178 let state = self.state.read().ok()?;
1179
1180 if state.revoked_tokens.contains(token) {
1182 return None;
1183 }
1184
1185 let token_info = state.access_tokens.get(token)?;
1186
1187 if token_info.is_expired() {
1188 return None;
1189 }
1190
1191 Some(token_info.clone())
1192 }
1193
1194 #[must_use]
1200 pub fn token_verifier(self: &Arc<Self>) -> OAuthTokenVerifier {
1201 OAuthTokenVerifier {
1202 server: Arc::clone(self),
1203 }
1204 }
1205
1206 pub fn cleanup_expired(&self) {
1214 let Ok(mut state) = self.state.write() else {
1215 return;
1216 };
1217
1218 state.authorization_codes.retain(|_, c| !c.is_expired());
1220
1221 state.access_tokens.retain(|_, t| !t.is_expired());
1223
1224 state.refresh_tokens.retain(|_, t| !t.is_expired());
1226 }
1227
1228 #[must_use]
1230 pub fn stats(&self) -> OAuthServerStats {
1231 let state = match self.state.read() {
1232 Ok(guard) => guard,
1233 Err(poisoned) => poisoned.into_inner(),
1235 };
1236 OAuthServerStats {
1237 clients: state.clients.len(),
1238 authorization_codes: state.authorization_codes.len(),
1239 access_tokens: state.access_tokens.len(),
1240 refresh_tokens: state.refresh_tokens.len(),
1241 revoked_tokens: state.revoked_tokens.len(),
1242 }
1243 }
1244}
1245
/// Point-in-time entry counts returned by `OAuthServer::stats`.
#[derive(Clone, Debug, Default)]
pub struct OAuthServerStats {
    /// Registered clients.
    pub clients: usize,
    /// Outstanding authorization codes.
    pub authorization_codes: usize,
    /// Stored access tokens.
    pub access_tokens: usize,
    /// Stored refresh tokens.
    pub refresh_tokens: usize,
    /// Entries in the revoked-token set.
    pub revoked_tokens: usize,
}
1260
1261pub struct OAuthTokenVerifier {
1270 server: Arc<OAuthServer>,
1271}
1272
1273impl TokenVerifier for OAuthTokenVerifier {
1274 fn verify(
1275 &self,
1276 _ctx: &McpContext,
1277 _request: AuthRequest<'_>,
1278 token: &AccessToken,
1279 ) -> McpResult<AuthContext> {
1280 if !token.scheme.eq_ignore_ascii_case("Bearer") {
1282 return Err(McpError::new(
1283 McpErrorCode::ResourceForbidden,
1284 "unsupported auth scheme",
1285 ));
1286 }
1287
1288 let token_info = self
1290 .server
1291 .validate_access_token(&token.token)
1292 .ok_or_else(|| {
1293 McpError::new(McpErrorCode::ResourceForbidden, "invalid or expired token")
1294 })?;
1295
1296 Ok(AuthContext {
1297 subject: token_info.subject,
1298 scopes: token_info.scopes,
1299 token: Some(token.clone()),
1300 claims: Some(serde_json::json!({
1301 "client_id": token_info.client_id,
1302 "iss": self.server.config.issuer,
1303 "iat": token_info.issued_at.elapsed().as_secs(),
1304 })),
1305 })
1306 }
1307}
1308
1309fn generate_token(bytes: usize) -> Result<String, OAuthError> {
1315 let mut buf = vec![0u8; bytes];
1316 getrandom::fill(&mut buf)
1317 .map_err(|e| OAuthError::ServerError(format!("secure random generation failed: {e}")))?;
1318
1319 Ok(base64url_encode(&buf))
1321}
1322
1323fn base64url_encode(data: &[u8]) -> String {
1325 use base64::Engine;
1326 use base64::engine::general_purpose::URL_SAFE_NO_PAD;
1327 URL_SAFE_NO_PAD.encode(data)
1328}
1329
1330fn compute_s256_challenge(verifier: &str) -> String {
1332 use sha2::Digest;
1333 let hash = sha2::Sha256::digest(verifier.as_bytes());
1334 base64url_encode(&hash)
1335}
1336
/// Percent-encodes every byte of `s` except the RFC 3986 unreserved set
/// (`A-Z a-z 0-9 - _ . ~`), using uppercase hex digits.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    s.bytes()
        .fold(String::with_capacity(s.len() * 3), |mut out, byte| {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
                out.push(char::from(byte));
            } else {
                out.push('%');
                out.push(char::from(HEX[usize::from(byte >> 4)]));
                out.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
            out
        })
}
1353
/// Compares two strings without short-circuiting on the first differing byte.
///
/// Unequal lengths return `false` immediately, so the length itself is not
/// hidden.
fn constant_time_eq(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    if lhs.len() != rhs.len() {
        return false;
    }
    let diff = lhs
        .iter()
        .zip(rhs)
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    diff == 0
}
1366
/// Returns `true` if `uri` is a plain `http://` URI whose host is a loopback
/// address (`localhost`, `127.0.0.1` or `[::1]`).
///
/// The authority must be a bare `host[:port]`. Userinfo (`@`), other
/// punctuation and non-numeric ports are rejected. A prefix check alone would
/// accept `http://localhost:1@evil.example/` or `http://localhost.evil.example/`,
/// neither of which actually points at a loopback host.
fn is_localhost_redirect(uri: &str) -> bool {
    split_loopback_uri(uri).is_some()
}

/// Returns `true` if `a` and `b` are both loopback redirects (as accepted by
/// `is_localhost_redirect`) that may differ only in port and in which loopback
/// host spelling they use. Everything after the authority (path, query,
/// fragment) must be identical.
///
/// The port is ignored because RFC 8252 §7.3 requires loopback redirect URIs
/// to accept any port chosen at request time.
fn localhost_match(a: &str, b: &str) -> bool {
    match (split_loopback_uri(a), split_loopback_uri(b)) {
        (Some((host_a, rest_a)), Some((host_b, rest_b))) => {
            normalize_localhost(host_a) == normalize_localhost(host_b) && rest_a == rest_b
        }
        _ => false,
    }
}

/// Splits a loopback `http://` URI into `(host, rest)`, where `rest` is
/// everything after the authority.
///
/// Returns `None` in any of these cases:
/// - the scheme is not `http`;
/// - the authority is not a bare `host[:digits]`;
/// - the host is not one of the loopback spellings.
fn split_loopback_uri(uri: &str) -> Option<(&str, &str)> {
    let after_scheme = uri.strip_prefix("http://")?;
    let authority_end = after_scheme
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(after_scheme.len());
    let (authority, rest) = after_scheme.split_at(authority_end);

    // Only characters that can appear in a loopback host or a port are allowed.
    // In particular '@' (userinfo) and '\' would let a browser's idea of the
    // host diverge from ours.
    let plain = authority
        .bytes()
        .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'.' | b'-' | b':' | b'[' | b']'));
    if !plain {
        return None;
    }

    // A bracketed IPv6 literal contains ':' itself, so split after the
    // closing bracket instead of at the first colon.
    let (host, port) = if authority.starts_with('[') {
        let close = authority.find(']')?;
        authority.split_at(close + 1)
    } else {
        authority.split_at(authority.find(':').unwrap_or(authority.len()))
    };

    let port_ok = match port.strip_prefix(':') {
        Some(digits) => digits.bytes().all(|b| b.is_ascii_digit()),
        None => port.is_empty(),
    };
    if !port_ok || normalize_localhost(host) != "localhost" {
        return None;
    }
    Some((host, rest))
}

/// Maps every loopback host spelling to `"localhost"` and anything else to `"other"`.
fn normalize_localhost(host: &str) -> &'static str {
    match host {
        "localhost" | "127.0.0.1" | "[::1]" => "localhost",
        _ => "other",
    }
}
1404
1405#[cfg(test)]
1410mod tests {
1411 use super::*;
1412
1413 fn issue_access_token_via_auth_code(
1414 server: &OAuthServer,
1415 client_id: &str,
1416 redirect_uri: &str,
1417 scopes: &[&str],
1418 subject: &str,
1419 ) -> TokenResponse {
1420 let code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk".to_string();
1421 let auth_request = AuthorizationRequest {
1422 response_type: "code".to_string(),
1423 client_id: client_id.to_string(),
1424 redirect_uri: redirect_uri.to_string(),
1425 scopes: scopes.iter().map(|scope| (*scope).to_string()).collect(),
1426 state: Some("oauth-test-state".to_string()),
1427 code_challenge: code_verifier.clone(),
1428 code_challenge_method: CodeChallengeMethod::Plain,
1429 };
1430
1431 let (code, _redirect) = server
1432 .authorize(&auth_request, Some(subject.to_string()))
1433 .expect("authorize");
1434 server
1435 .token(&TokenRequest {
1436 grant_type: "authorization_code".to_string(),
1437 code: Some(code),
1438 redirect_uri: Some(redirect_uri.to_string()),
1439 client_id: client_id.to_string(),
1440 client_secret: None,
1441 code_verifier: Some(code_verifier),
1442 refresh_token: None,
1443 scopes: None,
1444 })
1445 .expect("token exchange")
1446 }
1447
/// A client built without a secret is public and keeps its URIs and scopes.
#[test]
fn test_client_builder() {
    let built = OAuthClient::builder("test-client")
        .redirect_uri("http://localhost:3000/callback")
        .scope("read")
        .scope("write")
        .name("Test Client")
        .build()
        .unwrap();

    assert_eq!(built.client_id, "test-client");
    assert_eq!(built.client_type, ClientType::Public);
    assert_eq!(built.redirect_uris.len(), 1);
    for scope in ["read", "write"] {
        assert!(built.allowed_scopes.contains(scope));
    }
}
1464
/// A client with a secret is confidential and authenticates only with that secret.
#[test]
fn test_confidential_client() {
    let confidential = OAuthClient::builder("test-client")
        .secret("super-secret")
        .redirect_uri("http://localhost:3000/callback")
        .build()
        .unwrap();

    assert_eq!(confidential.client_type, ClientType::Confidential);
    let attempts = [
        (Some("super-secret"), true),
        (Some("wrong-secret"), false),
        (None, false),
    ];
    for (secret, expected) in attempts {
        assert_eq!(confidential.authenticate(secret), expected);
    }
}
1478
/// Exact URIs are accepted. Loopback URIs may differ in port or loopback host
/// spelling, but not in path. Unregistered hosts are rejected.
#[test]
fn test_redirect_uri_validation() {
    let client = OAuthClient::builder("test-client")
        .redirect_uri("http://localhost:3000/callback")
        .redirect_uri("https://example.com/oauth/callback")
        .build()
        .unwrap();

    for accepted in [
        "http://localhost:3000/callback",
        "https://example.com/oauth/callback",
        "http://localhost:8080/callback",
        "http://127.0.0.1:9000/callback",
    ] {
        assert!(client.validate_redirect_uri(accepted));
    }

    for rejected in ["http://localhost:3000/other", "https://evil.com/callback"] {
        assert!(!client.validate_redirect_uri(rejected));
    }
}
1499
/// Requested scopes must all be in the client's allowed set.
#[test]
fn test_scope_validation() {
    let client = OAuthClient::builder("test-client")
        .redirect_uri("http://localhost:3000/callback")
        .scope("read")
        .scope("write")
        .build()
        .unwrap();

    let owned = |names: &[&str]| names.iter().map(|n| n.to_string()).collect::<Vec<_>>();
    assert!(client.validate_scopes(&owned(&["read"])));
    assert!(client.validate_scopes(&owned(&["read", "write"])));
    assert!(!client.validate_scopes(&owned(&["admin"])));
}
1513
/// Registering the same `client_id` twice fails, and lookups only find registered ids.
#[test]
fn test_oauth_server_client_registration() {
    let server = OAuthServer::with_defaults();
    let make_client = || {
        OAuthClient::builder("test-client")
            .redirect_uri("http://localhost:3000/callback")
            .build()
            .unwrap()
    };

    server.register_client(make_client()).unwrap();
    // A second registration under an already-used id must be rejected.
    assert!(server.register_client(make_client()).is_err());

    assert!(server.get_client("test-client").is_some());
    assert!(server.get_client("nonexistent").is_none());
}
1536
/// A valid authorization request yields a non-empty code and a redirect URL that
/// carries both the code and the caller-supplied `state`.
#[test]
fn test_authorization_flow() {
    let server = OAuthServer::with_defaults();
    server
        .register_client(
            OAuthClient::builder("test-client")
                .redirect_uri("http://localhost:3000/callback")
                .scope("read")
                .build()
                .unwrap(),
        )
        .unwrap();

    // The challenge is the RFC 7636 Appendix B S256 example value.
    let request = AuthorizationRequest {
        response_type: "code".to_string(),
        client_id: "test-client".to_string(),
        redirect_uri: "http://localhost:3000/callback".to_string(),
        scopes: vec!["read".to_string()],
        state: Some("xyz".to_string()),
        code_challenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM".to_string(),
        code_challenge_method: CodeChallengeMethod::S256,
    };

    let (issued_code, redirect_url) = server
        .authorize(&request, Some("user123".to_string()))
        .unwrap();

    assert!(!issued_code.is_empty());
    assert!(["code=", "state=xyz"]
        .iter()
        .all(|needle| redirect_url.contains(*needle)));
}
1567
/// An empty `code_challenge` is rejected with `InvalidRequest`: PKCE is mandatory.
#[test]
fn test_pkce_required() {
    let server = OAuthServer::with_defaults();
    let client = OAuthClient::builder("test-client")
        .redirect_uri("http://localhost:3000/callback")
        .build()
        .unwrap();
    server.register_client(client).unwrap();

    let without_challenge = AuthorizationRequest {
        response_type: "code".to_string(),
        client_id: "test-client".to_string(),
        redirect_uri: "http://localhost:3000/callback".to_string(),
        scopes: Vec::new(),
        state: None,
        // Deliberately empty: no PKCE challenge supplied.
        code_challenge: String::new(),
        code_challenge_method: CodeChallengeMethod::S256,
    };

    let outcome = server.authorize(&without_challenge, None);
    assert!(matches!(outcome, Err(OAuthError::InvalidRequest(_))));
}
1592
/// Two generated tokens differ, and each uses only the base64url alphabet
/// (`A-Z`, `a-z`, `0-9`, `-`, `_`).
#[test]
fn test_token_generation() {
    let first = generate_token(32).unwrap();
    let second = generate_token(32).unwrap();

    assert_ne!(first, second);

    let is_base64url = |c: char| c.is_ascii_alphanumeric() || matches!(c, '-' | '_');
    assert!(first.chars().all(is_base64url));
}
1607
/// RFC 4648 section 10 test vectors, encoded with the URL-safe alphabet and without
/// `=` padding.
#[test]
fn test_base64url_encode() {
    let vectors: [(&[u8], &str); 7] = [
        (b"", ""),
        (b"f", "Zg"),
        (b"fo", "Zm8"),
        (b"foo", "Zm9v"),
        (b"foob", "Zm9vYg"),
        (b"fooba", "Zm9vYmE"),
        (b"foobar", "Zm9vYmFy"),
    ];
    for (input, expected) in vectors.iter() {
        assert_eq!(base64url_encode(*input), *expected);
    }
}
1619
/// Plain alphanumerics pass through; space, `=` and `&` are percent-encoded.
#[test]
fn test_url_encode() {
    let cases = [
        ("hello", "hello"),
        ("hello world", "hello%20world"),
        ("a=b&c=d", "a%3Db%26c%3Dd"),
    ];
    for (raw, encoded) in cases.iter() {
        assert_eq!(url_encode(raw), *encoded);
    }
}
1626
/// Identical strings compare equal; differing content or differing length do not.
#[test]
fn test_constant_time_eq() {
    let reference = "hello";
    assert!(constant_time_eq(reference, "hello"));
    // Same length, different bytes.
    assert!(!constant_time_eq(reference, "world"));
    // Different length.
    assert!(!constant_time_eq(reference, "hell"));
}
1633
/// Loopback redirect matching ignores port and loopback host alias, but not path.
#[test]
fn test_localhost_match() {
    let registered = "http://localhost:3000/callback";

    // Only the port differs.
    assert!(localhost_match(registered, "http://localhost:8080/callback"));
    // `127.0.0.1` and `localhost` are interchangeable.
    assert!(localhost_match(
        "http://127.0.0.1:3000/callback",
        "http://localhost:8080/callback"
    ));
    // A different path never matches.
    assert!(!localhost_match(registered, "http://localhost:3000/other"));
}
1649
/// A fresh server reports zero clients and tokens; registering a client bumps `clients`.
#[test]
fn test_oauth_server_stats() {
    let server = OAuthServer::with_defaults();

    let initial = server.stats();
    assert_eq!(initial.clients, 0);
    assert_eq!(initial.access_tokens, 0);

    server
        .register_client(
            OAuthClient::builder("test-client")
                .redirect_uri("http://localhost:3000/callback")
                .build()
                .unwrap(),
        )
        .unwrap();

    assert_eq!(server.stats().clients, 1);
}
1667
/// `parse` recognises exactly `"plain"` and `"S256"` and returns `None` otherwise.
#[test]
fn test_code_challenge_method_parse() {
    let known = [
        ("plain", CodeChallengeMethod::Plain),
        ("S256", CodeChallengeMethod::S256),
    ];
    for (text, method) in known.iter() {
        assert_eq!(CodeChallengeMethod::parse(text), Some(*method));
    }
    assert_eq!(CodeChallengeMethod::parse("unknown"), None);
}
1680
/// An `OAuthError` exposes its wire code and description and displays as
/// `"<code>: <description>"`.
#[test]
fn test_oauth_error_display() {
    let error = OAuthError::InvalidRequest(String::from("missing parameter"));

    assert_eq!(error.error_code(), "invalid_request");
    assert_eq!(error.description(), "missing parameter");

    let rendered = error.to_string();
    assert_eq!(rendered, "invalid_request: missing parameter");
}
1688
/// An access token validates after issuance and stops validating once revoked.
#[test]
fn test_token_revocation() {
    let server = Arc::new(OAuthServer::with_defaults());

    let client = OAuthClient::builder("test-client")
        .redirect_uri("http://localhost:3000/callback")
        .scope("read")
        .build()
        .unwrap();
    server.register_client(client).unwrap();

    let issued = issue_access_token_via_auth_code(
        &server,
        "test-client",
        "http://localhost:3000/callback",
        &["read"],
        "user123",
    );
    let token = issued.access_token;

    assert!(server.validate_access_token(&token).is_some());

    server.revoke(&token, "test-client", None).unwrap();

    assert!(server.validate_access_token(&token).is_none());
}
1728
/// Unregistering removes the client, and unregistering an unknown id is an error.
#[test]
fn test_client_unregistration() {
    let server = OAuthServer::with_defaults();
    let id = "test-client";

    let client = OAuthClient::builder(id)
        .redirect_uri("http://localhost:3000/callback")
        .build()
        .unwrap();
    server.register_client(client).unwrap();
    assert!(server.get_client(id).is_some());

    server.unregister_client(id).unwrap();
    assert!(server.get_client(id).is_none());

    // A second unregistration targets an id that no longer exists.
    assert!(server.unregister_client(id).is_err());
}
1748
/// The server's `TokenVerifier` behaves as follows:
/// - it accepts a live `Bearer` token and exposes its subject and scopes;
/// - it rejects unknown token values;
/// - it rejects a valid token presented under a non-`Bearer` scheme.
#[test]
fn test_token_verifier() {
    let server = Arc::new(OAuthServer::with_defaults());

    let client = OAuthClient::builder("test-client")
        .redirect_uri("http://localhost:3000/callback")
        .scope("read")
        .build()
        .unwrap();
    server.register_client(client).unwrap();

    let issued = issue_access_token_via_auth_code(
        server.as_ref(),
        "test-client",
        "http://localhost:3000/callback",
        &["read"],
        "user123",
    );

    let verifier = server.token_verifier();
    let mcp_ctx = McpContext::new(asupersync::Cx::for_testing(), 1);
    let auth_request = AuthRequest {
        method: "test",
        params: None,
        request_id: 1,
    };
    let bearer = |token: String| AccessToken {
        scheme: "Bearer".to_string(),
        token,
    };

    // Happy path.
    let auth = verifier
        .verify(&mcp_ctx, auth_request, &bearer(issued.access_token.clone()))
        .expect("a freshly issued token should verify");
    assert_eq!(auth.subject, Some("user123".to_string()));
    assert_eq!(auth.scopes, vec!["read".to_string()]);

    // Unknown token value.
    let unknown = bearer("invalid-value".to_string());
    assert!(verifier.verify(&mcp_ctx, auth_request, &unknown).is_err());

    // Right token, wrong scheme.
    let wrong_scheme = AccessToken {
        scheme: "Basic".to_string(),
        token: issued.access_token,
    };
    assert!(verifier.verify(&mcp_ctx, auth_request, &wrong_scheme).is_err());
}
1806
/// Pins every `OAuthServerConfig::default()` value. Destructuring is exhaustive, so a
/// newly added field forces this test to be updated.
#[test]
fn config_default_values() {
    let OAuthServerConfig {
        issuer,
        access_token_lifetime,
        refresh_token_lifetime,
        authorization_code_lifetime,
        allow_public_clients,
        min_code_verifier_length,
        max_code_verifier_length,
        token_entropy_bytes,
    } = OAuthServerConfig::default();

    assert_eq!(issuer, "fastmcp");
    assert_eq!(access_token_lifetime, Duration::from_secs(3600));
    assert_eq!(refresh_token_lifetime, Duration::from_secs(30 * 86400));
    assert_eq!(authorization_code_lifetime, Duration::from_secs(600));
    assert!(allow_public_clients);
    // 43..=128 are the RFC 7636 code_verifier length bounds.
    assert_eq!(min_code_verifier_length, 43);
    assert_eq!(max_code_verifier_length, 128);
    assert_eq!(token_entropy_bytes, 32);
}
1823
/// The `Debug` output names the type and the issuer; a clone keeps the issuer.
#[test]
fn config_debug_and_clone() {
    let original = OAuthServerConfig::default();

    let rendered = format!("{:?}", original);
    for fragment in ["OAuthServerConfig", "fastmcp"].iter() {
        assert!(rendered.contains(*fragment));
    }

    assert_eq!(original.clone().issuer, "fastmcp");
}
1834
/// `ClientType` equality distinguishes variants, and `Debug` prints the variant name.
#[test]
fn client_type_debug_and_eq() {
    assert_eq!(ClientType::Public, ClientType::Public);
    assert_ne!(ClientType::Public, ClientType::Confidential);
    assert!(format!("{:?}", ClientType::Confidential).contains("Confidential"));
}
1846
/// `ClientType` is `Copy`: the source stays usable after assignment.
#[test]
fn client_type_copy() {
    let original = ClientType::Public;
    let copied = original;
    // Using `original` after the move compiles only because the type is `Copy`.
    assert_eq!(original, copied);
}
1853
/// `OAuthClient`'s `Debug` output names the type and the client id; a clone keeps the id.
#[test]
fn client_debug_and_clone() {
    let client = OAuthClient::builder("dbg")
        .redirect_uri("http://localhost/cb")
        .build()
        .unwrap();

    let rendered = format!("{:?}", client);
    assert!(rendered.contains("OAuthClient") && rendered.contains("dbg"));

    assert_eq!(client.clone().client_id, "dbg");
}
1871
/// A public client authenticates without a secret, and presenting any secret fails.
#[test]
fn client_authenticate_public_no_secret() {
    let public = OAuthClient::builder("pub")
        .redirect_uri("http://localhost/cb")
        .build()
        .unwrap();

    let without_secret = public.authenticate(None);
    let with_secret = public.authenticate(Some("any"));
    assert!(without_secret);
    assert!(!with_secret);
}
1883
/// Non-loopback redirect URIs must match exactly; there is no prefix or host leniency.
#[test]
fn client_validate_redirect_uri_non_localhost() {
    let client = OAuthClient::builder("c")
        .redirect_uri("https://example.com/cb")
        .build()
        .unwrap();

    assert!(client.validate_redirect_uri("https://example.com/cb"));
    for near_miss in ["https://example.com/cb2", "https://other.com/cb"].iter() {
        assert!(!client.validate_redirect_uri(near_miss));
    }
}
1895
/// An IPv6 loopback registration (`[::1]`) accepts other ports and the `localhost` alias.
#[test]
fn client_validate_redirect_uri_localhost_ipv6() {
    let client = OAuthClient::builder("c")
        .redirect_uri("http://[::1]:3000/callback")
        .build()
        .unwrap();

    let equivalents = [
        "http://[::1]:8080/callback",
        "http://localhost:9000/callback",
    ];
    assert!(equivalents.iter().all(|uri| client.validate_redirect_uri(uri)));
}
1907
/// Requesting no scopes at all is always valid.
#[test]
fn client_validate_scopes_empty() {
    let client = OAuthClient::builder("c")
        .redirect_uri("http://localhost/cb")
        .scope("read")
        .build()
        .unwrap();

    let no_scopes: [String; 0] = [];
    assert!(client.validate_scopes(&no_scopes));
}
1918
/// The builder's `Debug` output names the type and includes the pending client id.
#[test]
fn client_builder_debug() {
    let rendered = format!("{:?}", OAuthClient::builder("test-id"));
    assert!(rendered.contains("OAuthClientBuilder"));
    assert!(rendered.contains("test-id"));
}
1930
/// `build()` refuses an empty `client_id`.
#[test]
fn client_builder_empty_id_fails() {
    let outcome = OAuthClient::builder("")
        .redirect_uri("http://localhost/cb")
        .build();
    assert!(outcome.is_err(), "an empty client_id must be rejected");
}
1938
/// `build()` requires at least one redirect URI.
#[test]
fn client_builder_no_redirect_uris_fails() {
    assert!(OAuthClient::builder("c").build().is_err());
}
1944
/// `redirect_uris` registers every URI it is given.
#[test]
fn client_builder_redirect_uris_multiple() {
    let uris = vec!["http://localhost/a", "http://localhost/b"];
    let client = OAuthClient::builder("c")
        .redirect_uris(uris)
        .build()
        .unwrap();
    assert_eq!(client.redirect_uris.len(), 2);
}
1953
/// `scopes` adds every scope it is given.
#[test]
fn client_builder_scopes_multiple() {
    let requested = vec!["r", "w", "admin"];
    let client = OAuthClient::builder("c")
        .redirect_uri("http://localhost/cb")
        .scopes(requested)
        .build()
        .unwrap();
    assert_eq!(client.allowed_scopes.len(), 3);
}
1963
/// `description` is carried through to the built client.
#[test]
fn client_builder_description() {
    let built = OAuthClient::builder("c")
        .redirect_uri("http://localhost/cb")
        .description("A test app")
        .build()
        .unwrap();
    assert_eq!(built.description.as_deref(), Some("A test app"));
}
1973
/// `as_str` returns the wire spelling of each PKCE method.
#[test]
fn code_challenge_method_as_str() {
    let pairs = [
        (CodeChallengeMethod::Plain, "plain"),
        (CodeChallengeMethod::S256, "S256"),
    ];
    for (method, wire) in pairs.iter() {
        assert_eq!(method.as_str(), *wire);
    }
}
1983
/// `CodeChallengeMethod` is both `Copy` and `Clone`, and copies/clones compare equal.
#[test]
// The explicit `.clone()` on a `Copy` type is intentional: it exercises the derived
// `Clone` impl, which `clippy::clone_on_copy` would otherwise flag and fail under
// `-D warnings`.
#[allow(clippy::clone_on_copy)]
fn code_challenge_method_clone_copy_eq() {
    let method = CodeChallengeMethod::S256;

    let copied = method;
    assert_eq!(method, copied);

    let cloned = method.clone();
    assert_eq!(method, cloned);
}
1992
/// A code whose `expires_at` lies in the future is not expired.
#[test]
fn authorization_code_not_expired_initially() {
    let now = Instant::now();
    let fresh = AuthorizationCode {
        code: "test-code".to_string(),
        client_id: "c".to_string(),
        redirect_uri: "http://localhost/cb".to_string(),
        scopes: Vec::new(),
        code_challenge: "challenge".to_string(),
        code_challenge_method: CodeChallengeMethod::Plain,
        issued_at: now,
        expires_at: now + Duration::from_secs(600),
        subject: None,
        state: None,
    };
    assert!(!fresh.is_expired());
}
2013
/// A code whose `expires_at` lies in the past is expired.
#[test]
fn authorization_code_expired() {
    let now = Instant::now();
    let stale = AuthorizationCode {
        code: "test-code".to_string(),
        client_id: "c".to_string(),
        redirect_uri: "http://localhost/cb".to_string(),
        scopes: Vec::new(),
        code_challenge: "challenge".to_string(),
        code_challenge_method: CodeChallengeMethod::Plain,
        issued_at: now - Duration::from_secs(100),
        expires_at: now - Duration::from_secs(1),
        subject: None,
        state: None,
    };
    assert!(stale.is_expired());
}
2030
/// With the `plain` method, the verifier must equal the stored challenge verbatim.
#[test]
fn authorization_code_validate_plain() {
    let verifier = "my-verifier";
    let now = Instant::now();
    let code = AuthorizationCode {
        code: "test".to_string(),
        client_id: "c".to_string(),
        redirect_uri: "http://localhost/cb".to_string(),
        scopes: Vec::new(),
        code_challenge: verifier.to_string(),
        code_challenge_method: CodeChallengeMethod::Plain,
        issued_at: now,
        expires_at: now + Duration::from_secs(600),
        subject: None,
        state: None,
    };
    assert!(code.validate_code_verifier(verifier));
    assert!(!code.validate_code_verifier("wrong"));
}
2048
/// With `S256`, the verifier is accepted when its computed S256 challenge matches the
/// stored one. The verifier is the RFC 7636 Appendix B example.
#[test]
fn authorization_code_validate_s256() {
    const VERIFIER: &str = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
    let now = Instant::now();
    let code = AuthorizationCode {
        code: "test".to_string(),
        client_id: "c".to_string(),
        redirect_uri: "http://localhost/cb".to_string(),
        scopes: Vec::new(),
        code_challenge: compute_s256_challenge(VERIFIER),
        code_challenge_method: CodeChallengeMethod::S256,
        issued_at: now,
        expires_at: now + Duration::from_secs(600),
        subject: None,
        state: None,
    };
    assert!(code.validate_code_verifier(VERIFIER));
    assert!(!code.validate_code_verifier("wrong-verifier"));
}
2068
/// `AuthorizationCode`'s `Debug` output names the type; a clone keeps the client id.
#[test]
fn authorization_code_debug_and_clone() {
    let now = Instant::now();
    let code = AuthorizationCode {
        code: "c".to_string(),
        client_id: "cid".to_string(),
        redirect_uri: "http://localhost/cb".to_string(),
        scopes: vec!["read".to_string()],
        code_challenge: "ch".to_string(),
        code_challenge_method: CodeChallengeMethod::Plain,
        issued_at: now,
        expires_at: now + Duration::from_secs(60),
        subject: Some("user".to_string()),
        state: Some("state".to_string()),
    };

    assert!(format!("{:?}", code).contains("AuthorizationCode"));
    assert_eq!(code.clone().client_id, "cid");
}
2088
/// The bearer token type serialises as lowercase `"bearer"`.
#[test]
fn token_type_as_str() {
    let bearer = TokenType::Bearer;
    assert_eq!(bearer.as_str(), "bearer");
}
2097
/// `TokenType` is `Copy`, `Clone`, `PartialEq` and `Debug`.
#[test]
// The explicit `.clone()` on a `Copy` type is intentional: it exercises the derived
// `Clone` impl, which `clippy::clone_on_copy` would otherwise flag and fail under
// `-D warnings`.
#[allow(clippy::clone_on_copy)]
fn token_type_debug_clone_copy_eq() {
    let token_type = TokenType::Bearer;

    let copied = token_type;
    assert_eq!(token_type, copied);

    let cloned = token_type.clone();
    assert_eq!(token_type, cloned);

    let debug = format!("{:?}", token_type);
    assert!(debug.contains("Bearer"));
}
2108
/// A token expiring an hour from now is live and reports a positive `expires_in_secs`.
#[test]
fn oauth_token_not_expired() {
    let now = Instant::now();
    let live = OAuthToken {
        token: "t".to_string(),
        token_type: TokenType::Bearer,
        client_id: "c".to_string(),
        scopes: Vec::new(),
        issued_at: now,
        expires_at: now + Duration::from_secs(3600),
        subject: None,
        is_refresh_token: false,
    };
    assert!(!live.is_expired());
    assert!(live.expires_in_secs() > 0);
}
2128
/// A token whose expiry is in the past is expired, and `expires_in_secs` clamps to 0.
#[test]
fn oauth_token_expired() {
    let now = Instant::now();
    let stale = OAuthToken {
        token: "t".to_string(),
        token_type: TokenType::Bearer,
        client_id: "c".to_string(),
        scopes: Vec::new(),
        issued_at: now - Duration::from_secs(100),
        expires_at: now - Duration::from_secs(1),
        subject: None,
        is_refresh_token: false,
    };
    assert!(stale.is_expired());
    assert_eq!(stale.expires_in_secs(), 0);
}
2144
/// `OAuthToken`'s `Debug` output names the type; a clone preserves the token value and
/// the refresh flag.
#[test]
fn oauth_token_debug_and_clone() {
    let now = Instant::now();
    let token = OAuthToken {
        token: "tok".to_string(),
        token_type: TokenType::Bearer,
        client_id: "c".to_string(),
        scopes: vec!["read".to_string()],
        issued_at: now,
        expires_at: now + Duration::from_secs(60),
        subject: Some("user".to_string()),
        is_refresh_token: true,
    };

    assert!(format!("{:?}", token).contains("OAuthToken"));

    let duplicate = token.clone();
    assert_eq!(duplicate.token, "tok");
    assert!(duplicate.is_refresh_token);
}
2163
/// `None` optional fields are omitted from the serialised JSON entirely.
#[test]
fn token_response_serialize_without_optional_fields() {
    let response = TokenResponse {
        access_token: "at".to_string(),
        token_type: "bearer".to_string(),
        expires_in: 3600,
        refresh_token: None,
        scope: None,
    };
    let json = serde_json::to_string(&response).unwrap();
    for absent in ["refresh_token", "scope"].iter() {
        assert!(!json.contains(*absent));
    }
}
2181
/// `Some` optional fields appear in the serialised JSON.
#[test]
fn token_response_serialize_with_optional_fields() {
    let response = TokenResponse {
        access_token: "at".to_string(),
        token_type: "bearer".to_string(),
        expires_in: 3600,
        refresh_token: Some("rt".to_string()),
        scope: Some("read write".to_string()),
    };
    let json = serde_json::to_string(&response).unwrap();
    for present in ["refresh_token", "scope"].iter() {
        assert!(json.contains(*present));
    }
}
2195
/// `AuthorizationRequest`'s `Debug` output names the type; a clone keeps the client id.
#[test]
fn authorization_request_debug_and_clone() {
    let request = AuthorizationRequest {
        response_type: "code".to_string(),
        client_id: "c".to_string(),
        redirect_uri: "http://localhost/cb".to_string(),
        scopes: vec!["read".to_string()],
        state: Some("s".to_string()),
        code_challenge: "ch".to_string(),
        code_challenge_method: CodeChallengeMethod::S256,
    };

    assert!(format!("{:?}", request).contains("AuthorizationRequest"));
    assert_eq!(request.clone().client_id, "c");
}
2216
/// `TokenRequest`'s `Debug` output names the type; a clone keeps the grant type.
#[test]
fn token_request_debug_and_clone() {
    let request = TokenRequest {
        grant_type: "authorization_code".to_string(),
        code: Some("code".to_string()),
        redirect_uri: Some("http://localhost/cb".to_string()),
        client_id: "c".to_string(),
        client_secret: None,
        code_verifier: Some("verifier".to_string()),
        refresh_token: None,
        scopes: None,
    };

    assert!(format!("{:?}", request).contains("TokenRequest"));
    assert_eq!(request.clone().grant_type, "authorization_code");
}
2234
/// Every `OAuthError` variant maps to its OAuth 2.0 wire error code and returns its
/// payload as the description.
#[test]
fn oauth_error_all_codes() {
    let table = [
        (OAuthError::InvalidRequest("x".into()), "invalid_request"),
        (OAuthError::InvalidClient("x".into()), "invalid_client"),
        (OAuthError::InvalidGrant("x".into()), "invalid_grant"),
        (OAuthError::UnauthorizedClient("x".into()), "unauthorized_client"),
        (OAuthError::UnsupportedGrantType("x".into()), "unsupported_grant_type"),
        (OAuthError::InvalidScope("x".into()), "invalid_scope"),
        (OAuthError::ServerError("x".into()), "server_error"),
        (
            OAuthError::TemporarilyUnavailable("x".into()),
            "temporarily_unavailable",
        ),
        (OAuthError::AccessDenied("x".into()), "access_denied"),
        (
            OAuthError::UnsupportedResponseType("x".into()),
            "unsupported_response_type",
        ),
    ];

    for (error, wire_code) in table.iter() {
        assert_eq!(error.error_code(), *wire_code);
        assert_eq!(error.description(), "x");
    }
}
2270
/// `OAuthError`'s `Debug` output names the variant; a clone keeps the description.
#[test]
fn oauth_error_debug_and_clone() {
    let error = OAuthError::ServerError("test".into());
    assert!(format!("{:?}", error).contains("ServerError"));
    assert_eq!(error.clone().description(), "test");
}
2279
/// Compile-time check: `OAuthError` can be used as a `dyn std::error::Error`.
#[test]
fn oauth_error_is_std_error() {
    fn assert_std_error(_: &dyn std::error::Error) {}
    assert_std_error(&OAuthError::InvalidGrant("x".into()));
}
2285
/// Converting client/authorization failures into `McpError` keeps the OAuth error code
/// in the message. Only the message is checked here; the test name suggests these map to
/// a forbidden-style MCP code, which is not asserted.
#[test]
fn oauth_error_into_mcp_error_forbidden() {
    let cases = vec![
        (OAuthError::InvalidClient("c".into()), "invalid_client"),
        (OAuthError::UnauthorizedClient("c".into()), "unauthorized_client"),
        (OAuthError::AccessDenied("d".into()), "access_denied"),
    ];
    for (oauth_error, wire_code) in cases {
        let mcp: McpError = oauth_error.into();
        assert!(mcp.message.contains(wire_code));
    }
}
2296
/// Converting request-shape failures into `McpError` keeps the OAuth error code in the
/// message.
#[test]
fn oauth_error_into_mcp_error_invalid_request() {
    let cases = vec![
        (OAuthError::InvalidScope("s".into()), "invalid_scope"),
        (OAuthError::UnsupportedGrantType("g".into()), "unsupported_grant_type"),
    ];
    for (oauth_error, wire_code) in cases {
        let mcp: McpError = oauth_error.into();
        assert!(mcp.message.contains(wire_code));
    }
}
2305
/// `config()` returns the configuration the server was constructed with.
#[test]
fn server_config_accessor() {
    let mut config = OAuthServerConfig::default();
    config.issuer = "custom-issuer".to_string();

    let server = OAuthServer::new(config);
    assert_eq!(server.config().issuer, "custom-issuer");
}
2319
/// With `allow_public_clients` disabled, registering a secretless (public) client fails
/// with `InvalidClient`.
#[test]
fn server_register_public_not_allowed() {
    let mut config = OAuthServerConfig::default();
    config.allow_public_clients = false;
    let server = OAuthServer::new(config);

    let public_client = OAuthClient::builder("c")
        .redirect_uri("http://localhost/cb")
        .build()
        .unwrap();

    let outcome = server.register_client(public_client);
    assert!(matches!(outcome, Err(OAuthError::InvalidClient(_))));
}
2335
/// `list_clients` starts empty and grows with each registration.
#[test]
fn server_list_clients() {
    let server = OAuthServer::with_defaults();
    assert_eq!(server.list_clients().len(), 0);

    server
        .register_client(
            OAuthClient::builder("a")
                .redirect_uri("http://localhost/cb")
                .build()
                .unwrap(),
        )
        .unwrap();

    assert_eq!(server.list_clients().len(), 1);
}
2348
/// Only the `code` response type is supported; `token` (implicit flow) is refused.
#[test]
fn server_authorize_unsupported_response_type() {
    let server = OAuthServer::with_defaults();
    let client = OAuthClient::builder("c")
        .redirect_uri("http://localhost/cb")
        .build()
        .unwrap();
    server.register_client(client).unwrap();

    let implicit_flow = AuthorizationRequest {
        // Implicit grant: not supported.
        response_type: "token".to_string(),
        client_id: "c".to_string(),
        redirect_uri: "http://localhost/cb".to_string(),
        scopes: Vec::new(),
        state: None,
        code_challenge: "ch".to_string(),
        code_challenge_method: CodeChallengeMethod::S256,
    };

    let outcome = server.authorize(&implicit_flow, None);
    assert!(matches!(
        outcome,
        Err(OAuthError::UnsupportedResponseType(_))
    ));
}
2373
/// A redirect URI that is not registered for the client is refused with `InvalidRequest`.
#[test]
fn server_authorize_invalid_redirect() {
    let server = OAuthServer::with_defaults();
    let client = OAuthClient::builder("c")
        .redirect_uri("http://localhost/cb")
        .build()
        .unwrap();
    server.register_client(client).unwrap();

    let foreign_redirect = AuthorizationRequest {
        response_type: "code".to_string(),
        client_id: "c".to_string(),
        redirect_uri: "https://evil.com/cb".to_string(),
        scopes: Vec::new(),
        state: None,
        code_challenge: "ch".to_string(),
        code_challenge_method: CodeChallengeMethod::S256,
    };

    let outcome = server.authorize(&foreign_redirect, None);
    assert!(matches!(outcome, Err(OAuthError::InvalidRequest(_))));
}
2395
/// Requesting a scope the client was not registered for is refused with `InvalidScope`.
#[test]
fn server_authorize_invalid_scope() {
    let server = OAuthServer::with_defaults();
    let client = OAuthClient::builder("c")
        .redirect_uri("http://localhost/cb")
        .scope("read")
        .build()
        .unwrap();
    server.register_client(client).unwrap();

    let over_scoped = AuthorizationRequest {
        response_type: "code".to_string(),
        client_id: "c".to_string(),
        redirect_uri: "http://localhost/cb".to_string(),
        scopes: vec!["admin".to_string()],
        state: None,
        code_challenge: "ch".to_string(),
        code_challenge_method: CodeChallengeMethod::S256,
    };

    let outcome = server.authorize(&over_scoped, None);
    assert!(matches!(outcome, Err(OAuthError::InvalidScope(_))));
}
2418
/// Authorizing for an unregistered `client_id` fails with `InvalidClient`.
#[test]
fn server_authorize_unknown_client() {
    let server = OAuthServer::with_defaults();

    let unknown_client = AuthorizationRequest {
        response_type: "code".to_string(),
        client_id: "nonexistent".to_string(),
        redirect_uri: "http://localhost/cb".to_string(),
        scopes: Vec::new(),
        state: None,
        code_challenge: "ch".to_string(),
        code_challenge_method: CodeChallengeMethod::S256,
    };

    let outcome = server.authorize(&unknown_client, None);
    assert!(matches!(outcome, Err(OAuthError::InvalidClient(_))));
}
2434
/// The token endpoint rejects grant types it does not implement (here `client_credentials`).
#[test]
fn server_token_unsupported_grant_type() {
    let server = OAuthServer::with_defaults();

    let client_credentials = TokenRequest {
        grant_type: "client_credentials".to_string(),
        code: None,
        redirect_uri: None,
        client_id: "c".to_string(),
        client_secret: None,
        code_verifier: None,
        refresh_token: None,
        scopes: None,
    };

    let outcome = server.token(&client_credentials);
    assert!(matches!(outcome, Err(OAuthError::UnsupportedGrantType(_))));
}
2451
/// An `authorization_code` grant without a `code` is an `InvalidRequest`.
#[test]
fn server_token_auth_code_missing_code() {
    let server = OAuthServer::with_defaults();

    let without_code = TokenRequest {
        grant_type: "authorization_code".to_string(),
        // Deliberately missing.
        code: None,
        redirect_uri: Some("http://localhost/cb".to_string()),
        client_id: "c".to_string(),
        client_secret: None,
        code_verifier: Some("v".repeat(43)),
        refresh_token: None,
        scopes: None,
    };

    let outcome = server.token(&without_code);
    assert!(matches!(outcome, Err(OAuthError::InvalidRequest(_))));
}
2468
/// An `authorization_code` grant without a `redirect_uri` is an `InvalidRequest`.
#[test]
fn server_token_auth_code_missing_redirect() {
    let server = OAuthServer::with_defaults();

    let without_redirect = TokenRequest {
        grant_type: "authorization_code".to_string(),
        code: Some("code".to_string()),
        // Deliberately missing.
        redirect_uri: None,
        client_id: "c".to_string(),
        client_secret: None,
        code_verifier: Some("v".repeat(43)),
        refresh_token: None,
        scopes: None,
    };

    let outcome = server.token(&without_redirect);
    assert!(matches!(outcome, Err(OAuthError::InvalidRequest(_))));
}
2485
/// An `authorization_code` grant without a PKCE `code_verifier` is an `InvalidRequest`.
#[test]
fn server_token_auth_code_missing_verifier() {
    let server = OAuthServer::with_defaults();

    let without_verifier = TokenRequest {
        grant_type: "authorization_code".to_string(),
        code: Some("code".to_string()),
        redirect_uri: Some("http://localhost/cb".to_string()),
        client_id: "c".to_string(),
        client_secret: None,
        // Deliberately missing.
        code_verifier: None,
        refresh_token: None,
        scopes: None,
    };

    let outcome = server.token(&without_verifier);
    assert!(matches!(outcome, Err(OAuthError::InvalidRequest(_))));
}
2502
/// `authorize` accepts a short `plain` challenge, but the token endpoint rejects the
/// matching verifier. It is 5 characters, below the default minimum length of 43.
#[test]
fn server_token_auth_code_verifier_too_short() {
    let server = OAuthServer::with_defaults();
    let client = OAuthClient::builder("c")
        .redirect_uri("http://localhost/cb")
        .build()
        .unwrap();
    server.register_client(client).unwrap();

    // Below `min_code_verifier_length`.
    let short_verifier = "short";

    let authorization = AuthorizationRequest {
        response_type: "code".to_string(),
        client_id: "c".to_string(),
        redirect_uri: "http://localhost/cb".to_string(),
        scopes: Vec::new(),
        state: None,
        code_challenge: short_verifier.to_string(),
        code_challenge_method: CodeChallengeMethod::Plain,
    };
    let (issued_code, _redirect) = server.authorize(&authorization, None).unwrap();

    let exchange = TokenRequest {
        grant_type: "authorization_code".to_string(),
        code: Some(issued_code),
        redirect_uri: Some("http://localhost/cb".to_string()),
        client_id: "c".to_string(),
        client_secret: None,
        code_verifier: Some(short_verifier.to_string()),
        refresh_token: None,
        scopes: None,
    };

    let outcome = server.token(&exchange);
    assert!(matches!(outcome, Err(OAuthError::InvalidRequest(_))));
}
2538
/// End-to-end authorization-code + S256 PKCE flow. The response carries:
/// - a non-empty access token;
/// - a refresh token;
/// - the `bearer` token type;
/// - the granted scope.
///
/// The verifier is the RFC 7636 Appendix B example.
#[test]
fn server_full_auth_code_flow_with_s256() {
    let server = OAuthServer::with_defaults();
    let client = OAuthClient::builder("c")
        .redirect_uri("http://localhost/cb")
        .scope("read")
        .build()
        .unwrap();
    server.register_client(client).unwrap();

    const VERIFIER: &str = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";

    let authorization = AuthorizationRequest {
        response_type: "code".to_string(),
        client_id: "c".to_string(),
        redirect_uri: "http://localhost/cb".to_string(),
        scopes: vec!["read".to_string()],
        state: None,
        code_challenge: compute_s256_challenge(VERIFIER),
        code_challenge_method: CodeChallengeMethod::S256,
    };
    let (issued_code, _redirect) = server
        .authorize(&authorization, Some("user1".to_string()))
        .unwrap();

    let exchange = TokenRequest {
        grant_type: "authorization_code".to_string(),
        code: Some(issued_code),
        redirect_uri: Some("http://localhost/cb".to_string()),
        client_id: "c".to_string(),
        client_secret: None,
        code_verifier: Some(VERIFIER.to_string()),
        refresh_token: None,
        scopes: None,
    };
    let granted = server.token(&exchange).unwrap();

    assert!(!granted.access_token.is_empty());
    assert!(granted.refresh_token.is_some());
    assert_eq!(granted.token_type, "bearer");
    assert_eq!(granted.scope.as_deref(), Some("read"));
}
2581
/// Authorization codes are single-use: replaying an exchanged code yields `InvalidGrant`.
#[test]
fn server_token_code_already_used() {
    let server = OAuthServer::with_defaults();
    let client = OAuthClient::builder("c")
        .redirect_uri("http://localhost/cb")
        .build()
        .unwrap();
    server.register_client(client).unwrap();

    let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
    let authorization = AuthorizationRequest {
        response_type: "code".to_string(),
        client_id: "c".to_string(),
        redirect_uri: "http://localhost/cb".to_string(),
        scopes: Vec::new(),
        state: None,
        code_challenge: verifier.to_string(),
        code_challenge_method: CodeChallengeMethod::Plain,
    };
    let (issued_code, _redirect) = server.authorize(&authorization, None).unwrap();

    let exchange = TokenRequest {
        grant_type: "authorization_code".to_string(),
        code: Some(issued_code),
        redirect_uri: Some("http://localhost/cb".to_string()),
        client_id: "c".to_string(),
        client_secret: None,
        code_verifier: Some(verifier.to_string()),
        refresh_token: None,
        scopes: None,
    };

    // The first exchange succeeds...
    server.token(&exchange).unwrap();
    // ...and the replay is refused.
    let replay = server.token(&exchange);
    assert!(matches!(replay, Err(OAuthError::InvalidGrant(_))));
}
2619
#[test]
fn server_validate_access_token_nonexistent() {
    // A token string the server never issued must not validate.
    let server = OAuthServer::with_defaults();
    let lookup = server.validate_access_token("nonexistent");
    assert!(lookup.is_none());
}
2625
#[test]
fn server_unregister_client_revokes_tokens() {
    let server = OAuthServer::with_defaults();
    let registration = OAuthClient::builder("c")
        .redirect_uri("http://localhost/cb")
        .scope("read")
        .build()
        .unwrap();
    server.register_client(registration).unwrap();

    let issued =
        issue_access_token_via_auth_code(&server, "c", "http://localhost/cb", &["read"], "user");
    let bearer = issued.access_token.as_str();

    // Valid while the owning client exists...
    assert!(server.validate_access_token(bearer).is_some());
    server.unregister_client("c").unwrap();
    // ...and rejected once that client has been unregistered.
    assert!(server.validate_access_token(bearer).is_none());
}
2648
#[test]
fn server_cleanup_expired_removes_old_tokens() {
    // Only the issued tokens need to expire quickly. The authorization code keeps
    // its default lifetime: the helper redeems it right after authorizing, and a
    // 1ms code lifetime would make that exchange itself flaky on a slow machine.
    let config = OAuthServerConfig {
        access_token_lifetime: Duration::from_millis(1),
        refresh_token_lifetime: Duration::from_millis(1),
        ..OAuthServerConfig::default()
    };
    let server = OAuthServer::new(config);
    let client = OAuthClient::builder("c")
        .redirect_uri("http://localhost/cb")
        .build()
        .unwrap();
    server.register_client(client).unwrap();

    let _resp =
        issue_access_token_via_auth_code(&server, "c", "http://localhost/cb", &[], "user");

    // Let both 1ms token lifetimes elapse.
    std::thread::sleep(Duration::from_millis(5));

    server.cleanup_expired();
    let stats = server.stats();

    // Both issued tokens are past their expiry, so cleanup must have purged them.
    // The previous `after <= before` check passed even when nothing was removed.
    assert_eq!(stats.access_tokens, 0);
    assert_eq!(stats.refresh_tokens, 0);
}
2676
#[test]
fn server_stats_default() {
    // Every counter of a default-constructed stats snapshot starts at zero.
    let OAuthServerStats {
        clients,
        authorization_codes,
        access_tokens,
        refresh_tokens,
        revoked_tokens,
        ..
    } = OAuthServerStats::default();

    assert_eq!(clients, 0);
    assert_eq!(authorization_codes, 0);
    assert_eq!(access_tokens, 0);
    assert_eq!(refresh_tokens, 0);
    assert_eq!(revoked_tokens, 0);
}
2690
#[test]
fn server_stats_debug_and_clone() {
    let stats = OAuthServerStats {
        access_tokens: 5,
        clients: 1,
        ..Default::default()
    };

    // The Debug rendering names the type.
    assert!(format!("{stats:?}").contains("OAuthServerStats"));

    // A clone carries the same field values.
    let duplicate = stats.clone();
    assert_eq!(duplicate.clients, 1);
}
2703
#[test]
fn is_localhost_redirect_tests() {
    // Loopback hosts in every supported spelling are recognised...
    let loopback = [
        "http://localhost:3000/cb",
        "http://127.0.0.1:8080/cb",
        "http://[::1]:9000/cb",
    ];
    for uri in loopback {
        assert!(is_localhost_redirect(uri));
    }

    // ...while public hosts are not.
    for uri in ["https://example.com/cb", "http://evil.com/cb"] {
        assert!(!is_localhost_redirect(uri));
    }
}
2716
#[test]
fn normalize_localhost_variants() {
    // (host, expected normalized form)
    let table = [
        ("localhost", "localhost"),
        ("127.0.0.1", "localhost"),
        ("[::1]", "localhost"),
        ("example.com", "other"),
    ];
    for (host, expected) in table {
        assert_eq!(normalize_localhost(host), expected);
    }
}
2724
#[test]
fn compute_s256_challenge_deterministic() {
    // Hashing the same verifier twice yields the same non-empty challenge.
    let verifier = "test-verifier";
    let (first, second) = (
        compute_s256_challenge(verifier),
        compute_s256_challenge(verifier),
    );
    assert_eq!(first, second);
    assert!(!first.is_empty());
}
2733
#[test]
fn url_encode_special_chars() {
    // Reserved characters are percent-encoded; unreserved ones pass through.
    let table = [
        ("a b", "a%20b"),
        ("a+b", "a%2Bb"),
        ("a/b", "a%2Fb"),
        ("safe-_~.", "safe-_~."),
    ];
    for (raw, encoded) in table {
        assert_eq!(url_encode(raw), encoded);
    }
}
2741
#[test]
fn constant_time_eq_same_length_different() {
    // Equal lengths but a differing final byte: not equal.
    let equal = constant_time_eq("abc", "abd");
    assert!(!equal);
}
2746
#[test]
fn localhost_match_different_paths_fail() {
    // Same loopback host and port, different path: no match.
    let registered = "http://localhost:3000/a";
    let presented = "http://localhost:3000/b";
    assert!(!localhost_match(registered, presented));
}
2754
#[test]
fn localhost_match_non_http_fails() {
    // Even an identical URI is rejected when its scheme is not http.
    let uri = "ftp://localhost/a";
    assert!(!localhost_match(uri, uri));
}
2759
#[test]
fn server_refresh_token_flow() {
    let server = OAuthServer::with_defaults();
    let registration = OAuthClient::builder("c1")
        .redirect_uri("http://localhost/cb")
        .scope("read")
        .scope("write")
        .build()
        .unwrap();
    server.register_client(registration).unwrap();

    let first = issue_access_token_via_auth_code(
        &server,
        "c1",
        "http://localhost/cb",
        &["read", "write"],
        "user1",
    );
    let refresh_token = first.refresh_token.unwrap();

    // Redeem the refresh token without requesting different scopes.
    let refresh_grant = TokenRequest {
        grant_type: String::from("refresh_token"),
        client_id: String::from("c1"),
        refresh_token: Some(refresh_token),
        code: None,
        redirect_uri: None,
        client_secret: None,
        code_verifier: None,
        scopes: None,
    };
    let second = server.token(&refresh_grant).unwrap();

    // A different bearer access token is minted and it carries a scope...
    assert_ne!(second.access_token, first.access_token);
    assert_eq!(second.token_type, "bearer");
    assert!(second.scope.is_some());
    // ...while this response does not hand out another refresh token.
    assert!(second.refresh_token.is_none());
}
2806
#[test]
fn server_refresh_token_scope_narrowing() {
    let server = OAuthServer::with_defaults();
    server
        .register_client(
            OAuthClient::builder("c1")
                .redirect_uri("http://localhost/cb")
                .scope("read")
                .scope("write")
                .build()
                .unwrap(),
        )
        .unwrap();

    let refresh = issue_access_token_via_auth_code(
        &server,
        "c1",
        "http://localhost/cb",
        &["read", "write"],
        "user1",
    )
    .refresh_token
    .unwrap();

    // Request a strict subset ("read") of the originally granted scopes.
    let narrowed = server
        .token(&TokenRequest {
            grant_type: String::from("refresh_token"),
            client_id: String::from("c1"),
            refresh_token: Some(refresh),
            scopes: Some(vec![String::from("read")]),
            code: None,
            redirect_uri: None,
            client_secret: None,
            code_verifier: None,
        })
        .unwrap();

    assert_eq!(narrowed.scope.as_deref(), Some("read"));
}
2843
#[test]
fn server_refresh_token_invalid_scope() {
    let server = OAuthServer::with_defaults();
    let registration = OAuthClient::builder("c1")
        .redirect_uri("http://localhost/cb")
        .scope("read")
        .build()
        .unwrap();
    server.register_client(registration).unwrap();

    let granted =
        issue_access_token_via_auth_code(&server, "c1", "http://localhost/cb", &["read"], "user1");

    let escalation = TokenRequest {
        grant_type: String::from("refresh_token"),
        code: None,
        redirect_uri: None,
        client_id: String::from("c1"),
        client_secret: None,
        code_verifier: None,
        refresh_token: Some(granted.refresh_token.unwrap()),
        // "admin" was never granted, so refreshing must not widen the scope set.
        scopes: Some(vec![String::from("admin")]),
    };
    let rejection = server.token(&escalation).unwrap_err();

    assert_eq!(rejection.error_code(), "invalid_scope");
}
2879
#[test]
fn server_refresh_token_revoked() {
    let server = OAuthServer::with_defaults();
    let registration = OAuthClient::builder("c1")
        .redirect_uri("http://localhost/cb")
        .scope("read")
        .build()
        .unwrap();
    server.register_client(registration).unwrap();

    let refresh =
        issue_access_token_via_auth_code(&server, "c1", "http://localhost/cb", &["read"], "user1")
            .refresh_token
            .unwrap();

    // The owning client revokes its own refresh token...
    server.revoke(&refresh, "c1", None).unwrap();

    // ...after which redeeming it is refused.
    let replay = TokenRequest {
        grant_type: String::from("refresh_token"),
        client_id: String::from("c1"),
        refresh_token: Some(refresh),
        code: None,
        redirect_uri: None,
        client_secret: None,
        code_verifier: None,
        scopes: None,
    };
    let err = server.token(&replay).unwrap_err();

    assert_eq!(err.error_code(), "invalid_grant");
    assert!(err.description().contains("revoked"));
}
2919
#[test]
fn server_refresh_token_client_id_mismatch() {
    let server = OAuthServer::with_defaults();
    for id in ["c1", "c2"] {
        let client = OAuthClient::builder(id)
            .redirect_uri("http://localhost/cb")
            .scope("read")
            .build()
            .unwrap();
        server.register_client(client).unwrap();
    }

    // The refresh token is issued to c1...
    let refresh =
        issue_access_token_via_auth_code(&server, "c1", "http://localhost/cb", &["read"], "user1")
            .refresh_token
            .unwrap();

    // ...but c2 tries to redeem it.
    let foreign_redemption = TokenRequest {
        grant_type: String::from("refresh_token"),
        client_id: String::from("c2"),
        refresh_token: Some(refresh),
        code: None,
        redirect_uri: None,
        client_secret: None,
        code_verifier: None,
        scopes: None,
    };
    let err = server.token(&foreign_redemption).unwrap_err();

    assert_eq!(err.error_code(), "invalid_grant");
    assert!(err.description().contains("client_id"));
}
2962
#[test]
fn server_refresh_token_missing_param() {
    let server = OAuthServer::with_defaults();
    server
        .register_client(
            OAuthClient::builder("c1")
                .redirect_uri("http://localhost/cb")
                .build()
                .unwrap(),
        )
        .unwrap();

    // A refresh_token grant that omits the refresh_token parameter entirely.
    let incomplete = TokenRequest {
        grant_type: String::from("refresh_token"),
        client_id: String::from("c1"),
        refresh_token: None,
        code: None,
        redirect_uri: None,
        client_secret: None,
        code_verifier: None,
        scopes: None,
    };
    let err = server.token(&incomplete).unwrap_err();

    assert_eq!(err.error_code(), "invalid_request");
    assert!(err.description().contains("refresh_token"));
}
2988
#[test]
fn server_refresh_token_not_found() {
    let server = OAuthServer::with_defaults();
    server
        .register_client(
            OAuthClient::builder("c1")
                .redirect_uri("http://localhost/cb")
                .build()
                .unwrap(),
        )
        .unwrap();

    // A refresh token value the server never issued.
    let bogus = TokenRequest {
        grant_type: String::from("refresh_token"),
        client_id: String::from("c1"),
        refresh_token: Some(String::from("nonexistent")),
        code: None,
        redirect_uri: None,
        client_secret: None,
        code_verifier: None,
        scopes: None,
    };
    let err = server.token(&bogus).unwrap_err();

    assert_eq!(err.error_code(), "invalid_grant");
}
3013
#[test]
fn server_token_auth_code_redirect_uri_mismatch() {
    let server = OAuthServer::with_defaults();
    let registration = OAuthClient::builder("c1")
        .redirect_uri("http://localhost/cb")
        .redirect_uri("http://localhost/cb2")
        .scope("read")
        .build()
        .unwrap();
    server.register_client(registration).unwrap();

    let pkce_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
    let authorization = AuthorizationRequest {
        response_type: String::from("code"),
        client_id: String::from("c1"),
        redirect_uri: String::from("http://localhost/cb"),
        scopes: vec![String::from("read")],
        state: None,
        code_challenge: String::from(pkce_verifier),
        code_challenge_method: CodeChallengeMethod::Plain,
    };
    let (auth_code, _) = server.authorize(&authorization, None).unwrap();

    // cb2 is also registered, but it is not the URI this code was issued for.
    let exchange = TokenRequest {
        grant_type: String::from("authorization_code"),
        code: Some(auth_code),
        redirect_uri: Some(String::from("http://localhost/cb2")),
        client_id: String::from("c1"),
        client_secret: None,
        code_verifier: Some(String::from(pkce_verifier)),
        refresh_token: None,
        scopes: None,
    };
    let err = server.token(&exchange).unwrap_err();

    assert_eq!(err.error_code(), "invalid_grant");
    assert!(err.description().contains("redirect_uri"));
}
3062
#[test]
fn server_token_auth_code_client_id_mismatch() {
    let server = OAuthServer::with_defaults();
    for id in ["c1", "c2"] {
        let client = OAuthClient::builder(id)
            .redirect_uri("http://localhost/cb")
            .scope("read")
            .build()
            .unwrap();
        server.register_client(client).unwrap();
    }

    // The code is authorized for c1...
    let pkce_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
    let authorization = AuthorizationRequest {
        response_type: String::from("code"),
        client_id: String::from("c1"),
        redirect_uri: String::from("http://localhost/cb"),
        scopes: vec![String::from("read")],
        state: None,
        code_challenge: String::from(pkce_verifier),
        code_challenge_method: CodeChallengeMethod::Plain,
    };
    let (auth_code, _) = server.authorize(&authorization, None).unwrap();

    // ...but c2 presents it at the token endpoint.
    let exchange = TokenRequest {
        grant_type: String::from("authorization_code"),
        code: Some(auth_code),
        redirect_uri: Some(String::from("http://localhost/cb")),
        client_id: String::from("c2"),
        client_secret: None,
        code_verifier: Some(String::from(pkce_verifier)),
        refresh_token: None,
        scopes: None,
    };
    let err = server.token(&exchange).unwrap_err();

    assert_eq!(err.error_code(), "invalid_grant");
    assert!(err.description().contains("client_id"));
}
3112
#[test]
fn server_token_auth_code_confidential_client_auth_fails() {
    let server = OAuthServer::with_defaults();
    let registration = OAuthClient::builder("c1")
        .secret("correct-secret")
        .redirect_uri("http://localhost/cb")
        .scope("read")
        .build()
        .unwrap();
    server.register_client(registration).unwrap();

    let pkce_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
    let authorization = AuthorizationRequest {
        response_type: String::from("code"),
        client_id: String::from("c1"),
        redirect_uri: String::from("http://localhost/cb"),
        scopes: vec![String::from("read")],
        state: None,
        code_challenge: String::from(pkce_verifier),
        code_challenge_method: CodeChallengeMethod::Plain,
    };
    let (auth_code, _) = server.authorize(&authorization, None).unwrap();

    // Everything is valid except the confidential client's secret.
    let exchange = TokenRequest {
        grant_type: String::from("authorization_code"),
        code: Some(auth_code),
        redirect_uri: Some(String::from("http://localhost/cb")),
        client_id: String::from("c1"),
        client_secret: Some(String::from("wrong-secret")),
        code_verifier: Some(String::from(pkce_verifier)),
        refresh_token: None,
        scopes: None,
    };
    let err = server.token(&exchange).unwrap_err();

    assert_eq!(err.error_code(), "invalid_client");
}
3156
#[test]
fn server_token_auth_code_verifier_too_long() {
    let server = OAuthServer::with_defaults();
    let registration = OAuthClient::builder("c1")
        .redirect_uri("http://localhost/cb")
        .scope("read")
        .build()
        .unwrap();
    server.register_client(registration).unwrap();

    let pkce_challenge = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
    let authorization = AuthorizationRequest {
        response_type: String::from("code"),
        client_id: String::from("c1"),
        redirect_uri: String::from("http://localhost/cb"),
        scopes: vec![String::from("read")],
        state: None,
        code_challenge: String::from(pkce_challenge),
        code_challenge_method: CodeChallengeMethod::Plain,
    };
    let (auth_code, _) = server.authorize(&authorization, None).unwrap();

    // 129 characters: one longer than the default `max_code_verifier_length` of 128.
    let oversized_verifier = "a".repeat(129);
    let exchange = TokenRequest {
        grant_type: String::from("authorization_code"),
        code: Some(auth_code),
        redirect_uri: Some(String::from("http://localhost/cb")),
        client_id: String::from("c1"),
        client_secret: None,
        code_verifier: Some(oversized_verifier),
        refresh_token: None,
        scopes: None,
    };
    let err = server.token(&exchange).unwrap_err();

    assert_eq!(err.error_code(), "invalid_request");
    assert!(err.description().contains("code_verifier"));
}
3201
#[test]
fn server_authorize_empty_code_challenge() {
    let server = OAuthServer::with_defaults();
    let registration = OAuthClient::builder("c1")
        .redirect_uri("http://localhost/cb")
        .scope("read")
        .build()
        .unwrap();
    server.register_client(registration).unwrap();

    // An otherwise valid request whose PKCE challenge is blank.
    let blank_challenge = AuthorizationRequest {
        response_type: String::from("code"),
        client_id: String::from("c1"),
        redirect_uri: String::from("http://localhost/cb"),
        scopes: vec![String::from("read")],
        state: None,
        code_challenge: String::new(),
        code_challenge_method: CodeChallengeMethod::S256,
    };
    let err = server.authorize(&blank_challenge, None).unwrap_err();

    assert_eq!(err.error_code(), "invalid_request");
    assert!(err.description().contains("code_challenge"));
}
3234
#[test]
fn server_authorize_with_state_in_redirect() {
    let server = OAuthServer::with_defaults();
    let registration = OAuthClient::builder("c1")
        .redirect_uri("http://localhost/cb")
        .scope("read")
        .build()
        .unwrap();
    server.register_client(registration).unwrap();

    let authorization = AuthorizationRequest {
        response_type: String::from("code"),
        client_id: String::from("c1"),
        redirect_uri: String::from("http://localhost/cb"),
        scopes: vec![String::from("read")],
        state: Some(String::from("my-csrf-state")),
        code_challenge: String::from("challenge-value"),
        code_challenge_method: CodeChallengeMethod::S256,
    };
    let (code, redirect) = server
        .authorize(&authorization, Some(String::from("user1")))
        .unwrap();

    // The redirect carries the URL-encoded code and echoes the caller's state.
    let encoded_code = url_encode(&code);
    assert!(redirect.contains("code="));
    assert!(redirect.contains(encoded_code.as_str()));
    assert!(redirect.contains("state=my-csrf-state"));
}
3265
#[test]
fn server_authorize_redirect_with_existing_query() {
    const CALLBACK: &str = "http://localhost/cb?foo=bar";

    let server = OAuthServer::with_defaults();
    let registration = OAuthClient::builder("c1")
        .redirect_uri(CALLBACK)
        .scope("read")
        .build()
        .unwrap();
    server.register_client(registration).unwrap();

    let authorization = AuthorizationRequest {
        response_type: String::from("code"),
        client_id: String::from("c1"),
        redirect_uri: String::from(CALLBACK),
        scopes: vec![String::from("read")],
        state: None,
        code_challenge: String::from("chal"),
        code_challenge_method: CodeChallengeMethod::Plain,
    };
    let (_, redirect) = server.authorize(&authorization, None).unwrap();

    // The existing query is preserved and the code is appended with `&`, not `?`.
    assert!(redirect.starts_with("http://localhost/cb?foo=bar&code="));
}
3294
#[test]
fn oauth_error_access_denied_into_mcp_error() {
    // access_denied converts to the MCP "resource forbidden" error code.
    let converted: McpError = OAuthError::AccessDenied(String::from("denied")).into();
    assert_eq!(converted.code, McpErrorCode::ResourceForbidden);
}
3305
#[test]
fn oauth_error_description_all_variants() {
    // `description()` returns the message wrapped by each variant.
    let table = [
        (OAuthError::ServerError("srv".into()), "srv"),
        (OAuthError::TemporarilyUnavailable("tmp".into()), "tmp"),
        (OAuthError::UnsupportedResponseType("rt".into()), "rt"),
    ];
    for (error, message) in table {
        assert_eq!(error.description(), message);
    }
}
3317
#[test]
fn oauth_error_display_all_remaining_variants() {
    // Display renders as "<error_code>: <description>".
    let table = [
        (
            OAuthError::TemporarilyUnavailable("try later".into()),
            "temporarily_unavailable: try later",
        ),
        (
            OAuthError::UnsupportedResponseType("bad".into()),
            "unsupported_response_type: bad",
        ),
        (OAuthError::AccessDenied("nope".into()), "access_denied: nope"),
    ];
    for (error, rendered) in table {
        assert_eq!(error.to_string(), rendered);
    }
}
3329
#[test]
fn server_revoke_unknown_token_succeeds() {
    let server = OAuthServer::with_defaults();
    server
        .register_client(
            OAuthClient::builder("c1")
                .redirect_uri("http://localhost/cb")
                .build()
                .unwrap(),
        )
        .unwrap();

    // Revoking a token value the server never issued is not an error.
    server.revoke("no-such-token", "c1", None).unwrap();
}
3346
#[test]
fn server_revoke_token_owned_by_other_client() {
    let server = OAuthServer::with_defaults();
    for id in ["c1", "c2"] {
        let client = OAuthClient::builder(id)
            .redirect_uri("http://localhost/cb")
            .scope("read")
            .build()
            .unwrap();
        server.register_client(client).unwrap();
    }

    let token_resp = issue_access_token_via_auth_code(
        &server,
        "c1",
        "http://localhost/cb",
        &["read"],
        "user1",
    );

    // c2 asking to revoke a token that belongs to c1 is answered with success...
    server.revoke(&token_resp.access_token, "c2", None).unwrap();

    // ...but the token must stay usable. Without this check the test passed even
    // if any registered client could revoke any other client's tokens.
    assert!(server
        .validate_access_token(&token_resp.access_token)
        .is_some());
}
3379
#[test]
fn server_revoke_unknown_client_fails() {
    // No client named "unknown" is registered, so the revocation is rejected.
    let server = OAuthServer::with_defaults();
    let rejection = server.revoke("some-token", "unknown", None).unwrap_err();
    assert_eq!(rejection.error_code(), "invalid_client");
}
3386
#[test]
fn server_unregister_client_removes_auth_codes() {
    let server = OAuthServer::with_defaults();
    let registration = OAuthClient::builder("c1")
        .redirect_uri("http://localhost/cb")
        .scope("read")
        .build()
        .unwrap();
    server.register_client(registration).unwrap();

    // Authorize without redeeming, so the code stays pending in server state.
    let pkce_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
    let authorization = AuthorizationRequest {
        response_type: String::from("code"),
        client_id: String::from("c1"),
        redirect_uri: String::from("http://localhost/cb"),
        scopes: vec![String::from("read")],
        state: None,
        code_challenge: String::from(pkce_verifier),
        code_challenge_method: CodeChallengeMethod::Plain,
    };
    let (code, _) = server.authorize(&authorization, None).unwrap();

    // Each call takes a short-lived read lock; the guard is released before the
    // closure returns, so it does not collide with `unregister_client` below.
    let code_is_stored = || {
        server
            .state
            .read()
            .unwrap()
            .authorization_codes
            .contains_key(&code)
    };

    assert!(code_is_stored());
    server.unregister_client("c1").unwrap();
    assert!(!code_is_stored());
}
3433
#[test]
fn server_with_defaults_is_valid() {
    // `with_defaults` exposes the default issuer and allows public clients.
    let server = OAuthServer::with_defaults();
    let config = server.config();
    assert_eq!(config.issuer, "fastmcp");
    assert!(config.allow_public_clients);
}
3444
#[test]
fn server_get_client_none_for_unknown() {
    // Looking up an id that was never registered yields nothing.
    let server = OAuthServer::with_defaults();
    let lookup = server.get_client("nonexistent");
    assert!(lookup.is_none());
}
3450
#[test]
fn server_validate_access_token_after_revoke() {
    let server = OAuthServer::with_defaults();
    let registration = OAuthClient::builder("c1")
        .redirect_uri("http://localhost/cb")
        .scope("read")
        .build()
        .unwrap();
    server.register_client(registration).unwrap();

    let bearer =
        issue_access_token_via_auth_code(&server, "c1", "http://localhost/cb", &["read"], "user1")
            .access_token;

    // Valid until its owning client revokes it, invalid afterwards.
    assert!(server.validate_access_token(&bearer).is_some());
    server.revoke(&bearer, "c1", None).unwrap();
    assert!(server.validate_access_token(&bearer).is_none());
}
3473
#[test]
fn token_verifier_claims_contain_client_id_and_issuer() {
    let server = Arc::new(OAuthServer::with_defaults());
    server
        .register_client(
            OAuthClient::builder("my-app")
                .redirect_uri("http://localhost/cb")
                .scope("read")
                .build()
                .unwrap(),
        )
        .unwrap();

    let issued = issue_access_token_via_auth_code(
        server.as_ref(),
        "my-app",
        "http://localhost/cb",
        &["read"],
        "user42",
    );

    // Run the issued token through the server's TokenVerifier, as the MCP auth
    // layer would.
    let token_verifier = server.token_verifier();
    let context = McpContext::new(asupersync::Cx::for_testing(), 1);
    let request = AuthRequest {
        request_id: 1,
        method: "test",
        params: None,
    };
    let bearer = AccessToken {
        token: issued.access_token,
        scheme: String::from("Bearer"),
    };
    let auth_context = token_verifier.verify(&context, request, &bearer).unwrap();

    // The claims identify both the client and the issuing server.
    let claims = auth_context.claims.unwrap();
    assert_eq!(claims["client_id"], "my-app");
    assert_eq!(claims["iss"], "fastmcp");
}
3511
#[test]
fn oauth_token_expires_in_secs_positive() {
    // Use one reference instant so issued_at and expires_at are exactly an hour
    // apart, instead of sampling the clock twice.
    let now = Instant::now();
    let token = OAuthToken {
        token: "t".to_string(),
        token_type: TokenType::Bearer,
        client_id: "c".to_string(),
        scopes: vec![],
        issued_at: now,
        expires_at: now + Duration::from_secs(3600),
        subject: None,
        is_refresh_token: false,
    };

    let remaining = token.expires_in_secs();
    // A freshly created token has time left, but never more than the hour granted.
    assert!(remaining > 0);
    assert!(remaining <= 3600);
}
3527
#[test]
fn server_refresh_token_confidential_client_auth_fails() {
    let server = OAuthServer::with_defaults();
    let registration = OAuthClient::builder("c1")
        .secret("correct-secret")
        .redirect_uri("http://localhost/cb")
        .scope("read")
        .build()
        .unwrap();
    server.register_client(registration).unwrap();

    let pkce_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
    let authorization = AuthorizationRequest {
        response_type: String::from("code"),
        client_id: String::from("c1"),
        redirect_uri: String::from("http://localhost/cb"),
        scopes: vec![String::from("read")],
        state: None,
        code_challenge: String::from(pkce_verifier),
        code_challenge_method: CodeChallengeMethod::Plain,
    };
    let (auth_code, _) = server.authorize(&authorization, None).unwrap();

    // Step 1: a properly authenticated code exchange yields a refresh token.
    let exchange = TokenRequest {
        grant_type: String::from("authorization_code"),
        code: Some(auth_code),
        redirect_uri: Some(String::from("http://localhost/cb")),
        client_id: String::from("c1"),
        client_secret: Some(String::from("correct-secret")),
        code_verifier: Some(String::from(pkce_verifier)),
        refresh_token: None,
        scopes: None,
    };
    let refresh = server.token(&exchange).unwrap().refresh_token.unwrap();

    // Step 2: redeeming that refresh token with the wrong secret is refused.
    let refresh_grant = TokenRequest {
        grant_type: String::from("refresh_token"),
        code: None,
        redirect_uri: None,
        client_id: String::from("c1"),
        client_secret: Some(String::from("wrong-secret")),
        code_verifier: None,
        refresh_token: Some(refresh),
        scopes: None,
    };
    let err = server.token(&refresh_grant).unwrap_err();

    assert_eq!(err.error_code(), "invalid_client");
}
3588
#[test]
fn code_challenge_method_parse_unknown() {
    // Unsupported or empty method names do not parse.
    for name in ["sha512", ""] {
        assert!(CodeChallengeMethod::parse(name).is_none());
    }
}
3594
#[test]
fn constant_time_eq_different_lengths() {
    // Inputs of different lengths never compare equal.
    for (left, right) in [("short", "longer_string"), ("", "a")] {
        assert!(!constant_time_eq(left, right));
    }
}
3600
#[test]
fn constant_time_eq_empty_strings() {
    // Two empty inputs compare equal.
    let empty = "";
    assert!(constant_time_eq(empty, empty));
}
3605
#[test]
fn localhost_match_different_localhost_variants() {
    // A registered `localhost` URI matches IPv4 and IPv6 loopback spellings on
    // other ports, as long as the path is the same.
    let registered = "http://localhost:3000/cb";
    for presented in ["http://127.0.0.1:8080/cb", "http://[::1]:9000/cb"] {
        assert!(localhost_match(registered, presented));
    }
}
3619
#[test]
fn url_encode_empty_and_unicode() {
    // Empty input stays empty.
    assert_eq!(url_encode(""), "");

    // Non-ASCII input comes out percent-encoded.
    let umlaut = url_encode("ü");
    assert!(umlaut.contains('%'));
}
3627
#[test]
fn server_revoke_confidential_client_wrong_secret() {
    let server = OAuthServer::with_defaults();
    server
        .register_client(
            OAuthClient::builder("c1")
                .secret("correct")
                .redirect_uri("http://localhost/cb")
                .scope("read")
                .build()
                .unwrap(),
        )
        .unwrap();

    // A confidential client that presents the wrong secret cannot revoke anything.
    let rejection = server.revoke("any-token", "c1", Some("wrong")).unwrap_err();
    assert_eq!(rejection.error_code(), "invalid_client");
}
3646
#[test]
fn server_validate_access_token_expired_returns_none() {
    // Access tokens live for 1ms; everything else keeps its default lifetime.
    let server = OAuthServer::new(OAuthServerConfig {
        access_token_lifetime: Duration::from_millis(1),
        ..OAuthServerConfig::default()
    });
    let registration = OAuthClient::builder("c1")
        .redirect_uri("http://localhost/cb")
        .scope("read")
        .build()
        .unwrap();
    server.register_client(registration).unwrap();

    let bearer =
        issue_access_token_via_auth_code(&server, "c1", "http://localhost/cb", &["read"], "user1")
            .access_token;

    // Wait past the 1ms lifetime; the token must then fail validation.
    std::thread::sleep(Duration::from_millis(5));
    assert!(server.validate_access_token(&bearer).is_none());
}
3672
#[test]
fn server_authorize_without_state_omits_state_from_redirect() {
    let server = OAuthServer::with_defaults();
    server
        .register_client(
            OAuthClient::builder("c1")
                .redirect_uri("http://localhost/cb")
                .build()
                .unwrap(),
        )
        .unwrap();

    let stateless = AuthorizationRequest {
        response_type: String::from("code"),
        client_id: String::from("c1"),
        redirect_uri: String::from("http://localhost/cb"),
        scopes: Vec::new(),
        state: None,
        code_challenge: String::from("chal"),
        code_challenge_method: CodeChallengeMethod::Plain,
    };
    let (_, redirect) = server.authorize(&stateless, None).unwrap();

    // The code is appended, but no `state=` parameter is invented.
    assert!(redirect.contains("code="));
    assert!(!redirect.contains("state="));
}
3700
#[test]
fn server_refresh_token_client_deleted_after_issue() {
    let server = OAuthServer::with_defaults();
    let registration = OAuthClient::builder("c1")
        .redirect_uri("http://localhost/cb")
        .scope("read")
        .build()
        .unwrap();
    server.register_client(registration).unwrap();

    let refresh =
        issue_access_token_via_auth_code(&server, "c1", "http://localhost/cb", &["read"], "user1")
            .refresh_token
            .unwrap();

    // Removing the client must invalidate the refresh token issued to it.
    server.unregister_client("c1").unwrap();

    let orphaned = TokenRequest {
        grant_type: String::from("refresh_token"),
        client_id: String::from("c1"),
        refresh_token: Some(refresh),
        code: None,
        redirect_uri: None,
        client_secret: None,
        code_verifier: None,
        scopes: None,
    };
    let err = server.token(&orphaned).unwrap_err();

    assert_eq!(err.error_code(), "invalid_grant");
}
3739
#[test]
fn server_issue_tokens_empty_scopes_returns_no_scope() {
    let server = OAuthServer::with_defaults();
    server
        .register_client(
            OAuthClient::builder("c1")
                .redirect_uri("http://localhost/cb")
                .build()
                .unwrap(),
        )
        .unwrap();

    // With no scopes requested, the token response omits `scope` entirely.
    let issued =
        issue_access_token_via_auth_code(&server, "c1", "http://localhost/cb", &[], "user1");
    assert_eq!(issued.scope, None);
}
3754
#[test]
fn server_revoke_refresh_token_specifically() {
    let server = OAuthServer::with_defaults();
    let registration = OAuthClient::builder("c1")
        .redirect_uri("http://localhost/cb")
        .scope("read")
        .build()
        .unwrap();
    server.register_client(registration).unwrap();

    let refresh =
        issue_access_token_via_auth_code(&server, "c1", "http://localhost/cb", &["read"], "user1")
            .refresh_token
            .unwrap();

    server.revoke(&refresh, "c1", None).unwrap();

    // The revoked refresh token is recorded in the server's revocation set. The
    // read guard is a temporary and is dropped at the end of this statement.
    let recorded = server.state.read().unwrap().revoked_tokens.contains(&refresh);
    assert!(recorded);
}
3783
#[test]
fn localhost_match_no_explicit_port() {
    // A port-less registered loopback URI matches both a request on an explicit
    // port and an identical port-less URI.
    let registered = "http://localhost/callback";
    for presented in ["http://localhost:8080/callback", "http://localhost/callback"] {
        assert!(localhost_match(registered, presented));
    }
}
3795}