1use std::sync::Arc;
11
12use rmcp::model::{
13 CallToolRequestParams, CallToolResult, ClientCapabilities, ClientInfo,
14 CreateMessageRequestParams, CreateMessageResult, ErrorData, GetPromptRequestParams,
15 GetPromptResult, Implementation, ReadResourceRequestParams, ReadResourceResult, Role,
16 SamplingCapability, SamplingMessage, SamplingMessageContent, ServerInfo, Tool,
17};
18use rmcp::service::{RequestContext, RoleClient, RunningService};
19use rmcp::transport::TokioChildProcess;
20use rmcp::{ClientHandler, ServiceExt};
21use tl_errors::security::SecurityPolicy;
22
23use crate::error::McpError;
24
/// Upper bound on completing the MCP handshake (`serve`) in the `connect*` constructors.
const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::new(30, 0);

/// Upper bound on a single tool invocation round trip.
const TOOL_CALL_TIMEOUT: std::time::Duration = std::time::Duration::new(60, 0);

/// Upper bound on lightweight requests: listing tools/resources/prompts,
/// reading a resource, fetching a prompt, and ping.
const METADATA_TIMEOUT: std::time::Duration = std::time::Duration::new(10, 0);
37
/// A server-initiated sampling (`sampling/createMessage`) request, converted
/// into plain Rust data for the host's sampling callback.
#[derive(Clone, Debug)]
pub struct SamplingRequest {
    /// Conversation as `(role, text)` pairs. `TlClientHandler` fills the role
    /// with `"user"` or `"assistant"` and the text with the concatenation of
    /// the message's text parts.
    pub messages: Vec<(String, String)>,
    /// Optional system prompt supplied by the server.
    pub system_prompt: Option<String>,
    /// Token budget requested by the server.
    pub max_tokens: u32,
    /// Sampling temperature, if the server specified one.
    pub temperature: Option<f64>,
    /// Name of the first model hint in the server's model preferences, if any.
    pub model_hint: Option<String>,
    /// Stop sequences requested by the server, if any.
    pub stop_sequences: Option<Vec<String>>,
}
60
/// The host's answer to a [`SamplingRequest`], turned into an assistant
/// message for the server.
#[derive(Clone, Debug)]
pub struct SamplingResponse {
    /// Name of the model that produced `content`.
    pub model: String,
    /// Generated text returned to the server.
    pub content: String,
    /// Optional stop reason forwarded to the server verbatim.
    pub stop_reason: Option<String>,
}
71
72pub type SamplingCallback =
78 Arc<dyn Fn(SamplingRequest) -> Result<SamplingResponse, String> + Send + Sync>;
79
80pub struct TlClientHandler {
90 pub(crate) sampling_callback: Option<SamplingCallback>,
92}
93
94impl std::fmt::Debug for TlClientHandler {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 f.debug_struct("TlClientHandler")
97 .field("has_sampling", &self.sampling_callback.is_some())
98 .finish()
99 }
100}
101
102impl TlClientHandler {
103 pub fn new() -> Self {
105 Self {
106 sampling_callback: None,
107 }
108 }
109
110 pub fn with_sampling(mut self, cb: SamplingCallback) -> Self {
112 self.sampling_callback = Some(cb);
113 self
114 }
115}
116
117impl Default for TlClientHandler {
118 fn default() -> Self {
119 Self::new()
120 }
121}
122
123impl ClientHandler for TlClientHandler {
124 fn get_info(&self) -> ClientInfo {
125 let mut caps = ClientCapabilities::default();
126 if self.sampling_callback.is_some() {
127 caps.sampling = Some(SamplingCapability::default());
128 }
129 ClientInfo::new(
130 caps,
131 Implementation::new("tl", env!("CARGO_PKG_VERSION"))
132 .with_title("ThinkingLanguage MCP Client"),
133 )
134 }
135
136 fn create_message(
137 &self,
138 params: CreateMessageRequestParams,
139 _context: RequestContext<RoleClient>,
140 ) -> impl Future<Output = Result<CreateMessageResult, ErrorData>> + Send + '_ {
141 let result = match &self.sampling_callback {
142 Some(cb) => {
143 let messages: Vec<(String, String)> = params
145 .messages
146 .iter()
147 .map(|m| {
148 let role = match m.role {
149 Role::User => "user".to_string(),
150 Role::Assistant => "assistant".to_string(),
151 };
152 let content: String = m
154 .content
155 .iter()
156 .filter_map(|c| c.as_text().map(|t| t.text.as_str()))
157 .collect::<Vec<_>>()
158 .join("");
159 (role, content)
160 })
161 .collect();
162
163 let model_hint = params
165 .model_preferences
166 .as_ref()
167 .and_then(|p| p.hints.as_ref())
168 .and_then(|h| h.first())
169 .and_then(|h| h.name.clone());
170
171 let req = SamplingRequest {
172 messages,
173 system_prompt: params.system_prompt.clone(),
174 max_tokens: params.max_tokens,
175 temperature: params.temperature.map(|t| t as f64),
176 model_hint,
177 stop_sequences: params.stop_sequences.clone(),
178 };
179
180 match cb(req) {
181 Ok(resp) => {
182 let mut result = CreateMessageResult::new(
183 SamplingMessage::new(
184 Role::Assistant,
185 SamplingMessageContent::text(resp.content),
186 ),
187 resp.model,
188 );
189 if let Some(reason) = resp.stop_reason {
190 result = result.with_stop_reason(reason);
191 }
192 Ok(result)
193 }
194 Err(e) => Err(ErrorData::internal_error(e, None)),
195 }
196 }
197 None => Err(ErrorData::method_not_found::<
198 rmcp::model::CreateMessageRequestMethod,
199 >()),
200 };
201 std::future::ready(result)
202 }
203}
204
205pub struct McpClient {
221 runtime: Arc<tokio::runtime::Runtime>,
223 service: Option<RunningService<RoleClient, TlClientHandler>>,
225 server_info: Option<ServerInfo>,
227}
228
229impl std::fmt::Debug for McpClient {
230 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231 f.debug_struct("McpClient")
232 .field("connected", &self.is_connected())
233 .field("server_info", &self.server_info)
234 .finish()
235 }
236}
237
238impl McpClient {
239 pub fn connect(
257 command: &str,
258 args: &[String],
259 security_policy: Option<&SecurityPolicy>,
260 ) -> Result<Self, McpError> {
261 Self::connect_with_sampling(command, args, security_policy, None)
262 }
263
264 pub fn connect_with_sampling(
279 command: &str,
280 args: &[String],
281 security_policy: Option<&SecurityPolicy>,
282 sampling_cb: Option<SamplingCallback>,
283 ) -> Result<Self, McpError> {
284 if let Some(policy) = security_policy
286 && !policy.check_command(command)
287 {
288 return Err(McpError::PermissionDenied(format!(
289 "Command '{}' is not allowed by security policy",
290 command
291 )));
292 }
293
294 let runtime = tokio::runtime::Builder::new_multi_thread()
296 .enable_all()
297 .build()
298 .map_err(|e| McpError::RuntimeError(e.to_string()))?;
299 let runtime = Arc::new(runtime);
300
301 let handler = match sampling_cb {
303 Some(cb) => TlClientHandler::new().with_sampling(cb),
304 None => TlClientHandler::new(),
305 };
306
307 let (service, server_info) = runtime.block_on(async {
309 let mut cmd = tokio::process::Command::new(command);
311 cmd.args(args);
312
313 let transport = TokioChildProcess::new(cmd).map_err(|e| {
315 McpError::ConnectionFailed(format!("Failed to spawn '{}': {}", command, e))
316 })?;
317
318 match tokio::time::timeout(CONNECT_TIMEOUT, handler.serve(transport)).await {
320 Ok(Ok(service)) => {
321 let server_info = service.peer().peer_info().cloned();
322 Ok::<_, McpError>((service, server_info))
323 }
324 Ok(Err(e)) => Err(McpError::ConnectionFailed(format!(
325 "Handshake failed: {}",
326 e
327 ))),
328 Err(_) => Err(McpError::Timeout),
329 }
330 })?;
331
332 Ok(McpClient {
333 runtime,
334 service: Some(service),
335 server_info,
336 })
337 }
338
339 pub fn connect_with_runtime(
344 command: &str,
345 args: &[String],
346 security_policy: Option<&SecurityPolicy>,
347 runtime: Arc<tokio::runtime::Runtime>,
348 ) -> Result<Self, McpError> {
349 Self::connect_with_runtime_and_sampling(command, args, security_policy, runtime, None)
350 }
351
352 pub fn connect_with_runtime_and_sampling(
354 command: &str,
355 args: &[String],
356 security_policy: Option<&SecurityPolicy>,
357 runtime: Arc<tokio::runtime::Runtime>,
358 sampling_cb: Option<SamplingCallback>,
359 ) -> Result<Self, McpError> {
360 if let Some(policy) = security_policy
362 && !policy.check_command(command)
363 {
364 return Err(McpError::PermissionDenied(format!(
365 "Command '{}' is not allowed by security policy",
366 command
367 )));
368 }
369
370 let handler = match sampling_cb {
372 Some(cb) => TlClientHandler::new().with_sampling(cb),
373 None => TlClientHandler::new(),
374 };
375
376 let (service, server_info) = runtime.block_on(async {
378 let mut cmd = tokio::process::Command::new(command);
379 cmd.args(args);
380
381 let transport = TokioChildProcess::new(cmd).map_err(|e| {
382 McpError::ConnectionFailed(format!("Failed to spawn '{}': {}", command, e))
383 })?;
384
385 match tokio::time::timeout(CONNECT_TIMEOUT, handler.serve(transport)).await {
387 Ok(Ok(service)) => {
388 let server_info = service.peer().peer_info().cloned();
389 Ok::<_, McpError>((service, server_info))
390 }
391 Ok(Err(e)) => Err(McpError::ConnectionFailed(format!(
392 "Handshake failed: {}",
393 e
394 ))),
395 Err(_) => Err(McpError::Timeout),
396 }
397 })?;
398
399 Ok(McpClient {
400 runtime,
401 service: Some(service),
402 server_info,
403 })
404 }
405
406 pub fn connect_http(url: &str) -> Result<Self, McpError> {
418 Self::connect_http_with_sampling(url, None)
419 }
420
421 pub fn connect_http_with_sampling(
423 url: &str,
424 sampling_cb: Option<SamplingCallback>,
425 ) -> Result<Self, McpError> {
426 let rt = Arc::new(
427 tokio::runtime::Builder::new_multi_thread()
428 .enable_all()
429 .build()
430 .map_err(|e| McpError::RuntimeError(format!("Failed to create runtime: {e}")))?,
431 );
432 Self::connect_http_with_runtime_and_sampling(url, rt, sampling_cb)
433 }
434
435 pub fn connect_http_with_runtime(
441 url: &str,
442 runtime: Arc<tokio::runtime::Runtime>,
443 ) -> Result<Self, McpError> {
444 Self::connect_http_with_runtime_and_sampling(url, runtime, None)
445 }
446
447 pub fn connect_http_with_runtime_and_sampling(
449 url: &str,
450 runtime: Arc<tokio::runtime::Runtime>,
451 sampling_cb: Option<SamplingCallback>,
452 ) -> Result<Self, McpError> {
453 let url_str = url.to_string();
454 let handler = match sampling_cb {
455 Some(cb) => TlClientHandler::new().with_sampling(cb),
456 None => TlClientHandler::new(),
457 };
458 let (service, server_info) = runtime.block_on(async {
459 use rmcp::transport::StreamableHttpClientTransport;
460
461 let transport = StreamableHttpClientTransport::from_uri(url_str);
462 match tokio::time::timeout(CONNECT_TIMEOUT, handler.serve(transport)).await {
463 Ok(Ok(service)) => {
464 let info = service.peer_info().cloned();
465 Ok::<_, McpError>((service, info))
466 }
467 Ok(Err(e)) => Err(McpError::ConnectionFailed(format!(
468 "HTTP connect failed: {e}"
469 ))),
470 Err(_) => Err(McpError::Timeout),
471 }
472 })?;
473
474 Ok(McpClient {
475 runtime,
476 service: Some(service),
477 server_info,
478 })
479 }
480
481 pub fn list_tools(&self) -> Result<Vec<Tool>, McpError> {
490 let service = self.service.as_ref().ok_or(McpError::TransportClosed)?;
491 self.runtime.block_on(async {
492 match tokio::time::timeout(METADATA_TIMEOUT, service.peer().list_all_tools()).await {
493 Ok(Ok(tools)) => Ok(tools),
494 Ok(Err(e)) => Err(McpError::ProtocolError(e.to_string())),
495 Err(_) => Err(McpError::Timeout),
496 }
497 })
498 }
499
500 pub fn call_tool(
511 &self,
512 name: &str,
513 arguments: serde_json::Value,
514 ) -> Result<CallToolResult, McpError> {
515 let service = self.service.as_ref().ok_or(McpError::TransportClosed)?;
516
517 let args_map = match arguments {
519 serde_json::Value::Object(map) => Some(map),
520 serde_json::Value::Null => None,
521 other => {
522 return Err(McpError::ProtocolError(format!(
523 "Tool arguments must be a JSON object, got: {}",
524 other
525 )));
526 }
527 };
528
529 let mut params = CallToolRequestParams::new(name.to_string());
530 if let Some(map) = args_map {
531 params = params.with_arguments(map);
532 }
533
534 let result = self.runtime.block_on(async {
535 match tokio::time::timeout(TOOL_CALL_TIMEOUT, service.peer().call_tool(params)).await {
536 Ok(Ok(r)) => Ok(r),
537 Ok(Err(e)) => Err(McpError::ProtocolError(e.to_string())),
538 Err(_) => Err(McpError::Timeout),
539 }
540 })?;
541
542 if result.is_error == Some(true) {
544 let error_text: String = result
546 .content
547 .iter()
548 .filter_map(|c| c.raw.as_text().map(|t| t.text.as_str()))
549 .collect::<Vec<_>>()
550 .join("\n");
551 return Err(McpError::ToolError(if error_text.is_empty() {
552 "Tool returned an error".to_string()
553 } else {
554 error_text
555 }));
556 }
557
558 Ok(result)
559 }
560
561 pub fn ping(&self) -> Result<(), McpError> {
566 let service = self.service.as_ref().ok_or(McpError::TransportClosed)?;
567 self.runtime.block_on(async {
568 let ping_fut = service
569 .peer()
570 .send_request(rmcp::model::ClientRequest::PingRequest(
571 rmcp::model::PingRequest {
572 method: Default::default(),
573 extensions: Default::default(),
574 },
575 ));
576 match tokio::time::timeout(METADATA_TIMEOUT, ping_fut).await {
577 Ok(Ok(_)) => Ok(()),
578 Ok(Err(e)) => Err(McpError::ProtocolError(e.to_string())),
579 Err(_) => Err(McpError::Timeout),
580 }
581 })
582 }
583
584 pub fn list_resources(&self) -> Result<Vec<rmcp::model::Resource>, McpError> {
589 let service = self.service.as_ref().ok_or(McpError::TransportClosed)?;
590 self.runtime.block_on(async {
591 match tokio::time::timeout(METADATA_TIMEOUT, service.peer().list_all_resources()).await
592 {
593 Ok(Ok(resources)) => Ok(resources),
594 Ok(Err(e)) => Err(McpError::ProtocolError(e.to_string())),
595 Err(_) => Err(McpError::Timeout),
596 }
597 })
598 }
599
600 pub fn read_resource(&self, uri: &str) -> Result<ReadResourceResult, McpError> {
605 let service = self.service.as_ref().ok_or(McpError::TransportClosed)?;
606 let params = ReadResourceRequestParams::new(uri);
607 self.runtime.block_on(async {
608 match tokio::time::timeout(METADATA_TIMEOUT, service.peer().read_resource(params)).await
609 {
610 Ok(Ok(result)) => Ok(result),
611 Ok(Err(e)) => Err(McpError::ProtocolError(e.to_string())),
612 Err(_) => Err(McpError::Timeout),
613 }
614 })
615 }
616
617 pub fn list_prompts(&self) -> Result<Vec<rmcp::model::Prompt>, McpError> {
622 let service = self.service.as_ref().ok_or(McpError::TransportClosed)?;
623 self.runtime.block_on(async {
624 match tokio::time::timeout(METADATA_TIMEOUT, service.peer().list_all_prompts()).await {
625 Ok(Ok(prompts)) => Ok(prompts),
626 Ok(Err(e)) => Err(McpError::ProtocolError(e.to_string())),
627 Err(_) => Err(McpError::Timeout),
628 }
629 })
630 }
631
632 pub fn get_prompt(
637 &self,
638 name: &str,
639 arguments: Option<serde_json::Map<String, serde_json::Value>>,
640 ) -> Result<GetPromptResult, McpError> {
641 let service = self.service.as_ref().ok_or(McpError::TransportClosed)?;
642 let mut params = GetPromptRequestParams::new(name);
643 if let Some(args) = arguments {
644 params.arguments = Some(args);
645 }
646 self.runtime.block_on(async {
647 match tokio::time::timeout(METADATA_TIMEOUT, service.peer().get_prompt(params)).await {
648 Ok(Ok(result)) => Ok(result),
649 Ok(Err(e)) => Err(McpError::ProtocolError(e.to_string())),
650 Err(_) => Err(McpError::Timeout),
651 }
652 })
653 }
654
655 pub fn server_info(&self) -> Option<&ServerInfo> {
659 self.server_info.as_ref()
660 }
661
662 pub fn disconnect(&mut self) -> Result<(), McpError> {
667 if let Some(service) = self.service.take() {
668 self.runtime.block_on(async {
669 let _ = service.cancel().await;
671 });
672 }
673 Ok(())
674 }
675
676 pub fn is_connected(&self) -> bool {
678 self.service
679 .as_ref()
680 .map(|s| !s.is_closed())
681 .unwrap_or(false)
682 }
683}
684
685impl Drop for McpClient {
686 fn drop(&mut self) {
687 if let Some(service) = self.service.take() {
691 let rt = self.runtime.clone();
694 std::thread::spawn(move || {
696 rt.block_on(async {
697 let _ = service.cancel().await;
698 });
699 });
700 }
701 }
702}
703
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that a connection attempt was rejected by the security policy.
    fn expect_permission_denied(result: Result<McpClient, McpError>) {
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, McpError::PermissionDenied(_)));
    }

    /// Asserts that a connection attempt passed the policy check but failed
    /// to connect (spawn or handshake).
    fn expect_connection_failed(result: Result<McpClient, McpError>) {
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, McpError::ConnectionFailed(_)),
            "Expected ConnectionFailed, got: {:?}",
            err
        );
    }

    #[test]
    fn test_mcp_error_display() {
        let cases = [
            (
                McpError::PermissionDenied("npx not allowed".to_string()),
                "Permission denied: npx not allowed",
            ),
            (
                McpError::ConnectionFailed("spawn failed".to_string()),
                "Connection failed: spawn failed",
            ),
            (
                McpError::ProtocolError("invalid response".to_string()),
                "Protocol error: invalid response",
            ),
            (
                McpError::ToolError("division by zero".to_string()),
                "Tool error: division by zero",
            ),
            (McpError::TransportClosed, "Transport closed"),
            (McpError::Timeout, "Timeout"),
            (
                McpError::RuntimeError("thread pool exhausted".to_string()),
                "Runtime error: thread pool exhausted",
            ),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }

    #[test]
    fn test_client_handler_info_no_sampling() {
        let info = TlClientHandler::new().get_info();
        let client = &info.client_info;

        assert_eq!(client.name, "tl");
        assert_eq!(client.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(
            client.title,
            Some("ThinkingLanguage MCP Client".to_string())
        );
        assert!(info.capabilities.sampling.is_none());
    }

    #[test]
    fn test_client_handler_info_with_sampling() {
        let reply: SamplingCallback = Arc::new(|_req| {
            Ok(SamplingResponse {
                model: "test".to_string(),
                content: "hello".to_string(),
                stop_reason: None,
            })
        });
        let info = TlClientHandler::new().with_sampling(reply).get_info();

        assert_eq!(info.client_info.name, "tl");
        assert!(info.capabilities.sampling.is_some());
    }

    #[test]
    fn test_sampling_callback_construction() {
        let echo: SamplingCallback = Arc::new(|req| {
            let last = req
                .messages
                .last()
                .map(|(_, text)| text.as_str())
                .unwrap_or("");
            Ok(SamplingResponse {
                model: "test-model".to_string(),
                content: format!("Echo: {}", last),
                stop_reason: Some("endTurn".to_string()),
            })
        });
        let handler = TlClientHandler::new().with_sampling(echo);
        assert!(handler.sampling_callback.is_some());
    }

    #[test]
    fn test_no_sampling_callback() {
        assert!(TlClientHandler::new().sampling_callback.is_none());
    }

    #[test]
    fn test_security_policy_denies_command() {
        // Sandbox defaults reject subprocesses outright.
        let mut policy = SecurityPolicy::sandbox();
        expect_permission_denied(McpClient::connect("npx", &[], Some(&policy)));

        // Subprocesses allowed, but `npx` is not on the allow-list.
        policy.allow_subprocess = true;
        policy.allowed_commands = vec!["node".to_string()];
        expect_permission_denied(McpClient::connect("npx", &[], Some(&policy)));
    }

    #[test]
    fn test_security_policy_allows_command() {
        let mut policy = SecurityPolicy::sandbox();
        policy.allow_subprocess = true;
        policy.allowed_commands = vec!["echo".to_string()];

        // `echo` passes the policy but is not an MCP server, so the handshake fails.
        expect_connection_failed(McpClient::connect(
            "echo",
            &["hello".to_string()],
            Some(&policy),
        ));
    }

    #[test]
    fn test_no_security_policy_allows_anything() {
        expect_connection_failed(McpClient::connect("__nonexistent_mcp_server__", &[], None));
    }

    #[test]
    fn test_permissive_policy_allows_anything() {
        let policy = SecurityPolicy::permissive();
        expect_connection_failed(McpClient::connect(
            "__nonexistent_mcp_server__",
            &[],
            Some(&policy),
        ));
    }
}