#[allow(unused_imports)]
use crate::sync_util::LockExt;
use std::borrow::Cow;
use std::fmt;
use std::sync::Arc;
use std::time::Duration;
use rig::completion::ToolDefinition;
use rig::tool::{ToolDyn, ToolError};
use rig::wasm_compat::WasmBoxedFuture;
use rmcp::ServiceError;
use rmcp::model::{CallToolRequestParams, JsonObject, RawContent};
use tokio::sync::Mutex;
use crate::agent::tools::check_perm;
use crate::extras::mcp::client::{SharedConnection, raw_connect};
use crate::extras::mcp::config::McpServerConfig;
use crate::permission::ask::AskSender;
use crate::permission::checker::PermCheck;
use crate::timeout::Deadline;
/// Error type carrying a plain, human-readable message for a failed MCP tool call.
///
/// The wrapped string is the whole payload: `Display` prints it verbatim and
/// there is no underlying source error.
#[derive(Debug)]
pub struct McpToolError(pub String);

impl std::error::Error for McpToolError {}

impl fmt::Display for McpToolError {
    /// Writes the wrapped message unchanged.
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        out.write_str(&self.0)
    }
}
/// One tool advertised by an MCP server, wrapped so it can be invoked through
/// the `ToolDyn` implementation in this file.
pub struct McpTool {
/// Name of the MCP server that owns the tool. Used in permission keys
/// (`mcp_tool:<server>:<tool>`), in error messages, and when reconnecting.
pub server_name: String,
/// Tool metadata as reported by the server; its name, description and input
/// schema are read when building the rig tool definition.
pub definition: rmcp::model::Tool,
/// Shared connection to the server. The current peer is fetched from it for
/// every call, and it is replaced in place after a successful reconnect.
pub connection: Arc<SharedConnection>,
/// Server configuration, if retained. Read for `allow_external_paths()` and
/// required for auto-reconnect; when `None`, external paths are not allowed
/// and a transport failure is reported without a reconnect attempt.
pub config: Option<Arc<McpServerConfig>>,
/// Async mutex guarding a reconnect generation counter. The counter is read
/// before a call and bumped after a successful reconnect, so a caller can
/// tell whether a reconnect already happened while it was waiting.
pub reconnect_lock: Arc<Mutex<u64>>,
/// Optional permission checker. When `None`, the prompt-level deny check and
/// the external-path guard are skipped; it is also handed to `check_perm`.
pub permission: Option<PermCheck>,
/// Optional sender handed to `check_perm` together with `permission`.
/// NOTE(review): presumably the channel used to ask the user for approval —
/// confirm against `check_perm`.
pub ask_tx: Option<AskSender>,
}
/// Reports whether `err` means the transport itself broke.
///
/// Only `ServiceError::TransportSend` and `ServiceError::TransportClosed`
/// count; every other variant yields `false`. The caller uses this to decide
/// whether an auto-reconnect is worth attempting.
fn is_transport_failure(err: &ServiceError) -> bool {
    let send_failed = matches!(err, ServiceError::TransportSend(_));
    let closed = matches!(err, ServiceError::TransportClosed);
    send_failed || closed
}
/// Exposes a single MCP server tool through rig's dynamic tool interface.
impl ToolDyn for McpTool {
    /// Returns the tool name exactly as stored in the MCP tool definition.
    fn name(&self) -> String {
        self.definition.name.to_string()
    }

