// Source file: bamboo_engine/mcp/protocol/client.rs

1use async_trait::async_trait;
2use serde_json::Value;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::Arc;
5use tokio::sync::{mpsc, oneshot, RwLock};
6use tracing::{error, trace, warn};
7
8use crate::mcp::error::{McpError, Result};
9use crate::mcp::protocol::models::*;
10use crate::mcp::types::{McpCallResult, McpTool};
11
12/// Transport trait for MCP communication
13#[async_trait]
14pub trait McpTransport: Send + Sync {
15    async fn connect(&mut self) -> Result<()>;
16    async fn disconnect(&mut self) -> Result<()>;
17    async fn send(&self, message: String) -> Result<()>;
18    async fn receive(&self) -> Result<Option<String>>;
19    fn is_connected(&self) -> bool;
20}
21
22/// Pending request waiting for response
23struct PendingRequest {
24    sender: oneshot::Sender<Result<JsonRpcResponse>>,
25}
26
27/// MCP protocol client
28pub struct McpProtocolClient {
29    transport: Arc<RwLock<Box<dyn McpTransport>>>,
30    next_id: AtomicU64,
31    pending_requests: Arc<RwLock<std::collections::HashMap<u64, PendingRequest>>>,
32    message_handler: Option<tokio::task::JoinHandle<()>>,
33    notification_tx: mpsc::Sender<JsonRpcNotification>,
34    notification_rx: Arc<RwLock<mpsc::Receiver<JsonRpcNotification>>>,
35}
36
37impl McpProtocolClient {
38    pub fn new(transport: Box<dyn McpTransport>) -> Self {
39        let (notification_tx, notification_rx) = mpsc::channel(100);
40        Self {
41            transport: Arc::new(RwLock::new(transport)),
42            next_id: AtomicU64::new(1),
43            pending_requests: Arc::new(RwLock::new(std::collections::HashMap::new())),
44            message_handler: None,
45            notification_tx,
46            notification_rx: Arc::new(RwLock::new(notification_rx)),
47        }
48    }
49
50    pub async fn connect(&mut self) -> Result<()> {
51        let mut transport = self.transport.write().await;
52        transport.connect().await?;
53        drop(transport);
54
55        // Start message handler
56        self.start_message_handler();
57
58        Ok(())
59    }
60
61    pub async fn disconnect(&mut self) -> Result<()> {
62        if let Some(handler) = self.message_handler.take() {
63            handler.abort();
64        }
65
66        let mut transport = self.transport.write().await;
67        transport.disconnect().await
68    }
69
70    fn start_message_handler(&mut self) {
71        let transport = self.transport.clone();
72        let pending_requests = self.pending_requests.clone();
73        let notification_tx = self.notification_tx.clone();
74
75        let handler = tokio::spawn(async move {
76            loop {
77                let transport = transport.read().await;
78                if !transport.is_connected() {
79                    break;
80                }
81
82                match transport.receive().await {
83                    Ok(Some(message)) => {
84                        // Raw inbound wire messages can be extremely noisy and may contain secrets.
85                        trace!("Received message (bytes={})", message.len());
86                        if let Err(e) =
87                            Self::handle_message(&message, &pending_requests, &notification_tx)
88                                .await
89                        {
90                            warn!("Failed to handle message: {}", e);
91                        }
92                    }
93                    Ok(None) => {
94                        // No message available, continue
95                        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
96                    }
97                    Err(e) => {
98                        error!("Transport error: {}", e);
99                        break;
100                    }
101                }
102            }
103        });
104
105        self.message_handler = Some(handler);
106    }
107
108    async fn handle_message(
109        message: &str,
110        pending_requests: &RwLock<std::collections::HashMap<u64, PendingRequest>>,
111        notification_tx: &mpsc::Sender<JsonRpcNotification>,
112    ) -> Result<()> {
113        // Try to parse as response
114        if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(message) {
115            let mut pending = pending_requests.write().await;
116            if let Some(request) = pending.remove(&response.id) {
117                trace!("MCP JSON-RPC response matched (id={})", response.id);
118                let _ = request.sender.send(Ok(response));
119            } else {
120                // Common in transport/proxy bugs: responses arrive but the client never registered
121                // the request, or IDs got out of sync.
122                warn!(
123                    "MCP JSON-RPC response had no pending request (id={})",
124                    response.id
125                );
126            }
127            return Ok(());
128        }
129
130        // Try to parse as notification
131        if let Ok(notification) = serde_json::from_str::<JsonRpcNotification>(message) {
132            trace!(
133                "MCP JSON-RPC notification received (method={})",
134                notification.method
135            );
136            let _ = notification_tx.send(notification).await;
137            return Ok(());
138        }
139
140        Err(McpError::Protocol("Unknown message type".to_string()))
141    }
142
143    async fn send_request(
144        &self,
145        method: &str,
146        params: Option<Value>,
147        timeout_ms: u64,
148    ) -> Result<JsonRpcResponse> {
149        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
150
151        let request = JsonRpcRequest::new(id, method, params);
152        let request_json = serde_json::to_string(&request)?;
153        trace!(
154            "MCP JSON-RPC request send (id={}, method={}, timeout_ms={})",
155            id,
156            method,
157            timeout_ms
158        );
159
160        let (tx, rx) = oneshot::channel();
161        {
162            let mut pending = self.pending_requests.write().await;
163            pending.insert(id, PendingRequest { sender: tx });
164        }
165
166        let transport = self.transport.read().await;
167        if let Err(e) = transport.send(request_json).await {
168            // Avoid leaking pending requests on send failure.
169            self.pending_requests.write().await.remove(&id);
170            warn!(
171                "MCP JSON-RPC request send failed (id={}, method={}): {}",
172                id, method, e
173            );
174            return Err(e);
175        }
176        drop(transport);
177
178        match tokio::time::timeout(tokio::time::Duration::from_millis(timeout_ms), rx).await {
179            Ok(Ok(Ok(response))) => {
180                if let Some(error) = response.error {
181                    Err(McpError::Protocol(format!(
182                        "{}: {}",
183                        error.code, error.message
184                    )))
185                } else {
186                    Ok(response)
187                }
188            }
189            Ok(Ok(Err(e))) => Err(e),
190            Ok(Err(_)) => Err(McpError::Disconnected),
191            Err(_) => {
192                self.pending_requests.write().await.remove(&id);
193                warn!(
194                    "MCP JSON-RPC request timed out (id={}, method={}, timeout_ms={})",
195                    id, method, timeout_ms
196                );
197                Err(McpError::Timeout(format!(
198                    "Request {} timed out after {}ms",
199                    id, timeout_ms
200                )))
201            }
202        }
203    }
204
205    pub async fn initialize(&self, timeout_ms: u64) -> Result<McpInitializeResult> {
206        let request = McpInitializeRequest::default();
207        let params = serde_json::to_value(request)?;
208
209        let response = self
210            .send_request("initialize", Some(params), timeout_ms)
211            .await?;
212
213        let result: McpInitializeResult = serde_json::from_value(
214            response
215                .result
216                .ok_or_else(|| McpError::Protocol("Missing result".to_string()))?,
217        )?;
218
219        // Send initialized notification
220        let initialized = JsonRpcNotification {
221            jsonrpc: "2.0".to_string(),
222            method: "notifications/initialized".to_string(),
223            params: None,
224        };
225        let transport = self.transport.read().await;
226        transport.send(serde_json::to_string(&initialized)?).await?;
227
228        Ok(result)
229    }
230
231    pub async fn list_tools(&self, timeout_ms: u64) -> Result<Vec<McpTool>> {
232        let response = self.send_request("tools/list", None, timeout_ms).await?;
233
234        let result: McpToolListResult = serde_json::from_value(
235            response
236                .result
237                .ok_or_else(|| McpError::Protocol("Missing result".to_string()))?,
238        )?;
239
240        Ok(result
241            .tools
242            .into_iter()
243            .map(|t| McpTool {
244                name: t.name,
245                description: t.description,
246                parameters: t.input_schema.unwrap_or_else(|| serde_json::json!({})),
247            })
248            .collect())
249    }
250
251    pub async fn call_tool(
252        &self,
253        name: &str,
254        arguments: Value,
255        timeout_ms: u64,
256    ) -> Result<McpCallResult> {
257        let request = McpToolCallRequest {
258            name: name.to_string(),
259            arguments: Some(arguments),
260        };
261        let params = serde_json::to_value(request)?;
262
263        let response = self
264            .send_request("tools/call", Some(params), timeout_ms)
265            .await?;
266
267        let result: McpToolCallResult = serde_json::from_value(
268            response
269                .result
270                .ok_or_else(|| McpError::Protocol("Missing result".to_string()))?,
271        )?;
272
273        Ok(McpCallResult {
274            content: result.content,
275            is_error: result.is_error,
276        })
277    }
278
279    pub async fn ping(&self, timeout_ms: u64) -> Result<()> {
280        self.send_request("ping", None, timeout_ms).await?;
281        Ok(())
282    }
283
284    pub async fn try_receive_notification(&self) -> Option<JsonRpcNotification> {
285        let mut rx = self.notification_rx.write().await;
286        rx.try_recv().ok()
287    }
288
289    pub async fn is_connected(&self) -> bool {
290        let transport = self.transport.read().await;
291        transport.is_connected()
292    }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// In-memory transport: records sent messages and replays a queue of inbound ones.
    struct MockTransport {
        connected: bool,
        messages_sent: Arc<RwLock<Vec<String>>>,
        messages_to_receive: Arc<RwLock<Vec<String>>>,
    }

    impl MockTransport {
        fn new() -> Self {
            Self {
                connected: false,
                messages_sent: Arc::new(RwLock::new(Vec::new())),
                messages_to_receive: Arc::new(RwLock::new(Vec::new())),
            }
        }

        fn with_response(message: String) -> Self {
            let messages = Arc::new(RwLock::new(vec![message]));
            Self {
                connected: false,
                messages_sent: Arc::new(RwLock::new(Vec::new())),
                messages_to_receive: messages,
            }
        }
    }

    #[async_trait]
    impl McpTransport for MockTransport {
        async fn connect(&mut self) -> Result<()> {
            self.connected = true;
            Ok(())
        }

        async fn disconnect(&mut self) -> Result<()> {
            self.connected = false;
            Ok(())
        }

        async fn send(&self, message: String) -> Result<()> {
            let mut sent = self.messages_sent.write().await;
            sent.push(message);
            Ok(())
        }

        async fn receive(&self) -> Result<Option<String>> {
            let mut messages = self.messages_to_receive.write().await;
            if messages.is_empty() {
                Ok(None)
            } else {
                Ok(Some(messages.remove(0)))
            }
        }

        fn is_connected(&self) -> bool {
            self.connected
        }
    }

    #[tokio::test]
    async fn test_client_new() {
        let transport = Box::new(MockTransport::new());
        let client = McpProtocolClient::new(transport);
        assert!(client.message_handler.is_none());
    }

    #[tokio::test]
    async fn test_client_connect() {
        let transport = Box::new(MockTransport::new());
        let mut client = McpProtocolClient::new(transport);

        let result = client.connect().await;
        assert!(result.is_ok());
        assert!(client.message_handler.is_some());
        assert!(client.is_connected().await);
    }

    #[tokio::test]
    async fn test_client_disconnect() {
        let transport = Box::new(MockTransport::new());
        let mut client = McpProtocolClient::new(transport);

        client.connect().await.unwrap();
        assert!(client.is_connected().await);

        let result = client.disconnect().await;
        assert!(result.is_ok());
        assert!(!client.is_connected().await);
    }

    #[tokio::test]
    async fn test_client_is_connected() {
        let transport = Box::new(MockTransport::new());
        let mut client = McpProtocolClient::new(transport);

        assert!(!client.is_connected().await);
        client.connect().await.unwrap();
        assert!(client.is_connected().await);
    }

    #[test]
    fn test_json_rpc_request_new() {
        let request =
            JsonRpcRequest::new(1, "test/method", Some(serde_json::json!({"key": "value"})));
        assert_eq!(request.jsonrpc, "2.0");
        assert_eq!(request.id, 1);
        assert_eq!(request.method, "test/method");
        assert!(request.params.is_some());
    }

    #[tokio::test]
    async fn test_send_request_timeout() {
        let transport = Box::new(MockTransport::new()); // Won't respond
        let client = McpProtocolClient::new(transport);

        let result = client.send_request("test", None, 100).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            McpError::Timeout(_) => {}
            _ => panic!("Expected Timeout error"),
        }
    }

    #[tokio::test]
    async fn test_send_request_receives_response() {
        let response = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id: 1,
            result: Some(serde_json::json!({"status": "ok"})),
            error: None,
        };
        let message = serde_json::to_string(&response).unwrap();

        let transport = Box::new(MockTransport::with_response(message));
        let mut client = McpProtocolClient::new(transport);
        client.connect().await.unwrap();

        // The queued response is already waiting when the reader task starts. This test
        // relies on the default current-thread `#[tokio::test]` runtime: the spawned reader
        // first runs when `send_request` awaits its reply, after request id 1 is registered.
        let result = client
            .send_request("test/method", None, 1000)
            .await
            .unwrap();
        assert_eq!(result.id, 1);
        assert!(result.result.is_some());
    }

    #[test]
    fn test_pending_request() {
        let (tx, rx) = oneshot::channel();
        let pending = PendingRequest { sender: tx };

        let response = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id: 1,
            result: Some(serde_json::json!({"status": "ok"})),
            error: None,
        };

        // Deliver through the struct's own sender, as the reader task does.
        assert!(pending.sender.send(Ok(response)).is_ok());

        let result = rx.blocking_recv().unwrap().unwrap();
        assert_eq!(result.id, 1);
        assert!(result.result.is_some());
    }

    #[tokio::test]
    async fn test_handle_message_completes_pending_request() {
        let pending = RwLock::new(std::collections::HashMap::new());
        let (tx, rx) = oneshot::channel();
        pending.write().await.insert(7, PendingRequest { sender: tx });
        let (notification_tx, _notification_rx) = mpsc::channel(1);

        let response = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id: 7,
            result: Some(serde_json::json!({"ok": true})),
            error: None,
        };
        let message = serde_json::to_string(&response).unwrap();

        McpProtocolClient::handle_message(&message, &pending, &notification_tx)
            .await
            .unwrap();

        let delivered = rx.await.unwrap().unwrap();
        assert_eq!(delivered.id, 7);
        assert!(pending.read().await.is_empty());
    }

    #[tokio::test]
    async fn test_handle_message_queues_notification() {
        let pending = RwLock::new(std::collections::HashMap::new());
        let (notification_tx, mut notification_rx) = mpsc::channel(1);

        let notification = JsonRpcNotification {
            jsonrpc: "2.0".to_string(),
            method: "notifications/progress".to_string(),
            params: None,
        };
        let message = serde_json::to_string(&notification).unwrap();

        McpProtocolClient::handle_message(&message, &pending, &notification_tx)
            .await
            .unwrap();

        let queued = notification_rx.try_recv().unwrap();
        assert_eq!(queued.method, "notifications/progress");
        assert!(pending.read().await.is_empty());
    }

    #[tokio::test]
    async fn test_handle_message_rejects_unknown_message() {
        let pending = RwLock::new(std::collections::HashMap::new());
        let (notification_tx, _notification_rx) = mpsc::channel(1);

        let result =
            McpProtocolClient::handle_message("not json", &pending, &notification_tx).await;
        assert!(matches!(result, Err(McpError::Protocol(_))));
    }
}