use std::fs::{self, File};
use std::io::{BufReader, BufWriter};
use std::path::{Path, PathBuf};
use std::time::Instant;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::cli::error::{CliError, CliResult};
/// Format version stored in every checkpoint's `version` field.
///
/// Both `load_ngram_checkpoint` and `load_embedding_checkpoint` reject a checkpoint whose
/// `version` differs from this value.
const CHECKPOINT_VERSION: u32 = 1;
/// Snapshot of an n-gram training run.
///
/// Written by `CheckpointManager::save_ngram_checkpoint` as zstd-compressed bincode and read
/// back by `CheckpointManager::load_ngram_checkpoint`.
/// NOTE(review): bincode encodes struct fields positionally, so adding, removing or
/// reordering fields here most likely requires bumping `CHECKPOINT_VERSION`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NgramCheckpoint {
    /// Format version; loading fails unless it equals `CHECKPOINT_VERSION`.
    pub version: u32,
    /// Configuration of the training run.
    pub config: NgramCheckpointConfig,
    /// Training progress at snapshot time; `state.sentences_processed` also names the file.
    pub state: NgramTrainingState,
    /// Path of the n-gram accumulator file. Presumably the value of
    /// `CheckpointManager::accumulator_path` — confirm with callers.
    pub accumulator_path: PathBuf,
    /// Number of distinct n-grams reported by the caller at snapshot time.
    pub unique_ngrams: usize,
    /// UTC time at which the snapshot was created.
    pub created_at: DateTime<Utc>,
}
/// Training parameters recorded in an `NgramCheckpoint`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NgramCheckpointConfig {
    /// N-gram order (presumably the maximum `n` counted — confirm against the trainer).
    pub order: usize,
    /// Minimum count threshold; presumably rarer n-grams are discarded — confirm.
    pub min_count: u64,
    /// Training corpus location, stored as a `String` rather than a `PathBuf`.
    pub corpus_path: String,
    /// Lowercasing flag; presumably text is lowercased before counting when `true` — confirm.
    pub lowercase: bool,
}
/// Progress counters of an n-gram training run, saved inside an `NgramCheckpoint`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct NgramTrainingState {
    /// Sentences processed so far; also forms the file name `ngram_ckpt_<n>.bin`.
    pub sentences_processed: u64,
    /// Tokens processed so far.
    pub tokens_processed: u64,
    /// Corpus bytes read so far.
    pub bytes_read: u64,
    /// Total corpus size in bytes, if known (presumably `None` when it could not be
    /// determined up front — confirm).
    pub total_bytes: Option<u64>,
    /// Accumulated training time in seconds. Presumably produced by
    /// `TrainingTimer::elapsed_secs` and restored via `TrainingTimer::resume_from` — confirm.
    pub elapsed_secs: f64,
    /// Position in the corpus from which training can resume.
    pub corpus_position: CorpusPosition,
}
/// Resume point within the training corpus.
///
/// NOTE(review): the exact meaning of these fields (which file list `file_index` indexes,
/// whether `line_number` is 0- or 1-based, what `byte_offset` is relative to) is defined by
/// the trainer, not here — confirm before relying on it.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CorpusPosition {
    /// Index of the corpus file currently being read.
    pub file_index: usize,
    /// Line number within that file.
    pub line_number: usize,
    /// Byte offset within that file.
    pub byte_offset: u64,
}
/// Manages training checkpoints stored in a single directory.
///
/// N-gram checkpoints are zstd-compressed bincode files named
/// `ngram_ckpt_<sentences_processed>.bin`. `latest.bin` refers to the most recently saved
/// one, and only the newest `max_checkpoints` of them (never fewer than one) are kept.
pub struct CheckpointManager {
    checkpoint_dir: PathBuf,
    max_checkpoints: usize,
}

impl CheckpointManager {
    /// Creates a manager rooted at `checkpoint_dir`, creating the directory if needed.
    ///
    /// `max_checkpoints` bounds how many n-gram checkpoints `save_ngram_checkpoint` retains;
    /// the newest checkpoint is always kept, even when this is `0`.
    ///
    /// # Errors
    /// Returns an I/O error if the directory cannot be created.
    pub fn new(checkpoint_dir: &Path, max_checkpoints: usize) -> CliResult<Self> {
        fs::create_dir_all(checkpoint_dir).map_err(|e| {
            CliError::io(format!(
                "Failed to create checkpoint directory {}: {}",
                checkpoint_dir.display(),
                e
            ))
        })?;
        Ok(Self {
            checkpoint_dir: checkpoint_dir.to_path_buf(),
            max_checkpoints,
        })
    }

