use crate::logging::format_duration;
use log::info;
use std::sync::Mutex;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use std::time::Instant;
/// Weight given to the newest rate sample in the exponential moving average; the remaining
/// `1 - EMA_ALPHA` is carried over from the running estimate.
const EMA_ALPHA: f64 = 0.3;

/// Exponential-moving-average estimator of the processing rate, in items per second.
struct EmaState {
    /// Cumulative count recorded at the most recent accepted sample.
    last_count: u64,
    /// Time of the most recent accepted sample (the creation time before any sample).
    last_time: Instant,
    /// Raw EMA of instantaneous rates; biased toward its 0.0 starting value.
    smoothed_rate: f64,
    /// Number of samples folded into `smoothed_rate`; drives the bias correction.
    calls: u32,
}
impl EmaState {
    /// Creates an estimator with no samples whose time reference is the current instant.
    fn new() -> Self {
        Self { smoothed_rate: 0.0, calls: 0, last_count: 0, last_time: Instant::now() }
    }

    /// Feeds the latest cumulative count and returns the bias-corrected rate (items/second).
    ///
    /// A count that does not advance past the last recorded one (e.g. a stale value arriving
    /// after a newer one) is ignored, as is an update where no measurable time has elapsed;
    /// in both cases the current estimate is returned unchanged.
    fn update(&mut self, current_count: u64) -> f64 {
        if current_count <= self.last_count {
            return self.corrected_rate();
        }
        let now = Instant::now();
        let dt = now.duration_since(self.last_time).as_secs_f64();
        if dt > 0.0 {
            // u64 -> f64 is exact up to 2^53 items; beyond that the rate is approximate anyway.
            #[allow(clippy::cast_precision_loss)]
            let dn = (current_count - self.last_count) as f64;
            let instantaneous_rate = dn / dt;
            self.smoothed_rate =
                EMA_ALPHA * instantaneous_rate + (1.0 - EMA_ALPHA) * self.smoothed_rate;
            // Saturate instead of overflowing: after roughly a hundred samples the bias
            // correction is exactly 1.0 in f64, so the precise count no longer matters.
            self.calls = self.calls.saturating_add(1);
            self.last_count = current_count;
            self.last_time = now;
        }
        self.corrected_rate()
    }

