aster/tools/
task_output_tool.rs1use super::base::{PermissionCheckResult, Tool};
6use super::context::{ToolContext, ToolResult};
7use super::error::ToolError;
8use super::task::TaskManager;
9use async_trait::async_trait;
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12use std::time::Duration;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct TaskOutputInput {
17 pub task_id: String,
19 pub block: Option<bool>,
21 pub timeout: Option<u64>,
23 pub show_history: Option<bool>,
25 pub lines: Option<usize>,
27}
28
29pub struct TaskOutputTool {
33 task_manager: Arc<TaskManager>,
35}
36
37impl TaskOutputTool {
38 pub fn new() -> Self {
40 Self {
41 task_manager: Arc::new(TaskManager::new()),
42 }
43 }
44
45 pub fn with_manager(task_manager: Arc<TaskManager>) -> Self {
47 Self { task_manager }
48 }
49}
50
51impl Default for TaskOutputTool {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57#[async_trait]
58impl Tool for TaskOutputTool {
59 fn name(&self) -> &str {
60 "TaskOutput"
61 }
62
63 fn description(&self) -> &str {
64 r#"获取后台任务的输出和状态
65
66用于查询通过 Task 工具启动的后台任务的执行状态和输出结果。
67
68参数:
69- task_id: 任务 ID(必需)
70- block: 是否等待任务完成(默认 false)
71- timeout: 等待超时时间(毫秒,默认 5000)
72- show_history: 显示详细执行历史(默认 false)
73- lines: 限制输出行数(可选)
74
75功能:
76- 查询任务状态(运行中/已完成/失败/超时/已终止)
77- 获取任务输出内容
78- 支持阻塞等待任务完成
79- 显示任务执行时间和统计信息"#
80 }
81
82 fn input_schema(&self) -> serde_json::Value {
83 serde_json::json!({
84 "type": "object",
85 "properties": {
86 "task_id": {
87 "type": "string",
88 "description": "要查询的任务 ID"
89 },
90 "block": {
91 "type": "boolean",
92 "description": "是否等待任务完成(默认 false)"
93 },
94 "timeout": {
95 "type": "number",
96 "description": "等待超时时间(毫秒,默认 5000)"
97 },
98 "show_history": {
99 "type": "boolean",
100 "description": "显示详细执行历史(默认 false)"
101 },
102 "lines": {
103 "type": "number",
104 "description": "限制输出行数(可选)"
105 }
106 },
107 "required": ["task_id"]
108 })
109 }
110
111 async fn execute(
112 &self,
113 params: serde_json::Value,
114 _context: &ToolContext,
115 ) -> Result<ToolResult, ToolError> {
116 let input: TaskOutputInput = serde_json::from_value(params)
117 .map_err(|e| ToolError::invalid_params(format!("参数解析失败: {}", e)))?;
118
119 let block = input.block.unwrap_or(false);
120 let timeout_ms = input.timeout.unwrap_or(5000);
121 let show_history = input.show_history.unwrap_or(false);
122
123 if !self.task_manager.task_exists(&input.task_id).await {
125 return Err(ToolError::not_found(format!(
126 "任务未找到: {}",
127 input.task_id
128 )));
129 }
130
131 if block {
133 let timeout = Duration::from_millis(timeout_ms);
134 let start_time = std::time::Instant::now();
135
136 loop {
137 if let Some(state) = self.task_manager.get_status(&input.task_id).await {
138 if state.status.is_terminal() {
139 break;
140 }
141 }
142
143 if start_time.elapsed() > timeout {
145 break;
146 }
147
148 tokio::time::sleep(Duration::from_millis(100)).await;
150 }
151 }
152
153 let state = self
155 .task_manager
156 .get_status(&input.task_id)
157 .await
158 .ok_or_else(|| ToolError::not_found(format!("任务状态未找到: {}", input.task_id)))?;
159
160 let mut output = Vec::new();
162 output.push(format!("=== 任务 {} ===", input.task_id));
163 output.push(format!("命令: {}", state.command));
164 output.push(format!("状态: {}", state.status));
165 output.push(format!("开始时间: {}", format_instant(state.start_time)));
166
167 let duration = state.duration();
168 if let Some(end_time) = state.end_time {
169 output.push(format!("结束时间: {}", format_instant(end_time)));
170 output.push(format!("执行时间: {:.2}秒", duration.as_secs_f64()));
171 } else {
172 output.push(format!("运行时间: {:.2}秒", duration.as_secs_f64()));
173 }
174
175 if let Some(exit_code) = state.exit_code {
176 output.push(format!("退出码: {}", exit_code));
177 }
178
179 output.push(format!("工作目录: {}", state.working_directory.display()));
180 output.push(format!("输出文件: {}", state.output_file.display()));
181 output.push(format!("会话 ID: {}", state.session_id));
182
183 if show_history {
185 output.push("\n=== 执行历史 ===".to_string());
186 output.push("(注意:当前实现中 TaskManager 不维护详细历史记录)".to_string());
187 output.push(format!("任务创建: {}", format_instant(state.start_time)));
188 if let Some(end_time) = state.end_time {
189 output.push(format!(
190 "任务结束: {} (状态: {})",
191 format_instant(end_time),
192 state.status
193 ));
194 }
195 }
196
197 match self
199 .task_manager
200 .get_output(&input.task_id, input.lines)
201 .await
202 {
203 Ok(task_output) => {
204 output.push("\n=== 任务输出 ===".to_string());
205 if task_output.trim().is_empty() {
206 output.push("(暂无输出)".to_string());
207 } else {
208 output.push(task_output);
209 }
210 }
211 Err(e) => {
212 output.push("\n=== 输出获取失败 ===".to_string());
213 output.push(format!("错误: {}", e));
214 }
215 }
216
217 match state.status {
219 super::task::TaskStatus::Running => {
220 output.push("\n=== 状态说明 ===".to_string());
221 output.push("任务仍在运行中。使用 block=true 参数等待任务完成。".to_string());
222 }
223 super::task::TaskStatus::Completed => {
224 output.push("\n=== 状态说明 ===".to_string());
225 output.push("任务已成功完成。".to_string());
226 }
227 super::task::TaskStatus::Failed => {
228 output.push("\n=== 状态说明 ===".to_string());
229 output.push("任务执行失败。请检查命令和输出错误信息。".to_string());
230 }
231 super::task::TaskStatus::TimedOut => {
232 output.push("\n=== 状态说明 ===".to_string());
233 output.push("任务因超时被终止。".to_string());
234 }
235 super::task::TaskStatus::Killed => {
236 output.push("\n=== 状态说明 ===".to_string());
237 output.push("任务被用户终止。".to_string());
238 }
239 }
240
241 Ok(ToolResult::success(output.join("\n"))
242 .with_metadata("task_id", serde_json::json!(input.task_id))
243 .with_metadata("status", serde_json::json!(state.status.to_string()))
244 .with_metadata("duration", serde_json::json!(duration.as_secs_f64()))
245 .with_metadata("exit_code", serde_json::json!(state.exit_code)))
246 }
247
248 async fn check_permissions(
249 &self,
250 _params: &serde_json::Value,
251 _context: &ToolContext,
252 ) -> PermissionCheckResult {
253 PermissionCheckResult::allow()
255 }
256}
257
/// Renders how long ago `instant` was: seconds with two decimals, followed by "秒前" ("seconds ago").
///
/// `Instant` is monotonic and carries no wall-clock time, so an age relative to now is what
/// gets displayed. An instant in the future saturates to zero.
fn format_instant(instant: std::time::Instant) -> String {
    let age = std::time::Instant::now().duration_since(instant);
    let seconds_ago = age.as_secs_f64();
    format!("{:.2}秒前", seconds_ago)
}
264
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use tempfile::TempDir;

    /// Builds a context rooted at `/tmp` with fixed session and user identifiers.
    fn create_test_context() -> ToolContext {
        ToolContext::new(PathBuf::from("/tmp"))
            .with_session_id("test-session")
            .with_user("test-user")
    }

    #[tokio::test]
    async fn test_task_output_tool_new() {
        let tool = TaskOutputTool::new();
        assert_eq!(tool.name(), "TaskOutput");
    }

    #[tokio::test]
    async fn test_task_output_tool_input_schema() {
        let tool = TaskOutputTool::new();
        let schema = tool.input_schema();

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["task_id"].is_object());
        assert_eq!(schema["required"], serde_json::json!(["task_id"]));
    }

    #[tokio::test]
    async fn test_task_output_tool_not_found() {
        let tool = TaskOutputTool::new();
        let context = create_test_context();

        let params = serde_json::json!({
            "task_id": "nonexistent-task"
        });

        let result = tool.execute(params, &context).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ToolError::NotFound(_)));
    }

    #[tokio::test]
    async fn test_task_output_tool_with_task() {
        let temp_dir = TempDir::new().unwrap();
        let task_manager = Arc::new(
            TaskManager::new()
                .with_output_directory(temp_dir.path().to_path_buf())
                .with_max_concurrent(5),
        );
        let tool = TaskOutputTool::with_manager(task_manager.clone());
        let context = create_test_context();

        let task_id = task_manager.start("echo hello", &context).await.unwrap();

        // Give the task a moment to run; completion is not asserted, so this is not timing-critical.
        tokio::time::sleep(Duration::from_millis(500)).await;

        let params = serde_json::json!({
            "task_id": task_id
        });

        let result = tool.execute(params, &context).await;
        assert!(result.is_ok());

        let tool_result = result.unwrap();
        assert!(tool_result.success);
        assert!(tool_result.output.as_ref().unwrap().contains(&task_id));
        assert!(tool_result.metadata.contains_key("status"));
    }

    #[tokio::test]
    async fn test_task_output_tool_with_block() {
        let temp_dir = TempDir::new().unwrap();
        let task_manager = Arc::new(
            TaskManager::new()
                .with_output_directory(temp_dir.path().to_path_buf())
                .with_max_concurrent(5),
        );
        let tool = TaskOutputTool::with_manager(task_manager.clone());
        let context = create_test_context();

        let task_id = task_manager
            .start("echo blocking test", &context)
            .await
            .unwrap();

        let params = serde_json::json!({
            "task_id": task_id,
            "block": true,
            "timeout": 2000
        });

        let result = tool.execute(params, &context).await;
        assert!(result.is_ok());

        let tool_result = result.unwrap();
        assert!(tool_result.success);
        let output = tool_result.output.as_ref().unwrap();
        // Either the echoed text was captured or the status reports completion.
        assert!(output.contains("blocking test") || output.contains("已完成"));
    }

    #[tokio::test]
    async fn test_task_output_tool_invalid_params() {
        let tool = TaskOutputTool::new();
        let context = create_test_context();

        // Missing the required `task_id`.
        let params = serde_json::json!({
            "invalid": "params"
        });

        let result = tool.execute(params, &context).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_task_output_tool_check_permissions() {
        let tool = TaskOutputTool::new();
        let context = create_test_context();
        let params = serde_json::json!({"task_id": "test"});

        let result = tool.check_permissions(&params, &context).await;
        assert!(result.is_allowed());
    }
}