Skip to main content

aft/bash_background/
persistence.rs

1use std::fs::{self, File, OpenOptions};
2use std::io::{self, Read, Write};
3use std::path::{Path, PathBuf};
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use serde::{Deserialize, Serialize};
7
8use crate::backup::hash_session;
9
10use super::BgTaskStatus;
11
/// Version number stamped into the `schema_version` field of every task
/// record created by [`PersistedTask::starting`].
const SCHEMA_VERSION: u32 = 1;
13
/// On-disk locations belonging to a single background task.
///
/// Built by [`task_paths`]; every file lives directly inside `dir` and is
/// named after the task id plus a fixed extension.
#[derive(Debug, Clone)]
pub struct TaskPaths {
    /// Directory that holds all files of the task's session.
    pub dir: PathBuf,
    /// `{task_id}.json` — the serialized task record.
    pub json: PathBuf,
    /// `{task_id}.stdout` file.
    pub stdout: PathBuf,
    /// `{task_id}.stderr` file.
    pub stderr: PathBuf,
    /// `{task_id}.exit` marker file.
    pub exit: PathBuf,
}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct PersistedTask {
25    pub schema_version: u32,
26    pub task_id: String,
27    pub session_id: String,
28    pub command: String,
29    pub workdir: PathBuf,
30    pub status: BgTaskStatus,
31    pub started_at: u64,
32    pub finished_at: Option<u64>,
33    pub duration_ms: Option<u64>,
34    pub timeout_ms: Option<u64>,
35    pub exit_code: Option<i32>,
36    pub child_pid: Option<u32>,
37    pub pgid: Option<i32>,
38    pub completion_delivered: bool,
39    #[serde(default = "default_notify_on_completion")]
40    pub notify_on_completion: bool,
41    pub status_reason: Option<String>,
42}
43
/// Serde default for `PersistedTask::notify_on_completion`: records that
/// do not contain the key are read as notifying.
fn default_notify_on_completion() -> bool {
    const NOTIFY_WHEN_UNSPECIFIED: bool = true;
    NOTIFY_WHEN_UNSPECIFIED
}
47
/// Parsed content of a task's exit-marker file (see [`read_exit_marker`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExitMarker {
    /// The marker held a decimal integer exit code.
    Code(i32),
    /// The marker held the literal text `killed`.
    Killed,
}
53
54impl PersistedTask {
55    pub fn starting(
56        task_id: String,
57        session_id: String,
58        command: String,
59        workdir: PathBuf,
60        timeout_ms: Option<u64>,
61        notify_on_completion: bool,
62    ) -> Self {
63        Self {
64            schema_version: SCHEMA_VERSION,
65            task_id,
66            session_id,
67            command,
68            workdir,
69            status: BgTaskStatus::Starting,
70            started_at: unix_millis(),
71            finished_at: None,
72            duration_ms: None,
73            timeout_ms,
74            exit_code: None,
75            child_pid: None,
76            pgid: None,
77            completion_delivered: !notify_on_completion,
78            notify_on_completion,
79            status_reason: None,
80        }
81    }
82
83    pub fn is_terminal(&self) -> bool {
84        self.status.is_terminal()
85    }
86
87    pub fn mark_running(&mut self, child_pid: u32, pgid: i32) {
88        self.status = BgTaskStatus::Running;
89        self.child_pid = Some(child_pid);
90        self.pgid = Some(pgid);
91    }
92
93    pub fn mark_terminal(
94        &mut self,
95        status: BgTaskStatus,
96        exit_code: Option<i32>,
97        reason: Option<String>,
98    ) {
99        let finished_at = unix_millis();
100        self.status = status;
101        self.exit_code = exit_code;
102        self.finished_at = Some(finished_at);
103        self.duration_ms = Some(finished_at.saturating_sub(self.started_at));
104        self.child_pid = None;
105        self.status_reason = reason;
106        self.completion_delivered = !self.notify_on_completion;
107    }
108}
109
110pub fn session_tasks_dir(storage_dir: &Path, session_id: &str) -> PathBuf {
111    storage_dir
112        .join("bash-tasks")
113        .join(hash_session(session_id))
114}
115
116pub fn task_paths(storage_dir: &Path, session_id: &str, task_id: &str) -> TaskPaths {
117    let dir = session_tasks_dir(storage_dir, session_id);
118    TaskPaths {
119        json: dir.join(format!("{task_id}.json")),
120        stdout: dir.join(format!("{task_id}.stdout")),
121        stderr: dir.join(format!("{task_id}.stderr")),
122        exit: dir.join(format!("{task_id}.exit")),
123        dir,
124    }
125}
126
127pub fn read_task(path: &Path) -> io::Result<PersistedTask> {
128    let content = fs::read_to_string(path)?;
129    serde_json::from_str(&content).map_err(io::Error::other)
130}
131
132pub fn write_task(path: &Path, task: &PersistedTask) -> io::Result<()> {
133    if let Some(parent) = path.parent() {
134        fs::create_dir_all(parent)?;
135    }
136    let content = serde_json::to_vec_pretty(task).map_err(io::Error::other)?;
137    atomic_write(path, &content, &task.task_id)
138}
139
140pub fn update_task<F>(path: &Path, update: F) -> io::Result<PersistedTask>
141where
142    F: FnOnce(&mut PersistedTask),
143{
144    let mut task = read_task(path)?;
145    let original_terminal = task.is_terminal();
146    let original = task.clone();
147    update(&mut task);
148    if original_terminal {
149        let completion_delivered = task.completion_delivered;
150        task = original;
151        task.completion_delivered = completion_delivered;
152    }
153    write_task(path, &task)?;
154    Ok(task)
155}
156
157pub fn write_kill_marker_if_absent(path: &Path) -> io::Result<()> {
158    if path.exists() {
159        return Ok(());
160    }
161    atomic_write(path, b"killed", "kill")
162}
163
164pub fn read_exit_marker(path: &Path) -> io::Result<Option<ExitMarker>> {
165    let mut file = match File::open(path) {
166        Ok(file) => file,
167        Err(error) if error.kind() == io::ErrorKind::NotFound => return Ok(None),
168        Err(error) => return Err(error),
169    };
170    let mut content = String::new();
171    file.read_to_string(&mut content)?;
172    let content = content.trim();
173    if content.is_empty() {
174        return Ok(None);
175    }
176    if content == "killed" {
177        return Ok(Some(ExitMarker::Killed));
178    }
179    match content.parse::<i32>() {
180        Ok(code) => Ok(Some(ExitMarker::Code(code))),
181        Err(_) => Ok(None),
182    }
183}
184
185pub fn atomic_write(path: &Path, content: &[u8], task_id: &str) -> io::Result<()> {
186    let parent = path.parent().unwrap_or_else(|| Path::new("."));
187    fs::create_dir_all(parent)?;
188    let file_name = path
189        .file_name()
190        .and_then(|name| name.to_str())
191        .unwrap_or("task");
192    let tmp = parent.join(format!(
193        ".{file_name}.tmp.{}.{}",
194        std::process::id(),
195        sanitize_task_id(task_id)
196    ));
197    {
198        let mut file = OpenOptions::new()
199            .create(true)
200            .truncate(true)
201            .write(true)
202            .open(&tmp)?;
203        file.write_all(content)?;
204        file.sync_all()?;
205    }
206    fs::rename(&tmp, path)?;
207    Ok(())
208}
209
210fn sanitize_task_id(task_id: &str) -> String {
211    task_id
212        .chars()
213        .map(|ch| match ch {
214            'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_' => ch,
215            _ => '_',
216        })
217        .collect()
218}
219
/// Opens `path` for writing, creating it (and its parent directory) when
/// missing and truncating any existing content.
///
/// # Errors
///
/// Returns the I/O error from creating the directory or opening the file.
pub fn create_capture_file(path: &Path) -> io::Result<File> {
    let parent_dir = path.parent();
    if let Some(dir) = parent_dir {
        fs::create_dir_all(dir)?;
    }
    OpenOptions::new()
        .write(true)
        .create(true)
        .truncate(true)
        .open(path)
}
226
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` when the system clock reads earlier than the epoch, and
/// saturates at `u64::MAX` instead of silently truncating if the
/// millisecond count ever exceeds 64 bits.
pub fn unix_millis() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| {
            u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
        })
}
233
#[cfg(test)]
mod tests {
    use std::thread;

    use super::*;

    /// Two threads write the same destination under different task ids.
    /// Both writes must succeed, the destination must hold one of the two
    /// payloads in full, and no temp file may be left in the directory.
    #[test]
    fn atomic_write_temp_names_include_task_id() {
        let dir = tempfile::tempdir().expect("create temp dir");
        let path = dir.path().join("task.json");

        let left_path = path.clone();
        let left = thread::spawn(move || atomic_write(&left_path, b"left", "task-left"));
        let right_path = path.clone();
        let right = thread::spawn(move || atomic_write(&right_path, b"right", "task-right"));

        left.join().expect("join left").expect("write left");
        right.join().expect("join right").expect("write right");

        let content = fs::read_to_string(&path).expect("read final content");
        assert!(content == "left" || content == "right");

        // Temp names carry the task id, so probing one fixed name proves
        // nothing; list the directory and require that only the destination
        // file is left.
        let entries: Vec<_> = fs::read_dir(dir.path())
            .expect("list temp dir")
            .map(|entry| entry.expect("read dir entry").file_name())
            .collect();
        assert_eq!(entries, [std::ffi::OsString::from("task.json")]);
    }
}
260}