Skip to main content

stakpak_mcp_client/
lib.rs

use anyhow::{Context, Result};
use rmcp::{
    RoleClient, ServiceExt,
    model::{CallToolRequestParam, ClientRequest, Meta, Request, Tool},
    service::{PeerRequestOptions, RequestHandle, RunningService},
    transport::StreamableHttpClientTransport,
    transport::streamable_http_client::StreamableHttpClientTransportConfig,
};
use stakpak_shared::cert_utils::CertificateChain;
use stakpak_shared::models::integrations::openai::ToolCallResultProgress;
use std::sync::Arc;
use tokio::sync::mpsc::Sender;
13
14mod local;
15
16pub use local::LocalClientHandler;
17
18pub type McpClient = RunningService<RoleClient, LocalClientHandler>;
19
20/// Connect to the MCP proxy via stdio (legacy method)
21pub async fn connect(progress_tx: Option<Sender<ToolCallResultProgress>>) -> Result<McpClient> {
22    local::connect(progress_tx).await
23}
24
25/// Connect to an MCP server via HTTPS with optional mTLS
26pub async fn connect_https(
27    url: &str,
28    certificate_chain: Option<Arc<CertificateChain>>,
29    progress_tx: Option<Sender<ToolCallResultProgress>>,
30) -> Result<McpClient> {
31    let mut client_builder = reqwest::Client::builder()
32        .pool_idle_timeout(std::time::Duration::from_secs(90))
33        .pool_max_idle_per_host(10)
34        .tcp_keepalive(std::time::Duration::from_secs(60));
35
36    // Configure TLS: use mTLS cert chain if provided, otherwise use
37    // platform-verified TLS so the OS CA store is trusted.
38    if let Some(cert_chain) = certificate_chain {
39        let tls_config = cert_chain.create_client_config()?;
40        client_builder = client_builder.use_preconfigured_tls(tls_config);
41    } else {
42        let arc_crypto_provider = std::sync::Arc::new(rustls::crypto::ring::default_provider());
43        if let Ok(tls_config) = rustls::ClientConfig::builder_with_provider(arc_crypto_provider)
44            .with_safe_default_protocol_versions()
45            .map(|builder| {
46                rustls_platform_verifier::BuilderVerifierExt::with_platform_verifier(builder)
47                    .with_no_client_auth()
48            })
49        {
50            client_builder = client_builder.use_preconfigured_tls(tls_config);
51        }
52    }
53
54    let http_client = client_builder.build()?;
55
56    let config = StreamableHttpClientTransportConfig::with_uri(url);
57    let transport =
58        StreamableHttpClientTransport::<reqwest::Client>::with_client(http_client, config);
59
60    let client_handler = LocalClientHandler::new(progress_tx);
61    let client: McpClient = client_handler.serve(transport).await?;
62
63    Ok(client)
64}
65
66/// Get all available tools from the MCP client
67pub async fn get_tools(client: &McpClient) -> Result<Vec<Tool>> {
68    let tools = client.list_tools(Default::default()).await?;
69    Ok(tools.tools)
70}
71
72/// Call a tool on the MCP client
73pub async fn call_tool(
74    client: &McpClient,
75    params: CallToolRequestParam,
76    metadata: Option<serde_json::Map<String, serde_json::Value>>,
77) -> Result<RequestHandle<RoleClient>, String> {
78    let options = PeerRequestOptions {
79        meta: Some(Meta(metadata.unwrap_or_default())),
80        ..Default::default()
81    };
82    client
83        .send_cancellable_request(
84            ClientRequest::CallToolRequest(Request::new(params)),
85            options,
86        )
87        .await
88        .map_err(|e| e.to_string())
89}