use std::collections::HashMap;
use std::process::{Command, Stdio};
use std::sync::mpsc;
use std::thread;
use std::time::{Duration, Instant};
use serde_json::{json, Value};
use crate::audit::{audit_call, CallStatus};
use crate::backend;
use crate::errors::{McpError, McpErrorKind};
use crate::session::{looks_like_glob, Session};
/// Timeout, in seconds, applied when a request omits `timeout_secs`.
const DEFAULT_TIMEOUT_SECS: u64 = 60;
/// Smallest accepted `timeout_secs` value (inclusive).
const MIN_TIMEOUT_SECS: u64 = 1;
/// Largest accepted `timeout_secs` value (inclusive): ten minutes.
const MAX_TIMEOUT_SECS: u64 = 10 * 60;
/// Validated arguments of a single `tsafe_run` request, as produced by `parse_args`.
#[derive(Debug)]
struct RunArgs {
    /// Program to execute; `parse_args` guarantees it is non-empty.
    command: String,
    /// Arguments passed to `command`; empty when the request omits `args`.
    args: Vec<String>,
    /// Vault key names whose values are set as environment variables of the child.
    allowed_keys: Vec<String>,
    /// Working directory for the child; when `None`, no directory is set explicitly.
    cwd: Option<String>,
    /// Wall-clock limit in seconds, within `MIN_TIMEOUT_SECS..=MAX_TIMEOUT_SECS`.
    timeout_secs: u64,
}
/// Entry point for the `tsafe_run` tool.
///
/// Validates the request and copies the requested secrets out of the vault.
/// It then runs `command` with those secrets set as environment variables (see
/// `spawn_with_timeout`) and returns the child's captured output as a JSON
/// object with `stdout`, `stderr`, `exit_code`, `duration_ms` and
/// `injected_keys`.
///
/// # Errors
/// - Errors from `parse_args` (returned without an audit record).
/// - `InvalidParams` when an `allowed_keys` entry contains glob characters.
/// - `KeyOutOfScope` when a key is outside the session's configured scope.
/// - Errors from opening the vault, looking up a key, or running the child.
///
/// Every outcome after argument parsing, success or failure, is written to the
/// audit log via `audit_call`.
pub fn call(session: &Session, raw: Value) -> Result<Value, McpError> {
    let args = parse_args(&raw)?;

    // Reject globs and out-of-scope keys before the vault is touched.
    for key in &args.allowed_keys {
        if looks_like_glob(key) {
            let err = McpError::new(
                McpErrorKind::InvalidParams,
                format!(
                    "allowed_keys entry '{key}' contains glob characters; only literal key names are accepted"
                ),
            );
            audit_failure(session, &args.allowed_keys, &err);
            return Err(err);
        }
        if !session.is_in_scope(key) {
            let err = McpError::new(
                McpErrorKind::KeyOutOfScope,
                format!("key '{key}' is outside the configured scope for this server"),
            )
            .with_data(json!({"key": key}));
            audit_failure(session, &args.allowed_keys, &err);
            return Err(err);
        }
    }

    let secrets = match fetch_secrets(session, &args.allowed_keys) {
        Ok(secrets) => secrets,
        Err(e) => {
            audit_failure(session, &args.allowed_keys, &e);
            return Err(e);
        }
    };

    let started_at = Instant::now();
    let outcome = spawn_with_timeout(&args, &secrets);
    // Runs are bounded by `timeout_secs` (at most MAX_TIMEOUT_SECS), so this
    // u128 -> u64 cast cannot truncate in practice.
    let duration_ms = started_at.elapsed().as_millis() as u64;
    // Scrub our plaintext copies as soon as the child no longer needs them.
    wipe_secrets(secrets);

    match outcome {
        Ok(child_outcome) => {
            audit_call(
                session,
                "tsafe_run",
                None,
                args.allowed_keys.clone(),
                Some(child_outcome.exit_code),
                Some(duration_ms),
                CallStatus::Success,
                None,
            );
            Ok(json!({
                "stdout": child_outcome.stdout,
                "stderr": child_outcome.stderr,
                "exit_code": child_outcome.exit_code,
                "duration_ms": duration_ms,
                "injected_keys": args.allowed_keys,
            }))
        }
        Err(e) => {
            audit_call(
                session,
                "tsafe_run",
                None,
                args.allowed_keys.clone(),
                None,
                Some(duration_ms),
                CallStatus::Failure,
                Some(&e.message),
            );
            Err(e)
        }
    }
}

