Skip to main content

oxi/
bash_executor.rs

//! Bash executor that emulates a persistent shell session.
//!
//! Each command runs in a fresh shell process; the executor carries the working
//! directory (following a leading `cd`) and a set of environment variables forward
//! from one command to the next.
5
6use std::collections::HashMap;
7use std::path::PathBuf;
8use std::process::{Command, Stdio};
9use std::sync::{Arc, RwLock};
10use std::time::Duration;
11
/// Outcome of running one command through the executor.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BashResult {
    /// The command exactly as the caller passed it in.
    pub command: String,
    /// Captured standard output (possibly truncated to the configured size limit).
    pub stdout: String,
    /// Captured standard error; when the command could not be run or timed out,
    /// a short description of that failure instead.
    pub stderr: String,
    /// Exit status of the shell process. `None` when the process ended without an
    /// exit code (e.g. terminated by a signal on Unix); `Some(-1)` is used when the
    /// command could not be spawned/waited on or was killed for exceeding the timeout.
    pub exit_code: Option<i32>,
    /// `true` if the command exceeded the configured timeout and was killed.
    pub timed_out: bool,
    /// Wall-clock execution time, in milliseconds.
    pub duration_ms: u64,
}
28
/// Settings used to construct an executor.
#[derive(Debug, Clone)]
pub struct BashExecutorConfig {
    /// Path of the shell binary; each command is run as `<shell> -c <command>`.
    pub shell: String,
    /// Working directory the executor starts in.
    pub cwd: PathBuf,
    /// Extra environment variables applied to every command.
    pub env: HashMap<String, String>,
    /// Upper bound on a single command's run time; `None` disables the limit.
    pub timeout: Option<Duration>,
    /// Output size limit in bytes, applied separately to stdout and stderr.
    pub max_output_size: usize,
}

impl Default for BashExecutorConfig {
    /// `/bin/bash`, the process's current directory (or `.` if it cannot be read),
    /// no extra variables, a five-minute timeout and a 10 MiB output limit.
    fn default() -> Self {
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);
        const TEN_MIB: usize = 10 << 20;

        let cwd = match std::env::current_dir() {
            Ok(dir) => dir,
            Err(_) => PathBuf::from("."),
        };

