use crate::config::ProgressConfig;
use crate::deps::*;
#[allow(unused_imports)]
use crate::error::{ProgressError, ProgressResult};
use crate::event::ProgressEvent;
/// Mutable tracker state that is kept behind the tracker's lock.
struct TrackerState {
/// Latest status text; replaced by `set_message` and cloned into every snapshot.
message: String,
/// When the most recent notification was sent; `None` until the first one.
/// `maybe_notify` compares this against the configured debounce interval.
last_notify: Option<Instant>,
/// Throughput samples (`current / elapsed seconds`) recorded by `notify_now`,
/// oldest first; `notify_now` trims the list to the 10 most recent entries.
rate_samples: Vec<f64>,
}
/// Tracks progress toward a fixed `total` and broadcasts `ProgressEvent`s to subscribers.
///
/// All methods take `&self`: the counters are atomics and the remaining mutable
/// state sits behind `state`'s lock.
pub struct ProgressTracker {
/// Number of units that counts as complete; the constructor asserts it is non-zero.
total: u64,
/// Units completed so far.
current: AtomicU64,
/// Set once by `finish`; `advance` and `set` return early while it is `true`.
finished: AtomicBool,
/// Captured at construction; `elapsed` is measured from this instant.
started_at: Instant,
/// Tuning values read by this type: channel capacity, debounce interval,
/// auto-finish threshold and the ETA smoothing parameters.
config: ProgressConfig,
/// Message, last-notification time and throughput samples.
state: RwLock<TrackerState>,
/// Sending half of the broadcast channel; `subscribe` hands out receivers.
sender: broadcast::Sender<ProgressEvent>,
}
impl ProgressTracker {
    /// Maximum number of throughput samples retained for ETA smoothing.
    const RATE_WINDOW: usize = 10;

    /// Creates a tracker for `total` units of work using `ProgressConfig::default()`.
    ///
    /// # Panics
    ///
    /// Panics if `total` is 0.
    pub fn new(total: u64) -> Self {
        Self::with_config(total, ProgressConfig::default())
    }

    /// Creates a tracker for `total` units of work with an explicit configuration.
    ///
    /// The elapsed-time clock starts when this function is called.
    ///
    /// # Panics
    ///
    /// Panics if `total` is 0.
    ///
    /// NOTE(review): `broadcast::channel` may reject some values of
    /// `config.channel_capacity` (tokio's implementation panics on 0) — confirm
    /// against the channel re-exported by `crate::deps`.
    pub fn with_config(total: u64, config: ProgressConfig) -> Self {
        assert!(total > 0, "total must be greater than 0");
        // The initial receiver is dropped immediately; consumers obtain their own
        // through `subscribe`.
        let (sender, _) = broadcast::channel(config.channel_capacity);
        Self {
            total,
            current: AtomicU64::new(0),
            finished: AtomicBool::new(false),
            started_at: Instant::now(),
            config,
            state: RwLock::new(TrackerState {
                message: String::new(),
                last_notify: None,
                rate_samples: Vec::with_capacity(Self::RATE_WINDOW),
            }),
            sender,
        }
    }

    /// Returns the number of units that counts as complete.
    pub fn total(&self) -> u64 {
        self.total
    }

    /// Returns the number of units completed so far.
    pub fn current(&self) -> u64 {
        self.current.load(Ordering::Relaxed)
    }

    /// Returns progress as a percentage of `total`.
    pub fn percentage(&self) -> f64 {
        (self.current() as f64 / self.total as f64) * 100.0
    }

    /// Returns `true` once `finish` has run (explicitly or via auto-finish).
    pub fn is_finished(&self) -> bool {
        self.finished.load(Ordering::Relaxed)
    }

    /// Returns the time elapsed since the tracker was constructed.
    pub fn elapsed(&self) -> Duration {
        self.started_at.elapsed()
    }

