Skip to main content

matrixcode_core/tools/
task.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use serde_json::{Value, json};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::{Mutex, mpsc};
7use uuid::Uuid;
8
9use super::{Tool, ToolDefinition};
10use super::subagent_executor::{
11    SubagentExecutor, SubagentConfig,
12    create_task, setup_worktree, cleanup_worktree,
13};
14use crate::approval::RiskLevel;
15use crate::event::AgentEvent;
16
/// Tool that hands a prompt to a sub-agent, either inline or as a background task.
///
/// The type carries no state; task records live in the module-level task manager.
#[derive(Debug, Default)]
pub struct TaskTool;
19
20#[async_trait]
21impl Tool for TaskTool {
22    fn definition(&self) -> ToolDefinition {
23        ToolDefinition {
24            name: "task".to_string(),
25            description: "启动新代理处理复杂的多步骤任务。每个代理独立运行,可并行处理不同任务。适用于:(1) 需多次查询/查找的研究任务;(2) 可在后台运行的长时间操作;(3) 需与主上下文隔离的任务;(4) 可并行执行的多个独立任务。".to_string(),
26            parameters: json!({
27                "type": "object",
28                "properties": {
29                    "description": {
30                        "type": "string",
31                        "description": "任务简短描述(3-5 个词)"
32                    },
33                    "prompt": {
34                        "type": "string",
35                        "description": "代理要执行的任务,需包含所有必要上下文"
36                    },
37                    "subagent_type": {
38                        "type": "string",
39                        "enum": ["general-purpose", "Explore", "Plan"],
40                        "default": "general-purpose",
41                        "description": "代理类型:'general-purpose' 用于通用任务,'Explore' 用于快速只读搜索,'Plan' 用于架构规划"
42                    },
43                    "run_in_background": {
44                        "type": "boolean",
45                        "default": false,
46                        "description": "若为 true,在后台运行代理,完成时会收到通知"
47                    },
48                    "isolation": {
49                        "type": "string",
50                        "enum": ["none", "worktree"],
51                        "default": "none",
52                        "description": "隔离模式:'none' 在当前目录工作,'worktree' 创建隔离的 git worktree"
53                    }
54                },
55                "required": ["description", "prompt"]
56            }),
57            ..Default::default()
58        }
59    }
60
61    fn risk_level(&self) -> RiskLevel {
62        RiskLevel::Mutating // Tasks can modify files
63    }
64
65    async fn execute(&self, params: Value) -> Result<String> {
66        let description = params["description"]
67            .as_str()
68            .ok_or_else(|| anyhow::anyhow!("missing 'description'"))?;
69        let prompt = params["prompt"]
70            .as_str()
71            .ok_or_else(|| anyhow::anyhow!("missing 'prompt'"))?;
72        let subagent_type = params["subagent_type"]
73            .as_str()
74            .unwrap_or("general-purpose");
75        let run_in_background = params["run_in_background"].as_bool().unwrap_or(false);
76        let isolation = params["isolation"].as_str().unwrap_or("none");
77
78        // Generate task ID
79        let task_id = Uuid::new_v4().to_string();
80
81        // Create task info
82        let task_info = TaskInfo {
83            id: task_id.clone(),
84            description: description.to_string(),
85            prompt: prompt.to_string(),
86            subagent_type: subagent_type.to_string(),
87            status: TaskStatus::Pending,
88            result: None,
89            started_at: Some(std::time::Instant::now()),
90        };
91
92        // Get or create task manager
93        let manager = get_task_manager();
94
95        // Add task
96        {
97            let mut tasks = manager.tasks.lock().await;
98            tasks.insert(task_id.clone(), task_info);
99        }
100
101        if run_in_background {
102            // Spawn background task
103            let manager_clone = Arc::clone(&manager);
104            let task_id_clone = task_id.clone();
105            let prompt_clone = prompt.to_string();
106            let subagent_type_clone = subagent_type.to_string();
107            let isolation_clone = isolation.to_string();
108
109            tokio::spawn(async move {
110                let result =
111                    execute_subagent_task(&prompt_clone, &subagent_type_clone, &isolation_clone)
112                        .await;
113
114                // Update task status
115                let mut tasks = manager_clone.tasks.lock().await;
116                if let Some(task) = tasks.get_mut(&task_id_clone) {
117                    match result {
118                        Ok(output) => {
119                            task.status = TaskStatus::Completed;
120                            task.result = Some(output);
121                        }
122                        Err(e) => {
123                            task.status = TaskStatus::Failed;
124                            task.result = Some(e.to_string());
125                        }
126                    }
127                }
128
129                // Send notification (if channel available)
130                if let Some(tx) = &manager_clone.notification_tx {
131                    let _ = tx.try_send(TaskNotification {
132                        task_id: task_id_clone,
133                        status: "completed".to_string(),
134                    });
135                }
136            });
137
138            Ok(format!(
139                "Task {} started in background. You'll be notified when it completes.",
140                task_id
141            ))
142        } else {
143            // Execute synchronously
144            let result = execute_subagent_task(prompt, subagent_type, isolation).await?;
145
146            // Update task status
147            {
148                let mut tasks = manager.tasks.lock().await;
149                if let Some(task) = tasks.get_mut(&task_id) {
150                    task.status = TaskStatus::Completed;
151                    task.result = Some(result.clone());
152                }
153            }
154
155            Ok(result)
156        }
157    }
158}
159
/// Lifecycle state of a tracked task.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskStatus {
    /// Recorded but not yet started.
    Pending,
    /// Currently executing.
    Running,
    /// Finished and produced a result.
    Completed,
    /// Finished with an error.
    Failed,
    /// Stopped on request before finishing.
    Cancelled,
}
169
170impl std::fmt::Display for TaskStatus {
171    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172        match self {
173            TaskStatus::Pending => write!(f, "pending"),
174            TaskStatus::Running => write!(f, "running"),
175            TaskStatus::Completed => write!(f, "completed"),
176            TaskStatus::Failed => write!(f, "failed"),
177            TaskStatus::Cancelled => write!(f, "cancelled"),
178        }
179    }
180}
181
182/// Task info
183#[derive(Debug, Clone)]
184pub struct TaskInfo {
185    pub id: String,
186    pub description: String,
187    pub prompt: String,
188    pub subagent_type: String,
189    pub status: TaskStatus,
190    pub result: Option<String>,
191    pub started_at: Option<std::time::Instant>,
192}
193
/// Message sent on the task manager's notification channel when a background task finishes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaskNotification {
    /// Identifier of the task the message refers to.
    pub task_id: String,
    /// Textual status of the task at the time the message was sent.
    pub status: String,
}
200
201/// Task manager singleton
202pub struct TaskManager {
203    pub tasks: Mutex<HashMap<String, TaskInfo>>,
204    pub notification_tx: Option<mpsc::Sender<TaskNotification>>,
205}
206
207static TASK_MANAGER: std::sync::OnceLock<Arc<TaskManager>> = std::sync::OnceLock::new();
208
209fn get_task_manager() -> Arc<TaskManager> {
210    TASK_MANAGER
211        .get_or_init(|| {
212            Arc::new(TaskManager {
213                tasks: Mutex::new(HashMap::new()),
214                notification_tx: None,
215            })
216        })
217        .clone()
218}
219
220/// Execute subagent task
221async fn execute_subagent_task(
222    prompt: &str,
223    subagent_type: &str,
224    isolation: &str,
225) -> Result<String> {
226    // Create the task
227    let task = create_task(
228        "", // description is not needed for execute
229        prompt,
230        subagent_type,
231        isolation,
232    );
233
234    // Setup worktree if needed
235    if isolation == "worktree" {
236        setup_worktree(&task).await?;
237    }
238
239    // Create event channel (no external forwarding in this context)
240    let (event_tx, _event_rx) = mpsc::channel::<AgentEvent>(100);
241
242    // Create executor with default config
243    // Use fast model for subagents
244    let config = SubagentConfig {
245        model_name: "claude-sonnet-4-20250514".to_string(),
246        max_tokens: 4096,
247        system_prompt_prefix: None,
248        think: false,
249        tool_names: None,
250    };
251
252    // Get tools from registry or use empty set
253    let tools = get_subagent_tools();
254
255    let mut executor = SubagentExecutor::new(config, event_tx, tools);
256
257    // Execute the task
258    let result = executor.execute(task.clone()).await?;
259
260    // Cleanup worktree if needed
261    if isolation == "worktree" {
262        cleanup_worktree(&task).await?;
263    }
264
265    // Format result
266    if result.success {
267        Ok(format!(
268            "[{} Agent] Task completed successfully\n\
269            Tokens used: {} input, {} output\n\
270            Result: {}",
271            subagent_type,
272            result.usage.input_tokens,
273            result.usage.output_tokens,
274            result.content
275        ))
276    } else {
277        Ok(format!(
278            "[{} Agent] Task failed\n\
279            Error: {}",
280            subagent_type,
281            result.content
282        ))
283    }
284}
285
286/// Get tools for subagent execution
287/// Returns a minimal set of tools suitable for background tasks
288fn get_subagent_tools() -> Vec<Arc<dyn Tool>> {
289    // In production, this would get tools from the main agent's tool registry
290    // For now, return empty set - the executor will handle tool filtering
291    Vec::new()
292}
293
/// Tool that registers a new background task from a description and a prompt.
#[derive(Debug, Default)]
pub struct TaskCreateTool;
296
297#[async_trait]
298impl Tool for TaskCreateTool {
299    fn definition(&self) -> ToolDefinition {
300        ToolDefinition {
301            name: "task_create".to_string(),
302            description: "创建独立运行的后台任务".to_string(),
303            parameters: json!({
304                "type": "object",
305                "properties": {
306                    "description": {
307                        "type": "string",
308                        "description": "任务描述"
309                    },
310                    "prompt": {
311                        "type": "string",
312                        "description": "任务提示"
313                    }
314                },
315                "required": ["description", "prompt"]
316            }),
317            ..Default::default()
318        }
319    }
320
321    fn risk_level(&self) -> RiskLevel {
322        RiskLevel::Mutating
323    }
324
325    async fn execute(&self, params: Value) -> Result<String> {
326        let description = params["description"]
327            .as_str()
328            .ok_or_else(|| anyhow::anyhow!("missing 'description'"))?;
329        let prompt = params["prompt"]
330            .as_str()
331            .ok_or_else(|| anyhow::anyhow!("missing 'prompt'"))?;
332
333        let task_id = Uuid::new_v4().to_string();
334        let manager = get_task_manager();
335
336        let task_info = TaskInfo {
337            id: task_id.clone(),
338            description: description.to_string(),
339            prompt: prompt.to_string(),
340            subagent_type: "general-purpose".to_string(),
341            status: TaskStatus::Running,
342            result: None,
343            started_at: Some(std::time::Instant::now()),
344        };
345
346        {
347            let mut tasks = manager.tasks.lock().await;
348            tasks.insert(task_id.clone(), task_info);
349        }
350
351        Ok(format!("Task {} created and running", task_id))
352    }
353}
354
/// Tool that reports the status and result of a single task.
#[derive(Debug, Default)]
pub struct TaskGetTool;
357
358#[async_trait]
359impl Tool for TaskGetTool {
360    fn definition(&self) -> ToolDefinition {
361        ToolDefinition {
362            name: "task_get".to_string(),
363            description: "获取指定任务的状态和结果".to_string(),
364            parameters: json!({
365                "type": "object",
366                "properties": {
367                    "task_id": {
368                        "type": "string",
369                        "description": "要查询的任务 ID"
370                    }
371                },
372                "required": ["task_id"]
373            }),
374            ..Default::default()
375        }
376    }
377
378    async fn execute(&self, params: Value) -> Result<String> {
379        let task_id = params["task_id"]
380            .as_str()
381            .ok_or_else(|| anyhow::anyhow!("missing 'task_id'"))?;
382
383        let manager = get_task_manager();
384        let tasks = manager.tasks.lock().await;
385
386        if let Some(task) = tasks.get(task_id) {
387            let status_str = task.status.to_string();
388
389            let elapsed = task
390                .started_at
391                .map(|s| format!("{:.1}s", s.elapsed().as_secs_f64()))
392                .unwrap_or_else(|| "N/A".to_string());
393
394            let result_str = task
395                .result
396                .clone()
397                .unwrap_or_else(|| "No result yet".to_string());
398
399            Ok(format!(
400                "Task: {}\nDescription: {}\nStatus: {}\nElapsed: {}\nResult: {}",
401                task_id, task.description, status_str, elapsed, result_str
402            ))
403        } else {
404            Ok(format!("Task {} not found", task_id))
405        }
406    }
407}
408
/// Tool that lists every task currently held by the task manager.
#[derive(Debug, Default)]
pub struct TaskListTool;
411
412#[async_trait]
413impl Tool for TaskListTool {
414    fn definition(&self) -> ToolDefinition {
415        ToolDefinition {
416            name: "task_list".to_string(),
417            description: "列出所有活动任务".to_string(),
418            parameters: json!({
419                "type": "object",
420                "properties": {}
421            }),
422            ..Default::default()
423        }
424    }
425
426    async fn execute(&self, _params: Value) -> Result<String> {
427        let manager = get_task_manager();
428        let tasks = manager.tasks.lock().await;
429
430        if tasks.is_empty() {
431            return Ok("No active tasks".to_string());
432        }
433
434        let mut result = Vec::new();
435        for (id, task) in tasks.iter() {
436            result.push(format!("{} [{}] - {}", id, task.status, task.description));
437        }
438
439        Ok(result.join("\n"))
440    }
441}
442
/// Tool that marks a task as cancelled.
#[derive(Debug, Default)]
pub struct TaskStopTool;
445
446#[async_trait]
447impl Tool for TaskStopTool {
448    fn definition(&self) -> ToolDefinition {
449        ToolDefinition {
450            name: "task_stop".to_string(),
451            description: "停止正在运行的任务".to_string(),
452            parameters: json!({
453                "type": "object",
454                "properties": {
455                    "task_id": {
456                        "type": "string",
457                        "description": "要停止的任务 ID"
458                    }
459                },
460                "required": ["task_id"]
461            }),
462            ..Default::default()
463        }
464    }
465
466    fn risk_level(&self) -> RiskLevel {
467        RiskLevel::Mutating
468    }
469
470    async fn execute(&self, params: Value) -> Result<String> {
471        let task_id = params["task_id"]
472            .as_str()
473            .ok_or_else(|| anyhow::anyhow!("missing 'task_id'"))?;
474
475        let manager = get_task_manager();
476        let mut tasks = manager.tasks.lock().await;
477
478        if let Some(task) = tasks.get_mut(task_id) {
479            task.status = TaskStatus::Cancelled;
480            Ok(format!("Task {} stopped", task_id))
481        } else {
482            Ok(format!("Task {} not found", task_id))
483        }
484    }
485}