1use futures::StreamExt;
32use serde::{Deserialize, Serialize};
33use std::collections::HashMap;
34use std::process::Stdio;
35use std::sync::Arc;
36use thiserror::Error;
37use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
38use tokio::process::{Child, Command};
39use tokio::sync::RwLock;
40use tracing::{debug, info, warn};
41use uuid::Uuid;
42
43use crate::tools::{
44 JsonSchema, ToolCategory, ToolDefinition, ToolImpl, ToolResult, ToolResultValue,
45};
46
47#[derive(Error, Debug)]
50pub enum McpError {
51 #[error("Transport error: {0}")]
52 Transport(String),
53
54 #[error("JSON-RPC error: {0}")]
55 JsonRpc(String),
56
57 #[error("Server error: {code} - {message}")]
58 Server { code: i32, message: String },
59
60 #[error("Tool not found: {0}")]
61 #[allow(dead_code)]
62 ToolNotFound(String),
63
64 #[error("Invalid tool arguments: {0}")]
65 #[allow(dead_code)]
66 InvalidArguments(String),
67
68 #[error("Connection failed: {0}")]
69 ConnectionFailed(String),
70
71 #[error("Timeout: {0}")]
72 Timeout(String),
73
74 #[error("IO error: {0}")]
75 Io(#[from] std::io::Error),
76
77 #[error("JSON error: {0}")]
78 Json(#[from] serde_json::Error),
79}
80
81pub type McpResult<T> = std::result::Result<T, McpError>;
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct JsonRpcRequest {
91 pub jsonrpc: String,
92 pub method: String,
93 #[serde(default)]
94 pub params: Option<serde_json::Value>,
95 pub id: serde_json::Value,
96}
97
98impl JsonRpcRequest {
99 pub fn new(method: &str, params: serde_json::Value, id: i64) -> Self {
100 Self {
101 jsonrpc: "2.0".to_string(),
102 method: method.to_string(),
103 params: Some(params),
104 id: serde_json::Value::Number(id.into()),
105 }
106 }
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct JsonRpcResponse {
112 pub jsonrpc: String,
113 #[serde(default, skip_serializing_if = "Option::is_none")]
114 pub result: Option<serde_json::Value>,
115 #[serde(default, skip_serializing_if = "Option::is_none")]
116 pub error: Option<JsonRpcError>,
117 pub id: serde_json::Value,
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct JsonRpcError {
122 pub code: i32,
123 pub message: String,
124 #[serde(default, skip_serializing_if = "Option::is_none")]
125 pub data: Option<serde_json::Value>,
126}
127
/// MCP protocol revision spoken by this module. The client sends it in
/// `initialize`, and the server reports it in its `initialize` result.
pub const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
135#[serde(rename_all = "camelCase")]
136pub struct InitializeParams {
137 pub protocol_version: String,
138 pub capabilities: McpClientCapabilities,
139 pub client_info: McpClientInfo,
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
143#[serde(rename_all = "camelCase")]
144pub struct McpClientCapabilities {
145 #[serde(default, skip_serializing_if = "Option::is_none")]
146 pub roots: Option<McpRootsCapability>,
147 #[serde(default, skip_serializing_if = "Option::is_none")]
148 pub sampling: Option<serde_json::Value>,
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize)]
152#[serde(rename_all = "camelCase")]
153pub struct McpRootsCapability {
154 pub list_changed: bool,
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct McpClientInfo {
159 pub name: String,
160 pub version: String,
161}
162
163#[derive(Debug, Clone, Serialize, Deserialize)]
165#[serde(rename_all = "camelCase")]
166pub struct InitializeResult {
167 pub protocol_version: String,
168 pub capabilities: McpServerCapabilities,
169 pub server_info: McpServerInfo,
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize)]
173#[serde(rename_all = "camelCase")]
174pub struct McpServerCapabilities {
175 #[serde(default, skip_serializing_if = "Option::is_none")]
176 pub tools: Option<McpToolsCapability>,
177 #[serde(default, skip_serializing_if = "Option::is_none")]
178 pub resources: Option<serde_json::Value>,
179 #[serde(default, skip_serializing_if = "Option::is_none")]
180 pub prompts: Option<serde_json::Value>,
181}
182
183#[derive(Debug, Clone, Serialize, Deserialize)]
184#[serde(rename_all = "camelCase")]
185pub struct McpToolsCapability {
186 #[serde(default)]
187 pub list_changed: bool,
188}
189
190#[derive(Debug, Clone, Serialize, Deserialize)]
191pub struct McpServerInfo {
192 pub name: String,
193 pub version: String,
194}
195
196#[derive(Debug, Clone, Serialize, Deserialize)]
198pub struct McpTool {
199 pub name: String,
200 #[serde(default, skip_serializing_if = "Option::is_none")]
201 pub description: Option<String>,
202 pub input_schema: serde_json::Value,
203}
204
205#[derive(Debug, Clone, Serialize, Deserialize)]
207pub struct McpToolCall {
208 pub name: String,
209 #[serde(default, skip_serializing_if = "Option::is_none")]
210 pub arguments: Option<serde_json::Value>,
211}
212
213#[derive(Debug, Clone, Serialize, Deserialize)]
215#[serde(rename_all = "camelCase")]
216pub struct McpToolResult {
217 pub content: Vec<McpContent>,
218 #[serde(default)]
219 pub is_error: bool,
220}
221
222#[derive(Debug, Clone, Serialize, Deserialize)]
223#[serde(tag = "type", rename_all = "lowercase")]
224pub enum McpContent {
225 Text { text: String },
226 Image { data: String, mime_type: String },
227 Resource { resource: serde_json::Value },
228}
229
/// How to reach an MCP server.
#[derive(Clone, Debug)]
pub enum McpTransportConfig {
    /// Spawn a local process and exchange newline-delimited JSON-RPC over its stdin/stdout.
    Stdio {
        /// Program to execute.
        command: String,
        /// Arguments passed to `command`.
        args: Vec<String>,
        /// Extra environment variables for the child process.
        env: HashMap<String, String>,
    },
    /// Connect to a remote server via an SSE stream plus HTTP POSTs.
    Sse {
        /// URL of the server's SSE stream.
        url: String,
    },
}
244
245pub struct McpTransport {
247 config: McpTransportConfig,
248 child: Option<Child>,
250 stdin: Option<tokio::process::ChildStdin>,
251 stdout_reader: Option<BufReader<tokio::process::ChildStdout>>,
252 http_client: Option<reqwest::Client>,
254 message_endpoint: Option<String>,
256 sse_response_rx: Option<tokio::sync::mpsc::UnboundedReceiver<String>>,
258 sse_reader_handle: Option<tokio::task::JoinHandle<()>>,
260 request_id: i64,
261}
262
263impl McpTransport {
264 pub fn new(config: McpTransportConfig) -> Self {
266 Self {
267 config,
268 child: None,
269 stdin: None,
270 stdout_reader: None,
271 http_client: None,
272 message_endpoint: None,
273 sse_response_rx: None,
274 sse_reader_handle: None,
275 request_id: 0,
276 }
277 }
278
279 pub async fn connect(&mut self) -> McpResult<()> {
281 let config = self.config.clone();
282 match &config {
283 McpTransportConfig::Stdio { command, args, env } => {
284 self.connect_stdio(command, args, env).await
285 }
286 McpTransportConfig::Sse { url } => self.connect_sse(url).await,
287 }
288 }
289
290 async fn connect_stdio(
292 &mut self,
293 command: &str,
294 args: &[String],
295 env: &HashMap<String, String>,
296 ) -> McpResult<()> {
297 let mut cmd = Command::new(command);
298 cmd.args(args);
299 cmd.stdin(Stdio::piped());
300 cmd.stdout(Stdio::piped());
301 cmd.stderr(Stdio::piped());
302 cmd.envs(env);
303
304 let mut child = cmd.spawn().map_err(|e| {
305 McpError::ConnectionFailed(format!("Failed to spawn {}: {}", command, e))
306 })?;
307
308 let stdin = child
309 .stdin
310 .take()
311 .ok_or_else(|| McpError::ConnectionFailed("No stdin available".to_string()))?;
312
313 let stdout = child
314 .stdout
315 .take()
316 .ok_or_else(|| McpError::ConnectionFailed("No stdout available".to_string()))?;
317
318 self.child = Some(child);
319 self.stdin = Some(stdin);
320 self.stdout_reader = Some(BufReader::new(stdout));
321
322 info!(command = %command, "MCP stdio transport connected");
323 Ok(())
324 }
325
326 async fn connect_sse(&mut self, url: &str) -> McpResult<()> {
328 let client = reqwest::Client::builder()
329 .danger_accept_invalid_certs(false)
330 .build()
331 .map_err(|e| {
332 McpError::ConnectionFailed(format!("Failed to create HTTP client: {}", e))
333 })?;
334
335 info!(url = %url, "MCP SSE transport connecting");
336
337 let response = client
339 .get(url)
340 .header("Accept", "text/event-stream")
341 .send()
342 .await
343 .map_err(|e| McpError::ConnectionFailed(format!("SSE connection failed: {}", e)))?;
344
345 if !response.status().is_success() {
346 return Err(McpError::ConnectionFailed(format!(
347 "SSE connection returned status {}",
348 response.status()
349 )));
350 }
351
352 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
354
355 let message_endpoint = Arc::new(tokio::sync::Mutex::new(None::<String>));
357 let message_endpoint_clone = message_endpoint.clone();
358
359 let handle = tokio::spawn(async move {
361 let mut stream = response.bytes_stream();
362 let mut buffer = String::new();
363
364 while let Some(chunk_result) = stream.next().await {
365 match chunk_result {
366 Ok(chunk) => {
367 let chunk_str = String::from_utf8_lossy(&chunk);
368 buffer.push_str(&chunk_str);
369
370 while let Some(event_end) = buffer.find("\n\n") {
372 let event = buffer[..event_end].to_string();
373 buffer = buffer[event_end + 2..].to_string();
374 Self::handle_sse_event(&event, &tx, &message_endpoint_clone).await;
375 }
376 }
377 Err(e) => {
378 warn!("SSE stream error: {}", e);
379 break;
380 }
381 }
382 }
383 debug!("MCP SSE stream ended");
384 });
385
386 self.http_client = Some(client);
387 self.sse_response_rx = Some(rx);
388 self.sse_reader_handle = Some(handle);
389 self.message_endpoint = None; let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_secs(5);
394 let mut endpoint_received = false;
395
396 while tokio::time::Instant::now() < deadline {
397 if let Ok(msg) = self.sse_response_rx.as_mut().unwrap().try_recv() {
398 if msg.starts_with("endpoint:") {
400 let ep = msg
401 .strip_prefix("endpoint:")
402 .unwrap_or("")
403 .trim()
404 .to_string();
405 let resolved = if ep.starts_with("http") {
407 ep.clone()
408 } else {
409 let base = url.trim_end_matches('/');
410 let path = ep.trim_start_matches('/');
411 format!("{}/{}", base, path)
412 };
413 self.message_endpoint = Some(resolved.clone());
414 *message_endpoint.lock().await = Some(resolved);
415 endpoint_received = true;
416 info!(
417 "MCP SSE endpoint received: {}",
418 self.message_endpoint.as_ref().unwrap()
419 );
420 break;
421 }
422 } else {
423 tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
424 }
425 }
426
427 if !endpoint_received {
428 self.message_endpoint = Some(url.to_string());
430 *message_endpoint.lock().await = Some(url.to_string());
431 warn!("No SSE endpoint event received, using SSE URL as message endpoint");
432 }
433
434 info!(url = %url, "MCP SSE transport connected");
435 Ok(())
436 }
437
438 async fn handle_sse_event(
440 event: &str,
441 tx: &tokio::sync::mpsc::UnboundedSender<String>,
442 message_endpoint: &Arc<tokio::sync::Mutex<Option<String>>>,
443 ) {
444 for line in event.lines() {
445 let line = line.trim();
446 if let Some(value) = line.strip_prefix("event:") {
447 let event_type = value.trim();
448 debug!("MCP SSE event: {}", event_type);
449 } else if let Some(value) = line.strip_prefix("data:") {
450 let data = value.trim().to_string();
451 if let Some(ep) = data.strip_prefix("/message") {
453 *message_endpoint.lock().await = Some(ep.to_string());
454 let _ = tx.send(format!("endpoint:{}", ep));
455 } else if data.starts_with("http") && data.contains("/message") {
456 *message_endpoint.lock().await = Some(data.clone());
457 let _ = tx.send(format!("endpoint:{}", data));
458 } else {
459 let _ = tx.send(data);
461 }
462 }
463 }
464 }
465
466 pub async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
468 match &self.config {
469 McpTransportConfig::Stdio { .. } => self.send_request_stdio(request).await,
470 McpTransportConfig::Sse { .. } => self.send_request_sse(request).await,
471 }
472 }
473
474 async fn send_request_stdio(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
476 let request_json = serde_json::to_string(&request)?;
477 debug!("MCP → {}", request_json);
478
479 let stdin = self
480 .stdin
481 .as_mut()
482 .ok_or_else(|| McpError::Transport("Transport not connected".to_string()))?;
483
484 stdin.write_all(request_json.as_bytes()).await?;
485 stdin.write_all(b"\n").await?;
486 stdin.flush().await?;
487
488 let stdout = self
489 .stdout_reader
490 .as_mut()
491 .ok_or_else(|| McpError::Transport("Transport not connected".to_string()))?;
492
493 let mut response_line = String::new();
494 stdout.read_line(&mut response_line).await?;
495
496 if response_line.trim().is_empty() {
497 return Err(McpError::Transport(
498 "Empty response from server".to_string(),
499 ));
500 }
501
502 debug!("MCP ← {}", response_line.trim());
503
504 let response: JsonRpcResponse = serde_json::from_str(&response_line)?;
505
506 if let Some(err) = &response.error {
507 return Err(McpError::Server {
508 code: err.code,
509 message: err.message.clone(),
510 });
511 }
512
513 Ok(response)
514 }
515
516 async fn send_request_sse(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
518 let request_json = serde_json::to_string(&request)?;
519 let request_id = request.id.clone();
520 debug!(
521 "MCP SSE → POST {}: {}",
522 self.message_endpoint.as_deref().unwrap_or("?"),
523 request_json
524 );
525
526 let client = self
527 .http_client
528 .as_ref()
529 .ok_or_else(|| McpError::Transport("HTTP client not initialized".to_string()))?;
530
531 let endpoint = self
532 .message_endpoint
533 .as_ref()
534 .ok_or_else(|| McpError::Transport("No message endpoint available".to_string()))?;
535
536 let response = client
538 .post(endpoint)
539 .header("Content-Type", "application/json")
540 .header("Accept", "application/json")
541 .body(request_json.clone())
542 .send()
543 .await
544 .map_err(|e| McpError::Transport(format!("HTTP POST failed: {}", e)))?;
545
546 let status = response.status();
549 let body = response
550 .text()
551 .await
552 .map_err(|e| McpError::Transport(format!("Failed to read response body: {}", e)))?;
553
554 if !body.trim().is_empty() {
555 if let Ok(rpc_response) = serde_json::from_str::<JsonRpcResponse>(&body) {
557 debug!("MCP SSE ← {}", body.trim());
558 if let Some(err) = &rpc_response.error {
559 return Err(McpError::Server {
560 code: err.code,
561 message: err.message.clone(),
562 });
563 }
564 return Ok(rpc_response);
565 }
566 }
567
568 if status.is_success() {
570 let rx = self
571 .sse_response_rx
572 .as_mut()
573 .ok_or_else(|| McpError::Transport("SSE receiver not available".to_string()))?;
574
575 let timeout = tokio::time::Duration::from_secs(30);
576 let deadline = tokio::time::Instant::now() + timeout;
577
578 while tokio::time::Instant::now() < deadline {
579 match tokio::time::timeout(tokio::time::Duration::from_secs(1), rx.recv()).await {
580 Ok(Some(msg)) => {
581 if msg.starts_with("endpoint:") {
583 continue;
584 }
585 if let Ok(rpc_response) = serde_json::from_str::<JsonRpcResponse>(&msg) {
587 if rpc_response.id == request_id
589 || rpc_response.id == serde_json::Value::Null
590 {
591 debug!("MCP SSE ← {}", msg);
592 if let Some(err) = &rpc_response.error {
593 return Err(McpError::Server {
594 code: err.code,
595 message: err.message.clone(),
596 });
597 }
598 return Ok(rpc_response);
599 }
600 }
601 }
602 Ok(None) => break,
603 Err(_) => continue, }
605 }
606
607 return Err(McpError::Timeout(
608 "No response received from SSE stream".to_string(),
609 ));
610 }
611
612 Err(McpError::Transport(format!(
613 "HTTP POST returned status {}",
614 status
615 )))
616 }
617
618 fn next_id(&mut self) -> i64 {
620 self.request_id += 1;
621 self.request_id
622 }
623}
624
625impl Drop for McpTransport {
626 fn drop(&mut self) {
627 if let Some(handle) = self.sse_reader_handle.take() {
629 handle.abort();
630 }
631 }
632}
633
634pub struct McpClient {
638 transport: McpTransport,
639 server_info: Option<McpServerInfo>,
640 tools: Arc<RwLock<Vec<McpTool>>>,
641}
642
643impl McpClient {
644 pub fn new(config: McpTransportConfig) -> Self {
646 Self {
647 transport: McpTransport::new(config),
648 server_info: None,
649 tools: Arc::new(RwLock::new(Vec::new())),
650 }
651 }
652
653 pub async fn connect(&mut self) -> McpResult<()> {
655 self.transport.connect().await?;
657
658 let init_params = InitializeParams {
660 protocol_version: MCP_PROTOCOL_VERSION.to_string(),
661 capabilities: McpClientCapabilities {
662 roots: Some(McpRootsCapability {
663 list_changed: false,
664 }),
665 sampling: None,
666 },
667 client_info: McpClientInfo {
668 name: "ravenclaws".to_string(),
669 version: env!("CARGO_PKG_VERSION").to_string(),
670 },
671 };
672
673 let init_id = self.transport.next_id();
674 let response = self
675 .transport
676 .send_request(JsonRpcRequest::new(
677 "initialize",
678 serde_json::to_value(init_params)?,
679 init_id,
680 ))
681 .await?;
682
683 let init_result: InitializeResult = response
684 .result
685 .and_then(|v| serde_json::from_value(v).ok())
686 .ok_or_else(|| McpError::JsonRpc("Invalid initialize response".to_string()))?;
687
688 let server_info = init_result.server_info.clone();
689 self.server_info = Some(init_result.server_info);
690
691 info!(
692 server = %server_info.name,
693 version = %server_info.version,
694 "MCP server initialized"
695 );
696
697 let notify = JsonRpcRequest {
699 jsonrpc: "2.0".to_string(),
700 method: "notifications/initialized".to_string(),
701 params: Some(serde_json::Value::Null),
702 id: serde_json::Value::Null,
703 };
704 self.transport.send_request(notify).await?;
705
706 self.discover_tools().await?;
708
709 Ok(())
710 }
711
712 pub async fn discover_tools(&mut self) -> McpResult<()> {
714 let list_id = self.transport.next_id();
715 let response = self
716 .transport
717 .send_request(JsonRpcRequest::new(
718 "tools/list",
719 serde_json::Value::Null,
720 list_id,
721 ))
722 .await?;
723
724 let tools_result = response
725 .result
726 .and_then(|v| v.get("tools").cloned())
727 .ok_or_else(|| McpError::JsonRpc("No tools in response".to_string()))?;
728
729 let tools: Vec<McpTool> = serde_json::from_value(tools_result)?;
730
731 info!(count = tools.len(), "Discovered MCP tools");
732
733 let mut tool_lock = self.tools.write().await;
734 *tool_lock = tools;
735
736 Ok(())
737 }
738
739 pub async fn get_tools(&self) -> Vec<McpTool> {
741 self.tools.read().await.clone()
742 }
743
744 pub async fn call_tool(
746 &mut self,
747 name: &str,
748 arguments: Option<serde_json::Value>,
749 ) -> McpResult<McpToolResult> {
750 let params = McpToolCall {
751 name: name.to_string(),
752 arguments,
753 };
754
755 let call_id = self.transport.next_id();
756 let response = self
757 .transport
758 .send_request(JsonRpcRequest::new(
759 "tools/call",
760 serde_json::to_value(params)?,
761 call_id,
762 ))
763 .await?;
764
765 let result: McpToolResult = response
766 .result
767 .and_then(|v| serde_json::from_value(v).ok())
768 .ok_or_else(|| McpError::JsonRpc("Invalid tool call response".to_string()))?;
769
770 if result.is_error {
771 return Err(McpError::Server {
772 code: -32000,
773 message: "Tool execution failed".to_string(),
774 });
775 }
776
777 Ok(result)
778 }
779
780 pub fn server_info(&self) -> Option<&McpServerInfo> {
782 self.server_info.as_ref()
783 }
784}
785
786pub struct McpClientManager {
794 clients: Vec<(String, Arc<RwLock<McpClient>>)>,
796}
797
798impl McpClientManager {
799 pub fn new() -> Self {
801 Self {
802 clients: Vec::new(),
803 }
804 }
805
806 pub async fn from_config(config: &crate::config::McpConfig) -> Self {
808 let mut manager = Self::new();
809 for server in &config.servers {
810 let transport_config = if !server.url.is_empty() {
812 McpTransportConfig::Sse {
813 url: server.url.clone(),
814 }
815 } else {
816 McpTransportConfig::Stdio {
817 command: server.command.clone(),
818 args: server.args.clone(),
819 env: server.env.clone(),
820 }
821 };
822 let mut client = McpClient::new(transport_config);
823 match client.connect().await {
824 Ok(()) => {
825 info!(
826 server = %server.name,
827 server_info = ?client.server_info(),
828 "MCP client connected from config"
829 );
830 manager
831 .clients
832 .push((server.name.clone(), Arc::new(RwLock::new(client))));
833 }
834 Err(e) => {
835 warn!(
836 server = %server.name,
837 error = %e,
838 "Failed to connect to MCP server from config, skipping"
839 );
840 }
841 }
842 }
843 manager
844 }
845
846 #[allow(dead_code)]
848 pub fn add_client(&mut self, name: String, client: Arc<RwLock<McpClient>>) {
849 self.clients.push((name, client));
850 }
851
852 #[allow(dead_code)]
854 pub fn clients(&self) -> &[(String, Arc<RwLock<McpClient>>)] {
855 &self.clients
856 }
857
858 #[allow(dead_code)]
860 pub fn get_client(&self, name: &str) -> Option<&Arc<RwLock<McpClient>>> {
861 self.clients.iter().find(|(n, _)| n == name).map(|(_, c)| c)
862 }
863
864 pub async fn register_all_tools(&self, registry: &mut crate::tools::ToolRegistry) -> usize {
866 let mut total = 0;
867 for (name, client) in &self.clients {
868 let mcp_client = client.read().await;
869 let mcp_tools = mcp_client.get_tools().await;
870 drop(mcp_client);
871
872 for mcp_tool in mcp_tools {
873 let wrapper = McpToolWrapper::new(client.clone(), mcp_tool);
874 registry.register(Arc::new(wrapper));
875 total += 1;
876 }
877 info!(
878 server = %name,
879 tools_registered = total,
880 "Registered MCP tools from server"
881 );
882 }
883 info!(total, "Total MCP tools registered from all servers");
884 total
885 }
886
887 pub fn len(&self) -> usize {
889 self.clients.len()
890 }
891
892 pub fn is_empty(&self) -> bool {
894 self.clients.is_empty()
895 }
896}
897
898impl Default for McpClientManager {
899 fn default() -> Self {
900 Self::new()
901 }
902}
903
904pub struct McpToolWrapper {
908 definition: ToolDefinition,
909 client: Arc<RwLock<McpClient>>,
910 tool_name: String,
911}
912
913impl McpToolWrapper {
914 pub fn new(client: Arc<RwLock<McpClient>>, mcp_tool: McpTool) -> Self {
916 let parameters = Self::convert_schema(&mcp_tool.input_schema);
918
919 Self {
920 definition: ToolDefinition {
921 name: mcp_tool.name.clone(),
922 description: mcp_tool
923 .description
924 .unwrap_or_else(|| "MCP-provided tool".to_string()),
925 parameters,
926 requires_approval: false,
927 category: ToolCategory::Mcp,
928 },
929 client,
930 tool_name: mcp_tool.name,
931 }
932 }
933
934 fn convert_schema(schema: &serde_json::Value) -> JsonSchema {
936 if let Some(obj) = schema.as_object() {
937 let schema_type = obj
938 .get("type")
939 .and_then(|v| v.as_str())
940 .unwrap_or("object")
941 .to_string();
942
943 let description = obj
944 .get("description")
945 .and_then(|v| v.as_str())
946 .map(|s| s.to_string());
947
948 let properties = obj
949 .get("properties")
950 .and_then(|v| v.as_object())
951 .map(|props| {
952 props
953 .iter()
954 .map(|(k, v)| (k.clone(), Self::convert_schema(v)))
955 .collect::<HashMap<String, JsonSchema>>()
956 });
957
958 let required = obj.get("required").and_then(|v| v.as_array()).map(|arr| {
959 arr.iter()
960 .filter_map(|v| v.as_str())
961 .map(|s| s.to_string())
962 .collect()
963 });
964
965 JsonSchema {
966 schema_type,
967 description,
968 properties,
969 required,
970 items: None,
971 enum_values: None,
972 }
973 } else {
974 JsonSchema::string("MCP tool parameter")
975 }
976 }
977}
978
979#[async_trait::async_trait]
980impl ToolImpl for McpToolWrapper {
981 fn definition(&self) -> &ToolDefinition {
982 &self.definition
983 }
984
985 async fn execute(&self, args: serde_json::Value) -> ToolResultValue<ToolResult> {
986 let mut client = self.client.write().await;
987
988 let result = client
989 .call_tool(&self.tool_name, Some(args))
990 .await
991 .map_err(|e| {
992 crate::tools::ToolError::ExecutionFailed(self.tool_name.clone(), e.to_string())
993 })?;
994
995 let output = result
997 .content
998 .iter()
999 .map(|c| match c {
1000 McpContent::Text { text } => text.clone(),
1001 McpContent::Image { data, mime_type } => {
1002 format!("[Image: {} bytes, {}]", data.len(), mime_type)
1003 }
1004 McpContent::Resource { resource } => {
1005 format!("[Resource: {}]", resource)
1006 }
1007 })
1008 .collect::<Vec<_>>()
1009 .join("\n");
1010
1011 Ok(ToolResult {
1012 tool_name: self.tool_name.clone(),
1013 success: !result.is_error,
1014 output,
1015 error: if result.is_error {
1016 Some("Tool returned error".to_string())
1017 } else {
1018 None
1019 },
1020 exit_code: None,
1021 duration_ms: None,
1022 })
1023 }
1024}
1025
1026pub async fn register_mcp_tools(
1030 registry: &mut crate::tools::ToolRegistry,
1031 client: Arc<RwLock<McpClient>>,
1032) -> McpResult<usize> {
1033 let mcp_client = client.read().await;
1034 let mcp_tools = mcp_client.get_tools().await;
1035 drop(mcp_client);
1036
1037 let count = mcp_tools.len();
1038
1039 for mcp_tool in mcp_tools {
1040 let wrapper = McpToolWrapper::new(client.clone(), mcp_tool);
1041 registry.register(Arc::new(wrapper));
1042 }
1043
1044 info!(count, "Registered MCP tools");
1045 Ok(count)
1046}
1047
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_json_rpc_request() {
        let req = JsonRpcRequest::new("tools/list", serde_json::Value::Null, 1);
        assert_eq!(req.jsonrpc, "2.0");
        assert_eq!(req.method, "tools/list");
        assert_eq!(req.id, serde_json::Value::Number(1.into()));
    }

    #[test]
    fn test_mcp_tool_serialization() {
        let tool = McpTool {
            name: "test_tool".to_string(),
            description: Some("A test tool".to_string()),
            input_schema: serde_json::json!({
                "type": "object",
                "properties": {
                    "query": {"type": "string"}
                }
            }),
        };

        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains("test_tool"));
        assert!(json.contains("A test tool"));
    }

    #[test]
    fn test_initialize_params() {
        let params = InitializeParams {
            protocol_version: MCP_PROTOCOL_VERSION.to_string(),
            capabilities: McpClientCapabilities {
                roots: Some(McpRootsCapability {
                    list_changed: false,
                }),
                sampling: None,
            },
            client_info: McpClientInfo {
                name: "ravenclaws".to_string(),
                version: "0.5.2".to_string(),
            },
        };

        // Was a mis-encoded `¶ms` (an HTML `&para;` artifact), which did not compile.
        let json = serde_json::to_string(&params).unwrap();
        assert!(json.contains("protocolVersion"));
        assert!(json.contains("ravenclaws"));
    }
}
1099
1100pub struct McpServer {
1117 registry: crate::tools::ToolRegistry,
1119 policy_engine: crate::policy::PolicyEngine,
1121 sandbox: crate::sandbox::Sandbox,
1123 audit_log: crate::audit::AuditLog,
1125 initialized: bool,
1127 server_info: McpServerInfo,
1129 request_id: i64,
1131}
1132
1133impl McpServer {
1134 pub fn new(registry: crate::tools::ToolRegistry) -> Self {
1138 let server_info = McpServerInfo {
1139 name: "ravenclaws".to_string(),
1140 version: env!("CARGO_PKG_VERSION").to_string(),
1141 };
1142
1143 Self {
1144 registry,
1145 policy_engine: crate::policy::PolicyEngine::default_secure(),
1146 sandbox: crate::sandbox::Sandbox::default(),
1147 audit_log: crate::audit::AuditLog::new(format!("mcp-server-{}", std::process::id())),
1148 initialized: false,
1149 server_info,
1150 request_id: 0,
1151 }
1152 }
1153
1154 pub async fn run(&mut self) -> Result<(), McpError> {
1157 self.sandbox
1159 .init()
1160 .await
1161 .map_err(|e| McpError::Transport(format!("Sandbox init failed: {}", e)))?;
1162
1163 info!("MCP server starting on stdio");
1164
1165 let stdin = tokio::io::stdin();
1166 let reader = BufReader::new(stdin);
1167 let mut lines = reader.lines();
1168
1169 while let Ok(Some(line)) = lines.next_line().await {
1170 let line = line.trim().to_string();
1171 if line.is_empty() {
1172 continue;
1173 }
1174
1175 debug!("MCP Server ← {}", &line);
1176
1177 let request: JsonRpcRequest = match serde_json::from_str(&line) {
1179 Ok(req) => req,
1180 Err(e) => {
1181 let error_response = serde_json::json!({
1182 "jsonrpc": "2.0",
1183 "error": {
1184 "code": -32700,
1185 "message": "Parse error",
1186 "data": e.to_string()
1187 },
1188 "id": serde_json::Value::Null
1189 });
1190 let _ = self.write_response(&error_response).await;
1191 continue;
1192 }
1193 };
1194
1195 let response = self.handle_request(&request).await;
1196 let _ = self.write_response(&response).await;
1197 }
1198
1199 info!("MCP server shutting down (stdin closed)");
1200 Ok(())
1201 }
1202
1203 async fn handle_request(&mut self, request: &JsonRpcRequest) -> serde_json::Value {
1205 let request_id = request.id.clone();
1206
1207 match request.method.as_str() {
1208 "initialize" => self.handle_initialize(request, &request_id).await,
1209 "notifications/initialized" => {
1210 self.initialized = true;
1211 info!("MCP server initialized by client");
1212 serde_json::json!({
1213 "jsonrpc": "2.0",
1214 "result": null,
1215 "id": request_id
1216 })
1217 }
1218 "tools/list" => self.handle_tools_list(&request_id).await,
1219 "tools/call" => self.handle_tools_call(request, &request_id).await,
1220 _ => {
1221 serde_json::json!({
1222 "jsonrpc": "2.0",
1223 "error": {
1224 "code": -32601,
1225 "message": format!("Method not found: {}", request.method)
1226 },
1227 "id": request_id
1228 })
1229 }
1230 }
1231 }
1232
1233 async fn handle_initialize(
1235 &mut self,
1236 request: &JsonRpcRequest,
1237 request_id: &serde_json::Value,
1238 ) -> serde_json::Value {
1239 if let Some(params) = request.params.as_ref().and_then(|p| p.as_object()) {
1241 if let Some(client_info) = params.get("clientInfo") {
1242 info!(
1243 client = ?client_info.get("name").and_then(|v| v.as_str()).unwrap_or("unknown"),
1244 "MCP client connected"
1245 );
1246 }
1247 }
1248
1249 let capabilities = McpServerCapabilities {
1250 tools: Some(McpToolsCapability {
1251 list_changed: false,
1252 }),
1253 resources: None,
1254 prompts: None,
1255 };
1256
1257 let result = serde_json::json!({
1258 "protocolVersion": MCP_PROTOCOL_VERSION,
1259 "capabilities": capabilities,
1260 "serverInfo": {
1261 "name": self.server_info.name,
1262 "version": self.server_info.version
1263 }
1264 });
1265
1266 serde_json::json!({
1267 "jsonrpc": "2.0",
1268 "result": result,
1269 "id": request_id
1270 })
1271 }
1272
1273 async fn handle_tools_list(&self, request_id: &serde_json::Value) -> serde_json::Value {
1275 let tools: Vec<serde_json::Value> = self
1276 .registry
1277 .definitions()
1278 .iter()
1279 .map(|def| {
1280 serde_json::json!({
1281 "name": def.name,
1282 "description": def.description,
1283 "inputSchema": def.parameters
1284 })
1285 })
1286 .collect();
1287
1288 serde_json::json!({
1289 "jsonrpc": "2.0",
1290 "result": {
1291 "tools": tools
1292 },
1293 "id": request_id
1294 })
1295 }
1296
1297 async fn handle_tools_call(
1299 &mut self,
1300 request: &JsonRpcRequest,
1301 request_id: &serde_json::Value,
1302 ) -> serde_json::Value {
1303 let params = request.params.as_ref().unwrap_or(&serde_json::Value::Null);
1304
1305 let name = params
1306 .get("name")
1307 .and_then(|v| v.as_str())
1308 .unwrap_or("")
1309 .to_string();
1310
1311 let arguments = params
1312 .get("arguments")
1313 .cloned()
1314 .unwrap_or(serde_json::Value::Null);
1315
1316 if name.is_empty() {
1317 return serde_json::json!({
1318 "jsonrpc": "2.0",
1319 "error": {
1320 "code": -32602,
1321 "message": "Invalid params: missing tool name"
1322 },
1323 "id": request_id
1324 });
1325 }
1326
1327 let policy_decision = self.policy_engine.check_tool_call(&name, &arguments);
1329 match policy_decision {
1330 crate::policy::Decision::Deny(reason) => {
1331 warn!(tool = %name, reason = %reason, "MCP tool call denied by policy");
1332 return serde_json::json!({
1333 "jsonrpc": "2.0",
1334 "result": {
1335 "content": [{
1336 "type": "text",
1337 "text": format!("Policy denied: {}", reason)
1338 }],
1339 "isError": true
1340 },
1341 "id": request_id
1342 });
1343 }
1344 crate::policy::Decision::Allow => {
1345 let _ = self.audit_log.tool_call(&name, &arguments);
1347 }
1348 }
1349
1350 let call = crate::tools::ToolCall {
1352 name: name.clone(),
1353 arguments,
1354 id: None,
1355 };
1356
1357 match self.registry.execute(call).await {
1358 Ok(result) => {
1359 let _ = self.audit_log.append(
1361 crate::audit::AuditEventType::ToolResult,
1362 &name,
1363 &format!("MCP tool executed: {} (success: {})", name, result.success),
1364 Some(serde_json::json!({
1365 "success": result.success,
1366 "exit_code": result.exit_code,
1367 "duration_ms": result.duration_ms,
1368 })),
1369 );
1370
1371 let content = if result.success {
1372 vec![serde_json::json!({
1373 "type": "text",
1374 "text": result.output
1375 })]
1376 } else {
1377 vec![serde_json::json!({
1378 "type": "text",
1379 "text": result.error.as_deref().unwrap_or("Unknown error")
1380 })]
1381 };
1382
1383 serde_json::json!({
1384 "jsonrpc": "2.0",
1385 "result": {
1386 "content": content,
1387 "isError": !result.success
1388 },
1389 "id": request_id
1390 })
1391 }
1392 Err(e) => {
1393 warn!(tool = %name, error = %e, "MCP tool execution failed");
1394 serde_json::json!({
1395 "jsonrpc": "2.0",
1396 "result": {
1397 "content": [{
1398 "type": "text",
1399 "text": format!("Tool execution failed: {}", e)
1400 }],
1401 "isError": true
1402 },
1403 "id": request_id
1404 })
1405 }
1406 }
1407 }
1408
1409 async fn write_response(&self, response: &serde_json::Value) -> std::io::Result<()> {
1411 let json = serde_json::to_string(response)?;
1412 debug!("MCP Server → {}", &json);
1413 use tokio::io::AsyncWriteExt;
1414 let mut stdout = tokio::io::stdout();
1415 stdout.write_all(json.as_bytes()).await?;
1416 stdout.write_all(b"\n").await?;
1417 stdout.flush().await?;
1418 Ok(())
1419 }
1420
1421 #[allow(dead_code)]
1423 fn next_id(&mut self) -> i64 {
1424 self.request_id += 1;
1425 self.request_id
1426 }
1427}
1428
#[cfg(test)]
mod server_tests {
    use super::*;
    use crate::tools::ToolRegistry;

    #[test]
    fn test_mcp_server_initialize_response() {
        let registry = ToolRegistry::with_default_tools();
        let server = McpServer::new(registry);

        assert_eq!(server.server_info.name, "ravenclaws");
        assert!(!server.server_info.version.is_empty());
        assert!(!server.initialized);
    }

    #[test]
    fn test_mcp_server_tools_list_response() {
        let registry = ToolRegistry::with_default_tools();
        let server = McpServer::new(registry);

        let defs = server.registry.definitions();
        let names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        assert!(names.contains(&"shell_exec"));
        assert!(names.contains(&"read_file"));
        assert!(names.contains(&"write_file"));
        assert!(names.contains(&"web_fetch"));
        assert!(names.contains(&"web_search"));
        assert!(names.contains(&"browser"));
        assert_eq!(defs.len(), 6);
    }

    #[tokio::test]
    async fn test_mcp_server_handle_unknown_method() {
        let registry = ToolRegistry::with_default_tools();
        let mut server = McpServer::new(registry);

        let request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "unknown_method".to_string(),
            params: Some(serde_json::Value::Null),
            id: serde_json::Value::Number(1.into()),
        };

        let response = server.handle_request(&request).await;
        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["code"],
            serde_json::Value::Number((-32601).into())
        );
    }

    #[tokio::test]
    async fn test_mcp_server_handle_tools_list() {
        let registry = ToolRegistry::with_default_tools();
        let server = McpServer::new(registry);

        let request_id = serde_json::Value::Number(1.into());
        let response = server.handle_tools_list(&request_id).await;

        assert!(response.get("result").is_some());
        let tools = &response["result"]["tools"];
        assert!(tools.is_array());
        assert!(!tools.as_array().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_mcp_server_handle_tools_call_missing_name() {
        let registry = ToolRegistry::with_default_tools();
        let mut server = McpServer::new(registry);

        let request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "tools/call".to_string(),
            params: Some(serde_json::json!({})),
            id: serde_json::Value::Number(1.into()),
        };

        let request_id = serde_json::Value::Number(1.into());
        let response = server.handle_tools_call(&request, &request_id).await;

        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["code"],
            serde_json::Value::Number((-32602).into())
        );
    }

    #[tokio::test]
    async fn test_mcp_server_handle_tools_call_unknown_tool() {
        let registry = ToolRegistry::with_default_tools();
        let mut server = McpServer::new(registry);

        let request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "tools/call".to_string(),
            params: Some(serde_json::json!({
                "name": "nonexistent_tool",
                "arguments": {}
            })),
            id: serde_json::Value::Number(1.into()),
        };

        let request_id = serde_json::Value::Number(1.into());
        let response = server.handle_tools_call(&request, &request_id).await;

        assert!(response["result"]["isError"].as_bool().unwrap_or(false));
    }

    /// The server must emit the standard JSON-RPC 2.0 error codes, not merely
    /// know their values: -32601 for unknown methods and -32602 for a
    /// `tools/call` without a tool name, echoing the caller's id.
    #[tokio::test]
    async fn test_mcp_server_json_rpc_error_codes() {
        let mut server = McpServer::new(ToolRegistry::with_default_tools());

        let unknown = JsonRpcRequest::new("no/such/method", serde_json::json!({}), 7);
        let response = server.handle_request(&unknown).await;
        assert_eq!(response["error"]["code"], -32601);

        let nameless =
            JsonRpcRequest::new("tools/call", serde_json::json!({ "arguments": {} }), 8);
        let response = server.handle_tools_call(&nameless, &nameless.id).await;
        assert_eq!(response["error"]["code"], -32602);
        assert_eq!(response["id"], 8);
    }
}
1552
1553pub struct McpSseServer {
1569 registry: crate::tools::ToolRegistry,
1570 policy_engine: crate::policy::PolicyEngine,
1571 sandbox: crate::sandbox::Sandbox,
1572 audit_log: crate::audit::AuditLog,
1573 server_info: McpServerInfo,
1574 clients: Arc<tokio::sync::RwLock<HashMap<String, tokio::sync::mpsc::UnboundedSender<String>>>>,
1576 host: String,
1577 port: u16,
1578}
1579
1580impl McpSseServer {
1581 pub fn new(registry: crate::tools::ToolRegistry, host: String, port: u16) -> Self {
1583 let server_info = McpServerInfo {
1584 name: "ravenclaws".to_string(),
1585 version: env!("CARGO_PKG_VERSION").to_string(),
1586 };
1587
1588 Self {
1589 registry,
1590 policy_engine: crate::policy::PolicyEngine::default_secure(),
1591 sandbox: crate::sandbox::Sandbox::default(),
1592 audit_log: crate::audit::AuditLog::new(format!("mcp-sse-{}", std::process::id())),
1593 server_info,
1594 clients: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
1595 host,
1596 port,
1597 }
1598 }
1599
1600 pub async fn run(
1602 &mut self,
1603 shutdown: tokio::sync::watch::Receiver<bool>,
1604 ) -> Result<(), McpError> {
1605 self.sandbox
1607 .init()
1608 .await
1609 .map_err(|e| McpError::Transport(format!("Sandbox init failed: {}", e)))?;
1610
1611 let addr: std::net::SocketAddr = format!("{}:{}", self.host, self.port)
1612 .parse()
1613 .map_err(|e| McpError::Transport(format!("Invalid address: {}", e)))?;
1614
1615 info!(addr = %addr, "MCP SSE server starting");
1616
1617 let listener = tokio::net::TcpListener::bind(addr)
1620 .await
1621 .map_err(|e| McpError::Transport(format!("Failed to bind: {}", e)))?;
1622
1623 let clients = self.clients.clone();
1624 let registry = Arc::new(tokio::sync::RwLock::new(std::mem::replace(
1625 &mut self.registry,
1626 crate::tools::ToolRegistry::new(),
1627 )));
1628 let policy_engine = Arc::new(tokio::sync::RwLock::new(std::mem::replace(
1629 &mut self.policy_engine,
1630 crate::policy::PolicyEngine::default_secure(),
1631 )));
1632 let sandbox = Arc::new(tokio::sync::RwLock::new(std::mem::take(&mut self.sandbox)));
1633 let audit_log = Arc::new(tokio::sync::RwLock::new(std::mem::replace(
1634 &mut self.audit_log,
1635 crate::audit::AuditLog::new(format!("mcp-sse-{}", std::process::id())),
1636 )));
1637 let server_info = Arc::new(self.server_info.clone());
1638
1639 let mut shutdown = shutdown;
1641
1642 loop {
1643 tokio::select! {
1644 accept_result = listener.accept() => {
1645 match accept_result {
1646 Ok((stream, peer_addr)) => {
1647 let clients = clients.clone();
1648 let registry = registry.clone();
1649 let policy_engine = policy_engine.clone();
1650 let sandbox = sandbox.clone();
1651 let audit_log = audit_log.clone();
1652 let server_info = server_info.clone();
1653
1654 tokio::spawn(async move {
1655 if let Err(e) = Self::handle_connection(
1656 stream, peer_addr, clients, registry,
1657 policy_engine, sandbox, audit_log, server_info,
1658 ).await {
1659 warn!(peer = %peer_addr, error = %e, "MCP SSE connection error");
1660 }
1661 });
1662 }
1663 Err(e) => {
1664 warn!("Accept error: {}", e);
1665 }
1666 }
1667 }
1668 _ = shutdown.changed() => {
1669 info!("MCP SSE server shutting down");
1670 break;
1671 }
1672 }
1673 }
1674
1675 Ok(())
1676 }
1677
1678 #[allow(clippy::too_many_arguments)]
1680 async fn handle_connection(
1681 mut stream: tokio::net::TcpStream,
1682 peer_addr: std::net::SocketAddr,
1683 clients: Arc<
1684 tokio::sync::RwLock<HashMap<String, tokio::sync::mpsc::UnboundedSender<String>>>,
1685 >,
1686 registry: Arc<tokio::sync::RwLock<crate::tools::ToolRegistry>>,
1687 policy_engine: Arc<tokio::sync::RwLock<crate::policy::PolicyEngine>>,
1688 sandbox: Arc<tokio::sync::RwLock<crate::sandbox::Sandbox>>,
1689 audit_log: Arc<tokio::sync::RwLock<crate::audit::AuditLog>>,
1690 server_info: Arc<McpServerInfo>,
1691 ) -> Result<(), McpError> {
1692 use tokio::io::AsyncReadExt;
1693
1694 let mut buf = [0u8; 8192];
1695 let n = stream
1696 .read(&mut buf)
1697 .await
1698 .map_err(|e| McpError::Transport(format!("Read error: {}", e)))?;
1699
1700 if n == 0 {
1701 return Ok(());
1702 }
1703
1704 let request = String::from_utf8_lossy(&buf[..n]).to_string();
1705
1706 let (method, path) = if let Some(first_line) = request.lines().next() {
1708 let parts: Vec<&str> = first_line.split_whitespace().collect();
1709 if parts.len() < 2 {
1710 return Err(McpError::Transport("Invalid HTTP request".to_string()));
1711 }
1712 (parts[0].to_string(), parts[1].to_string())
1713 } else {
1714 return Err(McpError::Transport("Empty HTTP request".to_string()));
1715 };
1716
1717 match (method.as_str(), path.as_str()) {
1718 ("GET", "/sse") => {
1719 Self::handle_sse_connection(
1720 stream,
1721 peer_addr,
1722 clients,
1723 registry,
1724 policy_engine,
1725 sandbox,
1726 audit_log,
1727 server_info,
1728 )
1729 .await
1730 }
1731 ("POST", "/message") => {
1732 let body = if let Some(body_start) = request.find("\r\n\r\n") {
1734 request[body_start + 4..].to_string()
1735 } else {
1736 return Err(McpError::Transport("No body in POST request".to_string()));
1737 };
1738
1739 Self::handle_message_post(
1740 stream,
1741 &body,
1742 ®istry,
1743 &policy_engine,
1744 &sandbox,
1745 &audit_log,
1746 &server_info,
1747 clients,
1748 )
1749 .await
1750 }
1751 _ => {
1752 let response = "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n";
1754 stream
1755 .write_all(response.as_bytes())
1756 .await
1757 .map_err(|e| McpError::Transport(format!("Write error: {}", e)))?;
1758 Ok(())
1759 }
1760 }
1761 }
1762
1763 #[allow(clippy::too_many_arguments)]
1765 async fn handle_sse_connection(
1766 mut stream: tokio::net::TcpStream,
1767 peer_addr: std::net::SocketAddr,
1768 clients: Arc<
1769 tokio::sync::RwLock<HashMap<String, tokio::sync::mpsc::UnboundedSender<String>>>,
1770 >,
1771 _registry: Arc<tokio::sync::RwLock<crate::tools::ToolRegistry>>,
1772 _policy_engine: Arc<tokio::sync::RwLock<crate::policy::PolicyEngine>>,
1773 _sandbox: Arc<tokio::sync::RwLock<crate::sandbox::Sandbox>>,
1774 _audit_log: Arc<tokio::sync::RwLock<crate::audit::AuditLog>>,
1775 _server_info: Arc<McpServerInfo>,
1776 ) -> Result<(), McpError> {
1777 use tokio::io::AsyncWriteExt;
1778
1779 let client_id = Uuid::new_v4().to_string();
1780 info!(client = %client_id, peer = %peer_addr, "MCP SSE client connected");
1781
1782 let headers = "HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nCache-Control: no-cache\r\nConnection: keep-alive\r\nAccess-Control-Allow-Origin: *\r\n\r\n";
1784 stream
1785 .write_all(headers.as_bytes())
1786 .await
1787 .map_err(|e| McpError::Transport(format!("Write error: {}", e)))?;
1788 stream
1789 .flush()
1790 .await
1791 .map_err(|e| McpError::Transport(format!("Flush error: {}", e)))?;
1792
1793 let endpoint_event = "event: endpoint\ndata: /message\n\n".to_string();
1795 stream
1796 .write_all(endpoint_event.as_bytes())
1797 .await
1798 .map_err(|e| McpError::Transport(format!("Write error: {}", e)))?;
1799 stream
1800 .flush()
1801 .await
1802 .map_err(|e| McpError::Transport(format!("Flush error: {}", e)))?;
1803
1804 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1806
1807 clients.write().await.insert(client_id.clone(), tx);
1809
1810 loop {
1812 tokio::select! {
1813 msg = rx.recv() => {
1814 match msg {
1815 Some(data) => {
1816 let sse_event = format!("data: {}\n\n", data);
1817 if stream.write_all(sse_event.as_bytes()).await.is_err() {
1818 break;
1819 }
1820 if stream.flush().await.is_err() {
1821 break;
1822 }
1823 }
1824 None => break,
1825 }
1826 }
1827 }
1828 }
1829
1830 clients.write().await.remove(&client_id);
1832 info!(client = %client_id, "MCP SSE client disconnected");
1833 Ok(())
1834 }
1835
1836 #[allow(clippy::too_many_arguments)]
1838 async fn handle_message_post(
1839 mut stream: tokio::net::TcpStream,
1840 body: &str,
1841 registry: &Arc<tokio::sync::RwLock<crate::tools::ToolRegistry>>,
1842 policy_engine: &Arc<tokio::sync::RwLock<crate::policy::PolicyEngine>>,
1843 sandbox: &Arc<tokio::sync::RwLock<crate::sandbox::Sandbox>>,
1844 audit_log: &Arc<tokio::sync::RwLock<crate::audit::AuditLog>>,
1845 server_info: &Arc<McpServerInfo>,
1846 clients: Arc<
1847 tokio::sync::RwLock<HashMap<String, tokio::sync::mpsc::UnboundedSender<String>>>,
1848 >,
1849 ) -> Result<(), McpError> {
1850 use tokio::io::AsyncWriteExt;
1851
1852 let request: JsonRpcRequest = match serde_json::from_str(body) {
1854 Ok(req) => req,
1855 Err(e) => {
1856 let error_response = serde_json::json!({
1857 "jsonrpc": "2.0",
1858 "error": {
1859 "code": -32700,
1860 "message": "Parse error",
1861 "data": e.to_string()
1862 },
1863 "id": serde_json::Value::Null
1864 });
1865 let response_body = serde_json::to_string(&error_response)?;
1866 let http_response = format!(
1867 "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
1868 response_body.len(),
1869 response_body
1870 );
1871 stream.write_all(http_response.as_bytes()).await?;
1872 return Ok(());
1873 }
1874 };
1875
1876 let request_id = request.id.clone();
1878 let response = Self::handle_jsonrpc_request(
1879 &request,
1880 &request_id,
1881 registry,
1882 policy_engine,
1883 sandbox,
1884 audit_log,
1885 server_info,
1886 )
1887 .await;
1888
1889 let response_body = serde_json::to_string(&response)?;
1890
1891 let http_response = format!(
1893 "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\n\r\n{}",
1894 response_body.len(),
1895 response_body
1896 );
1897 stream.write_all(http_response.as_bytes()).await?;
1898 stream.flush().await?;
1899
1900 let response_json = serde_json::to_string(&response)?;
1902 let clients_guard = clients.read().await;
1903 for (_, tx) in clients_guard.iter() {
1904 let _ = tx.send(response_json.clone());
1905 }
1906
1907 Ok(())
1908 }
1909
1910 async fn handle_jsonrpc_request(
1912 request: &JsonRpcRequest,
1913 request_id: &serde_json::Value,
1914 registry: &Arc<tokio::sync::RwLock<crate::tools::ToolRegistry>>,
1915 policy_engine: &Arc<tokio::sync::RwLock<crate::policy::PolicyEngine>>,
1916 _sandbox: &Arc<tokio::sync::RwLock<crate::sandbox::Sandbox>>,
1917 audit_log: &Arc<tokio::sync::RwLock<crate::audit::AuditLog>>,
1918 server_info: &Arc<McpServerInfo>,
1919 ) -> serde_json::Value {
1920 match request.method.as_str() {
1921 "initialize" => {
1922 if let Some(params) = request.params.as_ref().and_then(|p| p.as_object()) {
1924 if let Some(client_info) = params.get("clientInfo") {
1925 info!(
1926 client = ?client_info.get("name").and_then(|v| v.as_str()).unwrap_or("unknown"),
1927 "MCP SSE client initialized"
1928 );
1929 }
1930 }
1931
1932 let capabilities = serde_json::json!({
1933 "protocolVersion": "2024-11-05",
1934 "capabilities": {
1935 "tools": {
1936 "listChanged": false
1937 }
1938 },
1939 "serverInfo": {
1940 "name": server_info.name,
1941 "version": server_info.version
1942 }
1943 });
1944
1945 serde_json::json!({
1946 "jsonrpc": "2.0",
1947 "result": capabilities,
1948 "id": request_id
1949 })
1950 }
1951 "notifications/initialized" => {
1952 info!("MCP SSE client initialized notification received");
1953 serde_json::json!({
1954 "jsonrpc": "2.0",
1955 "result": null,
1956 "id": request_id
1957 })
1958 }
1959 "tools/list" => {
1960 let defs = registry.read().await.definitions().clone();
1961 let tools: Vec<serde_json::Value> = defs
1962 .iter()
1963 .map(|def| {
1964 serde_json::json!({
1965 "name": def.name,
1966 "description": def.description,
1967 "inputSchema": def.parameters
1968 })
1969 })
1970 .collect();
1971
1972 serde_json::json!({
1973 "jsonrpc": "2.0",
1974 "result": {
1975 "tools": tools
1976 },
1977 "id": request_id
1978 })
1979 }
1980 "tools/call" => {
1981 let params = request.params.as_ref().unwrap_or(&serde_json::Value::Null);
1982
1983 let name = params
1984 .get("name")
1985 .and_then(|v| v.as_str())
1986 .unwrap_or("")
1987 .to_string();
1988
1989 let arguments = params
1990 .get("arguments")
1991 .cloned()
1992 .unwrap_or(serde_json::Value::Null);
1993
1994 if name.is_empty() {
1995 return serde_json::json!({
1996 "jsonrpc": "2.0",
1997 "error": {
1998 "code": -32602,
1999 "message": "Invalid params: missing tool name"
2000 },
2001 "id": request_id
2002 });
2003 }
2004
2005 let decision = policy_engine
2007 .read()
2008 .await
2009 .check_tool_call(&name, &arguments);
2010 match decision {
2011 crate::policy::Decision::Deny(reason) => {
2012 warn!(tool = %name, reason = %reason, "MCP SSE tool call denied by policy");
2013 return serde_json::json!({
2014 "jsonrpc": "2.0",
2015 "result": {
2016 "content": [{
2017 "type": "text",
2018 "text": format!("Policy denied: {}", reason)
2019 }],
2020 "isError": true
2021 },
2022 "id": request_id
2023 });
2024 }
2025 crate::policy::Decision::Allow => {
2026 let _ = audit_log.write().await.tool_call(&name, &arguments);
2027 }
2028 }
2029
2030 let call = crate::tools::ToolCall {
2032 name: name.clone(),
2033 arguments,
2034 id: None,
2035 };
2036
2037 match registry.read().await.execute(call).await {
2038 Ok(result) => {
2039 let _ = audit_log.write().await.append(
2040 crate::audit::AuditEventType::ToolResult,
2041 &name,
2042 &format!(
2043 "MCP SSE tool executed: {} (success: {})",
2044 name, result.success
2045 ),
2046 Some(serde_json::json!({
2047 "success": result.success,
2048 "exit_code": result.exit_code,
2049 "duration_ms": result.duration_ms,
2050 })),
2051 );
2052
2053 let content = if result.success {
2054 vec![serde_json::json!({
2055 "type": "text",
2056 "text": result.output
2057 })]
2058 } else {
2059 vec![serde_json::json!({
2060 "type": "text",
2061 "text": result.error.as_deref().unwrap_or("Unknown error")
2062 })]
2063 };
2064
2065 serde_json::json!({
2066 "jsonrpc": "2.0",
2067 "result": {
2068 "content": content,
2069 "isError": !result.success
2070 },
2071 "id": request_id
2072 })
2073 }
2074 Err(e) => {
2075 warn!(tool = %name, error = %e, "MCP SSE tool execution failed");
2076 serde_json::json!({
2077 "jsonrpc": "2.0",
2078 "result": {
2079 "content": [{
2080 "type": "text",
2081 "text": format!("Tool execution failed: {}", e)
2082 }],
2083 "isError": true
2084 },
2085 "id": request_id
2086 })
2087 }
2088 }
2089 }
2090 _ => {
2091 serde_json::json!({
2092 "jsonrpc": "2.0",
2093 "error": {
2094 "code": -32601,
2095 "message": format!("Method not found: {}", request.method)
2096 },
2097 "id": request_id
2098 })
2099 }
2100 }
2101 }
2102}
2103
#[cfg(test)]
mod sse_server_tests {
    use super::*;
    use crate::tools::ToolRegistry;