/// Opens the session's vault and copies out the value of every key in `keys`.
///
/// The vault is dropped before this returns. If any lookup fails, the values
/// already copied are wiped before the error is returned, so no plaintext is
/// left behind in freed memory on the failure path.
fn fetch_secrets(session: &Session, keys: &[String]) -> Result<HashMap<String, String>, McpError> {
    let vault = backend::open_vault(session)?;
    let mut secrets = HashMap::with_capacity(keys.len());
    for key in keys {
        match backend::lookup_key(&vault, key) {
            Ok(value) => {
                secrets.insert(key.clone(), value.as_str().to_string());
            }
            Err(e) => {
                wipe_secrets(secrets);
                return Err(e);
            }
        }
    }
    Ok(secrets)
}

/// Overwrites every secret value with zeros and then drops the map.
fn wipe_secrets(secrets: HashMap<String, String>) {
    use zeroize::Zeroize;
    for (_, mut value) in secrets {
        value.zeroize();
    }
}
/// Writes a failed `tsafe_run` record to the audit log.
///
/// `injected` is the list of keys the request asked for. No exit code or
/// duration is recorded, and the error's message is recorded as the reason.
fn audit_failure(session: &Session, injected: &[String], err: &McpError) {
    let keys = Vec::from(injected);
    let (exit_code, duration_ms) = (None, None);
    let reason = Some(&err.message);
    audit_call(
        session,
        "tsafe_run",
        None,
        keys,
        exit_code,
        duration_ms,
        CallStatus::Failure,
        reason,
    );
}
/// Parses and validates the raw JSON arguments of a `tsafe_run` request.
///
/// Accepted fields (any other field is rejected):
/// - `command` (required): non-empty string naming the program to run.
/// - `args` (optional): array of strings; defaults to an empty list.
/// - `allowed_keys` (required): array of strings with at least one entry.
/// - `cwd` (optional): working directory string.
/// - `timeout_secs` (optional): unsigned integer in
///   `MIN_TIMEOUT_SECS..=MAX_TIMEOUT_SECS`; defaults to `DEFAULT_TIMEOUT_SECS`.
///
/// # Errors
/// Returns an `InvalidParams` error describing the first problem found.
fn parse_args(raw: &Value) -> Result<RunArgs, McpError> {
    let obj = raw
        .as_object()
        .ok_or_else(|| McpError::new(McpErrorKind::InvalidParams, "expected an object"))?;

    // Unknown fields are an error rather than being ignored, so misspelled or
    // unsupported options are never silently dropped.
    for k in obj.keys() {
        if !matches!(
            k.as_str(),
            "command" | "args" | "allowed_keys" | "cwd" | "timeout_secs"
        ) {
            return Err(McpError::new(
                McpErrorKind::InvalidParams,
                format!("unknown field '{k}'"),
            ));
        }
    }

    // An absent `command` and a `command` of the wrong type get distinct
    // messages; previously both were reported as "missing".
    let command = match obj.get("command") {
        None => {
            return Err(McpError::new(
                McpErrorKind::InvalidParams,
                "missing 'command'",
            ));
        }
        Some(v) => v
            .as_str()
            .ok_or_else(|| {
                McpError::new(McpErrorKind::InvalidParams, "'command' must be a string")
            })?
            .to_string(),
    };
    if command.is_empty() {
        return Err(McpError::new(
            McpErrorKind::InvalidParams,
            "'command' must be non-empty",
        ));
    }

    let args = obj
        .get("args")
        .map(|v| -> Result<Vec<String>, McpError> {
            let arr = v.as_array().ok_or_else(|| {
                McpError::new(
                    McpErrorKind::InvalidParams,
                    "'args' must be an array of strings",
                )
            })?;
            arr.iter()
                .map(|x| {
                    x.as_str().map(str::to_string).ok_or_else(|| {
                        McpError::new(
                            McpErrorKind::InvalidParams,
                            "'args' entries must be strings",
                        )
                    })
                })
                .collect()
        })
        .transpose()?
        .unwrap_or_default();

    let allowed_keys = obj
        .get("allowed_keys")
        .ok_or_else(|| McpError::new(McpErrorKind::InvalidParams, "missing 'allowed_keys'"))?;
    let allowed_keys_arr = allowed_keys.as_array().ok_or_else(|| {
        McpError::new(
            McpErrorKind::InvalidParams,
            "'allowed_keys' must be an array",
        )
    })?;
    if allowed_keys_arr.is_empty() {
        return Err(McpError::new(
            McpErrorKind::InvalidParams,
            "'allowed_keys' must have minItems=1",
        ));
    }
    let allowed_keys: Vec<String> = allowed_keys_arr
        .iter()
        .map(|v| {
            v.as_str().map(str::to_string).ok_or_else(|| {
                McpError::new(
                    McpErrorKind::InvalidParams,
                    "'allowed_keys' entries must be strings",
                )
            })
        })
        .collect::<Result<_, _>>()?;

    let cwd = obj
        .get("cwd")
        .map(|v| {
            v.as_str()
                .map(str::to_string)
                .ok_or_else(|| McpError::new(McpErrorKind::InvalidParams, "'cwd' must be a string"))
        })
        .transpose()?;

    let timeout_secs = match obj.get("timeout_secs") {
        Some(v) => {
            // `as_u64` rejects negatives, floats and non-numbers alike.
            let n = v.as_u64().ok_or_else(|| {
                McpError::new(
                    McpErrorKind::InvalidParams,
                    "'timeout_secs' must be a positive integer",
                )
            })?;
            if !(MIN_TIMEOUT_SECS..=MAX_TIMEOUT_SECS).contains(&n) {
                return Err(McpError::new(
                    McpErrorKind::InvalidParams,
                    format!(
                        "'timeout_secs' must be between {MIN_TIMEOUT_SECS} and {MAX_TIMEOUT_SECS}"
                    ),
                ));
            }
            n
        }
        None => DEFAULT_TIMEOUT_SECS,
    };

    Ok(RunArgs {
        command,
        args,
        allowed_keys,
        cwd,
        timeout_secs,
    })
}
/// Captured result of a child process that finished before its timeout.
#[derive(Debug)]
struct ChildOutcome {
    /// Everything the child wrote to stdout, decoded lossily as UTF-8.
    stdout: String,
    /// Everything the child wrote to stderr, decoded lossily as UTF-8.
    stderr: String,
    /// Exit status code, or `-1` when the platform reported none.
    exit_code: i32,
}
/// Spawns `args.command` with `secrets` set as environment variables and waits
/// up to `args.timeout_secs` seconds for it to finish, capturing its output.
///
/// stdin is connected to the null device. stdout and stderr are captured in
/// full and decoded lossily as UTF-8. The exit code is `-1` when the platform
/// reports none.
///
/// # Errors
/// - `InternalError` if the process cannot be spawned, if waiting on it fails,
///   or if the monitoring thread dies before reporting a result.
/// - `RunTimeout` if the process is still running when the timeout expires;
///   it is then force-killed by PID via `kill_pid`.
fn spawn_with_timeout(
    args: &RunArgs,
    secrets: &HashMap<String, String>,
) -> Result<ChildOutcome, McpError> {
    let mut cmd = Command::new(&args.command);
    cmd.args(&args.args)
        .stdin(Stdio::null())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped());
    if let Some(dir) = &args.cwd {
        cmd.current_dir(dir);
    }
    // NOTE(review): `Command` keeps its own copies of these values and drops
    // them without zeroizing; the caller's wipe does not reach them.
    for (name, value) in secrets {
        cmd.env(name, value);
    }

    let mut child = cmd.spawn().map_err(|e| {
        McpError::new(
            McpErrorKind::InternalError,
            format!("failed to spawn '{}': {e}", args.command),
        )
    })?;
    let pid = child.id();
    let stdout_pipe = child.stdout.take();
    let stderr_pipe = child.stderr.take();

    let (tx, rx) = mpsc::channel();
    let monitor = thread::spawn(move || {
        // stderr is drained on its own thread. Reading the two pipes one after
        // the other deadlocks when the child fills the stderr pipe buffer while
        // we are still blocked waiting for stdout EOF.
        let stderr_reader = thread::spawn(move || read_pipe(stderr_pipe));
        let out_buf = read_pipe(stdout_pipe);
        let err_buf = stderr_reader.join().unwrap_or_default();
        let status = child.wait();
        let _ = tx.send((status, out_buf, err_buf));
    });

    match rx.recv_timeout(Duration::from_secs(args.timeout_secs)) {
        Ok((status, out_buf, err_buf)) => {
            let _ = monitor.join();
            let status = status.map_err(|e| {
                McpError::new(
                    McpErrorKind::InternalError,
                    format!("child wait failed: {e}"),
                )
            })?;
            Ok(ChildOutcome {
                stdout: String::from_utf8_lossy(&out_buf).into_owned(),
                stderr: String::from_utf8_lossy(&err_buf).into_owned(),
                exit_code: status.code().unwrap_or(-1),
            })
        }
        Err(mpsc::RecvTimeoutError::Timeout) => {
            // Best effort. The monitor thread is left detached; it reaps the
            // child once the pipes close.
            // NOTE(review): if the monitor reaps the child between the timeout
            // and this kill, the PID could in principle be reused.
            let _ = kill_pid(pid);
            Err(McpError::new(
                McpErrorKind::RunTimeout,
                format!("command exceeded {}s timeout", args.timeout_secs),
            ))
        }
        Err(mpsc::RecvTimeoutError::Disconnected) => {
            // The monitor thread died without reporting (e.g. a panic), which
            // is not a timeout. Make sure the child does not outlive this call.
            let _ = kill_pid(pid);
            Err(McpError::new(
                McpErrorKind::InternalError,
                "child monitor thread exited unexpectedly",
            ))
        }
    }
}

