1use anyhow::{anyhow, Result};
3use serde_json::json;
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::{mpsc, Mutex, RwLock};
7
8use mcp_protocol::{
9 constants::{error_codes, methods, PROTOCOL_VERSION},
10 messages::{ClientCapabilities, InitializeParams, InitializeResult, JsonRpcMessage},
11 types::{
12 completion::{CompleteRequest, CompleteResponse},
13 sampling::{CreateMessageParams, CreateMessageResult},
14 tool::{ToolCallParams, ToolCallResult, ToolsListResult},
15 ClientInfo,
16 },
17};
18
19use crate::transport::Transport;
20
/// Lifecycle phase of a `Client`.
///
/// The state only moves forward: `Created` -> `Initializing` -> `Ready` ->
/// `ShuttingDown`. It is a plain tag with no payload, so it is `Copy` and
/// supports full equality (`Eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ClientState {
    /// The client has been built; `initialize` has not been called yet.
    Created,
    /// `initialize` has been entered and the handshake has not finished.
    Initializing,
    /// The handshake completed; regular requests are accepted.
    Ready,
    /// `shutdown` has been called.
    ShuttingDown,
}
29
30struct PendingRequest {
32 response_tx: mpsc::Sender<Result<JsonRpcMessage>>,
33}
34
35pub struct ClientBuilder {
37 name: String,
38 version: String,
39 transport: Option<Box<dyn Transport>>,
40 sampling_enabled: bool,
41}
42
43impl ClientBuilder {
44 pub fn new(name: &str, version: &str) -> Self {
46 Self {
47 name: name.to_string(),
48 version: version.to_string(),
49 transport: None,
50 sampling_enabled: false,
51 }
52 }
53
54 pub fn with_sampling(mut self) -> Self {
56 self.sampling_enabled = true;
57 self
58 }
59
60 pub fn with_transport<T: Transport>(mut self, transport: T) -> Self {
62 self.transport = Some(Box::new(transport));
63 self
64 }
65
66 pub fn build(self) -> Result<Client> {
68 let transport = self
69 .transport
70 .ok_or_else(|| anyhow!("Transport is required"))?;
71
72 let capabilities = if self.sampling_enabled {
74 let mut caps = ClientCapabilities::default();
75 caps.sampling = Some(HashMap::new());
76 caps
77 } else {
78 ClientCapabilities::default()
79 };
80
81 Ok(Client {
82 name: self.name,
83 version: self.version,
84 transport,
85 sampling_enabled: self.sampling_enabled,
86 capabilities,
87 state: Arc::new(RwLock::new(ClientState::Created)),
88 next_id: Arc::new(Mutex::new(1)),
89 pending_requests: Arc::new(RwLock::new(HashMap::new())),
90 initialized_result: Arc::new(RwLock::new(None)),
91 sampling_callback: Arc::new(RwLock::new(None)),
92 })
93 }
94}
95
96pub type SamplingCallback =
98 Box<dyn Fn(CreateMessageParams) -> Result<CreateMessageResult> + Send + Sync>;
99
100pub struct Client {
102 name: String,
103 version: String,
104 transport: Box<dyn Transport>,
105 sampling_enabled: bool,
106 capabilities: ClientCapabilities,
107 state: Arc<RwLock<ClientState>>,
108 next_id: Arc<Mutex<i64>>,
109 pending_requests: Arc<RwLock<HashMap<String, PendingRequest>>>,
110 initialized_result: Arc<RwLock<Option<InitializeResult>>>,
111 sampling_callback: Arc<RwLock<Option<SamplingCallback>>>,
112}
113
114impl Client {
115 pub async fn initialize(&self) -> Result<InitializeResult> {
117 {
119 let state = self.state.read().await;
120 if *state != ClientState::Created {
121 return Err(anyhow!("Client already initialized"));
122 }
123 }
124
125 {
127 let mut state = self.state.write().await;
128 *state = ClientState::Initializing;
129 }
130
131 self.transport.start().await?;
133
134 let params = InitializeParams {
136 protocol_version: PROTOCOL_VERSION.to_string(),
137 capabilities: self.capabilities.clone(),
138 client_info: ClientInfo {
139 name: self.name.clone(),
140 version: self.version.clone(),
141 },
142 };
143
144 let id = self.next_request_id().await?;
146 let response = self
147 .send_request(methods::INITIALIZE, Some(json!(params)), id.to_string())
148 .await?;
149
150 match response {
151 JsonRpcMessage::Response { result, error, .. } => {
152 if let Some(error) = error {
153 return Err(anyhow!(
154 "Initialize error: {} (code: {})",
155 error.message,
156 error.code
157 ));
158 }
159
160 if let Some(result) = result {
161 let result: InitializeResult = serde_json::from_value(result)?;
162
163 {
165 let mut initialized = self.initialized_result.write().await;
166 *initialized = Some(result.clone());
167 }
168
169 self.transport
171 .send(JsonRpcMessage::notification(methods::INITIALIZED, None))
172 .await?;
173
174 {
176 let mut state = self.state.write().await;
177 *state = ClientState::Ready;
178 }
179
180 return Ok(result);
181 }
182
183 Err(anyhow!("Invalid initialize response"))
184 }
185 _ => Err(anyhow!("Invalid response type")),
186 }
187 }
188
189 pub async fn list_tools(&self) -> Result<ToolsListResult> {
191 {
193 let state = self.state.read().await;
194 if *state != ClientState::Ready {
195 return Err(anyhow!("Client not initialized"));
196 }
197 }
198
199 let id = self.next_request_id().await?;
201 let response = self
202 .send_request(methods::TOOLS_LIST, None, id.to_string())
203 .await?;
204
205 match response {
206 JsonRpcMessage::Response { result, error, .. } => {
207 if let Some(error) = error {
208 return Err(anyhow!(
209 "List tools error: {} (code: {})",
210 error.message,
211 error.code
212 ));
213 }
214
215 if let Some(result) = result {
216 let result: ToolsListResult = serde_json::from_value(result)?;
217 return Ok(result);
218 }
219
220 Err(anyhow!("Invalid list tools response"))
221 }
222 _ => Err(anyhow!("Invalid response type")),
223 }
224 }
225
226 pub async fn list_resource_templates(
228 &self,
229 ) -> Result<mcp_protocol::types::resource::ResourceTemplatesListResult> {
230 {
232 let state = self.state.read().await;
233 if *state != ClientState::Ready {
234 return Err(anyhow!("Client not initialized"));
235 }
236 }
237
238 let id = self.next_request_id().await?;
240 let response = self
241 .send_request(methods::RESOURCES_TEMPLATES_LIST, None, id.to_string())
242 .await?;
243
244 match response {
245 JsonRpcMessage::Response { result, error, .. } => {
246 if let Some(error) = error {
247 return Err(anyhow!(
248 "List resource templates error: {} (code: {})",
249 error.message,
250 error.code
251 ));
252 }
253
254 if let Some(result) = result {
255 let result: mcp_protocol::types::resource::ResourceTemplatesListResult =
256 serde_json::from_value(result)?;
257 return Ok(result);
258 }
259
260 Err(anyhow!("Invalid resource templates list response"))
261 }
262 _ => Err(anyhow!("Invalid response type")),
263 }
264 }
265
266 pub async fn complete(&self, request: CompleteRequest) -> Result<CompleteResponse> {
268 {
270 let state = self.state.read().await;
271 if *state != ClientState::Ready {
272 return Err(anyhow!("Client not initialized"));
273 }
274 }
275
276 let id = self.next_request_id().await?;
278 let response = self
279 .send_request("completion/complete", Some(json!(request)), id.to_string())
280 .await?;
281
282 match response {
283 JsonRpcMessage::Response { result, error, .. } => {
284 if let Some(error) = error {
285 return Err(anyhow!(
286 "Completion error: {} (code: {})",
287 error.message,
288 error.code
289 ));
290 }
291
292 if let Some(result) = result {
293 let result: CompleteResponse = serde_json::from_value(result)?;
294 return Ok(result);
295 }
296
297 Err(anyhow!("Invalid completion response"))
298 }
299 _ => Err(anyhow!("Invalid response type")),
300 }
301 }
302
303 pub async fn call_tool(
305 &self,
306 name: &str,
307 arguments: &serde_json::Value,
308 ) -> Result<ToolCallResult> {
309 {
311 let state = self.state.read().await;
312 if *state != ClientState::Ready {
313 return Err(anyhow!("Client not initialized"));
314 }
315 }
316
317 let params = ToolCallParams {
319 name: name.to_string(),
320 arguments: arguments.clone(),
321 };
322
323 let id = self.next_request_id().await?;
325 let response = self
326 .send_request(methods::TOOLS_CALL, Some(json!(params)), id.to_string())
327 .await?;
328
329 match response {
330 JsonRpcMessage::Response { result, error, .. } => {
331 if let Some(error) = error {
332 return Err(anyhow!(
333 "Tool call error: {} (code: {})",
334 error.message,
335 error.code
336 ));
337 }
338
339 if let Some(result) = result {
340 let result: ToolCallResult = serde_json::from_value(result)?;
341 return Ok(result);
342 }
343
344 Err(anyhow!("Invalid tool call response"))
345 }
346 _ => Err(anyhow!("Invalid response type")),
347 }
348 }
349
350 pub async fn shutdown(&self) -> Result<()> {
352 {
354 let state = self.state.read().await;
355 if *state != ClientState::Ready {
356 return Err(anyhow!("Client not initialized"));
357 }
358 }
359
360 {
362 let mut state = self.state.write().await;
363 *state = ClientState::ShuttingDown;
364 }
365
366 self.transport.close().await?;
368
369 Ok(())
370 }
371
372 pub async fn refresh_prompts(&self) -> Result<serde_json::Value> {
374 {
376 let state = self.state.read().await;
377 if *state != ClientState::Ready {
378 return Err(anyhow!("Client not initialized"));
379 }
380 }
381
382 let id = self.next_request_id().await?;
384 let response = self
385 .send_request(methods::PROMPTS_LIST, None, id.to_string())
386 .await?;
387
388 match response {
389 JsonRpcMessage::Response { result, error, .. } => {
390 if let Some(error) = error {
391 return Err(anyhow!(
392 "List prompts error: {} (code: {})",
393 error.message,
394 error.code
395 ));
396 }
397
398 if let Some(result) = result {
399 return Ok(result);
400 }
401
402 Err(anyhow!("Invalid list prompts response"))
403 }
404 _ => Err(anyhow!("Invalid response type")),
405 }
406 }
407
408 pub async fn next_request_id(&self) -> Result<i64> {
410 let mut id = self.next_id.lock().await;
411 let current = *id;
412 *id += 1;
413 Ok(current)
414 }
415
416 pub async fn send_request(
418 &self,
419 method: &str,
420 params: Option<serde_json::Value>,
421 id: String,
422 ) -> Result<JsonRpcMessage> {
423 let request = JsonRpcMessage::request(id.clone().into(), method, params);
425
426 let (tx, mut rx) = mpsc::channel(1);
428
429 {
431 let mut pending = self.pending_requests.write().await;
432 pending.insert(id.clone(), PendingRequest { response_tx: tx });
433 }
434
435 self.transport.send(request).await?;
437
438 match rx.recv().await {
440 Some(result) => {
441 let mut pending = self.pending_requests.write().await;
443 pending.remove(&id);
444
445 result
446 }
447 None => Err(anyhow!("Failed to receive response")),
448 }
449 }
450
451 pub async fn register_sampling_callback(&self, callback: SamplingCallback) -> Result<()> {
453 if !self.sampling_enabled {
454 return Err(anyhow!("Sampling is not enabled"));
455 }
456
457 let mut sampling_callback = self.sampling_callback.write().await;
458 *sampling_callback = Some(callback);
459
460 Ok(())
461 }
462
463 async fn handle_sampling_create_message(&self, message: JsonRpcMessage) -> Result<()> {
465 match message {
466 JsonRpcMessage::Request { id, params, .. } => {
467 if !self.sampling_enabled {
469 self.transport
471 .send(JsonRpcMessage::error(
472 id,
473 error_codes::SAMPLING_NOT_ENABLED,
474 "Sampling is not enabled",
475 None,
476 ))
477 .await?;
478 return Ok(());
479 }
480
481 let params: CreateMessageParams = match params {
483 Some(params) => match serde_json::from_value(params) {
484 Ok(params) => params,
485 Err(err) => {
486 self.transport
488 .send(JsonRpcMessage::error(
489 id,
490 error_codes::INVALID_PARAMS,
491 &format!("Invalid sampling parameters: {}", err),
492 None,
493 ))
494 .await?;
495 return Ok(());
496 }
497 },
498 None => {
499 self.transport
501 .send(JsonRpcMessage::error(
502 id,
503 error_codes::INVALID_PARAMS,
504 "Missing sampling parameters",
505 None,
506 ))
507 .await?;
508 return Ok(());
509 }
510 };
511
512 let callback_result = {
514 let callback = self.sampling_callback.read().await;
515 if callback.is_some() {
516 Ok(())
517 } else {
518 Err(anyhow!("No sampling callback registered"))
519 }
520 };
521
522 if let Err(_) = callback_result {
524 self.transport
526 .send(JsonRpcMessage::error(
527 id,
528 error_codes::SAMPLING_NO_CALLBACK,
529 "No sampling callback registered",
530 None,
531 ))
532 .await?;
533 return Ok(());
534 }
535
536 let result = {
539 let callback_guard = self.sampling_callback.read().await;
540 if let Some(callback) = &*callback_guard {
542 callback(params.clone())
543 } else {
544 Err(anyhow!("No sampling callback registered"))
546 }
547 };
548
549 match result {
550 Ok(result) => {
551 self.transport
553 .send(JsonRpcMessage::response(id, json!(result)))
554 .await?;
555 }
556 Err(err) => {
557 self.transport
559 .send(JsonRpcMessage::error(
560 id,
561 error_codes::SAMPLING_ERROR,
562 &format!("Sampling error: {}", err),
563 None,
564 ))
565 .await?;
566 }
567 }
568
569 Ok(())
570 }
571 _ => Err(anyhow!(
572 "Expected request message for sampling/createMessage"
573 )),
574 }
575 }
576
577 pub async fn handle_message(&self, message: JsonRpcMessage) -> Result<()> {
579 match message.clone() {
580 JsonRpcMessage::Response { ref id, .. } => {
581 let id = match id {
583 serde_json::Value::String(s) => s.clone(),
584 serde_json::Value::Number(n) => n.to_string(),
585 _ => return Err(anyhow!("Invalid response ID type")),
586 };
587
588 let pending = {
590 let pending = self.pending_requests.read().await;
591 match pending.get(&id) {
592 Some(req) => req.response_tx.clone(),
593 None => return Err(anyhow!("No pending request for ID: {}", id)),
594 }
595 };
596
597 if let Err(e) = pending.send(Ok(message)).await {
599 Err(anyhow!("Failed to send response: {}", e))
600 } else {
601 Ok(())
602 }
603 }
604 JsonRpcMessage::Notification { method, params, .. } => {
605 match method.as_str() {
607 methods::PROMPTS_LIST_CHANGED => {
609 tracing::debug!("Received notification: prompts list changed");
611
612 Ok(())
615 }
616 methods::RESOURCES_UPDATED => {
618 if let Some(params) = params {
620 if let Some(uri) = params.get("uri").and_then(|u| u.as_str()) {
621 tracing::debug!(
622 "Received notification: resource updated - URI: {}",
623 uri
624 );
625 }
626 }
627 Ok(())
628 }
629 _ => {
631 tracing::debug!("Unhandled notification: {}", method);
632 Ok(())
633 }
634 }
635 }
636 JsonRpcMessage::Request { method, .. } => match method.as_str() {
637 methods::SAMPLING_CREATE_MESSAGE => {
638 self.handle_sampling_create_message(message).await
639 }
640 _ => {
641 tracing::debug!("Unhandled server request: {}", method);
642 Ok(())
643 }
644 },
645 }
646 }
647}