    /// Shared state for driving `McpSseServer::handle_jsonrpc_request`
    /// directly, without a network listener.
    struct Harness {
        registry: Arc<tokio::sync::RwLock<ToolRegistry>>,
        policy: Arc<tokio::sync::RwLock<crate::policy::PolicyEngine>>,
        sandbox: Arc<tokio::sync::RwLock<crate::sandbox::Sandbox>>,
        audit: Arc<tokio::sync::RwLock<crate::audit::AuditLog>>,
        server_info: Arc<McpServerInfo>,
    }

    impl Harness {
        fn new() -> Self {
            Self {
                registry: Arc::new(tokio::sync::RwLock::new(
                    ToolRegistry::with_default_tools(),
                )),
                policy: Arc::new(tokio::sync::RwLock::new(
                    crate::policy::PolicyEngine::default_secure(),
                )),
                sandbox: Arc::new(tokio::sync::RwLock::new(
                    crate::sandbox::Sandbox::default(),
                )),
                audit: Arc::new(tokio::sync::RwLock::new(crate::audit::AuditLog::new(
                    "test".to_string(),
                ))),
                server_info: Arc::new(McpServerInfo {
                    name: "ravenclaws".to_string(),
                    version: env!("CARGO_PKG_VERSION").to_string(),
                }),
            }
        }

