Skip to main content

aster/tools/
task_output_tool.rs

1//! TaskOutput Tool - 任务输出查询工具
2//!
3//! 用于查询后台任务的状态和输出,对齐 Claude Agent SDK
4
5use super::base::{PermissionCheckResult, Tool};
6use super::context::{ToolContext, ToolResult};
7use super::error::ToolError;
8use super::task::TaskManager;
9use async_trait::async_trait;
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12use std::time::Duration;
13
14/// TaskOutputTool 输入参数
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct TaskOutputInput {
17    /// 任务 ID
18    pub task_id: String,
19    /// 是否阻塞等待任务完成
20    pub block: Option<bool>,
21    /// 等待超时时间(毫秒)
22    pub timeout: Option<u64>,
23    /// 显示详细历史(扩展功能)
24    pub show_history: Option<bool>,
25    /// 限制输出行数
26    pub lines: Option<usize>,
27}
28
29/// TaskOutputTool - 查询任务输出和状态
30///
31/// 对齐 Claude Agent SDK 的 TaskOutputTool 功能
32pub struct TaskOutputTool {
33    /// 任务管理器
34    task_manager: Arc<TaskManager>,
35}
36
37impl TaskOutputTool {
38    /// 创建新的 TaskOutputTool
39    pub fn new() -> Self {
40        Self {
41            task_manager: Arc::new(TaskManager::new()),
42        }
43    }
44
45    /// 使用自定义 TaskManager 创建 TaskOutputTool
46    pub fn with_manager(task_manager: Arc<TaskManager>) -> Self {
47        Self { task_manager }
48    }
49}
50
51impl Default for TaskOutputTool {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57#[async_trait]
58impl Tool for TaskOutputTool {
59    fn name(&self) -> &str {
60        "TaskOutput"
61    }
62
63    fn description(&self) -> &str {
64        r#"获取后台任务的输出和状态
65
66用于查询通过 Task 工具启动的后台任务的执行状态和输出结果。
67
68参数:
69- task_id: 任务 ID(必需)
70- block: 是否等待任务完成(默认 false)
71- timeout: 等待超时时间(毫秒,默认 5000)
72- show_history: 显示详细执行历史(默认 false)
73- lines: 限制输出行数(可选)
74
75功能:
76- 查询任务状态(运行中/已完成/失败/超时/已终止)
77- 获取任务输出内容
78- 支持阻塞等待任务完成
79- 显示任务执行时间和统计信息"#
80    }
81
82    fn input_schema(&self) -> serde_json::Value {
83        serde_json::json!({
84            "type": "object",
85            "properties": {
86                "task_id": {
87                    "type": "string",
88                    "description": "要查询的任务 ID"
89                },
90                "block": {
91                    "type": "boolean",
92                    "description": "是否等待任务完成(默认 false)"
93                },
94                "timeout": {
95                    "type": "number",
96                    "description": "等待超时时间(毫秒,默认 5000)"
97                },
98                "show_history": {
99                    "type": "boolean",
100                    "description": "显示详细执行历史(默认 false)"
101                },
102                "lines": {
103                    "type": "number",
104                    "description": "限制输出行数(可选)"
105                }
106            },
107            "required": ["task_id"]
108        })
109    }
110
111    async fn execute(
112        &self,
113        params: serde_json::Value,
114        _context: &ToolContext,
115    ) -> Result<ToolResult, ToolError> {
116        let input: TaskOutputInput = serde_json::from_value(params)
117            .map_err(|e| ToolError::invalid_params(format!("参数解析失败: {}", e)))?;
118
119        let block = input.block.unwrap_or(false);
120        let timeout_ms = input.timeout.unwrap_or(5000);
121        let show_history = input.show_history.unwrap_or(false);
122
123        // 检查任务是否存在
124        if !self.task_manager.task_exists(&input.task_id).await {
125            return Err(ToolError::not_found(format!(
126                "任务未找到: {}",
127                input.task_id
128            )));
129        }
130
131        // 如果需要阻塞等待
132        if block {
133            let timeout = Duration::from_millis(timeout_ms);
134            let start_time = std::time::Instant::now();
135
136            loop {
137                if let Some(state) = self.task_manager.get_status(&input.task_id).await {
138                    if state.status.is_terminal() {
139                        break;
140                    }
141                }
142
143                // 检查超时
144                if start_time.elapsed() > timeout {
145                    break;
146                }
147
148                // 等待100ms后重新检查
149                tokio::time::sleep(Duration::from_millis(100)).await;
150            }
151        }
152
153        // 获取任务状态
154        let state = self
155            .task_manager
156            .get_status(&input.task_id)
157            .await
158            .ok_or_else(|| ToolError::not_found(format!("任务状态未找到: {}", input.task_id)))?;
159
160        // 构建输出信息
161        let mut output = Vec::new();
162        output.push(format!("=== 任务 {} ===", input.task_id));
163        output.push(format!("命令: {}", state.command));
164        output.push(format!("状态: {}", state.status));
165        output.push(format!("开始时间: {}", format_instant(state.start_time)));
166
167        let duration = state.duration();
168        if let Some(end_time) = state.end_time {
169            output.push(format!("结束时间: {}", format_instant(end_time)));
170            output.push(format!("执行时间: {:.2}秒", duration.as_secs_f64()));
171        } else {
172            output.push(format!("运行时间: {:.2}秒", duration.as_secs_f64()));
173        }
174
175        if let Some(exit_code) = state.exit_code {
176            output.push(format!("退出码: {}", exit_code));
177        }
178
179        output.push(format!("工作目录: {}", state.working_directory.display()));
180        output.push(format!("输出文件: {}", state.output_file.display()));
181        output.push(format!("会话 ID: {}", state.session_id));
182
183        // 显示详细历史(扩展功能)
184        if show_history {
185            output.push("\n=== 执行历史 ===".to_string());
186            output.push("(注意:当前实现中 TaskManager 不维护详细历史记录)".to_string());
187            output.push(format!("任务创建: {}", format_instant(state.start_time)));
188            if let Some(end_time) = state.end_time {
189                output.push(format!(
190                    "任务结束: {} (状态: {})",
191                    format_instant(end_time),
192                    state.status
193                ));
194            }
195        }
196
197        // 获取任务输出
198        match self
199            .task_manager
200            .get_output(&input.task_id, input.lines)
201            .await
202        {
203            Ok(task_output) => {
204                output.push("\n=== 任务输出 ===".to_string());
205                if task_output.trim().is_empty() {
206                    output.push("(暂无输出)".to_string());
207                } else {
208                    output.push(task_output);
209                }
210            }
211            Err(e) => {
212                output.push("\n=== 输出获取失败 ===".to_string());
213                output.push(format!("错误: {}", e));
214            }
215        }
216
217        // 根据任务状态添加状态说明
218        match state.status {
219            super::task::TaskStatus::Running => {
220                output.push("\n=== 状态说明 ===".to_string());
221                output.push("任务仍在运行中。使用 block=true 参数等待任务完成。".to_string());
222            }
223            super::task::TaskStatus::Completed => {
224                output.push("\n=== 状态说明 ===".to_string());
225                output.push("任务已成功完成。".to_string());
226            }
227            super::task::TaskStatus::Failed => {
228                output.push("\n=== 状态说明 ===".to_string());
229                output.push("任务执行失败。请检查命令和输出错误信息。".to_string());
230            }
231            super::task::TaskStatus::TimedOut => {
232                output.push("\n=== 状态说明 ===".to_string());
233                output.push("任务因超时被终止。".to_string());
234            }
235            super::task::TaskStatus::Killed => {
236                output.push("\n=== 状态说明 ===".to_string());
237                output.push("任务被用户终止。".to_string());
238            }
239        }
240
241        Ok(ToolResult::success(output.join("\n"))
242            .with_metadata("task_id", serde_json::json!(input.task_id))
243            .with_metadata("status", serde_json::json!(state.status.to_string()))
244            .with_metadata("duration", serde_json::json!(duration.as_secs_f64()))
245            .with_metadata("exit_code", serde_json::json!(state.exit_code)))
246    }
247
248    async fn check_permissions(
249        &self,
250        _params: &serde_json::Value,
251        _context: &ToolContext,
252    ) -> PermissionCheckResult {
253        // 查询任务输出是只读操作
254        PermissionCheckResult::allow()
255    }
256}
257
258/// 格式化 Instant 为可读字符串
259/// 注意:Instant 不能直接转换为绝对时间,这里只显示相对时间
260fn format_instant(instant: std::time::Instant) -> String {
261    let elapsed = instant.elapsed();
262    format!("{:.2}秒前", elapsed.as_secs_f64())
263}
264
265#[cfg(test)]
266mod tests {
267    use super::*;
268    use std::path::PathBuf;
269    use tempfile::TempDir;
270
271    fn create_test_context() -> ToolContext {
272        ToolContext::new(PathBuf::from("/tmp"))
273            .with_session_id("test-session")
274            .with_user("test-user")
275    }
276
277    #[tokio::test]
278    async fn test_task_output_tool_new() {
279        let tool = TaskOutputTool::new();
280        assert_eq!(tool.name(), "TaskOutput");
281    }
282
283    #[tokio::test]
284    async fn test_task_output_tool_input_schema() {
285        let tool = TaskOutputTool::new();
286        let schema = tool.input_schema();
287
288        assert_eq!(schema["type"], "object");
289        assert!(schema["properties"]["task_id"].is_object());
290        assert_eq!(schema["required"], serde_json::json!(["task_id"]));
291    }
292
293    #[tokio::test]
294    async fn test_task_output_tool_not_found() {
295        let tool = TaskOutputTool::new();
296        let context = create_test_context();
297
298        let params = serde_json::json!({
299            "task_id": "nonexistent-task"
300        });
301
302        let result = tool.execute(params, &context).await;
303        assert!(result.is_err());
304        assert!(matches!(result.unwrap_err(), ToolError::NotFound(_)));
305    }
306
307    #[tokio::test]
308    async fn test_task_output_tool_with_task() {
309        let temp_dir = TempDir::new().unwrap();
310        let task_manager = Arc::new(
311            TaskManager::new()
312                .with_output_directory(temp_dir.path().to_path_buf())
313                .with_max_concurrent(5),
314        );
315        let tool = TaskOutputTool::with_manager(task_manager.clone());
316        let context = create_test_context();
317
318        // 先启动一个任务
319        let task_id = task_manager.start("echo hello", &context).await.unwrap();
320
321        // 等待任务完成
322        tokio::time::sleep(Duration::from_millis(500)).await;
323
324        // 查询任务输出
325        let params = serde_json::json!({
326            "task_id": task_id
327        });
328
329        let result = tool.execute(params, &context).await;
330        assert!(result.is_ok());
331
332        let tool_result = result.unwrap();
333        assert!(tool_result.success);
334        assert!(tool_result.output.as_ref().unwrap().contains(&task_id));
335        assert!(tool_result.metadata.contains_key("status"));
336    }
337
338    #[tokio::test]
339    async fn test_task_output_tool_with_block() {
340        let temp_dir = TempDir::new().unwrap();
341        let task_manager = Arc::new(
342            TaskManager::new()
343                .with_output_directory(temp_dir.path().to_path_buf())
344                .with_max_concurrent(5),
345        );
346        let tool = TaskOutputTool::with_manager(task_manager.clone());
347        let context = create_test_context();
348
349        // 启动一个快速任务
350        let task_id = task_manager
351            .start("echo blocking test", &context)
352            .await
353            .unwrap();
354
355        // 使用阻塞模式查询
356        let params = serde_json::json!({
357            "task_id": task_id,
358            "block": true,
359            "timeout": 2000
360        });
361
362        let result = tool.execute(params, &context).await;
363        assert!(result.is_ok());
364
365        let tool_result = result.unwrap();
366        assert!(tool_result.success);
367        // 应该包含任务输出
368        let output = tool_result.output.as_ref().unwrap();
369        assert!(output.contains("blocking test") || output.contains("已完成"));
370    }
371
372    #[tokio::test]
373    async fn test_task_output_tool_invalid_params() {
374        let tool = TaskOutputTool::new();
375        let context = create_test_context();
376
377        let params = serde_json::json!({
378            "invalid": "params"
379        });
380
381        let result = tool.execute(params, &context).await;
382        assert!(result.is_err());
383    }
384
385    #[tokio::test]
386    async fn test_task_output_tool_check_permissions() {
387        let tool = TaskOutputTool::new();
388        let context = create_test_context();
389        let params = serde_json::json!({"task_id": "test"});
390
391        let result = tool.check_permissions(&params, &context).await;
392        assert!(result.is_allowed());
393    }
394}