    /// Builds the rig tool definition from the MCP tool metadata.
    ///
    /// A missing description becomes the empty string. An input schema that
    /// fails to serialize, or serializes to JSON `null`, becomes `{}`.
    /// `_prompt` is not used.
    fn definition(&self, _prompt: String) -> WasmBoxedFuture<'_, ToolDefinition> {
        let name = self.definition.name.to_string();
        // Borrow the description instead of cloning the `Cow` and then
        // allocating a second time for the owned `String`.
        let description = self
            .definition
            .description
            .as_deref()
            .unwrap_or("")
            .to_owned();
        let parameters = serde_json::to_value(&self.definition.input_schema)
            .ok()
            .filter(|v| !v.is_null())
            .unwrap_or_else(|| serde_json::json!({}));
        Box::pin(async move {
            ToolDefinition {
                name,
                description,
                parameters,
            }
        })
    }

    /// Invokes the MCP tool with `args`, a JSON object encoded as a string.
    ///
    /// Steps, in order:
    /// 1. Reject the call if the permission checker reports the tool as
    ///    denied by the active prompt.
    /// 2. Run `check_perm` for the `mcp_tool:<server>:<tool>` key.
    /// 3. Parse `args`; blank input means "no arguments", anything else must
    ///    be a JSON object.
    /// 4. Unless the server config allows external paths, refuse arguments
    ///    that contain a path outside the working directory.
    /// 5. Call the server (with one auto-reconnect on transport failure)
    ///    under the configured `mcp_call` timeout budget.
    /// 6. Flatten the returned content into one string.
    ///
    /// Both the success output and the text of an error result are limited
    /// to `MCP_RESULT_CAP_BYTES`; a truncation note is appended when the
    /// limit is hit.
    ///
    /// # Errors
    ///
    /// Every failure is returned as `ToolError::ToolCallError` wrapping a
    /// `McpToolError` with a human-readable message.
    fn call(&self, args: String) -> WasmBoxedFuture<'_, Result<String, ToolError>> {
        /// Upper bound, in bytes, on server-provided text returned to the caller.
        const MCP_RESULT_CAP_BYTES: usize = 256 * 1024;

        let server_name = self.server_name.clone();
        let tool_name = self.definition.name.to_string();
        let connection = Arc::clone(&self.connection);
        let config = self.config.clone();
        let reconnect_lock = self.reconnect_lock.clone();
        let permission = self.permission.clone();
        let ask_tx = self.ask_tx.clone();
        Box::pin(async move {
            // One key serves both the prompt-level deny check and `check_perm`.
            let perm_key = format!("mcp_tool:{server_name}:{tool_name}");

            if let Some(perm) = permission.as_ref() {
                // Keep the guard in its own scope so it is released before
                // the first `.await` below.
                let denied = {
                    let guard = perm.lock_ignore_poison();
                    guard.any_prompt_denied(&[tool_name.as_str(), perm_key.as_str(), "mcp_tool"])
                };
                if denied {
                    return Err(mcp_tool_error(format!(
                        "MCP tool {}::{} is denied by the active prompt's `deny_tools` frontmatter. Switch with `/prompt <other>` to use it.",
                        server_name, tool_name,
                    )));
                }
            }

            check_perm(&permission, &ask_tx, "mcp_tool", &perm_key)
                .await
                .map_err(|e| mcp_tool_error(e.to_string()))?;

            let trimmed = args.trim();
            let arguments: Option<JsonObject> = if trimmed.is_empty() {
                None
            } else {
                match serde_json::from_str::<JsonObject>(trimmed) {
                    Ok(obj) => Some(obj),
                    Err(e) => {
                        return Err(mcp_tool_error(format!(
                            "MCP tool {}::{}: malformed JSON arguments ({e}). Got: {trimmed:.200}",
                            server_name, tool_name,
                        )));
                    }
                }
            };

            // Without a retained config, external paths are not allowed.
            let allow_external = config
                .as_ref()
                .map(|c| c.allow_external_paths())
                .unwrap_or(false);
            if let Some(perm) = permission.as_ref()
                && let Some(args_obj) = arguments.as_ref()
                && let Some(p) = first_external_path(perm, args_obj, allow_external)
            {
                return Err(mcp_tool_error(format!(
                    "MCP tool {server_name}::{tool_name} refused: path {p:?} is outside the working directory. \
                     Set `allow_external_paths: true` on the `{server_name}` server config to permit external paths for this server."
                )));
            }

            let params = arguments
                .map(|a| CallToolRequestParams::new(tool_name.clone()).with_arguments(a))
                .unwrap_or_else(|| CallToolRequestParams::new(tool_name.clone()));
            let deadline = Deadline::start(crate::timeout::Timeouts::get().mcp_call);
            let result = try_call_with_reconnect(
                &server_name,
                &connection,
                config.as_deref(),
                &reconnect_lock,
                params,
                deadline,
            )
            .await
            .map_err(mcp_tool_error)?;

            if result.is_error.unwrap_or(false) {
                // Only text items contribute to the error message.
                let mut msg = result
                    .content
                    .iter()
                    .filter_map(|c| match &c.raw {
                        RawContent::Text(t) => Some(t.text.as_str()),
                        _ => None,
                    })
                    .collect::<Vec<_>>()
                    .join("\n");
                if msg.is_empty() {
                    msg.push_str("MCP tool returned an error");
                } else if msg.len() > MCP_RESULT_CAP_BYTES {
                    // Error text comes from the server just like normal
                    // output, so it gets the same size limit.
                    let cut = mcp_char_floor(&msg, MCP_RESULT_CAP_BYTES);
                    msg.truncate(cut);
                    msg.push_str(&mcp_truncation_note(
                        MCP_RESULT_CAP_BYTES,
                        &server_name,
                        &tool_name,
                    ));
                }
                return Err(mcp_tool_error(msg));
            }

            let mut content = String::new();
            let mut truncated = false;
            for item in result.content {
                let chunk: String = match item.raw {
                    RawContent::Text(t) => t.text,
                    RawContent::Image(img) => {
                        format!("data:{};base64,{}", img.mime_type, img.data)
                    }
                    RawContent::Resource(r) => match r.resource {
                        rmcp::model::ResourceContents::TextResourceContents { text, .. } => text,
                        rmcp::model::ResourceContents::BlobResourceContents { blob, .. } => blob,
                    },
                    // Other content kinds are skipped.
                    _ => continue,
                };
                let room = MCP_RESULT_CAP_BYTES.saturating_sub(content.len());
                if chunk.len() <= room {
                    content.push_str(&chunk);
                } else {
                    // Keep the part that still fits, cut on a char boundary,
                    // and stop consuming further items.
                    let cut = mcp_char_floor(&chunk, room);
                    content.push_str(&chunk[..cut]);
                    truncated = true;
                    break;
                }
            }
            if truncated {
                content.push_str(&mcp_truncation_note(
                    MCP_RESULT_CAP_BYTES,
                    &server_name,
                    &tool_name,
                ));
            }
            Ok(content)
        })
    }
}