    /// Adds `delta` units of progress.
    ///
    /// Does nothing if the tracker is already finished. The counter saturates at
    /// `total`, so it can neither overflow nor overshoot. Afterwards the tracker
    /// either finishes (progress ratio reached `config.auto_finish_threshold`) or
    /// emits a debounced notification.
    pub fn advance(&self, delta: u64) {
        if self.is_finished() {
            return;
        }
        let total = self.total;
        let clamp = |value: u64| value.saturating_add(delta).min(total);
        // A compare-and-swap loop instead of `fetch_add`: the stored value must
        // never exceed `total`, even if `finish` runs concurrently and has already
        // set the counter to `total`.
        let previous = match self.current.fetch_update(
            Ordering::Relaxed,
            Ordering::Relaxed,
            |value| Some(clamp(value)),
        ) {
            // The closure always returns `Some`, so `Err` cannot occur; both arms
            // carry the previous value.
            Ok(previous) | Err(previous) => previous,
        };
        self.after_update(clamp(previous));
    }

    /// Sets the progress counter to `value`, clamped to `total`.
    ///
    /// Does nothing if the tracker is already finished. Afterwards the tracker
    /// either finishes (progress ratio reached `config.auto_finish_threshold`) or
    /// emits a debounced notification.
    pub fn set(&self, value: u64) {
        if self.is_finished() {
            return;
        }
        let value = value.min(self.total);
        self.current.store(value, Ordering::Relaxed);
        self.after_update(value);
    }

    /// Replaces the status message and emits a debounced notification.
    pub fn set_message(&self, message: impl Into<String>) {
        // The write guard is a temporary of this statement, so it is released
        // before `maybe_notify` takes the lock again.
        self.state.write().message = message.into();
        self.maybe_notify();
    }

    /// Marks the tracker as finished, sets the counter to `total`, and sends a
    /// final notification that bypasses the debounce interval.
    ///
    /// Only the first call has any effect; later calls return immediately.
    pub fn finish(&self) {
        if self.finished.swap(true, Ordering::Relaxed) {
            return;
        }
        self.current.store(self.total, Ordering::Relaxed);
        self.notify_now();
    }

    /// Returns a new receiver for the events this tracker sends.
    pub fn subscribe(&self) -> broadcast::Receiver<ProgressEvent> {
        self.sender.subscribe()
    }

    /// Builds a `ProgressEvent` describing the tracker's current state.
    pub fn snapshot(&self) -> ProgressEvent {
        let current = self.current();
        let elapsed = self.elapsed();
        // The read guard lives only for this statement. `calculate_eta` takes its
        // own read lock, and acquiring a read lock while already holding one on
        // the same thread can deadlock if a writer queues up in between.
        let message = self.state.read().message.clone();
        ProgressEvent {
            current,
            total: self.total,
            message,
            elapsed,
            eta: self.calculate_eta(current, elapsed),
            finished: self.is_finished(),
        }
    }

    /// Finishes the tracker when `value / total` has reached the configured
    /// auto-finish threshold; otherwise emits a debounced notification.
    fn after_update(&self, value: u64) {
        let ratio = value as f64 / self.total as f64;
        if ratio >= self.config.auto_finish_threshold {
            self.finish();
        } else {
            self.maybe_notify();
        }
    }

    /// Sends a notification unless one was sent less than
    /// `config.debounce_interval` ago.
    fn maybe_notify(&self) {
        let now = Instant::now();
        // Saturating difference: another thread may have stamped `last_notify`
        // after `now` was sampled, in which case the gap is treated as zero.
        let due = self.state.read().last_notify.map_or(true, |last| {
            now.saturating_duration_since(last) >= self.config.debounce_interval
        });
        if due {
            self.notify_now();
        }
    }

    /// Sends a snapshot to all subscribers immediately and records a throughput
    /// sample for later ETA smoothing.
    fn notify_now(&self) {
        let event = self.snapshot();
        let elapsed_secs = event.elapsed.as_secs_f64();
        {
            let mut state = self.state.write();
            state.last_notify = Some(Instant::now());
            // Skip the sample when no time has elapsed; the rate would be undefined.
            if elapsed_secs > 0.0 {
                let samples = &mut state.rate_samples;
                // Evict the oldest sample first so the window never grows past
                // `RATE_WINDOW`.
                if samples.len() >= Self::RATE_WINDOW {
                    samples.remove(0);
                }
                samples.push(event.current as f64 / elapsed_secs);
            }
        }
        // `send` reports an error when nobody is subscribed; that is not a failure
        // for the tracker, so the result is deliberately ignored.
        let _ = self.sender.send(event);
    }

