Skip to main content

aft/bash_background/
persistence.rs

1use std::fs::{self, File, OpenOptions};
2use std::io::{self, Read, Write};
3use std::path::{Path, PathBuf};
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use serde::{Deserialize, Serialize};
7
8use crate::backup::hash_session;
9
10use super::BgTaskStatus;
11
/// Schema version stamped into every record created by [`PersistedTask::starting`].
const SCHEMA_VERSION: u32 = 1;
13
/// Locations of every on-disk artifact that belongs to a single background task.
///
/// Produced by [`task_paths`]; each file path is `dir` joined with `<task_id>.<ext>`.
#[derive(Clone, Debug)]
pub struct TaskPaths {
    /// Per-session directory containing all of the task's files.
    pub dir: PathBuf,
    /// `<task_id>.json` — the serialized task record.
    pub json: PathBuf,
    /// `<task_id>.stdout` — stdout capture file.
    pub stdout: PathBuf,
    /// `<task_id>.stderr` — stderr capture file.
    pub stderr: PathBuf,
    /// `<task_id>.exit` — exit marker, parsed by [`read_exit_marker`].
    pub exit: PathBuf,
}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct PersistedTask {
25    pub schema_version: u32,
26    pub task_id: String,
27    pub session_id: String,
28    pub command: String,
29    pub workdir: PathBuf,
30    pub status: BgTaskStatus,
31    pub started_at: u64,
32    pub finished_at: Option<u64>,
33    pub duration_ms: Option<u64>,
34    pub timeout_ms: Option<u64>,
35    pub exit_code: Option<i32>,
36    pub child_pid: Option<u32>,
37    pub pgid: Option<i32>,
38    pub completion_delivered: bool,
39    pub status_reason: Option<String>,
40}
41
/// Decoded contents of a task's `.exit` file, as returned by [`read_exit_marker`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ExitMarker {
    /// The file contained an integer exit code.
    Code(i32),
    /// The file contained the literal `killed`, as written by [`write_kill_marker_if_absent`].
    Killed,
}
47
48impl PersistedTask {
49    pub fn starting(
50        task_id: String,
51        session_id: String,
52        command: String,
53        workdir: PathBuf,
54        timeout_ms: Option<u64>,
55    ) -> Self {
56        Self {
57            schema_version: SCHEMA_VERSION,
58            task_id,
59            session_id,
60            command,
61            workdir,
62            status: BgTaskStatus::Starting,
63            started_at: unix_millis(),
64            finished_at: None,
65            duration_ms: None,
66            timeout_ms,
67            exit_code: None,
68            child_pid: None,
69            pgid: None,
70            completion_delivered: true,
71            status_reason: None,
72        }
73    }
74
75    pub fn is_terminal(&self) -> bool {
76        self.status.is_terminal()
77    }
78
79    pub fn mark_running(&mut self, child_pid: u32, pgid: i32) {
80        self.status = BgTaskStatus::Running;
81        self.child_pid = Some(child_pid);
82        self.pgid = Some(pgid);
83    }
84
85    pub fn mark_terminal(
86        &mut self,
87        status: BgTaskStatus,
88        exit_code: Option<i32>,
89        reason: Option<String>,
90    ) {
91        let finished_at = unix_millis();
92        self.status = status;
93        self.exit_code = exit_code;
94        self.finished_at = Some(finished_at);
95        self.duration_ms = Some(finished_at.saturating_sub(self.started_at));
96        self.child_pid = None;
97        self.status_reason = reason;
98        self.completion_delivered = false;
99    }
100}
101
102pub fn session_tasks_dir(storage_dir: &Path, session_id: &str) -> PathBuf {
103    storage_dir
104        .join("bash-tasks")
105        .join(hash_session(session_id))
106}
107
108pub fn task_paths(storage_dir: &Path, session_id: &str, task_id: &str) -> TaskPaths {
109    let dir = session_tasks_dir(storage_dir, session_id);
110    TaskPaths {
111        json: dir.join(format!("{task_id}.json")),
112        stdout: dir.join(format!("{task_id}.stdout")),
113        stderr: dir.join(format!("{task_id}.stderr")),
114        exit: dir.join(format!("{task_id}.exit")),
115        dir,
116    }
117}
118
119pub fn read_task(path: &Path) -> io::Result<PersistedTask> {
120    let content = fs::read_to_string(path)?;
121    serde_json::from_str(&content).map_err(io::Error::other)
122}
123
124pub fn write_task(path: &Path, task: &PersistedTask) -> io::Result<()> {
125    if let Some(parent) = path.parent() {
126        fs::create_dir_all(parent)?;
127    }
128    let content = serde_json::to_vec_pretty(task).map_err(io::Error::other)?;
129    atomic_write(path, &content, &task.task_id)
130}
131
132pub fn update_task<F>(path: &Path, update: F) -> io::Result<PersistedTask>
133where
134    F: FnOnce(&mut PersistedTask),
135{
136    let mut task = read_task(path)?;
137    let original_terminal = task.is_terminal();
138    let original = task.clone();
139    update(&mut task);
140    if original_terminal {
141        let completion_delivered = task.completion_delivered;
142        task = original;
143        task.completion_delivered = completion_delivered;
144    }
145    write_task(path, &task)?;
146    Ok(task)
147}
148
/// Writes the `killed` exit marker to `path` unless an exit marker already exists there.
///
/// The file is opened with `create_new`, which makes "does it exist?" and "create it" a single
/// atomic filesystem operation. A separate `exists()` probe followed by a write or rename leaves
/// a window in which an exit code written concurrently by another writer is overwritten with
/// `killed`.
///
/// Between creation and the write finishing, the file is briefly empty. [`read_exit_marker`]
/// treats empty or unparseable content as "no marker", so readers see that state as absent.
///
/// # Errors
///
/// Returns I/O errors from creating the parent directory, creating the file, or writing and
/// syncing it. An already-existing marker is not an error.
pub fn write_kill_marker_if_absent(path: &Path) -> io::Result<()> {
    if let Some(parent) = path.parent() {
        fs::create_dir_all(parent)?;
    }
    let mut file = match OpenOptions::new().write(true).create_new(true).open(path) {
        Ok(file) => file,
        // An exit marker is already recorded; it wins over the kill marker.
        Err(error) if error.kind() == io::ErrorKind::AlreadyExists => return Ok(()),
        Err(error) => return Err(error),
    };
    file.write_all(b"killed")?;
    file.sync_all()
}
155
156pub fn read_exit_marker(path: &Path) -> io::Result<Option<ExitMarker>> {
157    let mut file = match File::open(path) {
158        Ok(file) => file,
159        Err(error) if error.kind() == io::ErrorKind::NotFound => return Ok(None),
160        Err(error) => return Err(error),
161    };
162    let mut content = String::new();
163    file.read_to_string(&mut content)?;
164    let content = content.trim();
165    if content.is_empty() {
166        return Ok(None);
167    }
168    if content == "killed" {
169        return Ok(Some(ExitMarker::Killed));
170    }
171    match content.parse::<i32>() {
172        Ok(code) => Ok(Some(ExitMarker::Code(code))),
173        Err(_) => Ok(None),
174    }
175}
176
/// Atomically replaces the contents of `path` with `content`.
///
/// The bytes are written and fsynced to a hidden sibling temp file, which is then renamed over
/// `path`. Readers therefore observe either the old or the new contents, never a partial write.
///
/// The temp name combines the process id, a sanitized `task_id`, and a per-process sequence
/// number. The sequence number guarantees that concurrent writers inside one process never share
/// a temp file, even when they target the same path with the same `task_id`. Without it, one
/// writer could truncate another's file after that file had already been renamed into place.
///
/// # Errors
///
/// Returns I/O errors from creating the parent directory, writing, syncing, or renaming. On
/// failure the temp file is removed on a best-effort basis.
pub fn atomic_write(path: &Path, content: &[u8], task_id: &str) -> io::Result<()> {
    use std::sync::atomic::{AtomicU64, Ordering};

    // Only uniqueness is needed, which `fetch_add` provides under any ordering.
    static TMP_SEQ: AtomicU64 = AtomicU64::new(0);

    let parent = path.parent().unwrap_or_else(|| Path::new("."));
    fs::create_dir_all(parent)?;
    let file_name = path
        .file_name()
        .and_then(|name| name.to_str())
        .unwrap_or("task");
    let seq = TMP_SEQ.fetch_add(1, Ordering::Relaxed);
    let tmp = parent.join(format!(
        ".{file_name}.tmp.{}.{}.{seq}",
        std::process::id(),
        sanitize_task_id(task_id)
    ));

    let result = write_file_synced(&tmp, content).and_then(|()| fs::rename(&tmp, path));
    if result.is_err() {
        // Don't leave orphaned temp files behind; the original error is what gets reported.
        let _ = fs::remove_file(&tmp);
    }
    result
}

