use super::routing::ShardKey;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::fs::{self, File};
use std::io::{BufReader, BufWriter};
use std::path::{Path, PathBuf};
use std::time::SystemTime;
use thiserror::Error;
/// On-disk format version stamped by `GlobalCheckpoint::new` and required by
/// `GlobalCheckpoint::load`; a file carrying any other version is rejected
/// with `CheckpointError::VersionMismatch`.
const CHECKPOINT_VERSION: u32 = 1;
/// Errors produced while loading, saving, or validating checkpoints.
#[derive(Error, Debug)]
pub enum CheckpointError {
    /// A filesystem operation (open, create, create_dir_all, rename, ...) failed.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// The checkpoint could not be parsed from, or serialized to, JSON.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),
    /// The file's `version` field differs from `CHECKPOINT_VERSION`.
    #[error("Checkpoint version mismatch: expected {expected}, found {found}")]
    VersionMismatch {
        /// The version this build understands (`CHECKPOINT_VERSION`).
        expected: u32,
        /// The version recorded in the file.
        found: u32,
    },
    /// Generic integrity failure with a human-readable description.
    // NOTE(review): not constructed anywhere in this module as seen — confirm callers.
    #[error("Checkpoint integrity error: {0}")]
    Integrity(String),
    /// A specific shard's recorded state is inconsistent.
    // NOTE(review): not constructed anywhere in this module as seen — confirm callers.
    #[error("Shard {shard_key} state inconsistency: {message}")]
    ShardInconsistency {
        /// The offending shard's key (presumably its `ShardKey::to_string()` form — confirm).
        shard_key: String,
        /// Description of the inconsistency.
        message: String,
    },
}
/// Convenience alias for results whose error type is [`CheckpointError`].
pub type CheckpointResult<T> = Result<T, CheckpointError>;
/// Lifecycle state of an import, persisted inside `GlobalCheckpoint`.
///
/// `GlobalCheckpoint::new` starts in `NotStarted`; `start_import` sets
/// `InProgress`; `complete_import` / `fail_import` set `Completed` / `Failed`
/// from any state; `detect_recovery_needed` turns `InProgress` into
/// `RequiresRecovery`; `resume_import` turns `RequiresRecovery` back into
/// `InProgress`. All timestamps are Unix time in seconds.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum ImportState {
    /// No import has been started for this checkpoint.
    NotStarted,
    /// An import is running.
    InProgress {
        /// When the import (or its most recent resumption) began.
        started_at: u64,
        /// Current sub-phase of the import.
        phase: ImportPhase,
    },
    /// The import finished successfully.
    Completed {
        /// When `complete_import` was called.
        completed_at: u64,
        /// Total n-gram count reported by the caller of `complete_import`.
        total_ngrams: u64,
        /// Distinct n-gram count reported by the caller of `complete_import`.
        unique_ngrams: u64,
    },
    /// The import was aborted with an error.
    Failed {
        /// When `fail_import` was called.
        failed_at: u64,
        /// Error message passed to `fail_import`.
        error: String,
    },
    /// A previously in-progress import was found on load and must be recovered.
    RequiresRecovery {
        /// The checkpoint's `last_updated` value at the moment recovery was detected.
        last_checkpoint_at: u64,
        /// Keys into the checkpoint's `shards` map whose `current_prefix` was set
        /// (i.e. had work in flight) when recovery was detected.
        in_progress_shards: Vec<String>,
    },
}
/// Sub-phase of an `ImportState::InProgress` import, changed via
/// `GlobalCheckpoint::set_phase`. `start_import` and `resume_import` both
/// (re)start at `Importing`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum ImportPhase {
    /// Ingesting raw n-grams into shards.
    Importing,
    // NOTE(review): "Mkn" presumably means Modified Kneser-Ney smoothing — confirm.
    /// Computing MKN statistics.
    ComputingMkn,
    /// Merging shard outputs.
    Merging,
    /// Final wrap-up before the import is marked complete.
    Finalizing,
}
/// Persisted progress for a single shard of the import.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ShardCheckpointRecord {
    /// Shard prefix copied from the originating `ShardKey`.
    pub prefix: String,
    /// Optional order copied from the originating `ShardKey`; `None` for keys
    /// built with `ShardKey::new`.
    pub order: Option<u8>,
    /// Location of the shard's backing file, as supplied at creation.
    pub path: PathBuf,
    /// Entry count last reported through `GlobalCheckpoint::update_shard`.
    pub entry_count: u64,
    /// Prefixes already fully processed for this shard.
    pub completed_prefixes: HashSet<String>,
    /// Prefix currently being processed, if any; `Some` means work is in flight.
    pub current_prefix: Option<String>,
    /// N-gram count last reported through `GlobalCheckpoint::update_shard`.
    pub ngrams_processed: u64,
    // NOTE(review): never updated in this module as seen; presumably a log
    // sequence number maintained elsewhere — confirm.
    /// Last log sequence number recorded for this shard.
    pub last_lsn: u64,
    /// Unix time (seconds) of creation or of the last `update_shard` call.
    pub last_checkpoint_time: u64,
}
impl ShardCheckpointRecord {
    /// Creates a fresh record for the shard identified by `key`, backed by the
    /// file at `path`.
    ///
    /// Prefix and order are copied from `key`; every counter starts at zero,
    /// no prefix is marked complete or in flight, and `last_checkpoint_time`
    /// is stamped with the current Unix time in seconds.
    pub fn new(key: &ShardKey, path: impl Into<PathBuf>) -> Self {
        let shard_path: PathBuf = path.into();
        Self {
            prefix: key.prefix.clone(),
            order: key.order,
            path: shard_path,
            entry_count: 0,
            ngrams_processed: 0,
            last_lsn: 0,
            completed_prefixes: HashSet::new(),
            current_prefix: None,
            last_checkpoint_time: current_timestamp(),
        }
    }

