1use std::collections::{BTreeMap, BTreeSet};
36use std::path::{Path, PathBuf};
37use std::time::{SystemTime, UNIX_EPOCH};
38
39use bincode::Options;
40use blake3::Hasher;
41use serde::{Deserialize, Serialize};
42
43use super::error::EmbedError;
44use crate::bincode_safe::deserialize_with_limit;
45use super::types::{EmbedChunk, EmbedSettings, RepoIdentifier};
46
/// Format version stamped into every checkpoint created by [`EmbedCheckpoint::new`].
///
/// [`EmbedCheckpoint::validate`] rejects checkpoints whose stored version is
/// greater than this value.
pub const CHECKPOINT_VERSION: u32 = 1;
49
50#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
56pub struct CheckpointRepoId {
57 pub namespace: String,
59 pub name: String,
61 pub version: String,
63 pub branch: String,
65 pub commit: String,
67}
68
69impl From<&RepoIdentifier> for CheckpointRepoId {
70 fn from(repo: &RepoIdentifier) -> Self {
71 Self {
72 namespace: repo.namespace.clone(),
73 name: repo.name.clone(),
74 version: repo.version.clone().unwrap_or_default(),
75 branch: repo.branch.clone().unwrap_or_default(),
76 commit: repo.commit.clone().unwrap_or_default(),
77 }
78 }
79}
80
81impl From<RepoIdentifier> for CheckpointRepoId {
82 fn from(repo: RepoIdentifier) -> Self {
83 Self::from(&repo)
84 }
85}
86
87impl From<&CheckpointRepoId> for RepoIdentifier {
88 fn from(cp: &CheckpointRepoId) -> Self {
89 Self {
90 namespace: cp.namespace.clone(),
91 name: cp.name.clone(),
92 version: if cp.version.is_empty() { None } else { Some(cp.version.clone()) },
93 branch: if cp.branch.is_empty() { None } else { Some(cp.branch.clone()) },
94 commit: if cp.commit.is_empty() { None } else { Some(cp.commit.clone()) },
95 }
96 }
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct EmbedCheckpoint {
102 pub version: u32,
104
105 pub created_at: u64,
107
108 pub updated_at: u64,
110
111 pub repo_id: CheckpointRepoId,
113
114 pub repo_path: String,
116
117 pub settings_hash: String,
119
120 pub processed_files: BTreeSet<String>,
122
123 pub remaining_files: Vec<String>,
125
126 pub chunks_by_file: BTreeMap<String, Vec<ChunkReference>>,
128
129 pub total_chunks: usize,
131
132 pub total_tokens: u64,
134
135 pub failed_files: BTreeMap<String, String>,
137
138 pub phase: CheckpointPhase,
140
141 pub integrity_hash: String,
143}
144
145#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct ChunkReference {
148 pub id: String,
150
151 pub full_hash: String,
153
154 pub tokens: u32,
156
157 pub lines: (u32, u32),
159
160 pub symbol: String,
162}
163
164impl From<&EmbedChunk> for ChunkReference {
165 fn from(chunk: &EmbedChunk) -> Self {
166 Self {
167 id: chunk.id.clone(),
168 full_hash: chunk.full_hash.clone(),
169 tokens: chunk.tokens,
170 lines: chunk.source.lines,
171 symbol: chunk.source.symbol.clone(),
172 }
173 }
174}
175
176#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
178pub enum CheckpointPhase {
179 Discovery,
181 Chunking,
183 CallGraph,
185 Hierarchy,
187 Sorting,
189 Complete,
191}
192
193impl Default for CheckpointPhase {
194 fn default() -> Self {
195 Self::Discovery
196 }
197}
198
199impl EmbedCheckpoint {
200 pub fn new(repo_path: &Path, repo_id: RepoIdentifier, settings: &EmbedSettings) -> Self {
202 let now = SystemTime::now()
203 .duration_since(UNIX_EPOCH)
204 .unwrap_or_default()
205 .as_secs();
206
207 let settings_hash = compute_settings_hash(settings);
208
209 Self {
210 version: CHECKPOINT_VERSION,
211 created_at: now,
212 updated_at: now,
213 repo_id: CheckpointRepoId::from(repo_id),
214 repo_path: repo_path.to_string_lossy().to_string(),
215 settings_hash,
216 processed_files: BTreeSet::new(),
217 remaining_files: Vec::new(),
218 chunks_by_file: BTreeMap::new(),
219 total_chunks: 0,
220 total_tokens: 0,
221 failed_files: BTreeMap::new(),
222 phase: CheckpointPhase::Discovery,
223 integrity_hash: String::new(),
224 }
225 }
226
227 pub fn set_files(&mut self, files: Vec<String>) {
229 self.remaining_files = files;
230 self.phase = CheckpointPhase::Chunking;
231 self.update_timestamp();
232 }
233
234 pub fn mark_file_processed(&mut self, file: &str, chunks: &[EmbedChunk]) {
236 self.remaining_files.retain(|f| f != file);
238 self.processed_files.insert(file.to_string());
239
240 let refs: Vec<ChunkReference> = chunks.iter().map(ChunkReference::from).collect();
242 let tokens: u64 = chunks.iter().map(|c| c.tokens as u64).sum();
243
244 self.total_chunks += chunks.len();
245 self.total_tokens += tokens;
246 self.chunks_by_file.insert(file.to_string(), refs);
247
248 self.update_timestamp();
249 }
250
251 pub fn mark_file_failed(&mut self, file: &str, error: &str) {
253 self.remaining_files.retain(|f| f != file);
254 self.failed_files.insert(file.to_string(), error.to_string());
255 self.update_timestamp();
256 }
257
258 pub fn set_phase(&mut self, phase: CheckpointPhase) {
260 self.phase = phase;
261 self.update_timestamp();
262 }
263
264 pub fn is_chunking_complete(&self) -> bool {
266 self.remaining_files.is_empty()
267 && (self.phase == CheckpointPhase::Chunking
268 || self.phase == CheckpointPhase::CallGraph
269 || self.phase == CheckpointPhase::Hierarchy
270 || self.phase == CheckpointPhase::Sorting
271 || self.phase == CheckpointPhase::Complete)
272 }
273
274 pub fn progress_percent(&self) -> u32 {
276 let total = self.processed_files.len() + self.remaining_files.len() + self.failed_files.len();
277 if total == 0 {
278 return 0;
279 }
280
281 let processed = self.processed_files.len() + self.failed_files.len();
282 ((processed * 100) / total) as u32
283 }
284
285 pub fn files_processed(&self) -> usize {
287 self.processed_files.len()
288 }
289
290 pub fn files_remaining(&self) -> usize {
292 self.remaining_files.len()
293 }
294
295 pub fn files_failed(&self) -> usize {
297 self.failed_files.len()
298 }
299
300 pub fn validate(&self, repo_path: &Path, settings: &EmbedSettings) -> Result<(), CheckpointError> {
302 if self.version > CHECKPOINT_VERSION {
304 return Err(CheckpointError::VersionMismatch {
305 checkpoint_version: self.version,
306 current_version: CHECKPOINT_VERSION,
307 });
308 }
309
310 let current_path = repo_path.to_string_lossy().to_string();
312 if self.repo_path != current_path {
313 return Err(CheckpointError::RepoMismatch {
314 checkpoint_repo: self.repo_path.clone(),
315 current_repo: current_path,
316 });
317 }
318
319 let current_hash = compute_settings_hash(settings);
321 if self.settings_hash != current_hash {
322 return Err(CheckpointError::SettingsMismatch {
323 checkpoint_hash: self.settings_hash.clone(),
324 current_hash,
325 });
326 }
327
328 Ok(())
329 }
330
331 fn update_timestamp(&mut self) {
333 self.updated_at = SystemTime::now()
334 .duration_since(UNIX_EPOCH)
335 .unwrap_or_default()
336 .as_secs();
337 }
338
339 pub fn compute_integrity(&mut self) {
341 let mut hasher = Hasher::new();
342
343 hasher.update(&self.version.to_le_bytes());
345 hasher.update(&self.created_at.to_le_bytes());
346 hasher.update(self.repo_path.as_bytes());
347 hasher.update(self.settings_hash.as_bytes());
348 hasher.update(&(self.processed_files.len() as u64).to_le_bytes());
349 hasher.update(&(self.total_chunks as u64).to_le_bytes());
350 hasher.update(&self.total_tokens.to_le_bytes());
351
352 for file in &self.processed_files {
354 hasher.update(file.as_bytes());
355 }
356
357 self.integrity_hash = hasher.finalize().to_hex().to_string();
358 }
359
360 pub fn verify_integrity(&self) -> bool {
362 let mut copy = self.clone();
363 copy.compute_integrity();
364 copy.integrity_hash == self.integrity_hash
365 }
366}
367
/// Reasons a checkpoint is rejected.
#[derive(Debug, Clone)]
pub enum CheckpointError {
    /// The checkpoint's stored version is newer than the version this code
    /// writes.
    VersionMismatch {
        /// Version stored in the checkpoint.
        checkpoint_version: u32,
        /// Version supported by the running code.
        current_version: u32,
    },
    /// The checkpoint was written for a different repository path.
    RepoMismatch {
        /// Repository path stored in the checkpoint.
        checkpoint_repo: String,
        /// Repository path of the current run.
        current_repo: String,
    },
    /// The checkpoint was written with different settings.
    SettingsMismatch {
        /// Settings hash stored in the checkpoint.
        checkpoint_hash: String,
        /// Settings hash of the current run.
        current_hash: String,
    },
    /// Integrity verification failed.
    IntegrityFailed,
    /// The checkpoint is corrupted; the payload is a free-form reason.
    Corrupted(String),
}
391
392impl std::fmt::Display for CheckpointError {
393 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394 match self {
395 Self::VersionMismatch { checkpoint_version, current_version } => {
396 write!(
397 f,
398 "Checkpoint version {} is newer than current version {}",
399 checkpoint_version, current_version
400 )
401 }
402 Self::RepoMismatch { checkpoint_repo, current_repo } => {
403 write!(
404 f,
405 "Checkpoint repo '{}' doesn't match current repo '{}'",
406 checkpoint_repo, current_repo
407 )
408 }
409 Self::SettingsMismatch { .. } => {
410 write!(f, "Checkpoint settings don't match current settings")
411 }
412 Self::IntegrityFailed => {
413 write!(f, "Checkpoint integrity verification failed")
414 }
415 Self::Corrupted(reason) => {
416 write!(f, "Checkpoint corrupted: {}", reason)
417 }
418 }
419 }
420}
421
422impl std::error::Error for CheckpointError {}
423
/// Loads, saves and deletes a checkpoint stored at a single file path.
pub struct CheckpointManager {
    /// Location of the checkpoint file on disk.
    path: PathBuf,
}
429
430impl CheckpointManager {
431 pub fn new(path: impl Into<PathBuf>) -> Self {
433 Self { path: path.into() }
434 }
435
436 pub fn path(&self) -> &Path {
438 &self.path
439 }
440
441 pub fn exists(&self) -> bool {
443 self.path.exists()
444 }
445
446 pub fn load(&self) -> Result<Option<EmbedCheckpoint>, EmbedError> {
448 if !self.path.exists() {
449 return Ok(None);
450 }
451
452 let bytes = std::fs::read(&self.path).map_err(|e| EmbedError::IoError {
453 path: self.path.clone(),
454 source: e,
455 })?;
456
457 let checkpoint: EmbedCheckpoint =
458 deserialize_with_limit(&bytes).map_err(|e| EmbedError::DeserializationError {
459 reason: format!("Failed to deserialize checkpoint: {}", e),
460 })?;
461
462 if !checkpoint.verify_integrity() {
464 return Err(EmbedError::ManifestCorrupted {
465 path: self.path.clone(),
466 expected: checkpoint.integrity_hash.clone(),
467 actual: "integrity check failed".to_string(),
468 });
469 }
470
471 Ok(Some(checkpoint))
472 }
473
474 pub fn save(&self, checkpoint: &mut EmbedCheckpoint) -> Result<(), EmbedError> {
476 checkpoint.compute_integrity();
478
479 let bytes = bincode::options()
480 .serialize(checkpoint)
481 .map_err(|e| EmbedError::SerializationError {
482 reason: format!("Failed to serialize checkpoint: {}", e),
483 })?;
484
485 let temp_path = self.path.with_extension("tmp");
487
488 std::fs::write(&temp_path, &bytes).map_err(|e| EmbedError::IoError {
489 path: temp_path.clone(),
490 source: e,
491 })?;
492
493 std::fs::rename(&temp_path, &self.path).map_err(|e| EmbedError::IoError {
494 path: self.path.clone(),
495 source: e,
496 })?;
497
498 Ok(())
499 }
500
501 pub fn delete(&self) -> Result<(), EmbedError> {
503 if self.path.exists() {
504 std::fs::remove_file(&self.path).map_err(|e| EmbedError::IoError {
505 path: self.path.clone(),
506 source: e,
507 })?;
508 }
509 Ok(())
510 }
511
512 pub fn load_validated(
514 &self,
515 repo_path: &Path,
516 settings: &EmbedSettings,
517 ) -> Result<Option<EmbedCheckpoint>, EmbedError> {
518 let checkpoint = match self.load()? {
519 Some(cp) => cp,
520 None => return Ok(None),
521 };
522
523 match checkpoint.validate(repo_path, settings) {
525 Ok(()) => Ok(Some(checkpoint)),
526 Err(CheckpointError::SettingsMismatch { .. }) => {
527 Ok(None)
529 }
530 Err(CheckpointError::RepoMismatch { .. }) => {
531 Ok(None)
533 }
534 Err(e) => Err(EmbedError::DeserializationError {
535 reason: e.to_string(),
536 }),
537 }
538 }
539}
540
541fn compute_settings_hash(settings: &EmbedSettings) -> String {
543 let mut hasher = Hasher::new();
544
545 hasher.update(&settings.max_tokens.to_le_bytes());
547 hasher.update(&settings.min_tokens.to_le_bytes());
548 hasher.update(&settings.overlap_tokens.to_le_bytes());
549 hasher.update(&settings.context_lines.to_le_bytes());
550 hasher.update(settings.token_model.as_bytes());
551 hasher.update(&[settings.include_imports as u8]);
552 hasher.update(&[settings.include_tests as u8]);
553 hasher.update(&[settings.include_top_level as u8]);
554 hasher.update(&[settings.scan_secrets as u8]);
555 hasher.update(&[settings.redact_secrets as u8]);
556 hasher.update(&[settings.fail_on_secrets as u8]);
557 hasher.update(&[settings.enable_hierarchy as u8]);
558 hasher.update(&settings.hierarchy_min_children.to_le_bytes());
559
560 for pattern in &settings.include_patterns {
562 hasher.update(pattern.as_bytes());
563 }
564 for pattern in &settings.exclude_patterns {
565 hasher.update(pattern.as_bytes());
566 }
567
568 hasher.finalize().to_hex().to_string()
569}
570
571#[derive(Debug, Clone)]
573pub struct CheckpointStats {
574 pub created_at: u64,
576 pub updated_at: u64,
578 pub phase: CheckpointPhase,
580 pub files_processed: usize,
582 pub files_remaining: usize,
584 pub files_failed: usize,
586 pub total_chunks: usize,
588 pub total_tokens: u64,
590 pub progress_percent: u32,
592}
593
594impl From<&EmbedCheckpoint> for CheckpointStats {
595 fn from(cp: &EmbedCheckpoint) -> Self {
596 Self {
597 created_at: cp.created_at,
598 updated_at: cp.updated_at,
599 phase: cp.phase,
600 files_processed: cp.files_processed(),
601 files_remaining: cp.files_remaining(),
602 files_failed: cp.files_failed(),
603 total_chunks: cp.total_chunks,
604 total_tokens: cp.total_tokens,
605 progress_percent: cp.progress_percent(),
606 }
607 }
608}
609
#[cfg(test)]
mod tests {
    use super::*;

