//! Background bash task persistence (`aft/bash_background/persistence.rs`).

use std::fs::{self, File, OpenOptions};
use std::io::{self, Read, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};

use serde::{Deserialize, Serialize};

use crate::backup::hash_session;
use crate::db::bash_tasks::BashTaskRow;

use super::BgTaskStatus;

/// Schema version stamped into every task record written by this module.
///
/// `read_task` additionally accepts records written with versions 2 and 3.
pub const SCHEMA_VERSION: u32 = 4;

/// Locations of all files that make up one background task's on-disk bundle.
#[derive(Debug, Clone)]
pub struct TaskPaths {
    /// Directory holding the bundle (shared by every task of the session).
    pub dir: PathBuf,
    /// `<task-id>.json`: the serialized `PersistedTask`.
    pub json: PathBuf,
    /// `<task-id>.stdout`: stdout capture in pipe mode.
    pub stdout: PathBuf,
    /// `<task-id>.stderr`: stderr capture in pipe mode.
    pub stderr: PathBuf,
    /// `<task-id>.exit`: exit marker (an exit code or `killed`).
    pub exit: PathBuf,
    /// `<task-id>.pty`: capture file used in PTY mode.
    pub pty: PathBuf,
}

25#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
26#[serde(rename_all = "lowercase")]
27pub enum BgMode {
28    #[default]
29    Pipes,
30    Pty,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct PersistedTask {
35    pub schema_version: u32,
36    pub task_id: String,
37    pub session_id: String,
38    pub command: String,
39    #[serde(default)]
40    pub mode: BgMode,
41    pub workdir: PathBuf,
42    #[serde(default)]
43    pub project_root: Option<PathBuf>,
44    pub status: BgTaskStatus,
45    pub started_at: u64,
46    pub finished_at: Option<u64>,
47    pub duration_ms: Option<u64>,
48    pub timeout_ms: Option<u64>,
49    pub exit_code: Option<i32>,
50    pub child_pid: Option<u32>,
51    pub pgid: Option<i32>,
52    pub completion_delivered: bool,
53    #[serde(default = "default_notify_on_completion")]
54    pub notify_on_completion: bool,
55    /// Per-call output compression opt-in. Defaults to `true` so existing
56    /// behavior (compression when `experimental.bash.compress=true`) is
57    /// unchanged. Agents can pass `compressed: false` to disable compression
58    /// for a single bash call without flipping the global flag.
59    #[serde(default = "default_compressed")]
60    pub compressed: bool,
61    #[serde(default)]
62    pub pty_rows: Option<u16>,
63    #[serde(default)]
64    pub pty_cols: Option<u16>,
65    pub status_reason: Option<String>,
66}
67
/// Serde default for `PersistedTask::notify_on_completion`: records without the field
/// request completion notifications.
fn default_notify_on_completion() -> bool {
    true
}

/// Serde default for `PersistedTask::compressed`: records without the field keep output
/// compression enabled.
fn default_compressed() -> bool {
    true
}

/// Parsed contents of a task's `.exit` marker file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExitMarker {
    /// The marker held an integer exit code.
    Code(i32),
    /// The marker held the literal `killed`.
    Killed,
}

