turbomcp_auth/oauth2/
client.rs1use std::collections::HashMap;
12
13use oauth2::{AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl, basic::BasicClient};
14
15use turbomcp_protocol::{Error as McpError, Result as McpResult};
16
17use super::super::config::{OAuth2Config, ProviderConfig, ProviderType, RefreshBehavior};
18
19#[derive(Debug, Clone)]
21pub struct OAuth2Client {
22 pub(crate) auth_code_client: BasicClient,
24 pub(crate) client_credentials_client: Option<BasicClient>,
26 pub(crate) device_code_client: Option<BasicClient>,
28 pub provider_config: ProviderConfig,
30}
31
32impl OAuth2Client {
33 pub fn new(config: &OAuth2Config, provider_type: ProviderType) -> McpResult<Self> {
35 let auth_url = AuthUrl::new(config.auth_url.clone())
37 .map_err(|_| McpError::validation("Invalid authorization URL".to_string()))?;
38
39 let token_url = TokenUrl::new(config.token_url.clone())
40 .map_err(|_| McpError::validation("Invalid token URL".to_string()))?;
41
42 let redirect_url = Self::validate_redirect_uri(&config.redirect_uri)?;
44
45 let client_secret = if config.client_secret.is_empty() {
47 None
48 } else {
49 Some(ClientSecret::new(config.client_secret.clone()))
50 };
51
52 let auth_code_client = BasicClient::new(
53 ClientId::new(config.client_id.clone()),
54 client_secret.clone(),
55 auth_url.clone(),
56 Some(token_url.clone()),
57 )
58 .set_redirect_uri(redirect_url);
59
60 let client_credentials_client = if client_secret.is_some() {
62 Some(BasicClient::new(
63 ClientId::new(config.client_id.clone()),
64 client_secret.clone(),
65 auth_url.clone(),
66 Some(token_url.clone()),
67 ))
68 } else {
69 None
70 };
71
72 let device_code_client = Some(BasicClient::new(
74 ClientId::new(config.client_id.clone()),
75 client_secret,
76 auth_url,
77 Some(token_url),
78 ));
79
80 let provider_config = Self::build_provider_config(provider_type);
82
83 Ok(Self {
84 auth_code_client,
85 client_credentials_client,
86 device_code_client,
87 provider_config,
88 })
89 }
90
91 fn build_provider_config(provider_type: ProviderType) -> ProviderConfig {
93 match provider_type {
94 ProviderType::Google => ProviderConfig {
95 provider_type,
96 default_scopes: vec![
97 "openid".to_string(),
98 "email".to_string(),
99 "profile".to_string(),
100 ],
101 refresh_behavior: RefreshBehavior::Proactive,
102 userinfo_endpoint: Some(
103 "https://www.googleapis.com/oauth2/v2/userinfo".to_string(),
104 ),
105 additional_params: HashMap::new(),
106 },
107 ProviderType::Microsoft => ProviderConfig {
108 provider_type,
109 default_scopes: vec![
110 "openid".to_string(),
111 "profile".to_string(),
112 "email".to_string(),
113 "User.Read".to_string(),
114 ],
115 refresh_behavior: RefreshBehavior::Proactive,
116 userinfo_endpoint: Some("https://graph.microsoft.com/v1.0/me".to_string()),
117 additional_params: HashMap::new(),
118 },
119 ProviderType::GitHub => ProviderConfig {
120 provider_type,
121 default_scopes: vec!["user:email".to_string(), "read:user".to_string()],
122 refresh_behavior: RefreshBehavior::Reactive,
123 userinfo_endpoint: Some("https://api.github.com/user".to_string()),
124 additional_params: HashMap::new(),
125 },
126 ProviderType::GitLab => ProviderConfig {
127 provider_type,
128 default_scopes: vec!["read_user".to_string(), "openid".to_string()],
129 refresh_behavior: RefreshBehavior::Proactive,
130 userinfo_endpoint: Some("https://gitlab.com/api/v4/user".to_string()),
131 additional_params: HashMap::new(),
132 },
133 ProviderType::Generic | ProviderType::Custom(_) => ProviderConfig {
134 provider_type,
135 default_scopes: vec!["openid".to_string(), "profile".to_string()],
136 refresh_behavior: RefreshBehavior::Proactive,
137 userinfo_endpoint: None,
138 additional_params: HashMap::new(),
139 },
140 }
141 }
142
143 fn validate_redirect_uri(uri: &str) -> McpResult<RedirectUrl> {
150 use url::Url;
151
152 let parsed = Url::parse(uri)
154 .map_err(|e| McpError::validation(format!("Invalid redirect URI format: {e}")))?;
155
156 match parsed.scheme() {
158 "http" => {
159 if let Some(host) = parsed.host_str() {
161 let is_localhost = host == "localhost"
163 || host.starts_with("localhost:")
164 || host == "127.0.0.1"
165 || host.starts_with("127.0.0.1:")
166 || host == "0.0.0.0"
167 || host.starts_with("0.0.0.0:");
168
169 if !is_localhost {
170 return Err(McpError::validation(
171 "HTTP redirect URIs only allowed for localhost in development"
172 .to_string(),
173 ));
174 }
175 } else {
176 return Err(McpError::validation(
177 "Redirect URI must have a valid host".to_string(),
178 ));
179 }
180 }
181 "https" => {
182 }
184 "com.example.app" | "msauth" => {
185 }
187 scheme if scheme.starts_with("app.") || scheme.ends_with(".app") => {
188 }
190 _ => {
191 return Err(McpError::validation(format!(
192 "Unsupported redirect URI scheme: {}. Use https, http (localhost only), or app-specific schemes",
193 parsed.scheme()
194 )));
195 }
196 }
197
198 if parsed.fragment().is_some() {
200 return Err(McpError::validation(
201 "Redirect URI must not contain URL fragment".to_string(),
202 ));
203 }
204
205 if let Some(path) = parsed.path_segments() {
209 for segment in path {
210 if segment == ".." {
211 return Err(McpError::validation(
212 "Redirect URI path must not contain traversal sequences".to_string(),
213 ));
214 }
215 }
216 }
217
218 RedirectUrl::new(uri.to_string())
223 .map_err(|_| McpError::validation("Failed to create redirect URL".to_string()))
224 }
225
226 #[must_use]
228 pub fn auth_code_client(&self) -> &BasicClient {
229 &self.auth_code_client
230 }
231
232 #[must_use]
234 pub fn client_credentials_client(&self) -> Option<&BasicClient> {
235 self.client_credentials_client.as_ref()
236 }
237
238 #[must_use]
240 pub fn device_code_client(&self) -> Option<&BasicClient> {
241 self.device_code_client.as_ref()
242 }
243
244 #[must_use]
246 pub fn provider_config(&self) -> &ProviderConfig {
247 &self.provider_config
248 }
249}