Skip to main content

aster/tools/
task_tool.rs

1//! Task Tool - 后台任务管理工具
2//!
3//! 基于 TaskManager 实现的任务启动和管理工具,对齐 Claude Agent SDK
4
5use super::base::{PermissionCheckResult, Tool};
6use super::context::{ToolContext, ToolResult};
7use super::error::ToolError;
8use super::task::TaskManager;
9use async_trait::async_trait;
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12use std::time::Duration;
13
14/// TaskTool 输入参数
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct TaskInput {
17    /// 任务命令
18    pub command: String,
19    /// 任务描述(可选)
20    pub description: Option<String>,
21    /// 是否在后台运行
22    pub run_in_background: Option<bool>,
23}
24
25/// TaskTool - 启动后台任务
26///
27/// 对齐 Claude Agent SDK 的 TaskTool 功能,用于启动和管理后台任务
28pub struct TaskTool {
29    /// 任务管理器
30    task_manager: Arc<TaskManager>,
31}
32
33impl TaskTool {
34    /// 创建新的 TaskTool
35    pub fn new() -> Self {
36        Self {
37            task_manager: Arc::new(TaskManager::new()),
38        }
39    }
40
41    /// 使用自定义 TaskManager 创建 TaskTool
42    pub fn with_manager(task_manager: Arc<TaskManager>) -> Self {
43        Self { task_manager }
44    }
45}
46
47impl Default for TaskTool {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53#[async_trait]
54impl Tool for TaskTool {
55    fn name(&self) -> &str {
56        "Task"
57    }
58
59    fn description(&self) -> &str {
60        r#"启动后台任务执行命令
61
62用于启动长时间运行的命令或需要并行执行的任务。支持:
63- 后台执行命令
64- 任务状态跟踪
65- 输出文件持久化
66- 并发任务限制
67
68参数:
69- command: 要执行的命令
70- description: 任务描述(可选)
71- run_in_background: 是否后台运行(默认 true)
72
73返回任务 ID,可用于后续查询任务状态和输出。"#
74    }
75
76    fn input_schema(&self) -> serde_json::Value {
77        serde_json::json!({
78            "type": "object",
79            "properties": {
80                "command": {
81                    "type": "string",
82                    "description": "要执行的命令"
83                },
84                "description": {
85                    "type": "string",
86                    "description": "任务描述(可选)"
87                },
88                "run_in_background": {
89                    "type": "boolean",
90                    "description": "是否在后台运行(默认 true)"
91                }
92            },
93            "required": ["command"]
94        })
95    }
96
97    async fn execute(
98        &self,
99        params: serde_json::Value,
100        context: &ToolContext,
101    ) -> Result<ToolResult, ToolError> {
102        let input: TaskInput = serde_json::from_value(params)
103            .map_err(|e| ToolError::invalid_params(format!("参数解析失败: {}", e)))?;
104
105        let run_in_background = input.run_in_background.unwrap_or(true);
106
107        // 启动任务
108        let task_id = self.task_manager.start(&input.command, context).await?;
109
110        let description = input.description.unwrap_or_else(|| {
111            // 截取命令的前50个字符作为描述,安全处理 UTF-8
112            let cmd = &input.command;
113            if cmd.chars().count() > 50 {
114                let truncated: String = cmd.chars().take(47).collect();
115                format!("{}...", truncated)
116            } else {
117                cmd.to_string()
118            }
119        });
120
121        if run_in_background {
122            // 后台运行 - 立即返回任务 ID
123            Ok(ToolResult::success(format!(
124                "任务已启动(后台运行)\n任务 ID: {}\n描述: {}\n命令: {}\n\n使用 TaskOutput 工具查询任务状态和输出。",
125                task_id, description, input.command
126            )).with_metadata("task_id", serde_json::json!(task_id)))
127        } else {
128            // 前台运行 - 等待完成
129            // 等待任务完成(最多等待30秒)
130            let timeout = Duration::from_secs(30);
131            let start_time = std::time::Instant::now();
132
133            loop {
134                if let Some(state) = self.task_manager.get_status(&task_id).await {
135                    if state.status.is_terminal() {
136                        // 任务已完成,获取输出
137                        let output = self
138                            .task_manager
139                            .get_output(&task_id, None)
140                            .await
141                            .unwrap_or_else(|_| "无法获取任务输出".to_string());
142
143                        let duration = state.duration().as_secs_f64();
144
145                        return Ok(ToolResult::success(format!(
146                            "任务已完成\n任务 ID: {}\n描述: {}\n状态: {}\n执行时间: {:.2}秒\n\n=== 输出 ===\n{}",
147                            task_id, description, state.status, duration, output
148                        )).with_metadata("task_id", serde_json::json!(task_id))
149                          .with_metadata("status", serde_json::json!(state.status.to_string()))
150                          .with_metadata("duration", serde_json::json!(duration)));
151                    }
152                }
153
154                // 检查超时
155                if start_time.elapsed() > timeout {
156                    return Ok(ToolResult::success(format!(
157                        "任务启动成功但执行时间超过 {}秒,已转为后台运行\n任务 ID: {}\n描述: {}\n\n使用 TaskOutput 工具查询任务状态和输出。",
158                        timeout.as_secs(), task_id, description
159                    )).with_metadata("task_id", serde_json::json!(task_id)));
160                }
161
162                // 等待100ms后重新检查
163                tokio::time::sleep(Duration::from_millis(100)).await;
164            }
165        }
166    }
167
168    async fn check_permissions(
169        &self,
170        _params: &serde_json::Value,
171        _context: &ToolContext,
172    ) -> PermissionCheckResult {
173        // 任务启动需要执行权限
174        PermissionCheckResult::ask("执行后台任务")
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a `ToolContext` for tests, rooted at the platform temp directory.
    ///
    /// `std::env::temp_dir()` replaces a hard-coded `/tmp`, which does not exist
    /// on every platform (e.g. Windows).
    fn create_test_context() -> ToolContext {
        ToolContext::new(std::env::temp_dir())
            .with_session_id("test-session")
            .with_user("test-user")
    }

    /// Builds a `TaskTool` whose manager writes task output under `temp_dir` and
    /// allows up to 5 concurrent tasks.
    ///
    /// The caller keeps `temp_dir` alive for the whole test so the output directory
    /// is not removed while a task may still be using it.
    fn create_test_tool(temp_dir: &TempDir) -> TaskTool {
        let task_manager = TaskManager::new()
            .with_output_directory(temp_dir.path().to_path_buf())
            .with_max_concurrent(5);
        TaskTool::with_manager(Arc::new(task_manager))
    }

    #[tokio::test]
    async fn test_task_tool_new() {
        let tool = TaskTool::new();
        assert_eq!(tool.name(), "Task");
    }

    #[tokio::test]
    async fn test_task_tool_input_schema() {
        let tool = TaskTool::new();
        let schema = tool.input_schema();

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["command"].is_object());
        assert_eq!(schema["required"], serde_json::json!(["command"]));
    }

    #[tokio::test]
    async fn test_task_tool_execute_background() {
        let temp_dir = TempDir::new().expect("failed to create temp dir");
        let tool = create_test_tool(&temp_dir);
        let context = create_test_context();

        let params = serde_json::json!({
            "command": "echo hello",
            "description": "测试任务",
            "run_in_background": true
        });

        // Unlike `assert!(result.is_ok())`, `expect` prints the ToolError on failure.
        let tool_result = tool
            .execute(params, &context)
            .await
            .expect("background execution should succeed");
        assert!(tool_result.success);
        let output = tool_result
            .output
            .as_ref()
            .expect("result should carry output text");
        assert!(output.contains("任务已启动"));
        assert!(tool_result.metadata.contains_key("task_id"));
    }

    #[tokio::test]
    async fn test_task_tool_execute_foreground() {
        let temp_dir = TempDir::new().expect("failed to create temp dir");
        let tool = create_test_tool(&temp_dir);
        let context = create_test_context();

        let params = serde_json::json!({
            "command": "echo hello world",
            "run_in_background": false
        });

        let tool_result = tool
            .execute(params, &context)
            .await
            .expect("foreground execution should succeed");
        assert!(tool_result.success);
        // The result should include the task output, or at least the completion banner.
        let output = tool_result
            .output
            .as_ref()
            .expect("result should carry output text");
        assert!(
            output.contains("hello world") || output.contains("任务已完成"),
            "unexpected output: {}",
            output
        );
    }

    #[tokio::test]
    async fn test_task_tool_invalid_params() {
        let tool = TaskTool::new();
        let context = create_test_context();

        let params = serde_json::json!({
            "invalid": "params"
        });

        let result = tool.execute(params, &context).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_task_tool_check_permissions() {
        let tool = TaskTool::new();
        let context = create_test_context();
        let params = serde_json::json!({"command": "echo test"});

        let result = tool.check_permissions(&params, &context).await;
        assert!(result.requires_confirmation());
    }
}