    /// Path of the n-gram accumulator file (`ngram_hot.artrie`) inside the checkpoint directory.
    pub fn accumulator_path(&self) -> PathBuf {
        self.checkpoint_dir.join("ngram_hot.artrie")
    }

    /// Durably writes `checkpoint` and returns the path of the new `.bin` file.
    ///
    /// The data is first written to a `.tmp` file. That file is fully flushed and synced
    /// before it is renamed into place, so a partially written `.bin` is never observed.
    /// Afterwards, `latest.bin` is pointed at the new file (best-effort) and old
    /// checkpoints are pruned.
    ///
    /// # Errors
    /// Returns an I/O error if encoding, writing, syncing, renaming or pruning fails.
    pub fn save_ngram_checkpoint(&self, checkpoint: &NgramCheckpoint) -> CliResult<PathBuf> {
        let name = format!("ngram_ckpt_{}", checkpoint.state.sentences_processed);
        let temp_path = self.checkpoint_dir.join(format!("{}.tmp", name));
        let final_path = self.checkpoint_dir.join(format!("{}.bin", name));
        Self::write_ngram_file(&temp_path, &final_path, checkpoint)?;
        self.update_ngram_latest(&final_path);
        self.prune_old_checkpoints()?;
        Ok(final_path)
    }

    /// Loads an n-gram checkpoint.
    ///
    /// `name` may be one of:
    /// - `"latest"`;
    /// - a path ending in `.bin`, tried as given and then inside the checkpoint directory;
    /// - a bare stem such as `ngram_ckpt_1000`, which is looked up in the checkpoint directory.
    ///
    /// # Errors
    /// - A not-found error if the file does not exist.
    /// - An I/O error if it cannot be read or decoded.
    /// - An unsupported error if its version differs from `CHECKPOINT_VERSION`.
    pub fn load_ngram_checkpoint(&self, name: &str) -> CliResult<NgramCheckpoint> {
        let path = self.resolve_ngram_path(name);
        if !path.exists() {
            return Err(CliError::file_not_found(&path));
        }
        let file = File::open(&path)
            .map_err(|e| CliError::io(format!("Failed to open checkpoint: {}", e)))?;
        let decoder = zstd::Decoder::new(BufReader::new(file))
            .map_err(|e| CliError::io(format!("Failed to create zstd decoder: {}", e)))?;
        let checkpoint: NgramCheckpoint = bincode::deserialize_from(decoder)
            .map_err(|e| CliError::io(format!("Failed to deserialize checkpoint: {}", e)))?;
        if checkpoint.version != CHECKPOINT_VERSION {
            return Err(CliError::unsupported(format!(
                "Checkpoint version {} not supported (expected {})",
                checkpoint.version, CHECKPOINT_VERSION
            )));
        }
        Ok(checkpoint)
    }

    /// Lists the `ngram_ckpt_*.bin` files in the checkpoint directory, newest first.
    ///
    /// "Newest" means the latest modification time. Ties are broken by the sentence count
    /// encoded in the file name.
    ///
    /// # Errors
    /// Returns an I/O error if the directory or a file's metadata cannot be read.
    pub fn list_checkpoints(&self) -> CliResult<Vec<CheckpointInfo>> {
        let entries = fs::read_dir(&self.checkpoint_dir)
            .map_err(|e| CliError::io(format!("Failed to read checkpoint directory: {}", e)))?;
        let mut checkpoints = Vec::new();
        for entry in entries {
            let entry = entry.map_err(|e| CliError::io(format!("Directory read error: {}", e)))?;
            let path = entry.path();
            let is_ngram_ckpt = path.extension().map_or(false, |ext| ext == "bin")
                && path
                    .file_stem()
                    .map_or(false, |s| s.to_string_lossy().starts_with("ngram_ckpt_"));
            if !is_ngram_ckpt {
                continue;
            }
            let metadata = fs::metadata(&path)
                .map_err(|e| CliError::io(format!("Failed to read metadata: {}", e)))?;
            checkpoints.push(CheckpointInfo {
                path,
                size: metadata.len(),
                modified: metadata.modified().ok().map(DateTime::<Utc>::from),
            });
        }
        // Modification times can tie on coarse-grained filesystems. Pruning relies on the
        // newest checkpoint coming first, so fall back to the sentence count in the name.
        checkpoints.sort_by(|a, b| {
            b.modified
                .cmp(&a.modified)
                .then_with(|| Self::ngram_sequence(&b.path).cmp(&Self::ngram_sequence(&a.path)))
        });
        Ok(checkpoints)
    }

