bamboo_engine/mcp/protocol/
client.rs1use async_trait::async_trait;
2use serde_json::Value;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::Arc;
5use tokio::sync::{mpsc, oneshot, RwLock};
6use tracing::{error, trace, warn};
7
8use crate::mcp::error::{McpError, Result};
9use crate::mcp::protocol::models::*;
10use crate::mcp::types::{McpCallResult, McpTool};
11
12#[async_trait]
14pub trait McpTransport: Send + Sync {
15 async fn connect(&mut self) -> Result<()>;
16 async fn disconnect(&mut self) -> Result<()>;
17 async fn send(&self, message: String) -> Result<()>;
18 async fn receive(&self) -> Result<Option<String>>;
19 fn is_connected(&self) -> bool;
20}
21
22struct PendingRequest {
24 sender: oneshot::Sender<Result<JsonRpcResponse>>,
25}
26
27pub struct McpProtocolClient {
29 transport: Arc<RwLock<Box<dyn McpTransport>>>,
30 next_id: AtomicU64,
31 pending_requests: Arc<RwLock<std::collections::HashMap<u64, PendingRequest>>>,
32 message_handler: Option<tokio::task::JoinHandle<()>>,
33 notification_tx: mpsc::Sender<JsonRpcNotification>,
34 notification_rx: Arc<RwLock<mpsc::Receiver<JsonRpcNotification>>>,
35}
36
37impl McpProtocolClient {
38 pub fn new(transport: Box<dyn McpTransport>) -> Self {
39 let (notification_tx, notification_rx) = mpsc::channel(100);
40 Self {
41 transport: Arc::new(RwLock::new(transport)),
42 next_id: AtomicU64::new(1),
43 pending_requests: Arc::new(RwLock::new(std::collections::HashMap::new())),
44 message_handler: None,
45 notification_tx,
46 notification_rx: Arc::new(RwLock::new(notification_rx)),
47 }
48 }
49
50 pub async fn connect(&mut self) -> Result<()> {
51 let mut transport = self.transport.write().await;
52 transport.connect().await?;
53 drop(transport);
54
55 self.start_message_handler();
57
58 Ok(())
59 }
60
61 pub async fn disconnect(&mut self) -> Result<()> {
62 if let Some(handler) = self.message_handler.take() {
63 handler.abort();
64 }
65
66 let mut transport = self.transport.write().await;
67 transport.disconnect().await
68 }
69
70 fn start_message_handler(&mut self) {
71 let transport = self.transport.clone();
72 let pending_requests = self.pending_requests.clone();
73 let notification_tx = self.notification_tx.clone();
74
75 let handler = tokio::spawn(async move {
76 loop {
77 let transport = transport.read().await;
78 if !transport.is_connected() {
79 break;
80 }
81
82 match transport.receive().await {
83 Ok(Some(message)) => {
84 trace!("Received message (bytes={})", message.len());
86 if let Err(e) =
87 Self::handle_message(&message, &pending_requests, ¬ification_tx)
88 .await
89 {
90 warn!("Failed to handle message: {}", e);
91 }
92 }
93 Ok(None) => {
94 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
96 }
97 Err(e) => {
98 error!("Transport error: {}", e);
99 break;
100 }
101 }
102 }
103 });
104
105 self.message_handler = Some(handler);
106 }
107
108 async fn handle_message(
109 message: &str,
110 pending_requests: &RwLock<std::collections::HashMap<u64, PendingRequest>>,
111 notification_tx: &mpsc::Sender<JsonRpcNotification>,
112 ) -> Result<()> {
113 if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(message) {
115 let mut pending = pending_requests.write().await;
116 if let Some(request) = pending.remove(&response.id) {
117 trace!("MCP JSON-RPC response matched (id={})", response.id);
118 let _ = request.sender.send(Ok(response));
119 } else {
120 warn!(
123 "MCP JSON-RPC response had no pending request (id={})",
124 response.id
125 );
126 }
127 return Ok(());
128 }
129
130 if let Ok(notification) = serde_json::from_str::<JsonRpcNotification>(message) {
132 trace!(
133 "MCP JSON-RPC notification received (method={})",
134 notification.method
135 );
136 let _ = notification_tx.send(notification).await;
137 return Ok(());
138 }
139
140 Err(McpError::Protocol("Unknown message type".to_string()))
141 }
142
143 async fn send_request(
144 &self,
145 method: &str,
146 params: Option<Value>,
147 timeout_ms: u64,
148 ) -> Result<JsonRpcResponse> {
149 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
150
151 let request = JsonRpcRequest::new(id, method, params);
152 let request_json = serde_json::to_string(&request)?;
153 trace!(
154 "MCP JSON-RPC request send (id={}, method={}, timeout_ms={})",
155 id,
156 method,
157 timeout_ms
158 );
159
160 let (tx, rx) = oneshot::channel();
161 {
162 let mut pending = self.pending_requests.write().await;
163 pending.insert(id, PendingRequest { sender: tx });
164 }
165
166 let transport = self.transport.read().await;
167 if let Err(e) = transport.send(request_json).await {
168 self.pending_requests.write().await.remove(&id);
170 warn!(
171 "MCP JSON-RPC request send failed (id={}, method={}): {}",
172 id, method, e
173 );
174 return Err(e);
175 }
176 drop(transport);
177
178 match tokio::time::timeout(tokio::time::Duration::from_millis(timeout_ms), rx).await {
179 Ok(Ok(Ok(response))) => {
180 if let Some(error) = response.error {
181 Err(McpError::Protocol(format!(
182 "{}: {}",
183 error.code, error.message
184 )))
185 } else {
186 Ok(response)
187 }
188 }
189 Ok(Ok(Err(e))) => Err(e),
190 Ok(Err(_)) => Err(McpError::Disconnected),
191 Err(_) => {
192 self.pending_requests.write().await.remove(&id);
193 warn!(
194 "MCP JSON-RPC request timed out (id={}, method={}, timeout_ms={})",
195 id, method, timeout_ms
196 );
197 Err(McpError::Timeout(format!(
198 "Request {} timed out after {}ms",
199 id, timeout_ms
200 )))
201 }
202 }
203 }
204
205 pub async fn initialize(&self, timeout_ms: u64) -> Result<McpInitializeResult> {
206 let request = McpInitializeRequest::default();
207 let params = serde_json::to_value(request)?;
208
209 let response = self
210 .send_request("initialize", Some(params), timeout_ms)
211 .await?;
212
213 let result: McpInitializeResult = serde_json::from_value(
214 response
215 .result
216 .ok_or_else(|| McpError::Protocol("Missing result".to_string()))?,
217 )?;
218
219 let initialized = JsonRpcNotification {
221 jsonrpc: "2.0".to_string(),
222 method: "notifications/initialized".to_string(),
223 params: None,
224 };
225 let transport = self.transport.read().await;
226 transport.send(serde_json::to_string(&initialized)?).await?;
227
228 Ok(result)
229 }
230
231 pub async fn list_tools(&self, timeout_ms: u64) -> Result<Vec<McpTool>> {
232 let response = self.send_request("tools/list", None, timeout_ms).await?;
233
234 let result: McpToolListResult = serde_json::from_value(
235 response
236 .result
237 .ok_or_else(|| McpError::Protocol("Missing result".to_string()))?,
238 )?;
239
240 Ok(result
241 .tools
242 .into_iter()
243 .map(|t| McpTool {
244 name: t.name,
245 description: t.description,
246 parameters: t.input_schema.unwrap_or_else(|| serde_json::json!({})),
247 })
248 .collect())
249 }
250
251 pub async fn call_tool(
252 &self,
253 name: &str,
254 arguments: Value,
255 timeout_ms: u64,
256 ) -> Result<McpCallResult> {
257 let request = McpToolCallRequest {
258 name: name.to_string(),
259 arguments: Some(arguments),
260 };
261 let params = serde_json::to_value(request)?;
262
263 let response = self
264 .send_request("tools/call", Some(params), timeout_ms)
265 .await?;
266
267 let result: McpToolCallResult = serde_json::from_value(
268 response
269 .result
270 .ok_or_else(|| McpError::Protocol("Missing result".to_string()))?,
271 )?;
272
273 Ok(McpCallResult {
274 content: result.content,
275 is_error: result.is_error,
276 })
277 }
278
279 pub async fn ping(&self, timeout_ms: u64) -> Result<()> {
280 self.send_request("ping", None, timeout_ms).await?;
281 Ok(())
282 }
283
284 pub async fn try_receive_notification(&self) -> Option<JsonRpcNotification> {
285 let mut rx = self.notification_rx.write().await;
286 rx.try_recv().ok()
287 }
288
289 pub async fn is_connected(&self) -> bool {
290 let transport = self.transport.read().await;
291 transport.is_connected()
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// In-memory transport for tests: `send` records outgoing messages and `receive`
    /// pops canned incoming messages in FIFO order, returning `Ok(None)` once exhausted.
    struct MockTransport {
        connected: bool,
        messages_sent: Arc<RwLock<Vec<String>>>,
        messages_to_receive: Arc<RwLock<Vec<String>>>,
    }

    impl MockTransport {
        /// A disconnected transport with nothing queued to receive.
        fn new() -> Self {
            Self::with_queue(Vec::new())
        }

        /// A disconnected transport that delivers `message` on its first `receive`.
        fn with_response(message: String) -> Self {
            Self::with_queue(vec![message])
        }

        /// A disconnected transport that will deliver `incoming` in order.
        fn with_queue(incoming: Vec<String>) -> Self {
            Self {
                connected: false,
                messages_sent: Arc::new(RwLock::new(Vec::new())),
                messages_to_receive: Arc::new(RwLock::new(incoming)),
            }
        }
    }

    #[async_trait]
    impl McpTransport for MockTransport {
        async fn connect(&mut self) -> Result<()> {
            self.connected = true;
            Ok(())
        }

        async fn disconnect(&mut self) -> Result<()> {
            self.connected = false;
            Ok(())
        }

        async fn send(&self, message: String) -> Result<()> {
            self.messages_sent.write().await.push(message);
            Ok(())
        }

        async fn receive(&self) -> Result<Option<String>> {
            let mut queue = self.messages_to_receive.write().await;
            let next = if queue.is_empty() {
                None
            } else {
                Some(queue.remove(0))
            };
            Ok(next)
        }

        fn is_connected(&self) -> bool {
            self.connected
        }
    }

    /// Builds a not-yet-connected client over an empty mock transport.
    fn idle_client() -> McpProtocolClient {
        McpProtocolClient::new(Box::new(MockTransport::new()))
    }

    #[tokio::test]
    async fn test_client_new() {
        let client = idle_client();
        assert!(client.message_handler.is_none());
    }

    #[tokio::test]
    async fn test_client_connect() {
        let mut client = idle_client();

        assert!(client.connect().await.is_ok());
        assert!(client.message_handler.is_some());
        assert!(client.is_connected().await);
    }

    #[tokio::test]
    async fn test_client_disconnect() {
        let mut client = idle_client();

        client.connect().await.unwrap();
        assert!(client.is_connected().await);

        assert!(client.disconnect().await.is_ok());
        assert!(!client.is_connected().await);
    }

    #[tokio::test]
    async fn test_client_is_connected() {
        let mut client = idle_client();

        assert!(!client.is_connected().await);
        client.connect().await.unwrap();
        assert!(client.is_connected().await);
    }

    #[test]
    fn test_json_rpc_request_new() {
        let params = serde_json::json!({"key": "value"});
        let request = JsonRpcRequest::new(1, "test/method", Some(params));

        assert_eq!(request.jsonrpc, "2.0");
        assert_eq!(request.id, 1);
        assert_eq!(request.method, "test/method");
        assert!(request.params.is_some());
    }

    #[tokio::test]
    async fn test_send_request_timeout() {
        let client = idle_client();

        let outcome = client.send_request("test", None, 100).await;
        assert!(
            matches!(outcome, Err(McpError::Timeout(_))),
            "Expected Timeout error"
        );
    }

    #[tokio::test]
    async fn test_send_request_receives_response() {
        let canned = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id: 1,
            result: Some(serde_json::json!({"status": "ok"})),
            error: None,
        };
        let wire = serde_json::to_string(&canned).unwrap();

        let mut client = McpProtocolClient::new(Box::new(MockTransport::with_response(wire)));
        client.connect().await.unwrap();

        let reply = client
            .send_request("test/method", None, 1000)
            .await
            .unwrap();
        assert_eq!(reply.id, 1);
        assert!(reply.result.is_some());
    }

    #[test]
    fn test_pending_request() {
        let (sender, _receiver) = oneshot::channel();
        let _pending = PendingRequest { sender };

        let canned = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id: 1,
            result: Some(serde_json::json!({"status": "ok"})),
            error: None,
        };

        let (tx, rx): (oneshot::Sender<Result<JsonRpcResponse>>, _) = oneshot::channel();
        tx.send(Ok(canned)).unwrap();

        let delivered = rx.blocking_recv().unwrap().unwrap();
        assert_eq!(delivered.id, 1);
        assert!(delivered.result.is_some());
    }
}
466}