1use crate::{Config, Result, Error, config::{ServerConfig, Transport, Auth}};
14use serde::{Deserialize, Serialize};
15use serde_json::{json, Value};
16use std::collections::HashMap;
17use std::process::Stdio;
18use std::sync::Arc;
19use std::time::{Duration, Instant};
20use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
21use tokio::process::{Child, Command};
22use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
23use tokio::time::{interval, timeout};
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct JsonRpcRequest {
32 pub jsonrpc: String,
33 pub id: Value,
34 pub method: String,
35 #[serde(default, skip_serializing_if = "Option::is_none")]
36 pub params: Option<Value>,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct JsonRpcResponse {
42 pub jsonrpc: String,
43 pub id: Value,
44 #[serde(skip_serializing_if = "Option::is_none")]
45 pub result: Option<Value>,
46 #[serde(skip_serializing_if = "Option::is_none")]
47 pub error: Option<JsonRpcError>,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct JsonRpcNotification {
53 pub jsonrpc: String,
54 pub method: String,
55 #[serde(default, skip_serializing_if = "Option::is_none")]
56 pub params: Option<Value>,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct JsonRpcError {
62 pub code: i64,
63 pub message: String,
64 #[serde(skip_serializing_if = "Option::is_none")]
65 pub data: Option<Value>,
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct McpTool {
71 pub name: String,
72 #[serde(skip_serializing_if = "Option::is_none")]
73 pub title: Option<String>,
74 #[serde(skip_serializing_if = "Option::is_none")]
75 pub description: Option<String>,
76 #[serde(rename = "inputSchema")]
77 pub input_schema: Value,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct McpResource {
83 pub uri: String,
84 pub name: String,
85 #[serde(skip_serializing_if = "Option::is_none")]
86 pub description: Option<String>,
87 #[serde(rename = "mimeType", skip_serializing_if = "Option::is_none")]
88 pub mime_type: Option<String>,
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct McpPrompt {
94 pub name: String,
95 #[serde(skip_serializing_if = "Option::is_none")]
96 pub description: Option<String>,
97 #[serde(default)]
98 pub arguments: Vec<McpPromptArgument>,
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct McpPromptArgument {
104 pub name: String,
105 #[serde(skip_serializing_if = "Option::is_none")]
106 pub description: Option<String>,
107 #[serde(default)]
108 pub required: bool,
109}
110
111#[derive(Debug, Clone, Default, Serialize, Deserialize)]
113pub struct McpCapabilities {
114 #[serde(default)]
115 pub tools: Option<ToolsCapability>,
116 #[serde(default)]
117 pub resources: Option<ResourcesCapability>,
118 #[serde(default)]
119 pub prompts: Option<PromptsCapability>,
120 #[serde(default)]
121 pub logging: Option<Value>,
122}
123
124#[derive(Debug, Clone, Default, Serialize, Deserialize)]
125pub struct ToolsCapability {
126 #[serde(rename = "listChanged", default)]
127 pub list_changed: bool,
128}
129
130#[derive(Debug, Clone, Default, Serialize, Deserialize)]
131pub struct ResourcesCapability {
132 #[serde(rename = "listChanged", default)]
133 pub list_changed: bool,
134 #[serde(default)]
135 pub subscribe: bool,
136}
137
138#[derive(Debug, Clone, Default, Serialize, Deserialize)]
139pub struct PromptsCapability {
140 #[serde(rename = "listChanged", default)]
141 pub list_changed: bool,
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct McpServerInfo {
147 pub name: String,
148 #[serde(default)]
149 pub version: String,
150}
151
/// Lifecycle state of a gateway-managed MCP server connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerStatus {
    /// A connection attempt is in progress.
    Connecting,
    /// The client initialized successfully.
    Connected,
    /// No connection established yet (the state of a freshly added server).
    Disconnected,
    /// The most recent connection attempt failed.
    Error,
    /// A lost connection is being re-established.
    /// NOTE(review): not assigned in the code reviewed here — confirm where it is set.
    Reconnecting,
}

impl std::fmt::Display for ServerStatus {
    /// Writes the lowercase state name, e.g. `connected`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Connecting => "connecting",
            Self::Connected => "connected",
            Self::Disconnected => "disconnected",
            Self::Error => "error",
            Self::Reconnecting => "reconnecting",
        };
        f.write_str(label)
    }
}
177
178pub struct StdioTransport {
184 stdin: Arc<Mutex<tokio::process::ChildStdin>>,
185 pending: Arc<RwLock<HashMap<String, oneshot::Sender<JsonRpcResponse>>>>,
186 connected: Arc<std::sync::atomic::AtomicBool>,
187 _child: Arc<Mutex<Child>>,
188}
189
190impl StdioTransport {
191 pub async fn spawn(command: &str, args: &[String], env: Option<&HashMap<String, String>>) -> Result<Self> {
193 let mut cmd = Command::new(command);
194 cmd.args(args)
195 .stdin(Stdio::piped())
196 .stdout(Stdio::piped())
197 .stderr(Stdio::piped());
198
199 if let Some(env_vars) = env {
200 for (k, v) in env_vars {
201 cmd.env(k, v);
202 }
203 }
204
205 let mut child = cmd.spawn()
206 .map_err(|e| Error::Transport(format!("failed to spawn {}: {}", command, e)))?;
207
208 let stdin = child.stdin.take()
209 .ok_or_else(|| Error::Transport("failed to get stdin".into()))?;
210 let stdout = child.stdout.take()
211 .ok_or_else(|| Error::Transport("failed to get stdout".into()))?;
212
213 let pending: Arc<RwLock<HashMap<String, oneshot::Sender<JsonRpcResponse>>>> =
214 Arc::new(RwLock::new(HashMap::new()));
215 let connected = Arc::new(std::sync::atomic::AtomicBool::new(true));
216
217 let pending_clone = pending.clone();
219 let connected_clone = connected.clone();
220 tokio::spawn(async move {
221 let mut reader = BufReader::new(stdout).lines();
222 while let Ok(Some(line)) = reader.next_line().await {
223 if line.is_empty() {
224 continue;
225 }
226
227 match serde_json::from_str::<JsonRpcResponse>(&line) {
228 Ok(response) => {
229 let id_str = match &response.id {
230 Value::Number(n) => n.to_string(),
231 Value::String(s) => s.clone(),
232 _ => continue,
233 };
234
235 let mut pending = pending_clone.write().await;
236 if let Some(tx) = pending.remove(&id_str) {
237 let _ = tx.send(response);
238 }
239 }
240 Err(e) => {
241 tracing::debug!("Failed to parse response: {} - line: {}", e, line);
242 }
243 }
244 }
245 connected_clone.store(false, std::sync::atomic::Ordering::SeqCst);
246 });
247
248 Ok(Self {
249 stdin: Arc::new(Mutex::new(stdin)),
250 pending,
251 connected,
252 _child: Arc::new(Mutex::new(child)),
253 })
254 }
255
256 pub async fn request(&self, req: JsonRpcRequest) -> Result<JsonRpcResponse> {
257 let id_str = match &req.id {
258 Value::Number(n) => n.to_string(),
259 Value::String(s) => s.clone(),
260 _ => return Err(Error::Protocol("invalid request id".into())),
261 };
262
263 let (tx, rx) = oneshot::channel();
264
265 {
266 let mut pending = self.pending.write().await;
267 pending.insert(id_str.clone(), tx);
268 }
269
270 let line = serde_json::to_string(&req)? + "\n";
271 {
272 let mut stdin = self.stdin.lock().await;
273 stdin.write_all(line.as_bytes()).await
274 .map_err(|e| Error::Transport(format!("write failed: {}", e)))?;
275 stdin.flush().await
276 .map_err(|e| Error::Transport(format!("flush failed: {}", e)))?;
277 }
278
279 match timeout(Duration::from_secs(30), rx).await {
280 Ok(Ok(response)) => Ok(response),
281 Ok(Err(_)) => Err(Error::Transport("response channel closed".into())),
282 Err(_) => {
283 let mut pending = self.pending.write().await;
284 pending.remove(&id_str);
285 Err(Error::Transport("request timeout".into()))
286 }
287 }
288 }
289
290 pub async fn notify(&self, notif: JsonRpcNotification) -> Result<()> {
291 let line = serde_json::to_string(¬if)? + "\n";
292 let mut stdin = self.stdin.lock().await;
293 stdin.write_all(line.as_bytes()).await
294 .map_err(|e| Error::Transport(format!("write failed: {}", e)))?;
295 stdin.flush().await
296 .map_err(|e| Error::Transport(format!("flush failed: {}", e)))?;
297 Ok(())
298 }
299
300 pub async fn close(&self) -> Result<()> {
301 let mut child = self._child.lock().await;
302 let _ = child.kill().await;
303 self.connected.store(false, std::sync::atomic::Ordering::SeqCst);
304 Ok(())
305 }
306
307 pub fn is_connected(&self) -> bool {
308 self.connected.load(std::sync::atomic::Ordering::SeqCst)
309 }
310}
311
312pub struct HttpTransport {
318 endpoint: String,
319 session_id: Arc<RwLock<Option<String>>>,
320 auth: Option<Auth>,
321 connected: Arc<std::sync::atomic::AtomicBool>,
322}
323
324impl HttpTransport {
325 pub fn new(endpoint: &str, auth: Option<Auth>) -> Result<Self> {
326 Ok(Self {
327 endpoint: endpoint.to_string(),
328 session_id: Arc::new(RwLock::new(None)),
329 auth,
330 connected: Arc::new(std::sync::atomic::AtomicBool::new(true)),
331 })
332 }
333
334 pub async fn request(&self, req: JsonRpcRequest) -> Result<JsonRpcResponse> {
335 use http_body_util::{BodyExt, Full};
336 use hyper::body::Bytes;
337 use hyper::Request;
338 use hyper_util::client::legacy::Client;
339 use hyper_util::rt::TokioExecutor;
340
341 let body_json = serde_json::to_string(&req)?;
342
343 let uri: hyper::Uri = self.endpoint.parse()
344 .map_err(|e| Error::Transport(format!("invalid URI: {}", e)))?;
345
346 let mut request_builder = Request::builder()
347 .method("POST")
348 .uri(&uri)
349 .header("Content-Type", "application/json")
350 .header("Accept", "application/json, text/event-stream");
351
352 if let Some(ref auth) = self.auth {
353 match auth {
354 Auth::Bearer { token } => {
355 request_builder = request_builder.header("Authorization", format!("Bearer {}", token));
356 }
357 Auth::Basic { username, password } => {
358 let credentials = format!("{}:{}", username, password);
359 let encoded = hex::encode(credentials.as_bytes());
360 request_builder = request_builder.header("Authorization", format!("Basic {}", encoded));
361 }
362 }
363 }
364
365 if let Some(ref sid) = *self.session_id.read().await {
366 request_builder = request_builder.header("Mcp-Session-Id", sid.as_str());
367 }
368
369 let request = request_builder
370 .body(Full::new(Bytes::from(body_json)))
371 .map_err(|e| Error::Transport(format!("failed to build request: {}", e)))?;
372
373 let https = hyper_util::client::legacy::connect::HttpConnector::new();
374 let client: Client<_, Full<Bytes>> = Client::builder(TokioExecutor::new()).build(https);
375
376 let response = client.request(request).await
377 .map_err(|e| Error::Transport(format!("HTTP request failed: {}", e)))?;
378
379 if let Some(sid) = response.headers().get("Mcp-Session-Id") {
380 if let Ok(sid_str) = sid.to_str() {
381 *self.session_id.write().await = Some(sid_str.to_string());
382 }
383 }
384
385 let status = response.status();
386 if !status.is_success() {
387 self.connected.store(false, std::sync::atomic::Ordering::SeqCst);
388 return Err(Error::Transport(format!("HTTP error: {}", status)));
389 }
390
391 let body_bytes = response.into_body().collect().await
392 .map_err(|e| Error::Transport(format!("failed to read response: {}", e)))?
393 .to_bytes();
394
395 let body = String::from_utf8_lossy(&body_bytes);
396
397 let json_str = if body.starts_with("data:") {
398 body.lines()
399 .filter(|l| l.starts_with("data:"))
400 .last()
401 .map(|l| l.trim_start_matches("data:").trim())
402 .unwrap_or(&body)
403 } else {
404 &body
405 };
406
407 serde_json::from_str(json_str)
408 .map_err(|e| Error::Protocol(format!("invalid JSON response: {}", e)))
409 }
410
411 pub async fn notify(&self, notif: JsonRpcNotification) -> Result<()> {
412 use http_body_util::Full;
413 use hyper::body::Bytes;
414 use hyper::Request;
415 use hyper_util::client::legacy::Client;
416 use hyper_util::rt::TokioExecutor;
417
418 let body_json = serde_json::to_string(¬if)?;
419 let uri: hyper::Uri = self.endpoint.parse()
420 .map_err(|e| Error::Transport(format!("invalid URI: {}", e)))?;
421
422 let mut request_builder = Request::builder()
423 .method("POST")
424 .uri(&uri)
425 .header("Content-Type", "application/json");
426
427 if let Some(ref auth) = self.auth {
428 match auth {
429 Auth::Bearer { token } => {
430 request_builder = request_builder.header("Authorization", format!("Bearer {}", token));
431 }
432 Auth::Basic { username, password } => {
433 let credentials = format!("{}:{}", username, password);
434 let encoded = hex::encode(credentials.as_bytes());
435 request_builder = request_builder.header("Authorization", format!("Basic {}", encoded));
436 }
437 }
438 }
439
440 if let Some(ref sid) = *self.session_id.read().await {
441 request_builder = request_builder.header("Mcp-Session-Id", sid.as_str());
442 }
443
444 let request = request_builder
445 .body(Full::new(Bytes::from(body_json)))
446 .map_err(|e| Error::Transport(format!("failed to build request: {}", e)))?;
447
448 let https = hyper_util::client::legacy::connect::HttpConnector::new();
449 let client: Client<_, Full<Bytes>> = Client::builder(TokioExecutor::new()).build(https);
450
451 let response = client.request(request).await
452 .map_err(|e| Error::Transport(format!("HTTP request failed: {}", e)))?;
453
454 let status = response.status();
455 if status != hyper::StatusCode::ACCEPTED && !status.is_success() {
456 return Err(Error::Transport(format!("unexpected status: {}", status)));
457 }
458
459 Ok(())
460 }
461
462 pub async fn close(&self) -> Result<()> {
463 self.connected.store(false, std::sync::atomic::Ordering::SeqCst);
464 Ok(())
465 }
466
467 pub fn is_connected(&self) -> bool {
468 self.connected.load(std::sync::atomic::Ordering::SeqCst)
469 }
470}
471
472pub struct WebSocketTransport {
478 write: Arc<Mutex<futures::stream::SplitSink<
479 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
480 tokio_tungstenite::tungstenite::Message
481 >>>,
482 pending: Arc<RwLock<HashMap<String, oneshot::Sender<JsonRpcResponse>>>>,
483 connected: Arc<std::sync::atomic::AtomicBool>,
484}
485
486impl WebSocketTransport {
487 pub async fn connect(url: &str) -> Result<Self> {
488 use futures::StreamExt;
489 use tokio_tungstenite::connect_async;
490
491 let (ws_stream, _) = connect_async(url).await
492 .map_err(|e| Error::Transport(format!("WebSocket connect failed: {}", e)))?;
493
494 let (write, mut read) = ws_stream.split();
495 let pending: Arc<RwLock<HashMap<String, oneshot::Sender<JsonRpcResponse>>>> =
496 Arc::new(RwLock::new(HashMap::new()));
497 let connected = Arc::new(std::sync::atomic::AtomicBool::new(true));
498
499 let pending_clone = pending.clone();
500 let connected_clone = connected.clone();
501 tokio::spawn(async move {
502 while let Some(msg) = read.next().await {
503 match msg {
504 Ok(tokio_tungstenite::tungstenite::Message::Text(text)) => {
505 if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&text) {
506 let id_str = match &response.id {
507 Value::Number(n) => n.to_string(),
508 Value::String(s) => s.clone(),
509 _ => continue,
510 };
511
512 let mut pending = pending_clone.write().await;
513 if let Some(tx) = pending.remove(&id_str) {
514 let _ = tx.send(response);
515 }
516 }
517 }
518 Ok(tokio_tungstenite::tungstenite::Message::Close(_)) => break,
519 Err(_) => break,
520 _ => {}
521 }
522 }
523 connected_clone.store(false, std::sync::atomic::Ordering::SeqCst);
524 });
525
526 Ok(Self { write: Arc::new(Mutex::new(write)), pending, connected })
527 }
528
529 pub async fn request(&self, req: JsonRpcRequest) -> Result<JsonRpcResponse> {
530 use futures::SinkExt;
531 use tokio_tungstenite::tungstenite::Message;
532
533 let id_str = match &req.id {
534 Value::Number(n) => n.to_string(),
535 Value::String(s) => s.clone(),
536 _ => return Err(Error::Protocol("invalid request id".into())),
537 };
538
539 let (tx, rx) = oneshot::channel();
540 { self.pending.write().await.insert(id_str.clone(), tx); }
541
542 let json = serde_json::to_string(&req)?;
543 { self.write.lock().await.send(Message::Text(json.into())).await
544 .map_err(|e| Error::Transport(format!("WebSocket send failed: {}", e)))?; }
545
546 match timeout(Duration::from_secs(30), rx).await {
547 Ok(Ok(response)) => Ok(response),
548 Ok(Err(_)) => Err(Error::Transport("response channel closed".into())),
549 Err(_) => { self.pending.write().await.remove(&id_str); Err(Error::Transport("request timeout".into())) }
550 }
551 }
552
553 pub async fn notify(&self, notif: JsonRpcNotification) -> Result<()> {
554 use futures::SinkExt;
555 use tokio_tungstenite::tungstenite::Message;
556
557 let json = serde_json::to_string(¬if)?;
558 self.write.lock().await.send(Message::Text(json.into())).await
559 .map_err(|e| Error::Transport(format!("WebSocket send failed: {}", e)))
560 }
561
562 pub async fn close(&self) -> Result<()> {
563 use futures::SinkExt;
564 use tokio_tungstenite::tungstenite::Message;
565 let _ = self.write.lock().await.send(Message::Close(None)).await;
566 self.connected.store(false, std::sync::atomic::Ordering::SeqCst);
567 Ok(())
568 }
569
570 pub fn is_connected(&self) -> bool {
571 self.connected.load(std::sync::atomic::Ordering::SeqCst)
572 }
573}
574
575enum McpClientTransport {
580 Stdio(StdioTransport),
581 Http(HttpTransport),
582 WebSocket(WebSocketTransport),
583}
584
585pub struct McpClient {
587 transport: McpClientTransport,
588 server_info: RwLock<Option<McpServerInfo>>,
589 capabilities: RwLock<McpCapabilities>,
590 tools: RwLock<Vec<McpTool>>,
591 resources: RwLock<Vec<McpResource>>,
592 prompts: RwLock<Vec<McpPrompt>>,
593 request_id: std::sync::atomic::AtomicU64,
594}
595
596impl McpClient {
597 fn next_id(&self) -> Value {
598 Value::Number(self.request_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst).into())
599 }
600
601 async fn send_request(&self, method: &str, params: Option<Value>) -> Result<JsonRpcResponse> {
602 let req = JsonRpcRequest {
603 jsonrpc: "2.0".to_string(),
604 id: self.next_id(),
605 method: method.to_string(),
606 params,
607 };
608 match &self.transport {
609 McpClientTransport::Stdio(t) => t.request(req).await,
610 McpClientTransport::Http(t) => t.request(req).await,
611 McpClientTransport::WebSocket(t) => t.request(req).await,
612 }
613 }
614
615 async fn send_notification(&self, method: &str, params: Option<Value>) -> Result<()> {
616 let notif = JsonRpcNotification {
617 jsonrpc: "2.0".to_string(),
618 method: method.to_string(),
619 params,
620 };
621 match &self.transport {
622 McpClientTransport::Stdio(t) => t.notify(notif).await,
623 McpClientTransport::Http(t) => t.notify(notif).await,
624 McpClientTransport::WebSocket(t) => t.notify(notif).await,
625 }
626 }
627
628 pub async fn connect_stdio(command: &str, args: &[String], env: Option<&HashMap<String, String>>) -> Result<Self> {
629 let transport = StdioTransport::spawn(command, args, env).await?;
630 let client = Self {
631 transport: McpClientTransport::Stdio(transport),
632 server_info: RwLock::new(None),
633 capabilities: RwLock::new(McpCapabilities::default()),
634 tools: RwLock::new(Vec::new()),
635 resources: RwLock::new(Vec::new()),
636 prompts: RwLock::new(Vec::new()),
637 request_id: std::sync::atomic::AtomicU64::new(1),
638 };
639 client.initialize().await?;
640 Ok(client)
641 }
642
643 pub async fn connect_http(endpoint: &str, auth: Option<Auth>) -> Result<Self> {
644 let transport = HttpTransport::new(endpoint, auth)?;
645 let client = Self {
646 transport: McpClientTransport::Http(transport),
647 server_info: RwLock::new(None),
648 capabilities: RwLock::new(McpCapabilities::default()),
649 tools: RwLock::new(Vec::new()),
650 resources: RwLock::new(Vec::new()),
651 prompts: RwLock::new(Vec::new()),
652 request_id: std::sync::atomic::AtomicU64::new(1),
653 };
654 client.initialize().await?;
655 Ok(client)
656 }
657
658 pub async fn connect_websocket(url: &str) -> Result<Self> {
659 let transport = WebSocketTransport::connect(url).await?;
660 let client = Self {
661 transport: McpClientTransport::WebSocket(transport),
662 server_info: RwLock::new(None),
663 capabilities: RwLock::new(McpCapabilities::default()),
664 tools: RwLock::new(Vec::new()),
665 resources: RwLock::new(Vec::new()),
666 prompts: RwLock::new(Vec::new()),
667 request_id: std::sync::atomic::AtomicU64::new(1),
668 };
669 client.initialize().await?;
670 Ok(client)
671 }
672
673 async fn initialize(&self) -> Result<()> {
674 let params = json!({
675 "protocolVersion": "2024-11-05",
676 "capabilities": { "roots": { "listChanged": true }, "sampling": {} },
677 "clientInfo": { "name": "zap-gateway", "version": env!("CARGO_PKG_VERSION") }
678 });
679
680 let response = self.send_request("initialize", Some(params)).await?;
681 if let Some(error) = response.error {
682 return Err(Error::Protocol(format!("initialize failed: {}", error.message)));
683 }
684
685 if let Some(result) = response.result {
686 if let Some(server_info) = result.get("serverInfo") {
687 *self.server_info.write().await = serde_json::from_value(server_info.clone()).ok();
688 }
689 if let Some(caps) = result.get("capabilities") {
690 *self.capabilities.write().await = serde_json::from_value(caps.clone()).unwrap_or_default();
691 }
692 }
693
694 self.send_notification("notifications/initialized", None).await?;
695 self.refresh_all().await?;
696 Ok(())
697 }
698
699 pub async fn refresh_all(&self) -> Result<()> {
700 let caps = self.capabilities.read().await.clone();
701 if caps.tools.is_some() { let _ = self.refresh_tools().await; }
702 if caps.resources.is_some() { let _ = self.refresh_resources().await; }
703 if caps.prompts.is_some() { let _ = self.refresh_prompts().await; }
704 Ok(())
705 }
706
707 pub async fn refresh_tools(&self) -> Result<()> {
708 let response = self.send_request("tools/list", None).await?;
709 if let Some(result) = response.result {
710 if let Some(tools_val) = result.get("tools") {
711 *self.tools.write().await = serde_json::from_value(tools_val.clone()).unwrap_or_default();
712 }
713 }
714 Ok(())
715 }
716
717 pub async fn refresh_resources(&self) -> Result<()> {
718 let response = self.send_request("resources/list", None).await?;
719 if let Some(result) = response.result {
720 if let Some(resources_val) = result.get("resources") {
721 *self.resources.write().await = serde_json::from_value(resources_val.clone()).unwrap_or_default();
722 }
723 }
724 Ok(())
725 }
726
727 pub async fn refresh_prompts(&self) -> Result<()> {
728 let response = self.send_request("prompts/list", None).await?;
729 if let Some(result) = response.result {
730 if let Some(prompts_val) = result.get("prompts") {
731 *self.prompts.write().await = serde_json::from_value(prompts_val.clone()).unwrap_or_default();
732 }
733 }
734 Ok(())
735 }
736
737 pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<Value> {
738 let params = json!({ "name": name, "arguments": arguments });
739 let response = self.send_request("tools/call", Some(params)).await?;
740 if let Some(error) = response.error {
741 return Err(Error::ToolCallFailed(format!("{}: {}", name, error.message)));
742 }
743 response.result.ok_or_else(|| Error::Protocol("empty tool result".into()))
744 }
745
746 pub async fn read_resource(&self, uri: &str) -> Result<Value> {
747 let params = json!({ "uri": uri });
748 let response = self.send_request("resources/read", Some(params)).await?;
749 if let Some(error) = response.error {
750 return Err(Error::ResourceNotFound(format!("{}: {}", uri, error.message)));
751 }
752 response.result.ok_or_else(|| Error::Protocol("empty resource result".into()))
753 }
754
755 pub async fn get_prompt(&self, name: &str, arguments: Option<Value>) -> Result<Value> {
756 let params = json!({ "name": name, "arguments": arguments.unwrap_or(json!({})) });
757 let response = self.send_request("prompts/get", Some(params)).await?;
758 if let Some(error) = response.error {
759 return Err(Error::Protocol(format!("prompt {} failed: {}", name, error.message)));
760 }
761 response.result.ok_or_else(|| Error::Protocol("empty prompt result".into()))
762 }
763
764 pub async fn tools(&self) -> Vec<McpTool> { self.tools.read().await.clone() }
765 pub async fn resources(&self) -> Vec<McpResource> { self.resources.read().await.clone() }
766 pub async fn prompts(&self) -> Vec<McpPrompt> { self.prompts.read().await.clone() }
767 pub async fn server_info(&self) -> Option<McpServerInfo> { self.server_info.read().await.clone() }
768
769 pub fn is_connected(&self) -> bool {
770 match &self.transport {
771 McpClientTransport::Stdio(t) => t.is_connected(),
772 McpClientTransport::Http(t) => t.is_connected(),
773 McpClientTransport::WebSocket(t) => t.is_connected(),
774 }
775 }
776
777 pub async fn close(&self) -> Result<()> {
778 match &self.transport {
779 McpClientTransport::Stdio(t) => t.close().await,
780 McpClientTransport::Http(t) => t.close().await,
781 McpClientTransport::WebSocket(t) => t.close().await,
782 }
783 }
784}
785
786struct ConnectedServer {
791 id: String,
792 name: String,
793 config: ServerConfig,
794 client: Option<Arc<McpClient>>,
795 status: ServerStatus,
796 last_error: Option<String>,
797 #[allow(dead_code)]
798 last_health_check: Option<Instant>,
799 reconnect_attempts: u32,
800}
801
802impl ConnectedServer {
803 fn new(id: String, name: String, config: ServerConfig) -> Self {
804 Self { id, name, config, client: None, status: ServerStatus::Disconnected,
805 last_error: None, last_health_check: None, reconnect_attempts: 0 }
806 }
807}
808
809pub struct Gateway {
815 config: Config,
816 servers: Arc<RwLock<HashMap<String, ConnectedServer>>>,
817 tool_routing: Arc<RwLock<HashMap<String, String>>>,
818 resource_routing: Arc<RwLock<HashMap<String, String>>>,
819 prompt_routing: Arc<RwLock<HashMap<String, String>>>,
820 shutdown_tx: Option<mpsc::Sender<()>>,
821}
822
823#[derive(Debug, Clone)]
825pub struct ServerInfo {
826 pub id: String,
827 pub name: String,
828 pub url: String,
829 pub status: ServerStatus,
830 pub tools_count: usize,
831 pub resources_count: usize,
832 pub prompts_count: usize,
833 pub last_error: Option<String>,
834}
835
836impl Gateway {
837 pub fn new(config: Config) -> Self {
838 Self {
839 config,
840 servers: Arc::new(RwLock::new(HashMap::new())),
841 tool_routing: Arc::new(RwLock::new(HashMap::new())),
842 resource_routing: Arc::new(RwLock::new(HashMap::new())),
843 prompt_routing: Arc::new(RwLock::new(HashMap::new())),
844 shutdown_tx: None,
845 }
846 }
847
848 fn generate_id() -> String {
849 use std::time::{SystemTime, UNIX_EPOCH};
850 format!("{:x}", SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_nanos())
851 }
852
853 pub async fn add_server(&self, name: &str, config: ServerConfig) -> Result<String> {
854 let id = Self::generate_id();
855 let server = ConnectedServer::new(id.clone(), name.to_string(), config);
856 self.servers.write().await.insert(id.clone(), server);
857
858 let servers = self.servers.clone();
859 let tool_routing = self.tool_routing.clone();
860 let resource_routing = self.resource_routing.clone();
861 let prompt_routing = self.prompt_routing.clone();
862 let server_id = id.clone();
863
864 tokio::spawn(async move {
865 if let Err(e) = Self::connect_server(&servers, &tool_routing, &resource_routing, &prompt_routing, &server_id).await {
866 tracing::error!("Failed to connect to server {}: {}", server_id, e);
867 }
868 });
869
870 Ok(id)
871 }
872
873 async fn connect_server(
874 servers: &Arc<RwLock<HashMap<String, ConnectedServer>>>,
875 tool_routing: &Arc<RwLock<HashMap<String, String>>>,
876 resource_routing: &Arc<RwLock<HashMap<String, String>>>,
877 prompt_routing: &Arc<RwLock<HashMap<String, String>>>,
878 server_id: &str,
879 ) -> Result<()> {
880 let config = {
881 let mut servers = servers.write().await;
882 let server = servers.get_mut(server_id).ok_or_else(|| Error::Server(format!("server {} not found", server_id)))?;
883 server.status = ServerStatus::Connecting;
884 server.config.clone()
885 };
886
887 let client_result = match config.transport {
888 Transport::Stdio => {
889 let url = url::Url::parse(&config.url).map_err(|e| Error::Config(format!("invalid URL: {}", e)))?;
890 let command = url.path();
891 let args: Vec<String> = url.query_pairs().filter(|(k, _)| k == "arg").map(|(_, v)| v.to_string()).collect();
892 McpClient::connect_stdio(command, &args, None).await
893 }
894 Transport::Http => McpClient::connect_http(&config.url, config.auth.clone()).await,
895 Transport::WebSocket => McpClient::connect_websocket(&config.url).await,
896 Transport::Zap => return Err(Error::Transport("ZAP transport not yet implemented".into())),
897 Transport::Unix => return Err(Error::Transport("Unix transport not yet implemented".into())),
898 };
899
900 match client_result {
901 Ok(client) => {
902 let client = Arc::new(client);
903
904 { let tools = client.tools().await; let mut routing = tool_routing.write().await;
905 for tool in &tools { routing.insert(tool.name.clone(), server_id.to_string()); } }
906
907 { let resources = client.resources().await; let mut routing = resource_routing.write().await;
908 for resource in &resources {
909 if let Some(scheme) = resource.uri.split(':').next() { routing.insert(format!("{}:", scheme), server_id.to_string()); }
910 routing.insert(resource.uri.clone(), server_id.to_string());
911 } }
912
913 { let prompts = client.prompts().await; let mut routing = prompt_routing.write().await;
914 for prompt in &prompts { routing.insert(prompt.name.clone(), server_id.to_string()); } }
915
916 { let mut servers = servers.write().await;
917 if let Some(server) = servers.get_mut(server_id) {
918 server.client = Some(client);
919 server.status = ServerStatus::Connected;
920 server.last_error = None;
921 server.reconnect_attempts = 0;
922 server.last_health_check = Some(Instant::now());
923 } }
924
925 tracing::info!("Connected to MCP server: {}", server_id);
926 Ok(())
927 }
928 Err(e) => {
929 let mut servers = servers.write().await;
930 if let Some(server) = servers.get_mut(server_id) {
931 server.status = ServerStatus::Error;
932 server.last_error = Some(e.to_string());
933 server.reconnect_attempts += 1;
934 }
935 Err(e)
936 }
937 }
938 }
939
940 pub async fn remove_server(&self, id: &str) -> Result<()> {
941 let server = self.servers.write().await.remove(id);
942 if let Some(server) = server {
943 self.tool_routing.write().await.retain(|_, v| v != id);
944 self.resource_routing.write().await.retain(|_, v| v != id);
945 self.prompt_routing.write().await.retain(|_, v| v != id);
946 if let Some(client) = &server.client { let _ = client.close().await; }
947 }
948 Ok(())
949 }
950
951 pub async fn list_servers(&self) -> Vec<ServerInfo> {
952 let servers = self.servers.read().await;
953 let mut result = Vec::new();
954 for server in servers.values() {
955 let (tools_count, resources_count, prompts_count) = if let Some(client) = &server.client {
956 (client.tools().await.len(), client.resources().await.len(), client.prompts().await.len())
957 } else { (0, 0, 0) };
958 result.push(ServerInfo {
959 id: server.id.clone(), name: server.name.clone(), url: server.config.url.clone(),
960 status: server.status, tools_count, resources_count, prompts_count, last_error: server.last_error.clone(),
961 });
962 }
963 result
964 }
965
966 pub async fn server_status(&self, id: &str) -> Option<ServerStatus> {
967 self.servers.read().await.get(id).map(|s| s.status)
968 }
969
970 pub async fn list_tools(&self) -> Vec<McpTool> {
971 let servers = self.servers.read().await;
972 let mut tools = Vec::new();
973 for server in servers.values() {
974 if let Some(client) = &server.client {
975 if server.status == ServerStatus::Connected { tools.extend(client.tools().await); }
976 }
977 }
978 tools
979 }
980
981 pub async fn list_resources(&self) -> Vec<McpResource> {
982 let servers = self.servers.read().await;
983 let mut resources = Vec::new();
984 for server in servers.values() {
985 if let Some(client) = &server.client {
986 if server.status == ServerStatus::Connected { resources.extend(client.resources().await); }
987 }
988 }
989 resources
990 }
991
992 pub async fn list_prompts(&self) -> Vec<McpPrompt> {
993 let servers = self.servers.read().await;
994 let mut prompts = Vec::new();
995 for server in servers.values() {
996 if let Some(client) = &server.client {
997 if server.status == ServerStatus::Connected { prompts.extend(client.prompts().await); }
998 }
999 }
1000 prompts
1001 }
1002
1003 pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<Value> {
1004 let server_id = self.tool_routing.read().await.get(name).cloned()
1005 .ok_or_else(|| Error::ToolNotFound(name.to_string()))?;
1006 let client = self.servers.read().await.get(&server_id).and_then(|s| s.client.clone())
1007 .ok_or_else(|| Error::Server(format!("server {} not connected", server_id)))?;
1008 client.call_tool(name, arguments).await
1009 }
1010
1011 pub async fn read_resource(&self, uri: &str) -> Result<Value> {
1012 let server_id = {
1013 let routing = self.resource_routing.read().await;
1014 routing.get(uri).cloned().or_else(|| routing.iter().find(|(prefix, _)| uri.starts_with(prefix.as_str())).map(|(_, id)| id.clone()))
1015 }.ok_or_else(|| Error::ResourceNotFound(uri.to_string()))?;
1016 let client = self.servers.read().await.get(&server_id).and_then(|s| s.client.clone())
1017 .ok_or_else(|| Error::Server(format!("server {} not connected", server_id)))?;
1018 client.read_resource(uri).await
1019 }
1020
1021 pub async fn get_prompt(&self, name: &str, arguments: Option<Value>) -> Result<Value> {
1022 let server_id = self.prompt_routing.read().await.get(name).cloned()
1023 .ok_or_else(|| Error::Protocol(format!("prompt {} not found", name)))?;
1024 let client = self.servers.read().await.get(&server_id).and_then(|s| s.client.clone())
1025 .ok_or_else(|| Error::Server(format!("server {} not connected", server_id)))?;
1026 client.get_prompt(name, arguments).await
1027 }
1028
1029 pub async fn run(&mut self) -> Result<()> {
1030 let addr = format!("{}:{}", self.config.listen, self.config.port);
1031 tracing::info!("ZAP gateway starting on {}", addr);
1032
1033 for server_config in self.config.servers.clone() {
1034 let name = server_config.name.clone();
1035 match self.add_server(&name, server_config).await {
1036 Ok(id) => tracing::info!("Added server {} with id {}", name, id),
1037 Err(e) => tracing::error!("Failed to add server {}: {}", name, e),
1038 }
1039 }
1040
1041 let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
1042 self.shutdown_tx = Some(shutdown_tx);
1043
1044 let servers = self.servers.clone();
1045 let tool_routing = self.tool_routing.clone();
1046 let resource_routing = self.resource_routing.clone();
1047 let prompt_routing = self.prompt_routing.clone();
1048
1049 let health_task = tokio::spawn(async move {
1050 let mut check_interval = interval(Duration::from_secs(30));
1051 loop {
1052 check_interval.tick().await;
1053 let server_ids: Vec<String> = servers.read().await.keys().cloned().collect();
1054 for server_id in server_ids {
1055 let (needs_reconnect, client) = {
1056 let servers = servers.read().await;
1057 if let Some(server) = servers.get(&server_id) {
1058 let needs_reconnect = match server.status {
1059 ServerStatus::Error | ServerStatus::Disconnected => true,
1060 ServerStatus::Connected => server.client.as_ref().map(|c| !c.is_connected()).unwrap_or(true),
1061 _ => false,
1062 };
1063 (needs_reconnect, server.client.clone())
1064 } else { (false, None) }
1065 };
1066
1067 if needs_reconnect {
1068 tracing::info!("Health check: reconnecting {}", server_id);
1069 { servers.write().await.get_mut(&server_id).map(|s| s.status = ServerStatus::Reconnecting); }
1070 let _ = Self::connect_server(&servers, &tool_routing, &resource_routing, &prompt_routing, &server_id).await;
1071 } else if let Some(client) = client {
1072 let _ = client.refresh_all().await;
1073 }
1074 }
1075 }
1076 });
1077
1078 tokio::select! {
1079 _ = shutdown_rx.recv() => { tracing::info!("Shutdown signal received"); }
1080 _ = tokio::signal::ctrl_c() => { tracing::info!("Ctrl+C received"); }
1081 }
1082
1083 health_task.abort();
1084 for id in self.servers.read().await.keys().cloned().collect::<Vec<_>>() { let _ = self.remove_server(&id).await; }
1085 tracing::info!("Gateway shutdown complete");
1086 Ok(())
1087 }
1088
1089 pub async fn shutdown(&self) -> Result<()> {
1090 if let Some(tx) = &self.shutdown_tx { let _ = tx.send(()).await; }
1091 Ok(())
1092 }
1093}
1094
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_json_rpc_request_serialize() {
        let req = JsonRpcRequest { jsonrpc: "2.0".to_string(), id: json!(1), method: "tools/list".to_string(), params: None };
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"jsonrpc\":\"2.0\""));
        assert!(json.contains("\"method\":\"tools/list\""));
        // `params` is `skip_serializing_if = "Option::is_none"`, so a `None` must not appear.
        assert!(!json.contains("params"));
    }

    #[test]
    fn test_json_rpc_response_deserialize() {
        let json = r#"{"jsonrpc": "2.0", "id": 1, "result": {"tools": []}}"#;
        let resp: JsonRpcResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.jsonrpc, "2.0");
        assert_eq!(resp.id, json!(1));
        assert!(resp.result.is_some());
        assert!(resp.error.is_none());
    }

    #[test]
    fn test_json_rpc_error_response_deserialize() {
        let json = r#"{"jsonrpc": "2.0", "id": 1, "error": {"code": -32601, "message": "Method not found"}}"#;
        let resp: JsonRpcResponse = serde_json::from_str(json).unwrap();
        assert!(resp.result.is_none());
        let err = resp.error.expect("error object should deserialize");
        assert_eq!(err.code, -32601);
        assert_eq!(err.message, "Method not found");
        assert!(err.data.is_none());
    }

    #[test]
    fn test_mcp_tool_deserialize() {
        let json = r#"{"name": "calculator", "description": "Perform calculations", "inputSchema": {"type": "object"}}"#;
        let tool: McpTool = serde_json::from_str(json).unwrap();
        assert_eq!(tool.name, "calculator");
        assert!(tool.title.is_none());
        assert_eq!(tool.description.as_deref(), Some("Perform calculations"));
        assert_eq!(tool.input_schema, json!({"type": "object"}));
    }

    #[tokio::test]
    async fn test_gateway_create() {
        let config = Config::default();
        let gateway = Gateway::new(config);
        assert!(gateway.list_servers().await.is_empty());
        assert!(gateway.list_tools().await.is_empty());
        assert!(gateway.server_status("missing").await.is_none());
    }

    #[tokio::test]
    async fn test_gateway_unknown_tool_and_prompt() {
        let gateway = Gateway::new(Config::default());
        let result = gateway.call_tool("missing", json!({})).await;
        assert!(matches!(result, Err(Error::ToolNotFound(_))));
        assert!(gateway.get_prompt("missing", None).await.is_err());
    }

    #[tokio::test]
    async fn test_gateway_add_remove_server() {
        let config = Config::default();
        let gateway = Gateway::new(config);
        let server_config = ServerConfig { name: "test".to_string(), url: "http://localhost:8080".to_string(),
            transport: Transport::Http, timeout: 30000, auth: None };
        let id = gateway.add_server("test", server_config).await.unwrap();
        assert!(!id.is_empty());
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(gateway.list_servers().await.len(), 1);
        assert!(gateway.server_status(&id).await.is_some());
        gateway.remove_server(&id).await.unwrap();
        assert!(gateway.list_servers().await.is_empty());
        assert!(gateway.server_status(&id).await.is_none());
        assert!(gateway.list_tools().await.is_empty());
    }

    #[test]
    fn test_server_status_display() {
        assert_eq!(ServerStatus::Connecting.to_string(), "connecting");
        assert_eq!(ServerStatus::Connected.to_string(), "connected");
    }
}