Skip to main content

agent_procs/daemon/
process_manager.rs

1use crate::daemon::log_writer::{self, OutputLine};
2use crate::daemon::port_allocator::PortAllocator;
3use crate::paths;
4use crate::protocol::{
5    ErrorCode, ProcessInfo, ProcessState, Response, Stream as ProtoStream, process_url,
6};
7use crate::session::IdCounter;
8use std::collections::HashMap;
9use std::process::Stdio;
10use std::time::{Duration, Instant};
11use tokio::process::{Child, Command};
12use tokio::sync::broadcast;
13
/// Size limit, in bytes, passed to `log_writer::capture_output` for each captured
/// stream (stdout and stderr separately) of every spawned process.
///
/// NOTE(review): presumably the threshold at which a log file is rotated —
/// confirm against `log_writer`.
const DEFAULT_MAX_LOG_BYTES: u64 = 50 * 1024 * 1024; // 50MB
15
/// Checks whether `name` can be used as a single DNS label.
///
/// A label is accepted when all of the following hold:
/// - it is between 1 and 63 bytes long,
/// - it neither starts nor ends with `-`,
/// - every byte is an ASCII lowercase letter, an ASCII digit, or `-`.
///
/// Any non-ASCII input is rejected, because its bytes fall outside the allowed set.
#[must_use]
pub fn is_valid_dns_label(name: &str) -> bool {
    let bytes = name.as_bytes();
    let length_ok = (1..=63).contains(&bytes.len());
    let hyphen_at_edge = bytes.first() == Some(&b'-') || bytes.last() == Some(&b'-');

    length_ok
        && !hyphen_at_edge
        && bytes
            .iter()
            .all(|byte| matches!(byte, b'a'..=b'z' | b'0'..=b'9' | b'-'))
}
29
/// Bookkeeping record for one process started by [`ProcessManager`].
pub struct ManagedProcess {
    /// User-facing name; `spawn_process` also uses it as the registry key and as
    /// the stem of the `<name>.stdout` / `<name>.stderr` log files.
    pub name: String,
    /// Identifier taken from the manager's `IdCounter`; accepted as an alternative
    /// lookup key wherever a process name is accepted.
    pub id: String,
    /// Shell command line that was passed to `sh -c`.
    pub command: String,
    /// Working directory requested at spawn time, kept so a restart can reuse it.
    pub cwd: Option<String>,
    /// Caller-supplied environment variables, kept so a restart can reuse them.
    /// Does not include the `PORT`/`HOST` values injected at spawn time.
    pub env: HashMap<String, String>,
    /// Handle to the child; `Some` while the process is considered running and
    /// set to `None` once its exit has been observed.
    pub child: Option<Child>,
    /// OS pid recorded at spawn time; `0` when the pid was not available.
    pub pid: u32,
    /// Moment the process was registered; used to report uptime.
    pub started_at: Instant,
    /// Exit code once known. `None` while running or when the exit status carried
    /// no code; `stop_process` records `Some(-9)` after a forced kill.
    pub exit_code: Option<i32>,
    /// Port associated with the process (explicit or auto-assigned), if any.
    pub port: Option<u16>,
}
42
/// Registry and lifecycle controller for the processes of one session.
pub struct ProcessManager {
    /// Managed processes keyed by process name.
    processes: HashMap<String, ManagedProcess>,
    /// Source of process ids; one id is drawn per `spawn_process` call.
    id_counter: IdCounter,
    /// Session name; determines the log directory via `paths::log_dir`.
    session: String,
    /// Sender half of the output broadcast channel; a clone is handed to each
    /// stdout/stderr capture task.
    pub output_tx: broadcast::Sender<OutputLine>,
    /// Tracks whether the proxy is enabled and auto-assigns ports when it is.
    port_allocator: PortAllocator,
}
50
51impl ProcessManager {
52    pub fn new(session: &str) -> Self {
53        let (output_tx, _) = broadcast::channel(1024);
54        Self {
55            processes: HashMap::new(),
56            id_counter: IdCounter::new(),
57            session: session.to_string(),
58            output_tx,
59            port_allocator: PortAllocator::new(),
60        }
61    }
62
63    #[allow(unsafe_code, clippy::unused_async)]
64    pub async fn spawn_process(
65        &mut self,
66        command: &str,
67        name: Option<String>,
68        cwd: Option<&str>,
69        env: Option<&HashMap<String, String>>,
70        port: Option<u16>,
71    ) -> Response {
72        let id = self.id_counter.next_id();
73        let name = name.unwrap_or_else(|| id.clone());
74
75        // Reject names that could cause path traversal in log files
76        if name.contains('/') || name.contains('\\') || name.contains("..") || name.contains('\0') {
77            return Response::Error {
78                code: ErrorCode::General,
79                message: format!("invalid process name: {}", name),
80            };
81        }
82
83        // When proxy is active, names must be valid DNS labels for subdomain routing
84        if self.port_allocator.is_proxy_enabled() && !is_valid_dns_label(&name) {
85            return Response::Error {
86                code: ErrorCode::General,
87                message: format!(
88                    "invalid process name for proxy: '{}' (must be lowercase alphanumeric/hyphens, max 63 chars)",
89                    name
90                ),
91            };
92        }
93
94        // Resolve the port: use explicit port, auto-assign if proxy is enabled, or None
95        let resolved_port = if let Some(p) = port {
96            Some(p)
97        } else if self.port_allocator.is_proxy_enabled() {
98            let assigned: std::collections::HashSet<u16> =
99                self.processes.values().filter_map(|p| p.port).collect();
100            match self.port_allocator.auto_assign_port(&assigned) {
101                Ok(p) => Some(p),
102                Err(e) => {
103                    return Response::Error {
104                        code: ErrorCode::General,
105                        message: e.to_string(),
106                    };
107                }
108            }
109        } else {
110            None
111        };
112
113        if self.processes.contains_key(&name) {
114            return Response::Error {
115                code: ErrorCode::General,
116                message: format!("process already exists: {}", name),
117            };
118        }
119
120        let log_dir = paths::log_dir(&self.session);
121        let _ = std::fs::create_dir_all(&log_dir);
122
123        let mut cmd = Command::new("sh");
124        cmd.arg("-c")
125            .arg(command)
126            .stdout(Stdio::piped())
127            .stderr(Stdio::piped());
128        if let Some(dir) = cwd {
129            cmd.current_dir(dir);
130        }
131        if let Some(p) = resolved_port {
132            // Inject PORT and HOST; user-supplied env takes precedence
133            let mut merged_env: HashMap<String, String> = HashMap::new();
134            merged_env.insert("PORT".to_string(), p.to_string());
135            merged_env.insert("HOST".to_string(), "127.0.0.1".to_string());
136            if let Some(env_vars) = env {
137                for (k, v) in env_vars {
138                    merged_env.insert(k.clone(), v.clone());
139                }
140            }
141            cmd.envs(&merged_env);
142        } else if let Some(env_vars) = env {
143            cmd.envs(env_vars);
144        }
145        // SAFETY: `setpgid(0, 0)` creates a new process group with the child as
146        // leader.  This must happen before exec so that all grandchildren inherit
147        // the group.  The parent uses this PGID to signal the entire tree on stop.
148        unsafe {
149            cmd.pre_exec(|| {
150                nix::unistd::setpgid(nix::unistd::Pid::from_raw(0), nix::unistd::Pid::from_raw(0))
151                    .map_err(std::io::Error::other)?;
152                Ok(())
153            });
154        }
155
156        let mut child = match cmd.spawn() {
157            Ok(c) => c,
158            Err(e) => {
159                return Response::Error {
160                    code: ErrorCode::General,
161                    message: format!("failed to spawn: {}", e),
162                };
163            }
164        };
165
166        let pid = child.id().unwrap_or(0);
167
168        // Spawn output capture tasks via log_writer
169        if let Some(stdout) = child.stdout.take() {
170            let tx = self.output_tx.clone();
171            let pname = name.clone();
172            let path = log_dir.join(format!("{}.stdout", name));
173            tokio::spawn(async move {
174                log_writer::capture_output(
175                    stdout,
176                    &path,
177                    &pname,
178                    ProtoStream::Stdout,
179                    tx,
180                    DEFAULT_MAX_LOG_BYTES,
181                    log_writer::DEFAULT_MAX_ROTATED_FILES,
182                )
183                .await;
184            });
185        }
186        if let Some(stderr) = child.stderr.take() {
187            let tx = self.output_tx.clone();
188            let pname = name.clone();
189            let path = log_dir.join(format!("{}.stderr", name));
190            tokio::spawn(async move {
191                log_writer::capture_output(
192                    stderr,
193                    &path,
194                    &pname,
195                    ProtoStream::Stderr,
196                    tx,
197                    DEFAULT_MAX_LOG_BYTES,
198                    log_writer::DEFAULT_MAX_ROTATED_FILES,
199                )
200                .await;
201            });
202        }
203
204        self.processes.insert(
205            name.clone(),
206            ManagedProcess {
207                name: name.clone(),
208                id: id.clone(),
209                command: command.to_string(),
210                cwd: cwd.map(std::string::ToString::to_string),
211                env: env.cloned().unwrap_or_default(),
212                child: Some(child),
213                pid,
214                started_at: Instant::now(),
215                exit_code: None,
216                port: resolved_port,
217            },
218        );
219
220        let url = resolved_port.map(|p| process_url(&name, p, None));
221        Response::RunOk {
222            name,
223            id,
224            pid,
225            port: resolved_port,
226            url,
227        }
228    }
229
230    pub async fn stop_process(&mut self, target: &str) -> Response {
231        let proc = match self.find_mut(target) {
232            Some(p) => p,
233            None => {
234                return Response::Error {
235                    code: ErrorCode::NotFound,
236                    message: format!("process not found: {}", target),
237                };
238            }
239        };
240
241        if let Some(ref child) = proc.child {
242            let raw_pid = child.id().unwrap_or(0) as i32;
243            if raw_pid > 0 {
244                // Signal the entire process group (child PID == PGID due to setpgid in pre_exec)
245                let pgid = nix::unistd::Pid::from_raw(raw_pid);
246                let _ = nix::sys::signal::killpg(pgid, nix::sys::signal::Signal::SIGTERM);
247            }
248        }
249
250        // Wait up to 10s for graceful exit, then SIGKILL
251        if let Some(ref mut child) = proc.child {
252            let wait_result = tokio::time::timeout(Duration::from_secs(10), child.wait()).await;
253
254            match wait_result {
255                Ok(Ok(status)) => {
256                    proc.exit_code = status.code();
257                }
258                _ => {
259                    // Timed out or error — force kill the process group
260                    let raw_pid = proc.pid as i32;
261                    if raw_pid > 0 {
262                        let pgid = nix::unistd::Pid::from_raw(raw_pid);
263                        let _ = nix::sys::signal::killpg(pgid, nix::sys::signal::Signal::SIGKILL);
264                    }
265                    let _ = child.wait().await;
266                    proc.exit_code = Some(-9);
267                }
268            }
269            proc.child = None;
270        }
271
272        Response::Ok {
273            message: format!("stopped {}", target),
274        }
275    }
276
277    pub async fn stop_all(&mut self) -> Response {
278        let names: Vec<String> = self.processes.keys().cloned().collect();
279        for name in names {
280            let _ = self.stop_process(&name).await;
281        }
282        self.processes.clear();
283        Response::Ok {
284            message: "all processes stopped".into(),
285        }
286    }
287
288    pub async fn restart_process(&mut self, target: &str) -> Response {
289        let (command, name, cwd, env, port) = match self.find(target) {
290            Some(p) => (
291                p.command.clone(),
292                p.name.clone(),
293                p.cwd.clone(),
294                p.env.clone(),
295                p.port,
296            ),
297            None => {
298                return Response::Error {
299                    code: ErrorCode::NotFound,
300                    message: format!("process not found: {}", target),
301                };
302            }
303        };
304        let _ = self.stop_process(target).await;
305        self.processes.remove(&name);
306        let env = if env.is_empty() { None } else { Some(env) };
307        self.spawn_process(&command, Some(name), cwd.as_deref(), env.as_ref(), port)
308            .await
309    }
310
    /// Turns on proxy mode by delegating to the port allocator.
    ///
    /// NOTE(review): presumably this makes `is_proxy_enabled()` return true, which
    /// in `spawn_process` gates DNS-label name validation and port
    /// auto-assignment — confirm against `PortAllocator`.
    pub fn enable_proxy(&mut self) {
        self.port_allocator.enable_proxy();
    }
314
315    pub fn status(&mut self) -> Response {
316        self.refresh_exit_states();
317        Response::Status {
318            processes: self.build_process_infos(),
319        }
320    }
321
322    /// Returns `None` if process not found or still running.
323    /// Returns `Some(exit_code)` if process has exited (`exit_code` is None for signal kills).
324    pub fn is_process_exited(&mut self, target: &str) -> Option<Option<i32>> {
325        self.refresh_exit_states();
326        self.find(target).and_then(|p| {
327            if p.child.is_none() {
328                Some(p.exit_code)
329            } else {
330                None
331            }
332        })
333    }
334
335    fn refresh_exit_states(&mut self) {
336        for proc in self.processes.values_mut() {
337            if proc.child.is_some() && proc.exit_code.is_none() {
338                if let Some(ref mut child) = proc.child {
339                    if let Ok(Some(status)) = child.try_wait() {
340                        proc.exit_code = status.code();
341                        proc.child = None;
342                    }
343                }
344            }
345        }
346    }
347
    /// Returns the session name this manager was created with.
    pub fn session_name(&self) -> &str {
        &self.session
    }
351
    /// Returns `true` when `target` matches a registered process by name or by id,
    /// regardless of whether that process is still running.
    pub fn has_process(&self, target: &str) -> bool {
        self.find(target).is_some()
    }
355
356    /// Returns a map of running process names to their assigned ports.
357    pub fn running_ports(&self) -> HashMap<String, u16> {
358        self.processes
359            .iter()
360            .filter_map(|(name, p)| {
361                if p.child.is_some() {
362                    p.port.map(|port| (name.clone(), port))
363                } else {
364                    None
365                }
366            })
367            .collect()
368    }
369
370    /// Non-mutating status snapshot for use by the proxy status page.
371    /// May show stale exit states since it skips `refresh_exit_states()`.
372    pub fn status_snapshot(&self) -> Response {
373        Response::Status {
374            processes: self.build_process_infos(),
375        }
376    }
377
378    fn build_process_infos(&self) -> Vec<ProcessInfo> {
379        let mut infos: Vec<ProcessInfo> = self
380            .processes
381            .values()
382            .map(|p| ProcessInfo {
383                name: p.name.clone(),
384                id: p.id.clone(),
385                pid: p.pid,
386                state: if p.child.is_some() {
387                    ProcessState::Running
388                } else {
389                    ProcessState::Exited
390                },
391                exit_code: p.exit_code,
392                uptime_secs: if p.child.is_some() {
393                    Some(p.started_at.elapsed().as_secs())
394                } else {
395                    None
396                },
397                command: p.command.clone(),
398                port: p.port,
399                url: p.port.map(|port| process_url(&p.name, port, None)),
400            })
401            .collect();
402        infos.sort_by(|a, b| a.name.cmp(&b.name));
403        infos
404    }
405
406    fn find(&self, target: &str) -> Option<&ManagedProcess> {
407        self.processes
408            .get(target)
409            .or_else(|| self.processes.values().find(|p| p.id == target))
410    }
411
412    fn find_mut(&mut self, target: &str) -> Option<&mut ManagedProcess> {
413        if self.processes.contains_key(target) {
414            self.processes.get_mut(target)
415        } else {
416            self.processes.values_mut().find(|p| p.id == target)
417        }
418    }
419}
420
#[cfg(test)]
mod tests {
    use super::is_valid_dns_label;

    /// Every well-formed label must be accepted.
    #[test]
    fn test_valid_dns_labels() {
        for label in ["api", "my-app", "a", "a1", "123"] {
            assert!(is_valid_dns_label(label), "expected {label:?} to be accepted");
        }
    }

    /// Every malformed label must be rejected.
    #[test]
    fn test_invalid_dns_labels() {
        let too_long = "a".repeat(64); // > 63 chars
        let rejected = [
            "",
            "-start",
            "end-",
            "UPPER",
            "has.dot",
            "has space",
            too_long.as_str(),
            "has_underscore",
        ];
        for label in rejected {
            assert!(!is_valid_dns_label(label), "expected {label:?} to be rejected");
        }
    }
}
445}