    /// Repository path shared by the fixtures below.
    const REPO: &str = "/test/repo";

    /// Default settings used by every test.
    fn test_settings() -> EmbedSettings {
        EmbedSettings::default()
    }

    /// Fresh checkpoint for [`REPO`] with a default repo identifier.
    fn fresh_checkpoint(settings: &EmbedSettings) -> EmbedCheckpoint {
        EmbedCheckpoint::new(Path::new(REPO), RepoIdentifier::default(), settings)
    }

    /// Turns string literals into the owned file list `set_files` expects.
    fn file_list(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    // A new checkpoint starts empty, in the Discovery phase.
    #[test]
    fn test_checkpoint_creation() {
        let settings = test_settings();
        let cp = fresh_checkpoint(&settings);

        assert_eq!(cp.version, CHECKPOINT_VERSION);
        assert_eq!(cp.phase, CheckpointPhase::Discovery);
        assert!(cp.processed_files.is_empty());
        assert!(cp.remaining_files.is_empty());
        assert_eq!(cp.total_chunks, 0);
    }

    // Files move from "remaining" to "processed"/"failed" and progress follows.
    #[test]
    fn test_checkpoint_file_tracking() {
        let settings = test_settings();
        let mut cp = fresh_checkpoint(&settings);

        cp.set_files(file_list(&["a.rs", "b.rs", "c.rs"]));

        assert_eq!(cp.files_remaining(), 3);
        assert_eq!(cp.files_processed(), 0);
        assert_eq!(cp.phase, CheckpointPhase::Chunking);

        cp.mark_file_processed("a.rs", &[]);
        assert_eq!(cp.files_remaining(), 2);
        assert_eq!(cp.files_processed(), 1);

        cp.mark_file_failed("b.rs", "Parse error");
        assert_eq!(cp.files_remaining(), 1);
        assert_eq!(cp.files_processed(), 1);
        assert_eq!(cp.files_failed(), 1);

        // Two of three files are done (processed or failed), rounded down.
        assert_eq!(cp.progress_percent(), 66);
    }

    // The integrity hash detects a change to a covered field.
    #[test]
    fn test_checkpoint_integrity() {
        let settings = test_settings();
        let mut cp = fresh_checkpoint(&settings);
        cp.set_files(file_list(&["test.rs"]));
        cp.mark_file_processed("test.rs", &[]);

        cp.compute_integrity();
        assert!(!cp.integrity_hash.is_empty());
        assert!(cp.verify_integrity());

        cp.total_chunks = 999;
        assert!(!cp.verify_integrity());
    }

    // Validation accepts the original repo/settings and rejects others.
    #[test]
    fn test_checkpoint_validation() {
        let settings = test_settings();
        let cp = fresh_checkpoint(&settings);

        assert!(cp.validate(Path::new(REPO), &settings).is_ok());

        assert!(cp.validate(Path::new("/other/repo"), &settings).is_err());

        let mut different_settings = settings.clone();
        different_settings.max_tokens = 9999;
        assert!(cp.validate(Path::new(REPO), &different_settings).is_err());
    }

    // A saved checkpoint can be loaded back and still verifies.
    #[test]
    fn test_checkpoint_save_load() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let checkpoint_path = temp_dir.path().join("checkpoint.bin");

        let settings = test_settings();
        let mut cp = fresh_checkpoint(&settings);
        cp.set_files(file_list(&["test.rs"]));
        cp.mark_file_processed("test.rs", &[]);

        let manager = CheckpointManager::new(&checkpoint_path);
        manager.save(&mut cp).unwrap();

        let loaded = manager.load().unwrap().unwrap();
        assert_eq!(loaded.files_processed(), 1);
        assert!(loaded.verify_integrity());
    }

