Skip to main content

aft/bash_background/
persistence.rs

1use std::fs::{self, File, OpenOptions};
2use std::io::{self, Read, Write};
3use std::path::{Path, PathBuf};
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use serde::{Deserialize, Serialize};
7
8use crate::backup::hash_session;
9use crate::db::bash_tasks::BashTaskRow;
10
11use super::BgTaskStatus;
12
/// Format version stamped into every [`PersistedTask`] this module creates.
///
/// [`read_task`] rejects any record whose `schema_version` differs from this
/// value.
pub const SCHEMA_VERSION: u32 = 2;
14
/// Filesystem locations belonging to a single background task.
///
/// [`task_paths`] fills these in as `<task_id>.<extension>` entries inside the
/// session's task directory.
#[derive(Debug, Clone)]
pub struct TaskPaths {
    /// Session task directory that holds every file listed below.
    pub dir: PathBuf,
    /// `<task_id>.json` — presumably the serialized [`PersistedTask`] record
    /// (confirm against callers of [`read_task`] / [`write_task`]).
    pub json: PathBuf,
    /// `<task_id>.stdout` — exported as the row's `stdout_path`.
    pub stdout: PathBuf,
    /// `<task_id>.stderr` — exported as the row's `stderr_path`.
    pub stderr: PathBuf,
    /// `<task_id>.exit` — presumably the marker parsed by
    /// [`read_exit_marker`] (confirm against callers).
    pub exit: PathBuf,
}
23
/// Serializable state of one background bash task.
///
/// Records are created with [`PersistedTask::starting`], stored as JSON via
/// [`write_task`] / [`read_task`], and can be converted to and from a
/// [`BashTaskRow`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistedTask {
    /// Record format version; [`read_task`] only accepts [`SCHEMA_VERSION`].
    pub schema_version: u32,
    /// Identifier of the task.
    pub task_id: String,
    /// Identifier of the session that owns the task.
    pub session_id: String,
    /// Command text; exported unchanged as the row's `command`.
    pub command: String,
    /// Working directory; exported as the row's `cwd`.
    pub workdir: PathBuf,
    /// Optional project root. When it is `None`, `workdir` is used instead to
    /// derive the row's project key. Absent in serialized input means `None`.
    #[serde(default)]
    pub project_root: Option<PathBuf>,
    /// Current lifecycle status.
    pub status: BgTaskStatus,
    /// Creation time in milliseconds since the Unix epoch (see [`unix_millis`]).
    pub started_at: u64,
    /// Time at which `mark_terminal` was called, in milliseconds since the
    /// Unix epoch; `None` until then.
    pub finished_at: Option<u64>,
    /// `finished_at - started_at` (saturating), filled in by `mark_terminal`.
    pub duration_ms: Option<u64>,
    /// Optional timeout in milliseconds, as passed to `starting`.
    pub timeout_ms: Option<u64>,
    /// Exit code supplied to `mark_terminal`, if any.
    pub exit_code: Option<i32>,
    /// Child process id; set by `mark_running`, cleared by `mark_terminal`.
    pub child_pid: Option<u32>,
    /// Process group id; set by `mark_running`.
    pub pgid: Option<i32>,
    /// Whether the completion has been delivered. Initialised (and reset by
    /// `mark_terminal`) to the negation of `notify_on_completion`.
    pub completion_delivered: bool,
    /// Whether a completion notification is wanted. Defaults to `true` when
    /// the field is missing from serialized input.
    #[serde(default = "default_notify_on_completion")]
    pub notify_on_completion: bool,
    /// Per-call output compression opt-in. Defaults to `true` so existing
    /// behavior (compression when `experimental.bash.compress=true`) is
    /// unchanged. Agents can pass `compressed: false` to disable compression
    /// for a single bash call without flipping the global flag.
    #[serde(default = "default_compressed")]
    pub compressed: bool,
    /// Optional free-form reason supplied to `mark_terminal`.
    pub status_reason: Option<String>,
}
52
/// Serde default for [`PersistedTask::notify_on_completion`]: records that
/// lack the field deserialize with notifications enabled.
fn default_notify_on_completion() -> bool {
    true
}
56
/// Serde default for [`PersistedTask::compressed`]: records that lack the
/// field deserialize with compression opted in.
fn default_compressed() -> bool {
    true
}
60
/// Parsed content of an exit marker file (see [`read_exit_marker`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExitMarker {
    /// The marker contained an integer, interpreted as an exit code.
    Code(i32),
    /// The marker contained the literal text `killed`.
    Killed,
}
66
67impl PersistedTask {
68    pub fn starting(
69        task_id: String,
70        session_id: String,
71        command: String,
72        workdir: PathBuf,
73        project_root: Option<PathBuf>,
74        timeout_ms: Option<u64>,
75        notify_on_completion: bool,
76        compressed: bool,
77    ) -> Self {
78        Self {
79            schema_version: SCHEMA_VERSION,
80            task_id,
81            session_id,
82            command,
83            workdir,
84            project_root,
85            status: BgTaskStatus::Starting,
86            started_at: unix_millis(),
87            finished_at: None,
88            duration_ms: None,
89            timeout_ms,
90            exit_code: None,
91            child_pid: None,
92            pgid: None,
93            completion_delivered: !notify_on_completion,
94            notify_on_completion,
95            compressed,
96            status_reason: None,
97        }
98    }
99
100    pub fn is_terminal(&self) -> bool {
101        self.status.is_terminal()
102    }
103
104    pub fn mark_running(&mut self, child_pid: u32, pgid: i32) {
105        self.status = BgTaskStatus::Running;
106        self.child_pid = Some(child_pid);
107        self.pgid = Some(pgid);
108    }
109
110    pub fn mark_terminal(
111        &mut self,
112        status: BgTaskStatus,
113        exit_code: Option<i32>,
114        reason: Option<String>,
115    ) {
116        let finished_at = unix_millis();
117        self.status = status;
118        self.exit_code = exit_code;
119        self.finished_at = Some(finished_at);
120        self.duration_ms = Some(finished_at.saturating_sub(self.started_at));
121        self.child_pid = None;
122        self.status_reason = reason;
123        self.completion_delivered = !self.notify_on_completion;
124    }
125
126    pub fn to_bash_task_row(
127        &self,
128        harness: &str,
129        paths: &TaskPaths,
130    ) -> Result<BashTaskRow, serde_json::Error> {
131        let project_root = self.project_root.as_deref().unwrap_or(&self.workdir);
132        let output_bytes = capture_output_bytes(paths);
133        Ok(BashTaskRow {
134            harness: harness.to_string(),
135            session_id: self.session_id.clone(),
136            task_id: self.task_id.clone(),
137            project_key: crate::search_index::project_cache_key(project_root),
138            command: self.command.clone(),
139            cwd: self.workdir.display().to_string(),
140            status: status_name(&self.status).to_string(),
141            exit_code: self.exit_code,
142            pid: self.child_pid.map(i64::from),
143            pgid: self.pgid.map(i64::from),
144            started_at: self.started_at as i64,
145            completed_at: self.finished_at.map(|value| value as i64),
146            stdout_path: Some(paths.stdout.display().to_string()),
147            stderr_path: Some(paths.stderr.display().to_string()),
148            compressed: self.compressed,
149            timeout_ms: self.timeout_ms.map(|value| value as i64),
150            completion_delivered: self.completion_delivered,
151            output_bytes,
152            metadata: serde_json::to_string(self)?,
153        })
154    }
155}
156
157impl From<BashTaskRow> for PersistedTask {
158    fn from(row: BashTaskRow) -> Self {
159        if let Ok(task) = serde_json::from_str::<PersistedTask>(&row.metadata) {
160            return task;
161        }
162
163        let status = match row.status.as_str() {
164            "starting" => BgTaskStatus::Starting,
165            "running" => BgTaskStatus::Running,
166            "killing" => BgTaskStatus::Killing,
167            "completed" => BgTaskStatus::Completed,
168            "failed" => BgTaskStatus::Failed,
169            "killed" => BgTaskStatus::Killed,
170            "timed_out" => BgTaskStatus::TimedOut,
171            _ => BgTaskStatus::Failed,
172        };
173        let started_at = u64::try_from(row.started_at).unwrap_or_default();
174        let finished_at = row.completed_at.and_then(|value| u64::try_from(value).ok());
175
176        PersistedTask {
177            schema_version: SCHEMA_VERSION,
178            task_id: row.task_id,
179            session_id: row.session_id,
180            command: row.command,
181            workdir: PathBuf::from(row.cwd),
182            project_root: None,
183            status,
184            started_at,
185            finished_at,
186            duration_ms: finished_at.map(|finished_at| finished_at.saturating_sub(started_at)),
187            timeout_ms: row.timeout_ms.and_then(|value| u64::try_from(value).ok()),
188            exit_code: row.exit_code,
189            child_pid: row.pid.and_then(|value| u32::try_from(value).ok()),
190            pgid: row.pgid.and_then(|value| i32::try_from(value).ok()),
191            completion_delivered: row.completion_delivered,
192            notify_on_completion: !row.completion_delivered,
193            compressed: row.compressed,
194            status_reason: None,
195        }
196    }
197}
198
199fn status_name(status: &BgTaskStatus) -> &'static str {
200    match status {
201        BgTaskStatus::Starting => "starting",
202        BgTaskStatus::Running => "running",
203        BgTaskStatus::Killing => "killing",
204        BgTaskStatus::Completed => "completed",
205        BgTaskStatus::Failed => "failed",
206        BgTaskStatus::Killed => "killed",
207        BgTaskStatus::TimedOut => "timed_out",
208    }
209}
210
211fn capture_output_bytes(paths: &TaskPaths) -> Option<i64> {
212    let stdout = fs::metadata(&paths.stdout)
213        .ok()
214        .map(|metadata| metadata.len());
215    let stderr = fs::metadata(&paths.stderr)
216        .ok()
217        .map(|metadata| metadata.len());
218    match (stdout, stderr) {
219        (Some(stdout), Some(stderr)) => Some(stdout.saturating_add(stderr) as i64),
220        (Some(bytes), None) | (None, Some(bytes)) => Some(bytes as i64),
221        (None, None) => None,
222    }
223}
224
225pub fn session_tasks_dir(storage_dir: &Path, session_id: &str) -> PathBuf {
226    storage_dir
227        .join("bash-tasks")
228        .join(hash_session(session_id))
229}
230
231pub fn task_paths(storage_dir: &Path, session_id: &str, task_id: &str) -> TaskPaths {
232    let dir = session_tasks_dir(storage_dir, session_id);
233    TaskPaths {
234        json: dir.join(format!("{task_id}.json")),
235        stdout: dir.join(format!("{task_id}.stdout")),
236        stderr: dir.join(format!("{task_id}.stderr")),
237        exit: dir.join(format!("{task_id}.exit")),
238        dir,
239    }
240}
241
242pub fn read_task(path: &Path) -> io::Result<PersistedTask> {
243    let content = fs::read_to_string(path)?;
244    let task: PersistedTask = serde_json::from_str(&content).map_err(io::Error::other)?;
245    if task.schema_version != SCHEMA_VERSION {
246        return Err(io::Error::new(
247            io::ErrorKind::InvalidData,
248            format!(
249                "unsupported background task schema_version {} (expected {SCHEMA_VERSION})",
250                task.schema_version
251            ),
252        ));
253    }
254    Ok(task)
255}
256
257pub fn write_task(path: &Path, task: &PersistedTask) -> io::Result<()> {
258    if let Some(parent) = path.parent() {
259        fs::create_dir_all(parent)?;
260    }
261    let content = serde_json::to_vec_pretty(task).map_err(io::Error::other)?;
262    atomic_write(path, &content, &task.task_id)
263}
264
265pub(super) fn delete_task_bundle(paths: &TaskPaths) -> io::Result<()> {
266    let mut first_error = None;
267    for path in task_bundle_files(paths) {
268        if let Err(error) = remove_file_if_present(&path) {
269            if first_error.is_none() {
270                first_error = Some(error);
271            }
272        }
273    }
274
275    if let Some(error) = first_error {
276        return Err(error);
277    }
278
279    match fs::remove_dir(&paths.dir) {
280        Ok(()) => Ok(()),
281        Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(()),
282        Err(error) if error.kind() == io::ErrorKind::DirectoryNotEmpty => Ok(()),
283        Err(error) => Err(error),
284    }
285}
286
287fn task_bundle_files(paths: &TaskPaths) -> Vec<PathBuf> {
288    let mut files = vec![
289        paths.json.clone(),
290        paths.stdout.clone(),
291        paths.stderr.clone(),
292        paths.exit.clone(),
293    ];
294    if let Some(stem) = paths.json.file_stem().and_then(|stem| stem.to_str()) {
295        // Windows background bash writes per-task wrapper scripts next to the
296        // capture files as `<task-id>.ps1`, `<task-id>.bat`, or `<task-id>.sh`
297        // depending on the shell selected in `detached_shell_command_for`.
298        for extension in ["ps1", "bat", "sh"] {
299            files.push(paths.dir.join(format!("{stem}.{extension}")));
300        }
301    }
302    files
303}
304
/// Removes the file at `path`, treating "not found" as success.
///
/// Every other error from [`fs::remove_file`] is returned unchanged.
fn remove_file_if_present(path: &Path) -> io::Result<()> {
    fs::remove_file(path).or_else(|error| {
        if error.kind() == io::ErrorKind::NotFound {
            Ok(())
        } else {
            Err(error)
        }
    })
}
312
313pub fn update_task<F>(path: &Path, update: F) -> io::Result<PersistedTask>
314where
315    F: FnOnce(&mut PersistedTask),
316{
317    let mut task = read_task(path)?;
318    let original_terminal = task.is_terminal();
319    let original = task.clone();
320    update(&mut task);
321    if original_terminal {
322        let completion_delivered = task.completion_delivered;
323        task = original;
324        task.completion_delivered = completion_delivered;
325    }
326    write_task(path, &task)?;
327    Ok(task)
328}
329
330pub fn write_kill_marker_if_absent(path: &Path) -> io::Result<()> {
331    if path.exists() {
332        return Ok(());
333    }
334    atomic_write(path, b"killed", "kill")
335}
336
337pub fn read_exit_marker(path: &Path) -> io::Result<Option<ExitMarker>> {
338    let mut file = match File::open(path) {
339        Ok(file) => file,
340        Err(error) if error.kind() == io::ErrorKind::NotFound => return Ok(None),
341        Err(error) => return Err(error),
342    };
343    let mut content = String::new();
344    file.read_to_string(&mut content)?;
345    let content = content.trim();
346    if content.is_empty() {
347        return Ok(None);
348    }
349    if content == "killed" {
350        return Ok(Some(ExitMarker::Killed));
351    }
352    match content.parse::<i32>() {
353        Ok(code) => Ok(Some(ExitMarker::Code(code))),
354        Err(_) => Ok(None),
355    }
356}
357
/// Replaces the file at `path` with `content` by writing a temporary sibling
/// file, syncing it, and renaming it over the target.
///
/// The parent directory is created if needed. The temporary file is named
/// `.<file_name>.tmp.<pid>.<sanitized task_id>.<sequence>`; the per-process
/// sequence number keeps concurrent calls that share a target and `task_id`
/// from writing through the same temporary file. `task_id` is only used for
/// that name.
///
/// # Errors
///
/// Returns any I/O error from creating the directory, writing or syncing the
/// temporary file, or renaming it. On failure the temporary file is removed
/// on a best-effort basis so that no stale `.tmp.` files accumulate.
pub fn atomic_write(path: &Path, content: &[u8], task_id: &str) -> io::Result<()> {
    static WRITE_SEQUENCE: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);

    let parent = path.parent().unwrap_or_else(|| Path::new("."));
    fs::create_dir_all(parent)?;
    let file_name = path
        .file_name()
        .and_then(|name| name.to_str())
        .unwrap_or("task");
    let sequence = WRITE_SEQUENCE.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let tmp = parent.join(format!(
        ".{file_name}.tmp.{}.{}.{sequence}",
        std::process::id(),
        sanitize_task_id(task_id)
    ));

    // The helper closes the file before the rename, which Windows requires.
    let outcome = write_temp_file(&tmp, content).and_then(|()| fs::rename(&tmp, path));
    if outcome.is_err() {
        // Cleanup is best effort; the original error is what callers need.
        let _ = fs::remove_file(&tmp);
    }
    outcome
}

