1use std::collections::{BTreeMap, BTreeSet};
36use std::path::{Path, PathBuf};
37use std::time::{SystemTime, UNIX_EPOCH};
38
39use bincode::Options;
40use blake3::Hasher;
41use serde::{Deserialize, Serialize};
42
43use super::error::EmbedError;
44use super::types::{EmbedChunk, EmbedSettings, RepoIdentifier};
45use crate::bincode_safe::deserialize_with_limit;
46
/// Checkpoint format version stamped into every checkpoint created by
/// [`EmbedCheckpoint::new`].
///
/// [`EmbedCheckpoint::validate`] rejects checkpoints whose stored version is
/// greater than this value.
pub const CHECKPOINT_VERSION: u32 = 1;
49
/// Serializable repository identity stored inside a checkpoint.
///
/// Every field is a plain `String`. When converting from a [`RepoIdentifier`],
/// its optional `version`, `branch` and `commit` are stored as empty strings
/// when absent, and empty strings are mapped back to `None` on the reverse
/// conversion.
///
/// NOTE: field order is part of the serialized layout; do not reorder.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct CheckpointRepoId {
    /// Repository namespace, copied from `RepoIdentifier::namespace`.
    pub namespace: String,
    /// Repository name, copied from `RepoIdentifier::name`.
    pub name: String,
    /// Version string; empty when the source identifier had none.
    pub version: String,
    /// Branch name; empty when the source identifier had none.
    pub branch: String,
    /// Commit identifier; empty when the source identifier had none.
    pub commit: String,
}
68
69impl From<&RepoIdentifier> for CheckpointRepoId {
70 fn from(repo: &RepoIdentifier) -> Self {
71 Self {
72 namespace: repo.namespace.clone(),
73 name: repo.name.clone(),
74 version: repo.version.clone().unwrap_or_default(),
75 branch: repo.branch.clone().unwrap_or_default(),
76 commit: repo.commit.clone().unwrap_or_default(),
77 }
78 }
79}
80
81impl From<RepoIdentifier> for CheckpointRepoId {
82 fn from(repo: RepoIdentifier) -> Self {
83 Self::from(&repo)
84 }
85}
86
87impl From<&CheckpointRepoId> for RepoIdentifier {
88 fn from(cp: &CheckpointRepoId) -> Self {
89 Self {
90 namespace: cp.namespace.clone(),
91 name: cp.name.clone(),
92 version: if cp.version.is_empty() {
93 None
94 } else {
95 Some(cp.version.clone())
96 },
97 branch: if cp.branch.is_empty() {
98 None
99 } else {
100 Some(cp.branch.clone())
101 },
102 commit: if cp.commit.is_empty() {
103 None
104 } else {
105 Some(cp.commit.clone())
106 },
107 }
108 }
109}
110
/// Persistent progress record for an embedding run.
///
/// Tracks which files have been processed, which are still pending, which
/// failed, and running chunk/token totals, together with enough metadata
/// (format version, repository path, settings hash) for
/// [`EmbedCheckpoint::validate`] to decide whether a stored checkpoint may be
/// reused.
///
/// NOTE: field order is part of the serialized layout; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbedCheckpoint {
    /// Checkpoint format version; set to [`CHECKPOINT_VERSION`] on creation.
    pub version: u32,

    /// Creation time in seconds since the Unix epoch.
    pub created_at: u64,

    /// Time of the last mutation in seconds since the Unix epoch.
    pub updated_at: u64,

    /// Identity of the repository this checkpoint belongs to.
    pub repo_id: CheckpointRepoId,

    /// Repository path as a (lossily converted) string; compared verbatim
    /// during validation.
    pub repo_path: String,

    /// Hex hash of the embedding settings the checkpoint was created with.
    pub settings_hash: String,

    /// Files recorded through `mark_file_processed`.
    pub processed_files: BTreeSet<String>,

    /// Files still pending; populated by `set_files` and shrunk as files are
    /// marked processed or failed.
    pub remaining_files: Vec<String>,

    /// Chunk references recorded for each processed file, keyed by file.
    pub chunks_by_file: BTreeMap<String, Vec<ChunkReference>>,

    /// Running total of chunks recorded via `mark_file_processed`.
    pub total_chunks: usize,

    /// Running total of chunk tokens recorded via `mark_file_processed`.
    pub total_tokens: u64,

    /// Files recorded through `mark_file_failed`, mapped to the error text.
    pub failed_files: BTreeMap<String, String>,

    /// Current pipeline phase.
    pub phase: CheckpointPhase,

    /// Hex integrity hash written by `compute_integrity`; empty until then.
    pub integrity_hash: String,
}
156
/// Lightweight record of a chunk, stored in a checkpoint in place of the
/// full [`EmbedChunk`].
///
/// NOTE: field order is part of the serialized layout; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkReference {
    /// Chunk identifier, copied from `EmbedChunk::id`.
    pub id: String,

    /// Copied from `EmbedChunk::full_hash`.
    pub full_hash: String,

    /// Token count, copied from `EmbedChunk::tokens`.
    pub tokens: u32,

    /// Line pair copied from the chunk's source
    /// (presumably `(start, end)` - TODO confirm against `ChunkSource`).
    pub lines: (u32, u32),

    /// Symbol name copied from the chunk's source.
    pub symbol: String,
}
175
176impl From<&EmbedChunk> for ChunkReference {
177 fn from(chunk: &EmbedChunk) -> Self {
178 Self {
179 id: chunk.id.clone(),
180 full_hash: chunk.full_hash.clone(),
181 tokens: chunk.tokens,
182 lines: chunk.source.lines,
183 symbol: chunk.source.symbol.clone(),
184 }
185 }
186}
187
/// Pipeline phase recorded in a checkpoint.
///
/// NOTE: variant order is part of the serialized layout; do not reorder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum CheckpointPhase {
    /// Initial phase of a freshly created checkpoint (the default).
    #[default]
    Discovery,
    /// Entered when the file list is installed via `set_files`.
    Chunking,
    /// Call-graph phase (set explicitly through `set_phase`).
    CallGraph,
    /// Hierarchy phase (set explicitly through `set_phase`).
    Hierarchy,
    /// Sorting phase (set explicitly through `set_phase`).
    Sorting,
    /// Final phase (set explicitly through `set_phase`).
    Complete,
}
205
206impl EmbedCheckpoint {
207 pub fn new(repo_path: &Path, repo_id: RepoIdentifier, settings: &EmbedSettings) -> Self {
209 let now = SystemTime::now()
210 .duration_since(UNIX_EPOCH)
211 .unwrap_or_default()
212 .as_secs();
213
214 let settings_hash = compute_settings_hash(settings);
215
216 Self {
217 version: CHECKPOINT_VERSION,
218 created_at: now,
219 updated_at: now,
220 repo_id: CheckpointRepoId::from(repo_id),
221 repo_path: repo_path.to_string_lossy().to_string(),
222 settings_hash,
223 processed_files: BTreeSet::new(),
224 remaining_files: Vec::new(),
225 chunks_by_file: BTreeMap::new(),
226 total_chunks: 0,
227 total_tokens: 0,
228 failed_files: BTreeMap::new(),
229 phase: CheckpointPhase::Discovery,
230 integrity_hash: String::new(),
231 }
232 }
233
234 pub fn set_files(&mut self, files: Vec<String>) {
236 self.remaining_files = files;
237 self.phase = CheckpointPhase::Chunking;
238 self.update_timestamp();
239 }
240
241 pub fn mark_file_processed(&mut self, file: &str, chunks: &[EmbedChunk]) {
243 self.remaining_files.retain(|f| f != file);
245 self.processed_files.insert(file.to_owned());
246
247 let refs: Vec<ChunkReference> = chunks.iter().map(ChunkReference::from).collect();
249 let tokens: u64 = chunks.iter().map(|c| c.tokens as u64).sum();
250
251 self.total_chunks += chunks.len();
252 self.total_tokens += tokens;
253 self.chunks_by_file.insert(file.to_owned(), refs);
254
255 self.update_timestamp();
256 }
257
258 pub fn mark_file_failed(&mut self, file: &str, error: &str) {
260 self.remaining_files.retain(|f| f != file);
261 self.failed_files.insert(file.to_owned(), error.to_owned());
262 self.update_timestamp();
263 }
264
265 pub fn set_phase(&mut self, phase: CheckpointPhase) {
267 self.phase = phase;
268 self.update_timestamp();
269 }
270
271 pub fn is_chunking_complete(&self) -> bool {
273 self.remaining_files.is_empty()
274 && (self.phase == CheckpointPhase::Chunking
275 || self.phase == CheckpointPhase::CallGraph
276 || self.phase == CheckpointPhase::Hierarchy
277 || self.phase == CheckpointPhase::Sorting
278 || self.phase == CheckpointPhase::Complete)
279 }
280
281 pub fn progress_percent(&self) -> u32 {
283 let total =
284 self.processed_files.len() + self.remaining_files.len() + self.failed_files.len();
285 if total == 0 {
286 return 0;
287 }
288
289 let processed = self.processed_files.len() + self.failed_files.len();
290 ((processed * 100) / total) as u32
291 }
292
293 pub fn files_processed(&self) -> usize {
295 self.processed_files.len()
296 }
297
298 pub fn files_remaining(&self) -> usize {
300 self.remaining_files.len()
301 }
302
303 pub fn files_failed(&self) -> usize {
305 self.failed_files.len()
306 }
307
308 pub fn validate(
310 &self,
311 repo_path: &Path,
312 settings: &EmbedSettings,
313 ) -> Result<(), CheckpointError> {
314 if self.version > CHECKPOINT_VERSION {
316 return Err(CheckpointError::VersionMismatch {
317 checkpoint_version: self.version,
318 current_version: CHECKPOINT_VERSION,
319 });
320 }
321
322 let current_path = repo_path.to_string_lossy().to_string();
324 if self.repo_path != current_path {
325 return Err(CheckpointError::RepoMismatch {
326 checkpoint_repo: self.repo_path.clone(),
327 current_repo: current_path,
328 });
329 }
330
331 let current_hash = compute_settings_hash(settings);
333 if self.settings_hash != current_hash {
334 return Err(CheckpointError::SettingsMismatch {
335 checkpoint_hash: self.settings_hash.clone(),
336 current_hash,
337 });
338 }
339
340 Ok(())
341 }
342
343 fn update_timestamp(&mut self) {
345 self.updated_at = SystemTime::now()
346 .duration_since(UNIX_EPOCH)
347 .unwrap_or_default()
348 .as_secs();
349 }
350
351 pub fn compute_integrity(&mut self) {
353 let mut hasher = Hasher::new();
354
355 hasher.update(&self.version.to_le_bytes());
357 hasher.update(&self.created_at.to_le_bytes());
358 hasher.update(self.repo_path.as_bytes());
359 hasher.update(self.settings_hash.as_bytes());
360 hasher.update(&(self.processed_files.len() as u64).to_le_bytes());
361 hasher.update(&(self.total_chunks as u64).to_le_bytes());
362 hasher.update(&self.total_tokens.to_le_bytes());
363
364 for file in &self.processed_files {
366 hasher.update(file.as_bytes());
367 }
368
369 self.integrity_hash = hasher.finalize().to_hex().to_string();
370 }
371
372 pub fn verify_integrity(&self) -> bool {
374 let mut copy = self.clone();
375 copy.compute_integrity();
376 copy.integrity_hash == self.integrity_hash
377 }
378}
379
/// Reasons a stored checkpoint cannot be used.
#[derive(Debug, Clone)]
pub enum CheckpointError {
    /// The checkpoint was written with a newer format version.
    VersionMismatch {
        checkpoint_version: u32,
        current_version: u32,
    },
    /// The checkpoint belongs to a different repository path.
    RepoMismatch {
        checkpoint_repo: String,
        current_repo: String,
    },
    /// The checkpoint was created with different settings.
    SettingsMismatch {
        checkpoint_hash: String,
        current_hash: String,
    },
    /// The integrity hash did not verify.
    IntegrityFailed,
    /// The checkpoint data is unusable; the payload describes why.
    Corrupted(String),
}

