Skip to main content

aster/checkpoint/
session.rs

1//! 检查点会话管理
2//!
3//! 管理检查点会话的创建、加载和保存
4
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9use super::diff::DiffEngine;
10use super::storage::CheckpointStorage;
11use super::types::*;
12
13/// 检查点会话
14pub struct CheckpointSession {
15    pub id: String,
16    pub start_time: i64,
17    pub working_directory: String,
18    pub checkpoints: HashMap<String, Vec<FileCheckpoint>>,
19    pub current_index: HashMap<String, usize>,
20    pub edit_counts: HashMap<String, u32>,
21    pub auto_checkpoint_interval: u32,
22    pub metadata: Option<SessionMetadata>,
23}
24
25impl CheckpointSession {
26    /// 创建新会话
27    pub fn new(
28        id: Option<String>,
29        working_directory: String,
30        auto_checkpoint_interval: u32,
31    ) -> Self {
32        let session_id = id.unwrap_or_else(generate_session_id);
33
34        Self {
35            id: session_id,
36            start_time: chrono::Utc::now().timestamp_millis(),
37            working_directory,
38            checkpoints: HashMap::new(),
39            current_index: HashMap::new(),
40            edit_counts: HashMap::new(),
41            auto_checkpoint_interval,
42            metadata: Some(SessionMetadata {
43                git_branch: get_git_branch(),
44                git_commit: get_git_commit(),
45                tags: None,
46                total_size: Some(0),
47            }),
48        }
49    }
50
51    /// 获取文件的检查点列表
52    pub fn get_checkpoints(&self, file_path: &str) -> Option<&Vec<FileCheckpoint>> {
53        self.checkpoints.get(file_path)
54    }
55
56    /// 获取文件的当前检查点索引
57    pub fn get_current_index(&self, file_path: &str) -> Option<usize> {
58        self.current_index.get(file_path).copied()
59    }
60}
61
62/// 检查点管理器
63pub struct CheckpointManager {
64    session: Arc<RwLock<Option<CheckpointSession>>>,
65    storage: CheckpointStorage,
66    diff_engine: DiffEngine,
67}
68
69impl CheckpointManager {
70    /// 创建新的检查点管理器
71    pub fn new() -> Self {
72        Self {
73            session: Arc::new(RwLock::new(None)),
74            storage: CheckpointStorage::new(),
75            diff_engine: DiffEngine::new(),
76        }
77    }
78
79    /// 初始化检查点系统
80    pub async fn init(
81        &self,
82        session_id: Option<String>,
83        auto_checkpoint_interval: u32,
84    ) -> Result<(), String> {
85        self.storage.ensure_checkpoint_dir().await?;
86
87        let working_dir = std::env::current_dir()
88            .map(|p| p.to_string_lossy().to_string())
89            .unwrap_or_else(|_| ".".to_string());
90
91        let session =
92            CheckpointSession::new(session_id.clone(), working_dir, auto_checkpoint_interval);
93
94        // 如果有 session_id,尝试加载现有会话
95        if let Some(ref id) = session_id {
96            if let Ok(loaded) = self.storage.load_session(id).await {
97                *self.session.write().await = Some(loaded);
98                return Ok(());
99            }
100        }
101
102        *self.session.write().await = Some(session);
103
104        // 清理旧检查点
105        self.storage.cleanup_old_checkpoints().await;
106
107        Ok(())
108    }
109
110    /// 创建检查点
111    pub async fn create_checkpoint(
112        &self,
113        file_path: &str,
114        options: Option<CreateCheckpointOptions>,
115    ) -> Option<FileCheckpoint> {
116        let mut session_guard = self.session.write().await;
117        let session = session_guard.as_mut()?;
118
119        let absolute_path = std::path::Path::new(file_path)
120            .canonicalize()
121            .ok()?
122            .to_string_lossy()
123            .to_string();
124
125        // 读取文件内容
126        let content = tokio::fs::read_to_string(&absolute_path).await.ok()?;
127        let hash = get_content_hash(&content);
128
129        // 检查内容是否与上次检查点相同
130        let existing = session.checkpoints.get(&absolute_path);
131        if let Some(checkpoints) = existing {
132            if let Some(last) = checkpoints.last() {
133                if last.hash == hash {
134                    return Some(last.clone());
135                }
136            }
137        }
138
139        let opts = options.unwrap_or_default();
140        let edit_count = session
141            .edit_counts
142            .get(&absolute_path)
143            .copied()
144            .unwrap_or(0);
145
146        // 决定使用完整内容还是 diff
147        let use_full_content =
148            existing.is_none_or(|c| c.is_empty()) || opts.force_full_content.unwrap_or(false);
149
150        let (checkpoint_content, checkpoint_diff, compressed) = if use_full_content {
151            let (content_str, is_compressed) = if content.len() > COMPRESSION_THRESHOLD_BYTES {
152                (self.storage.compress_content(&content), true)
153            } else {
154                (content.clone(), false)
155            };
156            (Some(content_str), None, is_compressed)
157        } else {
158            let last_content = self.reconstruct_content_internal(session, &absolute_path, None)?;
159            let diff = self.diff_engine.calculate_diff(&last_content, &content);
160            (None, Some(diff), false)
161        };
162
163        let metadata = tokio::fs::metadata(&absolute_path)
164            .await
165            .ok()
166            .map(|m| FileMetadata {
167                mode: None,
168                uid: None,
169                gid: None,
170                size: Some(m.len()),
171            });
172
173        let checkpoint = FileCheckpoint {
174            path: absolute_path.clone(),
175            content: checkpoint_content,
176            diff: checkpoint_diff,
177            hash,
178            timestamp: chrono::Utc::now().timestamp_millis(),
179            name: opts.name,
180            description: opts.description,
181            git_commit: get_git_commit(),
182            edit_count: Some(edit_count),
183            compressed: Some(compressed),
184            metadata,
185            tags: opts.tags,
186        };
187
188        // 添加到会话
189        session
190            .checkpoints
191            .entry(absolute_path.clone())
192            .or_insert_with(Vec::new)
193            .push(checkpoint.clone());
194
195        // 限制检查点数量
196        if let Some(checkpoints) = session.checkpoints.get_mut(&absolute_path) {
197            if checkpoints.len() > MAX_CHECKPOINTS_PER_FILE {
198                let to_remove = checkpoints.len() - MAX_CHECKPOINTS_PER_FILE;
199                checkpoints.drain(1..=to_remove);
200            }
201        }
202
203        // 更新索引
204        let len = session
205            .checkpoints
206            .get(&absolute_path)
207            .map_or(0, |c| c.len());
208        session
209            .current_index
210            .insert(absolute_path.clone(), len.saturating_sub(1));
211        session.edit_counts.insert(absolute_path, 0);
212
213        // 保存到磁盘
214        let _ = self.storage.save_checkpoint(&session.id, &checkpoint).await;
215
216        Some(checkpoint)
217    }
218
219    /// 跟踪文件编辑
220    pub async fn track_file_edit(&self, file_path: &str) {
221        let should_checkpoint = {
222            let mut session_guard = self.session.write().await;
223            if let Some(session) = session_guard.as_mut() {
224                let absolute_path = std::path::Path::new(file_path)
225                    .canonicalize()
226                    .map(|p| p.to_string_lossy().to_string())
227                    .unwrap_or_else(|_| file_path.to_string());
228
229                let edit_count = session
230                    .edit_counts
231                    .entry(absolute_path.clone())
232                    .or_insert(0);
233                *edit_count += 1;
234
235                // 检查是否需要自动检查点
236                if *edit_count >= session.auto_checkpoint_interval {
237                    Some((absolute_path, *edit_count))
238                } else {
239                    None
240                }
241            } else {
242                None
243            }
244        };
245
246        // 在锁释放后创建检查点
247        if let Some((absolute_path, edit_count)) = should_checkpoint {
248            self.create_checkpoint(
249                &absolute_path,
250                Some(CreateCheckpointOptions {
251                    name: Some(format!("Auto-checkpoint at {} edits", edit_count)),
252                    ..Default::default()
253                }),
254            )
255            .await;
256        }
257    }
258
259    /// 恢复检查点
260    pub async fn restore_checkpoint(
261        &self,
262        file_path: &str,
263        index: Option<usize>,
264        options: Option<CheckpointRestoreOptions>,
265    ) -> CheckpointResult {
266        let absolute_path = std::path::Path::new(file_path)
267            .canonicalize()
268            .map(|p| p.to_string_lossy().to_string())
269            .unwrap_or_else(|_| file_path.to_string());
270
271        let opts = options.unwrap_or_default();
272
273        // 第一阶段:读取并重建内容
274        let (content, checkpoint_name, should_backup) = {
275            let session_guard = self.session.read().await;
276            let session = match session_guard.as_ref() {
277                Some(s) => s,
278                None => return CheckpointResult::err("No active checkpoint session"),
279            };
280
281            let checkpoints = match session.checkpoints.get(&absolute_path) {
282                Some(c) if !c.is_empty() => c,
283                _ => return CheckpointResult::err("No checkpoints found for this file"),
284            };
285
286            let target_index = index.unwrap_or_else(|| {
287                session
288                    .current_index
289                    .get(&absolute_path)
290                    .copied()
291                    .unwrap_or(checkpoints.len() - 1)
292            });
293
294            if target_index >= checkpoints.len() {
295                return CheckpointResult::err("Invalid checkpoint index");
296            }
297
298            let content = match self.reconstruct_content_internal(
299                session,
300                &absolute_path,
301                Some(target_index),
302            ) {
303                Some(c) => c,
304                None => return CheckpointResult::err("Failed to reconstruct content"),
305            };
306
307            // Dry run 模式
308            if opts.dry_run.unwrap_or(false) {
309                return CheckpointResult::ok_with_content("Dry run successful", content);
310            }
311
312            let checkpoint = &checkpoints[target_index];
313            let name = checkpoint.name.clone().unwrap_or_else(|| {
314                format!(
315                    "checkpoint from {}",
316                    chrono::DateTime::from_timestamp_millis(checkpoint.timestamp)
317                        .map(|dt| dt.format("%Y-%m-%d %H:%M:%S").to_string())
318                        .unwrap_or_else(|| "unknown".to_string())
319                )
320            });
321
322            (content, name, opts.create_backup.unwrap_or(true))
323        };
324
325        // 第二阶段:创建备份(锁已释放)
326        if should_backup {
327            self.create_checkpoint(
328                &absolute_path,
329                Some(CreateCheckpointOptions {
330                    name: Some("Pre-restore backup".to_string()),
331                    ..Default::default()
332                }),
333            )
334            .await;
335        }
336
337        // 第三阶段:恢复内容
338        if let Err(e) = tokio::fs::write(&absolute_path, &content).await {
339            return CheckpointResult::err(format!("Failed to restore: {}", e));
340        }
341
342        CheckpointResult::ok(format!("Restored to: {}", checkpoint_name))
343    }
344
345    /// 内部重建内容方法
346    fn reconstruct_content_internal(
347        &self,
348        session: &CheckpointSession,
349        file_path: &str,
350        index: Option<usize>,
351    ) -> Option<String> {
352        let checkpoints = session.checkpoints.get(file_path)?;
353        let target_index = index.unwrap_or(checkpoints.len().saturating_sub(1));
354
355        if target_index >= checkpoints.len() {
356            return None;
357        }
358
359        // 找到最近的完整内容检查点
360        let mut base_index = target_index;
361        while base_index > 0 && checkpoints[base_index].content.is_none() {
362            base_index -= 1;
363        }
364
365        let base_checkpoint = &checkpoints[base_index];
366        let mut content = base_checkpoint.content.clone()?;
367
368        // 解压缩
369        if base_checkpoint.compressed.unwrap_or(false) {
370            content = self.storage.decompress_content(&content);
371        }
372
373        // 应用 diff
374        for checkpoint in checkpoints
375            .iter()
376            .take(target_index + 1)
377            .skip(base_index + 1)
378        {
379            if let Some(ref diff) = checkpoint.diff {
380                content = self.diff_engine.apply_diff(&content, diff);
381            } else if let Some(ref c) = checkpoint.content {
382                content = if checkpoint.compressed.unwrap_or(false) {
383                    self.storage.decompress_content(c)
384                } else {
385                    c.clone()
386                };
387            }
388        }
389
390        Some(content)
391    }
392
393    /// Undo - 回到上一个检查点
394    pub async fn undo(&self, file_path: &str) -> CheckpointResult {
395        let session_guard = self.session.read().await;
396        let session = match session_guard.as_ref() {
397            Some(s) => s,
398            None => return CheckpointResult::err("No active checkpoint session"),
399        };
400
401        let absolute_path = std::path::Path::new(file_path)
402            .canonicalize()
403            .map(|p| p.to_string_lossy().to_string())
404            .unwrap_or_else(|_| file_path.to_string());
405
406        let current_index = session
407            .current_index
408            .get(&absolute_path)
409            .copied()
410            .unwrap_or(0);
411        if current_index == 0 {
412            return CheckpointResult::err("Already at oldest checkpoint");
413        }
414
415        drop(session_guard);
416        self.restore_checkpoint(&absolute_path, Some(current_index - 1), None)
417            .await
418    }
419
420    /// Redo - 前进到下一个检查点
421    pub async fn redo(&self, file_path: &str) -> CheckpointResult {
422        let session_guard = self.session.read().await;
423        let session = match session_guard.as_ref() {
424            Some(s) => s,
425            None => return CheckpointResult::err("No active checkpoint session"),
426        };
427
428        let absolute_path = std::path::Path::new(file_path)
429            .canonicalize()
430            .map(|p| p.to_string_lossy().to_string())
431            .unwrap_or_else(|_| file_path.to_string());
432
433        let checkpoints = match session.checkpoints.get(&absolute_path) {
434            Some(c) => c,
435            None => return CheckpointResult::err("No checkpoints available"),
436        };
437
438        let current_index = session
439            .current_index
440            .get(&absolute_path)
441            .copied()
442            .unwrap_or(0);
443        if current_index >= checkpoints.len() - 1 {
444            return CheckpointResult::err("Already at newest checkpoint");
445        }
446
447        drop(session_guard);
448        self.restore_checkpoint(&absolute_path, Some(current_index + 1), None)
449            .await
450    }
451
452    /// 获取检查点历史
453    pub async fn get_checkpoint_history(&self, file_path: &str) -> CheckpointHistory {
454        let session_guard = self.session.read().await;
455        let session = match session_guard.as_ref() {
456            Some(s) => s,
457            None => {
458                return CheckpointHistory {
459                    checkpoints: vec![],
460                    current_index: -1,
461                }
462            }
463        };
464
465        let absolute_path = std::path::Path::new(file_path)
466            .canonicalize()
467            .map(|p| p.to_string_lossy().to_string())
468            .unwrap_or_else(|_| file_path.to_string());
469
470        let checkpoints = session.checkpoints.get(&absolute_path);
471        let current_index = session
472            .current_index
473            .get(&absolute_path)
474            .copied()
475            .unwrap_or(0);
476
477        let items = checkpoints.map_or(vec![], |cps| {
478            cps.iter()
479                .enumerate()
480                .map(|(idx, cp)| CheckpointHistoryItem {
481                    index: idx,
482                    timestamp: cp.timestamp,
483                    hash: cp.hash.clone(),
484                    name: cp.name.clone(),
485                    description: cp.description.clone(),
486                    git_commit: cp.git_commit.clone(),
487                    tags: cp.tags.clone(),
488                    size: cp.metadata.as_ref().and_then(|m| m.size),
489                    compressed: cp.compressed,
490                    current: idx == current_index,
491                })
492                .collect()
493        });
494
495        CheckpointHistory {
496            checkpoints: items,
497            current_index: current_index as i32,
498        }
499    }
500
501    /// 获取统计信息
502    pub async fn get_stats(&self) -> CheckpointStats {
503        let session_guard = self.session.read().await;
504        let session = match session_guard.as_ref() {
505            Some(s) => s,
506            None => {
507                return CheckpointStats {
508                    total_checkpoints: 0,
509                    total_files: 0,
510                    total_size: 0,
511                    oldest_checkpoint: None,
512                    newest_checkpoint: None,
513                    compression_ratio: None,
514                }
515            }
516        };
517
518        let mut total_checkpoints = 0;
519        let mut oldest: Option<i64> = None;
520        let mut newest: Option<i64> = None;
521
522        for checkpoints in session.checkpoints.values() {
523            total_checkpoints += checkpoints.len();
524            for cp in checkpoints {
525                oldest = Some(oldest.map_or(cp.timestamp, |o| o.min(cp.timestamp)));
526                newest = Some(newest.map_or(cp.timestamp, |n| n.max(cp.timestamp)));
527            }
528        }
529
530        CheckpointStats {
531            total_checkpoints,
532            total_files: session.checkpoints.len(),
533            total_size: session
534                .metadata
535                .as_ref()
536                .and_then(|m| m.total_size)
537                .unwrap_or(0),
538            oldest_checkpoint: oldest,
539            newest_checkpoint: newest,
540            compression_ratio: None,
541        }
542    }
543
544    /// 结束会话
545    pub async fn end_session(&self) {
546        *self.session.write().await = None;
547    }
548}
549
/// Options accepted when creating a checkpoint.
///
/// Every field is optional; `Default` leaves them all unset.
#[derive(Default, Clone, Debug)]
pub struct CreateCheckpointOptions {
    /// Human-readable checkpoint name.
    pub name: Option<String>,
    /// Free-form description.
    pub description: Option<String>,
    /// Tags attached to the checkpoint.
    pub tags: Option<Vec<String>>,
    /// Store full content instead of a diff even when earlier checkpoints exist.
    pub force_full_content: Option<bool>,
}
558
559/// 生成会话 ID
560fn generate_session_id() -> String {
561    let uuid_str = uuid::Uuid::new_v4().to_string();
562    format!(
563        "{}-{}",
564        chrono::Utc::now().timestamp_millis(),
565        uuid_str.get(..8).unwrap_or(&uuid_str)
566    )
567}
568
569/// 获取内容哈希
570fn get_content_hash(content: &str) -> String {
571    use sha2::{Digest, Sha256};
572    let mut hasher = Sha256::new();
573    hasher.update(content.as_bytes());
574    let result = hasher.finalize();
575    hex::encode(&result[..8])
576}
577
/// Runs `git` with `args` in the process's current directory.
///
/// Returns the trimmed stdout. Returns `None` if `git` cannot be spawned,
/// exits unsuccessfully, prints non-UTF-8 output, or prints only whitespace.
/// Blocks the calling thread until `git` exits.
fn git_stdout(args: &[&str]) -> Option<String> {
    let output = std::process::Command::new("git").args(args).output().ok()?;
    if !output.status.success() {
        return None;
    }
    let text = String::from_utf8(output.stdout).ok()?;
    let trimmed = text.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_string())
}

/// Returns the current git branch as reported by `git rev-parse --abbrev-ref HEAD`.
///
/// Returns `None` when the command fails (e.g. outside a repository or
/// without `git`).
fn get_git_branch() -> Option<String> {
    git_stdout(&["rev-parse", "--abbrev-ref", "HEAD"])
}

/// Returns the current commit hash as reported by `git rev-parse HEAD`.
///
/// Returns `None` when the command fails (e.g. outside a repository, with no
/// commits, or without `git`).
fn get_git_commit() -> Option<String> {
    git_stdout(&["rev-parse", "HEAD"])
}
611
612impl Default for CheckpointManager {
613    fn default() -> Self {
614        Self::new()
615    }
616}