Skip to main content

fgumi_lib/
progress.rs

1//! Progress tracking utilities
2//!
3//! This module provides a thread-safe progress tracker for logging progress at regular intervals.
4//! The tracker maintains an internal count and logs when interval boundaries are crossed.
5//! When a total is known, it also displays percentage complete and ETA using an exponential
6//! moving average (EMA) of the processing rate with bias correction (tqdm-style).
7
8use crate::logging::format_duration;
9use log::info;
10use std::sync::Mutex;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::time::Duration;
13use std::time::Instant;
14
/// Smoothing constant for the EMA rate estimator.
///
/// 0.3 balances responsiveness to rate changes with stability.
/// Same default as tqdm.
const EMA_ALPHA: f64 = 0.3;

/// State for the exponential moving average (EMA) rate estimator.
struct EmaState {
    /// Smoothed rate (records per second), pre-bias-correction.
    smoothed_rate: f64,
    /// Number of EMA updates applied so far (for bias correction).
    /// Saturates at `u32::MAX` instead of overflowing.
    calls: u32,
    /// Count at last EMA update.
    last_count: u64,
    /// Time at last EMA update.
    last_time: Instant,
}

impl EmaState {
    /// Create an estimator with no observations; the reference time is "now".
    fn new() -> Self {
        Self { smoothed_rate: 0.0, calls: 0, last_count: 0, last_time: Instant::now() }
    }

    /// Update the EMA with a new observation and return the bias-corrected rate.
    ///
    /// Observations whose count does not exceed the last recorded count are
    /// ignored (the current estimate is returned unchanged), as are observations
    /// for which no measurable time has elapsed since the previous update.
    fn update(&mut self, current_count: u64) -> f64 {
        if current_count <= self.last_count {
            return self.corrected_rate();
        }

        let now = Instant::now();
        let dt = now.duration_since(self.last_time).as_secs_f64();
        if dt > 0.0 {
            #[allow(clippy::cast_precision_loss)]
            let dn = (current_count - self.last_count) as f64;
            let instantaneous_rate = dn / dt;
            self.smoothed_rate =
                EMA_ALPHA * instantaneous_rate + (1.0 - EMA_ALPHA) * self.smoothed_rate;
            // Saturate rather than overflow: once the counter is large the bias
            // correction factor is indistinguishable from 1, so the exact value
            // no longer matters, but an overflow would panic in debug builds.
            self.calls = self.calls.saturating_add(1);
            self.last_count = current_count;
            self.last_time = now;
        }
        self.corrected_rate()
    }