impl std::fmt::Display for CheckpointError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::VersionMismatch { checkpoint_version: stored, current_version: supported } => {
                write!(
                    f,
                    "Checkpoint version {} is newer than current version {}",
                    stored, supported
                )
            },
            Self::RepoMismatch { checkpoint_repo: stored, current_repo: requested } => {
                write!(
                    f,
                    "Checkpoint repo '{}' doesn't match current repo '{}'",
                    stored, requested
                )
            },
            // The hashes themselves are not useful to a reader, so they are omitted.
            Self::SettingsMismatch { .. } => {
                f.write_str("Checkpoint settings don't match current settings")
            },
            Self::IntegrityFailed => f.write_str("Checkpoint integrity verification failed"),
            Self::Corrupted(reason) => write!(f, "Checkpoint corrupted: {}", reason),
        }
    }
}

impl std::error::Error for CheckpointError {}
426
/// Loads, saves and deletes a checkpoint file at a fixed location.
pub struct CheckpointManager {
    /// Location of the checkpoint file on disk.
    path: PathBuf,
}
432
433impl CheckpointManager {
434 pub fn new(path: impl Into<PathBuf>) -> Self {
436 Self { path: path.into() }
437 }
438
439 pub fn path(&self) -> &Path {
441 &self.path
442 }
443
444 pub fn exists(&self) -> bool {
446 self.path.exists()
447 }
448
449 pub fn load(&self) -> Result<Option<EmbedCheckpoint>, EmbedError> {
451 if !self.path.exists() {
452 return Ok(None);
453 }
454
455 let bytes = std::fs::read(&self.path)
456 .map_err(|e| EmbedError::IoError { path: self.path.clone(), source: e })?;
457
458 let checkpoint: EmbedCheckpoint =
459 deserialize_with_limit(&bytes).map_err(|e| EmbedError::DeserializationError {
460 reason: format!("Failed to deserialize checkpoint: {}", e),
461 })?;
462
463 if !checkpoint.verify_integrity() {
465 return Err(EmbedError::ManifestCorrupted {
466 path: self.path.clone(),
467 expected: checkpoint.integrity_hash,
468 actual: "integrity check failed".to_owned(),
469 });
470 }
471
472 Ok(Some(checkpoint))
473 }
474
475 pub fn save(&self, checkpoint: &mut EmbedCheckpoint) -> Result<(), EmbedError> {
477 checkpoint.compute_integrity();
479
480 let bytes = bincode::options().serialize(checkpoint).map_err(|e| {
481 EmbedError::SerializationError {
482 reason: format!("Failed to serialize checkpoint: {}", e),
483 }
484 })?;
485
486 let temp_path = self.path.with_extension("tmp");
488
489 std::fs::write(&temp_path, &bytes)
490 .map_err(|e| EmbedError::IoError { path: temp_path.clone(), source: e })?;
491
492 std::fs::rename(&temp_path, &self.path)
493 .map_err(|e| EmbedError::IoError { path: self.path.clone(), source: e })?;
494
495 Ok(())
496 }
497
498 pub fn delete(&self) -> Result<(), EmbedError> {
500 if self.path.exists() {
501 std::fs::remove_file(&self.path)
502 .map_err(|e| EmbedError::IoError { path: self.path.clone(), source: e })?;
503 }
504 Ok(())
505 }
506
507 pub fn load_validated(
509 &self,
510 repo_path: &Path,
511 settings: &EmbedSettings,
512 ) -> Result<Option<EmbedCheckpoint>, EmbedError> {
513 let checkpoint = match self.load()? {
514 Some(cp) => cp,
515 None => return Ok(None),
516 };
517
518 match checkpoint.validate(repo_path, settings) {
520 Ok(()) => Ok(Some(checkpoint)),
521 Err(CheckpointError::SettingsMismatch { .. }) => {
522 Ok(None)
524 },
525 Err(CheckpointError::RepoMismatch { .. }) => {
526 Ok(None)
528 },
529 Err(e) => Err(EmbedError::DeserializationError { reason: e.to_string() }),
530 }
531 }
532}
533
534fn compute_settings_hash(settings: &EmbedSettings) -> String {
536 let mut hasher = Hasher::new();
537
538 hasher.update(&settings.max_tokens.to_le_bytes());
540 hasher.update(&settings.min_tokens.to_le_bytes());
541 hasher.update(&settings.overlap_tokens.to_le_bytes());
542 hasher.update(&settings.context_lines.to_le_bytes());
543 hasher.update(settings.token_model.as_bytes());
544 hasher.update(&[settings.include_imports as u8]);
545 hasher.update(&[settings.include_tests as u8]);
546 hasher.update(&[settings.include_top_level as u8]);
547 hasher.update(&[settings.scan_secrets as u8]);
548 hasher.update(&[settings.redact_secrets as u8]);
549 hasher.update(&[settings.fail_on_secrets as u8]);
550 hasher.update(&[settings.enable_hierarchy as u8]);
551 hasher.update(&settings.hierarchy_min_children.to_le_bytes());
552
553 for pattern in &settings.include_patterns {
555 hasher.update(pattern.as_bytes());
556 }
557 for pattern in &settings.exclude_patterns {
558 hasher.update(pattern.as_bytes());
559 }
560
561 hasher.finalize().to_hex().to_string()
562}
563
/// Summary of a checkpoint's progress, derived from an [`EmbedCheckpoint`].
#[derive(Debug, Clone)]
pub struct CheckpointStats {
    /// Copied from `EmbedCheckpoint::created_at`.
    pub created_at: u64,
    /// Copied from `EmbedCheckpoint::updated_at`.
    pub updated_at: u64,
    /// Phase the checkpoint was in when the stats were taken.
    pub phase: CheckpointPhase,
    /// Number of files recorded as processed.
    pub files_processed: usize,
    /// Number of files still pending.
    pub files_remaining: usize,
    /// Number of files recorded as failed.
    pub files_failed: usize,
    /// Copied from `EmbedCheckpoint::total_chunks`.
    pub total_chunks: usize,
    /// Copied from `EmbedCheckpoint::total_tokens`.
    pub total_tokens: u64,
    /// Value of `EmbedCheckpoint::progress_percent` at the time of capture.
    pub progress_percent: u32,
}
586
587impl From<&EmbedCheckpoint> for CheckpointStats {
588 fn from(cp: &EmbedCheckpoint) -> Self {
589 Self {
590 created_at: cp.created_at,
591 updated_at: cp.updated_at,
592 phase: cp.phase,
593 files_processed: cp.files_processed(),
594 files_remaining: cp.files_remaining(),
595 files_failed: cp.files_failed(),
596 total_chunks: cp.total_chunks,
597 total_tokens: cp.total_tokens,
598 progress_percent: cp.progress_percent(),
599 }
600 }
601}
602
#[cfg(test)]
mod tests {
    use super::*;