/// Reads `pipe` to EOF and returns the bytes read; `None` yields an empty buffer.
///
/// Read errors are ignored (best effort); whatever was read before the error
/// is kept.
fn read_pipe<R: std::io::Read>(pipe: Option<R>) -> Vec<u8> {
    let mut buf = Vec::new();
    if let Some(mut reader) = pipe {
        let _ = reader.read_to_end(&mut buf);
    }
    buf
}
/// Force-kills process `pid` by running `kill -9 <pid>`.
///
/// The helper's stdin, stdout and stderr are all connected to the null device.
/// `Command::status` would otherwise inherit this process's own streams, and
/// anything `kill` printed would end up in them.
///
/// # Errors
/// Returns an error if `kill` cannot be started or exits with a non-zero status.
#[cfg(unix)]
fn kill_pid(pid: u32) -> std::io::Result<()> {
    let status = Command::new("kill")
        .arg("-9")
        .arg(pid.to_string())
        .stdin(Stdio::null())
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .status()?;
    if status.success() {
        Ok(())
    } else {
        Err(std::io::Error::other("kill returned non-zero"))
    }
}
/// Force-kills process `pid` by running `taskkill /F /PID <pid>`.
///
/// The helper's stdin, stdout and stderr are all connected to the null device.
/// `Command::status` would otherwise inherit this process's own streams, and
/// `taskkill`'s status output would be written into them.
///
/// # Errors
/// Returns an error if `taskkill` cannot be started or exits with a non-zero
/// status.
#[cfg(windows)]
fn kill_pid(pid: u32) -> std::io::Result<()> {
    let status = Command::new("taskkill")
        .args(["/F", "/PID"])
        .arg(pid.to_string())
        .stdin(Stdio::null())
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .status()?;
    if status.success() {
        Ok(())
    } else {
        Err(std::io::Error::other("taskkill returned non-zero"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::path::PathBuf;

    /// Builds a session scoped to `demo/*` with no deny list and a vault path
    /// of `nonexistent`.
    fn session() -> Session {
        Session {
            profile: "demo".to_string(),
            allowed_globs: vec!["demo/*".to_string()],
            denied_globs: vec![],
            contract: None,
            allow_reveal: false,
            audit_source: "mcp:test:1".to_string(),
            pid: 1,
            require_agent: false,
            vault_path: PathBuf::from("nonexistent"),
        }
    }

    /// Runs `f` with `TSAFE_VAULT_DIR` pointing at an empty temporary directory
    /// that lives for the duration of the call.
    fn isolated<F: FnOnce()>(f: F) {
        let tmp = tempfile::tempdir().unwrap();
        let vault_dir = tmp.path().join("vaults");
        std::fs::create_dir_all(&vault_dir).unwrap();
        temp_env::with_var("TSAFE_VAULT_DIR", Some(vault_dir.as_os_str()), f);
    }

    #[test]
    fn parse_args_rejects_non_object() {
        let err = parse_args(&json!(["echo"])).unwrap_err();
        assert_eq!(err.kind, McpErrorKind::InvalidParams);
        assert!(err.message.contains("expected an object"));
    }

    #[test]
    fn parse_args_requires_command_and_keys() {
        let err = parse_args(&json!({})).unwrap_err();
        assert_eq!(err.kind, McpErrorKind::InvalidParams);
        let err = parse_args(&json!({"command": "echo"})).unwrap_err();
        assert!(err.message.contains("allowed_keys"));
    }

    #[test]
    fn parse_args_applies_defaults() {
        let Ok(args) = parse_args(&json!({"command": "echo", "allowed_keys": ["demo/a"]})) else {
            panic!("minimal valid arguments should parse");
        };
        assert_eq!(args.command, "echo");
        assert!(args.args.is_empty());
        assert_eq!(args.allowed_keys, vec!["demo/a".to_string()]);
        assert!(args.cwd.is_none());
        assert_eq!(args.timeout_secs, DEFAULT_TIMEOUT_SECS);
    }

    #[test]
    fn parse_args_rejects_empty_allowed_keys() {
        let err = parse_args(&json!({"command": "x", "allowed_keys": []})).unwrap_err();
        assert_eq!(err.kind, McpErrorKind::InvalidParams);
    }

    #[test]
    fn parse_args_rejects_unknown_field() {
        let err = parse_args(&json!({"command": "x", "allowed_keys": ["k"], "shell": "/bin/sh"}))
            .unwrap_err();
        assert!(err.message.contains("unknown field"));
    }

    /// Out-of-range timeouts are rejected, not clamped.
    #[test]
    fn parse_args_rejects_out_of_range_timeout() {
        let err = parse_args(&json!({"command": "x", "allowed_keys": ["k"], "timeout_secs": 0}))
            .unwrap_err();
        assert!(err.message.contains("timeout"));
        let err = parse_args(&json!({"command": "x", "allowed_keys": ["k"], "timeout_secs": 9999}))
            .unwrap_err();
        assert!(err.message.contains("timeout"));
    }

    /// Both ends of the timeout range are inclusive.
    #[test]
    fn parse_args_accepts_timeout_bounds() {
        for secs in [MIN_TIMEOUT_SECS, MAX_TIMEOUT_SECS] {
            let Ok(args) = parse_args(
                &json!({"command": "x", "allowed_keys": ["k"], "timeout_secs": secs}),
            ) else {
                panic!("timeout_secs={secs} should be accepted");
            };
            assert_eq!(args.timeout_secs, secs);
        }
    }

    #[test]
    fn out_of_scope_key_returns_key_out_of_scope() {
        isolated(|| {
            let err = call(
                &session(),
                json!({
                    "command": "echo",
                    "args": ["hi"],
                    "allowed_keys": ["other/forbidden"]
                }),
            )
            .unwrap_err();
            assert_eq!(err.kind, McpErrorKind::KeyOutOfScope);
        });
    }

    #[test]
    fn glob_in_allowed_keys_returns_invalid_params() {
        isolated(|| {
            let err = call(
                &session(),
                json!({
                    "command": "echo",
                    "allowed_keys": ["demo/*"]
                }),
            )
            .unwrap_err();
            assert_eq!(err.kind, McpErrorKind::InvalidParams);
            assert!(err.message.contains("glob"));
        });
    }
}