        /// Dispatches `request`, echoing its own id.
        async fn dispatch(&self, request: &JsonRpcRequest) -> serde_json::Value {
            McpSseServer::handle_jsonrpc_request(
                request,
                &request.id,
                &self.registry,
                &self.policy,
                &self.sandbox,
                &self.audit,
                &self.server_info,
            )
            .await
        }
    }

    #[test]
    fn test_mcp_sse_server_new() {
        let registry = ToolRegistry::with_default_tools();
        let server = McpSseServer::new(registry, "127.0.0.1".to_string(), 9091);

        assert_eq!(server.host, "127.0.0.1");
        assert_eq!(server.port, 9091);
        assert_eq!(server.server_info.name, "ravenclaws");
        assert!(server.clients.blocking_read().is_empty());
    }

    #[test]
    fn test_mcp_sse_server_info() {
        let registry = ToolRegistry::with_default_tools();
        let server = McpSseServer::new(registry, "0.0.0.0".to_string(), 9092);

        assert_eq!(server.server_info.name, "ravenclaws");
        assert!(!server.server_info.version.is_empty());
    }

    #[tokio::test]
    async fn test_mcp_sse_handle_initialize() {
        let harness = Harness::new();
        let request = JsonRpcRequest::new(
            "initialize",
            serde_json::json!({
                "protocolVersion": "2024-11-05",
                "clientInfo": {
                    "name": "test-client",
                    "version": "1.0.0"
                }
            }),
            1,
        );

        let response = harness.dispatch(&request).await;

        assert!(response.get("result").is_some());
        assert_eq!(response["result"]["serverInfo"]["name"], "ravenclaws");
        assert_eq!(response["result"]["protocolVersion"], "2024-11-05");
        assert_eq!(response["id"], 1);
    }

