Skip to main content

aft/bash_background/
persistence.rs

1use std::fs::{self, File, OpenOptions};
2use std::io::{self, Read, Write};
3use std::path::{Path, PathBuf};
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use serde::{Deserialize, Serialize};
7
8use crate::backup::hash_session;
9
10use super::BgTaskStatus;
11
/// Version number stamped into the `schema_version` field of every record
/// built by `PersistedTask::starting`.
const SCHEMA_VERSION: u32 = 1;
13
/// The set of on-disk locations that belong to a single background task.
///
/// All four files live directly inside `dir` and are named after the task id
/// with a distinguishing extension (see `task_paths`).
#[derive(Clone, Debug)]
pub struct TaskPaths {
    /// Directory that contains the files below.
    pub dir: PathBuf,
    /// `<task_id>.json` — the serialized task record.
    pub json: PathBuf,
    /// `<task_id>.stdout` — capture file for standard output.
    pub stdout: PathBuf,
    /// `<task_id>.stderr` — capture file for standard error.
    pub stderr: PathBuf,
    /// `<task_id>.exit` — the exit marker (an exit code or the text `killed`).
    pub exit: PathBuf,
}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct PersistedTask {
25    pub schema_version: u32,
26    pub task_id: String,
27    pub session_id: String,
28    pub command: String,
29    pub workdir: PathBuf,
30    pub status: BgTaskStatus,
31    pub started_at: u64,
32    pub finished_at: Option<u64>,
33    pub duration_ms: Option<u64>,
34    pub timeout_ms: Option<u64>,
35    pub exit_code: Option<i32>,
36    pub child_pid: Option<u32>,
37    pub pgid: Option<i32>,
38    pub completion_delivered: bool,
39    #[serde(default = "default_notify_on_completion")]
40    pub notify_on_completion: bool,
41    /// Per-call output compression opt-in. Defaults to `true` so existing
42    /// behavior (compression when `experimental.bash.compress=true`) is
43    /// unchanged. Agents can pass `compressed: false` to disable compression
44    /// for a single bash call without flipping the global flag.
45    #[serde(default = "default_compressed")]
46    pub compressed: bool,
47    pub status_reason: Option<String>,
48}
49
/// Serde fallback for `PersistedTask::notify_on_completion`: records that do
/// not contain the key are read as having notification enabled.
fn default_notify_on_completion() -> bool {
    true
}
53
/// Serde fallback for `PersistedTask::compressed`: records that do not contain
/// the key are read as having compression opted in.
fn default_compressed() -> bool {
    true
}
57
/// Parsed content of a task's exit-marker file (see `read_exit_marker`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ExitMarker {
    /// The marker held a decimal integer: the recorded exit code.
    Code(i32),
    /// The marker held the literal text `killed`.
    Killed,
}
63
64impl PersistedTask {
65    pub fn starting(
66        task_id: String,
67        session_id: String,
68        command: String,
69        workdir: PathBuf,
70        timeout_ms: Option<u64>,
71        notify_on_completion: bool,
72        compressed: bool,
73    ) -> Self {
74        Self {
75            schema_version: SCHEMA_VERSION,
76            task_id,
77            session_id,
78            command,
79            workdir,
80            status: BgTaskStatus::Starting,
81            started_at: unix_millis(),
82            finished_at: None,
83            duration_ms: None,
84            timeout_ms,
85            exit_code: None,
86            child_pid: None,
87            pgid: None,
88            completion_delivered: !notify_on_completion,
89            notify_on_completion,
90            compressed,
91            status_reason: None,
92        }
93    }
94
95    pub fn is_terminal(&self) -> bool {
96        self.status.is_terminal()
97    }
98
99    pub fn mark_running(&mut self, child_pid: u32, pgid: i32) {
100        self.status = BgTaskStatus::Running;
101        self.child_pid = Some(child_pid);
102        self.pgid = Some(pgid);
103    }
104
105    pub fn mark_terminal(
106        &mut self,
107        status: BgTaskStatus,
108        exit_code: Option<i32>,
109        reason: Option<String>,
110    ) {
111        let finished_at = unix_millis();
112        self.status = status;
113        self.exit_code = exit_code;
114        self.finished_at = Some(finished_at);
115        self.duration_ms = Some(finished_at.saturating_sub(self.started_at));
116        self.child_pid = None;
117        self.status_reason = reason;
118        self.completion_delivered = !self.notify_on_completion;
119    }
120}
121
122pub fn session_tasks_dir(storage_dir: &Path, session_id: &str) -> PathBuf {
123    storage_dir
124        .join("bash-tasks")
125        .join(hash_session(session_id))
126}
127
128pub fn task_paths(storage_dir: &Path, session_id: &str, task_id: &str) -> TaskPaths {
129    let dir = session_tasks_dir(storage_dir, session_id);
130    TaskPaths {
131        json: dir.join(format!("{task_id}.json")),
132        stdout: dir.join(format!("{task_id}.stdout")),
133        stderr: dir.join(format!("{task_id}.stderr")),
134        exit: dir.join(format!("{task_id}.exit")),
135        dir,
136    }
137}
138
139pub fn read_task(path: &Path) -> io::Result<PersistedTask> {
140    let content = fs::read_to_string(path)?;
141    serde_json::from_str(&content).map_err(io::Error::other)
142}
143
144pub fn write_task(path: &Path, task: &PersistedTask) -> io::Result<()> {
145    if let Some(parent) = path.parent() {
146        fs::create_dir_all(parent)?;
147    }
148    let content = serde_json::to_vec_pretty(task).map_err(io::Error::other)?;
149    atomic_write(path, &content, &task.task_id)
150}
151
152pub fn update_task<F>(path: &Path, update: F) -> io::Result<PersistedTask>
153where
154    F: FnOnce(&mut PersistedTask),
155{
156    let mut task = read_task(path)?;
157    let original_terminal = task.is_terminal();
158    let original = task.clone();
159    update(&mut task);
160    if original_terminal {
161        let completion_delivered = task.completion_delivered;
162        task = original;
163        task.completion_delivered = completion_delivered;
164    }
165    write_task(path, &task)?;
166    Ok(task)
167}
168
/// Writes the text `killed` to the exit marker at `path` unless a marker is
/// already there.
///
/// The file is opened with `create_new`, so the "is it absent?" check and the
/// creation are a single atomic step. A marker that appears concurrently (for
/// example a real exit code) is never overwritten, which a separate
/// `exists()` check followed by a rename could not guarantee.
///
/// The content is written after the file has been created, so a reader may
/// briefly observe an empty or partial marker; `read_exit_marker` reports
/// such content as `None`.
///
/// The parent directory is created if it does not exist.
///
/// # Errors
///
/// Returns any I/O error other than "already exists" from creating the
/// directory, creating the file, writing it, or syncing it.
pub fn write_kill_marker_if_absent(path: &Path) -> io::Result<()> {
    if let Some(parent) = path.parent() {
        fs::create_dir_all(parent)?;
    }

    let mut file = match OpenOptions::new().write(true).create_new(true).open(path) {
        Ok(file) => file,
        // Somebody else already produced a marker: keep theirs.
        Err(error) if error.kind() == io::ErrorKind::AlreadyExists => return Ok(()),
        Err(error) => return Err(error),
    };

    file.write_all(b"killed")?;
    file.sync_all()
}
175
176pub fn read_exit_marker(path: &Path) -> io::Result<Option<ExitMarker>> {
177    let mut file = match File::open(path) {
178        Ok(file) => file,
179        Err(error) if error.kind() == io::ErrorKind::NotFound => return Ok(None),
180        Err(error) => return Err(error),
181    };
182    let mut content = String::new();
183    file.read_to_string(&mut content)?;
184    let content = content.trim();
185    if content.is_empty() {
186        return Ok(None);
187    }
188    if content == "killed" {
189        return Ok(Some(ExitMarker::Killed));
190    }
191    match content.parse::<i32>() {
192        Ok(code) => Ok(Some(ExitMarker::Code(code))),
193        Err(_) => Ok(None),
194    }
195}
196
/// Replaces the file at `path` with `content` by writing a temporary sibling
/// file and renaming it over the destination.
///
/// The temporary file lives in the same directory and is named
/// `.<file_name>.tmp.<pid>.<sanitized task_id>`, so writers that differ in
/// process or task id do not share a temporary file. The data is synced to
/// disk before the rename. The parent directory is created if missing.
///
/// If any step fails, the temporary file is removed (best effort) so failed
/// writes do not leave stray `.tmp.` files behind.
///
/// # Errors
///
/// Returns the first I/O error from creating the directory, writing or
/// syncing the temporary file, or renaming it into place.
pub fn atomic_write(path: &Path, content: &[u8], task_id: &str) -> io::Result<()> {
    let parent = path.parent().unwrap_or_else(|| Path::new("."));
    fs::create_dir_all(parent)?;
    let file_name = path
        .file_name()
        .and_then(|name| name.to_str())
        .unwrap_or("task");
    let tmp = parent.join(format!(
        ".{file_name}.tmp.{}.{}",
        std::process::id(),
        sanitize_task_id(task_id)
    ));

    // The helper closes the file before the rename is attempted.
    let outcome = write_and_sync(&tmp, content).and_then(|()| fs::rename(&tmp, path));
    if outcome.is_err() {
        // Best effort: the original error is what the caller needs to see.
        let _ = fs::remove_file(&tmp);
    }
    outcome
}

