1use crate::core::types::TrackingResult;
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11#[derive(Debug, Clone)]
13pub struct ExportProgress {
14 pub current_stage: ExportStage,
16 pub stage_progress: f64,
18 pub overall_progress: f64,
20 pub processed_allocations: usize,
22 pub total_allocations: usize,
24 pub elapsed_time: Duration,
26 pub estimated_remaining: Option<Duration>,
28 pub processing_speed: f64,
30 pub stage_details: String,
32}
33
/// Phases an export moves through (in pipeline order), plus its terminal states.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExportStage {
    /// Setting up the export environment.
    Initializing,
    /// Localizing data to reduce global state access.
    DataLocalization,
    /// Processing shards in parallel.
    ParallelProcessing,
    /// Buffered writing of the output.
    Writing,
    /// The export finished successfully.
    Completed,
    /// The export was cancelled.
    Cancelled,
    /// The export failed; carries the error message.
    Error(String),
}

impl ExportStage {
    /// Share of overall progress attributed to this stage.
    ///
    /// The four working stages sum to `1.0`. `Completed` reports the full
    /// `1.0`, while `Cancelled` and `Error` contribute nothing.
    pub fn weight(&self) -> f64 {
        match self {
            Self::Initializing => 0.05,
            Self::DataLocalization => 0.15,
            Self::ParallelProcessing => 0.70,
            Self::Writing => 0.10,
            Self::Completed => 1.0,
            Self::Cancelled | Self::Error(_) => 0.0,
        }
    }