82impl PersistedTask {
83    pub fn starting(
84        task_id: String,
85        session_id: String,
86        command: String,
87        workdir: PathBuf,
88        project_root: Option<PathBuf>,
89        timeout_ms: Option<u64>,
90        notify_on_completion: bool,
91        compressed: bool,
92    ) -> Self {
93        Self {
94            schema_version: SCHEMA_VERSION,
95            task_id,
96            session_id,
97            command,
98            mode: BgMode::Pipes,
99            workdir,
100            project_root,
101            status: BgTaskStatus::Starting,
102            started_at: unix_millis(),
103            finished_at: None,
104            duration_ms: None,
105            timeout_ms,
106            exit_code: None,
107            child_pid: None,
108            pgid: None,
109            completion_delivered: !notify_on_completion,
110            notify_on_completion,
111            compressed,
112            pty_rows: None,
113            pty_cols: None,
114            status_reason: None,
115        }
116    }
117
118    pub fn is_terminal(&self) -> bool {
119        self.status.is_terminal()
120    }
121
122    pub fn mark_running(&mut self, child_pid: u32, pgid: i32) {
123        self.status = BgTaskStatus::Running;
124        self.child_pid = Some(child_pid);
125        self.pgid = Some(pgid);
126    }
127
128    pub fn mark_terminal(
129        &mut self,
130        status: BgTaskStatus,
131        exit_code: Option<i32>,
132        reason: Option<String>,
133    ) {
134        let finished_at = unix_millis();
135        self.status = status;
136        self.exit_code = exit_code;
137        self.finished_at = Some(finished_at);
138        self.duration_ms = Some(finished_at.saturating_sub(self.started_at));
139        self.child_pid = None;
140        self.status_reason = reason;
141        self.completion_delivered = !self.notify_on_completion;
142    }
143
144    pub fn to_bash_task_row(
145        &self,
146        harness: &str,
147        paths: &TaskPaths,
148    ) -> Result<BashTaskRow, serde_json::Error> {
149        let project_root = self.project_root.as_deref().unwrap_or(&self.workdir);
150        let output_bytes = capture_output_bytes(&self.mode, paths);
151        let stdout_path = match self.mode {
152            BgMode::Pipes => Some(paths.stdout.display().to_string()),
153            BgMode::Pty => Some(paths.pty.display().to_string()),
154        };
155        let stderr_path = match self.mode {
156            BgMode::Pipes => Some(paths.stderr.display().to_string()),
157            BgMode::Pty => None,
158        };
159        let mut metadata = self.clone();
160        metadata.schema_version = SCHEMA_VERSION;
161        Ok(BashTaskRow {
162            harness: harness.to_string(),
163            session_id: self.session_id.clone(),
164            task_id: self.task_id.clone(),
165            project_key: crate::search_index::project_cache_key(project_root),
166            command: self.command.clone(),
167            cwd: self.workdir.display().to_string(),
168            status: status_name(&self.status).to_string(),
169            exit_code: self.exit_code,
170            pid: self.child_pid.map(i64::from),
171            pgid: self.pgid.map(i64::from),
172            started_at: self.started_at as i64,
173            completed_at: self.finished_at.map(|value| value as i64),
174            stdout_path,
175            stderr_path,
176            compressed: self.compressed,
177            timeout_ms: self.timeout_ms.map(|value| value as i64),
178            completion_delivered: self.completion_delivered,
179            output_bytes,
180            metadata: serde_json::to_string(&metadata)?,
181        })
182    }
183}
184
185impl From<BashTaskRow> for PersistedTask {
186    fn from(row: BashTaskRow) -> Self {
187        if let Ok(task) = serde_json::from_str::<PersistedTask>(&row.metadata) {
188            return task;
189        }
190
191        let status = match row.status.as_str() {
192            "starting" => BgTaskStatus::Starting,
193            "running" => BgTaskStatus::Running,
194            "killing" => BgTaskStatus::Killing,
195            "completed" => BgTaskStatus::Completed,
196            "failed" => BgTaskStatus::Failed,
197            "killed" => BgTaskStatus::Killed,
198            "timed_out" => BgTaskStatus::TimedOut,
199            _ => BgTaskStatus::Failed,
200        };
201        let started_at = u64::try_from(row.started_at).unwrap_or_default();
202        let finished_at = row.completed_at.and_then(|value| u64::try_from(value).ok());
203
204        PersistedTask {
205            schema_version: SCHEMA_VERSION,
206            task_id: row.task_id,
207            session_id: row.session_id,
208            command: row.command,
209            mode: BgMode::Pipes,
210            workdir: PathBuf::from(row.cwd),
211            project_root: None,
212            status,
213            started_at,
214            finished_at,
215            duration_ms: finished_at.map(|finished_at| finished_at.saturating_sub(started_at)),
216            timeout_ms: row.timeout_ms.and_then(|value| u64::try_from(value).ok()),
217            exit_code: row.exit_code,
218            child_pid: row.pid.and_then(|value| u32::try_from(value).ok()),
219            pgid: row.pgid.and_then(|value| i32::try_from(value).ok()),
220            completion_delivered: row.completion_delivered,
221            notify_on_completion: !row.completion_delivered,
222            compressed: row.compressed,
223            pty_rows: None,
224            pty_cols: None,
225            status_reason: None,
226        }
227    }
228}
229
230fn status_name(status: &BgTaskStatus) -> &'static str {
231    match status {
232        BgTaskStatus::Starting => "starting",
233        BgTaskStatus::Running => "running",
234        BgTaskStatus::Killing => "killing",
235        BgTaskStatus::Completed => "completed",
236        BgTaskStatus::Failed => "failed",
237        BgTaskStatus::Killed => "killed",
238        BgTaskStatus::TimedOut => "timed_out",
239    }
240}
241
242fn capture_output_bytes(mode: &BgMode, paths: &TaskPaths) -> Option<i64> {
243    match mode {
244        BgMode::Pipes => {
245            let stdout = fs::metadata(&paths.stdout)
246                .ok()
247                .map(|metadata| metadata.len());
248            let stderr = fs::metadata(&paths.stderr)
249                .ok()
250                .map(|metadata| metadata.len());
251            match (stdout, stderr) {
252                (Some(stdout), Some(stderr)) => Some(stdout.saturating_add(stderr) as i64),
253                (Some(bytes), None) | (None, Some(bytes)) => Some(bytes as i64),
254                (None, None) => None,
255            }
256        }
257        BgMode::Pty => fs::metadata(&paths.pty)
258            .ok()
259            .map(|metadata| metadata.len() as i64),
260    }
261}
262
263pub fn session_tasks_dir(storage_dir: &Path, session_id: &str) -> PathBuf {
264    let session_hash = hash_session(session_id);
265    let direct = storage_dir.join("bash-tasks").join(&session_hash);
266    if direct.exists() {
267        return direct;
268    }
269
270    let mut harness_matches = ["opencode", "pi"]
271        .into_iter()
272        .map(|harness| {
273            storage_dir
274                .join(harness)
275                .join("bash-tasks")
276                .join(&session_hash)
277        })
278        .filter(|path| path.exists())
279        .collect::<Vec<_>>();
280    if harness_matches.len() == 1 {
281        return harness_matches.remove(0);
282    }
283
284    direct
285}
286
287pub fn task_paths(storage_dir: &Path, session_id: &str, task_id: &str) -> TaskPaths {
288    let dir = session_tasks_dir(storage_dir, session_id);
289    TaskPaths {
290        json: dir.join(format!("{task_id}.json")),
291        stdout: dir.join(format!("{task_id}.stdout")),
292        stderr: dir.join(format!("{task_id}.stderr")),
293        exit: dir.join(format!("{task_id}.exit")),
294        pty: dir.join(format!("{task_id}.pty")),
295        dir,
296    }
297}
298
299pub fn read_task(path: &Path) -> io::Result<PersistedTask> {
300    let content = fs::read_to_string(path)?;
301    let task: PersistedTask = serde_json::from_str(&content).map_err(io::Error::other)?;
302    if !matches!(task.schema_version, 2 | 3 | SCHEMA_VERSION) {
303        return Err(io::Error::new(
304            io::ErrorKind::InvalidData,
305            format!(
306                "unsupported background task schema_version {} (expected 2, 3, or {SCHEMA_VERSION})",
307                task.schema_version
308            ),
309        ));
310    }
311    Ok(task)
312}
313
314pub fn write_task(path: &Path, task: &PersistedTask) -> io::Result<()> {
315    if let Some(parent) = path.parent() {
316        fs::create_dir_all(parent)?;
317    }
318    let mut upgraded = task.clone();
319    upgraded.schema_version = SCHEMA_VERSION;
320    let content = serde_json::to_vec_pretty(&upgraded).map_err(io::Error::other)?;
321    atomic_write(path, &content, &upgraded.task_id)
322}
323
324pub(super) fn delete_task_bundle(paths: &TaskPaths) -> io::Result<()> {
325    let mut first_error = None;
326    for path in task_bundle_files(paths) {
327        if let Err(error) = remove_file_if_present(&path) {
328            if first_error.is_none() {
329                first_error = Some(error);
330            }
331        }
332    }
333
334    if let Some(error) = first_error {
335        return Err(error);
336    }
337
338    match fs::remove_dir(&paths.dir) {
339        Ok(()) => Ok(()),
340        Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(()),
341        Err(error) if error.kind() == io::ErrorKind::DirectoryNotEmpty => Ok(()),
342        Err(error) => Err(error),
343    }
344}
345
346pub fn task_bundle_files(paths: &TaskPaths) -> Vec<PathBuf> {
347    let mut files = vec![
348        paths.json.clone(),
349        paths.stdout.clone(),
350        paths.stderr.clone(),
351        paths.exit.clone(),
352        paths.pty.clone(),
353    ];
354    if let Some(stem) = paths.json.file_stem().and_then(|stem| stem.to_str()) {
355        // Windows background bash writes per-task wrapper scripts next to the
356        // capture files as `<task-id>.ps1`, `<task-id>.bat`, or `<task-id>.sh`
357        // depending on the shell selected in `detached_shell_command_for`.
358        for extension in ["ps1", "bat", "sh"] {
359            files.push(paths.dir.join(format!("{stem}.{extension}")));
360        }
361    }
362    files
363}
364
/// Removes the file at `path`, treating an already-missing file as success.
fn remove_file_if_present(path: &Path) -> io::Result<()> {
    fs::remove_file(path).or_else(|error| match error.kind() {
        io::ErrorKind::NotFound => Ok(()),
        _ => Err(error),
    })
}

