// memscope_rs/export/progress_monitor.rs

//! Progress monitoring and cancellation mechanism
//!
//! This module provides progress reporting, cancellation, and remaining-time estimation for the export process.
//! It supports progress callbacks, graceful interruption, and partial result saving.
5
6use crate::core::types::TrackingResult;
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11/// Export progress information
12#[derive(Debug, Clone)]
13pub struct ExportProgress {
14    /// Current stage
15    pub current_stage: ExportStage,
16    /// Current stage progress (0.0 - 1.0)
17    pub stage_progress: f64,
18    /// Overall progress (0.0 - 1.0)
19    pub overall_progress: f64,
20    /// Number of processed allocations
21    pub processed_allocations: usize,
22    /// Total number of allocations
23    pub total_allocations: usize,
24    /// Elapsed time
25    pub elapsed_time: Duration,
26    /// Estimated remaining time
27    pub estimated_remaining: Option<Duration>,
28    /// Current processing speed (allocations/second)
29    pub processing_speed: f64,
30    /// Stage details
31    pub stage_details: String,
32}
33
/// Phases an export moves through, plus its terminal outcomes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExportStage {
    /// Setting up the export environment.
    Initializing,
    /// Localizing data to reduce global state access.
    DataLocalization,
    /// Processing data shards in parallel.
    ParallelProcessing,
    /// Writing output through a buffered writer.
    Writing,
    /// The export finished successfully.
    Completed,
    /// The export was cancelled.
    Cancelled,
    /// The export failed; carries the error message.
    Error(String),
}

impl ExportStage {
    /// Share of the overall progress attributed to this stage.
    ///
    /// The four working stages (`Initializing`, `DataLocalization`,
    /// `ParallelProcessing`, `Writing`) add up to `1.0`. `Completed` reports
    /// `1.0`, while `Cancelled` and `Error` report `0.0`.
    pub fn weight(&self) -> f64 {
        match self {
            Self::Initializing => 0.05,
            Self::DataLocalization => 0.15,
            Self::ParallelProcessing => 0.70,
            Self::Writing => 0.10,
            Self::Completed => 1.0,
            Self::Cancelled | Self::Error(_) => 0.0,
        }
    }

