Skip to main content

research_master/utils/
progress.rs

//! Progress tracking utilities for long-running operations.
//!
//! This module provides progress bar support for operations like
//! batch downloads, searches, and paper processing.
//!
//! # Usage
//!
//! ```rust
//! use research_master::utils::ProgressReporter;
//!
//! let reporter = ProgressReporter::new("Processing papers", 100);
//! reporter.inc();
//! assert_eq!(reporter.current(), 1);
//! ```
15
16use std::sync::atomic::{AtomicUsize, Ordering};
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19
/// Progress reporter with optional terminal output.
///
/// Supports both quiet mode (no output) and verbose mode, where a status
/// line is rewritten in place on stdout. The counter is an atomic behind an
/// [`Arc`], so clones of a reporter share (and update) the same count.
#[derive(Debug, Clone)]
pub struct ProgressReporter {
    /// Name of the operation being tracked
    name: String,

    /// Total units of work (0 if unknown)
    total: usize,

    /// Current progress (atomic for thread safety, shared between clones)
    current: Arc<AtomicUsize>,

    /// Start time for calculating elapsed time and ETA
    start_time: Instant,

    /// Whether to suppress progress output
    quiet: bool,
}

impl ProgressReporter {
    /// Number of units between two automatic status lines printed by
    /// [`inc_by`](Self::inc_by).
    const REPORT_INTERVAL: usize = 10;

    /// Create a new progress reporter.
    ///
    /// - `name`: Description of the operation
    /// - `total`: Total number of units of work (0 for indeterminate)
    ///
    /// Output is suppressed when the `RESEARCH_MASTER_QUIET` environment
    /// variable is set (to any value).
    pub fn new(name: &str, total: usize) -> Self {
        // `var_os` only asks "is it set?"; it neither allocates a `String`
        // nor treats a non-UTF-8 value as "unset".
        let quiet = std::env::var_os("RESEARCH_MASTER_QUIET").is_some();
        Self::with_quiet(name, total, quiet)
    }

    /// Create a quiet reporter that doesn't output anything.
    pub fn quiet(name: &str, total: usize) -> Self {
        Self::with_quiet(name, total, true)
    }

    /// Shared constructor: a reporter starting at zero, timed from now.
    fn with_quiet(name: &str, total: usize, quiet: bool) -> Self {
        Self {
            name: name.to_string(),
            total,
            current: Arc::new(AtomicUsize::new(0)),
            start_time: Instant::now(),
            quiet,
        }
    }

    /// Increment progress by one unit.
    pub fn inc(&self) {
        self.inc_by(1);
    }

    /// Increment progress by multiple units.
    ///
    /// Unless quiet, a status line is printed whenever the counter crosses a
    /// multiple of [`REPORT_INTERVAL`](Self::REPORT_INTERVAL) or first
    /// reaches the total.
    pub fn inc_by(&self, delta: usize) {
        let previous = self.current.fetch_add(delta, Ordering::SeqCst);
        // `fetch_add` wraps on overflow, so mirror that instead of panicking
        // in debug builds.
        let updated = previous.wrapping_add(delta);

        if !self.quiet && self.should_report(previous, updated) {
            self.print_progress(updated);
        }
    }

    /// Set the current progress to a specific value.
    ///
    /// Unless quiet, a status line is printed on every call.
    pub fn set(&self, value: usize) {
        self.current.store(value, Ordering::SeqCst);

        if !self.quiet {
            self.print_progress(value);
        }
    }

    /// Decide whether moving from `previous` to `current` deserves a status
    /// line.
    ///
    /// Comparing interval buckets (rather than testing `current` for exact
    /// divisibility) means increments larger than one cannot step over every
    /// report, and the completion line is still printed when the total is
    /// not a multiple of the interval.
    fn should_report(&self, previous: usize, current: usize) -> bool {
        let crossed_interval =
            current / Self::REPORT_INTERVAL != previous / Self::REPORT_INTERVAL;
        let just_completed =
            self.total > 0 && previous < self.total && current >= self.total;

        crossed_interval || just_completed
    }

