pulseengine_mcp_auth/transport/
stdio_auth.rs1use super::auth_extractors::{
7 AuthExtractionResult, AuthExtractor, AuthUtils, TransportAuthContext, TransportAuthError,
8 TransportRequest, TransportType,
9};
10use async_trait::async_trait;
11use serde_json::Value;
12
/// Configuration for [`StdioAuthExtractor`]: which credential sources are
/// consulted and whether a credential is mandatory.
///
/// `Debug` is implemented by hand (instead of derived) so that formatting a
/// config with `{:?}` never reveals `default_api_key` in clear text.
#[derive(Clone)]
pub struct StdioAuthConfig {
    /// Name of the environment variable read for the API key.
    pub api_key_env_var: String,

    /// Whether an API key may be taken from the request body's `params` object.
    pub allow_init_params: bool,

    /// Whether an API key may be taken from `--api-key <key>` or
    /// `--api-key=<key>` process arguments.
    pub allow_process_args: bool,

    /// Fallback API key used when no other source provides one.
    pub default_api_key: Option<String>,

    /// When `true`, finding no credential is an error instead of `Ok(None)`.
    pub require_auth: bool,
}

impl Default for StdioAuthConfig {
    /// Reads `MCP_API_KEY`, accepts keys from init params, ignores process
    /// arguments, has no fallback key and does not require authentication.
    fn default() -> Self {
        Self {
            api_key_env_var: "MCP_API_KEY".to_string(),
            allow_init_params: true,
            allow_process_args: false,
            default_api_key: None,
            require_auth: false,
        }
    }
}

