use super::{Agent, Channel, LlmProvider};
impl<C: Channel> Agent<C> {
    /// Dispatches a `/mcp` slash command to the matching subcommand handler.
    ///
    /// `args` is the text after `/mcp`. It is split on whitespace, and the first token selects
    /// the subcommand (`add`, `list`, `tools`, `remove`). Any other input, including an empty
    /// string, gets a usage hint as the reply.
    ///
    /// # Errors
    ///
    /// Returns an error only if sending a reply through the channel fails. MCP-level failures
    /// are reported to the user as messages instead.
    pub(super) async fn handle_mcp_command(
        &mut self,
        args: &str,
    ) -> Result<(), super::error::AgentError> {
        let parts: Vec<&str> = args.split_whitespace().collect();
        match parts.first().copied() {
            // Reaching this arm means `parts` is non-empty, so slicing from 1 cannot panic.
            Some("add") => self.handle_mcp_add(&parts[1..]).await,
            Some("list") => self.handle_mcp_list().await,
            Some("tools") => self.handle_mcp_tools(parts.get(1).copied()).await,
            Some("remove") => self.handle_mcp_remove(parts.get(1).copied()).await,
            _ => {
                self.channel
                    .send("Usage: /mcp add|list|tools|remove")
                    .await?;
                Ok(())
            }
        }
    }

    /// Handles `/mcp add <id> <command> [args...]` and `/mcp add <id> <url>`.
    ///
    /// A target starting with `http://` or `https://` is connected over HTTP. Anything else is
    /// treated as a stdio command, and the remaining tokens become its arguments. Stdio
    /// commands must appear in `allowed_commands` when that list is non-empty. The new server
    /// is registered with `trusted: false` and a 30 second `timeout`.
    ///
    /// On success, the new tools are appended to the cached tool list. The executor, the
    /// registry and the metrics are then re-synchronised. A connection failure is logged and
    /// reported to the user; it is not returned as an error.
    ///
    /// # Errors
    ///
    /// Returns an error only if sending a reply through the channel fails.
    async fn handle_mcp_add(&mut self, args: &[&str]) -> Result<(), super::error::AgentError> {
        if args.len() < 2 {
            self.channel
                .send("Usage: /mcp add <id> <command> [args...] | /mcp add <id> <url>")
                .await?;
            return Ok(());
        }
        let Some(ref manager) = self.mcp.manager else {
            self.channel.send("MCP is not enabled.").await?;
            return Ok(());
        };

        let target = args[1];
        let is_url = target.starts_with("http://") || target.starts_with("https://");
        // An empty allow-list imposes no restriction; URLs are never checked against it.
        if !is_url
            && !self.mcp.allowed_commands.is_empty()
            && !self.mcp.allowed_commands.iter().any(|c| c == target)
        {
            self.channel
                .send(&format!(
                    "Command '{target}' is not allowed. Permitted: {}",
                    self.mcp.allowed_commands.join(", ")
                ))
                .await?;
            return Ok(());
        }

        // NOTE(review): this counts every server the manager lists, not only those added at
        // runtime. Confirm that is intended for a limit named `max_dynamic`.
        let current_count = manager.list_servers().await.len();
        if current_count >= self.mcp.max_dynamic {
            self.channel
                .send(&format!(
                    "Server limit reached ({}/{}).",
                    current_count, self.mcp.max_dynamic
                ))
                .await?;
            return Ok(());
        }

        let transport = if is_url {
            zeph_mcp::McpTransport::Http {
                url: target.to_owned(),
                headers: std::collections::HashMap::new(),
            }
        } else {
            zeph_mcp::McpTransport::Stdio {
                command: target.to_owned(),
                args: args[2..].iter().map(|&s| s.to_owned()).collect(),
                env: std::collections::HashMap::new(),
            }
        };
        let entry = zeph_mcp::ServerEntry {
            id: args[0].to_owned(),
            transport,
            timeout: std::time::Duration::from_secs(30),
            trusted: false,
        };

        // Status updates are best-effort; a failure to show them must not abort the command.
        let _ = self.channel.send_status("connecting to mcp...").await;
        match manager.add_server(&entry).await {
            Ok(tools) => {
                let _ = self.channel.send_status("").await;
                let count = tools.len();
                self.mcp.tools.extend(tools);
                self.sync_mcp_executor_tools();
                self.sync_mcp_registry().await;
                let mcp_total = self.mcp.tools.len();
                let mcp_servers = self.distinct_mcp_tool_servers();
                self.update_metrics(|m| {
                    m.mcp_tool_count = mcp_total;
                    m.mcp_server_count = mcp_servers;
                });
                self.channel
                    .send(&format!(
                        "Connected MCP server '{}' ({count} tool(s))",
                        entry.id
                    ))
                    .await?;
                Ok(())
            }
            Err(e) => {
                let _ = self.channel.send_status("").await;
                tracing::warn!(server_id = entry.id, "MCP add failed: {e:#}");
                self.channel
                    .send(&format!("Failed to connect server '{}': {e}", entry.id))
                    .await?;
                Ok(())
            }
        }
    }