    /// Print current progress, rewriting the current terminal line.
    ///
    /// Write errors (for example a closed stdout) are ignored on purpose:
    /// progress output is best-effort and must not abort the operation.
    fn print_progress(&self, current: usize) {
        use std::io::Write;

        let line = self.render_line(current, self.start_time.elapsed());
        let finished = self.total > 0 && current >= self.total;

        let stdout = std::io::stdout();
        let mut out = stdout.lock();
        let written = if finished {
            // Terminate the line on completion so later output starts fresh.
            writeln!(out, "{}", line)
        } else {
            write!(out, "{}", line)
        };
        // Without a newline the line would otherwise sit in the buffer.
        let _ = written.and_then(|()| out.flush());
    }

    /// Build the status line for `current` units done after `elapsed`.
    ///
    /// The line starts with `\r` so that it overwrites the previous one.
    fn render_line(&self, current: usize, elapsed: Duration) -> String {
        let elapsed_text = Self::format_duration(elapsed);

        if self.total == 0 {
            // Indeterminate progress: animated dots instead of a percentage.
            return format!(
                "\r{}: {} ({} elapsed)",
                self.name,
                Self::loading_dots(current),
                elapsed_text
            );
        }

        // Determinate progress; clamp so overshooting never shows > 100%.
        let percent = (current as f64 / self.total as f64 * 100.0).min(100.0);
        let eta_text = self
            .estimate_eta(current, elapsed)
            .map_or_else(|| String::from("unknown"), Self::format_duration);

        format!(
            "\r{}: [{:>3.0}%] {}/{} ({} elapsed, ETA: {})",
            self.name, percent, current, self.total, elapsed_text, eta_text
        )
    }

    /// Estimate time remaining, assuming the average rate so far continues.
    ///
    /// Returns `None` when nothing has been completed yet, since no rate can
    /// be derived.
    fn estimate_eta(&self, current: usize, elapsed: Duration) -> Option<Duration> {
        if current == 0 {
            return None;
        }

        let remaining = self.total.saturating_sub(current);
        let remaining_secs = elapsed.as_secs_f64() * remaining as f64 / current as f64;

        // Float-to-integer `as` saturates, so absurdly large estimates
        // cannot overflow; sub-second precision is dropped deliberately.
        Some(Duration::from_secs(remaining_secs as u64))
    }

    /// Format duration for display as `"{m}m {s}s"` or `"{s}s"`.
    fn format_duration(duration: Duration) -> String {
        let secs = duration.as_secs();

        if secs >= 60 {
            format!("{}m {}s", secs / 60, secs % 60)
        } else {
            format!("{}s", secs)
        }
    }

    /// Generate loading dots for indeterminate progress.
    ///
    /// Cycles through zero to four dots, padded with spaces to a fixed width
    /// of four so the rest of the line does not shift.
    fn loading_dots(count: usize) -> String {
        let dots = count % 5;
        format!("{}{}", ".".repeat(dots), " ".repeat(4 - dots))
    }

    /// Build the final summary line printed by [`finish`](Self::finish).
    fn summary_line(&self, current: usize, elapsed: Duration) -> String {
        if self.total > 0 {
            // Clamp the divisor so an instant finish does not divide by zero.
            let rate = current as f64 / elapsed.as_secs_f64().max(0.001);
            format!(
                "{}: completed {}/{} in {:?} ({:.1} items/sec)",
                self.name, current, self.total, elapsed, rate
            )
        } else {
            format!(
                "{}: completed {} items in {:?}",
                self.name, current, elapsed
            )
        }
    }

    /// Finish the progress and print final stats (nothing in quiet mode).
    pub fn finish(&self) {
        use std::io::Write;

        if self.quiet {
            return;
        }

        let summary = self.summary_line(self.current(), self.start_time.elapsed());
        let stdout = std::io::stdout();
        // Best-effort, like the status line: ignore a closed stdout.
        let _ = writeln!(stdout.lock(), "{}", summary);
    }

