1use std::collections::{BTreeMap, BTreeSet};
36use std::path::{Path, PathBuf};
37use std::time::{SystemTime, UNIX_EPOCH};
38
39use blake3::Hasher;
40use serde::{Deserialize, Serialize};
41
42use super::error::EmbedError;
43use super::types::{EmbedChunk, EmbedSettings, RepoIdentifier};
44use crate::bincode_safe::{deserialize_with_limit, serialize};
45
/// Current on-disk checkpoint format version. `EmbedCheckpoint::validate`
/// rejects checkpoints whose stored version is strictly newer than this.
pub const CHECKPOINT_VERSION: u32 = 2;
48
/// Serialization-friendly mirror of [`RepoIdentifier`]: the optional fields
/// are flattened to plain `String`s, with the empty string standing in for
/// `None` (see the `From` conversions below).
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct CheckpointRepoId {
    /// Repository namespace; "" when unknown.
    pub namespace: String,
    /// Repository name (always present).
    pub name: String,
    /// Version label; "" when unknown.
    pub version: String,
    /// Branch name; "" when unknown.
    pub branch: String,
    /// Commit identifier; "" when unknown.
    pub commit: String,
}
67
68impl From<&RepoIdentifier> for CheckpointRepoId {
69 fn from(repo: &RepoIdentifier) -> Self {
70 Self {
71 namespace: repo.namespace.clone().unwrap_or_default(),
72 name: repo.name.clone(),
73 version: repo.version.clone().unwrap_or_default(),
74 branch: repo.branch.clone().unwrap_or_default(),
75 commit: repo.commit.clone().unwrap_or_default(),
76 }
77 }
78}
79
/// Owned-value convenience; delegates to the by-reference conversion.
impl From<RepoIdentifier> for CheckpointRepoId {
    fn from(repo: RepoIdentifier) -> Self {
        Self::from(&repo)
    }
}
85
86impl From<&CheckpointRepoId> for RepoIdentifier {
87 fn from(cp: &CheckpointRepoId) -> Self {
88 Self {
89 namespace: if cp.namespace.is_empty() {
90 None
91 } else {
92 Some(cp.namespace.clone())
93 },
94 name: cp.name.clone(),
95 version: if cp.version.is_empty() {
96 None
97 } else {
98 Some(cp.version.clone())
99 },
100 branch: if cp.branch.is_empty() {
101 None
102 } else {
103 Some(cp.branch.clone())
104 },
105 commit: if cp.commit.is_empty() {
106 None
107 } else {
108 Some(cp.commit.clone())
109 },
110 }
111 }
112}
113
/// Resumable state of an embedding run, persisted between sessions by
/// [`CheckpointManager`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbedCheckpoint {
    /// Format version written at creation; `validate` rejects newer versions.
    pub version: u32,

    /// Creation time, seconds since the Unix epoch.
    pub created_at: u64,

    /// Last modification time, seconds since the Unix epoch.
    pub updated_at: u64,

    /// Identity of the repository being embedded.
    pub repo_id: CheckpointRepoId,

    /// Lossy string form of the repository path; compared in `validate`.
    pub repo_path: String,

    /// Digest of the embedding settings (see `compute_settings_hash`);
    /// compared in `validate`.
    pub settings_hash: String,

    /// Files that have been chunked successfully.
    pub processed_files: BTreeSet<String>,

    /// Files still awaiting processing.
    pub remaining_files: Vec<String>,

    /// Per-file references to the chunks produced so far.
    pub chunks_by_file: BTreeMap<String, Vec<ChunkReference>>,

    /// Running count of all chunks produced.
    pub total_chunks: usize,

    /// Running sum of token counts across all chunks.
    pub total_tokens: u64,

    /// Files that failed, mapped to their error message.
    pub failed_files: BTreeMap<String, String>,

    /// Pipeline phase the run had reached when last saved.
    pub phase: CheckpointPhase,

    /// Blake3 digest guarding against on-disk corruption; maintained by
    /// `compute_integrity` and checked by `verify_integrity`.
    pub integrity_hash: String,
}
159
/// Lightweight per-chunk record stored in the checkpoint instead of the
/// full chunk content; built from an [`EmbedChunk`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkReference {
    /// Chunk identifier (copied from `EmbedChunk::id`).
    pub id: String,

    /// Full content hash of the chunk.
    pub full_hash: String,

    /// Token count of the chunk.
    pub tokens: u32,

    /// (start, end) line numbers of the chunk in its source file.
    pub lines: (u32, u32),

    /// Symbol name the chunk was extracted from.
    pub symbol: String,
}
178
179impl From<&EmbedChunk> for ChunkReference {
180 fn from(chunk: &EmbedChunk) -> Self {
181 Self {
182 id: chunk.id.clone(),
183 full_hash: chunk.full_hash.clone(),
184 tokens: chunk.tokens,
185 lines: chunk.source.lines,
186 symbol: chunk.source.symbol.clone(),
187 }
188 }
189}
190
/// Pipeline phase recorded in the checkpoint so a resumed run knows where
/// to pick up. Variants beyond `Chunking` are named after later pipeline
/// stages; their exact semantics live in the embedding pipeline itself
/// (not visible here — confirm against the pipeline code).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum CheckpointPhase {
    /// Initial phase for a freshly created checkpoint (see `new`).
    #[default]
    Discovery,
    /// Files are being split into chunks (entered by `set_files`).
    Chunking,
    CallGraph,
    Hierarchy,
    Sorting,
    /// The run finished.
    Complete,
}
208
209impl EmbedCheckpoint {
210 pub fn new(repo_path: &Path, repo_id: RepoIdentifier, settings: &EmbedSettings) -> Self {
212 let now = SystemTime::now()
213 .duration_since(UNIX_EPOCH)
214 .unwrap_or_default()
215 .as_secs();
216
217 let settings_hash = compute_settings_hash(settings);
218
219 Self {
220 version: CHECKPOINT_VERSION,
221 created_at: now,
222 updated_at: now,
223 repo_id: CheckpointRepoId::from(repo_id),
224 repo_path: repo_path.to_string_lossy().to_string(),
225 settings_hash,
226 processed_files: BTreeSet::new(),
227 remaining_files: Vec::new(),
228 chunks_by_file: BTreeMap::new(),
229 total_chunks: 0,
230 total_tokens: 0,
231 failed_files: BTreeMap::new(),
232 phase: CheckpointPhase::Discovery,
233 integrity_hash: String::new(),
234 }
235 }
236
237 pub fn set_files(&mut self, files: Vec<String>) {
239 self.remaining_files = files;
240 self.phase = CheckpointPhase::Chunking;
241 self.update_timestamp();
242 }
243
244 pub fn mark_file_processed(&mut self, file: &str, chunks: &[EmbedChunk]) {
246 self.remaining_files.retain(|f| f != file);
248 self.processed_files.insert(file.to_owned());
249
250 let refs: Vec<ChunkReference> = chunks.iter().map(ChunkReference::from).collect();
252 let tokens: u64 = chunks.iter().map(|c| c.tokens as u64).sum();
253
254 self.total_chunks += chunks.len();
255 self.total_tokens += tokens;
256 self.chunks_by_file.insert(file.to_owned(), refs);
257
258 self.update_timestamp();
259 }
260
261 pub fn mark_file_failed(&mut self, file: &str, error: &str) {
263 self.remaining_files.retain(|f| f != file);
264 self.failed_files.insert(file.to_owned(), error.to_owned());
265 self.update_timestamp();
266 }
267
268 pub fn set_phase(&mut self, phase: CheckpointPhase) {
270 self.phase = phase;
271 self.update_timestamp();
272 }
273
274 pub fn is_chunking_complete(&self) -> bool {
276 self.remaining_files.is_empty()
277 && (self.phase == CheckpointPhase::Chunking
278 || self.phase == CheckpointPhase::CallGraph
279 || self.phase == CheckpointPhase::Hierarchy
280 || self.phase == CheckpointPhase::Sorting
281 || self.phase == CheckpointPhase::Complete)
282 }
283
284 pub fn progress_percent(&self) -> u32 {
286 let total =
287 self.processed_files.len() + self.remaining_files.len() + self.failed_files.len();
288 if total == 0 {
289 return 0;
290 }
291
292 let processed = self.processed_files.len() + self.failed_files.len();
293 ((processed * 100) / total) as u32
294 }
295
296 pub fn files_processed(&self) -> usize {
298 self.processed_files.len()
299 }
300
301 pub fn files_remaining(&self) -> usize {
303 self.remaining_files.len()
304 }
305
306 pub fn files_failed(&self) -> usize {
308 self.failed_files.len()
309 }
310
311 pub fn validate(
313 &self,
314 repo_path: &Path,
315 settings: &EmbedSettings,
316 ) -> Result<(), CheckpointError> {
317 if self.version > CHECKPOINT_VERSION {
319 return Err(CheckpointError::VersionMismatch {
320 checkpoint_version: self.version,
321 current_version: CHECKPOINT_VERSION,
322 });
323 }
324
325 let current_path = repo_path.to_string_lossy().to_string();
327 if self.repo_path != current_path {
328 return Err(CheckpointError::RepoMismatch {
329 checkpoint_repo: self.repo_path.clone(),
330 current_repo: current_path,
331 });
332 }
333
334 let current_hash = compute_settings_hash(settings);
336 if self.settings_hash != current_hash {
337 return Err(CheckpointError::SettingsMismatch {
338 checkpoint_hash: self.settings_hash.clone(),
339 current_hash,
340 });
341 }
342
343 Ok(())
344 }
345
346 fn update_timestamp(&mut self) {
348 self.updated_at = SystemTime::now()
349 .duration_since(UNIX_EPOCH)
350 .unwrap_or_default()
351 .as_secs();
352 }
353
354 pub fn compute_integrity(&mut self) {
356 let mut hasher = Hasher::new();
357
358 hasher.update(&self.version.to_le_bytes());
360 hasher.update(&self.created_at.to_le_bytes());
361 hasher.update(self.repo_path.as_bytes());
362 hasher.update(self.settings_hash.as_bytes());
363 hasher.update(&(self.processed_files.len() as u64).to_le_bytes());
364 hasher.update(&(self.total_chunks as u64).to_le_bytes());
365 hasher.update(&self.total_tokens.to_le_bytes());
366
367 for file in &self.processed_files {
369 hasher.update(file.as_bytes());
370 }
371
372 self.integrity_hash = hasher.finalize().to_hex().to_string();
373 }
374
375 pub fn verify_integrity(&self) -> bool {
377 let mut copy = self.clone();
378 copy.compute_integrity();
379 copy.integrity_hash == self.integrity_hash
380 }
381}
382
/// Reasons a saved checkpoint cannot be resumed (see `EmbedCheckpoint::validate`).
#[derive(Debug, Clone)]
pub enum CheckpointError {
    /// Checkpoint was written by a newer format version than this build supports.
    VersionMismatch { checkpoint_version: u32, current_version: u32 },
    /// Checkpoint belongs to a different repository path.
    RepoMismatch { checkpoint_repo: String, current_repo: String },
    /// Embedding settings changed since the checkpoint was written.
    SettingsMismatch { checkpoint_hash: String, current_hash: String },
    /// Stored integrity hash did not match the recomputed one.
    IntegrityFailed,
    /// Checkpoint data is unreadable; the payload describes why.
    Corrupted(String),
}
397
398impl std::fmt::Display for CheckpointError {
399 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
400 match self {
401 Self::VersionMismatch { checkpoint_version, current_version } => {
402 write!(
403 f,
404 "Checkpoint version {} is newer than current version {}",
405 checkpoint_version, current_version
406 )
407 },
408 Self::RepoMismatch { checkpoint_repo, current_repo } => {
409 write!(
410 f,
411 "Checkpoint repo '{}' doesn't match current repo '{}'",
412 checkpoint_repo, current_repo
413 )
414 },
415 Self::SettingsMismatch { .. } => {
416 write!(f, "Checkpoint settings don't match current settings")
417 },
418 Self::IntegrityFailed => {
419 write!(f, "Checkpoint integrity verification failed")
420 },
421 Self::Corrupted(reason) => {
422 write!(f, "Checkpoint corrupted: {}", reason)
423 },
424 }
425 }
426}
427
// Marker impl: `Debug` + `Display` above satisfy `Error`'s default methods.
impl std::error::Error for CheckpointError {}
429
/// Handles atomic persistence (load / save / delete) of an
/// [`EmbedCheckpoint`] at a fixed filesystem path.
pub struct CheckpointManager {
    /// Location of the checkpoint file on disk.
    path: PathBuf,
}
435
436impl CheckpointManager {
437 pub fn new(path: impl Into<PathBuf>) -> Self {
439 Self { path: path.into() }
440 }
441
442 pub fn path(&self) -> &Path {
444 &self.path
445 }
446
447 pub fn exists(&self) -> bool {
449 self.path.exists()
450 }
451
452 pub fn load(&self) -> Result<Option<EmbedCheckpoint>, EmbedError> {
454 if !self.path.exists() {
455 return Ok(None);
456 }
457
458 let bytes = std::fs::read(&self.path)
459 .map_err(|e| EmbedError::IoError { path: self.path.clone(), source: e })?;
460
461 let checkpoint: EmbedCheckpoint =
462 deserialize_with_limit(&bytes).map_err(|e| EmbedError::DeserializationError {
463 reason: format!("Failed to deserialize checkpoint (it may have been created by an older version of infiniloom; delete {:?} and re-run): {}", &self.path, e),
464 })?;
465
466 if !checkpoint.verify_integrity() {
468 return Err(EmbedError::ManifestCorrupted {
469 path: self.path.clone(),
470 expected: checkpoint.integrity_hash,
471 actual: "integrity check failed".to_owned(),
472 });
473 }
474
475 Ok(Some(checkpoint))
476 }
477
478 pub fn save(&self, checkpoint: &mut EmbedCheckpoint) -> Result<(), EmbedError> {
480 checkpoint.compute_integrity();
482
483 let bytes = serialize(checkpoint).map_err(|e| EmbedError::SerializationError {
484 reason: format!("Failed to serialize checkpoint: {}", e),
485 })?;
486
487 let temp_path = self.path.with_extension("tmp");
489
490 std::fs::write(&temp_path, &bytes)
491 .map_err(|e| EmbedError::IoError { path: temp_path.clone(), source: e })?;
492
493 std::fs::rename(&temp_path, &self.path)
494 .map_err(|e| EmbedError::IoError { path: self.path.clone(), source: e })?;
495
496 Ok(())
497 }
498
499 pub fn delete(&self) -> Result<(), EmbedError> {
501 if self.path.exists() {
502 std::fs::remove_file(&self.path)
503 .map_err(|e| EmbedError::IoError { path: self.path.clone(), source: e })?;
504 }
505 Ok(())
506 }
507
508 pub fn load_validated(
510 &self,
511 repo_path: &Path,
512 settings: &EmbedSettings,
513 ) -> Result<Option<EmbedCheckpoint>, EmbedError> {
514 let checkpoint = match self.load()? {
515 Some(cp) => cp,
516 None => return Ok(None),
517 };
518
519 match checkpoint.validate(repo_path, settings) {
521 Ok(()) => Ok(Some(checkpoint)),
522 Err(CheckpointError::SettingsMismatch { .. }) => {
523 Ok(None)
525 },
526 Err(CheckpointError::RepoMismatch { .. }) => {
527 Ok(None)
529 },
530 Err(e) => Err(EmbedError::DeserializationError { reason: e.to_string() }),
531 }
532 }
533}
534
535fn compute_settings_hash(settings: &EmbedSettings) -> String {
537 let mut hasher = Hasher::new();
538
539 hasher.update(&settings.max_tokens.to_le_bytes());
541 hasher.update(&settings.min_tokens.to_le_bytes());
542 hasher.update(&settings.overlap_tokens.to_le_bytes());
543 hasher.update(&settings.context_lines.to_le_bytes());
544 hasher.update(settings.token_model.as_bytes());
545 hasher.update(&[settings.include_imports as u8]);
546 hasher.update(&[settings.include_tests as u8]);
547 hasher.update(&[settings.include_top_level as u8]);
548 hasher.update(&[settings.scan_secrets as u8]);
549 hasher.update(&[settings.redact_secrets as u8]);
550 hasher.update(&[settings.fail_on_secrets as u8]);
551 hasher.update(&[settings.enable_hierarchy as u8]);
552 hasher.update(&settings.hierarchy_min_children.to_le_bytes());
553
554 for pattern in &settings.include_patterns {
556 hasher.update(pattern.as_bytes());
557 }
558 for pattern in &settings.exclude_patterns {
559 hasher.update(pattern.as_bytes());
560 }
561
562 hasher.finalize().to_hex().to_string()
563}
564
/// Read-only summary of a checkpoint's progress, derived via
/// `From<&EmbedCheckpoint>`.
#[derive(Debug, Clone)]
pub struct CheckpointStats {
    /// Checkpoint creation time (Unix seconds).
    pub created_at: u64,
    /// Last update time (Unix seconds).
    pub updated_at: u64,
    /// Phase the run had reached.
    pub phase: CheckpointPhase,
    /// Number of successfully processed files.
    pub files_processed: usize,
    /// Number of files still pending.
    pub files_remaining: usize,
    /// Number of files that failed.
    pub files_failed: usize,
    /// Total chunks produced so far.
    pub total_chunks: usize,
    /// Total tokens across all chunks.
    pub total_tokens: u64,
    /// Integer completion percentage (0-100).
    pub progress_percent: u32,
}
587
588impl From<&EmbedCheckpoint> for CheckpointStats {
589 fn from(cp: &EmbedCheckpoint) -> Self {
590 Self {
591 created_at: cp.created_at,
592 updated_at: cp.updated_at,
593 phase: cp.phase,
594 files_processed: cp.files_processed(),
595 files_remaining: cp.files_remaining(),
596 files_failed: cp.files_failed(),
597 total_chunks: cp.total_chunks,
598 total_tokens: cp.total_tokens,
599 progress_percent: cp.progress_percent(),
600 }
601 }
602}
603
#[cfg(test)]
mod tests {
    use super::*;

