Skip to main content

bamboo_agent/process/
registry.rs

1//! Process management for external agent runs and Claude sessions
2//!
3//! This module provides process lifecycle management:
4//! - Registration and tracking of running processes
5//! - Graceful and forceful termination
6//! - Live output capture
7//! - Cross-platform process killing
8
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::{Arc, Mutex};
13use tokio::process::Child;
14use tokio::sync::Mutex as AsyncMutex;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub enum ProcessType {
18    AgentRun { agent_id: i64, agent_name: String },
19    ClaudeSession { session_id: String },
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ProcessInfo {
24    pub run_id: i64,
25    pub process_type: ProcessType,
26    pub pid: u32,
27    pub started_at: DateTime<Utc>,
28    pub project_path: String,
29    pub task: String,
30    pub model: String,
31}
32
/// Everything needed to register an agent run with the registry.
///
/// The fields are copied into a [`ProcessInfo`] whose process type is
/// `ProcessType::AgentRun`.
#[derive(Debug, Clone)]
pub struct ProcessRegistrationConfig {
    /// Registry key for the new entry (chosen by the caller).
    pub run_id: i64,
    /// Goes into `ProcessType::AgentRun::agent_id`.
    pub agent_id: i64,
    /// Goes into `ProcessType::AgentRun::agent_name`.
    pub agent_name: String,
    /// Operating-system process id of the run.
    pub pid: u32,
    /// Copied verbatim into `ProcessInfo::project_path`.
    pub project_path: String,
    /// Copied verbatim into `ProcessInfo::task`.
    pub task: String,
    /// Copied verbatim into `ProcessInfo::model`.
    pub model: String,
}
43
44#[allow(dead_code)]
45pub struct ProcessHandle {
46    pub info: ProcessInfo,
47    pub child: Arc<Mutex<Option<Child>>>,
48    pub live_output: Arc<Mutex<String>>,
49}
50
51pub struct ProcessRegistry {
52    processes: Arc<AsyncMutex<HashMap<i64, ProcessHandle>>>,
53    next_id: Arc<Mutex<i64>>,
54}
55
56impl ProcessRegistry {
57    pub fn new() -> Self {
58        Self {
59            processes: Arc::new(AsyncMutex::new(HashMap::new())),
60            next_id: Arc::new(Mutex::new(1000000)),
61        }
62    }
63
64    pub fn generate_id(&self) -> Result<i64, String> {
65        let mut next_id = self.next_id.lock().map_err(|e| e.to_string())?;
66        let id = *next_id;
67        *next_id += 1;
68        Ok(id)
69    }
70
71    pub async fn register_process(
72        &self,
73        config: ProcessRegistrationConfig,
74        child: Child,
75    ) -> Result<(), String> {
76        let ProcessRegistrationConfig {
77            run_id,
78            agent_id,
79            agent_name,
80            pid,
81            project_path,
82            task,
83            model,
84        } = config;
85
86        let process_info = ProcessInfo {
87            run_id,
88            process_type: ProcessType::AgentRun {
89                agent_id,
90                agent_name,
91            },
92            pid,
93            started_at: Utc::now(),
94            project_path,
95            task,
96            model,
97        };
98
99        self.register_process_internal(run_id, process_info, child).await
100    }
101
102    pub async fn register_sidecar_process(
103        &self,
104        config: ProcessRegistrationConfig,
105    ) -> Result<(), String> {
106        let ProcessRegistrationConfig {
107            run_id,
108            agent_id,
109            agent_name,
110            pid,
111            project_path,
112            task,
113            model,
114        } = config;
115
116        let process_info = ProcessInfo {
117            run_id,
118            process_type: ProcessType::AgentRun {
119                agent_id,
120                agent_name,
121            },
122            pid,
123            started_at: Utc::now(),
124            project_path,
125            task,
126            model,
127        };
128
129        let mut processes = self.processes.lock().await;
130
131        let process_handle = ProcessHandle {
132            info: process_info,
133            child: Arc::new(Mutex::new(None)),
134            live_output: Arc::new(Mutex::new(String::new())),
135        };
136
137        processes.insert(run_id, process_handle);
138        Ok(())
139    }
140
141    pub async fn register_claude_session(
142        &self,
143        session_id: String,
144        pid: u32,
145        project_path: String,
146        task: String,
147        model: String,
148        child: Arc<Mutex<Option<Child>>>,
149    ) -> Result<i64, String> {
150        let run_id = self.generate_id()?;
151
152        let process_info = ProcessInfo {
153            run_id,
154            process_type: ProcessType::ClaudeSession { session_id },
155            pid,
156            started_at: Utc::now(),
157            project_path,
158            task,
159            model,
160        };
161
162        let mut processes = self.processes.lock().await;
163
164        let process_handle = ProcessHandle {
165            info: process_info,
166            child,
167            live_output: Arc::new(Mutex::new(String::new())),
168        };
169
170        processes.insert(run_id, process_handle);
171        Ok(run_id)
172    }
173
174    async fn register_process_internal(
175        &self,
176        run_id: i64,
177        process_info: ProcessInfo,
178        child: Child,
179    ) -> Result<(), String> {
180        let mut processes = self.processes.lock().await;
181
182        let process_handle = ProcessHandle {
183            info: process_info,
184            child: Arc::new(Mutex::new(Some(child))),
185            live_output: Arc::new(Mutex::new(String::new())),
186        };
187
188        processes.insert(run_id, process_handle);
189        Ok(())
190    }
191
192    pub async fn get_running_claude_sessions(&self) -> Result<Vec<ProcessInfo>, String> {
193        let processes = self.processes.lock().await;
194        Ok(processes
195            .values()
196            .filter_map(|handle| match &handle.info.process_type {
197                ProcessType::ClaudeSession { .. } => Some(handle.info.clone()),
198                _ => None,
199            })
200            .collect())
201    }
202
203    pub async fn get_claude_session_by_id(
204        &self,
205        session_id: &str,
206    ) -> Result<Option<ProcessInfo>, String> {
207        let processes = self.processes.lock().await;
208        Ok(processes
209            .values()
210            .find(|handle| match &handle.info.process_type {
211                ProcessType::ClaudeSession { session_id: sid } => sid == session_id,
212                _ => false,
213            })
214            .map(|handle| handle.info.clone()))
215    }
216
217    pub async fn unregister_process(&self, run_id: i64) -> Result<(), String> {
218        let mut processes = self.processes.lock().await;
219        processes.remove(&run_id);
220        Ok(())
221    }
222
223    /// Synchronous version for use in non-async contexts
224    #[allow(dead_code)]
225    fn unregister_process_sync(&self, run_id: i64) -> Result<(), String> {
226        // Use try_lock to avoid blocking in sync context
227        // If we can't get the lock, that's okay - the process will be cleaned up later
228        if let Ok(mut processes) = self.processes.try_lock() {
229            processes.remove(&run_id);
230        }
231        Ok(())
232    }
233
234    #[allow(dead_code)]
235    pub async fn get_running_processes(&self) -> Result<Vec<ProcessInfo>, String> {
236        let processes = self.processes.lock().await;
237        Ok(processes
238            .values()
239            .map(|handle| handle.info.clone())
240            .collect())
241    }
242
243    pub async fn get_running_agent_processes(&self) -> Result<Vec<ProcessInfo>, String> {
244        let processes = self.processes.lock().await;
245        Ok(processes
246            .values()
247            .filter_map(|handle| match &handle.info.process_type {
248                ProcessType::AgentRun { .. } => Some(handle.info.clone()),
249                _ => None,
250            })
251            .collect())
252    }
253
254    #[allow(dead_code)]
255    pub async fn get_process(&self, run_id: i64) -> Result<Option<ProcessInfo>, String> {
256        let processes = self.processes.lock().await;
257        Ok(processes.get(&run_id).map(|handle| handle.info.clone()))
258    }
259
260    pub async fn kill_process(&self, run_id: i64) -> Result<bool, String> {
261        use log::{error, info, warn};
262
263        let (pid, child_arc) = {
264            let processes = self.processes.lock().await;
265            if let Some(handle) = processes.get(&run_id) {
266                (handle.info.pid, handle.child.clone())
267            } else {
268                warn!("Process {} not found in registry", run_id);
269                return Ok(false);
270            }
271        };
272
273        info!(
274            "Attempting graceful shutdown of process {} (PID: {})",
275            run_id, pid
276        );
277
278        let kill_sent = {
279            let mut child_guard = child_arc.lock().map_err(|e| e.to_string())?;
280            if let Some(child) = child_guard.as_mut() {
281                match child.start_kill() {
282                    Ok(_) => {
283                        info!("Successfully sent kill signal to process {}", run_id);
284                        true
285                    }
286                    Err(e) => {
287                        error!("Failed to send kill signal to process {}: {}", run_id, e);
288                        false
289                    }
290                }
291            } else {
292                warn!(
293                    "No child handle available for process {} (PID: {}), attempting system kill",
294                    run_id, pid
295                );
296                false
297            }
298        };
299
300        if !kill_sent {
301            info!(
302                "Attempting fallback kill for process {} (PID: {})",
303                run_id, pid
304            );
305            match self.kill_process_by_pid(run_id, pid).await {
306                Ok(true) => return Ok(true),
307                Ok(false) => warn!(
308                    "Fallback kill also failed for process {} (PID: {})",
309                    run_id, pid
310                ),
311                Err(e) => error!("Error during fallback kill: {}", e),
312            }
313        }
314
315        let wait_result = tokio::time::timeout(tokio::time::Duration::from_secs(5), async {
316            loop {
317                let status = {
318                    let mut child_guard = child_arc.lock().map_err(|e| e.to_string())?;
319                    if let Some(child) = child_guard.as_mut() {
320                        match child.try_wait() {
321                            Ok(Some(status)) => {
322                                info!("Process {} exited with status: {:?}", run_id, status);
323                                *child_guard = None;
324                                Some(Ok::<(), String>(()))
325                            }
326                            Ok(None) => None,
327                            Err(e) => {
328                                error!("Error checking process status: {}", e);
329                                Some(Err(e.to_string()))
330                            }
331                        }
332                    } else {
333                        Some(Ok(()))
334                    }
335                };
336
337                match status {
338                    Some(result) => return result,
339                    None => {
340                        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
341                    }
342                }
343            }
344        })
345        .await;
346
347        match wait_result {
348            Ok(Ok(_)) => {
349                info!("Process {} exited gracefully", run_id);
350            }
351            Ok(Err(e)) => {
352                error!("Error waiting for process {}: {}", run_id, e);
353            }
354            Err(_) => {
355                warn!("Process {} didn't exit within 5 seconds after kill", run_id);
356                if let Ok(mut child_guard) = child_arc.lock() {
357                    *child_guard = None;
358                }
359                let _ = self.kill_process_by_pid(run_id, pid).await;
360            }
361        }
362
363        self.unregister_process(run_id).await?;
364
365        Ok(true)
366    }
367
368    pub async fn kill_process_by_pid(&self, run_id: i64, pid: u32) -> Result<bool, String> {
369        use log::{error, info, warn};
370
371        info!("Attempting to kill process {} by PID {}", run_id, pid);
372
373        let kill_result = if cfg!(target_os = "windows") {
374            std::process::Command::new("taskkill")
375                .args(["/F", "/PID", &pid.to_string()])
376                .output()
377        } else {
378            let term_result = std::process::Command::new("kill")
379                .args(["-TERM", &pid.to_string()])
380                .output();
381
382            match &term_result {
383                Ok(output) if output.status.success() => {
384                    info!("Sent SIGTERM to PID {}", pid);
385                    tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
386
387                    let check_result = std::process::Command::new("kill")
388                        .args(["-0", &pid.to_string()])
389                        .output();
390
391                    if let Ok(output) = check_result {
392                        if output.status.success() {
393                            warn!(
394                                "Process {} still running after SIGTERM, sending SIGKILL",
395                                pid
396                            );
397                            std::process::Command::new("kill")
398                                .args(["-KILL", &pid.to_string()])
399                                .output()
400                        } else {
401                            term_result
402                        }
403                    } else {
404                        term_result
405                    }
406                }
407                _ => {
408                    warn!("SIGTERM failed for PID {}, trying SIGKILL", pid);
409                    std::process::Command::new("kill")
410                        .args(["-KILL", &pid.to_string()])
411                        .output()
412                }
413            }
414        };
415
416        match kill_result {
417            Ok(output) => {
418                if output.status.success() {
419                    info!("Successfully killed process with PID {}", pid);
420                    self.unregister_process(run_id).await?;
421                    Ok(true)
422                } else {
423                    let error_msg = String::from_utf8_lossy(&output.stderr);
424                    warn!("Failed to kill PID {}: {}", pid, error_msg);
425                    Ok(false)
426                }
427            }
428            Err(e) => {
429                error!("Failed to execute kill command for PID {}: {}", pid, e);
430                Err(format!("Failed to execute kill command: {}", e))
431            }
432        }
433    }
434
435    #[allow(dead_code)]
436    pub async fn is_process_running(&self, run_id: i64) -> Result<bool, String> {
437        let processes = self.processes.lock().await;
438
439        if let Some(handle) = processes.get(&run_id) {
440            let child_arc = handle.child.clone();
441            drop(processes);
442
443            let mut child_guard = child_arc.lock().map_err(|e| e.to_string())?;
444            if let Some(ref mut child) = child_guard.as_mut() {
445                match child.try_wait() {
446                    Ok(Some(_)) => {
447                        *child_guard = None;
448                        Ok(false)
449                    }
450                    Ok(None) => Ok(true),
451                    Err(_) => {
452                        *child_guard = None;
453                        Ok(false)
454                    }
455                }
456            } else {
457                Ok(false)
458            }
459        } else {
460            Ok(false)
461        }
462    }
463
464    pub async fn append_live_output(&self, run_id: i64, output: &str) -> Result<(), String> {
465        let processes = self.processes.lock().await;
466        if let Some(handle) = processes.get(&run_id) {
467            let mut live_output = handle.live_output.lock().map_err(|e| e.to_string())?;
468            live_output.push_str(output);
469            live_output.push('\n');
470        }
471        Ok(())
472    }
473
474    pub async fn get_live_output(&self, run_id: i64) -> Result<String, String> {
475        let processes = self.processes.lock().await;
476        if let Some(handle) = processes.get(&run_id) {
477            let live_output = handle.live_output.lock().map_err(|e| e.to_string())?;
478            Ok(live_output.clone())
479        } else {
480            Ok(String::new())
481        }
482    }
483
484    #[allow(dead_code)]
485    pub async fn cleanup_finished_processes(&self) -> Result<Vec<i64>, String> {
486        let mut finished_runs = Vec::new();
487
488        {
489            let processes = self.processes.lock().await;
490            let run_ids: Vec<i64> = processes.keys().cloned().collect();
491            drop(processes);
492
493            for run_id in run_ids {
494                if !self.is_process_running(run_id).await? {
495                    finished_runs.push(run_id);
496                }
497            }
498        }
499
500        {
501            let mut processes = self.processes.lock().await;
502            for run_id in &finished_runs {
503                processes.remove(run_id);
504            }
505        }
506
507        Ok(finished_runs)
508    }
509}
510
511impl Default for ProcessRegistry {
512    fn default() -> Self {
513        Self::new()
514    }
515}
516
#[cfg(test)]
mod tests {
    use super::*;

    /// Appended chunks come back in order, each terminated by a newline.
    #[tokio::test]
    async fn test_append_and_get_live_output() {
        let registry = ProcessRegistry::new();

        // A session without a child handle is enough to exercise the buffer.
        let no_child = Arc::new(Mutex::new(None));
        let run_id = registry
            .register_claude_session(
                String::from("session-1"),
                1234,
                String::from("/tmp/project"),
                String::from("task"),
                String::from("model"),
                no_child,
            )
            .await
            .unwrap();

        for line in ["line1", "line2"] {
            registry.append_live_output(run_id, line).await.unwrap();
        }

        let output = registry.get_live_output(run_id).await.unwrap();
        assert_eq!(output, "line1\nline2\n");
    }
}