1use crate::error::Result;
21use crate::mcp::tools::ToolRegistry;
22use crate::mcp::types::{
23 JsonRpcRequest, JsonRpcResponse, McpCapabilities, McpServerInfo, ToolCallParams,
24};
25use serde_json::{json, Value};
26use std::io::{self, BufRead, Write};
27use tokio::sync::RwLock;
28use tracing::{debug, error, info, instrument, warn};
29
/// Name of the environment variable that `McpServer::new` reads to obtain the auth token.
const MCP_TOKEN_ENV_VAR: &str = "REASONKIT_MCP_TOKEN";

/// Error code placed in the JSON-RPC error response when authentication fails.
const AUTH_ERROR_CODE: i32 = -32001;
35
/// MCP server that reads line-delimited JSON-RPC requests from stdin and writes
/// the responses to stdout.
pub struct McpServer {
    /// Registry used to list tool definitions and to execute tool calls.
    tools: ToolRegistry,
    /// Server metadata (name, version) returned in the `initialize` response.
    info: McpServerInfo,
    /// Set to `true` by `initialize` and reset to `false` by `shutdown`.
    initialized: RwLock<bool>,
    /// Shared secret every request must present; `None` disables authentication.
    auth_token: Option<String>,
}
48
49impl McpServer {
50 pub fn new() -> Self {
56 let auth_token = std::env::var(MCP_TOKEN_ENV_VAR)
57 .ok()
58 .filter(|t| !t.is_empty());
59
60 if auth_token.is_some() {
61 info!(
62 "MCP server authentication enabled via {}",
63 MCP_TOKEN_ENV_VAR
64 );
65 } else {
66 warn!(
67 "MCP server running without authentication. Set {} to enable.",
68 MCP_TOKEN_ENV_VAR
69 );
70 }
71
72 Self {
73 tools: ToolRegistry::new(),
74 info: McpServerInfo::default(),
75 initialized: RwLock::new(false),
76 auth_token,
77 }
78 }
79
80 pub fn with_auth_token(token: impl Into<String>) -> Self {
85 let token = token.into();
86 let auth_token = if token.is_empty() { None } else { Some(token) };
87
88 Self {
89 tools: ToolRegistry::new(),
90 info: McpServerInfo::default(),
91 initialized: RwLock::new(false),
92 auth_token,
93 }
94 }
95
    /// Returns `true` when an auth token is configured, i.e. when requests
    /// are checked by `validate_auth` before being dispatched.
    pub fn is_auth_enabled(&self) -> bool {
        self.auth_token.is_some()
    }
100
101 fn validate_auth(
114 &self,
115 request: &JsonRpcRequest,
116 ) -> std::result::Result<(), Box<JsonRpcResponse>> {
117 let expected_token = match &self.auth_token {
118 Some(token) => token,
119 None => return Ok(()), };
121
122 let provided_token = request
124 .params
125 .as_ref()
126 .and_then(|p| p.get("auth_token"))
127 .and_then(|v| v.as_str());
128
129 match provided_token {
130 Some(token) => {
131 if constant_time_compare(token, expected_token) {
133 debug!("Authentication successful for method: {}", request.method);
134 Ok(())
135 } else {
136 warn!(
137 method = %request.method,
138 "Authentication failed: invalid token"
139 );
140 Err(Box::new(JsonRpcResponse::error(
141 request.id.clone(),
142 AUTH_ERROR_CODE,
143 "Authentication failed: invalid token",
144 )))
145 }
146 }
147 None => {
148 warn!(
149 method = %request.method,
150 "Authentication failed: missing auth_token in params"
151 );
152 Err(Box::new(JsonRpcResponse::error(
153 request.id.clone(),
154 AUTH_ERROR_CODE,
155 "Authentication required: missing auth_token in params",
156 )))
157 }
158 }
159 }
160
161 #[instrument(skip(self))]
163 pub async fn run(&self) -> Result<()> {
164 info!(
165 "Starting MCP server: {} v{}",
166 self.info.name, self.info.version
167 );
168
169 if self.is_auth_enabled() {
170 info!("Authentication is ENABLED - all requests require valid auth_token");
171 } else {
172 warn!("Authentication is DISABLED - accepting all requests");
173 }
174
175 let stdin = io::stdin();
176 let mut stdout = io::stdout();
177
178 for line in stdin.lock().lines() {
179 let line = match line {
180 Ok(l) => l,
181 Err(e) => {
182 error!("Failed to read line: {}", e);
183 continue;
184 }
185 };
186
187 if line.trim().is_empty() {
188 continue;
189 }
190
191 debug!("Received: {}", line);
192
193 let response = self.handle_line(&line).await;
194
195 if let Some(resp) = response {
196 let json = serde_json::to_string(&resp).unwrap_or_else(|e| {
197 error!("Failed to serialize response: {}", e);
198 r#"{"jsonrpc":"2.0","error":{"code":-32603,"message":"Internal error"}}"#
199 .to_string()
200 });
201
202 debug!("Sending: {}", json);
203
204 if let Err(e) = writeln!(stdout, "{}", json) {
205 error!("Failed to write response: {}", e);
206 }
207 if let Err(e) = stdout.flush() {
208 error!("Failed to flush stdout: {}", e);
209 }
210 }
211 }
212
213 info!("MCP server shutting down");
214 Ok(())
215 }
216
217 async fn handle_line(&self, line: &str) -> Option<JsonRpcResponse> {
219 let request: JsonRpcRequest = match serde_json::from_str(line) {
221 Ok(r) => r,
222 Err(e) => {
223 warn!("Failed to parse request: {}", e);
224 return Some(JsonRpcResponse::parse_error());
225 }
226 };
227
228 self.handle_request(request).await
230 }
231
232 #[instrument(skip(self, request))]
234 async fn handle_request(&self, request: JsonRpcRequest) -> Option<JsonRpcResponse> {
235 let id = request.id.clone();
236 let method = request.method.as_str();
237
238 info!("Handling method: {}", method);
239
240 if let Err(auth_error) = self.validate_auth(&request) {
243 return Some(*auth_error);
244 }
245
246 let result = match method {
247 "initialize" => self.handle_initialize(request.params).await,
249 "initialized" => {
250 return None;
252 }
253 "shutdown" => self.handle_shutdown().await,
254
255 "tools/list" => self.handle_tools_list().await,
257 "tools/call" => self.handle_tools_call(request.params).await,
258
259 "ping" => Ok(json!({ "pong": true })),
261
262 _ => {
264 warn!("Unknown method: {}", method);
265 return Some(JsonRpcResponse::method_not_found(id, method));
266 }
267 };
268
269 Some(match result {
270 Ok(value) => JsonRpcResponse::success(id, value),
271 Err(e) => JsonRpcResponse::internal_error(id, &e.to_string()),
272 })
273 }
274
275 async fn handle_initialize(&self, params: Option<Value>) -> Result<Value> {
277 info!("Handling initialize");
278
279 if let Some(ref p) = params {
281 if let Some(version) = p.get("protocolVersion").and_then(|v| v.as_str()) {
282 debug!("Client protocol version: {}", version);
283 }
285 }
286
287 *self.initialized.write().await = true;
288
289 Ok(json!({
290 "protocolVersion": "2024-11-05",
291 "capabilities": McpCapabilities::default(),
292 "serverInfo": self.info
293 }))
294 }
295
296 async fn handle_shutdown(&self) -> Result<Value> {
298 info!("Handling shutdown");
299 *self.initialized.write().await = false;
300 Ok(json!(null))
301 }
302
303 async fn handle_tools_list(&self) -> Result<Value> {
305 let definitions = self.tools.definitions();
306 Ok(json!({
307 "tools": definitions
308 }))
309 }
310
311 async fn handle_tools_call(&self, params: Option<Value>) -> Result<Value> {
313 let params = params.ok_or_else(|| crate::error::Error::generic("Missing params"))?;
314
315 let tool_params: ToolCallParams = serde_json::from_value(params)
316 .map_err(|e| crate::error::Error::generic(format!("Invalid params: {}", e)))?;
317
318 let result = self
319 .tools
320 .execute(&tool_params.name, tool_params.arguments)
321 .await;
322
323 Ok(serde_json::to_value(result)?)
324 }
325}
326
327impl Default for McpServer {
328 fn default() -> Self {
329 Self::new()
330 }
331}
332
/// Compares two strings for byte equality without an early exit.
///
/// `a` is expected to be the untrusted, caller-supplied value and `b` the
/// secret. The loop always runs once per byte of `a`, and a length mismatch is
/// folded into the same accumulator instead of returning early, so the amount
/// of work does not reveal where the first differing byte is or whether the
/// guess has the secret's length.
///
/// Returns `true` only when both strings have the same length and contents.
fn constant_time_compare(a: &str, b: &str) -> bool {
    let a_bytes = a.as_bytes();
    let b_bytes = b.as_bytes();

    // Start with the length mismatch so it can never be masked by the bytes.
    let mut diff: u8 = u8::from(a_bytes.len() != b_bytes.len());

    // Cycle the reference so every byte of `a` has something to be compared
    // with. When `b` is empty, `a` is compared with itself; the length flag
    // above already decides the outcome in that case.
    let reference: &[u8] = if b_bytes.is_empty() { a_bytes } else { b_bytes };

    for (x, y) in a_bytes.iter().zip(reference.iter().cycle()) {
        diff |= x ^ y;
    }

    diff == 0
}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a JSON-RPC 2.0 request with the given method, params, and id.
    fn rpc(method: &str, params: Option<Value>, id: Option<Value>) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: method.to_string(),
            params,
            id,
        }
    }

    /// Clears the token environment variable, then builds a server via `new`.
    fn server_without_env_token() -> McpServer {
        std::env::remove_var(MCP_TOKEN_ENV_VAR);
        McpServer::new()
    }

    /// Asserts that `response` is an auth error whose message contains `fragment`.
    fn assert_auth_failure(response: &JsonRpcResponse, fragment: &str) {
        let err = response.error.as_ref().expect("expected an error payload");
        assert_eq!(err.code, AUTH_ERROR_CODE);
        assert!(err.message.contains(fragment));
    }

    /// Asserts that `response` is a successful ping reply.
    fn assert_pong(response: JsonRpcResponse) {
        let result = response.result.expect("expected a result payload");
        assert_eq!(result["pong"].as_bool(), Some(true));
    }

    #[test]
    fn test_constant_time_compare_equal() {
        let pairs = [("secret123", "secret123"), ("", ""), ("a", "a")];
        for &(left, right) in pairs.iter() {
            assert!(constant_time_compare(left, right));
        }
    }

    #[test]
    fn test_constant_time_compare_unequal() {
        let pairs = [
            ("secret123", "secret124"),
            ("secret123", "Secret123"),
            ("abc", "def"),
        ];
        for &(left, right) in pairs.iter() {
            assert!(!constant_time_compare(left, right));
        }
    }

    #[test]
    fn test_constant_time_compare_different_lengths() {
        let pairs = [("short", "longer"), ("longer", "short"), ("abc", "")];
        for &(left, right) in pairs.iter() {
            assert!(!constant_time_compare(left, right));
        }
    }

    #[tokio::test]
    async fn test_mcp_server_new() {
        let server = server_without_env_token();
        assert_eq!(server.info.name, "reasonkit-web");
        assert!(!server.is_auth_enabled());
    }

    #[tokio::test]
    async fn test_mcp_server_with_auth_token() {
        assert!(McpServer::with_auth_token("test-secret-token").is_auth_enabled());
    }

    #[tokio::test]
    async fn test_mcp_server_with_empty_auth_token() {
        assert!(!McpServer::with_auth_token("").is_auth_enabled());
    }

    #[tokio::test]
    async fn test_validate_auth_no_token_configured() {
        let server = McpServer::with_auth_token("");
        let request = rpc("ping", None, Some(json!(1)));

        assert!(server.validate_auth(&request).is_ok());
    }

    #[tokio::test]
    async fn test_validate_auth_valid_token() {
        let server = McpServer::with_auth_token("my-secret-token");
        let request = rpc(
            "ping",
            Some(json!({ "auth_token": "my-secret-token" })),
            Some(json!(1)),
        );

        assert!(server.validate_auth(&request).is_ok());
    }

    #[tokio::test]
    async fn test_validate_auth_invalid_token() {
        let server = McpServer::with_auth_token("my-secret-token");
        let request = rpc(
            "ping",
            Some(json!({ "auth_token": "wrong-token" })),
            Some(json!(1)),
        );

        let rejection = server.validate_auth(&request).unwrap_err();
        assert_auth_failure(&rejection, "invalid token");
    }

    #[tokio::test]
    async fn test_validate_auth_missing_token() {
        let server = McpServer::with_auth_token("my-secret-token");
        let request = rpc("ping", None, Some(json!(1)));

        let rejection = server.validate_auth(&request).unwrap_err();
        assert_auth_failure(&rejection, "missing auth_token");
    }

    #[tokio::test]
    async fn test_validate_auth_token_in_params_but_not_string() {
        let server = McpServer::with_auth_token("my-secret-token");
        // A numeric token is not a string and must be rejected.
        let request = rpc("ping", Some(json!({ "auth_token": 12345 })), Some(json!(1)));

        assert!(server.validate_auth(&request).is_err());
    }

    #[tokio::test]
    async fn test_handle_request_with_auth_required() {
        let server = McpServer::with_auth_token("secret");
        // No params at all, hence no token.
        let request = rpc("ping", None, Some(json!(1)));

        let response = server.handle_request(request).await.unwrap();
        let err = response.error.as_ref().expect("expected an error payload");
        assert_eq!(err.code, AUTH_ERROR_CODE);
    }

    #[tokio::test]
    async fn test_handle_request_with_valid_auth() {
        let server = McpServer::with_auth_token("secret");
        let request = rpc(
            "ping",
            Some(json!({ "auth_token": "secret" })),
            Some(json!(1)),
        );

        assert_pong(server.handle_request(request).await.unwrap());
    }

    #[tokio::test]
    async fn test_handle_ping() {
        let server = server_without_env_token();
        let request = rpc("ping", None, Some(json!(1)));

        assert_pong(server.handle_request(request).await.unwrap());
    }

    #[tokio::test]
    async fn test_handle_initialize() {
        let server = server_without_env_token();
        let request = rpc(
            "initialize",
            Some(json!({
                "protocolVersion": "2024-11-05"
            })),
            Some(json!(1)),
        );

        let response = server.handle_request(request).await.unwrap();
        let result = response.result.expect("expected a result payload");
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert!(result["capabilities"].is_object());
        assert!(result["serverInfo"].is_object());
    }

    #[tokio::test]
    async fn test_handle_tools_list() {
        let server = server_without_env_token();
        let request = rpc("tools/list", None, Some(json!(2)));

        let response = server.handle_request(request).await.unwrap();
        let result = response.result.expect("expected a result payload");
        let tools = result["tools"].as_array().expect("tools should be an array");
        assert!(!tools.is_empty());
    }

    #[tokio::test]
    async fn test_handle_unknown_method() {
        let server = server_without_env_token();
        let request = rpc("unknown/method", None, Some(json!(3)));

        let response = server.handle_request(request).await.unwrap();
        let err = response.error.expect("expected an error payload");
        assert_eq!(err.code, -32601);
    }

    #[tokio::test]
    async fn test_handle_notification() {
        let server = server_without_env_token();
        // No id: this is a notification.
        let request = rpc("initialized", None, None);

        // Notifications never produce a response.
        assert!(server.handle_request(request).await.is_none());
    }

    #[tokio::test]
    async fn test_handle_initialize_with_auth() {
        let server = McpServer::with_auth_token("init-secret");
        let request = rpc(
            "initialize",
            Some(json!({
                "protocolVersion": "2024-11-05",
                "auth_token": "init-secret"
            })),
            Some(json!(1)),
        );

        let response = server.handle_request(request).await.unwrap();
        let result = response.result.expect("expected a result payload");
        assert_eq!(result["protocolVersion"], "2024-11-05");
    }

    #[tokio::test]
    async fn test_handle_tools_list_with_auth() {
        let server = McpServer::with_auth_token("list-secret");
        let request = rpc(
            "tools/list",
            Some(json!({ "auth_token": "list-secret" })),
            Some(json!(2)),
        );

        let response = server.handle_request(request).await.unwrap();
        let result = response.result.expect("expected a result payload");
        assert!(result["tools"].is_array());
    }
}