    /// Handles `/mcp list`.
    ///
    /// Reports each server id returned by the manager, the number of cached tools attributed
    /// to it, and the overall total.
    ///
    /// # Errors
    ///
    /// Returns an error only if sending a reply through the channel fails.
    async fn handle_mcp_list(&mut self) -> Result<(), super::error::AgentError> {
        use std::fmt::Write;
        let Some(ref manager) = self.mcp.manager else {
            self.channel.send("MCP is not enabled.").await?;
            return Ok(());
        };
        let server_ids = manager.list_servers().await;
        if server_ids.is_empty() {
            self.channel.send("No MCP servers connected.").await?;
            return Ok(());
        }
        let mut output = String::from("Connected MCP servers:\n");
        let mut total = 0usize;
        for id in &server_ids {
            let count = self.mcp.tools.iter().filter(|t| t.server_id == *id).count();
            total += count;
            // `fmt::Write` into a `String` never fails, so the result can be ignored.
            let _ = writeln!(output, "- {id} ({count} tools)");
        }
        let _ = write!(output, "Total: {total} tool(s)");
        self.channel.send(&output).await?;
        Ok(())
    }

    /// Handles `/mcp tools <server_id>`: lists the cached tools belonging to one server.
    ///
    /// This reads only the local tool cache and never consults the manager. When MCP is
    /// disabled, it therefore answers with "No tools found" rather than "not enabled".
    ///
    /// # Errors
    ///
    /// Returns an error only if sending a reply through the channel fails.
    async fn handle_mcp_tools(
        &mut self,
        server_id: Option<&str>,
    ) -> Result<(), super::error::AgentError> {
        use std::fmt::Write;
        let Some(server_id) = server_id else {
            self.channel.send("Usage: /mcp tools <server_id>").await?;
            return Ok(());
        };
        let tools: Vec<_> = self
            .mcp
            .tools
            .iter()
            .filter(|t| t.server_id == server_id)
            .collect();
        if tools.is_empty() {
            self.channel
                .send(&format!("No tools found for server '{server_id}'."))
                .await?;
            return Ok(());
        }
        let mut output = format!("Tools for '{server_id}' ({} total):\n", tools.len());
        for t in &tools {
            if t.description.is_empty() {
                let _ = writeln!(output, "- {}", t.name);
            } else {
                let _ = writeln!(output, "- {} — {}", t.name, t.description);
            }
        }
        self.channel.send(&output).await?;
        Ok(())
    }

    /// Handles `/mcp remove <id>`.
    ///
    /// Disconnects the server, drops its tools from the cache, and re-synchronises the
    /// executor, the registry and the metrics. A manager failure is logged and reported to
    /// the user; it is not returned as an error.
    ///
    /// # Errors
    ///
    /// Returns an error only if sending a reply through the channel fails.
    async fn handle_mcp_remove(
        &mut self,
        server_id: Option<&str>,
    ) -> Result<(), super::error::AgentError> {
        let Some(server_id) = server_id else {
            self.channel.send("Usage: /mcp remove <id>").await?;
            return Ok(());
        };
        let Some(ref manager) = self.mcp.manager else {
            self.channel.send("MCP is not enabled.").await?;
            return Ok(());
        };
        match manager.remove_server(server_id).await {
            Ok(()) => {
                let before = self.mcp.tools.len();
                self.mcp.tools.retain(|t| t.server_id != server_id);
                let removed = before - self.mcp.tools.len();
                self.sync_mcp_executor_tools();
                self.sync_mcp_registry().await;
                let mcp_total = self.mcp.tools.len();
                let mcp_servers = self.distinct_mcp_tool_servers();
                // Built once here instead of once per active tool inside `retain`.
                // Assumes qualified tool names look like `<server_id>:<tool>`; TODO confirm
                // against `McpTool::qualified_name`.
                let removed_prefix = format!("{server_id}:");
                self.update_metrics(|m| {
                    m.mcp_tool_count = mcp_total;
                    m.mcp_server_count = mcp_servers;
                    m.active_mcp_tools
                        .retain(|name| !name.starts_with(removed_prefix.as_str()));
                });
                self.channel
                    .send(&format!(
                        "Disconnected MCP server '{server_id}' (removed {removed} tools)"
                    ))
                    .await?;
                Ok(())
            }
            Err(e) => {
                tracing::warn!(server_id, "MCP remove failed: {e:#}");
                self.channel
                    .send(&format!("Failed to remove server '{server_id}': {e}"))
                    .await?;
                Ok(())
            }
        }
    }

