use std::sync::Arc;
use rmcp::model::{
CallToolRequestParams, CallToolResult, ClientCapabilities, ClientInfo,
CreateMessageRequestParams, CreateMessageResult, ErrorData, GetPromptRequestParams,
GetPromptResult, Implementation, ReadResourceRequestParams, ReadResourceResult, Role,
SamplingCapability, SamplingMessage, SamplingMessageContent, ServerInfo, Tool,
};
use rmcp::service::{RequestContext, RoleClient, RunningService};
use rmcp::transport::TokioChildProcess;
use rmcp::{ClientHandler, ServiceExt};
use tl_errors::security::SecurityPolicy;
use crate::error::McpError;
/// Upper bound on the MCP initialize handshake once the transport has been created.
const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
/// Upper bound on a single tool invocation (`McpClient::call_tool`).
const TOOL_CALL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
/// Upper bound on lightweight requests: listing tools, resources and prompts, reading a
/// resource, fetching a prompt, and ping.
const METADATA_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
/// A server-initiated sampling request, flattened into plain data for the host's
/// [`SamplingCallback`].
#[derive(Debug, Clone, PartialEq)]
pub struct SamplingRequest {
    /// Conversation as `(role, text)` pairs in the order the server sent them. The role is
    /// `"user"` or `"assistant"`; the text is the concatenation of the message's text parts.
    pub messages: Vec<(String, String)>,
    /// System prompt requested by the server, if any.
    pub system_prompt: Option<String>,
    /// Maximum number of tokens the server asked the model to generate.
    pub max_tokens: u32,
    /// Sampling temperature requested by the server, if any.
    pub temperature: Option<f64>,
    /// Name of the first model hint in the server's model preferences, if any.
    pub model_hint: Option<String>,
    /// Stop sequences requested by the server, if any.
    pub stop_sequences: Option<Vec<String>>,
}
/// The host's answer to a [`SamplingRequest`], sent back to the server as a single
/// assistant text message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SamplingResponse {
    /// Name of the model that produced `content`.
    pub model: String,
    /// Generated text.
    pub content: String,
    /// Why generation stopped, if the host reports it.
    pub stop_reason: Option<String>,
}
/// Host-supplied handler for server-initiated sampling requests.
///
/// `TlClientHandler` converts each sampling request into a [`SamplingRequest`] and answers the
/// server with the returned [`SamplingResponse`]. An `Err(String)` is reported back to the
/// server as an internal error.
pub type SamplingCallback =
Arc<dyn Fn(SamplingRequest) -> Result<SamplingResponse, String> + Send + Sync>;
/// `rmcp` client handler used for every `McpClient` session.
///
/// Its only state is an optional sampling callback. When present, the client advertises the
/// sampling capability and answers server sampling requests with it.
pub struct TlClientHandler {
/// Callback for server-initiated sampling requests; `None` disables sampling.
pub(crate) sampling_callback: Option<SamplingCallback>,
}
/// Debug output exposes only whether a sampling callback is installed; the callback itself
/// is an opaque closure with no useful representation.
impl std::fmt::Debug for TlClientHandler {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_sampling = self.sampling_callback.is_some();
        let mut builder = f.debug_struct("TlClientHandler");
        builder.field("has_sampling", &has_sampling);
        builder.finish()
    }
}
impl TlClientHandler {
    /// Creates a handler with no sampling callback, so sampling is not advertised.
    #[must_use]
    pub fn new() -> Self {
        Self {
            sampling_callback: None,
        }
    }

    /// Installs `cb` as the sampling callback, replacing any previous one.
    ///
    /// Builder style: consumes and returns `self`, so the result must be used.
    #[must_use]
    pub fn with_sampling(mut self, cb: SamplingCallback) -> Self {
        self.sampling_callback = Some(cb);
        self
    }
}

