Skip to main content

stakpak_mcp_client/
local.rs

1use anyhow::Result;
2use rmcp::{
3    ClientHandler, RoleClient, ServiceExt,
4    model::{ClientCapabilities, ClientInfo, Implementation},
5    service::RunningService,
6    transport::TokioChildProcess,
7};
8use stakpak_shared::models::integrations::openai::ToolCallResultProgress;
9use tokio::process::Command;
10use tokio::sync::mpsc::Sender;
11
12#[derive(Clone)]
13pub struct LocalClientHandler {
14    progress_tx: Option<Sender<ToolCallResultProgress>>,
15}
16
17impl LocalClientHandler {
18    pub fn new(progress_tx: Option<Sender<ToolCallResultProgress>>) -> Self {
19        Self { progress_tx }
20    }
21}
22
23impl ClientHandler for LocalClientHandler {
24    async fn on_progress(
25        &self,
26        progress: rmcp::model::ProgressNotificationParam,
27        _ctx: rmcp::service::NotificationContext<rmcp::RoleClient>,
28    ) {
29        if let Some(progress_tx) = self.progress_tx.clone()
30            && let Some(message) = progress.message
31        {
32            match serde_json::from_str::<ToolCallResultProgress>(&message) {
33                Ok(tool_call_progress) => {
34                    let _ = progress_tx.send(tool_call_progress).await;
35                }
36                Err(e) => {
37                    tracing::warn!("Failed to deserialize ToolCallProgress: {}", e);
38                }
39            }
40        }
41    }
42
43    fn get_info(&self) -> ClientInfo {
44        ClientInfo {
45            protocol_version: Default::default(),
46            capabilities: ClientCapabilities::default(),
47            client_info: Implementation {
48                name: "stakpak-mcp-client".to_string(),
49                version: "0.0.1".to_string(),
50                title: Some("Stakpak MCP Client".to_string()),
51                icons: Some(vec![]),
52                website_url: Some("https://stakpak.dev".to_string()),
53            },
54        }
55    }
56}
57
58pub async fn connect(
59    progress_tx: Option<Sender<ToolCallResultProgress>>,
60) -> Result<RunningService<RoleClient, LocalClientHandler>> {
61    // Get the path to the current executable and use it for the proxy
62    let current_exe = std::env::current_exe()?;
63    let mut cmd = Command::new(current_exe);
64    cmd.arg("mcp").arg("proxy");
65
66    let proc = TokioChildProcess::new(cmd)?;
67    let client_handler = LocalClientHandler::new(progress_tx);
68    let client: RunningService<RoleClient, LocalClientHandler> = client_handler.serve(proc).await?;
69
70    Ok(client)
71}