1use futures::StreamExt;
32use serde::{Deserialize, Serialize};
33use std::collections::HashMap;
34use std::process::Stdio;
35use std::sync::Arc;
36use thiserror::Error;
37use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
38use tokio::process::{Child, Command};
39use tokio::sync::RwLock;
40use tracing::{debug, info, warn};
41use uuid::Uuid;
42
43use crate::tools::{
44 JsonSchema, ToolCategory, ToolDefinition, ToolImpl, ToolResult, ToolResultValue,
45};
46
/// Errors returned by the MCP client, transport, and server code in this module.
#[derive(Error, Debug)]
pub enum McpError {
    /// The transport is not connected, or exchanging a message over it failed.
    #[error("Transport error: {0}")]
    Transport(String),

    /// A reply was received but did not carry the JSON-RPC payload that was expected.
    #[error("JSON-RPC error: {0}")]
    JsonRpc(String),

    /// The peer answered with a JSON-RPC error object (`code` and `message` copied from it).
    #[error("Server error: {code} - {message}")]
    Server { code: i32, message: String },

    /// A tool name could not be resolved.
    // NOTE(review): not constructed anywhere in the visible code, hence the `allow`.
    #[error("Tool not found: {0}")]
    #[allow(dead_code)]
    ToolNotFound(String),

    /// Arguments supplied for a tool call were rejected.
    // NOTE(review): not constructed anywhere in the visible code, hence the `allow`.
    #[error("Invalid tool arguments: {0}")]
    #[allow(dead_code)]
    InvalidArguments(String),

    /// Establishing the transport failed (process spawn, HTTP client setup, SSE handshake).
    #[error("Connection failed: {0}")]
    ConnectionFailed(String),

    /// No response arrived within the allotted time.
    #[error("Timeout: {0}")]
    Timeout(String),