/// Creates (or truncates) `tmp`, writes `content` and flushes it to disk.
fn write_temp_file(tmp: &Path, content: &[u8]) -> io::Result<()> {
    let mut file = OpenOptions::new()
        .create(true)
        .truncate(true)
        .write(true)
        .open(tmp)?;
    file.write_all(content)?;
    file.sync_all()
}

/// Makes `task_id` safe to embed in a file name: ASCII letters, digits, `-`
/// and `_` are kept, every other character becomes `_`.
fn sanitize_task_id(task_id: &str) -> String {
    task_id
        .chars()
        .map(|ch| {
            if ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_') {
                ch
            } else {
                '_'
            }
        })
        .collect()
}
392
/// Creates (or truncates) the file at `path` and returns it opened for
/// writing, creating the parent directory first when the path has one.
///
/// # Errors
///
/// Returns any I/O error from creating the directory or the file.
pub fn create_capture_file(path: &Path) -> io::Result<File> {
    path.parent().map_or(Ok(()), fs::create_dir_all)?;
    File::create(path)
}
399
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
pub fn unix_millis() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
406
#[cfg(test)]
mod tests {
    use std::thread;

    use super::*;

    /// Two threads write the same target concurrently with different task
    /// ids. Both writes must succeed, the target must end up holding exactly
    /// one of the two payloads, and no temp file named without a task-id
    /// suffix (`.task.json.tmp.<pid>`) may exist afterwards.
    #[test]
    fn atomic_write_temp_names_include_task_id() {
        let dir = tempfile::tempdir().expect("create temp dir");
        let path = dir.path().join("task.json");

        // Each writer gets its own clone of the target path to move into
        // its thread.
        let left_path = path.clone();
        let left = thread::spawn(move || atomic_write(&left_path, b"left", "task-left"));
        let right_path = path.clone();
        let right = thread::spawn(move || atomic_write(&right_path, b"right", "task-right"));

        left.join().expect("join left").expect("write left");
        right.join().expect("join right").expect("write right");

        // Whichever rename landed last wins; the content is never a mix.
        let content = fs::read_to_string(&path).expect("read final content");
        assert!(content == "left" || content == "right");
        // A temp name built from the pid alone must not have been used.
        assert!(!dir
            .path()
            .join(format!(".task.json.tmp.{}", std::process::id()))
            .exists());
    }
}
433}