// pulseengine_mcp_auth/transport/stdio_auth.rs

//! Stdio Transport Authentication
//!
//! This module provides authentication for stdio-based MCP servers,
//! typically used with Claude Desktop and CLI clients.

use std::fmt;

use async_trait::async_trait;
use serde_json::Value;

use super::auth_extractors::{
    AuthExtractionResult, AuthExtractor, AuthUtils, TransportAuthContext, TransportAuthError,
    TransportRequest, TransportType,
};

/// Configuration for stdio authentication.
///
/// Controls which credential sources [`StdioAuthExtractor`] consults and whether a
/// missing credential is an error.
///
/// `Debug` is implemented by hand so that [`default_api_key`](Self::default_api_key) is
/// never printed: `{:?}` output (for example in logs) must not leak a configured key.
#[derive(Clone)]
pub struct StdioAuthConfig {
    /// Name of the environment variable that holds the API key.
    pub api_key_env_var: String,

    /// Accept an API key supplied in the request's JSON-RPC `params`.
    pub allow_init_params: bool,

    /// Accept an API key supplied as `--api-key <key>` or `--api-key=<key>` on the process
    /// command line. Command-line arguments can be observable by other local processes, so
    /// this is a security risk in production.
    pub allow_process_args: bool,

    /// Fallback API key used when no other source provides one (intended for development).
    pub default_api_key: Option<String>,

    /// Return an error when no credential can be found.
    pub require_auth: bool,
}

impl fmt::Debug for StdioAuthConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("StdioAuthConfig")
            .field("api_key_env_var", &self.api_key_env_var)
            .field("allow_init_params", &self.allow_init_params)
            .field("allow_process_args", &self.allow_process_args)
            // Reveal only whether a default key is configured, never its value.
            .field(
                "default_api_key",
                &self.default_api_key.as_ref().map(|_| "<redacted>"),
            )
            .field("require_auth", &self.require_auth)
            .finish()
    }
}

