use std::{
sync::{
Arc,
atomic::{AtomicBool, AtomicU64, Ordering},
},
time::Duration,
};
use compact_str::CompactString;
use parking_lot::RwLock;
use web_time::Instant;
/// A clonable handle to the state of one progress indicator.
///
/// The name, item, error, stop time, position, total and finished flag live behind `Arc`s,
/// so every clone observes updates made through any other clone. `kind` and `start` are
/// stored by value and are therefore copied, not shared, when the handle is cloned.
#[derive(Clone)]
pub struct Progress {
    /// Spinner or bar; there is no setter for this in `impl Progress`.
    pub(crate) kind: ProgressType,
    /// Timing origin used by `get_elapsed`. `Progress::new` leaves this `None`;
    /// NOTE(review): presumably populated elsewhere in the crate when the job starts — confirm.
    pub(crate) start: Option<Instant>,
    /// Name, stop time and error, grouped behind a single lock.
    pub(crate) cold: Arc<RwLock<Cold>>,
    /// Free-text label for the item currently being processed (see `set_item`).
    pub(crate) item: Arc<RwLock<CompactString>>,
    /// Units completed so far; updated lock-free via `inc`/`set_pos`.
    pub(crate) position: Arc<AtomicU64>,
    /// Expected number of units; `new_spinner` sets this to 0.
    pub(crate) total: Arc<AtomicU64>,
    /// Set by `finish`/`set_finished` (Release store) and read by `is_finished` (Acquire load).
    pub(crate) finished: Arc<AtomicBool>,
}
/// Progress metadata that `Progress` keeps together behind one `RwLock`.
pub struct Cold {
    /// Name of the tracked job (see `Progress::set_name`).
    pub(crate) name: CompactString,
    /// Instant recorded by `Progress::finish`; only set when a start time exists.
    pub(crate) stopped: Option<Instant>,
    /// Error message, if one was recorded (see `Progress::set_error`).
    pub(crate) error: Option<CompactString>,
}
/// Kind of progress indicator.
///
/// `#[repr(u8)]` together with the rkyv/serde derives makes the variant order part of the
/// encoded form — append new variants, never reorder existing ones.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "rkyv", rkyv(derive(Debug, Eq, PartialEq)))]
pub enum ProgressType {
    /// Default kind; `Progress::new_spinner` creates this with a total of 0.
    #[default]
    Spinner,
    /// Created by `Progress::new_pb` with a caller-supplied total.
    Bar,
}
impl std::fmt::Debug for Progress {
    /// Prints kind, start, name, item, position, total, finished and error.
    ///
    /// `stopped` is not printed. Each atomic is sampled on its own with `Relaxed`
    /// ordering, so under concurrent updates the printed values are not guaranteed to
    /// describe a single instant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let position = self.position.load(Ordering::Relaxed);
        let total = self.total.load(Ordering::Relaxed);
        let finished = self.finished.load(Ordering::Relaxed);
        let meta = self.cold.read();
        let current_item = self.item.read();

