use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use async_trait::async_trait;
use serde_json::{json, Value};
use tokio::io::AsyncBufReadExt;
use tokio::process::Command;
use tokio_util::sync::CancellationToken;
use crate::shell_risk::{classify_shell_command, ShellRiskLevel};
use crate::tools::{
builtin_tool_specs, fs_glob_bounded, ToolFailure, ToolFailureKind, ToolInvocation,
ToolOutcome, ToolRuntime, ToolRuntimeError, ToolSpec,
};
use crate::tools::approval::{is_read_only, ApprovalGate};
/// Callback used to push streaming events (e.g. `bash_stdout_line`) out of the runtime.
/// It is cloned into spawned reader tasks, which is why it is an `Arc` that is `Send + Sync`.
pub type EmitFn = Arc<dyn Fn(Value) + Send + Sync + 'static>;
/// Construction parameters for `LocalToolRuntime`.
pub struct LocalToolConfig {
/// Working directory used to resolve relative paths and to run commands.
/// `None` or an empty path falls back to the process's current directory, then to `/`.
pub cwd: Option<PathBuf>,
/// Decides which tools are advertised and whether non-auto-approved invocations may run.
pub approval: Arc<dyn ApprovalGate>,
/// Sink for streaming events emitted while tools run.
pub emit: EmitFn,
}
/// Tool runtime that executes the builtin tools directly on the local machine.
/// Cloning is cheap: the gate and emitter are shared `Arc`s; only `cwd` is copied.
#[derive(Clone)]
pub struct LocalToolRuntime {
// Resolved working directory (never empty; see `LocalToolRuntime::new`).
cwd: PathBuf,
// Approval policy consulted before running tools that are not auto-allowed.
approval: Arc<dyn ApprovalGate>,
// Streaming event sink, cloned into bash reader tasks.
emit: EmitFn,
}
impl LocalToolRuntime {
pub fn new(config: LocalToolConfig) -> Self {
let cwd = config.cwd
.filter(|p| !p.as_os_str().is_empty())
.or_else(|| std::env::current_dir().ok())
.unwrap_or_else(|| PathBuf::from("/"));
Self { cwd, approval: config.approval, emit: config.emit }
}
fn resolve(&self, path: &str) -> PathBuf {
let p = Path::new(path);
if p.is_absolute() { p.to_path_buf() } else { self.cwd.join(p) }
}
async fn gate(
&self,
inv: &ToolInvocation,
cancel: Option<&CancellationToken>,
) -> Result<(), String> {
if inv.name == "bash" {
let cmd = inv.input.get("command").and_then(Value::as_str).unwrap_or("");
let decision = classify_shell_command(cmd);
match decision.level {
ShellRiskLevel::Blocked => {
return Err(format!("命令在禁止清单上,已拒绝:{}", decision.reason));
}
ShellRiskLevel::SafeRead => return Ok(()),
ShellRiskLevel::BoundedWrite
if self.approval.advertise_mutating_tools() =>
{
return Ok(());
}
_ => {}
}
} else if is_read_only(&inv.name) {
return Ok(());
}
let approved = if let Some(tok) = cancel {
tokio::select! {
biased;
_ = tok.cancelled() => return Err("已取消".into()),
result = self.approval.approve(inv) => result,
}
} else {
self.approval.approve(inv).await
};
if approved { Ok(()) } else { Err("操作被拒绝".into()) }
}
}
#[async_trait]
impl ToolRuntime for LocalToolRuntime {
fn specs(&self) -> Vec<ToolSpec> {
let all = builtin_tool_specs();
if self.approval.advertise_mutating_tools() {
all
} else {
all.into_iter().filter(|s| is_read_only(&s.name)).collect()
}
}
async fn invoke(&self, inv: ToolInvocation) -> Result<ToolOutcome, ToolRuntimeError> {
self.invoke_cancellable(inv, None).await
}
async fn invoke_cancellable(
&self,
inv: ToolInvocation,
cancel: Option<&CancellationToken>,
) -> Result<ToolOutcome, ToolRuntimeError> {
if let Err(reason) = self.gate(&inv, cancel).await {
return Ok(ToolOutcome {
output: Err(ToolFailure::new(ToolFailureKind::Denied, reason)),
attachments: vec![],
});
}
match inv.name.as_str() {
"bash" => bash_invoke(inv, cancel, &self.cwd, self.emit.clone()).await,
"read" => read_invoke(inv, self).await,
"write" => write_invoke(inv, self).await,
"edit" => edit_invoke(inv, self).await,
"glob" => glob_invoke(inv, self).await,
"grep" => grep_invoke(inv, self).await,
other => Err(ToolRuntimeError::UnknownTool(other.into())),
}
}
}
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before 1970.
fn epoch_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
async fn bash_invoke(
inv: ToolInvocation,
cancel: Option<&CancellationToken>,
cwd: &Path,
emit: EmitFn,
) -> Result<ToolOutcome, ToolRuntimeError> {
let command = req_str(&inv, "command")?;
let soft_ms: u64 = inv.input.get("soft_timeout_ms")
.and_then(|v| v.as_u64())
.unwrap_or(10_000);
let hard_ms: u64 = inv.input.get("timeout_ms")
.and_then(|v| v.as_u64())
.unwrap_or(120_000)
.min(3_600_000);
let last_out = Arc::new(AtomicU64::new(epoch_ms()));
let mut child = Command::new("/bin/sh")
.args(["-lc", command])
.current_dir(cwd)
.kill_on_drop(true)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.map_err(|e| ToolRuntimeError::Runtime(format!("spawn failed: {e}")))?;
let raw_stdout = child.stdout.take().expect("stdout piped");
let raw_stderr = child.stderr.take().expect("stderr piped");
let act1 = last_out.clone();
let emit_out = emit.clone();
let stdout_task = tokio::spawn(async move {
let mut lines = tokio::io::BufReader::new(raw_stdout).lines();
let mut buf = String::new();
while let Ok(Some(line)) = lines.next_line().await {
emit_out(json!({ "type": "bash_stdout_line", "line": line, "stream": "stdout" }));
act1.store(epoch_ms(), Ordering::Relaxed);
buf.push_str(&line);
buf.push('\n');
}
buf
});
let act2 = last_out.clone();
let emit_err = emit.clone();
let stderr_task = tokio::spawn(async move {
let mut lines = tokio::io::BufReader::new(raw_stderr).lines();
let mut buf = String::new();
while let Ok(Some(line)) = lines.next_line().await {
emit_err(json!({ "type": "bash_stdout_line", "line": line, "stream": "stderr" }));
act2.store(epoch_ms(), Ordering::Relaxed);
buf.push_str(&line);
buf.push('\n');
}
buf
});
let watcher_ts = last_out.clone();
let soft_watcher = async move {
let start = epoch_ms();
loop {
tokio::time::sleep(Duration::from_millis(500)).await;
let now = epoch_ms();
if now.saturating_sub(start) >= soft_ms
&& now.saturating_sub(watcher_ts.load(Ordering::Relaxed)) >= soft_ms
{
return (now.saturating_sub(start), now.saturating_sub(watcher_ts.load(Ordering::Relaxed)));
}
}
};
let timed = async {
let (out, err) = tokio::join!(
async { stdout_task.await.unwrap_or_default() },
async { stderr_task.await.unwrap_or_default() },
);
let status = child.wait().await;
(out, err, status)
};
let hard_timer = tokio::time::sleep(Duration::from_millis(hard_ms));
let soft_err = |tot: u64, sil: u64| ToolOutcome {
output: Err(ToolFailure::new(ToolFailureKind::Timeout, format!(
"Soft timeout: no output for {sil}ms (total: {tot}ms). \
Pass `soft_timeout_ms` or `timeout_ms` to extend."
))),
attachments: vec![],
};
let hard_err = || ToolOutcome {
output: Err(ToolFailure::new(ToolFailureKind::Timeout, format!(
"Hard timeout: command did not finish in {hard_ms}ms. \
Pass `timeout_ms` to extend (max 3 600 000)."
))),
attachments: vec![],
};
let result: Result<(String, String, _), ToolOutcome> = if let Some(tok) = cancel {
tokio::select! {
v = timed => Ok(v),
(tot, sil) = soft_watcher => Err(soft_err(tot, sil)),
_ = hard_timer => Err(hard_err()),
_ = tok.cancelled() => Err(ToolOutcome {
output: Err(ToolFailure::new(ToolFailureKind::Runtime, "cancelled")),
attachments: vec![],
}),
}
} else {
tokio::select! {
v = timed => Ok(v),
(tot, sil) = soft_watcher => Err(soft_err(tot, sil)),
_ = hard_timer => Err(hard_err()),
}
};
let (stdout, stderr, status_result) = match result {
Err(outcome) => return Ok(outcome),
Ok(v) => v,
};
let exit_code = status_result.map(|s| s.code().unwrap_or(-1)).unwrap_or(-1);
if exit_code != 0 {
Ok(ToolOutcome {
output: Err(ToolFailure::new(
ToolFailureKind::NonZeroExit,
truncate(format!("exit {exit_code}\nstdout: {stdout}\nstderr: {stderr}")),
)),
attachments: vec![],
})
} else {
Ok(ToolOutcome {
output: Ok(json!({
"stdout": truncate(stdout),
"stderr": truncate(stderr),
"exit_code": exit_code,
})),
attachments: vec![],
})
}
}
/// Reads a UTF-8 text file (resolved against the runtime cwd), optionally windowed by lines.
///
/// Input fields:
/// - `path`: required.
/// - `offset`: 0-based line index to start from (default 0).
/// - `limit`: maximum number of lines, clamped to 1..=2000 (default: all remaining lines).
///
/// The result reports 1-based `start_line`/`end_line` (null when the window is empty),
/// `total_lines`, and `truncated`, which is true when a `limit` left lines unread. A
/// trailing newline is kept only when the window reaches the end of a file that ends with
/// one. A missing file becomes a `NotFound` failure; other read errors (including non-UTF-8
/// content) become a `Runtime` failure.
///
/// # Errors
/// `ToolRuntimeError::InvalidInput` when `path` is missing or empty.
async fn read_invoke(inv: ToolInvocation, rt: &LocalToolRuntime) -> Result<ToolOutcome, ToolRuntimeError> {
    let path = req_str(&inv, "path")?;
    let resolved = rt.resolve(path);
    match tokio::fs::read_to_string(&resolved).await {
        Ok(content) => {
            // Split once; the count, the window, and the end-of-file check all reuse it.
            let lines: Vec<&str> = content.lines().collect();
            let total = lines.len();
            // `offset`/`limit` are model-supplied. Saturate instead of overflowing (a debug
            // panic, or a wrong `truncated` flag in release). A u64 that does not fit in
            // usize (32-bit targets) is treated as "past the end".
            let offset = inv
                .input
                .get("offset")
                .and_then(Value::as_u64)
                .map_or(0, |v| usize::try_from(v).unwrap_or(usize::MAX));
            let limit = inv
                .input
                .get("limit")
                .and_then(Value::as_u64)
                .map(|v| v.clamp(1, 2_000) as usize);
            let start = offset.min(total);
            let stop = limit.map_or(total, |n| start.saturating_add(n).min(total));
            let selected = &lines[start..stop];
            let text = if selected.is_empty() {
                String::new()
            } else {
                let mut t = selected.join("\n");
                if content.ends_with('\n') && stop == total {
                    t.push('\n');
                }
                t
            };
            Ok(ToolOutcome {
                output: Ok(json!({
                    "path": resolved.to_string_lossy(),
                    "content": truncate(text),
                    "offset": offset,
                    "limit": limit,
                    "start_line": if selected.is_empty() { Value::Null } else { json!(start + 1) },
                    "end_line": if selected.is_empty() { Value::Null } else { json!(stop) },
                    "total_lines": total,
                    "truncated": limit.is_some() && stop < total,
                })),
                attachments: vec![],
            })
        }
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(ToolOutcome {
            output: Err(ToolFailure::new(ToolFailureKind::NotFound,
                format!("file not found: {}", resolved.display()))),
            attachments: vec![],
        }),
        Err(e) => Ok(ToolOutcome {
            output: Err(ToolFailure::new(ToolFailureKind::Runtime, format!("read error: {e}"))),
            attachments: vec![],
        }),
    }
}
async fn write_invoke(inv: ToolInvocation, rt: &LocalToolRuntime) -> Result<ToolOutcome, ToolRuntimeError> {
let path = req_str(&inv, "path")?;
let content = req_str(&inv, "content")?;
let resolved = rt.resolve(path);
if let Some(parent) = resolved.parent() {
if !parent.as_os_str().is_empty() {
tokio::fs::create_dir_all(parent).await
.map_err(|e| ToolRuntimeError::Runtime(format!("mkdir: {e}")))?;
}
}
tokio::fs::write(&resolved, content).await
.map_err(|e| ToolRuntimeError::Runtime(format!("write error: {e}")))?;
Ok(ToolOutcome {
output: Ok(json!({ "path": resolved.to_string_lossy(), "written": true })),
attachments: vec![],
})
}
async fn edit_invoke(inv: ToolInvocation, rt: &LocalToolRuntime) -> Result<ToolOutcome, ToolRuntimeError> {
let path = req_str(&inv, "path")?;
let old_string = req_str(&inv, "old_string")?;
let new_string = inv.input.get("new_string").and_then(Value::as_str).unwrap_or("");
let replace_all = inv.input.get("replace_all").and_then(Value::as_bool).unwrap_or(false);
let resolved = rt.resolve(path);
let content = match tokio::fs::read_to_string(&resolved).await {
Ok(c) => c,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(ToolOutcome {
output: Err(ToolFailure::new(ToolFailureKind::NotFound,
format!("file not found: {}", resolved.display()))),
attachments: vec![],
}),
Err(e) => return Err(ToolRuntimeError::Runtime(e.to_string())),
};
let occurrences = content.matches(old_string).count();
if occurrences == 0 {
return Ok(ToolOutcome {
output: Err(ToolFailure::new(ToolFailureKind::InvalidInput,
format!("old_string not found in {}", resolved.display()))),
attachments: vec![],
});
}
if !replace_all && occurrences > 1 {
return Ok(ToolOutcome {
output: Err(ToolFailure::new(ToolFailureKind::InvalidInput,
format!("{occurrences} occurrences; pass replace_all=true"))),
attachments: vec![],
});
}
let replaced = if replace_all { occurrences } else { 1 };
let new_content = if replace_all {
content.replace(old_string, new_string)
} else {
content.replacen(old_string, new_string, 1)
};
tokio::fs::write(&resolved, new_content).await
.map_err(|e| ToolRuntimeError::Runtime(e.to_string()))?;
Ok(ToolOutcome {
output: Ok(json!({
"path": resolved.to_string_lossy(),
"replaced": replaced,
"old_lines": old_string.lines().count(),
"new_lines": new_string.lines().count(),
})),
attachments: vec![],
})
}
/// Finds paths matching `pattern` under `path` (default: the runtime cwd).
///
/// Delegates to `fs_glob_bounded`, which returns the matches plus a flag that is reported
/// as `truncated`. Presumably that flag means the match list was capped; confirm against
/// `fs_glob_bounded`.
///
/// # Errors
/// `ToolRuntimeError::InvalidInput` when `pattern` is missing or empty.
async fn glob_invoke(inv: ToolInvocation, rt: &LocalToolRuntime) -> Result<ToolOutcome, ToolRuntimeError> {
    let pattern = req_str(&inv, "pattern")?.to_string();
    let base = inv
        .input
        .get("path")
        .and_then(Value::as_str)
        .filter(|s| !s.is_empty())
        .map_or_else(|| rt.cwd.clone(), |p| rt.resolve(p));
    let (matches, truncated) = fs_glob_bounded(&pattern, &base);
    let count = matches.len();
    let listing = json!({
        "pattern": pattern,
        "count": count,
        "matches": matches,
        "truncated": truncated,
    });
    Ok(ToolOutcome { output: Ok(listing), attachments: vec![] })
}
/// Recursively searches `path` (default: the runtime cwd) for `pattern` with `grep -rn`.
///
/// Input fields:
/// - `pattern`: required, passed via `-e` so a leading `-` is safe.
/// - `case_insensitive`: adds `-i`.
/// - `path`: optional search root.
///
/// Common dependency/build directories are excluded. The search is limited to 30 s, and a
/// timed-out grep is killed (`kill_on_drop`) rather than left running.
///
/// Exit status 0/1 (matches / no matches) returns `matches`. Any other status (2, or
/// termination by a signal) returns a `Runtime` failure only when grep printed nothing. If
/// it did print matches (e.g. `-r` hit some unreadable files), those matches are kept and
/// grep's stderr is added under `warnings`.
///
/// # Errors
/// `ToolRuntimeError::InvalidInput` when `pattern` is missing or empty.
/// `ToolRuntimeError::Runtime` when grep cannot be spawned.
async fn grep_invoke(inv: ToolInvocation, rt: &LocalToolRuntime) -> Result<ToolOutcome, ToolRuntimeError> {
    let pattern = req_str(&inv, "pattern")?.to_string();
    let ci = inv.input.get("case_insensitive").and_then(Value::as_bool).unwrap_or(false);
    let search = match inv.input.get("path").and_then(Value::as_str).filter(|s| !s.is_empty()) {
        Some(p) => rt.resolve(p),
        None => rt.cwd.clone(),
    };
    let mut cmd = Command::new("grep");
    cmd.arg("-rn");
    if ci {
        cmd.arg("-i");
    }
    cmd.args([
        "--exclude-dir=node_modules",
        "--exclude-dir=target",
        "--exclude-dir=.git",
        "--exclude-dir=dist",
        "--exclude-dir=build",
        "--exclude-dir=__pycache__",
        "--exclude-dir=.venv",
        "--exclude-dir=vendor",
        "--exclude-dir=.next",
    ]);
    cmd.arg("-e").arg(&pattern).arg("--").arg(&search);
    cmd.current_dir(&rt.cwd);
    // Without this, dropping the `output()` future on timeout leaves grep running orphaned.
    cmd.kill_on_drop(true);
    match tokio::time::timeout(Duration::from_secs(30), cmd.output()).await {
        Err(_) => Ok(ToolOutcome {
            output: Err(ToolFailure::new(ToolFailureKind::Timeout, "grep timed out after 30s")),
            attachments: vec![],
        }),
        Ok(Err(e)) => Err(ToolRuntimeError::Runtime(format!("grep spawn failed: {e}"))),
        Ok(Ok(out)) => {
            let stdout = String::from_utf8_lossy(&out.stdout).into_owned();
            // 0 = matches, 1 = no matches. Exit 2 signals an error, which `grep -r` also
            // reports when only some files were unreadable. `None` means killed by a signal.
            let clean_exit = matches!(out.status.code(), Some(0) | Some(1));
            if !clean_exit && stdout.is_empty() {
                let stderr = String::from_utf8_lossy(&out.stderr);
                return Ok(ToolOutcome {
                    output: Err(ToolFailure::new(ToolFailureKind::Runtime,
                        truncate(format!("grep error: {stderr}")))),
                    attachments: vec![],
                });
            }
            let mut payload = json!({
                "pattern": pattern,
                "matches": truncate(stdout),
            });
            if !clean_exit {
                // Partial results: keep the matches and surface what grep complained about.
                let stderr = String::from_utf8_lossy(&out.stderr).into_owned();
                payload["warnings"] = json!(truncate(stderr));
            }
            Ok(ToolOutcome {
                output: Ok(payload),
                attachments: vec![],
            })
        }
    }
}
/// Maximum number of characters (Unicode scalar values) returned by any single tool field.
const MAX_TOOL_CHARS: usize = 60_000;