    /// Underlying I/O failure; `#[from]` lets `?` convert `std::io::Error` automatically.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// JSON (de)serialization failure; `#[from]` lets `?` convert `serde_json::Error`.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),
}
80
/// Shorthand for results whose error type is [`McpError`].
pub type McpResult<T> = std::result::Result<T, McpError>;
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct JsonRpcRequest {
91 pub jsonrpc: String,
92 pub method: String,
93 #[serde(default)]
94 pub params: Option<serde_json::Value>,
95 pub id: serde_json::Value,
96}
97
98impl JsonRpcRequest {
99 pub fn new(method: &str, params: serde_json::Value, id: i64) -> Self {
100 Self {
101 jsonrpc: "2.0".to_string(),
102 method: method.to_string(),
103 params: Some(params),
104 id: serde_json::Value::Number(id.into()),
105 }
106 }
107}
108
/// A JSON-RPC 2.0 response carrying either `result` or `error`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcResponse {
    /// Protocol marker as sent by the peer.
    pub jsonrpc: String,
    /// Success payload; absent from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub result: Option<serde_json::Value>,
    /// Error payload; absent from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<JsonRpcError>,
    /// Identifier echoed from the request this response answers.
    pub id: serde_json::Value,
}
119
/// The `error` member of a JSON-RPC response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcError {
    /// Numeric error code.
    pub code: i32,
    /// Human-readable description of the error.
    pub message: String,
    /// Optional extra detail; omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
}
127
/// Protocol version string sent by the client in `initialize` and reported by the server.
pub const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
132
/// Parameters of the `initialize` request (serialized with camelCase keys).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeParams {
    /// Protocol version the client speaks.
    pub protocol_version: String,
    /// Capabilities the client advertises.
    pub capabilities: McpClientCapabilities,
    /// Name and version of the client.
    pub client_info: McpClientInfo,
}
141
/// Capabilities advertised by the client; unset entries are omitted when serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct McpClientCapabilities {
    /// Roots capability, if advertised.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub roots: Option<McpRootsCapability>,
    /// Sampling capability, kept as raw JSON.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub sampling: Option<serde_json::Value>,
}
150
/// Details of the client's roots capability.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct McpRootsCapability {
    /// Serialized as `listChanged`.
    pub list_changed: bool,
}
156
/// Identification of the client, sent inside [`InitializeParams`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpClientInfo {
    /// Client name.
    pub name: String,
    /// Client version string.
    pub version: String,
}
162
/// Result payload of the `initialize` request (camelCase keys on the wire).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeResult {
    /// Protocol version reported by the server.
    pub protocol_version: String,
    /// Capabilities reported by the server.
    pub capabilities: McpServerCapabilities,
    /// Name and version of the server.
    pub server_info: McpServerInfo,
}
171
/// Capabilities reported by a server; unset entries are omitted when serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct McpServerCapabilities {
    /// Tools capability, if reported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tools: Option<McpToolsCapability>,
    /// Resources capability, kept as raw JSON.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub resources: Option<serde_json::Value>,
    /// Prompts capability, kept as raw JSON.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompts: Option<serde_json::Value>,
}
182
/// Details of a server's tools capability.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct McpToolsCapability {
    /// Serialized as `listChanged`; defaults to `false` when the key is missing.
    #[serde(default)]
    pub list_changed: bool,
}
189
/// Identification of a server, as carried in [`InitializeResult`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerInfo {
    /// Server name.
    pub name: String,
    /// Server version string.
    pub version: String,
}
195
196#[derive(Debug, Clone, Serialize, Deserialize)]
198pub struct McpTool {
199 pub name: String,
200 #[serde(default, skip_serializing_if = "Option::is_none")]
201 pub description: Option<String>,
202 pub input_schema: serde_json::Value,
203}
204
/// Parameters of a `tools/call` request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolCall {
    /// Name of the tool to invoke.
    pub name: String,
    /// Tool arguments; omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub arguments: Option<serde_json::Value>,
}
212
/// Result payload of a `tools/call` request (camelCase keys on the wire).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct McpToolResult {
    /// Content items produced by the tool.
    pub content: Vec<McpContent>,
    /// Serialized as `isError`; defaults to `false` when the key is missing.
    #[serde(default)]
    pub is_error: bool,
}
221
222#[derive(Debug, Clone, Serialize, Deserialize)]
223#[serde(tag = "type", rename_all = "lowercase")]
224pub enum McpContent {
225 Text { text: String },
226 Image { data: String, mime_type: String },
227 Resource { resource: serde_json::Value },
228}
229
/// How to reach an MCP server.
#[derive(Debug, Clone)]
pub enum McpTransportConfig {
    /// Spawn a child process and exchange messages over its stdin/stdout.
    Stdio {
        /// Program to spawn.
        command: String,
        /// Arguments passed to the program.
        args: Vec<String>,
        /// Environment variables added for the child process.
        env: HashMap<String, String>,
    },
    /// Connect to a Server-Sent-Events endpoint at `url`.
    Sse { url: String },
}
244
/// Connection state for one MCP server, for either the stdio or the SSE transport.
///
/// All optional fields start as `None` and are filled in by `connect`.
pub struct McpTransport {
    /// Which transport to use and how to reach the server.
    config: McpTransportConfig,
    /// Spawned server process (stdio transport).
    child: Option<Child>,
    /// Write end used to send requests to the child (stdio transport).
    stdin: Option<tokio::process::ChildStdin>,
    /// Buffered read end used to receive responses from the child (stdio transport).
    stdout_reader: Option<BufReader<tokio::process::ChildStdout>>,
    /// HTTP client used to POST requests (SSE transport).
    http_client: Option<reqwest::Client>,
    /// URL that requests are POSTed to (SSE transport).
    message_endpoint: Option<String>,
    /// Receives payloads forwarded by the background SSE reader task.
    sse_response_rx: Option<tokio::sync::mpsc::UnboundedReceiver<String>>,
    /// Handle of the background SSE reader task; aborted when the transport is dropped.
    sse_reader_handle: Option<tokio::task::JoinHandle<()>>,
    /// Last request id handed out by `next_id`.
    request_id: i64,
}
262
263impl McpTransport {
264 pub fn new(config: McpTransportConfig) -> Self {
266 Self {
267 config,
268 child: None,
269 stdin: None,
270 stdout_reader: None,
271 http_client: None,
272 message_endpoint: None,
273 sse_response_rx: None,
274 sse_reader_handle: None,
275 request_id: 0,
276 }
277 }
278
279 pub async fn connect(&mut self) -> McpResult<()> {
281 let config = self.config.clone();
282 match &config {
283 McpTransportConfig::Stdio { command, args, env } => {
284 self.connect_stdio(command, args, env).await
285 }
286 McpTransportConfig::Sse { url } => self.connect_sse(url).await,
287 }
288 }
289
290 async fn connect_stdio(
292 &mut self,
293 command: &str,
294 args: &[String],
295 env: &HashMap<String, String>,
296 ) -> McpResult<()> {
297 let mut cmd = Command::new(command);
298 cmd.args(args);
299 cmd.stdin(Stdio::piped());
300 cmd.stdout(Stdio::piped());
301 cmd.stderr(Stdio::piped());
302 cmd.envs(env);
303
304 let mut child = cmd.spawn().map_err(|e| {
305 McpError::ConnectionFailed(format!("Failed to spawn {}: {}", command, e))
306 })?;
307
308 let stdin = child
309 .stdin
310 .take()
311 .ok_or_else(|| McpError::ConnectionFailed("No stdin available".to_string()))?;
312
313 let stdout = child
314 .stdout
315 .take()
316 .ok_or_else(|| McpError::ConnectionFailed("No stdout available".to_string()))?;
317
318 self.child = Some(child);
319 self.stdin = Some(stdin);
320 self.stdout_reader = Some(BufReader::new(stdout));
321
322 info!(command = %command, "MCP stdio transport connected");
323 Ok(())
324 }
325
326 async fn connect_sse(&mut self, url: &str) -> McpResult<()> {
328 let client = reqwest::Client::builder()
329 .danger_accept_invalid_certs(false)
330 .build()
331 .map_err(|e| {
332 McpError::ConnectionFailed(format!("Failed to create HTTP client: {}", e))
333 })?;
334
335 info!(url = %url, "MCP SSE transport connecting");
336
337 let response = client
339 .get(url)
340 .header("Accept", "text/event-stream")
341 .send()
342 .await
343 .map_err(|e| McpError::ConnectionFailed(format!("SSE connection failed: {}", e)))?;
344
345 if !response.status().is_success() {
346 return Err(McpError::ConnectionFailed(format!(
347 "SSE connection returned status {}",
348 response.status()
349 )));
350 }
351
352 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
354
355 let message_endpoint = Arc::new(tokio::sync::Mutex::new(None::<String>));
357 let message_endpoint_clone = message_endpoint.clone();
358
359 let handle = tokio::spawn(async move {
361 let mut stream = response.bytes_stream();
362 let mut buffer = String::new();
363
364 while let Some(chunk_result) = stream.next().await {
365 match chunk_result {
366 Ok(chunk) => {
367 let chunk_str = String::from_utf8_lossy(&chunk);
368 buffer.push_str(&chunk_str);
369
370 while let Some(event_end) = buffer.find("\n\n") {
372 let event = buffer[..event_end].to_string();
373 buffer = buffer[event_end + 2..].to_string();
374 Self::handle_sse_event(&event, &tx, &message_endpoint_clone).await;
375 }
376 }
377 Err(e) => {
378 warn!("SSE stream error: {}", e);
379 break;
380 }
381 }
382 }
383 debug!("MCP SSE stream ended");
384 });
385
386 self.http_client = Some(client);
387 self.sse_response_rx = Some(rx);
388 self.sse_reader_handle = Some(handle);
389 self.message_endpoint = None; let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_secs(5);
394 let mut endpoint_received = false;
395
396 while tokio::time::Instant::now() < deadline {
397 if let Ok(msg) = self.sse_response_rx.as_mut().unwrap().try_recv() {
398 if msg.starts_with("endpoint:") {
400 let ep = msg
401 .strip_prefix("endpoint:")
402 .unwrap_or("")
403 .trim()
404 .to_string();
405 let resolved = if ep.starts_with("http") {
407 ep.clone()
408 } else {
409 let base = url.trim_end_matches('/');
410 let path = ep.trim_start_matches('/');
411 format!("{}/{}", base, path)
412 };
413 self.message_endpoint = Some(resolved.clone());
414 *message_endpoint.lock().await = Some(resolved);
415 endpoint_received = true;
416 info!(
417 "MCP SSE endpoint received: {}",
418 self.message_endpoint.as_ref().unwrap()
419 );
420 break;
421 }
422 } else {
423 tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
424 }
425 }
426
427 if !endpoint_received {
428 self.message_endpoint = Some(url.to_string());
430 *message_endpoint.lock().await = Some(url.to_string());
431 warn!("No SSE endpoint event received, using SSE URL as message endpoint");
432 }
433
434 info!(url = %url, "MCP SSE transport connected");
435 Ok(())
436 }
437
438 async fn handle_sse_event(
440 event: &str,
441 tx: &tokio::sync::mpsc::UnboundedSender<String>,
442 message_endpoint: &Arc<tokio::sync::Mutex<Option<String>>>,
443 ) {
444 for line in event.lines() {
445 let line = line.trim();
446 if let Some(value) = line.strip_prefix("event:") {
447 let event_type = value.trim();
448 debug!("MCP SSE event: {}", event_type);
449 } else if let Some(value) = line.strip_prefix("data:") {
450 let data = value.trim().to_string();
451 if let Some(ep) = data.strip_prefix("/message") {
453 *message_endpoint.lock().await = Some(ep.to_string());
454 let _ = tx.send(format!("endpoint:{}", ep));
455 } else if data.starts_with("http") && data.contains("/message") {
456 *message_endpoint.lock().await = Some(data.clone());
457 let _ = tx.send(format!("endpoint:{}", data));
458 } else {
459 let _ = tx.send(data);
461 }
462 }
463 }
464 }
465
466 pub async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
468 match &self.config {
469 McpTransportConfig::Stdio { .. } => self.send_request_stdio(request).await,
470 McpTransportConfig::Sse { .. } => self.send_request_sse(request).await,
471 }
472 }
473
474 async fn send_request_stdio(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
476 let request_json = serde_json::to_string(&request)?;
477 debug!("MCP → {}", request_json);
478
479 let stdin = self
480 .stdin
481 .as_mut()
482 .ok_or_else(|| McpError::Transport("Transport not connected".to_string()))?;
483
484 stdin.write_all(request_json.as_bytes()).await?;
485 stdin.write_all(b"\n").await?;
486 stdin.flush().await?;
487
488 let stdout = self
489 .stdout_reader
490 .as_mut()
491 .ok_or_else(|| McpError::Transport("Transport not connected".to_string()))?;
492
493 let mut response_line = String::new();
494 stdout.read_line(&mut response_line).await?;
495
496 if response_line.trim().is_empty() {
497 return Err(McpError::Transport(
498 "Empty response from server".to_string(),
499 ));
500 }
501
502 debug!("MCP ← {}", response_line.trim());
503
504 let response: JsonRpcResponse = serde_json::from_str(&response_line)?;
505
506 if let Some(err) = &response.error {
507 return Err(McpError::Server {
508 code: err.code,
509 message: err.message.clone(),
510 });
511 }
512
513 Ok(response)
514 }
515
516 async fn send_request_sse(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
518 let request_json = serde_json::to_string(&request)?;
519 let request_id = request.id.clone();
520 debug!(
521 "MCP SSE → POST {}: {}",
522 self.message_endpoint.as_deref().unwrap_or("?"),
523 request_json
524 );
525
526 let client = self
527 .http_client
528 .as_ref()
529 .ok_or_else(|| McpError::Transport("HTTP client not initialized".to_string()))?;
530
531 let endpoint = self
532 .message_endpoint
533 .as_ref()
534 .ok_or_else(|| McpError::Transport("No message endpoint available".to_string()))?;
535
536 let response = client
538 .post(endpoint)
539 .header("Content-Type", "application/json")
540 .header("Accept", "application/json")
541 .body(request_json.clone())
542 .send()
543 .await
544 .map_err(|e| McpError::Transport(format!("HTTP POST failed: {}", e)))?;
545
546 let status = response.status();
549 let body = response
550 .text()
551 .await
552 .map_err(|e| McpError::Transport(format!("Failed to read response body: {}", e)))?;
553
554 if !body.trim().is_empty() {
555 if let Ok(rpc_response) = serde_json::from_str::<JsonRpcResponse>(&body) {
557 debug!("MCP SSE ← {}", body.trim());
558 if let Some(err) = &rpc_response.error {
559 return Err(McpError::Server {
560 code: err.code,
561 message: err.message.clone(),
562 });
563 }
564 return Ok(rpc_response);
565 }
566 }
567
568 if status.is_success() {
570 let rx = self
571 .sse_response_rx
572 .as_mut()
573 .ok_or_else(|| McpError::Transport("SSE receiver not available".to_string()))?;
574
575 let timeout = tokio::time::Duration::from_secs(30);
576 let deadline = tokio::time::Instant::now() + timeout;
577
578 while tokio::time::Instant::now() < deadline {
579 match tokio::time::timeout(tokio::time::Duration::from_secs(1), rx.recv()).await {
580 Ok(Some(msg)) => {
581 if msg.starts_with("endpoint:") {
583 continue;
584 }
585 if let Ok(rpc_response) = serde_json::from_str::<JsonRpcResponse>(&msg) {
587 if rpc_response.id == request_id
589 || rpc_response.id == serde_json::Value::Null
590 {
591 debug!("MCP SSE ← {}", msg);
592 if let Some(err) = &rpc_response.error {
593 return Err(McpError::Server {
594 code: err.code,
595 message: err.message.clone(),
596 });
597 }
598 return Ok(rpc_response);
599 }
600 }
601 }
602 Ok(None) => break,
603 Err(_) => continue, }
605 }
606
607 return Err(McpError::Timeout(
608 "No response received from SSE stream".to_string(),
609 ));
610 }
611
612 Err(McpError::Transport(format!(
613 "HTTP POST returned status {}",
614 status
615 )))
616 }
617
618 fn next_id(&mut self) -> i64 {
620 self.request_id += 1;
621 self.request_id
622 }
623}
624
625impl Drop for McpTransport {
626 fn drop(&mut self) {
627 if let Some(handle) = self.sse_reader_handle.take() {
629 handle.abort();
630 }
631 }
632}
633
/// Client for a single MCP server: owns the transport and caches what the server reported.
pub struct McpClient {
    /// Connection to the server.
    transport: McpTransport,
    /// Server identification, filled in by `connect` after `initialize` succeeds.
    server_info: Option<McpServerInfo>,
    /// Tools discovered via `tools/list`, replaced by each `discover_tools` call.
    tools: Arc<RwLock<Vec<McpTool>>>,
}
642
643impl McpClient {
644 pub fn new(config: McpTransportConfig) -> Self {
646 Self {
647 transport: McpTransport::new(config),
648 server_info: None,
649 tools: Arc::new(RwLock::new(Vec::new())),
650 }
651 }
652
653 pub async fn connect(&mut self) -> McpResult<()> {
655 self.transport.connect().await?;
657
658 let init_params = InitializeParams {
660 protocol_version: MCP_PROTOCOL_VERSION.to_string(),
661 capabilities: McpClientCapabilities {
662 roots: Some(McpRootsCapability {
663 list_changed: false,
664 }),
665 sampling: None,
666 },
667 client_info: McpClientInfo {
668 name: "ravenclaws".to_string(),
669 version: env!("CARGO_PKG_VERSION").to_string(),
670 },
671 };
672
673 let init_id = self.transport.next_id();
674 let response = self
675 .transport
676 .send_request(JsonRpcRequest::new(
677 "initialize",
678 serde_json::to_value(init_params)?,
679 init_id,
680 ))
681 .await?;
682
683 let init_result: InitializeResult = response
684 .result
685 .and_then(|v| serde_json::from_value(v).ok())
686 .ok_or_else(|| McpError::JsonRpc("Invalid initialize response".to_string()))?;
687
688 let server_info = init_result.server_info.clone();
689 self.server_info = Some(init_result.server_info);
690
691 info!(
692 server = %server_info.name,
693 version = %server_info.version,
694 "MCP server initialized"
695 );
696
697 let notify = JsonRpcRequest {
699 jsonrpc: "2.0".to_string(),
700 method: "notifications/initialized".to_string(),
701 params: Some(serde_json::Value::Null),
702 id: serde_json::Value::Null,
703 };
704 self.transport.send_request(notify).await?;
705
706 self.discover_tools().await?;
708
709 Ok(())
710 }
711
712 pub async fn discover_tools(&mut self) -> McpResult<()> {
714 let list_id = self.transport.next_id();
715 let response = self
716 .transport
717 .send_request(JsonRpcRequest::new(
718 "tools/list",
719 serde_json::Value::Null,
720 list_id,
721 ))
722 .await?;
723
724 let tools_result = response
725 .result
726 .and_then(|v| v.get("tools").cloned())
727 .ok_or_else(|| McpError::JsonRpc("No tools in response".to_string()))?;
728
729 let tools: Vec<McpTool> = serde_json::from_value(tools_result)?;
730
731 info!(count = tools.len(), "Discovered MCP tools");
732
733 let mut tool_lock = self.tools.write().await;
734 *tool_lock = tools;
735
736 Ok(())
737 }
738
739 pub async fn get_tools(&self) -> Vec<McpTool> {
741 self.tools.read().await.clone()
742 }
743
744 pub async fn call_tool(
746 &mut self,
747 name: &str,
748 arguments: Option<serde_json::Value>,
749 ) -> McpResult<McpToolResult> {
750 let params = McpToolCall {
751 name: name.to_string(),
752 arguments,
753 };
754
755 let call_id = self.transport.next_id();
756 let response = self
757 .transport
758 .send_request(JsonRpcRequest::new(
759 "tools/call",
760 serde_json::to_value(params)?,
761 call_id,
762 ))
763 .await?;
764
765 let result: McpToolResult = response
766 .result
767 .and_then(|v| serde_json::from_value(v).ok())
768 .ok_or_else(|| McpError::JsonRpc("Invalid tool call response".to_string()))?;
769
770 if result.is_error {
771 return Err(McpError::Server {
772 code: -32000,
773 message: "Tool execution failed".to_string(),
774 });
775 }
776
777 Ok(result)
778 }
779
780 pub fn server_info(&self) -> Option<&McpServerInfo> {
782 self.server_info.as_ref()
783 }
784}
785
/// Holds the connected MCP clients, each paired with the name it was registered under.
pub struct McpClientManager {
    /// `(name, client)` pairs in insertion order.
    clients: Vec<(String, Arc<RwLock<McpClient>>)>,
}
797
798impl McpClientManager {
799 pub fn new() -> Self {
801 Self {
802 clients: Vec::new(),
803 }
804 }
805
806 pub async fn from_config(config: &crate::config::McpConfig) -> Self {
808 let mut manager = Self::new();
809 for server in &config.servers {
810 let transport_config = if !server.url.is_empty() {
812 McpTransportConfig::Sse {
813 url: server.url.clone(),
814 }
815 } else {
816 McpTransportConfig::Stdio {
817 command: server.command.clone(),
818 args: server.args.clone(),
819 env: server.env.clone(),
820 }
821 };
822 let mut client = McpClient::new(transport_config);
823 match client.connect().await {
824 Ok(()) => {
825 info!(
826 server = %server.name,
827 server_info = ?client.server_info(),
828 "MCP client connected from config"
829 );
830 manager
831 .clients
832 .push((server.name.clone(), Arc::new(RwLock::new(client))));
833 }
834 Err(e) => {
835 warn!(
836 server = %server.name,
837 error = %e,
838 "Failed to connect to MCP server from config, skipping"
839 );
840 }
841 }
842 }
843 manager
844 }
845
846 #[allow(dead_code)]
848 pub fn add_client(&mut self, name: String, client: Arc<RwLock<McpClient>>) {
849 self.clients.push((name, client));
850 }
851
852 #[allow(dead_code)]
854 pub fn clients(&self) -> &[(String, Arc<RwLock<McpClient>>)] {
855 &self.clients
856 }
857
858 #[allow(dead_code)]
860 pub fn get_client(&self, name: &str) -> Option<&Arc<RwLock<McpClient>>> {
861 self.clients.iter().find(|(n, _)| n == name).map(|(_, c)| c)
862 }
863
864 pub async fn register_all_tools(&self, registry: &mut crate::tools::ToolRegistry) -> usize {
866 let mut total = 0;
867 for (name, client) in &self.clients {
868 let mcp_client = client.read().await;
869 let mcp_tools = mcp_client.get_tools().await;
870 drop(mcp_client);
871
872 for mcp_tool in mcp_tools {
873 let wrapper = McpToolWrapper::new(client.clone(), mcp_tool);
874 registry.register(Arc::new(wrapper));
875 total += 1;
876 }
877 info!(
878 server = %name,
879 tools_registered = total,
880 "Registered MCP tools from server"
881 );
882 }
883 info!(total, "Total MCP tools registered from all servers");
884 total
885 }
886
887 pub fn len(&self) -> usize {
889 self.clients.len()
890 }
891
892 pub fn is_empty(&self) -> bool {
894 self.clients.is_empty()
895 }
896}
897
898impl Default for McpClientManager {
899 fn default() -> Self {
900 Self::new()
901 }
902}
903
/// Adapter that exposes one MCP server tool through the local `ToolImpl` interface.
pub struct McpToolWrapper {
    /// Local tool definition derived from the MCP tool description.
    definition: ToolDefinition,
    /// Client used to forward calls to the server that owns the tool.
    client: Arc<RwLock<McpClient>>,
    /// Tool name as known to the server.
    tool_name: String,
}
912
913impl McpToolWrapper {
914 pub fn new(client: Arc<RwLock<McpClient>>, mcp_tool: McpTool) -> Self {
916 let parameters = Self::convert_schema(&mcp_tool.input_schema);
918
919 Self {
920 definition: ToolDefinition {
921 name: mcp_tool.name.clone(),
922 description: mcp_tool
923 .description
924 .unwrap_or_else(|| "MCP-provided tool".to_string()),
925 parameters,
926 requires_approval: false,
927 category: ToolCategory::Mcp,
928 },
929 client,
930 tool_name: mcp_tool.name,
931 }
932 }
933
934 fn convert_schema(schema: &serde_json::Value) -> JsonSchema {
936 if let Some(obj) = schema.as_object() {
937 let schema_type = obj
938 .get("type")
939 .and_then(|v| v.as_str())
940 .unwrap_or("object")
941 .to_string();
942
943 let description = obj
944 .get("description")
945 .and_then(|v| v.as_str())
946 .map(|s| s.to_string());
947
948 let properties = obj
949 .get("properties")
950 .and_then(|v| v.as_object())
951 .map(|props| {
952 props
953 .iter()
954 .map(|(k, v)| (k.clone(), Self::convert_schema(v)))
955 .collect::<HashMap<String, JsonSchema>>()
956 });
957
958 let required = obj.get("required").and_then(|v| v.as_array()).map(|arr| {
959 arr.iter()
960 .filter_map(|v| v.as_str())
961 .map(|s| s.to_string())
962 .collect()
963 });
964
965 JsonSchema {
966 schema_type,
967 description,
968 properties,
969 required,
970 items: None,
971 enum_values: None,
972 }
973 } else {
974 JsonSchema::string("MCP tool parameter")
975 }
976 }
977}
978
979#[async_trait::async_trait]
980impl ToolImpl for McpToolWrapper {
981 fn definition(&self) -> &ToolDefinition {
982 &self.definition
983 }
984
985 async fn execute(&self, args: serde_json::Value) -> ToolResultValue<ToolResult> {
986 let mut client = self.client.write().await;
987
988 let result = client
989 .call_tool(&self.tool_name, Some(args))
990 .await
991 .map_err(|e| {
992 crate::tools::ToolError::ExecutionFailed(self.tool_name.clone(), e.to_string())
993 })?;
994
995 let output = result
997 .content
998 .iter()
999 .map(|c| match c {
1000 McpContent::Text { text } => text.clone(),
1001 McpContent::Image { data, mime_type } => {
1002 format!("[Image: {} bytes, {}]", data.len(), mime_type)
1003 }
1004 McpContent::Resource { resource } => {
1005 format!("[Resource: {}]", resource)
1006 }
1007 })
1008 .collect::<Vec<_>>()
1009 .join("\n");
1010
1011 Ok(ToolResult {
1012 tool_name: self.tool_name.clone(),
1013 success: !result.is_error,
1014 output,
1015 error: if result.is_error {
1016 Some("Tool returned error".to_string())
1017 } else {
1018 None
1019 },
1020 exit_code: None,
1021 duration_ms: None,
1022 })
1023 }
1024}
1025
1026pub async fn register_mcp_tools(
1030 registry: &mut crate::tools::ToolRegistry,
1031 client: Arc<RwLock<McpClient>>,
1032) -> McpResult<usize> {
1033 let mcp_client = client.read().await;
1034 let mcp_tools = mcp_client.get_tools().await;
1035 drop(mcp_client);
1036
1037 let count = mcp_tools.len();
1038
1039 for mcp_tool in mcp_tools {
1040 let wrapper = McpToolWrapper::new(client.clone(), mcp_tool);
1041 registry.register(Arc::new(wrapper));
1042 }
1043
1044 info!(count, "Registered MCP tools");
1045 Ok(count)
1046}
1047
#[cfg(test)]
mod tests {
    use super::*;