    /// Estimates the time remaining, or `None` when no estimate is available.
    ///
    /// With fewer than `config.eta_min_samples` recorded samples the overall
    /// average rate (`current / elapsed`) is used. Otherwise the samples are
    /// combined oldest-to-newest with an exponential moving average whose weight
    /// for each newer sample is `config.eta_smoothing_factor`.
    ///
    /// Returns `None` when `current` is 0, when the rate is not positive, or when
    /// the estimate cannot be represented as a `Duration`.
    fn calculate_eta(&self, current: u64, elapsed: Duration) -> Option<Duration> {
        if current == 0 {
            return None;
        }
        // Saturating: a caller-supplied `current` above `total` means nothing is left.
        let remaining = self.total.saturating_sub(current) as f64;
        let rate = {
            let state = self.state.read();
            let samples = &state.rate_samples;
            if samples.len() < self.config.eta_min_samples {
                current as f64 / elapsed.as_secs_f64()
            } else {
                let alpha = self.config.eta_smoothing_factor;
                // Seed with the oldest sample and walk towards the newest so the
                // weights sum to 1 and recent samples dominate.
                samples
                    .iter()
                    .copied()
                    .reduce(|avg, sample| avg * (1.0 - alpha) + sample * alpha)
                    .unwrap_or(0.0)
            }
        };
        if rate > 0.0 {
            // The fallible conversion turns non-finite or out-of-range quotients
            // into `None` instead of panicking.
            Duration::try_from_secs_f64(remaining / rate).ok()
        } else {
            None
        }
    }
}
/// Debug output lists the total, the live counter, the finished flag and the
/// elapsed time; the configuration, internal state and channel are omitted.
impl std::fmt::Debug for ProgressTracker {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let current = self.current();
        let finished = self.is_finished();
        let elapsed = self.elapsed();

        let mut out = formatter.debug_struct("ProgressTracker");
        out.field("total", &self.total);
        out.field("current", &current);
        out.field("finished", &finished);
        out.field("elapsed", &elapsed);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built tracker reports its total, zero progress, and is unfinished.
    #[test]
    fn test_new() {
        let fresh = ProgressTracker::new(100);
        assert_eq!(fresh.total(), 100);
        assert_eq!(fresh.current(), 0);
        assert!(!fresh.is_finished());
    }

    /// Successive `advance` calls accumulate.
    #[test]
    fn test_advance() {
        let progress = ProgressTracker::new(100);
        for (step, expected) in [(10, 10), (20, 30)] {
            progress.advance(step);
            assert_eq!(progress.current(), expected);
        }
    }

    /// `set` overwrites the counter with the given value.
    #[test]
    fn test_set() {
        let progress = ProgressTracker::new(100);
        progress.set(50);
        let observed = progress.current();
        assert_eq!(observed, 50);
    }

    /// `finish` marks the tracker done and jumps the counter to the total.
    #[test]
    fn test_finish() {
        let progress = ProgressTracker::new(100);
        progress.advance(50);
        progress.finish();
        assert!(progress.is_finished());
        assert_eq!(progress.current(), 100);
    }

    /// A quarter of the total reads as 25 percent.
    #[test]
    fn test_percentage() {
        let progress = ProgressTracker::new(100);
        progress.set(25);
        let deviation = (progress.percentage() - 25.0).abs();
        assert!(deviation < 0.001);
    }

    /// The message set on the tracker appears in the next snapshot.
    #[test]
    fn test_set_message() {
        let progress = ProgressTracker::new(100);
        progress.set_message("Processing...");
        let event = progress.snapshot();
        assert_eq!(event.message, "Processing...");
    }

    /// Reaching the total through `set` finishes the tracker automatically.
    #[test]
    fn test_auto_finish() {
        let progress = ProgressTracker::new(100);
        progress.set(100);
        assert!(progress.is_finished());
    }

    /// With debouncing disabled, a subscriber receives the event for an `advance`.
    #[tokio::test]
    async fn test_subscribe() {
        let config = ProgressConfig::default().no_debounce();
        let progress = ProgressTracker::with_config(100, config);
        let mut events = progress.subscribe();

        progress.advance(10);

        let wait = Duration::from_millis(100);
        let received = tokio::time::timeout(wait, events.recv())
            .await
            .expect("timeout")
            .expect("recv error");
        assert_eq!(received.current, 10);
    }
}