/// Caps `s` at `MAX_TOOL_CHARS` characters.
///
/// Strings within the limit are returned unchanged. Longer ones keep their first
/// `MAX_TOOL_CHARS` characters, followed by a note giving the original character count. The
/// cut is made on a char boundary, so multi-byte text is never split mid-character.
fn truncate(s: String) -> String {
    // Byte offset of the first character past the limit; `None` means `s` already fits.
    let cut = s.char_indices().nth(MAX_TOOL_CHARS).map(|(idx, _)| idx);
    match cut {
        None => s,
        Some(idx) => {
            let total = MAX_TOOL_CHARS + s[idx..].chars().count();
            let mut kept = s;
            kept.truncate(idx);
            format!(
                "{kept}\n\n[output truncated: {total} chars total, showing first {MAX_TOOL_CHARS}. \
                 Use a narrower command, path, or offset/limit to see more.]"
            )
        }
    }
}
/// Fetches a required, non-empty string field `key` from `inv.input`.
///
/// # Errors
/// `ToolRuntimeError::InvalidInput` naming the tool when the field is absent, not a string,
/// or the empty string.
fn req_str<'a>(inv: &'a ToolInvocation, key: &str) -> Result<&'a str, ToolRuntimeError> {
    match inv.input.get(key).and_then(Value::as_str) {
        Some(value) if !value.is_empty() => Ok(value),
        _ => Err(ToolRuntimeError::InvalidInput {
            tool: inv.name.clone(),
            message: format!("missing field `{key}`"),
        }),
    }
}