32impl Default for StdioAuthConfig {
33    fn default() -> Self {
34        Self {
35            api_key_env_var: "MCP_API_KEY".to_string(),
36            allow_init_params: true,
37            allow_process_args: false, // Security risk in production
38            default_api_key: None,
39            require_auth: false, // Often used locally
40        }
41    }
42}
43
44/// Stdio authentication extractor
45pub struct StdioAuthExtractor {
46    config: StdioAuthConfig,
47}
48
49impl StdioAuthExtractor {
50    /// Create a new stdio authentication extractor
51    pub fn new(config: StdioAuthConfig) -> Self {
52        Self { config }
53    }
54
55    /// Create with default configuration
56    pub fn default() -> Self {
57        Self::new(StdioAuthConfig::default())
58    }
59
60    /// Extract authentication from environment variables
61    fn extract_env_auth(&self) -> AuthExtractionResult {
62        if let Ok(api_key) = std::env::var(&self.config.api_key_env_var) {
63            if !api_key.is_empty() {
64                AuthUtils::validate_api_key_format(&api_key)?;
65                let context = TransportAuthContext::new(
66                    api_key,
67                    "Environment".to_string(),
68                    TransportType::Stdio,
69                );
70                return Ok(Some(context));
71            }
72        }
73
74        Ok(None)
75    }
76
77    /// Extract authentication from MCP initialize parameters
78    fn extract_init_params(&self, request: &TransportRequest) -> AuthExtractionResult {
79        if !self.config.allow_init_params {
80            return Ok(None);
81        }
82
83        if let Some(body) = &request.body {
84            // Look for authentication in initialize request params
85            if let Some(params) = body.get("params") {
86                // Check for API key in various locations
87                if let Some(api_key) = self.find_api_key_in_params(params) {
88                    AuthUtils::validate_api_key_format(&api_key)?;
89                    let context = TransportAuthContext::new(
90                        api_key,
91                        "InitParams".to_string(),
92                        TransportType::Stdio,
93                    );
94                    return Ok(Some(context));
95                }
96            }
97        }
98
99        Ok(None)
100    }
101
102    /// Find API key in various parameter structures
103    fn find_api_key_in_params(&self, params: &Value) -> Option<String> {
104        // Try direct api_key field
105        if let Some(api_key) = params.get("api_key").and_then(|v| v.as_str()) {
106            return Some(api_key.to_string());
107        }
108
109        // Try nested clientInfo
110        if let Some(client_info) = params.get("clientInfo") {
111            if let Some(api_key) = client_info.get("api_key").and_then(|v| v.as_str()) {
112                return Some(api_key.to_string());
113            }
114
115            // Try in capabilities
116            if let Some(capabilities) = client_info.get("capabilities") {
117                if let Some(auth) = capabilities.get("authentication") {
118                    if let Some(api_key) = auth.get("api_key").and_then(|v| v.as_str()) {
119                        return Some(api_key.to_string());
120                    }
121                }
122            }
123        }
124
125        // Try in server capabilities/config
126        if let Some(capabilities) = params.get("capabilities") {
127            if let Some(auth) = capabilities.get("authentication") {
128                if let Some(api_key) = auth.get("api_key").and_then(|v| v.as_str()) {
129                    return Some(api_key.to_string());
130                }
131            }
132        }
133
134        None
135    }
136
137    /// Extract authentication from process arguments
138    fn extract_process_args(&self) -> AuthExtractionResult {
139        if !self.config.allow_process_args {
140            return Ok(None);
141        }
142
143        let args: Vec<String> = std::env::args().collect();
144
145        // Look for --api-key argument
146        for i in 0..args.len() {
147            if args[i] == "--api-key" && i + 1 < args.len() {
148                let api_key = &args[i + 1];
149                AuthUtils::validate_api_key_format(api_key)?;
150                let context = TransportAuthContext::new(
151                    api_key.clone(),
152                    "ProcessArgs".to_string(),
153                    TransportType::Stdio,
154                );
155                return Ok(Some(context));
156            }
157
158            // Look for --api-key=value format
159            if let Some(key_value) = args[i].strip_prefix("--api-key=") {
160                AuthUtils::validate_api_key_format(key_value)?;
161                let context = TransportAuthContext::new(
162                    key_value.to_string(),
163                    "ProcessArgs".to_string(),
164                    TransportType::Stdio,
165                );
166                return Ok(Some(context));
167            }
168        }
169
170        Ok(None)
171    }
172
173    /// Use default API key if configured
174    fn extract_default_auth(&self) -> AuthExtractionResult {
175        if let Some(ref api_key) = self.config.default_api_key {
176            AuthUtils::validate_api_key_format(api_key)?;
177            let context = TransportAuthContext::new(
178                api_key.clone(),
179                "Default".to_string(),
180                TransportType::Stdio,
181            );
182            return Ok(Some(context));
183        }
184
185        Ok(None)
186    }
187
188    /// Add stdio-specific context information
189    fn enrich_context(
190        &self,
191        mut context: TransportAuthContext,
192        _request: &TransportRequest,
193    ) -> TransportAuthContext {
194        // Add process information
195        if let Ok(current_exe) = std::env::current_exe() {
196            if let Some(exe_name) = current_exe.file_name().and_then(|n| n.to_str()) {
197                context = context.with_metadata("process".to_string(), exe_name.to_string());
198            }
199        }
200
201        // Add working directory
202        if let Ok(cwd) = std::env::current_dir() {
203            context =
204                context.with_metadata("working_dir".to_string(), cwd.to_string_lossy().to_string());
205        }
206
207        // Add user information if available
208        if let Ok(user) = std::env::var("USER").or_else(|_| std::env::var("USERNAME")) {
209            context = context.with_metadata("user".to_string(), user);
210        }
211
212        context
213    }
214}
215
216#[async_trait]
217impl AuthExtractor for StdioAuthExtractor {
218    async fn extract_auth(&self, request: &TransportRequest) -> AuthExtractionResult {
219        // Try different authentication sources in order of preference
220
221        // 1. Environment variables
222        if let Ok(Some(context)) = self.extract_env_auth() {
223            return Ok(Some(self.enrich_context(context, request)));
224        }
225
226        // 2. MCP initialize parameters
227        if let Ok(Some(context)) = self.extract_init_params(request) {
228            return Ok(Some(self.enrich_context(context, request)));
229        }
230
231        // 3. Process arguments (if allowed)
232        if let Ok(Some(context)) = self.extract_process_args() {
233            return Ok(Some(self.enrich_context(context, request)));
234        }
235
236        // 4. Default API key (if configured)
237        if let Ok(Some(context)) = self.extract_default_auth() {
238            return Ok(Some(self.enrich_context(context, request)));
239        }
240
241        // No authentication found
242        if self.config.require_auth {
243            return Err(TransportAuthError::NoAuth);
244        }
245
246        Ok(None)
247    }
248
249    fn transport_type(&self) -> TransportType {
250        TransportType::Stdio
251    }
252
253    fn can_handle(&self, _request: &TransportRequest) -> bool {
254        // Stdio extractor can always attempt extraction
255        true
256    }
257
258    async fn validate_auth(
259        &self,
260        context: &TransportAuthContext,
261    ) -> Result<(), TransportAuthError> {
262        // Stdio-specific validation
263        if context.credential.is_empty() {
264            return Err(TransportAuthError::InvalidFormat(
265                "Empty credential".to_string(),
266            ));
267        }
268
269        // Additional validation for development environments
270        if context.method == "Default" {
271            tracing::warn!(
272                "Using default API key for stdio authentication - not recommended for production"
273            );
274        }
275
276        Ok(())
277    }
278}
279
280/// Helper for creating stdio authentication configuration
281impl StdioAuthConfig {
282    /// Create a development-friendly configuration
283    pub fn development() -> Self {
284        Self {
285            api_key_env_var: "MCP_API_KEY".to_string(),
286            allow_init_params: true,
287            allow_process_args: true,
288            default_api_key: Some("lmcp_dev_1234567890abcdef".to_string()),
289            require_auth: false,
290        }
291    }
292
293    /// Create a production configuration
294    pub fn production() -> Self {
295        Self {
296            api_key_env_var: "MCP_API_KEY".to_string(),
297            allow_init_params: true,
298            allow_process_args: false,
299            default_api_key: None,
300            require_auth: true,
301        }
302    }
303
304    /// Create a secure configuration (minimal attack surface)
305    pub fn secure() -> Self {
306        Self {
307            api_key_env_var: "MCP_API_KEY".to_string(),
308            allow_init_params: false,
309            allow_process_args: false,
310            default_api_key: None,
311            require_auth: true,
312        }
313    }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A valid-format API key shared by the tests below.
    const TEST_KEY: &str = "lmcp_test_1234567890abcdef";

    /// Name of an environment variable that no test sets.
    ///
    /// Pointing configs at it keeps results independent of whether `MCP_API_KEY` happens
    /// to be set in the surrounding environment (for example on a developer machine or in CI).
    const UNSET_ENV_VAR: &str = "PULSEENGINE_STDIO_AUTH_TEST_UNSET";

    /// The default configuration, but reading the API key from [`UNSET_ENV_VAR`].
    fn isolated_config() -> StdioAuthConfig {
        StdioAuthConfig {
            api_key_env_var: UNSET_ENV_VAR.to_string(),
            ..Default::default()
        }
    }

    /// Run `extract_auth` to completion on the current thread.
    fn extract(extractor: &StdioAuthExtractor, request: &TransportRequest) -> AuthExtractionResult {
        tokio_test::block_on(extractor.extract_auth(request))
    }

    /// Sets an environment variable and removes it again on drop.
    ///
    /// Removal on drop means a failing assertion cannot leave the variable behind for
    /// other tests.
    struct EnvVarGuard(&'static str);

    impl EnvVarGuard {
        fn set(name: &'static str, value: &str) -> Self {
            // SAFETY: `set_var` is unsafe because other threads may read the environment
            // concurrently. Each test uses its own variable name. We assume no test running
            // in parallel reads the environment through non-`std` APIs (e.g. libc directly).
            unsafe { std::env::set_var(name, value) };
            Self(name)
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same reasoning as in `EnvVarGuard::set`.
            unsafe { std::env::remove_var(self.0) };
        }
    }

    #[test]
    fn test_environment_variable_extraction() {
        let _guard = EnvVarGuard::set("TEST_MCP_API_KEY", TEST_KEY);
        let config = StdioAuthConfig {
            api_key_env_var: "TEST_MCP_API_KEY".to_string(),
            ..Default::default()
        };
        let extractor = StdioAuthExtractor::new(config);

        let context = extract(&extractor, &TransportRequest::new())
            .unwrap()
            .expect("key from environment");
        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "Environment");
        assert_eq!(context.transport_type, TransportType::Stdio);
    }

    #[test]
    fn test_environment_takes_precedence_over_init_params() {
        let _guard = EnvVarGuard::set("TEST_MCP_API_KEY_PRECEDENCE", TEST_KEY);
        let config = StdioAuthConfig {
            api_key_env_var: "TEST_MCP_API_KEY_PRECEDENCE".to_string(),
            ..Default::default()
        };
        let extractor = StdioAuthExtractor::new(config);
        let request = TransportRequest::new().with_body(json!({
            "params": { "api_key": "lmcp_default_1234567890abcdef" }
        }));

        let context = extract(&extractor, &request).unwrap().expect("key found");
        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "Environment");
    }

    #[test]
    fn test_init_params_extraction() {
        let extractor = StdioAuthExtractor::new(isolated_config());
        let request = TransportRequest::new().with_body(json!({
            "params": {
                "api_key": TEST_KEY,
                "clientInfo": { "name": "test-client" }
            }
        }));

        let context = extract(&extractor, &request).unwrap().expect("key in params");
        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "InitParams");
    }

    #[test]
    fn test_nested_init_params_extraction() {
        let extractor = StdioAuthExtractor::new(isolated_config());
        let request = TransportRequest::new().with_body(json!({
            "params": {
                "clientInfo": {
                    "name": "test-client",
                    "capabilities": {
                        "authentication": { "api_key": TEST_KEY }
                    }
                }
            }
        }));

        let context = extract(&extractor, &request).unwrap().expect("nested key");
        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "InitParams");
    }

    #[test]
    fn test_top_level_capabilities_extraction() {
        let extractor = StdioAuthExtractor::new(isolated_config());
        let request = TransportRequest::new().with_body(json!({
            "params": {
                "capabilities": {
                    "authentication": { "api_key": TEST_KEY }
                }
            }
        }));

        let context = extract(&extractor, &request).unwrap().expect("capabilities key");
        assert_eq!(context.credential, TEST_KEY);
        assert_eq!(context.method, "InitParams");
    }

    #[test]
    fn test_init_params_ignored_when_disabled() {
        let config = StdioAuthConfig {
            allow_init_params: false,
            ..isolated_config()
        };
        let extractor = StdioAuthExtractor::new(config);
        let request = TransportRequest::new().with_body(json!({
            "params": { "api_key": TEST_KEY }
        }));

        assert!(extract(&extractor, &request).unwrap().is_none());
    }

    #[test]
    fn test_default_api_key() {
        let config = StdioAuthConfig {
            default_api_key: Some("lmcp_default_1234567890abcdef".to_string()),
            ..isolated_config()
        };
        let extractor = StdioAuthExtractor::new(config);

        let context = extract(&extractor, &TransportRequest::new())
            .unwrap()
            .expect("default key");
        assert_eq!(context.credential, "lmcp_default_1234567890abcdef");
        assert_eq!(context.method, "Default");
    }

    #[test]
    fn test_no_authentication_required() {
        let config = StdioAuthConfig {
            require_auth: false,
            ..isolated_config()
        };
        let extractor = StdioAuthExtractor::new(config);

        assert!(extract(&extractor, &TransportRequest::new()).unwrap().is_none());
    }

    #[test]
    fn test_authentication_required_but_missing() {
        let config = StdioAuthConfig {
            require_auth: true,
            ..isolated_config()
        };
        let extractor = StdioAuthExtractor::new(config);

        let result = extract(&extractor, &TransportRequest::new());
        assert!(matches!(result, Err(TransportAuthError::NoAuth)));
    }

    #[test]
    fn test_configuration_presets() {
        let dev_config = StdioAuthConfig::development();
        assert!(dev_config.allow_process_args);
        assert!(dev_config.default_api_key.is_some());
        assert!(!dev_config.require_auth);

        let prod_config = StdioAuthConfig::production();
        assert!(prod_config.allow_init_params);
        assert!(!prod_config.allow_process_args);
        assert!(prod_config.default_api_key.is_none());
        assert!(prod_config.require_auth);

        let secure_config = StdioAuthConfig::secure();
        assert!(!secure_config.allow_init_params);
        assert!(!secure_config.allow_process_args);
        assert!(secure_config.default_api_key.is_none());
        assert!(secure_config.require_auth);
    }
}