// turbomcp_auth/oauth2/client.rs

//! OAuth 2.0 Client Implementation
//!
//! This module provides an OAuth 2.0 client wrapper that supports:
//! - Authorization Code flow (with PKCE)
//! - Client Credentials flow (server-to-server)
//! - Device Authorization flow (CLI/IoT)
//!
//! The client supplies provider-specific defaults (scopes, token refresh behavior, and
//! userinfo endpoint) for Google, Microsoft, GitHub, GitLab, and generic OAuth providers.
10
11use std::collections::HashMap;
12
13use oauth2::{AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl, basic::BasicClient};
14
15use turbomcp_protocol::{Error as McpError, Result as McpResult};
16
17use super::super::config::{OAuth2Config, ProviderConfig, ProviderType, RefreshBehavior};
18
19/// Production-grade OAuth2 client wrapper supporting all modern flows
20#[derive(Debug, Clone)]
21pub struct OAuth2Client {
22    /// Authorization code flow client (most common)
23    pub(crate) auth_code_client: BasicClient,
24    /// Client credentials client (server-to-server)
25    pub(crate) client_credentials_client: Option<BasicClient>,
26    /// Device code client (for CLI/IoT applications)
27    pub(crate) device_code_client: Option<BasicClient>,
28    /// Provider-specific configuration
29    pub provider_config: ProviderConfig,
30}
31
32impl OAuth2Client {
33    /// Create a proven OAuth2 client supporting all flows
34    pub fn new(config: &OAuth2Config, provider_type: ProviderType) -> McpResult<Self> {
35        // Validate URLs
36        let auth_url = AuthUrl::new(config.auth_url.clone())
37            .map_err(|_| McpError::validation("Invalid authorization URL".to_string()))?;
38
39        let token_url = TokenUrl::new(config.token_url.clone())
40            .map_err(|_| McpError::validation("Invalid token URL".to_string()))?;
41
42        // Enhanced redirect URI validation with comprehensive security checks
43        let redirect_url = Self::validate_redirect_uri(&config.redirect_uri)?;
44
45        // Create authorization code flow client (primary)
46        let client_secret = if config.client_secret.is_empty() {
47            None
48        } else {
49            Some(ClientSecret::new(config.client_secret.clone()))
50        };
51
52        let auth_code_client = BasicClient::new(
53            ClientId::new(config.client_id.clone()),
54            client_secret.clone(),
55            auth_url.clone(),
56            Some(token_url.clone()),
57        )
58        .set_redirect_uri(redirect_url);
59
60        // Create client credentials client if we have a secret (server-to-server)
61        let client_credentials_client = if client_secret.is_some() {
62            Some(BasicClient::new(
63                ClientId::new(config.client_id.clone()),
64                client_secret.clone(),
65                auth_url.clone(),
66                Some(token_url.clone()),
67            ))
68        } else {
69            None
70        };
71
72        // Device code client (for CLI/IoT apps) - uses same configuration
73        let device_code_client = Some(BasicClient::new(
74            ClientId::new(config.client_id.clone()),
75            client_secret,
76            auth_url,
77            Some(token_url),
78        ));
79
80        // Provider-specific configuration
81        let provider_config = Self::build_provider_config(provider_type);
82
83        Ok(Self {
84            auth_code_client,
85            client_credentials_client,
86            device_code_client,
87            provider_config,
88        })
89    }
90
91    /// Build provider-specific configuration
92    fn build_provider_config(provider_type: ProviderType) -> ProviderConfig {
93        match provider_type {
94            ProviderType::Google => ProviderConfig {
95                provider_type,
96                default_scopes: vec![
97                    "openid".to_string(),
98                    "email".to_string(),
99                    "profile".to_string(),
100                ],
101                refresh_behavior: RefreshBehavior::Proactive,
102                userinfo_endpoint: Some(
103                    "https://www.googleapis.com/oauth2/v2/userinfo".to_string(),
104                ),
105                additional_params: HashMap::new(),
106            },
107            ProviderType::Microsoft => ProviderConfig {
108                provider_type,
109                default_scopes: vec![
110                    "openid".to_string(),
111                    "profile".to_string(),
112                    "email".to_string(),
113                    "User.Read".to_string(),
114                ],
115                refresh_behavior: RefreshBehavior::Proactive,
116                userinfo_endpoint: Some("https://graph.microsoft.com/v1.0/me".to_string()),
117                additional_params: HashMap::new(),
118            },
119            ProviderType::GitHub => ProviderConfig {
120                provider_type,
121                default_scopes: vec!["user:email".to_string(), "read:user".to_string()],
122                refresh_behavior: RefreshBehavior::Reactive,
123                userinfo_endpoint: Some("https://api.github.com/user".to_string()),
124                additional_params: HashMap::new(),
125            },
126            ProviderType::GitLab => ProviderConfig {
127                provider_type,
128                default_scopes: vec!["read_user".to_string(), "openid".to_string()],
129                refresh_behavior: RefreshBehavior::Proactive,
130                userinfo_endpoint: Some("https://gitlab.com/api/v4/user".to_string()),
131                additional_params: HashMap::new(),
132            },
133            ProviderType::Generic | ProviderType::Custom(_) => ProviderConfig {
134                provider_type,
135                default_scopes: vec!["openid".to_string(), "profile".to_string()],
136                refresh_behavior: RefreshBehavior::Proactive,
137                userinfo_endpoint: None,
138                additional_params: HashMap::new(),
139            },
140        }
141    }
142
143    /// Comprehensive redirect URI validation with security best practices
144    ///
145    /// Security considerations:
146    /// - Prevents open redirect attacks
147    /// - Validates URL format and structure
148    /// - Environment-aware validation (localhost for development)
149    fn validate_redirect_uri(uri: &str) -> McpResult<RedirectUrl> {
150        use url::Url;
151
152        // Parse and validate URL structure
153        let parsed = Url::parse(uri)
154            .map_err(|e| McpError::validation(format!("Invalid redirect URI format: {e}")))?;
155
156        // Security: Validate scheme
157        match parsed.scheme() {
158            "http" => {
159                // Only allow http for localhost/127.0.0.1/0.0.0.0 in development
160                if let Some(host) = parsed.host_str() {
161                    // Allow localhost, 127.0.0.1, 0.0.0.0 (bind all interfaces)
162                    let is_localhost = host == "localhost"
163                        || host.starts_with("localhost:")
164                        || host == "127.0.0.1"
165                        || host.starts_with("127.0.0.1:")
166                        || host == "0.0.0.0"
167                        || host.starts_with("0.0.0.0:");
168
169                    if !is_localhost {
170                        return Err(McpError::validation(
171                            "HTTP redirect URIs only allowed for localhost in development"
172                                .to_string(),
173                        ));
174                    }
175                } else {
176                    return Err(McpError::validation(
177                        "Redirect URI must have a valid host".to_string(),
178                    ));
179                }
180            }
181            "https" => {
182                // HTTPS is always allowed
183            }
184            "com.example.app" | "msauth" => {
185                // Allow custom schemes for mobile apps (common patterns)
186            }
187            scheme if scheme.starts_with("app.") || scheme.ends_with(".app") => {
188                // Allow app-specific custom schemes
189            }
190            _ => {
191                return Err(McpError::validation(format!(
192                    "Unsupported redirect URI scheme: {}. Use https, http (localhost only), or app-specific schemes",
193                    parsed.scheme()
194                )));
195            }
196        }
197
198        // Security: Prevent fragment in redirect URI (per OAuth 2.0 spec)
199        if parsed.fragment().is_some() {
200            return Err(McpError::validation(
201                "Redirect URI must not contain URL fragment".to_string(),
202            ));
203        }
204
205        // Security: Check for path traversal in PATH component only
206        // Note: url::Url::parse() already normalizes paths and removes .. segments
207        // We check the final path to ensure no traversal remains after normalization
208        if let Some(path) = parsed.path_segments() {
209            for segment in path {
210                if segment == ".." {
211                    return Err(McpError::validation(
212                        "Redirect URI path must not contain traversal sequences".to_string(),
213                    ));
214                }
215            }
216        }
217
218        // Industry Standard: Use oauth2 crate's RedirectUrl for validation
219        // This provides well-established URL validation per OAuth 2.0 specifications
220        // For production security, implement exact whitelist matching of allowed URIs
221        // (not pattern matching, which is error-prone per OAuth Security Best Practice RFC)
222        RedirectUrl::new(uri.to_string())
223            .map_err(|_| McpError::validation("Failed to create redirect URL".to_string()))
224    }
225
226    /// Get access to the authorization code client
227    #[must_use]
228    pub fn auth_code_client(&self) -> &BasicClient {
229        &self.auth_code_client
230    }
231
232    /// Get access to the client credentials client (if available)
233    #[must_use]
234    pub fn client_credentials_client(&self) -> Option<&BasicClient> {
235        self.client_credentials_client.as_ref()
236    }
237
238    /// Get access to the device code client (if available)
239    #[must_use]
240    pub fn device_code_client(&self) -> Option<&BasicClient> {
241        self.device_code_client.as_ref()
242    }
243
244    /// Get the provider configuration
245    #[must_use]
246    pub fn provider_config(&self) -> &ProviderConfig {
247        &self.provider_config
248    }
249}