impl Default for TlClientHandler {
    /// Same as [`TlClientHandler::new`]: no sampling callback.
    fn default() -> Self {
        Self::new()
    }
}
impl ClientHandler for TlClientHandler {
fn get_info(&self) -> ClientInfo {
let mut caps = ClientCapabilities::default();
if self.sampling_callback.is_some() {
caps.sampling = Some(SamplingCapability::default());
}
ClientInfo::new(
caps,
Implementation::new("tl", env!("CARGO_PKG_VERSION"))
.with_title("ThinkingLanguage MCP Client"),
)
}
fn create_message(
&self,
params: CreateMessageRequestParams,
_context: RequestContext<RoleClient>,
) -> impl Future<Output = Result<CreateMessageResult, ErrorData>> + Send + '_ {
let result = match &self.sampling_callback {
Some(cb) => {
let messages: Vec<(String, String)> = params
.messages
.iter()
.map(|m| {
let role = match m.role {
Role::User => "user".to_string(),
Role::Assistant => "assistant".to_string(),
};
let content: String = m
.content
.iter()
.filter_map(|c| c.as_text().map(|t| t.text.as_str()))
.collect::<Vec<_>>()
.join("");
(role, content)
})
.collect();
let model_hint = params
.model_preferences
.as_ref()
.and_then(|p| p.hints.as_ref())
.and_then(|h| h.first())
.and_then(|h| h.name.clone());
let req = SamplingRequest {
messages,
system_prompt: params.system_prompt.clone(),
max_tokens: params.max_tokens,
temperature: params.temperature.map(|t| t as f64),
model_hint,
stop_sequences: params.stop_sequences.clone(),
};
match cb(req) {
Ok(resp) => {
let mut result = CreateMessageResult::new(
SamplingMessage::new(
Role::Assistant,
SamplingMessageContent::text(resp.content),
),
resp.model,
);
if let Some(reason) = resp.stop_reason {
result = result.with_stop_reason(reason);
}
Ok(result)
}
Err(e) => Err(ErrorData::internal_error(e, None)),
}
}
None => Err(ErrorData::method_not_found::<
rmcp::model::CreateMessageRequestMethod,
>()),
};
std::future::ready(result)
}
}
/// Synchronous handle to a single MCP server session.
///
/// Request methods block on `runtime` to drive the underlying async `rmcp` service, so the
/// client is intended to be used from non-async code.
pub struct McpClient {
/// Runtime that drives the session; either built by the client or supplied by the caller.
runtime: Arc<tokio::runtime::Runtime>,
/// Live session; `None` once `disconnect` has run (or while the client is being dropped).
service: Option<RunningService<RoleClient, TlClientHandler>>,
/// Server metadata captured at handshake time, if the server provided it.
server_info: Option<ServerInfo>,
}
/// Debug output shows connection liveness and the handshake metadata; the runtime and
/// service handles are omitted.
impl std::fmt::Debug for McpClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let connected = self.is_connected();
        let mut out = f.debug_struct("McpClient");
        out.field("connected", &connected);
        out.field("server_info", &self.server_info);
        out.finish()
    }
}
impl McpClient {
    /// Spawns `command` with `args` as a stdio MCP server, without sampling support.
    ///
    /// Shorthand for [`McpClient::connect_with_sampling`] with no callback; see that method
    /// for errors and panics.
    pub fn connect(
        command: &str,
        args: &[String],
        security_policy: Option<&SecurityPolicy>,
    ) -> Result<Self, McpError> {
        Self::connect_with_sampling(command, args, security_policy, None)
    }

