// elizaos_plugin_shell/service.rs

1#![allow(missing_docs)]
2
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::process::Stdio;
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7
8use tokio::io::AsyncReadExt;
9use tokio::process::Command;
10use tokio::time::timeout;
11use tracing::info;
12
13use crate::error::Result;
14use crate::path_utils::{is_forbidden_command, is_safe_command, validate_path};
15use crate::types::{
16    CommandHistoryEntry, CommandResult, FileOperation, FileOperationType, ShellConfig,
17};
18
19pub struct ShellService {
20    config: ShellConfig,
21    current_directory: PathBuf,
22    command_history: HashMap<String, Vec<CommandHistoryEntry>>,
23    max_history_per_conversation: usize,
24}
25
26impl ShellService {
27    pub fn new(config: ShellConfig) -> Self {
28        let current_directory = config.allowed_directory.clone();
29        info!("Shell service initialized with history tracking");
30
31        Self {
32            config,
33            current_directory,
34            command_history: HashMap::new(),
35            max_history_per_conversation: 100,
36        }
37    }
38
39    pub fn current_directory(&self) -> &Path {
40        &self.current_directory
41    }
42
43    pub fn allowed_directory(&self) -> &Path {
44        &self.config.allowed_directory
45    }
46
47    pub async fn execute_command(
48        &mut self,
49        command: &str,
50        conversation_id: Option<&str>,
51    ) -> Result<CommandResult> {
52        if !self.config.enabled {
53            return Ok(CommandResult::error(
54                "Shell plugin disabled",
55                "Shell plugin is disabled.",
56                &self.current_directory.display().to_string(),
57            ));
58        }
59
60        let trimmed_command = command.trim();
61        if trimmed_command.is_empty() {
62            return Ok(CommandResult::error(
63                "Invalid command",
64                "Command must be a non-empty string",
65                &self.current_directory.display().to_string(),
66            ));
67        }
68
69        if !is_safe_command(trimmed_command) {
70            return Ok(CommandResult::error(
71                "Security policy violation",
72                "Command contains forbidden patterns",
73                &self.current_directory.display().to_string(),
74            ));
75        }
76
77        if is_forbidden_command(trimmed_command, &self.config.forbidden_commands) {
78            return Ok(CommandResult::error(
79                "Forbidden command",
80                "Command is forbidden by security policy",
81                &self.current_directory.display().to_string(),
82            ));
83        }
84
85        if trimmed_command.starts_with("cd ") {
86            let result = self.handle_cd_command(trimmed_command);
87            if let Some(conv_id) = conversation_id {
88                self.add_to_history(conv_id, trimmed_command, &result, None);
89            }
90            return Ok(result);
91        }
92
93        let result = self.run_command(trimmed_command).await?;
94
95        if let Some(conv_id) = conversation_id {
96            let file_ops = if result.success {
97                self.detect_file_operations(trimmed_command)
98            } else {
99                None
100            };
101            self.add_to_history(conv_id, trimmed_command, &result, file_ops);
102        }
103
104        Ok(result)
105    }
106
107    fn handle_cd_command(&mut self, command: &str) -> CommandResult {
108        let parts: Vec<&str> = command.split_whitespace().collect();
109
110        if parts.len() < 2 {
111            self.current_directory = self.config.allowed_directory.clone();
112            return CommandResult::success(
113                format!("Changed directory to: {}", self.current_directory.display()),
114                &self.current_directory.display().to_string(),
115            );
116        }
117
118        let target_path = parts[1..].join(" ");
119        let validated = validate_path(
120            &target_path,
121            &self.config.allowed_directory,
122            &self.current_directory,
123        );
124
125        match validated {
126            Some(path) => {
127                self.current_directory = path;
128                CommandResult::success(
129                    format!("Changed directory to: {}", self.current_directory.display()),
130                    &self.current_directory.display().to_string(),
131                )
132            }
133            None => CommandResult::error(
134                "Permission denied",
135                "Cannot navigate outside allowed directory",
136                &self.current_directory.display().to_string(),
137            ),
138        }
139    }
140
141    /// Run a command using tokio process.
142    async fn run_command(&self, command: &str) -> Result<CommandResult> {
143        let cwd = self.current_directory.display().to_string();
144        let use_shell = command.contains('>') || command.contains('<') || command.contains('|');
145
146        let mut cmd = if use_shell {
147            info!("Executing shell command: sh -c \"{}\" in {}", command, cwd);
148            let mut c = Command::new("sh");
149            c.args(["-c", command]);
150            c
151        } else {
152            let parts: Vec<&str> = command.split_whitespace().collect();
153            if parts.is_empty() {
154                return Ok(CommandResult::error(
155                    "Invalid command",
156                    "Empty command",
157                    &cwd,
158                ));
159            }
160            info!("Executing command: {} in {}", command, cwd);
161            let mut c = Command::new(parts[0]);
162            if parts.len() > 1 {
163                c.args(&parts[1..]);
164            }
165            c
166        };
167
168        cmd.current_dir(&self.current_directory)
169            .stdout(Stdio::piped())
170            .stderr(Stdio::piped());
171
172        let timeout_duration = Duration::from_millis(self.config.timeout_ms);
173        let spawn_result = cmd.spawn();
174
175        match spawn_result {
176            Ok(mut child) => {
177                let stdout_handle = child.stdout.take();
178                let stderr_handle = child.stderr.take();
179
180                match timeout(timeout_duration, child.wait()).await {
181                    Ok(Ok(status)) => {
182                        let mut stdout = String::new();
183                        let mut stderr = String::new();
184
185                        if let Some(mut handle) = stdout_handle {
186                            let _ = handle.read_to_string(&mut stdout).await;
187                        }
188                        if let Some(mut handle) = stderr_handle {
189                            let _ = handle.read_to_string(&mut stderr).await;
190                        }
191
192                        Ok(CommandResult {
193                            success: status.success(),
194                            stdout,
195                            stderr,
196                            exit_code: status.code(),
197                            error: None,
198                            executed_in: cwd,
199                        })
200                    }
201                    Ok(Err(e)) => Ok(CommandResult::error(
202                        "Failed to execute command",
203                        &e.to_string(),
204                        &cwd,
205                    )),
206                    Err(_) => {
207                        let _ = child.kill().await;
208                        Ok(CommandResult {
209                            success: false,
210                            stdout: String::new(),
211                            stderr: "Command timed out".to_string(),
212                            exit_code: None,
213                            error: Some("Command execution timeout".to_string()),
214                            executed_in: cwd,
215                        })
216                    }
217                }
218            }
219            Err(e) => Ok(CommandResult::error(
220                "Failed to execute command",
221                &e.to_string(),
222                &cwd,
223            )),
224        }
225    }
226
227    fn add_to_history(
228        &mut self,
229        conversation_id: &str,
230        command: &str,
231        result: &CommandResult,
232        file_operations: Option<Vec<FileOperation>>,
233    ) {
234        let timestamp = SystemTime::now()
235            .duration_since(UNIX_EPOCH)
236            .map(|d| d.as_secs_f64())
237            .unwrap_or(0.0);
238
239        let entry = CommandHistoryEntry {
240            command: command.to_string(),
241            stdout: result.stdout.clone(),
242            stderr: result.stderr.clone(),
243            exit_code: result.exit_code,
244            timestamp,
245            working_directory: result.executed_in.clone(),
246            file_operations,
247        };
248
249        let history = self
250            .command_history
251            .entry(conversation_id.to_string())
252            .or_default();
253
254        history.push(entry);
255
256        if history.len() > self.max_history_per_conversation {
257            history.remove(0);
258        }
259    }
260
261    fn detect_file_operations(&self, command: &str) -> Option<Vec<FileOperation>> {
262        let parts: Vec<&str> = command.split_whitespace().collect();
263        if parts.is_empty() {
264            return None;
265        }
266
267        let cmd = parts[0].to_lowercase();
268        let cwd = &self.current_directory;
269        let mut operations = Vec::new();
270
271        let resolve_path = |path: &str| -> String {
272            if Path::new(path).is_absolute() {
273                path.to_string()
274            } else {
275                cwd.join(path).display().to_string()
276            }
277        };
278
279        match cmd.as_str() {
280            "touch" if parts.len() > 1 => {
281                operations.push(FileOperation {
282                    op_type: FileOperationType::Create,
283                    target: resolve_path(parts[1]),
284                    secondary_target: None,
285                });
286            }
287            "echo" if command.contains('>') => {
288                if let Some(pos) = command.rfind('>') {
289                    let target = command[pos + 1..].trim();
290                    if !target.is_empty() {
291                        let target = target.split_whitespace().next().unwrap_or(target);
292                        operations.push(FileOperation {
293                            op_type: FileOperationType::Write,
294                            target: resolve_path(target),
295                            secondary_target: None,
296                        });
297                    }
298                }
299            }
300            "mkdir" if parts.len() > 1 => {
301                operations.push(FileOperation {
302                    op_type: FileOperationType::Mkdir,
303                    target: resolve_path(parts[1]),
304                    secondary_target: None,
305                });
306            }
307            "cat" if parts.len() > 1 && !command.contains('>') => {
308                operations.push(FileOperation {
309                    op_type: FileOperationType::Read,
310                    target: resolve_path(parts[1]),
311                    secondary_target: None,
312                });
313            }
314            "mv" if parts.len() > 2 => {
315                operations.push(FileOperation {
316                    op_type: FileOperationType::Move,
317                    target: resolve_path(parts[1]),
318                    secondary_target: Some(resolve_path(parts[2])),
319                });
320            }
321            "cp" if parts.len() > 2 => {
322                operations.push(FileOperation {
323                    op_type: FileOperationType::Copy,
324                    target: resolve_path(parts[1]),
325                    secondary_target: Some(resolve_path(parts[2])),
326                });
327            }
328            _ => {}
329        }
330
331        if operations.is_empty() {
332            None
333        } else {
334            Some(operations)
335        }
336    }
337
338    pub fn get_command_history(
339        &self,
340        conversation_id: &str,
341        limit: Option<usize>,
342    ) -> Vec<CommandHistoryEntry> {
343        let history = self
344            .command_history
345            .get(conversation_id)
346            .cloned()
347            .unwrap_or_default();
348
349        match limit {
350            Some(n) if n > 0 => history.into_iter().rev().take(n).rev().collect(),
351            _ => history,
352        }
353    }
354
355    pub fn clear_command_history(&mut self, conversation_id: &str) {
356        self.command_history.remove(conversation_id);
357        info!(
358            "Cleared command history for conversation: {}",
359            conversation_id
360        );
361    }
362
363    pub fn get_current_directory(&self, _conversation_id: Option<&str>) -> &Path {
364        &self.current_directory
365    }
366
367    pub fn get_allowed_directory(&self) -> &Path {
368        &self.config.allowed_directory
369    }
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Conversation id shared by the history tests.
    const CONVERSATION: &str = "test-conv";

    /// Builds an enabled configuration whose allowed directory is a freshly
    /// created temporary directory, with `rm` and `rmdir` forbidden.
    fn test_config() -> ShellConfig {
        // NOTE(review): `keep()` presumably persists the directory so it
        // outlives the `TempDir` guard — confirm against the tempfile docs.
        let allowed_directory = tempdir().unwrap().keep();
        let forbidden_commands = ["rm", "rmdir"]
            .iter()
            .map(|name| name.to_string())
            .collect();

        ShellConfig {
            enabled: true,
            allowed_directory,
            timeout_ms: 30000,
            forbidden_commands,
        }
    }

    /// A disabled service refuses to run anything.
    #[tokio::test]
    async fn test_disabled_shell() {
        let mut service = ShellService::new(ShellConfig {
            enabled: false,
            ..test_config()
        });

        let outcome = service.execute_command("ls", None).await.unwrap();

        assert!(!outcome.success);
        assert!(outcome.stderr.contains("disabled"));
    }

    /// Commands on the forbidden list are rejected before execution.
    #[tokio::test]
    async fn test_forbidden_command() {
        let mut service = ShellService::new(test_config());

        let outcome = service.execute_command("rm file.txt", None).await.unwrap();

        assert!(!outcome.success);
        assert!(outcome.stderr.contains("forbidden"));
    }

    /// Executing with a conversation id records the command in its history.
    #[tokio::test]
    async fn test_history_tracking() {
        let mut service = ShellService::new(test_config());

        service
            .execute_command("echo hello", Some(CONVERSATION))
            .await
            .unwrap();

        let recorded = service.get_command_history(CONVERSATION, None);
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].command, "echo hello");
    }

    /// Clearing a conversation removes all of its recorded entries.
    #[tokio::test]
    async fn test_clear_history() {
        let mut service = ShellService::new(test_config());

        service
            .execute_command("echo test", Some(CONVERSATION))
            .await
            .unwrap();
        assert_eq!(service.get_command_history(CONVERSATION, None).len(), 1);

        service.clear_command_history(CONVERSATION);
        assert_eq!(service.get_command_history(CONVERSATION, None).len(), 0);
    }
}