fgumi_lib/progress.rs
1//! Progress tracking utilities
2//!
3//! This module provides a thread-safe progress tracker for logging progress at regular intervals.
4//! The tracker maintains an internal count and logs when interval boundaries are crossed.
5//! When a total is known, it also displays percentage complete and ETA using an exponential
6//! moving average (EMA) of the processing rate with bias correction (tqdm-style).
7
8use crate::logging::format_duration;
9use log::info;
10use std::sync::Mutex;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::time::Duration;
13use std::time::Instant;
14
/// Smoothing constant for the EMA rate estimator.
///
/// 0.3 balances responsiveness to rate changes with stability; it is the same
/// default smoothing factor tqdm uses.
const EMA_ALPHA: f64 = 0.3;

/// State for the exponential moving average (EMA) rate estimator.
///
/// The smoothed rate starts at zero; [`EmaState::corrected_rate`] compensates for
/// that zero-initialization bias.
struct EmaState {
    /// Smoothed rate (records per second), before bias correction.
    smoothed_rate: f64,
    /// Number of EMA updates applied so far (used for bias correction).
    calls: u32,
    /// Count observed at the last EMA update.
    last_count: u64,
    /// Time of the last EMA update (initially, the construction time).
    last_time: Instant,
}

impl EmaState {
    /// Create an estimator with no observations, anchored at the current time.
    fn new() -> Self {
        Self { smoothed_rate: 0.0, calls: 0, last_count: 0, last_time: Instant::now() }
    }

    /// Update the EMA with a new observation and return the bias-corrected rate.
    ///
    /// Observations that do not advance the count (`current_count <= last_count`), or
    /// that arrive with no measurable elapsed time, leave the state unchanged and
    /// simply return the current estimate.
    fn update(&mut self, current_count: u64) -> f64 {
        if current_count <= self.last_count {
            return self.corrected_rate();
        }

        let now = Instant::now();
        let dt = now.duration_since(self.last_time).as_secs_f64();
        if dt > 0.0 {
            // Counts are converted to f64 only to compute an approximate rate.
            #[allow(clippy::cast_precision_loss)]
            let dn = (current_count - self.last_count) as f64;
            let instantaneous_rate = dn / dt;
            self.smoothed_rate =
                EMA_ALPHA * instantaneous_rate + (1.0 - EMA_ALPHA) * self.smoothed_rate;
            // Saturate instead of `+= 1`: an overflow would panic in debug builds
            // (poisoning the caller's mutex) or wrap to 0 in release builds,
            // silently discarding the estimate.
            self.calls = self.calls.saturating_add(1);
            self.last_count = current_count;
            self.last_time = now;
        }
        self.corrected_rate()
    }