        let mut builder = f.debug_struct("Progress");
        builder.field("kind", &self.kind);
        builder.field("start", &self.start);
        builder.field("name", &meta.name);
        builder.field("item", &*current_item);
        builder.field("position", &position);
        builder.field("total", &total);
        builder.field("finished", &finished);
        builder.field("error", &meta.error);
        builder.finish()
    }
}
impl Progress {
    /// Creates a progress handle of the given `kind`.
    ///
    /// The handle starts at position 0, unfinished, with an empty item, no error and no
    /// start time, so [`Progress::get_elapsed`] returns `None` until `start` is populated.
    #[must_use]
    pub fn new(kind: ProgressType, name: impl Into<CompactString>, total: impl Into<u64>) -> Self {
        Self {
            kind,
            start: None,
            cold: Arc::new(RwLock::new(Cold {
                name: name.into(),
                stopped: None,
                error: None,
            })),
            item: Arc::new(RwLock::new(CompactString::default())),
            position: Arc::new(AtomicU64::new(0)),
            total: Arc::new(AtomicU64::new(total.into())),
            finished: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Creates a [`ProgressType::Bar`] expecting `total` units.
    #[must_use]
    pub fn new_pb(name: impl Into<CompactString>, total: impl Into<u64>) -> Self {
        Self::new(ProgressType::Bar, name, total)
    }

    /// Creates a [`ProgressType::Spinner`] with a total of 0.
    #[must_use]
    pub fn new_spinner(name: impl Into<CompactString>) -> Self {
        Self::new(ProgressType::Spinner, name, 0u64)
    }

    /// Returns a copy of the current name.
    #[must_use]
    pub fn get_name(&self) -> CompactString {
        self.cold.read().name.clone()
    }

    /// Replaces the name; visible through every clone of this handle.
    pub fn set_name(&self, name: impl Into<CompactString>) {
        self.cold.write().name = name.into();
    }

    /// Returns a copy of the current item label.
    #[must_use]
    pub fn get_item(&self) -> CompactString {
        self.item.read().clone()
    }

    /// Replaces the item label; visible through every clone of this handle.
    pub fn set_item(&self, item: impl Into<CompactString>) {
        *self.item.write() = item.into();
    }

    /// Returns a copy of the recorded error message, if any.
    #[must_use]
    pub fn get_error(&self) -> Option<CompactString> {
        self.cold.read().error.clone()
    }

    /// Sets (`Some`) or clears (`None`) the error message.
    ///
    /// A bare `None` cannot infer the `impl Into` type; clear with `set_error(None::<&str>)`.
    pub fn set_error(&self, error: Option<impl Into<CompactString>>) {
        // Convert before taking the write lock so the lock is held only for the store.
        let error = error.map(Into::into);
        self.cold.write().error = error;
    }

    /// Advances the position by `amount`.
    ///
    /// Lock-free and safe to call concurrently from many clones. Like
    /// `AtomicU64::fetch_add`, it wraps around on overflow.
    pub fn inc(&self, amount: impl Into<u64>) {
        self.position.fetch_add(amount.into(), Ordering::Relaxed);
    }

    /// Advances the position by one.
    pub fn bump(&self) {
        self.inc(1u64);
    }

    /// Returns the current position.
    #[must_use]
    pub fn get_pos(&self) -> u64 {
        self.position.load(Ordering::Relaxed)
    }

    /// Overwrites the position.
    pub fn set_pos(&self, pos: u64) {
        self.position.store(pos, Ordering::Relaxed);
    }

    /// Returns the expected total.
    #[must_use]
    pub fn get_total(&self) -> u64 {
        self.total.load(Ordering::Relaxed)
    }

    /// Overwrites the expected total.
    pub fn set_total(&self, total: u64) {
        self.total.store(total, Ordering::Relaxed);
    }

    /// Whether the progress has been marked finished.
    ///
    /// The `Acquire` load pairs with the `Release` store in [`Progress::set_finished`].
    #[must_use]
    pub fn is_finished(&self) -> bool {
        self.finished.load(Ordering::Acquire)
    }

    /// Sets or clears the finished flag (`Release` store).
    pub fn set_finished(&self, finished: bool) {
        self.finished.store(finished, Ordering::Release);
    }

    /// Time since `start`, or `None` when no start time was recorded.
    ///
    /// Once [`Progress::finish`] has recorded a stop time, this is the fixed span from
    /// `start` to that stop time; otherwise it is measured up to now. Takes a read lock
    /// on `cold`, so do not call it while already holding a `cold` guard.
    #[must_use]
    pub fn get_elapsed(&self) -> Option<Duration> {
        let start = self.start?;
        let cold = self.cold.read();
        Some(
            cold.stopped
                .map_or_else(|| start.elapsed(), |stopped| stopped.duration_since(start)),
        )
    }

    /// Completion as a percentage: `position / total * 100`.
    ///
    /// Returns `0.0` when `total` is 0 (e.g. spinners). The value is not clamped, so it
    /// exceeds 100 if the position overshoots the total.
    #[allow(clippy::cast_precision_loss)] // u64 -> f64 rounding is irrelevant for a percentage
    #[must_use]
    pub fn get_percent(&self) -> f64 {
        let pos = self.get_pos();
        let total = self.get_total();
        if total == 0 {
            return 0.0;
        }
        (pos as f64 / total as f64) * 100.0
    }

    /// Marks the progress finished.
    ///
    /// If a start time exists, the current instant is first stored as the stop time
    /// (replacing any earlier one), which freezes [`Progress::get_elapsed`].
    pub fn finish(&self) {
        if self.start.is_some() {
            self.cold.write().stopped.replace(Instant::now());
        }
        self.set_finished(true);
    }

    /// Sets the item label, then marks the progress finished.
    pub fn finish_with_item(&self, item: impl Into<CompactString>) {
        self.set_item(item);
        self.finish();
    }

    /// Records `error`, then marks the progress finished.
    pub fn finish_with_error(&self, error: impl Into<CompactString>) {
        self.set_error(Some(error));
        self.finish();
    }

    /// Returns a shared handle to the position counter for direct atomic updates.
    #[must_use]
    pub fn atomic_pos(&self) -> Arc<AtomicU64> {
        self.position.clone()
    }

    /// Returns a shared handle to the total counter for direct atomic updates.
    #[must_use]
    pub fn atomic_total(&self) -> Arc<AtomicU64> {
        self.total.clone()
    }

    /// Captures the current state as a [`ProgressSnapshot`].
    #[must_use]
    pub fn snapshot(&self) -> ProgressSnapshot {
        self.into()
    }
}
/// Owned, point-in-time copy of a [`Progress`], built by `Progress::snapshot` or `From<&Progress>`.
///
/// Field order determines the rkyv archived layout and the encoding in order-sensitive
/// serde formats; do not reorder fields.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "rkyv", rkyv(derive(Debug, Eq, PartialEq)))]
pub struct ProgressSnapshot {
    /// Kind of the source progress.
    pub kind: ProgressType,
    /// Name at snapshot time.
    pub name: CompactString,
    /// Item label at snapshot time.
    pub item: CompactString,
    /// Result of `Progress::get_elapsed` at snapshot time (`None` without a start time).
    pub elapsed: Option<Duration>,
    /// Position at snapshot time.
    pub position: u64,
    /// Total at snapshot time.
    pub total: u64,
    /// Finished flag at snapshot time.
    pub finished: bool,
    /// Error message at snapshot time, if any.
    pub error: Option<CompactString>,
}
impl From<&Progress> for ProgressSnapshot {
    /// Captures a point-in-time copy of `progress`.
    ///
    /// `finished` is loaded first, with `Acquire` ordering matching the `Release` store in
    /// `Progress::set_finished`. A snapshot that reports `finished == true` therefore also
    /// sees everything written before `set_finished(true)`: the item, error, stop time and
    /// position set by the finishing thread. Other fields are sampled one at a time, so an
    /// unfinished snapshot may mix values from slightly different moments.
    fn from(progress: &Progress) -> Self {
        let finished = progress.finished.load(Ordering::Acquire);
        // Copy out of `cold` and release its lock before `get_elapsed` re-acquires it:
        // parking_lot's RwLock is task-fair, so a recursive read can deadlock when a
        // writer is queued.
        let (name, error) = {
            let cold = progress.cold.read();
            (cold.name.clone(), cold.error.clone())
        };
        Self {
            kind: progress.kind,
            name,
            item: progress.item.read().clone(),
            elapsed: progress.get_elapsed(),
            position: progress.position.load(Ordering::Relaxed),
            total: progress.total.load(Ordering::Relaxed),
            finished,
            error,
        }
    }
}
impl ProgressSnapshot {
    /// Kind of the source progress.
    #[must_use]
    pub const fn kind(&self) -> ProgressType {
        self.kind
    }

    /// Name at snapshot time.
    #[must_use]
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Item label at snapshot time.
    #[must_use]
    pub fn item(&self) -> &str {
        &self.item
    }

    /// Elapsed time at snapshot time, or `None` when the progress had no start time.
    #[must_use]
    pub const fn elapsed(&self) -> Option<Duration> {
        self.elapsed
    }

    /// Position at snapshot time.
    #[must_use]
    pub const fn position(&self) -> u64 {
        self.position
    }

    /// Total at snapshot time.
    #[must_use]
    pub const fn total(&self) -> u64 {
        self.total
    }

    /// Whether the progress was finished at snapshot time.
    #[must_use]
    pub const fn finished(&self) -> bool {
        self.finished
    }

    /// Error message at snapshot time, if any.
    #[must_use]
    pub fn error(&self) -> Option<&str> {
        self.error.as_deref()
    }

    /// Estimated time remaining, extrapolated from the average rate so far.
    ///
    /// Returns `None` when no meaningful estimate exists: nothing done yet, a total of 0,
    /// already finished, no elapsed time (or at most 1 µs), or an estimate too large to fit
    /// in a `Duration`. Returns `Some(Duration::ZERO)` once the position reaches the total.
    #[allow(clippy::cast_precision_loss)] // estimate only; u64 -> f64 rounding is acceptable
    #[must_use]
    pub fn eta(&self) -> Option<Duration> {
        if self.position == 0 || self.total == 0 || self.finished {
            return None;
        }
        let secs = self.elapsed?.as_secs_f64();
        // Too little elapsed time for a meaningful rate.
        if secs <= 1e-6 {
            return None;
        }
        // position > 0 and secs > 1e-6, so the rate is finite and strictly positive.
        let rate = self.position as f64 / secs;
        let remaining_items = self.total.saturating_sub(self.position);
        let remaining_secs = remaining_items as f64 / rate;
        // A slow rate against a huge remainder can exceed `Duration::MAX`;
        // `from_secs_f64` would panic there, so report "unknown" instead.
        Duration::try_from_secs_f64(remaining_secs).ok()
    }

    /// Average units per second since start; `0.0` when elapsed is unknown or zero.
    #[allow(clippy::cast_precision_loss)] // rate only; u64 -> f64 rounding is acceptable
    #[must_use]
    pub fn throughput(&self) -> f64 {
        if let Some(elapsed) = self.elapsed {
            let secs = elapsed.as_secs_f64();
            if secs > 0.0 {
                return self.position as f64 / secs;
            }
        }
        0.0
    }

    /// Units per second between `prev` and `self`.
    ///
    /// Returns `0.0` if either snapshot lacks an elapsed time or `self` is not later than
    /// `prev`. A position that went backwards counts as zero progress.
    #[allow(clippy::cast_precision_loss)] // rate only; u64 -> f64 rounding is acceptable
    #[must_use]
    pub fn throughput_since(&self, prev: &Self) -> f64 {
        let pos_diff = self.position.saturating_sub(prev.position) as f64;
        let time_diff = match (self.elapsed, prev.elapsed) {
            (Some(curr), Some(old)) => curr.as_secs_f64() - old.as_secs_f64(),
            _ => 0.0,
        };
        if time_diff > 0.0 {
            pos_diff / time_diff
        } else {
            0.0
        }
    }
}
#[cfg(test)]
mod tests {
    use std::{thread, time::Duration};

    use super::{Progress, ProgressSnapshot};

    #[test]
    #[allow(clippy::float_cmp)] // 0/100 and 50/100 percentages are exactly representable
    fn test_basic_lifecycle() {
        let p = Progress::new_pb("test_job", 100u64);
        assert_eq!(p.get_pos(), 0);
        assert!(!p.is_finished());
        assert_eq!(p.get_percent(), 0.0);
        p.inc(50u64);
        assert_eq!(p.get_pos(), 50);
        assert_eq!(p.get_percent(), 50.0);
        p.finish();
        assert!(p.is_finished());
        // `new_pb` records no start time, so there is nothing to measure.
        assert!(p.get_elapsed().is_none());
    }

    #[test]
    fn test_concurrency_atomics() {
        let p = Progress::new_spinner("concurrent_job");
        let mut handles = vec![];
        for _ in 0..10 {
            let p_ref = p.clone();
            handles.push(thread::spawn(move || {
                for _ in 0..100 {
                    p_ref.inc(1u64);
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(p.get_pos(), 1000, "Atomic updates should be lossless");
    }

    #[test]
    fn test_snapshot_metadata() {
        let p = Progress::new_pb("initial_name", 100u64);
        p.set_name("updated_name");
        p.set_item("file_a.txt");
        p.set_error(Some("disk_full"));
        let snap = p.snapshot();
        assert_eq!(snap.name, "updated_name");
        assert_eq!(snap.item, "file_a.txt");
        assert_eq!(snap.error, Some("disk_full".into()));
    }

    #[test]
    fn test_finish_variants_and_error_clearing() {
        let p = Progress::new_pb("job", 10u64);
        p.finish_with_error("boom");
        assert!(p.is_finished());
        assert_eq!(p.get_error().as_deref(), Some("boom"));
        p.set_error(None::<&str>);
        assert!(p.get_error().is_none());

        let q = Progress::new_spinner("job2");
        q.finish_with_item("done");
        assert!(q.is_finished());
        assert_eq!(q.get_item(), "done");
        let snap = q.snapshot();
        assert!(snap.finished());
        assert_eq!(snap.item(), "done");
    }

    #[allow(clippy::float_cmp)] // 0.0 is returned literally by the guarded paths
    #[test]
    fn test_math_safety() {
        let p = Progress::new_pb("math_test", 100u64);
        let snap = p.snapshot();
        assert_eq!(snap.throughput(), 0.0);
        assert!(snap.eta().is_none());
        let p_zero = Progress::new_pb("zero_total", 0u64);
        assert_eq!(p_zero.get_percent(), 0.0);
    }

    #[allow(clippy::float_cmp)] // inputs chosen so every rate is exactly representable
    #[test]
    fn test_rate_math() {
        let snap = ProgressSnapshot {
            elapsed: Some(Duration::from_secs(10)),
            position: 50,
            total: 100,
            ..ProgressSnapshot::default()
        };
        assert_eq!(snap.eta(), Some(Duration::from_secs(10)));
        assert_eq!(snap.throughput(), 5.0);

        let prev = ProgressSnapshot {
            elapsed: Some(Duration::from_secs(5)),
            position: 20,
            total: 100,
            ..ProgressSnapshot::default()
        };
        assert_eq!(snap.throughput_since(&prev), 6.0);
        // Comparing against a later snapshot is not a valid interval.
        assert_eq!(prev.throughput_since(&snap), 0.0);

        let done = ProgressSnapshot {
            finished: true,
            ..snap.clone()
        };
        assert!(done.eta().is_none());
    }
}