    /// Settings shared by the tests: the type's defaults.
    fn test_settings() -> EmbedSettings {
        EmbedSettings::default()
    }

    /// A new checkpoint starts empty, in the discovery phase, at the current
    /// format version.
    #[test]
    fn test_checkpoint_creation() {
        let settings = test_settings();
        let cp =
            EmbedCheckpoint::new(Path::new("/test/repo"), RepoIdentifier::default(), &settings);

        assert_eq!(cp.version, CHECKPOINT_VERSION);
        assert_eq!(cp.phase, CheckpointPhase::Discovery);
        assert!(cp.processed_files.is_empty());
        assert!(cp.remaining_files.is_empty());
        assert_eq!(cp.total_chunks, 0);
    }

    /// Files move from pending to processed/failed, and progress reflects both.
    #[test]
    fn test_checkpoint_file_tracking() {
        let settings = test_settings();
        let mut cp =
            EmbedCheckpoint::new(Path::new("/test/repo"), RepoIdentifier::default(), &settings);

        // Installing the file list also switches to the chunking phase.
        cp.set_files(vec!["a.rs".to_owned(), "b.rs".to_owned(), "c.rs".to_owned()]);

        assert_eq!(cp.files_remaining(), 3);
        assert_eq!(cp.files_processed(), 0);
        assert_eq!(cp.phase, CheckpointPhase::Chunking);

        // One file processed successfully.
        cp.mark_file_processed("a.rs", &[]);
        assert_eq!(cp.files_remaining(), 2);
        assert_eq!(cp.files_processed(), 1);

        // One file failed: it leaves the pending list but is not "processed".
        cp.mark_file_failed("b.rs", "Parse error");
        assert_eq!(cp.files_remaining(), 1);
        assert_eq!(cp.files_processed(), 1);
        assert_eq!(cp.files_failed(), 1);

        // 2 of 3 files handled (processed + failed), truncated to 66%.
        assert_eq!(cp.progress_percent(), 66);
    }

