reasonkit_web/mcp/
server.rs

1//! MCP stdio server implementation
2//!
3//! This module implements the MCP server that communicates over stdio,
4//! handling JSON-RPC requests and dispatching to registered tools.
5//!
6//! # Security
7//!
//! The server supports optional token-based authentication, configured through the
//! `REASONKIT_MCP_TOKEN` environment variable. When it is set to a non-empty value,
//! every request must carry that token as an `auth_token` field in its params object;
//! requests with a missing or mismatched token are rejected with an authentication error.
//!
//! ## Configuring Authentication
//!
//! 1. **Server**: set `REASONKIT_MCP_TOKEN` to the shared secret to enable authentication.
//! 2. **Client**: include the same value as `auth_token` in every request's params object.
//!
//! If `REASONKIT_MCP_TOKEN` is unset or empty, authentication is disabled (backwards-compatible).
19
20use crate::error::Result;
21use crate::mcp::tools::ToolRegistry;
22use crate::mcp::types::{
23    JsonRpcRequest, JsonRpcResponse, McpCapabilities, McpServerInfo, ToolCallParams,
24};
25use serde_json::{json, Value};
26use std::io::{self, BufRead, Write};
27use tokio::sync::RwLock;
28use tracing::{debug, error, info, instrument, warn};
29
/// Name of the environment variable holding the shared MCP authentication token.
const MCP_TOKEN_ENV_VAR: &str = "REASONKIT_MCP_TOKEN";

