1use std::path::{Path, PathBuf};
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::sync::{Arc, Mutex};
4use std::time::{Duration, SystemTime, UNIX_EPOCH};
5
6use async_trait::async_trait;
7use serde_json::{json, Value};
8use tokio::io::AsyncBufReadExt;
9use tokio::process::Command;
10use tokio_util::sync::CancellationToken;
11
12use crate::shell_risk::{classify_shell_command, ShellRiskLevel};
13use crate::tools::{
14 builtin_tool_specs, fs_glob_bounded, ToolFailure, ToolFailureKind, ToolInvocation,
15 ToolOutcome, ToolRuntime, ToolRuntimeError, ToolSpec,
16};
17use crate::tools::approval::{is_read_only, ApprovalGate};
18
19pub type EmitFn = Arc<dyn Fn(Value) + Send + Sync + 'static>;
23
24pub struct LocalToolConfig {
26 pub cwd: Option<PathBuf>,
30 pub approval: Arc<dyn ApprovalGate>,
34 pub emit: EmitFn,
38}
39
40#[derive(Clone)]
47pub struct LocalToolRuntime {
48 cwd: PathBuf,
49 approval: Arc<dyn ApprovalGate>,
50 emit: EmitFn,
51}
52
53impl LocalToolRuntime {
54 pub fn new(config: LocalToolConfig) -> Self {
55 let cwd = config.cwd
56 .filter(|p| !p.as_os_str().is_empty())
57 .or_else(|| std::env::current_dir().ok())
58 .unwrap_or_else(|| PathBuf::from("/"));
59 Self { cwd, approval: config.approval, emit: config.emit }
60 }
61
62 fn resolve(&self, path: &str) -> PathBuf {
63 let p = Path::new(path);
64 if p.is_absolute() { p.to_path_buf() } else { self.cwd.join(p) }
65 }
66
67 async fn gate(
73 &self,
74 inv: &ToolInvocation,
75 cancel: Option<&CancellationToken>,
76 ) -> Result<(), String> {
77 if inv.name == "bash" {
78 let cmd = inv.input.get("command").and_then(Value::as_str).unwrap_or("");
79 let decision = classify_shell_command(cmd);
80 match decision.level {
81 ShellRiskLevel::Blocked => {
82 return Err(format!("命令在禁止清单上,已拒绝:{}", decision.reason));
83 }
84 ShellRiskLevel::SafeRead => return Ok(()),
85 ShellRiskLevel::BoundedWrite
86 if self.approval.advertise_mutating_tools() =>
87 {
88 return Ok(());
89 }
90 _ => {}
91 }
92 } else if is_read_only(&inv.name) {
93 return Ok(());
94 }
95
96 let approved = if let Some(tok) = cancel {
98 tokio::select! {
99 biased;
100 _ = tok.cancelled() => return Err("已取消".into()),
101 result = self.approval.approve(inv) => result,
102 }
103 } else {
104 self.approval.approve(inv).await
105 };
106
107 if approved { Ok(()) } else { Err("操作被拒绝".into()) }
108 }
109}
110
111#[async_trait]
112impl ToolRuntime for LocalToolRuntime {
113 fn specs(&self) -> Vec<ToolSpec> {
114 let all = builtin_tool_specs();
115 if self.approval.advertise_mutating_tools() {
116 all
117 } else {
118 all.into_iter().filter(|s| is_read_only(&s.name)).collect()
119 }
120 }
121
122 async fn invoke(&self, inv: ToolInvocation) -> Result<ToolOutcome, ToolRuntimeError> {
123 self.invoke_cancellable(inv, None).await
124 }
125
126 async fn invoke_cancellable(
127 &self,
128 inv: ToolInvocation,
129 cancel: Option<&CancellationToken>,
130 ) -> Result<ToolOutcome, ToolRuntimeError> {
131 if let Err(reason) = self.gate(&inv, cancel).await {
132 return Ok(ToolOutcome {
133 output: Err(ToolFailure::new(ToolFailureKind::Denied, reason)),
134 attachments: vec![],
135 });
136 }
137 match inv.name.as_str() {
138 "bash" => bash_invoke(inv, cancel, &self.cwd, self.emit.clone()).await,
139 "read" => read_invoke(inv, self).await,
140 "write" => write_invoke(inv, self).await,
141 "edit" => edit_invoke(inv, self).await,
142 "glob" => glob_invoke(inv, self).await,
143 "grep" => grep_invoke(inv, self).await,
144 other => Err(ToolRuntimeError::UnknownTool(other.into())),
145 }
146 }
147}
148
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A clock set before 1970 yields 0 instead of panicking.
fn epoch_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or(Duration::ZERO);
    since_epoch.as_millis() as u64
}
157
158async fn bash_invoke(
159 inv: ToolInvocation,
160 cancel: Option<&CancellationToken>,
161 cwd: &Path,
162 emit: EmitFn,
163) -> Result<ToolOutcome, ToolRuntimeError> {
164 let command = req_str(&inv, "command")?;
165 let id = &*inv.id;
166
167 let soft_ms: u64 = inv.input.get("soft_timeout_ms")
174 .and_then(|v| v.as_u64())
175 .unwrap_or(10_000);
176 let hard_ms: u64 = inv.input.get("timeout_ms")
177 .and_then(|v| v.as_u64())
178 .unwrap_or(120_000)
179 .min(3_600_000);
180
181 let last_out = Arc::new(AtomicU64::new(epoch_ms()));
182 let stdout_buf = Arc::new(Mutex::new(String::new()));
183 let stderr_buf = Arc::new(Mutex::new(String::new()));
184
185 let shell = if Path::new("/bin/bash").exists() { "/bin/bash" } else { "/bin/sh" };
186 let mut child = Command::new(shell)
187 .args(["-lc", command])
188 .current_dir(cwd)
189 .kill_on_drop(true)
190 .stdout(std::process::Stdio::piped())
191 .stderr(std::process::Stdio::piped())
192 .spawn()
193 .map_err(|e| ToolRuntimeError::Runtime(format!("spawn failed: {e}")))?;
194
195 let raw_stdout = child.stdout.take().expect("stdout piped");
196 let raw_stderr = child.stderr.take().expect("stderr piped");
197
198 let act1 = last_out.clone();
199 let emit_out = emit.clone();
200 let stdout_acc = stdout_buf.clone();
201 let stdout_task = tokio::spawn(async move {
202 let mut lines = tokio::io::BufReader::new(raw_stdout).lines();
203 let mut buf = String::new();
204 while let Ok(Some(line)) = lines.next_line().await {
205 emit_out(json!({ "type": "bash_stdout_line", "line": line, "stream": "stdout" }));
206 act1.store(epoch_ms(), Ordering::Relaxed);
207 buf.push_str(&line);
208 buf.push('\n');
209 if let Ok(mut acc) = stdout_acc.lock() {
210 acc.push_str(&line);
211 acc.push('\n');
212 }
213 }
214 buf
215 });
216
217 let act2 = last_out.clone();
218 let emit_err = emit.clone();
219 let stderr_acc = stderr_buf.clone();
220 let stderr_task = tokio::spawn(async move {
221 let mut lines = tokio::io::BufReader::new(raw_stderr).lines();
222 let mut buf = String::new();
223 while let Ok(Some(line)) = lines.next_line().await {
224 emit_err(json!({ "type": "bash_stdout_line", "line": line, "stream": "stderr" }));
225 act2.store(epoch_ms(), Ordering::Relaxed);
226 buf.push_str(&line);
227 buf.push('\n');
228 if let Ok(mut acc) = stderr_acc.lock() {
229 acc.push_str(&line);
230 acc.push('\n');
231 }
232 }
233 buf
234 });
235
236 let watcher_ts = last_out.clone();
237 let soft_watcher = async move {
238 let start = epoch_ms();
239 loop {
240 tokio::time::sleep(Duration::from_millis(500)).await;
241 let now = epoch_ms();
242 if now.saturating_sub(start) >= soft_ms
243 && now.saturating_sub(watcher_ts.load(Ordering::Relaxed)) >= soft_ms
244 {
245 return (now.saturating_sub(start), now.saturating_sub(watcher_ts.load(Ordering::Relaxed)));
246 }
247 }
248 };
249
250 let timed = async {
251 let (out, err) = tokio::join!(
252 async { stdout_task.await.unwrap_or_default() },
253 async { stderr_task.await.unwrap_or_default() },
254 );
255 let status = child.wait().await;
256 (out, err, status)
257 };
258
259 let hard_timer = tokio::time::sleep(Duration::from_millis(hard_ms));
260
261 let timeout_outcome = |kind: &str, message: String| ToolOutcome {
262 output: Ok(json!({
263 "command": command,
264 "shell": shell,
265 "stdout": bound_output(stdout_buf.lock().map(|s| s.clone()).unwrap_or_default(), id, "stdout"),
266 "stderr": bound_output(stderr_buf.lock().map(|s| s.clone()).unwrap_or_default(), id, "stderr"),
267 "exit_code": null,
268 "success": false,
269 "timed_out": true,
270 "timeout_kind": kind,
271 "message": message,
272 })),
273 attachments: vec![],
274 };
275 let soft_err = |tot: u64, sil: u64| timeout_outcome(
276 "soft",
277 format!(
278 "Command produced no output for {sil}ms (total {tot}ms). \
279Retry with larger `soft_timeout_ms` or `timeout_ms` if it is expected to take longer."
280 ),
281 );
282 let hard_err = || timeout_outcome(
283 "hard",
284 format!(
285 "Command did not finish in {hard_ms}ms. Retry with a larger `timeout_ms` if it is expected to take longer."
286 ),
287 );
288
289 let result: Result<(String, String, _), ToolOutcome> = if let Some(tok) = cancel {
290 tokio::select! {
291 v = timed => Ok(v),
292 (tot, sil) = soft_watcher => Err(soft_err(tot, sil)),
293 _ = hard_timer => Err(hard_err()),
294 _ = tok.cancelled() => Err(ToolOutcome {
295 output: Err(ToolFailure::new(ToolFailureKind::Runtime, "cancelled")),
296 attachments: vec![],
297 }),
298 }
299 } else {
300 tokio::select! {
301 v = timed => Ok(v),
302 (tot, sil) = soft_watcher => Err(soft_err(tot, sil)),
303 _ = hard_timer => Err(hard_err()),
304 }
305 };
306
307 let (stdout, stderr, status_result) = match result {
308 Err(outcome) => return Ok(outcome),
309 Ok(v) => v,
310 };
311
312 let exit_code = status_result.map(|s| s.code().unwrap_or(-1)).unwrap_or(-1);
313
314 Ok(ToolOutcome {
315 output: Ok(json!({
316 "command": command,
317 "shell": shell,
318 "stdout": bound_output(stdout, id, "stdout"),
319 "stderr": bound_output(stderr, id, "stderr"),
320 "exit_code": exit_code,
321 "success": exit_code == 0,
322 })),
323 attachments: vec![],
324 })
325}
326
327async fn read_invoke(inv: ToolInvocation, rt: &LocalToolRuntime) -> Result<ToolOutcome, ToolRuntimeError> {
330 let path = req_str(&inv, "path")?;
331 let resolved = rt.resolve(path);
332 match tokio::fs::read_to_string(&resolved).await {
333 Ok(content) => {
334 let total = content.lines().count();
335 let offset = inv.input.get("offset").and_then(Value::as_u64).unwrap_or(0) as usize;
336 let limit = inv.input.get("limit").and_then(Value::as_u64)
337 .map(|v| v.clamp(1, 2_000) as usize);
338 let selected: Vec<&str> = match limit {
339 Some(n) => content.lines().skip(offset).take(n).collect(),
340 None => content.lines().skip(offset).collect(),
341 };
342 let end = offset + selected.len();
343 let text = if selected.is_empty() {
344 String::new()
345 } else {
346 let mut t = selected.join("\n");
347 if content.ends_with('\n') && end == total { t.push('\n'); }
348 t
349 };
350 Ok(ToolOutcome {
351 output: Ok(json!({
352 "path": resolved.to_string_lossy(),
353 "content": truncate(text),
354 "offset": offset,
355 "limit": limit,
356 "start_line": if selected.is_empty() { Value::Null } else { json!(offset + 1) },
357 "end_line": if selected.is_empty() { Value::Null } else { json!(end) },
358 "total_lines": total,
359 "truncated": limit.map(|n| offset + n < total).unwrap_or(false),
360 })),
361 attachments: vec![],
362 })
363 }
364 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(ToolOutcome {
365 output: Err(ToolFailure::new(ToolFailureKind::NotFound,
366 format!("file not found: {}", resolved.display()))),
367 attachments: vec![],
368 }),
369 Err(e) => Ok(ToolOutcome {
370 output: Err(ToolFailure::new(ToolFailureKind::Runtime, format!("read error: {e}"))),
371 attachments: vec![],
372 }),
373 }
374}
375
376async fn write_invoke(inv: ToolInvocation, rt: &LocalToolRuntime) -> Result<ToolOutcome, ToolRuntimeError> {
379 let path = req_str(&inv, "path")?;
380 let content = req_str(&inv, "content")?;
381 let resolved = rt.resolve(path);
382 if let Some(parent) = resolved.parent() {
383 if !parent.as_os_str().is_empty() {
384 tokio::fs::create_dir_all(parent).await
385 .map_err(|e| ToolRuntimeError::Runtime(format!("mkdir: {e}")))?;
386 }
387 }
388 tokio::fs::write(&resolved, content).await
389 .map_err(|e| ToolRuntimeError::Runtime(format!("write error: {e}")))?;
390 Ok(ToolOutcome {
391 output: Ok(json!({ "path": resolved.to_string_lossy(), "written": true })),
392 attachments: vec![],
393 })
394}
395
396async fn edit_invoke(inv: ToolInvocation, rt: &LocalToolRuntime) -> Result<ToolOutcome, ToolRuntimeError> {
399 let path = req_str(&inv, "path")?;
400 let old_string = req_str(&inv, "old_string")?;
401 let new_string = inv.input.get("new_string").and_then(Value::as_str).unwrap_or("");
402 let replace_all = inv.input.get("replace_all").and_then(Value::as_bool).unwrap_or(false);
403 let resolved = rt.resolve(path);
404
405 let content = match tokio::fs::read_to_string(&resolved).await {
406 Ok(c) => c,
407 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(ToolOutcome {
408 output: Err(ToolFailure::new(ToolFailureKind::NotFound,
409 format!("file not found: {}", resolved.display()))),
410 attachments: vec![],
411 }),
412 Err(e) => return Err(ToolRuntimeError::Runtime(e.to_string())),
413 };
414
415 let occurrences = content.matches(old_string).count();
416 if occurrences == 0 {
417 return Ok(ToolOutcome {
418 output: Err(ToolFailure::new(ToolFailureKind::InvalidInput,
419 "Could not find old_string in the file. It must match exactly, including whitespace and indentation. Read the file again before retrying.".to_string())),
420 attachments: vec![],
421 });
422 }
423 if !replace_all && occurrences > 1 {
424 return Ok(ToolOutcome {
425 output: Err(ToolFailure::new(ToolFailureKind::InvalidInput,
426 format!("Found {occurrences} exact matches for old_string. Provide more surrounding context or set replace_all=true."))),
427 attachments: vec![],
428 });
429 }
430
431 let replaced = if replace_all { occurrences } else { 1 };
432 let new_content = if replace_all {
433 content.replace(old_string, new_string)
434 } else {
435 content.replacen(old_string, new_string, 1)
436 };
437 tokio::fs::write(&resolved, new_content).await
438 .map_err(|e| ToolRuntimeError::Runtime(e.to_string()))?;
439 Ok(ToolOutcome {
440 output: Ok(json!({
441 "path": resolved.to_string_lossy(),
442 "replaced": replaced,
443 "old_lines": old_string.lines().count(),
444 "new_lines": new_string.lines().count(),
445 })),
446 attachments: vec![],
447 })
448}
449
450async fn glob_invoke(inv: ToolInvocation, rt: &LocalToolRuntime) -> Result<ToolOutcome, ToolRuntimeError> {
453 let pattern = req_str(&inv, "pattern")?.to_string();
454 let base = match inv.input.get("path").and_then(Value::as_str).filter(|s| !s.is_empty()) {
455 Some(p) => rt.resolve(p),
456 None => rt.cwd.clone(),
457 };
458 let (matches, truncated) = fs_glob_bounded(&pattern, &base);
459 Ok(ToolOutcome {
460 output: Ok(json!({
461 "pattern": pattern,
462 "count": matches.len(),
463 "matches": matches,
464 "truncated": truncated,
465 })),
466 attachments: vec![],
467 })
468}
469
470async fn grep_invoke(inv: ToolInvocation, rt: &LocalToolRuntime) -> Result<ToolOutcome, ToolRuntimeError> {
473 let pattern = req_str(&inv, "pattern")?.to_string();
474 let ci = inv.input.get("case_insensitive").and_then(Value::as_bool).unwrap_or(false);
475 let search = match inv.input.get("path").and_then(Value::as_str).filter(|s| !s.is_empty()) {
476 Some(p) => rt.resolve(p),
477 None => rt.cwd.clone(),
478 };
479
480 let mut cmd = Command::new("grep");
481 cmd.arg("-rn");
482 if ci { cmd.arg("-i"); }
483 cmd.args([
484 "--exclude-dir=node_modules",
485 "--exclude-dir=target",
486 "--exclude-dir=.git",
487 "--exclude-dir=dist",
488 "--exclude-dir=build",
489 "--exclude-dir=__pycache__",
490 "--exclude-dir=.venv",
491 "--exclude-dir=vendor",
492 "--exclude-dir=.next",
493 ]);
494 cmd.arg("-e").arg(&pattern).arg("--").arg(&search);
495 cmd.current_dir(&rt.cwd);
496
497 match tokio::time::timeout(Duration::from_secs(30), cmd.output()).await {
498 Err(_) => Ok(ToolOutcome {
499 output: Err(ToolFailure::new(ToolFailureKind::Timeout, "grep timed out after 30s")),
500 attachments: vec![],
501 }),
502 Ok(Err(e)) => Err(ToolRuntimeError::Runtime(format!("grep spawn failed: {e}"))),
503 Ok(Ok(out)) => {
504 let code = out.status.code().unwrap_or(-1);
505 if code >= 2 {
506 let stderr = String::from_utf8_lossy(&out.stderr).into_owned();
507 return Ok(ToolOutcome {
508 output: Err(ToolFailure::new(ToolFailureKind::Runtime,
509 truncate(format!("grep error: {stderr}")))),
510 attachments: vec![],
511 });
512 }
513 let stdout = String::from_utf8_lossy(&out.stdout).into_owned();
514 Ok(ToolOutcome {
515 output: Ok(json!({
516 "pattern": pattern,
517 "matches": bound_output(stdout, &inv.id, "matches"),
518 })),
519 attachments: vec![],
520 })
521 }
522 }
523}
524
525fn bound_output(content: String, id: &str, suffix: &str) -> String {
531 let path = format!("/tmp/harness_out_{id}_{suffix}.txt");
532 match crate::tools::bounded_preview(&content, &path) {
533 None => content,
534 Some(preview) => {
535 let _ = std::fs::write(&path, &content);
536 preview
537 }
538 }
539}
540
541fn truncate(s: String) -> String {
544 crate::tools::clip_head(s)
545}
546
547fn req_str<'a>(inv: &'a ToolInvocation, key: &str) -> Result<&'a str, ToolRuntimeError> {
550 inv.input
551 .get(key)
552 .and_then(Value::as_str)
553 .filter(|s| !s.is_empty())
554 .ok_or_else(|| ToolRuntimeError::InvalidInput {
555 tool: inv.name.clone(),
556 message: format!("missing field `{key}`"),
557 })
558}
559
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::approval::YoloApproval;

    /// Runtime rooted at the OS temp dir that approves every call and discards
    /// streamed events.
    fn runtime() -> LocalToolRuntime {
        let config = LocalToolConfig {
            cwd: Some(std::env::temp_dir()),
            approval: Arc::new(YoloApproval),
            emit: Arc::new(|_| {}),
        };
        LocalToolRuntime::new(config)
    }

    /// Runs the `bash` tool with `input` and unwraps the successful JSON payload.
    async fn run_bash(id: &str, input: Value) -> Value {
        let call = ToolInvocation {
            id: id.into(),
            name: "bash".into(),
            input,
        };
        runtime().invoke(call).await.unwrap().output.unwrap()
    }

    #[tokio::test]
    async fn bash_non_zero_exit_returns_structured_result() {
        let out = run_bash("tc_nonzero", json!({ "command": "printf nope >&2; exit 7" })).await;
        assert_eq!(out["exit_code"], 7);
        assert_eq!(out["success"], false);
        assert_eq!(out["stderr"], "nope\n");
    }

    #[tokio::test]
    async fn bash_timeout_returns_structured_result() {
        let input = json!({
            "command": "sleep 2",
            "soft_timeout_ms": 1000,
            "timeout_ms": 5000
        });
        let out = run_bash("tc_timeout", input).await;
        assert_eq!(out["success"], false);
        assert_eq!(out["timed_out"], true);
        assert_eq!(out["timeout_kind"], "soft");
    }

    #[tokio::test]
    async fn bash_tool_supports_bash_syntax_when_bash_exists() {
        if !Path::new("/bin/bash").exists() {
            return;
        }
        let out = run_bash("tc_bash_syntax", json!({ "command": "diff <(printf a) <(printf a)" })).await;
        assert_eq!(out["success"], true);
        assert_eq!(out["exit_code"], 0);
        assert_eq!(out["shell"], "/bin/bash");
    }
}