Skip to main content

agent_procs/daemon/
process_manager.rs

1use crate::daemon::log_writer::{self, OutputLine};
2use crate::error::ProxyError;
3use crate::paths;
4use crate::protocol::{ProcessInfo, ProcessState, Response, Stream as ProtoStream};
5use crate::session::IdCounter;
6use std::collections::HashMap;
7use std::process::Stdio;
8use std::time::{Duration, Instant};
9use tokio::process::{Child, Command};
10use tokio::sync::broadcast;
11
/// Size limit (50 MiB) passed to `log_writer::capture_output` for each captured stream.
const DEFAULT_MAX_LOG_BYTES: u64 = 50 << 20;
/// Lower bound (inclusive) of the range scanned when auto-assigning ports.
const AUTO_PORT_MIN: u16 = 4_000;
/// Upper bound (inclusive) of the range scanned when auto-assigning ports.
const AUTO_PORT_MAX: u16 = 4_999;
15
/// Checks whether `name` can be used as a single DNS label.
///
/// A valid label is 1 to 63 bytes long, made only of `a-z`, `0-9` and `-`,
/// and neither begins nor ends with `-`. Uppercase letters, dots, underscores
/// and any non-ASCII character make the label invalid.
#[must_use]
pub fn is_valid_dns_label(name: &str) -> bool {
    let bytes = name.as_bytes();
    match (bytes.first(), bytes.last()) {
        (Some(&head), Some(&tail)) => {
            bytes.len() <= 63
                && head != b'-'
                && tail != b'-'
                && bytes
                    .iter()
                    .all(|&b| matches!(b, b'a'..=b'z' | b'0'..=b'9' | b'-'))
        }
        // Only the empty string has no first/last byte.
        _ => false,
    }
}
29
/// Bookkeeping for one child process started by [`ProcessManager::spawn_process`].
pub struct ManagedProcess {
    /// Unique key in the manager's map; also the stem of the `.stdout`/`.stderr` log files.
    pub name: String,
    /// Identifier from the session's `IdCounter`; lookups accept it in place of `name`.
    pub id: String,
    /// Command line as given by the caller; executed via `sh -c`.
    pub command: String,
    /// Working-directory override, if any.
    pub cwd: Option<String>,
    /// Caller-supplied environment only; the injected `PORT`/`HOST` are not stored here.
    pub env: HashMap<String, String>,
    /// Child handle: `Some` while the process is considered running, `None` once its
    /// exit has been observed or it has been stopped.
    pub child: Option<Child>,
    /// OS PID captured at spawn time (0 if it was unavailable). The child is made a
    /// process-group leader via `setpgid(0, 0)`, so this is also its PGID.
    pub pid: u32,
    /// When the process was spawned; used for uptime reporting.
    pub started_at: Instant,
    /// Exit status once known: `None` while running or when terminated by a signal;
    /// `Some(-9)` marks a force-kill by `ProcessManager::stop_process`.
    pub exit_code: Option<i32>,
    /// Port exported to the child as `PORT` (explicit or auto-assigned), if any.
    pub port: Option<u16>,
}
42
/// Tracks all processes of one session plus the broadcast of their output.
pub struct ProcessManager {
    /// Managed processes keyed by name.
    processes: HashMap<String, ManagedProcess>,
    /// Source of process ids (an id also serves as the default name).
    id_counter: IdCounter,
    /// Session name; selects the log directory via `paths::log_dir`.
    session: String,
    /// Broadcast of captured stdout/stderr lines from every process (capacity 1024).
    pub output_tx: broadcast::Sender<OutputLine>,
    /// Set by `enable_proxy`; when true, names must be DNS labels and processes
    /// without an explicit port get one auto-assigned.
    proxy_enabled: bool,
    /// Where the next auto-port scan starts; wraps within `AUTO_PORT_MIN..=AUTO_PORT_MAX`.
    next_auto_port: u16,
}
51
52impl ProcessManager {
53    pub fn new(session: &str) -> Self {
54        let (output_tx, _) = broadcast::channel(1024);
55        Self {
56            processes: HashMap::new(),
57            id_counter: IdCounter::new(),
58            session: session.to_string(),
59            output_tx,
60            proxy_enabled: false,
61            next_auto_port: AUTO_PORT_MIN,
62        }
63    }
64
65    fn auto_assign_port(&mut self) -> Result<u16, ProxyError> {
66        let start = self.next_auto_port;
67        let assigned: std::collections::HashSet<u16> =
68            self.processes.values().filter_map(|p| p.port).collect();
69        let range_size = (AUTO_PORT_MAX - AUTO_PORT_MIN + 1) as usize;
70
71        for i in 0..range_size {
72            let candidate = AUTO_PORT_MIN
73                + (((self.next_auto_port - AUTO_PORT_MIN) as usize + i) % range_size) as u16;
74            if assigned.contains(&candidate) {
75                continue;
76            }
77            // Bind-test: if we can bind, the port is free (listener drops immediately)
78            if std::net::TcpListener::bind(("127.0.0.1", candidate)).is_ok() {
79                self.next_auto_port = if candidate >= AUTO_PORT_MAX {
80                    AUTO_PORT_MIN
81                } else {
82                    candidate + 1
83                };
84                return Ok(candidate);
85            }
86        }
87        Err(ProxyError::NoFreeAutoPort {
88            min: AUTO_PORT_MIN,
89            max: AUTO_PORT_MAX,
90            start,
91        })
92    }
93
94    #[allow(unsafe_code, clippy::unused_async)]
95    pub async fn spawn_process(
96        &mut self,
97        command: &str,
98        name: Option<String>,
99        cwd: Option<&str>,
100        env: Option<&HashMap<String, String>>,
101        port: Option<u16>,
102    ) -> Response {
103        let id = self.id_counter.next_id();
104        let name = name.unwrap_or_else(|| id.clone());
105
106        // Reject names that could cause path traversal in log files
107        if name.contains('/') || name.contains('\\') || name.contains("..") || name.contains('\0') {
108            return Response::Error {
109                code: 1,
110                message: format!("invalid process name: {}", name),
111            };
112        }
113
114        // When proxy is active, names must be valid DNS labels for subdomain routing
115        if self.proxy_enabled && !is_valid_dns_label(&name) {
116            return Response::Error {
117                code: 1,
118                message: format!(
119                    "invalid process name for proxy: '{}' (must be lowercase alphanumeric/hyphens, max 63 chars)",
120                    name
121                ),
122            };
123        }
124
125        // Resolve the port: use explicit port, auto-assign if proxy is enabled, or None
126        let resolved_port = if let Some(p) = port {
127            Some(p)
128        } else if self.proxy_enabled {
129            match self.auto_assign_port() {
130                Ok(p) => Some(p),
131                Err(e) => {
132                    return Response::Error {
133                        code: 1,
134                        message: e.to_string(),
135                    };
136                }
137            }
138        } else {
139            None
140        };
141
142        if self.processes.contains_key(&name) {
143            return Response::Error {
144                code: 1,
145                message: format!("process already exists: {}", name),
146            };
147        }
148
149        let log_dir = paths::log_dir(&self.session);
150        let _ = std::fs::create_dir_all(&log_dir);
151
152        let mut cmd = Command::new("sh");
153        cmd.arg("-c")
154            .arg(command)
155            .stdout(Stdio::piped())
156            .stderr(Stdio::piped());
157        if let Some(dir) = cwd {
158            cmd.current_dir(dir);
159        }
160        if let Some(p) = resolved_port {
161            // Inject PORT and HOST; user-supplied env takes precedence
162            let mut merged_env: HashMap<String, String> = HashMap::new();
163            merged_env.insert("PORT".to_string(), p.to_string());
164            merged_env.insert("HOST".to_string(), "127.0.0.1".to_string());
165            if let Some(env_vars) = env {
166                for (k, v) in env_vars {
167                    merged_env.insert(k.clone(), v.clone());
168                }
169            }
170            cmd.envs(&merged_env);
171        } else if let Some(env_vars) = env {
172            cmd.envs(env_vars);
173        }
174        // SAFETY: `setpgid(0, 0)` creates a new process group with the child as
175        // leader.  This must happen before exec so that all grandchildren inherit
176        // the group.  The parent uses this PGID to signal the entire tree on stop.
177        unsafe {
178            cmd.pre_exec(|| {
179                nix::unistd::setpgid(nix::unistd::Pid::from_raw(0), nix::unistd::Pid::from_raw(0))
180                    .map_err(std::io::Error::other)?;
181                Ok(())
182            });
183        }
184
185        let mut child = match cmd.spawn() {
186            Ok(c) => c,
187            Err(e) => {
188                return Response::Error {
189                    code: 1,
190                    message: format!("failed to spawn: {}", e),
191                };
192            }
193        };
194
195        let pid = child.id().unwrap_or(0);
196
197        // Spawn output capture tasks via log_writer
198        if let Some(stdout) = child.stdout.take() {
199            let tx = self.output_tx.clone();
200            let pname = name.clone();
201            let path = log_dir.join(format!("{}.stdout", name));
202            tokio::spawn(async move {
203                log_writer::capture_output(
204                    stdout,
205                    &path,
206                    &pname,
207                    ProtoStream::Stdout,
208                    tx,
209                    DEFAULT_MAX_LOG_BYTES,
210                )
211                .await;
212            });
213        }
214        if let Some(stderr) = child.stderr.take() {
215            let tx = self.output_tx.clone();
216            let pname = name.clone();
217            let path = log_dir.join(format!("{}.stderr", name));
218            tokio::spawn(async move {
219                log_writer::capture_output(
220                    stderr,
221                    &path,
222                    &pname,
223                    ProtoStream::Stderr,
224                    tx,
225                    DEFAULT_MAX_LOG_BYTES,
226                )
227                .await;
228            });
229        }
230
231        self.processes.insert(
232            name.clone(),
233            ManagedProcess {
234                name: name.clone(),
235                id: id.clone(),
236                command: command.to_string(),
237                cwd: cwd.map(std::string::ToString::to_string),
238                env: env.cloned().unwrap_or_default(),
239                child: Some(child),
240                pid,
241                started_at: Instant::now(),
242                exit_code: None,
243                port: resolved_port,
244            },
245        );
246
247        let url = resolved_port.map(|p| format!("http://127.0.0.1:{}", p));
248        Response::RunOk {
249            name,
250            id,
251            pid,
252            port: resolved_port,
253            url,
254        }
255    }
256
257    pub async fn stop_process(&mut self, target: &str) -> Response {
258        let proc = match self.find_mut(target) {
259            Some(p) => p,
260            None => {
261                return Response::Error {
262                    code: 2,
263                    message: format!("process not found: {}", target),
264                };
265            }
266        };
267
268        if let Some(ref child) = proc.child {
269            let raw_pid = child.id().unwrap_or(0) as i32;
270            if raw_pid > 0 {
271                // Signal the entire process group (child PID == PGID due to setpgid in pre_exec)
272                let pgid = nix::unistd::Pid::from_raw(raw_pid);
273                let _ = nix::sys::signal::killpg(pgid, nix::sys::signal::Signal::SIGTERM);
274            }
275        }
276
277        // Wait up to 10s for graceful exit, then SIGKILL
278        if let Some(ref mut child) = proc.child {
279            let wait_result = tokio::time::timeout(Duration::from_secs(10), child.wait()).await;
280
281            match wait_result {
282                Ok(Ok(status)) => {
283                    proc.exit_code = status.code();
284                }
285                _ => {
286                    // Timed out or error — force kill the process group
287                    let raw_pid = proc.pid as i32;
288                    if raw_pid > 0 {
289                        let pgid = nix::unistd::Pid::from_raw(raw_pid);
290                        let _ = nix::sys::signal::killpg(pgid, nix::sys::signal::Signal::SIGKILL);
291                    }
292                    let _ = child.wait().await;
293                    proc.exit_code = Some(-9);
294                }
295            }
296            proc.child = None;
297        }
298
299        Response::Ok {
300            message: format!("stopped {}", target),
301        }
302    }
303
304    pub async fn stop_all(&mut self) -> Response {
305        let names: Vec<String> = self.processes.keys().cloned().collect();
306        for name in names {
307            let _ = self.stop_process(&name).await;
308        }
309        self.processes.clear();
310        Response::Ok {
311            message: "all processes stopped".into(),
312        }
313    }
314
315    pub async fn restart_process(&mut self, target: &str) -> Response {
316        let (command, name, cwd, env, port) = match self.find(target) {
317            Some(p) => (
318                p.command.clone(),
319                p.name.clone(),
320                p.cwd.clone(),
321                p.env.clone(),
322                p.port,
323            ),
324            None => {
325                return Response::Error {
326                    code: 2,
327                    message: format!("process not found: {}", target),
328                };
329            }
330        };
331        let _ = self.stop_process(target).await;
332        self.processes.remove(&name);
333        let env = if env.is_empty() { None } else { Some(env) };
334        self.spawn_process(&command, Some(name), cwd.as_deref(), env.as_ref(), port)
335            .await
336    }
337
338    pub fn enable_proxy(&mut self) {
339        self.proxy_enabled = true;
340    }
341
342    pub fn status(&mut self) -> Response {
343        self.refresh_exit_states();
344        Response::Status {
345            processes: self.build_process_infos(),
346        }
347    }
348
349    /// Returns `None` if process not found or still running.
350    /// Returns `Some(exit_code)` if process has exited (`exit_code` is None for signal kills).
351    pub fn is_process_exited(&mut self, target: &str) -> Option<Option<i32>> {
352        self.refresh_exit_states();
353        self.find(target).and_then(|p| {
354            if p.child.is_none() {
355                Some(p.exit_code)
356            } else {
357                None
358            }
359        })
360    }
361
362    fn refresh_exit_states(&mut self) {
363        for proc in self.processes.values_mut() {
364            if proc.child.is_some() && proc.exit_code.is_none() {
365                if let Some(ref mut child) = proc.child {
366                    if let Ok(Some(status)) = child.try_wait() {
367                        proc.exit_code = status.code();
368                        proc.child = None;
369                    }
370                }
371            }
372        }
373    }
374
375    pub fn session_name(&self) -> &str {
376        &self.session
377    }
378
379    pub fn has_process(&self, target: &str) -> bool {
380        self.find(target).is_some()
381    }
382
383    pub fn get_process_port(&self, name: &str) -> Option<u16> {
384        self.processes
385            .get(name)
386            .and_then(|p| if p.child.is_some() { p.port } else { None })
387    }
388
389    /// Non-mutating status snapshot for use by the proxy status page.
390    /// May show stale exit states since it skips `refresh_exit_states()`.
391    pub fn status_snapshot(&self) -> Response {
392        Response::Status {
393            processes: self.build_process_infos(),
394        }
395    }
396
397    fn build_process_infos(&self) -> Vec<ProcessInfo> {
398        let mut infos: Vec<ProcessInfo> = self
399            .processes
400            .values()
401            .map(|p| ProcessInfo {
402                name: p.name.clone(),
403                id: p.id.clone(),
404                pid: p.pid,
405                state: if p.child.is_some() {
406                    ProcessState::Running
407                } else {
408                    ProcessState::Exited
409                },
410                exit_code: p.exit_code,
411                uptime_secs: if p.child.is_some() {
412                    Some(p.started_at.elapsed().as_secs())
413                } else {
414                    None
415                },
416                command: p.command.clone(),
417                port: p.port,
418                url: p.port.map(|port| format!("http://127.0.0.1:{}", port)),
419            })
420            .collect();
421        infos.sort_by(|a, b| a.name.cmp(&b.name));
422        infos
423    }
424
425    fn find(&self, target: &str) -> Option<&ManagedProcess> {
426        self.processes
427            .get(target)
428            .or_else(|| self.processes.values().find(|p| p.id == target))
429    }
430
431    fn find_mut(&mut self, target: &str) -> Option<&mut ManagedProcess> {
432        if self.processes.contains_key(target) {
433            self.processes.get_mut(target)
434        } else {
435            self.processes.values_mut().find(|p| p.id == target)
436        }
437    }
438}
439
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_valid_dns_labels() {
        assert!(is_valid_dns_label("api"));
        assert!(is_valid_dns_label("my-app"));
        assert!(is_valid_dns_label("a"));
        assert!(is_valid_dns_label("a1"));
        assert!(is_valid_dns_label("123"));
        // Exactly 63 characters is the longest label accepted.
        assert!(is_valid_dns_label(&"a".repeat(63)));
    }

    #[test]
    fn test_invalid_dns_labels() {
        assert!(!is_valid_dns_label(""));
        assert!(!is_valid_dns_label("-"));
        assert!(!is_valid_dns_label("-start"));
        assert!(!is_valid_dns_label("end-"));
        assert!(!is_valid_dns_label("UPPER"));
        assert!(!is_valid_dns_label("has.dot"));
        assert!(!is_valid_dns_label("has space"));
        assert!(!is_valid_dns_label(&"a".repeat(64))); // > 63 chars
        assert!(!is_valid_dns_label("has_underscore"));
        // Non-ASCII letters are rejected even if lowercase.
        assert!(!is_valid_dns_label("café"));
    }
}