    // All tests start from the default settings; tests that need a differing
    // configuration mutate a copy.
    fn test_settings() -> EmbedSettings {
        EmbedSettings::default()
    }

    #[test]
    fn test_checkpoint_creation() {
        let settings = test_settings();
        let cp =
            EmbedCheckpoint::new(Path::new("/test/repo"), RepoIdentifier::default(), &settings);

        // A fresh checkpoint is empty, in Discovery, at the current version.
        assert_eq!(cp.version, CHECKPOINT_VERSION);
        assert_eq!(cp.phase, CheckpointPhase::Discovery);
        assert!(cp.processed_files.is_empty());
        assert!(cp.remaining_files.is_empty());
        assert_eq!(cp.total_chunks, 0);
    }

    #[test]
    fn test_checkpoint_file_tracking() {
        let settings = test_settings();
        let mut cp =
            EmbedCheckpoint::new(Path::new("/test/repo"), RepoIdentifier::default(), &settings);

        // Registering files advances the checkpoint to the Chunking phase.
        cp.set_files(vec!["a.rs".to_owned(), "b.rs".to_owned(), "c.rs".to_owned()]);

        assert_eq!(cp.files_remaining(), 3);
        assert_eq!(cp.files_processed(), 0);
        assert_eq!(cp.phase, CheckpointPhase::Chunking);

        cp.mark_file_processed("a.rs", &[]);
        assert_eq!(cp.files_remaining(), 2);
        assert_eq!(cp.files_processed(), 1);

        cp.mark_file_failed("b.rs", "Parse error");
        assert_eq!(cp.files_remaining(), 1);
        assert_eq!(cp.files_processed(), 1);
        assert_eq!(cp.files_failed(), 1);

        // 2 of 3 files handled (processed + failed) -> 66% by integer division.
        assert_eq!(cp.progress_percent(), 66);
    }

