Skip to main content

bamboo_agent/process/
registry.rs

1//! Process management for external agent runs and Claude sessions
2//!
3//! This module provides process lifecycle management:
4//! - Registration and tracking of running processes
5//! - Graceful and forceful termination
6//! - Live output capture
7//! - Cross-platform process killing
8
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::{Arc, Mutex};
13use tokio::process::Child;
14use tokio::sync::Mutex as AsyncMutex;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub enum ProcessType {
18    AgentRun { agent_id: i64, agent_name: String },
19    ClaudeSession { session_id: String },
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ProcessInfo {
24    pub run_id: i64,
25    pub process_type: ProcessType,
26    pub pid: u32,
27    pub started_at: DateTime<Utc>,
28    pub project_path: String,
29    pub task: String,
30    pub model: String,
31}
32
/// Parameters needed to register an agent run with the registry.
///
/// Bundles the values that end up in `ProcessInfo` (everything except the
/// start timestamp, which is taken at registration time).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProcessRegistrationConfig {
    /// Key under which the process will be stored.
    pub run_id: i64,
    /// Id of the agent being run.
    pub agent_id: i64,
    /// Display name of the agent being run.
    pub agent_name: String,
    /// Operating-system process id of the spawned process.
    pub pid: u32,
    /// Project path the run operates on.
    pub project_path: String,
    /// Task description for the run.
    pub task: String,
    /// Model name used for the run.
    pub model: String,
}
43
44#[allow(dead_code)]
45pub struct ProcessHandle {
46    pub info: ProcessInfo,
47    pub child: Arc<Mutex<Option<Child>>>,
48    pub live_output: Arc<Mutex<String>>,
49}
50
51pub struct ProcessRegistry {
52    processes: Arc<AsyncMutex<HashMap<i64, ProcessHandle>>>,
53    next_id: Arc<Mutex<i64>>,
54}
55
56impl ProcessRegistry {
57    pub fn new() -> Self {
58        Self {
59            processes: Arc::new(AsyncMutex::new(HashMap::new())),
60            next_id: Arc::new(Mutex::new(1000000)),
61        }
62    }
63
64    pub fn generate_id(&self) -> Result<i64, String> {
65        let mut next_id = self.next_id.lock().map_err(|e| e.to_string())?;
66        let id = *next_id;
67        *next_id += 1;
68        Ok(id)
69    }
70
71    pub async fn register_process(
72        &self,
73        config: ProcessRegistrationConfig,
74        child: Child,
75    ) -> Result<(), String> {
76        let ProcessRegistrationConfig {
77            run_id,
78            agent_id,
79            agent_name,
80            pid,
81            project_path,
82            task,
83            model,
84        } = config;
85
86        let process_info = ProcessInfo {
87            run_id,
88            process_type: ProcessType::AgentRun {
89                agent_id,
90                agent_name,
91            },
92            pid,
93            started_at: Utc::now(),
94            project_path,
95            task,
96            model,
97        };
98
99        self.register_process_internal(run_id, process_info, child)
100            .await
101    }
102
103    pub async fn register_sidecar_process(
104        &self,
105        config: ProcessRegistrationConfig,
106    ) -> Result<(), String> {
107        let ProcessRegistrationConfig {
108            run_id,
109            agent_id,
110            agent_name,
111            pid,
112            project_path,
113            task,
114            model,
115        } = config;
116
117        let process_info = ProcessInfo {
118            run_id,
119            process_type: ProcessType::AgentRun {
120                agent_id,
121                agent_name,
122            },
123            pid,
124            started_at: Utc::now(),
125            project_path,
126            task,
127            model,
128        };
129
130        let mut processes = self.processes.lock().await;
131
132        let process_handle = ProcessHandle {
133            info: process_info,
134            child: Arc::new(Mutex::new(None)),
135            live_output: Arc::new(Mutex::new(String::new())),
136        };
137
138        processes.insert(run_id, process_handle);
139        Ok(())
140    }
141
142    pub async fn register_claude_session(
143        &self,
144        session_id: String,
145        pid: u32,
146        project_path: String,
147        task: String,
148        model: String,
149        child: Arc<Mutex<Option<Child>>>,
150    ) -> Result<i64, String> {
151        let run_id = self.generate_id()?;
152
153        let process_info = ProcessInfo {
154            run_id,
155            process_type: ProcessType::ClaudeSession { session_id },
156            pid,
157            started_at: Utc::now(),
158            project_path,
159            task,
160            model,
161        };
162
163        let mut processes = self.processes.lock().await;
164
165        let process_handle = ProcessHandle {
166            info: process_info,
167            child,
168            live_output: Arc::new(Mutex::new(String::new())),
169        };
170
171        processes.insert(run_id, process_handle);
172        Ok(run_id)
173    }
174
175    async fn register_process_internal(
176        &self,
177        run_id: i64,
178        process_info: ProcessInfo,
179        child: Child,
180    ) -> Result<(), String> {
181        let mut processes = self.processes.lock().await;
182
183        let process_handle = ProcessHandle {
184            info: process_info,
185            child: Arc::new(Mutex::new(Some(child))),
186            live_output: Arc::new(Mutex::new(String::new())),
187        };
188
189        processes.insert(run_id, process_handle);
190        Ok(())
191    }
192
193    pub async fn get_running_claude_sessions(&self) -> Result<Vec<ProcessInfo>, String> {
194        let processes = self.processes.lock().await;
195        Ok(processes
196            .values()
197            .filter_map(|handle| match &handle.info.process_type {
198                ProcessType::ClaudeSession { .. } => Some(handle.info.clone()),
199                _ => None,
200            })
201            .collect())
202    }
203
204    pub async fn get_claude_session_by_id(
205        &self,
206        session_id: &str,
207    ) -> Result<Option<ProcessInfo>, String> {
208        let processes = self.processes.lock().await;
209        Ok(processes
210            .values()
211            .find(|handle| match &handle.info.process_type {
212                ProcessType::ClaudeSession { session_id: sid } => sid == session_id,
213                _ => false,
214            })
215            .map(|handle| handle.info.clone()))
216    }
217
218    pub async fn unregister_process(&self, run_id: i64) -> Result<(), String> {
219        let mut processes = self.processes.lock().await;
220        processes.remove(&run_id);
221        Ok(())
222    }
223
224    /// Synchronous version for use in non-async contexts
225    #[allow(dead_code)]
226    fn unregister_process_sync(&self, run_id: i64) -> Result<(), String> {
227        // Use try_lock to avoid blocking in sync context
228        // If we can't get the lock, that's okay - the process will be cleaned up later
229        if let Ok(mut processes) = self.processes.try_lock() {
230            processes.remove(&run_id);
231        }
232        Ok(())
233    }
234
235    #[allow(dead_code)]
236    pub async fn get_running_processes(&self) -> Result<Vec<ProcessInfo>, String> {
237        let processes = self.processes.lock().await;
238        Ok(processes
239            .values()
240            .map(|handle| handle.info.clone())
241            .collect())
242    }
243
244    pub async fn get_running_agent_processes(&self) -> Result<Vec<ProcessInfo>, String> {
245        let processes = self.processes.lock().await;
246        Ok(processes
247            .values()
248            .filter_map(|handle| match &handle.info.process_type {
249                ProcessType::AgentRun { .. } => Some(handle.info.clone()),
250                _ => None,
251            })
252            .collect())
253    }
254
255    #[allow(dead_code)]
256    pub async fn get_process(&self, run_id: i64) -> Result<Option<ProcessInfo>, String> {
257        let processes = self.processes.lock().await;
258        Ok(processes.get(&run_id).map(|handle| handle.info.clone()))
259    }
260
261    pub async fn kill_process(&self, run_id: i64) -> Result<bool, String> {
262        use log::{error, info, warn};
263
264        let (pid, child_arc) = {
265            let processes = self.processes.lock().await;
266            if let Some(handle) = processes.get(&run_id) {
267                (handle.info.pid, handle.child.clone())
268            } else {
269                warn!("Process {} not found in registry", run_id);
270                return Ok(false);
271            }
272        };
273
274        info!(
275            "Attempting graceful shutdown of process {} (PID: {})",
276            run_id, pid
277        );
278
279        let kill_sent = {
280            let mut child_guard = child_arc.lock().map_err(|e| e.to_string())?;
281            if let Some(child) = child_guard.as_mut() {
282                match child.start_kill() {
283                    Ok(_) => {
284                        info!("Successfully sent kill signal to process {}", run_id);
285                        true
286                    }
287                    Err(e) => {
288                        error!("Failed to send kill signal to process {}: {}", run_id, e);
289                        false
290                    }
291                }
292            } else {
293                warn!(
294                    "No child handle available for process {} (PID: {}), attempting system kill",
295                    run_id, pid
296                );
297                false
298            }
299        };
300
301        if !kill_sent {
302            info!(
303                "Attempting fallback kill for process {} (PID: {})",
304                run_id, pid
305            );
306            match self.kill_process_by_pid(run_id, pid).await {
307                Ok(true) => return Ok(true),
308                Ok(false) => warn!(
309                    "Fallback kill also failed for process {} (PID: {})",
310                    run_id, pid
311                ),
312                Err(e) => error!("Error during fallback kill: {}", e),
313            }
314        }
315
316        let wait_result = tokio::time::timeout(tokio::time::Duration::from_secs(5), async {
317            loop {
318                let status = {
319                    let mut child_guard = child_arc.lock().map_err(|e| e.to_string())?;
320                    if let Some(child) = child_guard.as_mut() {
321                        match child.try_wait() {
322                            Ok(Some(status)) => {
323                                info!("Process {} exited with status: {:?}", run_id, status);
324                                *child_guard = None;
325                                Some(Ok::<(), String>(()))
326                            }
327                            Ok(None) => None,
328                            Err(e) => {
329                                error!("Error checking process status: {}", e);
330                                Some(Err(e.to_string()))
331                            }
332                        }
333                    } else {
334                        Some(Ok(()))
335                    }
336                };
337
338                match status {
339                    Some(result) => return result,
340                    None => {
341                        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
342                    }
343                }
344            }
345        })
346        .await;
347
348        match wait_result {
349            Ok(Ok(_)) => {
350                info!("Process {} exited gracefully", run_id);
351            }
352            Ok(Err(e)) => {
353                error!("Error waiting for process {}: {}", run_id, e);
354            }
355            Err(_) => {
356                warn!("Process {} didn't exit within 5 seconds after kill", run_id);
357                if let Ok(mut child_guard) = child_arc.lock() {
358                    *child_guard = None;
359                }
360                let _ = self.kill_process_by_pid(run_id, pid).await;
361            }
362        }
363
364        self.unregister_process(run_id).await?;
365
366        Ok(true)
367    }
368
369    pub async fn kill_process_by_pid(&self, run_id: i64, pid: u32) -> Result<bool, String> {
370        use log::{error, info, warn};
371
372        info!("Attempting to kill process {} by PID {}", run_id, pid);
373
374        let kill_result = if cfg!(target_os = "windows") {
375            std::process::Command::new("taskkill")
376                .args(["/F", "/PID", &pid.to_string()])
377                .output()
378        } else {
379            let term_result = std::process::Command::new("kill")
380                .args(["-TERM", &pid.to_string()])
381                .output();
382
383            match &term_result {
384                Ok(output) if output.status.success() => {
385                    info!("Sent SIGTERM to PID {}", pid);
386                    tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
387
388                    let check_result = std::process::Command::new("kill")
389                        .args(["-0", &pid.to_string()])
390                        .output();
391
392                    if let Ok(output) = check_result {
393                        if output.status.success() {
394                            warn!(
395                                "Process {} still running after SIGTERM, sending SIGKILL",
396                                pid
397                            );
398                            std::process::Command::new("kill")
399                                .args(["-KILL", &pid.to_string()])
400                                .output()
401                        } else {
402                            term_result
403                        }
404                    } else {
405                        term_result
406                    }
407                }
408                _ => {
409                    warn!("SIGTERM failed for PID {}, trying SIGKILL", pid);
410                    std::process::Command::new("kill")
411                        .args(["-KILL", &pid.to_string()])
412                        .output()
413                }
414            }
415        };
416
417        match kill_result {
418            Ok(output) => {
419                if output.status.success() {
420                    info!("Successfully killed process with PID {}", pid);
421                    self.unregister_process(run_id).await?;
422                    Ok(true)
423                } else {
424                    let error_msg = String::from_utf8_lossy(&output.stderr);
425                    warn!("Failed to kill PID {}: {}", pid, error_msg);
426                    Ok(false)
427                }
428            }
429            Err(e) => {
430                error!("Failed to execute kill command for PID {}: {}", pid, e);
431                Err(format!("Failed to execute kill command: {}", e))
432            }
433        }
434    }
435
436    #[allow(dead_code)]
437    pub async fn is_process_running(&self, run_id: i64) -> Result<bool, String> {
438        let processes = self.processes.lock().await;
439
440        if let Some(handle) = processes.get(&run_id) {
441            let child_arc = handle.child.clone();
442            drop(processes);
443
444            let mut child_guard = child_arc.lock().map_err(|e| e.to_string())?;
445            if let Some(ref mut child) = child_guard.as_mut() {
446                match child.try_wait() {
447                    Ok(Some(_)) => {
448                        *child_guard = None;
449                        Ok(false)
450                    }
451                    Ok(None) => Ok(true),
452                    Err(_) => {
453                        *child_guard = None;
454                        Ok(false)
455                    }
456                }
457            } else {
458                Ok(false)
459            }
460        } else {
461            Ok(false)
462        }
463    }
464
465    pub async fn append_live_output(&self, run_id: i64, output: &str) -> Result<(), String> {
466        let processes = self.processes.lock().await;
467        if let Some(handle) = processes.get(&run_id) {
468            let mut live_output = handle.live_output.lock().map_err(|e| e.to_string())?;
469            live_output.push_str(output);
470            live_output.push('\n');
471        }
472        Ok(())
473    }
474
475    pub async fn get_live_output(&self, run_id: i64) -> Result<String, String> {
476        let processes = self.processes.lock().await;
477        if let Some(handle) = processes.get(&run_id) {
478            let live_output = handle.live_output.lock().map_err(|e| e.to_string())?;
479            Ok(live_output.clone())
480        } else {
481            Ok(String::new())
482        }
483    }
484
485    #[allow(dead_code)]
486    pub async fn cleanup_finished_processes(&self) -> Result<Vec<i64>, String> {
487        let mut finished_runs = Vec::new();
488
489        {
490            let processes = self.processes.lock().await;
491            let run_ids: Vec<i64> = processes.keys().cloned().collect();
492            drop(processes);
493
494            for run_id in run_ids {
495                if !self.is_process_running(run_id).await? {
496                    finished_runs.push(run_id);
497                }
498            }
499        }
500
501        {
502            let mut processes = self.processes.lock().await;
503            for run_id in &finished_runs {
504                processes.remove(run_id);
505            }
506        }
507
508        Ok(finished_runs)
509    }
510}
511
512impl Default for ProcessRegistry {
513    fn default() -> Self {
514        Self::new()
515    }
516}
517
#[cfg(test)]
mod tests {
    use super::*;

    /// Chunks appended to a session's live output are returned in order,
    /// each terminated by a newline.
    #[tokio::test]
    async fn test_append_and_get_live_output() {
        let registry = ProcessRegistry::new();

        // A session registered without a child handle is enough here.
        let no_child = Arc::new(Mutex::new(None));
        let registration = registry.register_claude_session(
            String::from("session-1"),
            1234,
            String::from("/tmp/project"),
            String::from("task"),
            String::from("model"),
            no_child,
        );
        let run_id = registration.await.unwrap();

        for chunk in ["line1", "line2"] {
            registry.append_live_output(run_id, chunk).await.unwrap();
        }

        let collected = registry.get_live_output(run_id).await.unwrap();
        assert_eq!(collected, "line1\nline2\n");
    }
}