    /// Deletes all but the newest `max_checkpoints` n-gram checkpoints (keeping at least one).
    fn prune_old_checkpoints(&self) -> CliResult<()> {
        // Never delete the checkpoint that was just written and that `latest.bin` refers to.
        let keep = self.max_checkpoints.max(1);
        for checkpoint in self.list_checkpoints()?.into_iter().skip(keep) {
            log::debug!("Pruning old checkpoint: {}", checkpoint.path.display());
            if let Err(e) = fs::remove_file(&checkpoint.path) {
                // Best-effort: a stale checkpoint is harmless, so don't fail the save.
                log::warn!(
                    "Failed to prune checkpoint {}: {}",
                    checkpoint.path.display(),
                    e
                );
            }
        }
        Ok(())
    }

    /// Maps a user-supplied checkpoint name to a path (see `load_ngram_checkpoint`).
    fn resolve_ngram_path(&self, name: &str) -> PathBuf {
        if name == "latest" {
            return self.checkpoint_dir.join("latest.bin");
        }
        if !name.ends_with(".bin") {
            return self.checkpoint_dir.join(format!("{}.bin", name));
        }
        let given = PathBuf::from(name);
        if given.exists() {
            return given;
        }
        // File names as listed in the checkpoint directory (e.g. "ngram_ckpt_100.bin") are
        // not relative to the working directory, so also look next to the other checkpoints.
        let in_dir = self.checkpoint_dir.join(name);
        if in_dir.exists() {
            in_dir
        } else {
            given
        }
    }

    /// Encodes `checkpoint` into `temp_path` and atomically renames it to `final_path`.
    ///
    /// On any failure, the partially written temp file is removed.
    fn write_ngram_file(
        temp_path: &Path,
        final_path: &Path,
        checkpoint: &NgramCheckpoint,
    ) -> CliResult<()> {
        let result = Self::encode_ngram_file(temp_path, checkpoint).and_then(|()| {
            fs::rename(temp_path, final_path)
                .map_err(|e| CliError::io(format!("Failed to finalize checkpoint: {}", e)))
        });
        if result.is_err() {
            let _ = fs::remove_file(temp_path);
        }
        result
    }

    /// Writes `checkpoint` to `path` as bincode inside a zstd (level 3) frame.
    ///
    /// The frame is finished, the buffer flushed and the file synced explicitly. This makes
    /// errors surface here instead of being swallowed when the encoder/writer is dropped.
    fn encode_ngram_file(path: &Path, checkpoint: &NgramCheckpoint) -> CliResult<()> {
        let file = File::create(path)
            .map_err(|e| CliError::io(format!("Failed to create checkpoint file: {}", e)))?;
        let mut encoder = zstd::Encoder::new(BufWriter::new(file), 3)
            .map_err(|e| CliError::io(format!("Failed to create zstd encoder: {}", e)))?;
        bincode::serialize_into(&mut encoder, checkpoint)
            .map_err(|e| CliError::io(format!("Failed to serialize checkpoint: {}", e)))?;
        let writer = encoder
            .finish()
            .map_err(|e| CliError::io(format!("Failed to finish zstd stream: {}", e)))?;
        let file = writer.into_inner().map_err(|e| {
            CliError::io(format!("Failed to flush checkpoint file: {}", e.error()))
        })?;
        file.sync_all()
            .map_err(|e| CliError::io(format!("Failed to sync checkpoint file: {}", e)))?;
        Ok(())
    }