    /// Get the current progress count.
    pub fn current(&self) -> usize {
        self.current.load(Ordering::SeqCst)
    }

    /// Check if the operation is complete.
    ///
    /// Always `false` when the total is unknown (0).
    pub fn is_done(&self) -> bool {
        self.total > 0 && self.current() >= self.total
    }
}
193
194/// Thread-safe progress tracker that can be shared across threads
195#[derive(Clone)]
196pub struct SharedProgress {
197    /// Inner reporter
198    reporter: ProgressReporter,
199
200    /// Callback for progress updates (called from any thread)
201    #[allow(dead_code)]
202    callback: Option<Arc<dyn Fn(usize, usize) + Send + Sync>>,
203}
204
205impl SharedProgress {
206    /// Create a new shared progress tracker
207    pub fn new(name: &str, total: usize) -> Self {
208        Self {
209            reporter: ProgressReporter::new(name, total),
210            callback: None,
211        }
212
213        // let callback = Arc::new(callback);
214    }
215
216    /// Set a callback for progress updates
217    #[allow(dead_code)]
218    pub fn set_callback<F>(&mut self, callback: F)
219    where
220        F: Fn(usize, usize) + Send + Sync + 'static,
221    {
222        self.callback = Some(Arc::new(callback));
223    }
224
225    /// Increment progress
226    pub fn inc(&self) {
227        self.reporter.inc();
228    }
229
230    /// Increment by a delta
231    pub fn inc_by(&self, delta: usize) {
232        self.reporter.inc_by(delta);
233    }
234
235    /// Set progress to a specific value
236    pub fn set(&self, value: usize) {
237        self.reporter.set(value);
238    }
239
240    /// Finish the progress
241    pub fn finish(&self) {
242        self.reporter.finish();
243    }
244}
245
246impl std::fmt::Debug for SharedProgress {
247    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
248        f.debug_struct("SharedProgress")
249            .field("reporter", &self.reporter)
250            .field(
251                "callback",
252                &self.callback.as_ref().map(|_| "Fn(usize, usize)"),
253            )
254            .finish()
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// The quiet constructor stores the total and enables quiet mode.
    #[test]
    fn test_progress_reporter_creation() {
        let ProgressReporter { total, quiet, .. } = ProgressReporter::quiet("test", 100);

        assert_eq!(total, 100);
        assert!(quiet);
    }

    /// `inc` adds one unit and `inc_by` adds the given delta.
    #[test]
    fn test_progress_reporter_increment() {
        let tracker = ProgressReporter::quiet("test", 100);

        tracker.inc();
        assert_eq!(tracker.current(), 1);

        tracker.inc_by(5);
        assert_eq!(tracker.current(), 6);
    }

    /// `set` overwrites the counter with an absolute value.
    #[test]
    fn test_progress_reporter_set() {
        let tracker = ProgressReporter::quiet("test", 100);

        tracker.set(50);

        assert_eq!(tracker.current(), 50);
    }

    /// A reporter is done only once the counter reaches the total.
    #[test]
    fn test_progress_reporter_is_done() {
        let tracker = ProgressReporter::quiet("test", 10);
        assert!(!tracker.is_done());

        for (value, expect_done) in [(5, false), (10, true)] {
            tracker.set(value);
            assert_eq!(tracker.is_done(), expect_done);
        }
    }

    /// With an unknown total (0) the reporter never reports completion.
    #[test]
    fn test_progress_reporter_zero_total() {
        let tracker = ProgressReporter::quiet("test", 0);
        assert!(!tracker.is_done());

        tracker.inc();
        assert!(!tracker.is_done()); // Never done when total is 0
    }

    /// `SharedProgress` forwards every update to its inner reporter.
    #[test]
    fn test_shared_progress() {
        let shared = SharedProgress::new("test", 100);
        let count = || shared.reporter.current();

        shared.inc();
        assert_eq!(count(), 1);

        shared.inc_by(10);
        assert_eq!(count(), 11);

        shared.set(50);
        assert_eq!(count(), 50);
    }
}