    /// `JsonRpcRequest::new` fills in the protocol marker, method and numeric id.
    #[test]
    fn test_json_rpc_request() {
        let req = JsonRpcRequest::new("tools/list", serde_json::Value::Null, 1);
        assert_eq!(req.jsonrpc, "2.0");
        assert_eq!(req.method, "tools/list");
        assert_eq!(req.id, serde_json::Value::Number(1.into()));
    }

    /// A serialized `McpTool` contains its name and description.
    #[test]
    fn test_mcp_tool_serialization() {
        let tool = McpTool {
            name: "test_tool".to_string(),
            description: Some("A test tool".to_string()),
            input_schema: serde_json::json!({
                "type": "object",
                "properties": {
                    "query": {"type": "string"}
                }
            }),
        };

        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains("test_tool"));
        assert!(json.contains("A test tool"));
    }

    /// `InitializeParams` serializes with camelCase keys.
    #[test]
    fn test_initialize_params() {
        let params = InitializeParams {
            protocol_version: MCP_PROTOCOL_VERSION.to_string(),
            capabilities: McpClientCapabilities {
                roots: Some(McpRootsCapability {
                    list_changed: false,
                }),
                sampling: None,
            },
            client_info: McpClientInfo {
                name: "ravenclaws".to_string(),
                version: "0.5.2".to_string(),
            },
        };

        // The argument had been corrupted to `¶ms`; it is a borrow of `params`.
        let json = serde_json::to_string(&params).unwrap();
        assert!(json.contains("protocolVersion"));
        assert!(json.contains("ravenclaws"));
    }
}
1099
/// MCP server that exposes the local tool registry over stdio.
pub struct McpServer {
    /// Tools offered to clients and executed on `tools/call`.
    registry: crate::tools::ToolRegistry,
    /// Consulted before every tool call; a denial is returned to the client as an error result.
    policy_engine: crate::policy::PolicyEngine,
    /// Initialized at the start of `run`.
    sandbox: crate::sandbox::Sandbox,
    /// Records allowed tool calls and their results.
    audit_log: crate::audit::AuditLog,
    /// Set once the client has sent `notifications/initialized`.
    initialized: bool,
    /// Name and version reported in the `initialize` response.
    server_info: McpServerInfo,
    /// Counter backing `next_id`.
    request_id: i64,
}
1132
1133impl McpServer {
1134 pub fn new(registry: crate::tools::ToolRegistry) -> Self {
1138 let server_info = McpServerInfo {
1139 name: "ravenclaws".to_string(),
1140 version: env!("CARGO_PKG_VERSION").to_string(),
1141 };
1142
1143 Self {
1144 registry,
1145 policy_engine: crate::policy::PolicyEngine::default_secure(),
1146 sandbox: crate::sandbox::Sandbox::default(),
1147 audit_log: crate::audit::AuditLog::new(format!("mcp-server-{}", std::process::id())),
1148 initialized: false,
1149 server_info,
1150 request_id: 0,
1151 }
1152 }
1153
1154 pub async fn run(&mut self) -> Result<(), McpError> {
1157 self.sandbox
1159 .init()
1160 .await
1161 .map_err(|e| McpError::Transport(format!("Sandbox init failed: {}", e)))?;
1162
1163 info!("MCP server starting on stdio");
1164
1165 let stdin = tokio::io::stdin();
1166 let reader = BufReader::new(stdin);
1167 let mut lines = reader.lines();
1168
1169 while let Ok(Some(line)) = lines.next_line().await {
1170 let line = line.trim().to_string();
1171 if line.is_empty() {
1172 continue;
1173 }
1174
1175 debug!("MCP Server ← {}", &line);
1176
1177 let request: JsonRpcRequest = match serde_json::from_str(&line) {
1179 Ok(req) => req,
1180 Err(e) => {
1181 let error_response = serde_json::json!({
1182 "jsonrpc": "2.0",
1183 "error": {
1184 "code": -32700,
1185 "message": "Parse error",
1186 "data": e.to_string()
1187 },
1188 "id": serde_json::Value::Null
1189 });
1190 let _ = self.write_response(&error_response).await;
1191 continue;
1192 }
1193 };
1194
1195 let response = self.handle_request(&request).await;
1196 let _ = self.write_response(&response).await;
1197 }
1198
1199 info!("MCP server shutting down (stdin closed)");
1200 Ok(())
1201 }
1202
1203 async fn handle_request(&mut self, request: &JsonRpcRequest) -> serde_json::Value {
1205 let request_id = request.id.clone();
1206
1207 match request.method.as_str() {
1208 "initialize" => self.handle_initialize(request, &request_id).await,
1209 "notifications/initialized" => {
1210 self.initialized = true;
1211 info!("MCP server initialized by client");
1212 serde_json::json!({
1213 "jsonrpc": "2.0",
1214 "result": null,
1215 "id": request_id
1216 })
1217 }
1218 "tools/list" => self.handle_tools_list(&request_id).await,
1219 "tools/call" => self.handle_tools_call(request, &request_id).await,
1220 _ => {
1221 serde_json::json!({
1222 "jsonrpc": "2.0",
1223 "error": {
1224 "code": -32601,
1225 "message": format!("Method not found: {}", request.method)
1226 },
1227 "id": request_id
1228 })
1229 }
1230 }
1231 }
1232
1233 async fn handle_initialize(
1235 &mut self,
1236 request: &JsonRpcRequest,
1237 request_id: &serde_json::Value,
1238 ) -> serde_json::Value {
1239 if let Some(params) = request.params.as_ref().and_then(|p| p.as_object()) {
1241 if let Some(client_info) = params.get("clientInfo") {
1242 info!(
1243 client = ?client_info.get("name").and_then(|v| v.as_str()).unwrap_or("unknown"),
1244 "MCP client connected"
1245 );
1246 }
1247 }
1248
1249 let capabilities = McpServerCapabilities {
1250 tools: Some(McpToolsCapability {
1251 list_changed: false,
1252 }),
1253 resources: None,
1254 prompts: None,
1255 };
1256
1257 let result = serde_json::json!({
1258 "protocolVersion": MCP_PROTOCOL_VERSION,
1259 "capabilities": capabilities,
1260 "serverInfo": {
1261 "name": self.server_info.name,
1262 "version": self.server_info.version
1263 }
1264 });
1265
1266 serde_json::json!({
1267 "jsonrpc": "2.0",
1268 "result": result,
1269 "id": request_id
1270 })
1271 }
1272
1273 async fn handle_tools_list(&self, request_id: &serde_json::Value) -> serde_json::Value {
1275 let tools: Vec<serde_json::Value> = self
1276 .registry
1277 .definitions()
1278 .iter()
1279 .map(|def| {
1280 serde_json::json!({
1281 "name": def.name,
1282 "description": def.description,
1283 "inputSchema": def.parameters
1284 })
1285 })
1286 .collect();
1287
1288 serde_json::json!({
1289 "jsonrpc": "2.0",
1290 "result": {
1291 "tools": tools
1292 },
1293 "id": request_id
1294 })
1295 }
1296
1297 async fn handle_tools_call(
1299 &mut self,
1300 request: &JsonRpcRequest,
1301 request_id: &serde_json::Value,
1302 ) -> serde_json::Value {
1303 let params = request.params.as_ref().unwrap_or(&serde_json::Value::Null);
1304
1305 let name = params
1306 .get("name")
1307 .and_then(|v| v.as_str())
1308 .unwrap_or("")
1309 .to_string();
1310
1311 let arguments = params
1312 .get("arguments")
1313 .cloned()
1314 .unwrap_or(serde_json::Value::Null);
1315
1316 if name.is_empty() {
1317 return serde_json::json!({
1318 "jsonrpc": "2.0",
1319 "error": {
1320 "code": -32602,
1321 "message": "Invalid params: missing tool name"
1322 },
1323 "id": request_id
1324 });
1325 }
1326
1327 let policy_decision = self.policy_engine.check_tool_call(&name, &arguments);
1329 match policy_decision {
1330 crate::policy::Decision::Deny(reason) => {
1331 warn!(tool = %name, reason = %reason, "MCP tool call denied by policy");
1332 return serde_json::json!({
1333 "jsonrpc": "2.0",
1334 "result": {
1335 "content": [{
1336 "type": "text",
1337 "text": format!("Policy denied: {}", reason)
1338 }],
1339 "isError": true
1340 },
1341 "id": request_id
1342 });
1343 }
1344 crate::policy::Decision::Allow => {
1345 let _ = self.audit_log.tool_call(&name, &arguments);
1347 }
1348 }
1349
1350 let call = crate::tools::ToolCall {
1352 name: name.clone(),
1353 arguments,
1354 id: None,
1355 };
1356
1357 match self.registry.execute(call).await {
1358 Ok(result) => {
1359 let _ = self.audit_log.append(
1361 crate::audit::AuditEventType::ToolResult,
1362 &name,
1363 &format!("MCP tool executed: {} (success: {})", name, result.success),
1364 Some(serde_json::json!({
1365 "success": result.success,
1366 "exit_code": result.exit_code,
1367 "duration_ms": result.duration_ms,
1368 })),
1369 );
1370
1371 let content = if result.success {
1372 vec![serde_json::json!({
1373 "type": "text",
1374 "text": result.output
1375 })]
1376 } else {
1377 vec![serde_json::json!({
1378 "type": "text",
1379 "text": result.error.as_deref().unwrap_or("Unknown error")
1380 })]
1381 };
1382
1383 serde_json::json!({
1384 "jsonrpc": "2.0",
1385 "result": {
1386 "content": content,
1387 "isError": !result.success
1388 },
1389 "id": request_id
1390 })
1391 }
1392 Err(e) => {
1393 warn!(tool = %name, error = %e, "MCP tool execution failed");
1394 serde_json::json!({
1395 "jsonrpc": "2.0",
1396 "result": {
1397 "content": [{
1398 "type": "text",
1399 "text": format!("Tool execution failed: {}", e)
1400 }],
1401 "isError": true
1402 },
1403 "id": request_id
1404 })
1405 }
1406 }
1407 }
1408
1409 async fn write_response(&self, response: &serde_json::Value) -> std::io::Result<()> {
1411 let json = serde_json::to_string(response)?;
1412 debug!("MCP Server → {}", &json);
1413 use tokio::io::AsyncWriteExt;
1414 let mut stdout = tokio::io::stdout();
1415 stdout.write_all(json.as_bytes()).await?;
1416 stdout.write_all(b"\n").await?;
1417 stdout.flush().await?;
1418 Ok(())
1419 }
1420
1421 #[allow(dead_code)]
1423 fn next_id(&mut self) -> i64 {
1424 self.request_id += 1;
1425 self.request_id
1426 }
1427}
1428
#[cfg(test)]
mod server_tests {
    use super::*;
    use crate::tools::ToolRegistry;

