oxify_mcp/
auth.rs

1//! Authentication support for MCP communication
2//!
3//! This module provides authentication mechanisms for secure MCP server communication.
4//!
5//! # Supported Authentication Methods
6//!
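//! - **API Key**: Pass an API key in a request header, optionally with a prefix
//! - **Basic Auth**: HTTP Basic authentication (username:password)
//! - **Bearer Token**: OAuth2-style bearer token authentication
//! - **Custom Headers**: Arbitrary caller-supplied authentication headers
//! - **OAuth2**: Client-credentials, authorization-code, and refresh-token grants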
10//!
11//! # Example
12//!
//! ```ignore
//! use oxify_mcp::auth::{AuthConfig, ApiKeyAuth, AuthenticatedHttpTransport};
//!
//! // API Key authentication (`new` returns a `Result`, so propagate the error)
//! let auth = AuthConfig::api_key("X-API-Key", "your-api-key");
//! let transport = AuthenticatedHttpTransport::new("http://localhost:3000", auth)?;
//!
//! // Bearer token authentication
//! let auth = AuthConfig::bearer_token("your-jwt-token");
//! let transport = AuthenticatedHttpTransport::new("http://localhost:3000", auth)?;
//! ```
24
25use crate::{McpError, McpTransport, Result};
26use async_trait::async_trait;
27use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
28use serde::{Deserialize, Serialize};
29use serde_json::Value;
30use std::collections::HashMap;
31use std::str::FromStr;
32
33/// Authentication method for MCP servers
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub enum AuthMethod {
36    /// No authentication
37    None,
38    /// API key authentication
39    ApiKey(ApiKeyAuth),
40    /// HTTP Basic authentication
41    Basic(BasicAuth),
42    /// Bearer token authentication
43    Bearer(BearerAuth),
44    /// Custom header authentication
45    CustomHeader(CustomHeaderAuth),
46    /// OAuth2 authentication with token refresh
47    OAuth2(OAuth2Auth),
48}
49
50/// API Key authentication configuration
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct ApiKeyAuth {
53    /// Header name to use (e.g., "X-API-Key", "Authorization")
54    pub header_name: String,
55    /// The API key value
56    pub api_key: String,
57    /// Optional prefix (e.g., "Bearer", "Api-Key")
58    pub prefix: Option<String>,
59}
60
61impl ApiKeyAuth {
62    /// Create new API key authentication
63    pub fn new(header_name: impl Into<String>, api_key: impl Into<String>) -> Self {
64        Self {
65            header_name: header_name.into(),
66            api_key: api_key.into(),
67            prefix: None,
68        }
69    }
70
71    /// Add a prefix to the API key value
72    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
73        self.prefix = Some(prefix.into());
74        self
75    }
76
77    /// Get the formatted header value
78    pub fn header_value(&self) -> String {
79        match &self.prefix {
80            Some(prefix) => format!("{} {}", prefix, self.api_key),
81            None => self.api_key.clone(),
82        }
83    }
84}
85
86/// HTTP Basic authentication configuration
87#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct BasicAuth {
89    /// Username
90    pub username: String,
91    /// Password
92    pub password: String,
93}
94
95impl BasicAuth {
96    /// Create new Basic authentication
97    pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
98        Self {
99            username: username.into(),
100            password: password.into(),
101        }
102    }
103
104    /// Get the encoded credentials
105    pub fn encoded_credentials(&self) -> String {
106        use base64::Engine;
107        let credentials = format!("{}:{}", self.username, self.password);
108        base64::engine::general_purpose::STANDARD.encode(credentials)
109    }
110
111    /// Get the Authorization header value
112    pub fn header_value(&self) -> String {
113        format!("Basic {}", self.encoded_credentials())
114    }
115}
116
117/// Bearer token authentication configuration
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct BearerAuth {
120    /// The bearer token (JWT or opaque token)
121    pub token: String,
122}
123
124impl BearerAuth {
125    /// Create new Bearer authentication
126    pub fn new(token: impl Into<String>) -> Self {
127        Self {
128            token: token.into(),
129        }
130    }
131
132    /// Get the Authorization header value
133    pub fn header_value(&self) -> String {
134        format!("Bearer {}", self.token)
135    }
136}
137
138/// Custom header authentication
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct CustomHeaderAuth {
141    /// Custom headers to add
142    pub headers: HashMap<String, String>,
143}
144
145impl CustomHeaderAuth {
146    /// Create new custom header authentication
147    pub fn new() -> Self {
148        Self {
149            headers: HashMap::new(),
150        }
151    }
152
153    /// Add a header
154    pub fn add_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
155        self.headers.insert(name.into(), value.into());
156        self
157    }
158}
159
160impl Default for CustomHeaderAuth {
161    fn default() -> Self {
162        Self::new()
163    }
164}
165
166/// OAuth2 grant type
167#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
168pub enum OAuth2GrantType {
169    /// Client credentials flow (machine-to-machine)
170    ClientCredentials,
171    /// Authorization code flow (with optional PKCE)
172    AuthorizationCode,
173    /// Refresh token flow
174    RefreshToken,
175}
176
177/// OAuth2 authentication configuration with token refresh support
178#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct OAuth2Auth {
180    /// OAuth2 token endpoint URL
181    pub token_url: String,
182    /// Client ID
183    pub client_id: String,
184    /// Client secret (optional for PKCE)
185    pub client_secret: Option<String>,
186    /// Grant type
187    pub grant_type: OAuth2GrantType,
188    /// Access token
189    pub access_token: Option<String>,
190    /// Refresh token
191    pub refresh_token: Option<String>,
192    /// Token expiry time (Unix timestamp)
193    pub expires_at: Option<i64>,
194    /// Requested scopes
195    pub scopes: Vec<String>,
196    /// PKCE code verifier (for authorization code flow)
197    pub code_verifier: Option<String>,
198    /// Authorization code (for authorization code flow)
199    pub authorization_code: Option<String>,
200}
201
202impl OAuth2Auth {
203    /// Create new OAuth2 authentication with client credentials
204    pub fn client_credentials(
205        token_url: impl Into<String>,
206        client_id: impl Into<String>,
207        client_secret: impl Into<String>,
208    ) -> Self {
209        Self {
210            token_url: token_url.into(),
211            client_id: client_id.into(),
212            client_secret: Some(client_secret.into()),
213            grant_type: OAuth2GrantType::ClientCredentials,
214            access_token: None,
215            refresh_token: None,
216            expires_at: None,
217            scopes: Vec::new(),
218            code_verifier: None,
219            authorization_code: None,
220        }
221    }
222
223    /// Create new OAuth2 authentication with authorization code
224    pub fn authorization_code(
225        token_url: impl Into<String>,
226        client_id: impl Into<String>,
227        client_secret: Option<String>,
228        authorization_code: impl Into<String>,
229    ) -> Self {
230        Self {
231            token_url: token_url.into(),
232            client_id: client_id.into(),
233            client_secret,
234            grant_type: OAuth2GrantType::AuthorizationCode,
235            access_token: None,
236            refresh_token: None,
237            expires_at: None,
238            scopes: Vec::new(),
239            code_verifier: None,
240            authorization_code: Some(authorization_code.into()),
241        }
242    }
243
244    /// Create OAuth2 authentication with existing access and refresh tokens
245    pub fn with_tokens(
246        token_url: impl Into<String>,
247        client_id: impl Into<String>,
248        client_secret: Option<String>,
249        access_token: impl Into<String>,
250        refresh_token: impl Into<String>,
251    ) -> Self {
252        Self {
253            token_url: token_url.into(),
254            client_id: client_id.into(),
255            client_secret,
256            grant_type: OAuth2GrantType::RefreshToken,
257            access_token: Some(access_token.into()),
258            refresh_token: Some(refresh_token.into()),
259            expires_at: None,
260            scopes: Vec::new(),
261            code_verifier: None,
262            authorization_code: None,
263        }
264    }
265
266    /// Add scopes to the OAuth2 request
267    pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
268        self.scopes = scopes;
269        self
270    }
271
272    /// Set PKCE code verifier (for authorization code flow)
273    pub fn with_pkce(mut self, code_verifier: impl Into<String>) -> Self {
274        self.code_verifier = Some(code_verifier.into());
275        self
276    }
277
278    /// Check if the token is expired
279    pub fn is_token_expired(&self) -> bool {
280        if let Some(expires_at) = self.expires_at {
281            let now = std::time::SystemTime::now()
282                .duration_since(std::time::UNIX_EPOCH)
283                .unwrap()
284                .as_secs() as i64;
285            // Consider token expired 60 seconds before actual expiry (buffer)
286            now >= expires_at - 60
287        } else {
288            // If no expiry time set, assume not expired
289            false
290        }
291    }
292
293    /// Get the Authorization header value
294    pub fn header_value(&self) -> Option<String> {
295        self.access_token
296            .as_ref()
297            .map(|token| format!("Bearer {}", token))
298    }
299
300    /// Request a new access token from the OAuth2 server
301    pub async fn request_token(&mut self) -> Result<()> {
302        let client = reqwest::Client::new();
303        let mut params: Vec<(String, String)> = Vec::new();
304
305        match self.grant_type {
306            OAuth2GrantType::ClientCredentials => {
307                params.push(("grant_type".to_string(), "client_credentials".to_string()));
308                params.push(("client_id".to_string(), self.client_id.clone()));
309                if let Some(ref secret) = self.client_secret {
310                    params.push(("client_secret".to_string(), secret.clone()));
311                }
312                if !self.scopes.is_empty() {
313                    params.push(("scope".to_string(), self.scopes.join(" ")));
314                }
315            }
316            OAuth2GrantType::AuthorizationCode => {
317                params.push(("grant_type".to_string(), "authorization_code".to_string()));
318                params.push(("client_id".to_string(), self.client_id.clone()));
319                if let Some(ref secret) = self.client_secret {
320                    params.push(("client_secret".to_string(), secret.clone()));
321                }
322                if let Some(ref code) = self.authorization_code {
323                    params.push(("code".to_string(), code.clone()));
324                }
325                if let Some(ref verifier) = self.code_verifier {
326                    params.push(("code_verifier".to_string(), verifier.clone()));
327                }
328            }
329            OAuth2GrantType::RefreshToken => {
330                params.push(("grant_type".to_string(), "refresh_token".to_string()));
331                params.push(("client_id".to_string(), self.client_id.clone()));
332                if let Some(ref secret) = self.client_secret {
333                    params.push(("client_secret".to_string(), secret.clone()));
334                }
335                if let Some(ref refresh) = self.refresh_token {
336                    params.push(("refresh_token".to_string(), refresh.clone()));
337                }
338            }
339        }
340
341        let response = client
342            .post(&self.token_url)
343            .form(&params)
344            .send()
345            .await
346            .map_err(|e| McpError::ServerError(format!("OAuth2 token request failed: {}", e)))?;
347
348        if !response.status().is_success() {
349            return Err(McpError::ServerError(format!(
350                "OAuth2 token request failed with status: {}",
351                response.status()
352            )));
353        }
354
355        let token_response: Value = response.json().await.map_err(|e| {
356            McpError::ProtocolError(format!("Failed to parse OAuth2 token response: {}", e))
357        })?;
358
359        // Extract access token
360        if let Some(access_token) = token_response.get("access_token").and_then(|v| v.as_str()) {
361            self.access_token = Some(access_token.to_string());
362        } else {
363            return Err(McpError::ProtocolError(
364                "OAuth2 response missing access_token".to_string(),
365            ));
366        }
367
368        // Extract refresh token (if provided)
369        if let Some(refresh_token) = token_response.get("refresh_token").and_then(|v| v.as_str()) {
370            self.refresh_token = Some(refresh_token.to_string());
371        }
372
373        // Extract expiry time
374        if let Some(expires_in) = token_response.get("expires_in").and_then(|v| v.as_i64()) {
375            let now = std::time::SystemTime::now()
376                .duration_since(std::time::UNIX_EPOCH)
377                .unwrap()
378                .as_secs() as i64;
379            self.expires_at = Some(now + expires_in);
380        }
381
382        Ok(())
383    }
384
385    /// Refresh the access token if expired
386    pub async fn refresh_if_needed(&mut self) -> Result<bool> {
387        if self.is_token_expired() && self.refresh_token.is_some() {
388            // Temporarily switch to refresh token grant type
389            let original_grant = self.grant_type.clone();
390            self.grant_type = OAuth2GrantType::RefreshToken;
391
392            let result = self.request_token().await;
393
394            // Restore original grant type
395            self.grant_type = original_grant;
396
397            result?;
398            Ok(true) // Token was refreshed
399        } else {
400            Ok(false) // No refresh needed
401        }
402    }
403}
404
405/// Authentication configuration wrapper
406#[derive(Debug, Clone, Serialize, Deserialize)]
407pub struct AuthConfig {
408    /// The authentication method
409    pub method: AuthMethod,
410    /// Optional scopes or permissions
411    pub scopes: Vec<String>,
412}
413
414impl AuthConfig {
415    /// Create authentication configuration with no auth
416    pub fn none() -> Self {
417        Self {
418            method: AuthMethod::None,
419            scopes: Vec::new(),
420        }
421    }
422
423    /// Create API key authentication
424    pub fn api_key(header_name: impl Into<String>, api_key: impl Into<String>) -> Self {
425        Self {
426            method: AuthMethod::ApiKey(ApiKeyAuth::new(header_name, api_key)),
427            scopes: Vec::new(),
428        }
429    }
430
431    /// Create API key authentication with prefix
432    pub fn api_key_with_prefix(
433        header_name: impl Into<String>,
434        api_key: impl Into<String>,
435        prefix: impl Into<String>,
436    ) -> Self {
437        Self {
438            method: AuthMethod::ApiKey(ApiKeyAuth::new(header_name, api_key).with_prefix(prefix)),
439            scopes: Vec::new(),
440        }
441    }
442
443    /// Create Basic authentication
444    pub fn basic(username: impl Into<String>, password: impl Into<String>) -> Self {
445        Self {
446            method: AuthMethod::Basic(BasicAuth::new(username, password)),
447            scopes: Vec::new(),
448        }
449    }
450
451    /// Create Bearer token authentication
452    pub fn bearer_token(token: impl Into<String>) -> Self {
453        Self {
454            method: AuthMethod::Bearer(BearerAuth::new(token)),
455            scopes: Vec::new(),
456        }
457    }
458
459    /// Create custom header authentication
460    pub fn custom_headers(headers: HashMap<String, String>) -> Self {
461        Self {
462            method: AuthMethod::CustomHeader(CustomHeaderAuth { headers }),
463            scopes: Vec::new(),
464        }
465    }
466
467    /// Create OAuth2 authentication with client credentials
468    pub fn oauth2_client_credentials(
469        token_url: impl Into<String>,
470        client_id: impl Into<String>,
471        client_secret: impl Into<String>,
472    ) -> Self {
473        Self {
474            method: AuthMethod::OAuth2(OAuth2Auth::client_credentials(
475                token_url,
476                client_id,
477                client_secret,
478            )),
479            scopes: Vec::new(),
480        }
481    }
482
483    /// Create OAuth2 authentication with authorization code
484    pub fn oauth2_authorization_code(
485        token_url: impl Into<String>,
486        client_id: impl Into<String>,
487        client_secret: Option<String>,
488        authorization_code: impl Into<String>,
489    ) -> Self {
490        Self {
491            method: AuthMethod::OAuth2(OAuth2Auth::authorization_code(
492                token_url,
493                client_id,
494                client_secret,
495                authorization_code,
496            )),
497            scopes: Vec::new(),
498        }
499    }
500
501    /// Create OAuth2 authentication with existing tokens
502    pub fn oauth2_with_tokens(
503        token_url: impl Into<String>,
504        client_id: impl Into<String>,
505        client_secret: Option<String>,
506        access_token: impl Into<String>,
507        refresh_token: impl Into<String>,
508    ) -> Self {
509        Self {
510            method: AuthMethod::OAuth2(OAuth2Auth::with_tokens(
511                token_url,
512                client_id,
513                client_secret,
514                access_token,
515                refresh_token,
516            )),
517            scopes: Vec::new(),
518        }
519    }
520
521    /// Add scopes to authentication
522    pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
523        self.scopes = scopes;
524        self
525    }
526
527    /// Build request headers from authentication config
528    pub fn build_headers(&self) -> Result<HeaderMap> {
529        let mut headers = HeaderMap::new();
530
531        match &self.method {
532            AuthMethod::None => {}
533            AuthMethod::ApiKey(auth) => {
534                let header_name = HeaderName::from_str(&auth.header_name)
535                    .map_err(|e| McpError::InvalidRequest(format!("Invalid header name: {}", e)))?;
536                let header_value = HeaderValue::from_str(&auth.header_value()).map_err(|e| {
537                    McpError::InvalidRequest(format!("Invalid header value: {}", e))
538                })?;
539                headers.insert(header_name, header_value);
540            }
541            AuthMethod::Basic(auth) => {
542                let header_value = HeaderValue::from_str(&auth.header_value()).map_err(|e| {
543                    McpError::InvalidRequest(format!("Invalid header value: {}", e))
544                })?;
545                headers.insert(reqwest::header::AUTHORIZATION, header_value);
546            }
547            AuthMethod::Bearer(auth) => {
548                let header_value = HeaderValue::from_str(&auth.header_value()).map_err(|e| {
549                    McpError::InvalidRequest(format!("Invalid header value: {}", e))
550                })?;
551                headers.insert(reqwest::header::AUTHORIZATION, header_value);
552            }
553            AuthMethod::CustomHeader(auth) => {
554                for (name, value) in &auth.headers {
555                    let header_name = HeaderName::from_str(name).map_err(|e| {
556                        McpError::InvalidRequest(format!("Invalid header name: {}", e))
557                    })?;
558                    let header_value = HeaderValue::from_str(value).map_err(|e| {
559                        McpError::InvalidRequest(format!("Invalid header value: {}", e))
560                    })?;
561                    headers.insert(header_name, header_value);
562                }
563            }
564            AuthMethod::OAuth2(auth) => {
565                if let Some(header_value_str) = auth.header_value() {
566                    let header_value = HeaderValue::from_str(&header_value_str).map_err(|e| {
567                        McpError::InvalidRequest(format!("Invalid header value: {}", e))
568                    })?;
569                    headers.insert(reqwest::header::AUTHORIZATION, header_value);
570                }
571            }
572        }
573
574        Ok(headers)
575    }
576}
577
578impl Default for AuthConfig {
579    fn default() -> Self {
580        Self::none()
581    }
582}
583
584/// HTTP transport with authentication support
585pub struct AuthenticatedHttpTransport {
586    client: reqwest::Client,
587    base_url: String,
588    auth: AuthConfig,
589    request_id: u64,
590    max_response_size: usize,
591}
592
593impl AuthenticatedHttpTransport {
594    /// Default maximum response size (10MB)
595    const DEFAULT_MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024;
596
597    /// Create a new authenticated HTTP transport
598    pub fn new(base_url: impl Into<String>, auth: AuthConfig) -> Result<Self> {
599        let headers = auth.build_headers()?;
600
601        let client = reqwest::Client::builder()
602            .timeout(std::time::Duration::from_secs(30))
603            .default_headers(headers)
604            .build()
605            .map_err(|e| McpError::ServerError(format!("Failed to build HTTP client: {}", e)))?;
606
607        Ok(Self {
608            client,
609            base_url: base_url.into(),
610            auth,
611            request_id: 1,
612            max_response_size: Self::DEFAULT_MAX_RESPONSE_SIZE,
613        })
614    }
615
616    /// Set maximum response size
617    pub fn with_max_response_size(mut self, size: usize) -> Self {
618        self.max_response_size = size;
619        self
620    }
621
622    /// Set timeout
623    pub fn with_timeout(mut self, timeout: std::time::Duration) -> Result<Self> {
624        let headers = self.auth.build_headers()?;
625        self.client = reqwest::Client::builder()
626            .timeout(timeout)
627            .default_headers(headers)
628            .build()
629            .map_err(|e| McpError::ServerError(format!("Failed to rebuild HTTP client: {}", e)))?;
630        Ok(self)
631    }
632
633    /// Get the current authentication config
634    pub fn auth_config(&self) -> &AuthConfig {
635        &self.auth
636    }
637}
638
639#[async_trait]
640impl McpTransport for AuthenticatedHttpTransport {
641    async fn send_request(&mut self, mut request: Value) -> Result<Value> {
642        // Add JSON-RPC fields
643        if let Value::Object(ref mut obj) = request {
644            obj.insert("jsonrpc".to_string(), Value::String("2.0".to_string()));
645            obj.insert("id".to_string(), Value::Number(self.request_id.into()));
646            self.request_id += 1;
647        }
648
649        let response = self
650            .client
651            .post(&self.base_url)
652            .json(&request)
653            .send()
654            .await
655            .map_err(|e| McpError::ServerError(format!("HTTP request failed: {}", e)))?;
656
657        // Check for auth errors
658        if response.status() == reqwest::StatusCode::UNAUTHORIZED {
659            return Err(McpError::ServerError(
660                "Authentication failed: Invalid or missing credentials".to_string(),
661            ));
662        }
663
664        if response.status() == reqwest::StatusCode::FORBIDDEN {
665            return Err(McpError::ServerError(
666                "Authorization failed: Insufficient permissions".to_string(),
667            ));
668        }
669
670        let response_json: Value = response
671            .json()
672            .await
673            .map_err(|e| McpError::ProtocolError(format!("Failed to parse response: {}", e)))?;
674
675        Ok(response_json)
676    }
677
678    async fn close(&mut self) -> Result<()> {
679        Ok(())
680    }
681}
682
683/// Credential store for managing multiple API keys
684#[derive(Debug, Clone, Default)]
685pub struct CredentialStore {
686    credentials: HashMap<String, AuthConfig>,
687}
688
689impl CredentialStore {
690    /// Create a new credential store
691    pub fn new() -> Self {
692        Self {
693            credentials: HashMap::new(),
694        }
695    }
696
697    /// Add credentials for a server
698    pub fn add(&mut self, server_id: impl Into<String>, auth: AuthConfig) {
699        self.credentials.insert(server_id.into(), auth);
700    }
701
702    /// Get credentials for a server
703    pub fn get(&self, server_id: &str) -> Option<&AuthConfig> {
704        self.credentials.get(server_id)
705    }
706
707    /// Remove credentials for a server
708    pub fn remove(&mut self, server_id: &str) -> Option<AuthConfig> {
709        self.credentials.remove(server_id)
710    }
711
712    /// Check if credentials exist for a server
713    pub fn has(&self, server_id: &str) -> bool {
714        self.credentials.contains_key(server_id)
715    }
716
717    /// List all server IDs with stored credentials
718    pub fn server_ids(&self) -> Vec<&String> {
719        self.credentials.keys().collect()
720    }
721
722    /// Get the number of stored credentials
723    pub fn len(&self) -> usize {
724        self.credentials.len()
725    }
726
727    /// Check if the store is empty
728    pub fn is_empty(&self) -> bool {
729        self.credentials.is_empty()
730    }
731}
732
733#[cfg(test)]
734mod tests {
735    use super::*;
736
737    #[test]
738    fn test_api_key_auth() {
739        let auth = ApiKeyAuth::new("X-API-Key", "secret123");
740        assert_eq!(auth.header_name, "X-API-Key");
741        assert_eq!(auth.api_key, "secret123");
742        assert_eq!(auth.header_value(), "secret123");
743    }
744
745    #[test]
746    fn test_api_key_auth_with_prefix() {
747        let auth = ApiKeyAuth::new("Authorization", "secret123").with_prefix("Api-Key");
748        assert_eq!(auth.header_value(), "Api-Key secret123");
749    }
750
751    #[test]
752    fn test_basic_auth() {
753        let auth = BasicAuth::new("user", "pass");
754        assert_eq!(auth.username, "user");
755        assert_eq!(auth.password, "pass");
756        // Base64 of "user:pass" is "dXNlcjpwYXNz"
757        assert_eq!(auth.header_value(), "Basic dXNlcjpwYXNz");
758    }
759
760    #[test]
761    fn test_bearer_auth() {
762        let auth = BearerAuth::new("jwt-token-here");
763        assert_eq!(auth.token, "jwt-token-here");
764        assert_eq!(auth.header_value(), "Bearer jwt-token-here");
765    }
766
767    #[test]
768    fn test_auth_config_none() {
769        let config = AuthConfig::none();
770        matches!(config.method, AuthMethod::None);
771    }
772
773    #[test]
774    fn test_auth_config_api_key() {
775        let config = AuthConfig::api_key("X-API-Key", "secret");
776        if let AuthMethod::ApiKey(auth) = config.method {
777            assert_eq!(auth.header_name, "X-API-Key");
778            assert_eq!(auth.api_key, "secret");
779        } else {
780            panic!("Expected ApiKey auth method");
781        }
782    }
783
784    #[test]
785    fn test_auth_config_bearer() {
786        let config = AuthConfig::bearer_token("token123");
787        if let AuthMethod::Bearer(auth) = config.method {
788            assert_eq!(auth.token, "token123");
789        } else {
790            panic!("Expected Bearer auth method");
791        }
792    }
793
794    #[test]
795    fn test_auth_config_basic() {
796        let config = AuthConfig::basic("user", "pass");
797        if let AuthMethod::Basic(auth) = config.method {
798            assert_eq!(auth.username, "user");
799            assert_eq!(auth.password, "pass");
800        } else {
801            panic!("Expected Basic auth method");
802        }
803    }
804
805    #[test]
806    fn test_auth_config_with_scopes() {
807        let config = AuthConfig::bearer_token("token")
808            .with_scopes(vec!["read".to_string(), "write".to_string()]);
809        assert_eq!(config.scopes.len(), 2);
810        assert!(config.scopes.contains(&"read".to_string()));
811        assert!(config.scopes.contains(&"write".to_string()));
812    }
813
814    #[test]
815    fn test_build_headers_api_key() {
816        let config = AuthConfig::api_key("X-API-Key", "secret");
817        let headers = config.build_headers().unwrap();
818        assert!(headers.contains_key("x-api-key"));
819        assert_eq!(headers.get("x-api-key").unwrap(), "secret");
820    }
821
822    #[test]
823    fn test_build_headers_bearer() {
824        let config = AuthConfig::bearer_token("token123");
825        let headers = config.build_headers().unwrap();
826        assert!(headers.contains_key("authorization"));
827        assert_eq!(headers.get("authorization").unwrap(), "Bearer token123");
828    }
829
830    #[test]
831    fn test_credential_store() {
832        let mut store = CredentialStore::new();
833        assert!(store.is_empty());
834
835        store.add("server1", AuthConfig::api_key("X-API-Key", "key1"));
836        store.add("server2", AuthConfig::bearer_token("token2"));
837
838        assert_eq!(store.len(), 2);
839        assert!(store.has("server1"));
840        assert!(!store.has("server3"));
841
842        let auth = store.get("server1").unwrap();
843        matches!(auth.method, AuthMethod::ApiKey(_));
844
845        store.remove("server1");
846        assert!(!store.has("server1"));
847        assert_eq!(store.len(), 1);
848    }
849
850    #[test]
851    fn test_custom_header_auth() {
852        let auth = CustomHeaderAuth::new()
853            .add_header("X-Custom-Header", "value1")
854            .add_header("X-Another-Header", "value2");
855
856        assert_eq!(auth.headers.len(), 2);
857        assert_eq!(
858            auth.headers.get("X-Custom-Header"),
859            Some(&"value1".to_string())
860        );
861    }
862
863    #[test]
864    fn test_build_headers_custom() {
865        let mut headers = HashMap::new();
866        headers.insert("X-Custom-1".to_string(), "value1".to_string());
867        headers.insert("X-Custom-2".to_string(), "value2".to_string());
868
869        let config = AuthConfig::custom_headers(headers);
870        let built = config.build_headers().unwrap();
871
872        assert!(built.contains_key("x-custom-1"));
873        assert!(built.contains_key("x-custom-2"));
874    }
875
876    #[test]
877    fn test_oauth2_client_credentials() {
878        let auth = OAuth2Auth::client_credentials(
879            "https://auth.example.com/token",
880            "client_id",
881            "client_secret",
882        );
883        assert_eq!(auth.token_url, "https://auth.example.com/token");
884        assert_eq!(auth.client_id, "client_id");
885        assert_eq!(auth.client_secret, Some("client_secret".to_string()));
886        assert_eq!(auth.grant_type, OAuth2GrantType::ClientCredentials);
887        assert!(auth.access_token.is_none());
888    }
889
890    #[test]
891    fn test_oauth2_authorization_code() {
892        let auth = OAuth2Auth::authorization_code(
893            "https://auth.example.com/token",
894            "client_id",
895            Some("client_secret".to_string()),
896            "auth_code_123",
897        );
898        assert_eq!(auth.grant_type, OAuth2GrantType::AuthorizationCode);
899        assert_eq!(auth.authorization_code, Some("auth_code_123".to_string()));
900    }
901
902    #[test]
903    fn test_oauth2_with_tokens() {
904        let auth = OAuth2Auth::with_tokens(
905            "https://auth.example.com/token",
906            "client_id",
907            Some("client_secret".to_string()),
908            "access_token_123",
909            "refresh_token_456",
910        );
911        assert_eq!(auth.grant_type, OAuth2GrantType::RefreshToken);
912        assert_eq!(auth.access_token, Some("access_token_123".to_string()));
913        assert_eq!(auth.refresh_token, Some("refresh_token_456".to_string()));
914    }
915
916    #[test]
917    fn test_oauth2_with_scopes() {
918        let auth = OAuth2Auth::client_credentials(
919            "https://auth.example.com/token",
920            "client_id",
921            "client_secret",
922        )
923        .with_scopes(vec!["read".to_string(), "write".to_string()]);
924
925        assert_eq!(auth.scopes.len(), 2);
926        assert!(auth.scopes.contains(&"read".to_string()));
927    }
928
929    #[test]
930    fn test_oauth2_with_pkce() {
931        let auth = OAuth2Auth::authorization_code(
932            "https://auth.example.com/token",
933            "client_id",
934            None,
935            "auth_code",
936        )
937        .with_pkce("code_verifier_123");
938
939        assert_eq!(auth.code_verifier, Some("code_verifier_123".to_string()));
940    }
941
942    #[test]
943    fn test_oauth2_header_value() {
944        let mut auth = OAuth2Auth::client_credentials(
945            "https://auth.example.com/token",
946            "client_id",
947            "client_secret",
948        );
949        assert!(auth.header_value().is_none());
950
951        auth.access_token = Some("test_token".to_string());
952        assert_eq!(auth.header_value(), Some("Bearer test_token".to_string()));
953    }
954
955    #[test]
956    fn test_oauth2_is_token_expired() {
957        let mut auth = OAuth2Auth::client_credentials(
958            "https://auth.example.com/token",
959            "client_id",
960            "client_secret",
961        );
962
963        // No expiry set
964        assert!(!auth.is_token_expired());
965
966        // Set expiry in the past
967        auth.expires_at = Some(1000);
968        assert!(auth.is_token_expired());
969
970        // Set expiry far in the future
971        let future = std::time::SystemTime::now()
972            .duration_since(std::time::UNIX_EPOCH)
973            .unwrap()
974            .as_secs() as i64
975            + 3600;
976        auth.expires_at = Some(future);
977        assert!(!auth.is_token_expired());
978    }
979
980    #[test]
981    fn test_auth_config_oauth2_client_credentials() {
982        let config = AuthConfig::oauth2_client_credentials(
983            "https://auth.example.com/token",
984            "client_id",
985            "client_secret",
986        );
987        if let AuthMethod::OAuth2(auth) = config.method {
988            assert_eq!(auth.grant_type, OAuth2GrantType::ClientCredentials);
989        } else {
990            panic!("Expected OAuth2 auth method");
991        }
992    }
993
994    #[test]
995    fn test_auth_config_oauth2_with_tokens() {
996        let config = AuthConfig::oauth2_with_tokens(
997            "https://auth.example.com/token",
998            "client_id",
999            Some("client_secret".to_string()),
1000            "access_token",
1001            "refresh_token",
1002        );
1003        if let AuthMethod::OAuth2(auth) = config.method {
1004            assert_eq!(auth.access_token, Some("access_token".to_string()));
1005            assert_eq!(auth.refresh_token, Some("refresh_token".to_string()));
1006        } else {
1007            panic!("Expected OAuth2 auth method");
1008        }
1009    }
1010
1011    #[test]
1012    fn test_build_headers_oauth2() {
1013        let mut auth = OAuth2Auth::client_credentials(
1014            "https://auth.example.com/token",
1015            "client_id",
1016            "client_secret",
1017        );
1018        auth.access_token = Some("test_access_token".to_string());
1019
1020        let config = AuthConfig {
1021            method: AuthMethod::OAuth2(auth),
1022            scopes: Vec::new(),
1023        };
1024
1025        let headers = config.build_headers().unwrap();
1026        assert!(headers.contains_key("authorization"));
1027        assert_eq!(
1028            headers.get("authorization").unwrap(),
1029            "Bearer test_access_token"
1030        );
1031    }
1032
1033    #[test]
1034    fn test_oauth2_grant_type_equality() {
1035        assert_eq!(
1036            OAuth2GrantType::ClientCredentials,
1037            OAuth2GrantType::ClientCredentials
1038        );
1039        assert_ne!(
1040            OAuth2GrantType::ClientCredentials,
1041            OAuth2GrantType::AuthorizationCode
1042        );
1043    }
1044}