    /// Selects the MCP tools relevant to `query` and records them in the metrics.
    ///
    /// The metrics receive the active tools plus the current tool and server counts. When the
    /// provider does not support native tool use, the matched tools are also rendered as text
    /// and appended to `system_prompt`.
    pub(super) async fn append_mcp_prompt(&mut self, query: &str, system_prompt: &mut String) {
        let matched_tools = self.match_mcp_tools(query).await;
        let active_mcp: Vec<String> = matched_tools
            .iter()
            .map(zeph_mcp::McpTool::qualified_name)
            .collect();
        let mcp_total = self.mcp.tools.len();
        let mcp_servers = self.distinct_mcp_tool_servers();
        self.update_metrics(|m| {
            m.active_mcp_tools = active_mcp;
            m.mcp_tool_count = mcp_total;
            m.mcp_server_count = mcp_servers;
        });
        // Tool-use capable providers get no textual tool list in the prompt (presumably the
        // tools reach them through native tool calling; confirm).
        if self.provider.supports_tool_use() {
            return;
        }
        if !matched_tools.is_empty() {
            let tool_names: Vec<&str> = matched_tools.iter().map(|t| t.name.as_str()).collect();
            tracing::debug!(
                skills = ?self.skill_state.active_skill_names,
                mcp_tools = ?tool_names,
                "matched items"
            );
            let tools_prompt = zeph_mcp::format_mcp_tools_prompt(&matched_tools);
            if !tools_prompt.is_empty() {
                system_prompt.push_str("\n\n");
                system_prompt.push_str(&tools_prompt);
            }
        }
    }

    /// Returns the MCP tools to expose for `query`.
    ///
    /// Without a registry, this returns every cached tool. Otherwise it runs a registry search
    /// that embeds text with the agent's provider. The search receives
    /// `skill_state.max_active_skills`, presumably as its result limit.
    async fn match_mcp_tools(&self, query: &str) -> Vec<zeph_mcp::McpTool> {
        let Some(ref registry) = self.mcp.registry else {
            return self.mcp.tools.clone();
        };
        let provider = self.provider.clone();
        registry
            .search(query, self.skill_state.max_active_skills, |text| {
                let owned = text.to_owned();
                let p = provider.clone();
                Box::pin(async move { p.embed(&owned).await })
            })
            .await
    }

    /// Number of cached MCP tools (test helper).
    #[cfg(test)]
    pub(crate) fn mcp_tool_count(&self) -> usize {
        self.mcp.tools.len()
    }

    /// Applies the latest tool list published on the `tools/list_changed` watch channel.
    ///
    /// Nothing happens when:
    /// - there is no receiver;
    /// - no unseen value is pending;
    /// - the sender side is closed (`has_changed` returns an error, treated as "unchanged").
    ///
    /// An update carrying an empty list is marked as seen but ignored. The current tools are
    /// kept, which also means a genuine transition to zero tools is never applied.
    pub(super) async fn check_tool_refresh(&mut self) {
        let Some(ref mut rx) = self.mcp.tool_rx else {
            return;
        };
        if !rx.has_changed().unwrap_or(false) {
            return;
        }
        let new_tools = rx.borrow_and_update().clone();
        if new_tools.is_empty() {
            return;
        }
        tracing::info!(
            tools = new_tools.len(),
            "tools/list_changed: agent tool list refreshed"
        );
        self.mcp.tools = new_tools;
        self.sync_mcp_executor_tools();
        self.sync_mcp_registry().await;
        let mcp_total = self.mcp.tools.len();
        let mcp_servers = self.distinct_mcp_tool_servers();
        self.update_metrics(|m| {
            m.mcp_tool_count = mcp_total;
            m.mcp_server_count = mcp_servers;
        });
    }

    /// Copies the cached tool list into the tool list shared with the executor, if one is
    /// configured.
    ///
    /// A poisoned lock is recovered (`PoisonError::into_inner`) instead of panicking.
    pub(super) fn sync_mcp_executor_tools(&self) {
        if let Some(ref shared) = self.mcp.shared_tools {
            let mut guard = shared
                .write()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            guard.clone_from(&self.mcp.tools);
        }
    }

    /// Re-syncs the semantic tool registry with the cached tools.
    ///
    /// This is skipped when there is no registry or the provider does not support
    /// embeddings. Sync failures are logged at `warn` level and otherwise ignored.
    pub(super) async fn sync_mcp_registry(&mut self) {
        let Some(ref mut registry) = self.mcp.registry else {
            return;
        };
        if !self.provider.supports_embeddings() {
            return;
        }
        let provider = self.provider.clone();
        let embed_fn = |text: &str| -> zeph_mcp::registry::EmbedFuture {
            let owned = text.to_owned();
            let p = provider.clone();
            Box::pin(async move { p.embed(&owned).await })
        };
        if let Err(e) = registry
            .sync(&self.mcp.tools, &self.skill_state.embedding_model, embed_fn)
            .await
        {
            tracing::warn!("failed to sync MCP tool registry: {e:#}");
        }
    }

