claude_agent/tools/
process.rs

1//! Process manager for background shell execution with security hardening.
2
3use std::collections::HashMap;
4use std::path::Path;
5use std::sync::Arc;
6use std::time::Instant;
7
8use tokio::io::AsyncReadExt;
9use tokio::process::{Child, Command};
10use tokio::sync::Mutex;
11
12use crate::security::bash::SanitizedEnv;
13
/// Unique identifier for a managed process.
///
/// [`ProcessManager`] generates one of these when it spawns a process (a random
/// version-4 UUID rendered as a string) and callers pass it back to address that
/// process in later calls.
pub type ProcessId = String;
16
/// Descriptive metadata about a process tracked by [`ProcessManager`].
///
/// The values are captured once, right after the process has been spawned, and are
/// handed out to callers as clones; they are not updated afterwards.
#[derive(Debug, Clone)]
pub struct ProcessInfo {
    /// Unique process identifier (the key used by the manager).
    pub id: ProcessId,
    /// The command string that was executed via `bash -c`.
    pub command: String,
    /// When the process was started (taken immediately after a successful spawn).
    pub started_at: Instant,
    /// OS process ID if available.
    pub pid: Option<u32>,
}
29
/// Maximum number of bytes of captured output retained per process.
///
/// Once a process's buffer grows beyond this size the oldest bytes are discarded,
/// so the most recent output is what survives.
const MAX_OUTPUT_BUFFER_SIZE: usize = 1024 * 1024; // 1MB limit
31
/// Internal bookkeeping for one spawned process.
struct ManagedProcess {
    /// Handle to the child process; its piped stdout/stderr are read through it.
    /// The child is spawned with `kill_on_drop(true)`.
    child: Child,
    /// Metadata handed out to callers.
    info: ProcessInfo,
    /// Output collected so far: stdout and stderr are both appended here (lossily
    /// decoded as UTF-8), capped at `MAX_OUTPUT_BUFFER_SIZE` bytes.
    output_buffer: String,
}
37
/// Manager for background shell processes.
///
/// The process table lives behind an [`Arc`], so cloning the manager yields another
/// handle to the *same* set of tracked processes rather than an independent copy.
#[derive(Clone)]
pub struct ProcessManager {
    /// Tracked processes keyed by their id, guarded by an async mutex.
    processes: Arc<Mutex<HashMap<ProcessId, ManagedProcess>>>,
}
43
44impl ProcessManager {
45    /// Create a new process manager.
46    #[must_use]
47    pub fn new() -> Self {
48        Self {
49            processes: Arc::new(Mutex::new(HashMap::new())),
50        }
51    }
52
53    /// Spawn a new background process.
54    pub async fn spawn(&self, command: &str, working_dir: &Path) -> Result<ProcessId, String> {
55        self.spawn_with_env(command, working_dir, SanitizedEnv::from_current())
56            .await
57    }
58
59    /// Spawn a new background process with custom sanitized environment.
60    pub async fn spawn_with_env(
61        &self,
62        command: &str,
63        working_dir: &Path,
64        env: SanitizedEnv,
65    ) -> Result<ProcessId, String> {
66        let mut cmd = Command::new("bash");
67        cmd.arg("-c").arg(command);
68        cmd.current_dir(working_dir);
69        cmd.env_clear();
70        cmd.envs(env);
71        cmd.stdout(std::process::Stdio::piped());
72        cmd.stderr(std::process::Stdio::piped());
73        // Ensure process is killed when Child is dropped (safety net)
74        cmd.kill_on_drop(true);
75
76        let child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
77
78        let id = uuid::Uuid::new_v4().to_string();
79        let pid = child.id();
80
81        let info = ProcessInfo {
82            id: id.clone(),
83            command: command.to_string(),
84            started_at: Instant::now(),
85            pid,
86        };
87
88        let managed = ManagedProcess {
89            child,
90            info,
91            output_buffer: String::new(),
92        };
93
94        self.processes.lock().await.insert(id.clone(), managed);
95        Ok(id)
96    }
97
98    /// Kill a background process and wait to reap it (prevents zombie).
99    pub async fn kill(&self, id: &ProcessId) -> Result<(), String> {
100        let mut processes = self.processes.lock().await;
101
102        if let Some(mut proc) = processes.remove(id) {
103            proc.child
104                .kill()
105                .await
106                .map_err(|e| format!("Failed to kill: {}", e))?;
107            // Wait to reap the process and prevent zombie
108            let _ = proc.child.wait().await;
109            Ok(())
110        } else {
111            Err(format!("Process '{}' not found", id))
112        }
113    }
114
115    /// Get output from a background process (non-blocking read of available output).
116    pub async fn get_output(&self, id: &ProcessId) -> Result<String, String> {
117        let mut processes = self.processes.lock().await;
118
119        let proc = processes
120            .get_mut(id)
121            .ok_or_else(|| format!("Process '{}' not found", id))?;
122
123        // Try to read available stdout
124        if let Some(ref mut stdout) = proc.child.stdout {
125            let mut buffer = vec![0u8; 8192];
126            match tokio::time::timeout(
127                std::time::Duration::from_millis(100),
128                stdout.read(&mut buffer),
129            )
130            .await
131            {
132                Ok(Ok(n)) if n > 0 => {
133                    let s = String::from_utf8_lossy(&buffer[..n]);
134                    proc.output_buffer.push_str(&s);
135                }
136                _ => {}
137            }
138        }
139
140        // Try to read available stderr
141        if let Some(ref mut stderr) = proc.child.stderr {
142            let mut buffer = vec![0u8; 8192];
143            match tokio::time::timeout(
144                std::time::Duration::from_millis(100),
145                stderr.read(&mut buffer),
146            )
147            .await
148            {
149                Ok(Ok(n)) if n > 0 => {
150                    let s = String::from_utf8_lossy(&buffer[..n]);
151                    proc.output_buffer.push_str(&s);
152                }
153                _ => {}
154            }
155        }
156
157        // Truncate buffer if it exceeds the limit (keep the most recent data)
158        // Uses drain() for in-place removal without new allocation
159        if proc.output_buffer.len() > MAX_OUTPUT_BUFFER_SIZE {
160            let remove_bytes = proc.output_buffer.len() - MAX_OUTPUT_BUFFER_SIZE;
161            // Find safe UTF-8 character boundary
162            let boundary = proc
163                .output_buffer
164                .char_indices()
165                .find(|(i, _)| *i >= remove_bytes)
166                .map_or(remove_bytes, |(i, _)| i);
167            proc.output_buffer.drain(..boundary);
168        }
169
170        Ok(proc.output_buffer.clone())
171    }
172
173    /// Check if a process is still running.
174    pub async fn is_running(&self, id: &ProcessId) -> bool {
175        let mut processes = self.processes.lock().await;
176
177        if let Some(proc) = processes.get_mut(id) {
178            matches!(proc.child.try_wait(), Ok(None))
179        } else {
180            false
181        }
182    }
183
184    /// List all tracked processes.
185    pub async fn list(&self) -> Vec<ProcessInfo> {
186        self.processes
187            .lock()
188            .await
189            .values()
190            .map(|p| p.info.clone())
191            .collect()
192    }
193
194    /// Clean up finished processes and return their final output.
195    pub async fn cleanup_finished(&self) -> Vec<(ProcessInfo, String)> {
196        let mut processes = self.processes.lock().await;
197        let mut finished = Vec::new();
198
199        let ids: Vec<_> = processes.keys().cloned().collect();
200        for id in ids {
201            if let Some(proc) = processes.get_mut(&id)
202                && let Ok(Some(_status)) = proc.child.try_wait()
203                && let Some(proc) = processes.remove(&id)
204            {
205                finished.push((proc.info, proc.output_buffer));
206            }
207        }
208
209        finished
210    }
211}
212
213impl Default for ProcessManager {
214    fn default() -> Self {
215        Self::new()
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;
    use std::time::{Duration, Instant};

    /// Pause between two polls of the manager.
    const POLL_INTERVAL: Duration = Duration::from_millis(20);
    /// Upper bound on how long a test waits for a child to exit.
    const EXIT_DEADLINE: Duration = Duration::from_secs(5);
    /// Upper bound on how many times a test polls for expected output.
    const OUTPUT_POLLS: usize = 50;

    /// Working directory used by every test command.
    fn work_dir() -> &'static Path {
        Path::new("/tmp")
    }

    /// Poll until the process is no longer reported as running.
    ///
    /// Returns `false` if it is still running once `EXIT_DEADLINE` has passed.
    /// Polling (rather than one fixed sleep) keeps the tests stable on slow machines.
    async fn wait_for_exit(mgr: &ProcessManager, id: &ProcessId) -> bool {
        let give_up = Instant::now() + EXIT_DEADLINE;
        while Instant::now() < give_up {
            if !mgr.is_running(id).await {
                return true;
            }
            tokio::time::sleep(POLL_INTERVAL).await;
        }
        !mgr.is_running(id).await
    }

    /// Poll `get_output` until it contains `needle` or the poll budget is used up.
    ///
    /// Returns the last output seen, so the caller can assert on it.
    async fn wait_for_output(mgr: &ProcessManager, id: &ProcessId, needle: &str) -> String {
        let mut output = String::new();
        for _ in 0..OUTPUT_POLLS {
            output = mgr.get_output(id).await.unwrap();
            if output.contains(needle) {
                break;
            }
            tokio::time::sleep(POLL_INTERVAL).await;
        }
        output
    }

    #[tokio::test]
    async fn test_spawn_and_list() {
        let mgr = ProcessManager::new();
        let id = mgr.spawn("sleep 0.1", work_dir()).await.unwrap();

        let list = mgr.list().await;
        assert_eq!(list.len(), 1);
        assert_eq!(list[0].id, id);

        assert!(wait_for_exit(&mgr, &id).await);
    }

    #[tokio::test]
    async fn test_kill() {
        let mgr = ProcessManager::new();
        let id = mgr.spawn("sleep 10", work_dir()).await.unwrap();

        assert!(mgr.is_running(&id).await);
        mgr.kill(&id).await.unwrap();
        // A killed process is no longer tracked, so it is not reported as running.
        assert!(!mgr.is_running(&id).await);
    }

    #[tokio::test]
    async fn test_get_output() {
        let mgr = ProcessManager::new();
        let id = mgr.spawn("echo hello", work_dir()).await.unwrap();

        let output = wait_for_output(&mgr, &id, "hello").await;
        assert!(output.contains("hello"));
    }

    #[tokio::test]
    async fn test_cleanup_finished() {
        let mgr = ProcessManager::new();
        let id = mgr.spawn("echo done", work_dir()).await.unwrap();

        // Read output into buffer before cleanup
        let output = wait_for_output(&mgr, &id, "done").await;
        assert!(output.contains("done"));
        assert!(wait_for_exit(&mgr, &id).await);

        let finished = mgr.cleanup_finished().await;
        assert_eq!(finished.len(), 1);
        assert!(finished[0].1.contains("done"));
    }

    #[tokio::test]
    async fn test_process_not_found() {
        let mgr = ProcessManager::new();
        let missing: ProcessId = "nonexistent".to_string();

        assert!(mgr.get_output(&missing).await.is_err());
        assert!(mgr.kill(&missing).await.is_err());
    }

    #[tokio::test]
    async fn test_buffer_overflow_keeps_recent_data() {
        let mgr = ProcessManager::new();

        // Generate output larger than MAX_OUTPUT_BUFFER_SIZE (1MB)
        // We generate 1.5MB of data: 1500 lines of 1000 bytes each
        let id = mgr
            .spawn(
                "for i in $(seq 1 1500); do printf 'LINE%04d:%0990d\\n' $i $i; done",
                work_dir(),
            )
            .await
            .unwrap();

        // Keep collecting until the process is done and a good part of its output
        // has been seen (or the iteration budget is exhausted).
        let mut output = String::new();
        for _ in 0..50 {
            tokio::time::sleep(Duration::from_millis(100)).await;
            output = mgr.get_output(&id).await.unwrap();
            if !mgr.is_running(&id).await && output.len() > MAX_OUTPUT_BUFFER_SIZE / 2 {
                break;
            }
        }

        // Buffer should be truncated to ~1MB
        assert!(
            output.len() <= MAX_OUTPUT_BUFFER_SIZE + 4,
            "Buffer should be truncated to MAX_OUTPUT_BUFFER_SIZE, got {}",
            output.len()
        );

        // When buffer overflows, recent data should be preserved
        if output.len() > MAX_OUTPUT_BUFFER_SIZE / 2 {
            // Check that we have some later lines (not necessarily the last one due to timing)
            let has_later_lines = (1000..=1500).any(|n| output.contains(&format!("LINE{:04}", n)));
            assert!(has_later_lines, "Some later data should be preserved");
        }
    }
}