    #[test]
    fn test_checkpoint_integrity() {
        let settings = test_settings();
        let mut cp =
            EmbedCheckpoint::new(Path::new("/test/repo"), RepoIdentifier::default(), &settings);
        cp.set_files(vec!["test.rs".to_owned()]);
        cp.mark_file_processed("test.rs", &[]);

        // Hash is non-empty and self-consistent after computing it.
        cp.compute_integrity();
        assert!(!cp.integrity_hash.is_empty());
        assert!(cp.verify_integrity());

        // Tampering with a hashed field must break verification.
        cp.total_chunks = 999;
        assert!(!cp.verify_integrity());
    }

    #[test]
    fn test_checkpoint_validation() {
        let settings = test_settings();
        let cp =
            EmbedCheckpoint::new(Path::new("/test/repo"), RepoIdentifier::default(), &settings);

        // Same path and settings as creation -> valid.
        assert!(cp.validate(Path::new("/test/repo"), &settings).is_ok());

        // A different repo path is rejected.
        assert!(cp.validate(Path::new("/other/repo"), &settings).is_err());

        // Changed settings are rejected via the settings hash.
        let mut different_settings = settings;
        different_settings.max_tokens = 9999;
        assert!(cp
            .validate(Path::new("/test/repo"), &different_settings)
            .is_err());
    }