    /// Counts the distinct `server_id`s among the cached MCP tools.
    ///
    /// Only servers that contributed at least one tool are counted. A connected server that
    /// exposes zero tools has no entry in `self.mcp.tools`, so it is not included.
    fn distinct_mcp_tool_servers(&self) -> usize {
        self.mcp
            .tools
            .iter()
            .map(|t| &t.server_id)
            .collect::<std::collections::HashSet<_>>()
            .len()
    }
}
#[cfg(test)]
mod tests {
    use super::super::agent_tests::{
        MockChannel, MockToolExecutor, create_test_registry, mock_provider,
    };
    use super::*;

    /// Builds a fresh agent backed by mocks, with no tools loaded.
    ///
    /// The `/mcp` tests below rely on this agent having no MCP manager configured.
    fn fresh_agent() -> Agent<MockChannel> {
        Agent::new(
            mock_provider(vec![]),
            MockChannel::new(vec![]),
            create_test_registry(),
            None,
            5,
            MockToolExecutor::no_tools(),
        )
    }

    /// Runs `/mcp <args>` on a fresh agent and asserts that some sent message contains
    /// `needle`. `what` names the expected message in the failure text.
    async fn expect_mcp_reply(args: &str, needle: &str, what: &str) {
        let mut agent = fresh_agent();
        agent.handle_mcp_command(args).await.unwrap();
        let sent = agent.channel.sent_messages();
        assert!(
            sent.iter().any(|s| s.contains(needle)),
            "expected {what} message, got: {sent:?}"
        );
    }

    /// Builds a tool owned by server `srv`, with an empty description and an empty schema.
    fn srv_tool(name: &'static str) -> zeph_mcp::McpTool {
        zeph_mcp::McpTool {
            server_id: "srv".into(),
            name: name.into(),
            description: "".into(),
            input_schema: serde_json::json!({}),
        }
    }

    #[tokio::test]
    async fn handle_mcp_command_unknown_subcommand_shows_usage() {
        expect_mcp_reply("unknown", "Usage: /mcp", "usage").await;
    }

    #[tokio::test]
    async fn handle_mcp_list_no_manager_shows_disabled() {
        expect_mcp_reply("list", "MCP is not enabled", "not-enabled").await;
    }

    #[tokio::test]
    async fn handle_mcp_tools_no_server_id_shows_usage() {
        expect_mcp_reply("tools", "Usage: /mcp tools", "tools usage").await;
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_server_id_shows_usage() {
        expect_mcp_reply("remove", "Usage: /mcp remove", "remove usage").await;
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_manager_shows_disabled() {
        expect_mcp_reply("remove my-server", "MCP is not enabled", "not-enabled").await;
    }

    #[tokio::test]
    async fn handle_mcp_add_insufficient_args_shows_usage() {
        expect_mcp_reply("add server-id", "Usage: /mcp add", "add usage").await;
    }

    #[tokio::test]
    async fn handle_mcp_tools_with_unknown_server_shows_no_tools() {
        expect_mcp_reply("tools nonexistent-server", "No tools found", "no-tools").await;
    }

    #[tokio::test]
    async fn mcp_tool_count_starts_at_zero() {
        let agent = fresh_agent();
        assert_eq!(agent.mcp_tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_rx_is_noop() {
        let mut agent = fresh_agent();
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp_tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_change_is_noop() {
        let mut agent = fresh_agent();
        let (sender, receiver) = tokio::sync::watch::channel(Vec::new());
        agent.mcp.tool_rx = Some(receiver);
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp_tool_count(), 0);
        // Keep the sender alive until after the refresh so the channel is not closed.
        drop(sender);
    }

    #[tokio::test]
    async fn check_tool_refresh_with_empty_initial_value_does_not_replace_tools() {
        let mut agent = fresh_agent();
        agent.mcp.tools = vec![srv_tool("existing_tool")];
        let (_sender, receiver) =
            tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.mcp.tool_rx = Some(receiver);
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp_tool_count(), 1);
    }

    #[tokio::test]
    async fn check_tool_refresh_applies_update() {
        let mut agent = fresh_agent();
        let (sender, receiver) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.mcp.tool_rx = Some(receiver);
        sender.send(vec![srv_tool("refreshed_tool")]).unwrap();
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp_tool_count(), 1);
        assert_eq!(agent.mcp.tools[0].name, "refreshed_tool");
    }
}