Skip to main content

aft/bash_background/
persistence.rs

1use std::fs::{self, File, OpenOptions};
2use std::io::{self, Read, Write};
3use std::path::{Path, PathBuf};
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use serde::{Deserialize, Serialize};
7
8use crate::backup::hash_session;
9
10use super::BgTaskStatus;
11
/// Schema version stamped into every persisted background-task record.
///
/// [`read_task`] rejects any record whose `schema_version` differs from this.
pub const SCHEMA_VERSION: u32 = 2;

/// Locations of all on-disk artifacts that belong to a single background task.
#[derive(Debug, Clone)]
pub struct TaskPaths {
    /// Directory containing the task's files.
    pub dir: PathBuf,
    /// Serialized [`PersistedTask`] record.
    pub json: PathBuf,
    /// Captured standard output.
    pub stdout: PathBuf,
    /// Captured standard error.
    pub stderr: PathBuf,
    /// Exit marker: an integer exit code or the literal `killed`.
    pub exit: PathBuf,
}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct PersistedTask {
25    pub schema_version: u32,
26    pub task_id: String,
27    pub session_id: String,
28    pub command: String,
29    pub workdir: PathBuf,
30    #[serde(default)]
31    pub project_root: Option<PathBuf>,
32    pub status: BgTaskStatus,
33    pub started_at: u64,
34    pub finished_at: Option<u64>,
35    pub duration_ms: Option<u64>,
36    pub timeout_ms: Option<u64>,
37    pub exit_code: Option<i32>,
38    pub child_pid: Option<u32>,
39    pub pgid: Option<i32>,
40    pub completion_delivered: bool,
41    #[serde(default = "default_notify_on_completion")]
42    pub notify_on_completion: bool,
43    /// Per-call output compression opt-in. Defaults to `true` so existing
44    /// behavior (compression when `experimental.bash.compress=true`) is
45    /// unchanged. Agents can pass `compressed: false` to disable compression
46    /// for a single bash call without flipping the global flag.
47    #[serde(default = "default_compressed")]
48    pub compressed: bool,
49    pub status_reason: Option<String>,
50}
51
/// Serde fallback for `PersistedTask::notify_on_completion` when the field is
/// missing from the stored JSON: notifications are requested.
const fn default_notify_on_completion() -> bool {
    true
}

