atomic_progress/progress.rs
1//! Core primitives for tracking progress state.
2//!
3//! This module defines the [`Progress`] struct, which acts as the central handle for
4//! updates. It is designed around a "Hot/Cold" split to maximize performance in
5//! multi-threaded environments:
6//!
//! * **Hot Data:** Position, Total, and Finished state are stored in atomic primitives.
//!   This allows high-frequency updates (e.g., in tight loops) without lock contention.
//! * **Cold Data:** Metadata like the name, the current item, and error state are guarded by
//!   [`RwLock`](parking_lot::RwLock)s. These are accessed less frequently, typically
//!   only by the rendering thread or when significant state changes occur.
12//!
13//! # Snapshots
14//!
15//! To render progress safely, use [`Progress::snapshot`] to obtain a [`ProgressSnapshot`].
16//! This provides a consistent, immutable view of the progress state at a specific instant,
17//! calculating derived metrics like ETA and throughput automatically.
18
19use std::{
20 sync::{
21 Arc,
22 atomic::{AtomicBool, AtomicU64, Ordering},
23 },
24 time::Duration,
25};
26
27use compact_str::CompactString;
28use parking_lot::RwLock;
29use web_time::Instant;
30
31/// A thread-safe, cloneable handle to a progress indicator.
32///
33/// `Progress` separates "hot" data (position, total, finished status) which are stored in
34/// atomics for high-performance updates, from "cold" data (names, errors, timing) which are
35/// guarded by an [`RwLock`].
36///
37/// Cloning a `Progress` is cheap (Arc bump) and points to the same underlying state.
38#[derive(Clone)]
39pub struct Progress {
40 /// The type of progress indicator (Bar vs Spinner). Immutable after creation.
41 pub(crate) kind: ProgressType,
42
43 /// The instant the progress tracker was created/started.
44 pub(crate) start: Option<Instant>,
45
46 /// Infrequently accessed metadata (name, error state, stop time).
47 pub(crate) cold: Arc<RwLock<Cold>>,
48
49 /// The current "item" being processed (e.g., filename).
50 pub(crate) item: Arc<RwLock<CompactString>>,
51
52 // Atomic fields for wait-free updates on the hot path.
53 pub(crate) position: Arc<AtomicU64>,
54 pub(crate) total: Arc<AtomicU64>,
55 pub(crate) finished: Arc<AtomicBool>,
56}
57
58/// "Cold" storage for metadata that changes infrequently.
59pub struct Cold {
60 pub(crate) name: CompactString,
61 pub(crate) stopped: Option<Instant>,
62 pub(crate) error: Option<CompactString>,
63}
64
/// Rendering hint describing how a progress indicator should be displayed.
///
/// Stored as a single byte (`#[repr(u8)]`); the default is [`ProgressType::Spinner`].
#[repr(u8)]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "rkyv", rkyv(derive(Debug, Eq, PartialEq)))]
pub enum ProgressType {
    /// Indeterminate indicator, for work whose total item count is unknown.
    #[default]
    Spinner,
    /// Determinate bar, for work whose total item count is known.
    Bar,
}
81
82impl Progress {
83 /// Creates a new `Progress` instance.
84 ///
85 /// # Parameters
86 ///
87 /// * `kind`: The type of indicator.
88 /// * `name`: A label for the task.
89 /// * `total`: The total expected count (use 0 for spinners).
90 pub fn new(kind: ProgressType, name: impl Into<CompactString>, total: impl Into<u64>) -> Self {
91 Self {
92 kind,
93 start: None,
94 cold: Arc::new(RwLock::new(Cold {
95 name: name.into(),
96 stopped: None,
97 error: None,
98 })),
99 item: Arc::new(RwLock::new(CompactString::default())),
100 position: Arc::new(AtomicU64::new(0)),
101 total: Arc::new(AtomicU64::new(total.into())),
102 finished: Arc::new(AtomicBool::new(false)),
103 }
104 }
105
106 /// Creates a new generic progress bar with a known total.
107 #[must_use]
108 pub fn new_pb(name: impl Into<CompactString>, total: impl Into<u64>) -> Self {
109 Self::new(ProgressType::Bar, name, total)
110 }
111
112 /// Creates a new spinner (indeterminate progress).
113 #[must_use]
114 pub fn new_spinner(name: impl Into<CompactString>) -> Self {
115 Self::new(ProgressType::Spinner, name, 0u64)
116 }
117
118 // ========================================================================
119 // Metadata Accessors
120 // ========================================================================
121
122 /// Gets the current name/label of the progress task.
123 #[must_use]
124 pub fn get_name(&self) -> CompactString {
125 self.cold.read().name.clone()
126 }
127
128 /// Updates the name/label of the progress task.
129 pub fn set_name(&self, name: impl Into<CompactString>) {
130 self.cold.write().name = name.into();
131 }
132
133 /// Gets the current item description (e.g., currently processing file).
134 #[must_use]
135 pub fn get_item(&self) -> CompactString {
136 self.item.read().clone()
137 }
138
139 /// Updates the current item description.
140 pub fn set_item(&self, item: impl Into<CompactString>) {
141 *self.item.write() = item.into();
142 }
143
144 /// Returns the error message, if one occurred.
145 #[must_use]
146 pub fn get_error(&self) -> Option<CompactString> {
147 self.cold.read().error.clone()
148 }
149
150 /// Sets (or clears) an error message for this task.
151 pub fn set_error(&self, error: Option<impl Into<CompactString>>) {
152 let error = error.map(Into::into);
153 self.cold.write().error = error;
154 }
155
156 // ========================================================================
157 // State & Metrics (Hot Path)
158 // ========================================================================
159
160 /// Increments the progress position by the specified amount.
161 ///
162 /// This uses `Ordering::Relaxed` for maximum performance.
163 pub fn inc(&self, amount: impl Into<u64>) {
164 self.position.fetch_add(amount.into(), Ordering::Relaxed);
165 }
166
167 /// Increments the progress position by 1.
168 pub fn bump(&self) {
169 self.inc(1u64);
170 }
171
172 /// Gets the current position.
173 #[must_use]
174 pub fn get_pos(&self) -> u64 {
175 self.position.load(Ordering::Relaxed)
176 }
177
178 /// Sets the absolute position.
179 pub fn set_pos(&self, pos: u64) {
180 self.position.store(pos, Ordering::Relaxed);
181 }
182
183 /// Gets the total target count.
184 #[must_use]
185 pub fn get_total(&self) -> u64 {
186 self.total.load(Ordering::Relaxed)
187 }
188
189 /// Updates the total target count.
190 pub fn set_total(&self, total: u64) {
191 self.total.store(total, Ordering::Relaxed);
192 }
193
194 /// Checks if the task is marked as finished.
195 #[must_use]
196 pub fn is_finished(&self) -> bool {
197 // Acquire ensures we see any memory writes that happened before the finish flag was set.
198 self.finished.load(Ordering::Acquire)
199 }
200
201 /// Manually sets the finished state.
202 ///
203 /// Prefer using [`finish`](Self::finish), [`finish_with_item`](Self::finish_with_item),
204 /// or [`finish_with_error`](Self::finish_with_error) to ensure timestamps are recorded.
205 pub fn set_finished(&self, finished: bool) {
206 self.finished.store(finished, Ordering::Release);
207 }
208
209 // ========================================================================
210 // Timing & Calculations
211 // ========================================================================
212
213 /// Calculates the duration elapsed since creation.
214 ///
215 /// If the task is finished, this returns the duration between start and finish.
216 /// If never started (no start time recorded), returns `None`.
217 #[must_use]
218 pub fn get_elapsed(&self) -> Option<Duration> {
219 let start = self.start?;
220 let cold = self.cold.read();
221
222 Some(
223 cold.stopped
224 .map_or_else(|| start.elapsed(), |stopped| stopped.duration_since(start)),
225 )
226 }
227
228 /// Returns the current completion percentage (0.0 to 100.0).
229 ///
230 /// Returns `0.0` if `total` is zero.
231 #[allow(clippy::cast_precision_loss)]
232 #[must_use]
233 pub fn get_percent(&self) -> f64 {
234 let pos = self.get_pos() as f64;
235 let total = self.get_total() as f64;
236
237 if total == 0.0 {
238 0.0
239 } else {
240 (pos / total) * 100.0
241 }
242 }
243
244 // ========================================================================
245 // Lifecycle Management
246 // ========================================================================
247
248 /// Marks the task as finished and records the stop time.
249 pub fn finish(&self) {
250 if self.start.is_some() {
251 self.cold.write().stopped.replace(Instant::now());
252 }
253 self.set_finished(true);
254 }
255
256 /// Sets the current item and marks the task as finished.
257 pub fn finish_with_item(&self, item: impl Into<CompactString>) {
258 self.set_item(item);
259 self.finish(); // Calls set_finished(true) internally
260 }
261
262 /// Sets an error message and marks the task as finished.
263 pub fn finish_with_error(&self, error: impl Into<CompactString>) {
264 self.set_error(Some(error));
265 self.finish();
266 }
267
268 // ========================================================================
269 // Advanced / Internal
270 // ========================================================================
271
272 /// Returns a shared reference to the atomic position counter.
273 ///
274 /// Useful for sharing this specific counter with other systems.
275 #[must_use]
276 pub fn atomic_pos(&self) -> Arc<AtomicU64> {
277 self.position.clone()
278 }
279
280 /// Returns a shared reference to the atomic total counter.
281 #[must_use]
282 pub fn atomic_total(&self) -> Arc<AtomicU64> {
283 self.total.clone()
284 }
285
286 /// Creates a consistent snapshot of the current state.
287 ///
288 /// This involves acquiring a read lock on the "cold" data.
289 #[must_use]
290 pub fn snapshot(&self) -> ProgressSnapshot {
291 self.into()
292 }
293}
294
295/// A plain-data snapshot of a [`Progress`] state at a specific point in time.
296///
297/// This is typically used for rendering, as it holds owned data and requires no locking
298/// to access.
299#[derive(Clone, Debug, Default, Eq, PartialEq)]
300#[cfg_attr(
301 feature = "rkyv",
302 derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
303)]
304#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
305#[cfg_attr(feature = "rkyv", rkyv(derive(Debug, Eq, PartialEq)))]
306pub struct ProgressSnapshot {
307 /// The type of progress indicator.
308 pub kind: ProgressType,
309
310 /// The name/label of the progress task.
311 pub name: CompactString,
312 /// The current item description.
313 pub item: CompactString,
314
315 /// The elapsed duration.
316 pub elapsed: Option<Duration>,
317
318 /// The current position.
319 pub position: u64,
320 /// The total target count.
321 pub total: u64,
322
323 /// Whether the task is finished.
324 pub finished: bool,
325
326 /// The associated error message, if any.
327 pub error: Option<CompactString>,
328}
329
330impl From<&Progress> for ProgressSnapshot {
331 fn from(progress: &Progress) -> Self {
332 // Lock cold data once
333 let cold = progress.cold.read();
334 let name = cold.name.clone();
335 let error = cold.error.clone();
336 drop(cold);
337
338 Self {
339 kind: progress.kind,
340 name,
341 item: progress.item.read().clone(),
342 elapsed: progress.get_elapsed(),
343 position: progress.position.load(Ordering::Relaxed),
344 total: progress.total.load(Ordering::Relaxed),
345 finished: progress.finished.load(Ordering::Relaxed),
346 error,
347 }
348 }
349}
350
351impl ProgressSnapshot {
352 /// Returns the type of progress indicator.
353 #[must_use]
354 pub const fn kind(&self) -> ProgressType {
355 self.kind
356 }
357
358 /// Returns the name/label of the progress task.
359 #[must_use]
360 pub fn name(&self) -> &str {
361 &self.name
362 }
363 /// Returns the current item description.
364 #[must_use]
365 pub fn item(&self) -> &str {
366 &self.item
367 }
368
369 /// Returns the elapsed duration.
370 #[must_use]
371 pub const fn elapsed(&self) -> Option<Duration> {
372 self.elapsed
373 }
374
375 /// Returns the current position.
376 #[must_use]
377 pub const fn position(&self) -> u64 {
378 self.position
379 }
380 /// Returns the total target count.
381 #[must_use]
382 pub const fn total(&self) -> u64 {
383 self.total
384 }
385
386 /// Returns whether the task is finished.
387 #[must_use]
388 pub const fn finished(&self) -> bool {
389 self.finished
390 }
391
392 /// Returns the error message, if any.
393 #[must_use]
394 pub fn error(&self) -> Option<&str> {
395 self.error.as_deref()
396 }
397
398 /// Estimates the time remaining (ETA) based on average speed since start.
399 ///
400 /// Returns `None` if:
401 /// * No progress has been made.
402 /// * Total is zero.
403 /// * Process is finished.
404 /// * Elapsed time is effectively zero.
405 #[allow(clippy::cast_precision_loss)]
406 #[must_use]
407 pub fn eta(&self) -> Option<Duration> {
408 if self.position == 0 || self.total == 0 || self.finished {
409 return None;
410 }
411
412 let elapsed = self.elapsed?;
413 let secs = elapsed.as_secs_f64();
414
415 // Avoid division by zero or extremely small intervals
416 if secs <= 1e-6 {
417 return None;
418 }
419
420 let rate = self.position as f64 / secs;
421 if rate <= 0.0 {
422 return None;
423 }
424
425 let remaining_items = self.total.saturating_sub(self.position);
426 let remaining_secs = remaining_items as f64 / rate;
427
428 Some(Duration::from_secs_f64(remaining_secs))
429 }
430
431 /// Calculates the average throughput (items per second) over the entire lifetime.
432 #[allow(clippy::cast_precision_loss)]
433 #[must_use]
434 pub fn throughput(&self) -> f64 {
435 if let Some(elapsed) = self.elapsed {
436 let secs = elapsed.as_secs_f64();
437 if secs > 0.0 {
438 return self.position as f64 / secs;
439 }
440 }
441 0.0
442 }
443
444 /// Calculates the instantaneous throughput relative to a previous snapshot.
445 ///
446 /// This is useful for calculating "current speed" (e.g., in the last second).
447 #[allow(clippy::cast_precision_loss)]
448 #[must_use]
449 pub fn throughput_since(&self, prev: &Self) -> f64 {
450 let pos_diff = self.position.saturating_sub(prev.position) as f64;
451
452 let time_diff = match (self.elapsed, prev.elapsed) {
453 (Some(curr), Some(old)) => curr.as_secs_f64() - old.as_secs_f64(),
454 _ => 0.0,
455 };
456
457 if time_diff > 0.0 {
458 pos_diff / time_diff
459 } else {
460 0.0
461 }
462 }
463}
464
#[cfg(test)]
mod tests {
    use std::{thread, time::Duration};