    /// The integrity hash verifies until a covered field is changed.
    #[test]
    fn test_checkpoint_integrity() {
        let settings = test_settings();
        let mut cp =
            EmbedCheckpoint::new(Path::new("/test/repo"), RepoIdentifier::default(), &settings);
        cp.set_files(vec!["test.rs".to_owned()]);
        cp.mark_file_processed("test.rs", &[]);

        // Freshly computed hash must be present and must verify.
        cp.compute_integrity();
        assert!(!cp.integrity_hash.is_empty());
        assert!(cp.verify_integrity());

        // Tamper with a field that is part of the hash input.
        cp.total_chunks = 999;
        assert!(!cp.verify_integrity());
    }

    /// Validation accepts matching path and settings and rejects either
    /// mismatch.
    #[test]
    fn test_checkpoint_validation() {
        let settings = test_settings();
        let cp =
            EmbedCheckpoint::new(Path::new("/test/repo"), RepoIdentifier::default(), &settings);

        // Same path, same settings.
        assert!(cp.validate(Path::new("/test/repo"), &settings).is_ok());

        // Different repository path.
        assert!(cp.validate(Path::new("/other/repo"), &settings).is_err());

        // Same path, different settings.
        let mut different_settings = settings;
        different_settings.max_tokens = 9999;
        assert!(cp
            .validate(Path::new("/test/repo"), &different_settings)
            .is_err());
    }

