use crate::error::BatchError;
use serde::{Deserialize, Serialize};
use std::fs::{self, File};
use std::io::{BufReader, BufWriter, Write};
use std::path::Path;
use std::time::{SystemTime, UNIX_EPOCH};
12
/// Serializable snapshot of a batch-generation job's progress, persisted to
/// disk so an interrupted run can resume where it left off.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
    /// Identifier of the job this checkpoint belongs to.
    pub job_id: String,
    /// Total number of keys the job should produce.
    pub total_count: usize,
    /// Number of keys generated (and written) so far.
    pub generated_count: usize,
    /// Path of the output file keys are appended to.
    pub output_path: String,
    /// Hex encoding of the most recently written key, if any.
    pub last_key: Option<String>,
    /// Unix timestamp (seconds) of the last update; 0 if the clock was unavailable.
    pub updated_at: u64,
    /// Generation strategy this job uses.
    pub mode: GenerationMode,
    /// Starting key for incremental mode; `None` for random mode.
    pub start_key: Option<String>,
    /// Position counter maintained via `update_position` — presumably an offset
    /// from `start_key` in incremental mode; confirm against the caller.
    pub current_position: u64,
}
35
/// Strategy used to produce keys for a batch job.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum GenerationMode {
    /// Keys are generated randomly.
    Random,
    /// Keys are generated sequentially from a starting key.
    Incremental,
}
44
45impl Checkpoint {
46 pub fn new_random(job_id: &str, total_count: usize, output_path: &str) -> Self {
48 Self {
49 job_id: job_id.to_string(),
50 total_count,
51 generated_count: 0,
52 output_path: output_path.to_string(),
53 last_key: None,
54 updated_at: current_timestamp(),
55 mode: GenerationMode::Random,
56 start_key: None,
57 current_position: 0,
58 }
59 }
60
61 pub fn new_incremental(
63 job_id: &str,
64 total_count: usize,
65 output_path: &str,
66 start_key: &str,
67 ) -> Self {
68 Self {
69 job_id: job_id.to_string(),
70 total_count,
71 generated_count: 0,
72 output_path: output_path.to_string(),
73 last_key: None,
74 updated_at: current_timestamp(),
75 mode: GenerationMode::Incremental,
76 start_key: Some(start_key.to_string()),
77 current_position: 0,
78 }
79 }
80
81 pub fn update(&mut self, generated: usize, last_key: Option<String>) {
83 self.generated_count = generated;
84 self.last_key = last_key;
85 self.updated_at = current_timestamp();
86 }
87
88 pub fn update_position(&mut self, position: u64) {
90 self.current_position = position;
91 self.updated_at = current_timestamp();
92 }
93
94 pub fn is_complete(&self) -> bool {
96 self.generated_count >= self.total_count
97 }
98
99 pub fn remaining(&self) -> usize {
101 self.total_count.saturating_sub(self.generated_count)
102 }
103
104 pub fn progress_percent(&self) -> f64 {
106 if self.total_count == 0 {
107 100.0
108 } else {
109 (self.generated_count as f64 / self.total_count as f64) * 100.0
110 }
111 }
112
113 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), BatchError> {
115 let file = File::create(path)
116 .map_err(|e| BatchError::io_error(format!("Failed to create checkpoint file: {}", e)))?;
117 let writer = BufWriter::new(file);
118 serde_json::to_writer_pretty(writer, self)
119 .map_err(|e| BatchError::io_error(format!("Failed to write checkpoint: {}", e)))?;
120 Ok(())
121 }
122
123 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, BatchError> {
125 let file = File::open(path)
126 .map_err(|e| BatchError::io_error(format!("Failed to open checkpoint file: {}", e)))?;
127 let reader = BufReader::new(file);
128 let checkpoint: Self = serde_json::from_reader(reader)
129 .map_err(|e| BatchError::io_error(format!("Failed to parse checkpoint: {}", e)))?;
130 Ok(checkpoint)
131 }
132
133 pub fn exists<P: AsRef<Path>>(path: P) -> bool {
135 path.as_ref().exists()
136 }
137
138 pub fn delete<P: AsRef<Path>>(path: P) -> Result<(), BatchError> {
140 if path.as_ref().exists() {
141 fs::remove_file(path)
142 .map_err(|e| BatchError::io_error(format!("Failed to delete checkpoint: {}", e)))?;
143 }
144 Ok(())
145 }
146}
147
/// Seconds since the Unix epoch, or 0 if the system clock reads earlier than
/// the epoch.
fn current_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
155
/// Batch key generator that periodically persists a [`Checkpoint`] so an
/// interrupted run can resume instead of starting over.
pub struct ResumableBatchGenerator {
    // Current progress state; loaded from disk on construction when present.
    checkpoint: Checkpoint,
    // Where the checkpoint JSON is saved between chunks.
    checkpoint_path: String,
    // Keys generated per inner iteration.
    chunk_size: usize,
    // Save the checkpoint after roughly this many keys.
    checkpoint_interval: usize,
    // Forwarded to the key generator's `parallel` setting.
    parallel: bool,
}
191
192impl ResumableBatchGenerator {
193 pub fn new(job_id: &str, total_count: usize, output_path: &str, checkpoint_path: &str) -> Self {
197 let checkpoint = if Checkpoint::exists(checkpoint_path) {
198 Checkpoint::load(checkpoint_path).unwrap_or_else(|_| {
199 Checkpoint::new_random(job_id, total_count, output_path)
200 })
201 } else {
202 Checkpoint::new_random(job_id, total_count, output_path)
203 };
204
205 Self {
206 checkpoint,
207 checkpoint_path: checkpoint_path.to_string(),
208 chunk_size: 10_000,
209 checkpoint_interval: 100_000,
210 parallel: true,
211 }
212 }
213
214 pub fn new_incremental(
216 job_id: &str,
217 total_count: usize,
218 output_path: &str,
219 checkpoint_path: &str,
220 start_key: &str,
221 ) -> Self {
222 let checkpoint = if Checkpoint::exists(checkpoint_path) {
223 Checkpoint::load(checkpoint_path).unwrap_or_else(|_| {
224 Checkpoint::new_incremental(job_id, total_count, output_path, start_key)
225 })
226 } else {
227 Checkpoint::new_incremental(job_id, total_count, output_path, start_key)
228 };
229
230 Self {
231 checkpoint,
232 checkpoint_path: checkpoint_path.to_string(),
233 chunk_size: 10_000,
234 checkpoint_interval: 100_000,
235 parallel: true,
236 }
237 }
238
239 pub fn chunk_size(mut self, size: usize) -> Self {
241 self.chunk_size = size;
242 self
243 }
244
245 pub fn checkpoint_interval(mut self, interval: usize) -> Self {
247 self.checkpoint_interval = interval;
248 self
249 }
250
251 pub fn parallel(mut self, enabled: bool) -> Self {
253 self.parallel = enabled;
254 self
255 }
256
257 pub fn progress(&self) -> &Checkpoint {
259 &self.checkpoint
260 }
261
262 pub fn generate_with_progress<F>(&mut self, mut progress_callback: F) -> Result<usize, BatchError>
264 where
265 F: FnMut(f64),
266 {
267 use crate::fast_gen::FastKeyGenerator;
268 use std::fs::OpenOptions;
269 use std::io::Write;
270
271 if self.checkpoint.is_complete() {
272 return Ok(self.checkpoint.generated_count);
273 }
274
275 let mut file = OpenOptions::new()
277 .create(true)
278 .append(true)
279 .open(&self.checkpoint.output_path)
280 .map_err(|e| BatchError::io_error(format!("Failed to open output file: {}", e)))?;
281
282 let mut keys_since_checkpoint = 0;
283
284 while !self.checkpoint.is_complete() {
285 let chunk_count = self.checkpoint.remaining().min(self.chunk_size);
286
287 let keys = FastKeyGenerator::new(chunk_count)
288 .parallel(self.parallel)
289 .generate();
290
291 for key in &keys {
293 writeln!(file, "{}", key.to_hex())
294 .map_err(|e| BatchError::io_error(format!("Failed to write key: {}", e)))?;
295 }
296
297 let last_key = keys.last().map(|k| k.to_hex());
299 self.checkpoint.update(
300 self.checkpoint.generated_count + keys.len(),
301 last_key,
302 );
303
304 keys_since_checkpoint += keys.len();
305
306 if keys_since_checkpoint >= self.checkpoint_interval {
308 self.checkpoint.save(&self.checkpoint_path)?;
309 keys_since_checkpoint = 0;
310 }
311
312 progress_callback(self.checkpoint.progress_percent());
314 }
315
316 self.checkpoint.save(&self.checkpoint_path)?;
318
319 Ok(self.checkpoint.generated_count)
320 }
321
322 pub fn generate(&mut self) -> Result<usize, BatchError> {
324 self.generate_with_progress(|_| {})
325 }
326
327 pub fn cleanup(&self) -> Result<(), BatchError> {
329 Checkpoint::delete(&self.checkpoint_path)
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_checkpoint_save_load() {
        let tmp = tempdir().unwrap();
        let ckpt_file = tmp.path().join("checkpoint.json");

        // Round-trip a partially-completed checkpoint through disk.
        let mut cp = Checkpoint::new_random("test-job", 1000, "output.txt");
        cp.update(500, Some("abc123".to_string()));
        cp.save(&ckpt_file).unwrap();

        let restored = Checkpoint::load(&ckpt_file).unwrap();
        assert_eq!(restored.job_id, "test-job");
        assert_eq!(restored.total_count, 1000);
        assert_eq!(restored.generated_count, 500);
        assert_eq!(restored.last_key, Some("abc123".to_string()));
    }

    #[test]
    fn test_checkpoint_progress() {
        let mut cp = Checkpoint::new_random("test", 1000, "out.txt");

        // Fresh checkpoint: nothing done yet.
        assert_eq!(cp.progress_percent(), 0.0);
        assert_eq!(cp.remaining(), 1000);
        assert!(!cp.is_complete());

        // Halfway.
        cp.update(500, None);
        assert_eq!(cp.progress_percent(), 50.0);
        assert_eq!(cp.remaining(), 500);

        // Done.
        cp.update(1000, None);
        assert_eq!(cp.progress_percent(), 100.0);
        assert!(cp.is_complete());
    }

    #[test]
    fn test_resumable_generator() {
        let tmp = tempdir().unwrap();
        let keys_file = tmp.path().join("keys.txt");
        let ckpt_file = tmp.path().join("checkpoint.json");

        let mut gen = ResumableBatchGenerator::new(
            "test-job",
            100,
            keys_file.to_str().unwrap(),
            ckpt_file.to_str().unwrap(),
        )
        .chunk_size(10)
        .checkpoint_interval(50);

        assert_eq!(gen.generate().unwrap(), 100);

        // One hex key per line in the output file.
        let written = std::fs::read_to_string(&keys_file).unwrap();
        assert_eq!(written.lines().count(), 100);

        // Cleanup removes the checkpoint.
        gen.cleanup().unwrap();
        assert!(!ckpt_file.exists());
    }

    #[test]
    fn test_incremental_checkpoint() {
        let cp = Checkpoint::new_incremental(
            "inc-job",
            1000,
            "output.txt",
            "0000000000000000000000000000000000000000000000000000000000000001",
        );

        assert_eq!(cp.mode, GenerationMode::Incremental);
        assert!(cp.start_key.is_some());
    }
}