Skip to main content

aft/bash_background/
persistence.rs

1use std::fs::{self, File, OpenOptions};
2use std::io::{self, Read, Write};
3use std::path::{Path, PathBuf};
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use serde::{Deserialize, Serialize};
7
8use crate::backup::hash_session;
9
10use super::BgTaskStatus;
11
/// Schema version stamped into every [`PersistedTask`] built by
/// [`PersistedTask::starting`]. Note: `read_task` does not validate it.
const SCHEMA_VERSION: u32 = 1;
13
/// On-disk locations for one background task, as produced by [`task_paths`].
///
/// All four files sit side by side in `dir` and use the task id as their
/// file stem.
#[derive(Debug, Clone)]
pub struct TaskPaths {
    /// Per-session directory returned by [`session_tasks_dir`].
    pub dir: PathBuf,
    /// `<task_id>.json`: the serialized [`PersistedTask`] record.
    pub json: PathBuf,
    /// `<task_id>.stdout`: presumably the task's captured standard output.
    pub stdout: PathBuf,
    /// `<task_id>.stderr`: presumably the task's captured standard error.
    pub stderr: PathBuf,
    /// `<task_id>.exit`: exit marker file parsed by [`read_exit_marker`].
    pub exit: PathBuf,
}
22
/// Persistent record of one background task, stored as pretty-printed JSON
/// by [`write_task`] and loaded by [`read_task`].
///
/// Field order here is the key order of the serialized JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistedTask {
    /// Record format version; set to `SCHEMA_VERSION` by `starting`.
    pub schema_version: u32,
    /// Task identifier; also the file stem used by [`task_paths`].
    pub task_id: String,
    /// Owning session; hashed by [`session_tasks_dir`] to pick the directory.
    pub session_id: String,
    /// Command text the task runs.
    pub command: String,
    /// Working directory recorded for the task.
    pub workdir: PathBuf,
    /// Current lifecycle status (`Starting` -> `Running` -> terminal).
    pub status: BgTaskStatus,
    /// Unix epoch milliseconds (from [`unix_millis`]) when the record was created.
    pub started_at: u64,
    /// Unix epoch milliseconds set by `mark_terminal`.
    pub finished_at: Option<u64>,
    /// `finished_at - started_at` (saturating), set by `mark_terminal`.
    pub duration_ms: Option<u64>,
    /// Timeout in milliseconds as supplied to `starting`; not enforced in this file.
    pub timeout_ms: Option<u64>,
    /// Exit code recorded by `mark_terminal`, if any.
    pub exit_code: Option<i32>,
    /// Child pid set by `mark_running`; cleared by `mark_terminal`.
    pub child_pid: Option<u32>,
    /// Process-group id set by `mark_running`; `mark_terminal` leaves it as is.
    pub pgid: Option<i32>,
    /// `true` at creation, reset to `false` by `mark_terminal`. Presumably set
    /// back to `true` once the completion is reported — confirm with callers.
    pub completion_delivered: bool,
    /// Optional free-form reason attached by `mark_terminal`.
    pub status_reason: Option<String>,
}
41
/// Parsed contents of a task's `.exit` file, as returned by [`read_exit_marker`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExitMarker {
    /// The file held an integer exit code.
    Code(i32),
    /// The file held the literal `killed`, as written by
    /// [`write_kill_marker_if_absent`].
    Killed,
}
47
48impl PersistedTask {
49    pub fn starting(
50        task_id: String,
51        session_id: String,
52        command: String,
53        workdir: PathBuf,
54        timeout_ms: Option<u64>,
55    ) -> Self {
56        Self {
57            schema_version: SCHEMA_VERSION,
58            task_id,
59            session_id,
60            command,
61            workdir,
62            status: BgTaskStatus::Starting,
63            started_at: unix_millis(),
64            finished_at: None,
65            duration_ms: None,
66            timeout_ms,
67            exit_code: None,
68            child_pid: None,
69            pgid: None,
70            completion_delivered: true,
71            status_reason: None,
72        }
73    }
74
75    pub fn is_terminal(&self) -> bool {
76        self.status.is_terminal()
77    }
78
79    pub fn mark_running(&mut self, child_pid: u32, pgid: i32) {
80        self.status = BgTaskStatus::Running;
81        self.child_pid = Some(child_pid);
82        self.pgid = Some(pgid);
83    }
84
85    pub fn mark_terminal(
86        &mut self,
87        status: BgTaskStatus,
88        exit_code: Option<i32>,
89        reason: Option<String>,
90    ) {
91        let finished_at = unix_millis();
92        self.status = status;
93        self.exit_code = exit_code;
94        self.finished_at = Some(finished_at);
95        self.duration_ms = Some(finished_at.saturating_sub(self.started_at));
96        self.child_pid = None;
97        self.status_reason = reason;
98        self.completion_delivered = false;
99    }
100}
101
102pub fn session_tasks_dir(storage_dir: &Path, session_id: &str) -> PathBuf {
103    storage_dir
104        .join("bash-tasks")
105        .join(hash_session(session_id))
106}
107
108pub fn task_paths(storage_dir: &Path, session_id: &str, task_id: &str) -> TaskPaths {
109    let dir = session_tasks_dir(storage_dir, session_id);
110    TaskPaths {
111        json: dir.join(format!("{task_id}.json")),
112        stdout: dir.join(format!("{task_id}.stdout")),
113        stderr: dir.join(format!("{task_id}.stderr")),
114        exit: dir.join(format!("{task_id}.exit")),
115        dir,
116    }
117}
118
119pub fn read_task(path: &Path) -> io::Result<PersistedTask> {
120    let content = fs::read_to_string(path)?;
121    serde_json::from_str(&content).map_err(io::Error::other)
122}
123
124pub fn write_task(path: &Path, task: &PersistedTask) -> io::Result<()> {
125    if let Some(parent) = path.parent() {
126        fs::create_dir_all(parent)?;
127    }
128    let content = serde_json::to_vec_pretty(task).map_err(io::Error::other)?;
129    atomic_write(path, &content)
130}
131
132pub fn update_task<F>(path: &Path, update: F) -> io::Result<PersistedTask>
133where
134    F: FnOnce(&mut PersistedTask),
135{
136    let mut task = read_task(path)?;
137    let original_terminal = task.is_terminal();
138    let original = task.clone();
139    update(&mut task);
140    if original_terminal {
141        let completion_delivered = task.completion_delivered;
142        task = original;
143        task.completion_delivered = completion_delivered;
144    }
145    write_task(path, &task)?;
146    Ok(task)
147}
148
/// Writes the literal `killed` exit marker to `path` unless a marker already
/// exists there.
///
/// The file is opened with `create_new`, which makes "is it absent?" and
/// "create it" one atomic step. A separate `exists()` check followed by a
/// rename would overwrite an exit code that the task wrote between the two
/// steps; this way a code that lands first is always kept.
///
/// A reader racing this write may briefly see an empty file. `read_exit_marker`
/// treats that as "no marker". Missing parent directories are created.
///
/// # Errors
/// Returns any I/O error other than the file already existing. On a failed
/// write the partially written marker is removed (best effort) so it does not
/// block later attempts.
pub fn write_kill_marker_if_absent(path: &Path) -> io::Result<()> {
    if let Some(parent) = path.parent() {
        fs::create_dir_all(parent)?;
    }
    let mut file = match OpenOptions::new().write(true).create_new(true).open(path) {
        Ok(file) => file,
        // Someone (most likely the task itself) already wrote a marker: keep it.
        Err(error) if error.kind() == io::ErrorKind::AlreadyExists => return Ok(()),
        Err(error) => return Err(error),
    };
    let written = file.write_all(b"killed").and_then(|()| file.sync_all());
    if written.is_err() {
        drop(file);
        let _ = fs::remove_file(path);
    }
    written
}
155
156pub fn read_exit_marker(path: &Path) -> io::Result<Option<ExitMarker>> {
157    let mut file = match File::open(path) {
158        Ok(file) => file,
159        Err(error) if error.kind() == io::ErrorKind::NotFound => return Ok(None),
160        Err(error) => return Err(error),
161    };
162    let mut content = String::new();
163    file.read_to_string(&mut content)?;
164    let content = content.trim();
165    if content.is_empty() {
166        return Ok(None);
167    }
168    if content == "killed" {
169        return Ok(Some(ExitMarker::Killed));
170    }
171    match content.parse::<i32>() {
172        Ok(code) => Ok(Some(ExitMarker::Code(code))),
173        Err(_) => Ok(None),
174    }
175}
176
/// Atomically replaces the contents of `path` with `content`.
///
/// The steps are:
/// 1. Write the bytes to a hidden temporary file in the same directory.
/// 2. `fsync` the temporary file.
/// 3. Rename it over `path`.
///
/// Readers therefore see either the previous file or the complete new one.
/// Missing parent directories are created.
///
/// The temporary name combines the process id with a process-wide counter.
/// This ensures two threads writing the same `path` at once never share, and
/// truncate, a single temporary file. If writing or renaming fails, the
/// temporary file is removed on a best-effort basis.
///
/// # Errors
/// Returns any I/O error from creating the directory, writing or syncing the
/// temporary file, or renaming it into place.
pub fn atomic_write(path: &Path, content: &[u8]) -> io::Result<()> {
    use std::sync::atomic::{AtomicU64, Ordering};

    // Only uniqueness matters, so relaxed ordering is sufficient.
    static NEXT_TMP_ID: AtomicU64 = AtomicU64::new(0);

    let parent = path.parent().unwrap_or_else(|| Path::new("."));
    fs::create_dir_all(parent)?;
    let file_name = path
        .file_name()
        .and_then(|name| name.to_str())
        .unwrap_or("task");
    let tmp_id = NEXT_TMP_ID.fetch_add(1, Ordering::Relaxed);
    let tmp = parent.join(format!(
        ".{file_name}.tmp.{}.{tmp_id}",
        std::process::id()
    ));

    let write_tmp = || -> io::Result<()> {
        let mut file = OpenOptions::new()
            .create(true)
            .truncate(true)
            .write(true)
            .open(&tmp)?;
        file.write_all(content)?;
        file.sync_all()
    };
    let result = write_tmp().and_then(|()| fs::rename(&tmp, path));
    if result.is_err() {
        // Don't leak the temporary file; the original error is what matters.
        let _ = fs::remove_file(&tmp);
    }
    result
}
197
/// Creates, or truncates, the capture file at `path` and returns the open
/// handle. Missing parent directories are created first.
///
/// # Errors
/// Propagates errors from [`fs::create_dir_all`] and [`File::create`].
pub fn create_capture_file(path: &Path) -> io::Result<File> {
    let parent_ready = path.parent().map_or(Ok(()), fs::create_dir_all);
    parent_ready?;
    File::create(path)
}
204
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch. The
/// `u128` millisecond count is narrowed to `u64` with `as`.
pub fn unix_millis() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}