Skip to main content

aft/bash_background/
registry.rs

1use std::collections::{HashMap, VecDeque};
2use std::fs;
3use std::path::{Path, PathBuf};
4use std::process::{Child, Command, Stdio};
5use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
6use std::sync::{Arc, Mutex};
7use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
8
9use serde::Serialize;
10
11#[cfg(unix)]
12use std::os::unix::process::CommandExt;
13
14use super::buffer::BgBuffer;
15use super::persistence::{
16    create_capture_file, read_exit_marker, read_task, session_tasks_dir, task_paths, unix_millis,
17    update_task, write_kill_marker_if_absent, write_task, ExitMarker, PersistedTask, TaskPaths,
18};
19use super::process::terminate_pgid;
20use super::{BgTaskInfo, BgTaskStatus};
21
/// Lifetime given to a background bash task when the caller passes no `timeout`
/// (callers express it in ms): thirty minutes.
const DEFAULT_BG_TIMEOUT: Duration = Duration::from_secs(60 * 30);
/// Age past which a task still persisted as `Running` is treated as orphaned when
/// its session is replayed: one day.
const STALE_RUNNING_AFTER: Duration = Duration::from_secs(60 * 60 * 24);
26
/// Notification that a background task reached a terminal state.
///
/// Queued by the registry and handed out by
/// `BgTaskRegistry::drain_completions` / `drain_completions_for_session`.
#[derive(Debug, Clone, Serialize)]
pub struct BgCompletion {
    /// Id of the task that finished.
    pub task_id: String,
    /// Owning session; used only to route the completion, never serialized.
    #[serde(skip_serializing)]
    pub session_id: String,
    /// Terminal status the task ended in.
    pub status: BgTaskStatus,
    /// Exit code when one is known (absent, e.g., for killed tasks).
    pub exit_code: Option<i32>,
    /// The shell command that was run.
    pub command: String,
}
36
/// Point-in-time view of a background task, as returned by the registry's
/// `status`, `list` and `kill` operations.
#[derive(Debug, Clone, Serialize)]
pub struct BgTaskSnapshot {
    /// Id, status, command, start time and duration; flattened into the
    /// top-level serialized object.
    #[serde(flatten)]
    pub info: BgTaskInfo,
    /// Exit code of the command, when known.
    pub exit_code: Option<i32>,
    /// Pid of the spawned shell, if one was recorded.
    pub child_pid: Option<u32>,
    /// Working directory the command was started in, as a display string.
    pub workdir: String,
    /// Tail of the captured output, limited to the requested preview size.
    pub output_preview: String,
    /// Whether `output_preview` omits earlier output (as reported by
    /// `BgBuffer::read_tail` — TODO confirm exact semantics there).
    pub output_truncated: bool,
    /// Path of the captured output, if the buffer exposes one.
    pub output_path: Option<String>,
}
48
/// Registry of background bash tasks.
///
/// Cloning is cheap: every clone shares the same state through an `Arc`.
#[derive(Clone)]
pub struct BgTaskRegistry {
    pub(crate) inner: Arc<RegistryInner>,
}
53
/// State shared by all clones of a `BgTaskRegistry`.
pub(crate) struct RegistryInner {
    /// Tracked tasks keyed by task id (ids are unique across sessions).
    pub(crate) tasks: Mutex<HashMap<String, Arc<BgTask>>>,
    /// Completions waiting to be drained, in the order they were queued.
    pub(crate) completions: Mutex<VecDeque<BgCompletion>>,
    /// Flipped on the first watchdog start so the watchdog is started at most once.
    watchdog_started: AtomicBool,
    /// Set by `BgTaskRegistry::detach`; presumably polled by the watchdog to stop
    /// — confirm in `watchdog.rs`.
    pub(crate) shutdown: AtomicBool,
}
60
/// One tracked background task.
pub(crate) struct BgTask {
    pub(crate) task_id: String,
    pub(crate) session_id: String,
    /// Locations of the task directory, metadata JSON, capture files and exit marker.
    pub(crate) paths: TaskPaths,
    /// When this registry began tracking the task (after spawn, or at rehydration
    /// during a replay).
    pub(crate) started: Instant,
    /// Mutable state; lock it before reading or updating metadata or the child.
    pub(crate) state: Mutex<BgTaskState>,
}
68
/// Mutable per-task state guarded by `BgTask::state`.
pub(crate) struct BgTaskState {
    /// In-memory copy of the task's persisted JSON metadata.
    pub(crate) metadata: PersistedTask,
    /// Handle to the spawned shell while this process still owns it; cleared once
    /// the task is reaped, killed, finalized or detached, and never set for
    /// rehydrated tasks.
    pub(crate) child: Option<Child>,
    /// `true` once the child handle has been released (or, for rehydrated tasks,
    /// never existed).
    pub(crate) detached: bool,
    /// Reader over the task's stdout/stderr capture files.
    pub(crate) buffer: BgBuffer,
}
75
76impl BgTaskRegistry {
77    pub fn new() -> Self {
78        Self {
79            inner: Arc::new(RegistryInner {
80                tasks: Mutex::new(HashMap::new()),
81                completions: Mutex::new(VecDeque::new()),
82                watchdog_started: AtomicBool::new(false),
83                shutdown: AtomicBool::new(false),
84            }),
85        }
86    }
87
88    #[cfg(unix)]
89    pub fn spawn(
90        &self,
91        command: &str,
92        session_id: String,
93        workdir: PathBuf,
94        env: HashMap<String, String>,
95        timeout: Option<Duration>,
96        storage_dir: PathBuf,
97        max_running: usize,
98    ) -> Result<String, String> {
99        self.start_watchdog();
100
101        let running = self.running_count();
102        if running >= max_running {
103            return Err(format!(
104                "background bash task limit exceeded: {running} running (max {max_running})"
105            ));
106        }
107
108        let timeout = timeout.or(Some(DEFAULT_BG_TIMEOUT));
109        let timeout_ms = timeout.map(|timeout| timeout.as_millis() as u64);
110        let task_id = self.generate_unique_task_id()?;
111        let paths = task_paths(&storage_dir, &session_id, &task_id);
112        fs::create_dir_all(&paths.dir)
113            .map_err(|e| format!("failed to create background task dir: {e}"))?;
114
115        let mut metadata = PersistedTask::starting(
116            task_id.clone(),
117            session_id.clone(),
118            command.to_string(),
119            workdir.clone(),
120            timeout_ms,
121        );
122        write_task(&paths.json, &metadata)
123            .map_err(|e| format!("failed to persist background task metadata: {e}"))?;
124
125        let stdout = create_capture_file(&paths.stdout)
126            .map_err(|e| format!("failed to create stdout capture file: {e}"))?;
127        let stderr = create_capture_file(&paths.stderr)
128            .map_err(|e| format!("failed to create stderr capture file: {e}"))?;
129
130        let child = detached_shell_command(command, &paths.exit)
131            .current_dir(&workdir)
132            .envs(&env)
133            .stdin(Stdio::null())
134            .stdout(Stdio::from(stdout))
135            .stderr(Stdio::from(stderr))
136            .spawn()
137            .map_err(|e| format!("failed to spawn background bash command: {e}"))?;
138
139        let child_pid = child.id();
140        metadata.mark_running(child_pid, child_pid as i32);
141        write_task(&paths.json, &metadata)
142            .map_err(|e| format!("failed to persist running background task metadata: {e}"))?;
143
144        let task = Arc::new(BgTask {
145            task_id: task_id.clone(),
146            session_id,
147            paths: paths.clone(),
148            started: Instant::now(),
149            state: Mutex::new(BgTaskState {
150                metadata,
151                child: Some(child),
152                detached: false,
153                buffer: BgBuffer::new(paths.stdout.clone(), paths.stderr.clone()),
154            }),
155        });
156
157        self.inner
158            .tasks
159            .lock()
160            .map_err(|_| "background task registry lock poisoned".to_string())?
161            .insert(task_id.clone(), task);
162
163        Ok(task_id)
164    }
165
166    #[cfg(windows)]
167    pub fn spawn(
168        &self,
169        _command: &str,
170        _session_id: String,
171        _workdir: PathBuf,
172        _env: HashMap<String, String>,
173        _timeout: Option<Duration>,
174        _storage_dir: PathBuf,
175        _max_running: usize,
176    ) -> Result<String, String> {
177        Err("background bash is not yet supported on Windows".to_string())
178    }
179
180    pub fn replay_session(&self, storage_dir: &Path, session_id: &str) -> Result<(), String> {
181        self.start_watchdog();
182        let dir = session_tasks_dir(storage_dir, session_id);
183        if !dir.exists() {
184            return Ok(());
185        }
186
187        let entries = fs::read_dir(&dir)
188            .map_err(|e| format!("failed to read background task dir {}: {e}", dir.display()))?;
189        for entry in entries.flatten() {
190            let path = entry.path();
191            if path.extension().and_then(|extension| extension.to_str()) != Some("json") {
192                continue;
193            }
194            let Ok(mut metadata) = read_task(&path) else {
195                continue;
196            };
197            if metadata.session_id != session_id {
198                continue;
199            }
200
201            let paths = task_paths(storage_dir, session_id, &metadata.task_id);
202            match metadata.status {
203                BgTaskStatus::Starting => {
204                    metadata.mark_terminal(
205                        BgTaskStatus::Failed,
206                        None,
207                        Some("spawn aborted".to_string()),
208                    );
209                    let _ = write_task(&paths.json, &metadata);
210                    self.enqueue_completion_if_needed(&metadata);
211                }
212                BgTaskStatus::Running => {
213                    if self.running_metadata_is_stale(&metadata) {
214                        metadata.mark_terminal(
215                            BgTaskStatus::Killed,
216                            None,
217                            Some("orphaned (>24h)".to_string()),
218                        );
219                        if !paths.exit.exists() {
220                            let _ = write_kill_marker_if_absent(&paths.exit);
221                        }
222                        let _ = write_task(&paths.json, &metadata);
223                        self.enqueue_completion_if_needed(&metadata);
224                    } else if let Ok(Some(marker)) = read_exit_marker(&paths.exit) {
225                        metadata = terminal_metadata_from_marker(metadata, marker, None);
226                        let _ = write_task(&paths.json, &metadata);
227                        self.enqueue_completion_if_needed(&metadata);
228                    } else {
229                        self.insert_rehydrated_task(metadata, paths, true)?;
230                    }
231                }
232                _ if metadata.status.is_terminal() => {
233                    self.insert_rehydrated_task(metadata.clone(), paths, true)?;
234                    self.enqueue_completion_if_needed(&metadata);
235                }
236                _ => {}
237            }
238        }
239
240        Ok(())
241    }
242
243    pub fn status(
244        &self,
245        task_id: &str,
246        session_id: &str,
247        preview_bytes: usize,
248    ) -> Option<BgTaskSnapshot> {
249        let task = self.task_for_session(task_id, session_id)?;
250        let _ = self.poll_task(&task);
251        Some(task.snapshot(preview_bytes))
252    }
253
254    pub fn list(&self, preview_bytes: usize) -> Vec<BgTaskSnapshot> {
255        let tasks = self
256            .inner
257            .tasks
258            .lock()
259            .map(|tasks| tasks.values().cloned().collect::<Vec<_>>())
260            .unwrap_or_default();
261        tasks
262            .into_iter()
263            .map(|task| {
264                let _ = self.poll_task(&task);
265                task.snapshot(preview_bytes)
266            })
267            .collect()
268    }
269
270    pub fn kill(&self, task_id: &str, session_id: &str) -> Result<BgTaskSnapshot, String> {
271        self.kill_with_status(task_id, session_id, BgTaskStatus::Killed)
272    }
273
274    pub(crate) fn kill_for_timeout(&self, task_id: &str, session_id: &str) -> Result<(), String> {
275        self.kill_with_status(task_id, session_id, BgTaskStatus::TimedOut)
276            .map(|_| ())
277    }
278
279    pub fn cleanup_finished(&self, _older_than: Duration) {}
280
281    pub fn drain_completions(&self) -> Vec<BgCompletion> {
282        self.drain_completions_for_session(None)
283    }
284
285    pub fn drain_completions_for_session(&self, session_id: Option<&str>) -> Vec<BgCompletion> {
286        let mut completions = match self.inner.completions.lock() {
287            Ok(completions) => completions,
288            Err(_) => return Vec::new(),
289        };
290
291        let drained = if let Some(session_id) = session_id {
292            let mut matched = Vec::new();
293            let mut retained = VecDeque::new();
294            while let Some(completion) = completions.pop_front() {
295                if completion.session_id == session_id {
296                    matched.push(completion);
297                } else {
298                    retained.push_back(completion);
299                }
300            }
301            *completions = retained;
302            matched
303        } else {
304            completions.drain(..).collect()
305        };
306        drop(completions);
307
308        for completion in &drained {
309            if let Some(task) = self.task_for_session(&completion.task_id, &completion.session_id) {
310                let _ = task.set_completion_delivered(true);
311            }
312        }
313
314        drained
315    }
316
317    pub fn pending_completions_for_session(&self, session_id: &str) -> Vec<BgCompletion> {
318        self.inner
319            .completions
320            .lock()
321            .map(|completions| {
322                completions
323                    .iter()
324                    .filter(|completion| completion.session_id == session_id)
325                    .cloned()
326                    .collect()
327            })
328            .unwrap_or_default()
329    }
330
331    pub fn detach(&self) {
332        self.inner.shutdown.store(true, Ordering::SeqCst);
333        if let Ok(mut tasks) = self.inner.tasks.lock() {
334            for task in tasks.values() {
335                if let Ok(mut state) = task.state.lock() {
336                    state.child = None;
337                    state.detached = true;
338                }
339            }
340            tasks.clear();
341        }
342    }
343
344    pub fn shutdown(&self) {
345        let tasks = self
346            .inner
347            .tasks
348            .lock()
349            .map(|tasks| {
350                tasks
351                    .values()
352                    .map(|task| (task.task_id.clone(), task.session_id.clone()))
353                    .collect::<Vec<_>>()
354            })
355            .unwrap_or_default();
356        for (task_id, session_id) in tasks {
357            let _ = self.kill(&task_id, &session_id);
358        }
359    }
360
361    pub(crate) fn poll_task(&self, task: &Arc<BgTask>) -> Result<(), String> {
362        let marker = match read_exit_marker(&task.paths.exit) {
363            Ok(Some(marker)) => marker,
364            Ok(None) => return Ok(()),
365            Err(error) => return Err(format!("failed to read exit marker: {error}")),
366        };
367        self.finalize_from_marker(task, marker, None)
368    }
369
370    pub(crate) fn reap_child(&self, task: &Arc<BgTask>) {
371        let Ok(mut state) = task.state.lock() else {
372            return;
373        };
374        if let Some(child) = state.child.as_mut() {
375            if matches!(child.try_wait(), Ok(Some(_))) {
376                state.child = None;
377                state.detached = true;
378            }
379        }
380    }
381
382    pub(crate) fn running_tasks(&self) -> Vec<Arc<BgTask>> {
383        self.inner
384            .tasks
385            .lock()
386            .map(|tasks| {
387                tasks
388                    .values()
389                    .filter(|task| task.is_running())
390                    .cloned()
391                    .collect()
392            })
393            .unwrap_or_default()
394    }
395
396    fn insert_rehydrated_task(
397        &self,
398        metadata: PersistedTask,
399        paths: TaskPaths,
400        detached: bool,
401    ) -> Result<(), String> {
402        let task_id = metadata.task_id.clone();
403        let session_id = metadata.session_id.clone();
404        let task = Arc::new(BgTask {
405            task_id: task_id.clone(),
406            session_id,
407            paths: paths.clone(),
408            started: Instant::now(),
409            state: Mutex::new(BgTaskState {
410                metadata,
411                child: None,
412                detached,
413                buffer: BgBuffer::new(paths.stdout.clone(), paths.stderr.clone()),
414            }),
415        });
416        self.inner
417            .tasks
418            .lock()
419            .map_err(|_| "background task registry lock poisoned".to_string())?
420            .insert(task_id, task);
421        Ok(())
422    }
423
424    fn kill_with_status(
425        &self,
426        task_id: &str,
427        session_id: &str,
428        terminal_status: BgTaskStatus,
429    ) -> Result<BgTaskSnapshot, String> {
430        let task = self
431            .task_for_session(task_id, session_id)
432            .ok_or_else(|| format!("background task not found: {task_id}"))?;
433
434        {
435            let mut state = task
436                .state
437                .lock()
438                .map_err(|_| "background task lock poisoned".to_string())?;
439            if state.metadata.status.is_terminal() {
440                return Ok(task.snapshot_locked(&state, 5 * 1024));
441            }
442
443            state.metadata.status = BgTaskStatus::Killing;
444            write_task(&task.paths.json, &state.metadata)
445                .map_err(|e| format!("failed to persist killing state: {e}"))?;
446
447            if let Some(pgid) = state.metadata.pgid {
448                terminate_pgid(pgid, state.child.as_mut());
449            }
450            if let Some(child) = state.child.as_mut() {
451                let _ = child.wait();
452            }
453            state.child = None;
454            state.detached = true;
455
456            if !task.paths.exit.exists() {
457                write_kill_marker_if_absent(&task.paths.exit)
458                    .map_err(|e| format!("failed to write kill marker: {e}"))?;
459            }
460
461            let exit_code = if terminal_status == BgTaskStatus::TimedOut {
462                Some(124)
463            } else {
464                None
465            };
466            state
467                .metadata
468                .mark_terminal(terminal_status, exit_code, None);
469            write_task(&task.paths.json, &state.metadata)
470                .map_err(|e| format!("failed to persist killed state: {e}"))?;
471            state.buffer.enforce_terminal_cap();
472            self.enqueue_completion_locked(&state.metadata);
473        }
474
475        Ok(task.snapshot(5 * 1024))
476    }
477
478    fn finalize_from_marker(
479        &self,
480        task: &Arc<BgTask>,
481        marker: ExitMarker,
482        reason: Option<String>,
483    ) -> Result<(), String> {
484        let mut state = task
485            .state
486            .lock()
487            .map_err(|_| "background task lock poisoned".to_string())?;
488        if state.metadata.status.is_terminal() {
489            return Ok(());
490        }
491
492        let updated = update_task(&task.paths.json, |metadata| {
493            let new_metadata = terminal_metadata_from_marker(metadata.clone(), marker, reason);
494            *metadata = new_metadata;
495        })
496        .map_err(|e| format!("failed to persist terminal state: {e}"))?;
497        state.metadata = updated;
498        state.child = None;
499        state.detached = true;
500        state.buffer.enforce_terminal_cap();
501        self.enqueue_completion_locked(&state.metadata);
502        Ok(())
503    }
504
505    fn enqueue_completion_if_needed(&self, metadata: &PersistedTask) {
506        if metadata.status.is_terminal() && !metadata.completion_delivered {
507            self.enqueue_completion_locked(metadata);
508        }
509    }
510
511    fn enqueue_completion_locked(&self, metadata: &PersistedTask) {
512        if !metadata.status.is_terminal() || metadata.completion_delivered {
513            return;
514        }
515        if let Ok(mut completions) = self.inner.completions.lock() {
516            if completions
517                .iter()
518                .any(|completion| completion.task_id == metadata.task_id)
519            {
520                return;
521            }
522            completions.push_back(BgCompletion {
523                task_id: metadata.task_id.clone(),
524                session_id: metadata.session_id.clone(),
525                status: metadata.status.clone(),
526                exit_code: metadata.exit_code,
527                command: metadata.command.clone(),
528            });
529        }
530    }
531
532    fn task(&self, task_id: &str) -> Option<Arc<BgTask>> {
533        self.inner
534            .tasks
535            .lock()
536            .ok()
537            .and_then(|tasks| tasks.get(task_id).cloned())
538    }
539
540    fn task_for_session(&self, task_id: &str, session_id: &str) -> Option<Arc<BgTask>> {
541        self.task(task_id)
542            .filter(|task| task.session_id == session_id)
543    }
544
545    fn running_count(&self) -> usize {
546        self.inner
547            .tasks
548            .lock()
549            .map(|tasks| tasks.values().filter(|task| task.is_running()).count())
550            .unwrap_or(0)
551    }
552
553    fn start_watchdog(&self) {
554        if !self.inner.watchdog_started.swap(true, Ordering::SeqCst) {
555            super::watchdog::start(self.clone());
556        }
557    }
558
559    fn running_metadata_is_stale(&self, metadata: &PersistedTask) -> bool {
560        unix_millis().saturating_sub(metadata.started_at) > STALE_RUNNING_AFTER.as_millis() as u64
561    }
562
563    #[cfg(test)]
564    pub fn task_json_path(&self, task_id: &str, session_id: &str) -> Option<PathBuf> {
565        self.task_for_session(task_id, session_id)
566            .map(|task| task.paths.json.clone())
567    }
568
569    #[cfg(test)]
570    pub fn task_exit_path(&self, task_id: &str, session_id: &str) -> Option<PathBuf> {
571        self.task_for_session(task_id, session_id)
572            .map(|task| task.paths.exit.clone())
573    }
574
575    /// Generate a `bgb-{8hex}` slug that is unique against live tasks and queued completions.
576    fn generate_unique_task_id(&self) -> Result<String, String> {
577        for _ in 0..32 {
578            let candidate = random_slug();
579            let tasks = self
580                .inner
581                .tasks
582                .lock()
583                .map_err(|_| "background task registry lock poisoned".to_string())?;
584            if tasks.contains_key(&candidate) {
585                continue;
586            }
587            let completions = self
588                .inner
589                .completions
590                .lock()
591                .map_err(|_| "background completions lock poisoned".to_string())?;
592            if completions
593                .iter()
594                .any(|completion| completion.task_id == candidate)
595            {
596                continue;
597            }
598            return Ok(candidate);
599        }
600        Err("failed to allocate unique background task id after 32 attempts".to_string())
601    }
602}
603
604impl Default for BgTaskRegistry {
605    fn default() -> Self {
606        Self::new()
607    }
608}
609
610impl BgTask {
611    fn snapshot(&self, preview_bytes: usize) -> BgTaskSnapshot {
612        let state = self
613            .state
614            .lock()
615            .unwrap_or_else(|poison| poison.into_inner());
616        self.snapshot_locked(&state, preview_bytes)
617    }
618
619    fn snapshot_locked(&self, state: &BgTaskState, preview_bytes: usize) -> BgTaskSnapshot {
620        let metadata = &state.metadata;
621        let duration_ms = metadata.duration_ms.or_else(|| {
622            metadata
623                .status
624                .is_terminal()
625                .then(|| self.started.elapsed().as_millis() as u64)
626        });
627        let (output_preview, output_truncated) = state.buffer.read_tail(preview_bytes);
628        BgTaskSnapshot {
629            info: BgTaskInfo {
630                task_id: self.task_id.clone(),
631                status: metadata.status.clone(),
632                command: metadata.command.clone(),
633                started_at: metadata.started_at,
634                duration_ms,
635            },
636            exit_code: metadata.exit_code,
637            child_pid: metadata.child_pid,
638            workdir: metadata.workdir.display().to_string(),
639            output_preview,
640            output_truncated,
641            output_path: state
642                .buffer
643                .output_path()
644                .map(|path| path.display().to_string()),
645        }
646    }
647
648    pub(crate) fn is_running(&self) -> bool {
649        self.state
650            .lock()
651            .map(|state| state.metadata.status == BgTaskStatus::Running)
652            .unwrap_or(false)
653    }
654
655    fn set_completion_delivered(&self, delivered: bool) -> Result<(), String> {
656        let mut state = self
657            .state
658            .lock()
659            .map_err(|_| "background task lock poisoned".to_string())?;
660        let updated = update_task(&self.paths.json, |metadata| {
661            metadata.completion_delivered = delivered;
662        })
663        .map_err(|e| format!("failed to update completion delivery: {e}"))?;
664        state.metadata = updated;
665        Ok(())
666    }
667}
668
669fn terminal_metadata_from_marker(
670    mut metadata: PersistedTask,
671    marker: ExitMarker,
672    reason: Option<String>,
673) -> PersistedTask {
674    match marker {
675        ExitMarker::Code(code) => {
676            let status = if code == 0 {
677                BgTaskStatus::Completed
678            } else {
679                BgTaskStatus::Failed
680            };
681            metadata.mark_terminal(status, Some(code), reason);
682        }
683        ExitMarker::Killed => metadata.mark_terminal(BgTaskStatus::Killed, None, reason),
684    }
685    metadata
686}
687
688#[cfg(unix)]
689fn detached_shell_command(command: &str, exit_path: &Path) -> Command {
690    let mut cmd = Command::new("/bin/sh");
691    cmd.arg("-c")
692        .arg("\"$0\" -c \"$1\"; code=$?; printf \"%s\" \"$code\" > \"$2.tmp.$$\"; mv -f \"$2.tmp.$$\" \"$2\"")
693        .arg("/bin/sh")
694        .arg(command)
695        .arg(exit_path);
696    unsafe {
697        cmd.pre_exec(|| {
698            if libc::setsid() == -1 {
699                return Err(std::io::Error::last_os_error());
700            }
701            Ok(())
702        });
703    }
704    cmd
705}
706
/// Produces a candidate task id of the form `bgb-XXXXXXXX` (8 lowercase hex digits).
///
/// The current time in nanoseconds, the process id and a per-process counter are
/// mixed, and the low 32 bits are kept. Uniqueness is not guaranteed; callers
/// check for collisions.
fn random_slug() -> String {
    static COUNTER: AtomicU64 = AtomicU64::new(0);
    let sequence = u128::from(COUNTER.fetch_add(1, Ordering::Relaxed));
    let pid = u128::from(std::process::id());
    let seed = unix_millis_nanos()
        ^ pid.wrapping_mul(0x9E3779B97F4A7C15)
        ^ sequence.wrapping_mul(0xBF58476D1CE4E5B9);
    let low_bits = seed as u32;
    format!("bgb-{low_bits:08x}")
}

/// Nanoseconds since the Unix epoch, or 0 if the clock reads before the epoch.
fn unix_millis_nanos() -> u128 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos(),
        Err(_) => 0,
    }
}