Skip to main content

deepseek_rust_cli/agent/
executor.rs

1use std::{
2    collections::HashMap,
3    time::{Duration, Instant},
4};
5
6use anyhow::Result;
7use futures::future::join_all;
8use once_cell::sync::Lazy;
9use serde_json::Value;
10
11use crate::{agent::types::UndoAction, tools, tools::base::ToolRegistry};
12
/// Time limit applied to most tool invocations (two minutes).
const DEFAULT_TOOL_TIMEOUT: Duration = Duration::from_secs(2 * 60);
/// Time limit for tools that routinely run long, such as `git_clone` or
/// `execute_shell_command` (ten minutes).
const LONG_TOOL_TIMEOUT: Duration = Duration::from_secs(10 * 60);
/// How long a cached read-only tool result stays valid (five seconds).
const CACHE_TTL: Duration = Duration::from_secs(5);
17
/// Read-only tools whose results may be served from the short-lived result cache.
///
/// Only list a tool here if running it has no side effects. On a cache hit the
/// tool is not executed at all, so a mutating tool listed here would silently
/// fail to run on a repeated call.
///
/// `git_stash` is deliberately absent. Plain `git stash` pushes changes, and
/// `stash pop`/`drop` rewrite repository state. Caching it could therefore skip
/// a repeated invocation or serve a stale stash list. Re-add it only after
/// confirming the tool is list-only.
const CACHEABLE_TOOLS: &[&str] = &[
    // Local filesystem / environment queries
    "read_local_file",
    "list_directory",
    "tree_view",
    "diff_files",
    "hash_file",
    "count_lines",
    "get_system_info",
    "get_env_var",
    // Git queries
    "git_status",
    "git_diff",
    "git_log",
    // NOTE(review): `git branch <name>` creates a branch; confirm this tool only lists.
    "git_branch",
    "git_remote_list",
    // GitHub API queries
    "github_repo_info",
    "github_repo_list_issues",
    "github_pr_list",
    "github_pr_info",
    "github_search_code",
    "github_search_repos",
    "github_get_file",
    "github_workflow_list",
    "github_workflow_runs",
];
44
45static TOOL_REGISTRY: Lazy<ToolRegistry> = Lazy::new(|| {
46    let mut registry = ToolRegistry::new();
47    for tool in tools::get_all_tools() {
48        registry.register(tool);
49    }
50    registry
51});
52
/// A cached tool result together with the moment it was stored.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// When the entry was inserted; compared against the cache TTL on lookup.
    pub timestamp: Instant,
    /// The tool's successful output, returned verbatim on a cache hit.
    pub result: String,
}

/// Tool-result cache, keyed by tool name plus its serialized arguments.
///
/// This is an ordinary map owned by the caller and passed by `&mut` into the
/// cached execution path. It is not a process-wide global.
pub type ToolCache = HashMap<String, CacheEntry>;
62
63/// Get appropriate timeout for a given tool
64fn tool_timeout(name: &str) -> Duration {
65    match name {
66        "git_clone" | "git_push" | "git_pull" | "execute_shell_command" | "fetch_url" => {
67            LONG_TOOL_TIMEOUT
68        }
69        _ => DEFAULT_TOOL_TIMEOUT,
70    }
71}
72
73/// Build a cache key from tool name and args
74fn cache_key(name: &str, args_val: &serde_json::Map<String, Value>) -> String {
75    let mut key = name.to_string();
76    let mut sorted: Vec<(&String, &Value)> = args_val.iter().collect();
77    sorted.sort_by(|a, b| a.0.cmp(b.0));
78    for (k, v) in sorted {
79        key.push(':');
80        key.push_str(k);
81        key.push('=');
82        key.push_str(&v.to_string());
83    }
84    key
85}
86
87fn is_cacheable(name: &str) -> bool {
88    CACHEABLE_TOOLS.contains(&name)
89}
90
91pub async fn execute_tool(
92    name: &str,
93    args_val: &serde_json::Map<String, Value>,
94    undo_stack: &mut Vec<UndoAction>,
95    agent_cwd: Option<&std::path::Path>,
96) -> Result<String> {
97    execute_tool_inner(name, args_val, undo_stack, agent_cwd).await
98}
99
100async fn execute_tool_inner(
101    name: &str,
102    args_val: &serde_json::Map<String, Value>,
103    undo_stack: &mut Vec<UndoAction>,
104    agent_cwd: Option<&std::path::Path>,
105) -> Result<String> {
106    let timeout = tool_timeout(name);
107
108    tokio::time::timeout(
109        timeout,
110        execute_tool_raw(name, args_val, undo_stack, agent_cwd),
111    )
112    .await
113    .unwrap_or_else(|_| {
114        Err(anyhow::anyhow!(
115            "Tool '{}' timed out after {:?}",
116            name,
117            timeout
118        ))
119    })
120}
121
122async fn execute_tool_raw(
123    name: &str,
124    args_val: &serde_json::Map<String, Value>,
125    undo_stack: &mut Vec<UndoAction>,
126    agent_cwd: Option<&std::path::Path>,
127) -> Result<String> {
128    // Convert serde_json::Map to HashMap for the Tool trait
129    let mut args = HashMap::new();
130    for (k, v) in args_val {
131        args.insert(k.clone(), v.clone());
132    }
133
134    TOOL_REGISTRY
135        .execute(name, &args, undo_stack, agent_cwd)
136        .await
137}
138
139/// Execute a tool with cache support.
140/// Returns (result, was_cache_hit)
141pub async fn execute_tool_cached(
142    name: &str,
143    args_val: &serde_json::Map<String, Value>,
144    undo_stack: &mut Vec<UndoAction>,
145    cache: &mut ToolCache,
146    agent_cwd: Option<&std::path::Path>,
147) -> (Result<String>, bool) {
148    // Only cache read-only tools
149    if is_cacheable(name) {
150        let key = cache_key(name, args_val);
151        let now = Instant::now();
152
153        if let Some(entry) = cache.get(&key) {
154            if now.duration_since(entry.timestamp) < CACHE_TTL {
155                tracing::debug!(target: "cache", "Cache hit: {}", key);
156                return (Ok(entry.result.clone()), true);
157            }
158        }
159
160        let result = execute_tool(name, args_val, undo_stack, agent_cwd).await;
161        if let Ok(ref res) = result {
162            if res.len() < 50_000 {
163                // Don't cache very large results
164                cache.insert(
165                    key,
166                    CacheEntry {
167                        timestamp: Instant::now(),
168                        result: res.clone(),
169                    },
170                );
171            }
172        }
173        (result, false)
174    } else {
175        (
176            execute_tool(name, args_val, undo_stack, agent_cwd).await,
177            false,
178        )
179    }
180}
181
182/// Execute multiple independent tool calls in parallel
183pub async fn execute_tools_parallel(
184    tool_calls: &[(String, serde_json::Map<String, Value>)],
185    agent_cwd: Option<std::path::PathBuf>,
186) -> Vec<(usize, Result<String>, Vec<UndoAction>)> {
187    let futures: Vec<_> = tool_calls
188        .iter()
189        .enumerate()
190        .map(|(idx, (name, args))| {
191            let name = name.clone();
192            let args = args.clone();
193            let cwd = agent_cwd.clone();
194            async move {
195                let mut temp_undo = Vec::new();
196                let result = execute_tool(&name, &args, &mut temp_undo, cwd.as_deref()).await;
197                (idx, result, temp_undo)
198            }
199        })
200        .collect();
201
202    join_all(futures).await
203}