    /// Points `latest.bin` at `final_path`.
    ///
    /// Uses a symlink on Unix and a copy elsewhere. This is best-effort: failures are
    /// logged, not returned.
    fn update_ngram_latest(&self, final_path: &Path) {
        let latest = self.checkpoint_dir.join("latest.bin");
        // The old link/copy may not exist yet; that's fine.
        let _ = fs::remove_file(&latest);
        #[cfg(unix)]
        let result = {
            // A relative symlink target resolves against the link's own directory, so link
            // to the bare file name; linking to `final_path` dangles for a relative dir.
            let target = final_path.file_name().map(Path::new).unwrap_or(final_path);
            std::os::unix::fs::symlink(target, &latest)
        };
        #[cfg(not(unix))]
        let result = fs::copy(final_path, &latest).map(|_| ());
        if let Err(e) = result {
            log::warn!("Failed to update {}: {}", latest.display(), e);
        }
    }

    /// Parses `<n>` from an `ngram_ckpt_<n>.bin` path, or `None` if it doesn't match.
    fn ngram_sequence(path: &Path) -> Option<u64> {
        path.file_stem()?
            .to_str()?
            .strip_prefix("ngram_ckpt_")?
            .parse()
            .ok()
    }
}
/// Summary of one n-gram checkpoint file, as returned by `CheckpointManager::list_checkpoints`.
#[derive(Debug)]
pub struct CheckpointInfo {
    /// Path of the `.bin` file (the checkpoint directory joined with the file name).
    pub path: PathBuf,
    /// File size in bytes, taken from filesystem metadata.
    pub size: u64,
    /// Last-modification time converted to UTC, or `None` if it could not be read.
    pub modified: Option<DateTime<Utc>>,
}
/// Snapshot of an embedding training run.
///
/// Written by `CheckpointManager::save_embedding_checkpoint` as zstd-compressed bincode and
/// read back by `CheckpointManager::load_embedding_checkpoint`.
/// NOTE(review): bincode encodes struct fields positionally, so changing the field list
/// most likely requires bumping `CHECKPOINT_VERSION`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingCheckpoint {
    /// Format version; loading fails unless it equals `CHECKPOINT_VERSION`.
    pub version: u32,
    /// Configuration of the training run.
    pub config: EmbeddingCheckpointConfig,
    /// Training progress at snapshot time; `state.completed_epochs` also names the file.
    pub state: EmbeddingTrainingState,
    /// Path of the saved model weights. Presumably produced by
    /// `CheckpointManager::embedding_model_path` — confirm with callers.
    pub model_path: PathBuf,
    /// Vocabulary size reported by the caller at snapshot time.
    pub vocab_size: usize,
    /// UTC time at which the snapshot was created.
    pub created_at: DateTime<Utc>,
}
/// Hyperparameters recorded in an `EmbeddingCheckpoint`.
///
/// NOTE(review): the field names suggest a word2vec-style skip-gram/negative-sampling
/// trainer; the exact semantics are defined by the trainer — confirm there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingCheckpointConfig {
    /// Embedding dimensionality.
    pub dim: usize,
    /// Context window size.
    pub window: usize,
    /// Minimum count threshold for vocabulary entries (presumably — confirm).
    pub min_count: u64,
    /// Number of negative samples per positive example (presumably — confirm).
    pub neg_samples: usize,
    /// Total number of epochs configured for the run.
    pub epochs: u32,
    /// Initial learning rate.
    pub learning_rate: f64,
    /// Training corpus location, stored as a `String` rather than a `PathBuf`.
    pub corpus_path: String,
}
/// Progress of an embedding training run, saved inside an `EmbeddingCheckpoint`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EmbeddingTrainingState {
    /// Epochs fully completed; also forms the file name `embedding_epoch_<n>.bin`.
    pub completed_epochs: u32,
    /// Words processed so far.
    pub words_processed: u64,
    /// Total number of words expected (presumably across all epochs — confirm).
    pub total_words: u64,
    /// Learning rate at snapshot time (presumably after any decay — confirm).
    pub current_learning_rate: f64,
    /// Recorded loss values (presumably one per epoch or reporting interval — confirm).
    pub loss_history: Vec<f64>,
    /// Accumulated training time in seconds (presumably from `TrainingTimer` — confirm).
    pub elapsed_secs: f64,
}
impl CheckpointManager {
pub fn save_embedding_checkpoint(
&self,
checkpoint: &EmbeddingCheckpoint,
) -> CliResult<PathBuf> {
let name = format!("embedding_epoch_{}", checkpoint.state.completed_epochs);
let temp_path = self.checkpoint_dir.join(format!("{}.tmp", name));
let final_path = self.checkpoint_dir.join(format!("{}.bin", name));
let file = File::create(&temp_path)
.map_err(|e| CliError::io(format!("Failed to create checkpoint file: {}", e)))?;
let writer = BufWriter::new(file);
let encoder = zstd::Encoder::new(writer, 3)
.map_err(|e| CliError::io(format!("Failed to create zstd encoder: {}", e)))?
.auto_finish();
bincode::serialize_into(encoder, checkpoint)
.map_err(|e| CliError::io(format!("Failed to serialize checkpoint: {}", e)))?;
fs::rename(&temp_path, &final_path)
.map_err(|e| CliError::io(format!("Failed to finalize checkpoint: {}", e)))?;
let latest = self.checkpoint_dir.join("embedding_latest.bin");
let _ = fs::remove_file(&latest);
#[cfg(unix)]
{
let _ = std::os::unix::fs::symlink(&final_path, &latest);
}
#[cfg(not(unix))]
{
let _ = fs::copy(&final_path, &latest);
}
Ok(final_path)
}
pub fn load_embedding_checkpoint(&self, name: &str) -> CliResult<EmbeddingCheckpoint> {
let path = if name == "latest" {
self.checkpoint_dir.join("embedding_latest.bin")
} else if name.ends_with(".bin") {
PathBuf::from(name)
} else {
self.checkpoint_dir.join(format!("{}.bin", name))
};
if !path.exists() {
return Err(CliError::file_not_found(&path));
}
let file = File::open(&path)
.map_err(|e| CliError::io(format!("Failed to open checkpoint: {}", e)))?;
let reader = BufReader::new(file);
let decoder = zstd::Decoder::new(reader)
.map_err(|e| CliError::io(format!("Failed to create zstd decoder: {}", e)))?;
let checkpoint: EmbeddingCheckpoint = bincode::deserialize_from(decoder)
.map_err(|e| CliError::io(format!("Failed to deserialize checkpoint: {}", e)))?;
if checkpoint.version != CHECKPOINT_VERSION {
return Err(CliError::unsupported(format!(
"Checkpoint version {} not supported (expected {})",
checkpoint.version, CHECKPOINT_VERSION
)));
}
Ok(checkpoint)
}
pub fn embedding_model_path(&self, epoch: u32) -> PathBuf {
self.checkpoint_dir
.join(format!("embedding_model_epoch_{}.bin", epoch))
}
}
/// Wall-clock timer for training runs.
///
/// It can be paused and resumed, and it can be seeded with seconds accumulated by an
/// earlier run.
pub struct TrainingTimer {
    start: Instant,
    elapsed_before_pause: f64,
    paused_at: Option<Instant>,
}