        BashExecutorConfig {
            shell: String::from("/bin/bash"),
            cwd,
            env: HashMap::default(),
            timeout: Some(FIVE_MINUTES),
            max_output_size: TEN_MIB,
        }
    }
}
55
56/// A persistent bash executor
57pub struct BashExecutor {
58    config: BashExecutorConfig,
59    /// Current working directory
60    cwd: RwLock<PathBuf>,
61    /// Environment variables
62    env: RwLock<HashMap<String, String>>,
63    /// Command history
64    history: RwLock<Vec<String>>,
65}
66
67impl BashExecutor {
68    /// Create a new bash executor
69    pub fn new(config: BashExecutorConfig) -> Self {
70        let cwd = RwLock::new(config.cwd.clone());
71        let env = RwLock::new(config.env.clone());
72        Self {
73            config,
74            cwd,
75            env,
76            history: RwLock::new(Vec::new()),
77        }
78    }
79
80    /// Create with default configuration
81    pub fn default() -> Self {
82        Self::new(BashExecutorConfig::default())
83    }
84
85    /// Get the current working directory
86    pub fn cwd(&self) -> PathBuf {
87        self.cwd.read().unwrap().clone()
88    }
89
90    /// Get a copy of environment variables
91    pub fn env(&self) -> HashMap<String, String> {
92        self.env.read().unwrap().clone()
93    }
94
95    /// Get command history
96    pub fn history(&self) -> Vec<String> {
97        self.history.read().unwrap().clone()
98    }
99
100    /// Set working directory
101    pub fn set_cwd(&self, path: PathBuf) {
102        if path.exists() && path.is_dir() {
103            *self.cwd.write().unwrap() = path;
104        }
105    }
106
107    /// Set an environment variable
108    pub fn set_env(&self, key: &str, value: &str) {
109        self.env.write().unwrap().insert(key.to_string(), value.to_string());
110    }
111
112    /// Remove an environment variable
113    pub fn remove_env(&self, key: &str) {
114        self.env.write().unwrap().remove(key);
115    }
116
117    /// Execute a command and return the result
118    pub fn execute(&self, command: &str) -> BashResult {
119        let start = std::time::Instant::now();
120
121        // Build the command — wrap with cwd tracking
122        // We prefix with a pwd capture so we can track directory changes
123        let wrapped = format!(
124            "{}; __oxi_cwd=$(pwd)",
125            command
126        );
127
128        let mut cmd = Command::new(&self.config.shell);
129        cmd.arg("-c")
130            .arg(&wrapped)
131            .current_dir(self.cwd.read().unwrap().as_path())
132            .stdout(Stdio::piped())
133            .stderr(Stdio::piped());
134
135        // Set custom environment variables
136        let env = self.env.read().unwrap();
137        for (key, value) in env.iter() {
138            cmd.env(key, value);
139        }
140
141        // Execute with timeout
142        let output_result = match self.config.timeout {
143            Some(t) => {
144                match cmd.spawn() {
145                    Ok(mut child) => {
146                        let deadline = std::time::Instant::now() + t;
147                        loop {
148                            match child.try_wait() {
149                                Ok(Some(_status)) => {
150                                    break child.wait_with_output();
151                                }
152                                Ok(None) => {
153                                    if std::time::Instant::now() >= deadline {
154                                        let _ = child.kill();
155                                        let _ = child.wait();
156                                        let duration_ms = start.elapsed().as_millis() as u64;
157                                        self.history.write().unwrap().push(command.to_string());
158                                        return BashResult {
159                                            command: command.to_string(),
160                                            stdout: String::new(),
161                                            stderr: "Command timed out".to_string(),
162                                            exit_code: Some(-1),
163                                            timed_out: true,
164                                            duration_ms,
165                                        };
166                                    }
167                                    std::thread::sleep(Duration::from_millis(50));
168                                }
169                                Err(e) => {
170                                    let duration_ms = start.elapsed().as_millis() as u64;
171                                    self.history.write().unwrap().push(command.to_string());
172                                    return BashResult {
173                                        command: command.to_string(),
174                                        stdout: String::new(),
175                                        stderr: format!("Failed to wait: {}", e),
176                                        exit_code: Some(-1),
177                                        timed_out: false,
178                                        duration_ms,
179                                    };
180                                }
181                            }
182                        }
183                    }
184                    Err(e) => {
185                        let duration_ms = start.elapsed().as_millis() as u64;
186                        self.history.write().unwrap().push(command.to_string());
187                        return BashResult {
188                            command: command.to_string(),
189                            stdout: String::new(),
190                            stderr: format!("Failed to spawn: {}", e),
191                            exit_code: Some(-1),
192                            timed_out: false,
193                            duration_ms,
194                        };
195                    }
196                }
197            }
198            None => cmd.output(),
199        };
200
201        let duration_ms = start.elapsed().as_millis() as u64;
202
203        let output = match output_result {
204            Ok(o) => o,
205            Err(e) => {
206                self.history.write().unwrap().push(command.to_string());
207                return BashResult {
208                    command: command.to_string(),
209                    stdout: String::new(),
210                    stderr: format!("Failed to execute: {}", e),
211                    exit_code: Some(-1),
212                    timed_out: false,
213                    duration_ms,
214                };
215            }
216        };
217
218        let stdout = self.truncate_output(String::from_utf8_lossy(&output.stdout).to_string());
219        let stderr = self.truncate_output(String::from_utf8_lossy(&output.stderr).to_string());
220        let exit_code = output.status.code();
221
222        // Track cd commands — re-resolve cwd after cd
223        if command.trim().starts_with("cd ") && exit_code == Some(0) {
224            let target = command.trim().strip_prefix("cd ").unwrap().trim();
225            let target = if target.starts_with("~/") {
226                format!("{}/{}", dirs::home_dir().map(|p| p.display().to_string()).unwrap_or_default(), &target[2..])
227            } else {
228                target.to_string()
229            };
230            let new_cwd = if target.starts_with('/') {
231                PathBuf::from(target)
232            } else {
233                self.cwd.read().unwrap().join(&target)
234            };
235            if new_cwd.is_dir() {
236                *self.cwd.write().unwrap() = new_cwd;
237            }
238        }
239
240        // Add to history
241        self.history.write().unwrap().push(command.to_string());
242
243        BashResult {
244            command: command.to_string(),
245            stdout,
246            stderr,
247            exit_code,
248            timed_out: false,
249            duration_ms,
250        }
251    }
252
253    /// Execute a command with streaming output
254    pub fn execute_streaming<F>(&self, command: &str, mut on_output: F) -> BashResult
255    where
256        F: FnMut(&str),
257    {
258        let start = std::time::Instant::now();
259
260        let mut cmd = Command::new(&self.config.shell);
261        cmd.arg("-c")
262            .arg(command)
263            .current_dir(self.cwd.read().unwrap().as_path())
264            .stdout(Stdio::piped())
265            .stderr(Stdio::piped())
266            .env_clear();
267
268        let env = self.env.read().unwrap();
269        for (key, value) in env.iter() {
270            cmd.env(key, value);
271        }
272        if let Ok(path) = std::env::var("PATH") {
273            cmd.env("PATH", path);
274        }
275
276        let mut child = match cmd.spawn() {
277            Ok(c) => c,
278            Err(e) => {
279                return BashResult {
280                    command: command.to_string(),
281                    stdout: String::new(),
282                    stderr: format!("Failed to spawn: {}", e),
283                    exit_code: Some(-1),
284                    timed_out: false,
285                    duration_ms: start.elapsed().as_millis() as u64,
286                };
287            }
288        };
289
290        // Read stdout
291        let mut stdout = String::new();
292        if let Some(ref mut out) = child.stdout {
293            use std::io::BufRead;
294            let reader = std::io::BufReader::new(out);
295            for line in reader.lines() {
296                if let Ok(line) = line {
297                    on_output(&line);
298                    stdout.push_str(&line);
299                    stdout.push('\n');
300                }
301            }
302        }
303
304        // Read stderr
305        let mut stderr = String::new();
306        if let Some(ref mut err) = child.stderr {
307            use std::io::BufRead;
308            let reader = std::io::BufReader::new(err);
309            for line in reader.lines() {
310                if let Ok(line) = line {
311                    stderr.push_str(&line);
312                    stderr.push('\n');
313                }
314            }
315        }
316
317        let status = child.wait().ok();
318        let exit_code = status.and_then(|s| s.code());
319        let duration_ms = start.elapsed().as_millis() as u64;
320
321        // Add to history
322        self.history.write().unwrap().push(command.to_string());
323
324        BashResult {
325            command: command.to_string(),
326            stdout: self.truncate_output(stdout),
327            stderr: self.truncate_output(stderr),
328            exit_code,
329            timed_out: false,
330            duration_ms,
331        }
332    }
333
334    /// Execute multiple commands in sequence
335    pub fn execute_batch(&self, commands: &[&str]) -> Vec<BashResult> {
336        commands.iter().map(|cmd| self.execute(cmd)).collect()
337    }
338
339    /// Execute with error propagation
340    pub fn execute_required(&self, command: &str) -> Result<String, String> {
341        let result = self.execute(command);
342        if result.exit_code == Some(0) {
343            Ok(result.stdout)
344        } else {
345            Err(format!(
346                "Command failed with exit code {:?}: {}\n{}",
347                result.exit_code, result.stderr, result.stdout
348            ))
349        }
350    }
351
352    /// Truncate output to max size
353    fn truncate_output(&self, output: String) -> String {
354        if output.len() > self.config.max_output_size {
355            format!(
356                "{}...\n[Output truncated: {} bytes -> {} bytes]",
357                &output[..self.config.max_output_size / 2],
358                output.len(),
359                self.config.max_output_size
360            )
361        } else {
362            output
363        }
364    }
365}
366
367impl Default for BashExecutor {
368    fn default() -> Self {
369        Self::default()
370    }
371}
372
373/// Create an Arc-wrapped executor for sharing
374pub fn create_executor(config: BashExecutorConfig) -> Arc<BashExecutor> {
375    Arc::new(BashExecutor::new(config))
376}
377
378/// Execute a single command without persistent state
379pub fn execute_once(command: &str) -> BashResult {
380    let executor = BashExecutor::default();
381    executor.execute(command)
382}
383
384/// Execute a command with a specific timeout
385pub fn execute_with_timeout(command: &str, timeout: Duration) -> BashResult {
386    let config = BashExecutorConfig {
387        timeout: Some(timeout),
388        ..Default::default()
389    };
390    let executor = BashExecutor::new(config);
391    executor.execute(command)
392}
393
/// Check whether `command` is something `sh` can run, as reported by `command -v`
/// (an executable on `PATH`, or a shell builtin/keyword).
///
/// The name is handed to `sh` as the positional argument `$1` instead of being spliced
/// into the script, so shell metacharacters in `command` are never interpreted. An empty
/// name returns `false`. Any failure to run `sh` itself also yields `false`.
pub fn command_exists(command: &str) -> bool {
    if command.is_empty() {
        return false;
    }
    Command::new("sh")
        .arg("-c")
        .arg(r#"command -v "$1" > /dev/null 2>&1"#)
        .arg("sh") // becomes $0 inside the script
        .arg(command) // becomes $1
        .stdin(Stdio::null())
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .status()
        .map(|status| status.success())
        .unwrap_or(false)
}
403
/// Return the shell named by the `SHELL` environment variable, or `/bin/bash` when it
/// is unset or not valid Unicode.
pub fn get_shell() -> String {
    match std::env::var("SHELL") {
        Ok(shell) => shell,
        Err(_) => String::from("/bin/bash"),
    }
}
408
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_execute_simple() {
        let executor = BashExecutor::default();
        let result = executor.execute("echo hello");
        assert_eq!(result.exit_code, Some(0));
        assert!(result.stdout.contains("hello"));
    }

    #[test]
    fn test_execute_with_cd() {
        let executor = BashExecutor::default();
        let result = executor.execute("cd /");
        assert_eq!(result.exit_code, Some(0));
        assert_eq!(executor.cwd(), PathBuf::from("/"));

        // The next command must start where the previous `cd` left off.
        let result = executor.execute("pwd");
        assert_eq!(result.exit_code, Some(0));
        assert_eq!(result.stdout.trim_end(), "/");
    }

    #[test]
    fn test_execute_failed_command() {
        let executor = BashExecutor::default();
        let result = executor.execute("exit 1");
        assert_eq!(result.exit_code, Some(1));
    }

    #[test]
    fn test_execute_nonexistent_command() {
        let executor = BashExecutor::default();
        let result = executor.execute("nonexistent_command_12345");
        let failed = matches!(result.exit_code, Some(code) if code != 0);
        assert!(failed || result.stderr.contains("not found"));
    }

    #[test]
    fn test_execute_batch() {
        let executor = BashExecutor::default();
        let results = executor.execute_batch(&["echo one", "echo two", "echo three"]);
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.exit_code == Some(0)));
    }

    #[test]
    fn test_cwd_tracking() {
        let executor = BashExecutor::default();
        let initial_cwd = executor.cwd();
        assert!(initial_cwd.exists());
    }

    #[test]
    fn test_env_tracking() {
        let executor = BashExecutor::default();
        executor.set_env("TEST_VAR", "test_value");
        let env = executor.env();
        assert_eq!(env.get("TEST_VAR"), Some(&"test_value".to_string()));

        // Tracked variables must actually reach the spawned shell.
        let result = executor.execute("echo \"$TEST_VAR\"");
        assert_eq!(result.stdout.trim_end(), "test_value");

        executor.remove_env("TEST_VAR");
        assert!(!executor.env().contains_key("TEST_VAR"));
    }

    #[test]
    fn test_history() {
        let executor = BashExecutor::default();
        executor.execute("echo 1");
        executor.execute("echo 2");
        let history = executor.history();
        assert_eq!(history, vec!["echo 1".to_string(), "echo 2".to_string()]);
    }

    #[test]
    fn test_execute_required_success() {
        let executor = BashExecutor::default();
        let result = executor.execute_required("echo hello");
        assert!(result.is_ok());
    }

    #[test]
    fn test_execute_required_failure() {
        let executor = BashExecutor::default();
        let result = executor.execute_required("exit 1");
        assert!(result.is_err());
    }

    #[test]
    fn test_command_exists() {
        assert!(command_exists("echo"));
        assert!(command_exists("ls"));
        assert!(!command_exists("nonexistent_command_xyz"));
    }

    #[test]
    fn test_get_shell() {
        // Depends on the environment, so only check consistency with $SHELL.
        let shell = get_shell();
        assert!(!shell.is_empty());
        match std::env::var("SHELL") {
            Ok(expected) => assert_eq!(shell, expected),
            Err(_) => assert_eq!(shell, "/bin/bash"),
        }
    }

    #[test]
    fn test_execute_with_timeout() {
        let result = execute_with_timeout("echo hello", Duration::from_secs(5));
        assert_eq!(result.exit_code, Some(0));
        assert!(!result.timed_out);
    }

    #[test]
    fn test_execute_long_output() {
        // A tiny limit makes truncation certain and keeps the test fast; with no timeout
        // the output is captured through `Command::output`.
        let executor = BashExecutor::new(BashExecutorConfig {
            timeout: None,
            max_output_size: 1000,
            ..Default::default()
        });
        // `yes | head -n 100000` writes 200_000 bytes ("y\n" per line).
        let result = executor.execute("yes | head -n 100000");
        assert_eq!(result.exit_code, Some(0));
        assert!(result
            .stdout
            .contains("[Output truncated: 200000 bytes -> 1000 bytes]"));
        assert!(result.stdout.len() < 1000);
    }

    #[test]
    fn test_execute_streaming() {
        let executor = BashExecutor::default();
        let mut output_lines = Vec::new();
        let result = executor.execute_streaming("echo line1; echo line2; echo line3", |line| {
            output_lines.push(line.to_string());
        });
        assert_eq!(output_lines, vec!["line1", "line2", "line3"]);
        assert_eq!(result.stdout, "line1\nline2\nline3\n");
        assert_eq!(result.exit_code, Some(0));
    }
}