// aster/tools/kill_shell_tool.rs
use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use std::sync::Arc;
14
15use super::base::{PermissionCheckResult, Tool};
16use super::context::{ToolContext, ToolOptions, ToolResult};
17use super::error::ToolError;
18use super::task::TaskManager;
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct KillShellInput {
23 #[serde(alias = "task_id")]
25 pub shell_id: String,
26}
27
28#[derive(Debug)]
36pub struct KillShellTool {
37 task_manager: Arc<TaskManager>,
39}
40
41impl Default for KillShellTool {
42 fn default() -> Self {
43 Self::new()
44 }
45}
46
47impl KillShellTool {
48 pub fn new() -> Self {
50 Self {
51 task_manager: Arc::new(TaskManager::new()),
52 }
53 }
54
55 pub fn with_task_manager(task_manager: Arc<TaskManager>) -> Self {
57 Self { task_manager }
58 }
59
60 pub fn task_manager(&self) -> &Arc<TaskManager> {
62 &self.task_manager
63 }
64}
65
66#[async_trait]
67impl Tool for KillShellTool {
68 fn name(&self) -> &str {
70 "KillShell"
71 }
72
73 fn description(&self) -> &str {
75 "Kills a running background bash shell by its ID. \
76 Takes a shell_id parameter identifying the shell to kill. \
77 Returns a success or failure status. \
78 Use this tool when you need to terminate a long-running shell. \
79 Shell IDs can be found using the TaskOutput tool or from background task execution results."
80 }
81
82 fn input_schema(&self) -> serde_json::Value {
84 serde_json::json!({
85 "type": "object",
86 "properties": {
87 "shell_id": {
88 "type": "string",
89 "description": "The ID of the background shell/task to kill"
90 }
91 },
92 "required": ["shell_id"]
93 })
94 }
95
96 async fn execute(
98 &self,
99 params: serde_json::Value,
100 _context: &ToolContext,
101 ) -> Result<ToolResult, ToolError> {
102 let shell_id = params
104 .get("shell_id")
105 .or_else(|| params.get("task_id"))
106 .and_then(|v| v.as_str())
107 .ok_or_else(|| ToolError::invalid_params("Missing required parameter: shell_id"))?;
108
109 match self.task_manager.kill(shell_id).await {
111 Ok(()) => {
112 let success_message = format!("Successfully killed shell: {}", shell_id);
113 Ok(ToolResult::success(success_message)
114 .with_metadata("shell_id", serde_json::json!(shell_id))
115 .with_metadata("killed", serde_json::json!(true)))
116 }
117 Err(ToolError::NotFound(_)) => {
118 let error_message = format!("No shell found with ID: {}", shell_id);
119 Ok(ToolResult::error(error_message)
120 .with_metadata("shell_id", serde_json::json!(shell_id))
121 .with_metadata("killed", serde_json::json!(false)))
122 }
123 Err(e) => {
124 let error_message = format!("Failed to kill shell {}: {}", shell_id, e);
125 Ok(ToolResult::error(error_message)
126 .with_metadata("shell_id", serde_json::json!(shell_id))
127 .with_metadata("killed", serde_json::json!(false)))
128 }
129 }
130 }
131
132 async fn check_permissions(
134 &self,
135 params: &serde_json::Value,
136 _context: &ToolContext,
137 ) -> PermissionCheckResult {
138 let shell_id = match params
140 .get("shell_id")
141 .or_else(|| params.get("task_id"))
142 .and_then(|v| v.as_str())
143 {
144 Some(id) => id,
145 None => return PermissionCheckResult::deny("Missing shell_id parameter"),
146 };
147
148 if shell_id.trim().is_empty() {
150 return PermissionCheckResult::deny("shell_id cannot be empty");
151 }
152
153 PermissionCheckResult::allow()
155 }
156
157 fn options(&self) -> ToolOptions {
159 ToolOptions::new()
160 .with_max_retries(0) .with_base_timeout(std::time::Duration::from_secs(10)) .with_dynamic_timeout(false)
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use tempfile::TempDir;

    /// Builds a tool context with a fixed working directory, session and user.
    fn create_test_context() -> ToolContext {
        ToolContext::new(PathBuf::from("/tmp"))
            .with_session_id("test-session")
            .with_user("test-user")
    }

    /// Builds a task manager that writes to a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the manager: dropping it
    /// deletes the directory, so callers must keep it alive for as long as
    /// the manager is in use.
    fn create_test_manager() -> (Arc<TaskManager>, TempDir) {
        let temp_dir = TempDir::new().expect("failed to create temporary output directory");
        let manager = TaskManager::new().with_output_directory(temp_dir.path().to_path_buf());
        (Arc::new(manager), temp_dir)
    }

    #[test]
    fn test_tool_name() {
        let tool = KillShellTool::new();
        assert_eq!(tool.name(), "KillShell");
    }

    #[test]
    fn test_tool_description() {
        let tool = KillShellTool::new();
        assert!(!tool.description().is_empty());
        assert!(tool.description().contains("kill"));
        assert!(tool.description().contains("shell"));
    }

    #[test]
    fn test_tool_input_schema() {
        let tool = KillShellTool::new();
        let schema = tool.input_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["shell_id"].is_object());
        assert!(schema["required"]
            .as_array()
            .unwrap()
            .contains(&serde_json::json!("shell_id")));
    }

    #[test]
    fn test_tool_options() {
        let tool = KillShellTool::new();
        let options = tool.options();
        assert_eq!(options.max_retries, 0);
        assert_eq!(options.base_timeout, std::time::Duration::from_secs(10));
        assert!(!options.enable_dynamic_timeout);
    }

    #[test]
    fn test_builder_with_task_manager() {
        let (task_manager, _temp_dir) = create_test_manager();
        let tool = KillShellTool::with_task_manager(task_manager.clone());
        assert!(Arc::ptr_eq(&tool.task_manager, &task_manager));
    }

    #[tokio::test]
    async fn test_check_permissions_valid_shell_id() {
        let tool = KillShellTool::new();
        let context = create_test_context();
        let params = serde_json::json!({"shell_id": "test-task-123"});

        let result = tool.check_permissions(&params, &context).await;
        assert!(result.is_allowed());
    }

    #[tokio::test]
    async fn test_check_permissions_task_id_alias() {
        let tool = KillShellTool::new();
        let context = create_test_context();
        let params = serde_json::json!({"task_id": "test-task-123"});

        let result = tool.check_permissions(&params, &context).await;
        assert!(result.is_allowed());
    }

    #[tokio::test]
    async fn test_check_permissions_missing_shell_id() {
        let tool = KillShellTool::new();
        let context = create_test_context();
        let params = serde_json::json!({});

        let result = tool.check_permissions(&params, &context).await;
        assert!(result.is_denied());
    }

    #[tokio::test]
    async fn test_check_permissions_empty_shell_id() {
        let tool = KillShellTool::new();
        let context = create_test_context();
        let params = serde_json::json!({"shell_id": ""});

        let result = tool.check_permissions(&params, &context).await;
        assert!(result.is_denied());
    }

    #[tokio::test]
    async fn test_execute_nonexistent_task() {
        let (task_manager, _temp_dir) = create_test_manager();
        let tool = KillShellTool::with_task_manager(task_manager);
        let context = create_test_context();
        let params = serde_json::json!({"shell_id": "nonexistent-task"});

        let result = tool.execute(params, &context).await;
        assert!(result.is_ok());
        let tool_result = result.unwrap();
        assert!(tool_result.is_error());
        assert!(tool_result.error.unwrap().contains("No shell found"));
    }

    #[tokio::test]
    async fn test_execute_missing_shell_id() {
        let tool = KillShellTool::new();
        let context = create_test_context();
        let params = serde_json::json!({});

        let result = tool.execute(params, &context).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ToolError::InvalidParams(_)));
    }

    #[tokio::test]
    async fn test_execute_with_task_id_alias() {
        let (task_manager, _temp_dir) = create_test_manager();
        let tool = KillShellTool::with_task_manager(task_manager);
        let context = create_test_context();
        let params = serde_json::json!({"task_id": "nonexistent-task"});

        let result = tool.execute(params, &context).await;
        assert!(result.is_ok());
        let tool_result = result.unwrap();
        assert!(tool_result.is_error());
        assert!(tool_result.error.unwrap().contains("No shell found"));
    }

    #[tokio::test]
    async fn test_execute_kill_running_task() {
        let (task_manager, _temp_dir) = create_test_manager();
        let tool = KillShellTool::with_task_manager(task_manager.clone());
        let context = create_test_context();

        // Start a long-running command so there is a live task to kill.
        let command = if cfg!(target_os = "windows") {
            "timeout /t 30"
        } else {
            "sleep 30"
        };
        let task_id = task_manager.start(command, &context).await.unwrap();

        let params = serde_json::json!({"shell_id": task_id});
        let result = tool.execute(params, &context).await;

        assert!(result.is_ok());
        let tool_result = result.unwrap();
        assert!(tool_result.is_success());
        assert!(tool_result.output.unwrap().contains("Successfully killed"));
    }
}