    /// Rebuilds the [`ShardKey`] this record describes.
    ///
    /// Records without an `order` map back to `ShardKey::new`; otherwise
    /// `ShardKey::with_order` is used.
    pub fn to_shard_key(&self) -> ShardKey {
        match self.order {
            Some(order) => ShardKey::with_order(&self.prefix, order),
            None => ShardKey::new(&self.prefix),
        }
    }

    /// Returns `true` while a prefix of this shard is being processed, i.e.
    /// whenever `current_prefix` is set.
    pub fn is_in_progress(&self) -> bool {
        matches!(self.current_prefix, Some(_))
    }
}
/// JSON-persisted record of an import's overall progress.
///
/// Holds the import lifecycle state, one [`ShardCheckpointRecord`] per shard,
/// and free-form string metadata.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GlobalCheckpoint {
    /// On-disk format version; must equal `CHECKPOINT_VERSION` when loading.
    version: u32,
    /// Unix time (seconds) at which this checkpoint was first created.
    created_at: u64,
    /// Unix time (seconds) of the most recent `touch` call.
    last_updated: u64,
    /// Current lifecycle state of the import.
    pub import_state: ImportState,
    /// Per-shard progress, keyed by the shard key's `to_string()` form.
    pub shards: HashMap<String, ShardCheckpointRecord>,
    /// Arbitrary caller-defined key/value pairs (e.g. `"language"`).
    pub metadata: HashMap<String, String>,
}
impl Default for GlobalCheckpoint {
fn default() -> Self {
Self::new()
}
}
impl GlobalCheckpoint {
    /// Creates an empty checkpoint in the [`ImportState::NotStarted`] state.
    ///
    /// `created_at` and `last_updated` are both set to the current Unix time in
    /// seconds, and the shard and metadata maps start empty.
    pub fn new() -> Self {
        let now = current_timestamp();
        Self {
            version: CHECKPOINT_VERSION,
            created_at: now,
            last_updated: now,
            import_state: ImportState::NotStarted,
            shards: HashMap::new(),
            metadata: HashMap::new(),
        }
    }

