Skip to main content

aft/bash_background/
registry.rs

1use std::collections::{HashMap, VecDeque};
2use std::fs;
3use std::path::{Path, PathBuf};
4use std::process::{Child, Command, Stdio};
5use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
6use std::sync::{Arc, Mutex};
7use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
8
9use serde::Serialize;
10
11#[cfg(unix)]
12use std::os::unix::process::CommandExt;
13
14use super::buffer::BgBuffer;
15use super::persistence::{
16    create_capture_file, read_exit_marker, read_task, session_tasks_dir, task_paths, unix_millis,
17    update_task, write_kill_marker_if_absent, write_task, ExitMarker, PersistedTask, TaskPaths,
18};
19#[cfg(unix)]
20use super::process::terminate_pgid;
21use super::{BgTaskInfo, BgTaskStatus};
22
/// Default timeout for background bash tasks: 30 minutes.
/// Agents can override per-call via the `timeout` parameter (in ms).
const DEFAULT_BG_TIMEOUT: Duration = Duration::from_secs(30 * 60);
/// Age (24 hours, measured from the persisted `started_at`) beyond which a task
/// still recorded as `Running` is treated as orphaned when a session is replayed.
const STALE_RUNNING_AFTER: Duration = Duration::from_secs(24 * 60 * 60);
27
/// Notice that a background task reached a terminal status.
///
/// Instances are queued in the registry's completion queue and handed out by the
/// `drain_completions*` / `pending_completions_for_session` methods.
#[derive(Debug, Clone, Serialize)]
pub struct BgCompletion {
    /// Identifier of the finished task.
    pub task_id: String,
    /// Session that owns the task; used for filtering and left out of serialized output.
    #[serde(skip_serializing)]
    pub session_id: String,
    /// Terminal status copied from the task metadata when the completion was queued.
    pub status: BgTaskStatus,
    /// Exit code copied from the task metadata, if one was recorded.
    pub exit_code: Option<i32>,
    /// Command line the task was started with.
    pub command: String,
}
37
/// Point-in-time view of a background task, as returned by `status`, `list` and `kill`.
#[derive(Debug, Clone, Serialize)]
pub struct BgTaskSnapshot {
    /// Core task info; its fields are flattened into this struct when serialized.
    #[serde(flatten)]
    pub info: BgTaskInfo,
    /// Exit code from the task metadata, if one was recorded.
    pub exit_code: Option<i32>,
    /// PID from the task metadata, if one was recorded.
    pub child_pid: Option<u32>,
    /// Working directory of the task, rendered with `Path::display`.
    pub workdir: String,
    /// Output text returned by the task buffer's `read_tail` for the requested preview size.
    pub output_preview: String,
    /// Truncation flag returned by `read_tail` alongside the preview.
    pub output_truncated: bool,
    /// Path reported by the task buffer's `output_path`, if any.
    pub output_path: Option<String>,
}
49
/// Handle to the registry of background bash tasks.
///
/// Cloning yields another handle to the same shared state, because the state
/// lives behind an `Arc`.
#[derive(Clone)]
pub struct BgTaskRegistry {
    /// Shared registry state.
    pub(crate) inner: Arc<RegistryInner>,
}
54
/// State shared by every clone of `BgTaskRegistry`.
pub(crate) struct RegistryInner {
    /// Known tasks, keyed by task id.
    pub(crate) tasks: Mutex<HashMap<String, Arc<BgTask>>>,
    /// Completions queued for delivery, oldest first.
    pub(crate) completions: Mutex<VecDeque<BgCompletion>>,
    /// Set by the first `start_watchdog` call so the watchdog is started only once.
    watchdog_started: AtomicBool,
    /// Set to `true` by `detach`.
    pub(crate) shutdown: AtomicBool,
}
61
/// One background task tracked by the registry.
pub(crate) struct BgTask {
    /// Identifier of the task (also its key in the registry map).
    pub(crate) task_id: String,
    /// Session that owns the task.
    pub(crate) session_id: String,
    /// On-disk locations (metadata JSON, exit marker, stdout/stderr captures, ...) for this task.
    pub(crate) paths: TaskPaths,
    /// Instant at which this in-memory entry was created (at spawn or at rehydration).
    pub(crate) started: Instant,
    /// Mutable per-task state, guarded by its own lock.
    pub(crate) state: Mutex<BgTaskState>,
}
69
/// Mutable part of a `BgTask`, accessed through `BgTask::state`.
pub(crate) struct BgTaskState {
    /// In-memory copy of the persisted task metadata.
    pub(crate) metadata: PersistedTask,
    /// Handle to the spawned process; `None` for rehydrated tasks and once the
    /// handle has been released (kill, detach, reap, or finalization).
    pub(crate) child: Option<Child>,
    /// `false` for freshly spawned tasks; set to `true` whenever the child handle
    /// is released, and passed as `true` for tasks rehydrated by `replay_session`.
    pub(crate) detached: bool,
    /// Reader over the task's captured stdout/stderr files.
    pub(crate) buffer: BgBuffer,
}
76
77impl BgTaskRegistry {
78    pub fn new() -> Self {
79        Self {
80            inner: Arc::new(RegistryInner {
81                tasks: Mutex::new(HashMap::new()),
82                completions: Mutex::new(VecDeque::new()),
83                watchdog_started: AtomicBool::new(false),
84                shutdown: AtomicBool::new(false),
85            }),
86        }
87    }
88
89    #[cfg(unix)]
90    pub fn spawn(
91        &self,
92        command: &str,
93        session_id: String,
94        workdir: PathBuf,
95        env: HashMap<String, String>,
96        timeout: Option<Duration>,
97        storage_dir: PathBuf,
98        max_running: usize,
99    ) -> Result<String, String> {
100        self.start_watchdog();
101
102        let running = self.running_count();
103        if running >= max_running {
104            return Err(format!(
105                "background bash task limit exceeded: {running} running (max {max_running})"
106            ));
107        }
108
109        let timeout = timeout.or(Some(DEFAULT_BG_TIMEOUT));
110        let timeout_ms = timeout.map(|timeout| timeout.as_millis() as u64);
111        let task_id = self.generate_unique_task_id()?;
112        let paths = task_paths(&storage_dir, &session_id, &task_id);
113        fs::create_dir_all(&paths.dir)
114            .map_err(|e| format!("failed to create background task dir: {e}"))?;
115
116        let mut metadata = PersistedTask::starting(
117            task_id.clone(),
118            session_id.clone(),
119            command.to_string(),
120            workdir.clone(),
121            timeout_ms,
122        );
123        write_task(&paths.json, &metadata)
124            .map_err(|e| format!("failed to persist background task metadata: {e}"))?;
125
126        let stdout = create_capture_file(&paths.stdout)
127            .map_err(|e| format!("failed to create stdout capture file: {e}"))?;
128        let stderr = create_capture_file(&paths.stderr)
129            .map_err(|e| format!("failed to create stderr capture file: {e}"))?;
130
131        let child = detached_shell_command(command, &paths.exit)
132            .current_dir(&workdir)
133            .envs(&env)
134            .stdin(Stdio::null())
135            .stdout(Stdio::from(stdout))
136            .stderr(Stdio::from(stderr))
137            .spawn()
138            .map_err(|e| format!("failed to spawn background bash command: {e}"))?;
139
140        let child_pid = child.id();
141        metadata.mark_running(child_pid, child_pid as i32);
142        write_task(&paths.json, &metadata)
143            .map_err(|e| format!("failed to persist running background task metadata: {e}"))?;
144
145        let task = Arc::new(BgTask {
146            task_id: task_id.clone(),
147            session_id,
148            paths: paths.clone(),
149            started: Instant::now(),
150            state: Mutex::new(BgTaskState {
151                metadata,
152                child: Some(child),
153                detached: false,
154                buffer: BgBuffer::new(paths.stdout.clone(), paths.stderr.clone()),
155            }),
156        });
157
158        self.inner
159            .tasks
160            .lock()
161            .map_err(|_| "background task registry lock poisoned".to_string())?
162            .insert(task_id.clone(), task);
163
164        Ok(task_id)
165    }
166
167    #[cfg(windows)]
168    pub fn spawn(
169        &self,
170        _command: &str,
171        _session_id: String,
172        _workdir: PathBuf,
173        _env: HashMap<String, String>,
174        _timeout: Option<Duration>,
175        _storage_dir: PathBuf,
176        _max_running: usize,
177    ) -> Result<String, String> {
178        Err("background bash is not yet supported on Windows".to_string())
179    }
180
181    pub fn replay_session(&self, storage_dir: &Path, session_id: &str) -> Result<(), String> {
182        self.start_watchdog();
183        let dir = session_tasks_dir(storage_dir, session_id);
184        if !dir.exists() {
185            return Ok(());
186        }
187
188        let entries = fs::read_dir(&dir)
189            .map_err(|e| format!("failed to read background task dir {}: {e}", dir.display()))?;
190        for entry in entries.flatten() {
191            let path = entry.path();
192            if path.extension().and_then(|extension| extension.to_str()) != Some("json") {
193                continue;
194            }
195            let Ok(mut metadata) = read_task(&path) else {
196                continue;
197            };
198            if metadata.session_id != session_id {
199                continue;
200            }
201
202            let paths = task_paths(storage_dir, session_id, &metadata.task_id);
203            match metadata.status {
204                BgTaskStatus::Starting => {
205                    metadata.mark_terminal(
206                        BgTaskStatus::Failed,
207                        None,
208                        Some("spawn aborted".to_string()),
209                    );
210                    let _ = write_task(&paths.json, &metadata);
211                    self.enqueue_completion_if_needed(&metadata);
212                }
213                BgTaskStatus::Running => {
214                    if self.running_metadata_is_stale(&metadata) {
215                        metadata.mark_terminal(
216                            BgTaskStatus::Killed,
217                            None,
218                            Some("orphaned (>24h)".to_string()),
219                        );
220                        if !paths.exit.exists() {
221                            let _ = write_kill_marker_if_absent(&paths.exit);
222                        }
223                        let _ = write_task(&paths.json, &metadata);
224                        self.enqueue_completion_if_needed(&metadata);
225                    } else if let Ok(Some(marker)) = read_exit_marker(&paths.exit) {
226                        metadata = terminal_metadata_from_marker(metadata, marker, None);
227                        let _ = write_task(&paths.json, &metadata);
228                        self.enqueue_completion_if_needed(&metadata);
229                    } else {
230                        self.insert_rehydrated_task(metadata, paths, true)?;
231                    }
232                }
233                _ if metadata.status.is_terminal() => {
234                    self.insert_rehydrated_task(metadata.clone(), paths, true)?;
235                    self.enqueue_completion_if_needed(&metadata);
236                }
237                _ => {}
238            }
239        }
240
241        Ok(())
242    }
243
244    pub fn status(
245        &self,
246        task_id: &str,
247        session_id: &str,
248        preview_bytes: usize,
249    ) -> Option<BgTaskSnapshot> {
250        let task = self.task_for_session(task_id, session_id)?;
251        let _ = self.poll_task(&task);
252        Some(task.snapshot(preview_bytes))
253    }
254
255    pub fn list(&self, preview_bytes: usize) -> Vec<BgTaskSnapshot> {
256        let tasks = self
257            .inner
258            .tasks
259            .lock()
260            .map(|tasks| tasks.values().cloned().collect::<Vec<_>>())
261            .unwrap_or_default();
262        tasks
263            .into_iter()
264            .map(|task| {
265                let _ = self.poll_task(&task);
266                task.snapshot(preview_bytes)
267            })
268            .collect()
269    }
270
271    pub fn kill(&self, task_id: &str, session_id: &str) -> Result<BgTaskSnapshot, String> {
272        self.kill_with_status(task_id, session_id, BgTaskStatus::Killed)
273    }
274
275    pub(crate) fn kill_for_timeout(&self, task_id: &str, session_id: &str) -> Result<(), String> {
276        self.kill_with_status(task_id, session_id, BgTaskStatus::TimedOut)
277            .map(|_| ())
278    }
279
280    pub fn cleanup_finished(&self, _older_than: Duration) {}
281
282    pub fn drain_completions(&self) -> Vec<BgCompletion> {
283        self.drain_completions_for_session(None)
284    }
285
286    pub fn drain_completions_for_session(&self, session_id: Option<&str>) -> Vec<BgCompletion> {
287        let mut completions = match self.inner.completions.lock() {
288            Ok(completions) => completions,
289            Err(_) => return Vec::new(),
290        };
291
292        let drained = if let Some(session_id) = session_id {
293            let mut matched = Vec::new();
294            let mut retained = VecDeque::new();
295            while let Some(completion) = completions.pop_front() {
296                if completion.session_id == session_id {
297                    matched.push(completion);
298                } else {
299                    retained.push_back(completion);
300                }
301            }
302            *completions = retained;
303            matched
304        } else {
305            completions.drain(..).collect()
306        };
307        drop(completions);
308
309        for completion in &drained {
310            if let Some(task) = self.task_for_session(&completion.task_id, &completion.session_id) {
311                let _ = task.set_completion_delivered(true);
312            }
313        }
314
315        drained
316    }
317
318    pub fn pending_completions_for_session(&self, session_id: &str) -> Vec<BgCompletion> {
319        self.inner
320            .completions
321            .lock()
322            .map(|completions| {
323                completions
324                    .iter()
325                    .filter(|completion| completion.session_id == session_id)
326                    .cloned()
327                    .collect()
328            })
329            .unwrap_or_default()
330    }
331
332    pub fn detach(&self) {
333        self.inner.shutdown.store(true, Ordering::SeqCst);
334        if let Ok(mut tasks) = self.inner.tasks.lock() {
335            for task in tasks.values() {
336                if let Ok(mut state) = task.state.lock() {
337                    state.child = None;
338                    state.detached = true;
339                }
340            }
341            tasks.clear();
342        }
343    }
344
345    pub fn shutdown(&self) {
346        let tasks = self
347            .inner
348            .tasks
349            .lock()
350            .map(|tasks| {
351                tasks
352                    .values()
353                    .map(|task| (task.task_id.clone(), task.session_id.clone()))
354                    .collect::<Vec<_>>()
355            })
356            .unwrap_or_default();
357        for (task_id, session_id) in tasks {
358            let _ = self.kill(&task_id, &session_id);
359        }
360    }
361
362    pub(crate) fn poll_task(&self, task: &Arc<BgTask>) -> Result<(), String> {
363        let marker = match read_exit_marker(&task.paths.exit) {
364            Ok(Some(marker)) => marker,
365            Ok(None) => return Ok(()),
366            Err(error) => return Err(format!("failed to read exit marker: {error}")),
367        };
368        self.finalize_from_marker(task, marker, None)
369    }
370
371    pub(crate) fn reap_child(&self, task: &Arc<BgTask>) {
372        let Ok(mut state) = task.state.lock() else {
373            return;
374        };
375        if let Some(child) = state.child.as_mut() {
376            if matches!(child.try_wait(), Ok(Some(_))) {
377                state.child = None;
378                state.detached = true;
379            }
380        }
381    }
382
383    pub(crate) fn running_tasks(&self) -> Vec<Arc<BgTask>> {
384        self.inner
385            .tasks
386            .lock()
387            .map(|tasks| {
388                tasks
389                    .values()
390                    .filter(|task| task.is_running())
391                    .cloned()
392                    .collect()
393            })
394            .unwrap_or_default()
395    }
396
397    fn insert_rehydrated_task(
398        &self,
399        metadata: PersistedTask,
400        paths: TaskPaths,
401        detached: bool,
402    ) -> Result<(), String> {
403        let task_id = metadata.task_id.clone();
404        let session_id = metadata.session_id.clone();
405        let task = Arc::new(BgTask {
406            task_id: task_id.clone(),
407            session_id,
408            paths: paths.clone(),
409            started: Instant::now(),
410            state: Mutex::new(BgTaskState {
411                metadata,
412                child: None,
413                detached,
414                buffer: BgBuffer::new(paths.stdout.clone(), paths.stderr.clone()),
415            }),
416        });
417        self.inner
418            .tasks
419            .lock()
420            .map_err(|_| "background task registry lock poisoned".to_string())?
421            .insert(task_id, task);
422        Ok(())
423    }
424
425    fn kill_with_status(
426        &self,
427        task_id: &str,
428        session_id: &str,
429        terminal_status: BgTaskStatus,
430    ) -> Result<BgTaskSnapshot, String> {
431        let task = self
432            .task_for_session(task_id, session_id)
433            .ok_or_else(|| format!("background task not found: {task_id}"))?;
434
435        {
436            let mut state = task
437                .state
438                .lock()
439                .map_err(|_| "background task lock poisoned".to_string())?;
440            if state.metadata.status.is_terminal() {
441                return Ok(task.snapshot_locked(&state, 5 * 1024));
442            }
443
444            state.metadata.status = BgTaskStatus::Killing;
445            write_task(&task.paths.json, &state.metadata)
446                .map_err(|e| format!("failed to persist killing state: {e}"))?;
447
448            #[cfg(unix)]
449            if let Some(pgid) = state.metadata.pgid {
450                terminate_pgid(pgid, state.child.as_mut());
451            }
452            if let Some(child) = state.child.as_mut() {
453                let _ = child.wait();
454            }
455            state.child = None;
456            state.detached = true;
457
458            if !task.paths.exit.exists() {
459                write_kill_marker_if_absent(&task.paths.exit)
460                    .map_err(|e| format!("failed to write kill marker: {e}"))?;
461            }
462
463            let exit_code = if terminal_status == BgTaskStatus::TimedOut {
464                Some(124)
465            } else {
466                None
467            };
468            state
469                .metadata
470                .mark_terminal(terminal_status, exit_code, None);
471            write_task(&task.paths.json, &state.metadata)
472                .map_err(|e| format!("failed to persist killed state: {e}"))?;
473            state.buffer.enforce_terminal_cap();
474            self.enqueue_completion_locked(&state.metadata);
475        }
476
477        Ok(task.snapshot(5 * 1024))
478    }
479
480    fn finalize_from_marker(
481        &self,
482        task: &Arc<BgTask>,
483        marker: ExitMarker,
484        reason: Option<String>,
485    ) -> Result<(), String> {
486        let mut state = task
487            .state
488            .lock()
489            .map_err(|_| "background task lock poisoned".to_string())?;
490        if state.metadata.status.is_terminal() {
491            return Ok(());
492        }
493
494        let updated = update_task(&task.paths.json, |metadata| {
495            let new_metadata = terminal_metadata_from_marker(metadata.clone(), marker, reason);
496            *metadata = new_metadata;
497        })
498        .map_err(|e| format!("failed to persist terminal state: {e}"))?;
499        state.metadata = updated;
500        state.child = None;
501        state.detached = true;
502        state.buffer.enforce_terminal_cap();
503        self.enqueue_completion_locked(&state.metadata);
504        Ok(())
505    }
506
507    fn enqueue_completion_if_needed(&self, metadata: &PersistedTask) {
508        if metadata.status.is_terminal() && !metadata.completion_delivered {
509            self.enqueue_completion_locked(metadata);
510        }
511    }
512
513    fn enqueue_completion_locked(&self, metadata: &PersistedTask) {
514        if !metadata.status.is_terminal() || metadata.completion_delivered {
515            return;
516        }
517        if let Ok(mut completions) = self.inner.completions.lock() {
518            if completions
519                .iter()
520                .any(|completion| completion.task_id == metadata.task_id)
521            {
522                return;
523            }
524            completions.push_back(BgCompletion {
525                task_id: metadata.task_id.clone(),
526                session_id: metadata.session_id.clone(),
527                status: metadata.status.clone(),
528                exit_code: metadata.exit_code,
529                command: metadata.command.clone(),
530            });
531        }
532    }
533
534    fn task(&self, task_id: &str) -> Option<Arc<BgTask>> {
535        self.inner
536            .tasks
537            .lock()
538            .ok()
539            .and_then(|tasks| tasks.get(task_id).cloned())
540    }
541
542    fn task_for_session(&self, task_id: &str, session_id: &str) -> Option<Arc<BgTask>> {
543        self.task(task_id)
544            .filter(|task| task.session_id == session_id)
545    }
546
547    fn running_count(&self) -> usize {
548        self.inner
549            .tasks
550            .lock()
551            .map(|tasks| tasks.values().filter(|task| task.is_running()).count())
552            .unwrap_or(0)
553    }
554
555    fn start_watchdog(&self) {
556        if !self.inner.watchdog_started.swap(true, Ordering::SeqCst) {
557            super::watchdog::start(self.clone());
558        }
559    }
560
561    fn running_metadata_is_stale(&self, metadata: &PersistedTask) -> bool {
562        unix_millis().saturating_sub(metadata.started_at) > STALE_RUNNING_AFTER.as_millis() as u64
563    }
564
565    #[cfg(test)]
566    pub fn task_json_path(&self, task_id: &str, session_id: &str) -> Option<PathBuf> {
567        self.task_for_session(task_id, session_id)
568            .map(|task| task.paths.json.clone())
569    }
570
571    #[cfg(test)]
572    pub fn task_exit_path(&self, task_id: &str, session_id: &str) -> Option<PathBuf> {
573        self.task_for_session(task_id, session_id)
574            .map(|task| task.paths.exit.clone())
575    }
576
577    /// Generate a `bgb-{8hex}` slug that is unique against live tasks and queued completions.
578    fn generate_unique_task_id(&self) -> Result<String, String> {
579        for _ in 0..32 {
580            let candidate = random_slug();
581            let tasks = self
582                .inner
583                .tasks
584                .lock()
585                .map_err(|_| "background task registry lock poisoned".to_string())?;
586            if tasks.contains_key(&candidate) {
587                continue;
588            }
589            let completions = self
590                .inner
591                .completions
592                .lock()
593                .map_err(|_| "background completions lock poisoned".to_string())?;
594            if completions
595                .iter()
596                .any(|completion| completion.task_id == candidate)
597            {
598                continue;
599            }
600            return Ok(candidate);
601        }
602        Err("failed to allocate unique background task id after 32 attempts".to_string())
603    }
604}
605
606impl Default for BgTaskRegistry {
607    fn default() -> Self {
608        Self::new()
609    }
610}
611
612impl BgTask {
613    fn snapshot(&self, preview_bytes: usize) -> BgTaskSnapshot {
614        let state = self
615            .state
616            .lock()
617            .unwrap_or_else(|poison| poison.into_inner());
618        self.snapshot_locked(&state, preview_bytes)
619    }
620
621    fn snapshot_locked(&self, state: &BgTaskState, preview_bytes: usize) -> BgTaskSnapshot {
622        let metadata = &state.metadata;
623        let duration_ms = metadata.duration_ms.or_else(|| {
624            metadata
625                .status
626                .is_terminal()
627                .then(|| self.started.elapsed().as_millis() as u64)
628        });
629        let (output_preview, output_truncated) = state.buffer.read_tail(preview_bytes);
630        BgTaskSnapshot {
631            info: BgTaskInfo {
632                task_id: self.task_id.clone(),
633                status: metadata.status.clone(),
634                command: metadata.command.clone(),
635                started_at: metadata.started_at,
636                duration_ms,
637            },
638            exit_code: metadata.exit_code,
639            child_pid: metadata.child_pid,
640            workdir: metadata.workdir.display().to_string(),
641            output_preview,
642            output_truncated,
643            output_path: state
644                .buffer
645                .output_path()
646                .map(|path| path.display().to_string()),
647        }
648    }
649
650    pub(crate) fn is_running(&self) -> bool {
651        self.state
652            .lock()
653            .map(|state| state.metadata.status == BgTaskStatus::Running)
654            .unwrap_or(false)
655    }
656
657    fn set_completion_delivered(&self, delivered: bool) -> Result<(), String> {
658        let mut state = self
659            .state
660            .lock()
661            .map_err(|_| "background task lock poisoned".to_string())?;
662        let updated = update_task(&self.paths.json, |metadata| {
663            metadata.completion_delivered = delivered;
664        })
665        .map_err(|e| format!("failed to update completion delivery: {e}"))?;
666        state.metadata = updated;
667        Ok(())
668    }
669}
670
671fn terminal_metadata_from_marker(
672    mut metadata: PersistedTask,
673    marker: ExitMarker,
674    reason: Option<String>,
675) -> PersistedTask {
676    match marker {
677        ExitMarker::Code(code) => {
678            let status = if code == 0 {
679                BgTaskStatus::Completed
680            } else {
681                BgTaskStatus::Failed
682            };
683            metadata.mark_terminal(status, Some(code), reason);
684        }
685        ExitMarker::Killed => metadata.mark_terminal(BgTaskStatus::Killed, None, reason),
686    }
687    metadata
688}
689
690#[cfg(unix)]
691fn detached_shell_command(command: &str, exit_path: &Path) -> Command {
692    let mut cmd = Command::new("/bin/sh");
693    cmd.arg("-c")
694        .arg("\"$0\" -c \"$1\"; code=$?; printf \"%s\" \"$code\" > \"$2.tmp.$$\"; mv -f \"$2.tmp.$$\" \"$2\"")
695        .arg("/bin/sh")
696        .arg(command)
697        .arg(exit_path);
698    unsafe {
699        cmd.pre_exec(|| {
700            if libc::setsid() == -1 {
701                return Err(std::io::Error::last_os_error());
702            }
703            Ok(())
704        });
705    }
706    cmd
707}
708
709fn random_slug() -> String {
710    static COUNTER: AtomicU64 = AtomicU64::new(0);
711    let counter = COUNTER.fetch_add(1, Ordering::Relaxed);
712    let mixed = unix_millis_nanos()
713        ^ (std::process::id() as u128).wrapping_mul(0x9E3779B97F4A7C15)
714        ^ (counter as u128).wrapping_mul(0xBF58476D1CE4E5B9);
715    format!("bgb-{:08x}", (mixed as u32))
716}
717
/// Returns the time since the Unix epoch in nanoseconds (despite the "millis" in
/// the name), or 0 if the system clock reads earlier than the epoch.
fn unix_millis_nanos() -> u128 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos(),
        Err(_) => 0,
    }
}