    /// A saved checkpoint can be loaded back and still verifies.
    #[test]
    fn test_checkpoint_save_load() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let checkpoint_path = temp_dir.path().join("checkpoint.bin");

        let settings = test_settings();
        let mut cp =
            EmbedCheckpoint::new(Path::new("/test/repo"), RepoIdentifier::default(), &settings);
        cp.set_files(vec!["test.rs".to_owned()]);
        cp.mark_file_processed("test.rs", &[]);

        // Write to disk.
        let manager = CheckpointManager::new(&checkpoint_path);
        manager.save(&mut cp).unwrap();

        // Read back and compare.
        let loaded = manager.load().unwrap().unwrap();
        assert_eq!(loaded.files_processed(), 1);
        assert!(loaded.verify_integrity());
    }

    /// `load_validated` returns the checkpoint for matching settings and
    /// `None` for different settings.
    #[test]
    fn test_checkpoint_manager_validated() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let checkpoint_path = temp_dir.path().join("checkpoint.bin");
        let repo_path = Path::new("/test/repo");

        let settings = test_settings();
        let mut cp = EmbedCheckpoint::new(repo_path, RepoIdentifier::default(), &settings);
        cp.set_files(vec!["test.rs".to_owned()]);

        let manager = CheckpointManager::new(&checkpoint_path);
        manager.save(&mut cp).unwrap();

        // Matching settings: checkpoint is returned.
        let loaded = manager.load_validated(repo_path, &settings).unwrap();
        assert!(loaded.is_some());

        // Different settings: not an error, just nothing to resume.
        let mut different = settings;
        different.max_tokens = 9999;
        let loaded = manager.load_validated(repo_path, &different).unwrap();
        assert!(loaded.is_none());
    }

    /// `set_phase` stores whatever phase it is given.
    #[test]
    fn test_checkpoint_phases() {
        let settings = test_settings();
        let mut cp =
            EmbedCheckpoint::new(Path::new("/test/repo"), RepoIdentifier::default(), &settings);

        assert_eq!(cp.phase, CheckpointPhase::Discovery);

        cp.set_phase(CheckpointPhase::Chunking);
        assert_eq!(cp.phase, CheckpointPhase::Chunking);

        cp.set_phase(CheckpointPhase::CallGraph);
        assert_eq!(cp.phase, CheckpointPhase::CallGraph);

        cp.set_phase(CheckpointPhase::Complete);
        assert_eq!(cp.phase, CheckpointPhase::Complete);
    }

    /// Stats mirror the checkpoint's counters and progress.
    #[test]
    fn test_checkpoint_stats() {
        let settings = test_settings();
        let mut cp =
            EmbedCheckpoint::new(Path::new("/test/repo"), RepoIdentifier::default(), &settings);
        cp.set_files(vec!["a.rs".to_owned(), "b.rs".to_owned()]);
        cp.mark_file_processed("a.rs", &[]);

        let stats = CheckpointStats::from(&cp);
        assert_eq!(stats.files_processed, 1);
        assert_eq!(stats.files_remaining, 1);
        assert_eq!(stats.progress_percent, 50);
    }

    /// The settings hash is stable for equal settings and changes when a
    /// hashed field changes.
    #[test]
    fn test_settings_hash_determinism() {
        let settings = test_settings();
        let hash1 = compute_settings_hash(&settings);
        let hash2 = compute_settings_hash(&settings);
        assert_eq!(hash1, hash2);

        let mut different = settings;
        // Assumes the default `max_tokens` is not 2000 - TODO confirm.
        different.max_tokens = 2000;
        let hash3 = compute_settings_hash(&different);
        assert_ne!(hash1, hash3);
    }

    /// Deleting removes the file written by `save`.
    #[test]
    fn test_checkpoint_delete() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let checkpoint_path = temp_dir.path().join("checkpoint.bin");

        let settings = test_settings();
        let mut cp =
            EmbedCheckpoint::new(Path::new("/test/repo"), RepoIdentifier::default(), &settings);

        let manager = CheckpointManager::new(&checkpoint_path);
        manager.save(&mut cp).unwrap();
        assert!(manager.exists());

        manager.delete().unwrap();
        assert!(!manager.exists());
    }

    /// Chunking is complete only once past discovery with nothing pending.
    #[test]
    fn test_is_chunking_complete() {
        let settings = test_settings();
        let mut cp =
            EmbedCheckpoint::new(Path::new("/test/repo"), RepoIdentifier::default(), &settings);

        // Still in discovery: not complete even though nothing is pending.
        assert!(!cp.is_chunking_complete());

        // In the chunking phase with a pending file: not complete.
        cp.set_files(vec!["a.rs".to_owned()]);
        assert!(!cp.is_chunking_complete());

        // Last pending file handled: complete.
        cp.mark_file_processed("a.rs", &[]);
        assert!(cp.is_chunking_complete());
    }

    /// A chunk reference copies id, token count, lines and symbol from the chunk.
    #[test]
    fn test_chunk_reference_from_embed_chunk() {
        use super::super::types::{ChunkContext, ChunkKind, ChunkSource};

        let chunk = EmbedChunk {
            id: "ec_abc123".to_owned(),
            full_hash: "deadbeef".repeat(8),
            content: "fn test() {}".to_owned(),
            tokens: 10,
            kind: ChunkKind::Function,
            source: ChunkSource {
                repo: RepoIdentifier::default(),
                file: "test.rs".to_owned(),
                lines: (1, 5),
                symbol: "test".to_owned(),
                fqn: None,
                language: "Rust".to_owned(),
                parent: None,
                visibility: super::super::types::Visibility::Public,
                is_test: false,
            },
            context: ChunkContext::default(),
            part: None,
        };

        let reference = ChunkReference::from(&chunk);
        assert_eq!(reference.id, "ec_abc123");
        assert_eq!(reference.tokens, 10);
        assert_eq!(reference.lines, (1, 5));
        assert_eq!(reference.symbol, "test");
    }
}