    /// Loads the checkpoint at `path`, or returns [`GlobalCheckpoint::new`] if
    /// no file exists there. Nothing is written to disk.
    ///
    /// # Errors
    ///
    /// Every failure other than "file not found" is propagated (including
    /// permission errors), so an unreadable checkpoint is never silently
    /// replaced by an empty one. See [`GlobalCheckpoint::load`].
    pub fn load_or_create(path: impl AsRef<Path>) -> CheckpointResult<Self> {
        // Try the open directly instead of probing with `Path::exists`, which
        // reports `false` for *any* metadata error and races with the open.
        match Self::load(path) {
            Err(CheckpointError::Io(err)) if err.kind() == std::io::ErrorKind::NotFound => {
                Ok(Self::new())
            }
            result => result,
        }
    }

    /// Reads and validates a checkpoint previously written by
    /// [`GlobalCheckpoint::save`].
    ///
    /// # Errors
    ///
    /// - [`CheckpointError::Io`] if the file cannot be opened.
    /// - [`CheckpointError::VersionMismatch`] if the file's `version` field is
    ///   an integer other than `CHECKPOINT_VERSION`.
    /// - [`CheckpointError::Json`] if the file is not valid JSON or does not
    ///   match the expected schema (including a missing/malformed `version`).
    pub fn load(path: impl AsRef<Path>) -> CheckpointResult<Self> {
        let file = File::open(path.as_ref())?;
        // Peek at `version` on an untyped value first: a file written by another
        // format version may not fit this struct's schema at all, and must be
        // reported as a version mismatch rather than an opaque parse error.
        let raw: serde_json::Value = serde_json::from_reader(BufReader::new(file))?;
        let declared_version = raw
            .get("version")
            .and_then(serde_json::Value::as_u64)
            .and_then(|v| u32::try_from(v).ok());
        if let Some(found) = declared_version.filter(|&v| v != CHECKPOINT_VERSION) {
            return Err(CheckpointError::VersionMismatch {
                expected: CHECKPOINT_VERSION,
                found,
            });
        }
        // A missing or out-of-range `version` is rejected by the typed parse.
        let checkpoint: Self = serde_json::from_value(raw)?;
        Ok(checkpoint)
    }