/// Wraps `message` in the error shape returned by `McpTool::call`.
fn mcp_tool_error(message: String) -> ToolError {
    ToolError::ToolCallError(Box::new(McpToolError(message)))
}

/// Returns the largest index `<= limit` (and `<= text.len()`) that lies on a
/// UTF-8 character boundary of `text`, so slicing there cannot panic.
fn mcp_char_floor(text: &str, limit: usize) -> usize {
    let mut cut = limit.min(text.len());
    // Index 0 is always a boundary, so this loop terminates.
    while !text.is_char_boundary(cut) {
        cut -= 1;
    }
    cut
}

/// Builds the note appended to output that was cut at `cap` bytes.
fn mcp_truncation_note(cap: usize, server_name: &str, tool_name: &str) -> String {
    format!(
        "\n…[MCP result truncated at {} bytes — {}::{} returned more]",
        cap, server_name, tool_name,
    )
}
async fn try_call_with_reconnect(
server_name: &str,
connection: &Arc<SharedConnection>,
config: Option<&McpServerConfig>,
reconnect_lock: &Arc<Mutex<u64>>,
params: CallToolRequestParams,
deadline: Deadline,
) -> Result<rmcp::model::CallToolResult, String> {
let gen_before = *reconnect_lock.lock().await;
let remaining = deadline.remaining();
let first = call_once(server_name, connection, params.clone(), remaining).await;
let err = match first {
Ok(r) => return Ok(r),
Err(e) => e,
};
let Some(svc_err) = err.as_service_error() else {
return Err(err.message);
};
if !is_transport_failure(svc_err) {
return Err(err.message);
}
let Some(cfg) = config else {
return Err(format!(
"{}\n(auto-reconnect unavailable — no config retained for server '{}')",
err.message, server_name,
));
};
{
let mut gen_guard = reconnect_lock.lock().await;
if *gen_guard == gen_before {
tracing::warn!(
target: "dirge::mcp",
server = %server_name,
"transport failure detected — attempting auto-reconnect",
);
let reconnect_budget = deadline.remaining();
let reconnect_result =
tokio::time::timeout(reconnect_budget, raw_connect(server_name, cfg)).await;
match reconnect_result {
Ok(Ok((new_peer, new_rs))) => {
connection.replace(new_peer, new_rs).await;
*gen_guard += 1;
tracing::info!(
target: "dirge::mcp",
server = %server_name,
"MCP server reconnected after transport failure",
);
}
Ok(Err(e)) => {
return Err(format!(
"{}\n(auto-reconnect to '{}' also failed: {})",
err.message, server_name, e,
));
}
Err(_) => {
return Err(format!(
"{}\n(auto-reconnect to '{}' timed out within the {}s budget)",
err.message,
server_name,
deadline.budget().as_secs(),
));
}
}
}
}
if deadline.is_expired() {
return Err(format!(
"MCP tool {}::{} budget ({}s) exhausted before retry",
server_name,
params.name,
deadline.budget().as_secs(),
));
}
let remaining = deadline.remaining();
match call_once(server_name, connection, params, remaining).await {
Ok(r) => Ok(r),
Err(e) => Err(format!(
"{}\n(reconnected but the retry also failed)",
e.message,
)),
}
}
/// Failure of a single tool call attempt, as produced by `call_once`.
struct CallErr {
/// Human-readable description of the failure, including server and tool name.
message: String,
/// The rmcp error behind the failure. `call_once` always fills this in: a
/// timeout is recorded as `ServiceError::Timeout`.
service_error: Option<ServiceError>,
}
impl CallErr {
/// Borrows the underlying rmcp error, if one was recorded.
fn as_service_error(&self) -> Option<&ServiceError> {
self.service_error.as_ref()
}
}
/// Performs one `call_tool` request against the connection's current peer,
/// giving up after `timeout`.
///
/// # Errors
///
/// Returns a `CallErr` whose `message` names the server and tool:
/// - when the peer reports a `ServiceError`, that error is kept in
///   `service_error`;
/// - when the call does not finish within `timeout`, `service_error` is set
///   to `ServiceError::Timeout`.
async fn call_once(
    server_name: &str,
    connection: &Arc<SharedConnection>,
    params: CallToolRequestParams,
    timeout: Duration,
) -> Result<rmcp::model::CallToolResult, CallErr> {
    // Capture the name first: `params` is moved into the request below.
    let tool_name = params.name.to_string();
    let peer = connection.current_peer().await;
    match tokio::time::timeout(timeout, peer.call_tool(params)).await {
        Ok(Ok(result)) => Ok(result),
        Ok(Err(svc_err)) => Err(CallErr {
            message: format!("MCP tool error ({server_name}::{tool_name}): {svc_err}"),
            service_error: Some(svc_err),
        }),
        Err(_) => Err(CallErr {
            // `timeout` is what was left of a deadline, so it is normally a
            // hair under a whole number of seconds and can be sub-second on a
            // retry. Whole-second truncation would report "29s" for a 30s
            // budget and "0s" for a short retry; print fractional seconds.
            message: format!(
                "MCP tool {server_name}::{tool_name} timed out after {:.1}s",
                timeout.as_secs_f64(),
            ),
            service_error: Some(ServiceError::Timeout { timeout }),
        }),
    }
}
/// Returns the first path-like argument value that the permission checker
/// classifies as external, or `None` when there is none.
///
/// When `allow_external` is `true` the check is skipped entirely. The checker
/// is only locked if `args` yields at least one candidate path.
pub(crate) fn first_external_path(
    perm: &PermCheck,
    args: &JsonObject,
    allow_external: bool,
) -> Option<String> {
    if allow_external {
        return None;
    }

    let candidates = extract_arg_paths(args);
    if candidates.is_empty() {
        return None;
    }

    let checker = perm.lock_ignore_poison();
    for candidate in candidates {
        if checker.is_external_path(&candidate) {
            return Some(candidate);
        }
    }
    None
}
/// Collects every string in `args` that should be treated as a filesystem path.
///
/// The JSON value tree is walked depth-first, in iteration order. A non-empty
/// string is collected when either
/// - it sits under a key listed in `PATH_KEYS` (directly, or as an element of
///   arrays under such a key), or
/// - it looks like a path on its own: it starts with `/`, `./`, `../` or
///   `~/`, or it is exactly `..` or `~`.
///
/// Objects reset the "under a path key" state for each of their own keys.
/// Numbers, booleans and nulls are ignored.
fn extract_arg_paths(args: &JsonObject) -> Vec<String> {
    /// Argument names whose string values are always treated as paths.
    const PATH_KEYS: &[&str] = &[
        "path",
        "file_path",
        "file",
        "directory",
        "dir",
        "cwd",
        "src",
        "src_path",
        "source",
        "target",
        "target_path",
        "dest",
        "destination",
        "input_file",
        "output_file",
        "working_dir",
        "target_dir",
        "paths",
    ];
    /// Prefixes that make a string path-shaped regardless of its key.
    const PATH_PREFIXES: &[&str] = &["/", "./", "../", "~/"];
    /// Whole values that are path-shaped without a trailing separator. The
    /// prefixes above do not match them, so a bare `..` under an unlisted key
    /// would otherwise never reach the external-path check.
    const BARE_PATHS: &[&str] = &["..", "~"];

    fn key_is_pathlike(k: &str) -> bool {
        PATH_KEYS.contains(&k)
    }

    fn looks_like_path(s: &str) -> bool {
        BARE_PATHS.contains(&s) || PATH_PREFIXES.iter().any(|prefix| s.starts_with(prefix))
    }

    fn walk(v: &serde_json::Value, key_pathlike: bool, out: &mut Vec<String>) {
        match v {
            serde_json::Value::String(s) => {
                if !s.is_empty() && (key_pathlike || looks_like_path(s)) {
                    out.push(s.clone());
                }
            }
            serde_json::Value::Array(arr) => {
                // Array elements inherit the key of the array itself.
                for e in arr {
                    walk(e, key_pathlike, out);
                }
            }
            serde_json::Value::Object(map) => {
                for (k, val) in map {
                    walk(val, key_is_pathlike(k), out);
                }
            }
            _ => {}
        }
    }

    let mut out: Vec<String> = Vec::new();
    for (k, val) in args {
        walk(val, key_is_pathlike(k), &mut out);
    }
    out
}
// Unit tests for transport-failure classification, the call deadline, the
// external-path guard, and the `allow_external_paths` config flag.
#[cfg(test)]
mod tests {
use super::*;
use std::time::Instant;
// Only the transport variants count as transport failures; timeouts, MCP
// errors, cancellations and unexpected responses do not.
#[test]
fn is_transport_failure_classifies_correctly() {
assert!(is_transport_failure(&ServiceError::TransportClosed));
assert!(!is_transport_failure(&ServiceError::UnexpectedResponse));
assert!(!is_transport_failure(&ServiceError::Timeout {
timeout: Duration::from_secs(1),
}));
let mcp_err = rmcp::ErrorData::new(
rmcp::model::ErrorCode::INTERNAL_ERROR,
"the tool refused",
None,
);
assert!(!is_transport_failure(&ServiceError::McpError(mcp_err)));
assert!(!is_transport_failure(&ServiceError::Cancelled {
reason: Some("user".into()),
}));
}
// The remaining time shrinks as wall-clock time passes and bottoms out at
// zero once the budget is used up.
#[test]
fn mcp_deadline_decays_and_saturates() {
let now = Instant::now();
let total = Duration::from_millis(100);
let deadline = Deadline::from_start(now, total);
assert!(deadline.remaining() > Duration::from_millis(90));
// Sleep past the 100ms budget so the deadline is certainly expired.
std::thread::sleep(Duration::from_millis(110));
assert_eq!(deadline.remaining(), Duration::ZERO);
}
use crate::agent::tools::check_perm;
use crate::permission::{
Action, OpSpec, PermissionConfig, RuleConfig, SecurityMode, checker::PermissionChecker,
};
use std::sync::{Arc, Mutex as StdMutex};
// Builds a permission checker in `SecurityMode::Standard` rooted at a freshly
// created, uniquely named temp directory, and returns it with that
// directory's path as a string. The directory is not removed afterwards.
fn mk_perm(extra_rules: PermissionConfig) -> (PermCheck, String) {
let cwd = std::env::temp_dir().join(format!(
"dirge-mgub-{}-{}",
std::process::id(),
crate::time_util::now_unix_nanos(),
));
std::fs::create_dir_all(&cwd).expect("create temp cwd");
let checker =
PermissionChecker::new(&extra_rules, SecurityMode::Standard, Some(cwd.clone()));
let perm: PermCheck = Arc::new(StdMutex::new(checker));
(perm, cwd.to_string_lossy().into_owned())
}
// With `allow_external = false`, an absolute path outside the temp cwd is
// reported.
#[test]
fn mcp_allow_external_paths_default_false_blocks_external() {
let (perm, _cwd) = mk_perm(PermissionConfig::default());
let args: JsonObject =
serde_json::from_str(r#"{"path": "/etc/passwd"}"#).expect("parse args");
let hit = first_external_path(&perm, &args, false);
assert_eq!(
hit.as_deref(),
Some("/etc/passwd"),
"default config must flag an external path; got {hit:?}",
);
}
// With `allow_external = true`, the guard reports nothing even though every
// path in the arguments is external.
#[test]
fn mcp_allow_external_paths_true_permits_external() {
let (perm, _cwd) = mk_perm(PermissionConfig::default());
let args: JsonObject =
serde_json::from_str(r#"{"path": "/etc/passwd", "paths": ["/var/log/system.log"]}"#)
.expect("parse args");
let hit = first_external_path(&perm, &args, true);
assert!(
hit.is_none(),
"allow_external_paths=true must skip the cwd guard; got {hit:?}",
);
}
// A deny rule for the server's tools is enforced by `check_perm` on its own;
// the path guard being switched off does not affect it.
#[tokio::test]
async fn mcp_allow_external_paths_does_not_bypass_deny_rules() {
let config = PermissionConfig {
rules: vec![RuleConfig {
op: OpSpec::Mcp,
pattern: "mcp_tool:indexer:*".to_string(),
effect: Action::Deny,
tool: None,
}],
..Default::default()
};
let (perm, _cwd) = mk_perm(config);
let args: JsonObject =
serde_json::from_str(r#"{"path": "/etc/passwd"}"#).expect("parse args");
let perm_key = "mcp_tool:indexer:scan".to_string();
let result = check_perm(&Some(perm.clone()), &None, "mcp_tool", &perm_key).await;
assert!(
result.is_err(),
"deny rule must block the call even when allow_external_paths=true",
);
let msg = result.unwrap_err().to_string();
assert!(
msg.contains("denied") || msg.contains("Deny") || msg.contains("Blocked"),
"expected deny message, got {msg:?}",
);
// The path guard itself stays silent when external paths are allowed.
let guard_hit = first_external_path(&perm, &args, true);
assert!(
guard_hit.is_none(),
"guard must skip cwd-check on allow_external_paths=true; got {guard_hit:?}",
);
}
// Arguments with no path-like key and no path-shaped value yield no hit.
#[test]
fn mcp_external_path_guard_skips_argless_calls() {
let (perm, _cwd) = mk_perm(PermissionConfig::default());
let args: JsonObject = serde_json::from_str(r#"{"query": "needle"}"#).expect("parse args");
assert!(first_external_path(&perm, &args, false).is_none());
}
// An absolute path inside the temp cwd and a `./`-relative path are accepted.
#[test]
fn mcp_external_path_guard_permits_in_cwd_paths() {
let (perm, cwd) = mk_perm(PermissionConfig::default());
let abs_in = format!("{cwd}/inside.txt");
let args = serde_json::json!({
"path": abs_in,
"paths": ["./relative-inside.rs"],
});
let obj = args.as_object().unwrap().clone();
assert!(first_external_path(&perm, &obj, false).is_none());
}
// Each of the listed key names, holding an external path, trips the guard.
#[test]
fn mcp_external_path_guard_catches_extended_key_names() {
let (perm, _cwd) = mk_perm(PermissionConfig::default());
for key in [
"src",
"src_path",
"source",
"target",
"target_path",
"dest",
"destination",
"input_file",
"output_file",
"working_dir",
"target_dir",
] {
let mut obj = serde_json::Map::new();
obj.insert(
key.to_string(),
serde_json::Value::String("/etc/passwd".to_string()),
);
assert!(
first_external_path(&perm, &obj, false).is_some(),
"extended key {key:?} should be picked up by the guard",
);
}
}
// A path-shaped value under an unknown key is caught; URLs, regexes and
// slash-containing names that do not start like a path are not.
#[test]
fn mcp_external_path_guard_path_shaped_fallback() {
let (perm, _cwd) = mk_perm(PermissionConfig::default());
let args = serde_json::json!({"weird_key_name_my_mcp_uses": "/etc/passwd"});
let obj = args.as_object().unwrap().clone();
assert!(first_external_path(&perm, &obj, false).is_some());
let args = serde_json::json!({
"url": "https://example.com/path",
"regex": ".*/foo",
"name": "foo/bar",
});
let obj = args.as_object().unwrap().clone();
assert!(first_external_path(&perm, &obj, false).is_none());
}
// The guard descends into nested objects and arrays.
#[test]
fn mcp_external_path_guard_recurses_into_nested_args() {
let (perm, _cwd) = mk_perm(PermissionConfig::default());
// Path-like key inside a nested object.
let args = serde_json::json!({"options": {"cwd": "/etc"}});
let obj = args.as_object().unwrap().clone();
assert!(
first_external_path(&perm, &obj, false).is_some(),
"nested whitelisted key should be caught",
);
// Array of objects: the second entry holds the external path.
let args = serde_json::json!({
"edits": [
{"path": "src/in_cwd.rs", "contents": "ok"},
{"path": "/etc/passwd", "contents": "evil"},
],
});
let obj = args.as_object().unwrap().clone();
assert!(
first_external_path(&perm, &obj, false).is_some(),
"external path in array-of-objects should be caught",
);
// Path-shaped string inside an array under unknown keys.
let args = serde_json::json!({"payload": {"items": ["ok", "/etc/shadow"]}});
let obj = args.as_object().unwrap().clone();
assert!(
first_external_path(&perm, &obj, false).is_some(),
"path-shaped scalar in nested array should be caught",
);
// Nested values that are neither under a path key nor path-shaped.
let args = serde_json::json!({
"options": {"url": "https://example.com", "label": "foo/bar"},
"edits": [{"note": "nothing path-shaped"}],
});
let obj = args.as_object().unwrap().clone();
assert!(
first_external_path(&perm, &obj, false).is_none(),
"nested non-path values must not trip the guard",
);
}
// `allow_external_paths` defaults to false and is honoured when set, for both
// the `command` and the `url` forms of the server config.
#[test]
fn mcp_server_config_allow_external_paths_round_trip() {
let cmd_default: McpServerConfig = serde_json::from_str(r#"{"command": "x"}"#).unwrap();
assert!(!cmd_default.allow_external_paths());
let cmd_true: McpServerConfig =
serde_json::from_str(r#"{"command": "x", "allow_external_paths": true}"#).unwrap();
assert!(cmd_true.allow_external_paths());
let url_default: McpServerConfig = serde_json::from_str(r#"{"url": "https://x"}"#).unwrap();
assert!(!url_default.allow_external_paths());
let url_true: McpServerConfig =
serde_json::from_str(r#"{"url": "https://x", "allow_external_paths": true}"#).unwrap();
assert!(url_true.allow_external_paths());
}
}