Skip to main content

agent_procs/daemon/
process_manager.rs

1use crate::daemon::log_writer::{self, OutputLine};
2use crate::daemon::port_allocator::PortAllocator;
3use crate::paths;
4use crate::protocol::{
5    ErrorCode, ProcessInfo, ProcessState, Response, Stream as ProtoStream, process_url,
6};
7use crate::session::IdCounter;
8use std::collections::HashMap;
9use std::process::Stdio;
10use std::sync::Arc;
11use std::sync::atomic::AtomicU64;
12use std::time::{Duration, Instant};
13use tokio::process::{Child, Command};
14use tokio::sync::broadcast;
15
/// Byte limit passed to `log_writer::capture_output` for each captured
/// stream (stdout / stderr) of a process: 50 MiB.
const DEFAULT_MAX_LOG_BYTES: u64 = 50 << 20;
17
/// Checks whether `name` can serve as a single DNS label.
///
/// A label qualifies when it is 1 to 63 bytes long, contains only lowercase
/// ASCII letters, ASCII digits and `-`, and neither begins nor ends with `-`.
/// Any non-ASCII input is rejected.
#[must_use]
pub fn is_valid_dns_label(name: &str) -> bool {
    let bytes = name.as_bytes();
    // An empty slice has no first/last byte, which covers the empty-name case.
    let (Some(&first), Some(&last)) = (bytes.first(), bytes.last()) else {
        return false;
    };
    if bytes.len() > 63 || first == b'-' || last == b'-' {
        return false;
    }
    bytes
        .iter()
        .all(|&b| matches!(b, b'a'..=b'z' | b'0'..=b'9' | b'-'))
}
31
32pub struct ManagedProcess {
33    pub name: String,
34    pub id: String,
35    pub command: String,
36    pub cwd: Option<String>,
37    pub env: HashMap<String, String>,
38    pub child: Option<Child>,
39    pub pid: u32,
40    pub started_at: Instant,
41    pub exit_code: Option<i32>,
42    pub port: Option<u16>,
43}
44
45pub struct ProcessManager {
46    processes: HashMap<String, ManagedProcess>,
47    id_counter: IdCounter,
48    session: String,
49    pub output_tx: broadcast::Sender<OutputLine>,
50    port_allocator: PortAllocator,
51}
52
53impl ProcessManager {
54    pub fn new(session: &str) -> Self {
55        let (output_tx, _) = broadcast::channel(1024);
56        Self {
57            processes: HashMap::new(),
58            id_counter: IdCounter::new(),
59            session: session.to_string(),
60            output_tx,
61            port_allocator: PortAllocator::new(),
62        }
63    }
64
65    #[allow(unsafe_code, clippy::unused_async)]
66    pub async fn spawn_process(
67        &mut self,
68        command: &str,
69        name: Option<String>,
70        cwd: Option<&str>,
71        env: Option<&HashMap<String, String>>,
72        port: Option<u16>,
73    ) -> Response {
74        let id = self.id_counter.next_id();
75        let name = name.unwrap_or_else(|| id.clone());
76
77        // Reject names that could cause path traversal in log files
78        if name.contains('/') || name.contains('\\') || name.contains("..") || name.contains('\0') {
79            return Response::Error {
80                code: ErrorCode::General,
81                message: format!("invalid process name: {}", name),
82            };
83        }
84
85        // When proxy is active, names must be valid DNS labels for subdomain routing
86        if self.port_allocator.is_proxy_enabled() && !is_valid_dns_label(&name) {
87            return Response::Error {
88                code: ErrorCode::General,
89                message: format!(
90                    "invalid process name for proxy: '{}' (must be lowercase alphanumeric/hyphens, max 63 chars)",
91                    name
92                ),
93            };
94        }
95
96        // Resolve the port: use explicit port, auto-assign if proxy is enabled, or None
97        let resolved_port = if let Some(p) = port {
98            Some(p)
99        } else if self.port_allocator.is_proxy_enabled() {
100            let assigned: std::collections::HashSet<u16> = self
101                .processes
102                .values()
103                .filter(|p| p.child.is_some())
104                .filter_map(|p| p.port)
105                .collect();
106            match self.port_allocator.auto_assign_port(&assigned) {
107                Ok(p) => Some(p),
108                Err(e) => {
109                    return Response::Error {
110                        code: ErrorCode::General,
111                        message: e.to_string(),
112                    };
113                }
114            }
115        } else {
116            None
117        };
118
119        if self.processes.contains_key(&name) {
120            return Response::Error {
121                code: ErrorCode::General,
122                message: format!("process already exists: {}", name),
123            };
124        }
125
126        let log_dir = paths::log_dir(&self.session);
127        let _ = std::fs::create_dir_all(&log_dir);
128
129        let mut cmd = Command::new("sh");
130        cmd.arg("-c")
131            .arg(command)
132            .stdout(Stdio::piped())
133            .stderr(Stdio::piped());
134        if let Some(dir) = cwd {
135            cmd.current_dir(dir);
136        }
137        if let Some(p) = resolved_port {
138            // Inject PORT and HOST; user-supplied env takes precedence
139            let mut merged_env: HashMap<String, String> = HashMap::new();
140            merged_env.insert("PORT".to_string(), p.to_string());
141            merged_env.insert("HOST".to_string(), "127.0.0.1".to_string());
142            if let Some(env_vars) = env {
143                for (k, v) in env_vars {
144                    merged_env.insert(k.clone(), v.clone());
145                }
146            }
147            cmd.envs(&merged_env);
148        } else if let Some(env_vars) = env {
149            cmd.envs(env_vars);
150        }
151        // SAFETY: `setpgid(0, 0)` creates a new process group with the child as
152        // leader.  This must happen before exec so that all grandchildren inherit
153        // the group.  The parent uses this PGID to signal the entire tree on stop.
154        unsafe {
155            cmd.pre_exec(|| {
156                nix::unistd::setpgid(nix::unistd::Pid::from_raw(0), nix::unistd::Pid::from_raw(0))
157                    .map_err(std::io::Error::other)?;
158                Ok(())
159            });
160        }
161
162        let mut child = match cmd.spawn() {
163            Ok(c) => c,
164            Err(e) => {
165                return Response::Error {
166                    code: ErrorCode::General,
167                    message: format!("failed to spawn: {}", e),
168                };
169            }
170        };
171
172        let pid = child.id().unwrap_or(0);
173
174        // Shared sequence counter for interleaved ordering across streams
175        let seq_counter = Arc::new(AtomicU64::new(0));
176
177        // Spawn output capture tasks via log_writer
178        if let Some(stdout) = child.stdout.take() {
179            let tx = self.output_tx.clone();
180            let pname = name.clone();
181            let path = log_dir.join(format!("{}.stdout", name));
182            let seq = Arc::clone(&seq_counter);
183            tokio::spawn(async move {
184                log_writer::capture_output(
185                    stdout,
186                    &path,
187                    &pname,
188                    ProtoStream::Stdout,
189                    tx,
190                    DEFAULT_MAX_LOG_BYTES,
191                    log_writer::DEFAULT_MAX_ROTATED_FILES,
192                    seq,
193                )
194                .await;
195            });
196        }
197        if let Some(stderr) = child.stderr.take() {
198            let tx = self.output_tx.clone();
199            let pname = name.clone();
200            let path = log_dir.join(format!("{}.stderr", name));
201            let seq = Arc::clone(&seq_counter);
202            tokio::spawn(async move {
203                log_writer::capture_output(
204                    stderr,
205                    &path,
206                    &pname,
207                    ProtoStream::Stderr,
208                    tx,
209                    DEFAULT_MAX_LOG_BYTES,
210                    log_writer::DEFAULT_MAX_ROTATED_FILES,
211                    seq,
212                )
213                .await;
214            });
215        }
216
217        self.processes.insert(
218            name.clone(),
219            ManagedProcess {
220                name: name.clone(),
221                id: id.clone(),
222                command: command.to_string(),
223                cwd: cwd.map(std::string::ToString::to_string),
224                env: env.cloned().unwrap_or_default(),
225                child: Some(child),
226                pid,
227                started_at: Instant::now(),
228                exit_code: None,
229                port: resolved_port,
230            },
231        );
232
233        let url = resolved_port.map(|p| process_url(&name, p, None));
234        Response::RunOk {
235            name,
236            id,
237            pid,
238            port: resolved_port,
239            url,
240        }
241    }
242
243    pub async fn stop_process(&mut self, target: &str) -> Response {
244        let proc = match self.find_mut(target) {
245            Some(p) => p,
246            None => {
247                return Response::Error {
248                    code: ErrorCode::NotFound,
249                    message: format!("process not found: {}", target),
250                };
251            }
252        };
253
254        if let Some(ref child) = proc.child {
255            let raw_pid = child.id().unwrap_or(0) as i32;
256            if raw_pid > 0 {
257                // Signal the entire process group (child PID == PGID due to setpgid in pre_exec)
258                let pgid = nix::unistd::Pid::from_raw(raw_pid);
259                let _ = nix::sys::signal::killpg(pgid, nix::sys::signal::Signal::SIGTERM);
260            }
261        }
262
263        // Wait up to 10s for graceful exit, then SIGKILL
264        if let Some(ref mut child) = proc.child {
265            let wait_result = tokio::time::timeout(Duration::from_secs(10), child.wait()).await;
266
267            match wait_result {
268                Ok(Ok(status)) => {
269                    proc.exit_code = status.code();
270                }
271                _ => {
272                    // Timed out or error — force kill the process group
273                    let raw_pid = proc.pid as i32;
274                    if raw_pid > 0 {
275                        let pgid = nix::unistd::Pid::from_raw(raw_pid);
276                        let _ = nix::sys::signal::killpg(pgid, nix::sys::signal::Signal::SIGKILL);
277                    }
278                    let _ = child.wait().await;
279                    proc.exit_code = Some(-9);
280                }
281            }
282            proc.child = None;
283        }
284
285        Response::Ok {
286            message: format!("stopped {}", target),
287        }
288    }
289
290    pub async fn stop_all(&mut self) -> Response {
291        let names: Vec<String> = self.processes.keys().cloned().collect();
292        for name in names {
293            let _ = self.stop_process(&name).await;
294        }
295        self.processes.clear();
296        Response::Ok {
297            message: "all processes stopped".into(),
298        }
299    }
300
301    pub async fn restart_process(&mut self, target: &str) -> Response {
302        let (command, name, cwd, env, port) = match self.find(target) {
303            Some(p) => (
304                p.command.clone(),
305                p.name.clone(),
306                p.cwd.clone(),
307                p.env.clone(),
308                p.port,
309            ),
310            None => {
311                return Response::Error {
312                    code: ErrorCode::NotFound,
313                    message: format!("process not found: {}", target),
314                };
315            }
316        };
317        let _ = self.stop_process(target).await;
318        self.processes.remove(&name);
319        let env = if env.is_empty() { None } else { Some(env) };
320        self.spawn_process(&command, Some(name), cwd.as_deref(), env.as_ref(), port)
321            .await
322    }
323
324    pub fn enable_proxy(&mut self) {
325        self.port_allocator.enable_proxy();
326    }
327
328    pub fn status(&mut self) -> Response {
329        self.refresh_exit_states();
330        Response::Status {
331            processes: self.build_process_infos(),
332        }
333    }
334
335    /// Returns `None` if process not found or still running.
336    /// Returns `Some(exit_code)` if process has exited (`exit_code` is None for signal kills).
337    pub fn is_process_exited(&mut self, target: &str) -> Option<Option<i32>> {
338        self.refresh_exit_states();
339        self.find(target).and_then(|p| {
340            if p.child.is_none() {
341                Some(p.exit_code)
342            } else {
343                None
344            }
345        })
346    }
347
348    pub(crate) fn refresh_exit_states(&mut self) -> bool {
349        let mut changed = false;
350        for proc in self.processes.values_mut() {
351            if proc.child.is_some()
352                && proc.exit_code.is_none()
353                && let Some(ref mut child) = proc.child
354                && let Ok(Some(status)) = child.try_wait()
355            {
356                proc.exit_code = status.code();
357                proc.child = None;
358                changed = true;
359            }
360        }
361        changed
362    }
363
364    pub fn session_name(&self) -> &str {
365        &self.session
366    }
367
368    pub fn has_process(&self, target: &str) -> bool {
369        self.find(target).is_some()
370    }
371
372    /// Returns a map of running process names to their assigned ports.
373    pub fn running_ports(&self) -> HashMap<String, u16> {
374        self.processes
375            .iter()
376            .filter_map(|(name, p)| {
377                if p.child.is_some() {
378                    p.port.map(|port| (name.clone(), port))
379                } else {
380                    None
381                }
382            })
383            .collect()
384    }
385
386    /// Non-mutating status snapshot for use by the proxy status page.
387    /// May show stale exit states since it skips `refresh_exit_states()`.
388    pub fn status_snapshot(&self) -> Response {
389        Response::Status {
390            processes: self.build_process_infos(),
391        }
392    }
393
394    fn build_process_infos(&self) -> Vec<ProcessInfo> {
395        let mut infos: Vec<ProcessInfo> = self
396            .processes
397            .values()
398            .map(|p| ProcessInfo {
399                name: p.name.clone(),
400                id: p.id.clone(),
401                pid: p.pid,
402                state: if p.child.is_some() {
403                    ProcessState::Running
404                } else {
405                    ProcessState::Exited
406                },
407                exit_code: p.exit_code,
408                uptime_secs: if p.child.is_some() {
409                    Some(p.started_at.elapsed().as_secs())
410                } else {
411                    None
412                },
413                command: p.command.clone(),
414                port: p.port,
415                url: p.port.map(|port| process_url(&p.name, port, None)),
416            })
417            .collect();
418        infos.sort_by(|a, b| a.name.cmp(&b.name));
419        infos
420    }
421
422    fn find(&self, target: &str) -> Option<&ManagedProcess> {
423        self.processes
424            .get(target)
425            .or_else(|| self.processes.values().find(|p| p.id == target))
426    }
427
428    fn find_mut(&mut self, target: &str) -> Option<&mut ManagedProcess> {
429        if self.processes.contains_key(target) {
430            self.processes.get_mut(target)
431        } else {
432            self.processes.values_mut().find(|p| p.id == target)
433        }
434    }
435}
436
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_valid_dns_labels() {
        assert!(is_valid_dns_label("api"));
        assert!(is_valid_dns_label("my-app"));
        assert!(is_valid_dns_label("a"));
        assert!(is_valid_dns_label("a1"));
        assert!(is_valid_dns_label("123"));
        // Only the ends are restricted; interior hyphen runs are fine.
        assert!(is_valid_dns_label("a--b"));
        // Exactly at the 63-character limit.
        assert!(is_valid_dns_label(&"a".repeat(63)));
    }

    #[test]
    fn test_invalid_dns_labels() {
        assert!(!is_valid_dns_label(""));
        assert!(!is_valid_dns_label("-"));
        assert!(!is_valid_dns_label("-start"));
        assert!(!is_valid_dns_label("end-"));
        assert!(!is_valid_dns_label("UPPER"));
        assert!(!is_valid_dns_label("has.dot"));
        assert!(!is_valid_dns_label("has space"));
        assert!(!is_valid_dns_label(&"a".repeat(64))); // > 63 chars
        assert!(!is_valid_dns_label("has_underscore"));
        // Non-ASCII letters are not allowed even if lowercase.
        assert!(!is_valid_dns_label("caf\u{e9}"));
    }
}