impl std::fmt::Debug for StdioAuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StdioAuthConfig")
            .field("api_key_env_var", &self.api_key_env_var)
            .field("allow_init_params", &self.allow_init_params)
            .field("allow_process_args", &self.allow_process_args)
            // Only reveal whether a fallback key is configured, never its value.
            .field(
                "default_api_key",
                &self.default_api_key.as_ref().map(|_| "<redacted>"),
            )
            .field("require_auth", &self.require_auth)
            .finish()
    }
}
43
44pub struct StdioAuthExtractor {
46 config: StdioAuthConfig,
47}
48
49impl StdioAuthExtractor {
50 pub fn new(config: StdioAuthConfig) -> Self {
52 Self { config }
53 }
54
55 pub fn default() -> Self {
57 Self::new(StdioAuthConfig::default())
58 }
59
60 fn extract_env_auth(&self) -> AuthExtractionResult {
62 if let Ok(api_key) = std::env::var(&self.config.api_key_env_var) {
63 if !api_key.is_empty() {
64 AuthUtils::validate_api_key_format(&api_key)?;
65 let context = TransportAuthContext::new(
66 api_key,
67 "Environment".to_string(),
68 TransportType::Stdio,
69 );
70 return Ok(Some(context));
71 }
72 }
73
74 Ok(None)
75 }
76
77 fn extract_init_params(&self, request: &TransportRequest) -> AuthExtractionResult {
79 if !self.config.allow_init_params {
80 return Ok(None);
81 }
82
83 if let Some(body) = &request.body {
84 if let Some(params) = body.get("params") {
86 if let Some(api_key) = self.find_api_key_in_params(params) {
88 AuthUtils::validate_api_key_format(&api_key)?;
89 let context = TransportAuthContext::new(
90 api_key,
91 "InitParams".to_string(),
92 TransportType::Stdio,
93 );
94 return Ok(Some(context));
95 }
96 }
97 }
98
99 Ok(None)
100 }
101
102 fn find_api_key_in_params(&self, params: &Value) -> Option<String> {
104 if let Some(api_key) = params.get("api_key").and_then(|v| v.as_str()) {
106 return Some(api_key.to_string());
107 }
108
109 if let Some(client_info) = params.get("clientInfo") {
111 if let Some(api_key) = client_info.get("api_key").and_then(|v| v.as_str()) {
112 return Some(api_key.to_string());
113 }
114
115 if let Some(capabilities) = client_info.get("capabilities") {
117 if let Some(auth) = capabilities.get("authentication") {
118 if let Some(api_key) = auth.get("api_key").and_then(|v| v.as_str()) {
119 return Some(api_key.to_string());
120 }
121 }
122 }
123 }
124
125 if let Some(capabilities) = params.get("capabilities") {
127 if let Some(auth) = capabilities.get("authentication") {
128 if let Some(api_key) = auth.get("api_key").and_then(|v| v.as_str()) {
129 return Some(api_key.to_string());
130 }
131 }
132 }
133
134 None
135 }
136
137 fn extract_process_args(&self) -> AuthExtractionResult {
139 if !self.config.allow_process_args {
140 return Ok(None);
141 }
142
143 let args: Vec<String> = std::env::args().collect();
144
145 for i in 0..args.len() {
147 if args[i] == "--api-key" && i + 1 < args.len() {
148 let api_key = &args[i + 1];
149 AuthUtils::validate_api_key_format(api_key)?;
150 let context = TransportAuthContext::new(
151 api_key.clone(),
152 "ProcessArgs".to_string(),
153 TransportType::Stdio,
154 );
155 return Ok(Some(context));
156 }
157
158 if let Some(key_value) = args[i].strip_prefix("--api-key=") {
160 AuthUtils::validate_api_key_format(key_value)?;
161 let context = TransportAuthContext::new(
162 key_value.to_string(),
163 "ProcessArgs".to_string(),
164 TransportType::Stdio,
165 );
166 return Ok(Some(context));
167 }
168 }
169
170 Ok(None)
171 }
172
173 fn extract_default_auth(&self) -> AuthExtractionResult {
175 if let Some(ref api_key) = self.config.default_api_key {
176 AuthUtils::validate_api_key_format(api_key)?;
177 let context = TransportAuthContext::new(
178 api_key.clone(),
179 "Default".to_string(),
180 TransportType::Stdio,
181 );
182 return Ok(Some(context));
183 }
184
185 Ok(None)
186 }
187
188 fn enrich_context(
190 &self,
191 mut context: TransportAuthContext,
192 _request: &TransportRequest,
193 ) -> TransportAuthContext {
194 if let Ok(current_exe) = std::env::current_exe() {
196 if let Some(exe_name) = current_exe.file_name().and_then(|n| n.to_str()) {
197 context = context.with_metadata("process".to_string(), exe_name.to_string());
198 }
199 }
200
201 if let Ok(cwd) = std::env::current_dir() {
203 context =
204 context.with_metadata("working_dir".to_string(), cwd.to_string_lossy().to_string());
205 }
206
207 if let Ok(user) = std::env::var("USER").or_else(|_| std::env::var("USERNAME")) {
209 context = context.with_metadata("user".to_string(), user);
210 }
211
212 context
213 }
214}
215
216#[async_trait]
217impl AuthExtractor for StdioAuthExtractor {
218 async fn extract_auth(&self, request: &TransportRequest) -> AuthExtractionResult {
219 if let Ok(Some(context)) = self.extract_env_auth() {
223 return Ok(Some(self.enrich_context(context, request)));
224 }
225
226 if let Ok(Some(context)) = self.extract_init_params(request) {
228 return Ok(Some(self.enrich_context(context, request)));
229 }
230
231 if let Ok(Some(context)) = self.extract_process_args() {
233 return Ok(Some(self.enrich_context(context, request)));
234 }
235
236 if let Ok(Some(context)) = self.extract_default_auth() {
238 return Ok(Some(self.enrich_context(context, request)));
239 }
240
241 if self.config.require_auth {
243 return Err(TransportAuthError::NoAuth);
244 }
245
246 Ok(None)
247 }
248
249 fn transport_type(&self) -> TransportType {
250 TransportType::Stdio
251 }
252
253 fn can_handle(&self, _request: &TransportRequest) -> bool {
254 true
256 }
257
258 async fn validate_auth(
259 &self,
260 context: &TransportAuthContext,
261 ) -> Result<(), TransportAuthError> {
262 if context.credential.is_empty() {
264 return Err(TransportAuthError::InvalidFormat(
265 "Empty credential".to_string(),
266 ));
267 }
268
269 if context.method == "Default" {
271 tracing::warn!(
272 "Using default API key for stdio authentication - not recommended for production"
273 );
274 }
275
276 Ok(())
277 }
278}
279
280impl StdioAuthConfig {
282 pub fn development() -> Self {
284 Self {
285 api_key_env_var: "MCP_API_KEY".to_string(),
286 allow_init_params: true,
287 allow_process_args: true,
288 default_api_key: Some("lmcp_dev_1234567890abcdef".to_string()),
289 require_auth: false,
290 }
291 }
292
293 pub fn production() -> Self {
295 Self {
296 api_key_env_var: "MCP_API_KEY".to_string(),
297 allow_init_params: true,
298 allow_process_args: false,
299 default_api_key: None,
300 require_auth: true,
301 }
302 }
303
304 pub fn secure() -> Self {
306 Self {
307 api_key_env_var: "MCP_API_KEY".to_string(),
308 allow_init_params: false,
309 allow_process_args: false,
310 default_api_key: None,
311 require_auth: true,
312 }
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// An environment variable no test sets. Tests that are not about env
    /// lookup point the env source here, so an ambient `MCP_API_KEY` in the
    /// developer's or CI environment cannot change their outcome.
    const UNSET_ENV_VAR: &str = "STDIO_AUTH_TESTS_UNSET_API_KEY";

    /// The default config, but with the env source isolated (see `UNSET_ENV_VAR`).
    fn isolated_config() -> StdioAuthConfig {
        StdioAuthConfig {
            api_key_env_var: UNSET_ENV_VAR.to_string(),
            ..Default::default()
        }
    }

    /// Sets an environment variable and removes it again on drop, so a failing
    /// assertion cannot leak the variable into other tests.
    struct EnvVarGuard {
        name: &'static str,
    }

    impl EnvVarGuard {
        fn set(name: &'static str, value: &str) -> Self {
            // SAFETY: NOTE(review) the test harness is multi-threaded, so this
            // is only sound while nothing reads the environment outside
            // `std::env` (which serialises its own accesses) concurrently. The
            // code exercised here only uses `std::env`, and the variable name
            // is unique to a single test.
            unsafe { std::env::set_var(name, value) };
            Self { name }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same conditions as in `EnvVarGuard::set`.
            unsafe { std::env::remove_var(self.name) };
        }
    }

    #[test]
    fn test_environment_variable_extraction() {
        let _env = EnvVarGuard::set("TEST_MCP_API_KEY", "lmcp_test_1234567890abcdef");

        let config = StdioAuthConfig {
            api_key_env_var: "TEST_MCP_API_KEY".to_string(),
            ..Default::default()
        };
        let extractor = StdioAuthExtractor::new(config);
        let request = TransportRequest::new();

        let context = tokio_test::block_on(extractor.extract_auth(&request))
            .unwrap()
            .expect("environment variable should provide a credential");

        assert_eq!(context.credential, "lmcp_test_1234567890abcdef");
        assert_eq!(context.method, "Environment");
        assert_eq!(context.transport_type, TransportType::Stdio);
    }

    #[test]
    fn test_default_constructor_uses_default_config() {
        let extractor = StdioAuthExtractor::default();
        assert_eq!(extractor.config.api_key_env_var, "MCP_API_KEY");
        assert!(extractor.config.allow_init_params);
        assert!(!extractor.config.allow_process_args);
        assert!(extractor.config.default_api_key.is_none());
        assert!(!extractor.config.require_auth);
    }

    #[test]
    fn test_init_params_extraction() {
        let extractor = StdioAuthExtractor::new(isolated_config());

        let init_request = json!({
            "params": {
                "api_key": "lmcp_test_1234567890abcdef",
                "clientInfo": {
                    "name": "test-client"
                }
            }
        });

        let request = TransportRequest::new().with_body(init_request);
        let context = tokio_test::block_on(extractor.extract_auth(&request))
            .unwrap()
            .expect("init params should provide a credential");

        assert_eq!(context.credential, "lmcp_test_1234567890abcdef");
        assert_eq!(context.method, "InitParams");
    }

    #[test]
    fn test_nested_init_params_extraction() {
        let extractor = StdioAuthExtractor::new(isolated_config());

        let init_request = json!({
            "params": {
                "clientInfo": {
                    "name": "test-client",
                    "capabilities": {
                        "authentication": {
                            "api_key": "lmcp_test_1234567890abcdef"
                        }
                    }
                }
            }
        });

        let request = TransportRequest::new().with_body(init_request);
        let context = tokio_test::block_on(extractor.extract_auth(&request))
            .unwrap()
            .expect("nested init params should provide a credential");

        assert_eq!(context.credential, "lmcp_test_1234567890abcdef");
        assert_eq!(context.method, "InitParams");
    }

    #[test]
    fn test_default_api_key() {
        let config = StdioAuthConfig {
            default_api_key: Some("lmcp_default_1234567890abcdef".to_string()),
            ..isolated_config()
        };
        let extractor = StdioAuthExtractor::new(config);
        let request = TransportRequest::new();

        let context = tokio_test::block_on(extractor.extract_auth(&request))
            .unwrap()
            .expect("default key should provide a credential");

        assert_eq!(context.credential, "lmcp_default_1234567890abcdef");
        assert_eq!(context.method, "Default");
    }

    #[test]
    fn test_no_authentication_required() {
        let config = StdioAuthConfig {
            require_auth: false,
            ..isolated_config()
        };
        let extractor = StdioAuthExtractor::new(config);
        let request = TransportRequest::new();

        let result = tokio_test::block_on(extractor.extract_auth(&request)).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_authentication_required_but_missing() {
        let config = StdioAuthConfig {
            require_auth: true,
            ..isolated_config()
        };
        let extractor = StdioAuthExtractor::new(config);
        let request = TransportRequest::new();

        let result = tokio_test::block_on(extractor.extract_auth(&request));
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), TransportAuthError::NoAuth));
    }

    #[test]
    fn test_configuration_presets() {
        let dev_config = StdioAuthConfig::development();
        assert!(dev_config.allow_process_args);
        assert!(dev_config.default_api_key.is_some());
        assert!(!dev_config.require_auth);

        let prod_config = StdioAuthConfig::production();
        assert!(!prod_config.allow_process_args);
        assert!(prod_config.default_api_key.is_none());
        assert!(prod_config.require_auth);

        let secure_config = StdioAuthConfig::secure();
        assert!(!secure_config.allow_init_params);
        assert!(!secure_config.allow_process_args);
        assert!(secure_config.require_auth);
    }
}