    /// Human-readable description of the stage.
    ///
    /// For `Error`, the carried error message is returned as-is.
    pub fn description(&self) -> &str {
        let text: &str = match self {
            Self::Initializing => "Initializing export environment",
            Self::DataLocalization => "Localizing data, reducing global state access",
            Self::ParallelProcessing => "Parallel shard processing",
            Self::Writing => "High-speed buffered writing",
            Self::Completed => "Export completed",
            Self::Cancelled => "Export cancelled",
            Self::Error(message) => message.as_str(),
        };
        text
    }
}
80
81/// Progress callback function type
82pub type ProgressCallback = Box<dyn Fn(ExportProgress) + Send + Sync>;
83
84/// Cancellation token for interrupting export operations
85#[derive(Debug, Clone)]
86pub struct CancellationToken {
87    cancelled: Arc<AtomicBool>,
88}
89
90impl CancellationToken {
91    /// Create new cancellation token
92    pub fn new() -> Self {
93        Self {
94            cancelled: Arc::new(AtomicBool::new(false)),
95        }
96    }
97
98    /// Cancel operation
99    pub fn cancel(&self) {
100        self.cancelled.store(true, Ordering::SeqCst);
101    }
102
103    /// Check if cancelled
104    pub fn is_cancelled(&self) -> bool {
105        self.cancelled.load(Ordering::SeqCst)
106    }
107
108    /// Return error if cancelled
109    pub fn check_cancelled(&self) -> TrackingResult<()> {
110        if self.is_cancelled() {
111            Err(std::io::Error::new(
112                std::io::ErrorKind::Interrupted,
113                "Export operation was cancelled",
114            )
115            .into())
116        } else {
117            Ok(())
118        }
119    }
120}
121
122impl Default for CancellationToken {
123    fn default() -> Self {
124        Self::new()
125    }
126}
127
128/// Progress monitor
129pub struct ProgressMonitor {
130    /// Start time
131    start_time: Instant,
132    /// Current stage
133    current_stage: ExportStage,
134    /// Total number of allocations
135    total_allocations: usize,
136    /// Number of processed allocations
137    processed_allocations: Arc<AtomicUsize>,
138    /// Progress callback
139    callback: Option<ProgressCallback>,
140    /// Cancellation token
141    cancellation_token: CancellationToken,
142    /// Last update time
143    last_update: Instant,
144    /// Update interval (to avoid too frequent callbacks)
145    update_interval: Duration,
146    /// Historical processing speed (for estimating remaining time)
147    speed_history: Vec<(Instant, usize)>,
148    /// Maximum history size
149    max_history_size: usize,
150}
151
152impl ProgressMonitor {
153    /// Create new progress monitor
154    pub fn new(total_allocations: usize) -> Self {
155        Self {
156            start_time: Instant::now(),
157            current_stage: ExportStage::Initializing,
158            total_allocations,
159            processed_allocations: Arc::new(AtomicUsize::new(0)),
160            callback: None,
161            cancellation_token: CancellationToken::new(),
162            last_update: Instant::now(),
163            update_interval: Duration::from_millis(100), // 100ms update interval
164            speed_history: Vec::new(),
165            max_history_size: 20,
166        }
167    }
168
169    /// Set progress callback
170    pub fn set_callback(&mut self, callback: ProgressCallback) {
171        self.callback = Some(callback);
172    }
173
174    /// Get cancellation token
175    pub fn cancellation_token(&self) -> CancellationToken {
176        self.cancellation_token.clone()
177    }
178
179    /// Set current stage
180    pub fn set_stage(&mut self, stage: ExportStage) {
181        self.current_stage = stage;
182        // Don't automatically call update_progress, let caller control progress
183    }
184
185    /// Update stage progress
186    pub fn update_progress(&mut self, stage_progress: f64, _details: Option<String>) {
187        let now = Instant::now();
188
189        // Check update interval to avoid too frequent callbacks
190        if now.duration_since(self.last_update) < self.update_interval {
191            return;
192        }
193
194        self.last_update = now;
195
196        let processed = self.processed_allocations.load(Ordering::SeqCst);
197
198        // Update speed history
199        self.speed_history.push((now, processed));
200        if self.speed_history.len() > self.max_history_size {
201            self.speed_history.remove(0);
202        }
203
204        let progress = self.calculate_progress(stage_progress, processed);
205
206        if let Some(ref callback) = self.callback {
207            callback(progress);
208        }
209    }
210
211    /// Add processed allocation count
212    pub fn add_processed(&self, count: usize) {
213        self.processed_allocations
214            .fetch_add(count, Ordering::SeqCst);
215    }
216
217    /// Set processed allocation count
218    pub fn set_processed(&self, count: usize) {
219        self.processed_allocations.store(count, Ordering::SeqCst);
220    }
221
222    /// Calculate progress information
223    fn calculate_progress(&self, stage_progress: f64, processed: usize) -> ExportProgress {
224        let elapsed = self.start_time.elapsed();
225
226        // Calculate overall progress
227        let stage_weights = [
228            (ExportStage::Initializing, 0.05),
229            (ExportStage::DataLocalization, 0.15),
230            (ExportStage::ParallelProcessing, 0.70),
231            (ExportStage::Writing, 0.10),
232        ];
233
234        let mut overall_progress = 0.0;
235        let mut found_current = false;
236
237        for (stage, weight) in &stage_weights {
238            if *stage == self.current_stage {
239                overall_progress += weight * stage_progress;
240                found_current = true;
241                break;
242            } else {
243                overall_progress += weight;
244            }
245        }
246
247        if !found_current {
248            overall_progress = match self.current_stage {
249                ExportStage::Completed => 1.0,
250                ExportStage::Cancelled => 0.0,
251                ExportStage::Error(_) => 0.0,
252                _ => overall_progress,
253            };
254        }
255
256        // Calculate processing speed
257        let processing_speed = if elapsed.as_secs() > 0 {
258            processed as f64 / elapsed.as_secs_f64()
259        } else {
260            0.0
261        };
262
263        // Estimate remaining time
264        let estimated_remaining = self.estimate_remaining_time(processed, processing_speed);
265
266        ExportProgress {
267            current_stage: self.current_stage.clone(),
268            stage_progress,
269            overall_progress,
270            processed_allocations: processed,
271            total_allocations: self.total_allocations,
272            elapsed_time: elapsed,
273            estimated_remaining,
274            processing_speed,
275            stage_details: self.current_stage.description().to_string(),
276        }
277    }
278
279    /// Estimate remaining time
280    fn estimate_remaining_time(&self, processed: usize, current_speed: f64) -> Option<Duration> {
281        if processed >= self.total_allocations || current_speed <= 0.0 {
282            return None;
283        }
284
285        // Use historical speed data for more accurate estimation
286        let avg_speed = if self.speed_history.len() >= 2 {
287            let recent_history = &self.speed_history[self.speed_history.len().saturating_sub(5)..];
288            if recent_history.len() >= 2 {
289                let first = &recent_history[0];
290                let last = &recent_history[recent_history.len() - 1];
291                let time_diff = last.0.duration_since(first.0).as_secs_f64();
292                let processed_diff = last.1.saturating_sub(first.1) as f64;
293
294                if time_diff > 0.0 {
295                    processed_diff / time_diff
296                } else {
297                    current_speed
298                }
299            } else {
300                current_speed
301            }
302        } else {
303            current_speed
304        };
305
306        if avg_speed > 0.0 {
307            let remaining_allocations = self.total_allocations.saturating_sub(processed) as f64;
308            let remaining_seconds = remaining_allocations / avg_speed;
309            Some(Duration::from_secs_f64(remaining_seconds))
310        } else {
311            None
312        }
313    }
314
315    /// Complete export
316    pub fn complete(&mut self) {
317        self.current_stage = ExportStage::Completed;
318        self.update_progress(1.0, Some("Export completed".to_string()));
319    }
320
321    /// Cancel export
322    pub fn cancel(&mut self) {
323        self.cancellation_token.cancel();
324        self.current_stage = ExportStage::Cancelled;
325        self.update_progress(0.0, Some("Export cancelled".to_string()));
326    }
327
328    /// Set error state
329    pub fn set_error(&mut self, error: String) {
330        self.current_stage = ExportStage::Error(error.clone());
331        self.update_progress(0.0, Some(error));
332    }
333
334    /// Check if should cancel
335    pub fn should_cancel(&self) -> bool {
336        self.cancellation_token.is_cancelled()
337    }
338
339    /// Get current progress snapshot
340    pub fn get_progress_snapshot(&self) -> ExportProgress {
341        let processed = self.processed_allocations.load(Ordering::SeqCst);
342        self.calculate_progress(0.0, processed)
343    }
344}
345
/// Switches controlling how export progress is monitored and reported.
#[derive(Debug, Clone)]
pub struct ProgressConfig {
    /// Turn progress monitoring on or off.
    pub enabled: bool,
    /// Minimum time between progress updates.
    pub update_interval: Duration,
    /// Include stage details in progress output.
    pub show_details: bool,
    /// Include the estimated remaining time in progress output.
    pub show_estimated_time: bool,
    /// Permit the export to be cancelled.
    pub allow_cancellation: bool,
}