    // load_validated returns the checkpoint for matching settings, None otherwise.
    #[test]
    fn test_checkpoint_manager_validated() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let checkpoint_path = temp_dir.path().join("checkpoint.bin");
        let repo_path = Path::new(REPO);

        let settings = test_settings();
        let mut cp = EmbedCheckpoint::new(repo_path, RepoIdentifier::default(), &settings);
        cp.set_files(file_list(&["test.rs"]));

        let manager = CheckpointManager::new(&checkpoint_path);
        manager.save(&mut cp).unwrap();

        let matching = manager.load_validated(repo_path, &settings).unwrap();
        assert!(matching.is_some());

        let mut different = settings.clone();
        different.max_tokens = 9999;
        let mismatching = manager.load_validated(repo_path, &different).unwrap();
        assert!(mismatching.is_none());
    }

    // set_phase stores exactly the phase it is given.
    #[test]
    fn test_checkpoint_phases() {
        let settings = test_settings();
        let mut cp = fresh_checkpoint(&settings);

        assert_eq!(cp.phase, CheckpointPhase::Discovery);

        cp.set_phase(CheckpointPhase::Chunking);
        assert_eq!(cp.phase, CheckpointPhase::Chunking);

        cp.set_phase(CheckpointPhase::CallGraph);
        assert_eq!(cp.phase, CheckpointPhase::CallGraph);

        cp.set_phase(CheckpointPhase::Complete);
        assert_eq!(cp.phase, CheckpointPhase::Complete);
    }

    // Stats mirror the checkpoint's counters.
    #[test]
    fn test_checkpoint_stats() {
        let settings = test_settings();
        let mut cp = fresh_checkpoint(&settings);
        cp.set_files(file_list(&["a.rs", "b.rs"]));
        cp.mark_file_processed("a.rs", &[]);

        let stats = CheckpointStats::from(&cp);
        assert_eq!(stats.files_processed, 1);
        assert_eq!(stats.files_remaining, 1);
        assert_eq!(stats.progress_percent, 50);
    }

    // The settings hash is stable and reacts to a changed setting.
    #[test]
    fn test_settings_hash_determinism() {
        let settings = test_settings();
        let hash1 = compute_settings_hash(&settings);
        let hash2 = compute_settings_hash(&settings);
        assert_eq!(hash1, hash2);

        let mut different = settings.clone();
        different.max_tokens = 2000;
        let hash3 = compute_settings_hash(&different);
        assert_ne!(hash1, hash3);
    }

    // delete removes a previously saved checkpoint file.
    #[test]
    fn test_checkpoint_delete() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let checkpoint_path = temp_dir.path().join("checkpoint.bin");

        let settings = test_settings();
        let mut cp = fresh_checkpoint(&settings);

        let manager = CheckpointManager::new(&checkpoint_path);
        manager.save(&mut cp).unwrap();
        assert!(manager.exists());

        manager.delete().unwrap();
        assert!(!manager.exists());
    }

    // Chunking is complete only after files were set and all of them handled.
    #[test]
    fn test_is_chunking_complete() {
        let settings = test_settings();
        let mut cp = fresh_checkpoint(&settings);

        assert!(!cp.is_chunking_complete());

        cp.set_files(file_list(&["a.rs"]));
        assert!(!cp.is_chunking_complete());

        cp.mark_file_processed("a.rs", &[]);
        assert!(cp.is_chunking_complete());
    }

    // ChunkReference copies id, token count, line pair and symbol from a chunk.
    #[test]
    fn test_chunk_reference_from_embed_chunk() {
        use super::super::types::{ChunkContext, ChunkKind, ChunkSource};

        let source = ChunkSource {
            repo: RepoIdentifier::default(),
            file: "test.rs".to_string(),
            lines: (1, 5),
            symbol: "test".to_string(),
            fqn: None,
            language: "Rust".to_string(),
            parent: None,
            visibility: super::super::types::Visibility::Public,
            is_test: false,
        };

        let chunk = EmbedChunk {
            id: "ec_abc123".to_string(),
            full_hash: "deadbeef".repeat(8),
            content: "fn test() {}".to_string(),
            tokens: 10,
            kind: ChunkKind::Function,
            source,
            context: ChunkContext::default(),
            part: None,
        };

        let reference = ChunkReference::from(&chunk);
        assert_eq!(reference.id, "ec_abc123");
        assert_eq!(reference.tokens, 10);
        assert_eq!(reference.lines, (1, 5));
        assert_eq!(reference.symbol, "test");
    }
}