// File: aster/tools/kill_shell_tool.rs
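//! Kill Shell Tool Implementation
//!
//! This module implements `KillShellTool`, which terminates running background tasks:
//! - Terminates a specific task identified by its task_id
//! - Integrates with `TaskManager`
//! - Provides a safe task-termination mechanism
//! - Is compatible with the Claude Agent SDK's KillShellTool interface
//!
//! Requirements: based on the KillShellTool implementation in the Claude Agent SDK's `bash.ts`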
10
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use std::sync::Arc;
14
15use super::base::{PermissionCheckResult, Tool};
16use super::context::{ToolContext, ToolOptions, ToolResult};
17use super::error::ToolError;
18use super::task::TaskManager;
19
20/// KillShell 工具输入参数
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct KillShellInput {
23    /// 要终止的任务 ID(支持 shell_id 和 task_id 两种格式)
24    #[serde(alias = "task_id")]
25    pub shell_id: String,
26}
27
28/// Kill Shell Tool for terminating background tasks
29///
30/// 提供安全的后台任务终止功能:
31/// - 通过 task_id 或 shell_id 终止任务
32/// - 与现有 TaskManager 集成
33/// - 支持向后兼容的参数名称
34/// - 提供详细的终止状态反馈
35#[derive(Debug)]
36pub struct KillShellTool {
37    /// Task manager for background task management
38    task_manager: Arc<TaskManager>,
39}
40
41impl Default for KillShellTool {
42    fn default() -> Self {
43        Self::new()
44    }
45}
46
47impl KillShellTool {
48    /// Create a new KillShellTool with default TaskManager
49    pub fn new() -> Self {
50        Self {
51            task_manager: Arc::new(TaskManager::new()),
52        }
53    }
54
55    /// Create a KillShellTool with custom TaskManager
56    pub fn with_task_manager(task_manager: Arc<TaskManager>) -> Self {
57        Self { task_manager }
58    }
59
60    /// Get the task manager
61    pub fn task_manager(&self) -> &Arc<TaskManager> {
62        &self.task_manager
63    }
64}
65
66#[async_trait]
67impl Tool for KillShellTool {
68    /// Returns the tool name
69    fn name(&self) -> &str {
70        "KillShell"
71    }
72
73    /// Returns the tool description
74    fn description(&self) -> &str {
75        "Kills a running background bash shell by its ID. \
76         Takes a shell_id parameter identifying the shell to kill. \
77         Returns a success or failure status. \
78         Use this tool when you need to terminate a long-running shell. \
79         Shell IDs can be found using the TaskOutput tool or from background task execution results."
80    }
81
82    /// Returns the JSON Schema for input parameters
83    fn input_schema(&self) -> serde_json::Value {
84        serde_json::json!({
85            "type": "object",
86            "properties": {
87                "shell_id": {
88                    "type": "string",
89                    "description": "The ID of the background shell/task to kill"
90                }
91            },
92            "required": ["shell_id"]
93        })
94    }
95
96    /// Execute the kill shell command
97    async fn execute(
98        &self,
99        params: serde_json::Value,
100        _context: &ToolContext,
101    ) -> Result<ToolResult, ToolError> {
102        // Extract shell_id parameter (also accept task_id for compatibility)
103        let shell_id = params
104            .get("shell_id")
105            .or_else(|| params.get("task_id"))
106            .and_then(|v| v.as_str())
107            .ok_or_else(|| ToolError::invalid_params("Missing required parameter: shell_id"))?;
108
109        // Attempt to kill the task
110        match self.task_manager.kill(shell_id).await {
111            Ok(()) => {
112                let success_message = format!("Successfully killed shell: {}", shell_id);
113                Ok(ToolResult::success(success_message)
114                    .with_metadata("shell_id", serde_json::json!(shell_id))
115                    .with_metadata("killed", serde_json::json!(true)))
116            }
117            Err(ToolError::NotFound(_)) => {
118                let error_message = format!("No shell found with ID: {}", shell_id);
119                Ok(ToolResult::error(error_message)
120                    .with_metadata("shell_id", serde_json::json!(shell_id))
121                    .with_metadata("killed", serde_json::json!(false)))
122            }
123            Err(e) => {
124                let error_message = format!("Failed to kill shell {}: {}", shell_id, e);
125                Ok(ToolResult::error(error_message)
126                    .with_metadata("shell_id", serde_json::json!(shell_id))
127                    .with_metadata("killed", serde_json::json!(false)))
128            }
129        }
130    }
131
132    /// Check permissions before execution
133    async fn check_permissions(
134        &self,
135        params: &serde_json::Value,
136        _context: &ToolContext,
137    ) -> PermissionCheckResult {
138        // Extract shell_id for validation
139        let shell_id = match params
140            .get("shell_id")
141            .or_else(|| params.get("task_id"))
142            .and_then(|v| v.as_str())
143        {
144            Some(id) => id,
145            None => return PermissionCheckResult::deny("Missing shell_id parameter"),
146        };
147
148        // Basic validation - ensure shell_id is not empty
149        if shell_id.trim().is_empty() {
150            return PermissionCheckResult::deny("shell_id cannot be empty");
151        }
152
153        // Allow the operation - killing tasks is generally safe
154        PermissionCheckResult::allow()
155    }
156
157    /// Get tool options
158    fn options(&self) -> ToolOptions {
159        ToolOptions::new()
160            .with_max_retries(0) // Don't retry kill operations
161            .with_base_timeout(std::time::Duration::from_secs(10)) // Quick timeout for kill operations
162            .with_dynamic_timeout(false)
163    }
164}
165
// =============================================================================
// Unit Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use tempfile::TempDir;

    fn create_test_context() -> ToolContext {
        ToolContext::new(PathBuf::from("/tmp"))
            .with_session_id("test-session")
            .with_user("test-user")
    }

    /// Builds a `TaskManager` whose output directory is a fresh temporary dir.
    ///
    /// The `TempDir` guard is returned as well: dropping it deletes the
    /// directory, so callers must keep it alive (bind it to `_temp_dir`, not `_`)
    /// for as long as the manager is in use.
    fn create_test_manager() -> (Arc<TaskManager>, TempDir) {
        let temp_dir = TempDir::new().expect("failed to create temp dir for task output");
        let manager =
            Arc::new(TaskManager::new().with_output_directory(temp_dir.path().to_path_buf()));
        (manager, temp_dir)
    }

    #[test]
    fn test_tool_name() {
        let tool = KillShellTool::new();
        assert_eq!(tool.name(), "KillShell");
    }

    #[test]
    fn test_tool_description() {
        let tool = KillShellTool::new();
        assert!(!tool.description().is_empty());
        assert!(tool.description().contains("kill"));
        assert!(tool.description().contains("shell"));
    }

    #[test]
    fn test_tool_input_schema() {
        let tool = KillShellTool::new();
        let schema = tool.input_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["shell_id"].is_object());
        assert!(schema["required"]
            .as_array()
            .unwrap()
            .contains(&serde_json::json!("shell_id")));
    }

    #[test]
    fn test_tool_options() {
        let tool = KillShellTool::new();
        let options = tool.options();
        assert_eq!(options.max_retries, 0);
        assert_eq!(options.base_timeout, std::time::Duration::from_secs(10));
        assert!(!options.enable_dynamic_timeout);
    }

    #[test]
    fn test_builder_with_task_manager() {
        let (task_manager, _temp_dir) = create_test_manager();
        let tool = KillShellTool::with_task_manager(task_manager.clone());
        assert!(Arc::ptr_eq(&tool.task_manager, &task_manager));
    }

    // Permission Check Tests

    #[tokio::test]
    async fn test_check_permissions_valid_shell_id() {
        let tool = KillShellTool::new();
        let context = create_test_context();
        let params = serde_json::json!({"shell_id": "test-task-123"});

        let result = tool.check_permissions(&params, &context).await;
        assert!(result.is_allowed());
    }

    #[tokio::test]
    async fn test_check_permissions_task_id_alias() {
        let tool = KillShellTool::new();
        let context = create_test_context();
        let params = serde_json::json!({"task_id": "test-task-123"});

        let result = tool.check_permissions(&params, &context).await;
        assert!(result.is_allowed());
    }

    #[tokio::test]
    async fn test_check_permissions_missing_shell_id() {
        let tool = KillShellTool::new();
        let context = create_test_context();
        let params = serde_json::json!({});

        let result = tool.check_permissions(&params, &context).await;
        assert!(result.is_denied());
    }

    #[tokio::test]
    async fn test_check_permissions_empty_shell_id() {
        let tool = KillShellTool::new();
        let context = create_test_context();
        let params = serde_json::json!({"shell_id": ""});

        let result = tool.check_permissions(&params, &context).await;
        assert!(result.is_denied());
    }

    // Execution Tests

    #[tokio::test]
    async fn test_execute_nonexistent_task() {
        let (task_manager, _temp_dir) = create_test_manager();
        let tool = KillShellTool::with_task_manager(task_manager);
        let context = create_test_context();
        let params = serde_json::json!({"shell_id": "nonexistent-task"});

        let result = tool.execute(params, &context).await;
        assert!(result.is_ok());
        let tool_result = result.unwrap();
        assert!(tool_result.is_error());
        // The error text lives in the `error` field, not in `output`.
        assert!(tool_result.error.unwrap().contains("No shell found"));
    }

    #[tokio::test]
    async fn test_execute_missing_shell_id() {
        let tool = KillShellTool::new();
        let context = create_test_context();
        let params = serde_json::json!({});

        let result = tool.execute(params, &context).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ToolError::InvalidParams(_)));
    }

    #[tokio::test]
    async fn test_execute_with_task_id_alias() {
        let (task_manager, _temp_dir) = create_test_manager();
        let tool = KillShellTool::with_task_manager(task_manager);
        let context = create_test_context();
        let params = serde_json::json!({"task_id": "nonexistent-task"});

        let result = tool.execute(params, &context).await;
        assert!(result.is_ok());
        let tool_result = result.unwrap();
        assert!(tool_result.is_error());
        // The error text lives in the `error` field, not in `output`.
        assert!(tool_result.error.unwrap().contains("No shell found"));
    }

    #[tokio::test]
    async fn test_execute_kill_running_task() {
        let (task_manager, _temp_dir) = create_test_manager();
        let tool = KillShellTool::with_task_manager(task_manager.clone());
        let context = create_test_context();

        // Start a long-running task
        let command = if cfg!(target_os = "windows") {
            "timeout /t 30"
        } else {
            "sleep 30"
        };
        let task_id = task_manager
            .start(command, &context)
            .await
            .expect("failed to start background task");

        // Kill the task
        let params = serde_json::json!({"shell_id": task_id});
        let result = tool.execute(params, &context).await;

        assert!(result.is_ok());
        let tool_result = result.unwrap();
        assert!(tool_result.is_success());
        assert!(tool_result.output.unwrap().contains("Successfully killed"));
    }
}