    /// Human-readable description of the stage; for `Error`, the carried message.
    pub fn description(&self) -> &str {
        match self {
            Self::Initializing => "Initializing export environment",
            Self::DataLocalization => "Localizing data, reducing global state access",
            Self::ParallelProcessing => "Parallel shard processing",
            Self::Writing => "High-speed buffered writing",
            Self::Completed => "Export completed",
            Self::Cancelled => "Export cancelled",
            Self::Error(message) => message.as_str(),
        }
    }
}
80
81pub type ProgressCallback = Box<dyn Fn(ExportProgress) + Send + Sync>;
83
84#[derive(Debug, Clone)]
86pub struct CancellationToken {
87 cancelled: Arc<AtomicBool>,
88}
89
90impl CancellationToken {
91 pub fn new() -> Self {
93 Self {
94 cancelled: Arc::new(AtomicBool::new(false)),
95 }
96 }
97
98 pub fn cancel(&self) {
100 self.cancelled.store(true, Ordering::SeqCst);
101 }
102
103 pub fn is_cancelled(&self) -> bool {
105 self.cancelled.load(Ordering::SeqCst)
106 }
107
108 pub fn check_cancelled(&self) -> TrackingResult<()> {
110 if self.is_cancelled() {
111 Err(std::io::Error::new(
112 std::io::ErrorKind::Interrupted,
113 "Export operation was cancelled",
114 )
115 .into())
116 } else {
117 Ok(())
118 }
119 }
120}
121
122impl Default for CancellationToken {
123 fn default() -> Self {
124 Self::new()
125 }
126}
127
128pub struct ProgressMonitor {
130 start_time: Instant,
132 current_stage: ExportStage,
134 total_allocations: usize,
136 processed_allocations: Arc<AtomicUsize>,
138 callback: Option<ProgressCallback>,
140 cancellation_token: CancellationToken,
142 last_update: Instant,
144 update_interval: Duration,
146 speed_history: Vec<(Instant, usize)>,
148 max_history_size: usize,
150}
151
152impl ProgressMonitor {
153 pub fn new(total_allocations: usize) -> Self {
155 Self {
156 start_time: Instant::now(),
157 current_stage: ExportStage::Initializing,
158 total_allocations,
159 processed_allocations: Arc::new(AtomicUsize::new(0)),
160 callback: None,
161 cancellation_token: CancellationToken::new(),
162 last_update: Instant::now(),
163 update_interval: Duration::from_millis(100), speed_history: Vec::new(),
165 max_history_size: 20,
166 }
167 }
168
169 pub fn set_callback(&mut self, callback: ProgressCallback) {
171 self.callback = Some(callback);
172 }
173
174 pub fn cancellation_token(&self) -> CancellationToken {
176 self.cancellation_token.clone()
177 }
178
179 pub fn set_stage(&mut self, stage: ExportStage) {
181 self.current_stage = stage;
182 }
184
185 pub fn update_progress(&mut self, stage_progress: f64, _details: Option<String>) {
187 let now = Instant::now();
188
189 if now.duration_since(self.last_update) < self.update_interval {
191 return;
192 }
193
194 self.last_update = now;
195
196 let processed = self.processed_allocations.load(Ordering::SeqCst);
197
198 self.speed_history.push((now, processed));
200 if self.speed_history.len() > self.max_history_size {
201 self.speed_history.remove(0);
202 }
203
204 let progress = self.calculate_progress(stage_progress, processed);
205
206 if let Some(ref callback) = self.callback {
207 callback(progress);
208 }
209 }
210
211 pub fn add_processed(&self, count: usize) {
213 self.processed_allocations
214 .fetch_add(count, Ordering::SeqCst);
215 }
216
217 pub fn set_processed(&self, count: usize) {
219 self.processed_allocations.store(count, Ordering::SeqCst);
220 }
221
222 fn calculate_progress(&self, stage_progress: f64, processed: usize) -> ExportProgress {
224 let elapsed = self.start_time.elapsed();
225
226 let stage_weights = [
228 (ExportStage::Initializing, 0.05),
229 (ExportStage::DataLocalization, 0.15),
230 (ExportStage::ParallelProcessing, 0.70),
231 (ExportStage::Writing, 0.10),
232 ];
233
234 let mut overall_progress = 0.0;
235 let mut found_current = false;
236
237 for (stage, weight) in &stage_weights {
238 if *stage == self.current_stage {
239 overall_progress += weight * stage_progress;
240 found_current = true;
241 break;
242 } else {
243 overall_progress += weight;
244 }
245 }
246
247 if !found_current {
248 overall_progress = match self.current_stage {
249 ExportStage::Completed => 1.0,
250 ExportStage::Cancelled => 0.0,
251 ExportStage::Error(_) => 0.0,
252 _ => overall_progress,
253 };
254 }
255
256 let processing_speed = if elapsed.as_secs() > 0 {
258 processed as f64 / elapsed.as_secs_f64()
259 } else {
260 0.0
261 };
262
263 let estimated_remaining = self.estimate_remaining_time(processed, processing_speed);
265
266 ExportProgress {
267 current_stage: self.current_stage.clone(),
268 stage_progress,
269 overall_progress,
270 processed_allocations: processed,
271 total_allocations: self.total_allocations,
272 elapsed_time: elapsed,
273 estimated_remaining,
274 processing_speed,
275 stage_details: self.current_stage.description().to_string(),
276 }
277 }
278
279 fn estimate_remaining_time(&self, processed: usize, current_speed: f64) -> Option<Duration> {
281 if processed >= self.total_allocations || current_speed <= 0.0 {
282 return None;
283 }
284
285 let avg_speed = if self.speed_history.len() >= 2 {
287 let recent_history = &self.speed_history[self.speed_history.len().saturating_sub(5)..];
288 if recent_history.len() >= 2 {
289 let first = &recent_history[0];
290 let last = &recent_history[recent_history.len() - 1];
291 let time_diff = last.0.duration_since(first.0).as_secs_f64();
292 let processed_diff = last.1.saturating_sub(first.1) as f64;
293
294 if time_diff > 0.0 {
295 processed_diff / time_diff
296 } else {
297 current_speed
298 }
299 } else {
300 current_speed
301 }
302 } else {
303 current_speed
304 };
305
306 if avg_speed > 0.0 {
307 let remaining_allocations = self.total_allocations.saturating_sub(processed) as f64;
308 let remaining_seconds = remaining_allocations / avg_speed;
309 Some(Duration::from_secs_f64(remaining_seconds))
310 } else {
311 None
312 }
313 }
314
315 pub fn complete(&mut self) {
317 self.current_stage = ExportStage::Completed;
318 self.update_progress(1.0, Some("Export completed".to_string()));
319 }
320
321 pub fn cancel(&mut self) {
323 self.cancellation_token.cancel();
324 self.current_stage = ExportStage::Cancelled;
325 self.update_progress(0.0, Some("Export cancelled".to_string()));
326 }
327
328 pub fn set_error(&mut self, error: String) {
330 self.current_stage = ExportStage::Error(error.clone());
331 self.update_progress(0.0, Some(error));
332 }
333
334 pub fn should_cancel(&self) -> bool {
336 self.cancellation_token.is_cancelled()
337 }
338
339 pub fn get_progress_snapshot(&self) -> ExportProgress {
341 let processed = self.processed_allocations.load(Ordering::SeqCst);
342 self.calculate_progress(0.0, processed)
343 }
344}
345
/// Configuration switches for progress reporting.
///
/// NOTE(review): `ProgressMonitor::new` does not take this config; the field
/// meanings below follow their names — confirm against the callers.
#[derive(Debug, Clone)]
pub struct ProgressConfig {
    /// Whether progress reporting is enabled.
    pub enabled: bool,
    /// Intended minimum interval between progress updates.
    pub update_interval: Duration,
    /// Whether stage details should be shown.
    pub show_details: bool,
    /// Whether the estimated remaining time should be shown.
    pub show_estimated_time: bool,
    /// Whether cancellation should be allowed.
    pub allow_cancellation: bool,
}

