use crate::error::BatchError;
use serde::{Deserialize, Serialize};
use std::fs::{self, File};
use std::io::{BufReader, BufWriter};
use std::path::Path;
use std::time::{SystemTime, UNIX_EPOCH};
/// On-disk progress snapshot for a batch key-generation job, serialized
/// to JSON so an interrupted run can be resumed where it left off.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
/// Caller-supplied identifier for this job.
pub job_id: String,
/// Total number of keys the job should produce.
pub total_count: usize,
/// Number of keys generated so far (running total, not a delta).
pub generated_count: usize,
/// Path of the file generated keys are appended to.
pub output_path: String,
/// Hex form of the most recently generated key, if any yet.
pub last_key: Option<String>,
/// Unix timestamp (seconds) of the last update; 0 if the system clock
/// reads before the epoch (see `current_timestamp`).
pub updated_at: u64,
/// Whether keys are generated randomly or incrementally.
pub mode: GenerationMode,
/// Starting key for incremental mode; `None` for random mode.
pub start_key: Option<String>,
/// Cursor updated via `update_position`.
/// NOTE(review): its exact semantics aren't visible in this file — confirm
/// against the incremental generator that consumes it.
pub current_position: u64,
}
/// How keys are produced for a job.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum GenerationMode {
/// Keys are generated independently at random (no `start_key`).
Random,
/// Keys advance from a configured `start_key`.
Incremental,
}
impl Checkpoint {
    /// Fresh checkpoint for a random-key job; no keys generated yet.
    pub fn new_random(job_id: &str, total_count: usize, output_path: &str) -> Self {
        Self {
            job_id: job_id.to_string(),
            total_count,
            generated_count: 0,
            output_path: output_path.to_string(),
            last_key: None,
            updated_at: current_timestamp(),
            mode: GenerationMode::Random,
            start_key: None,
            current_position: 0,
        }
    }

    /// Fresh checkpoint for an incremental job beginning at `start_key`.
    pub fn new_incremental(
        job_id: &str,
        total_count: usize,
        output_path: &str,
        start_key: &str,
    ) -> Self {
        Self {
            job_id: job_id.to_string(),
            total_count,
            generated_count: 0,
            output_path: output_path.to_string(),
            last_key: None,
            updated_at: current_timestamp(),
            mode: GenerationMode::Incremental,
            start_key: Some(start_key.to_string()),
            current_position: 0,
        }
    }

    /// Record progress: `generated` is the new running total (not a delta)
    /// and `last_key` the most recent key. Refreshes `updated_at`.
    pub fn update(&mut self, generated: usize, last_key: Option<String>) {
        self.generated_count = generated;
        self.last_key = last_key;
        self.updated_at = current_timestamp();
    }

    /// Record the incremental-mode cursor and refresh `updated_at`.
    pub fn update_position(&mut self, position: u64) {
        self.current_position = position;
        self.updated_at = current_timestamp();
    }

    /// True once the generated count has reached the target.
    pub fn is_complete(&self) -> bool {
        self.generated_count >= self.total_count
    }

    /// Keys still to generate (saturating: never underflows if the count
    /// somehow exceeds the target).
    pub fn remaining(&self) -> usize {
        self.total_count.saturating_sub(self.generated_count)
    }

    /// Completion as a percentage in [0, 100]; an empty job counts as done
    /// (avoids a 0/0 NaN).
    pub fn progress_percent(&self) -> f64 {
        if self.total_count == 0 {
            100.0
        } else {
            (self.generated_count as f64 / self.total_count as f64) * 100.0
        }
    }