    /// Returns `smoothed_rate / (1 - (1 - EMA_ALPHA)^calls)`, removing the bias toward the
    /// zero initial value; returns 0.0 before any sample has been recorded.
    fn corrected_rate(&self) -> f64 {
        if self.calls == 0 {
            return 0.0;
        }
        let beta = 1.0 - EMA_ALPHA;
        // Clamp rather than reinterpret: a count above i32::MAX must not become a negative
        // exponent, which would make the correction negative and zero out the estimate.
        let exponent = i32::try_from(self.calls).unwrap_or(i32::MAX);
        let correction = 1.0 - beta.powi(exponent);
        if correction <= 0.0 { 0.0 } else { self.smoothed_rate / correction }
    }
}
/// Formats `secs` with `format_duration` after rounding it to the nearest whole second.
fn fmt_duration(secs: f64) -> String {
    // Float-to-integer `as` saturates: NaN and negative inputs become 0 and values beyond
    // `u64::MAX` clamp to it, so silencing the truncation/sign-loss lints is safe.
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    let whole_secs = secs.round() as u64;
    format_duration(Duration::from_secs(whole_secs))
}
/// Progress counter that logs a line each time the processed count crosses a multiple of a
/// fixed interval, adding percentage and ETA when the expected total is known.
///
/// All mutable state sits behind an atomic or a mutex, so it can be updated through `&self`.
pub struct ProgressTracker {
    /// Label printed at the start of every progress line.
    message: String,
    /// Number of items between logged milestones.
    interval: u64,
    /// Expected total number of items; `None` when unknown.
    total: Option<u64>,
    /// Items processed so far.
    count: AtomicU64,
    /// Creation time, used to report elapsed time on completion.
    start_time: Instant,
    /// Rate estimator used to compute ETAs.
    ema: Mutex<EmaState>,
}
impl ProgressTracker {
    /// Creates a tracker labelled `message` that logs every 10,000 items and has no known
    /// total.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            interval: 10_000,
            message: message.into(),
            count: AtomicU64::new(0),
            total: None,
            start_time: Instant::now(),
            ema: Mutex::new(EmaState::new()),
        }
    }

    /// Sets how many items must be processed between logged milestones.
    ///
    /// An `interval` of 0 is treated as 1: milestones are found by dividing the count by the
    /// interval, so a zero interval would otherwise panic on the first non-zero update.
    #[must_use]
    pub fn with_interval(mut self, interval: u64) -> Self {
        self.interval = interval.max(1);
        self
    }

    /// Sets the expected total, enabling percentage and ETA output. A `total` of 0 is
    /// treated as unknown.
    #[must_use]
    pub fn with_total(mut self, total: u64) -> Self {
        self.total = (total > 0).then_some(total);
        self
    }

    /// Adds `additional` items to the count and logs one line for every interval milestone
    /// this call crosses.
    ///
    /// Passing 0 only queries the tracker and never logs.
    ///
    /// Returns `true` when the resulting count is a positive multiple of the interval, i.e.
    /// the count sits exactly on a milestone.
    // u64 -> f64 is only inexact above 2^53, which would merely perturb the displayed
    // percentage/ETA.
    #[allow(clippy::cast_precision_loss)]
    pub fn log_if_needed(&self, additional: u64) -> bool {
        if additional == 0 {
            let count = self.count.load(Ordering::Relaxed);
            return count > 0 && count.is_multiple_of(self.interval);
        }
        // `fetch_add` gives each caller a disjoint range (prev, new_count], so every
        // milestone is logged by exactly one caller even when called concurrently.
        let prev = self.count.fetch_add(additional, Ordering::Relaxed);
        let new_count = prev + additional;
        let prev_intervals = prev / self.interval;
        let new_intervals = new_count / self.interval;
        if new_intervals > prev_intervals {
            let milestones = ((prev_intervals + 1)..=new_intervals).map(|i| i * self.interval);
            match self.total {
                Some(total) => {
                    // One rate sample per call, shared by every milestone crossed in this
                    // batch; a poisoned lock just disables the ETA.
                    let rate = self.ema.lock().map_or(0.0, |mut ema| ema.update(new_count));
                    for milestone in milestones {
                        let pct = (milestone as f64 / total as f64) * 100.0;
                        let eta_suffix = if rate > 0.0 {
                            let remaining = total.saturating_sub(milestone) as f64;
                            format!(", ETA ~{}", fmt_duration(remaining / rate))
                        } else {
                            String::new()
                        };
                        info!(
                            "{} {} / {} ({:.1}%{})",
                            self.message, milestone, total, pct, eta_suffix
                        );
                    }
                }
                None => {
                    for milestone in milestones {
                        info!("{} {}", self.message, milestone);
                    }
                }
            }
        }
        new_count.is_multiple_of(self.interval)
    }

    /// Logs a closing summary line.
    ///
    /// Does nothing when no items were counted and no total was set. With a total, the line
    /// reports the final count and the time elapsed since the tracker was created. Without a
    /// total, the line is skipped when the final count is itself a milestone, because
    /// [`log_if_needed`](Self::log_if_needed) already logged that exact number.
    pub fn log_final(&self) {
        let count = self.count.load(Ordering::Relaxed);
        if count == 0 && self.total.is_none() {
            return;
        }
        if self.total.is_some() {
            let elapsed = self.start_time.elapsed().as_secs_f64();
            info!("{} {} (complete, {})", self.message, count, fmt_duration(elapsed));
        } else if !self.log_if_needed(0) {
            info!("{} {} (complete)", self.message, count);
        }
    }

    /// Returns the number of items counted so far.
    #[must_use]
    pub fn count(&self) -> u64 {
        self.count.load(Ordering::Relaxed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    #[test]
    fn test_progress_tracker_new() {
        let t = ProgressTracker::new("Processing");
        assert_eq!(t.interval, 10_000);
        assert_eq!(t.message, "Processing");
        assert_eq!(t.count(), 0);
        assert!(t.total.is_none());
    }

    #[test]
    fn test_progress_tracker_with_interval() {
        let t = ProgressTracker::new("Processing").with_interval(100);
        assert_eq!(t.interval, 100);
    }

    #[test]
    fn test_progress_tracker_with_total() {
        let t = ProgressTracker::new("Processing").with_total(1000);
        assert_eq!(t.total, Some(1000));
    }

    #[test]
    fn test_log_if_needed_returns_correctly() {
        let t = ProgressTracker::new("Test").with_interval(10);
        // Running totals: 5, 8, 10 (exact boundary), 15, 25 (boundary crossed, not hit).
        for (step, on_boundary) in [(5, false), (3, false), (2, true), (5, false), (10, false)] {
            assert_eq!(t.log_if_needed(step), on_boundary, "after adding {step}");
        }
    }

    #[test]
    fn test_log_if_needed_zero() {
        let t = ProgressTracker::new("Test").with_interval(10);
        assert!(!t.log_if_needed(0), "nothing counted yet");
        t.log_if_needed(10);
        assert!(t.log_if_needed(0), "10 is a boundary");
        t.log_if_needed(5);
        assert!(!t.log_if_needed(0), "15 is not a boundary");
    }

    #[test]
    fn test_count() {
        let t = ProgressTracker::new("Test").with_interval(100);
        assert_eq!(t.count(), 0);
        for (step, running_total) in [(50, 50), (75, 125)] {
            t.log_if_needed(step);
            assert_eq!(t.count(), running_total);
        }
    }

    #[test]
    fn test_crossing_multiple_intervals() {
        let t = ProgressTracker::new("Test").with_interval(10);
        // 0 -> 35 crosses 10, 20 and 30 without landing on any of them.
        assert!(!t.log_if_needed(35));
        assert_eq!(t.count(), 35);
        // 35 -> 40 lands exactly on a boundary.
        assert!(t.log_if_needed(5));
    }

    #[test]
    fn test_thread_safety() {
        let tracker = ProgressTracker::new("Test").with_interval(1000);
        // The scope joins every worker and re-raises any worker panic.
        std::thread::scope(|s| {
            for _ in 0..10 {
                s.spawn(|| {
                    for _ in 0..100 {
                        tracker.log_if_needed(1);
                    }
                });
            }
        });
        assert_eq!(tracker.count(), 1000);
    }

    #[test]
    fn test_with_total_tracks_count() {
        let t = ProgressTracker::new("Test").with_interval(10).with_total(100);
        t.log_if_needed(25);
        assert_eq!(t.count(), 25);
        t.log_if_needed(75);
        assert_eq!(t.count(), 100);
    }

    #[rstest]
    #[case(0.0, "0s")]
    #[case(59.0, "59s")]
    #[case(59.5, "1m")]
    #[case(90.0, "1m 30s")]
    #[case(3600.0, "1h")]
    #[case(5400.0, "1h 30m")]
    fn test_fmt_duration(#[case] secs: f64, #[case] expected: &str) {
        assert_eq!(fmt_duration(secs), expected);
    }

    #[test]
    fn test_ema_bias_correction() {
        let mut ema = EmaState::new();
        // No samples yet, so the estimate is exactly zero.
        assert!(ema.corrected_rate().abs() < f64::EPSILON);
        // Let measurable time pass so the first sample has a positive dt.
        std::thread::sleep(std::time::Duration::from_millis(10));
        ema.last_count = 0;
        let rate = ema.update(1000);
        assert!(rate > 0.0, "rate should be positive after first update");
    }
}