use crate::commands::agent::run::helpers::convert_tools_with_filter;
use crate::commands::get_client;
use crate::config::AppConfig;
use crate::utils::network;
use stakpak_api::local::skills::default_skill_directories;
use stakpak_mcp_client::McpClient;
use stakpak_mcp_proxy::client::{ClientPoolConfig, ServerConfig};
use stakpak_mcp_proxy::server::start_proxy_server;
use stakpak_mcp_server::{
EnabledToolsConfig, MCPServerConfig, SubagentConfig, ToolMode, start_server,
};
use stakpak_shared::cert_utils::CertificateChain;
use stakpak_shared::models::integrations::openai::ToolCallResultProgress;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::broadcast;
use tokio::sync::mpsc::Sender;
/// Caller-supplied options for bringing up the local MCP server and the proxy
/// in front of it.
// NOTE(review): `dead_code` is presumably allowed because at least one field
// (`enable_mtls`) is not read by the code visible here — confirm before removing.
#[allow(dead_code)]
pub struct McpInitConfig {
/// Tool toggles; cloned into the local MCP server's configuration.
pub enabled_tools: EnabledToolsConfig,
/// Not read by any code visible in this file.
/// NOTE(review): presumably intended to switch mutual TLS on or off — confirm against callers.
pub enable_mtls: bool,
/// Copied into the local MCP server's configuration.
pub enable_subagents: bool,
/// Optional allow-list passed to the tool conversion filter. When `None`,
/// the application config's `allowed_tools` is used instead.
pub allowed_tools: Option<Vec<String>>,
/// Cloned into the local MCP server's configuration.
pub subagent_config: SubagentConfig,
/// Passed through to the proxy server when it is started.
pub redact_secrets: bool,
/// Passed through to the proxy server when it is started.
pub privacy_mode: bool,
}
impl Default for McpInitConfig {
    /// Returns the baseline configuration: mTLS, subagents and secret
    /// redaction switched on; privacy mode and the Slack tool switched off;
    /// no explicit tool allow-list; default subagent settings.
    fn default() -> Self {
        let enabled_tools = EnabledToolsConfig { slack: false };
        let subagent_config = SubagentConfig::default();

        Self {
            // Switches that start enabled.
            enable_mtls: true,
            enable_subagents: true,
            redact_secrets: true,
            // Switches that start disabled.
            privacy_mode: false,
            // `None` lets the application-level allow-list (if any) apply.
            allowed_tools: None,
            enabled_tools,
            subagent_config,
        }
    }
}
/// Everything `initialize_mcp_server_and_tools` hands back to its caller.
pub struct McpInitResult {
/// Client connected to the proxy over HTTPS.
pub client: Arc<McpClient>,
/// Tool list as returned by `stakpak_mcp_client::get_tools` for `client`.
pub mcp_tools: Vec<rmcp::model::Tool>,
/// `mcp_tools` after `convert_tools_with_filter` applied the allow-list (if any).
pub tools: Vec<stakpak_shared::models::integrations::openai::Tool>,
/// Sending half of the broadcast channel whose receiver was given to the local MCP server.
pub server_shutdown_tx: broadcast::Sender<()>,
/// Sending half of the broadcast channel whose receiver was given to the proxy server.
pub proxy_shutdown_tx: broadcast::Sender<()>,
}
/// The two certificate chains generated for one initialization run.
struct CertificateChains {
/// Chain given to the local MCP server and to the proxy's "stakpak" upstream entry.
/// Always `Some` when produced by `CertificateChains::generate`.
server_chain: Arc<Option<CertificateChain>>,
/// Chain given to the proxy server and to the client that connects to it.
proxy_chain: Arc<CertificateChain>,
}
/// An address together with the listener obtained for it.
struct ServerBinding {
/// Address string returned by `network::find_available_bind_address_with_listener`;
/// used verbatim as the host part of URLs built by `https_url`.
address: String,
/// Listener returned alongside `address`; handed to the server that will serve on it.
listener: TcpListener,
}
impl CertificateChains {
    /// Generates one certificate chain for the local MCP server and a second,
    /// independent one for the proxy.
    ///
    /// # Errors
    /// Returns a message naming which chain (server or proxy) could not be
    /// generated, followed by the underlying error.
    fn generate() -> Result<Self, String> {
        // Server chain first; a failure here returns before the proxy chain
        // is attempted.
        let server = CertificateChain::generate()
            .map_err(|e| format!("Failed to generate server certificates: {}", e))?;
        let proxy = CertificateChain::generate()
            .map_err(|e| format!("Failed to generate proxy certificates: {}", e))?;

        Ok(Self {
            server_chain: Arc::new(Some(server)),
            proxy_chain: Arc::new(proxy),
        })
    }
}
impl ServerBinding {
    /// Obtains an address and listener from
    /// `network::find_available_bind_address_with_listener`.
    ///
    /// `purpose` is only used to label the error message.
    ///
    /// # Errors
    /// Returns `"Failed to find available port for <purpose>: <error>"` when
    /// the lookup fails.
    async fn new(purpose: &str) -> Result<Self, String> {
        match network::find_available_bind_address_with_listener().await {
            Ok((address, listener)) => Ok(Self { address, listener }),
            Err(e) => Err(format!(
                "Failed to find available port for {}: {}",
                purpose, e
            )),
        }
    }