    /// Spawns `command` with `args` as a stdio MCP server and completes the MCP handshake.
    ///
    /// The session runs on a new multi-threaded Tokio runtime owned by the returned client.
    /// When `sampling_cb` is `Some`, the client advertises the sampling capability and
    /// answers server-initiated sampling requests with it.
    ///
    /// # Errors
    ///
    /// - [`McpError::PermissionDenied`] if `security_policy` rejects `command`. This is
    ///   checked before any runtime is built or process spawned.
    /// - [`McpError::RuntimeError`] if the Tokio runtime cannot be built.
    /// - [`McpError::ConnectionFailed`] if the process cannot be spawned or the handshake
    ///   fails.
    /// - [`McpError::Timeout`] if the handshake exceeds `CONNECT_TIMEOUT` (30 s).
    ///
    /// # Panics
    ///
    /// Drives the handshake with `Runtime::block_on`, which panics inside an async execution
    /// context. Call this from synchronous code only.
    pub fn connect_with_sampling(
        command: &str,
        args: &[String],
        security_policy: Option<&SecurityPolicy>,
        sampling_cb: Option<SamplingCallback>,
    ) -> Result<Self, McpError> {
        // Check the policy first so a denied command never pays for a thread pool.
        Self::ensure_command_allowed(command, security_policy)?;
        let runtime = Self::new_runtime()?;
        Self::spawn_stdio(command, args, runtime, sampling_cb)
    }

    /// Like [`McpClient::connect`], but drives the session on a caller-supplied runtime.
    pub fn connect_with_runtime(
        command: &str,
        args: &[String],
        security_policy: Option<&SecurityPolicy>,
        runtime: Arc<tokio::runtime::Runtime>,
    ) -> Result<Self, McpError> {
        Self::connect_with_runtime_and_sampling(command, args, security_policy, runtime, None)
    }

    /// Like [`McpClient::connect_with_sampling`], but drives the session on a caller-supplied
    /// runtime instead of building one, so several clients can share a thread pool.
    ///
    /// # Errors
    ///
    /// Same as [`McpClient::connect_with_sampling`], except that no runtime is built here.
    pub fn connect_with_runtime_and_sampling(
        command: &str,
        args: &[String],
        security_policy: Option<&SecurityPolicy>,
        runtime: Arc<tokio::runtime::Runtime>,
        sampling_cb: Option<SamplingCallback>,
    ) -> Result<Self, McpError> {
        Self::ensure_command_allowed(command, security_policy)?;
        Self::spawn_stdio(command, args, runtime, sampling_cb)
    }

    /// Connects to a streamable-HTTP MCP endpoint at `url` without sampling support.
    ///
    /// Shorthand for [`McpClient::connect_http_with_sampling`] with no callback.
    pub fn connect_http(url: &str) -> Result<Self, McpError> {
        Self::connect_http_with_sampling(url, None)
    }

    /// Connects to a streamable-HTTP MCP endpoint at `url`.
    ///
    /// The session runs on a new runtime owned by the returned client. No
    /// [`SecurityPolicy`] is consulted for HTTP endpoints.
    ///
    /// # Errors
    ///
    /// - [`McpError::RuntimeError`] if the Tokio runtime cannot be built.
    /// - [`McpError::ConnectionFailed`] if the connection or handshake fails.
    /// - [`McpError::Timeout`] if the handshake exceeds `CONNECT_TIMEOUT` (30 s).
    pub fn connect_http_with_sampling(
        url: &str,
        sampling_cb: Option<SamplingCallback>,
    ) -> Result<Self, McpError> {
        let runtime = Self::new_runtime()?;
        Self::connect_http_with_runtime_and_sampling(url, runtime, sampling_cb)
    }

    /// Like [`McpClient::connect_http`], but drives the session on a caller-supplied runtime.
    pub fn connect_http_with_runtime(
        url: &str,
        runtime: Arc<tokio::runtime::Runtime>,
    ) -> Result<Self, McpError> {
        Self::connect_http_with_runtime_and_sampling(url, runtime, None)
    }