    #[tokio::test]
    async fn test_mcp_sse_handle_initialized_notification() {
        let harness = Harness::new();
        let request =
            JsonRpcRequest::new("notifications/initialized", serde_json::Value::Null, 3);

        let response = harness.dispatch(&request).await;

        assert!(response.get("error").is_none());
        assert!(response["result"].is_null());
        assert_eq!(response["id"], 3);
    }

    #[tokio::test]
    async fn test_mcp_sse_handle_tools_list() {
        let harness = Harness::new();
        let request = JsonRpcRequest::new("tools/list", serde_json::Value::Null, 1);

        let response = harness.dispatch(&request).await;

        assert!(response.get("result").is_some());
        let tools = &response["result"]["tools"];
        assert!(tools.is_array());
        assert!(!tools.as_array().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_mcp_sse_handle_unknown_method() {
        let harness = Harness::new();
        let request = JsonRpcRequest::new("unknown_method", serde_json::Value::Null, 1);

        let response = harness.dispatch(&request).await;

        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["code"],
            serde_json::Value::Number((-32601).into())
        );
    }

    #[tokio::test]
    async fn test_mcp_sse_handle_tools_call_missing_name() {
        let harness = Harness::new();
        let request = JsonRpcRequest::new("tools/call", serde_json::json!({}), 1);

        let response = harness.dispatch(&request).await;

        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["code"],
            serde_json::Value::Number((-32602).into())
        );
    }

