Skip to main content

agent_procs/daemon/
process_manager.rs

1use crate::daemon::log_writer::{self, OutputLine};
2use crate::daemon::port_allocator::PortAllocator;
3use crate::paths;
4use crate::protocol::{
5    ErrorCode, ProcessInfo, ProcessState, Response, Stream as ProtoStream, process_url,
6};
7use crate::session::IdCounter;
8use std::collections::HashMap;
9use std::process::Stdio;
10use std::sync::Arc;
11use std::sync::atomic::AtomicU64;
12use std::time::{Duration, Instant};
13use tokio::process::{Child, Command};
14use tokio::sync::broadcast;
15
16const DEFAULT_MAX_LOG_BYTES: u64 = 50 * 1024 * 1024; // 50MB
17
/// Checks whether `name` can be used as a single DNS label.
///
/// A label is accepted when it is 1 to 63 bytes long, consists solely of
/// ASCII lowercase letters, ASCII digits and hyphens, and neither begins nor
/// ends with a hyphen.
#[must_use]
pub fn is_valid_dns_label(name: &str) -> bool {
    let bytes = name.as_bytes();

    // Every accepted character is ASCII, so byte length equals character count.
    let length_ok = (1..=63).contains(&bytes.len());
    let edges_ok = bytes.first() != Some(&b'-') && bytes.last() != Some(&b'-');

    length_ok
        && edges_ok
        && bytes
            .iter()
            .all(|byte| matches!(byte, b'a'..=b'z' | b'0'..=b'9' | b'-'))
}
31
32pub struct ManagedProcess {
33    pub name: String,
34    pub id: String,
35    pub command: String,
36    pub cwd: Option<String>,
37    pub env: HashMap<String, String>,
38    pub child: Option<Child>,
39    pub pid: u32,
40    pub started_at: Instant,
41    pub exit_code: Option<i32>,
42    pub port: Option<u16>,
43}
44
45pub struct ProcessManager {
46    processes: HashMap<String, ManagedProcess>,
47    id_counter: IdCounter,
48    session: String,
49    pub output_tx: broadcast::Sender<OutputLine>,
50    port_allocator: PortAllocator,
51}
52
53impl ProcessManager {
54    pub fn new(session: &str) -> Self {
55        let (output_tx, _) = broadcast::channel(1024);
56        Self {
57            processes: HashMap::new(),
58            id_counter: IdCounter::new(),
59            session: session.to_string(),
60            output_tx,
61            port_allocator: PortAllocator::new(),
62        }
63    }
64
65    #[allow(unsafe_code, clippy::unused_async)]
66    pub async fn spawn_process(
67        &mut self,
68        command: &str,
69        name: Option<String>,
70        cwd: Option<&str>,
71        env: Option<&HashMap<String, String>>,
72        port: Option<u16>,
73    ) -> Response {
74        let id = self.id_counter.next_id();
75        let name = name.unwrap_or_else(|| id.clone());
76
77        // Reject names that could cause path traversal in log files
78        if name.contains('/') || name.contains('\\') || name.contains("..") || name.contains('\0') {
79            return Response::Error {
80                code: ErrorCode::General,
81                message: format!("invalid process name: {}", name),
82            };
83        }
84
85        // When proxy is active, names must be valid DNS labels for subdomain routing
86        if self.port_allocator.is_proxy_enabled() && !is_valid_dns_label(&name) {
87            return Response::Error {
88                code: ErrorCode::General,
89                message: format!(
90                    "invalid process name for proxy: '{}' (must be lowercase alphanumeric/hyphens, max 63 chars)",
91                    name
92                ),
93            };
94        }
95
96        // Resolve the port: use explicit port, auto-assign if proxy is enabled, or None
97        let resolved_port = if let Some(p) = port {
98            Some(p)
99        } else if self.port_allocator.is_proxy_enabled() {
100            let assigned: std::collections::HashSet<u16> =
101                self.processes.values().filter_map(|p| p.port).collect();
102            match self.port_allocator.auto_assign_port(&assigned) {
103                Ok(p) => Some(p),
104                Err(e) => {
105                    return Response::Error {
106                        code: ErrorCode::General,
107                        message: e.to_string(),
108                    };
109                }
110            }
111        } else {
112            None
113        };
114
115        if self.processes.contains_key(&name) {
116            return Response::Error {
117                code: ErrorCode::General,
118                message: format!("process already exists: {}", name),
119            };
120        }
121
122        let log_dir = paths::log_dir(&self.session);
123        let _ = std::fs::create_dir_all(&log_dir);
124
125        let mut cmd = Command::new("sh");
126        cmd.arg("-c")
127            .arg(command)
128            .stdout(Stdio::piped())
129            .stderr(Stdio::piped());
130        if let Some(dir) = cwd {
131            cmd.current_dir(dir);
132        }
133        if let Some(p) = resolved_port {
134            // Inject PORT and HOST; user-supplied env takes precedence
135            let mut merged_env: HashMap<String, String> = HashMap::new();
136            merged_env.insert("PORT".to_string(), p.to_string());
137            merged_env.insert("HOST".to_string(), "127.0.0.1".to_string());
138            if let Some(env_vars) = env {
139                for (k, v) in env_vars {
140                    merged_env.insert(k.clone(), v.clone());
141                }
142            }
143            cmd.envs(&merged_env);
144        } else if let Some(env_vars) = env {
145            cmd.envs(env_vars);
146        }
147        // SAFETY: `setpgid(0, 0)` creates a new process group with the child as
148        // leader.  This must happen before exec so that all grandchildren inherit
149        // the group.  The parent uses this PGID to signal the entire tree on stop.
150        unsafe {
151            cmd.pre_exec(|| {
152                nix::unistd::setpgid(nix::unistd::Pid::from_raw(0), nix::unistd::Pid::from_raw(0))
153                    .map_err(std::io::Error::other)?;
154                Ok(())
155            });
156        }
157
158        let mut child = match cmd.spawn() {
159            Ok(c) => c,
160            Err(e) => {
161                return Response::Error {
162                    code: ErrorCode::General,
163                    message: format!("failed to spawn: {}", e),
164                };
165            }
166        };
167
168        let pid = child.id().unwrap_or(0);
169
170        // Shared sequence counter for interleaved ordering across streams
171        let seq_counter = Arc::new(AtomicU64::new(0));
172
173        // Spawn output capture tasks via log_writer
174        if let Some(stdout) = child.stdout.take() {
175            let tx = self.output_tx.clone();
176            let pname = name.clone();
177            let path = log_dir.join(format!("{}.stdout", name));
178            let seq = Arc::clone(&seq_counter);
179            tokio::spawn(async move {
180                log_writer::capture_output(
181                    stdout,
182                    &path,
183                    &pname,
184                    ProtoStream::Stdout,
185                    tx,
186                    DEFAULT_MAX_LOG_BYTES,
187                    log_writer::DEFAULT_MAX_ROTATED_FILES,
188                    seq,
189                )
190                .await;
191            });
192        }
193        if let Some(stderr) = child.stderr.take() {
194            let tx = self.output_tx.clone();
195            let pname = name.clone();
196            let path = log_dir.join(format!("{}.stderr", name));
197            let seq = Arc::clone(&seq_counter);
198            tokio::spawn(async move {
199                log_writer::capture_output(
200                    stderr,
201                    &path,
202                    &pname,
203                    ProtoStream::Stderr,
204                    tx,
205                    DEFAULT_MAX_LOG_BYTES,
206                    log_writer::DEFAULT_MAX_ROTATED_FILES,
207                    seq,
208                )
209                .await;
210            });
211        }
212
213        self.processes.insert(
214            name.clone(),
215            ManagedProcess {
216                name: name.clone(),
217                id: id.clone(),
218                command: command.to_string(),
219                cwd: cwd.map(std::string::ToString::to_string),
220                env: env.cloned().unwrap_or_default(),
221                child: Some(child),
222                pid,
223                started_at: Instant::now(),
224                exit_code: None,
225                port: resolved_port,
226            },
227        );
228
229        let url = resolved_port.map(|p| process_url(&name, p, None));
230        Response::RunOk {
231            name,
232            id,
233            pid,
234            port: resolved_port,
235            url,
236        }
237    }
238
239    pub async fn stop_process(&mut self, target: &str) -> Response {
240        let proc = match self.find_mut(target) {
241            Some(p) => p,
242            None => {
243                return Response::Error {
244                    code: ErrorCode::NotFound,
245                    message: format!("process not found: {}", target),
246                };
247            }
248        };
249
250        if let Some(ref child) = proc.child {
251            let raw_pid = child.id().unwrap_or(0) as i32;
252            if raw_pid > 0 {
253                // Signal the entire process group (child PID == PGID due to setpgid in pre_exec)
254                let pgid = nix::unistd::Pid::from_raw(raw_pid);
255                let _ = nix::sys::signal::killpg(pgid, nix::sys::signal::Signal::SIGTERM);
256            }
257        }
258
259        // Wait up to 10s for graceful exit, then SIGKILL
260        if let Some(ref mut child) = proc.child {
261            let wait_result = tokio::time::timeout(Duration::from_secs(10), child.wait()).await;
262
263            match wait_result {
264                Ok(Ok(status)) => {
265                    proc.exit_code = status.code();
266                }
267                _ => {
268                    // Timed out or error — force kill the process group
269                    let raw_pid = proc.pid as i32;
270                    if raw_pid > 0 {
271                        let pgid = nix::unistd::Pid::from_raw(raw_pid);
272                        let _ = nix::sys::signal::killpg(pgid, nix::sys::signal::Signal::SIGKILL);
273                    }
274                    let _ = child.wait().await;
275                    proc.exit_code = Some(-9);
276                }
277            }
278            proc.child = None;
279        }
280
281        Response::Ok {
282            message: format!("stopped {}", target),
283        }
284    }
285
286    pub async fn stop_all(&mut self) -> Response {
287        let names: Vec<String> = self.processes.keys().cloned().collect();
288        for name in names {
289            let _ = self.stop_process(&name).await;
290        }
291        self.processes.clear();
292        Response::Ok {
293            message: "all processes stopped".into(),
294        }
295    }
296
297    pub async fn restart_process(&mut self, target: &str) -> Response {
298        let (command, name, cwd, env, port) = match self.find(target) {
299            Some(p) => (
300                p.command.clone(),
301                p.name.clone(),
302                p.cwd.clone(),
303                p.env.clone(),
304                p.port,
305            ),
306            None => {
307                return Response::Error {
308                    code: ErrorCode::NotFound,
309                    message: format!("process not found: {}", target),
310                };
311            }
312        };
313        let _ = self.stop_process(target).await;
314        self.processes.remove(&name);
315        let env = if env.is_empty() { None } else { Some(env) };
316        self.spawn_process(&command, Some(name), cwd.as_deref(), env.as_ref(), port)
317            .await
318    }
319
320    pub fn enable_proxy(&mut self) {
321        self.port_allocator.enable_proxy();
322    }
323
324    pub fn status(&mut self) -> Response {
325        self.refresh_exit_states();
326        Response::Status {
327            processes: self.build_process_infos(),
328        }
329    }
330
331    /// Returns `None` if process not found or still running.
332    /// Returns `Some(exit_code)` if process has exited (`exit_code` is None for signal kills).
333    pub fn is_process_exited(&mut self, target: &str) -> Option<Option<i32>> {
334        self.refresh_exit_states();
335        self.find(target).and_then(|p| {
336            if p.child.is_none() {
337                Some(p.exit_code)
338            } else {
339                None
340            }
341        })
342    }
343
344    fn refresh_exit_states(&mut self) {
345        for proc in self.processes.values_mut() {
346            if proc.child.is_some()
347                && proc.exit_code.is_none()
348                && let Some(ref mut child) = proc.child
349                && let Ok(Some(status)) = child.try_wait()
350            {
351                proc.exit_code = status.code();
352                proc.child = None;
353            }
354        }
355    }
356
357    pub fn session_name(&self) -> &str {
358        &self.session
359    }
360
361    pub fn has_process(&self, target: &str) -> bool {
362        self.find(target).is_some()
363    }
364
365    /// Returns a map of running process names to their assigned ports.
366    pub fn running_ports(&self) -> HashMap<String, u16> {
367        self.processes
368            .iter()
369            .filter_map(|(name, p)| {
370                if p.child.is_some() {
371                    p.port.map(|port| (name.clone(), port))
372                } else {
373                    None
374                }
375            })
376            .collect()
377    }
378
379    /// Non-mutating status snapshot for use by the proxy status page.
380    /// May show stale exit states since it skips `refresh_exit_states()`.
381    pub fn status_snapshot(&self) -> Response {
382        Response::Status {
383            processes: self.build_process_infos(),
384        }
385    }
386
387    fn build_process_infos(&self) -> Vec<ProcessInfo> {
388        let mut infos: Vec<ProcessInfo> = self
389            .processes
390            .values()
391            .map(|p| ProcessInfo {
392                name: p.name.clone(),
393                id: p.id.clone(),
394                pid: p.pid,
395                state: if p.child.is_some() {
396                    ProcessState::Running
397                } else {
398                    ProcessState::Exited
399                },
400                exit_code: p.exit_code,
401                uptime_secs: if p.child.is_some() {
402                    Some(p.started_at.elapsed().as_secs())
403                } else {
404                    None
405                },
406                command: p.command.clone(),
407                port: p.port,
408                url: p.port.map(|port| process_url(&p.name, port, None)),
409            })
410            .collect();
411        infos.sort_by(|a, b| a.name.cmp(&b.name));
412        infos
413    }
414
415    fn find(&self, target: &str) -> Option<&ManagedProcess> {
416        self.processes
417            .get(target)
418            .or_else(|| self.processes.values().find(|p| p.id == target))
419    }
420
421    fn find_mut(&mut self, target: &str) -> Option<&mut ManagedProcess> {
422        if self.processes.contains_key(target) {
423            self.processes.get_mut(target)
424        } else {
425            self.processes.values_mut().find(|p| p.id == target)
426        }
427    }
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Labels that must be accepted, including the 63-character upper bound.
    #[test]
    fn test_valid_dns_labels() {
        let longest_allowed = "a".repeat(63);
        let accepted = ["api", "my-app", "a", "a1", "123", longest_allowed.as_str()];
        for label in accepted {
            assert!(
                is_valid_dns_label(label),
                "expected {label:?} to be accepted"
            );
        }
    }

    /// Labels that must be rejected: empty, hyphen at an edge, characters
    /// outside `[a-z0-9-]` (including non-ASCII), and anything over 63 chars.
    #[test]
    fn test_invalid_dns_labels() {
        let too_long = "a".repeat(64); // > 63 chars
        let rejected = [
            "",
            "-",
            "-start",
            "end-",
            "UPPER",
            "has.dot",
            "has space",
            "has_underscore",
            "caf\u{e9}",
            too_long.as_str(),
        ];
        for label in rejected {
            assert!(
                !is_valid_dns_label(label),
                "expected {label:?} to be rejected"
            );
        }
    }
}
454}