    #[test]
    fn test_checkpoint_save_load() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let checkpoint_path = temp_dir.path().join("checkpoint.bin");

        let settings = test_settings();
        let mut cp =
            EmbedCheckpoint::new(Path::new("/test/repo"), RepoIdentifier::default(), &settings);
        cp.set_files(vec!["test.rs".to_owned()]);
        cp.mark_file_processed("test.rs", &[]);

        // Round-trip through disk; save() refreshes the integrity hash.
        let manager = CheckpointManager::new(&checkpoint_path);
        manager.save(&mut cp).unwrap();

        let loaded = manager.load().unwrap().unwrap();
        assert_eq!(loaded.files_processed(), 1);
        assert!(loaded.verify_integrity());
    }

    #[test]
    fn test_checkpoint_manager_validated() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let checkpoint_path = temp_dir.path().join("checkpoint.bin");
        let repo_path = Path::new("/test/repo");

        let settings = test_settings();
        let mut cp = EmbedCheckpoint::new(repo_path, RepoIdentifier::default(), &settings);
        cp.set_files(vec!["test.rs".to_owned()]);

        let manager = CheckpointManager::new(&checkpoint_path);
        manager.save(&mut cp).unwrap();

        // Matching repo + settings -> checkpoint is returned.
        let loaded = manager.load_validated(repo_path, &settings).unwrap();
        assert!(loaded.is_some());

        // Mismatched settings are not an error; the checkpoint is discarded.
        let mut different = settings;
        different.max_tokens = 9999;
        let loaded = manager.load_validated(repo_path, &different).unwrap();
        assert!(loaded.is_none());
    }

    #[test]
    fn test_checkpoint_phases() {
        let settings = test_settings();
        let mut cp =
            EmbedCheckpoint::new(Path::new("/test/repo"), RepoIdentifier::default(), &settings);

        assert_eq!(cp.phase, CheckpointPhase::Discovery);

        cp.set_phase(CheckpointPhase::Chunking);
        assert_eq!(cp.phase, CheckpointPhase::Chunking);

        cp.set_phase(CheckpointPhase::CallGraph);
        assert_eq!(cp.phase, CheckpointPhase::CallGraph);

        cp.set_phase(CheckpointPhase::Complete);
        assert_eq!(cp.phase, CheckpointPhase::Complete);
    }

    #[test]
    fn test_checkpoint_stats() {
        let settings = test_settings();
        let mut cp =
            EmbedCheckpoint::new(Path::new("/test/repo"), RepoIdentifier::default(), &settings);
        cp.set_files(vec!["a.rs".to_owned(), "b.rs".to_owned()]);
        cp.mark_file_processed("a.rs", &[]);

        // Stats snapshot mirrors the checkpoint's derived counters.
        let stats = CheckpointStats::from(&cp);
        assert_eq!(stats.files_processed, 1);
        assert_eq!(stats.files_remaining, 1);
        assert_eq!(stats.progress_percent, 50);
    }

    #[test]
    fn test_settings_hash_determinism() {
        let settings = test_settings();
        let hash1 = compute_settings_hash(&settings);
        let hash2 = compute_settings_hash(&settings);
        // Same settings always hash to the same digest.
        assert_eq!(hash1, hash2);

        // Any changed field must change the digest.
        let mut different = settings;
        different.max_tokens = 2000; let hash3 = compute_settings_hash(&different);
        assert_ne!(hash1, hash3);
    }

    #[test]
    fn test_checkpoint_delete() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let checkpoint_path = temp_dir.path().join("checkpoint.bin");

        let settings = test_settings();
        let mut cp =
            EmbedCheckpoint::new(Path::new("/test/repo"), RepoIdentifier::default(), &settings);

        let manager = CheckpointManager::new(&checkpoint_path);
        manager.save(&mut cp).unwrap();
        assert!(manager.exists());

        manager.delete().unwrap();
        assert!(!manager.exists());
    }

    #[test]
    fn test_is_chunking_complete() {
        let settings = test_settings();
        let mut cp =
            EmbedCheckpoint::new(Path::new("/test/repo"), RepoIdentifier::default(), &settings);

        // Discovery phase: never complete, even with no files registered.
        assert!(!cp.is_chunking_complete());

        cp.set_files(vec!["a.rs".to_owned()]);
        assert!(!cp.is_chunking_complete());

        // All files handled and phase >= Chunking -> complete.
        cp.mark_file_processed("a.rs", &[]);
        assert!(cp.is_chunking_complete());
    }

    #[test]
    fn test_chunk_reference_from_embed_chunk() {
        use super::super::types::{ChunkContext, ChunkKind, ChunkSource};

        let chunk = EmbedChunk {
            id: "ec_abc123".to_owned(),
            full_hash: "deadbeef".repeat(8),
            content: "fn test() {}".to_owned(),
            tokens: 10,
            kind: ChunkKind::Function,
            source: ChunkSource {
                repo: RepoIdentifier::default(),
                file: "test.rs".to_owned(),
                lines: (1, 5),
                symbol: "test".to_owned(),
                fqn: None,
                language: "Rust".to_owned(),
                parent: None,
                visibility: super::super::types::Visibility::Public,
                is_test: false,
                module_path: None,
                parent_chunk_id: None,
            },
            children_ids: Vec::new(),
            context: ChunkContext::default(),
            repr: "code".to_string(),
            code_chunk_id: None,
            part: None,
        };

        // The reference keeps only the checkpoint-relevant subset.
        let reference = ChunkReference::from(&chunk);
        assert_eq!(reference.id, "ec_abc123");
        assert_eq!(reference.tokens, 10);
        assert_eq!(reference.lines, (1, 5));
        assert_eq!(reference.symbol, "test");
    }
}