/// Creates or truncates `path`, writes `content`, and fsyncs it before returning.
fn write_file_synced(path: &Path, content: &[u8]) -> io::Result<()> {
    let mut file = OpenOptions::new()
        .create(true)
        .truncate(true)
        .write(true)
        .open(path)?;
    file.write_all(content)?;
    file.sync_all()
}

/// Maps `task_id` onto a filename-safe alphabet. ASCII letters, digits, `-` and `_` are kept;
/// every other character (including `.`, `/` and non-ASCII) becomes `_`.
fn sanitize_task_id(task_id: &str) -> String {
    task_id
        .chars()
        .map(|ch| match ch {
            'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_' => ch,
            _ => '_',
        })
        .collect()
}
211
#[cfg(test)]
mod tests {
    use std::thread;

    use super::*;

    #[test]
    fn atomic_write_temp_names_include_task_id() {
        let dir = tempfile::tempdir().expect("create temp dir");
        let path = dir.path().join("task.json");

        let left_path = path.clone();
        let left = thread::spawn(move || atomic_write(&left_path, b"left", "task-left"));
        let right_path = path.clone();
        let right = thread::spawn(move || atomic_write(&right_path, b"right", "task-right"));

        left.join().expect("join left").expect("write left");
        right.join().expect("join right").expect("write right");

        let content = fs::read_to_string(&path).expect("read final content");
        assert!(content == "left" || content == "right");

        // Every temp file must have been renamed away; only the final target may remain.
        // (Checking one guessed temp name would pass even if temp files leaked.)
        let leftovers: Vec<_> = fs::read_dir(dir.path())
            .expect("list temp dir")
            .map(|entry| entry.expect("dir entry").file_name())
            .filter(|name| name != "task.json")
            .collect();
        assert!(leftovers.is_empty(), "temp files left behind: {leftovers:?}");
    }

    #[test]
    fn kill_marker_does_not_replace_existing_exit_code() {
        let dir = tempfile::tempdir().expect("create temp dir");
        let path = dir.path().join("task.exit");
        fs::write(&path, "3\n").expect("seed exit code");

        write_kill_marker_if_absent(&path).expect("write kill marker");

        assert_eq!(
            read_exit_marker(&path).expect("read marker"),
            Some(ExitMarker::Code(3))
        );
    }
}
239
/// Creates (or truncates) the output capture file at `path`, creating parent directories first.
///
/// # Errors
///
/// Propagates failures from directory creation or [`File::create`].
pub fn create_capture_file(path: &Path) -> io::Result<File> {
    path.parent().map_or(Ok(()), fs::create_dir_all)?;
    File::create(path)
}
246
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch. The `u128 -> u64` cast
/// cannot truncate for any realistic date (overflow is ~5.8 * 10^8 years away).
pub fn unix_millis() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}