    #[tokio::test]
    async fn test_mcp_sse_handle_tools_call_unknown_tool() {
        let harness = Harness::new();
        let request = JsonRpcRequest::new(
            "tools/call",
            serde_json::json!({
                "name": "nonexistent_tool",
                "arguments": {}
            }),
            1,
        );

        let response = harness.dispatch(&request).await;

        assert!(response.get("error").is_none());
        assert_eq!(response["result"]["isError"], true);
    }

    #[tokio::test]
    async fn test_mcp_sse_transport_config_serde() {
        let config = McpTransportConfig::Sse {
            url: "http://localhost:9090/sse".to_string(),
        };

        match config {
            McpTransportConfig::Sse { url } => {
                assert_eq!(url, "http://localhost:9090/sse");
            }
            _ => panic!("Expected SSE variant"),
        }
    }

    #[tokio::test]
    async fn test_mcp_sse_transport_connect_failure() {
        let config = McpTransportConfig::Sse {
            url: "http://127.0.0.1:1/sse".to_string(),
        };

        let mut transport = McpTransport::new(config);
        let result = transport.connect().await;

        assert!(
            matches!(
                result,
                Err(McpError::ConnectionFailed(_)) | Err(McpError::Transport(_))
            ),
            "Expected connection or transport error, got {:?}",
            result
        );
    }
}