373pub fn update_task<F>(path: &Path, update: F) -> io::Result<PersistedTask>
374where
375    F: FnOnce(&mut PersistedTask),
376{
377    let mut task = read_task(path)?;
378    let original_terminal = task.is_terminal();
379    let original = task.clone();
380    update(&mut task);
381    task.schema_version = SCHEMA_VERSION;
382    if original_terminal {
383        let completion_delivered = task.completion_delivered;
384        task = original;
385        task.completion_delivered = completion_delivered;
386        task.schema_version = SCHEMA_VERSION;
387    }
388    write_task(path, &task)?;
389    Ok(task)
390}
391
392pub fn write_kill_marker_if_absent(path: &Path) -> io::Result<()> {
393    if path.exists() {
394        return Ok(());
395    }
396    atomic_write(path, b"killed", "kill")
397}
398
399pub fn read_exit_marker(path: &Path) -> io::Result<Option<ExitMarker>> {
400    let mut file = match File::open(path) {
401        Ok(file) => file,
402        Err(error) if error.kind() == io::ErrorKind::NotFound => return Ok(None),
403        Err(error) => return Err(error),
404    };
405    let mut content = String::new();
406    file.read_to_string(&mut content)?;
407    let content = content.trim();
408    if content.is_empty() {
409        return Ok(None);
410    }
411    if content == "killed" {
412        return Ok(Some(ExitMarker::Killed));
413    }
414    match content.parse::<i32>() {
415        Ok(code) => Ok(Some(ExitMarker::Code(code))),
416        Err(_) => Ok(None),
417    }
418}
419
420pub fn atomic_write(path: &Path, content: &[u8], task_id: &str) -> io::Result<()> {
421    let parent = path.parent().unwrap_or_else(|| Path::new("."));
422    fs::create_dir_all(parent)?;
423    let file_name = path
424        .file_name()
425        .and_then(|name| name.to_str())
426        .unwrap_or("task");
427    let tmp = parent.join(format!(
428        ".{file_name}.tmp.{}.{}",
429        std::process::id(),
430        sanitize_task_id(task_id)
431    ));
432    {
433        let mut file = OpenOptions::new()
434            .create(true)
435            .truncate(true)
436            .write(true)
437            .open(&tmp)?;
438        file.write_all(content)?;
439        file.sync_all()?;
440    }
441    fs::rename(&tmp, path)?;
442    Ok(())
443}
444
/// Maps `task_id` onto `[A-Za-z0-9_-]` for use in a temp file name.
///
/// Every other character, including each non-ASCII char, becomes `_`.
fn sanitize_task_id(task_id: &str) -> String {
    let mut sanitized = String::with_capacity(task_id.len());
    for ch in task_id.chars() {
        let keep = ch.is_ascii_alphanumeric() || ch == '-' || ch == '_';
        sanitized.push(if keep { ch } else { '_' });
    }
    sanitized
}

