airsprotocols_mcp/integration/
client.rs1use std::collections::HashMap;
107use std::time::Duration;
108
109use serde_json::Value;
111use tracing::{debug, error, info};
112
113use crate::integration::constants::methods;
115use crate::integration::McpError;
116use crate::protocol::{
117 CallToolRequest, CallToolResponse, ClientCapabilities, ClientInfo, Content, GetPromptRequest,
118 GetPromptResponse, InitializeRequest, InitializeResponse, JsonRpcRequest, ListPromptsRequest,
119 ListPromptsResponse, ListResourcesRequest, ListResourcesResponse, ListToolsRequest,
120 ListToolsResponse, LoggingConfig, Prompt, PromptMessage, ProtocolVersion, ReadResourceRequest,
121 ReadResourceResponse, RequestId, Resource, ServerCapabilities, SetLoggingRequest,
122 SetLoggingResponse, SubscribeResourceRequest, Tool, TransportClient,
123};
124
125pub type McpResult<T> = Result<T, McpError>;
127
/// Lifecycle of the MCP handshake as tracked by [`McpClient`].
#[derive(Clone, Debug, PartialEq)]
pub enum McpSessionState {
    /// `initialize` has not been run yet, or the client has been closed.
    NotInitialized,
    /// An `initialize` handshake is currently in flight.
    Initializing,
    /// The handshake succeeded; protocol requests may be issued.
    Ready,
    /// The most recent `initialize` attempt failed.
    Failed,
}
140
141#[derive(Debug, Clone)]
143pub struct McpClientConfig {
144 pub client_info: ClientInfo,
146 pub capabilities: ClientCapabilities,
148 pub protocol_version: ProtocolVersion,
150 pub default_timeout: Duration,
152}
153
154impl Default for McpClientConfig {
155 fn default() -> Self {
156 Self {
157 client_info: ClientInfo {
158 name: "airsprotocols-mcp-client".to_string(),
159 version: env!("CARGO_PKG_VERSION").to_string(),
160 },
161 capabilities: ClientCapabilities::default(),
162 protocol_version: ProtocolVersion::current(),
163 default_timeout: Duration::from_secs(30),
164 }
165 }
166}
167
168pub struct McpClientBuilder {
170 config: McpClientConfig,
171}
172
173impl McpClientBuilder {
174 pub fn new() -> Self {
176 Self {
177 config: McpClientConfig::default(),
178 }
179 }
180
181 pub fn client_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
183 self.config.client_info = ClientInfo {
184 name: name.into(),
185 version: version.into(),
186 };
187 self
188 }
189
190 pub fn capabilities(mut self, capabilities: ClientCapabilities) -> Self {
192 self.config.capabilities = capabilities;
193 self
194 }
195
196 pub fn protocol_version(mut self, version: ProtocolVersion) -> Self {
198 self.config.protocol_version = version;
199 self
200 }
201
202 pub fn timeout(mut self, timeout: Duration) -> Self {
204 self.config.default_timeout = timeout;
205 self
206 }
207
208 pub fn build<T: TransportClient + 'static>(self, transport: T) -> McpClient<T> {
243 McpClient {
244 transport,
245 config: self.config,
246 session_state: McpSessionState::NotInitialized,
247 server_capabilities: None,
248 }
249 }
250}
251
252impl Default for McpClientBuilder {
253 fn default() -> Self {
254 Self::new()
255 }
256}
257
258pub struct McpClient<T: TransportClient> {
260 transport: T,
262 config: McpClientConfig,
264 session_state: McpSessionState,
266 server_capabilities: Option<ServerCapabilities>,
268}
269
270impl<T: TransportClient + 'static> McpClient<T> {
271 pub async fn initialize(&mut self) -> McpResult<ServerCapabilities> {
273 info!("Starting MCP client initialization");
274
275 if matches!(self.session_state, McpSessionState::Ready) {
277 return Err(McpError::AlreadyConnected);
278 }
279
280 if !self.transport.is_ready() {
282 return Err(McpError::NotConnected);
283 }
284
285 self.session_state = McpSessionState::Initializing;
286
287 let result = self.perform_initialize().await;
288
289 match &result {
290 Ok(capabilities) => {
291 self.session_state = McpSessionState::Ready;
292 self.server_capabilities = Some(capabilities.clone());
293 info!("MCP client initialization completed successfully");
294 }
295 Err(error) => {
296 self.session_state = McpSessionState::Failed;
297 error!(%error, "MCP client initialization failed");
298 }
299 }
300
301 result
302 }
303
304 async fn perform_initialize(&mut self) -> McpResult<ServerCapabilities> {
306 debug!("Sending initialize request");
307
308 let request = InitializeRequest {
309 protocol_version: self.config.protocol_version.clone(),
310 capabilities: serde_json::to_value(&self.config.capabilities)
311 .map_err(|e| McpError::custom(format!("Serialization error: {e}")))?,
312 client_info: self.config.client_info.clone(),
313 };
314
315 let json_request = JsonRpcRequest {
316 jsonrpc: "2.0".to_string(),
317 method: methods::INITIALIZE.to_string(),
318 params: Some(
319 serde_json::to_value(&request)
320 .map_err(|e| McpError::custom(format!("Serialization error: {e}")))?,
321 ),
322 id: RequestId::new_number(1),
323 };
324
325 debug!("Calling transport with initialize request");
326 let response = self
327 .transport
328 .call(json_request)
329 .await
330 .map_err(|e| McpError::custom(format!("Transport error: {e}")))?;
331
332 debug!("Received initialize response");
333
334 if let Some(error) = response.error {
336 return Err(McpError::custom(format!("JSON-RPC error: {error:?}")));
337 }
338
339 let init_response: InitializeResponse =
341 serde_json::from_value(response.result.ok_or_else(|| McpError::InvalidResponse {
342 reason: "Missing result in initialize response".to_string(),
343 })?)
344 .map_err(|e| McpError::InvalidResponse {
345 reason: format!("Invalid initialize response: {e}"),
346 })?;
347
348 debug!(
349 protocol_version = %init_response.protocol_version,
350 "Initialization successful"
351 );
352
353 let server_capabilities: ServerCapabilities =
355 serde_json::from_value(init_response.capabilities).map_err(|e| {
356 McpError::InvalidResponse {
357 reason: format!("Invalid server capabilities: {e}"),
358 }
359 })?;
360
361 Ok(server_capabilities)
362 }
363
364 pub fn session_state(&self) -> McpSessionState {
366 self.session_state.clone()
367 }
368
369 pub fn is_ready(&self) -> bool {
371 self.transport.is_ready() && matches!(self.session_state, McpSessionState::Ready)
372 }
373
374 pub fn server_capabilities(&self) -> Option<&ServerCapabilities> {
376 self.server_capabilities.as_ref()
377 }
378
379 fn ensure_initialized(&self) -> McpResult<()> {
381 if !self.is_ready() {
382 return Err(McpError::NotConnected);
383 }
384 Ok(())
385 }
386
387 pub fn supports_capability(&self, check: impl Fn(&ServerCapabilities) -> bool) -> bool {
389 if let Some(caps) = &self.server_capabilities {
390 check(caps)
391 } else {
392 false
393 }
394 }
395
396 pub async fn list_resources(&mut self) -> McpResult<Vec<Resource>> {
400 self.ensure_initialized()?;
401
402 if !self.supports_capability(|caps| caps.resources.is_some()) {
404 return Err(McpError::UnsupportedCapability {
405 capability: "resources".to_string(),
406 });
407 }
408
409 let request = ListResourcesRequest::new();
410 let response = self.call_mcp(methods::RESOURCES_LIST, &request).await?;
411
412 let list_response: ListResourcesResponse =
413 serde_json::from_value(response).map_err(|e| McpError::InvalidResponse {
414 reason: format!("Invalid list resources response: {e}"),
415 })?;
416
417 Ok(list_response.resources)
418 }
419
420 pub async fn read_resource(&mut self, uri: impl Into<String>) -> McpResult<Vec<Content>> {
422 self.ensure_initialized()?;
423 let uri = uri.into();
424
425 let request =
426 ReadResourceRequest::new(uri.clone()).map_err(|e| McpError::custom(e.to_string()))?;
427
428 let response = self.call_mcp(methods::RESOURCES_READ, &request).await?;
429
430 let read_response: ReadResourceResponse =
431 serde_json::from_value(response).map_err(|e| McpError::InvalidResponse {
432 reason: format!("Invalid read resource response: {e}"),
433 })?;
434
435 Ok(read_response.contents)
436 }
437
438 pub async fn subscribe_to_resource(&mut self, uri: impl Into<String>) -> McpResult<()> {
440 self.ensure_initialized()?;
441 let uri = uri.into();
442
443 if !self.supports_capability(|caps| {
445 caps.resources
446 .as_ref()
447 .map(|r| r.subscribe.unwrap_or(false))
448 .unwrap_or(false)
449 }) {
450 return Err(McpError::UnsupportedCapability {
451 capability: "resource subscriptions".to_string(),
452 });
453 }
454
455 let request = SubscribeResourceRequest::new(uri.clone())
456 .map_err(|e| McpError::custom(e.to_string()))?;
457
458 let _response = self
459 .call_mcp(methods::RESOURCES_SUBSCRIBE, &request)
460 .await?;
461
462 Ok(())
463 }
464
465 pub async fn list_tools(&mut self) -> McpResult<Vec<Tool>> {
469 self.ensure_initialized()?;
470
471 if !self.supports_capability(|caps| caps.tools.is_some()) {
473 return Err(McpError::UnsupportedCapability {
474 capability: "tools".to_string(),
475 });
476 }
477
478 let request = ListToolsRequest::new();
479 let response = self.call_mcp(methods::TOOLS_LIST, &request).await?;
480
481 let list_response: ListToolsResponse =
482 serde_json::from_value(response).map_err(|e| McpError::InvalidResponse {
483 reason: format!("Invalid list tools response: {e}"),
484 })?;
485
486 Ok(list_response.tools)
487 }
488
489 pub async fn call_tool(
491 &mut self,
492 name: impl Into<String>,
493 arguments: Option<Value>,
494 ) -> McpResult<Vec<Content>> {
495 self.ensure_initialized()?;
496 let name = name.into();
497
498 let request = CallToolRequest::new(name.clone(), arguments.unwrap_or(Value::Null));
499 let response = self.call_mcp(methods::TOOLS_CALL, &request).await?;
500
501 let call_response: CallToolResponse =
502 serde_json::from_value(response).map_err(|e| McpError::InvalidResponse {
503 reason: format!("Invalid call tool response: {e}"),
504 })?;
505
506 if call_response.is_error.unwrap_or(false) {
507 use crate::protocol::errors::ProtocolError;
508 return Err(McpError::Protocol(ProtocolError::invalid_message(format!(
509 "Tool '{}' returned error: {}",
510 name,
511 call_response
512 .content
513 .first()
514 .map(|c| format!("{c:?}"))
515 .unwrap_or_else(|| "Unknown error".to_string())
516 ))));
517 }
518
519 Ok(call_response.content)
520 }
521
522 pub async fn list_prompts(&mut self) -> McpResult<Vec<Prompt>> {
526 self.ensure_initialized()?;
527
528 if !self.supports_capability(|caps| caps.prompts.is_some()) {
530 return Err(McpError::UnsupportedCapability {
531 capability: "prompts".to_string(),
532 });
533 }
534
535 let request = ListPromptsRequest::new();
536 let response = self.call_mcp(methods::PROMPTS_LIST, &request).await?;
537
538 let list_response: ListPromptsResponse =
539 serde_json::from_value(response).map_err(|e| McpError::InvalidResponse {
540 reason: format!("Invalid list prompts response: {e}"),
541 })?;
542
543 Ok(list_response.prompts)
544 }
545
546 pub async fn get_prompt(
548 &mut self,
549 name: impl Into<String>,
550 arguments: HashMap<String, String>,
551 ) -> McpResult<Vec<PromptMessage>> {
552 self.ensure_initialized()?;
553 let name = name.into();
554
555 let request = GetPromptRequest::new(name.clone(), arguments);
556 let response = self.call_mcp(methods::PROMPTS_GET, &request).await?;
557
558 let prompt_response: GetPromptResponse =
559 serde_json::from_value(response).map_err(|e| McpError::InvalidResponse {
560 reason: format!("Invalid get prompt response: {e}"),
561 })?;
562
563 Ok(prompt_response.messages)
564 }
565
566 pub async fn set_logging_config(&mut self, config: LoggingConfig) -> McpResult<()> {
570 self.ensure_initialized()?;
571
572 if !self.supports_capability(|caps| caps.logging.is_some()) {
574 return Err(McpError::UnsupportedCapability {
575 capability: "logging".to_string(),
576 });
577 }
578
579 let request = SetLoggingRequest::new(config.level);
580 let response = self.call_mcp(methods::LOGGING_SET_LEVEL, &request).await?;
581
582 let log_response: SetLoggingResponse =
583 serde_json::from_value(response).map_err(|e| McpError::InvalidResponse {
584 reason: format!("Invalid set logging response: {e}"),
585 })?;
586
587 if !log_response.success {
588 use crate::protocol::errors::ProtocolError;
589 return Err(McpError::Protocol(ProtocolError::invalid_message(
590 "Server rejected logging configuration".to_string(),
591 )));
592 }
593
594 Ok(())
595 }
596
597 pub async fn close(&mut self) -> McpResult<()> {
629 info!("Closing MCP client");
630
631 self.transport
633 .close()
634 .await
635 .map_err(|e| McpError::custom(format!("Transport close error: {e}")))?;
636
637 self.session_state = McpSessionState::NotInitialized;
639 self.server_capabilities = None;
640
641 info!("MCP client closed successfully");
642 Ok(())
643 }
644
645 async fn call_mcp<P: serde::Serialize>(
647 &mut self,
648 method: &str,
649 params: &P,
650 ) -> McpResult<Value> {
651 debug!(method = method, "Calling MCP method");
652
653 let params_value = serde_json::to_value(params)
654 .map_err(|e| McpError::custom(format!("Serialization error: {e}")))?;
655
656 let request = JsonRpcRequest {
657 jsonrpc: "2.0".to_string(),
658 method: method.to_string(),
659 params: Some(params_value),
660 id: RequestId::new_number(42), };
662
663 let response = self
664 .transport
665 .call(request)
666 .await
667 .map_err(|e| McpError::custom(format!("Transport error: {e}")))?;
668
669 if let Some(error) = response.error {
670 return Err(McpError::custom(format!("JSON-RPC error: {error:?}")));
671 }
672
673 debug!(method = method, "MCP method call completed successfully");
674 Ok(response.result.unwrap_or(Value::Null))
675 }
676}
677
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{JsonRpcResponse, TransportError};
    use async_trait::async_trait;

    /// In-memory transport. It answers each request with a canned `result` keyed by
    /// method name, or `{}` when no canned result exists.
    struct MockTransportClient {
        ready: bool,
        responses: HashMap<String, Value>,
    }

    impl MockTransportClient {
        /// Ready transport whose only canned result is a successful `initialize`.
        /// That result advertises tools, resources (with subscribe), prompts and logging.
        fn new() -> Self {
            let initialize_result = serde_json::json!({
                "protocolVersion": "1.0.0",
                "capabilities": {
                    "tools": { "listChanged": true },
                    "resources": { "subscribe": true, "listChanged": true },
                    "prompts": { "listChanged": true },
                    "logging": {}
                },
                "serverInfo": {
                    "name": "mock-server",
                    "version": "1.0.0"
                }
            });

            Self {
                ready: true,
                responses: HashMap::from([("initialize".to_string(), initialize_result)]),
            }
        }
    }

    #[async_trait]
    impl TransportClient for MockTransportClient {
        type Error = TransportError;

        async fn call(&mut self, request: JsonRpcRequest) -> Result<JsonRpcResponse, Self::Error> {
            let payload = match self.responses.get(&request.method) {
                Some(canned) => canned.clone(),
                None => serde_json::json!({}),
            };

            Ok(JsonRpcResponse {
                jsonrpc: "2.0".to_string(),
                result: Some(payload),
                error: None,
                id: Some(request.id),
            })
        }

        fn is_ready(&self) -> bool {
            self.ready
        }

        fn transport_type(&self) -> &'static str {
            "mock"
        }

        async fn close(&mut self) -> Result<(), Self::Error> {
            self.ready = false;
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_client_creation() {
        let client = McpClientBuilder::new()
            .client_info("test-client", "1.0.0")
            .build(MockTransportClient::new());

        assert_eq!(client.session_state(), McpSessionState::NotInitialized);
        assert!(!client.is_ready());
    }

    #[tokio::test]
    async fn test_initialization() {
        let mut client = McpClientBuilder::new().build(MockTransportClient::new());

        let caps = client.initialize().await.unwrap();

        assert_eq!(client.session_state(), McpSessionState::Ready);
        assert!(client.is_ready());
        assert!(caps.tools.is_some());
        assert!(caps.resources.is_some());
    }

    #[tokio::test]
    async fn test_double_initialization() {
        let mut client = McpClientBuilder::new().build(MockTransportClient::new());

        client.initialize().await.unwrap();

        let second_attempt = client.initialize().await;
        assert!(matches!(second_attempt, Err(McpError::AlreadyConnected)));
    }

    #[tokio::test]
    async fn test_client_close() {
        let mut client = McpClientBuilder::new().build(MockTransportClient::new());

        client.initialize().await.unwrap();
        assert!(client.is_ready());

        client.close().await.unwrap();
        assert_eq!(client.session_state(), McpSessionState::NotInitialized);
        assert!(!client.is_ready());
    }
}