    /// Like [`McpClient::connect_http_with_sampling`], but drives the session on a
    /// caller-supplied runtime.
    pub fn connect_http_with_runtime_and_sampling(
        url: &str,
        runtime: Arc<tokio::runtime::Runtime>,
        sampling_cb: Option<SamplingCallback>,
    ) -> Result<Self, McpError> {
        let uri = url.to_string();
        let handler = Self::handler_for(sampling_cb);
        let (service, server_info) = runtime.block_on(async {
            use rmcp::transport::StreamableHttpClientTransport;
            let transport = StreamableHttpClientTransport::from_uri(uri);
            Self::complete_handshake(handler.serve(transport), "HTTP connect failed").await
        })?;
        Ok(McpClient {
            runtime,
            service: Some(service),
            server_info,
        })
    }

    /// Lists every tool the server exposes, following pagination.
    ///
    /// # Errors
    ///
    /// [`McpError::TransportClosed`] after `disconnect`, [`McpError::ProtocolError`] if the
    /// request fails, and [`McpError::Timeout`] after `METADATA_TIMEOUT` (10 s).
    pub fn list_tools(&self) -> Result<Vec<Tool>, McpError> {
        let service = self.active_service()?;
        self.block_on_request(METADATA_TIMEOUT, move || service.peer().list_all_tools())
    }

    /// Invokes tool `name` with `arguments`.
    ///
    /// `arguments` must be a JSON object, or `null` for "no arguments".
    ///
    /// # Errors
    ///
    /// - [`McpError::TransportClosed`] after `disconnect`.
    /// - [`McpError::ProtocolError`] if `arguments` is neither an object nor `null`, or if the
    ///   request fails.
    /// - [`McpError::Timeout`] after `TOOL_CALL_TIMEOUT` (60 s).
    /// - [`McpError::ToolError`] when the server flags the result as an error. The message is
    ///   the result's text content joined by newlines, or a generic message if there is none.
    pub fn call_tool(
        &self,
        name: &str,
        arguments: serde_json::Value,
    ) -> Result<CallToolResult, McpError> {
        let service = self.active_service()?;
        let mut params = CallToolRequestParams::new(name.to_string());
        match arguments {
            serde_json::Value::Object(map) => params = params.with_arguments(map),
            serde_json::Value::Null => {}
            other => {
                return Err(McpError::ProtocolError(format!(
                    "Tool arguments must be a JSON object, got: {}",
                    other
                )));
            }
        }
        let result =
            self.block_on_request(TOOL_CALL_TIMEOUT, move || service.peer().call_tool(params))?;
        if result.is_error == Some(true) {
            let error_text: String = result
                .content
                .iter()
                .filter_map(|c| c.raw.as_text().map(|t| t.text.as_str()))
                .collect::<Vec<_>>()
                .join("\n");
            return Err(McpError::ToolError(if error_text.is_empty() {
                "Tool returned an error".to_string()
            } else {
                error_text
            }));
        }
        Ok(result)
    }

    /// Sends an MCP ping and waits for the reply.
    ///
    /// # Errors
    ///
    /// Same as [`McpClient::list_tools`].
    pub fn ping(&self) -> Result<(), McpError> {
        let service = self.active_service()?;
        self.block_on_request(METADATA_TIMEOUT, move || {
            service
                .peer()
                .send_request(rmcp::model::ClientRequest::PingRequest(
                    rmcp::model::PingRequest {
                        method: Default::default(),
                        extensions: Default::default(),
                    },
                ))
        })
        .map(|_| ())
    }

    /// Lists every resource the server exposes, following pagination.
    ///
    /// # Errors
    ///
    /// Same as [`McpClient::list_tools`].
    pub fn list_resources(&self) -> Result<Vec<rmcp::model::Resource>, McpError> {
        let service = self.active_service()?;
        self.block_on_request(METADATA_TIMEOUT, move || {
            service.peer().list_all_resources()
        })
    }

    /// Reads the resource identified by `uri`.
    ///
    /// # Errors
    ///
    /// Same as [`McpClient::list_tools`].
    pub fn read_resource(&self, uri: &str) -> Result<ReadResourceResult, McpError> {
        let service = self.active_service()?;
        let params = ReadResourceRequestParams::new(uri);
        self.block_on_request(METADATA_TIMEOUT, move || {
            service.peer().read_resource(params)
        })
    }