/// Creates (truncating if present) the capture file at `path`, creating any missing parent
/// directories first.
pub fn create_capture_file(path: &Path) -> io::Result<File> {
    path.parent().map_or(Ok(()), fs::create_dir_all)?;
    File::create(path)
}

/// Current wall-clock time in milliseconds since the Unix epoch, or 0 if the system clock
/// reads earlier than the epoch.
pub fn unix_millis() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}

#[cfg(test)]
mod tests {
    use std::thread;

    use super::*;

    /// Names of the entries in `dir` other than `keep`, i.e. anything a writer left behind.
    fn stray_entries(dir: &Path, keep: &str) -> Vec<String> {
        fs::read_dir(dir)
            .expect("list temp dir")
            .filter_map(|entry| entry.ok())
            .map(|entry| entry.file_name().to_string_lossy().into_owned())
            .filter(|name| name != keep)
            .collect()
    }

    #[test]
    fn atomic_write_temp_names_include_task_id() {
        let dir = tempfile::tempdir().expect("create temp dir");
        let path = dir.path().join("task.json");

        let left_path = path.clone();
        let left = thread::spawn(move || atomic_write(&left_path, b"left", "task-left"));
        let right_path = path.clone();
        let right = thread::spawn(move || atomic_write(&right_path, b"right", "task-right"));

        left.join().expect("join left").expect("write left");
        right.join().expect("join right").expect("write right");

        let content = fs::read_to_string(&path).expect("read final content");
        assert!(content == "left" || content == "right");
        // Every temp file must have been renamed into place. Scanning the whole directory
        // catches leftovers whatever suffix the temp name carries; probing a single guessed
        // name could never fail.
        let strays = stray_entries(dir.path(), "task.json");
        assert!(strays.is_empty(), "temporary files left behind: {strays:?}");
    }