/// Serde fallback for `PersistedTask::compressed` when the field is missing
/// from the stored JSON: compression stays opted in.
const fn default_compressed() -> bool {
    true
}
59
/// Parsed contents of a task's `.exit` marker file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExitMarker {
    /// The file held an integer exit code.
    Code(i32),
    /// The file held the literal `killed`.
    Killed,
}
65
66impl PersistedTask {
67    pub fn starting(
68        task_id: String,
69        session_id: String,
70        command: String,
71        workdir: PathBuf,
72        project_root: Option<PathBuf>,
73        timeout_ms: Option<u64>,
74        notify_on_completion: bool,
75        compressed: bool,
76    ) -> Self {
77        Self {
78            schema_version: SCHEMA_VERSION,
79            task_id,
80            session_id,
81            command,
82            workdir,
83            project_root,
84            status: BgTaskStatus::Starting,
85            started_at: unix_millis(),
86            finished_at: None,
87            duration_ms: None,
88            timeout_ms,
89            exit_code: None,
90            child_pid: None,
91            pgid: None,
92            completion_delivered: !notify_on_completion,
93            notify_on_completion,
94            compressed,
95            status_reason: None,
96        }
97    }
98
99    pub fn is_terminal(&self) -> bool {
100        self.status.is_terminal()
101    }
102
103    pub fn mark_running(&mut self, child_pid: u32, pgid: i32) {
104        self.status = BgTaskStatus::Running;
105        self.child_pid = Some(child_pid);
106        self.pgid = Some(pgid);
107    }
108
109    pub fn mark_terminal(
110        &mut self,
111        status: BgTaskStatus,
112        exit_code: Option<i32>,
113        reason: Option<String>,
114    ) {
115        let finished_at = unix_millis();
116        self.status = status;
117        self.exit_code = exit_code;
118        self.finished_at = Some(finished_at);
119        self.duration_ms = Some(finished_at.saturating_sub(self.started_at));
120        self.child_pid = None;
121        self.status_reason = reason;
122        self.completion_delivered = !self.notify_on_completion;
123    }
124}
125
126pub fn session_tasks_dir(storage_dir: &Path, session_id: &str) -> PathBuf {
127    storage_dir
128        .join("bash-tasks")
129        .join(hash_session(session_id))
130}
131
132pub fn task_paths(storage_dir: &Path, session_id: &str, task_id: &str) -> TaskPaths {
133    let dir = session_tasks_dir(storage_dir, session_id);
134    TaskPaths {
135        json: dir.join(format!("{task_id}.json")),
136        stdout: dir.join(format!("{task_id}.stdout")),
137        stderr: dir.join(format!("{task_id}.stderr")),
138        exit: dir.join(format!("{task_id}.exit")),
139        dir,
140    }
141}
142
143pub fn read_task(path: &Path) -> io::Result<PersistedTask> {
144    let content = fs::read_to_string(path)?;
145    let task: PersistedTask = serde_json::from_str(&content).map_err(io::Error::other)?;
146    if task.schema_version != SCHEMA_VERSION {
147        return Err(io::Error::new(
148            io::ErrorKind::InvalidData,
149            format!(
150                "unsupported background task schema_version {} (expected {SCHEMA_VERSION})",
151                task.schema_version
152            ),
153        ));
154    }
155    Ok(task)
156}
157
158pub fn write_task(path: &Path, task: &PersistedTask) -> io::Result<()> {
159    if let Some(parent) = path.parent() {
160        fs::create_dir_all(parent)?;
161    }
162    let content = serde_json::to_vec_pretty(task).map_err(io::Error::other)?;
163    atomic_write(path, &content, &task.task_id)
164}
165
166pub(super) fn delete_task_bundle(paths: &TaskPaths) -> io::Result<()> {
167    let mut first_error = None;
168    for path in task_bundle_files(paths) {
169        if let Err(error) = remove_file_if_present(&path) {
170            if first_error.is_none() {
171                first_error = Some(error);
172            }
173        }
174    }
175
176    if let Some(error) = first_error {
177        return Err(error);
178    }
179
180    match fs::remove_dir(&paths.dir) {
181        Ok(()) => Ok(()),
182        Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(()),
183        Err(error) if error.kind() == io::ErrorKind::DirectoryNotEmpty => Ok(()),
184        Err(error) => Err(error),
185    }
186}
187
188fn task_bundle_files(paths: &TaskPaths) -> Vec<PathBuf> {
189    let mut files = vec![
190        paths.json.clone(),
191        paths.stdout.clone(),
192        paths.stderr.clone(),
193        paths.exit.clone(),
194    ];
195    if let Some(stem) = paths.json.file_stem().and_then(|stem| stem.to_str()) {
196        // Windows background bash writes per-task wrapper scripts next to the
197        // capture files as `<task-id>.ps1`, `<task-id>.bat`, or `<task-id>.sh`
198        // depending on the shell selected in `detached_shell_command_for`.
199        for extension in ["ps1", "bat", "sh"] {
200            files.push(paths.dir.join(format!("{stem}.{extension}")));
201        }
202    }
203    files
204}
205
/// Deletes `path`, treating an already-missing file as success.
fn remove_file_if_present(path: &Path) -> io::Result<()> {
    fs::remove_file(path).or_else(|error| {
        if error.kind() == io::ErrorKind::NotFound {
            Ok(())
        } else {
            Err(error)
        }
    })
}
213
214pub fn update_task<F>(path: &Path, update: F) -> io::Result<PersistedTask>
215where
216    F: FnOnce(&mut PersistedTask),
217{
218    let mut task = read_task(path)?;
219    let original_terminal = task.is_terminal();
220    let original = task.clone();
221    update(&mut task);
222    if original_terminal {
223        let completion_delivered = task.completion_delivered;
224        task = original;
225        task.completion_delivered = completion_delivered;
226    }
227    write_task(path, &task)?;
228    Ok(task)
229}
230
231pub fn write_kill_marker_if_absent(path: &Path) -> io::Result<()> {
232    if path.exists() {
233        return Ok(());
234    }
235    atomic_write(path, b"killed", "kill")
236}
237
238pub fn read_exit_marker(path: &Path) -> io::Result<Option<ExitMarker>> {
239    let mut file = match File::open(path) {
240        Ok(file) => file,
241        Err(error) if error.kind() == io::ErrorKind::NotFound => return Ok(None),
242        Err(error) => return Err(error),
243    };
244    let mut content = String::new();
245    file.read_to_string(&mut content)?;
246    let content = content.trim();
247    if content.is_empty() {
248        return Ok(None);
249    }
250    if content == "killed" {
251        return Ok(Some(ExitMarker::Killed));
252    }
253    match content.parse::<i32>() {
254        Ok(code) => Ok(Some(ExitMarker::Code(code))),
255        Err(_) => Ok(None),
256    }
257}
258
/// Atomically replaces `path` with `content`.
///
/// Steps:
/// 1. Create any missing parent directories.
/// 2. Write the bytes to a hidden temporary file in the same directory.
/// 3. Flush that file with `sync_all`.
/// 4. Rename it over `path`.
///
/// The temporary name combines the target file name, the process id, the
/// sanitized `task_id`, and a per-process sequence number. This keeps
/// concurrent writers apart, even two threads writing the same task.
///
/// # Errors
///
/// Returns any I/O error from:
/// - creating the directory,
/// - writing or syncing the temporary file,
/// - renaming it into place.
///
/// On failure the temporary file is removed on a best-effort basis so no
/// `.tmp` debris is left next to the target.
pub fn atomic_write(path: &Path, content: &[u8], task_id: &str) -> io::Result<()> {
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Creates (or truncates) `tmp`, writes `content`, and fsyncs it. The file
    /// handle is dropped before returning, ahead of the rename.
    fn write_synced(tmp: &Path, content: &[u8]) -> io::Result<()> {
        let mut file = OpenOptions::new()
            .create(true)
            .truncate(true)
            .write(true)
            .open(tmp)?;
        file.write_all(content)?;
        file.sync_all()
    }

    // Without this, two concurrent calls in one process for the same target and
    // task id would share a temp name: they would truncate each other's data
    // and the second rename would fail with NotFound. Relaxed ordering is
    // enough because only the uniqueness of each value matters.
    static TMP_SEQUENCE: AtomicU64 = AtomicU64::new(0);

    let parent = path.parent().unwrap_or_else(|| Path::new("."));
    fs::create_dir_all(parent)?;
    let file_name = path
        .file_name()
        .and_then(|name| name.to_str())
        .unwrap_or("task");
    let sequence = TMP_SEQUENCE.fetch_add(1, Ordering::Relaxed);
    let tmp = parent.join(format!(
        ".{file_name}.tmp.{}.{}.{sequence}",
        std::process::id(),
        sanitize_task_id(task_id)
    ));

    let result = write_synced(&tmp, content).and_then(|()| fs::rename(&tmp, path));
    if result.is_err() {
        // Best effort: report the original failure, not a cleanup error.
        let _ = fs::remove_file(&tmp);
    }
    result
}