    /// Builds a stdio MCP server over the default tool registry.
    fn default_server() -> McpServer {
        McpServer::new(ToolRegistry::with_default_tools())
    }

    /// The request id (`1`) used by every request in this module.
    fn first_id() -> serde_json::Value {
        serde_json::Value::Number(1.into())
    }

    /// Builds a JSON-RPC 2.0 request with id `1`.
    fn rpc(method: &str, params: serde_json::Value) -> JsonRpcRequest {
        JsonRpcRequest::new(method, params, 1)
    }

    #[test]
    fn test_mcp_server_initialize_response() {
        let server = default_server();

        assert_eq!(server.server_info.name, "ravenclaws");
        assert!(!server.server_info.version.is_empty());
        // A freshly constructed server has not completed the handshake.
        assert!(!server.initialized);
    }

    #[test]
    fn test_mcp_server_tools_list_response() {
        let server = default_server();

        let defs = server.registry.definitions();
        let names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();

        let expected_tools = [
            "shell_exec",
            "read_file",
            "write_file",
            "web_fetch",
            "web_search",
        ];
        for expected in expected_tools.iter() {
            assert!(names.contains(expected));
        }
        assert_eq!(defs.len(), 5);
    }

    #[tokio::test]
    async fn test_mcp_server_handle_unknown_method() {
        let mut server = default_server();
        let request = rpc("unknown_method", serde_json::Value::Null);

        let response = server.handle_request(&request).await;

        // -32601 is JSON-RPC "Method not found".
        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["code"],
            serde_json::Value::Number((-32601).into())
        );
    }

    #[tokio::test]
    async fn test_mcp_server_handle_tools_list() {
        let server = default_server();

        let response = server.handle_tools_list(&first_id()).await;

        assert!(response.get("result").is_some());
        let tools = &response["result"]["tools"];
        assert!(tools.is_array());
        assert!(!tools.as_array().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_mcp_server_handle_tools_call_missing_name() {
        let mut server = default_server();
        let request = rpc("tools/call", serde_json::json!({}));

        let response = server.handle_tools_call(&request, &first_id()).await;

        // -32602 is JSON-RPC "Invalid params".
        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["code"],
            serde_json::Value::Number((-32602).into())
        );
    }

    #[tokio::test]
    async fn test_mcp_server_handle_tools_call_unknown_tool() {
        let mut server = default_server();
        let request = rpc(
            "tools/call",
            serde_json::json!({
                "name": "nonexistent_tool",
                "arguments": {}
            }),
        );

        let response = server.handle_tools_call(&request, &first_id()).await;

        // Unknown tools surface as a tool-level error, not a protocol error.
        assert!(response["result"]["isError"].as_bool().unwrap_or(false));
    }

    #[test]
    fn test_mcp_server_json_rpc_error_codes() {
        // NOTE(review): these assertions compare literals with themselves and
        // only record the JSON-RPC codes used by the server (parse error,
        // method not found, invalid params); they do not exercise any code.
        assert_eq!(-32700i32, -32700);
        assert_eq!(-32601i32, -32601);
        assert_eq!(-32602i32, -32602);
    }
}
1551
/// MCP server that serves the tool registry over plain HTTP: a `GET /sse`
/// event stream per client plus a `POST /message` endpoint for JSON-RPC
/// requests.
pub struct McpSseServer {
    /// Tools served by `tools/list` and executed by `tools/call`. Moved into
    /// shared state when `run` starts.
    registry: crate::tools::ToolRegistry,
    /// Consulted before every tool call; a `Deny` decision blocks execution.
    policy_engine: crate::policy::PolicyEngine,
    /// Initialized at the start of `run`. The request handlers receive it but
    /// do not currently use it.
    sandbox: crate::sandbox::Sandbox,
    /// Receives an entry for each allowed tool call and each tool result.
    audit_log: crate::audit::AuditLog,
    /// Name and version reported in the `initialize` response.
    server_info: McpServerInfo,
    /// Connected SSE clients, keyed by a per-connection UUID string. Each
    /// string sent through a client's channel is emitted as one SSE `data:`
    /// event on that client's stream.
    clients: Arc<tokio::sync::RwLock<HashMap<String, tokio::sync::mpsc::UnboundedSender<String>>>>,
    /// Host part of the listen address.
    host: String,
    /// TCP port to listen on.
    port: u16,
}
1578
1579impl McpSseServer {
1580 pub fn new(registry: crate::tools::ToolRegistry, host: String, port: u16) -> Self {
1582 let server_info = McpServerInfo {
1583 name: "ravenclaws".to_string(),
1584 version: env!("CARGO_PKG_VERSION").to_string(),
1585 };
1586
1587 Self {
1588 registry,
1589 policy_engine: crate::policy::PolicyEngine::default_secure(),
1590 sandbox: crate::sandbox::Sandbox::default(),
1591 audit_log: crate::audit::AuditLog::new(format!("mcp-sse-{}", std::process::id())),
1592 server_info,
1593 clients: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
1594 host,
1595 port,
1596 }
1597 }
1598
1599 pub async fn run(
1601 &mut self,
1602 shutdown: tokio::sync::watch::Receiver<bool>,
1603 ) -> Result<(), McpError> {
1604 self.sandbox
1606 .init()
1607 .await
1608 .map_err(|e| McpError::Transport(format!("Sandbox init failed: {}", e)))?;
1609
1610 let addr: std::net::SocketAddr = format!("{}:{}", self.host, self.port)
1611 .parse()
1612 .map_err(|e| McpError::Transport(format!("Invalid address: {}", e)))?;
1613
1614 info!(addr = %addr, "MCP SSE server starting");
1615
1616 let listener = tokio::net::TcpListener::bind(addr)
1619 .await
1620 .map_err(|e| McpError::Transport(format!("Failed to bind: {}", e)))?;
1621
1622 let clients = self.clients.clone();
1623 let registry = Arc::new(tokio::sync::RwLock::new(std::mem::replace(
1624 &mut self.registry,
1625 crate::tools::ToolRegistry::new(),
1626 )));
1627 let policy_engine = Arc::new(tokio::sync::RwLock::new(std::mem::replace(
1628 &mut self.policy_engine,
1629 crate::policy::PolicyEngine::default_secure(),
1630 )));
1631 let sandbox = Arc::new(tokio::sync::RwLock::new(std::mem::take(&mut self.sandbox)));
1632 let audit_log = Arc::new(tokio::sync::RwLock::new(std::mem::replace(
1633 &mut self.audit_log,
1634 crate::audit::AuditLog::new(format!("mcp-sse-{}", std::process::id())),
1635 )));
1636 let server_info = Arc::new(self.server_info.clone());
1637
1638 let mut shutdown = shutdown;
1640
1641 loop {
1642 tokio::select! {
1643 accept_result = listener.accept() => {
1644 match accept_result {
1645 Ok((stream, peer_addr)) => {
1646 let clients = clients.clone();
1647 let registry = registry.clone();
1648 let policy_engine = policy_engine.clone();
1649 let sandbox = sandbox.clone();
1650 let audit_log = audit_log.clone();
1651 let server_info = server_info.clone();
1652
1653 tokio::spawn(async move {
1654 if let Err(e) = Self::handle_connection(
1655 stream, peer_addr, clients, registry,
1656 policy_engine, sandbox, audit_log, server_info,
1657 ).await {
1658 warn!(peer = %peer_addr, error = %e, "MCP SSE connection error");
1659 }
1660 });
1661 }
1662 Err(e) => {
1663 warn!("Accept error: {}", e);
1664 }
1665 }
1666 }
1667 _ = shutdown.changed() => {
1668 info!("MCP SSE server shutting down");
1669 break;
1670 }
1671 }
1672 }
1673
1674 Ok(())
1675 }
1676
1677 #[allow(clippy::too_many_arguments)]
1679 async fn handle_connection(
1680 mut stream: tokio::net::TcpStream,
1681 peer_addr: std::net::SocketAddr,
1682 clients: Arc<
1683 tokio::sync::RwLock<HashMap<String, tokio::sync::mpsc::UnboundedSender<String>>>,
1684 >,
1685 registry: Arc<tokio::sync::RwLock<crate::tools::ToolRegistry>>,
1686 policy_engine: Arc<tokio::sync::RwLock<crate::policy::PolicyEngine>>,
1687 sandbox: Arc<tokio::sync::RwLock<crate::sandbox::Sandbox>>,
1688 audit_log: Arc<tokio::sync::RwLock<crate::audit::AuditLog>>,
1689 server_info: Arc<McpServerInfo>,
1690 ) -> Result<(), McpError> {
1691 use tokio::io::AsyncReadExt;
1692
1693 let mut buf = [0u8; 8192];
1694 let n = stream
1695 .read(&mut buf)
1696 .await
1697 .map_err(|e| McpError::Transport(format!("Read error: {}", e)))?;
1698
1699 if n == 0 {
1700 return Ok(());
1701 }
1702
1703 let request = String::from_utf8_lossy(&buf[..n]).to_string();
1704
1705 let (method, path) = if let Some(first_line) = request.lines().next() {
1707 let parts: Vec<&str> = first_line.split_whitespace().collect();
1708 if parts.len() < 2 {
1709 return Err(McpError::Transport("Invalid HTTP request".to_string()));
1710 }
1711 (parts[0].to_string(), parts[1].to_string())
1712 } else {
1713 return Err(McpError::Transport("Empty HTTP request".to_string()));
1714 };
1715
1716 match (method.as_str(), path.as_str()) {
1717 ("GET", "/sse") => {
1718 Self::handle_sse_connection(
1719 stream,
1720 peer_addr,
1721 clients,
1722 registry,
1723 policy_engine,
1724 sandbox,
1725 audit_log,
1726 server_info,
1727 )
1728 .await
1729 }
1730 ("POST", "/message") => {
1731 let body = if let Some(body_start) = request.find("\r\n\r\n") {
1733 request[body_start + 4..].to_string()
1734 } else {
1735 return Err(McpError::Transport("No body in POST request".to_string()));
1736 };
1737
1738 Self::handle_message_post(
1739 stream,
1740 &body,
1741 ®istry,
1742 &policy_engine,
1743 &sandbox,
1744 &audit_log,
1745 &server_info,
1746 clients,
1747 )
1748 .await
1749 }
1750 _ => {
1751 let response = "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n";
1753 stream
1754 .write_all(response.as_bytes())
1755 .await
1756 .map_err(|e| McpError::Transport(format!("Write error: {}", e)))?;
1757 Ok(())
1758 }
1759 }
1760 }
1761
1762 #[allow(clippy::too_many_arguments)]
1764 async fn handle_sse_connection(
1765 mut stream: tokio::net::TcpStream,
1766 peer_addr: std::net::SocketAddr,
1767 clients: Arc<
1768 tokio::sync::RwLock<HashMap<String, tokio::sync::mpsc::UnboundedSender<String>>>,
1769 >,
1770 _registry: Arc<tokio::sync::RwLock<crate::tools::ToolRegistry>>,
1771 _policy_engine: Arc<tokio::sync::RwLock<crate::policy::PolicyEngine>>,
1772 _sandbox: Arc<tokio::sync::RwLock<crate::sandbox::Sandbox>>,
1773 _audit_log: Arc<tokio::sync::RwLock<crate::audit::AuditLog>>,
1774 _server_info: Arc<McpServerInfo>,
1775 ) -> Result<(), McpError> {
1776 use tokio::io::AsyncWriteExt;
1777
1778 let client_id = Uuid::new_v4().to_string();
1779 info!(client = %client_id, peer = %peer_addr, "MCP SSE client connected");
1780
1781 let headers = "HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nCache-Control: no-cache\r\nConnection: keep-alive\r\nAccess-Control-Allow-Origin: *\r\n\r\n";
1783 stream
1784 .write_all(headers.as_bytes())
1785 .await
1786 .map_err(|e| McpError::Transport(format!("Write error: {}", e)))?;
1787 stream
1788 .flush()
1789 .await
1790 .map_err(|e| McpError::Transport(format!("Flush error: {}", e)))?;
1791
1792 let endpoint_event = "event: endpoint\ndata: /message\n\n".to_string();
1794 stream
1795 .write_all(endpoint_event.as_bytes())
1796 .await
1797 .map_err(|e| McpError::Transport(format!("Write error: {}", e)))?;
1798 stream
1799 .flush()
1800 .await
1801 .map_err(|e| McpError::Transport(format!("Flush error: {}", e)))?;
1802
1803 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1805
1806 clients.write().await.insert(client_id.clone(), tx);
1808
1809 loop {
1811 tokio::select! {
1812 msg = rx.recv() => {
1813 match msg {
1814 Some(data) => {
1815 let sse_event = format!("data: {}\n\n", data);
1816 if stream.write_all(sse_event.as_bytes()).await.is_err() {
1817 break;
1818 }
1819 if stream.flush().await.is_err() {
1820 break;
1821 }
1822 }
1823 None => break,
1824 }
1825 }
1826 }
1827 }
1828
1829 clients.write().await.remove(&client_id);
1831 info!(client = %client_id, "MCP SSE client disconnected");
1832 Ok(())
1833 }
1834
1835 #[allow(clippy::too_many_arguments)]
1837 async fn handle_message_post(
1838 mut stream: tokio::net::TcpStream,
1839 body: &str,
1840 registry: &Arc<tokio::sync::RwLock<crate::tools::ToolRegistry>>,
1841 policy_engine: &Arc<tokio::sync::RwLock<crate::policy::PolicyEngine>>,
1842 sandbox: &Arc<tokio::sync::RwLock<crate::sandbox::Sandbox>>,
1843 audit_log: &Arc<tokio::sync::RwLock<crate::audit::AuditLog>>,
1844 server_info: &Arc<McpServerInfo>,
1845 clients: Arc<
1846 tokio::sync::RwLock<HashMap<String, tokio::sync::mpsc::UnboundedSender<String>>>,
1847 >,
1848 ) -> Result<(), McpError> {
1849 use tokio::io::AsyncWriteExt;
1850
1851 let request: JsonRpcRequest = match serde_json::from_str(body) {
1853 Ok(req) => req,
1854 Err(e) => {
1855 let error_response = serde_json::json!({
1856 "jsonrpc": "2.0",
1857 "error": {
1858 "code": -32700,
1859 "message": "Parse error",
1860 "data": e.to_string()
1861 },
1862 "id": serde_json::Value::Null
1863 });
1864 let response_body = serde_json::to_string(&error_response)?;
1865 let http_response = format!(
1866 "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
1867 response_body.len(),
1868 response_body
1869 );
1870 stream.write_all(http_response.as_bytes()).await?;
1871 return Ok(());
1872 }
1873 };
1874
1875 let request_id = request.id.clone();
1877 let response = Self::handle_jsonrpc_request(
1878 &request,
1879 &request_id,
1880 registry,
1881 policy_engine,
1882 sandbox,
1883 audit_log,
1884 server_info,
1885 )
1886 .await;
1887
1888 let response_body = serde_json::to_string(&response)?;
1889
1890 let http_response = format!(
1892 "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\n\r\n{}",
1893 response_body.len(),
1894 response_body
1895 );
1896 stream.write_all(http_response.as_bytes()).await?;
1897 stream.flush().await?;
1898
1899 let response_json = serde_json::to_string(&response)?;
1901 let clients_guard = clients.read().await;
1902 for (_, tx) in clients_guard.iter() {
1903 let _ = tx.send(response_json.clone());
1904 }
1905
1906 Ok(())
1907 }
1908
1909 async fn handle_jsonrpc_request(
1911 request: &JsonRpcRequest,
1912 request_id: &serde_json::Value,
1913 registry: &Arc<tokio::sync::RwLock<crate::tools::ToolRegistry>>,
1914 policy_engine: &Arc<tokio::sync::RwLock<crate::policy::PolicyEngine>>,
1915 _sandbox: &Arc<tokio::sync::RwLock<crate::sandbox::Sandbox>>,
1916 audit_log: &Arc<tokio::sync::RwLock<crate::audit::AuditLog>>,
1917 server_info: &Arc<McpServerInfo>,
1918 ) -> serde_json::Value {
1919 match request.method.as_str() {
1920 "initialize" => {
1921 if let Some(params) = request.params.as_ref().and_then(|p| p.as_object()) {
1923 if let Some(client_info) = params.get("clientInfo") {
1924 info!(
1925 client = ?client_info.get("name").and_then(|v| v.as_str()).unwrap_or("unknown"),
1926 "MCP SSE client initialized"
1927 );
1928 }
1929 }
1930
1931 let capabilities = serde_json::json!({
1932 "protocolVersion": "2024-11-05",
1933 "capabilities": {
1934 "tools": {
1935 "listChanged": false
1936 }
1937 },
1938 "serverInfo": {
1939 "name": server_info.name,
1940 "version": server_info.version
1941 }
1942 });
1943
1944 serde_json::json!({
1945 "jsonrpc": "2.0",
1946 "result": capabilities,
1947 "id": request_id
1948 })
1949 }
1950 "notifications/initialized" => {
1951 info!("MCP SSE client initialized notification received");
1952 serde_json::json!({
1953 "jsonrpc": "2.0",
1954 "result": null,
1955 "id": request_id
1956 })
1957 }
1958 "tools/list" => {
1959 let defs = registry.read().await.definitions().clone();
1960 let tools: Vec<serde_json::Value> = defs
1961 .iter()
1962 .map(|def| {
1963 serde_json::json!({
1964 "name": def.name,
1965 "description": def.description,
1966 "inputSchema": def.parameters
1967 })
1968 })
1969 .collect();
1970
1971 serde_json::json!({
1972 "jsonrpc": "2.0",
1973 "result": {
1974 "tools": tools
1975 },
1976 "id": request_id
1977 })
1978 }
1979 "tools/call" => {
1980 let params = request.params.as_ref().unwrap_or(&serde_json::Value::Null);
1981
1982 let name = params
1983 .get("name")
1984 .and_then(|v| v.as_str())
1985 .unwrap_or("")
1986 .to_string();
1987
1988 let arguments = params
1989 .get("arguments")
1990 .cloned()
1991 .unwrap_or(serde_json::Value::Null);
1992
1993 if name.is_empty() {
1994 return serde_json::json!({
1995 "jsonrpc": "2.0",
1996 "error": {
1997 "code": -32602,
1998 "message": "Invalid params: missing tool name"
1999 },
2000 "id": request_id
2001 });
2002 }
2003
2004 let decision = policy_engine
2006 .read()
2007 .await
2008 .check_tool_call(&name, &arguments);
2009 match decision {
2010 crate::policy::Decision::Deny(reason) => {
2011 warn!(tool = %name, reason = %reason, "MCP SSE tool call denied by policy");
2012 return serde_json::json!({
2013 "jsonrpc": "2.0",
2014 "result": {
2015 "content": [{
2016 "type": "text",
2017 "text": format!("Policy denied: {}", reason)
2018 }],
2019 "isError": true
2020 },
2021 "id": request_id
2022 });
2023 }
2024 crate::policy::Decision::Allow => {
2025 let _ = audit_log.write().await.tool_call(&name, &arguments);
2026 }
2027 }
2028
2029 let call = crate::tools::ToolCall {
2031 name: name.clone(),
2032 arguments,
2033 id: None,
2034 };
2035
2036 match registry.read().await.execute(call).await {
2037 Ok(result) => {
2038 let _ = audit_log.write().await.append(
2039 crate::audit::AuditEventType::ToolResult,
2040 &name,
2041 &format!(
2042 "MCP SSE tool executed: {} (success: {})",
2043 name, result.success
2044 ),
2045 Some(serde_json::json!({
2046 "success": result.success,
2047 "exit_code": result.exit_code,
2048 "duration_ms": result.duration_ms,
2049 })),
2050 );
2051
2052 let content = if result.success {
2053 vec![serde_json::json!({
2054 "type": "text",
2055 "text": result.output
2056 })]
2057 } else {
2058 vec![serde_json::json!({
2059 "type": "text",
2060 "text": result.error.as_deref().unwrap_or("Unknown error")
2061 })]
2062 };
2063
2064 serde_json::json!({
2065 "jsonrpc": "2.0",
2066 "result": {
2067 "content": content,
2068 "isError": !result.success
2069 },
2070 "id": request_id
2071 })
2072 }
2073 Err(e) => {
2074 warn!(tool = %name, error = %e, "MCP SSE tool execution failed");
2075 serde_json::json!({
2076 "jsonrpc": "2.0",
2077 "result": {
2078 "content": [{
2079 "type": "text",
2080 "text": format!("Tool execution failed: {}", e)
2081 }],
2082 "isError": true
2083 },
2084 "id": request_id
2085 })
2086 }
2087 }
2088 }
2089 _ => {
2090 serde_json::json!({
2091 "jsonrpc": "2.0",
2092 "error": {
2093 "code": -32601,
2094 "message": format!("Method not found: {}", request.method)
2095 },
2096 "id": request_id
2097 })
2098 }
2099 }
2100 }
2101}
2102
#[cfg(test)]
mod sse_server_tests {
    use super::*;
    use crate::tools::ToolRegistry;