    #[test]
    fn read_exit_marker_recognises_codes_kills_and_garbage() {
        let dir = tempfile::tempdir().expect("create temp dir");
        let marker = dir.path().join("task.exit");

        assert_eq!(read_exit_marker(&marker).expect("missing marker"), None);

        let cases: [(&str, Option<ExitMarker>); 5] = [
            ("0\n", Some(ExitMarker::Code(0))),
            (" -15 ", Some(ExitMarker::Code(-15))),
            ("killed\n", Some(ExitMarker::Killed)),
            ("   ", None),
            ("not a number", None),
        ];
        for (content, expected) in cases {
            fs::write(&marker, content).expect("write marker");
            assert_eq!(
                read_exit_marker(&marker).expect("read marker"),
                expected,
                "content {content:?}"
            );
        }
    }

    #[test]
    fn write_kill_marker_keeps_an_existing_exit_code() {
        let dir = tempfile::tempdir().expect("create temp dir");

        let existing = dir.path().join("done.exit");
        fs::write(&existing, "3").expect("write exit code");
        write_kill_marker_if_absent(&existing).expect("kill marker on existing");
        assert_eq!(
            read_exit_marker(&existing).expect("read existing"),
            Some(ExitMarker::Code(3))
        );

        let fresh = dir.path().join("fresh.exit");
        write_kill_marker_if_absent(&fresh).expect("kill marker on fresh");
        assert_eq!(
            read_exit_marker(&fresh).expect("read fresh"),
            Some(ExitMarker::Killed)
        );
    }
}