    /// Lists every prompt the server exposes, following pagination.
    ///
    /// # Errors
    ///
    /// Same as [`McpClient::list_tools`].
    pub fn list_prompts(&self) -> Result<Vec<rmcp::model::Prompt>, McpError> {
        let service = self.active_service()?;
        self.block_on_request(METADATA_TIMEOUT, move || service.peer().list_all_prompts())
    }

    /// Fetches prompt `name`, optionally filled in with `arguments`.
    ///
    /// # Errors
    ///
    /// Same as [`McpClient::list_tools`].
    pub fn get_prompt(
        &self,
        name: &str,
        arguments: Option<serde_json::Map<String, serde_json::Value>>,
    ) -> Result<GetPromptResult, McpError> {
        let service = self.active_service()?;
        let mut params = GetPromptRequestParams::new(name);
        if let Some(args) = arguments {
            params.arguments = Some(args);
        }
        self.block_on_request(METADATA_TIMEOUT, move || service.peer().get_prompt(params))
    }

    /// Server metadata reported during the handshake, if any.
    pub fn server_info(&self) -> Option<&ServerInfo> {
        self.server_info.as_ref()
    }

    /// Cancels the session and waits for the cancellation to finish.
    ///
    /// Idempotent: once disconnected, further calls do nothing. Always returns `Ok(())`;
    /// the cancellation outcome is intentionally ignored.
    ///
    /// # Panics
    ///
    /// Uses `Runtime::block_on`, which panics inside an async execution context.
    pub fn disconnect(&mut self) -> Result<(), McpError> {
        if let Some(service) = self.service.take() {
            self.runtime.block_on(async {
                let _ = service.cancel().await;
            });
        }
        Ok(())
    }

    /// Returns `true` while a session exists and its service has not closed.
    pub fn is_connected(&self) -> bool {
        self.service.as_ref().is_some_and(|s| !s.is_closed())
    }

    /// Rejects `command` when a policy is given and does not allow it.
    fn ensure_command_allowed(
        command: &str,
        security_policy: Option<&SecurityPolicy>,
    ) -> Result<(), McpError> {
        match security_policy {
            Some(policy) if !policy.check_command(command) => {
                Err(McpError::PermissionDenied(format!(
                    "Command '{}' is not allowed by security policy",
                    command
                )))
            }
            _ => Ok(()),
        }
    }

    /// Builds the multi-threaded runtime used by the constructors that own their runtime.
    fn new_runtime() -> Result<Arc<tokio::runtime::Runtime>, McpError> {
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .map(Arc::new)
            .map_err(|e| McpError::RuntimeError(format!("Failed to create runtime: {e}")))
    }

    /// Builds the client handler, installing `sampling_cb` when provided.
    fn handler_for(sampling_cb: Option<SamplingCallback>) -> TlClientHandler {
        match sampling_cb {
            Some(cb) => TlClientHandler::new().with_sampling(cb),
            None => TlClientHandler::new(),
        }
    }

    /// Spawns `command args...` on `runtime` and completes the handshake over its stdio.
    ///
    /// The security-policy check is the caller's responsibility.
    fn spawn_stdio(
        command: &str,
        args: &[String],
        runtime: Arc<tokio::runtime::Runtime>,
        sampling_cb: Option<SamplingCallback>,
    ) -> Result<Self, McpError> {
        let handler = Self::handler_for(sampling_cb);
        let (service, server_info) = runtime.block_on(async {
            // Spawn inside the runtime: the child's async pipes are driven by its reactor.
            let mut cmd = tokio::process::Command::new(command);
            cmd.args(args);
            let transport = match TokioChildProcess::new(cmd) {
                Ok(transport) => transport,
                Err(e) => {
                    return Err(McpError::ConnectionFailed(format!(
                        "Failed to spawn '{}': {}",
                        command, e
                    )));
                }
            };
            Self::complete_handshake(handler.serve(transport), "Handshake failed").await
        })?;
        Ok(McpClient {
            runtime,
            service: Some(service),
            server_info,
        })
    }