    /// Atomically writes the checkpoint to `path` as pretty-printed JSON.
    ///
    /// Data goes to a sibling `*.json.tmp` file first, which is flushed and
    /// synced to disk before being renamed over `path`, so a crash or write
    /// failure never replaces the previous checkpoint with a partial one.
    /// Missing parent directories are created.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors (directory creation, write, flush, sync, rename)
    /// and JSON serialization errors.
    pub fn save(&self, path: impl AsRef<Path>) -> CheckpointResult<()> {
        let path = path.as_ref();
        let temp_path = path.with_extension("json.tmp");
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent)?;
        }
        {
            let mut writer = BufWriter::new(File::create(&temp_path)?);
            serde_json::to_writer_pretty(&mut writer, self)?;
            // Dropping a `BufWriter` flushes but silently discards errors, which
            // could let a truncated file be renamed over the last good
            // checkpoint. `into_inner` flushes and reports the error instead.
            let file = writer.into_inner().map_err(std::io::Error::from)?;
            // Make the contents durable before the rename publishes them.
            file.sync_all()?;
        }
        fs::rename(&temp_path, path)?;
        Ok(())
    }

    /// Stamps `last_updated` with the current Unix time in seconds.
    pub fn touch(&mut self) {
        self.last_updated = current_timestamp();
    }

    /// Enters [`ImportState::InProgress`] in the `Importing` phase, starting now.
    pub fn start_import(&mut self) {
        self.import_state = ImportState::InProgress {
            started_at: current_timestamp(),
            phase: ImportPhase::Importing,
        };
        self.touch();
    }

    /// Switches the phase of an in-progress import.
    ///
    /// Has no effect (and does not touch `last_updated`) unless the import is
    /// currently [`ImportState::InProgress`]; `started_at` is preserved.
    pub fn set_phase(&mut self, phase: ImportPhase) {
        if let ImportState::InProgress { phase: current, .. } = &mut self.import_state {
            *current = phase;
            self.touch();
        }
    }

    /// Marks the import as [`ImportState::Completed`] with the given totals,
    /// regardless of the current state.
    pub fn complete_import(&mut self, total_ngrams: u64, unique_ngrams: u64) {
        self.import_state = ImportState::Completed {
            completed_at: current_timestamp(),
            total_ngrams,
            unique_ngrams,
        };
        self.touch();
    }

    /// Marks the import as [`ImportState::Failed`] with `error` as the message,
    /// regardless of the current state.
    pub fn fail_import(&mut self, error: impl Into<String>) {
        self.import_state = ImportState::Failed {
            failed_at: current_timestamp(),
            error: error.into(),
        };
        self.touch();
    }

    /// Returns `true` if the state is [`ImportState::RequiresRecovery`].
    pub fn needs_recovery(&self) -> bool {
        matches!(self.import_state, ImportState::RequiresRecovery { .. })
    }

    /// Returns `true` if the state is [`ImportState::InProgress`].
    pub fn is_in_progress(&self) -> bool {
        matches!(self.import_state, ImportState::InProgress { .. })
    }

    /// Returns `true` if the state is [`ImportState::Completed`].
    pub fn is_completed(&self) -> bool {
        matches!(self.import_state, ImportState::Completed { .. })
    }

    /// Returns the record for `key`, creating it with `path` if absent.
    ///
    /// If a record already exists, it is returned unchanged and `path` is
    /// ignored. Does not touch `last_updated`.
    pub fn get_or_create_shard(
        &mut self,
        key: &ShardKey,
        path: impl Into<PathBuf>,
    ) -> &mut ShardCheckpointRecord {
        self.shards
            .entry(key.to_string())
            .or_insert_with(|| ShardCheckpointRecord::new(key, path))
    }

    /// Records the latest entry and n-gram counts for shard `key` and stamps
    /// its `last_checkpoint_time`.
    ///
    /// Unknown shards are ignored, but `last_updated` is refreshed either way.
    pub fn update_shard(&mut self, key: &ShardKey, entry_count: u64, ngrams_processed: u64) {
        if let Some(record) = self.shards.get_mut(&key.to_string()) {
            record.entry_count = entry_count;
            record.ngrams_processed = ngrams_processed;
            record.last_checkpoint_time = current_timestamp();
        }
        self.touch();
    }

    /// Marks `prefix` as completed for shard `key` and clears the shard's
    /// in-flight `current_prefix` (whichever prefix it held).
    ///
    /// Unknown shards are ignored, but `last_updated` is refreshed either way.
    pub fn complete_prefix(&mut self, key: &ShardKey, prefix: &str) {
        if let Some(record) = self.shards.get_mut(&key.to_string()) {
            record.completed_prefixes.insert(prefix.to_string());
            record.current_prefix = None;
        }
        self.touch();
    }

    /// Sets (or, with `None`, clears) the prefix shard `key` is working on.
    ///
    /// Unknown shards are ignored, but `last_updated` is refreshed either way.
    pub fn set_current_prefix(&mut self, key: &ShardKey, prefix: Option<&str>) {
        if let Some(record) = self.shards.get_mut(&key.to_string()) {
            record.current_prefix = prefix.map(String::from);
        }
        self.touch();
    }

    /// Union of completed prefixes across all shards.
    pub fn all_completed_prefixes(&self) -> HashSet<String> {
        self.shards
            .values()
            .flat_map(|r| r.completed_prefixes.iter().cloned())
            .collect()
    }

    /// Union of completed prefixes across shards applicable to `order`.
    ///
    /// Shards recorded without an order are treated as covering every order.
    pub fn completed_prefixes_for_order(&self, order: u8) -> HashSet<String> {
        self.shards
            .values()
            .filter(|r| r.order.is_none() || r.order == Some(order))
            .flat_map(|r| r.completed_prefixes.iter().cloned())
            .collect()
    }

    /// Shards that currently have a prefix in flight (in unspecified order).
    pub fn in_progress_shards(&self) -> Vec<&ShardCheckpointRecord> {
        self.shards
            .values()
            .filter(|r| r.is_in_progress())
            .collect()
    }

    /// If the import is [`ImportState::InProgress`], switches it to
    /// [`ImportState::RequiresRecovery`], recording `last_updated` and the
    /// keys of shards with work in flight. Otherwise does nothing.
    ///
    /// The shard keys are sorted so the persisted state is deterministic.
    /// `last_updated` itself is left untouched.
    pub fn detect_recovery_needed(&mut self) {
        if !self.is_in_progress() {
            return;
        }
        let mut in_progress: Vec<String> = self
            .shards
            .iter()
            .filter(|(_, r)| r.is_in_progress())
            .map(|(k, _)| k.clone())
            .collect();
        // HashMap iteration order is arbitrary; sort for reproducible output.
        in_progress.sort_unstable();
        self.import_state = ImportState::RequiresRecovery {
            last_checkpoint_at: self.last_updated,
            in_progress_shards: in_progress,
        };
    }

    /// Moves a [`ImportState::RequiresRecovery`] import back to
    /// [`ImportState::InProgress`]. Otherwise does nothing.
    ///
    /// `RequiresRecovery` retains neither the original start time nor the
    /// phase, so `started_at` is reset to now and the phase to `Importing`.
    pub fn resume_import(&mut self) {
        if self.needs_recovery() {
            self.import_state = ImportState::InProgress {
                started_at: current_timestamp(),
                phase: ImportPhase::Importing,
            };
            self.touch();
        }
    }

    /// Sum of `ngrams_processed` over all shards.
    pub fn total_ngrams(&self) -> u64 {
        self.shards.values().map(|r| r.ngrams_processed).sum()
    }

    /// Sum of `entry_count` over all shards.
    pub fn total_entries(&self) -> u64 {
        self.shards.values().map(|r| r.entry_count).sum()
    }

    /// Inserts or overwrites a metadata entry and refreshes `last_updated`.
    pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.metadata.insert(key.into(), value.into());
        self.touch();
    }

    /// Looks up a metadata entry by key.
    pub fn get_metadata(&self, key: &str) -> Option<&String> {
        self.metadata.get(key)
    }

    /// Builds a lightweight [`CheckpointSummary`] of the current state.
    ///
    /// `state` is the `Debug` rendering of the import state, and
    /// `completed_prefixes` counts distinct prefixes across all shards.
    pub fn summary(&self) -> CheckpointSummary {
        CheckpointSummary {
            state: format!("{:?}", self.import_state),
            shard_count: self.shards.len(),
            total_entries: self.total_entries(),
            total_ngrams: self.total_ngrams(),
            completed_prefixes: self.all_completed_prefixes().len(),
            last_updated: self.last_updated,
        }
    }
}
/// Point-in-time overview of a checkpoint, produced by `GlobalCheckpoint::summary`.
#[derive(Clone, Debug)]
pub struct CheckpointSummary {
    /// `Debug` rendering of the checkpoint's `ImportState`.
    pub state: String,
    /// Number of shard records tracked.
    pub shard_count: usize,
    /// Sum of `entry_count` across all shards.
    pub total_entries: u64,
    /// Sum of `ngrams_processed` across all shards.
    pub total_ngrams: u64,
    /// Number of distinct completed prefixes across all shards.
    pub completed_prefixes: usize,
    /// The checkpoint's `last_updated` Unix timestamp (seconds).
    pub last_updated: u64,
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch, rather
/// than failing.
fn current_timestamp() -> u64 {
    match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
/// Owns a [`GlobalCheckpoint`] together with its file path and throttles how
/// often it is written back to disk.
pub struct CheckpointManager {
    /// File the checkpoint is loaded from and saved to.
    checkpoint_path: PathBuf,
    /// In-memory checkpoint state.
    checkpoint: GlobalCheckpoint,
    /// Minimum time between saves performed by `maybe_save`, in milliseconds.
    auto_save_interval_ms: u64,
    /// When the last successful save happened (initially, construction time).
    last_save_time: std::time::Instant,
}
impl CheckpointManager {
    /// Opens (or starts) the checkpoint stored at `checkpoint_path`.
    ///
    /// The checkpoint is read with `GlobalCheckpoint::load_or_create`, so a
    /// missing file yields a fresh in-memory checkpoint; nothing is written
    /// until [`save`](Self::save) or [`maybe_save`](Self::maybe_save) runs.
    /// The auto-save timer starts at construction time.
    ///
    /// # Errors
    ///
    /// Propagates any error from loading an existing checkpoint file.
    pub fn new(
        checkpoint_path: impl Into<PathBuf>,
        auto_save_interval_ms: u64,
    ) -> CheckpointResult<Self> {
        let checkpoint_path = checkpoint_path.into();
        let checkpoint = GlobalCheckpoint::load_or_create(&checkpoint_path)?;
        Ok(Self {
            checkpoint_path,
            checkpoint,
            auto_save_interval_ms,
            last_save_time: std::time::Instant::now(),
        })
    }

    /// Read-only access to the managed checkpoint.
    pub fn checkpoint(&self) -> &GlobalCheckpoint {
        &self.checkpoint
    }

    /// Mutable access to the managed checkpoint; changes reach disk on the
    /// next save.
    pub fn checkpoint_mut(&mut self) -> &mut GlobalCheckpoint {
        &mut self.checkpoint
    }

    /// Writes the checkpoint to its path now and restarts the auto-save timer.
    ///
    /// # Errors
    ///
    /// Propagates errors from `GlobalCheckpoint::save`; the timer is reset
    /// only when the save succeeds.
    pub fn save(&mut self) -> CheckpointResult<()> {
        self.checkpoint.save(&self.checkpoint_path)?;
        self.last_save_time = std::time::Instant::now();
        Ok(())
    }

    /// Saves only if at least `auto_save_interval_ms` milliseconds have elapsed
    /// since the last successful save (or since construction).
    ///
    /// Returns `Ok(true)` if a save was performed and `Ok(false)` otherwise.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`save`](Self::save).
    pub fn maybe_save(&mut self) -> CheckpointResult<bool> {
        // Compare `Duration`s directly instead of narrowing `as_millis()`
        // (a `u128`) to `u64` with a lossy `as` cast.
        let interval = std::time::Duration::from_millis(self.auto_save_interval_ms);
        if self.last_save_time.elapsed() < interval {
            return Ok(false);
        }
        self.save()?;
        Ok(true)
    }

    /// Returns `true` if the checkpoint is in `ImportState::RequiresRecovery`.
    pub fn needs_recovery(&self) -> bool {
        self.checkpoint.needs_recovery()
    }

    /// Flags an in-progress import as requiring recovery; see
    /// `GlobalCheckpoint::detect_recovery_needed`.
    pub fn detect_recovery(&mut self) {
        self.checkpoint.detect_recovery_needed();
    }

    /// Resumes an import that requires recovery; see
    /// `GlobalCheckpoint::resume_import`.
    pub fn resume(&mut self) {
        self.checkpoint.resume_import();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Saving and reloading preserves import state and metadata.
    #[test]
    fn test_checkpoint_create_and_save() {
        let dir = TempDir::new().expect("Failed to create temp dir");
        let path = dir.path().join("checkpoint.json");
        let mut checkpoint = GlobalCheckpoint::new();
        checkpoint.start_import();
        checkpoint.set_metadata("language", "en");
        checkpoint.save(&path).expect("Failed to save");
        let loaded = GlobalCheckpoint::load(&path).expect("Failed to load");
        assert!(loaded.is_in_progress());
        assert_eq!(loaded.get_metadata("language"), Some(&"en".to_string()));
    }

    /// Setting and completing a prefix moves a shard in and out of "in progress".
    #[test]
    fn test_checkpoint_shard_tracking() {
        let mut checkpoint = GlobalCheckpoint::new();
        checkpoint.start_import();
        let key = ShardKey::new("th");
        checkpoint.get_or_create_shard(&key, "/tmp/shard_th.artrie");
        checkpoint.set_current_prefix(&key, Some("th"));
        assert_eq!(checkpoint.in_progress_shards().len(), 1);
        checkpoint.complete_prefix(&key, "th");
        assert!(checkpoint.all_completed_prefixes().contains("th"));
        assert!(checkpoint.in_progress_shards().is_empty());
    }

    /// An in-progress import with work in flight is flagged for recovery and
    /// can then be resumed.
    #[test]
    fn test_recovery_detection() {
        let mut checkpoint = GlobalCheckpoint::new();
        checkpoint.start_import();
        let key = ShardKey::new("th");
        checkpoint.get_or_create_shard(&key, "/tmp/shard_th.artrie");
        checkpoint.set_current_prefix(&key, Some("th"));
        checkpoint.detect_recovery_needed();
        assert!(checkpoint.needs_recovery());
        if let ImportState::RequiresRecovery {
            in_progress_shards, ..
        } = &checkpoint.import_state
        {
            assert_eq!(in_progress_shards.len(), 1);
            assert_eq!(in_progress_shards[0], "th");
        } else {
            panic!("Expected RequiresRecovery state");
        }
        checkpoint.resume_import();
        assert!(checkpoint.is_in_progress());
    }

    /// Completion records the supplied totals.
    #[test]
    fn test_checkpoint_completion() {
        let mut checkpoint = GlobalCheckpoint::new();
        checkpoint.start_import();
        checkpoint.complete_import(1000000, 500000);
        assert!(checkpoint.is_completed());
        if let ImportState::Completed {
            total_ngrams,
            unique_ngrams,
            ..
        } = &checkpoint.import_state
        {
            assert_eq!(*total_ngrams, 1000000);
            assert_eq!(*unique_ngrams, 500000);
        } else {
            panic!("Expected Completed state");
        }
    }

    /// A successful save leaves no temp file behind.
    #[test]
    fn test_checkpoint_atomic_save() {
        let dir = TempDir::new().expect("Failed to create temp dir");
        let path = dir.path().join("checkpoint.json");
        let temp_path = path.with_extension("json.tmp");
        let mut checkpoint = GlobalCheckpoint::new();
        checkpoint.start_import();
        checkpoint.save(&path).expect("Failed to save");
        assert!(!temp_path.exists());
        assert!(path.exists());
    }

    /// Saving again overwrites the previous file, and shard counters round-trip.
    #[test]
    fn test_save_overwrites_and_roundtrips_shards() {
        let dir = TempDir::new().expect("Failed to create temp dir");
        let path = dir.path().join("checkpoint.json");
        let mut checkpoint = GlobalCheckpoint::new();
        checkpoint.start_import();
        let key = ShardKey::new("th");
        checkpoint.get_or_create_shard(&key, "/tmp/shard_th.artrie");
        checkpoint.update_shard(&key, 10, 42);
        checkpoint.save(&path).expect("Failed to save");
        checkpoint.update_shard(&key, 20, 84);
        checkpoint.save(&path).expect("Failed to save again");
        let loaded = GlobalCheckpoint::load(&path).expect("Failed to load");
        assert_eq!(loaded.shards.len(), 1);
        assert_eq!(loaded.total_entries(), 20);
        assert_eq!(loaded.total_ngrams(), 84);
    }

    /// A file carrying a different format version is rejected as such.
    #[test]
    fn test_load_rejects_version_mismatch() {
        let dir = TempDir::new().expect("Failed to create temp dir");
        let path = dir.path().join("checkpoint.json");
        GlobalCheckpoint::new().save(&path).expect("Failed to save");
        let text = std::fs::read_to_string(&path).expect("Failed to read");
        let mut value: serde_json::Value = serde_json::from_str(&text).expect("Failed to parse");
        value["version"] = serde_json::Value::from(CHECKPOINT_VERSION + 1);
        std::fs::write(&path, value.to_string()).expect("Failed to write");
        match GlobalCheckpoint::load(&path) {
            Err(CheckpointError::VersionMismatch { expected, found }) => {
                assert_eq!(expected, CHECKPOINT_VERSION);
                assert_eq!(found, CHECKPOINT_VERSION + 1);
            }
            other => panic!("Expected VersionMismatch, got {:?}", other),
        }
    }

    /// A missing file yields a fresh checkpoint without creating the file.
    #[test]
    fn test_load_or_create_missing_file_yields_fresh_checkpoint() {
        let dir = TempDir::new().expect("Failed to create temp dir");
        let path = dir.path().join("does_not_exist.json");
        let checkpoint =
            GlobalCheckpoint::load_or_create(&path).expect("Failed to load or create");
        assert_eq!(checkpoint.import_state, ImportState::NotStarted);
        assert!(checkpoint.shards.is_empty());
        assert!(!path.exists());
    }

    /// `set_phase` is ignored unless an import is in progress.
    #[test]
    fn test_set_phase_only_applies_in_progress() {
        let mut checkpoint = GlobalCheckpoint::new();
        checkpoint.set_phase(ImportPhase::Merging);
        assert_eq!(checkpoint.import_state, ImportState::NotStarted);
        checkpoint.start_import();
        checkpoint.set_phase(ImportPhase::Merging);
        assert!(matches!(
            checkpoint.import_state,
            ImportState::InProgress {
                phase: ImportPhase::Merging,
                ..
            }
        ));
    }

    /// Order-less shards count toward every order; other orders are excluded.
    #[test]
    fn test_completed_prefixes_for_order() {
        let mut checkpoint = GlobalCheckpoint::new();
        let bigram = ShardKey::with_order("ab", 2);
        let any_order = ShardKey::new("cd");
        let trigram = ShardKey::with_order("ef", 3);
        checkpoint.get_or_create_shard(&bigram, "/tmp/shard_ab_2.artrie");
        checkpoint.get_or_create_shard(&any_order, "/tmp/shard_cd.artrie");
        checkpoint.get_or_create_shard(&trigram, "/tmp/shard_ef_3.artrie");
        checkpoint.complete_prefix(&bigram, "ab");
        checkpoint.complete_prefix(&any_order, "cd");
        checkpoint.complete_prefix(&trigram, "ef");
        let for_bigrams = checkpoint.completed_prefixes_for_order(2);
        assert!(for_bigrams.contains("ab"));
        assert!(for_bigrams.contains("cd"));
        assert!(!for_bigrams.contains("ef"));
    }

    /// State saved through one manager is visible to a new manager.
    #[test]
    fn test_checkpoint_manager() {
        let dir = TempDir::new().expect("Failed to create temp dir");
        let path = dir.path().join("checkpoint.json");
        let mut manager = CheckpointManager::new(&path, 1000).expect("Failed to create manager");
        manager.checkpoint_mut().start_import();
        manager.save().expect("Failed to save");
        let manager2 = CheckpointManager::new(&path, 1000).expect("Failed to create manager");
        assert!(manager2.checkpoint().is_in_progress());
    }

    /// `maybe_save` saves immediately with a zero interval and not at all with
    /// an effectively infinite one.
    #[test]
    fn test_checkpoint_manager_maybe_save_respects_interval() {
        let dir = TempDir::new().expect("Failed to create temp dir");

        let eager_path = dir.path().join("eager.json");
        let mut eager = CheckpointManager::new(&eager_path, 0).expect("Failed to create manager");
        assert!(eager.maybe_save().expect("maybe_save failed"));
        assert!(eager_path.exists());

        let lazy_path = dir.path().join("lazy.json");
        let mut lazy =
            CheckpointManager::new(&lazy_path, u64::MAX).expect("Failed to create manager");
        assert!(!lazy.maybe_save().expect("maybe_save failed"));
        assert!(!lazy_path.exists());
    }

    /// A shard record round-trips back to an equivalent `ShardKey`.
    #[test]
    fn test_shard_checkpoint_record() {
        let key = ShardKey::with_order("th", 2);
        let record = ShardCheckpointRecord::new(&key, "/tmp/shard.artrie");
        assert_eq!(record.prefix, "th");
        assert_eq!(record.order, Some(2));
        assert!(!record.is_in_progress());
        let restored_key = record.to_shard_key();
        assert_eq!(restored_key.prefix, "th");
        assert_eq!(restored_key.order, Some(2));
    }
}