Skip to main content

aster/checkpoint/
storage.rs

1//! 检查点存储管理
2//!
3//! 负责检查点的磁盘存储、加载和清理
4
5use std::path::PathBuf;
6use tokio::fs;
7
8use super::session::CheckpointSession;
9use super::types::*;
10
/// On-disk checkpoint storage.
///
/// Holds the root directory under which one sub-directory per session is kept.
#[derive(Debug, Clone)]
pub struct CheckpointStorage {
    /// Root directory that contains one directory per session.
    checkpoint_dir: PathBuf,
}
15
16impl CheckpointStorage {
17    /// 创建新的存储管理器
18    pub fn new() -> Self {
19        let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
20        Self {
21            checkpoint_dir: home.join(".aster").join("checkpoints"),
22        }
23    }
24
25    /// 确保检查点目录存在
26    pub async fn ensure_checkpoint_dir(&self) -> Result<(), String> {
27        if !self.checkpoint_dir.exists() {
28            fs::create_dir_all(&self.checkpoint_dir)
29                .await
30                .map_err(|e| format!("Failed to create checkpoint directory: {}", e))?;
31        }
32        Ok(())
33    }
34
35    /// 获取会话目录
36    fn get_session_dir(&self, session_id: &str) -> PathBuf {
37        self.checkpoint_dir.join(session_id)
38    }
39
40    /// 保存检查点到磁盘
41    pub async fn save_checkpoint(
42        &self,
43        session_id: &str,
44        checkpoint: &FileCheckpoint,
45    ) -> Result<(), String> {
46        let session_dir = self.get_session_dir(session_id);
47        if !session_dir.exists() {
48            fs::create_dir_all(&session_dir)
49                .await
50                .map_err(|e| format!("Failed to create session directory: {}", e))?;
51        }
52
53        let file_hash = self.get_path_hash(&checkpoint.path);
54        let checkpoint_file =
55            session_dir.join(format!("{}-{}.json", file_hash, checkpoint.timestamp));
56
57        let data = serde_json::to_string_pretty(checkpoint)
58            .map_err(|e| format!("Failed to serialize checkpoint: {}", e))?;
59
60        fs::write(&checkpoint_file, data)
61            .await
62            .map_err(|e| format!("Failed to write checkpoint file: {}", e))?;
63
64        Ok(())
65    }
66
67    /// 加载会话
68    pub async fn load_session(&self, session_id: &str) -> Result<CheckpointSession, String> {
69        let session_dir = self.get_session_dir(session_id);
70        if !session_dir.exists() {
71            return Err("Session not found".to_string());
72        }
73
74        let mut session = CheckpointSession::new(
75            Some(session_id.to_string()),
76            ".".to_string(),
77            DEFAULT_AUTO_CHECKPOINT_INTERVAL,
78        );
79
80        let mut entries = fs::read_dir(&session_dir)
81            .await
82            .map_err(|e| format!("Failed to read session directory: {}", e))?;
83
84        while let Ok(Some(entry)) = entries.next_entry().await {
85            let path = entry.path();
86            if path.extension().is_some_and(|e| e == "json") {
87                if path.file_name().is_some_and(|n| n == "session.json") {
88                    continue;
89                }
90
91                if let Ok(data) = fs::read_to_string(&path).await {
92                    if let Ok(checkpoint) = serde_json::from_str::<FileCheckpoint>(&data) {
93                        session
94                            .checkpoints
95                            .entry(checkpoint.path.clone())
96                            .or_default()
97                            .push(checkpoint);
98                    }
99                }
100            }
101        }
102
103        // 按时间戳排序
104        for checkpoints in session.checkpoints.values_mut() {
105            checkpoints.sort_by_key(|c| c.timestamp);
106        }
107
108        // 更新索引
109        for (path, checkpoints) in &session.checkpoints {
110            session
111                .current_index
112                .insert(path.clone(), checkpoints.len().saturating_sub(1));
113        }
114
115        Ok(session)
116    }
117
118    /// 清理旧检查点
119    pub async fn cleanup_old_checkpoints(&self) {
120        let cutoff_time = chrono::Utc::now().timestamp_millis()
121            - (CHECKPOINT_RETENTION_DAYS as i64 * 24 * 60 * 60 * 1000);
122
123        if let Ok(mut entries) = fs::read_dir(&self.checkpoint_dir).await {
124            while let Ok(Some(entry)) = entries.next_entry().await {
125                let path = entry.path();
126                if path.is_dir() {
127                    if let Ok(metadata) = fs::metadata(&path).await {
128                        if let Ok(modified) = metadata.modified() {
129                            let modified_ms = modified
130                                .duration_since(std::time::UNIX_EPOCH)
131                                .map(|d| d.as_millis() as i64)
132                                .unwrap_or(0);
133
134                            if modified_ms < cutoff_time {
135                                let _ = fs::remove_dir_all(&path).await;
136                            }
137                        }
138                    }
139                }
140            }
141        }
142    }
143
144    /// 压缩内容(简化实现,使用 base64 编码)
145    pub fn compress_content(&self, content: &str) -> String {
146        use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
147        BASE64.encode(content.as_bytes())
148    }
149
150    /// 解压缩内容
151    pub fn decompress_content(&self, compressed: &str) -> String {
152        use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
153        if let Ok(data) = BASE64.decode(compressed) {
154            if let Ok(s) = String::from_utf8(data) {
155                return s;
156            }
157        }
158        compressed.to_string()
159    }
160
161    /// 获取路径哈希
162    fn get_path_hash(&self, path: &str) -> String {
163        use sha2::{Digest, Sha256};
164        let mut hasher = Sha256::new();
165        hasher.update(path.as_bytes());
166        let result = hasher.finalize();
167        hex::encode(&result[..8])
168    }
169}
170
171impl Default for CheckpointStorage {
172    fn default() -> Self {
173        Self::new()
174    }
175}