1#![allow(missing_docs)]
2
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::process::Stdio;
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7
8use tokio::io::AsyncReadExt;
9use tokio::process::Command;
10use tokio::time::timeout;
11use tracing::info;
12
13use crate::error::Result;
14use crate::path_utils::{is_forbidden_command, is_safe_command, validate_path};
15use crate::types::{
16 CommandHistoryEntry, CommandResult, FileOperation, FileOperationType, ShellConfig,
17};
18
19pub struct ShellService {
20 config: ShellConfig,
21 current_directory: PathBuf,
22 command_history: HashMap<String, Vec<CommandHistoryEntry>>,
23 max_history_per_conversation: usize,
24}
25
26impl ShellService {
27 pub fn new(config: ShellConfig) -> Self {
28 let current_directory = config.allowed_directory.clone();
29 info!("Shell service initialized with history tracking");
30
31 Self {
32 config,
33 current_directory,
34 command_history: HashMap::new(),
35 max_history_per_conversation: 100,
36 }
37 }
38
39 pub fn current_directory(&self) -> &Path {
40 &self.current_directory
41 }
42
43 pub fn allowed_directory(&self) -> &Path {
44 &self.config.allowed_directory
45 }
46
47 pub async fn execute_command(
48 &mut self,
49 command: &str,
50 conversation_id: Option<&str>,
51 ) -> Result<CommandResult> {
52 if !self.config.enabled {
53 return Ok(CommandResult::error(
54 "Shell plugin disabled",
55 "Shell plugin is disabled.",
56 &self.current_directory.display().to_string(),
57 ));
58 }
59
60 let trimmed_command = command.trim();
61 if trimmed_command.is_empty() {
62 return Ok(CommandResult::error(
63 "Invalid command",
64 "Command must be a non-empty string",
65 &self.current_directory.display().to_string(),
66 ));
67 }
68
69 if !is_safe_command(trimmed_command) {
70 return Ok(CommandResult::error(
71 "Security policy violation",
72 "Command contains forbidden patterns",
73 &self.current_directory.display().to_string(),
74 ));
75 }
76
77 if is_forbidden_command(trimmed_command, &self.config.forbidden_commands) {
78 return Ok(CommandResult::error(
79 "Forbidden command",
80 "Command is forbidden by security policy",
81 &self.current_directory.display().to_string(),
82 ));
83 }
84
85 if trimmed_command.starts_with("cd ") {
86 let result = self.handle_cd_command(trimmed_command);
87 if let Some(conv_id) = conversation_id {
88 self.add_to_history(conv_id, trimmed_command, &result, None);
89 }
90 return Ok(result);
91 }
92
93 let result = self.run_command(trimmed_command).await?;
94
95 if let Some(conv_id) = conversation_id {
96 let file_ops = if result.success {
97 self.detect_file_operations(trimmed_command)
98 } else {
99 None
100 };
101 self.add_to_history(conv_id, trimmed_command, &result, file_ops);
102 }
103
104 Ok(result)
105 }
106
107 fn handle_cd_command(&mut self, command: &str) -> CommandResult {
108 let parts: Vec<&str> = command.split_whitespace().collect();
109
110 if parts.len() < 2 {
111 self.current_directory = self.config.allowed_directory.clone();
112 return CommandResult::success(
113 format!("Changed directory to: {}", self.current_directory.display()),
114 &self.current_directory.display().to_string(),
115 );
116 }
117
118 let target_path = parts[1..].join(" ");
119 let validated = validate_path(
120 &target_path,
121 &self.config.allowed_directory,
122 &self.current_directory,
123 );
124
125 match validated {
126 Some(path) => {
127 self.current_directory = path;
128 CommandResult::success(
129 format!("Changed directory to: {}", self.current_directory.display()),
130 &self.current_directory.display().to_string(),
131 )
132 }
133 None => CommandResult::error(
134 "Permission denied",
135 "Cannot navigate outside allowed directory",
136 &self.current_directory.display().to_string(),
137 ),
138 }
139 }
140
141 async fn run_command(&self, command: &str) -> Result<CommandResult> {
143 let cwd = self.current_directory.display().to_string();
144 let use_shell = command.contains('>') || command.contains('<') || command.contains('|');
145
146 let mut cmd = if use_shell {
147 info!("Executing shell command: sh -c \"{}\" in {}", command, cwd);
148 let mut c = Command::new("sh");
149 c.args(["-c", command]);
150 c
151 } else {
152 let parts: Vec<&str> = command.split_whitespace().collect();
153 if parts.is_empty() {
154 return Ok(CommandResult::error(
155 "Invalid command",
156 "Empty command",
157 &cwd,
158 ));
159 }
160 info!("Executing command: {} in {}", command, cwd);
161 let mut c = Command::new(parts[0]);
162 if parts.len() > 1 {
163 c.args(&parts[1..]);
164 }
165 c
166 };
167
168 cmd.current_dir(&self.current_directory)
169 .stdout(Stdio::piped())
170 .stderr(Stdio::piped());
171
172 let timeout_duration = Duration::from_millis(self.config.timeout_ms);
173 let spawn_result = cmd.spawn();
174
175 match spawn_result {
176 Ok(mut child) => {
177 let stdout_handle = child.stdout.take();
178 let stderr_handle = child.stderr.take();
179
180 match timeout(timeout_duration, child.wait()).await {
181 Ok(Ok(status)) => {
182 let mut stdout = String::new();
183 let mut stderr = String::new();
184
185 if let Some(mut handle) = stdout_handle {
186 let _ = handle.read_to_string(&mut stdout).await;
187 }
188 if let Some(mut handle) = stderr_handle {
189 let _ = handle.read_to_string(&mut stderr).await;
190 }
191
192 Ok(CommandResult {
193 success: status.success(),
194 stdout,
195 stderr,
196 exit_code: status.code(),
197 error: None,
198 executed_in: cwd,
199 })
200 }
201 Ok(Err(e)) => Ok(CommandResult::error(
202 "Failed to execute command",
203 &e.to_string(),
204 &cwd,
205 )),
206 Err(_) => {
207 let _ = child.kill().await;
208 Ok(CommandResult {
209 success: false,
210 stdout: String::new(),
211 stderr: "Command timed out".to_string(),
212 exit_code: None,
213 error: Some("Command execution timeout".to_string()),
214 executed_in: cwd,
215 })
216 }
217 }
218 }
219 Err(e) => Ok(CommandResult::error(
220 "Failed to execute command",
221 &e.to_string(),
222 &cwd,
223 )),
224 }
225 }
226
227 fn add_to_history(
228 &mut self,
229 conversation_id: &str,
230 command: &str,
231 result: &CommandResult,
232 file_operations: Option<Vec<FileOperation>>,
233 ) {
234 let timestamp = SystemTime::now()
235 .duration_since(UNIX_EPOCH)
236 .map(|d| d.as_secs_f64())
237 .unwrap_or(0.0);
238
239 let entry = CommandHistoryEntry {
240 command: command.to_string(),
241 stdout: result.stdout.clone(),
242 stderr: result.stderr.clone(),
243 exit_code: result.exit_code,
244 timestamp,
245 working_directory: result.executed_in.clone(),
246 file_operations,
247 };
248
249 let history = self
250 .command_history
251 .entry(conversation_id.to_string())
252 .or_default();
253
254 history.push(entry);
255
256 if history.len() > self.max_history_per_conversation {
257 history.remove(0);
258 }
259 }
260
261 fn detect_file_operations(&self, command: &str) -> Option<Vec<FileOperation>> {
262 let parts: Vec<&str> = command.split_whitespace().collect();
263 if parts.is_empty() {
264 return None;
265 }
266
267 let cmd = parts[0].to_lowercase();
268 let cwd = &self.current_directory;
269 let mut operations = Vec::new();
270
271 let resolve_path = |path: &str| -> String {
272 if Path::new(path).is_absolute() {
273 path.to_string()
274 } else {
275 cwd.join(path).display().to_string()
276 }
277 };
278
279 match cmd.as_str() {
280 "touch" if parts.len() > 1 => {
281 operations.push(FileOperation {
282 op_type: FileOperationType::Create,
283 target: resolve_path(parts[1]),
284 secondary_target: None,
285 });
286 }
287 "echo" if command.contains('>') => {
288 if let Some(pos) = command.rfind('>') {
289 let target = command[pos + 1..].trim();
290 if !target.is_empty() {
291 let target = target.split_whitespace().next().unwrap_or(target);
292 operations.push(FileOperation {
293 op_type: FileOperationType::Write,
294 target: resolve_path(target),
295 secondary_target: None,
296 });
297 }
298 }
299 }
300 "mkdir" if parts.len() > 1 => {
301 operations.push(FileOperation {
302 op_type: FileOperationType::Mkdir,
303 target: resolve_path(parts[1]),
304 secondary_target: None,
305 });
306 }
307 "cat" if parts.len() > 1 && !command.contains('>') => {
308 operations.push(FileOperation {
309 op_type: FileOperationType::Read,
310 target: resolve_path(parts[1]),
311 secondary_target: None,
312 });
313 }
314 "mv" if parts.len() > 2 => {
315 operations.push(FileOperation {
316 op_type: FileOperationType::Move,
317 target: resolve_path(parts[1]),
318 secondary_target: Some(resolve_path(parts[2])),
319 });
320 }
321 "cp" if parts.len() > 2 => {
322 operations.push(FileOperation {
323 op_type: FileOperationType::Copy,
324 target: resolve_path(parts[1]),
325 secondary_target: Some(resolve_path(parts[2])),
326 });
327 }
328 _ => {}
329 }
330
331 if operations.is_empty() {
332 None
333 } else {
334 Some(operations)
335 }
336 }
337
338 pub fn get_command_history(
339 &self,
340 conversation_id: &str,
341 limit: Option<usize>,
342 ) -> Vec<CommandHistoryEntry> {
343 let history = self
344 .command_history
345 .get(conversation_id)
346 .cloned()
347 .unwrap_or_default();
348
349 match limit {
350 Some(n) if n > 0 => history.into_iter().rev().take(n).rev().collect(),
351 _ => history,
352 }
353 }
354
355 pub fn clear_command_history(&mut self, conversation_id: &str) {
356 self.command_history.remove(conversation_id);
357 info!(
358 "Cleared command history for conversation: {}",
359 conversation_id
360 );
361 }
362
363 pub fn get_current_directory(&self, _conversation_id: Option<&str>) -> &Path {
364 &self.current_directory
365 }
366
367 pub fn get_allowed_directory(&self) -> &Path {
368 &self.config.allowed_directory
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Enabled configuration rooted in a fresh temporary directory.
    ///
    /// The directory is released via `keep()`, so it outlives the `TempDir`
    /// guard and remains valid for the whole test. `rm` and `rmdir` are
    /// forbidden.
    fn test_config() -> ShellConfig {
        let sandbox = tempdir().unwrap().keep();
        ShellConfig {
            enabled: true,
            allowed_directory: sandbox,
            timeout_ms: 30000,
            forbidden_commands: ["rm", "rmdir"].iter().map(|name| name.to_string()).collect(),
        }
    }

    #[tokio::test]
    async fn test_disabled_shell() {
        let disabled = ShellConfig {
            enabled: false,
            ..test_config()
        };
        let mut shell = ShellService::new(disabled);

        let outcome = shell.execute_command("ls", None).await.unwrap();

        assert!(!outcome.success);
        assert!(outcome.stderr.contains("disabled"));
    }

    #[tokio::test]
    async fn test_forbidden_command() {
        let mut shell = ShellService::new(test_config());

        let outcome = shell.execute_command("rm file.txt", None).await.unwrap();

        assert!(!outcome.success);
        assert!(outcome.stderr.contains("forbidden"));
    }

    #[tokio::test]
    async fn test_history_tracking() {
        const CONVERSATION: &str = "test-conv";
        let mut shell = ShellService::new(test_config());

        shell
            .execute_command("echo hello", Some(CONVERSATION))
            .await
            .unwrap();

        let recorded = shell.get_command_history(CONVERSATION, None);
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].command, "echo hello");
    }

    #[tokio::test]
    async fn test_clear_history() {
        const CONVERSATION: &str = "test-conv";
        let mut shell = ShellService::new(test_config());

        shell
            .execute_command("echo test", Some(CONVERSATION))
            .await
            .unwrap();
        assert_eq!(shell.get_command_history(CONVERSATION, None).len(), 1);

        shell.clear_command_history(CONVERSATION);
        assert!(shell.get_command_history(CONVERSATION, None).is_empty());
    }
}