    /// Return the bias-corrected rate estimate.
    ///
    /// Uses the correction factor `1 / (1 - (1-α)^n)` to compensate for
    /// zero-initialization of the EMA, giving accurate estimates even with
    /// only a few updates. Returns 0.0 before the first update.
    fn corrected_rate(&self) -> f64 {
        if self.calls == 0 {
            return 0.0;
        }
        let beta = 1.0 - EMA_ALPHA;
        // `powi` takes an `i32`. Clamp before the sign cast so that a count above
        // `i32::MAX` cannot be reinterpreted as a negative exponent, which would
        // make the correction negative and wrongly zero the reported rate.
        let exponent = self.calls.min(i32::MAX.cast_unsigned()).cast_signed();
        let correction = 1.0 - beta.powi(exponent);
        if correction <= 0.0 { 0.0 } else { self.smoothed_rate / correction }
    }
}
72
73/// Convert seconds (f64) to a formatted duration string via [`crate::logging::format_duration`].
74fn fmt_duration(secs: f64) -> String {
75    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
76    format_duration(Duration::from_secs(secs.round() as u64))
77}
78
/// Thread-safe progress tracker for logging progress at regular intervals.
///
/// Maintains an internal count and logs progress messages when the count crosses
/// interval boundaries. Safe to use from multiple threads.
///
/// When a total is set via [`with_total`](Self::with_total), progress messages include
/// percentage complete and an ETA estimate using an exponential moving average of the
/// processing rate with bias correction.
///
/// # Example
/// ```
/// use fgumi_lib::progress::ProgressTracker;
///
/// let tracker = ProgressTracker::new("Processed records")
///     .with_interval(100);
///
/// // Add items and log at interval boundaries
/// for _ in 0..250 {
///     tracker.log_if_needed(1);  // Logs at 100, 200
/// }
/// tracker.log_final();  // Logs "Processed records 250 (complete)"
/// ```
///
/// # Multi-threaded Example
/// ```
/// use fgumi_lib::progress::ProgressTracker;
/// use std::sync::Arc;
///
/// let tracker = Arc::new(ProgressTracker::new("Processed records").with_interval(1000));
///
/// // Multiple threads can safely add to the same tracker
/// let tracker_clone = Arc::clone(&tracker);
/// std::thread::spawn(move || {
///     tracker_clone.log_if_needed(500);
/// });
/// ```
pub struct ProgressTracker {
    /// The logging interval: progress is logged each time the count crosses a
    /// multiple of this value. Defaults to 10,000.
    interval: u64,
    /// Message prefix placed at the start of every log line.
    message: String,
    /// Running count of items processed; updated atomically so the tracker can be
    /// shared between threads.
    count: AtomicU64,
    /// Optional total count for percentage and ETA display. `None` when no total
    /// has been set (a total of 0 is treated as "not set").
    total: Option<u64>,
    /// Time the tracker was created (for elapsed time in the final message).
    start_time: Instant,
    /// EMA rate estimator state. Locked only when an interval boundary is crossed
    /// and a total is set, so contention is negligible.
    ema: Mutex<EmaState>,
}
129
130impl ProgressTracker {
131    /// Create a new progress tracker with the specified message.
132    ///
133    /// The tracker starts with a count of 0 and a default interval of 10,000.
134    ///
135    /// # Arguments
136    /// * `message` - Message prefix for progress logs (e.g., "Processed records")
137    #[must_use]
138    pub fn new(message: impl Into<String>) -> Self {
139        Self {
140            interval: 10_000,
141            message: message.into(),
142            count: AtomicU64::new(0),
143            total: None,
144            start_time: Instant::now(),
145            ema: Mutex::new(EmaState::new()),
146        }
147    }
148
149    /// Set the logging interval.
150    ///
151    /// Progress will be logged each time the count crosses a multiple of this interval.
152    /// For example, with interval=1000, logs will occur at 1000, 2000, 3000, etc.
153    ///
154    /// # Arguments
155    /// * `interval` - The interval between progress logs
156    #[must_use]
157    pub fn with_interval(mut self, interval: u64) -> Self {
158        self.interval = interval;
159        self
160    }
161
162    /// Set the total expected count.
163    ///
164    /// When set, progress messages include percentage complete and an ETA estimate
165    /// using an exponential moving average of the processing rate.
166    ///
167    /// # Arguments
168    /// * `total` - The total expected count of items
169    #[must_use]
170    pub fn with_total(mut self, total: u64) -> Self {
171        self.total = (total > 0).then_some(total);
172        self
173    }
174
175    /// Add to the count and log if an interval boundary was crossed.
176    ///
177    /// This method is thread-safe and can be called from multiple threads.
178    /// It atomically adds `additional` to the internal count and logs progress
179    /// for each interval boundary crossed.
180    ///
181    /// When a total is set, log messages include percentage and ETA.
182    ///
183    /// # Arguments
184    /// * `additional` - Number of items to add to the count
185    ///
186    /// # Returns
187    /// `true` if the final count is exactly a multiple of the interval,
188    /// `false` otherwise. This is useful for `log_final()` to know if a
189    /// final message is needed.
190    #[allow(clippy::cast_precision_loss)]
191    pub fn log_if_needed(&self, additional: u64) -> bool {
192        if additional == 0 {
193            // No change, just check if current count is on interval
194            let count = self.count.load(Ordering::Relaxed);
195            return count > 0 && count.is_multiple_of(self.interval);
196        }
197
198        let prev = self.count.fetch_add(additional, Ordering::Relaxed);
199        let new_count = prev + additional;
200
201        // Calculate how many interval boundaries we crossed
202        let prev_intervals = prev / self.interval;
203        let new_intervals = new_count / self.interval;
204
205        if new_intervals > prev_intervals {
206            // We crossed at least one interval — update EMA and log.
207            // Compute rate once from the final new_count.
208            let rate = if self.total.is_some() {
209                if let Ok(mut ema) = self.ema.lock() { ema.update(new_count) } else { 0.0 }
210            } else {
211                0.0
212            };
213
214            for i in (prev_intervals + 1)..=new_intervals {
215                let milestone = i * self.interval;
216                if let Some(total) = self.total {
217                    let pct = (milestone as f64 / total as f64) * 100.0;
218                    // Derive remaining work from milestone, not new_count, so each
219                    // logged line shows the ETA appropriate for that milestone.
220                    let eta_suffix = if rate > 0.0 {
221                        let remaining = total.saturating_sub(milestone) as f64;
222                        format!(", ETA ~{}", fmt_duration(remaining / rate))
223                    } else {
224                        String::new()
225                    };
226                    info!("{} {} / {} ({:.1}%{})", self.message, milestone, total, pct, eta_suffix);
227                } else {
228                    info!("{} {}", self.message, milestone);
229                }
230            }
231        }
232
233        // Return true if we landed exactly on an interval
234        new_count.is_multiple_of(self.interval)
235    }
236
237    /// Log final progress.
238    ///
239    /// When a total is set, always logs a completion message with elapsed time.
240    /// Otherwise, logs only if the current count is not on an interval boundary.
241    pub fn log_final(&self) {
242        let count = self.count.load(Ordering::Relaxed);
243        if count == 0 && self.total.is_none() {
244            return;
245        }
246
247        if self.total.is_some() {
248            let elapsed = self.start_time.elapsed().as_secs_f64();
249            info!("{} {} (complete, {})", self.message, count, fmt_duration(elapsed));
250        } else if !self.log_if_needed(0) {
251            info!("{} {} (complete)", self.message, count);
252        }
253    }
254
255    /// Get the current count.
256    ///
257    /// # Returns
258    /// The current count of items processed.
259    #[must_use]
260    pub fn count(&self) -> u64 {
261        self.count.load(Ordering::Relaxed)
262    }
263}
264
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    #[test]
    fn test_progress_tracker_new() {
        let tracker = ProgressTracker::new("Processing");
        assert_eq!(tracker.interval, 10_000);
        assert_eq!(tracker.message, "Processing");
        assert_eq!(tracker.count(), 0);
        assert!(tracker.total.is_none());
    }

    #[test]
    fn test_progress_tracker_with_interval() {
        let tracker = ProgressTracker::new("Processing").with_interval(100);
        assert_eq!(tracker.interval, 100);
    }

    #[test]
    fn test_progress_tracker_with_total() {
        let tracker = ProgressTracker::new("Processing").with_total(1000);
        assert_eq!(tracker.total, Some(1000));
    }

    #[test]
    fn test_progress_tracker_with_zero_total_is_unset() {
        // A total of zero carries no information, so it must not enable
        // percentage/ETA output.
        let tracker = ProgressTracker::new("Processing").with_total(0);
        assert!(tracker.total.is_none());
    }

    #[test]
    fn test_log_if_needed_returns_correctly() {
        let tracker = ProgressTracker::new("Test").with_interval(10);

        // Not on interval
        assert!(!tracker.log_if_needed(5)); // count=5
        assert!(!tracker.log_if_needed(3)); // count=8

        // Crosses interval, lands on it
        assert!(tracker.log_if_needed(2)); // count=10, exactly on interval

        // Not on interval
        assert!(!tracker.log_if_needed(5)); // count=15

        // Crosses interval, doesn't land on it
        assert!(!tracker.log_if_needed(10)); // count=25, crossed 20
    }

    #[test]
    fn test_log_if_needed_zero() {
        let tracker = ProgressTracker::new("Test").with_interval(10);

        // Zero count, zero additional
        assert!(!tracker.log_if_needed(0));

        // Add to exactly on interval
        tracker.log_if_needed(10);
        assert!(tracker.log_if_needed(0)); // count=10, exactly on interval

        // Add more, not on interval
        tracker.log_if_needed(5);
        assert!(!tracker.log_if_needed(0)); // count=15
    }

    #[test]
    fn test_count() {
        let tracker = ProgressTracker::new("Test").with_interval(100);

        assert_eq!(tracker.count(), 0);
        tracker.log_if_needed(50);
        assert_eq!(tracker.count(), 50);
        tracker.log_if_needed(75);
        assert_eq!(tracker.count(), 125);
    }

    #[test]
    fn test_crossing_multiple_intervals() {
        let tracker = ProgressTracker::new("Test").with_interval(10);

        // Cross multiple intervals at once (10, 20, 30)
        assert!(!tracker.log_if_needed(35)); // count=35, crossed 10, 20, 30 but not on interval
        assert_eq!(tracker.count(), 35);

        // Cross to exactly on interval
        assert!(tracker.log_if_needed(5)); // count=40
    }

    #[test]
    fn test_thread_safety() {
        use std::sync::Arc;
        use std::thread;

        let tracker = Arc::new(ProgressTracker::new("Test").with_interval(1000));
        let mut handles = vec![];

        // Spawn 10 threads, each adding 100 items
        for _ in 0..10 {
            let tracker_clone = Arc::clone(&tracker);
            handles.push(thread::spawn(move || {
                for _ in 0..100 {
                    tracker_clone.log_if_needed(1);
                }
            }));
        }

        for handle in handles {
            handle.join().expect("thread should join successfully");
        }

        // Total should be 1000
        assert_eq!(tracker.count(), 1000);
    }

    #[test]
    fn test_with_total_tracks_count() {
        let tracker = ProgressTracker::new("Test").with_interval(10).with_total(100);

        tracker.log_if_needed(25);
        assert_eq!(tracker.count(), 25);
        tracker.log_if_needed(75);
        assert_eq!(tracker.count(), 100);
    }

    #[rstest]
    #[case(0.0, "0s")]
    #[case(59.0, "59s")]
    #[case(59.5, "1m")]
    #[case(90.0, "1m 30s")]
    #[case(3600.0, "1h")]
    #[case(5400.0, "1h 30m")]
    fn test_fmt_duration(#[case] secs: f64, #[case] expected: &str) {
        assert_eq!(fmt_duration(secs), expected);
    }

    #[test]
    fn test_ema_bias_correction() {
        let mut ema = EmaState::new();

        // With zero calls, corrected rate should be 0
        assert!(ema.corrected_rate().abs() < f64::EPSILON);

        // After first update, corrected rate equals instantaneous rate
        // (bias correction factor is 1/(1-0.7^1) = 1/0.3 = 3.33,
        //  and smoothed_rate = 0.3 * rate, so corrected = rate)
        std::thread::sleep(std::time::Duration::from_millis(10));
        let rate = ema.update(1000);
        assert!(rate > 0.0, "rate should be positive after first update");
        assert_eq!(ema.calls, 1);

        // Recover the instantaneous rate from the uncorrected EMA and check that
        // the bias correction undoes the zero-initialization exactly (up to
        // floating-point rounding).
        let instantaneous = ema.smoothed_rate / EMA_ALPHA;
        assert!(
            (rate - instantaneous).abs() <= instantaneous * 1e-9,
            "corrected rate {rate} should equal instantaneous rate {instantaneous}"
        );

        // A count that does not advance must leave the estimate untouched.
        let stale = ema.update(500);
        assert!((stale - rate).abs() <= rate * 1e-12);
        assert_eq!(ema.calls, 1);
        assert_eq!(ema.last_count, 1000);
    }
}
407        ema.last_count = 0;
408        let rate = ema.update(1000);
409        assert!(rate > 0.0, "rate should be positive after first update");
410    }
411}