1use super::base::{PermissionCheckResult, Tool};
6use super::context::{ToolContext, ToolResult};
7use super::error::ToolError;
8use super::task::TaskManager;
9use async_trait::async_trait;
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12use std::time::Duration;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct TaskInput {
17 pub command: String,
19 pub description: Option<String>,
21 pub run_in_background: Option<bool>,
23}
24
25pub struct TaskTool {
29 task_manager: Arc<TaskManager>,
31}
32
33impl TaskTool {
34 pub fn new() -> Self {
36 Self {
37 task_manager: Arc::new(TaskManager::new()),
38 }
39 }
40
41 pub fn with_manager(task_manager: Arc<TaskManager>) -> Self {
43 Self { task_manager }
44 }
45}
46
47impl Default for TaskTool {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53#[async_trait]
54impl Tool for TaskTool {
55 fn name(&self) -> &str {
56 "Task"
57 }
58
59 fn description(&self) -> &str {
60 r#"启动后台任务执行命令
61
62用于启动长时间运行的命令或需要并行执行的任务。支持:
63- 后台执行命令
64- 任务状态跟踪
65- 输出文件持久化
66- 并发任务限制
67
68参数:
69- command: 要执行的命令
70- description: 任务描述(可选)
71- run_in_background: 是否后台运行(默认 true)
72
73返回任务 ID,可用于后续查询任务状态和输出。"#
74 }
75
76 fn input_schema(&self) -> serde_json::Value {
77 serde_json::json!({
78 "type": "object",
79 "properties": {
80 "command": {
81 "type": "string",
82 "description": "要执行的命令"
83 },
84 "description": {
85 "type": "string",
86 "description": "任务描述(可选)"
87 },
88 "run_in_background": {
89 "type": "boolean",
90 "description": "是否在后台运行(默认 true)"
91 }
92 },
93 "required": ["command"]
94 })
95 }
96
97 async fn execute(
98 &self,
99 params: serde_json::Value,
100 context: &ToolContext,
101 ) -> Result<ToolResult, ToolError> {
102 let input: TaskInput = serde_json::from_value(params)
103 .map_err(|e| ToolError::invalid_params(format!("参数解析失败: {}", e)))?;
104
105 let run_in_background = input.run_in_background.unwrap_or(true);
106
107 let task_id = self.task_manager.start(&input.command, context).await?;
109
110 let description = input.description.unwrap_or_else(|| {
111 let cmd = &input.command;
113 if cmd.chars().count() > 50 {
114 let truncated: String = cmd.chars().take(47).collect();
115 format!("{}...", truncated)
116 } else {
117 cmd.to_string()
118 }
119 });
120
121 if run_in_background {
122 Ok(ToolResult::success(format!(
124 "任务已启动(后台运行)\n任务 ID: {}\n描述: {}\n命令: {}\n\n使用 TaskOutput 工具查询任务状态和输出。",
125 task_id, description, input.command
126 )).with_metadata("task_id", serde_json::json!(task_id)))
127 } else {
128 let timeout = Duration::from_secs(30);
131 let start_time = std::time::Instant::now();
132
133 loop {
134 if let Some(state) = self.task_manager.get_status(&task_id).await {
135 if state.status.is_terminal() {
136 let output = self
138 .task_manager
139 .get_output(&task_id, None)
140 .await
141 .unwrap_or_else(|_| "无法获取任务输出".to_string());
142
143 let duration = state.duration().as_secs_f64();
144
145 return Ok(ToolResult::success(format!(
146 "任务已完成\n任务 ID: {}\n描述: {}\n状态: {}\n执行时间: {:.2}秒\n\n=== 输出 ===\n{}",
147 task_id, description, state.status, duration, output
148 )).with_metadata("task_id", serde_json::json!(task_id))
149 .with_metadata("status", serde_json::json!(state.status.to_string()))
150 .with_metadata("duration", serde_json::json!(duration)));
151 }
152 }
153
154 if start_time.elapsed() > timeout {
156 return Ok(ToolResult::success(format!(
157 "任务启动成功但执行时间超过 {}秒,已转为后台运行\n任务 ID: {}\n描述: {}\n\n使用 TaskOutput 工具查询任务状态和输出。",
158 timeout.as_secs(), task_id, description
159 )).with_metadata("task_id", serde_json::json!(task_id)));
160 }
161
162 tokio::time::sleep(Duration::from_millis(100)).await;
164 }
165 }
166 }
167
168 async fn check_permissions(
169 &self,
170 _params: &serde_json::Value,
171 _context: &ToolContext,
172 ) -> PermissionCheckResult {
173 PermissionCheckResult::ask("执行后台任务")
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use tempfile::TempDir;

    /// Builds a minimal context rooted at `/tmp` with fixed session and user identifiers.
    fn create_test_context() -> ToolContext {
        ToolContext::new(PathBuf::from("/tmp"))
            .with_session_id("test-session")
            .with_user("test-user")
    }

    #[tokio::test]
    async fn test_task_tool_new() {
        let tool = TaskTool::new();
        assert_eq!(tool.name(), "Task");
    }

    #[tokio::test]
    async fn test_task_tool_input_schema() {
        let tool = TaskTool::new();
        let schema = tool.input_schema();

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["command"].is_object());
        assert_eq!(schema["required"], serde_json::json!(["command"]));
    }

    #[tokio::test]
    async fn test_task_tool_execute_background() {
        // Keep task output files inside a throwaway directory.
        let temp_dir = TempDir::new().unwrap();
        let task_manager = Arc::new(
            TaskManager::new()
                .with_output_directory(temp_dir.path().to_path_buf())
                .with_max_concurrent(5),
        );
        let tool = TaskTool::with_manager(task_manager);
        let context = create_test_context();

        let params = serde_json::json!({
            "command": "echo hello",
            "description": "测试任务",
            "run_in_background": true
        });

        let result = tool.execute(params, &context).await;
        assert!(result.is_ok());

        let tool_result = result.unwrap();
        assert!(tool_result.success);
        assert!(tool_result.output.as_ref().unwrap().contains("任务已启动"));
        assert!(tool_result.metadata.contains_key("task_id"));
    }

    #[tokio::test]
    async fn test_task_tool_execute_foreground() {
        let temp_dir = TempDir::new().unwrap();
        let task_manager = Arc::new(
            TaskManager::new()
                .with_output_directory(temp_dir.path().to_path_buf())
                .with_max_concurrent(5),
        );
        let tool = TaskTool::with_manager(task_manager);
        let context = create_test_context();

        let params = serde_json::json!({
            "command": "echo hello world",
            "run_in_background": false
        });

        let result = tool.execute(params, &context).await;
        assert!(result.is_ok());

        let tool_result = result.unwrap();
        assert!(tool_result.success);
        let output = tool_result.output.as_ref().unwrap();
        // How output is captured is up to TaskManager, so accept either the echoed text or
        // the completion header.
        assert!(output.contains("hello world") || output.contains("任务已完成"));
    }

    #[tokio::test]
    async fn test_task_tool_invalid_params() {
        let tool = TaskTool::new();
        let context = create_test_context();

        // Missing the required `command` field.
        let params = serde_json::json!({
            "invalid": "params"
        });

        let result = tool.execute(params, &context).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_task_tool_check_permissions() {
        let tool = TaskTool::new();
        let context = create_test_context();
        let params = serde_json::json!({"command": "echo test"});

        let result = tool.check_permissions(&params, &context).await;
        assert!(result.requires_confirmation());
    }
}