pulseengine_mcp_auth/transport/
stdio_auth.rs1use super::auth_extractors::{
7 AuthExtractionResult, AuthExtractor, AuthUtils, TransportAuthContext, TransportAuthError,
8 TransportRequest, TransportType,
9};
10use async_trait::async_trait;
11use serde_json::Value;
12
/// Configuration controlling where the stdio transport looks for an API key.
///
/// The extractor consults the sources in this order: environment variable,
/// initialization parameters, process arguments, configured default key.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StdioAuthConfig {
    /// Name of the environment variable that is read for the API key.
    pub api_key_env_var: String,

    /// Whether the `params` object of the request body is searched for an API key.
    pub allow_init_params: bool,

    /// Whether the process command line is scanned for `--api-key <key>` or
    /// `--api-key=<key>`.
    pub allow_process_args: bool,

    /// Fallback API key used when no other source yields a credential.
    pub default_api_key: Option<String>,

    /// When `true`, extraction fails with an error if no credential is found;
    /// when `false`, it reports "no credential" instead.
    pub require_auth: bool,
}
31
32impl Default for StdioAuthConfig {
33 fn default() -> Self {
34 Self {
35 api_key_env_var: "MCP_API_KEY".to_string(),
36 allow_init_params: true,
37 allow_process_args: false, default_api_key: None,
39 require_auth: false, }
41 }
42}
43
44pub struct StdioAuthExtractor {
46 config: StdioAuthConfig,
47}
48
49impl StdioAuthExtractor {
50 pub fn new(config: StdioAuthConfig) -> Self {
52 Self { config }
53 }
54
55 pub fn default() -> Self {
57 Self::new(StdioAuthConfig::default())
58 }
59
60 fn extract_env_auth(&self) -> AuthExtractionResult {
62 if let Ok(api_key) = std::env::var(&self.config.api_key_env_var) {
63 if !api_key.is_empty() {
64 AuthUtils::validate_api_key_format(&api_key)?;
65 let context = TransportAuthContext::new(
66 api_key,
67 "Environment".to_string(),
68 TransportType::Stdio,
69 );
70 return Ok(Some(context));
71 }
72 }
73
74 Ok(None)
75 }
76
77 fn extract_init_params(&self, request: &TransportRequest) -> AuthExtractionResult {
79 if !self.config.allow_init_params {
80 return Ok(None);
81 }
82
83 if let Some(body) = &request.body {
84 if let Some(params) = body.get("params") {
86 if let Some(api_key) = self.find_api_key_in_params(params) {
88 AuthUtils::validate_api_key_format(&api_key)?;
89 let context = TransportAuthContext::new(
90 api_key,
91 "InitParams".to_string(),
92 TransportType::Stdio,
93 );
94 return Ok(Some(context));
95 }
96 }
97 }
98
99 Ok(None)
100 }
101
102 fn find_api_key_in_params(&self, params: &Value) -> Option<String> {
104 if let Some(api_key) = params.get("api_key").and_then(|v| v.as_str()) {
106 return Some(api_key.to_string());
107 }
108
109 if let Some(client_info) = params.get("clientInfo") {
111 if let Some(api_key) = client_info.get("api_key").and_then(|v| v.as_str()) {
112 return Some(api_key.to_string());
113 }
114
115 if let Some(capabilities) = client_info.get("capabilities") {
117 if let Some(auth) = capabilities.get("authentication") {
118 if let Some(api_key) = auth.get("api_key").and_then(|v| v.as_str()) {
119 return Some(api_key.to_string());
120 }
121 }
122 }
123 }
124
125 if let Some(capabilities) = params.get("capabilities") {
127 if let Some(auth) = capabilities.get("authentication") {
128 if let Some(api_key) = auth.get("api_key").and_then(|v| v.as_str()) {
129 return Some(api_key.to_string());
130 }
131 }
132 }
133
134 None
135 }
136
137 fn extract_process_args(&self) -> AuthExtractionResult {
139 if !self.config.allow_process_args {
140 return Ok(None);
141 }
142
143 let args: Vec<String> = std::env::args().collect();
144
145 for i in 0..args.len() {
147 if args[i] == "--api-key" && i + 1 < args.len() {
148 let api_key = &args[i + 1];
149 AuthUtils::validate_api_key_format(api_key)?;
150 let context = TransportAuthContext::new(
151 api_key.clone(),
152 "ProcessArgs".to_string(),
153 TransportType::Stdio,
154 );
155 return Ok(Some(context));
156 }
157
158 if let Some(key_value) = args[i].strip_prefix("--api-key=") {
160 AuthUtils::validate_api_key_format(key_value)?;
161 let context = TransportAuthContext::new(
162 key_value.to_string(),
163 "ProcessArgs".to_string(),
164 TransportType::Stdio,
165 );
166 return Ok(Some(context));
167 }
168 }
169
170 Ok(None)
171 }
172
173 fn extract_default_auth(&self) -> AuthExtractionResult {
175 if let Some(ref api_key) = self.config.default_api_key {
176 AuthUtils::validate_api_key_format(api_key)?;
177 let context = TransportAuthContext::new(
178 api_key.clone(),
179 "Default".to_string(),
180 TransportType::Stdio,
181 );
182 return Ok(Some(context));
183 }
184
185 Ok(None)
186 }
187
188 fn enrich_context(
190 &self,
191 mut context: TransportAuthContext,
192 _request: &TransportRequest,
193 ) -> TransportAuthContext {
194 if let Ok(current_exe) = std::env::current_exe() {
196 if let Some(exe_name) = current_exe.file_name().and_then(|n| n.to_str()) {
197 context = context.with_metadata("process".to_string(), exe_name.to_string());
198 }
199 }
200
201 if let Ok(cwd) = std::env::current_dir() {
203 context =
204 context.with_metadata("working_dir".to_string(), cwd.to_string_lossy().to_string());
205 }
206
207 if let Ok(user) = std::env::var("USER").or_else(|_| std::env::var("USERNAME")) {
209 context = context.with_metadata("user".to_string(), user);
210 }
211
212 context
213 }
214}
215
216#[async_trait]
217impl AuthExtractor for StdioAuthExtractor {
218 async fn extract_auth(&self, request: &TransportRequest) -> AuthExtractionResult {
219 if let Ok(Some(context)) = self.extract_env_auth() {
223 return Ok(Some(self.enrich_context(context, request)));
224 }
225
226 if let Ok(Some(context)) = self.extract_init_params(request) {
228 return Ok(Some(self.enrich_context(context, request)));
229 }
230
231 if let Ok(Some(context)) = self.extract_process_args() {
233 return Ok(Some(self.enrich_context(context, request)));
234 }
235
236 if let Ok(Some(context)) = self.extract_default_auth() {
238 return Ok(Some(self.enrich_context(context, request)));
239 }
240
241 if self.config.require_auth {
243 return Err(TransportAuthError::NoAuth);
244 }
245
246 Ok(None)
247 }
248
249 fn transport_type(&self) -> TransportType {
250 TransportType::Stdio
251 }
252
253 fn can_handle(&self, _request: &TransportRequest) -> bool {
254 true
256 }
257
258 async fn validate_auth(
259 &self,
260 context: &TransportAuthContext,
261 ) -> Result<(), TransportAuthError> {
262 if context.credential.is_empty() {
264 return Err(TransportAuthError::InvalidFormat(
265 "Empty credential".to_string(),
266 ));
267 }
268
269 if context.method == "Default" {
271 tracing::warn!(
272 "Using default API key for stdio authentication - not recommended for production"
273 );
274 }
275
276 Ok(())
277 }
278}
279
280impl StdioAuthConfig {
282 pub fn development() -> Self {
284 Self {
285 api_key_env_var: "MCP_API_KEY".to_string(),
286 allow_init_params: true,
287 allow_process_args: true,
288 default_api_key: Some("lmcp_dev_1234567890abcdef".to_string()),
289 require_auth: false,
290 }
291 }
292
293 pub fn production() -> Self {
295 Self {
296 api_key_env_var: "MCP_API_KEY".to_string(),
297 allow_init_params: true,
298 allow_process_args: false,
299 default_api_key: None,
300 require_auth: true,
301 }
302 }
303
304 pub fn secure() -> Self {
306 Self {
307 api_key_env_var: "MCP_API_KEY".to_string(),
308 allow_init_params: false,
309 allow_process_args: false,
310 default_api_key: None,
311 require_auth: true,
312 }
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Default configuration, except that the API key is read from `env_var`.
    ///
    /// Every test uses its own variable name so that a real `MCP_API_KEY` in
    /// the developer's environment, or a variable set by another test running
    /// in parallel, cannot change the outcome.
    fn config_reading(env_var: &str) -> StdioAuthConfig {
        StdioAuthConfig {
            api_key_env_var: env_var.to_string(),
            ..Default::default()
        }
    }

    #[test]
    fn test_environment_variable_extraction() {
        std::env::set_var("TEST_MCP_API_KEY", "lmcp_test_1234567890abcdef");

        let extractor = StdioAuthExtractor::new(config_reading("TEST_MCP_API_KEY"));
        let request = TransportRequest::new();

        let result = tokio_test::block_on(extractor.extract_auth(&request));

        // Clean up before asserting so a failed assertion cannot leave the
        // variable behind.
        std::env::remove_var("TEST_MCP_API_KEY");

        let context = result.unwrap().expect("credential from environment");
        assert_eq!(context.credential, "lmcp_test_1234567890abcdef");
        assert_eq!(context.method, "Environment");
        assert_eq!(context.transport_type, TransportType::Stdio);
    }

    #[test]
    fn test_init_params_extraction() {
        let extractor =
            StdioAuthExtractor::new(config_reading("TEST_MCP_API_KEY_UNSET_INIT_PARAMS"));

        let init_request = json!({
            "params": {
                "api_key": "lmcp_test_1234567890abcdef",
                "clientInfo": {
                    "name": "test-client"
                }
            }
        });

        let request = TransportRequest::new().with_body(init_request);
        let result = tokio_test::block_on(extractor.extract_auth(&request)).unwrap();

        assert!(result.is_some());
        let context = result.unwrap();
        assert_eq!(context.credential, "lmcp_test_1234567890abcdef");
        assert_eq!(context.method, "InitParams");
    }

    #[test]
    fn test_nested_init_params_extraction() {
        let extractor =
            StdioAuthExtractor::new(config_reading("TEST_MCP_API_KEY_UNSET_NESTED_PARAMS"));

        let init_request = json!({
            "params": {
                "clientInfo": {
                    "name": "test-client",
                    "capabilities": {
                        "authentication": {
                            "api_key": "lmcp_test_1234567890abcdef"
                        }
                    }
                }
            }
        });

        let request = TransportRequest::new().with_body(init_request);
        let result = tokio_test::block_on(extractor.extract_auth(&request)).unwrap();

        assert!(result.is_some());
        let context = result.unwrap();
        assert_eq!(context.credential, "lmcp_test_1234567890abcdef");
        assert_eq!(context.method, "InitParams");
    }

    #[test]
    fn test_default_api_key() {
        let config = StdioAuthConfig {
            default_api_key: Some("lmcp_default_1234567890abcdef".to_string()),
            ..config_reading("TEST_MCP_API_KEY_UNSET_DEFAULT_KEY")
        };
        let extractor = StdioAuthExtractor::new(config);
        let request = TransportRequest::new();

        let result = tokio_test::block_on(extractor.extract_auth(&request)).unwrap();

        assert!(result.is_some());
        let context = result.unwrap();
        assert_eq!(context.credential, "lmcp_default_1234567890abcdef");
        assert_eq!(context.method, "Default");
    }

    #[test]
    fn test_no_authentication_required() {
        let config = StdioAuthConfig {
            require_auth: false,
            ..config_reading("TEST_MCP_API_KEY_UNSET_NOT_REQUIRED")
        };
        let extractor = StdioAuthExtractor::new(config);
        let request = TransportRequest::new();

        let result = tokio_test::block_on(extractor.extract_auth(&request)).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_authentication_required_but_missing() {
        let config = StdioAuthConfig {
            require_auth: true,
            ..config_reading("TEST_MCP_API_KEY_UNSET_REQUIRED")
        };
        let extractor = StdioAuthExtractor::new(config);
        let request = TransportRequest::new();

        let result = tokio_test::block_on(extractor.extract_auth(&request));
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), TransportAuthError::NoAuth));
    }

    #[test]
    fn test_configuration_presets() {
        let dev_config = StdioAuthConfig::development();
        assert!(dev_config.allow_process_args);
        assert!(dev_config.default_api_key.is_some());
        assert!(!dev_config.require_auth);

        let prod_config = StdioAuthConfig::production();
        assert!(!prod_config.allow_process_args);
        assert!(prod_config.default_api_key.is_none());
        assert!(prod_config.require_auth);

        let secure_config = StdioAuthConfig::secure();
        assert!(!secure_config.allow_init_params);
        assert!(!secure_config.allow_process_args);
        assert!(secure_config.require_auth);
    }
}