1use crate::{McpError, McpTransport, Result};
26use async_trait::async_trait;
27use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
28use serde::{Deserialize, Serialize};
29use serde_json::Value;
30use std::collections::HashMap;
31use std::str::FromStr;
32
/// The authentication scheme used when talking to an MCP server over HTTP.
///
/// Each variant carries the credentials for its scheme; see
/// [`AuthConfig::build_headers`] for how each one is turned into request headers.
///
/// NOTE(review): the derived `Debug` includes the wrapped credentials in clear text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AuthMethod {
    /// No authentication headers are sent.
    None,
    /// An API key placed in a caller-chosen header.
    ApiKey(ApiKeyAuth),
    /// HTTP Basic authentication (`Authorization: Basic ...`).
    Basic(BasicAuth),
    /// A static bearer token (`Authorization: Bearer ...`).
    Bearer(BearerAuth),
    /// An arbitrary set of caller-supplied headers.
    CustomHeader(CustomHeaderAuth),
    /// OAuth 2.0; the current access token, if any, is sent as a bearer token.
    OAuth2(OAuth2Auth),
}
49
/// API-key authentication: a secret sent (optionally prefixed) in a named header.
///
/// NOTE(review): the derived `Debug` prints `api_key` in clear text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiKeyAuth {
    /// Name of the header that carries the key, e.g. `X-API-Key`.
    pub header_name: String,
    /// The secret key value.
    pub api_key: String,
    /// Optional word placed before the key, separated from it by a single space.
    pub prefix: Option<String>,
}
60
61impl ApiKeyAuth {
62 pub fn new(header_name: impl Into<String>, api_key: impl Into<String>) -> Self {
64 Self {
65 header_name: header_name.into(),
66 api_key: api_key.into(),
67 prefix: None,
68 }
69 }
70
71 pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
73 self.prefix = Some(prefix.into());
74 self
75 }
76
77 pub fn header_value(&self) -> String {
79 match &self.prefix {
80 Some(prefix) => format!("{} {}", prefix, self.api_key),
81 None => self.api_key.clone(),
82 }
83 }
84}
85
/// HTTP Basic credentials: `username:password`, base64-encoded into the header.
///
/// NOTE(review): the derived `Debug` prints `password` in clear text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BasicAuth {
    /// User name; placed before the `:` separator.
    pub username: String,
    /// Password; placed after the `:` separator.
    pub password: String,
}
94
95impl BasicAuth {
96 pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
98 Self {
99 username: username.into(),
100 password: password.into(),
101 }
102 }
103
104 pub fn encoded_credentials(&self) -> String {
106 use base64::Engine;
107 let credentials = format!("{}:{}", self.username, self.password);
108 base64::engine::general_purpose::STANDARD.encode(credentials)
109 }
110
111 pub fn header_value(&self) -> String {
113 format!("Basic {}", self.encoded_credentials())
114 }
115}
116
/// A static bearer token, sent as `Authorization: Bearer <token>`.
///
/// NOTE(review): the derived `Debug` prints `token` in clear text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BearerAuth {
    /// The raw token, without the `Bearer ` prefix.
    pub token: String,
}
123
124impl BearerAuth {
125 pub fn new(token: impl Into<String>) -> Self {
127 Self {
128 token: token.into(),
129 }
130 }
131
132 pub fn header_value(&self) -> String {
134 format!("Bearer {}", self.token)
135 }
136}
137
/// Arbitrary authentication headers, inserted as given by `AuthConfig::build_headers`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomHeaderAuth {
    /// Header name to header value. Each name and value must be valid for HTTP,
    /// otherwise `AuthConfig::build_headers` returns an error.
    pub headers: HashMap<String, String>,
}
144
145impl CustomHeaderAuth {
146 pub fn new() -> Self {
148 Self {
149 headers: HashMap::new(),
150 }
151 }
152
153 pub fn add_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
155 self.headers.insert(name.into(), value.into());
156 self
157 }
158}
159
160impl Default for CustomHeaderAuth {
161 fn default() -> Self {
162 Self::new()
163 }
164}
165
/// OAuth 2.0 grant type used when calling the token endpoint.
///
/// `OAuth2Auth::request_token` sends these as the `grant_type` form values
/// `client_credentials`, `authorization_code` and `refresh_token` respectively.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum OAuth2GrantType {
    /// The client authenticates with its own `client_id` / `client_secret`.
    ClientCredentials,
    /// Exchange an authorization code (optionally with a PKCE verifier) for tokens.
    AuthorizationCode,
    /// Exchange a refresh token for a new access token.
    RefreshToken,
}
176
/// OAuth 2.0 client configuration plus the most recently obtained tokens.
///
/// NOTE(review): the derived `Debug` prints `client_secret` and the tokens in clear text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuth2Auth {
    /// Token endpoint URL; token requests are POSTed here as form data.
    pub token_url: String,
    /// Client identifier, sent as `client_id` on every token request.
    pub client_id: String,
    /// Client secret, sent as `client_secret` when present.
    pub client_secret: Option<String>,
    /// Grant used by `request_token`.
    pub grant_type: OAuth2GrantType,
    /// Current access token; sent as `Authorization: Bearer <token>` when set.
    pub access_token: Option<String>,
    /// Refresh token, used by `refresh_if_needed`.
    pub refresh_token: Option<String>,
    /// Access-token expiry as Unix time in seconds; `None` means the token is treated
    /// as not expired.
    pub expires_at: Option<i64>,
    /// Scopes sent (space-separated) with the client-credentials grant.
    pub scopes: Vec<String>,
    /// PKCE code verifier, sent as `code_verifier` with the authorization-code grant.
    pub code_verifier: Option<String>,
    /// Authorization code to exchange, sent as `code` with the authorization-code grant.
    pub authorization_code: Option<String>,
}
201
202impl OAuth2Auth {
203 pub fn client_credentials(
205 token_url: impl Into<String>,
206 client_id: impl Into<String>,
207 client_secret: impl Into<String>,
208 ) -> Self {
209 Self {
210 token_url: token_url.into(),
211 client_id: client_id.into(),
212 client_secret: Some(client_secret.into()),
213 grant_type: OAuth2GrantType::ClientCredentials,
214 access_token: None,
215 refresh_token: None,
216 expires_at: None,
217 scopes: Vec::new(),
218 code_verifier: None,
219 authorization_code: None,
220 }
221 }
222
223 pub fn authorization_code(
225 token_url: impl Into<String>,
226 client_id: impl Into<String>,
227 client_secret: Option<String>,
228 authorization_code: impl Into<String>,
229 ) -> Self {
230 Self {
231 token_url: token_url.into(),
232 client_id: client_id.into(),
233 client_secret,
234 grant_type: OAuth2GrantType::AuthorizationCode,
235 access_token: None,
236 refresh_token: None,
237 expires_at: None,
238 scopes: Vec::new(),
239 code_verifier: None,
240 authorization_code: Some(authorization_code.into()),
241 }
242 }
243
244 pub fn with_tokens(
246 token_url: impl Into<String>,
247 client_id: impl Into<String>,
248 client_secret: Option<String>,
249 access_token: impl Into<String>,
250 refresh_token: impl Into<String>,
251 ) -> Self {
252 Self {
253 token_url: token_url.into(),
254 client_id: client_id.into(),
255 client_secret,
256 grant_type: OAuth2GrantType::RefreshToken,
257 access_token: Some(access_token.into()),
258 refresh_token: Some(refresh_token.into()),
259 expires_at: None,
260 scopes: Vec::new(),
261 code_verifier: None,
262 authorization_code: None,
263 }
264 }
265
266 pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
268 self.scopes = scopes;
269 self
270 }
271
272 pub fn with_pkce(mut self, code_verifier: impl Into<String>) -> Self {
274 self.code_verifier = Some(code_verifier.into());
275 self
276 }
277
278 pub fn is_token_expired(&self) -> bool {
280 if let Some(expires_at) = self.expires_at {
281 let now = std::time::SystemTime::now()
282 .duration_since(std::time::UNIX_EPOCH)
283 .unwrap()
284 .as_secs() as i64;
285 now >= expires_at - 60
287 } else {
288 false
290 }
291 }
292
293 pub fn header_value(&self) -> Option<String> {
295 self.access_token
296 .as_ref()
297 .map(|token| format!("Bearer {}", token))
298 }
299
300 pub async fn request_token(&mut self) -> Result<()> {
302 let client = reqwest::Client::new();
303 let mut params: Vec<(String, String)> = Vec::new();
304
305 match self.grant_type {
306 OAuth2GrantType::ClientCredentials => {
307 params.push(("grant_type".to_string(), "client_credentials".to_string()));
308 params.push(("client_id".to_string(), self.client_id.clone()));
309 if let Some(ref secret) = self.client_secret {
310 params.push(("client_secret".to_string(), secret.clone()));
311 }
312 if !self.scopes.is_empty() {
313 params.push(("scope".to_string(), self.scopes.join(" ")));
314 }
315 }
316 OAuth2GrantType::AuthorizationCode => {
317 params.push(("grant_type".to_string(), "authorization_code".to_string()));
318 params.push(("client_id".to_string(), self.client_id.clone()));
319 if let Some(ref secret) = self.client_secret {
320 params.push(("client_secret".to_string(), secret.clone()));
321 }
322 if let Some(ref code) = self.authorization_code {
323 params.push(("code".to_string(), code.clone()));
324 }
325 if let Some(ref verifier) = self.code_verifier {
326 params.push(("code_verifier".to_string(), verifier.clone()));
327 }
328 }
329 OAuth2GrantType::RefreshToken => {
330 params.push(("grant_type".to_string(), "refresh_token".to_string()));
331 params.push(("client_id".to_string(), self.client_id.clone()));
332 if let Some(ref secret) = self.client_secret {
333 params.push(("client_secret".to_string(), secret.clone()));
334 }
335 if let Some(ref refresh) = self.refresh_token {
336 params.push(("refresh_token".to_string(), refresh.clone()));
337 }
338 }
339 }
340
341 let response = client
342 .post(&self.token_url)
343 .form(¶ms)
344 .send()
345 .await
346 .map_err(|e| McpError::ServerError(format!("OAuth2 token request failed: {}", e)))?;
347
348 if !response.status().is_success() {
349 return Err(McpError::ServerError(format!(
350 "OAuth2 token request failed with status: {}",
351 response.status()
352 )));
353 }
354
355 let token_response: Value = response.json().await.map_err(|e| {
356 McpError::ProtocolError(format!("Failed to parse OAuth2 token response: {}", e))
357 })?;
358
359 if let Some(access_token) = token_response.get("access_token").and_then(|v| v.as_str()) {
361 self.access_token = Some(access_token.to_string());
362 } else {
363 return Err(McpError::ProtocolError(
364 "OAuth2 response missing access_token".to_string(),
365 ));
366 }
367
368 if let Some(refresh_token) = token_response.get("refresh_token").and_then(|v| v.as_str()) {
370 self.refresh_token = Some(refresh_token.to_string());
371 }
372
373 if let Some(expires_in) = token_response.get("expires_in").and_then(|v| v.as_i64()) {
375 let now = std::time::SystemTime::now()
376 .duration_since(std::time::UNIX_EPOCH)
377 .unwrap()
378 .as_secs() as i64;
379 self.expires_at = Some(now + expires_in);
380 }
381
382 Ok(())
383 }
384
385 pub async fn refresh_if_needed(&mut self) -> Result<bool> {
387 if self.is_token_expired() && self.refresh_token.is_some() {
388 let original_grant = self.grant_type.clone();
390 self.grant_type = OAuth2GrantType::RefreshToken;
391
392 let result = self.request_token().await;
393
394 self.grant_type = original_grant;
396
397 result?;
398 Ok(true) } else {
400 Ok(false) }
402 }
403}
404
/// Authentication settings for one MCP server: the scheme plus a list of scopes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthConfig {
    /// How credentials are attached to requests.
    pub method: AuthMethod,
    /// Scopes associated with this configuration.
    ///
    /// NOTE(review): nothing in this module reads this field; in particular it is not
    /// forwarded to `OAuth2Auth::scopes`. Confirm whether callers rely on it.
    pub scopes: Vec<String>,
}
413
414impl AuthConfig {
415 pub fn none() -> Self {
417 Self {
418 method: AuthMethod::None,
419 scopes: Vec::new(),
420 }
421 }
422
423 pub fn api_key(header_name: impl Into<String>, api_key: impl Into<String>) -> Self {
425 Self {
426 method: AuthMethod::ApiKey(ApiKeyAuth::new(header_name, api_key)),
427 scopes: Vec::new(),
428 }
429 }
430
431 pub fn api_key_with_prefix(
433 header_name: impl Into<String>,
434 api_key: impl Into<String>,
435 prefix: impl Into<String>,
436 ) -> Self {
437 Self {
438 method: AuthMethod::ApiKey(ApiKeyAuth::new(header_name, api_key).with_prefix(prefix)),
439 scopes: Vec::new(),
440 }
441 }
442
443 pub fn basic(username: impl Into<String>, password: impl Into<String>) -> Self {
445 Self {
446 method: AuthMethod::Basic(BasicAuth::new(username, password)),
447 scopes: Vec::new(),
448 }
449 }
450
451 pub fn bearer_token(token: impl Into<String>) -> Self {
453 Self {
454 method: AuthMethod::Bearer(BearerAuth::new(token)),
455 scopes: Vec::new(),
456 }
457 }
458
459 pub fn custom_headers(headers: HashMap<String, String>) -> Self {
461 Self {
462 method: AuthMethod::CustomHeader(CustomHeaderAuth { headers }),
463 scopes: Vec::new(),
464 }
465 }
466
467 pub fn oauth2_client_credentials(
469 token_url: impl Into<String>,
470 client_id: impl Into<String>,
471 client_secret: impl Into<String>,
472 ) -> Self {
473 Self {
474 method: AuthMethod::OAuth2(OAuth2Auth::client_credentials(
475 token_url,
476 client_id,
477 client_secret,
478 )),
479 scopes: Vec::new(),
480 }
481 }
482
483 pub fn oauth2_authorization_code(
485 token_url: impl Into<String>,
486 client_id: impl Into<String>,
487 client_secret: Option<String>,
488 authorization_code: impl Into<String>,
489 ) -> Self {
490 Self {
491 method: AuthMethod::OAuth2(OAuth2Auth::authorization_code(
492 token_url,
493 client_id,
494 client_secret,
495 authorization_code,
496 )),
497 scopes: Vec::new(),
498 }
499 }
500
501 pub fn oauth2_with_tokens(
503 token_url: impl Into<String>,
504 client_id: impl Into<String>,
505 client_secret: Option<String>,
506 access_token: impl Into<String>,
507 refresh_token: impl Into<String>,
508 ) -> Self {
509 Self {
510 method: AuthMethod::OAuth2(OAuth2Auth::with_tokens(
511 token_url,
512 client_id,
513 client_secret,
514 access_token,
515 refresh_token,
516 )),
517 scopes: Vec::new(),
518 }
519 }
520
521 pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
523 self.scopes = scopes;
524 self
525 }
526
527 pub fn build_headers(&self) -> Result<HeaderMap> {
529 let mut headers = HeaderMap::new();
530
531 match &self.method {
532 AuthMethod::None => {}
533 AuthMethod::ApiKey(auth) => {
534 let header_name = HeaderName::from_str(&auth.header_name)
535 .map_err(|e| McpError::InvalidRequest(format!("Invalid header name: {}", e)))?;
536 let header_value = HeaderValue::from_str(&auth.header_value()).map_err(|e| {
537 McpError::InvalidRequest(format!("Invalid header value: {}", e))
538 })?;
539 headers.insert(header_name, header_value);
540 }
541 AuthMethod::Basic(auth) => {
542 let header_value = HeaderValue::from_str(&auth.header_value()).map_err(|e| {
543 McpError::InvalidRequest(format!("Invalid header value: {}", e))
544 })?;
545 headers.insert(reqwest::header::AUTHORIZATION, header_value);
546 }
547 AuthMethod::Bearer(auth) => {
548 let header_value = HeaderValue::from_str(&auth.header_value()).map_err(|e| {
549 McpError::InvalidRequest(format!("Invalid header value: {}", e))
550 })?;
551 headers.insert(reqwest::header::AUTHORIZATION, header_value);
552 }
553 AuthMethod::CustomHeader(auth) => {
554 for (name, value) in &auth.headers {
555 let header_name = HeaderName::from_str(name).map_err(|e| {
556 McpError::InvalidRequest(format!("Invalid header name: {}", e))
557 })?;
558 let header_value = HeaderValue::from_str(value).map_err(|e| {
559 McpError::InvalidRequest(format!("Invalid header value: {}", e))
560 })?;
561 headers.insert(header_name, header_value);
562 }
563 }
564 AuthMethod::OAuth2(auth) => {
565 if let Some(header_value_str) = auth.header_value() {
566 let header_value = HeaderValue::from_str(&header_value_str).map_err(|e| {
567 McpError::InvalidRequest(format!("Invalid header value: {}", e))
568 })?;
569 headers.insert(reqwest::header::AUTHORIZATION, header_value);
570 }
571 }
572 }
573
574 Ok(headers)
575 }
576}
577
578impl Default for AuthConfig {
579 fn default() -> Self {
580 Self::none()
581 }
582}
583
/// An [`McpTransport`] that POSTs JSON-RPC requests to a single HTTP endpoint, with
/// authentication headers derived from an [`AuthConfig`].
pub struct AuthenticatedHttpTransport {
    /// HTTP client; the headers from `AuthConfig::build_headers` are installed as its
    /// default headers when it is built.
    client: reqwest::Client,
    /// Endpoint every request is POSTed to.
    base_url: String,
    /// Authentication configuration the client was built from.
    auth: AuthConfig,
    /// JSON-RPC `id` assigned to the next request; starts at 1.
    request_id: u64,
    /// Response-size limit in bytes, set via `with_max_response_size`
    /// (defaults to `DEFAULT_MAX_RESPONSE_SIZE`, 10 MiB).
    max_response_size: usize,
}
592
593impl AuthenticatedHttpTransport {
594 const DEFAULT_MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024;
596
597 pub fn new(base_url: impl Into<String>, auth: AuthConfig) -> Result<Self> {
599 let headers = auth.build_headers()?;
600
601 let client = reqwest::Client::builder()
602 .timeout(std::time::Duration::from_secs(30))
603 .default_headers(headers)
604 .build()
605 .map_err(|e| McpError::ServerError(format!("Failed to build HTTP client: {}", e)))?;
606
607 Ok(Self {
608 client,
609 base_url: base_url.into(),
610 auth,
611 request_id: 1,
612 max_response_size: Self::DEFAULT_MAX_RESPONSE_SIZE,
613 })
614 }
615
616 pub fn with_max_response_size(mut self, size: usize) -> Self {
618 self.max_response_size = size;
619 self
620 }
621
622 pub fn with_timeout(mut self, timeout: std::time::Duration) -> Result<Self> {
624 let headers = self.auth.build_headers()?;
625 self.client = reqwest::Client::builder()
626 .timeout(timeout)
627 .default_headers(headers)
628 .build()
629 .map_err(|e| McpError::ServerError(format!("Failed to rebuild HTTP client: {}", e)))?;
630 Ok(self)
631 }
632
633 pub fn auth_config(&self) -> &AuthConfig {
635 &self.auth
636 }
637}
638
639#[async_trait]
640impl McpTransport for AuthenticatedHttpTransport {
641 async fn send_request(&mut self, mut request: Value) -> Result<Value> {
642 if let Value::Object(ref mut obj) = request {
644 obj.insert("jsonrpc".to_string(), Value::String("2.0".to_string()));
645 obj.insert("id".to_string(), Value::Number(self.request_id.into()));
646 self.request_id += 1;
647 }
648
649 let response = self
650 .client
651 .post(&self.base_url)
652 .json(&request)
653 .send()
654 .await
655 .map_err(|e| McpError::ServerError(format!("HTTP request failed: {}", e)))?;
656
657 if response.status() == reqwest::StatusCode::UNAUTHORIZED {
659 return Err(McpError::ServerError(
660 "Authentication failed: Invalid or missing credentials".to_string(),
661 ));
662 }
663
664 if response.status() == reqwest::StatusCode::FORBIDDEN {
665 return Err(McpError::ServerError(
666 "Authorization failed: Insufficient permissions".to_string(),
667 ));
668 }
669
670 let response_json: Value = response
671 .json()
672 .await
673 .map_err(|e| McpError::ProtocolError(format!("Failed to parse response: {}", e)))?;
674
675 Ok(response_json)
676 }
677
678 async fn close(&mut self) -> Result<()> {
679 Ok(())
680 }
681}
682
/// In-memory map from a server identifier to that server's [`AuthConfig`].
#[derive(Debug, Clone, Default)]
pub struct CredentialStore {
    /// Credentials keyed by server id.
    credentials: HashMap<String, AuthConfig>,
}
688
689impl CredentialStore {
690 pub fn new() -> Self {
692 Self {
693 credentials: HashMap::new(),
694 }
695 }
696
697 pub fn add(&mut self, server_id: impl Into<String>, auth: AuthConfig) {
699 self.credentials.insert(server_id.into(), auth);
700 }
701
702 pub fn get(&self, server_id: &str) -> Option<&AuthConfig> {
704 self.credentials.get(server_id)
705 }
706
707 pub fn remove(&mut self, server_id: &str) -> Option<AuthConfig> {
709 self.credentials.remove(server_id)
710 }
711
712 pub fn has(&self, server_id: &str) -> bool {
714 self.credentials.contains_key(server_id)
715 }
716
717 pub fn server_ids(&self) -> Vec<&String> {
719 self.credentials.keys().collect()
720 }
721
722 pub fn len(&self) -> usize {
724 self.credentials.len()
725 }
726
727 pub fn is_empty(&self) -> bool {
729 self.credentials.is_empty()
730 }
731}
732
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_api_key_auth() {
        let auth = ApiKeyAuth::new("X-API-Key", "secret123");
        assert_eq!(auth.header_name, "X-API-Key");
        assert_eq!(auth.api_key, "secret123");
        assert_eq!(auth.header_value(), "secret123");
    }

    #[test]
    fn test_api_key_auth_with_prefix() {
        let auth = ApiKeyAuth::new("Authorization", "secret123").with_prefix("Api-Key");
        assert_eq!(auth.header_value(), "Api-Key secret123");
    }

    #[test]
    fn test_basic_auth() {
        let auth = BasicAuth::new("user", "pass");
        assert_eq!(auth.username, "user");
        assert_eq!(auth.password, "pass");
        // base64("user:pass") == "dXNlcjpwYXNz"
        assert_eq!(auth.header_value(), "Basic dXNlcjpwYXNz");
    }

    #[test]
    fn test_bearer_auth() {
        let auth = BearerAuth::new("jwt-token-here");
        assert_eq!(auth.token, "jwt-token-here");
        assert_eq!(auth.header_value(), "Bearer jwt-token-here");
    }

    #[test]
    fn test_auth_config_none() {
        let config = AuthConfig::none();
        // `matches!` only yields a bool; it has to be asserted to check anything.
        assert!(matches!(config.method, AuthMethod::None));
        assert!(config.scopes.is_empty());
    }

    #[test]
    fn test_auth_config_api_key() {
        let config = AuthConfig::api_key("X-API-Key", "secret");
        if let AuthMethod::ApiKey(auth) = config.method {
            assert_eq!(auth.header_name, "X-API-Key");
            assert_eq!(auth.api_key, "secret");
        } else {
            panic!("Expected ApiKey auth method");
        }
    }

    #[test]
    fn test_auth_config_bearer() {
        let config = AuthConfig::bearer_token("token123");
        if let AuthMethod::Bearer(auth) = config.method {
            assert_eq!(auth.token, "token123");
        } else {
            panic!("Expected Bearer auth method");
        }
    }

    #[test]
    fn test_auth_config_basic() {
        let config = AuthConfig::basic("user", "pass");
        if let AuthMethod::Basic(auth) = config.method {
            assert_eq!(auth.username, "user");
            assert_eq!(auth.password, "pass");
        } else {
            panic!("Expected Basic auth method");
        }
    }

    #[test]
    fn test_auth_config_with_scopes() {
        let config = AuthConfig::bearer_token("token")
            .with_scopes(vec!["read".to_string(), "write".to_string()]);
        assert_eq!(config.scopes.len(), 2);
        assert!(config.scopes.contains(&"read".to_string()));
        assert!(config.scopes.contains(&"write".to_string()));
    }

    #[test]
    fn test_build_headers_api_key() {
        let config = AuthConfig::api_key("X-API-Key", "secret");
        let headers = config.build_headers().unwrap();
        assert!(headers.contains_key("x-api-key"));
        assert_eq!(headers.get("x-api-key").unwrap(), "secret");
    }

    #[test]
    fn test_build_headers_api_key_with_prefix() {
        let config = AuthConfig::api_key_with_prefix("Authorization", "secret", "Api-Key");
        let headers = config.build_headers().unwrap();
        assert_eq!(headers.get("authorization").unwrap(), "Api-Key secret");
    }

    #[test]
    fn test_build_headers_basic() {
        let config = AuthConfig::basic("user", "pass");
        let headers = config.build_headers().unwrap();
        assert_eq!(headers.get("authorization").unwrap(), "Basic dXNlcjpwYXNz");
    }

    #[test]
    fn test_build_headers_bearer() {
        let config = AuthConfig::bearer_token("token123");
        let headers = config.build_headers().unwrap();
        assert!(headers.contains_key("authorization"));
        assert_eq!(headers.get("authorization").unwrap(), "Bearer token123");
    }

    #[test]
    fn test_build_headers_none_is_empty() {
        assert!(AuthConfig::none().build_headers().unwrap().is_empty());
    }

    #[test]
    fn test_build_headers_rejects_invalid_header_name() {
        // Spaces are not allowed in HTTP header names.
        let config = AuthConfig::api_key("bad header", "secret");
        assert!(config.build_headers().is_err());
    }

    #[test]
    fn test_credential_store() {
        let mut store = CredentialStore::new();
        assert!(store.is_empty());

        store.add("server1", AuthConfig::api_key("X-API-Key", "key1"));
        store.add("server2", AuthConfig::bearer_token("token2"));

        assert_eq!(store.len(), 2);
        assert!(store.has("server1"));
        assert!(!store.has("server3"));

        let auth = store.get("server1").unwrap();
        assert!(matches!(auth.method, AuthMethod::ApiKey(_)));

        assert!(store.remove("server1").is_some());
        assert!(!store.has("server1"));
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_custom_header_auth() {
        let auth = CustomHeaderAuth::new()
            .add_header("X-Custom-Header", "value1")
            .add_header("X-Another-Header", "value2");

        assert_eq!(auth.headers.len(), 2);
        assert_eq!(
            auth.headers.get("X-Custom-Header"),
            Some(&"value1".to_string())
        );
    }

    #[test]
    fn test_build_headers_custom() {
        let mut headers = HashMap::new();
        headers.insert("X-Custom-1".to_string(), "value1".to_string());
        headers.insert("X-Custom-2".to_string(), "value2".to_string());

        let config = AuthConfig::custom_headers(headers);
        let built = config.build_headers().unwrap();

        assert!(built.contains_key("x-custom-1"));
        assert!(built.contains_key("x-custom-2"));
    }

    #[test]
    fn test_oauth2_client_credentials() {
        let auth = OAuth2Auth::client_credentials(
            "https://auth.example.com/token",
            "client_id",
            "client_secret",
        );
        assert_eq!(auth.token_url, "https://auth.example.com/token");
        assert_eq!(auth.client_id, "client_id");
        assert_eq!(auth.client_secret, Some("client_secret".to_string()));
        assert_eq!(auth.grant_type, OAuth2GrantType::ClientCredentials);
        assert!(auth.access_token.is_none());
    }

    #[test]
    fn test_oauth2_authorization_code() {
        let auth = OAuth2Auth::authorization_code(
            "https://auth.example.com/token",
            "client_id",
            Some("client_secret".to_string()),
            "auth_code_123",
        );
        assert_eq!(auth.grant_type, OAuth2GrantType::AuthorizationCode);
        assert_eq!(auth.authorization_code, Some("auth_code_123".to_string()));
    }

    #[test]
    fn test_oauth2_with_tokens() {
        let auth = OAuth2Auth::with_tokens(
            "https://auth.example.com/token",
            "client_id",
            Some("client_secret".to_string()),
            "access_token_123",
            "refresh_token_456",
        );
        assert_eq!(auth.grant_type, OAuth2GrantType::RefreshToken);
        assert_eq!(auth.access_token, Some("access_token_123".to_string()));
        assert_eq!(auth.refresh_token, Some("refresh_token_456".to_string()));
    }

    #[test]
    fn test_oauth2_with_scopes() {
        let auth = OAuth2Auth::client_credentials(
            "https://auth.example.com/token",
            "client_id",
            "client_secret",
        )
        .with_scopes(vec!["read".to_string(), "write".to_string()]);

        assert_eq!(auth.scopes.len(), 2);
        assert!(auth.scopes.contains(&"read".to_string()));
    }

    #[test]
    fn test_oauth2_with_pkce() {
        let auth = OAuth2Auth::authorization_code(
            "https://auth.example.com/token",
            "client_id",
            None,
            "auth_code",
        )
        .with_pkce("code_verifier_123");

        assert_eq!(auth.code_verifier, Some("code_verifier_123".to_string()));
    }

    #[test]
    fn test_oauth2_header_value() {
        let mut auth = OAuth2Auth::client_credentials(
            "https://auth.example.com/token",
            "client_id",
            "client_secret",
        );
        assert!(auth.header_value().is_none());

        auth.access_token = Some("test_token".to_string());
        assert_eq!(auth.header_value(), Some("Bearer test_token".to_string()));
    }

    #[test]
    fn test_oauth2_is_token_expired() {
        let mut auth = OAuth2Auth::client_credentials(
            "https://auth.example.com/token",
            "client_id",
            "client_secret",
        );

        // Unknown expiry is treated as not expired.
        assert!(!auth.is_token_expired());

        // A timestamp far in the past is expired.
        auth.expires_at = Some(1000);
        assert!(auth.is_token_expired());

        let future = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64
            + 3600;
        auth.expires_at = Some(future);
        assert!(!auth.is_token_expired());
    }

    #[test]
    fn test_auth_config_oauth2_client_credentials() {
        let config = AuthConfig::oauth2_client_credentials(
            "https://auth.example.com/token",
            "client_id",
            "client_secret",
        );
        if let AuthMethod::OAuth2(auth) = config.method {
            assert_eq!(auth.grant_type, OAuth2GrantType::ClientCredentials);
        } else {
            panic!("Expected OAuth2 auth method");
        }
    }

    #[test]
    fn test_auth_config_oauth2_with_tokens() {
        let config = AuthConfig::oauth2_with_tokens(
            "https://auth.example.com/token",
            "client_id",
            Some("client_secret".to_string()),
            "access_token",
            "refresh_token",
        );
        if let AuthMethod::OAuth2(auth) = config.method {
            assert_eq!(auth.access_token, Some("access_token".to_string()));
            assert_eq!(auth.refresh_token, Some("refresh_token".to_string()));
        } else {
            panic!("Expected OAuth2 auth method");
        }
    }

    #[test]
    fn test_build_headers_oauth2() {
        let mut auth = OAuth2Auth::client_credentials(
            "https://auth.example.com/token",
            "client_id",
            "client_secret",
        );
        auth.access_token = Some("test_access_token".to_string());

        let config = AuthConfig {
            method: AuthMethod::OAuth2(auth),
            scopes: Vec::new(),
        };

        let headers = config.build_headers().unwrap();
        assert!(headers.contains_key("authorization"));
        assert_eq!(
            headers.get("authorization").unwrap(),
            "Bearer test_access_token"
        );
    }

    #[test]
    fn test_build_headers_oauth2_without_token_is_empty() {
        let config = AuthConfig::oauth2_client_credentials(
            "https://auth.example.com/token",
            "client_id",
            "client_secret",
        );
        assert!(config.build_headers().unwrap().is_empty());
    }

    #[test]
    fn test_oauth2_grant_type_equality() {
        assert_eq!(
            OAuth2GrantType::ClientCredentials,
            OAuth2GrantType::ClientCredentials
        );
        assert_ne!(
            OAuth2GrantType::ClientCredentials,
            OAuth2GrantType::AuthorizationCode
        );
    }
}