/// JSON-RPC error code sent back when a request fails authentication.
///
/// The value sits inside the -32000..=-32099 band that JSON-RPC 2.0 reserves for
/// implementation-defined server errors.
const AUTH_ERROR_CODE: i32 = -32_001;
35
36/// MCP server state
37pub struct McpServer {
38    /// Tool registry
39    tools: ToolRegistry,
40    /// Server info
41    info: McpServerInfo,
42    /// Whether the server has been initialized
43    initialized: RwLock<bool>,
44    /// Optional authentication token (loaded from REASONKIT_MCP_TOKEN env var)
45    /// When Some, all requests must include a matching auth_token in params
46    auth_token: Option<String>,
47}
48
49impl McpServer {
50    /// Create a new MCP server
51    ///
52    /// Loads authentication token from `REASONKIT_MCP_TOKEN` environment variable if set.
53    /// When a token is configured, all incoming requests must include a matching
54    /// `auth_token` field in their params object.
55    pub fn new() -> Self {
56        let auth_token = std::env::var(MCP_TOKEN_ENV_VAR)
57            .ok()
58            .filter(|t| !t.is_empty());
59
60        if auth_token.is_some() {
61            info!(
62                "MCP server authentication enabled via {}",
63                MCP_TOKEN_ENV_VAR
64            );
65        } else {
66            warn!(
67                "MCP server running without authentication. Set {} to enable.",
68                MCP_TOKEN_ENV_VAR
69            );
70        }
71
72        Self {
73            tools: ToolRegistry::new(),
74            info: McpServerInfo::default(),
75            initialized: RwLock::new(false),
76            auth_token,
77        }
78    }
79
80    /// Create a new MCP server with a specific authentication token
81    ///
82    /// This method is primarily for testing purposes. In production, use `new()`
83    /// which loads the token from the environment variable.
84    pub fn with_auth_token(token: impl Into<String>) -> Self {
85        let token = token.into();
86        let auth_token = if token.is_empty() { None } else { Some(token) };
87
88        Self {
89            tools: ToolRegistry::new(),
90            info: McpServerInfo::default(),
91            initialized: RwLock::new(false),
92            auth_token,
93        }
94    }
95
96    /// Check if authentication is enabled
97    pub fn is_auth_enabled(&self) -> bool {
98        self.auth_token.is_some()
99    }
100
101    /// Validate authentication for an incoming request
102    ///
103    /// # Authentication Logic
104    ///
105    /// 1. If no auth_token is configured (None), authentication passes (backwards-compatible)
106    /// 2. If auth_token is configured, the request params must contain a matching `auth_token` field
107    /// 3. Token comparison uses constant-time comparison to prevent timing attacks
108    ///
109    /// # Returns
110    ///
111    /// - `Ok(())` if authentication succeeds or is not required
112    /// - `Err(JsonRpcResponse)` with authentication error if validation fails
113    fn validate_auth(
114        &self,
115        request: &JsonRpcRequest,
116    ) -> std::result::Result<(), Box<JsonRpcResponse>> {
117        let expected_token = match &self.auth_token {
118            Some(token) => token,
119            None => return Ok(()), // No auth required if token not configured
120        };
121
122        // Extract auth_token from request params
123        let provided_token = request
124            .params
125            .as_ref()
126            .and_then(|p| p.get("auth_token"))
127            .and_then(|v| v.as_str());
128
129        match provided_token {
130            Some(token) => {
131                // Use constant-time comparison to prevent timing attacks
132                if constant_time_compare(token, expected_token) {
133                    debug!("Authentication successful for method: {}", request.method);
134                    Ok(())
135                } else {
136                    warn!(
137                        method = %request.method,
138                        "Authentication failed: invalid token"
139                    );
140                    Err(Box::new(JsonRpcResponse::error(
141                        request.id.clone(),
142                        AUTH_ERROR_CODE,
143                        "Authentication failed: invalid token",
144                    )))
145                }
146            }
147            None => {
148                warn!(
149                    method = %request.method,
150                    "Authentication failed: missing auth_token in params"
151                );
152                Err(Box::new(JsonRpcResponse::error(
153                    request.id.clone(),
154                    AUTH_ERROR_CODE,
155                    "Authentication required: missing auth_token in params",
156                )))
157            }
158        }
159    }
160
161    /// Run the MCP server (blocking)
162    #[instrument(skip(self))]
163    pub async fn run(&self) -> Result<()> {
164        info!(
165            "Starting MCP server: {} v{}",
166            self.info.name, self.info.version
167        );
168
169        if self.is_auth_enabled() {
170            info!("Authentication is ENABLED - all requests require valid auth_token");
171        } else {
172            warn!("Authentication is DISABLED - accepting all requests");
173        }
174
175        let stdin = io::stdin();
176        let mut stdout = io::stdout();
177
178        for line in stdin.lock().lines() {
179            let line = match line {
180                Ok(l) => l,
181                Err(e) => {
182                    error!("Failed to read line: {}", e);
183                    continue;
184                }
185            };
186
187            if line.trim().is_empty() {
188                continue;
189            }
190
191            debug!("Received: {}", line);
192
193            let response = self.handle_line(&line).await;
194
195            if let Some(resp) = response {
196                let json = serde_json::to_string(&resp).unwrap_or_else(|e| {
197                    error!("Failed to serialize response: {}", e);
198                    r#"{"jsonrpc":"2.0","error":{"code":-32603,"message":"Internal error"}}"#
199                        .to_string()
200                });
201
202                debug!("Sending: {}", json);
203
204                if let Err(e) = writeln!(stdout, "{}", json) {
205                    error!("Failed to write response: {}", e);
206                }
207                if let Err(e) = stdout.flush() {
208                    error!("Failed to flush stdout: {}", e);
209                }
210            }
211        }
212
213        info!("MCP server shutting down");
214        Ok(())
215    }
216
217    /// Handle a single line of input
218    async fn handle_line(&self, line: &str) -> Option<JsonRpcResponse> {
219        // Try to parse as JSON-RPC request
220        let request: JsonRpcRequest = match serde_json::from_str(line) {
221            Ok(r) => r,
222            Err(e) => {
223                warn!("Failed to parse request: {}", e);
224                return Some(JsonRpcResponse::parse_error());
225            }
226        };
227
228        // Handle the request
229        self.handle_request(request).await
230    }
231
232    /// Handle a JSON-RPC request
233    #[instrument(skip(self, request))]
234    async fn handle_request(&self, request: JsonRpcRequest) -> Option<JsonRpcResponse> {
235        let id = request.id.clone();
236        let method = request.method.as_str();
237
238        info!("Handling method: {}", method);
239
240        // Validate authentication BEFORE processing any method
241        // This prevents unauthenticated access to any server functionality
242        if let Err(auth_error) = self.validate_auth(&request) {
243            return Some(*auth_error);
244        }
245
246        let result = match method {
247            // Lifecycle methods
248            "initialize" => self.handle_initialize(request.params).await,
249            "initialized" => {
250                // Notification, no response needed
251                return None;
252            }
253            "shutdown" => self.handle_shutdown().await,
254
255            // Tool methods
256            "tools/list" => self.handle_tools_list().await,
257            "tools/call" => self.handle_tools_call(request.params).await,
258
259            // Ping (for testing)
260            "ping" => Ok(json!({ "pong": true })),
261
262            // Unknown method
263            _ => {
264                warn!("Unknown method: {}", method);
265                return Some(JsonRpcResponse::method_not_found(id, method));
266            }
267        };
268
269        Some(match result {
270            Ok(value) => JsonRpcResponse::success(id, value),
271            Err(e) => JsonRpcResponse::internal_error(id, &e.to_string()),
272        })
273    }
274
275    /// Handle initialize request
276    async fn handle_initialize(&self, params: Option<Value>) -> Result<Value> {
277        info!("Handling initialize");
278
279        // Validate protocol version if provided
280        if let Some(ref p) = params {
281            if let Some(version) = p.get("protocolVersion").and_then(|v| v.as_str()) {
282                debug!("Client protocol version: {}", version);
283                // We support MCP protocol version 2024-11-05 and earlier
284            }
285        }
286
287        *self.initialized.write().await = true;
288
289        Ok(json!({
290            "protocolVersion": "2024-11-05",
291            "capabilities": McpCapabilities::default(),
292            "serverInfo": self.info
293        }))
294    }
295
296    /// Handle shutdown request
297    async fn handle_shutdown(&self) -> Result<Value> {
298        info!("Handling shutdown");
299        *self.initialized.write().await = false;
300        Ok(json!(null))
301    }
302
303    /// Handle tools/list request
304    async fn handle_tools_list(&self) -> Result<Value> {
305        let definitions = self.tools.definitions();
306        Ok(json!({
307            "tools": definitions
308        }))
309    }
310
311    /// Handle tools/call request
312    async fn handle_tools_call(&self, params: Option<Value>) -> Result<Value> {
313        let params = params.ok_or_else(|| crate::error::Error::generic("Missing params"))?;
314
315        let tool_params: ToolCallParams = serde_json::from_value(params)
316            .map_err(|e| crate::error::Error::generic(format!("Invalid params: {}", e)))?;
317
318        let result = self
319            .tools
320            .execute(&tool_params.name, tool_params.arguments)
321            .await;
322
323        Ok(serde_json::to_value(result)?)
324    }
325}
326
327impl Default for McpServer {
328    fn default() -> Self {
329        Self::new()
330    }
331}
332
/// Compare two strings without short-circuiting on their contents.
///
/// Every byte position up to the length of the longer input is examined. A length
/// mismatch is folded into the result instead of triggering an early return. The
/// running time is therefore intended to depend only on the input lengths, never on
/// where (or whether) the contents differ. This keeps an attacker from recovering a
/// secret token byte by byte through response timing.
///
/// # Security Note
///
/// Input lengths are not hidden. Returns `true` only when both strings consist of
/// exactly the same bytes.
fn constant_time_compare(a: &str, b: &str) -> bool {
    let a = a.as_bytes();
    let b = b.as_bytes();

    // Start non-zero when the lengths differ, so the result is `false`
    // without needing an early exit.
    let mut diff = u8::from(a.len() != b.len());

    // Positions past the end of the shorter input are compared as 0; the length
    // flag above already guarantees such inputs compare unequal.
    for i in 0..a.len().max(b.len()) {
        let x = a.get(i).copied().unwrap_or(0);
        let y = b.get(i).copied().unwrap_or(0);
        diff |= x ^ y;
    }

    diff == 0
}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a JSON-RPC 2.0 request for the tests below.
    fn request(method: &str, params: Option<Value>, id: Option<Value>) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: method.to_string(),
            params,
            id,
        }
    }

    /// A server with authentication disabled that never reads the process
    /// environment, so handler tests neither depend on nor mutate
    /// `REASONKIT_MCP_TOKEN`.
    fn unauthenticated_server() -> McpServer {
        McpServer::with_auth_token("")
    }

    #[test]
    fn test_constant_time_compare_equal() {
        assert!(constant_time_compare("secret123", "secret123"));
        assert!(constant_time_compare("", ""));
        assert!(constant_time_compare("a", "a"));
    }

    #[test]
    fn test_constant_time_compare_unequal() {
        assert!(!constant_time_compare("secret123", "secret124"));
        assert!(!constant_time_compare("secret123", "Secret123"));
        assert!(!constant_time_compare("abc", "def"));
    }

    #[test]
    fn test_constant_time_compare_different_lengths() {
        assert!(!constant_time_compare("short", "longer"));
        assert!(!constant_time_compare("longer", "short"));
        assert!(!constant_time_compare("abc", ""));
        assert!(!constant_time_compare("", "abc"));
        // A token that is a strict prefix of the expected one must not match.
        assert!(!constant_time_compare("abc", "abc\0"));
    }

    #[tokio::test]
    async fn test_mcp_server_new() {
        // The only test that exercises environment loading; clear any inherited value.
        std::env::remove_var(MCP_TOKEN_ENV_VAR);
        let server = McpServer::new();
        assert_eq!(server.info.name, "reasonkit-web");
        assert!(!server.is_auth_enabled());
    }

    #[tokio::test]
    async fn test_mcp_server_with_auth_token() {
        let server = McpServer::with_auth_token("test-secret-token");
        assert!(server.is_auth_enabled());
    }

    #[tokio::test]
    async fn test_mcp_server_with_empty_auth_token() {
        let server = McpServer::with_auth_token("");
        assert!(!server.is_auth_enabled());
    }

    #[tokio::test]
    async fn test_validate_auth_no_token_configured() {
        let server = unauthenticated_server();
        let req = request("ping", None, Some(json!(1)));

        assert!(server.validate_auth(&req).is_ok());
    }

    #[tokio::test]
    async fn test_validate_auth_valid_token() {
        let server = McpServer::with_auth_token("my-secret-token");
        let req = request(
            "ping",
            Some(json!({ "auth_token": "my-secret-token" })),
            Some(json!(1)),
        );

        assert!(server.validate_auth(&req).is_ok());
    }

    #[tokio::test]
    async fn test_validate_auth_invalid_token() {
        let server = McpServer::with_auth_token("my-secret-token");
        let req = request(
            "ping",
            Some(json!({ "auth_token": "wrong-token" })),
            Some(json!(1)),
        );

        let err_response = server.validate_auth(&req).unwrap_err();
        let error = err_response.error.as_ref().expect("auth failure carries an error");
        assert_eq!(error.code, AUTH_ERROR_CODE);
        assert!(error.message.contains("invalid token"));
    }

    #[tokio::test]
    async fn test_validate_auth_missing_token() {
        let server = McpServer::with_auth_token("my-secret-token");
        let req = request("ping", None, Some(json!(1)));

        let err_response = server.validate_auth(&req).unwrap_err();
        let error = err_response.error.as_ref().expect("auth failure carries an error");
        assert_eq!(error.code, AUTH_ERROR_CODE);
        assert!(error.message.contains("missing auth_token"));
    }

    #[tokio::test]
    async fn test_validate_auth_token_in_params_but_not_string() {
        let server = McpServer::with_auth_token("my-secret-token");
        // Number, not string
        let req = request("ping", Some(json!({ "auth_token": 12345 })), Some(json!(1)));

        assert!(server.validate_auth(&req).is_err());
    }

    #[tokio::test]
    async fn test_handle_request_with_auth_required() {
        let server = McpServer::with_auth_token("secret");
        // No auth token
        let req = request("ping", None, Some(json!(1)));

        let response = server.handle_request(req).await.unwrap();
        assert!(response.error.is_some());
        assert_eq!(response.error.as_ref().unwrap().code, AUTH_ERROR_CODE);
    }

    #[tokio::test]
    async fn test_handle_request_with_valid_auth() {
        let server = McpServer::with_auth_token("secret");
        let req = request("ping", Some(json!({ "auth_token": "secret" })), Some(json!(1)));

        let response = server.handle_request(req).await.unwrap();
        assert!(response.result.is_some());
        assert!(response.result.unwrap()["pong"].as_bool().unwrap());
    }

    #[tokio::test]
    async fn test_handle_ping() {
        let server = unauthenticated_server();
        let req = request("ping", None, Some(json!(1)));

        let response = server.handle_request(req).await.unwrap();
        assert!(response.result.is_some());
        assert!(response.result.unwrap()["pong"].as_bool().unwrap());
    }

    #[tokio::test]
    async fn test_handle_initialize() {
        let server = unauthenticated_server();
        let req = request(
            "initialize",
            Some(json!({ "protocolVersion": "2024-11-05" })),
            Some(json!(1)),
        );

        let response = server.handle_request(req).await.unwrap();
        assert!(response.result.is_some());
        let result = response.result.unwrap();
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert!(result["capabilities"].is_object());
        assert!(result["serverInfo"].is_object());
    }

    #[tokio::test]
    async fn test_handle_tools_list() {
        let server = unauthenticated_server();
        let req = request("tools/list", None, Some(json!(2)));

        let response = server.handle_request(req).await.unwrap();
        assert!(response.result.is_some());
        let result = response.result.unwrap();
        assert!(result["tools"].is_array());
        assert!(!result["tools"].as_array().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_handle_unknown_method() {
        let server = unauthenticated_server();
        let req = request("unknown/method", None, Some(json!(3)));

        let response = server.handle_request(req).await.unwrap();
        assert!(response.error.is_some());
        assert_eq!(response.error.unwrap().code, -32601);
    }

    #[tokio::test]
    async fn test_handle_notification() {
        let server = unauthenticated_server();
        // Notification: no id
        let req = request("initialized", None, None);

        let response = server.handle_request(req).await;
        assert!(response.is_none()); // Notifications don't get responses
    }

    #[tokio::test]
    async fn test_handle_initialize_with_auth() {
        let server = McpServer::with_auth_token("init-secret");
        let req = request(
            "initialize",
            Some(json!({
                "protocolVersion": "2024-11-05",
                "auth_token": "init-secret"
            })),
            Some(json!(1)),
        );

        let response = server.handle_request(req).await.unwrap();
        assert!(response.result.is_some());
        let result = response.result.unwrap();
        assert_eq!(result["protocolVersion"], "2024-11-05");
    }

    #[tokio::test]
    async fn test_handle_tools_list_with_auth() {
        let server = McpServer::with_auth_token("list-secret");
        let req = request(
            "tools/list",
            Some(json!({ "auth_token": "list-secret" })),
            Some(json!(2)),
        );

        let response = server.handle_request(req).await.unwrap();
        assert!(response.result.is_some());
        let result = response.result.unwrap();
        assert!(result["tools"].is_array());
    }

    #[tokio::test]
    async fn test_handle_tools_list_with_invalid_auth() {
        let server = McpServer::with_auth_token("list-secret");
        let req = request(
            "tools/list",
            Some(json!({ "auth_token": "not-the-secret" })),
            Some(json!(2)),
        );

        let response = server.handle_request(req).await.unwrap();
        assert!(response.result.is_none());
        assert_eq!(response.error.unwrap().code, AUTH_ERROR_CODE);
    }
}