impl Default for ProgressConfig {
    /// Everything enabled, with a 100 ms update interval.
    fn default() -> Self {
        let update_interval = Duration::from_millis(100);
        ProgressConfig {
            enabled: true,
            update_interval,
            show_details: true,
            show_estimated_time: true,
            allow_cancellation: true,
        }
    }
}
372
373/// Console progress display
374pub struct ConsoleProgressDisplay {
375    last_line_length: usize,
376}
377
378impl ConsoleProgressDisplay {
379    /// Create new console progress display
380    pub fn new() -> Self {
381        Self {
382            last_line_length: 0,
383        }
384    }
385
386    /// Display progress
387    pub fn display(&mut self, progress: &ExportProgress) {
388        // Clear previous line
389        if self.last_line_length > 0 {
390            print!("\r{}", " ".repeat(self.last_line_length));
391            print!("\r");
392        }
393
394        let progress_bar = self.create_progress_bar(progress.overall_progress);
395        let speed_info = if progress.processing_speed > 0.0 {
396            format!(" ({:.0} allocs/sec)", progress.processing_speed)
397        } else {
398            String::new()
399        };
400
401        let time_info = if let Some(remaining) = progress.estimated_remaining {
402            format!(" Remaining: {remaining:?}")
403        } else {
404            String::new()
405        };
406
407        let line = format!(
408            "{progress_bar} {:.1}% {} ({}/{}){speed_info}{time_info}",
409            progress.overall_progress * 100.0,
410            progress.current_stage.description(),
411            progress.processed_allocations,
412            progress.total_allocations,
413        );
414
415        print!("{line}");
416        std::io::Write::flush(&mut std::io::stdout()).ok();
417
418        self.last_line_length = line.len();
419    }
420
421    /// Create progress bar
422    fn create_progress_bar(&self, progress: f64) -> String {
423        let width = 20;
424        let filled = (progress * width as f64) as usize;
425        let empty = width - filled;
426
427        format!("[{}{}]", "█".repeat(filled), "░".repeat(empty))
428    }
429
430    /// Finish display (newline)
431    pub fn finish(&mut self) {
432        tracing::info!("");
433        self.last_line_length = 0;
434    }
435}
436
437impl Default for ConsoleProgressDisplay {
438    fn default() -> Self {
439        Self::new()
440    }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    #[test]
    fn test_cancellation_token() {
        let cancel_token = CancellationToken::new();
        assert!(!cancel_token.is_cancelled());

        cancel_token.cancel();
        assert!(cancel_token.is_cancelled());
        assert!(cancel_token.check_cancelled().is_err());
    }

    #[test]
    fn test_progress_monitor_basic() {
        let mut mon = ProgressMonitor::new(1000);

        // A freshly created monitor reports the initial state.
        let snapshot = mon.get_progress_snapshot();
        assert_eq!(snapshot.current_stage, ExportStage::Initializing);
        assert_eq!(snapshot.processed_allocations, 0);
        assert_eq!(snapshot.total_allocations, 1000);

        // Stage transitions show up in later snapshots.
        mon.set_stage(ExportStage::DataLocalization);
        assert_eq!(
            mon.get_progress_snapshot().current_stage,
            ExportStage::DataLocalization
        );

        // The processed counter accumulates.
        mon.add_processed(100);
        assert_eq!(mon.get_progress_snapshot().processed_allocations, 100);
    }

    #[test]
    fn test_progress_callback() {
        use crate::core::safe_operations::SafeLock;

        let was_called = Arc::new(Mutex::new(false));
        let flag = Arc::clone(&was_called);

        let mut mon = ProgressMonitor::new(100);
        // A tiny interval keeps the update below from being throttled.
        mon.update_interval = Duration::from_millis(1);

        mon.set_callback(Box::new(move |_| {
            let mut guard = flag
                .safe_lock()
                .expect("Failed to acquire lock on callback_called");
            *guard = true;
        }));

        // Let the throttle window elapse before reporting.
        std::thread::sleep(Duration::from_millis(10));
        mon.update_progress(0.5, None);

        let called = *was_called
            .safe_lock()
            .expect("Failed to acquire lock on callback_called");
        assert!(called);
    }

    #[test]
    fn test_progress_calculation() {
        let mut mon = ProgressMonitor::new(1000);
        mon.update_interval = Duration::from_millis(1);

        // `calculate_progress` is exercised directly, bypassing throttling.
        assert_eq!(
            mon.calculate_progress(1.0, 0).current_stage,
            ExportStage::Initializing
        );

        mon.set_stage(ExportStage::Initializing);
        let init = mon.calculate_progress(1.0, 0);
        assert!(
            (init.overall_progress - 0.05).abs() < 0.01,
            "Expected ~0.05, got {}",
            init.overall_progress
        );

        mon.set_stage(ExportStage::DataLocalization);
        let localizing = mon.calculate_progress(0.5, 0);
        let want = 0.05 + 0.15 * 0.5;
        assert!(
            (localizing.overall_progress - want).abs() < 0.01,
            "Expected ~{}, got {}",
            want,
            localizing.overall_progress
        );

        mon.set_stage(ExportStage::Completed);
        let done = mon.calculate_progress(1.0, 0);
        assert_eq!(done.overall_progress, 1.0);
        assert_eq!(done.current_stage, ExportStage::Completed);
    }

    #[test]
    fn test_speed_calculation() {
        let mon = ProgressMonitor::new(1000);

        // Guarantee a non-zero elapsed time.
        std::thread::sleep(Duration::from_millis(10));
        mon.add_processed(100);

        let snap = mon.get_progress_snapshot();
        assert!(
            snap.processing_speed >= 0.0,
            "Processing speed should be >= 0, got {}",
            snap.processing_speed
        );
        assert!(
            snap.elapsed_time.as_millis() > 0,
            "Elapsed time should be > 0"
        );

        // Mirror the monitor's formula: speed is only reported after a full second.
        let expected_speed = match snap.elapsed_time.as_secs() {
            0 => 0.0,
            _ => 100.0 / snap.elapsed_time.as_secs_f64(),
        };
        assert!(
            (snap.processing_speed - expected_speed).abs() < 1.0,
            "Speed calculation mismatch: expected ~{}, got {}",
            expected_speed,
            snap.processing_speed
        );
    }

    #[test]
    fn test_console_progress_display() {
        let sample = ExportProgress {
            current_stage: ExportStage::ParallelProcessing,
            stage_progress: 0.5,
            overall_progress: 0.6,
            processed_allocations: 600,
            total_allocations: 1000,
            elapsed_time: Duration::from_secs(10),
            estimated_remaining: Some(Duration::from_secs(7)),
            processing_speed: 60.0,
            stage_details: "Parallel shard processing".to_string(),
        };

        // Smoke test: rendering and finishing must not panic.
        let mut console = ConsoleProgressDisplay::new();
        console.display(&sample);
        console.finish();
    }

    #[test]
    fn test_export_stage_weights() {
        let expectations = [
            (ExportStage::Initializing, 0.05),
            (ExportStage::DataLocalization, 0.15),
            (ExportStage::ParallelProcessing, 0.70),
            (ExportStage::Writing, 0.10),
            (ExportStage::Completed, 1.0),
        ];
        for (stage, weight) in &expectations {
            assert_eq!(stage.weight(), *weight);
        }
    }
}
612}