Skip to main content

agent_procs/daemon/
process_manager.rs

1use crate::daemon::log_writer::{self, OutputLine};
2use crate::paths;
3use crate::protocol::{ProcessInfo, ProcessState, Response, Stream as ProtoStream};
4use crate::session::IdCounter;
5use std::collections::HashMap;
6use std::process::Stdio;
7use std::time::{Duration, Instant};
8use tokio::process::{Child, Command};
9use tokio::sync::broadcast;
10
/// Size cap in bytes handed to `log_writer::capture_output` for each captured
/// stream (stdout and stderr separately): 50 MiB.
const DEFAULT_MAX_LOG_BYTES: u64 = 50 << 20;
12
13pub struct ManagedProcess {
14    pub name: String,
15    pub id: String,
16    pub command: String,
17    pub cwd: Option<String>,
18    pub env: HashMap<String, String>,
19    pub child: Option<Child>,
20    pub pid: u32,
21    pub started_at: Instant,
22    pub exit_code: Option<i32>,
23}
24
25pub struct ProcessManager {
26    processes: HashMap<String, ManagedProcess>,
27    id_counter: IdCounter,
28    session: String,
29    pub output_tx: broadcast::Sender<OutputLine>,
30}
31
32impl ProcessManager {
33    pub fn new(session: &str) -> Self {
34        let (output_tx, _) = broadcast::channel(1024);
35        Self {
36            processes: HashMap::new(),
37            id_counter: IdCounter::new(),
38            session: session.to_string(),
39            output_tx,
40        }
41    }
42
43    pub async fn spawn_process(
44        &mut self,
45        command: &str,
46        name: Option<String>,
47        cwd: Option<&str>,
48        env: Option<&HashMap<String, String>>,
49    ) -> Response {
50        let id = self.id_counter.next_id();
51        let name = name.unwrap_or_else(|| id.clone());
52
53        // Reject names that could cause path traversal in log files
54        if name.contains('/') || name.contains('\\') || name.contains("..") || name.contains('\0') {
55            return Response::Error {
56                code: 1,
57                message: format!("invalid process name: {}", name),
58            };
59        }
60
61        if self.processes.contains_key(&name) {
62            return Response::Error {
63                code: 1,
64                message: format!("process already exists: {}", name),
65            };
66        }
67
68        let log_dir = paths::log_dir(&self.session);
69        let _ = std::fs::create_dir_all(&log_dir);
70
71        let mut cmd = Command::new("sh");
72        cmd.arg("-c")
73            .arg(command)
74            .stdout(Stdio::piped())
75            .stderr(Stdio::piped());
76        if let Some(dir) = cwd {
77            cmd.current_dir(dir);
78        }
79        if let Some(env_vars) = env {
80            cmd.envs(env_vars);
81        }
82        // Put child in its own process group so we can signal the entire tree
83        unsafe {
84            cmd.pre_exec(|| {
85                nix::unistd::setpgid(nix::unistd::Pid::from_raw(0), nix::unistd::Pid::from_raw(0))
86                    .map_err(std::io::Error::other)?;
87                Ok(())
88            });
89        }
90
91        let mut child = match cmd.spawn() {
92            Ok(c) => c,
93            Err(e) => {
94                return Response::Error {
95                    code: 1,
96                    message: format!("failed to spawn: {}", e),
97                }
98            }
99        };
100
101        let pid = child.id().unwrap_or(0);
102
103        // Spawn output capture tasks via log_writer
104        if let Some(stdout) = child.stdout.take() {
105            let tx = self.output_tx.clone();
106            let pname = name.clone();
107            let path = log_dir.join(format!("{}.stdout", name));
108            tokio::spawn(async move {
109                log_writer::capture_output(
110                    stdout,
111                    &path,
112                    &pname,
113                    ProtoStream::Stdout,
114                    tx,
115                    DEFAULT_MAX_LOG_BYTES,
116                )
117                .await;
118            });
119        }
120        if let Some(stderr) = child.stderr.take() {
121            let tx = self.output_tx.clone();
122            let pname = name.clone();
123            let path = log_dir.join(format!("{}.stderr", name));
124            tokio::spawn(async move {
125                log_writer::capture_output(
126                    stderr,
127                    &path,
128                    &pname,
129                    ProtoStream::Stderr,
130                    tx,
131                    DEFAULT_MAX_LOG_BYTES,
132                )
133                .await;
134            });
135        }
136
137        self.processes.insert(
138            name.clone(),
139            ManagedProcess {
140                name: name.clone(),
141                id: id.clone(),
142                command: command.to_string(),
143                cwd: cwd.map(|s| s.to_string()),
144                env: env.cloned().unwrap_or_default(),
145                child: Some(child),
146                pid,
147                started_at: Instant::now(),
148                exit_code: None,
149            },
150        );
151
152        Response::RunOk { name, id, pid }
153    }
154
155    pub async fn stop_process(&mut self, target: &str) -> Response {
156        let proc = match self.find_mut(target) {
157            Some(p) => p,
158            None => {
159                return Response::Error {
160                    code: 2,
161                    message: format!("process not found: {}", target),
162                }
163            }
164        };
165
166        if let Some(ref child) = proc.child {
167            let raw_pid = child.id().unwrap_or(0) as i32;
168            if raw_pid > 0 {
169                // Signal the entire process group (child PID == PGID due to setpgid in pre_exec)
170                let pgid = nix::unistd::Pid::from_raw(raw_pid);
171                let _ = nix::sys::signal::killpg(pgid, nix::sys::signal::Signal::SIGTERM);
172            }
173        }
174
175        // Wait up to 10s for graceful exit, then SIGKILL
176        if let Some(ref mut child) = proc.child {
177            let wait_result = tokio::time::timeout(Duration::from_secs(10), child.wait()).await;
178
179            match wait_result {
180                Ok(Ok(status)) => {
181                    proc.exit_code = status.code();
182                }
183                _ => {
184                    // Timed out or error — force kill the process group
185                    let raw_pid = proc.pid as i32;
186                    if raw_pid > 0 {
187                        let pgid = nix::unistd::Pid::from_raw(raw_pid);
188                        let _ = nix::sys::signal::killpg(pgid, nix::sys::signal::Signal::SIGKILL);
189                    }
190                    let _ = child.wait().await;
191                    proc.exit_code = Some(-9);
192                }
193            }
194            proc.child = None;
195        }
196
197        Response::Ok {
198            message: format!("stopped {}", target),
199        }
200    }
201
202    pub async fn stop_all(&mut self) -> Response {
203        let names: Vec<String> = self.processes.keys().cloned().collect();
204        for name in names {
205            self.stop_process(&name).await;
206        }
207        self.processes.clear();
208        Response::Ok {
209            message: "all processes stopped".into(),
210        }
211    }
212
213    pub async fn restart_process(&mut self, target: &str) -> Response {
214        let (command, name, cwd, env) = match self.find(target) {
215            Some(p) => (
216                p.command.clone(),
217                p.name.clone(),
218                p.cwd.clone(),
219                p.env.clone(),
220            ),
221            None => {
222                return Response::Error {
223                    code: 2,
224                    message: format!("process not found: {}", target),
225                }
226            }
227        };
228        self.stop_process(target).await;
229        self.processes.remove(&name);
230        let env = if env.is_empty() { None } else { Some(env) };
231        self.spawn_process(&command, Some(name), cwd.as_deref(), env.as_ref())
232            .await
233    }
234
235    pub fn status(&mut self) -> Response {
236        self.refresh_exit_states();
237        let mut infos: Vec<ProcessInfo> = self
238            .processes
239            .values()
240            .map(|p| ProcessInfo {
241                name: p.name.clone(),
242                id: p.id.clone(),
243                pid: p.pid,
244                state: if p.child.is_some() {
245                    ProcessState::Running
246                } else {
247                    ProcessState::Exited
248                },
249                exit_code: p.exit_code,
250                uptime_secs: if p.child.is_some() {
251                    Some(p.started_at.elapsed().as_secs())
252                } else {
253                    None
254                },
255                command: p.command.clone(),
256            })
257            .collect();
258        infos.sort_by(|a, b| a.name.cmp(&b.name));
259        Response::Status { processes: infos }
260    }
261
262    /// Returns `None` if process not found or still running.
263    /// Returns `Some(exit_code)` if process has exited (exit_code is None for signal kills).
264    pub fn is_process_exited(&mut self, target: &str) -> Option<Option<i32>> {
265        self.refresh_exit_states();
266        self.find(target).and_then(|p| {
267            if p.child.is_none() {
268                Some(p.exit_code)
269            } else {
270                None
271            }
272        })
273    }
274
275    fn refresh_exit_states(&mut self) {
276        for proc in self.processes.values_mut() {
277            if proc.child.is_some() && proc.exit_code.is_none() {
278                if let Some(ref mut child) = proc.child {
279                    if let Ok(Some(status)) = child.try_wait() {
280                        proc.exit_code = status.code();
281                        proc.child = None;
282                    }
283                }
284            }
285        }
286    }
287
288    pub fn has_process(&self, target: &str) -> bool {
289        self.find(target).is_some()
290    }
291
292    fn find(&self, target: &str) -> Option<&ManagedProcess> {
293        self.processes
294            .get(target)
295            .or_else(|| self.processes.values().find(|p| p.id == target))
296    }
297
298    fn find_mut(&mut self, target: &str) -> Option<&mut ManagedProcess> {
299        if self.processes.contains_key(target) {
300            self.processes.get_mut(target)
301        } else {
302            self.processes.values_mut().find(|p| p.id == target)
303        }
304    }
305}