    /// Awaits the MCP handshake future `serve`, bounded by `CONNECT_TIMEOUT`, and captures
    /// the server's info.
    ///
    /// Handshake errors become [`McpError::ConnectionFailed`] with the message
    /// `"{failure_prefix}: {error}"`.
    async fn complete_handshake<E: std::fmt::Display>(
        serve: impl Future<Output = Result<RunningService<RoleClient, TlClientHandler>, E>>,
        failure_prefix: &str,
    ) -> Result<(RunningService<RoleClient, TlClientHandler>, Option<ServerInfo>), McpError> {
        match tokio::time::timeout(CONNECT_TIMEOUT, serve).await {
            Ok(Ok(service)) => {
                let server_info = service.peer().peer_info().cloned();
                Ok((service, server_info))
            }
            Ok(Err(e)) => Err(McpError::ConnectionFailed(format!("{failure_prefix}: {e}"))),
            Err(_) => Err(McpError::Timeout),
        }
    }

    /// Returns the live session, or [`McpError::TransportClosed`] after `disconnect`.
    fn active_service(&self) -> Result<&RunningService<RoleClient, TlClientHandler>, McpError> {
        self.service.as_ref().ok_or(McpError::TransportClosed)
    }

    /// Blocks on the request future built by `make_request`, bounded by `limit`.
    ///
    /// `make_request` is called inside the runtime, so the future is created in runtime
    /// context. Request errors map to [`McpError::ProtocolError`]; an elapsed deadline maps
    /// to [`McpError::Timeout`].
    fn block_on_request<T, E, F, Fut>(
        &self,
        limit: std::time::Duration,
        make_request: F,
    ) -> Result<T, McpError>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T, E>>,
        E: std::fmt::Display,
    {
        self.runtime.block_on(async move {
            match tokio::time::timeout(limit, make_request()).await {
                Ok(Ok(value)) => Ok(value),
                Ok(Err(e)) => Err(McpError::ProtocolError(e.to_string())),
                Err(_) => Err(McpError::Timeout),
            }
        })
    }
}
impl Drop for McpClient {
fn drop(&mut self) {
if let Some(service) = self.service.take() {
let rt = self.runtime.clone();
std::thread::spawn(move || {
rt.block_on(async {
let _ = service.cancel().await;
});
});
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mcp_error_display() {
        let err = McpError::PermissionDenied("npx not allowed".to_string());
        assert_eq!(err.to_string(), "Permission denied: npx not allowed");
        let err = McpError::ConnectionFailed("spawn failed".to_string());
        assert_eq!(err.to_string(), "Connection failed: spawn failed");
        let err = McpError::ProtocolError("invalid response".to_string());
        assert_eq!(err.to_string(), "Protocol error: invalid response");
        let err = McpError::ToolError("division by zero".to_string());
        assert_eq!(err.to_string(), "Tool error: division by zero");
        let err = McpError::TransportClosed;
        assert_eq!(err.to_string(), "Transport closed");
        let err = McpError::Timeout;
        assert_eq!(err.to_string(), "Timeout");
        let err = McpError::RuntimeError("thread pool exhausted".to_string());
        assert_eq!(err.to_string(), "Runtime error: thread pool exhausted");
    }

    #[test]
    fn test_client_handler_info_no_sampling() {
        let handler = TlClientHandler::new();
        let info = handler.get_info();
        assert_eq!(info.client_info.name, "tl");
        assert_eq!(info.client_info.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(
            info.client_info.title,
            Some("ThinkingLanguage MCP Client".to_string())
        );
        assert!(info.capabilities.sampling.is_none());
    }

    #[test]
    fn test_client_handler_info_with_sampling() {
        let cb: SamplingCallback = Arc::new(|_req| {
            Ok(SamplingResponse {
                model: "test".to_string(),
                content: "hello".to_string(),
                stop_reason: None,
            })
        });
        let handler = TlClientHandler::new().with_sampling(cb);
        let info = handler.get_info();
        assert_eq!(info.client_info.name, "tl");
        assert!(info.capabilities.sampling.is_some());
    }

    #[test]
    fn test_sampling_callback_construction() {
        let cb: SamplingCallback = Arc::new(|req| {
            Ok(SamplingResponse {
                model: "test-model".to_string(),
                content: format!(
                    "Echo: {}",
                    req.messages.last().map(|(_, c)| c.as_str()).unwrap_or("")
                ),
                stop_reason: Some("endTurn".to_string()),
            })
        });
        let handler = TlClientHandler::new().with_sampling(cb);
        assert!(handler.sampling_callback.is_some());
    }

    #[test]
    fn test_no_sampling_callback() {
        let handler = TlClientHandler::new();
        assert!(handler.sampling_callback.is_none());
    }

    #[test]
    fn test_default_handler_has_no_sampling() {
        let handler = TlClientHandler::default();
        assert!(handler.sampling_callback.is_none());
        assert!(handler.get_info().capabilities.sampling.is_none());
    }

    #[test]
    fn test_client_handler_debug_reports_sampling() {
        let handler = TlClientHandler::new();
        assert_eq!(
            format!("{handler:?}"),
            "TlClientHandler { has_sampling: false }"
        );
        let cb: SamplingCallback = Arc::new(|_req| {
            Ok(SamplingResponse {
                model: "test".to_string(),
                content: String::new(),
                stop_reason: None,
            })
        });
        let handler = TlClientHandler::new().with_sampling(cb);
        assert_eq!(
            format!("{handler:?}"),
            "TlClientHandler { has_sampling: true }"
        );
    }

    #[test]
    fn test_security_policy_denies_command() {
        let mut policy = SecurityPolicy::sandbox();
        let result = McpClient::connect("npx", &[], Some(&policy));
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, McpError::PermissionDenied(_)));
        policy.allow_subprocess = true;
        policy.allowed_commands = vec!["node".to_string()];
        let result = McpClient::connect("npx", &[], Some(&policy));
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, McpError::PermissionDenied(_)));
    }

    #[test]
    fn test_security_policy_denies_command_with_shared_runtime() {
        let runtime =
            Arc::new(tokio::runtime::Runtime::new().expect("failed to build test runtime"));
        let policy = SecurityPolicy::sandbox();
        let result = McpClient::connect_with_runtime("npx", &[], Some(&policy), runtime);
        assert!(matches!(result, Err(McpError::PermissionDenied(_))));
    }

    #[test]
    fn test_security_policy_allows_command() {
        let mut policy = SecurityPolicy::sandbox();
        policy.allow_subprocess = true;
        policy.allowed_commands = vec!["echo".to_string()];
        let result = McpClient::connect("echo", &["hello".to_string()], Some(&policy));
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, McpError::ConnectionFailed(_)),
            "Expected ConnectionFailed, got: {:?}",
            err
        );
    }

    #[test]
    fn test_no_security_policy_allows_anything() {
        let result = McpClient::connect("__nonexistent_mcp_server__", &[], None);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, McpError::ConnectionFailed(_)),
            "Expected ConnectionFailed, got: {:?}",
            err
        );
    }

    #[test]
    fn test_permissive_policy_allows_anything() {
        let policy = SecurityPolicy::permissive();
        let result = McpClient::connect("__nonexistent_mcp_server__", &[], Some(&policy));
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, McpError::ConnectionFailed(_)),
            "Expected ConnectionFailed, got: {:?}",
            err
        );
    }
}