    use super::{Progress, ProgressSnapshot, ProgressType};

    /// Basic Lifecycle
    /// Verifies the fundamental state machine: New -> Inc -> Finish.
    #[test]
    #[allow(clippy::float_cmp)] // exact values: 50/100 * 100 is representable
    fn test_basic_lifecycle() {
        let p = Progress::new_pb("test_job", 100u64);

        assert_eq!(p.get_pos(), 0);
        assert!(!p.is_finished());
        assert_eq!(p.get_percent(), 0.0);

        p.inc(50u64);
        assert_eq!(p.get_pos(), 50);
        assert_eq!(p.get_percent(), 50.0);

        p.finish();
        assert!(p.is_finished());

        // Default constructor does not start the timer; elapsed should be None.
        assert!(p.get_elapsed().is_none());
    }

    /// Concurrency & Atomics
    /// Ensures that high-contention updates from multiple threads are lossless.
    #[test]
    fn test_concurrency_atomics() {
        let p = Progress::new_spinner("concurrent_job");

        // Spawn 10 threads, each incrementing 100 times.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let p_ref = p.clone();
                thread::spawn(move || {
                    for _ in 0..100 {
                        p_ref.inc(1u64);
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(p.get_pos(), 1000, "Atomic updates should be lossless");
    }

    /// Snapshot Metadata
    /// Verifies that "Cold" data (names, errors) propagates to snapshots correctly.
    #[test]
    fn test_snapshot_metadata() {
        let p = Progress::new_pb("initial_name", 100u64);

        // Mutate cold state
        p.set_name("updated_name");
        p.set_item("file_a.txt");
        p.set_error(Some("disk_full"));

        let snap = p.snapshot();

        assert_eq!(snap.name, "updated_name");
        assert_eq!(snap.item, "file_a.txt");
        assert_eq!(snap.error, Some("disk_full".into()));

        // Clearing the error requires naming the item type of the `None`.
        p.set_error(None::<&str>);
        assert!(p.get_error().is_none());
        assert!(p.snapshot().error.is_none());
    }

    /// Finish Variants
    /// Verifies that `finish_with_*` records the payload and sets the finished flag.
    #[test]
    fn test_finish_variants() {
        let p = Progress::new_pb("job", 10u64);
        p.finish_with_error("boom");
        assert!(p.is_finished());
        assert_eq!(p.get_error().as_deref(), Some("boom"));

        let q = Progress::new_spinner("job2");
        q.finish_with_item("last.txt");
        assert!(q.is_finished());
        assert_eq!(q.get_item(), "last.txt");
        assert!(q.snapshot().finished);
    }

    /// Throughput & ETA Safety
    /// Verifies edge-case safety (NaN/Inf checks) when no time has elapsed.
    #[allow(clippy::float_cmp)] // exact sentinel value 0.0
    #[test]
    fn test_math_safety() {
        let p = Progress::new_pb("math_test", 100u64);
        let snap = p.snapshot();

        // Edge case: No time elapsed, no progress
        assert_eq!(snap.throughput(), 0.0);
        assert!(snap.eta().is_none());

        // A zero total handles the percentage gracefully.
        let p_zero = Progress::new_pb("zero_total", 0u64);
        assert_eq!(p_zero.get_percent(), 0.0);
    }

    /// ETA & Throughput Math
    /// Snapshots are plain data, so time is "mocked" by constructing one directly with a
    /// chosen `elapsed` value — no sleeping required.
    #[allow(clippy::float_cmp)] // all rates below are exactly representable
    #[test]
    fn test_snapshot_rates() {
        let prev = ProgressSnapshot {
            kind: ProgressType::Bar,
            position: 10,
            total: 100,
            elapsed: Some(Duration::from_secs(2)),
            ..ProgressSnapshot::default()
        };
        let curr = ProgressSnapshot {
            position: 25,
            elapsed: Some(Duration::from_secs(5)),
            ..prev.clone()
        };

        // 25 items in 5s => 5 items/s; 75 remaining => 15s.
        assert_eq!(curr.throughput(), 5.0);
        assert_eq!(curr.eta(), Some(Duration::from_secs(15)));

        // 15 items over the last 3s => 5 items/s.
        assert_eq!(curr.throughput_since(&prev), 5.0);
        // Comparing against a later snapshot never yields a negative speed.
        assert_eq!(prev.throughput_since(&curr), 0.0);

        // Finished tasks have no ETA.
        let done = ProgressSnapshot {
            finished: true,
            ..curr.clone()
        };
        assert!(done.eta().is_none());

        // Unknown elapsed time: no throughput and no ETA.
        let untimed = ProgressSnapshot {
            elapsed: None,
            ..curr
        };
        assert_eq!(untimed.throughput(), 0.0);
        assert!(untimed.eta().is_none());
    }
}