Skip to main content

aster/rewind/
file_history.rs

1//! 文件历史跟踪系统
2//!
3//! 提供文件修改跟踪、快照创建、状态恢复功能
4
5use serde::{Deserialize, Serialize};
6use sha2::{Digest, Sha256};
7use std::collections::{HashMap, HashSet};
8use std::fs;
9use std::path::{Path, PathBuf};
10
11/// 文件备份信息
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct FileBackup {
14    /// 备份文件名
15    pub backup_file_name: Option<String>,
16    /// 原始文件的最后修改时间
17    pub mtime: u64,
18    /// 版本号
19    pub version: u32,
20    /// 文件哈希
21    pub hash: Option<String>,
22}
23
24/// 快照数据结构
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct FileSnapshot {
27    /// 关联的消息 ID
28    pub message_id: String,
29    /// 快照创建时间
30    pub timestamp: i64,
31    /// 被跟踪文件的备份信息
32    pub tracked_file_backups: HashMap<String, FileBackup>,
33}
34
35/// Rewind 结果
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct RewindResult {
38    pub success: bool,
39    pub files_changed: Vec<String>,
40    pub insertions: u32,
41    pub deletions: u32,
42    pub error: Option<String>,
43}
44
45impl RewindResult {
46    pub fn success(files_changed: Vec<String>, insertions: u32, deletions: u32) -> Self {
47        Self {
48            success: true,
49            files_changed,
50            insertions,
51            deletions,
52            error: None,
53        }
54    }
55
56    pub fn error(msg: impl Into<String>) -> Self {
57        Self {
58            success: false,
59            files_changed: vec![],
60            insertions: 0,
61            deletions: 0,
62            error: Some(msg.into()),
63        }
64    }
65}
66
67/// 文件历史管理器
68pub struct FileHistoryManager {
69    session_id: String,
70    tracked_files: HashSet<String>,
71    snapshots: Vec<FileSnapshot>,
72    backup_dir: PathBuf,
73    enabled: bool,
74}
75
76impl FileHistoryManager {
77    /// 创建新的文件历史管理器
78    pub fn new(session_id: impl Into<String>) -> Self {
79        let session_id = session_id.into();
80        let backup_dir = dirs::config_dir()
81            .unwrap_or_else(|| PathBuf::from("~/.config"))
82            .join("aster")
83            .join("file-history")
84            .join(&session_id);
85
86        // 确保备份目录存在
87        let _ = fs::create_dir_all(&backup_dir);
88
89        Self {
90            session_id,
91            tracked_files: HashSet::new(),
92            snapshots: Vec::new(),
93            backup_dir,
94            enabled: true,
95        }
96    }
97
98    /// 检查是否启用
99    pub fn is_enabled(&self) -> bool {
100        self.enabled
101    }
102
103    /// 启用/禁用文件历史
104    pub fn set_enabled(&mut self, enabled: bool) {
105        self.enabled = enabled;
106    }
107
108    /// 开始跟踪文件
109    pub fn track_file(&mut self, file_path: impl AsRef<Path>) {
110        if !self.enabled {
111            return;
112        }
113        let path = self.normalize_path(file_path.as_ref());
114        self.tracked_files.insert(path);
115    }
116
117    /// 检查文件是否被跟踪
118    pub fn is_tracked(&self, file_path: impl AsRef<Path>) -> bool {
119        let path = self.normalize_path(file_path.as_ref());
120        self.tracked_files.contains(&path)
121    }
122
123    /// 在文件修改前创建备份
124    pub fn backup_file_before_change(&mut self, file_path: impl AsRef<Path>) -> Option<FileBackup> {
125        if !self.enabled {
126            return None;
127        }
128
129        let path = file_path.as_ref();
130        let normalized = self.normalize_path(path);
131
132        // 如果文件不存在,返回空备份
133        if !path.exists() {
134            return Some(FileBackup {
135                backup_file_name: None,
136                mtime: 0,
137                version: 1,
138                hash: None,
139            });
140        }
141
142        // 读取文件内容并计算哈希
143        let content = fs::read(path).ok()?;
144        let hash = self.compute_hash(&content);
145        let mtime = fs::metadata(path)
146            .ok()?
147            .modified()
148            .ok()?
149            .duration_since(std::time::UNIX_EPOCH)
150            .ok()?
151            .as_secs();
152
153        // 生成备份文件名
154        let backup_file_name = self.generate_backup_file_name(path, &hash);
155        let backup_path = self.backup_dir.join(&backup_file_name);
156
157        // 如果备份不存在,创建它
158        if !backup_path.exists() {
159            let _ = fs::write(&backup_path, &content);
160        }
161
162        // 开始跟踪这个文件
163        self.tracked_files.insert(normalized);
164
165        Some(FileBackup {
166            backup_file_name: Some(backup_file_name),
167            mtime,
168            version: 1,
169            hash: Some(hash),
170        })
171    }
172
173    /// 创建快照
174    pub fn create_snapshot(&mut self, message_id: impl Into<String>) {
175        if !self.enabled {
176            return;
177        }
178
179        let mut tracked_file_backups = HashMap::new();
180
181        for file_path in self.tracked_files.clone() {
182            if let Some(backup) = self.backup_file_before_change(&file_path) {
183                tracked_file_backups.insert(file_path, backup);
184            }
185        }
186
187        self.snapshots.push(FileSnapshot {
188            message_id: message_id.into(),
189            timestamp: chrono::Utc::now().timestamp(),
190            tracked_file_backups,
191        });
192    }
193
194    /// 检查是否有指定消息的快照
195    pub fn has_snapshot(&self, message_id: &str) -> bool {
196        self.snapshots.iter().any(|s| s.message_id == message_id)
197    }
198
199    /// 获取快照列表
200    pub fn get_snapshots(&self) -> &[FileSnapshot] {
201        &self.snapshots
202    }
203
204    /// 回退到指定消息的状态
205    pub fn rewind_to_message(&self, message_id: &str, dry_run: bool) -> RewindResult {
206        if !self.enabled {
207            return RewindResult::error("文件历史已禁用");
208        }
209
210        // 查找快照
211        let snapshot = self
212            .snapshots
213            .iter()
214            .rev()
215            .find(|s| s.message_id == message_id);
216        let snapshot = match snapshot {
217            Some(s) => s,
218            None => return RewindResult::error(format!("未找到消息 {} 的快照", message_id)),
219        };
220
221        self.apply_snapshot(snapshot, dry_run)
222    }
223
224    /// 应用快照
225    fn apply_snapshot(&self, snapshot: &FileSnapshot, dry_run: bool) -> RewindResult {
226        let mut files_changed = Vec::new();
227        let mut insertions = 0u32;
228        let mut deletions = 0u32;
229
230        // 遍历快照中的所有文件备份
231        for (file_path, backup) in &snapshot.tracked_file_backups {
232            let path = Path::new(file_path);
233
234            if backup.backup_file_name.is_none() {
235                // 文件在快照时不存在,应该删除
236                if path.exists() {
237                    deletions += self.count_lines(path);
238                    if !dry_run {
239                        let _ = fs::remove_file(path);
240                    }
241                    files_changed.push(file_path.clone());
242                }
243            } else if let Some(ref backup_name) = backup.backup_file_name {
244                // 恢复文件内容
245                let backup_path = self.backup_dir.join(backup_name);
246                if !backup_path.exists() {
247                    continue;
248                }
249
250                // 检查文件是否需要恢复(通过哈希比较)
251                let current_hash = if path.exists() {
252                    fs::read(path).ok().map(|c| self.compute_hash(&c))
253                } else {
254                    None
255                };
256
257                let needs_restore = current_hash.as_ref() != backup.hash.as_ref();
258
259                if needs_restore {
260                    let (ins, del) = self.calculate_diff(path, &backup_path);
261                    insertions += ins;
262                    deletions += del;
263
264                    if !dry_run {
265                        if let Ok(content) = fs::read(&backup_path) {
266                            if let Some(parent) = path.parent() {
267                                let _ = fs::create_dir_all(parent);
268                            }
269                            let _ = fs::write(path, content);
270                        }
271                    }
272                    files_changed.push(file_path.clone());
273                }
274            }
275        }
276
277        RewindResult::success(files_changed, insertions, deletions)
278    }
279
280    /// 计算文件差异
281    fn calculate_diff(&self, current: &Path, backup: &Path) -> (u32, u32) {
282        let current_lines = self.count_lines(current);
283        let backup_lines = self.count_lines(backup);
284
285        let insertions = backup_lines.saturating_sub(current_lines);
286        let deletions = current_lines.saturating_sub(backup_lines);
287
288        (insertions, deletions)
289    }
290
291    /// 计算文件行数
292    fn count_lines(&self, path: &Path) -> u32 {
293        fs::read_to_string(path)
294            .map(|s| s.lines().count() as u32)
295            .unwrap_or(0)
296    }
297
298    /// 生成备份文件名
299    fn generate_backup_file_name(&self, file_path: &Path, hash: &str) -> String {
300        let _file_name = file_path
301            .file_name()
302            .and_then(|n| n.to_str())
303            .unwrap_or("file");
304        let ext = file_path.extension().and_then(|e| e.to_str()).unwrap_or("");
305        let name = file_path
306            .file_stem()
307            .and_then(|n| n.to_str())
308            .unwrap_or("file");
309
310        if ext.is_empty() {
311            format!("{}_{}", name, hash.get(..8).unwrap_or(hash))
312        } else {
313            format!("{}_{}.{}", name, hash.get(..8).unwrap_or(hash), ext)
314        }
315    }
316
317    /// 计算文件内容的哈希
318    fn compute_hash(&self, content: &[u8]) -> String {
319        let mut hasher = Sha256::new();
320        hasher.update(content);
321        format!("{:x}", hasher.finalize())
322    }
323
324    /// 规范化文件路径
325    fn normalize_path(&self, path: &Path) -> String {
326        if path.is_absolute() {
327            path.display().to_string()
328        } else {
329            std::env::current_dir()
330                .map(|cwd| cwd.join(path).display().to_string())
331                .unwrap_or_else(|_| path.display().to_string())
332        }
333    }
334
335    /// 清理备份文件
336    pub fn cleanup(&self) {
337        let _ = fs::remove_dir_all(&self.backup_dir);
338    }
339
340    /// 获取被跟踪的文件数量
341    pub fn get_tracked_files_count(&self) -> usize {
342        self.tracked_files.len()
343    }
344
345    /// 获取快照数量
346    pub fn get_snapshots_count(&self) -> usize {
347        self.snapshots.len()
348    }
349}
350
351// ============ 增强功能 ============
352
353impl FileHistoryManager {
354    /// 获取会话 ID
355    pub fn session_id(&self) -> &str {
356        &self.session_id
357    }
358
359    /// 获取备份目录
360    pub fn backup_dir(&self) -> &Path {
361        &self.backup_dir
362    }
363
364    /// 获取所有被跟踪的文件
365    pub fn get_tracked_files(&self) -> Vec<String> {
366        self.tracked_files.iter().cloned().collect()
367    }
368
369    /// 停止跟踪文件
370    pub fn untrack_file(&mut self, file_path: impl AsRef<Path>) {
371        let path = self.normalize_path(file_path.as_ref());
372        self.tracked_files.remove(&path);
373    }
374
375    /// 清除所有跟踪的文件
376    pub fn clear_tracked_files(&mut self) {
377        self.tracked_files.clear();
378    }
379
380    /// 获取指定消息的快照
381    pub fn get_snapshot(&self, message_id: &str) -> Option<&FileSnapshot> {
382        self.snapshots.iter().find(|s| s.message_id == message_id)
383    }
384
385    /// 获取最新的快照
386    pub fn get_latest_snapshot(&self) -> Option<&FileSnapshot> {
387        self.snapshots.last()
388    }
389
390    /// 删除指定消息之后的所有快照
391    pub fn remove_snapshots_after(&mut self, message_id: &str) -> usize {
392        let idx = self
393            .snapshots
394            .iter()
395            .position(|s| s.message_id == message_id);
396        match idx {
397            Some(i) if i + 1 < self.snapshots.len() => {
398                let removed = self.snapshots.len() - i - 1;
399                self.snapshots.truncate(i + 1);
400                removed
401            }
402            _ => 0,
403        }
404    }
405
406    /// 获取文件在指定快照时的内容
407    pub fn get_file_content_at_snapshot(
408        &self,
409        message_id: &str,
410        file_path: &str,
411    ) -> Option<Vec<u8>> {
412        let snapshot = self.get_snapshot(message_id)?;
413        let backup = snapshot.tracked_file_backups.get(file_path)?;
414        let backup_name = backup.backup_file_name.as_ref()?;
415        let backup_path = self.backup_dir.join(backup_name);
416        fs::read(&backup_path).ok()
417    }
418
419    /// 获取备份目录大小(字节)
420    pub fn get_backup_size(&self) -> u64 {
421        self.calculate_dir_size(&self.backup_dir)
422    }
423
424    fn calculate_dir_size(&self, path: &Path) -> u64 {
425        fs::read_dir(path)
426            .map(|entries| {
427                entries
428                    .filter_map(|e| e.ok())
429                    .map(|e| e.metadata().map(|m| m.len()).unwrap_or(0))
430                    .sum()
431            })
432            .unwrap_or(0)
433    }
434}
435
436// ============ 单元测试 ============
437
438#[cfg(test)]
439mod tests {
440    use super::*;
441    use std::io::Write;
442    use tempfile::TempDir;
443
444    fn create_test_file(dir: &Path, name: &str, content: &str) -> PathBuf {
445        let path = dir.join(name);
446        let mut file = fs::File::create(&path).unwrap();
447        file.write_all(content.as_bytes()).unwrap();
448        path
449    }
450
451    #[test]
452    fn test_new_manager() {
453        let manager = FileHistoryManager::new("test-session");
454        assert_eq!(manager.session_id(), "test-session");
455        assert!(manager.is_enabled());
456        assert_eq!(manager.get_tracked_files_count(), 0);
457        assert_eq!(manager.get_snapshots_count(), 0);
458        manager.cleanup();
459    }
460
461    #[test]
462    fn test_track_file() {
463        let mut manager = FileHistoryManager::new("test-track");
464        manager.track_file("/tmp/test.rs");
465        assert!(manager.is_tracked("/tmp/test.rs"));
466        assert!(!manager.is_tracked("/tmp/other.rs"));
467        assert_eq!(manager.get_tracked_files_count(), 1);
468        manager.cleanup();
469    }
470
471    #[test]
472    fn test_untrack_file() {
473        let mut manager = FileHistoryManager::new("test-untrack");
474        manager.track_file("/tmp/test.rs");
475        assert!(manager.is_tracked("/tmp/test.rs"));
476        manager.untrack_file("/tmp/test.rs");
477        assert!(!manager.is_tracked("/tmp/test.rs"));
478        manager.cleanup();
479    }
480
481    #[test]
482    fn test_backup_and_snapshot() {
483        let temp_dir = TempDir::new().unwrap();
484        let test_file = create_test_file(temp_dir.path(), "test.txt", "hello world");
485
486        let mut manager = FileHistoryManager::new("test-backup");
487
488        // 备份文件
489        let backup = manager.backup_file_before_change(&test_file);
490        assert!(backup.is_some());
491        let backup = backup.unwrap();
492        assert!(backup.backup_file_name.is_some());
493        assert!(backup.hash.is_some());
494
495        // 创建快照
496        manager.create_snapshot("msg-1");
497        assert_eq!(manager.get_snapshots_count(), 1);
498        assert!(manager.has_snapshot("msg-1"));
499
500        manager.cleanup();
501    }
502
503    #[test]
504    fn test_rewind_to_message() {
505        let temp_dir = TempDir::new().unwrap();
506        let test_file = create_test_file(temp_dir.path(), "test.txt", "original content");
507
508        let mut manager = FileHistoryManager::new("test-rewind");
509
510        // 备份原始状态
511        manager.backup_file_before_change(&test_file);
512        manager.create_snapshot("msg-1");
513
514        // 修改文件
515        fs::write(&test_file, "modified content").unwrap();
516
517        // 预览回退
518        let preview = manager.rewind_to_message("msg-1", true);
519        assert!(preview.success);
520
521        // 文件应该还是修改后的内容(dry_run)
522        let content = fs::read_to_string(&test_file).unwrap();
523        assert_eq!(content, "modified content");
524
525        // 实际回退
526        let result = manager.rewind_to_message("msg-1", false);
527        assert!(result.success);
528
529        // 文件应该恢复为原始内容
530        let content = fs::read_to_string(&test_file).unwrap();
531        assert_eq!(content, "original content");
532
533        manager.cleanup();
534    }
535
536    #[test]
537    fn test_rewind_nonexistent_snapshot() {
538        let manager = FileHistoryManager::new("test-nonexistent");
539        let result = manager.rewind_to_message("nonexistent", false);
540        assert!(!result.success);
541        assert!(result.error.is_some());
542        manager.cleanup();
543    }
544
545    #[test]
546    fn test_disabled_manager() {
547        let mut manager = FileHistoryManager::new("test-disabled");
548        manager.set_enabled(false);
549        assert!(!manager.is_enabled());
550
551        manager.track_file("/tmp/test.rs");
552        assert_eq!(manager.get_tracked_files_count(), 0);
553
554        manager.create_snapshot("msg-1");
555        assert_eq!(manager.get_snapshots_count(), 0);
556
557        let result = manager.rewind_to_message("msg-1", false);
558        assert!(!result.success);
559
560        manager.cleanup();
561    }
562
563    #[test]
564    fn test_remove_snapshots_after() {
565        let mut manager = FileHistoryManager::new("test-remove");
566
567        manager.create_snapshot("msg-1");
568        manager.create_snapshot("msg-2");
569        manager.create_snapshot("msg-3");
570        assert_eq!(manager.get_snapshots_count(), 3);
571
572        let removed = manager.remove_snapshots_after("msg-1");
573        assert_eq!(removed, 2);
574        assert_eq!(manager.get_snapshots_count(), 1);
575        assert!(manager.has_snapshot("msg-1"));
576        assert!(!manager.has_snapshot("msg-2"));
577
578        manager.cleanup();
579    }
580
581    #[test]
582    fn test_compute_hash() {
583        let manager = FileHistoryManager::new("test-hash");
584        let hash1 = manager.compute_hash(b"hello");
585        let hash2 = manager.compute_hash(b"hello");
586        let hash3 = manager.compute_hash(b"world");
587
588        assert_eq!(hash1, hash2);
589        assert_ne!(hash1, hash3);
590        assert_eq!(hash1.len(), 64); // SHA256 hex
591
592        manager.cleanup();
593    }
594}