    /// Sends a request with id `1` through `McpSseServer::handle_jsonrpc_request`
    /// using fresh default state (default tools, secure policy, default
    /// sandbox, a `"test"` audit log) and returns the response.
    async fn dispatch(method: &str, params: serde_json::Value) -> serde_json::Value {
        let request = JsonRpcRequest::new(method, params, 1);
        let request_id = serde_json::Value::Number(1.into());

        let registry = Arc::new(tokio::sync::RwLock::new(
            ToolRegistry::with_default_tools(),
        ));
        let policy = Arc::new(tokio::sync::RwLock::new(
            crate::policy::PolicyEngine::default_secure(),
        ));
        let sandbox = Arc::new(tokio::sync::RwLock::new(crate::sandbox::Sandbox::default()));
        let audit = Arc::new(tokio::sync::RwLock::new(crate::audit::AuditLog::new(
            "test".to_string(),
        )));
        let server_info = Arc::new(McpServerInfo {
            name: "ravenclaws".to_string(),
            version: env!("CARGO_PKG_VERSION").to_string(),
        });

        McpSseServer::handle_jsonrpc_request(
            &request,
            &request_id,
            &registry,
            &policy,
            &sandbox,
            &audit,
            &server_info,
        )
        .await
    }

    #[test]
    fn test_mcp_sse_server_new() {
        let server = McpSseServer::new(
            ToolRegistry::with_default_tools(),
            "127.0.0.1".to_string(),
            9091,
        );

        assert_eq!(server.host, "127.0.0.1");
        assert_eq!(server.port, 9091);
        assert_eq!(server.server_info.name, "ravenclaws");
        // No client can be registered before the server runs.
        assert!(server.clients.blocking_read().is_empty());
    }