    /// Return the bias-corrected rate estimate.
    ///
    /// Uses the correction factor `1 / (1 - (1-α)^n)` to compensate for
    /// zero-initialization of the EMA, giving accurate estimates even with
    /// only a few updates. Returns `0.0` before the first update.
    fn corrected_rate(&self) -> f64 {
        if self.calls == 0 {
            return 0.0;
        }
        let beta = 1.0 - EMA_ALPHA;
        // `powi` takes an `i32`. Clamp rather than reinterpret the bits: a huge call
        // count must yield a huge positive exponent (β^n → 0, correction → 1), not a
        // negative one that would make the correction negative.
        let exponent = i32::try_from(self.calls).unwrap_or(i32::MAX);
        let correction = 1.0 - beta.powi(exponent);
        if correction <= 0.0 { 0.0 } else { self.smoothed_rate / correction }
    }
}
72
73/// Convert seconds (f64) to a formatted duration string via [`crate::logging::format_duration`].
74fn fmt_duration(secs: f64) -> String {
75 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
76 format_duration(Duration::from_secs(secs.round() as u64))
77}
78
79/// Thread-safe progress tracker for logging progress at regular intervals.
80///
81/// Maintains an internal count and logs progress messages when the count crosses
82/// interval boundaries. Safe to use from multiple threads.
83///
84/// When a total is set via [`with_total`](Self::with_total), progress messages include
85/// percentage complete and an ETA estimate using an exponential moving average of the
86/// processing rate with bias correction.
87///
88/// # Example
89/// ```
90/// use fgumi_lib::progress::ProgressTracker;
91///
92/// let tracker = ProgressTracker::new("Processed records")
93/// .with_interval(100);
94///
95/// // Add items and log at interval boundaries
96/// for _ in 0..250 {
97/// tracker.log_if_needed(1); // Logs at 100, 200
98/// }
99/// tracker.log_final(); // Logs "Processed records 250 (complete)"
100/// ```
101///
102/// # Multi-threaded Example
103/// ```
104/// use fgumi_lib::progress::ProgressTracker;
105/// use std::sync::Arc;
106///
107/// let tracker = Arc::new(ProgressTracker::new("Processed records").with_interval(1000));
108///
109/// // Multiple threads can safely add to the same tracker
110/// let tracker_clone = Arc::clone(&tracker);
111/// std::thread::spawn(move || {
112/// tracker_clone.log_if_needed(500);
113/// });
114/// ```
115pub struct ProgressTracker {
116 /// The logging interval - progress is logged when count crosses multiples of this.
117 interval: u64,
118 /// Message prefix for log output.
119 message: String,
120 /// Internal count of items processed (thread-safe).
121 count: AtomicU64,
122 /// Optional total count for percentage and ETA display.
123 total: Option<u64>,
124 /// Time the tracker was created (for elapsed time in final message).
125 start_time: Instant,
126 /// EMA rate estimator state (only accessed during logging, so contention is negligible).
127 ema: Mutex<EmaState>,
128}
129
130impl ProgressTracker {
131 /// Create a new progress tracker with the specified message.
132 ///
133 /// The tracker starts with a count of 0 and a default interval of 10,000.
134 ///
135 /// # Arguments
136 /// * `message` - Message prefix for progress logs (e.g., "Processed records")
137 #[must_use]
138 pub fn new(message: impl Into<String>) -> Self {
139 Self {
140 interval: 10_000,
141 message: message.into(),
142 count: AtomicU64::new(0),
143 total: None,
144 start_time: Instant::now(),
145 ema: Mutex::new(EmaState::new()),
146 }
147 }
148
149 /// Set the logging interval.
150 ///
151 /// Progress will be logged each time the count crosses a multiple of this interval.
152 /// For example, with interval=1000, logs will occur at 1000, 2000, 3000, etc.
153 ///
154 /// # Arguments
155 /// * `interval` - The interval between progress logs
156 #[must_use]
157 pub fn with_interval(mut self, interval: u64) -> Self {
158 self.interval = interval;
159 self
160 }
161
162 /// Set the total expected count.
163 ///
164 /// When set, progress messages include percentage complete and an ETA estimate
165 /// using an exponential moving average of the processing rate.
166 ///
167 /// # Arguments
168 /// * `total` - The total expected count of items
169 #[must_use]
170 pub fn with_total(mut self, total: u64) -> Self {
171 self.total = (total > 0).then_some(total);
172 self
173 }
174
175 /// Add to the count and log if an interval boundary was crossed.
176 ///
177 /// This method is thread-safe and can be called from multiple threads.
178 /// It atomically adds `additional` to the internal count and logs progress
179 /// for each interval boundary crossed.
180 ///
181 /// When a total is set, log messages include percentage and ETA.
182 ///
183 /// # Arguments
184 /// * `additional` - Number of items to add to the count
185 ///
186 /// # Returns
187 /// `true` if the final count is exactly a multiple of the interval,
188 /// `false` otherwise. This is useful for `log_final()` to know if a
189 /// final message is needed.
190 #[allow(clippy::cast_precision_loss)]
191 pub fn log_if_needed(&self, additional: u64) -> bool {
192 if additional == 0 {
193 // No change, just check if current count is on interval
194 let count = self.count.load(Ordering::Relaxed);
195 return count > 0 && count.is_multiple_of(self.interval);
196 }
197
198 let prev = self.count.fetch_add(additional, Ordering::Relaxed);
199 let new_count = prev + additional;
200
201 // Calculate how many interval boundaries we crossed
202 let prev_intervals = prev / self.interval;
203 let new_intervals = new_count / self.interval;
204
205 if new_intervals > prev_intervals {
206 // We crossed at least one interval — update EMA and log.
207 // Compute rate once from the final new_count.
208 let rate = if self.total.is_some() {
209 if let Ok(mut ema) = self.ema.lock() { ema.update(new_count) } else { 0.0 }
210 } else {
211 0.0
212 };
213
214 for i in (prev_intervals + 1)..=new_intervals {
215 let milestone = i * self.interval;
216 if let Some(total) = self.total {
217 let pct = (milestone as f64 / total as f64) * 100.0;
218 // Derive remaining work from milestone, not new_count, so each
219 // logged line shows the ETA appropriate for that milestone.
220 let eta_suffix = if rate > 0.0 {
221 let remaining = total.saturating_sub(milestone) as f64;
222 format!(", ETA ~{}", fmt_duration(remaining / rate))
223 } else {
224 String::new()
225 };
226 info!("{} {} / {} ({:.1}%{})", self.message, milestone, total, pct, eta_suffix);
227 } else {
228 info!("{} {}", self.message, milestone);
229 }
230 }
231 }
232
233 // Return true if we landed exactly on an interval
234 new_count.is_multiple_of(self.interval)
235 }
236
237 /// Log final progress.
238 ///
239 /// When a total is set, always logs a completion message with elapsed time.
240 /// Otherwise, logs only if the current count is not on an interval boundary.
241 pub fn log_final(&self) {
242 let count = self.count.load(Ordering::Relaxed);
243 if count == 0 && self.total.is_none() {
244 return;
245 }
246
247 if self.total.is_some() {
248 let elapsed = self.start_time.elapsed().as_secs_f64();
249 info!("{} {} (complete, {})", self.message, count, fmt_duration(elapsed));
250 } else if !self.log_if_needed(0) {
251 info!("{} {} (complete)", self.message, count);
252 }
253 }
254
255 /// Get the current count.
256 ///
257 /// # Returns
258 /// The current count of items processed.
259 #[must_use]
260 pub fn count(&self) -> u64 {
261 self.count.load(Ordering::Relaxed)
262 }
263}
264
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    #[test]
    fn test_progress_tracker_new() {
        let tracker = ProgressTracker::new("Processing");
        assert_eq!(tracker.interval, 10_000);
        assert_eq!(tracker.message, "Processing");
        assert_eq!(tracker.count(), 0);
        assert!(tracker.total.is_none());
    }

    #[test]
    fn test_progress_tracker_with_interval() {
        let tracker = ProgressTracker::new("Processing").with_interval(100);
        assert_eq!(tracker.interval, 100);
    }

    #[test]
    fn test_progress_tracker_with_total() {
        let tracker = ProgressTracker::new("Processing").with_total(1000);
        assert_eq!(tracker.total, Some(1000));
    }

    #[test]
    fn test_with_total_zero_means_unknown() {
        // A zero total disables percentage/ETA output rather than dividing by zero.
        let tracker = ProgressTracker::new("Processing").with_total(0);
        assert!(tracker.total.is_none());
    }

    #[test]
    fn test_log_if_needed_returns_correctly() {
        let tracker = ProgressTracker::new("Test").with_interval(10);

        // Not on interval
        assert!(!tracker.log_if_needed(5)); // count=5
        assert!(!tracker.log_if_needed(3)); // count=8

        // Crosses interval, lands on it
        assert!(tracker.log_if_needed(2)); // count=10, exactly on interval

        // Not on interval
        assert!(!tracker.log_if_needed(5)); // count=15

        // Crosses interval, doesn't land on it
        assert!(!tracker.log_if_needed(10)); // count=25, crossed 20
    }

    #[test]
    fn test_log_if_needed_with_total_returns_correctly() {
        // Same return contract when the percentage/ETA path is active.
        let tracker = ProgressTracker::new("Test").with_interval(10).with_total(100);

        assert!(!tracker.log_if_needed(25)); // count=25, crossed 10 and 20
        assert!(tracker.log_if_needed(5)); // count=30, exactly on interval
        tracker.log_final();
        assert_eq!(tracker.count(), 30);
    }

    #[test]
    fn test_log_if_needed_zero() {
        let tracker = ProgressTracker::new("Test").with_interval(10);

        // Zero count, zero additional
        assert!(!tracker.log_if_needed(0));

        // Add to exactly on interval
        tracker.log_if_needed(10);
        assert!(tracker.log_if_needed(0)); // count=10, exactly on interval

        // Add more, not on interval
        tracker.log_if_needed(5);
        assert!(!tracker.log_if_needed(0)); // count=15
    }

    #[test]
    fn test_count() {
        let tracker = ProgressTracker::new("Test").with_interval(100);

        assert_eq!(tracker.count(), 0);
        tracker.log_if_needed(50);
        assert_eq!(tracker.count(), 50);
        tracker.log_if_needed(75);
        assert_eq!(tracker.count(), 125);
    }

    #[test]
    fn test_crossing_multiple_intervals() {
        let tracker = ProgressTracker::new("Test").with_interval(10);

        // Cross multiple intervals at once (10, 20, 30)
        assert!(!tracker.log_if_needed(35)); // count=35, crossed 10, 20, 30 but not on interval
        assert_eq!(tracker.count(), 35);

        // Cross to exactly on interval
        assert!(tracker.log_if_needed(5)); // count=40
    }

    #[test]
    fn test_thread_safety() {
        use std::sync::Arc;
        use std::thread;

        let tracker = Arc::new(ProgressTracker::new("Test").with_interval(1000));
        let mut handles = vec![];

        // Spawn 10 threads, each adding 100 items
        for _ in 0..10 {
            let tracker_clone = Arc::clone(&tracker);
            handles.push(thread::spawn(move || {
                for _ in 0..100 {
                    tracker_clone.log_if_needed(1);
                }
            }));
        }

        for handle in handles {
            handle.join().expect("thread should join successfully");
        }

        // Total should be 1000
        assert_eq!(tracker.count(), 1000);
    }

    #[test]
    fn test_with_total_tracks_count() {
        let tracker = ProgressTracker::new("Test").with_interval(10).with_total(100);

        tracker.log_if_needed(25);
        assert_eq!(tracker.count(), 25);
        tracker.log_if_needed(75);
        assert_eq!(tracker.count(), 100);
    }

    #[rstest]
    #[case(0.0, "0s")]
    #[case(59.0, "59s")]
    #[case(59.5, "1m")]
    #[case(90.0, "1m 30s")]
    #[case(3600.0, "1h")]
    #[case(5400.0, "1h 30m")]
    fn test_fmt_duration(#[case] secs: f64, #[case] expected: &str) {
        assert_eq!(fmt_duration(secs), expected);
    }

    #[test]
    fn test_ema_bias_correction() {
        use std::time::{Duration, Instant};

        let mut ema = EmaState::new();

        // With zero calls, corrected rate should be 0
        assert!(ema.corrected_rate().abs() < f64::EPSILON);

        // Backdate the estimator by 10s instead of sleeping, so the first update
        // observes roughly 1000 records / 10 s = 100 records/s.
        ema.last_time = Instant::now()
            .checked_sub(Duration::from_secs(10))
            .expect("Instant should allow subtracting 10s");
        let rate = ema.update(1000);

        // After the first update smoothed_rate = α·r, and the bias correction
        // 1/(1-(1-α)^1) = 1/α exactly undoes that, so the corrected rate equals the
        // instantaneous rate r = 1000/dt. dt is slightly over 10s, so r is just under 100.
        assert!(rate > 90.0 && rate <= 100.0, "unexpected rate {rate}");
        assert!((rate - ema.smoothed_rate / EMA_ALPHA).abs() < 1e-9 * rate);
        assert_eq!(ema.calls, 1);
        assert_eq!(ema.last_count, 1000);
    }

    #[test]
    fn test_ema_ignores_non_increasing_count() {
        let mut ema = EmaState::new();
        ema.smoothed_rate = 3.0;
        ema.calls = 1;
        ema.last_count = 10;

        // Same or lower counts return the current estimate (3.0 / α = 10) and leave
        // the estimator state untouched.
        for count in [10, 5] {
            let rate = ema.update(count);
            assert!((rate - 10.0).abs() < 1e-9, "unexpected rate {rate}");
            assert_eq!(ema.calls, 1);
            assert_eq!(ema.last_count, 10);
        }
    }
}
411}