Skip to main content

aster/background/
persistence.rs

//! 后台任务持久化模块
//!
//! 负责保存和恢复后台任务状态
//!
//! # 功能
//! - 任务状态持久化
//! - Agent 状态持久化
//! - 自动过期清理
//! - 导入/导出功能

11use std::collections::HashMap;
12use std::path::PathBuf;
13use tokio::fs;
14
15use super::types::{
16    AgentStats, PersistedAgentState, PersistedTaskState, PersistenceStats, TaskStats, TaskType,
17};
18
/// Persistence configuration for [`PersistenceManager`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PersistenceOptions {
    /// Directory where task state files are written. Agent state files are
    /// written to an `agents` directory next to this one.
    pub storage_dir: PathBuf,
    /// NOTE(review): not read by `PersistenceManager` in this module; presumably
    /// consulted by callers to decide whether to restore tasks on startup — confirm.
    pub auto_restore: bool,
    /// Maximum age in milliseconds, measured from a task's `start_time`, after
    /// which a persisted task is treated as expired.
    pub expiry_time_ms: u64,
    /// NOTE(review): not read anywhere in this module — task and agent files are
    /// always written as pretty-printed, uncompressed JSON.
    pub compress: bool,
}

28impl Default for PersistenceOptions {
29    fn default() -> Self {
30        let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
31        Self {
32            storage_dir: home.join(".aster").join("background-tasks"),
33            auto_restore: true,
34            expiry_time_ms: 86_400_000, // 24 小时
35            compress: false,
36        }
37    }
38}
39
40/// 持久化管理器
41pub struct PersistenceManager {
42    storage_dir: PathBuf,
43    options: PersistenceOptions,
44}
45
46impl PersistenceManager {
47    /// 创建新的持久化管理器
48    pub async fn new(options: PersistenceOptions) -> Result<Self, String> {
49        let storage_dir = options.storage_dir.clone();
50
51        // 确保存储目录存在
52        if !storage_dir.exists() {
53            fs::create_dir_all(&storage_dir)
54                .await
55                .map_err(|e| format!("Failed to create storage directory: {}", e))?;
56        }
57
58        Ok(Self {
59            storage_dir,
60            options,
61        })
62    }
63
64    /// 获取任务文件路径
65    fn get_task_file_path(&self, id: &str, task_type: TaskType) -> PathBuf {
66        let prefix = match task_type {
67            TaskType::Bash => "bash",
68            TaskType::Agent => "agent",
69            TaskType::Generic => "generic",
70        };
71        self.storage_dir.join(format!("{}_{}.json", prefix, id))
72    }
73
74    /// 保存任务状态
75    pub async fn save_task(&self, task: &PersistedTaskState) -> Result<(), String> {
76        let file_path = self.get_task_file_path(&task.id, task.task_type);
77        let data = serde_json::to_string_pretty(task)
78            .map_err(|e| format!("Failed to serialize task: {}", e))?;
79
80        fs::write(&file_path, data)
81            .await
82            .map_err(|e| format!("Failed to write task file: {}", e))?;
83
84        Ok(())
85    }
86
87    /// 加载任务状态
88    pub async fn load_task(&self, id: &str, task_type: TaskType) -> Option<PersistedTaskState> {
89        let file_path = self.get_task_file_path(id, task_type);
90
91        if !file_path.exists() {
92            return None;
93        }
94
95        let data = fs::read_to_string(&file_path).await.ok()?;
96        let task: PersistedTaskState = serde_json::from_str(&data).ok()?;
97
98        // 检查是否过期
99        if self.is_expired(&task) {
100            let _ = self.delete_task(id, task_type).await;
101            return None;
102        }
103
104        Some(task)
105    }
106
107    /// 删除任务状态
108    pub async fn delete_task(&self, id: &str, task_type: TaskType) -> Result<(), String> {
109        let file_path = self.get_task_file_path(id, task_type);
110
111        if file_path.exists() {
112            fs::remove_file(&file_path)
113                .await
114                .map_err(|e| format!("Failed to delete task file: {}", e))?;
115        }
116
117        Ok(())
118    }
119
120    /// 检查任务是否过期
121    fn is_expired(&self, task: &PersistedTaskState) -> bool {
122        let now = chrono::Utc::now().timestamp_millis();
123        let age = (now - task.start_time) as u64;
124        age > self.options.expiry_time_ms
125    }
126
127    /// 保存 Agent 状态
128    pub async fn save_agent(&self, agent: &PersistedAgentState) -> Result<(), String> {
129        let agent_dir = self
130            .storage_dir
131            .parent()
132            .unwrap_or(&self.storage_dir)
133            .join("agents");
134
135        if !agent_dir.exists() {
136            fs::create_dir_all(&agent_dir)
137                .await
138                .map_err(|e| format!("Failed to create agent directory: {}", e))?;
139        }
140
141        let file_path = agent_dir.join(format!("{}.json", agent.id));
142        let data = serde_json::to_string_pretty(agent)
143            .map_err(|e| format!("Failed to serialize agent: {}", e))?;
144
145        fs::write(&file_path, data)
146            .await
147            .map_err(|e| format!("Failed to write agent file: {}", e))?;
148
149        Ok(())
150    }
151
152    /// 加载 Agent 状态
153    pub async fn load_agent(&self, id: &str) -> Option<PersistedAgentState> {
154        let agent_dir = self
155            .storage_dir
156            .parent()
157            .unwrap_or(&self.storage_dir)
158            .join("agents");
159        let file_path = agent_dir.join(format!("{}.json", id));
160
161        if !file_path.exists() {
162            return None;
163        }
164
165        let data = fs::read_to_string(&file_path).await.ok()?;
166        serde_json::from_str(&data).ok()
167    }
168
169    /// 列出所有保存的任务
170    pub async fn list_tasks(&self, task_type: Option<TaskType>) -> Vec<PersistedTaskState> {
171        let mut tasks = Vec::new();
172
173        let mut entries = match fs::read_dir(&self.storage_dir).await {
174            Ok(e) => e,
175            Err(_) => return tasks,
176        };
177
178        while let Ok(Some(entry)) = entries.next_entry().await {
179            let path = entry.path();
180            if path.extension().is_none_or(|e| e != "json") {
181                continue;
182            }
183
184            let file_name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
185
186            let file_type = if file_name.starts_with("bash_") {
187                Some(TaskType::Bash)
188            } else if file_name.starts_with("agent_") {
189                Some(TaskType::Agent)
190            } else if file_name.starts_with("generic_") {
191                Some(TaskType::Generic)
192            } else {
193                None
194            };
195
196            if let Some(ft) = file_type {
197                if task_type.is_none() || task_type == Some(ft) {
198                    if let Ok(data) = fs::read_to_string(&path).await {
199                        if let Ok(task) = serde_json::from_str::<PersistedTaskState>(&data) {
200                            tasks.push(task);
201                        }
202                    }
203                }
204            }
205        }
206
207        tasks
208    }
209
210    /// 列出所有保存的 Agent
211    pub async fn list_agents(&self) -> Vec<PersistedAgentState> {
212        let mut agents = Vec::new();
213        let agent_dir = self
214            .storage_dir
215            .parent()
216            .unwrap_or(&self.storage_dir)
217            .join("agents");
218
219        if !agent_dir.exists() {
220            return agents;
221        }
222
223        let mut entries = match fs::read_dir(&agent_dir).await {
224            Ok(e) => e,
225            Err(_) => return agents,
226        };
227
228        while let Ok(Some(entry)) = entries.next_entry().await {
229            let path = entry.path();
230            if path.extension().is_some_and(|e| e == "json") {
231                if let Ok(data) = fs::read_to_string(&path).await {
232                    if let Ok(agent) = serde_json::from_str::<PersistedAgentState>(&data) {
233                        agents.push(agent);
234                    }
235                }
236            }
237        }
238
239        agents
240    }
241
242    /// 清理过期的任务
243    pub async fn cleanup_expired(&self) -> usize {
244        let tasks = self.list_tasks(None).await;
245        let mut cleaned = 0;
246
247        for task in tasks {
248            if self.is_expired(&task) && self.delete_task(&task.id, task.task_type).await.is_ok() {
249                cleaned += 1;
250            }
251        }
252
253        cleaned
254    }
255
256    /// 清理已完成的任务
257    pub async fn cleanup_completed(&self) -> usize {
258        let tasks = self.list_tasks(None).await;
259        let mut cleaned = 0;
260
261        for task in tasks {
262            if (task.status == "completed" || task.status == "failed")
263                && self.delete_task(&task.id, task.task_type).await.is_ok()
264            {
265                cleaned += 1;
266            }
267        }
268
269        cleaned
270    }
271
272    /// 清除所有任务
273    pub async fn clear_all(&self) -> usize {
274        let mut cleared = 0;
275
276        let mut entries = match fs::read_dir(&self.storage_dir).await {
277            Ok(e) => e,
278            Err(_) => return cleared,
279        };
280
281        while let Ok(Some(entry)) = entries.next_entry().await {
282            let path = entry.path();
283            if path.extension().is_some_and(|e| e == "json") && fs::remove_file(&path).await.is_ok()
284            {
285                cleared += 1;
286            }
287        }
288
289        cleared
290    }
291
292    /// 获取统计信息
293    pub async fn get_stats(&self) -> PersistenceStats {
294        let tasks = self.list_tasks(None).await;
295        let agents = self.list_agents().await;
296
297        let mut tasks_by_status: HashMap<String, usize> = HashMap::new();
298        for task in &tasks {
299            *tasks_by_status.entry(task.status.clone()).or_insert(0) += 1;
300        }
301
302        let mut agents_by_status: HashMap<String, usize> = HashMap::new();
303        for agent in &agents {
304            *agents_by_status.entry(agent.status.clone()).or_insert(0) += 1;
305        }
306
307        PersistenceStats {
308            tasks: TaskStats {
309                total: tasks.len(),
310                by_status: tasks_by_status,
311            },
312            agents: AgentStats {
313                total: agents.len(),
314                by_status: agents_by_status,
315            },
316            storage_dir: self.storage_dir.to_string_lossy().to_string(),
317            expiry_time_ms: self.options.expiry_time_ms,
318        }
319    }
320}