1use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9use super::diff::DiffEngine;
10use super::storage::CheckpointStorage;
11use super::types::*;
12
/// In-memory state of one checkpointing session: per-file checkpoint histories plus the
/// bookkeeping used for undo/redo and automatic checkpoints.
pub struct CheckpointSession {
    /// Session identifier (caller-supplied, or generated when the session is created).
    pub id: String,
    /// Creation time of the session in milliseconds since the Unix epoch.
    pub start_time: i64,
    /// Working directory recorded when the session was created.
    pub working_directory: String,
    /// Checkpoint history per file path, oldest first (new checkpoints are appended).
    pub checkpoints: HashMap<String, Vec<FileCheckpoint>>,
    /// Per file, the index into its `checkpoints` entry that undo/redo step from.
    pub current_index: HashMap<String, usize>,
    /// Per file, the number of tracked edits since its last recorded checkpoint.
    pub edit_counts: HashMap<String, u32>,
    /// Edit count at which an automatic checkpoint is taken.
    pub auto_checkpoint_interval: u32,
    /// Optional session-level metadata (git branch/commit, tags, total size).
    pub metadata: Option<SessionMetadata>,
}
24
25impl CheckpointSession {
26 pub fn new(
28 id: Option<String>,
29 working_directory: String,
30 auto_checkpoint_interval: u32,
31 ) -> Self {
32 let session_id = id.unwrap_or_else(generate_session_id);
33
34 Self {
35 id: session_id,
36 start_time: chrono::Utc::now().timestamp_millis(),
37 working_directory,
38 checkpoints: HashMap::new(),
39 current_index: HashMap::new(),
40 edit_counts: HashMap::new(),
41 auto_checkpoint_interval,
42 metadata: Some(SessionMetadata {
43 git_branch: get_git_branch(),
44 git_commit: get_git_commit(),
45 tags: None,
46 total_size: Some(0),
47 }),
48 }
49 }
50
51 pub fn get_checkpoints(&self, file_path: &str) -> Option<&Vec<FileCheckpoint>> {
53 self.checkpoints.get(file_path)
54 }
55
56 pub fn get_current_index(&self, file_path: &str) -> Option<usize> {
58 self.current_index.get(file_path).copied()
59 }
60}
61
/// Coordinates checkpoint creation, restoration and undo/redo for the active session.
pub struct CheckpointManager {
    /// The active session: `None` until `init` runs and again after `end_session`.
    session: Arc<RwLock<Option<CheckpointSession>>>,
    /// Used to persist/load sessions and checkpoints, compress content and clean up old data.
    storage: CheckpointStorage,
    /// Computes and applies the diffs stored between consecutive checkpoints.
    diff_engine: DiffEngine,
}
68
69impl CheckpointManager {
70 pub fn new() -> Self {
72 Self {
73 session: Arc::new(RwLock::new(None)),
74 storage: CheckpointStorage::new(),
75 diff_engine: DiffEngine::new(),
76 }
77 }
78
79 pub async fn init(
81 &self,
82 session_id: Option<String>,
83 auto_checkpoint_interval: u32,
84 ) -> Result<(), String> {
85 self.storage.ensure_checkpoint_dir().await?;
86
87 let working_dir = std::env::current_dir()
88 .map(|p| p.to_string_lossy().to_string())
89 .unwrap_or_else(|_| ".".to_string());
90
91 let session =
92 CheckpointSession::new(session_id.clone(), working_dir, auto_checkpoint_interval);
93
94 if let Some(ref id) = session_id {
96 if let Ok(loaded) = self.storage.load_session(id).await {
97 *self.session.write().await = Some(loaded);
98 return Ok(());
99 }
100 }
101
102 *self.session.write().await = Some(session);
103
104 self.storage.cleanup_old_checkpoints().await;
106
107 Ok(())
108 }
109
110 pub async fn create_checkpoint(
112 &self,
113 file_path: &str,
114 options: Option<CreateCheckpointOptions>,
115 ) -> Option<FileCheckpoint> {
116 let mut session_guard = self.session.write().await;
117 let session = session_guard.as_mut()?;
118
119 let absolute_path = std::path::Path::new(file_path)
120 .canonicalize()
121 .ok()?
122 .to_string_lossy()
123 .to_string();
124
125 let content = tokio::fs::read_to_string(&absolute_path).await.ok()?;
127 let hash = get_content_hash(&content);
128
129 let existing = session.checkpoints.get(&absolute_path);
131 if let Some(checkpoints) = existing {
132 if let Some(last) = checkpoints.last() {
133 if last.hash == hash {
134 return Some(last.clone());
135 }
136 }
137 }
138
139 let opts = options.unwrap_or_default();
140 let edit_count = session
141 .edit_counts
142 .get(&absolute_path)
143 .copied()
144 .unwrap_or(0);
145
146 let use_full_content =
148 existing.is_none_or(|c| c.is_empty()) || opts.force_full_content.unwrap_or(false);
149
150 let (checkpoint_content, checkpoint_diff, compressed) = if use_full_content {
151 let (content_str, is_compressed) = if content.len() > COMPRESSION_THRESHOLD_BYTES {
152 (self.storage.compress_content(&content), true)
153 } else {
154 (content.clone(), false)
155 };
156 (Some(content_str), None, is_compressed)
157 } else {
158 let last_content = self.reconstruct_content_internal(session, &absolute_path, None)?;
159 let diff = self.diff_engine.calculate_diff(&last_content, &content);
160 (None, Some(diff), false)
161 };
162
163 let metadata = tokio::fs::metadata(&absolute_path)
164 .await
165 .ok()
166 .map(|m| FileMetadata {
167 mode: None,
168 uid: None,
169 gid: None,
170 size: Some(m.len()),
171 });
172
173 let checkpoint = FileCheckpoint {
174 path: absolute_path.clone(),
175 content: checkpoint_content,
176 diff: checkpoint_diff,
177 hash,
178 timestamp: chrono::Utc::now().timestamp_millis(),
179 name: opts.name,
180 description: opts.description,
181 git_commit: get_git_commit(),
182 edit_count: Some(edit_count),
183 compressed: Some(compressed),
184 metadata,
185 tags: opts.tags,
186 };
187
188 session
190 .checkpoints
191 .entry(absolute_path.clone())
192 .or_insert_with(Vec::new)
193 .push(checkpoint.clone());
194
195 if let Some(checkpoints) = session.checkpoints.get_mut(&absolute_path) {
197 if checkpoints.len() > MAX_CHECKPOINTS_PER_FILE {
198 let to_remove = checkpoints.len() - MAX_CHECKPOINTS_PER_FILE;
199 checkpoints.drain(1..=to_remove);
200 }
201 }
202
203 let len = session
205 .checkpoints
206 .get(&absolute_path)
207 .map_or(0, |c| c.len());
208 session
209 .current_index
210 .insert(absolute_path.clone(), len.saturating_sub(1));
211 session.edit_counts.insert(absolute_path, 0);
212
213 let _ = self.storage.save_checkpoint(&session.id, &checkpoint).await;
215
216 Some(checkpoint)
217 }
218
219 pub async fn track_file_edit(&self, file_path: &str) {
221 let should_checkpoint = {
222 let mut session_guard = self.session.write().await;
223 if let Some(session) = session_guard.as_mut() {
224 let absolute_path = std::path::Path::new(file_path)
225 .canonicalize()
226 .map(|p| p.to_string_lossy().to_string())
227 .unwrap_or_else(|_| file_path.to_string());
228
229 let edit_count = session
230 .edit_counts
231 .entry(absolute_path.clone())
232 .or_insert(0);
233 *edit_count += 1;
234
235 if *edit_count >= session.auto_checkpoint_interval {
237 Some((absolute_path, *edit_count))
238 } else {
239 None
240 }
241 } else {
242 None
243 }
244 };
245
246 if let Some((absolute_path, edit_count)) = should_checkpoint {
248 self.create_checkpoint(
249 &absolute_path,
250 Some(CreateCheckpointOptions {
251 name: Some(format!("Auto-checkpoint at {} edits", edit_count)),
252 ..Default::default()
253 }),
254 )
255 .await;
256 }
257 }
258
259 pub async fn restore_checkpoint(
261 &self,
262 file_path: &str,
263 index: Option<usize>,
264 options: Option<CheckpointRestoreOptions>,
265 ) -> CheckpointResult {
266 let absolute_path = std::path::Path::new(file_path)
267 .canonicalize()
268 .map(|p| p.to_string_lossy().to_string())
269 .unwrap_or_else(|_| file_path.to_string());
270
271 let opts = options.unwrap_or_default();
272
273 let (content, checkpoint_name, should_backup) = {
275 let session_guard = self.session.read().await;
276 let session = match session_guard.as_ref() {
277 Some(s) => s,
278 None => return CheckpointResult::err("No active checkpoint session"),
279 };
280
281 let checkpoints = match session.checkpoints.get(&absolute_path) {
282 Some(c) if !c.is_empty() => c,
283 _ => return CheckpointResult::err("No checkpoints found for this file"),
284 };
285
286 let target_index = index.unwrap_or_else(|| {
287 session
288 .current_index
289 .get(&absolute_path)
290 .copied()
291 .unwrap_or(checkpoints.len() - 1)
292 });
293
294 if target_index >= checkpoints.len() {
295 return CheckpointResult::err("Invalid checkpoint index");
296 }
297
298 let content = match self.reconstruct_content_internal(
299 session,
300 &absolute_path,
301 Some(target_index),
302 ) {
303 Some(c) => c,
304 None => return CheckpointResult::err("Failed to reconstruct content"),
305 };
306
307 if opts.dry_run.unwrap_or(false) {
309 return CheckpointResult::ok_with_content("Dry run successful", content);
310 }
311
312 let checkpoint = &checkpoints[target_index];
313 let name = checkpoint.name.clone().unwrap_or_else(|| {
314 format!(
315 "checkpoint from {}",
316 chrono::DateTime::from_timestamp_millis(checkpoint.timestamp)
317 .map(|dt| dt.format("%Y-%m-%d %H:%M:%S").to_string())
318 .unwrap_or_else(|| "unknown".to_string())
319 )
320 });
321
322 (content, name, opts.create_backup.unwrap_or(true))
323 };
324
325 if should_backup {
327 self.create_checkpoint(
328 &absolute_path,
329 Some(CreateCheckpointOptions {
330 name: Some("Pre-restore backup".to_string()),
331 ..Default::default()
332 }),
333 )
334 .await;
335 }
336
337 if let Err(e) = tokio::fs::write(&absolute_path, &content).await {
339 return CheckpointResult::err(format!("Failed to restore: {}", e));
340 }
341
342 CheckpointResult::ok(format!("Restored to: {}", checkpoint_name))
343 }
344
345 fn reconstruct_content_internal(
347 &self,
348 session: &CheckpointSession,
349 file_path: &str,
350 index: Option<usize>,
351 ) -> Option<String> {
352 let checkpoints = session.checkpoints.get(file_path)?;
353 let target_index = index.unwrap_or(checkpoints.len().saturating_sub(1));
354
355 if target_index >= checkpoints.len() {
356 return None;
357 }
358
359 let mut base_index = target_index;
361 while base_index > 0 && checkpoints[base_index].content.is_none() {
362 base_index -= 1;
363 }
364
365 let base_checkpoint = &checkpoints[base_index];
366 let mut content = base_checkpoint.content.clone()?;
367
368 if base_checkpoint.compressed.unwrap_or(false) {
370 content = self.storage.decompress_content(&content);
371 }
372
373 for checkpoint in checkpoints
375 .iter()
376 .take(target_index + 1)
377 .skip(base_index + 1)
378 {
379 if let Some(ref diff) = checkpoint.diff {
380 content = self.diff_engine.apply_diff(&content, diff);
381 } else if let Some(ref c) = checkpoint.content {
382 content = if checkpoint.compressed.unwrap_or(false) {
383 self.storage.decompress_content(c)
384 } else {
385 c.clone()
386 };
387 }
388 }
389
390 Some(content)
391 }
392
393 pub async fn undo(&self, file_path: &str) -> CheckpointResult {
395 let session_guard = self.session.read().await;
396 let session = match session_guard.as_ref() {
397 Some(s) => s,
398 None => return CheckpointResult::err("No active checkpoint session"),
399 };
400
401 let absolute_path = std::path::Path::new(file_path)
402 .canonicalize()
403 .map(|p| p.to_string_lossy().to_string())
404 .unwrap_or_else(|_| file_path.to_string());
405
406 let current_index = session
407 .current_index
408 .get(&absolute_path)
409 .copied()
410 .unwrap_or(0);
411 if current_index == 0 {
412 return CheckpointResult::err("Already at oldest checkpoint");
413 }
414
415 drop(session_guard);
416 self.restore_checkpoint(&absolute_path, Some(current_index - 1), None)
417 .await
418 }
419
420 pub async fn redo(&self, file_path: &str) -> CheckpointResult {
422 let session_guard = self.session.read().await;
423 let session = match session_guard.as_ref() {
424 Some(s) => s,
425 None => return CheckpointResult::err("No active checkpoint session"),
426 };
427
428 let absolute_path = std::path::Path::new(file_path)
429 .canonicalize()
430 .map(|p| p.to_string_lossy().to_string())
431 .unwrap_or_else(|_| file_path.to_string());
432
433 let checkpoints = match session.checkpoints.get(&absolute_path) {
434 Some(c) => c,
435 None => return CheckpointResult::err("No checkpoints available"),
436 };
437
438 let current_index = session
439 .current_index
440 .get(&absolute_path)
441 .copied()
442 .unwrap_or(0);
443 if current_index >= checkpoints.len() - 1 {
444 return CheckpointResult::err("Already at newest checkpoint");
445 }
446
447 drop(session_guard);
448 self.restore_checkpoint(&absolute_path, Some(current_index + 1), None)
449 .await
450 }
451
452 pub async fn get_checkpoint_history(&self, file_path: &str) -> CheckpointHistory {
454 let session_guard = self.session.read().await;
455 let session = match session_guard.as_ref() {
456 Some(s) => s,
457 None => {
458 return CheckpointHistory {
459 checkpoints: vec![],
460 current_index: -1,
461 }
462 }
463 };
464
465 let absolute_path = std::path::Path::new(file_path)
466 .canonicalize()
467 .map(|p| p.to_string_lossy().to_string())
468 .unwrap_or_else(|_| file_path.to_string());
469
470 let checkpoints = session.checkpoints.get(&absolute_path);
471 let current_index = session
472 .current_index
473 .get(&absolute_path)
474 .copied()
475 .unwrap_or(0);
476
477 let items = checkpoints.map_or(vec![], |cps| {
478 cps.iter()
479 .enumerate()
480 .map(|(idx, cp)| CheckpointHistoryItem {
481 index: idx,
482 timestamp: cp.timestamp,
483 hash: cp.hash.clone(),
484 name: cp.name.clone(),
485 description: cp.description.clone(),
486 git_commit: cp.git_commit.clone(),
487 tags: cp.tags.clone(),
488 size: cp.metadata.as_ref().and_then(|m| m.size),
489 compressed: cp.compressed,
490 current: idx == current_index,
491 })
492 .collect()
493 });
494
495 CheckpointHistory {
496 checkpoints: items,
497 current_index: current_index as i32,
498 }
499 }
500
501 pub async fn get_stats(&self) -> CheckpointStats {
503 let session_guard = self.session.read().await;
504 let session = match session_guard.as_ref() {
505 Some(s) => s,
506 None => {
507 return CheckpointStats {
508 total_checkpoints: 0,
509 total_files: 0,
510 total_size: 0,
511 oldest_checkpoint: None,
512 newest_checkpoint: None,
513 compression_ratio: None,
514 }
515 }
516 };
517
518 let mut total_checkpoints = 0;
519 let mut oldest: Option<i64> = None;
520 let mut newest: Option<i64> = None;
521
522 for checkpoints in session.checkpoints.values() {
523 total_checkpoints += checkpoints.len();
524 for cp in checkpoints {
525 oldest = Some(oldest.map_or(cp.timestamp, |o| o.min(cp.timestamp)));
526 newest = Some(newest.map_or(cp.timestamp, |n| n.max(cp.timestamp)));
527 }
528 }
529
530 CheckpointStats {
531 total_checkpoints,
532 total_files: session.checkpoints.len(),
533 total_size: session
534 .metadata
535 .as_ref()
536 .and_then(|m| m.total_size)
537 .unwrap_or(0),
538 oldest_checkpoint: oldest,
539 newest_checkpoint: newest,
540 compression_ratio: None,
541 }
542 }
543
544 pub async fn end_session(&self) {
546 *self.session.write().await = None;
547 }
548}
549
/// Optional settings for `CheckpointManager::create_checkpoint`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CreateCheckpointOptions {
    /// Human-readable name copied onto the checkpoint.
    pub name: Option<String>,
    /// Free-form description copied onto the checkpoint.
    pub description: Option<String>,
    /// Tags copied onto the checkpoint.
    pub tags: Option<Vec<String>>,
    /// When `Some(true)`, store the complete file content even if a diff could be used.
    pub force_full_content: Option<bool>,
}
558
559fn generate_session_id() -> String {
561 let uuid_str = uuid::Uuid::new_v4().to_string();
562 format!(
563 "{}-{}",
564 chrono::Utc::now().timestamp_millis(),
565 uuid_str.get(..8).unwrap_or(&uuid_str)
566 )
567}
568
569fn get_content_hash(content: &str) -> String {
571 use sha2::{Digest, Sha256};
572 let mut hasher = Sha256::new();
573 hasher.update(content.as_bytes());
574 let result = hasher.finalize();
575 hex::encode(&result[..8])
576}
577
/// Runs `git` with `args` in the process's current directory and returns its trimmed stdout.
///
/// Returns `None` if git cannot be spawned, exits unsuccessfully, prints non-UTF-8 output,
/// or prints only whitespace. Note: `Command::output` blocks the calling thread until git
/// exits.
fn git_output(args: &[&str]) -> Option<String> {
    let output = std::process::Command::new("git").args(args).output().ok()?;
    if !output.status.success() {
        return None;
    }
    let text = String::from_utf8(output.stdout).ok()?;
    let trimmed = text.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_string())
}

/// Name of the checked-out branch (`git rev-parse --abbrev-ref HEAD`), if git reports one.
fn get_git_branch() -> Option<String> {
    git_output(&["rev-parse", "--abbrev-ref", "HEAD"])
}

/// Hash of the current `HEAD` commit (`git rev-parse HEAD`), if git reports one.
fn get_git_commit() -> Option<String> {
    git_output(&["rev-parse", "HEAD"])
}
611
612impl Default for CheckpointManager {
613 fn default() -> Self {
614 Self::new()
615 }
616}