impl TrainingTimer {
    /// Starts a running timer with no prior elapsed time.
    pub fn new() -> Self {
        Self::resume_from(0.0)
    }

    /// Starts a running timer that already counts `elapsed_secs` seconds.
    ///
    /// Typically used with a value restored from a checkpoint.
    pub fn resume_from(elapsed_secs: f64) -> Self {
        let now = Instant::now();
        Self {
            start: now,
            elapsed_before_pause: elapsed_secs,
            paused_at: None,
        }
    }

    /// Total elapsed seconds: the banked time plus the current segment.
    ///
    /// While paused, the current segment ends at the moment of pausing.
    pub fn elapsed_secs(&self) -> f64 {
        let segment_end = match self.paused_at {
            Some(frozen_at) => frozen_at,
            None => Instant::now(),
        };
        let segment = segment_end.duration_since(self.start).as_secs_f64();
        self.elapsed_before_pause + segment
    }

    /// Freezes the timer. Calling it again while already paused has no effect.
    pub fn pause(&mut self) {
        self.paused_at.get_or_insert_with(Instant::now);
    }

    /// Unfreezes the timer by banking the paused segment and starting a new one.
    ///
    /// Does nothing if the timer is not paused.
    pub fn resume(&mut self) {
        match self.paused_at.take() {
            Some(frozen_at) => {
                let banked = frozen_at.duration_since(self.start).as_secs_f64();
                self.elapsed_before_pause += banked;
                self.start = Instant::now();
            }
            None => {}
        }
    }
}

impl Default for TrainingTimer {
    /// Same as [`TrainingTimer::new`].
    fn default() -> Self {
        Self::new()
    }
}