/// Makes `task_id` safe to embed in a file name.
///
/// ASCII letters, digits, `-` and `_` are kept; every other character
/// becomes `_`.
fn sanitize_task_id(task_id: &str) -> String {
    task_id
        .chars()
        .map(|ch| match ch {
            'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_' => ch,
            _ => '_',
        })
        .collect()
}
293
/// Opens `path` for writing captured output, creating missing parent
/// directories first and truncating any existing file.
pub fn create_capture_file(path: &Path) -> io::Result<File> {
    path.parent().map(fs::create_dir_all).transpose()?;
    File::create(path)
}
300
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
pub fn unix_millis() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
307
#[cfg(test)]
mod tests {
    use std::thread;

    use super::*;

    /// Sorted names of every entry in `dir` except `skip`.
    fn entries_except(dir: &Path, skip: &str) -> Vec<String> {
        let mut names: Vec<String> = fs::read_dir(dir)
            .expect("list directory")
            .map(|entry| {
                entry
                    .expect("read dir entry")
                    .file_name()
                    .to_string_lossy()
                    .into_owned()
            })
            .filter(|name| name.as_str() != skip)
            .collect();
        names.sort();
        names
    }

    #[test]
    fn atomic_write_temp_names_include_task_id() {
        let dir = tempfile::tempdir().expect("create temp dir");
        let path = dir.path().join("task.json");

        let left_path = path.clone();
        let left = thread::spawn(move || atomic_write(&left_path, b"left", "task-left"));
        let right_path = path.clone();
        let right = thread::spawn(move || atomic_write(&right_path, b"right", "task-right"));

        left.join().expect("join left").expect("write left");
        right.join().expect("join right").expect("write right");

        let content = fs::read_to_string(&path).expect("read final content");
        assert!(content == "left" || content == "right");
        // Every temporary file must have been renamed into place.
        let leftovers = entries_except(dir.path(), "task.json");
        assert!(leftovers.is_empty(), "temporary files left behind: {leftovers:?}");
    }