    #[test]
    fn test_mcp_sse_server_info() {
        let server = McpSseServer::new(
            ToolRegistry::with_default_tools(),
            "0.0.0.0".to_string(),
            9092,
        );

        assert_eq!(server.server_info.name, "ravenclaws");
        assert!(!server.server_info.version.is_empty());
    }

    #[tokio::test]
    async fn test_mcp_sse_handle_initialize() {
        let response = dispatch(
            "initialize",
            serde_json::json!({
                "protocolVersion": "2024-11-05",
                "clientInfo": {
                    "name": "test-client",
                    "version": "1.0.0"
                }
            }),
        )
        .await;

        assert!(response.get("result").is_some());
        assert_eq!(response["result"]["serverInfo"]["name"], "ravenclaws");
    }

    #[tokio::test]
    async fn test_mcp_sse_handle_tools_list() {
        let response = dispatch("tools/list", serde_json::Value::Null).await;

        assert!(response.get("result").is_some());
        let tools = &response["result"]["tools"];
        assert!(tools.is_array());
        assert!(!tools.as_array().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_mcp_sse_handle_unknown_method() {
        let response = dispatch("unknown_method", serde_json::Value::Null).await;

        // -32601 is JSON-RPC "Method not found".
        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["code"],
            serde_json::Value::Number((-32601).into())
        );
    }

    #[tokio::test]
    async fn test_mcp_sse_handle_tools_call_missing_name() {
        let response = dispatch("tools/call", serde_json::json!({})).await;

        // -32602 is JSON-RPC "Invalid params".
        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["code"],
            serde_json::Value::Number((-32602).into())
        );
    }

    #[tokio::test]
    async fn test_mcp_sse_transport_config_serde() {
        let config = McpTransportConfig::Sse {
            url: "http://localhost:9090/sse".to_string(),
        };

        // Only checks that the SSE variant keeps the URL it was built with.
        match config {
            McpTransportConfig::Sse { url } => {
                assert_eq!(url, "http://localhost:9090/sse");
            }
            _ => panic!("Expected SSE variant"),
        }
    }

    #[tokio::test]
    async fn test_mcp_sse_transport_connect_failure() {
        // Port 1 on loopback is expected to have no listener.
        let mut transport = McpTransport::new(McpTransportConfig::Sse {
            url: "http://127.0.0.1:1/sse".to_string(),
        });
        let result = transport.connect().await;

        assert!(result.is_err());
        match result {
            Err(McpError::ConnectionFailed(_)) | Err(McpError::Transport(_)) => {}
            _ => panic!("Expected connection or transport error, got {:?}", result),
        }
    }
}