    /// Returns `https://<address><path>`; `path` is appended as given, so it
    /// should carry its own leading slash.
    fn https_url(&self, path: &str) -> String {
        ["https://", self.address.as_str(), path].concat()
    }
}
async fn start_mcp_server(
app_config: &AppConfig,
mcp_config: &McpInitConfig,
binding: ServerBinding,
cert_chain: Arc<Option<CertificateChain>>,
shutdown_rx: broadcast::Receiver<()>,
) -> Result<(), String> {
let api_client = get_client(app_config).await?;
let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<Result<(), String>>();
let bind_address = binding.address.clone();
let enabled_tools = mcp_config.enabled_tools.clone();
let enable_subagents = mcp_config.enable_subagents;
let subagent_config = mcp_config.subagent_config.clone();
tokio::spawn(async move {
let server_config = MCPServerConfig {
client: Some(api_client),
bind_address,
enabled_tools,
tool_mode: ToolMode::Combined,
enable_subagents,
certificate_chain: cert_chain,
skill_directories: default_skill_directories(),
subagent_config,
server_tls_config: None,
};
let _ = ready_tx.send(Ok(()));
if let Err(e) = start_server(server_config, Some(binding.listener), Some(shutdown_rx)).await
{
tracing::error!("Local MCP server error: {}", e);
}
});
ready_rx
.await
.map_err(|_| "MCP server task failed to start".to_string())?
.map_err(|e| format!("MCP server failed to start: {}", e))?;
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
Ok(())
}
/// Builds the proxy's client-pool configuration with two HTTP upstreams:
///
/// * `"stakpak"` — the local MCP server at `local_server_url`, using
///   `server_cert_chain`.
/// * `"paks"` — the fixed remote endpoint
///   `https://apiv2.stakpak.dev/v1/paks/mcp`, with no certificate chain.
///
/// Neither entry sets extra headers or a client TLS config.
fn build_proxy_config(
    local_server_url: String,
    server_cert_chain: Arc<Option<CertificateChain>>,
) -> ClientPoolConfig {
    let local_server = ServerConfig::Http {
        url: local_server_url,
        headers: None,
        certificate_chain: server_cert_chain,
        client_tls_config: None,
    };
    let paks_server = ServerConfig::Http {
        url: String::from("https://apiv2.stakpak.dev/v1/paks/mcp"),
        headers: None,
        certificate_chain: Arc::new(None),
        client_tls_config: None,
    };

    let servers: HashMap<String, ServerConfig> = HashMap::from([
        (String::from("stakpak"), local_server),
        (String::from("paks"), paks_server),
    ]);
    ClientPoolConfig::with_servers(servers)
}
/// Spawns the proxy server on a background task and waits for that task to
/// report that it has begun.
///
/// The proxy serves on the listener carried by `binding`, uses `cert_chain`,
/// fronts the upstreams in `pool_config`, and receives the `redact_secrets`
/// and `privacy_mode` flags from `mcp_config` plus `shutdown_rx` as its
/// shutdown signal.
///
/// Readiness is signalled before `start_proxy_server` is awaited, so a
/// successful return does not show that the proxy is accepting connections.
/// Errors returned by `start_proxy_server` are logged, not returned.
///
/// # Errors
/// Returns a message if the spawned task dropped the readiness channel or
/// reported a failure through it.
async fn start_proxy(
    pool_config: ClientPoolConfig,
    mcp_config: &McpInitConfig,
    binding: ServerBinding,
    cert_chain: Arc<CertificateChain>,
    shutdown_rx: broadcast::Receiver<()>,
) -> Result<(), String> {
    // Copy the flags out of the borrowed config so the task can own them.
    let (redact_secrets, privacy_mode) = (mcp_config.redact_secrets, mcp_config.privacy_mode);
    let listener = binding.listener;

    let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<Result<(), String>>();

    tokio::spawn(async move {
        // The receiver may already be gone; that is not an error here.
        let _ = ready_tx.send(Ok(()));

        let outcome = start_proxy_server(
            pool_config,
            listener,
            cert_chain,
            redact_secrets,
            privacy_mode,
            Some(shutdown_rx),
        )
        .await;
        if let Err(e) = outcome {
            tracing::error!("Proxy server error: {}", e);
        }
    });

    match ready_rx.await {
        Err(_) => Err("Proxy server task failed to start".to_string()),
        Ok(Err(e)) => Err(format!("Proxy server failed to start: {}", e)),
        Ok(Ok(())) => Ok(()),
    }
}
/// Connects an MCP client to the proxy at `proxy_url` over HTTPS, retrying on
/// failure.
///
/// Up to five attempts are made. After each failed attempt except the last,
/// the task sleeps and the delay doubles, starting at 50 ms
/// (50, 100, 200, 400 ms). `cert_chain` and `progress_tx` are passed to every
/// attempt.
///
/// # Errors
/// Returns a message containing the attempt limit and the error from the
/// final attempt when all attempts fail.
async fn connect_to_proxy(
    proxy_url: &str,
    cert_chain: Arc<CertificateChain>,
    progress_tx: Option<Sender<ToolCallResultProgress>>,
) -> Result<Arc<McpClient>, String> {
    const MAX_RETRIES: u32 = 5;

    let mut backoff = tokio::time::Duration::from_millis(50);
    let mut attempt: u32 = 1;

    // The loop only exits with the error of the final attempt; success
    // returns straight out of the function.
    let final_error = loop {
        let outcome = stakpak_mcp_client::connect_https(
            proxy_url,
            Some(Arc::clone(&cert_chain)),
            progress_tx.clone(),
        )
        .await;

        match outcome {
            Ok(client) => return Ok(Arc::new(client)),
            Err(e) if attempt >= MAX_RETRIES => break e,
            Err(_) => {
                tokio::time::sleep(backoff).await;
                backoff *= 2;
                attempt += 1;
            }
        }
    };

    Err(format!(
        "Failed to connect to MCP proxy after {} retries: {}",
        MAX_RETRIES,
        final_error.to_string()
    ))
}
pub async fn initialize_mcp_server_and_tools(
app_config: &AppConfig,
mcp_config: McpInitConfig,
progress_tx: Option<Sender<ToolCallResultProgress>>,
) -> Result<McpInitResult, String> {
let certs = CertificateChains::generate()?;
let server_binding = ServerBinding::new("MCP server").await?;
let proxy_binding = ServerBinding::new("proxy").await?;
let local_mcp_server_url = server_binding.https_url("/mcp");
let proxy_url = proxy_binding.https_url("/mcp");
let (server_shutdown_tx, server_shutdown_rx) = broadcast::channel::<()>(1);
let (proxy_shutdown_tx, proxy_shutdown_rx) = broadcast::channel::<()>(1);
start_mcp_server(
app_config,
&mcp_config,
server_binding,
certs.server_chain.clone(),
server_shutdown_rx,
)
.await?;
let pool_config = build_proxy_config(local_mcp_server_url, certs.server_chain);
start_proxy(
pool_config,
&mcp_config,
proxy_binding,
certs.proxy_chain.clone(),
proxy_shutdown_rx,
)
.await?;
let mcp_client = connect_to_proxy(&proxy_url, certs.proxy_chain, progress_tx).await?;
let mcp_tools = stakpak_mcp_client::get_tools(&mcp_client)
.await
.map_err(|e| format!("Failed to get tools: {}", e))?;
let allowed_tools_ref = mcp_config
.allowed_tools
.as_ref()
.or(app_config.allowed_tools.as_ref());
let tools = convert_tools_with_filter(&mcp_tools, allowed_tools_ref);
Ok(McpInitResult {
client: mcp_client,
mcp_tools,
tools,
server_shutdown_tx,
proxy_shutdown_tx,
})
}