// harness/tools/local.rs

1use std::path::{Path, PathBuf};
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::sync::{Arc, Mutex};
4use std::time::{Duration, SystemTime, UNIX_EPOCH};
5
6use async_trait::async_trait;
7use serde_json::{json, Value};
8use tokio::io::AsyncBufReadExt;
9use tokio::process::Command;
10use tokio_util::sync::CancellationToken;
11
12use crate::shell_risk::{classify_shell_command, ShellRiskLevel};
13use crate::tools::{
14    builtin_tool_specs, fs_glob_bounded, ToolFailure, ToolFailureKind, ToolInvocation,
15    ToolOutcome, ToolRuntime, ToolRuntimeError, ToolSpec,
16};
17use crate::tools::approval::{is_read_only, ApprovalGate};
18
/// Opaque event emitter: receives structured JSON events produced during tool
/// execution (e.g. `bash_stdout_line`, emitted once per line of shell output).
///
/// The callback is shared across spawned tasks, hence the `Send + Sync +
/// 'static` bounds. Use an `Arc<dyn Fn(Value) + ...>` or provide a no-op
/// with `Arc::new(|_| {})`.
pub type EmitFn = Arc<dyn Fn(Value) + Send + Sync + 'static>;
23
/// Configuration for `LocalToolRuntime`.
///
/// Consumed by `LocalToolRuntime::new`, which takes ownership of every field.
pub struct LocalToolConfig {
    /// Absolute path to the working directory for all tool operations.
    /// Relative paths inside tool arguments are resolved against this.
    /// Defaults to the process's `std::env::current_dir()` when `None`
    /// (an empty path is treated like `None`); if the current directory
    /// cannot be determined either, `/` is used.
    pub cwd: Option<PathBuf>,
    /// Controls which tool calls are allowed and which tool specs are
    /// advertised to the model. Typically `YoloApproval`, `PlanApproval`,
    /// or a custom gate (e.g. `TauriApproval` in a desktop app).
    pub approval: Arc<dyn ApprovalGate>,
    /// Receives structured events produced during tool execution.
    /// Most callers forward these to a UI channel. Pass `Arc::new(|_| {})`
    /// to discard them.
    pub emit: EmitFn,
}
39
/// Tool runtime that executes bash, read, write, edit, glob and grep
/// directly on the local filesystem.
///
/// Tool output is bounded before it is returned (bash and grep go through
/// `bound_output`, read through `truncate`) so large files and noisy
/// commands don't blow out the model's context window.
/// Grep skips common dependency and build directories (`node_modules`,
/// `target`, `.git`, ...); glob delegates its bounding to `fs_glob_bounded`.
///
/// Cloning is cheap: the approval gate and the emitter are shared `Arc`s.
#[derive(Clone)]
pub struct LocalToolRuntime {
    /// Base directory that relative tool paths are resolved against.
    cwd: PathBuf,
    /// Gate consulted for every invocation that is not read-only.
    approval: Arc<dyn ApprovalGate>,
    /// Sink for structured events produced while a tool runs.
    emit: EmitFn,
}
52
53impl LocalToolRuntime {
54    pub fn new(config: LocalToolConfig) -> Self {
55        let cwd = config.cwd
56            .filter(|p| !p.as_os_str().is_empty())
57            .or_else(|| std::env::current_dir().ok())
58            .unwrap_or_else(|| PathBuf::from("/"));
59        Self { cwd, approval: config.approval, emit: config.emit }
60    }
61
62    fn resolve(&self, path: &str) -> PathBuf {
63        let p = Path::new(path);
64        if p.is_absolute() { p.to_path_buf() } else { self.cwd.join(p) }
65    }
66
67    /// Gate that decides whether a tool invocation may proceed.
68    ///
69    /// * Hard-blocked bash commands are rejected in every mode.
70    /// * Read-only bash / read / glob / grep pass without hitting the gate.
71    /// * Everything else goes to `approval.approve()`.
72    async fn gate(
73        &self,
74        inv: &ToolInvocation,
75        cancel: Option<&CancellationToken>,
76    ) -> Result<(), String> {
77        if inv.name == "bash" {
78            let cmd = inv.input.get("command").and_then(Value::as_str).unwrap_or("");
79            let decision = classify_shell_command(cmd);
80            match decision.level {
81                ShellRiskLevel::Blocked => {
82                    return Err(format!("命令在禁止清单上,已拒绝:{}", decision.reason));
83                }
84                ShellRiskLevel::SafeRead => return Ok(()),
85                ShellRiskLevel::BoundedWrite
86                    if self.approval.advertise_mutating_tools() =>
87                {
88                    return Ok(());
89                }
90                _ => {}
91            }
92        } else if is_read_only(&inv.name) {
93            return Ok(());
94        }
95
96        // Non-read-only tool → delegate to the approval gate.
97        let approved = if let Some(tok) = cancel {
98            tokio::select! {
99                biased;
100                _ = tok.cancelled() => return Err("已取消".into()),
101                result = self.approval.approve(inv) => result,
102            }
103        } else {
104            self.approval.approve(inv).await
105        };
106
107        if approved { Ok(()) } else { Err("操作被拒绝".into()) }
108    }
109}
110
111#[async_trait]
112impl ToolRuntime for LocalToolRuntime {
113    fn specs(&self) -> Vec<ToolSpec> {
114        let all = builtin_tool_specs();
115        if self.approval.advertise_mutating_tools() {
116            all
117        } else {
118            all.into_iter().filter(|s| is_read_only(&s.name)).collect()
119        }
120    }
121
122    async fn invoke(&self, inv: ToolInvocation) -> Result<ToolOutcome, ToolRuntimeError> {
123        self.invoke_cancellable(inv, None).await
124    }
125
126    async fn invoke_cancellable(
127        &self,
128        inv: ToolInvocation,
129        cancel: Option<&CancellationToken>,
130    ) -> Result<ToolOutcome, ToolRuntimeError> {
131        if let Err(reason) = self.gate(&inv, cancel).await {
132            return Ok(ToolOutcome {
133                output: Err(ToolFailure::new(ToolFailureKind::Denied, reason)),
134                attachments: vec![],
135            });
136        }
137        match inv.name.as_str() {
138            "bash"  => bash_invoke(inv, cancel, &self.cwd, self.emit.clone()).await,
139            "read"  => read_invoke(inv, self).await,
140            "write" => write_invoke(inv, self).await,
141            "edit"  => edit_invoke(inv, self).await,
142            "glob"  => glob_invoke(inv, self).await,
143            "grep"  => grep_invoke(inv, self).await,
144            other   => Err(ToolRuntimeError::UnknownTool(other.into())),
145        }
146    }
147}
148
149// ── bash ──────────────────────────────────────────────────────────────────────
150
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn epoch_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
157
158async fn bash_invoke(
159    inv: ToolInvocation,
160    cancel: Option<&CancellationToken>,
161    cwd: &Path,
162    emit: EmitFn,
163) -> Result<ToolOutcome, ToolRuntimeError> {
164    let command = req_str(&inv, "command")?;
165    let id = &*inv.id;
166
167    // Dual-layer timeout:
168    //   soft_timeout_ms — no-output detector: if the process produces
169    //     nothing for this long it is killed and the model is told to retry.
170    //     Streaming output resets the clock, so a long build that prints
171    //     progress will never hit this.
172    //   timeout_ms — absolute hard ceiling, regardless of output activity.
173    let soft_ms: u64 = inv.input.get("soft_timeout_ms")
174        .and_then(|v| v.as_u64())
175        .unwrap_or(10_000);
176    let hard_ms: u64 = inv.input.get("timeout_ms")
177        .and_then(|v| v.as_u64())
178        .unwrap_or(120_000)
179        .min(3_600_000);
180
181    let last_out = Arc::new(AtomicU64::new(epoch_ms()));
182    let stdout_buf = Arc::new(Mutex::new(String::new()));
183    let stderr_buf = Arc::new(Mutex::new(String::new()));
184
185    let shell = if Path::new("/bin/bash").exists() { "/bin/bash" } else { "/bin/sh" };
186    let mut child = Command::new(shell)
187        .args(["-lc", command])
188        .current_dir(cwd)
189        .kill_on_drop(true)
190        .stdout(std::process::Stdio::piped())
191        .stderr(std::process::Stdio::piped())
192        .spawn()
193        .map_err(|e| ToolRuntimeError::Runtime(format!("spawn failed: {e}")))?;
194
195    let raw_stdout = child.stdout.take().expect("stdout piped");
196    let raw_stderr = child.stderr.take().expect("stderr piped");
197
198    let act1 = last_out.clone();
199    let emit_out = emit.clone();
200    let stdout_acc = stdout_buf.clone();
201    let stdout_task = tokio::spawn(async move {
202        let mut lines = tokio::io::BufReader::new(raw_stdout).lines();
203        let mut buf = String::new();
204        while let Ok(Some(line)) = lines.next_line().await {
205            emit_out(json!({ "type": "bash_stdout_line", "line": line, "stream": "stdout" }));
206            act1.store(epoch_ms(), Ordering::Relaxed);
207            buf.push_str(&line);
208            buf.push('\n');
209            if let Ok(mut acc) = stdout_acc.lock() {
210                acc.push_str(&line);
211                acc.push('\n');
212            }
213        }
214        buf
215    });
216
217    let act2 = last_out.clone();
218    let emit_err = emit.clone();
219    let stderr_acc = stderr_buf.clone();
220    let stderr_task = tokio::spawn(async move {
221        let mut lines = tokio::io::BufReader::new(raw_stderr).lines();
222        let mut buf = String::new();
223        while let Ok(Some(line)) = lines.next_line().await {
224            emit_err(json!({ "type": "bash_stdout_line", "line": line, "stream": "stderr" }));
225            act2.store(epoch_ms(), Ordering::Relaxed);
226            buf.push_str(&line);
227            buf.push('\n');
228            if let Ok(mut acc) = stderr_acc.lock() {
229                acc.push_str(&line);
230                acc.push('\n');
231            }
232        }
233        buf
234    });
235
236    let watcher_ts = last_out.clone();
237    let soft_watcher = async move {
238        let start = epoch_ms();
239        loop {
240            tokio::time::sleep(Duration::from_millis(500)).await;
241            let now = epoch_ms();
242            if now.saturating_sub(start) >= soft_ms
243                && now.saturating_sub(watcher_ts.load(Ordering::Relaxed)) >= soft_ms
244            {
245                return (now.saturating_sub(start), now.saturating_sub(watcher_ts.load(Ordering::Relaxed)));
246            }
247        }
248    };
249
250    let timed = async {
251        let (out, err) = tokio::join!(
252            async { stdout_task.await.unwrap_or_default() },
253            async { stderr_task.await.unwrap_or_default() },
254        );
255        let status = child.wait().await;
256        (out, err, status)
257    };
258
259    let hard_timer = tokio::time::sleep(Duration::from_millis(hard_ms));
260
261    let timeout_outcome = |kind: &str, message: String| ToolOutcome {
262        output: Ok(json!({
263            "command": command,
264            "shell": shell,
265            "stdout": bound_output(stdout_buf.lock().map(|s| s.clone()).unwrap_or_default(), id, "stdout"),
266            "stderr": bound_output(stderr_buf.lock().map(|s| s.clone()).unwrap_or_default(), id, "stderr"),
267            "exit_code": null,
268            "success": false,
269            "timed_out": true,
270            "timeout_kind": kind,
271            "message": message,
272        })),
273        attachments: vec![],
274    };
275    let soft_err = |tot: u64, sil: u64| timeout_outcome(
276        "soft",
277        format!(
278            "Command produced no output for {sil}ms (total {tot}ms). \
279Retry with larger `soft_timeout_ms` or `timeout_ms` if it is expected to take longer."
280        ),
281    );
282    let hard_err = || timeout_outcome(
283        "hard",
284        format!(
285            "Command did not finish in {hard_ms}ms. Retry with a larger `timeout_ms` if it is expected to take longer."
286        ),
287    );
288
289    let result: Result<(String, String, _), ToolOutcome> = if let Some(tok) = cancel {
290        tokio::select! {
291            v = timed => Ok(v),
292            (tot, sil) = soft_watcher => Err(soft_err(tot, sil)),
293            _ = hard_timer => Err(hard_err()),
294            _ = tok.cancelled() => Err(ToolOutcome {
295                output: Err(ToolFailure::new(ToolFailureKind::Runtime, "cancelled")),
296                attachments: vec![],
297            }),
298        }
299    } else {
300        tokio::select! {
301            v = timed => Ok(v),
302            (tot, sil) = soft_watcher => Err(soft_err(tot, sil)),
303            _ = hard_timer => Err(hard_err()),
304        }
305    };
306
307    let (stdout, stderr, status_result) = match result {
308        Err(outcome) => return Ok(outcome),
309        Ok(v) => v,
310    };
311
312    let exit_code = status_result.map(|s| s.code().unwrap_or(-1)).unwrap_or(-1);
313
314    Ok(ToolOutcome {
315        output: Ok(json!({
316            "command": command,
317            "shell": shell,
318            "stdout": bound_output(stdout, id, "stdout"),
319            "stderr": bound_output(stderr, id, "stderr"),
320            "exit_code": exit_code,
321            "success": exit_code == 0,
322        })),
323        attachments: vec![],
324    })
325}
326
327// ── read ──────────────────────────────────────────────────────────────────────
328
329async fn read_invoke(inv: ToolInvocation, rt: &LocalToolRuntime) -> Result<ToolOutcome, ToolRuntimeError> {
330    let path = req_str(&inv, "path")?;
331    let resolved = rt.resolve(path);
332    match tokio::fs::read_to_string(&resolved).await {
333        Ok(content) => {
334            let total = content.lines().count();
335            let offset = inv.input.get("offset").and_then(Value::as_u64).unwrap_or(0) as usize;
336            let limit = inv.input.get("limit").and_then(Value::as_u64)
337                .map(|v| v.clamp(1, 2_000) as usize);
338            let selected: Vec<&str> = match limit {
339                Some(n) => content.lines().skip(offset).take(n).collect(),
340                None    => content.lines().skip(offset).collect(),
341            };
342            let end = offset + selected.len();
343            let text = if selected.is_empty() {
344                String::new()
345            } else {
346                let mut t = selected.join("\n");
347                if content.ends_with('\n') && end == total { t.push('\n'); }
348                t
349            };
350            Ok(ToolOutcome {
351                output: Ok(json!({
352                    "path": resolved.to_string_lossy(),
353                    "content": truncate(text),
354                    "offset": offset,
355                    "limit": limit,
356                    "start_line": if selected.is_empty() { Value::Null } else { json!(offset + 1) },
357                    "end_line": if selected.is_empty() { Value::Null } else { json!(end) },
358                    "total_lines": total,
359                    "truncated": limit.map(|n| offset + n < total).unwrap_or(false),
360                })),
361                attachments: vec![],
362            })
363        }
364        Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(ToolOutcome {
365            output: Err(ToolFailure::new(ToolFailureKind::NotFound,
366                format!("file not found: {}", resolved.display()))),
367            attachments: vec![],
368        }),
369        Err(e) => Ok(ToolOutcome {
370            output: Err(ToolFailure::new(ToolFailureKind::Runtime, format!("read error: {e}"))),
371            attachments: vec![],
372        }),
373    }
374}
375
376// ── write ─────────────────────────────────────────────────────────────────────
377
378async fn write_invoke(inv: ToolInvocation, rt: &LocalToolRuntime) -> Result<ToolOutcome, ToolRuntimeError> {
379    let path = req_str(&inv, "path")?;
380    let content = req_str(&inv, "content")?;
381    let resolved = rt.resolve(path);
382    if let Some(parent) = resolved.parent() {
383        if !parent.as_os_str().is_empty() {
384            tokio::fs::create_dir_all(parent).await
385                .map_err(|e| ToolRuntimeError::Runtime(format!("mkdir: {e}")))?;
386        }
387    }
388    tokio::fs::write(&resolved, content).await
389        .map_err(|e| ToolRuntimeError::Runtime(format!("write error: {e}")))?;
390    Ok(ToolOutcome {
391        output: Ok(json!({ "path": resolved.to_string_lossy(), "written": true })),
392        attachments: vec![],
393    })
394}
395
396// ── edit ──────────────────────────────────────────────────────────────────────
397
398async fn edit_invoke(inv: ToolInvocation, rt: &LocalToolRuntime) -> Result<ToolOutcome, ToolRuntimeError> {
399    let path = req_str(&inv, "path")?;
400    let old_string = req_str(&inv, "old_string")?;
401    let new_string = inv.input.get("new_string").and_then(Value::as_str).unwrap_or("");
402    let replace_all = inv.input.get("replace_all").and_then(Value::as_bool).unwrap_or(false);
403    let resolved = rt.resolve(path);
404
405    let content = match tokio::fs::read_to_string(&resolved).await {
406        Ok(c) => c,
407        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(ToolOutcome {
408            output: Err(ToolFailure::new(ToolFailureKind::NotFound,
409                format!("file not found: {}", resolved.display()))),
410            attachments: vec![],
411        }),
412        Err(e) => return Err(ToolRuntimeError::Runtime(e.to_string())),
413    };
414
415    let occurrences = content.matches(old_string).count();
416    if occurrences == 0 {
417        return Ok(ToolOutcome {
418            output: Err(ToolFailure::new(ToolFailureKind::InvalidInput,
419                "Could not find old_string in the file. It must match exactly, including whitespace and indentation. Read the file again before retrying.".to_string())),
420            attachments: vec![],
421        });
422    }
423    if !replace_all && occurrences > 1 {
424        return Ok(ToolOutcome {
425            output: Err(ToolFailure::new(ToolFailureKind::InvalidInput,
426                format!("Found {occurrences} exact matches for old_string. Provide more surrounding context or set replace_all=true."))),
427            attachments: vec![],
428        });
429    }
430
431    let replaced = if replace_all { occurrences } else { 1 };
432    let new_content = if replace_all {
433        content.replace(old_string, new_string)
434    } else {
435        content.replacen(old_string, new_string, 1)
436    };
437    tokio::fs::write(&resolved, new_content).await
438        .map_err(|e| ToolRuntimeError::Runtime(e.to_string()))?;
439    Ok(ToolOutcome {
440        output: Ok(json!({
441            "path": resolved.to_string_lossy(),
442            "replaced": replaced,
443            "old_lines": old_string.lines().count(),
444            "new_lines": new_string.lines().count(),
445        })),
446        attachments: vec![],
447    })
448}
449
450// ── glob ──────────────────────────────────────────────────────────────────────
451
452async fn glob_invoke(inv: ToolInvocation, rt: &LocalToolRuntime) -> Result<ToolOutcome, ToolRuntimeError> {
453    let pattern = req_str(&inv, "pattern")?.to_string();
454    let base = match inv.input.get("path").and_then(Value::as_str).filter(|s| !s.is_empty()) {
455        Some(p) => rt.resolve(p),
456        None    => rt.cwd.clone(),
457    };
458    let (matches, truncated) = fs_glob_bounded(&pattern, &base);
459    Ok(ToolOutcome {
460        output: Ok(json!({
461            "pattern": pattern,
462            "count": matches.len(),
463            "matches": matches,
464            "truncated": truncated,
465        })),
466        attachments: vec![],
467    })
468}
469
470// ── grep ──────────────────────────────────────────────────────────────────────
471
472async fn grep_invoke(inv: ToolInvocation, rt: &LocalToolRuntime) -> Result<ToolOutcome, ToolRuntimeError> {
473    let pattern = req_str(&inv, "pattern")?.to_string();
474    let ci = inv.input.get("case_insensitive").and_then(Value::as_bool).unwrap_or(false);
475    let search = match inv.input.get("path").and_then(Value::as_str).filter(|s| !s.is_empty()) {
476        Some(p) => rt.resolve(p),
477        None    => rt.cwd.clone(),
478    };
479
480    let mut cmd = Command::new("grep");
481    cmd.arg("-rn");
482    if ci { cmd.arg("-i"); }
483    cmd.args([
484        "--exclude-dir=node_modules",
485        "--exclude-dir=target",
486        "--exclude-dir=.git",
487        "--exclude-dir=dist",
488        "--exclude-dir=build",
489        "--exclude-dir=__pycache__",
490        "--exclude-dir=.venv",
491        "--exclude-dir=vendor",
492        "--exclude-dir=.next",
493    ]);
494    cmd.arg("-e").arg(&pattern).arg("--").arg(&search);
495    cmd.current_dir(&rt.cwd);
496
497    match tokio::time::timeout(Duration::from_secs(30), cmd.output()).await {
498        Err(_) => Ok(ToolOutcome {
499            output: Err(ToolFailure::new(ToolFailureKind::Timeout, "grep timed out after 30s")),
500            attachments: vec![],
501        }),
502        Ok(Err(e)) => Err(ToolRuntimeError::Runtime(format!("grep spawn failed: {e}"))),
503        Ok(Ok(out)) => {
504            let code = out.status.code().unwrap_or(-1);
505            if code >= 2 {
506                let stderr = String::from_utf8_lossy(&out.stderr).into_owned();
507                return Ok(ToolOutcome {
508                    output: Err(ToolFailure::new(ToolFailureKind::Runtime,
509                        truncate(format!("grep error: {stderr}")))),
510                    attachments: vec![],
511                });
512            }
513            let stdout = String::from_utf8_lossy(&out.stdout).into_owned();
514            Ok(ToolOutcome {
515                output: Ok(json!({
516                    "pattern": pattern,
517                    "matches": bound_output(stdout, &inv.id, "matches"),
518                })),
519                attachments: vec![],
520            })
521        }
522    }
523}
524
525// ── output cap ────────────────────────────────────────────────────────────────
526
527/// Write `content` to `/tmp/harness_out_{id}_{suffix}.txt` and return an
528/// error-aware head+tail preview with the path. Used for bash stdout/stderr
529/// and grep matches. Spills only when over budget (see `bounded_preview`).
530fn bound_output(content: String, id: &str, suffix: &str) -> String {
531    let path = format!("/tmp/harness_out_{id}_{suffix}.txt");
532    match crate::tools::bounded_preview(&content, &path) {
533        None => content,
534        Some(preview) => {
535            let _ = std::fs::write(&path, &content);
536            preview
537        }
538    }
539}
540
/// Simple safety truncation: forwards `s` to `crate::tools::clip_head` and
/// returns the result.
///
/// Used by the read tool as a backstop for pages that exceed the output
/// budget after pagination, and by grep to bound its error message.
fn truncate(s: String) -> String {
    crate::tools::clip_head(s)
}
546
547// ── helpers ───────────────────────────────────────────────────────────────────
548
549fn req_str<'a>(inv: &'a ToolInvocation, key: &str) -> Result<&'a str, ToolRuntimeError> {
550    inv.input
551        .get(key)
552        .and_then(Value::as_str)
553        .filter(|s| !s.is_empty())
554        .ok_or_else(|| ToolRuntimeError::InvalidInput {
555            tool: inv.name.clone(),
556            message: format!("missing field `{key}`"),
557        })
558}
559
// Integration-style tests that run real shell commands through the runtime.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::approval::YoloApproval;

    /// Runtime rooted in the system temp directory, using `YoloApproval`
    /// as the gate and discarding all emitted events.
    fn runtime() -> LocalToolRuntime {
        LocalToolRuntime::new(LocalToolConfig {
            cwd: Some(std::env::temp_dir()),
            approval: Arc::new(YoloApproval),
            emit: Arc::new(|_| {}),
        })
    }

    /// A failing command is a successful tool outcome that carries the exit
    /// code and the captured stderr, not a runtime error.
    #[tokio::test]
    async fn bash_non_zero_exit_returns_structured_result() {
        let out = runtime()
            .invoke(ToolInvocation {
                id: "tc_nonzero".into(),
                name: "bash".into(),
                input: json!({"command": "printf nope >&2; exit 7"}),
            })
            .await
            .unwrap()
            .output
            .unwrap();
        assert_eq!(out["exit_code"], 7);
        assert_eq!(out["success"], false);
        // The reader appends a newline to every captured line.
        assert_eq!(out["stderr"], "nope\n");
    }

    /// A silent command that outlives `soft_timeout_ms` is reported as a
    /// soft timeout well before the hard limit is reached.
    #[tokio::test]
    async fn bash_timeout_returns_structured_result() {
        let out = runtime()
            .invoke(ToolInvocation {
                id: "tc_timeout".into(),
                name: "bash".into(),
                input: json!({
                    "command": "sleep 2",
                    "soft_timeout_ms": 1000,
                    "timeout_ms": 5000
                }),
            })
            .await
            .unwrap()
            .output
            .unwrap();
        assert_eq!(out["success"], false);
        assert_eq!(out["timed_out"], true);
        assert_eq!(out["timeout_kind"], "soft");
    }

    /// Process substitution is bash-only syntax, so this passes only when
    /// the command really runs under `/bin/bash`.
    #[tokio::test]
    async fn bash_tool_supports_bash_syntax_when_bash_exists() {
        // Skip on systems without bash, where the runtime falls back to /bin/sh.
        if !Path::new("/bin/bash").exists() {
            return;
        }
        let out = runtime()
            .invoke(ToolInvocation {
                id: "tc_bash_syntax".into(),
                name: "bash".into(),
                input: json!({"command": "diff <(printf a) <(printf a)"}),
            })
            .await
            .unwrap()
            .output
            .unwrap();
        assert_eq!(out["success"], true);
        assert_eq!(out["exit_code"], 0);
        assert_eq!(out["shell"], "/bin/bash");
    }
}