    /// Persist the checkpoint as pretty-printed JSON.
    ///
    /// Writes to a sibling `.tmp` file and renames it into place so a crash
    /// mid-write can never leave a truncated checkpoint that a later
    /// [`Checkpoint::load`] would fail to parse.
    ///
    /// # Errors
    /// Returns `BatchError` if the file cannot be created, written, flushed,
    /// or renamed into place.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), BatchError> {
        use std::io::Write;
        let path = path.as_ref();
        let tmp = path.with_extension("tmp");
        let file = File::create(&tmp)
            .map_err(|e| BatchError::io_error(format!("Failed to create checkpoint file: {}", e)))?;
        let mut writer = BufWriter::new(file);
        serde_json::to_writer_pretty(&mut writer, self)
            .map_err(|e| BatchError::io_error(format!("Failed to write checkpoint: {}", e)))?;
        // Explicit flush: BufWriter's Drop flushes but silently swallows errors.
        writer
            .flush()
            .map_err(|e| BatchError::io_error(format!("Failed to write checkpoint: {}", e)))?;
        fs::rename(&tmp, path)
            .map_err(|e| BatchError::io_error(format!("Failed to finalize checkpoint: {}", e)))?;
        Ok(())
    }

    /// Load a checkpoint previously written by [`Checkpoint::save`].
    ///
    /// # Errors
    /// Returns `BatchError` if the file cannot be opened or parsed as JSON.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, BatchError> {
        let file = File::open(path)
            .map_err(|e| BatchError::io_error(format!("Failed to open checkpoint file: {}", e)))?;
        let reader = BufReader::new(file);
        let checkpoint: Self = serde_json::from_reader(reader)
            .map_err(|e| BatchError::io_error(format!("Failed to parse checkpoint: {}", e)))?;
        Ok(checkpoint)
    }

    /// Whether a checkpoint file exists at `path`.
    pub fn exists<P: AsRef<Path>>(path: P) -> bool {
        path.as_ref().exists()
    }

    /// Delete the checkpoint file; succeeds if it is already gone.
    ///
    /// Uses a single `remove_file` call and treats `NotFound` as success,
    /// avoiding the check-then-delete race of `exists()` + `remove_file`.
    pub fn delete<P: AsRef<Path>>(path: P) -> Result<(), BatchError> {
        match fs::remove_file(path) {
            Ok(()) => Ok(()),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
            Err(e) => Err(BatchError::io_error(format!(
                "Failed to delete checkpoint: {}",
                e
            ))),
        }
    }
}
/// Seconds since the Unix epoch, or 0 when the system clock reads
/// before 1970 (in which case `duration_since` fails).
fn current_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Batch key generator that persists a [`Checkpoint`] so an interrupted
/// run can resume; tuned via the builder-style setters on its `impl`.
pub struct ResumableBatchGenerator {
// Current progress state (loaded from disk when resuming).
checkpoint: Checkpoint,
// Where the checkpoint JSON is saved.
checkpoint_path: String,
// Keys generated per inner-loop iteration (default 10_000).
chunk_size: usize,
// Save the checkpoint after at least this many new keys (default 100_000).
checkpoint_interval: usize,
// Whether key generation runs in parallel (default true).
parallel: bool,
}
impl ResumableBatchGenerator {
pub fn new(job_id: &str, total_count: usize, output_path: &str, checkpoint_path: &str) -> Self {
let checkpoint = if Checkpoint::exists(checkpoint_path) {
Checkpoint::load(checkpoint_path).unwrap_or_else(|_| {
Checkpoint::new_random(job_id, total_count, output_path)
})
} else {
Checkpoint::new_random(job_id, total_count, output_path)
};
Self {
checkpoint,
checkpoint_path: checkpoint_path.to_string(),
chunk_size: 10_000,
checkpoint_interval: 100_000,
parallel: true,
}
}
pub fn new_incremental(
job_id: &str,
total_count: usize,
output_path: &str,
checkpoint_path: &str,
start_key: &str,
) -> Self {
let checkpoint = if Checkpoint::exists(checkpoint_path) {
Checkpoint::load(checkpoint_path).unwrap_or_else(|_| {
Checkpoint::new_incremental(job_id, total_count, output_path, start_key)
})
} else {
Checkpoint::new_incremental(job_id, total_count, output_path, start_key)
};
Self {
checkpoint,
checkpoint_path: checkpoint_path.to_string(),
chunk_size: 10_000,
checkpoint_interval: 100_000,
parallel: true,
}
}
pub fn chunk_size(mut self, size: usize) -> Self {
self.chunk_size = size;
self
}
pub fn checkpoint_interval(mut self, interval: usize) -> Self {
self.checkpoint_interval = interval;
self
}
pub fn parallel(mut self, enabled: bool) -> Self {
self.parallel = enabled;
self
}
pub fn progress(&self) -> &Checkpoint {
&self.checkpoint
}
pub fn generate_with_progress<F>(&mut self, mut progress_callback: F) -> Result<usize, BatchError>
where
F: FnMut(f64),
{
use crate::fast_gen::FastKeyGenerator;
use std::fs::OpenOptions;
use std::io::Write;
if self.checkpoint.is_complete() {
return Ok(self.checkpoint.generated_count);
}
let mut file = OpenOptions::new()
.create(true)
.append(true)
.open(&self.checkpoint.output_path)
.map_err(|e| BatchError::io_error(format!("Failed to open output file: {}", e)))?;
let mut keys_since_checkpoint = 0;
while !self.checkpoint.is_complete() {
let chunk_count = self.checkpoint.remaining().min(self.chunk_size);
let keys = FastKeyGenerator::new(chunk_count)
.parallel(self.parallel)
.generate();
for key in &keys {
writeln!(file, "{}", key.to_hex())
.map_err(|e| BatchError::io_error(format!("Failed to write key: {}", e)))?;
}
let last_key = keys.last().map(|k| k.to_hex());
self.checkpoint.update(
self.checkpoint.generated_count + keys.len(),
last_key,
);
keys_since_checkpoint += keys.len();
if keys_since_checkpoint >= self.checkpoint_interval {
self.checkpoint.save(&self.checkpoint_path)?;
keys_since_checkpoint = 0;
}
progress_callback(self.checkpoint.progress_percent());
}
self.checkpoint.save(&self.checkpoint_path)?;
Ok(self.checkpoint.generated_count)
}
pub fn generate(&mut self) -> Result<usize, BatchError> {
self.generate_with_progress(|_| {})
}
pub fn cleanup(&self) -> Result<(), BatchError> {
Checkpoint::delete(&self.checkpoint_path)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// A checkpoint round-trips through save/load with all fields intact.
    #[test]
    fn test_checkpoint_save_load() {
        let tmp = tempdir().unwrap();
        let cp_path = tmp.path().join("checkpoint.json");
        let mut cp = Checkpoint::new_random("test-job", 1000, "output.txt");
        cp.update(500, Some("abc123".to_string()));
        cp.save(&cp_path).unwrap();

        let restored = Checkpoint::load(&cp_path).unwrap();
        assert_eq!(restored.job_id, "test-job");
        assert_eq!(restored.total_count, 1000);
        assert_eq!(restored.generated_count, 500);
        assert_eq!(restored.last_key, Some("abc123".to_string()));
    }

    /// Progress accounting tracks updates from 0% through completion.
    #[test]
    fn test_checkpoint_progress() {
        let mut cp = Checkpoint::new_random("test", 1000, "out.txt");
        assert!(!cp.is_complete());
        assert_eq!(cp.remaining(), 1000);
        assert_eq!(cp.progress_percent(), 0.0);

        cp.update(500, None);
        assert_eq!(cp.remaining(), 500);
        assert_eq!(cp.progress_percent(), 50.0);

        cp.update(1000, None);
        assert!(cp.is_complete());
        assert_eq!(cp.progress_percent(), 100.0);
    }

    /// End-to-end: generate 100 keys, verify the output file, then clean up.
    #[test]
    fn test_resumable_generator() {
        let tmp = tempdir().unwrap();
        let keys_file = tmp.path().join("keys.txt");
        let cp_file = tmp.path().join("checkpoint.json");

        let mut gen = ResumableBatchGenerator::new(
            "test-job",
            100,
            keys_file.to_str().unwrap(),
            cp_file.to_str().unwrap(),
        )
        .chunk_size(10)
        .checkpoint_interval(50);

        assert_eq!(gen.generate().unwrap(), 100);

        let written = std::fs::read_to_string(&keys_file).unwrap();
        assert_eq!(written.lines().count(), 100);

        gen.cleanup().unwrap();
        assert!(!cp_file.exists());
    }

    /// Incremental constructor records the mode and the starting key.
    #[test]
    fn test_incremental_checkpoint() {
        let cp = Checkpoint::new_incremental(
            "inc-job",
            1000,
            "output.txt",
            "0000000000000000000000000000000000000000000000000000000000000001",
        );
        assert_eq!(cp.mode, GenerationMode::Incremental);
        assert!(cp.start_key.is_some());
    }
}