    #[test]
    fn read_exit_marker_parses_codes_and_kill_marker() {
        let dir = tempfile::tempdir().expect("create temp dir");
        let path = dir.path().join("t.exit");
        assert_eq!(read_exit_marker(&path).expect("missing marker"), None);

        let cases = [
            ("0\n", Some(ExitMarker::Code(0))),
            (" -9 ", Some(ExitMarker::Code(-9))),
            ("killed\n", Some(ExitMarker::Killed)),
            ("", None),
            ("garbage", None),
        ];
        for (content, expected) in cases {
            fs::write(&path, content).expect("write marker");
            assert_eq!(
                read_exit_marker(&path).expect("read marker"),
                expected,
                "content {content:?}"
            );
        }
    }

    #[test]
    fn kill_marker_does_not_replace_existing_marker() {
        let dir = tempfile::tempdir().expect("create temp dir");
        let path = dir.path().join("t.exit");

        write_kill_marker_if_absent(&path).expect("write kill marker");
        assert_eq!(read_exit_marker(&path).expect("read"), Some(ExitMarker::Killed));

        fs::write(&path, "3").expect("write exit code");
        write_kill_marker_if_absent(&path).expect("second kill marker");
        assert_eq!(read_exit_marker(&path).expect("read"), Some(ExitMarker::Code(3)));
    }

    #[test]
    fn delete_task_bundle_keeps_shared_dir_until_empty() {
        let root = tempfile::tempdir().expect("create temp dir");
        let dir = root.path().join("session");
        let bundle = |task: &str| TaskPaths {
            dir: dir.clone(),
            json: dir.join(format!("{task}.json")),
            stdout: dir.join(format!("{task}.stdout")),
            stderr: dir.join(format!("{task}.stderr")),
            exit: dir.join(format!("{task}.exit")),
        };
        let first = bundle("t1");
        let second = bundle("t2");

        fs::create_dir_all(&dir).expect("create session dir");
        for file in [&first.json, &first.stdout, &first.exit, &second.json] {
            fs::write(file, b"x").expect("write bundle file");
        }
        fs::write(dir.join("t1.sh"), b"x").expect("write wrapper script");

        delete_task_bundle(&first).expect("delete first bundle");
        assert_eq!(entries_except(&dir, ""), ["t2.json"]);

        delete_task_bundle(&second).expect("delete second bundle");
        assert!(!dir.exists());
    }
}
334}