/// Creates (or truncates) `path`, writes `content`, and syncs it to disk.
fn write_and_sync(path: &Path, content: &[u8]) -> io::Result<()> {
    let mut file = OpenOptions::new()
        .create(true)
        .truncate(true)
        .write(true)
        .open(path)?;
    file.write_all(content)?;
    file.sync_all()
}

/// Maps every character of `task_id` that is not an ASCII letter, ASCII digit,
/// `-` or `_` to `_`, so the id is safe to embed in a file name.
fn sanitize_task_id(task_id: &str) -> String {
    task_id
        .chars()
        .map(|ch| {
            if ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_') {
                ch
            } else {
                '_'
            }
        })
        .collect()
}
231
/// Creates (or truncates) the file at `path` and returns it opened for
/// writing, creating the parent directory first if it does not exist.
///
/// # Errors
///
/// Returns any I/O error from creating the directory or the file.
pub fn create_capture_file(path: &Path) -> io::Result<File> {
    path.parent().map(fs::create_dir_all).transpose()?;
    File::create(path)
}
238
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
pub fn unix_millis() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
245
#[cfg(test)]
mod tests {
    use std::thread;

    use super::*;

    /// Two threads of the same process write the same destination under
    /// different task ids. Both writes must succeed, the destination must hold
    /// one of the two payloads in full, and no temporary file with the
    /// task-id-less name `.task.json.tmp.<pid>` may exist afterwards.
    #[test]
    fn atomic_write_temp_names_include_task_id() {
        let scratch = tempfile::tempdir().expect("create temp dir");
        let target = scratch.path().join("task.json");

        let spawn_writer = |payload: &'static [u8], task_id: &'static str| {
            let destination = target.clone();
            thread::spawn(move || atomic_write(&destination, payload, task_id))
        };
        let left = spawn_writer(b"left", "task-left");
        let right = spawn_writer(b"right", "task-right");

        left.join().expect("join left").expect("write left");
        right.join().expect("join right").expect("write right");

        let written = fs::read_to_string(&target).expect("read final content");
        assert!(matches!(written.as_str(), "left" | "right"));

        let legacy_tmp = scratch
            .path()
            .join(format!(".task.json.tmp.{}", std::process::id()));
        assert!(!legacy_tmp.exists());
    }
}