1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
use futures::stream::SplitSink;
use futures::{SinkExt, StreamExt};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, oneshot, Mutex};
use tokio::time::timeout;
use tracing::{debug, error, info, warn};
use crate::{
error::{MCPError, Result},
retry::RetryConfig,
schema::*,
transport::{Transport, TransportStream},
};
/// Reply handed to a waiting request by the message-handling task.
///
/// A JSON-RPC server answers a request either with a result or with an error
/// object; this enum carries whichever of the two arrived.
enum ResponseOrError {
    /// A successful JSON-RPC response carrying a result.
    Response(JSONRPCResponse),
    /// A JSON-RPC error reply.
    Error(JSONRPCError),
}
/// Tunable settings for an [`MCPClient`].
#[derive(Clone, Debug)]
pub struct ClientConfig {
    /// Retry settings for requests.
    ///
    /// NOTE(review): the client's retry path is still a TODO, so this value is
    /// not read anywhere in this module yet.
    pub retry: RetryConfig,
    /// How long a single request waits for its reply before failing with a timeout.
    pub request_timeout: Duration,
}
impl Default for ClientConfig {
fn default() -> Self {
Self {
retry: RetryConfig::default(),
request_timeout: Duration::from_secs(30),
}
}
}
/// MCP client: sends JSON-RPC requests over a transport and routes the server's replies.
///
/// Replies are matched to their callers through `pending_requests`. Server
/// notifications are forwarded to a channel that callers obtain with
/// `take_notification_receiver`.
pub struct MCPClient {
    /// Write half of the framed transport; `None` until a transport stream has been attached.
    transport_tx: Option<SplitSink<Box<dyn TransportStream>, JSONRPCMessage>>,
    /// One-shot reply senders for in-flight requests, keyed by request ID.
    /// Shared with the background reader task.
    pending_requests: Arc<Mutex<HashMap<String, oneshot::Sender<ResponseOrError>>>>,
    /// Sending side of the notification channel; a clone is handed to the reader task.
    notification_tx: mpsc::Sender<JSONRPCNotification>,
    /// Receiving side of the notification channel, held until the caller takes it.
    notification_rx: Option<mpsc::Receiver<JSONRPCNotification>>,
    /// Counter used to mint `req-N` request IDs.
    next_request_id: Arc<Mutex<u64>>,
    /// Timeout and retry settings.
    config: ClientConfig,
}
impl MCPClient {
/// Create a new MCP client with default configuration
pub fn new() -> Self {
Self::with_config(ClientConfig::default())
}
/// Create a new MCP client with custom configuration
pub fn with_config(config: ClientConfig) -> Self {
let (notification_tx, notification_rx) = mpsc::channel(100);
Self {
transport_tx: None,
pending_requests: Arc::new(Mutex::new(HashMap::new())),
notification_tx,
notification_rx: Some(notification_rx),
next_request_id: Arc::new(Mutex::new(1)),
config,
}
}
/// Connect using the provided transport
pub async fn connect(&mut self, mut transport: Box<dyn Transport>) -> Result<()> {
transport.connect().await?;
let stream = transport.framed()?;
// Start the message handler task before storing transport
self.start_message_handler(stream).await?;
info!("MCP client connected");
Ok(())
}
/// Initialize the connection with the server
pub async fn initialize(
&mut self,
client_info: Implementation,
capabilities: ClientCapabilities,
) -> Result<InitializeResult> {
let request = ClientRequest::Initialize {
protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
capabilities,
client_info,
};
let value = self.request(request).await?;
let result: InitializeResult = serde_json::from_value(value)?;
// Send the initialized notification to complete the handshake
self.send_notification("notifications/initialized", None)
.await?;
Ok(result)
}
/// List available tools from the server
pub async fn list_tools(&mut self) -> Result<ListToolsResult> {
let value = self.request(ClientRequest::ListTools).await?;
let result: ListToolsResult = serde_json::from_value(value)?;
Ok(result)
}
/// Call a tool on the server
pub async fn call_tool(
&mut self,
name: String,
arguments: Option<serde_json::Value>,
) -> Result<CallToolResult> {
let arguments = arguments.map(|args| {
if let serde_json::Value::Object(map) = args {
map.into_iter().collect()
} else {
std::collections::HashMap::new()
}
});
let request = ClientRequest::CallTool { name, arguments };
let value = self.request_with_retry(request).await?;
let result: CallToolResult = serde_json::from_value(value)?;
Ok(result)
}
/// Take the notification receiver channel
pub fn take_notification_receiver(&mut self) -> Option<mpsc::Receiver<JSONRPCNotification>> {
self.notification_rx.take()
}
/// Send a request with retry logic
async fn request_with_retry(&mut self, request: ClientRequest) -> Result<serde_json::Value> {
// For now, we'll just do a single request without retry
// TODO: Implement proper retry logic that doesn't require mutable self in closure
self.request(request).await
}
/// Send a request and wait for response
async fn request(&mut self, request: ClientRequest) -> Result<serde_json::Value> {
let id = self.next_request_id().await;
let (tx, rx) = oneshot::channel();
// Store the response channel
{
let mut pending = self.pending_requests.lock().await;
pending.insert(id.clone(), tx);
}
// Create the JSON-RPC request
let jsonrpc_request = JSONRPCRequest {
jsonrpc: JSONRPC_VERSION.to_string(),
id: RequestId::String(id.clone()),
request: Request {
method: request.method().to_string(),
params: Some(RequestParams {
meta: None,
other: serde_json::to_value(&request)?
.as_object()
.unwrap_or(&serde_json::Map::new())
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect(),
}),
},
};
self.send_message(JSONRPCMessage::Request(jsonrpc_request))
.await?;
// Wait for response with timeout
match timeout(self.config.request_timeout, rx).await {
Ok(Ok(response_or_error)) => {
match response_or_error {
ResponseOrError::Response(response) => {
// Extract result from the flattened Result structure
// For now, we'll return the whole result as JSON
Ok(serde_json::to_value(response.result)?)
}
ResponseOrError::Error(error) => {
// Map JSON-RPC errors to appropriate MCPError variants
match error.error.code {
METHOD_NOT_FOUND => Err(MCPError::MethodNotFound(error.error.message)),
INVALID_PARAMS => Err(MCPError::invalid_params(
request.method(),
error.error.message,
)),
_ => Err(MCPError::Protocol(format!(
"JSON-RPC error {}: {}",
error.error.code, error.error.message
))),
}
}
}
}
Ok(Err(e)) => {
error!("Response channel closed for request {}: {}", id, e);
// Remove the pending request
self.pending_requests.lock().await.remove(&id);
Err(MCPError::Protocol("Response channel closed".to_string()))
}
Err(_) => {
// Timeout occurred
error!(
"Request {} timed out after {:?}",
id, self.config.request_timeout
);
// Remove the pending request
self.pending_requests.lock().await.remove(&id);
Err(MCPError::timeout(self.config.request_timeout, id))
}
}
}
/// Send a message through the transport
async fn send_message(&mut self, message: JSONRPCMessage) -> Result<()> {
if let Some(transport_tx) = &mut self.transport_tx {
transport_tx.send(message).await?;
Ok(())
} else {
Err(MCPError::Transport("Not connected".to_string()))
}
}
/// Send a notification to the server
async fn send_notification(
&mut self,
method: &str,
params: Option<serde_json::Value>,
) -> Result<()> {
let notification_params = params.map(|v| NotificationParams {
meta: None,
other: if let Some(obj) = v.as_object() {
obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
} else {
HashMap::new()
},
});
let notification = JSONRPCNotification {
jsonrpc: JSONRPC_VERSION.to_string(),
notification: Notification {
method: method.to_string(),
params: notification_params,
},
};
self.send_message(JSONRPCMessage::Notification(notification))
.await
}
/// Generate the next request ID
async fn next_request_id(&self) -> String {
let mut id = self.next_request_id.lock().await;
let current = *id;
*id += 1;
format!("req-{current}")
}
/// Start the background task that handles incoming messages
async fn start_message_handler(&mut self, stream: Box<dyn TransportStream>) -> Result<()> {
let pending_requests = self.pending_requests.clone();
let notification_tx = self.notification_tx.clone();
// Split the transport stream into read and write halves
let (tx, mut rx) = stream.split();
// Store the sender half for sending messages
self.transport_tx = Some(tx);
// Spawn a task to handle incoming messages
tokio::spawn(async move {
debug!("Message handler started");
while let Some(result) = rx.next().await {
match result {
Ok(message) => {
debug!("Received message: {:?}", message);
match message {
JSONRPCMessage::Response(response) => {
// Extract the ID and find the corresponding request
if let RequestId::String(id) = &response.id {
let mut pending = pending_requests.lock().await;
if let Some(tx) = pending.remove(id) {
// Send the response to the waiting request
let _ = tx.send(ResponseOrError::Response(response));
} else {
warn!("Received response for unknown request ID: {}", id);
}
}
}
JSONRPCMessage::Notification(notification) => {
// Forward notifications to the notification channel
if let Err(e) = notification_tx.send(notification).await {
error!("Failed to send notification: {}", e);
// If the receiver is dropped, we should stop
break;
}
}
JSONRPCMessage::Error(error) => {
// Handle JSON-RPC errors
if let RequestId::String(id) = &error.id {
let mut pending = pending_requests.lock().await;
if let Some(tx) = pending.remove(id) {
let _ = tx.send(ResponseOrError::Error(error));
} else {
warn!("Received error for unknown request ID: {}", id);
}
} else {
error!(
"Received error with non-string request ID: {:?}",
error.id
);
}
}
JSONRPCMessage::Request(_request) => {
// Clients typically don't receive requests from servers in MCP
warn!("Received unexpected request from server");
}
JSONRPCMessage::BatchRequest(_batch) => {
// Clients typically don't receive batch requests from servers
warn!("Received unexpected batch request from server");
}
JSONRPCMessage::BatchResponse(_batch) => {
// TODO: Handle batch responses if we implement batch requests
warn!(
"Received batch response - batch requests not yet implemented"
);
}
}
}
Err(e) => {
error!("Error receiving message: {}", e);
// On error, we should probably break the loop
break;
}
}
}
info!("Message handler stopped");
});
Ok(())
}
}
impl Default for MCPClient {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built client has no transport attached.
    #[test]
    fn test_client_creation() {
        let fresh = MCPClient::new();
        assert!(
            fresh.transport_tx.is_none(),
            "a new client must start without a transport"
        );
    }

    /// Request IDs start at `req-1` and increase by one per call.
    #[tokio::test]
    async fn test_next_request_id() {
        let client = MCPClient::new();
        for expected in ["req-1", "req-2"] {
            assert_eq!(client.next_request_id().await, expected);
        }
    }
}