impl Default for ProgressConfig {
    /// Everything switched on, with a 100 ms update interval.
    fn default() -> Self {
        let update_interval = Duration::from_millis(100);
        Self {
            update_interval,
            enabled: true,
            show_details: true,
            show_estimated_time: true,
            allow_cancellation: true,
        }
    }
}
372
373pub struct ConsoleProgressDisplay {
375 last_line_length: usize,
376}
377
378impl ConsoleProgressDisplay {
379 pub fn new() -> Self {
381 Self {
382 last_line_length: 0,
383 }
384 }
385
386 pub fn display(&mut self, progress: &ExportProgress) {
388 if self.last_line_length > 0 {
390 print!("\r{}", " ".repeat(self.last_line_length));
391 print!("\r");
392 }
393
394 let progress_bar = self.create_progress_bar(progress.overall_progress);
395 let speed_info = if progress.processing_speed > 0.0 {
396 format!(" ({:.0} allocs/sec)", progress.processing_speed)
397 } else {
398 String::new()
399 };
400
401 let time_info = if let Some(remaining) = progress.estimated_remaining {
402 format!(" Remaining: {remaining:?}")
403 } else {
404 String::new()
405 };
406
407 let line = format!(
408 "{progress_bar} {:.1}% {} ({}/{}){speed_info}{time_info}",
409 progress.overall_progress * 100.0,
410 progress.current_stage.description(),
411 progress.processed_allocations,
412 progress.total_allocations,
413 );
414
415 print!("{line}");
416 std::io::Write::flush(&mut std::io::stdout()).ok();
417
418 self.last_line_length = line.len();
419 }
420
421 fn create_progress_bar(&self, progress: f64) -> String {
423 let width = 20;
424 let filled = (progress * width as f64) as usize;
425 let empty = width - filled;
426
427 format!("[{}{}]", "█".repeat(filled), "░".repeat(empty))
428 }
429
430 pub fn finish(&mut self) {
432 tracing::info!("");
433 self.last_line_length = 0;
434 }
435}
436
437impl Default for ConsoleProgressDisplay {
438 fn default() -> Self {
439 Self::new()
440 }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    /// Cancelling a token flips `is_cancelled` and makes `check_cancelled` fail.
    #[test]
    fn test_cancellation_token() {
        let token = CancellationToken::new();
        assert!(!token.is_cancelled());

        token.cancel();

        assert!(token.is_cancelled());
        assert!(token.check_cancelled().is_err());
    }

    /// Snapshots reflect the initial state, stage changes and processed counts.
    #[test]
    fn test_progress_monitor_basic() {
        let mut monitor = ProgressMonitor::new(1000);

        let initial = monitor.get_progress_snapshot();
        assert_eq!(initial.current_stage, ExportStage::Initializing);
        assert_eq!(initial.processed_allocations, 0);
        assert_eq!(initial.total_allocations, 1000);

        monitor.set_stage(ExportStage::DataLocalization);
        assert_eq!(
            monitor.get_progress_snapshot().current_stage,
            ExportStage::DataLocalization
        );

        monitor.add_processed(100);
        assert_eq!(monitor.get_progress_snapshot().processed_allocations, 100);
    }

    /// A non-throttled update invokes the registered callback.
    #[test]
    fn test_progress_callback() {
        use crate::core::safe_operations::SafeLock;

        let fired = Arc::new(Mutex::new(false));
        let fired_in_callback = Arc::clone(&fired);

        let mut monitor = ProgressMonitor::new(100);
        monitor.update_interval = Duration::from_millis(1);
        monitor.set_callback(Box::new(move |_progress| {
            let mut flag = fired_in_callback
                .safe_lock()
                .expect("Failed to acquire lock on callback_called");
            *flag = true;
        }));

        // Let the 1 ms throttle window elapse so the next update is emitted.
        std::thread::sleep(Duration::from_millis(10));
        monitor.update_progress(0.5, None);

        let was_called = *fired
            .safe_lock()
            .expect("Failed to acquire lock on callback_called");
        assert!(was_called);
    }

    /// Overall progress follows the stage weights.
    #[test]
    fn test_progress_calculation() {
        let assert_close = |actual: f64, expected: f64| {
            assert!(
                (actual - expected).abs() < 0.01,
                "Expected ~{}, got {}",
                expected,
                actual
            );
        };

        let mut monitor = ProgressMonitor::new(1000);
        monitor.update_interval = Duration::from_millis(1);

        let fresh = monitor.calculate_progress(1.0, 0);
        assert_eq!(fresh.current_stage, ExportStage::Initializing);

        monitor.set_stage(ExportStage::Initializing);
        assert_close(monitor.calculate_progress(1.0, 0).overall_progress, 0.05);

        monitor.set_stage(ExportStage::DataLocalization);
        assert_close(
            monitor.calculate_progress(0.5, 0).overall_progress,
            0.05 + 0.15 * 0.5,
        );

        monitor.set_stage(ExportStage::Completed);
        let done = monitor.calculate_progress(1.0, 0);
        assert_eq!(done.overall_progress, 1.0);
        assert_eq!(done.current_stage, ExportStage::Completed);
    }

    /// The reported speed matches processed / elapsed once a second has passed.
    #[test]
    fn test_speed_calculation() {
        let monitor = ProgressMonitor::new(1000);
        std::thread::sleep(Duration::from_millis(10));
        monitor.add_processed(100);

        let snapshot = monitor.get_progress_snapshot();
        assert!(
            snapshot.processing_speed >= 0.0,
            "Processing speed should be >= 0, got {}",
            snapshot.processing_speed
        );
        assert!(
            snapshot.elapsed_time.as_millis() > 0,
            "Elapsed time should be > 0"
        );

        // Mirror the monitor's rule: speed is only reported after a whole second.
        let expected_speed = match snapshot.elapsed_time.as_secs() {
            0 => 0.0,
            _ => 100.0 / snapshot.elapsed_time.as_secs_f64(),
        };
        assert!(
            (snapshot.processing_speed - expected_speed).abs() < 1.0,
            "Speed calculation mismatch: expected ~{}, got {}",
            expected_speed,
            snapshot.processing_speed
        );
    }

    /// Rendering and finishing a line must not panic.
    #[test]
    fn test_console_progress_display() {
        let mut display = ConsoleProgressDisplay::new();

        let snapshot = ExportProgress {
            current_stage: ExportStage::ParallelProcessing,
            stage_progress: 0.5,
            overall_progress: 0.6,
            processed_allocations: 600,
            total_allocations: 1000,
            elapsed_time: Duration::from_secs(10),
            estimated_remaining: Some(Duration::from_secs(7)),
            processing_speed: 60.0,
            stage_details: "Parallel shard processing".to_string(),
        };

        display.display(&snapshot);
        display.finish();
    }

    /// Stage weights have their documented values.
    #[test]
    fn test_export_stage_weights() {
        let cases = [
            (ExportStage::Initializing, 0.05),
            (ExportStage::DataLocalization, 0.15),
            (ExportStage::ParallelProcessing, 0.70),
            (ExportStage::Writing, 0.10),
            (ExportStage::Completed, 1.0),
        ];
        for (stage, weight) in cases {
            assert_eq!(stage.weight(), weight);
        }
    }
}